[Core] Replaced JsonDB with Config (#770)

This commit is contained in:
Will 2017-05-27 22:28:59 -04:00 committed by Twentysix
parent a8745297dc
commit 3988fbbc09
23 changed files with 1298 additions and 380 deletions

12
.travis.yml Normal file
View File

@ -0,0 +1,12 @@
language: python
python:
- "3.5.3"
- "3.6.1"
install:
- pip install -r requirements.txt
script:
- python -m compileall ./cogs
- python -m pytest
cache: pip
notifications:
email: false

0
cogs/__init__.py Normal file
View File

View File

@ -0,0 +1 @@
from core.config import Config

View File

@ -1,6 +1,7 @@
from discord.ext import commands
from collections import Counter
from core.settings import CoreDB
from core import Config
from enum import Enum
import os
@ -8,17 +9,33 @@ import os
class Red(commands.Bot):
def __init__(self, cli_flags, **kwargs):
self._shutdown_mode = ExitCodes.CRITICAL
self.db = CoreDB("core/data/settings.json",
relative_path=False)
self.db = Config.get_core_conf(force_registration=True)
self.db.register_global(
token=None,
prefix=[],
packages=[],
coowners=[],
whitelist=[],
blacklist=[]
)
self.db.register_guild(
prefix=[],
whitelist=[],
blacklist=[],
admin_role=None,
mod_role=None
)
def prefix_manager(bot, message):
if not cli_flags.prefix:
global_prefix = self.db.get_global("prefix", [])
global_prefix = self.db.prefix()
else:
global_prefix = cli_flags.prefix
if message.guild is None:
return global_prefix
server_prefix = self.db.get(message.guild, "prefix", [])
server_prefix = self.db.guild(message.guild).prefix()
return server_prefix if server_prefix else global_prefix
if "command_prefix" not in kwargs:
@ -30,7 +47,7 @@ class Red(commands.Bot):
async def is_owner(self, user, allow_coowners=True):
if allow_coowners:
if user.id in self.db.get_global("coowners", []):
if user.id in self.db.coowners():
return True
return await super().is_owner(user)
@ -65,7 +82,7 @@ class Red(commands.Bot):
for package in self.extensions:
if package.startswith("cogs."):
loaded.append(package)
await self.db.set_global("packages", loaded)
await self.db.set("packages", loaded)
class ExitCodes(Enum):

View File

@ -1,3 +1,4 @@
import discord
from discord.ext import commands
@ -23,8 +24,12 @@ def mod_or_permissions(**perms):
if ctx.guild is None:
return has_perms_or_is_owner
author = ctx.author
mod_role = ctx.bot.db.get_mod_role(ctx.guild)
admin_role = ctx.bot.db.get_admin_role(ctx.guild)
mod_role_id = ctx.bot.db.guild(ctx.guild).mod_role()
admin_role_id = ctx.bot.db.guild(ctx.guild).admin_role()
mod_role = discord.utils.get(ctx.guild.roles, id=mod_role_id)
admin_role = discord.utils.get(ctx.guild.roles, id=admin_role_id)
is_staff = mod_role in author.roles or admin_role in author.roles
is_guild_owner = author == ctx.guild.owner
@ -40,7 +45,7 @@ def admin_or_permissions(**perms):
return has_perms_or_is_owner
author = ctx.author
is_guild_owner = author == ctx.guild.owner
admin_role = ctx.bot.db.get_admin_role(ctx.guild)
admin_role = ctx.bot.db.guild(ctx.guild).admin_role()
return admin_role in author.roles or has_perms_or_is_owner or is_guild_owner

View File

@ -1,3 +1,4 @@
import argparse
import asyncio
@ -19,7 +20,7 @@ def interactive_config(red, token_set, prefix_set):
print("That doesn't look like a valid token.")
token = ""
if token:
loop.run_until_complete(red.db.set_global("token", token))
loop.run_until_complete(red.db.set("token", token))
if not prefix_set:
prefix = ""
@ -36,6 +37,50 @@ def interactive_config(red, token_set, prefix_set):
if not confirm("> "):
prefix = ""
if prefix:
loop.run_until_complete(red.db.set_global("prefix", [prefix]))
loop.run_until_complete(red.db.set("prefix", [prefix]))
return token
def parse_cli_flags():
parser = argparse.ArgumentParser(description="Red - Discord Bot")
parser.add_argument("--owner", help="ID of the owner. Only who hosts "
"Red should be owner, this has "
"security implications")
parser.add_argument("--prefix", "-p", action="append",
help="Global prefix. Can be multiple")
parser.add_argument("--no-prompt",
action="store_true",
help="Disables console inputs. Features requiring "
"console interaction could be disabled as a "
"result")
parser.add_argument("--no-cogs",
action="store_true",
help="Starts Red with no cogs loaded, only core")
parser.add_argument("--self-bot",
action='store_true',
help="Specifies if Red should log in as selfbot")
parser.add_argument("--not-bot",
action='store_true',
help="Specifies if the token used belongs to a bot "
"account.")
parser.add_argument("--dry-run",
action="store_true",
help="Makes Red quit with code 0 just before the "
"login. This is useful for testing the boot "
"process.")
parser.add_argument("--debug",
action="store_true",
help="Sets the loggers level as debug")
parser.add_argument("--dev",
action="store_true",
help="Enables developer mode")
args = parser.parse_args()
if args.prefix:
args.prefix = sorted(args.prefix, reverse=True)
else:
args.prefix = []
return args

521
core/config.py Normal file
View File

@ -0,0 +1,521 @@
from pathlib import Path
from core.drivers.red_json import JSON as JSONDriver
from core.drivers.red_mongo import Mongo
import logging
from typing import Callable
log = logging.getLogger("red.config")
class BaseConfig:
def __init__(self, cog_name, unique_identifier, driver_spawn, force_registration=False,
hash_uuid=True, collection="GLOBAL", collection_uuid=None,
defaults={}):
self.cog_name = cog_name
if hash_uuid:
self.uuid = str(hash(unique_identifier))
else:
self.uuid = unique_identifier
self.driver_spawn = driver_spawn
self._driver = None
self.collection = collection
self.collection_uuid = collection_uuid
self.force_registration = force_registration
try:
self.driver.maybe_add_ident(self.uuid)
except AttributeError:
pass
self.driver_getmap = {
"GLOBAL": self.driver.get_global,
"GUILD": self.driver.get_guild,
"CHANNEL": self.driver.get_channel,
"ROLE": self.driver.get_role,
"USER": self.driver.get_user
}
self.driver_setmap = {
"GLOBAL": self.driver.set_global,
"GUILD": self.driver.set_guild,
"CHANNEL": self.driver.set_channel,
"ROLE": self.driver.set_role,
"USER": self.driver.set_user
}
self.curr_key = None
self.unsettable_keys = ("cog_name", "cog_identifier", "_id",
"guild_id", "channel_id", "role_id",
"user_id", "uuid")
self.invalid_keys = (
"driver_spawn",
"_driver", "collection",
"collection_uuid", "force_registration"
)
self.defaults = defaults if defaults else {
"GLOBAL": {}, "GUILD": {}, "CHANNEL": {}, "ROLE": {},
"MEMBER": {}, "USER": {}}
@classmethod
def get_conf(cls, cog_instance: object, unique_identifier: int=0,
force_registration: bool=False):
"""
Gets a config object that cog's can use to safely store data. The
backend to this is totally modular and can easily switch between
JSON and a DB. However, when changed, all data will likely be lost
unless cogs write some converters for their data.
Positional Arguments:
cog_instance - The cog `self` object, can be passed in from your
cog's __init__ method.
Keyword Arguments:
unique_identifier - a random integer or string that is used to
differentiate your cog from any other named the same. This way we
can safely store data for multiple cogs that are named the same.
YOU SHOULD USE THIS.
force_registration - A flag which will cause the Config object to
throw exceptions if you try to get/set data keys that you have
not pre-registered. I highly recommend you ENABLE this as it
will help reduce dumb typo errors.
"""
url = None # TODO: get mongo url
port = None # TODO: get mongo port
def spawn_mongo_driver():
return Mongo(url, port)
# TODO: Determine which backend users want, default to JSON
cog_name = cog_instance.__class__.__name__
driver_spawn = JSONDriver(cog_name)
return cls(cog_name=cog_name, unique_identifier=unique_identifier,
driver_spawn=driver_spawn, force_registration=force_registration)
@classmethod
def get_core_conf(cls, force_registration: bool=False):
core_data_path = Path.cwd() / 'core' / '.data'
driver_spawn = JSONDriver("Core", data_path_override=core_data_path)
return cls(cog_name="Core", driver_spawn=driver_spawn,
unique_identifier=0,
force_registration=force_registration)
@property
def driver(self):
if self._driver is None:
try:
self._driver = self.driver_spawn()
except TypeError:
return self.driver_spawn
return self._driver
def __getattr__(self, key):
"""This should be used to return config key data as determined by
`self.collection` and `self.collection_uuid`."""
raise NotImplemented
def __setattr__(self, key, value):
if 'defaults' in self.__dict__: # Necessary to let the cog load
restricted = list(self.defaults[self.collection].keys()) + \
list(self.unsettable_keys)
if key in restricted:
raise ValueError("Not allowed to dynamically set attributes of"
" unsettable_keys: {}".format(restricted))
else:
self.__dict__[key] = value
else:
self.__dict__[key] = value
def clear(self):
"""Clears all values in the current context ONLY."""
raise NotImplemented
def set(self, key, value):
"""This should set config key with value `value` in the
corresponding collection as defined by `self.collection` and
`self.collection_uuid`."""
raise NotImplemented
def guild(self, guild):
"""This should return a `BaseConfig` instance with the corresponding
`collection` and `collection_uuid`."""
raise NotImplemented
def channel(self, channel):
"""This should return a `BaseConfig` instance with the corresponding
`collection` and `collection_uuid`."""
raise NotImplemented
def role(self, role):
"""This should return a `BaseConfig` instance with the corresponding
`collection` and `collection_uuid`."""
raise NotImplemented
def member(self, member):
"""This should return a `BaseConfig` instance with the corresponding
`collection` and `collection_uuid`."""
raise NotImplemented
def user(self, user):
"""This should return a `BaseConfig` instance with the corresponding
`collection` and `collection_uuid`."""
raise NotImplemented
def register_global(self, **global_defaults):
"""
Registers a new dict of global defaults. This function should
be called EVERY TIME the cog loads (aka just do it in
__init__)!
:param global_defaults: Each key should be the key you want to
access data by and the value is the default value of that
key.
:return:
"""
for k, v in global_defaults.items():
try:
self._register_global(k, v)
except KeyError:
log.exception("Bad default global key.")
def _register_global(self, key, default=None):
"""Registers a global config key `key`"""
if key in self.unsettable_keys:
raise KeyError("Attempt to use restricted key: '{}'".format(key))
elif not key.isidentifier():
raise RuntimeError("Invalid key name, must be a valid python variable"
" name.")
self.defaults["GLOBAL"][key] = default
def register_guild(self, **guild_defaults):
"""
Registers a new dict of guild defaults. This function should
be called EVERY TIME the cog loads (aka just do it in
__init__)!
:param guild_defaults: Each key should be the key you want to
access data by and the value is the default value of that
key.
:return:
"""
for k, v in guild_defaults.items():
try:
self._register_guild(k, v)
except KeyError:
log.exception("Bad default guild key.")
def _register_guild(self, key, default=None):
"""Registers a guild config key `key`"""
if key in self.unsettable_keys:
raise KeyError("Attempt to use restricted key: '{}'".format(key))
elif not key.isidentifier():
raise RuntimeError("Invalid key name, must be a valid python variable"
" name.")
self.defaults["GUILD"][key] = default
def register_channel(self, **channel_defaults):
"""
Registers a new dict of channel defaults. This function should
be called EVERY TIME the cog loads (aka just do it in
__init__)!
:param channel_defaults: Each key should be the key you want to
access data by and the value is the default value of that
key.
:return:
"""
for k, v in channel_defaults.items():
try:
self._register_channel(k, v)
except KeyError:
log.exception("Bad default channel key.")
def _register_channel(self, key, default=None):
"""Registers a channel config key `key`"""
if key in self.unsettable_keys:
raise KeyError("Attempt to use restricted key: '{}'".format(key))
elif not key.isidentifier():
raise RuntimeError("Invalid key name, must be a valid python variable"
" name.")
self.defaults["CHANNEL"][key] = default
def register_role(self, **role_defaults):
"""
Registers a new dict of role defaults. This function should
be called EVERY TIME the cog loads (aka just do it in
__init__)!
:param role_defaults: Each key should be the key you want to
access data by and the value is the default value of that
key.
:return:
"""
for k, v in role_defaults.items():
try:
self._register_role(k, v)
except KeyError:
log.exception("Bad default role key.")
def _register_role(self, key, default=None):
"""Registers a role config key `key`"""
if key in self.unsettable_keys:
raise KeyError("Attempt to use restricted key: '{}'".format(key))
elif not key.isidentifier():
raise RuntimeError("Invalid key name, must be a valid python variable"
" name.")
self.defaults["ROLE"][key] = default
def register_member(self, **member_defaults):
"""
Registers a new dict of member defaults. This function should
be called EVERY TIME the cog loads (aka just do it in
__init__)!
:param member_defaults: Each key should be the key you want to
access data by and the value is the default value of that
key.
:return:
"""
for k, v in member_defaults.items():
try:
self._register_member(k, v)
except KeyError:
log.exception("Bad default member key.")
def _register_member(self, key, default=None):
"""Registers a member config key `key`"""
if key in self.unsettable_keys:
raise KeyError("Attempt to use restricted key: '{}'".format(key))
elif not key.isidentifier():
raise RuntimeError("Invalid key name, must be a valid python variable"
" name.")
self.defaults["MEMBER"][key] = default
def register_user(self, **user_defaults):
"""
Registers a new dict of user defaults. This function should
be called EVERY TIME the cog loads (aka just do it in
__init__)!
:param user_defaults: Each key should be the key you want to
access data by and the value is the default value of that
key.
:return:
"""
for k, v in user_defaults.items():
try:
self._register_user(k, v)
except KeyError:
log.exception("Bad default user key.")
def _register_user(self, key, default=None):
"""Registers a user config key `key`"""
if key in self.unsettable_keys:
raise KeyError("Attempt to use restricted key: '{}'".format(key))
elif not key.isidentifier():
raise RuntimeError("Invalid key name, must be a valid python variable"
" name.")
self.defaults["USER"][key] = default
class Config(BaseConfig):
"""
Config object created by `Config.get_conf()`
This configuration object is designed to make backend data
storage mechanisms pluggable. It also is designed to
help a cog developer make fewer mistakes (such as
typos) when dealing with cog data and to make those mistakes
apparent much faster in the design process.
It also has the capability to safely store data between cogs
that share the same name.
There are two main components to this config object. First,
you have the ability to get data on a level specific basis.
The seven levels available are: global, guild, channel, role,
member, user, and misc.
The second main component is registering default values for
data in each of the levels. This functionality is OPTIONAL
and must be explicitly enabled when creating the Config object
using the kwarg `force_registration=True`.
Basic Usage:
Creating a Config object:
Use the `Config.get_conf()` class method to create new
Config objects.
See the `Config.get_conf()` documentation for more
information.
Registering Default Values (optional):
You can register default values for data at all levels
EXCEPT misc.
Simply pass in the key/value pairs as keyword arguments to
the respective function.
e.g.: conf_obj.register_global(enabled=True)
conf_obj.register_guild(likes_red=True)
Retrieving data by attributes:
Since I registered the "enabled" key in the previous example
at the global level I can now do:
conf_obj.enabled()
which will retrieve the current value of the "enabled"
key, making use of the default of "True". I can also do
the same for the guild key "likes_red":
conf_obj.guild(guild_obj).likes_red()
If I elected to not register default values, you can provide them
when you try to access the key:
conf_obj.no_default(default=True)
However if you do not provide a default and you do not register
defaults, accessing the attribute will return "None".
Saving data:
This is accomplished by using the `set` function available at
every level.
e.g.: conf_obj.set("enabled", False)
conf_obj.guild(guild_obj).set("likes_red", False)
If `force_registration` was enabled when the config object
was created you will only be allowed to save keys that you
have registered.
Misc data is special, use `conf.misc()` and `conf.set_misc(value)`
respectively.
"""
def __getattr__(self, key) -> Callable:
"""
Until I've got a better way to do this I'm just gonna fake __call__
:param key:
:return: lambda function with kwarg
"""
return self._get_value_from_key(key)
def _get_value_from_key(self, key) -> Callable:
try:
default = self.defaults[self.collection][key]
except KeyError as e:
if self.force_registration:
raise AttributeError("Key '{}' not registered!".format(key)) from e
default = None
self.curr_key = key
if self.collection != "MEMBER":
ret = lambda default=default: self.driver_getmap[self.collection](
self.cog_name, self.uuid, self.collection_uuid, key,
default=default)
else:
mid, sid = self.collection_uuid
ret = lambda default=default: self.driver.get_member(
self.cog_name, self.uuid, mid, sid, key,
default=default)
return ret
def get(self, key, default=None):
"""
Included as an alternative to registering defaults.
:param key:
:param default:
:return:
"""
try:
return getattr(self, key)(default=default)
except AttributeError:
return
async def set(self, key, value):
# Notice to future developers:
# This code was commented to allow users to set keys without having to register them.
# That being said, if they try to get keys without registering them
# things will blow up. I do highly recommend enforcing the key registration.
if key in self.unsettable_keys or key in self.invalid_keys:
raise KeyError("Restricted key name, please use another.")
if self.force_registration and key not in self.defaults[self.collection]:
raise AttributeError("Key '{}' not registered!".format(key))
if not key.isidentifier():
raise RuntimeError("Invalid key name, must be a valid python variable"
" name.")
if self.collection == "GLOBAL":
await self.driver.set_global(self.cog_name, self.uuid, key, value)
elif self.collection == "MEMBER":
mid, sid = self.collection_uuid
await self.driver.set_member(self.cog_name, self.uuid, mid, sid,
key, value)
elif self.collection in self.driver_setmap:
func = self.driver_setmap[self.collection]
await func(self.cog_name, self.uuid, self.collection_uuid, key, value)
async def clear(self):
await self.driver_setmap[self.collection](
self.cog_name, self.uuid, self.collection_uuid, None, None,
clear=True)
def guild(self, guild):
new = type(self)(self.cog_name, self.uuid, self.driver,
hash_uuid=False, defaults=self.defaults)
new.collection = "GUILD"
new.collection_uuid = guild.id
new._driver = None
return new
def channel(self, channel):
new = type(self)(self.cog_name, self.uuid, self.driver,
hash_uuid=False, defaults=self.defaults)
new.collection = "CHANNEL"
new.collection_uuid = channel.id
new._driver = None
return new
def role(self, role):
new = type(self)(self.cog_name, self.uuid, self.driver,
hash_uuid=False, defaults=self.defaults)
new.collection = "ROLE"
new.collection_uuid = role.id
new._driver = None
return new
def member(self, member):
guild = member.guild
new = type(self)(self.cog_name, self.uuid, self.driver,
hash_uuid=False, defaults=self.defaults)
new.collection = "MEMBER"
new.collection_uuid = (member.id, guild.id)
new._driver = None
return new
def user(self, user):
new = type(self)(self.cog_name, self.uuid, self.driver,
hash_uuid=False, defaults=self.defaults)
new.collection = "USER"
new.collection_uuid = user.id
new._driver = None
return new

0
core/drivers/__init__.py Normal file
View File

45
core/drivers/red_base.py Normal file
View File

@ -0,0 +1,45 @@
class BaseDriver:
def get_global(self, cog_name, ident, collection_id, key, *, default=None):
raise NotImplementedError()
def get_guild(self, cog_name, ident, guild_id, key, *, default=None):
raise NotImplementedError()
def get_channel(self, cog_name, ident, channel_id, key, *, default=None):
raise NotImplementedError()
def get_role(self, cog_name, ident, role_id, key, *, default=None):
raise NotImplementedError()
def get_member(self, cog_name, ident, user_id, guild_id, key, *,
default=None):
raise NotImplementedError()
def get_user(self, cog_name, ident, user_id, key, *, default=None):
raise NotImplementedError()
def get_misc(self, cog_name, ident, *, default=None):
raise NotImplementedError()
async def set_global(self, cog_name, ident, key, value, clear=False):
raise NotImplementedError()
async def set_guild(self, cog_name, ident, guild_id, key, value, clear=False):
raise NotImplementedError()
async def set_channel(self, cog_name, ident, channel_id, key, value,
clear=False):
raise NotImplementedError()
async def set_role(self, cog_name, ident, role_id, key, value, clear=False):
raise NotImplementedError()
async def set_member(self, cog_name, ident, user_id, guild_id, key, value,
clear=False):
raise NotImplementedError()
async def set_user(self, cog_name, ident, user_id, key, value, clear=False):
raise NotImplementedError()
async def set_misc(self, cog_name, ident, value, clear=False):
raise NotImplementedError()

135
core/drivers/red_json.py Normal file
View File

@ -0,0 +1,135 @@
from core.json_io import JsonIO
import os
from .red_base import BaseDriver
from pathlib import Path
class JSON(BaseDriver):
def __init__(self, cog_name, *args, data_path_override: Path=None,
file_name_override: str="settings.json", **kwargs):
self.cog_name = cog_name
self.file_name = file_name_override
if data_path_override:
self.data_path = data_path_override
else:
self.data_path = Path.cwd() / 'cogs' / '.data' / self.cog_name
self.data_path.mkdir(parents=True, exist_ok=True)
self.data_path = self.data_path / self.file_name
self.jsonIO = JsonIO(self.data_path)
try:
self.data = self.jsonIO._load_json()
except FileNotFoundError:
self.data = {}
def maybe_add_ident(self, ident: str):
if ident in self.data:
return
self.data[ident] = {}
for k in ("GLOBAL", "GUILD", "CHANNEL", "ROLE", "MEMBER", "USER"):
if k not in self.data[ident]:
self.data[ident][k] = {}
self.jsonIO._save_json(self.data)
def get_global(self, cog_name, ident, _, key, *, default=None):
return self.data[ident]["GLOBAL"].get(key, default)
def get_guild(self, cog_name, ident, guild_id, key, *, default=None):
guilddata = self.data[ident]["GUILD"].get(str(guild_id), {})
return guilddata.get(key, default)
def get_channel(self, cog_name, ident, channel_id, key, *, default=None):
channeldata = self.data[ident]["CHANNEL"].get(str(channel_id), {})
return channeldata.get(key, default)
def get_role(self, cog_name, ident, role_id, key, *, default=None):
roledata = self.data[ident]["ROLE"].get(str(role_id), {})
return roledata.get(key, default)
def get_member(self, cog_name, ident, user_id, guild_id, key, *,
default=None):
userdata = self.data[ident]["MEMBER"].get(str(user_id), {})
guilddata = userdata.get(str(guild_id), {})
return guilddata.get(key, default)
def get_user(self, cog_name, ident, user_id, key, *, default=None):
userdata = self.data[ident]["USER"].get(str(user_id), {})
return userdata.get(key, default)
async def set_global(self, cog_name, ident, key, value, clear=False):
if clear:
self.data[ident]["GLOBAL"] = {}
else:
self.data[ident]["GLOBAL"][key] = value
await self.jsonIO._threadsafe_save_json(self.data)
async def set_guild(self, cog_name, ident, guild_id, key, value, clear=False):
guild_id = str(guild_id)
if clear:
self.data[ident]["GUILD"][guild_id] = {}
else:
try:
self.data[ident]["GUILD"][guild_id][key] = value
except KeyError:
self.data[ident]["GUILD"][guild_id] = {}
self.data[ident]["GUILD"][guild_id][key] = value
await self.jsonIO._threadsafe_save_json(self.data)
async def set_channel(self, cog_name, ident, channel_id, key, value, clear=False):
channel_id = str(channel_id)
if clear:
self.data[ident]["CHANNEL"][channel_id] = {}
else:
try:
self.data[ident]["CHANNEL"][channel_id][key] = value
except KeyError:
self.data[ident]["CHANNEL"][channel_id] = {}
self.data[ident]["CHANNEL"][channel_id][key] = value
await self.jsonIO._threadsafe_save_json(self.data)
async def set_role(self, cog_name, ident, role_id, key, value, clear=False):
role_id = str(role_id)
if clear:
self.data[ident]["ROLE"][role_id] = {}
else:
try:
self.data[ident]["ROLE"][role_id][key] = value
except KeyError:
self.data[ident]["ROLE"][role_id] = {}
self.data[ident]["ROLE"][role_id][key] = value
await self.jsonIO._threadsafe_save_json(self.data)
async def set_member(self, cog_name, ident, user_id, guild_id, key, value, clear=False):
user_id = str(user_id)
guild_id = str(guild_id)
if clear:
self.data[ident]["MEMBER"][user_id] = {}
else:
try:
self.data[ident]["MEMBER"][user_id][guild_id][key] = value
except KeyError:
if user_id not in self.data[ident]["MEMBER"]:
self.data[ident]["MEMBER"][user_id] = {}
if guild_id not in self.data[ident]["MEMBER"][user_id]:
self.data[ident]["MEMBER"][user_id][guild_id] = {}
self.data[ident]["MEMBER"][user_id][guild_id][key] = value
await self.jsonIO._threadsafe_save_json(self.data)
async def set_user(self, cog_name, ident, user_id, key, value, clear=False):
user_id = str(user_id)
if clear:
self.data[ident]["USER"][user_id] = {}
else:
try:
self.data[ident]["USER"][user_id][key] = value
except KeyError:
self.data[ident]["USER"][user_id] = {}
self.data[ident]["USER"][user_id][key] = value
await self.jsonIO._threadsafe_save_json(self.data)

211
core/drivers/red_mongo.py Normal file
View File

@ -0,0 +1,211 @@
import pymongo as m
from .red_base import BaseDriver
class RedMongoException(Exception):
"""Base Red Mongo Exception class"""
pass
class MultipleMatches(RedMongoException):
"""Raised when multiple documents match a single cog_name and
cog_identifier pair."""
pass
class MissingCollection(RedMongoException):
"""Raised when a collection is missing from the mongo db"""
pass
class Mongo(BaseDriver):
def __init__(self, host, port=27017, admin_user=None, admin_pass=None,
**kwargs):
self.conn = m.MongoClient(host=host, port=port, **kwargs)
self.admin_user = admin_user
self.admin_pass = admin_pass
self._db = self.conn.red
if self.admin_user is not None and self.admin_pass is not None:
self._db.authenticate(self.admin_user, self.admin_pass)
self._global = self._db.GLOBAL
self._guild = self._db.GUILD
self._channel = self._db.CHANNEL
self._role = self._db.ROLE
self._member = self._db.MEMBER
self._user = self._db.USER
def get_global(self, cog_name, cog_identifier, _, key, *, default=None):
doc = self._global.find(
{"cog_name": cog_name, "cog_identifier": cog_identifier},
projection=[key, ], batch_size=2)
if doc.count() == 2:
raise MultipleMatches("Too many matching documents at the GLOBAL"
" level: ({}, {})".format(cog_name,
cog_identifier))
elif doc.count() == 1:
return doc[0].get(key, default)
return default
def get_guild(self, cog_name, cog_identifier, guild_id, key, *,
default=None):
doc = self._guild.find(
{"cog_name": cog_name, "cog_identifier": cog_identifier,
"guild_id": guild_id},
projection=[key, ], batch_size=2)
if doc.count() == 2:
raise MultipleMatches("Too many matching documents at the GUILD"
" level: ({}, {}, {})".format(
cog_name, cog_identifier, guild_id))
elif doc.count() == 1:
return doc[0].get(key, default)
return default
def get_channel(self, cog_name, cog_identifier, channel_id, key, *,
default=None):
doc = self._channel.find(
{"cog_name": cog_name, "cog_identifier": cog_identifier,
"channel_id": channel_id},
projection=[key, ], batch_size=2)
if doc.count() == 2:
raise MultipleMatches("Too many matching documents at the CHANNEL"
" level: ({}, {}, {})".format(
cog_name, cog_identifier, channel_id))
elif doc.count() == 1:
return doc[0].get(key, default)
return default
def get_role(self, cog_name, cog_identifier, role_id, key, *,
default=None):
doc = self._role.find(
{"cog_name": cog_name, "cog_identifier": cog_identifier,
"role_id": role_id},
projection=[key, ], batch_size=2)
if doc.count() == 2:
raise MultipleMatches("Too many matching documents at the ROLE"
" level: ({}, {}, {})".format(
cog_name, cog_identifier, role_id))
elif doc.count() == 1:
return doc[0].get(key, default)
return default
def get_member(self, cog_name, cog_identifier, user_id, guild_id, key, *,
default=None):
doc = self._member.find(
{"cog_name": cog_name, "cog_identifier": cog_identifier,
"user_id": user_id, "guild_id": guild_id},
projection=[key, ], batch_size=2)
if doc.count() == 2:
raise MultipleMatches("Too many matching documents at the MEMBER"
" level: ({}, {}, mid {}, sid {})".format(
cog_name, cog_identifier, user_id,
guild_id))
elif doc.count() == 1:
return doc[0].get(key, default)
return default
def get_user(self, cog_name, cog_identifier, user_id, key, *,
default=None):
doc = self._user.find(
{"cog_name": cog_name, "cog_identifier": cog_identifier,
"user_id": user_id},
projection=[key, ], batch_size=2)
if doc.count() == 2:
raise MultipleMatches("Too many matching documents at the USER"
" level: ({}, {}, mid {})".format(
cog_name, cog_identifier, user_id))
elif doc.count() == 1:
return doc[0].get(key, default)
else:
return default
def set_global(self, cog_name, cog_identifier, key, value, clear=False):
filter = {"cog_name": cog_name, "cog_identifier": cog_identifier}
data = {"$set": {key: value}}
if self._global.count(filter) > 1:
raise MultipleMatches("Too many matching documents at the GLOBAL"
" level: ({}, {})".format(cog_name,
cog_identifier))
else:
if clear:
self._global.delete_one(filter)
else:
self._global.update_one(filter, data, upsert=True)
def set_guild(self, cog_name, cog_identifier, guild_id, key, value,
clear=False):
filter = {"cog_name": cog_name, "cog_identifier": cog_identifier,
"guild_id": guild_id}
data = {"$set": {key: value}}
if self._guild.count(filter) > 1:
raise MultipleMatches("Too many matching documents at the GUILD"
" level: ({}, {}, {})".format(
cog_name, cog_identifier, guild_id))
else:
if clear:
self._guild.delete_one(filter)
else:
self._guild.update_one(filter, data, upsert=True)
def set_channel(self, cog_name, cog_identifier, channel_id, key, value,
clear=False):
filter = {"cog_name": cog_name, "cog_identifier": cog_identifier,
"channel_id": channel_id}
data = {"$set": {key: value}}
if self._channel.count(filter) > 1:
raise MultipleMatches("Too many matching documents at the CHANNEL"
" level: ({}, {}, {})".format(
cog_name, cog_identifier, channel_id))
else:
if clear:
self._channel.delete_one(filter)
else:
self._channel.update_one(filter, data, upsert=True)
def set_role(self, cog_name, cog_identifier, role_id, key, value,
clear=False):
filter = {"cog_name": cog_name, "cog_identifier": cog_identifier,
"role_id": role_id}
data = {"$set": {key: value}}
if self._role.count(filter) > 1:
raise MultipleMatches("Too many matching documents at the ROLE"
" level: ({}, {}, {})".format(
cog_name, cog_identifier, role_id))
else:
if clear:
self._role.delete_one(filter)
else:
self._role.update_one(filter, data, upsert=True)
def set_member(self, cog_name, cog_identifier, user_id, guild_id, key,
value, clear=False):
filter = {"cog_name": cog_name, "cog_identifier": cog_identifier,
"guild_id": guild_id, "user_id": user_id}
data = {"$set": {key: value}}
if self._member.count(filter) > 1:
raise MultipleMatches("Too many matching documents at the MEMBER"
" level: ({}, {}, mid {}, sid {})".format(
cog_name, cog_identifier, user_id,
guild_id))
else:
if clear:
self._member.delete_one(filter)
else:
self._member.update_one(filter, data, upsert=True)
def set_user(self, cog_name, cog_identifier, user_id, key, value,
clear=False):
filter = {"cog_name": cog_name, "cog_identifier": cog_identifier,
"user_id": user_id}
data = {"$set": {key: value}}
if self._user.count(filter) > 1:
raise MultipleMatches("Too many matching documents at the USER"
" level: ({}, {}, mid {})".format(
cog_name, cog_identifier, user_id))
else:
if clear:
self._user.delete_one(filter)
else:
self._user.update_one(filter, data, upsert=True)

View File

@ -32,7 +32,7 @@ def init_events(bot, cli_flags):
if cli_flags.no_cogs is False:
print("Loading packages...")
failed = []
packages = bot.db.get_global("packages", [])
packages = bot.db.packages()
for package in packages:
try:
@ -69,7 +69,7 @@ def init_events(bot, cli_flags):
print("\nInvite URL: {}\n".format(invite_url))
@bot.event
async def on_command_error(error, ctx):
async def on_command_error(ctx, error):
if isinstance(error, commands.MissingRequiredArgument):
await bot.send_cmd_help(ctx)
elif isinstance(error, commands.BadArgument):

View File

@ -1,3 +1,6 @@
from discord.ext import commands
def init_global_checks(bot):
@bot.check
@ -5,20 +8,19 @@ def init_global_checks(bot):
if await bot.is_owner(ctx.author):
return True
if bot.db.get_global("whitelist", []):
return ctx.author.id in bot.db.get_global("whitelist", [])
if bot.db.whitelist():
return ctx.author.id in bot.db.whitelist()
return ctx.author.id not in bot.db.get_global("blacklist", [])
return ctx.author.id not in bot.db.blacklist()
@bot.check
async def local_perms(ctx):
async def local_perms(ctx: commands.Context):
if await bot.is_owner(ctx.author):
return True
elif ctx.message.guild is None:
return True
guild_perms = bot.db.get_all(ctx.guild, {})
local_blacklist = guild_perms.get("blacklist", [])
local_whitelist = guild_perms.get("whitelist", [])
local_blacklist = bot.db.guild(ctx.guild).blacklist()
local_whitelist = bot.db.guild(ctx.guild).whitelist()
if local_whitelist:
return ctx.author.id in local_whitelist

View File

@ -7,7 +7,7 @@ from uuid import uuid4
# This is basically our old DataIO, except that it's now threadsafe
# and just a base for much more elaborate classes
from pathlib import Path
log = logging.getLogger("red")
@ -17,25 +17,33 @@ MINIFIED = {"sort_keys": True, "separators": (',', ':')}
class JsonIO:
"""Basic functions for atomic saving / loading of json files"""
_lock = asyncio.Lock()
def __init__(self, path: Path=Path.cwd()):
"""
:param path: Full path to file.
"""
self._lock = asyncio.Lock()
self.path = path
def _save_json(self, path, data, settings=PRETTY):
log.debug("Saving file {}".format(path))
filename, _ = os.path.splitext(path)
# noinspection PyUnresolvedReferences
def _save_json(self, data, settings=PRETTY):
log.debug("Saving file {}".format(self.path))
filename = self.path.stem
tmp_file = "{}-{}.tmp".format(filename, uuid4().fields[0])
with open(tmp_file, encoding="utf-8", mode="w") as f:
tmp_path = self.path.parent / tmp_file
with tmp_path.open(encoding="utf-8", mode="w") as f:
json.dump(data, f, **settings)
os.replace(tmp_file, path)
tmp_path.replace(self.path)
async def _threadsafe_save_json(self, path, data, settings=PRETTY):
async def _threadsafe_save_json(self, data, settings=PRETTY):
loop = asyncio.get_event_loop()
func = functools.partial(self._save_json, path, data, settings)
func = functools.partial(self._save_json, data, settings)
with await self._lock:
await loop.run_in_executor(None, func)
def _load_json(self, path):
log.debug("Reading file {}".format(path))
with open(path, encoding='utf-8', mode="r") as f:
# noinspection PyUnresolvedReferences
def _load_json(self):
log.debug("Reading file {}".format(self.path))
with self.path.open(encoding='utf-8', mode="r") as f:
data = json.load(f)
return data

View File

@ -1,124 +0,0 @@
from core.utils.helpers import JsonGuildDB
import discord
import argparse
class CoreDB(JsonGuildDB):
"""
The central DB used by Red to store a variety
of settings, both global and guild specific
"""
def get_admin_role(self, guild):
"""Returns the guild's admin role
Returns None if not set or if the role
couldn't be retrieved"""
_id = self.get_all(guild, {}).get("admin_role", None)
return discord.utils.get(guild.roles, id=_id)
def get_mod_role(self, guild):
"""Returns the guild's mod role
Returns None if not set or if the role
couldn't be retrieved"""
_id = self.get_all(guild, {}).get("mod_role", None)
return discord.utils.get(guild.roles, id=_id)
async def set_admin_role(self, role):
"""Sets the admin role for the guild"""
if not isinstance(role, discord.Role):
raise TypeError("A valid Discord role must be passed.")
await self.set(role.guild, "admin_role", role.id)
async def set_mod_role(self, role):
"""Sets the mod role for the guild"""
if not isinstance(role, discord.Role):
raise TypeError("A valid Discord role must be passed.")
await self.set(role.guild, "mod_role", role.id)
def get_global_whitelist(self):
"""Returns the global whitelist"""
return self.get_global("whitelist", [])
def get_global_blacklist(self):
"""Returns the global whitelist"""
return self.get_global("blacklist", [])
async def set_global_whitelist(self, whitelist):
"""Sets the global whitelist"""
if not isinstance(list, whitelist):
raise TypeError("A list of IDs must be passed.")
await self.set_global("whitelist", whitelist)
async def set_global_blacklist(self, blacklist):
"""Sets the global blacklist"""
if not isinstance(list, blacklist):
raise TypeError("A list of IDs must be passed.")
await self.set_global("blacklist", blacklist)
def get_guild_whitelist(self, guild):
"""Returns the guild's whitelist"""
return self.get(guild, "whitelist", [])
def get_guild_blacklist(self, guild):
"""Returns the guild's blacklist"""
return self.get(guild, "blacklist", [])
async def set_guild_whitelist(self, guild, whitelist):
"""Sets the guild's whitelist"""
if not isinstance(guild, discord.Guild) or not isinstance(whitelist, list):
raise TypeError("A valid Discord guild and a list of IDs "
"must be passed.")
await self.set(guild, "whitelist", whitelist)
async def set_guild_blacklist(self, guild, blacklist):
"""Sets the guild's blacklist"""
if not isinstance(guild, discord.Guild) or not isinstance(blacklist, list):
raise TypeError("A valid Discord guild and a list of IDs "
"must be passed.")
await self.set(guild, "blacklist", blacklist)
def parse_cli_flags():
parser = argparse.ArgumentParser(description="Red - Discord Bot")
parser.add_argument("--owner", help="ID of the owner. Only who hosts "
"Red should be owner, this has "
"security implications")
parser.add_argument("--prefix", "-p", action="append",
help="Global prefix. Can be multiple")
parser.add_argument("--no-prompt",
action="store_true",
help="Disables console inputs. Features requiring "
"console interaction could be disabled as a "
"result")
parser.add_argument("--no-cogs",
action="store_true",
help="Starts Red with no cogs loaded, only core")
parser.add_argument("--self-bot",
action='store_true',
help="Specifies if Red should log in as selfbot")
parser.add_argument("--not-bot",
action='store_true',
help="Specifies if the token used belongs to a bot "
"account.")
parser.add_argument("--dry-run",
action="store_true",
help="Makes Red quit with code 0 just before the "
"login. This is useful for testing the boot "
"process.")
parser.add_argument("--debug",
action="store_true",
help="Sets the loggers level as debug")
parser.add_argument("--dev",
action="store_true",
help="Enables developer mode")
args = parser.parse_args()
if args.prefix:
args.prefix = sorted(args.prefix, reverse=True)
else:
args.prefix = []
return args

View File

@ -1,216 +0,0 @@
import os
import discord
import asyncio
import functools
import inspect
from collections import defaultdict
from core.json_io import JsonIO
GLOBAL_KEY = '__global__'
SENTINEL = object()
class JsonDB(JsonIO):
"""
A DB-like helper class to streamline the saving of json files
Parameters:
file_path: str
The path of the json file you want to create / access
create_dirs: bool=True
If True, it will create any missing directory leading to
the file you want to create
relative_path: bool=True
The file_path you specified is relative to the path from which
you're instantiating this object from
i.e. If you're in a package's folder and your file_path is
'data/settings.json', these files will be created inside
the package's folder and not Red's root folder
default_value: Optional=None
Same behaviour as a defaultdict
"""
_caller = ""
def __init__(self, file_path, **kwargs):
local = kwargs.pop("relative_path", True)
if local and not self._caller:
self._caller = self._get_caller_path()
create_dirs = kwargs.pop("create_dirs", True)
default_value = kwargs.pop("default_value", SENTINEL)
self.autosave = kwargs.pop("autosave", False)
self.path = os.path.join(self._caller, file_path)
file_exists = os.path.isfile(self.path)
if create_dirs and not file_exists:
path, _ = os.path.split(self.path)
if path:
try:
os.makedirs(path)
except FileExistsError:
pass
if file_exists:
# Might be worth looking into threadsafe ways for very large files
self._data = self._load_json(self.path)
else:
self._data = {}
self._blocking_save()
if default_value is not SENTINEL:
def _get_default():
return default_value
self._data = defaultdict(_get_default, self._data)
self._loop = asyncio.get_event_loop()
self._task = functools.partial(self._threadsafe_save_json, self._data)
async def set(self, key, value):
"""Sets a DB's entry"""
self._data[key] = value
await self.save()
def get(self, key, default=None):
"""Returns a DB's entry"""
return self._data.get(key, default)
async def remove(self, key):
"""Removes a DB's entry"""
del self._data[key]
await self.save()
async def pop(self, key, default=None):
"""Removes and returns a DB's entry"""
value = self._data.pop(key, default)
await self.save()
return value
async def wipe(self):
"""Wipes DB"""
self._data = {}
await self.save()
def all(self):
"""Returns all DB's data"""
return self._data
def _blocking_save(self):
"""Using this should be avoided. Let's stick to threadsafe saves"""
self._save_json(self.path, self._data)
async def save(self):
"""Threadsafe save to file"""
await self._threadsafe_save_json(self.path, self._data)
def _get_caller_path(self):
frame = inspect.stack()[2]
module = inspect.getmodule(frame[0])
abspath = os.path.abspath(module.__file__)
return os.path.dirname(abspath)
def __contains__(self, key):
return key in self._data
def __getitem__(self, key):
return self._data[key]
def __len__(self):
return len(self._data)
def __repr__(self):
return "<{} {}>".format(self.__class__.__name__, self._data)
class JsonGuildDB(JsonDB):
"""
A DB-like helper class to streamline the saving of json files
This is a variant of JsonDB that allows for guild specific data
Global data is still allowed with dedicated methods
Same parameters as JsonDB
"""
def __init__(self, *args, **kwargs):
local = kwargs.get("relative_path", True)
if local and not self._caller:
self._caller = self._get_caller_path()
super().__init__(*args, **kwargs)
async def set(self, guild, key, value):
"""Sets a guild's entry"""
if not isinstance(guild, discord.Guild):
raise TypeError('Can only set guild data')
if str(guild.id) not in self._data:
self._data[str(guild.id)] = {}
self._data[str(guild.id)][key] = value
await self.save()
def get(self, guild, key, default=None):
"""Returns a guild's entry"""
if not isinstance(guild, discord.Guild):
raise TypeError('Can only get guild data')
if str(guild.id) not in self._data:
return default
return self._data[str(guild.id)].get(key, default)
async def remove(self, guild, key):
"""Removes a guild's entry"""
if not isinstance(guild, discord.Guild):
raise TypeError('Can only remove guild data')
if str(guild.id) not in self._data:
raise KeyError('Guild data is not present')
del self._data[str(guild.id)][key]
await self.save()
async def pop(self, guild, key, default=None):
"""Removes and returns a guild's entry"""
if not isinstance(guild, discord.Guild):
raise TypeError('Can only remove guild data')
value = self._data.get(str(guild.id), {}).pop(key, default)
await self.save()
return value
def get_all(self, guild, default=None):
"""Returns all entries of a guild"""
if not isinstance(guild, discord.Guild):
raise TypeError('Can only get guild data')
return self._data.get(str(guild.id), default)
async def remove_all(self, guild):
"""Removes all entries of a guild"""
if not isinstance(guild, discord.Guild):
raise TypeError('Can only remove guilds')
await super().remove(str(guild.id))
async def set_global(self, key, value):
"""Sets a global value"""
if GLOBAL_KEY not in self._data:
self._data[GLOBAL_KEY] = {}
self._data[GLOBAL_KEY][key] = value
await self.save()
def get_global(self, key, default=None):
"""Gets a global value"""
if GLOBAL_KEY not in self._data:
self._data[GLOBAL_KEY] = {}
return self._data[GLOBAL_KEY].get(key, default)
async def remove_global(self, key):
"""Removes a global value"""
if GLOBAL_KEY not in self._data:
self._data[GLOBAL_KEY] = {}
del self._data[GLOBAL_KEY][key]
await self.save()
async def pop_global(self, key, default=None):
"""Removes and returns a global value"""
if GLOBAL_KEY not in self._data:
self._data[GLOBAL_KEY] = {}
value = self._data[GLOBAL_KEY].pop(key, default)
await self.save()
return value

11
main.py
View File

@ -1,8 +1,7 @@
from core.bot import Red, ExitCodes
from core.global_checks import init_global_checks
from core.events import init_events
from core.settings import parse_cli_flags
from core.cli import interactive_config, confirm
from core.cli import interactive_config, confirm, parse_cli_flags
from core.core_commands import Core
from core.dev_commands import Dev
import asyncio
@ -65,8 +64,8 @@ if __name__ == '__main__':
if cli_flags.dev:
red.add_cog(Dev())
token = os.environ.get("RED_TOKEN", red.db.get_global("token", None))
prefix = cli_flags.prefix or red.db.get_global("prefix", [])
token = os.environ.get("RED_TOKEN", red.db.token())
prefix = cli_flags.prefix or red.db.prefix()
if token is None or not prefix:
if cli_flags.no_prompt is False:
@ -89,11 +88,11 @@ if __name__ == '__main__':
"a user account, remember that the --not-bot flag "
"must be used. For self-bot functionalities instead, "
"--self-bot")
db_token = red.db.get_global("token")
db_token = red.db.token()
if db_token and not cli_flags.no_prompt:
print("\nDo you want to reset the token? (y/n)")
if confirm("> "):
loop.run_until_complete(red.db.remove_global("token"))
loop.run_until_complete(red.db.set("token", ""))
print("Token has been reset.")
except KeyboardInterrupt:
log.info("Keyboard interrupt detected. Quitting...")

