diff --git a/.travis.yml b/.travis.yml index 5d94153b1..a6a8410fd 100644 --- a/.travis.yml +++ b/.travis.yml @@ -1,3 +1,4 @@ +dist: trusty language: python python: - "3.5.3" diff --git a/cogs/alias/alias.py b/cogs/alias/alias.py index 2179a543e..959ede213 100644 --- a/cogs/alias/alias.py +++ b/cogs/alias/alias.py @@ -81,24 +81,24 @@ class Alias: if global_: curr_aliases = self._aliases.entries() curr_aliases.append(alias.to_json()) - await self._aliases.set("entries", curr_aliases) + await self._aliases.entries.set(curr_aliases) else: curr_aliases = self._aliases.guild(ctx.guild).entries() curr_aliases.append(alias.to_json()) - await self._aliases.guild(ctx.guild).set("entries", curr_aliases) + await self._aliases.guild(ctx.guild).entries.set(curr_aliases) - await self._aliases.guild(ctx.guild).set("enabled", True) + await self._aliases.guild(ctx.guild).enabled.set(True) return alias async def delete_alias(self, ctx: commands.Context, alias_name: str, global_: bool=False) -> bool: if global_: aliases = self.unloaded_global_aliases() - setter_func = self._aliases.set + setter_func = self._aliases.entries.set else: aliases = self.unloaded_aliases(ctx.guild) - setter_func = self._aliases.guild(ctx.guild).set + setter_func = self._aliases.guild(ctx.guild).entries.set did_delete_alias = False @@ -110,7 +110,6 @@ class Alias: did_delete_alias = True await setter_func( - "entries", [a.to_json() for a in to_keep] ) @@ -355,8 +354,9 @@ class Alias: await ctx.send(box("\n".join(names), "diff")) async def on_message(self, message: discord.Message): - aliases = list(self.unloaded_aliases(message.guild)) + \ - list(self.unloaded_global_aliases()) + aliases = list(self.unloaded_global_aliases()) + if message.guild is not None: + aliases = aliases + list(self.unloaded_aliases(message.guild)) if len(aliases) == 0: return diff --git a/cogs/downloader/downloader.py b/cogs/downloader/downloader.py index 54bd0da0c..174273f6c 100644 --- a/cogs/downloader/downloader.py +++ b/cogs/downloader/downloader.py @@ -30,7 +30,7 @@ class Downloader: def __init__(self, bot: Red): self.bot = bot - self.conf = Config.get_conf(self, unique_identifier=998240343, + self.conf = Config.get_conf(self, identifier=998240343, force_registration=True) self.conf.register_global( @@ -73,7 +73,7 @@ class Downloader: if cog_json not in installed: installed.append(cog_json) - await self.conf.set("installed", installed) + await self.conf.installed.set(installed) async def _remove_from_installed(self, cog: Installable): """ @@ -86,7 +86,7 @@ class Downloader: if cog_json in installed: installed.remove(cog_json) - await self.conf.set("installed", installed) + await self.conf.installed.set(installed) async def _reinstall_cogs(self, cogs: Tuple[Installable]) -> Tuple[Installable]: """ diff --git a/cogs/downloader/repo_manager.py b/cogs/downloader/repo_manager.py index a174ea06c..ea2120f4c 100644 --- a/cogs/downloader/repo_manager.py +++ b/cogs/downloader/repo_manager.py @@ -526,4 +526,4 @@ class RepoManager: async def _save_repos(self): repo_json_info = {name: r.to_json() for name, r in self._repos.items()} - await self.downloader_config.set("repos", repo_json_info) + await self.downloader_config.repos.set(repo_json_info) diff --git a/core/bot.py b/core/bot.py index ef03ddb0f..01f4716d9 100644 --- a/core/bot.py +++ b/core/bot.py @@ -47,7 +47,7 @@ class Red(commands.Bot): kwargs["owner_id"] = cli_flags.owner if "owner_id" not in kwargs: - kwargs["owner_id"] = self.db.get("owner") + kwargs["owner_id"] = self.db.owner() self.counter = Counter() self.uptime = None @@ -89,7 +89,7 @@ class Red(commands.Bot): for package in self.extensions: if package.startswith("cogs."): loaded.append(package) - await self.db.set("packages", loaded) + await self.db.packages.set(loaded) class ExitCodes(Enum): diff --git a/core/cli.py b/core/cli.py index 36861e913..9cc145551 100644 --- a/core/cli.py +++ b/core/cli.py @@ -22,7 +22,7 @@ def interactive_config(red, token_set, prefix_set): print("That doesn't look like a valid token.") token = "" if token: - loop.run_until_complete(red.db.set("token", token)) + loop.run_until_complete(red.db.token.set(token)) if not prefix_set: prefix = "" @@ -39,7 +39,7 @@ def interactive_config(red, token_set, prefix_set): if not confirm("> "): prefix = "" if prefix: - loop.run_until_complete(red.db.set("prefix", [prefix])) + loop.run_until_complete(red.db.prefix.set([prefix])) ask_sentry(red) @@ -55,9 +55,9 @@ def ask_sentry(red: Red): " found issues in a timely manner. If you wish to opt in\n" " the process please type \"yes\":\n") if not confirm("> "): - loop.run_until_complete(red.db.set("enable_sentry", False)) + loop.run_until_complete(red.db.enable_sentry.set(False)) else: - loop.run_until_complete(red.db.set("enable_sentry", True)) + loop.run_until_complete(red.db.enable_sentry.set(True)) print("\nThank you for helping us with the development process!") diff --git a/core/config.py b/core/config.py index 3d636decb..137d7be30 100644 --- a/core/config.py +++ b/core/config.py @@ -1,521 +1,374 @@ -from pathlib import Path - -from core.drivers.red_json import JSON as JSONDriver -from core.drivers.red_mongo import Mongo import logging -from typing import Callable +from typing import Callable, Union, Tuple + +import discord +from copy import deepcopy + +from pathlib import Path + +from .drivers.red_json import JSON as JSONDriver log = logging.getLogger("red.config") -class BaseConfig: - def __init__(self, cog_name, unique_identifier, driver_spawn, force_registration=False, - hash_uuid=True, collection="GLOBAL", collection_uuid=None, - defaults={}): - self.cog_name = cog_name - if hash_uuid: - self.uuid = str(hash(unique_identifier)) - else: - self.uuid = unique_identifier - self.driver_spawn = driver_spawn - self._driver = None - self.collection = collection - self.collection_uuid = collection_uuid - self.force_registration = force_registration +class Value: + def __init__(self, identifiers: Tuple[str], default_value, spawner): + self._identifiers = identifiers + self.default = default_value + self.spawner = spawner + + @property + def identifiers(self): + return tuple(str(i) for i in self._identifiers) + + def __call__(self, default=None): + driver = self.spawner.get_driver() try: - self.driver.maybe_add_ident(self.uuid) - except AttributeError: - pass + ret = driver.get(self.identifiers) + except KeyError: + return default or self.default + return ret - self.driver_getmap = { - "GLOBAL": self.driver.get_global, - "GUILD": self.driver.get_guild, - "CHANNEL": self.driver.get_channel, - "ROLE": self.driver.get_role, - "USER": self.driver.get_user - } + async def set(self, value): + driver = self.spawner.get_driver() + await driver.set(self.identifiers, value) - self.driver_setmap = { - "GLOBAL": self.driver.set_global, - "GUILD": self.driver.set_guild, - "CHANNEL": self.driver.set_channel, - "ROLE": self.driver.set_role, - "USER": self.driver.set_user - } - self.curr_key = None +class Group(Value): + def __init__(self, identifiers: Tuple[str], + defaults: dict, + spawner, + force_registration: bool=False): + self.defaults = defaults + self.force_registration = force_registration + self.spawner = spawner - self.unsettable_keys = ("cog_name", "cog_identifier", "_id", - "guild_id", "channel_id", "role_id", - "user_id", "uuid") - self.invalid_keys = ( - "driver_spawn", - "_driver", "collection", - "collection_uuid", "force_registration" + super().__init__(identifiers, {}, self.spawner) + + # noinspection PyTypeChecker + def __getattr__(self, item: str) -> Union["Group", Value]: + """ + Takes in the next accessible item. If it's found to be a Group + we return another Group object. If it's found to be a Value + we return a Value object. If it is not found and + force_registration is True then we raise AttributeException, + otherwise return a Value object. + :param item: + :return: + """ + is_group = self.is_group(item) + is_value = not is_group and self.is_value(item) + new_identifiers = self.identifiers + (item, ) + if is_group: + return Group( + identifiers=new_identifiers, + defaults=self.defaults[item], + spawner=self.spawner, + force_registration=self.force_registration + ) + elif is_value: + return Value( + identifiers=new_identifiers, + default_value=self.defaults[item], + spawner=self.spawner + ) + elif self.force_registration: + raise AttributeError( + "'{}' is not a valid registered Group" + "or value.".format(item) + ) + else: + return Value( + identifiers=new_identifiers, + default_value=None, + spawner=self.spawner + ) + + @property + def _super_group(self) -> 'Group': + super_group = Group( + self.identifiers[:-1], + defaults={}, + spawner=self.spawner, + force_registration=self.force_registration ) + return super_group - self.defaults = defaults if defaults else { - "GLOBAL": {}, "GUILD": {}, "CHANNEL": {}, "ROLE": {}, - "MEMBER": {}, "USER": {}} + def is_group(self, item: str) -> bool: + """ + Determines if an attribute access is pointing at a registered group. + :param item: + :return: + """ + default = self.defaults.get(item) + return isinstance(default, dict) + + def is_value(self, item: str) -> bool: + """ + Determines if an attribute access is pointing at a registered value. + :param item: + :return: + """ + try: + default = self.defaults[item] + except KeyError: + return False + + return not isinstance(default, dict) + + def get_attr(self, item: str, default=None): + """ + You should avoid this function whenever possible. + :param item: + :param default: + :return: + """ + value = getattr(self, item) + return value(default=default) + + def all(self) -> dict: + """ + Gets all entries of the given kind. If this kind is member + then this method returns all members from the same + server. + :return: + """ + # noinspection PyTypeChecker + return self._super_group() + + async def set(self, value): + if not isinstance(value, dict): + raise ValueError( + "You may only set the value of a group to be a dict." + ) + await super().set(value) + + async def set_attr(self, item: str, value): + """ + You should avoid this function whenever possible. + :param item: + :param value: + :return: + """ + value_obj = getattr(self, item) + await value_obj.set(value) + + async def clear(self): + """ + Wipes out data for the given entry in this category + e.g. Guild/Role/User + :return: + """ + await self.set({}) + + async def clear_all(self): + """ + Removes all data from all entries. + :return: + """ + await self._super_group.set({}) + + +class MemberGroup(Group): + @property + def _super_group(self) -> Group: + new_identifiers = self.identifiers[:2] + group_obj = Group( + identifiers=new_identifiers, + defaults={}, + spawner=self.spawner + ) + return group_obj + + @property + def _guild_group(self) -> Group: + new_identifiers = self.identifiers[:3] + group_obj = Group( + identifiers=new_identifiers, + defaults={}, + spawner=self.spawner + ) + return group_obj + + def all_guilds(self) -> dict: + """ + Gets a dict of all guilds and members. + + REMEMBER: ID's are stored in these dicts as STRINGS. + :return: + """ + # noinspection PyTypeChecker + return self._super_group() + + def all(self) -> dict: + """ + Returns the dict of all members in the same guild. + :return: + """ + # noinspection PyTypeChecker + return self._guild_group() + +class Config: + GLOBAL = "GLOBAL" + GUILD = "GUILD" + CHANNEL = "TEXTCHANNEL" + ROLE = "ROLE" + USER = "USER" + MEMBER = "MEMBER" + + def __init__(self, cog_name: str, unique_identifier: str, + driver_spawn: Callable, + force_registration: bool=False, + defaults: dict=None): + self.cog_name = cog_name + self.unique_identifier = unique_identifier + + self.spawner = driver_spawn + self.force_registration = force_registration + self.defaults = defaults or {} @classmethod - def get_conf(cls, cog_instance: object, unique_identifier: int=0, - force_registration: bool=False): + def get_conf(cls, cog_instance, identifier: int, + force_registration=False): """ - Gets a config object that cog's can use to safely store data. The - backend to this is totally modular and can easily switch between - JSON and a DB. However, when changed, all data will likely be lost - unless cogs write some converters for their data. - - Positional Arguments: - cog_instance - The cog `self` object, can be passed in from your - cog's __init__ method. - - Keyword Arguments: - unique_identifier - a random integer or string that is used to - differentiate your cog from any other named the same. This way we - can safely store data for multiple cogs that are named the same. - - YOU SHOULD USE THIS. - - force_registration - A flag which will cause the Config object to - throw exceptions if you try to get/set data keys that you have - not pre-registered. I highly recommend you ENABLE this as it - will help reduce dumb typo errors. + Returns a Config instance based on a simplified set of initial + variables. + :param cog_instance: + :param identifier: Any random integer, used to keep your data + distinct from any other cog with the same name. + :param force_registration: Should config require registration + of data keys before allowing you to get/set values? + :return: """ - - url = None # TODO: get mongo url - port = None # TODO: get mongo port - - def spawn_mongo_driver(): - return Mongo(url, port) - - # TODO: Determine which backend users want, default to JSON - cog_name = cog_instance.__class__.__name__ + uuid = str(hash(identifier)) - driver_spawn = JSONDriver(cog_name) - - return cls(cog_name=cog_name, unique_identifier=unique_identifier, - driver_spawn=driver_spawn, force_registration=force_registration) + spawner = JSONDriver(cog_name) + return cls(cog_name=cog_name, unique_identifier=uuid, + force_registration=force_registration, + driver_spawn=spawner) @classmethod def get_core_conf(cls, force_registration: bool=False): core_data_path = Path.cwd() / 'core' / '.data' driver_spawn = JSONDriver("Core", data_path_override=core_data_path) return cls(cog_name="Core", driver_spawn=driver_spawn, - unique_identifier=0, + unique_identifier='0', force_registration=force_registration) - @property - def driver(self): - if self._driver is None: - try: - self._driver = self.driver_spawn() - except TypeError: - return self.driver_spawn + def __getattr__(self, item: str) -> Union[Group, Value]: + """ + This is used to generate Value or Group objects for global + values. + :param item: + :return: + """ + global_group = self._get_base_group(self.GLOBAL) + return getattr(global_group, item) - return self._driver - - def __getattr__(self, key): - """This should be used to return config key data as determined by - `self.collection` and `self.collection_uuid`.""" - raise NotImplemented - - def __setattr__(self, key, value): - if 'defaults' in self.__dict__: # Necessary to let the cog load - restricted = list(self.defaults[self.collection].keys()) + \ - list(self.unsettable_keys) - if key in restricted: - raise ValueError("Not allowed to dynamically set attributes of" - " unsettable_keys: {}".format(restricted)) + @staticmethod + def _get_defaults_dict(key: str, value) -> dict: + """ + Since we're allowing nested config stuff now, not storing the + defaults as a flat dict sounds like a good idea. May turn + out to be an awful one but we'll see. + :param key: + :param value: + :return: + """ + ret = {} + partial = ret + splitted = key.split('__') + for i, k in enumerate(splitted, start=1): + if not k.isidentifier(): + raise RuntimeError("'{}' is an invalid config key.".format(k)) + if i == len(splitted): + partial[k] = value else: - self.__dict__[key] = value - else: - self.__dict__[key] = value - - def clear(self): - """Clears all values in the current context ONLY.""" - raise NotImplemented - - def set(self, key, value): - """This should set config key with value `value` in the - corresponding collection as defined by `self.collection` and - `self.collection_uuid`.""" - raise NotImplemented - - def guild(self, guild): - """This should return a `BaseConfig` instance with the corresponding - `collection` and `collection_uuid`.""" - raise NotImplemented - - def channel(self, channel): - """This should return a `BaseConfig` instance with the corresponding - `collection` and `collection_uuid`.""" - raise NotImplemented - - def role(self, role): - """This should return a `BaseConfig` instance with the corresponding - `collection` and `collection_uuid`.""" - raise NotImplemented - - def member(self, member): - """This should return a `BaseConfig` instance with the corresponding - `collection` and `collection_uuid`.""" - raise NotImplemented - - def user(self, user): - """This should return a `BaseConfig` instance with the corresponding - `collection` and `collection_uuid`.""" - raise NotImplemented - - def register_global(self, **global_defaults): - """ - Registers a new dict of global defaults. This function should - be called EVERY TIME the cog loads (aka just do it in - __init__)! - - :param global_defaults: Each key should be the key you want to - access data by and the value is the default value of that - key. - :return: - """ - for k, v in global_defaults.items(): - try: - self._register_global(k, v) - except KeyError: - log.exception("Bad default global key.") - - def _register_global(self, key, default=None): - """Registers a global config key `key`""" - if key in self.unsettable_keys: - raise KeyError("Attempt to use restricted key: '{}'".format(key)) - elif not key.isidentifier(): - raise RuntimeError("Invalid key name, must be a valid python variable" - " name.") - self.defaults["GLOBAL"][key] = default - - def register_guild(self, **guild_defaults): - """ - Registers a new dict of guild defaults. This function should - be called EVERY TIME the cog loads (aka just do it in - __init__)! - - :param guild_defaults: Each key should be the key you want to - access data by and the value is the default value of that - key. - :return: - """ - for k, v in guild_defaults.items(): - try: - self._register_guild(k, v) - except KeyError: - log.exception("Bad default guild key.") - - def _register_guild(self, key, default=None): - """Registers a guild config key `key`""" - if key in self.unsettable_keys: - raise KeyError("Attempt to use restricted key: '{}'".format(key)) - elif not key.isidentifier(): - raise RuntimeError("Invalid key name, must be a valid python variable" - " name.") - self.defaults["GUILD"][key] = default - - def register_channel(self, **channel_defaults): - """ - Registers a new dict of channel defaults. This function should - be called EVERY TIME the cog loads (aka just do it in - __init__)! - - :param channel_defaults: Each key should be the key you want to - access data by and the value is the default value of that - key. - :return: - """ - for k, v in channel_defaults.items(): - try: - self._register_channel(k, v) - except KeyError: - log.exception("Bad default channel key.") - - def _register_channel(self, key, default=None): - """Registers a channel config key `key`""" - if key in self.unsettable_keys: - raise KeyError("Attempt to use restricted key: '{}'".format(key)) - elif not key.isidentifier(): - raise RuntimeError("Invalid key name, must be a valid python variable" - " name.") - self.defaults["CHANNEL"][key] = default - - def register_role(self, **role_defaults): - """ - Registers a new dict of role defaults. This function should - be called EVERY TIME the cog loads (aka just do it in - __init__)! - - :param role_defaults: Each key should be the key you want to - access data by and the value is the default value of that - key. - :return: - """ - for k, v in role_defaults.items(): - try: - self._register_role(k, v) - except KeyError: - log.exception("Bad default role key.") - - def _register_role(self, key, default=None): - """Registers a role config key `key`""" - if key in self.unsettable_keys: - raise KeyError("Attempt to use restricted key: '{}'".format(key)) - elif not key.isidentifier(): - raise RuntimeError("Invalid key name, must be a valid python variable" - " name.") - self.defaults["ROLE"][key] = default - - def register_member(self, **member_defaults): - """ - Registers a new dict of member defaults. This function should - be called EVERY TIME the cog loads (aka just do it in - __init__)! - - :param member_defaults: Each key should be the key you want to - access data by and the value is the default value of that - key. - :return: - """ - for k, v in member_defaults.items(): - try: - self._register_member(k, v) - except KeyError: - log.exception("Bad default member key.") - - def _register_member(self, key, default=None): - """Registers a member config key `key`""" - if key in self.unsettable_keys: - raise KeyError("Attempt to use restricted key: '{}'".format(key)) - elif not key.isidentifier(): - raise RuntimeError("Invalid key name, must be a valid python variable" - " name.") - self.defaults["MEMBER"][key] = default - - def register_user(self, **user_defaults): - """ - Registers a new dict of user defaults. This function should - be called EVERY TIME the cog loads (aka just do it in - __init__)! - - :param user_defaults: Each key should be the key you want to - access data by and the value is the default value of that - key. - :return: - """ - for k, v in user_defaults.items(): - try: - self._register_user(k, v) - except KeyError: - log.exception("Bad default user key.") - - def _register_user(self, key, default=None): - """Registers a user config key `key`""" - if key in self.unsettable_keys: - raise KeyError("Attempt to use restricted key: '{}'".format(key)) - elif not key.isidentifier(): - raise RuntimeError("Invalid key name, must be a valid python variable" - " name.") - self.defaults["USER"][key] = default - - -class Config(BaseConfig): - """ - Config object created by `Config.get_conf()` - - This configuration object is designed to make backend data - storage mechanisms pluggable. It also is designed to - help a cog developer make fewer mistakes (such as - typos) when dealing with cog data and to make those mistakes - apparent much faster in the design process. - - It also has the capability to safely store data between cogs - that share the same name. - - There are two main components to this config object. First, - you have the ability to get data on a level specific basis. - The seven levels available are: global, guild, channel, role, - member, user, and misc. - - The second main component is registering default values for - data in each of the levels. This functionality is OPTIONAL - and must be explicitly enabled when creating the Config object - using the kwarg `force_registration=True`. - - Basic Usage: - Creating a Config object: - Use the `Config.get_conf()` class method to create new - Config objects. - - See the `Config.get_conf()` documentation for more - information. - - Registering Default Values (optional): - You can register default values for data at all levels - EXCEPT misc. - - Simply pass in the key/value pairs as keyword arguments to - the respective function. - - e.g.: conf_obj.register_global(enabled=True) - conf_obj.register_guild(likes_red=True) - - Retrieving data by attributes: - Since I registered the "enabled" key in the previous example - at the global level I can now do: - - conf_obj.enabled() - - which will retrieve the current value of the "enabled" - key, making use of the default of "True". I can also do - the same for the guild key "likes_red": - - conf_obj.guild(guild_obj).likes_red() - - If I elected to not register default values, you can provide them - when you try to access the key: - - conf_obj.no_default(default=True) - - However if you do not provide a default and you do not register - defaults, accessing the attribute will return "None". - - Saving data: - This is accomplished by using the `set` function available at - every level. - - e.g.: conf_obj.set("enabled", False) - conf_obj.guild(guild_obj).set("likes_red", False) - - If `force_registration` was enabled when the config object - was created you will only be allowed to save keys that you - have registered. - - Misc data is special, use `conf.misc()` and `conf.set_misc(value)` - respectively. - """ - - def __getattr__(self, key) -> Callable: - """ - Until I've got a better way to do this I'm just gonna fake __call__ - - :param key: - :return: lambda function with kwarg - """ - return self._get_value_from_key(key) - - def _get_value_from_key(self, key) -> Callable: - try: - default = self.defaults[self.collection][key] - except KeyError as e: - if self.force_registration: - raise AttributeError("Key '{}' not registered!".format(key)) from e - default = None - - self.curr_key = key - - if self.collection != "MEMBER": - ret = lambda default=default: self.driver_getmap[self.collection]( - self.cog_name, self.uuid, self.collection_uuid, key, - default=default) - else: - mid, sid = self.collection_uuid - ret = lambda default=default: self.driver.get_member( - self.cog_name, self.uuid, mid, sid, key, - default=default) + partial[k] = {} + partial = partial[k] return ret - def get(self, key, default=None): + @staticmethod + def _update_defaults(to_add: dict, _partial: dict): """ - Included as an alternative to registering defaults. - - :param key: - :param default: - :return: + This tries to update the defaults dictionary with the nested + partial dict generated by _get_defaults_dict. This WILL + throw an error if you try to have both a value and a group + registered under the same name. + :param to_add: + :param _partial: + :return: """ + for k, v in to_add.items(): + val_is_dict = isinstance(v, dict) + if k in _partial: + existing_is_dict = isinstance(_partial[k], dict) + if val_is_dict != existing_is_dict: + # != is XOR + raise KeyError("You cannot register a Group and a Value under" + " the same name.") + if val_is_dict: + Config._update_defaults(v, _partial=_partial[k]) + else: + _partial[k] = v + else: + _partial[k] = v - if default is not None: - return self._get_value_from_key(key)(default) - else: - return self._get_value_from_key(key)() + def _register_default(self, key: str, **kwargs): + if key not in self.defaults: + self.defaults[key] = {} - async def set(self, key, value): - # Notice to future developers: - # This code was commented to allow users to set keys without having to register them. - # That being said, if they try to get keys without registering them - # things will blow up. I do highly recommend enforcing the key registration. + data = deepcopy(kwargs) - if key in self.unsettable_keys or key in self.invalid_keys: - raise KeyError("Restricted key name, please use another.") + for k, v in data.items(): + to_add = self._get_defaults_dict(k, v) + self._update_defaults(to_add, self.defaults[key]) - if self.force_registration and key not in self.defaults[self.collection]: - raise AttributeError("Key '{}' not registered!".format(key)) + def register_global(self, **kwargs): + self._register_default(self.GLOBAL, **kwargs) - if not key.isidentifier(): - raise RuntimeError("Invalid key name, must be a valid python variable" - " name.") + def register_guild(self, **kwargs): + self._register_default(self.GUILD, **kwargs) - if self.collection == "GLOBAL": - await self.driver.set_global(self.cog_name, self.uuid, key, value) - elif self.collection == "MEMBER": - mid, sid = self.collection_uuid - await self.driver.set_member(self.cog_name, self.uuid, mid, sid, - key, value) - elif self.collection in self.driver_setmap: - func = self.driver_setmap[self.collection] - await func(self.cog_name, self.uuid, self.collection_uuid, key, value) + def register_channel(self, **kwargs): + # We may need to add a voice channel category later + self._register_default(self.CHANNEL, **kwargs) - async def clear(self): - await self.driver_setmap[self.collection]( - self.cog_name, self.uuid, self.collection_uuid, None, None, - clear=True) + def register_role(self, **kwargs): + self._register_default(self.ROLE, **kwargs) - def guild(self, guild): - new = type(self)(self.cog_name, self.uuid, self.driver, - hash_uuid=False, defaults=self.defaults) - new.collection = "GUILD" - new.collection_uuid = guild.id - new._driver = None - return new + def register_user(self, **kwargs): + self._register_default(self.USER, **kwargs) - def channel(self, channel): - new = type(self)(self.cog_name, self.uuid, self.driver, - hash_uuid=False, defaults=self.defaults) - new.collection = "CHANNEL" - new.collection_uuid = channel.id - new._driver = None - return new + def register_member(self, **kwargs): + self._register_default(self.MEMBER, **kwargs) - def role(self, role): - new = type(self)(self.cog_name, self.uuid, self.driver, - hash_uuid=False, defaults=self.defaults) - new.collection = "ROLE" - new.collection_uuid = role.id - new._driver = None - return new + def _get_base_group(self, key: str, *identifiers: str, + group_class=Group) -> Group: + # noinspection PyTypeChecker + return group_class( + identifiers=(self.unique_identifier, key) + identifiers, + defaults=self.defaults.get(key, {}), + spawner=self.spawner, + force_registration=self.force_registration + ) - def member(self, member): - guild = member.guild - new = type(self)(self.cog_name, self.uuid, self.driver, - hash_uuid=False, defaults=self.defaults) - new.collection = "MEMBER" - new.collection_uuid = (member.id, guild.id) - new._driver = None - return new + def guild(self, guild: discord.Guild) -> Group: + return self._get_base_group(self.GUILD, guild.id) + + def channel(self, channel: discord.TextChannel) -> Group: + return self._get_base_group(self.CHANNEL, channel.id) + + def role(self, role: discord.Role) -> Group: + return self._get_base_group(self.ROLE, role.id) + + def user(self, user: discord.User) -> Group: + return self._get_base_group(self.USER, user.id) + + def member(self, member: discord.Member) -> MemberGroup: + return self._get_base_group(self.MEMBER, member.guild.id, member.id, + group_class=MemberGroup) - def user(self, user): - new = type(self)(self.cog_name, self.uuid, self.driver, - hash_uuid=False, defaults=self.defaults) - new.collection = "USER" - new.collection_uuid = user.id - new._driver = None - return new diff --git a/core/core_commands.py b/core/core_commands.py index 76c2bfdf8..3c99fab1f 100644 --- a/core/core_commands.py +++ b/core/core_commands.py @@ -98,7 +98,7 @@ class Core: @commands.guild_only() async def adminrole(self, ctx, *, role: discord.Role): """Sets the admin role for this server""" - await ctx.bot.db.guild(ctx.guild).set("admin_role", role.id) + await ctx.bot.db.guild(ctx.guild).admin_role.set(role.id) await ctx.send("The admin role for this server has been set.") @_set.command() @@ -106,7 +106,7 @@ class Core: @commands.guild_only() async def modrole(self, ctx, *, role: discord.Role): """Sets the mod role for this server""" - await ctx.bot.db.guild(ctx.guild).set("mod_role", role.id) + await ctx.bot.db.guild(ctx.guild).mod_role.set(role.id) await ctx.send("The mod role for this server has been set.") @_set.command() @@ -225,7 +225,7 @@ class Core: await ctx.bot.send_cmd_help(ctx) return prefixes = sorted(prefixes, reverse=True) - await ctx.bot.db.set("prefix", prefixes) + await ctx.bot.db.prefix.set(prefixes) await ctx.send("Prefix set.") @_set.command(aliases=["serverprefixes"]) @@ -234,11 +234,11 @@ class Core: async def serverprefix(self, ctx, *prefixes): """Sets Red's server prefix(es)""" if not prefixes: - await ctx.bot.db.guild(ctx.guild).set("prefix", []) + await ctx.bot.db.guild(ctx.guild).prefix.set([]) await ctx.send("Server prefixes have been reset.") return prefixes = sorted(prefixes, reverse=True) - await ctx.bot.db.guild(ctx.guild).set("prefix", prefixes) + await ctx.bot.db.guild(ctx.guild).prefix.set(prefixes) await ctx.send("Prefix set.") @_set.command() @@ -276,7 +276,7 @@ class Core: else: if message.content.strip() == token: self.owner.reset_cooldown(ctx) - await ctx.bot.db.set("owner", ctx.author.id) + await ctx.bot.db.owner.set(ctx.author.id) ctx.bot.owner_id = ctx.author.id await ctx.send("You have been set as owner.") else: diff --git a/core/drivers/red_base.py b/core/drivers/red_base.py index 1b58cc1dc..a0d838272 100644 --- a/core/drivers/red_base.py +++ b/core/drivers/red_base.py @@ -1,45 +1,12 @@ +from typing import Tuple + + class BaseDriver: - def get_global(self, cog_name, ident, collection_id, key, *, default=None): - raise NotImplementedError() + def get_driver(self): + raise NotImplementedError - def get_guild(self, cog_name, ident, guild_id, key, *, default=None): - raise NotImplementedError() + def get(self, identifiers: Tuple[str]): + raise NotImplementedError - def get_channel(self, cog_name, ident, channel_id, key, *, default=None): - raise NotImplementedError() - - def get_role(self, cog_name, ident, role_id, key, *, default=None): - raise NotImplementedError() - - def get_member(self, cog_name, ident, user_id, guild_id, key, *, - default=None): - raise NotImplementedError() - - def get_user(self, cog_name, ident, user_id, key, *, default=None): - raise NotImplementedError() - - def get_misc(self, cog_name, ident, *, default=None): - raise NotImplementedError() - - async def set_global(self, cog_name, ident, key, value, clear=False): - raise NotImplementedError() - - async def set_guild(self, cog_name, ident, guild_id, key, value, clear=False): - raise NotImplementedError() - - async def set_channel(self, cog_name, ident, channel_id, key, value, - clear=False): - raise NotImplementedError() - - async def set_role(self, cog_name, ident, role_id, key, value, clear=False): - raise NotImplementedError() - - async def set_member(self, cog_name, ident, user_id, guild_id, key, value, - clear=False): - raise NotImplementedError() - - async def set_user(self, cog_name, ident, user_id, key, value, clear=False): - raise NotImplementedError() - - async def set_misc(self, cog_name, ident, value, clear=False): - raise NotImplementedError() + async def set(self, identifiers: Tuple[str], value): + raise NotImplementedError diff --git a/core/drivers/red_json.py b/core/drivers/red_json.py index 073d37a7c..7dfc64a5a 100644 --- a/core/drivers/red_json.py +++ b/core/drivers/red_json.py @@ -1,13 +1,15 @@ +from typing import Tuple + +from core.drivers.red_base import BaseDriver from core.json_io import JsonIO -import os -from .red_base import BaseDriver from pathlib import Path class JSON(BaseDriver): - def __init__(self, cog_name, *args, data_path_override: Path=None, - file_name_override: str="settings.json", **kwargs): + def __init__(self, cog_name, *, data_path_override: Path=None, + file_name_override: str="settings.json"): + super().__init__() self.cog_name = cog_name self.file_name = file_name_override if data_path_override: @@ -25,111 +27,23 @@ class JSON(BaseDriver): self.data = self.jsonIO._load_json() except FileNotFoundError: self.data = {} - - def maybe_add_ident(self, ident: str): - if ident in self.data: - return - - self.data[ident] = {} - for k in ("GLOBAL", "GUILD", "CHANNEL", "ROLE", "MEMBER", "USER"): - if k not in self.data[ident]: - self.data[ident][k] = {} - self.jsonIO._save_json(self.data) - def get_global(self, cog_name, ident, _, key, *, default=None): - return self.data[ident]["GLOBAL"].get(key, default) + def get_driver(self): + return self - def get_guild(self, cog_name, ident, guild_id, key, *, default=None): - guilddata = self.data[ident]["GUILD"].get(str(guild_id), {}) - return guilddata.get(key, default) + def get(self, identifiers: Tuple[str]): + partial = self.data + for i in identifiers: + partial = partial[i] + return partial - def get_channel(self, cog_name, ident, channel_id, key, *, default=None): - channeldata = self.data[ident]["CHANNEL"].get(str(channel_id), {}) - return channeldata.get(key, default) + async def set(self, identifiers, value): + partial = self.data + for i in identifiers[:-1]: + if i not in partial: + partial[i] = {} + partial = partial[i] - def get_role(self, cog_name, ident, role_id, key, *, default=None): - roledata = self.data[ident]["ROLE"].get(str(role_id), {}) - return roledata.get(key, default) - - def get_member(self, cog_name, ident, user_id, guild_id, key, *, - default=None): - userdata = self.data[ident]["MEMBER"].get(str(user_id), {}) - guilddata = userdata.get(str(guild_id), {}) - return guilddata.get(key, default) - - def get_user(self, cog_name, ident, user_id, key, *, default=None): - userdata = self.data[ident]["USER"].get(str(user_id), {}) - return userdata.get(key, default) - - async def set_global(self, cog_name, ident, key, value, clear=False): - if clear: - self.data[ident]["GLOBAL"] = {} - else: - self.data[ident]["GLOBAL"][key] = value - await self.jsonIO._threadsafe_save_json(self.data) - - async def set_guild(self, cog_name, ident, guild_id, key, value, clear=False): - guild_id = str(guild_id) - if clear: - self.data[ident]["GUILD"][guild_id] = {} - else: - try: - self.data[ident]["GUILD"][guild_id][key] = value - except KeyError: - self.data[ident]["GUILD"][guild_id] = {} - self.data[ident]["GUILD"][guild_id][key] = value - await self.jsonIO._threadsafe_save_json(self.data) - - async def set_channel(self, cog_name, ident, channel_id, key, value, clear=False): - channel_id = str(channel_id) - if clear: - self.data[ident]["CHANNEL"][channel_id] = {} - else: - try: - self.data[ident]["CHANNEL"][channel_id][key] = value - except KeyError: - self.data[ident]["CHANNEL"][channel_id] = {} - self.data[ident]["CHANNEL"][channel_id][key] = value - await self.jsonIO._threadsafe_save_json(self.data) - - async def set_role(self, cog_name, ident, role_id, key, value, clear=False): - role_id = str(role_id) - if clear: - self.data[ident]["ROLE"][role_id] = {} - else: - try: - self.data[ident]["ROLE"][role_id][key] = value - except KeyError: - self.data[ident]["ROLE"][role_id] = {} - self.data[ident]["ROLE"][role_id][key] = value - await self.jsonIO._threadsafe_save_json(self.data) - - async def set_member(self, cog_name, ident, user_id, guild_id, key, value, clear=False): - user_id = str(user_id) - guild_id = str(guild_id) - if clear: - self.data[ident]["MEMBER"][user_id] = {} - else: - try: - self.data[ident]["MEMBER"][user_id][guild_id][key] = value - except KeyError: - if user_id not in self.data[ident]["MEMBER"]: - self.data[ident]["MEMBER"][user_id] = {} - if guild_id not in self.data[ident]["MEMBER"][user_id]: - self.data[ident]["MEMBER"][user_id][guild_id] = {} - - self.data[ident]["MEMBER"][user_id][guild_id][key] = value - await self.jsonIO._threadsafe_save_json(self.data) - - async def set_user(self, cog_name, ident, user_id, key, value, clear=False): - user_id = str(user_id) - if clear: - self.data[ident]["USER"][user_id] = {} - else: - try: - self.data[ident]["USER"][user_id][key] = value - except KeyError: - self.data[ident]["USER"][user_id] = {} - self.data[ident]["USER"][user_id][key] = value + partial[identifiers[-1]] = value await self.jsonIO._threadsafe_save_json(self.data) diff --git a/main.py b/main.py index 8282667a6..ed41e82b8 100644 --- a/main.py +++ b/main.py @@ -113,7 +113,7 @@ if __name__ == '__main__': if db_token and not cli_flags.no_prompt: print("\nDo you want to reset the token? (y/n)") if confirm("> "): - loop.run_until_complete(red.db.set("token", "")) + loop.run_until_complete(red.db.token.set("")) print("Token has been reset.") except KeyboardInterrupt: log.info("Keyboard interrupt detected. Quitting...") diff --git a/tests/cogs/test_alias.py b/tests/cogs/test_alias.py index 9e39b6f51..f9c59b7a2 100644 --- a/tests/cogs/test_alias.py +++ b/tests/cogs/test_alias.py @@ -2,12 +2,11 @@ from cogs.alias import Alias import pytest -@pytest.fixture -def alias(monkeysession, config): - def get_mock_conf(*args, **kwargs): - return config +@pytest.fixture() +def alias(config): + import cogs.alias.alias - monkeysession.setattr("core.config.Config.get_conf", get_mock_conf) + cogs.alias.alias.Config.get_conf = lambda *args, **kwargs: config return Alias(None) @@ -25,9 +24,17 @@ def test_empty_global_aliases(alias): assert list(alias.unloaded_global_aliases()) == [] +async def create_test_guild_alias(alias, ctx): + await alias.add_alias(ctx, "test", "ping", global_=False) + + +async def create_test_global_alias(alias, ctx): + await alias.add_alias(ctx, "test", "ping", global_=True) + + @pytest.mark.asyncio async def test_add_guild_alias(alias, ctx): - await alias.add_alias(ctx, "test", "ping", global_=False) + await create_test_guild_alias(alias, ctx) is_alias, alias_obj = alias.is_alias(ctx.guild, "test") assert is_alias is True @@ -36,6 +43,7 @@ async def test_add_guild_alias(alias, ctx): @pytest.mark.asyncio async def test_delete_guild_alias(alias, ctx): + await create_test_guild_alias(alias, ctx) is_alias, _ = alias.is_alias(ctx.guild, "test") assert is_alias is True @@ -47,7 +55,7 @@ async def test_delete_guild_alias(alias, ctx): @pytest.mark.asyncio async def test_add_global_alias(alias, ctx): - await alias.add_alias(ctx, "test", "ping", global_=True) + await create_test_global_alias(alias, ctx) is_alias, alias_obj = alias.is_alias(ctx.guild, "test") assert is_alias is True @@ -56,6 +64,7 @@ async def test_add_global_alias(alias, ctx): @pytest.mark.asyncio async def test_delete_global_alias(alias, ctx): + await create_test_global_alias(alias, ctx) is_alias, alias_obj = alias.is_alias(ctx.guild, "test") assert is_alias is True assert alias_obj.global_ is True diff --git a/tests/conftest.py b/tests/conftest.py index 86d70afef..5ecc9b178 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -17,21 +17,27 @@ def monkeysession(request): mpatch.undo() -@pytest.fixture(scope="module") +@pytest.fixture() def json_driver(tmpdir_factory): + import uuid + rand = str(uuid.uuid4()) + path = Path(str(tmpdir_factory.mktemp(rand))) driver = red_json.JSON( "PyTest", - data_path_override=Path(str(tmpdir_factory.getbasetemp())) + data_path_override=path ) return driver @pytest.fixture() def config(json_driver): - return Config( + import uuid + conf = Config( cog_name="PyTest", - unique_identifier=0, + unique_identifier=str(uuid.uuid4()), driver_spawn=json_driver) + yield conf + conf.defaults = {} @pytest.fixture() @@ -39,19 +45,32 @@ def config_fr(json_driver): """ Mocked config object with force_register enabled. """ - return Config( + import uuid + conf = Config( cog_name="PyTest", - unique_identifier=0, + unique_identifier=str(uuid.uuid4()), driver_spawn=json_driver, force_registration=True ) + yield conf + conf.defaults = {} #region Dpy Mocks -@pytest.fixture(scope="module") -def empty_guild(): +@pytest.fixture() +def guild_factory(): mock_guild = namedtuple("Guild", "id members") - return mock_guild(random.randint(1, 999999999), []) + + class GuildFactory: + def get(self): + return mock_guild(random.randint(1, 999999999), []) + + return GuildFactory() + + +@pytest.fixture() +def empty_guild(guild_factory): + return guild_factory.get() @pytest.fixture(scope="module") @@ -66,16 +85,39 @@ def empty_role(): return mock_role(random.randint(1, 999999999)) -@pytest.fixture(scope="module") -def empty_member(empty_guild): +@pytest.fixture() +def member_factory(guild_factory): mock_member = namedtuple("Member", "id guild") - return mock_member(random.randint(1, 999999999), empty_guild) + + class MemberFactory: + def get(self): + return mock_member( + random.randint(1, 999999999), + guild_factory.get()) + + return MemberFactory() -@pytest.fixture(scope="module") -def empty_user(): +@pytest.fixture() +def empty_member(member_factory): + return member_factory.get() + + +@pytest.fixture() +def user_factory(): mock_user = namedtuple("User", "id") - return mock_user(random.randint(1, 999999999)) + + class UserFactory: + def get(self): + return mock_user( + random.randint(1, 999999999)) + + return UserFactory() + + +@pytest.fixture() +def empty_user(user_factory): + return user_factory.get() @pytest.fixture(scope="module") @@ -84,7 +126,7 @@ def empty_message(): return mock_msg("No content.") -@pytest.fixture +@pytest.fixture() def ctx(empty_member, empty_channel, red): mock_ctx = namedtuple("Context", "author guild channel message bot") return mock_ctx(empty_member, empty_member.guild, empty_channel, @@ -93,15 +135,14 @@ def ctx(empty_member, empty_channel, red): #region Red Mock -@pytest.fixture -def red(monkeysession, config_fr): +@pytest.fixture() +def red(config_fr): from core.cli import parse_cli_flags cli_flags = parse_cli_flags() description = "Red v3 - Alpha" - monkeysession.setattr("core.config.Config.get_core_conf", - lambda *args, **kwargs: config_fr) + Config.get_core_conf = (lambda *args, **kwargs: config_fr) red = Red(cli_flags, description=description, pm_help=None) diff --git a/tests/core/test_config.py b/tests/core/test_config.py index 6c59fbd4c..8db6176c8 100644 --- a/tests/core/test_config.py +++ b/tests/core/test_config.py @@ -15,9 +15,9 @@ def test_config_register_global_badvalues(config): def test_config_register_guild(config, empty_guild): config.register_guild(enabled=False, some_list=[], some_dict={}) - assert config.defaults["GUILD"]["enabled"] is False - assert config.defaults["GUILD"]["some_list"] == [] - assert config.defaults["GUILD"]["some_dict"] == {} + assert config.defaults[config.GUILD]["enabled"] is False + assert config.defaults[config.GUILD]["some_list"] == [] + assert config.defaults[config.GUILD]["some_dict"] == {} assert config.guild(empty_guild).enabled() is False assert config.guild(empty_guild).some_list() == [] @@ -26,25 +26,25 @@ def test_config_register_guild(config, empty_guild): def test_config_register_channel(config, empty_channel): config.register_channel(enabled=False) - assert config.defaults["CHANNEL"]["enabled"] is False + assert config.defaults[config.CHANNEL]["enabled"] is False assert config.channel(empty_channel).enabled() is False def test_config_register_role(config, empty_role): config.register_role(enabled=False) - assert config.defaults["ROLE"]["enabled"] is False + assert config.defaults[config.ROLE]["enabled"] is False assert config.role(empty_role).enabled() is False def test_config_register_member(config, empty_member): config.register_member(some_number=-1) - assert config.defaults["MEMBER"]["some_number"] == -1 + assert config.defaults[config.MEMBER]["some_number"] == -1 assert config.member(empty_member).some_number() == -1 def test_config_register_user(config, empty_user): config.register_user(some_value=None) - assert config.defaults["USER"]["some_value"] is None + assert config.defaults[config.USER]["some_value"] is None assert config.user(empty_user).some_value() is None @@ -57,106 +57,233 @@ def test_config_force_register_global(config_fr): #endregion +# Test nested registration +def test_nested_registration(config): + config.register_global(foo__bar__baz=False) + assert config.foo.bar.baz() is False + + +def test_nested_registration_asdict(config): + defaults = {'bar': {'baz': False}} + config.register_global(foo=defaults) + + assert config.foo.bar.baz() is False + + +@pytest.mark.asyncio +async def test_nested_registration_and_changing(config): + defaults = {'bar': {'baz': False}} + config.register_global(foo=defaults) + + assert config.foo.bar.baz() is False + + with pytest.raises(ValueError): + await config.foo.set(True) + + +def test_doubleset_default(config): + config.register_global(foo=True) + config.register_global(foo=False) + + assert config.foo() is False + + +def test_nested_registration_multidict(config): + defaults = { + "foo": { + "bar": { + "baz": True + } + }, + "blah": True + } + config.register_global(**defaults) + + assert config.foo.bar.baz() is True + assert config.blah() is True + + +def test_nested_group_value_badreg(config): + config.register_global(foo=True) + with pytest.raises(KeyError): + config.register_global(foo__bar=False) + + +def test_nested_toplevel_reg(config): + defaults = {'bar': True, 'baz': False} + config.register_global(foo=defaults) + + assert config.foo.bar() is True + assert config.foo.baz() is False + + +def test_nested_overlapping(config): + config.register_global(foo__bar=True) + config.register_global(foo__baz=False) + + assert config.foo.bar() is True + assert config.foo.baz() is False + + +def test_nesting_nofr(config): + config.register_global(foo={}) + assert config.foo.bar() is None + assert config.foo() == {} + + #region Default Value Overrides def test_global_default_override(config): assert config.enabled(True) is True - assert config.get("enabled") is None - assert config.get("enabled", default=True) is True def test_global_default_nofr(config): assert config.nofr() is None assert config.nofr(True) is True - assert config.get("nofr") is None - assert config.get("nofr", default=True) is True def test_guild_default_override(config, empty_guild): assert config.guild(empty_guild).enabled(True) is True - assert config.guild(empty_guild).get("enabled") is None - assert config.guild(empty_guild).get("enabled", default=True) is True def test_channel_default_override(config, empty_channel): assert config.channel(empty_channel).enabled(True) is True - assert config.channel(empty_channel).get("enabled") is None - assert config.channel(empty_channel).get("enabled", default=True) is True def test_role_default_override(config, empty_role): assert config.role(empty_role).enabled(True) is True - assert config.role(empty_role).get("enabled") is None - assert config.role(empty_role).get("enabled", default=True) is True def test_member_default_override(config, empty_member): assert config.member(empty_member).enabled(True) is True - assert config.member(empty_member).get("enabled") is None - assert config.member(empty_member).get("enabled", default=True) is True def test_user_default_override(config, empty_user): assert config.user(empty_user).some_value(True) is True - assert config.user(empty_user).get("some_value") is None - assert config.user(empty_user).get("some_value", default=True) is True #endregion #region Setting Values @pytest.mark.asyncio async def test_set_global(config): - await config.set("enabled", True) + await config.enabled.set(True) assert config.enabled() is True -@pytest.mark.asyncio -async def test_set_global_badkey(config): - with pytest.raises(RuntimeError): - await config.set("this is a bad key", True) - - -@pytest.mark.asyncio -async def test_set_global_invalidkey(config): - with pytest.raises(KeyError): - await config.set("uuid", True) - - @pytest.mark.asyncio async def test_set_guild(config, empty_guild): - await config.guild(empty_guild).set("enabled", True) + await config.guild(empty_guild).enabled.set(True) assert config.guild(empty_guild).enabled() is True curr_list = config.guild(empty_guild).some_list([1, 2, 3]) assert curr_list == [1, 2, 3] curr_list.append(4) - await config.guild(empty_guild).set("some_list", curr_list) + await config.guild(empty_guild).some_list.set(curr_list) assert config.guild(empty_guild).some_list() == curr_list @pytest.mark.asyncio async def test_set_channel(config, empty_channel): - await config.channel(empty_channel).set("enabled", True) + await config.channel(empty_channel).enabled.set(True) assert config.channel(empty_channel).enabled() is True @pytest.mark.asyncio async def test_set_channel_no_register(config, empty_channel): - await config.channel(empty_channel).set("no_register", True) + await config.channel(empty_channel).no_register.set(True) assert config.channel(empty_channel).no_register() is True #endregion -# region Getting Values -def test_get_func_w_reg(config): - config.register_global( - thing=True - ) - assert config.get("thing") is True - assert config.get("thing", False) is False +# Dynamic attribute testing +@pytest.mark.asyncio +async def test_set_dynamic_attr(config): + await config.set_attr("foobar", True) + + assert config.foobar() is True -def test_get_func_wo_reg(config): - assert config.get("thing") is None - assert config.get("thing", True) is True -# endregion +def test_get_dynamic_attr(config): + assert config.get_attr("foobaz", True) is True + + +# Member Group testing +@pytest.mark.asyncio +async def test_membergroup_allguilds(config, empty_member): + await config.member(empty_member).foo.set(False) + + all_servers = config.member(empty_member).all_guilds() + assert str(empty_member.guild.id) in all_servers + + +@pytest.mark.asyncio +async def test_membergroup_allmembers(config, empty_member): + await config.member(empty_member).foo.set(False) + + all_members = config.member(empty_member).all() + assert str(empty_member.id) in all_members + + +# Clearing testing +@pytest.mark.asyncio +async def test_global_clear(config): + config.register_global(foo=True, bar=False) + + await config.foo.set(False) + await config.bar.set(True) + + assert config.foo() is False + assert config.bar() is True + + await config.clear() + + assert config.foo() is True + assert config.bar() is False + + +@pytest.mark.asyncio +async def test_member_clear(config, member_factory): + config.register_member(foo=True) + + m1 = member_factory.get() + await config.member(m1).foo.set(False) + assert config.member(m1).foo() is False + + m2 = member_factory.get() + await config.member(m2).foo.set(False) + assert config.member(m2).foo() is False + + assert m1.guild.id != m2.guild.id + + await config.member(m1).clear() + assert config.member(m1).foo() is True + assert config.member(m2).foo() is False + + +@pytest.mark.asyncio +async def test_member_clear_all(config, member_factory): + server_ids = [] + for _ in range(5): + member = member_factory.get() + await config.member(member).foo.set(True) + server_ids.append(member.guild.id) + + member = member_factory.get() + assert len(config.member(member).all_guilds()) == len(server_ids) + + await config.member(member).clear_all() + + assert len(config.member(member).all_guilds()) == 0 + + +# Get All testing +@pytest.mark.asyncio +async def test_user_get_all(config, user_factory): + for _ in range(5): + user = user_factory.get() + await config.user(user).foo.set(True) + + user = user_factory.get() + all_data = config.user(user).all() + + assert len(all_data) == 5