[Config] Rewrite (#869)

This commit is contained in:
Will 2017-07-30 19:40:31 -04:00 committed by Twentysix
parent 5c2be25dfc
commit 99bfb2fc7a
14 changed files with 636 additions and 724 deletions

View File

@ -1,3 +1,4 @@
dist: trusty
language: python language: python
python: python:
- "3.5.3" - "3.5.3"

View File

@ -81,24 +81,24 @@ class Alias:
if global_: if global_:
curr_aliases = self._aliases.entries() curr_aliases = self._aliases.entries()
curr_aliases.append(alias.to_json()) curr_aliases.append(alias.to_json())
await self._aliases.set("entries", curr_aliases) await self._aliases.entries.set(curr_aliases)
else: else:
curr_aliases = self._aliases.guild(ctx.guild).entries() curr_aliases = self._aliases.guild(ctx.guild).entries()
curr_aliases.append(alias.to_json()) curr_aliases.append(alias.to_json())
await self._aliases.guild(ctx.guild).set("entries", curr_aliases) await self._aliases.guild(ctx.guild).entries.set(curr_aliases)
await self._aliases.guild(ctx.guild).set("enabled", True) await self._aliases.guild(ctx.guild).enabled.set(True)
return alias return alias
async def delete_alias(self, ctx: commands.Context, alias_name: str, async def delete_alias(self, ctx: commands.Context, alias_name: str,
global_: bool=False) -> bool: global_: bool=False) -> bool:
if global_: if global_:
aliases = self.unloaded_global_aliases() aliases = self.unloaded_global_aliases()
setter_func = self._aliases.set setter_func = self._aliases.entries.set
else: else:
aliases = self.unloaded_aliases(ctx.guild) aliases = self.unloaded_aliases(ctx.guild)
setter_func = self._aliases.guild(ctx.guild).set setter_func = self._aliases.guild(ctx.guild).entries.set
did_delete_alias = False did_delete_alias = False
@ -110,7 +110,6 @@ class Alias:
did_delete_alias = True did_delete_alias = True
await setter_func( await setter_func(
"entries",
[a.to_json() for a in to_keep] [a.to_json() for a in to_keep]
) )
@ -355,8 +354,9 @@ class Alias:
await ctx.send(box("\n".join(names), "diff")) await ctx.send(box("\n".join(names), "diff"))
async def on_message(self, message: discord.Message): async def on_message(self, message: discord.Message):
aliases = list(self.unloaded_aliases(message.guild)) + \ aliases = list(self.unloaded_global_aliases())
list(self.unloaded_global_aliases()) if message.guild is not None:
aliases = aliases + list(self.unloaded_aliases(message.guild))
if len(aliases) == 0: if len(aliases) == 0:
return return

View File

@ -30,7 +30,7 @@ class Downloader:
def __init__(self, bot: Red): def __init__(self, bot: Red):
self.bot = bot self.bot = bot
self.conf = Config.get_conf(self, unique_identifier=998240343, self.conf = Config.get_conf(self, identifier=998240343,
force_registration=True) force_registration=True)
self.conf.register_global( self.conf.register_global(
@ -73,7 +73,7 @@ class Downloader:
if cog_json not in installed: if cog_json not in installed:
installed.append(cog_json) installed.append(cog_json)
await self.conf.set("installed", installed) await self.conf.installed.set(installed)
async def _remove_from_installed(self, cog: Installable): async def _remove_from_installed(self, cog: Installable):
""" """
@ -86,7 +86,7 @@ class Downloader:
if cog_json in installed: if cog_json in installed:
installed.remove(cog_json) installed.remove(cog_json)
await self.conf.set("installed", installed) await self.conf.installed.set(installed)
async def _reinstall_cogs(self, cogs: Tuple[Installable]) -> Tuple[Installable]: async def _reinstall_cogs(self, cogs: Tuple[Installable]) -> Tuple[Installable]:
""" """

View File

@ -526,4 +526,4 @@ class RepoManager:
async def _save_repos(self): async def _save_repos(self):
repo_json_info = {name: r.to_json() for name, r in self._repos.items()} repo_json_info = {name: r.to_json() for name, r in self._repos.items()}
await self.downloader_config.set("repos", repo_json_info) await self.downloader_config.repos.set(repo_json_info)

View File

@ -47,7 +47,7 @@ class Red(commands.Bot):
kwargs["owner_id"] = cli_flags.owner kwargs["owner_id"] = cli_flags.owner
if "owner_id" not in kwargs: if "owner_id" not in kwargs:
kwargs["owner_id"] = self.db.get("owner") kwargs["owner_id"] = self.db.owner()
self.counter = Counter() self.counter = Counter()
self.uptime = None self.uptime = None
@ -89,7 +89,7 @@ class Red(commands.Bot):
for package in self.extensions: for package in self.extensions:
if package.startswith("cogs."): if package.startswith("cogs."):
loaded.append(package) loaded.append(package)
await self.db.set("packages", loaded) await self.db.packages.set(loaded)
class ExitCodes(Enum): class ExitCodes(Enum):

View File

@ -22,7 +22,7 @@ def interactive_config(red, token_set, prefix_set):
print("That doesn't look like a valid token.") print("That doesn't look like a valid token.")
token = "" token = ""
if token: if token:
loop.run_until_complete(red.db.set("token", token)) loop.run_until_complete(red.db.token.set(token))
if not prefix_set: if not prefix_set:
prefix = "" prefix = ""
@ -39,7 +39,7 @@ def interactive_config(red, token_set, prefix_set):
if not confirm("> "): if not confirm("> "):
prefix = "" prefix = ""
if prefix: if prefix:
loop.run_until_complete(red.db.set("prefix", [prefix])) loop.run_until_complete(red.db.prefix.set([prefix]))
ask_sentry(red) ask_sentry(red)
@ -55,9 +55,9 @@ def ask_sentry(red: Red):
" found issues in a timely manner. If you wish to opt in\n" " found issues in a timely manner. If you wish to opt in\n"
" the process please type \"yes\":\n") " the process please type \"yes\":\n")
if not confirm("> "): if not confirm("> "):
loop.run_until_complete(red.db.set("enable_sentry", False)) loop.run_until_complete(red.db.enable_sentry.set(False))
else: else:
loop.run_until_complete(red.db.set("enable_sentry", True)) loop.run_until_complete(red.db.enable_sentry.set(True))
print("\nThank you for helping us with the development process!") print("\nThank you for helping us with the development process!")

View File

@ -1,521 +1,374 @@
from pathlib import Path
from core.drivers.red_json import JSON as JSONDriver
from core.drivers.red_mongo import Mongo
import logging import logging
from typing import Callable from typing import Callable, Union, Tuple
import discord
from copy import deepcopy
from pathlib import Path
from .drivers.red_json import JSON as JSONDriver
log = logging.getLogger("red.config") log = logging.getLogger("red.config")
class BaseConfig:
def __init__(self, cog_name, unique_identifier, driver_spawn, force_registration=False,
hash_uuid=True, collection="GLOBAL", collection_uuid=None,
defaults={}):
self.cog_name = cog_name
if hash_uuid:
self.uuid = str(hash(unique_identifier))
else:
self.uuid = unique_identifier
self.driver_spawn = driver_spawn
self._driver = None
self.collection = collection
self.collection_uuid = collection_uuid
self.force_registration = force_registration class Value:
def __init__(self, identifiers: Tuple[str], default_value, spawner):
self._identifiers = identifiers
self.default = default_value
self.spawner = spawner
@property
def identifiers(self):
return tuple(str(i) for i in self._identifiers)
def __call__(self, default=None):
driver = self.spawner.get_driver()
try: try:
self.driver.maybe_add_ident(self.uuid) ret = driver.get(self.identifiers)
except AttributeError: except KeyError:
pass return default or self.default
return ret
self.driver_getmap = { async def set(self, value):
"GLOBAL": self.driver.get_global, driver = self.spawner.get_driver()
"GUILD": self.driver.get_guild, await driver.set(self.identifiers, value)
"CHANNEL": self.driver.get_channel,
"ROLE": self.driver.get_role,
"USER": self.driver.get_user
}
self.driver_setmap = {
"GLOBAL": self.driver.set_global,
"GUILD": self.driver.set_guild,
"CHANNEL": self.driver.set_channel,
"ROLE": self.driver.set_role,
"USER": self.driver.set_user
}
self.curr_key = None class Group(Value):
def __init__(self, identifiers: Tuple[str],
defaults: dict,
spawner,
force_registration: bool=False):
self.defaults = defaults
self.force_registration = force_registration
self.spawner = spawner
self.unsettable_keys = ("cog_name", "cog_identifier", "_id", super().__init__(identifiers, {}, self.spawner)
"guild_id", "channel_id", "role_id",
"user_id", "uuid") # noinspection PyTypeChecker
self.invalid_keys = ( def __getattr__(self, item: str) -> Union["Group", Value]:
"driver_spawn", """
"_driver", "collection", Takes in the next accessible item. If it's found to be a Group
"collection_uuid", "force_registration" we return another Group object. If it's found to be a Value
we return a Value object. If it is not found and
force_registration is True then we raise AttributeException,
otherwise return a Value object.
:param item:
:return:
"""
is_group = self.is_group(item)
is_value = not is_group and self.is_value(item)
new_identifiers = self.identifiers + (item, )
if is_group:
return Group(
identifiers=new_identifiers,
defaults=self.defaults[item],
spawner=self.spawner,
force_registration=self.force_registration
)
elif is_value:
return Value(
identifiers=new_identifiers,
default_value=self.defaults[item],
spawner=self.spawner
)
elif self.force_registration:
raise AttributeError(
"'{}' is not a valid registered Group"
"or value.".format(item)
)
else:
return Value(
identifiers=new_identifiers,
default_value=None,
spawner=self.spawner
)
@property
def _super_group(self) -> 'Group':
super_group = Group(
self.identifiers[:-1],
defaults={},
spawner=self.spawner,
force_registration=self.force_registration
) )
return super_group
self.defaults = defaults if defaults else { def is_group(self, item: str) -> bool:
"GLOBAL": {}, "GUILD": {}, "CHANNEL": {}, "ROLE": {}, """
"MEMBER": {}, "USER": {}} Determines if an attribute access is pointing at a registered group.
:param item:
:return:
"""
default = self.defaults.get(item)
return isinstance(default, dict)
def is_value(self, item: str) -> bool:
"""
Determines if an attribute access is pointing at a registered value.
:param item:
:return:
"""
try:
default = self.defaults[item]
except KeyError:
return False
return not isinstance(default, dict)
def get_attr(self, item: str, default=None):
"""
You should avoid this function whenever possible.
:param item:
:param default:
:return:
"""
value = getattr(self, item)
return value(default=default)
def all(self) -> dict:
"""
Gets all entries of the given kind. If this kind is member
then this method returns all members from the same
server.
:return:
"""
# noinspection PyTypeChecker
return self._super_group()
async def set(self, value):
if not isinstance(value, dict):
raise ValueError(
"You may only set the value of a group to be a dict."
)
await super().set(value)
async def set_attr(self, item: str, value):
"""
You should avoid this function whenever possible.
:param item:
:param value:
:return:
"""
value_obj = getattr(self, item)
await value_obj.set(value)
async def clear(self):
"""
Wipes out data for the given entry in this category
e.g. Guild/Role/User
:return:
"""
await self.set({})
async def clear_all(self):
"""
Removes all data from all entries.
:return:
"""
await self._super_group.set({})
class MemberGroup(Group):
@property
def _super_group(self) -> Group:
new_identifiers = self.identifiers[:2]
group_obj = Group(
identifiers=new_identifiers,
defaults={},
spawner=self.spawner
)
return group_obj
@property
def _guild_group(self) -> Group:
new_identifiers = self.identifiers[:3]
group_obj = Group(
identifiers=new_identifiers,
defaults={},
spawner=self.spawner
)
return group_obj
def all_guilds(self) -> dict:
"""
Gets a dict of all guilds and members.
REMEMBER: ID's are stored in these dicts as STRINGS.
:return:
"""
# noinspection PyTypeChecker
return self._super_group()
def all(self) -> dict:
"""
Returns the dict of all members in the same guild.
:return:
"""
# noinspection PyTypeChecker
return self._guild_group()
class Config:
GLOBAL = "GLOBAL"
GUILD = "GUILD"
CHANNEL = "TEXTCHANNEL"
ROLE = "ROLE"
USER = "USER"
MEMBER = "MEMBER"
def __init__(self, cog_name: str, unique_identifier: str,
driver_spawn: Callable,
force_registration: bool=False,
defaults: dict=None):
self.cog_name = cog_name
self.unique_identifier = unique_identifier
self.spawner = driver_spawn
self.force_registration = force_registration
self.defaults = defaults or {}
@classmethod @classmethod
def get_conf(cls, cog_instance: object, unique_identifier: int=0, def get_conf(cls, cog_instance, identifier: int,
force_registration: bool=False): force_registration=False):
""" """
Gets a config object that cog's can use to safely store data. The Returns a Config instance based on a simplified set of initial
backend to this is totally modular and can easily switch between variables.
JSON and a DB. However, when changed, all data will likely be lost :param cog_instance:
unless cogs write some converters for their data. :param identifier: Any random integer, used to keep your data
distinct from any other cog with the same name.
Positional Arguments: :param force_registration: Should config require registration
cog_instance - The cog `self` object, can be passed in from your of data keys before allowing you to get/set values?
cog's __init__ method. :return:
Keyword Arguments:
unique_identifier - a random integer or string that is used to
differentiate your cog from any other named the same. This way we
can safely store data for multiple cogs that are named the same.
YOU SHOULD USE THIS.
force_registration - A flag which will cause the Config object to
throw exceptions if you try to get/set data keys that you have
not pre-registered. I highly recommend you ENABLE this as it
will help reduce dumb typo errors.
""" """
url = None # TODO: get mongo url
port = None # TODO: get mongo port
def spawn_mongo_driver():
return Mongo(url, port)
# TODO: Determine which backend users want, default to JSON
cog_name = cog_instance.__class__.__name__ cog_name = cog_instance.__class__.__name__
uuid = str(hash(identifier))
driver_spawn = JSONDriver(cog_name) spawner = JSONDriver(cog_name)
return cls(cog_name=cog_name, unique_identifier=uuid,
return cls(cog_name=cog_name, unique_identifier=unique_identifier, force_registration=force_registration,
driver_spawn=driver_spawn, force_registration=force_registration) driver_spawn=spawner)
@classmethod @classmethod
def get_core_conf(cls, force_registration: bool=False): def get_core_conf(cls, force_registration: bool=False):
core_data_path = Path.cwd() / 'core' / '.data' core_data_path = Path.cwd() / 'core' / '.data'
driver_spawn = JSONDriver("Core", data_path_override=core_data_path) driver_spawn = JSONDriver("Core", data_path_override=core_data_path)
return cls(cog_name="Core", driver_spawn=driver_spawn, return cls(cog_name="Core", driver_spawn=driver_spawn,
unique_identifier=0, unique_identifier='0',
force_registration=force_registration) force_registration=force_registration)
@property def __getattr__(self, item: str) -> Union[Group, Value]:
def driver(self):
if self._driver is None:
try:
self._driver = self.driver_spawn()
except TypeError:
return self.driver_spawn
return self._driver
def __getattr__(self, key):
"""This should be used to return config key data as determined by
`self.collection` and `self.collection_uuid`."""
raise NotImplemented
def __setattr__(self, key, value):
if 'defaults' in self.__dict__: # Necessary to let the cog load
restricted = list(self.defaults[self.collection].keys()) + \
list(self.unsettable_keys)
if key in restricted:
raise ValueError("Not allowed to dynamically set attributes of"
" unsettable_keys: {}".format(restricted))
else:
self.__dict__[key] = value
else:
self.__dict__[key] = value
def clear(self):
"""Clears all values in the current context ONLY."""
raise NotImplemented
def set(self, key, value):
"""This should set config key with value `value` in the
corresponding collection as defined by `self.collection` and
`self.collection_uuid`."""
raise NotImplemented
def guild(self, guild):
"""This should return a `BaseConfig` instance with the corresponding
`collection` and `collection_uuid`."""
raise NotImplemented
def channel(self, channel):
"""This should return a `BaseConfig` instance with the corresponding
`collection` and `collection_uuid`."""
raise NotImplemented
def role(self, role):
"""This should return a `BaseConfig` instance with the corresponding
`collection` and `collection_uuid`."""
raise NotImplemented
def member(self, member):
"""This should return a `BaseConfig` instance with the corresponding
`collection` and `collection_uuid`."""
raise NotImplemented
def user(self, user):
"""This should return a `BaseConfig` instance with the corresponding
`collection` and `collection_uuid`."""
raise NotImplemented
def register_global(self, **global_defaults):
""" """
Registers a new dict of global defaults. This function should This is used to generate Value or Group objects for global
be called EVERY TIME the cog loads (aka just do it in values.
__init__)! :param item:
:param global_defaults: Each key should be the key you want to
access data by and the value is the default value of that
key.
:return: :return:
""" """
for k, v in global_defaults.items(): global_group = self._get_base_group(self.GLOBAL)
try: return getattr(global_group, item)
self._register_global(k, v)
except KeyError:
log.exception("Bad default global key.")
def _register_global(self, key, default=None): @staticmethod
"""Registers a global config key `key`""" def _get_defaults_dict(key: str, value) -> dict:
if key in self.unsettable_keys:
raise KeyError("Attempt to use restricted key: '{}'".format(key))
elif not key.isidentifier():
raise RuntimeError("Invalid key name, must be a valid python variable"
" name.")
self.defaults["GLOBAL"][key] = default
def register_guild(self, **guild_defaults):
""" """
Registers a new dict of guild defaults. This function should Since we're allowing nested config stuff now, not storing the
be called EVERY TIME the cog loads (aka just do it in defaults as a flat dict sounds like a good idea. May turn
__init__)! out to be an awful one but we'll see.
:param guild_defaults: Each key should be the key you want to
access data by and the value is the default value of that
key.
:return:
"""
for k, v in guild_defaults.items():
try:
self._register_guild(k, v)
except KeyError:
log.exception("Bad default guild key.")
def _register_guild(self, key, default=None):
"""Registers a guild config key `key`"""
if key in self.unsettable_keys:
raise KeyError("Attempt to use restricted key: '{}'".format(key))
elif not key.isidentifier():
raise RuntimeError("Invalid key name, must be a valid python variable"
" name.")
self.defaults["GUILD"][key] = default
def register_channel(self, **channel_defaults):
"""
Registers a new dict of channel defaults. This function should
be called EVERY TIME the cog loads (aka just do it in
__init__)!
:param channel_defaults: Each key should be the key you want to
access data by and the value is the default value of that
key.
:return:
"""
for k, v in channel_defaults.items():
try:
self._register_channel(k, v)
except KeyError:
log.exception("Bad default channel key.")
def _register_channel(self, key, default=None):
"""Registers a channel config key `key`"""
if key in self.unsettable_keys:
raise KeyError("Attempt to use restricted key: '{}'".format(key))
elif not key.isidentifier():
raise RuntimeError("Invalid key name, must be a valid python variable"
" name.")
self.defaults["CHANNEL"][key] = default
def register_role(self, **role_defaults):
"""
Registers a new dict of role defaults. This function should
be called EVERY TIME the cog loads (aka just do it in
__init__)!
:param role_defaults: Each key should be the key you want to
access data by and the value is the default value of that
key.
:return:
"""
for k, v in role_defaults.items():
try:
self._register_role(k, v)
except KeyError:
log.exception("Bad default role key.")
def _register_role(self, key, default=None):
"""Registers a role config key `key`"""
if key in self.unsettable_keys:
raise KeyError("Attempt to use restricted key: '{}'".format(key))
elif not key.isidentifier():
raise RuntimeError("Invalid key name, must be a valid python variable"
" name.")
self.defaults["ROLE"][key] = default
def register_member(self, **member_defaults):
"""
Registers a new dict of member defaults. This function should
be called EVERY TIME the cog loads (aka just do it in
__init__)!
:param member_defaults: Each key should be the key you want to
access data by and the value is the default value of that
key.
:return:
"""
for k, v in member_defaults.items():
try:
self._register_member(k, v)
except KeyError:
log.exception("Bad default member key.")
def _register_member(self, key, default=None):
"""Registers a member config key `key`"""
if key in self.unsettable_keys:
raise KeyError("Attempt to use restricted key: '{}'".format(key))
elif not key.isidentifier():
raise RuntimeError("Invalid key name, must be a valid python variable"
" name.")
self.defaults["MEMBER"][key] = default
def register_user(self, **user_defaults):
"""
Registers a new dict of user defaults. This function should
be called EVERY TIME the cog loads (aka just do it in
__init__)!
:param user_defaults: Each key should be the key you want to
access data by and the value is the default value of that
key.
:return:
"""
for k, v in user_defaults.items():
try:
self._register_user(k, v)
except KeyError:
log.exception("Bad default user key.")
def _register_user(self, key, default=None):
"""Registers a user config key `key`"""
if key in self.unsettable_keys:
raise KeyError("Attempt to use restricted key: '{}'".format(key))
elif not key.isidentifier():
raise RuntimeError("Invalid key name, must be a valid python variable"
" name.")
self.defaults["USER"][key] = default
class Config(BaseConfig):
"""
Config object created by `Config.get_conf()`
This configuration object is designed to make backend data
storage mechanisms pluggable. It also is designed to
help a cog developer make fewer mistakes (such as
typos) when dealing with cog data and to make those mistakes
apparent much faster in the design process.
It also has the capability to safely store data between cogs
that share the same name.
There are two main components to this config object. First,
you have the ability to get data on a level specific basis.
The seven levels available are: global, guild, channel, role,
member, user, and misc.
The second main component is registering default values for
data in each of the levels. This functionality is OPTIONAL
and must be explicitly enabled when creating the Config object
using the kwarg `force_registration=True`.
Basic Usage:
Creating a Config object:
Use the `Config.get_conf()` class method to create new
Config objects.
See the `Config.get_conf()` documentation for more
information.
Registering Default Values (optional):
You can register default values for data at all levels
EXCEPT misc.
Simply pass in the key/value pairs as keyword arguments to
the respective function.
e.g.: conf_obj.register_global(enabled=True)
conf_obj.register_guild(likes_red=True)
Retrieving data by attributes:
Since I registered the "enabled" key in the previous example
at the global level I can now do:
conf_obj.enabled()
which will retrieve the current value of the "enabled"
key, making use of the default of "True". I can also do
the same for the guild key "likes_red":
conf_obj.guild(guild_obj).likes_red()
If I elected to not register default values, you can provide them
when you try to access the key:
conf_obj.no_default(default=True)
However if you do not provide a default and you do not register
defaults, accessing the attribute will return "None".
Saving data:
This is accomplished by using the `set` function available at
every level.
e.g.: conf_obj.set("enabled", False)
conf_obj.guild(guild_obj).set("likes_red", False)
If `force_registration` was enabled when the config object
was created you will only be allowed to save keys that you
have registered.
Misc data is special, use `conf.misc()` and `conf.set_misc(value)`
respectively.
"""
def __getattr__(self, key) -> Callable:
"""
Until I've got a better way to do this I'm just gonna fake __call__
:param key: :param key:
:return: lambda function with kwarg :param value:
:return:
""" """
return self._get_value_from_key(key) ret = {}
partial = ret
def _get_value_from_key(self, key) -> Callable: splitted = key.split('__')
try: for i, k in enumerate(splitted, start=1):
default = self.defaults[self.collection][key] if not k.isidentifier():
except KeyError as e: raise RuntimeError("'{}' is an invalid config key.".format(k))
if self.force_registration: if i == len(splitted):
raise AttributeError("Key '{}' not registered!".format(key)) from e partial[k] = value
default = None else:
partial[k] = {}
self.curr_key = key partial = partial[k]
if self.collection != "MEMBER":
ret = lambda default=default: self.driver_getmap[self.collection](
self.cog_name, self.uuid, self.collection_uuid, key,
default=default)
else:
mid, sid = self.collection_uuid
ret = lambda default=default: self.driver.get_member(
self.cog_name, self.uuid, mid, sid, key,
default=default)
return ret return ret
def get(self, key, default=None): @staticmethod
def _update_defaults(to_add: dict, _partial: dict):
""" """
Included as an alternative to registering defaults. This tries to update the defaults dictionary with the nested
partial dict generated by _get_defaults_dict. This WILL
:param key: throw an error if you try to have both a value and a group
:param default: registered under the same name.
:param to_add:
:param _partial:
:return: :return:
""" """
for k, v in to_add.items():
val_is_dict = isinstance(v, dict)
if k in _partial:
existing_is_dict = isinstance(_partial[k], dict)
if val_is_dict != existing_is_dict:
# != is XOR
raise KeyError("You cannot register a Group and a Value under"
" the same name.")
if val_is_dict:
Config._update_defaults(v, _partial=_partial[k])
else:
_partial[k] = v
else:
_partial[k] = v
if default is not None: def _register_default(self, key: str, **kwargs):
return self._get_value_from_key(key)(default) if key not in self.defaults:
else: self.defaults[key] = {}
return self._get_value_from_key(key)()
async def set(self, key, value): data = deepcopy(kwargs)
# Notice to future developers:
# This code was commented to allow users to set keys without having to register them.
# That being said, if they try to get keys without registering them
# things will blow up. I do highly recommend enforcing the key registration.
if key in self.unsettable_keys or key in self.invalid_keys: for k, v in data.items():
raise KeyError("Restricted key name, please use another.") to_add = self._get_defaults_dict(k, v)
self._update_defaults(to_add, self.defaults[key])
if self.force_registration and key not in self.defaults[self.collection]: def register_global(self, **kwargs):
raise AttributeError("Key '{}' not registered!".format(key)) self._register_default(self.GLOBAL, **kwargs)
if not key.isidentifier(): def register_guild(self, **kwargs):
raise RuntimeError("Invalid key name, must be a valid python variable" self._register_default(self.GUILD, **kwargs)
" name.")
if self.collection == "GLOBAL": def register_channel(self, **kwargs):
await self.driver.set_global(self.cog_name, self.uuid, key, value) # We may need to add a voice channel category later
elif self.collection == "MEMBER": self._register_default(self.CHANNEL, **kwargs)
mid, sid = self.collection_uuid
await self.driver.set_member(self.cog_name, self.uuid, mid, sid,
key, value)
elif self.collection in self.driver_setmap:
func = self.driver_setmap[self.collection]
await func(self.cog_name, self.uuid, self.collection_uuid, key, value)
async def clear(self): def register_role(self, **kwargs):
await self.driver_setmap[self.collection]( self._register_default(self.ROLE, **kwargs)
self.cog_name, self.uuid, self.collection_uuid, None, None,
clear=True)
def guild(self, guild): def register_user(self, **kwargs):
new = type(self)(self.cog_name, self.uuid, self.driver, self._register_default(self.USER, **kwargs)
hash_uuid=False, defaults=self.defaults)
new.collection = "GUILD"
new.collection_uuid = guild.id
new._driver = None
return new
def channel(self, channel): def register_member(self, **kwargs):
new = type(self)(self.cog_name, self.uuid, self.driver, self._register_default(self.MEMBER, **kwargs)
hash_uuid=False, defaults=self.defaults)
new.collection = "CHANNEL"
new.collection_uuid = channel.id
new._driver = None
return new
def role(self, role): def _get_base_group(self, key: str, *identifiers: str,
new = type(self)(self.cog_name, self.uuid, self.driver, group_class=Group) -> Group:
hash_uuid=False, defaults=self.defaults) # noinspection PyTypeChecker
new.collection = "ROLE" return group_class(
new.collection_uuid = role.id identifiers=(self.unique_identifier, key) + identifiers,
new._driver = None defaults=self.defaults.get(key, {}),
return new spawner=self.spawner,
force_registration=self.force_registration
)
def member(self, member): def guild(self, guild: discord.Guild) -> Group:
guild = member.guild return self._get_base_group(self.GUILD, guild.id)
new = type(self)(self.cog_name, self.uuid, self.driver,
hash_uuid=False, defaults=self.defaults) def channel(self, channel: discord.TextChannel) -> Group:
new.collection = "MEMBER" return self._get_base_group(self.CHANNEL, channel.id)
new.collection_uuid = (member.id, guild.id)
new._driver = None def role(self, role: discord.Role) -> Group:
return new return self._get_base_group(self.ROLE, role.id)
def user(self, user: discord.User) -> Group:
return self._get_base_group(self.USER, user.id)
def member(self, member: discord.Member) -> MemberGroup:
return self._get_base_group(self.MEMBER, member.guild.id, member.id,
group_class=MemberGroup)
def user(self, user):
new = type(self)(self.cog_name, self.uuid, self.driver,
hash_uuid=False, defaults=self.defaults)
new.collection = "USER"
new.collection_uuid = user.id
new._driver = None
return new

View File

@ -98,7 +98,7 @@ class Core:
@commands.guild_only() @commands.guild_only()
async def adminrole(self, ctx, *, role: discord.Role): async def adminrole(self, ctx, *, role: discord.Role):
"""Sets the admin role for this server""" """Sets the admin role for this server"""
await ctx.bot.db.guild(ctx.guild).set("admin_role", role.id) await ctx.bot.db.guild(ctx.guild).admin_role.set(role.id)
await ctx.send("The admin role for this server has been set.") await ctx.send("The admin role for this server has been set.")
@_set.command() @_set.command()
@ -106,7 +106,7 @@ class Core:
@commands.guild_only() @commands.guild_only()
async def modrole(self, ctx, *, role: discord.Role): async def modrole(self, ctx, *, role: discord.Role):
"""Sets the mod role for this server""" """Sets the mod role for this server"""
await ctx.bot.db.guild(ctx.guild).set("mod_role", role.id) await ctx.bot.db.guild(ctx.guild).mod_role.set(role.id)
await ctx.send("The mod role for this server has been set.") await ctx.send("The mod role for this server has been set.")
@_set.command() @_set.command()
@ -225,7 +225,7 @@ class Core:
await ctx.bot.send_cmd_help(ctx) await ctx.bot.send_cmd_help(ctx)
return return
prefixes = sorted(prefixes, reverse=True) prefixes = sorted(prefixes, reverse=True)
await ctx.bot.db.set("prefix", prefixes) await ctx.bot.db.prefix.set(prefixes)
await ctx.send("Prefix set.") await ctx.send("Prefix set.")
@_set.command(aliases=["serverprefixes"]) @_set.command(aliases=["serverprefixes"])
@ -234,11 +234,11 @@ class Core:
async def serverprefix(self, ctx, *prefixes): async def serverprefix(self, ctx, *prefixes):
"""Sets Red's server prefix(es)""" """Sets Red's server prefix(es)"""
if not prefixes: if not prefixes:
await ctx.bot.db.guild(ctx.guild).set("prefix", []) await ctx.bot.db.guild(ctx.guild).prefix.set([])
await ctx.send("Server prefixes have been reset.") await ctx.send("Server prefixes have been reset.")
return return
prefixes = sorted(prefixes, reverse=True) prefixes = sorted(prefixes, reverse=True)
await ctx.bot.db.guild(ctx.guild).set("prefix", prefixes) await ctx.bot.db.guild(ctx.guild).prefix.set(prefixes)
await ctx.send("Prefix set.") await ctx.send("Prefix set.")
@_set.command() @_set.command()
@ -276,7 +276,7 @@ class Core:
else: else:
if message.content.strip() == token: if message.content.strip() == token:
self.owner.reset_cooldown(ctx) self.owner.reset_cooldown(ctx)
await ctx.bot.db.set("owner", ctx.author.id) await ctx.bot.db.owner.set(ctx.author.id)
ctx.bot.owner_id = ctx.author.id ctx.bot.owner_id = ctx.author.id
await ctx.send("You have been set as owner.") await ctx.send("You have been set as owner.")
else: else:

View File

@ -1,45 +1,12 @@
from typing import Tuple
class BaseDriver: class BaseDriver:
def get_global(self, cog_name, ident, collection_id, key, *, default=None): def get_driver(self):
raise NotImplementedError() raise NotImplementedError
def get_guild(self, cog_name, ident, guild_id, key, *, default=None): def get(self, identifiers: Tuple[str]):
raise NotImplementedError() raise NotImplementedError
def get_channel(self, cog_name, ident, channel_id, key, *, default=None): async def set(self, identifiers: Tuple[str], value):
raise NotImplementedError() raise NotImplementedError
def get_role(self, cog_name, ident, role_id, key, *, default=None):
raise NotImplementedError()
def get_member(self, cog_name, ident, user_id, guild_id, key, *,
default=None):
raise NotImplementedError()
def get_user(self, cog_name, ident, user_id, key, *, default=None):
raise NotImplementedError()
def get_misc(self, cog_name, ident, *, default=None):
raise NotImplementedError()
async def set_global(self, cog_name, ident, key, value, clear=False):
raise NotImplementedError()
async def set_guild(self, cog_name, ident, guild_id, key, value, clear=False):
raise NotImplementedError()
async def set_channel(self, cog_name, ident, channel_id, key, value,
clear=False):
raise NotImplementedError()
async def set_role(self, cog_name, ident, role_id, key, value, clear=False):
raise NotImplementedError()
async def set_member(self, cog_name, ident, user_id, guild_id, key, value,
clear=False):
raise NotImplementedError()
async def set_user(self, cog_name, ident, user_id, key, value, clear=False):
raise NotImplementedError()
async def set_misc(self, cog_name, ident, value, clear=False):
raise NotImplementedError()

View File

@ -1,13 +1,15 @@
from typing import Tuple
from core.drivers.red_base import BaseDriver
from core.json_io import JsonIO from core.json_io import JsonIO
import os
from .red_base import BaseDriver
from pathlib import Path from pathlib import Path
class JSON(BaseDriver): class JSON(BaseDriver):
def __init__(self, cog_name, *args, data_path_override: Path=None, def __init__(self, cog_name, *, data_path_override: Path=None,
file_name_override: str="settings.json", **kwargs): file_name_override: str="settings.json"):
super().__init__()
self.cog_name = cog_name self.cog_name = cog_name
self.file_name = file_name_override self.file_name = file_name_override
if data_path_override: if data_path_override:
@ -25,111 +27,23 @@ class JSON(BaseDriver):
self.data = self.jsonIO._load_json() self.data = self.jsonIO._load_json()
except FileNotFoundError: except FileNotFoundError:
self.data = {} self.data = {}
def maybe_add_ident(self, ident: str):
if ident in self.data:
return
self.data[ident] = {}
for k in ("GLOBAL", "GUILD", "CHANNEL", "ROLE", "MEMBER", "USER"):
if k not in self.data[ident]:
self.data[ident][k] = {}
self.jsonIO._save_json(self.data) self.jsonIO._save_json(self.data)
def get_global(self, cog_name, ident, _, key, *, default=None): def get_driver(self):
return self.data[ident]["GLOBAL"].get(key, default) return self
def get_guild(self, cog_name, ident, guild_id, key, *, default=None): def get(self, identifiers: Tuple[str]):
guilddata = self.data[ident]["GUILD"].get(str(guild_id), {}) partial = self.data
return guilddata.get(key, default) for i in identifiers:
partial = partial[i]
return partial
def get_channel(self, cog_name, ident, channel_id, key, *, default=None): async def set(self, identifiers, value):
channeldata = self.data[ident]["CHANNEL"].get(str(channel_id), {}) partial = self.data
return channeldata.get(key, default) for i in identifiers[:-1]:
if i not in partial:
partial[i] = {}
partial = partial[i]
def get_role(self, cog_name, ident, role_id, key, *, default=None): partial[identifiers[-1]] = value
roledata = self.data[ident]["ROLE"].get(str(role_id), {})
return roledata.get(key, default)
def get_member(self, cog_name, ident, user_id, guild_id, key, *,
default=None):
userdata = self.data[ident]["MEMBER"].get(str(user_id), {})
guilddata = userdata.get(str(guild_id), {})
return guilddata.get(key, default)
def get_user(self, cog_name, ident, user_id, key, *, default=None):
userdata = self.data[ident]["USER"].get(str(user_id), {})
return userdata.get(key, default)
async def set_global(self, cog_name, ident, key, value, clear=False):
if clear:
self.data[ident]["GLOBAL"] = {}
else:
self.data[ident]["GLOBAL"][key] = value
await self.jsonIO._threadsafe_save_json(self.data)
async def set_guild(self, cog_name, ident, guild_id, key, value, clear=False):
guild_id = str(guild_id)
if clear:
self.data[ident]["GUILD"][guild_id] = {}
else:
try:
self.data[ident]["GUILD"][guild_id][key] = value
except KeyError:
self.data[ident]["GUILD"][guild_id] = {}
self.data[ident]["GUILD"][guild_id][key] = value
await self.jsonIO._threadsafe_save_json(self.data)
async def set_channel(self, cog_name, ident, channel_id, key, value, clear=False):
channel_id = str(channel_id)
if clear:
self.data[ident]["CHANNEL"][channel_id] = {}
else:
try:
self.data[ident]["CHANNEL"][channel_id][key] = value
except KeyError:
self.data[ident]["CHANNEL"][channel_id] = {}
self.data[ident]["CHANNEL"][channel_id][key] = value
await self.jsonIO._threadsafe_save_json(self.data)
async def set_role(self, cog_name, ident, role_id, key, value, clear=False):
role_id = str(role_id)
if clear:
self.data[ident]["ROLE"][role_id] = {}
else:
try:
self.data[ident]["ROLE"][role_id][key] = value
except KeyError:
self.data[ident]["ROLE"][role_id] = {}
self.data[ident]["ROLE"][role_id][key] = value
await self.jsonIO._threadsafe_save_json(self.data)
async def set_member(self, cog_name, ident, user_id, guild_id, key, value, clear=False):
user_id = str(user_id)
guild_id = str(guild_id)
if clear:
self.data[ident]["MEMBER"][user_id] = {}
else:
try:
self.data[ident]["MEMBER"][user_id][guild_id][key] = value
except KeyError:
if user_id not in self.data[ident]["MEMBER"]:
self.data[ident]["MEMBER"][user_id] = {}
if guild_id not in self.data[ident]["MEMBER"][user_id]:
self.data[ident]["MEMBER"][user_id][guild_id] = {}
self.data[ident]["MEMBER"][user_id][guild_id][key] = value
await self.jsonIO._threadsafe_save_json(self.data)
async def set_user(self, cog_name, ident, user_id, key, value, clear=False):
user_id = str(user_id)
if clear:
self.data[ident]["USER"][user_id] = {}
else:
try:
self.data[ident]["USER"][user_id][key] = value
except KeyError:
self.data[ident]["USER"][user_id] = {}
self.data[ident]["USER"][user_id][key] = value
await self.jsonIO._threadsafe_save_json(self.data) await self.jsonIO._threadsafe_save_json(self.data)

View File

@ -113,7 +113,7 @@ if __name__ == '__main__':
if db_token and not cli_flags.no_prompt: if db_token and not cli_flags.no_prompt:
print("\nDo you want to reset the token? (y/n)") print("\nDo you want to reset the token? (y/n)")
if confirm("> "): if confirm("> "):
loop.run_until_complete(red.db.set("token", "")) loop.run_until_complete(red.db.token.set(""))
print("Token has been reset.") print("Token has been reset.")
except KeyboardInterrupt: except KeyboardInterrupt:
log.info("Keyboard interrupt detected. Quitting...") log.info("Keyboard interrupt detected. Quitting...")

View File

@ -2,12 +2,11 @@ from cogs.alias import Alias
import pytest import pytest
@pytest.fixture @pytest.fixture()
def alias(monkeysession, config): def alias(config):
def get_mock_conf(*args, **kwargs): import cogs.alias.alias
return config
monkeysession.setattr("core.config.Config.get_conf", get_mock_conf) cogs.alias.alias.Config.get_conf = lambda *args, **kwargs: config
return Alias(None) return Alias(None)
@ -25,9 +24,17 @@ def test_empty_global_aliases(alias):
assert list(alias.unloaded_global_aliases()) == [] assert list(alias.unloaded_global_aliases()) == []
async def create_test_guild_alias(alias, ctx):
await alias.add_alias(ctx, "test", "ping", global_=False)
async def create_test_global_alias(alias, ctx):
await alias.add_alias(ctx, "test", "ping", global_=True)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_add_guild_alias(alias, ctx): async def test_add_guild_alias(alias, ctx):
await alias.add_alias(ctx, "test", "ping", global_=False) await create_test_guild_alias(alias, ctx)
is_alias, alias_obj = alias.is_alias(ctx.guild, "test") is_alias, alias_obj = alias.is_alias(ctx.guild, "test")
assert is_alias is True assert is_alias is True
@ -36,6 +43,7 @@ async def test_add_guild_alias(alias, ctx):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_delete_guild_alias(alias, ctx): async def test_delete_guild_alias(alias, ctx):
await create_test_guild_alias(alias, ctx)
is_alias, _ = alias.is_alias(ctx.guild, "test") is_alias, _ = alias.is_alias(ctx.guild, "test")
assert is_alias is True assert is_alias is True
@ -47,7 +55,7 @@ async def test_delete_guild_alias(alias, ctx):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_add_global_alias(alias, ctx): async def test_add_global_alias(alias, ctx):
await alias.add_alias(ctx, "test", "ping", global_=True) await create_test_global_alias(alias, ctx)
is_alias, alias_obj = alias.is_alias(ctx.guild, "test") is_alias, alias_obj = alias.is_alias(ctx.guild, "test")
assert is_alias is True assert is_alias is True
@ -56,6 +64,7 @@ async def test_add_global_alias(alias, ctx):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_delete_global_alias(alias, ctx): async def test_delete_global_alias(alias, ctx):
await create_test_global_alias(alias, ctx)
is_alias, alias_obj = alias.is_alias(ctx.guild, "test") is_alias, alias_obj = alias.is_alias(ctx.guild, "test")
assert is_alias is True assert is_alias is True
assert alias_obj.global_ is True assert alias_obj.global_ is True

View File

@ -17,21 +17,27 @@ def monkeysession(request):
mpatch.undo() mpatch.undo()
@pytest.fixture(scope="module") @pytest.fixture()
def json_driver(tmpdir_factory): def json_driver(tmpdir_factory):
import uuid
rand = str(uuid.uuid4())
path = Path(str(tmpdir_factory.mktemp(rand)))
driver = red_json.JSON( driver = red_json.JSON(
"PyTest", "PyTest",
data_path_override=Path(str(tmpdir_factory.getbasetemp())) data_path_override=path
) )
return driver return driver
@pytest.fixture() @pytest.fixture()
def config(json_driver): def config(json_driver):
return Config( import uuid
conf = Config(
cog_name="PyTest", cog_name="PyTest",
unique_identifier=0, unique_identifier=str(uuid.uuid4()),
driver_spawn=json_driver) driver_spawn=json_driver)
yield conf
conf.defaults = {}
@pytest.fixture() @pytest.fixture()
@ -39,19 +45,32 @@ def config_fr(json_driver):
""" """
Mocked config object with force_register enabled. Mocked config object with force_register enabled.
""" """
return Config( import uuid
conf = Config(
cog_name="PyTest", cog_name="PyTest",
unique_identifier=0, unique_identifier=str(uuid.uuid4()),
driver_spawn=json_driver, driver_spawn=json_driver,
force_registration=True force_registration=True
) )
yield conf
conf.defaults = {}
#region Dpy Mocks #region Dpy Mocks
@pytest.fixture(scope="module") @pytest.fixture()
def empty_guild(): def guild_factory():
mock_guild = namedtuple("Guild", "id members") mock_guild = namedtuple("Guild", "id members")
return mock_guild(random.randint(1, 999999999), [])
class GuildFactory:
def get(self):
return mock_guild(random.randint(1, 999999999), [])
return GuildFactory()
@pytest.fixture()
def empty_guild(guild_factory):
return guild_factory.get()
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
@ -66,16 +85,39 @@ def empty_role():
return mock_role(random.randint(1, 999999999)) return mock_role(random.randint(1, 999999999))
@pytest.fixture(scope="module") @pytest.fixture()
def empty_member(empty_guild): def member_factory(guild_factory):
mock_member = namedtuple("Member", "id guild") mock_member = namedtuple("Member", "id guild")
return mock_member(random.randint(1, 999999999), empty_guild)
class MemberFactory:
def get(self):
return mock_member(
random.randint(1, 999999999),
guild_factory.get())
return MemberFactory()
@pytest.fixture(scope="module") @pytest.fixture()
def empty_user(): def empty_member(member_factory):
return member_factory.get()
@pytest.fixture()
def user_factory():
mock_user = namedtuple("User", "id") mock_user = namedtuple("User", "id")
return mock_user(random.randint(1, 999999999))
class UserFactory:
def get(self):
return mock_user(
random.randint(1, 999999999))
return UserFactory()
@pytest.fixture()
def empty_user(user_factory):
return user_factory.get()
@pytest.fixture(scope="module") @pytest.fixture(scope="module")
@ -84,7 +126,7 @@ def empty_message():
return mock_msg("No content.") return mock_msg("No content.")
@pytest.fixture @pytest.fixture()
def ctx(empty_member, empty_channel, red): def ctx(empty_member, empty_channel, red):
mock_ctx = namedtuple("Context", "author guild channel message bot") mock_ctx = namedtuple("Context", "author guild channel message bot")
return mock_ctx(empty_member, empty_member.guild, empty_channel, return mock_ctx(empty_member, empty_member.guild, empty_channel,
@ -93,15 +135,14 @@ def ctx(empty_member, empty_channel, red):
#region Red Mock #region Red Mock
@pytest.fixture @pytest.fixture()
def red(monkeysession, config_fr): def red(config_fr):
from core.cli import parse_cli_flags from core.cli import parse_cli_flags
cli_flags = parse_cli_flags() cli_flags = parse_cli_flags()
description = "Red v3 - Alpha" description = "Red v3 - Alpha"
monkeysession.setattr("core.config.Config.get_core_conf", Config.get_core_conf = (lambda *args, **kwargs: config_fr)
lambda *args, **kwargs: config_fr)
red = Red(cli_flags, description=description, pm_help=None) red = Red(cli_flags, description=description, pm_help=None)

View File

@ -15,9 +15,9 @@ def test_config_register_global_badvalues(config):
def test_config_register_guild(config, empty_guild): def test_config_register_guild(config, empty_guild):
config.register_guild(enabled=False, some_list=[], some_dict={}) config.register_guild(enabled=False, some_list=[], some_dict={})
assert config.defaults["GUILD"]["enabled"] is False assert config.defaults[config.GUILD]["enabled"] is False
assert config.defaults["GUILD"]["some_list"] == [] assert config.defaults[config.GUILD]["some_list"] == []
assert config.defaults["GUILD"]["some_dict"] == {} assert config.defaults[config.GUILD]["some_dict"] == {}
assert config.guild(empty_guild).enabled() is False assert config.guild(empty_guild).enabled() is False
assert config.guild(empty_guild).some_list() == [] assert config.guild(empty_guild).some_list() == []
@ -26,25 +26,25 @@ def test_config_register_guild(config, empty_guild):
def test_config_register_channel(config, empty_channel): def test_config_register_channel(config, empty_channel):
config.register_channel(enabled=False) config.register_channel(enabled=False)
assert config.defaults["CHANNEL"]["enabled"] is False assert config.defaults[config.CHANNEL]["enabled"] is False
assert config.channel(empty_channel).enabled() is False assert config.channel(empty_channel).enabled() is False
def test_config_register_role(config, empty_role): def test_config_register_role(config, empty_role):
config.register_role(enabled=False) config.register_role(enabled=False)
assert config.defaults["ROLE"]["enabled"] is False assert config.defaults[config.ROLE]["enabled"] is False
assert config.role(empty_role).enabled() is False assert config.role(empty_role).enabled() is False
def test_config_register_member(config, empty_member): def test_config_register_member(config, empty_member):
config.register_member(some_number=-1) config.register_member(some_number=-1)
assert config.defaults["MEMBER"]["some_number"] == -1 assert config.defaults[config.MEMBER]["some_number"] == -1
assert config.member(empty_member).some_number() == -1 assert config.member(empty_member).some_number() == -1
def test_config_register_user(config, empty_user): def test_config_register_user(config, empty_user):
config.register_user(some_value=None) config.register_user(some_value=None)
assert config.defaults["USER"]["some_value"] is None assert config.defaults[config.USER]["some_value"] is None
assert config.user(empty_user).some_value() is None assert config.user(empty_user).some_value() is None
@ -57,106 +57,233 @@ def test_config_force_register_global(config_fr):
#endregion #endregion
# Test nested registration
def test_nested_registration(config):
config.register_global(foo__bar__baz=False)
assert config.foo.bar.baz() is False
def test_nested_registration_asdict(config):
defaults = {'bar': {'baz': False}}
config.register_global(foo=defaults)
assert config.foo.bar.baz() is False
@pytest.mark.asyncio
async def test_nested_registration_and_changing(config):
defaults = {'bar': {'baz': False}}
config.register_global(foo=defaults)
assert config.foo.bar.baz() is False
with pytest.raises(ValueError):
await config.foo.set(True)
def test_doubleset_default(config):
config.register_global(foo=True)
config.register_global(foo=False)
assert config.foo() is False
def test_nested_registration_multidict(config):
defaults = {
"foo": {
"bar": {
"baz": True
}
},
"blah": True
}
config.register_global(**defaults)
assert config.foo.bar.baz() is True
assert config.blah() is True
def test_nested_group_value_badreg(config):
config.register_global(foo=True)
with pytest.raises(KeyError):
config.register_global(foo__bar=False)
def test_nested_toplevel_reg(config):
defaults = {'bar': True, 'baz': False}
config.register_global(foo=defaults)
assert config.foo.bar() is True
assert config.foo.baz() is False
def test_nested_overlapping(config):
config.register_global(foo__bar=True)
config.register_global(foo__baz=False)
assert config.foo.bar() is True
assert config.foo.baz() is False
def test_nesting_nofr(config):
config.register_global(foo={})
assert config.foo.bar() is None
assert config.foo() == {}
#region Default Value Overrides #region Default Value Overrides
def test_global_default_override(config): def test_global_default_override(config):
assert config.enabled(True) is True assert config.enabled(True) is True
assert config.get("enabled") is None
assert config.get("enabled", default=True) is True
def test_global_default_nofr(config): def test_global_default_nofr(config):
assert config.nofr() is None assert config.nofr() is None
assert config.nofr(True) is True assert config.nofr(True) is True
assert config.get("nofr") is None
assert config.get("nofr", default=True) is True
def test_guild_default_override(config, empty_guild): def test_guild_default_override(config, empty_guild):
assert config.guild(empty_guild).enabled(True) is True assert config.guild(empty_guild).enabled(True) is True
assert config.guild(empty_guild).get("enabled") is None
assert config.guild(empty_guild).get("enabled", default=True) is True
def test_channel_default_override(config, empty_channel): def test_channel_default_override(config, empty_channel):
assert config.channel(empty_channel).enabled(True) is True assert config.channel(empty_channel).enabled(True) is True
assert config.channel(empty_channel).get("enabled") is None
assert config.channel(empty_channel).get("enabled", default=True) is True
def test_role_default_override(config, empty_role): def test_role_default_override(config, empty_role):
assert config.role(empty_role).enabled(True) is True assert config.role(empty_role).enabled(True) is True
assert config.role(empty_role).get("enabled") is None
assert config.role(empty_role).get("enabled", default=True) is True
def test_member_default_override(config, empty_member): def test_member_default_override(config, empty_member):
assert config.member(empty_member).enabled(True) is True assert config.member(empty_member).enabled(True) is True
assert config.member(empty_member).get("enabled") is None
assert config.member(empty_member).get("enabled", default=True) is True
def test_user_default_override(config, empty_user): def test_user_default_override(config, empty_user):
assert config.user(empty_user).some_value(True) is True assert config.user(empty_user).some_value(True) is True
assert config.user(empty_user).get("some_value") is None
assert config.user(empty_user).get("some_value", default=True) is True
#endregion #endregion
#region Setting Values #region Setting Values
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_set_global(config): async def test_set_global(config):
await config.set("enabled", True) await config.enabled.set(True)
assert config.enabled() is True assert config.enabled() is True
@pytest.mark.asyncio
async def test_set_global_badkey(config):
with pytest.raises(RuntimeError):
await config.set("this is a bad key", True)
@pytest.mark.asyncio
async def test_set_global_invalidkey(config):
with pytest.raises(KeyError):
await config.set("uuid", True)
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_set_guild(config, empty_guild): async def test_set_guild(config, empty_guild):
await config.guild(empty_guild).set("enabled", True) await config.guild(empty_guild).enabled.set(True)
assert config.guild(empty_guild).enabled() is True assert config.guild(empty_guild).enabled() is True
curr_list = config.guild(empty_guild).some_list([1, 2, 3]) curr_list = config.guild(empty_guild).some_list([1, 2, 3])
assert curr_list == [1, 2, 3] assert curr_list == [1, 2, 3]
curr_list.append(4) curr_list.append(4)
await config.guild(empty_guild).set("some_list", curr_list) await config.guild(empty_guild).some_list.set(curr_list)
assert config.guild(empty_guild).some_list() == curr_list assert config.guild(empty_guild).some_list() == curr_list
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_set_channel(config, empty_channel): async def test_set_channel(config, empty_channel):
await config.channel(empty_channel).set("enabled", True) await config.channel(empty_channel).enabled.set(True)
assert config.channel(empty_channel).enabled() is True assert config.channel(empty_channel).enabled() is True
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_set_channel_no_register(config, empty_channel): async def test_set_channel_no_register(config, empty_channel):
await config.channel(empty_channel).set("no_register", True) await config.channel(empty_channel).no_register.set(True)
assert config.channel(empty_channel).no_register() is True assert config.channel(empty_channel).no_register() is True
#endregion #endregion
# region Getting Values # Dynamic attribute testing
def test_get_func_w_reg(config): @pytest.mark.asyncio
config.register_global( async def test_set_dynamic_attr(config):
thing=True await config.set_attr("foobar", True)
)
assert config.get("thing") is True assert config.foobar() is True
assert config.get("thing", False) is False
def test_get_func_wo_reg(config): def test_get_dynamic_attr(config):
assert config.get("thing") is None assert config.get_attr("foobaz", True) is True
assert config.get("thing", True) is True
# endregion
# Member Group testing
@pytest.mark.asyncio
async def test_membergroup_allguilds(config, empty_member):
await config.member(empty_member).foo.set(False)
all_servers = config.member(empty_member).all_guilds()
assert str(empty_member.guild.id) in all_servers
@pytest.mark.asyncio
async def test_membergroup_allmembers(config, empty_member):
await config.member(empty_member).foo.set(False)
all_members = config.member(empty_member).all()
assert str(empty_member.id) in all_members
# Clearing testing
@pytest.mark.asyncio
async def test_global_clear(config):
config.register_global(foo=True, bar=False)
await config.foo.set(False)
await config.bar.set(True)
assert config.foo() is False
assert config.bar() is True
await config.clear()
assert config.foo() is True
assert config.bar() is False
@pytest.mark.asyncio
async def test_member_clear(config, member_factory):
config.register_member(foo=True)
m1 = member_factory.get()
await config.member(m1).foo.set(False)
assert config.member(m1).foo() is False
m2 = member_factory.get()
await config.member(m2).foo.set(False)
assert config.member(m2).foo() is False
assert m1.guild.id != m2.guild.id
await config.member(m1).clear()
assert config.member(m1).foo() is True
assert config.member(m2).foo() is False
@pytest.mark.asyncio
async def test_member_clear_all(config, member_factory):
server_ids = []
for _ in range(5):
member = member_factory.get()
await config.member(member).foo.set(True)
server_ids.append(member.guild.id)
member = member_factory.get()
assert len(config.member(member).all_guilds()) == len(server_ids)
await config.member(member).clear_all()
assert len(config.member(member).all_guilds()) == 0
# Get All testing
@pytest.mark.asyncio
async def test_user_get_all(config, user_factory):
for _ in range(5):
user = user_factory.get()
await config.user(user).foo.set(True)
user = user_factory.get()
all_data = config.user(user).all()
assert len(all_data) == 5