View File

@ -1,2 +1,5 @@
git+https://github.com/Rapptz/discord.py@rewrite#egg=discord.py[voice]
youtube_dl
pytest
git+https://github.com/pytest-dev/pytest-asyncio
pymongo

0
tests/__init__.py Normal file
View File

101
tests/conftest.py Normal file
View File

@ -0,0 +1,101 @@
from collections import namedtuple
from pathlib import Path
import pytest
import random
from core.bot import Red
from core.drivers import red_json
from core import Config
@pytest.fixture(scope="module")
def json_driver(tmpdir_factory):
driver = red_json.JSON(
"PyTest",
data_path_override=Path(str(tmpdir_factory.getbasetemp()))
)
return driver
@pytest.fixture(scope="module")
def config(json_driver):
return Config(
cog_name="PyTest",
unique_identifier=0,
driver_spawn=json_driver)
@pytest.fixture(scope="module")
def config_fr(json_driver):
"""
Mocked config object with force_register enabled.
"""
return Config(
cog_name="PyTest",
unique_identifier=0,
driver_spawn=json_driver,
force_registration=True
)
#region Dpy Mocks
@pytest.fixture(scope="module")
def empty_guild():
mock_guild = namedtuple("Guild", "id members")
return mock_guild(random.randint(1, 999999999), [])
@pytest.fixture(scope="module")
def empty_channel():
mock_channel = namedtuple("Channel", "id")
return mock_channel(random.randint(1, 999999999))
@pytest.fixture(scope="module")
def empty_role():
mock_role = namedtuple("Role", "id")
return mock_role(random.randint(1, 999999999))
@pytest.fixture(scope="module")
def empty_member(empty_guild):
mock_member = namedtuple("Member", "id guild")
return mock_member(random.randint(1, 999999999), empty_guild)
@pytest.fixture(scope="module")
def empty_user():
mock_user = namedtuple("User", "id")
return mock_user(random.randint(1, 999999999))
@pytest.fixture(scope="module")
def empty_message():
mock_msg = namedtuple("Message", "content")
return mock_msg("No content.")
@pytest.fixture(scope="module")
def ctx(empty_member, empty_channel, red):
mock_ctx = namedtuple("Context", "author guild channel message bot")
return mock_ctx(empty_member, empty_member.guild, empty_channel,
empty_message, red)
#endregion
#region Red Mock
@pytest.fixture
def red(monkeypatch, config_fr, event_loop):
from core.cli import parse_cli_flags
cli_flags = parse_cli_flags()
description = "Red v3 - Alpha"
monkeypatch.setattr("core.config.Config.get_core_conf",
lambda *args, **kwargs: config_fr)
red = Red(cli_flags, description=description, pm_help=None,
loop=event_loop)
return red
#endregion

0
tests/core/__init__.py Normal file
View File

147
tests/core/test_config.py Normal file
View File

@ -0,0 +1,147 @@
import pytest
#region Register Tests
def test_config_register_global(config):
config.register_global(enabled=False)
assert config.defaults["GLOBAL"]["enabled"] is False
assert config.enabled() is False
def test_config_register_global_badvalues(config):
with pytest.raises(RuntimeError):
config.register_global(**{"invalid var name": True})
def test_config_register_guild(config, empty_guild):
config.register_guild(enabled=False, some_list=[], some_dict={})
assert config.defaults["GUILD"]["enabled"] is False
assert config.defaults["GUILD"]["some_list"] == []
assert config.defaults["GUILD"]["some_dict"] == {}
assert config.guild(empty_guild).enabled() is False
assert config.guild(empty_guild).some_list() == []
assert config.guild(empty_guild).some_dict() == {}
def test_config_register_channel(config, empty_channel):
config.register_channel(enabled=False)
assert config.defaults["CHANNEL"]["enabled"] is False
assert config.channel(empty_channel).enabled() is False
def test_config_register_role(config, empty_role):
config.register_role(enabled=False)
assert config.defaults["ROLE"]["enabled"] is False
assert config.role(empty_role).enabled() is False
def test_config_register_member(config, empty_member):
config.register_member(some_number=-1)
assert config.defaults["MEMBER"]["some_number"] == -1
assert config.member(empty_member).some_number() == -1
def test_config_register_user(config, empty_user):
config.register_user(some_value=None)
assert config.defaults["USER"]["some_value"] is None
assert config.user(empty_user).some_value() is None
def test_config_force_register_global(config_fr):
with pytest.raises(AttributeError):
config_fr.enabled()
config_fr.register_global(enabled=True)
assert config_fr.enabled() is True
#endregion
#region Default Value Overrides
def test_global_default_override(config):
assert config.enabled(True) is True
assert config.get("enabled") is None
assert config.get("enabled", default=True) is True
def test_global_default_nofr(config):
assert config.nofr() is None
assert config.nofr(True) is True
assert config.get("nofr") is None
assert config.get("nofr", default=True) is True
def test_guild_default_override(config, empty_guild):
assert config.guild(empty_guild).enabled(True) is True
assert config.guild(empty_guild).get("enabled") is None
assert config.guild(empty_guild).get("enabled", default=True) is True
def test_channel_default_override(config, empty_channel):
assert config.channel(empty_channel).enabled(True) is True
assert config.channel(empty_channel).get("enabled") is None
assert config.channel(empty_channel).get("enabled", default=True) is True
def test_role_default_override(config, empty_role):
assert config.role(empty_role).enabled(True) is True
assert config.role(empty_role).get("enabled") is None
assert config.role(empty_role).get("enabled", default=True) is True
def test_member_default_override(config, empty_member):
assert config.member(empty_member).enabled(True) is True
assert config.member(empty_member).get("enabled") is None
assert config.member(empty_member).get("enabled", default=True) is True
def test_user_default_override(config, empty_user):
assert config.user(empty_user).some_value(True) is True
assert config.user(empty_user).get("some_value") is None
assert config.user(empty_user).get("some_value", default=True) is True
#endregion
#region Setting Values
@pytest.mark.asyncio
async def test_set_global(config):
await config.set("enabled", True)
assert config.enabled() is True
@pytest.mark.asyncio
async def test_set_global_badkey(config):
with pytest.raises(RuntimeError):
await config.set("this is a bad key", True)
@pytest.mark.asyncio
async def test_set_global_invalidkey(config):
with pytest.raises(KeyError):
await config.set("uuid", True)
@pytest.mark.asyncio
async def test_set_guild(config, empty_guild):
await config.guild(empty_guild).set("enabled", True)
assert config.guild(empty_guild).enabled() is True
curr_list = config.guild(empty_guild).some_list([1, 2, 3])
assert curr_list == [1, 2, 3]
curr_list.append(4)
await config.guild(empty_guild).set("some_list", curr_list)
assert config.guild(empty_guild).some_list() == curr_list
@pytest.mark.asyncio
async def test_set_channel(config, empty_channel):
await config.channel(empty_channel).set("enabled", True)
assert config.channel(empty_channel).enabled() is True
@pytest.mark.asyncio
async def test_set_channel_no_register(config, empty_channel):
await config.channel(empty_channel).set("no_register", True)
assert config.channel(empty_channel).no_register() is True
#endregion

View File

@ -0,0 +1,6 @@
import pytest
@pytest.mark.asyncio
async def test_can_init_bot(red):
assert red is not None