diff --git a/docs/framework_config.rst b/docs/framework_config.rst index 3837d9915..96b8d5768 100644 --- a/docs/framework_config.rst +++ b/docs/framework_config.rst @@ -374,6 +374,21 @@ API Reference inside the bot itself! Simply take a peek inside of the :code:`tests/core/test_config.py` file for examples of using Config in all kinds of ways. +.. important:: + + When getting, setting or clearing values in Config, all keys are casted to `str` for you. This + includes keys within a `dict` when one is being set, as well as keys in nested dictionaries + within that `dict`. For example:: + + >>> conf = Config.get_conf(self, identifier=999) + >>> conf.register_global(foo={}) + >>> await conf.foo.set_raw(123, value=True) + >>> await conf.foo() + {'123': True} + >>> await conf.foo.set({123: True, 456: {789: False}} + >>> await conf.foo() + {'123': True, '456': {'789': False}} + .. automodule:: redbot.core.config Config diff --git a/redbot/cogs/audio/__init__.py b/redbot/cogs/audio/__init__.py index 2f35ddd32..b13d2a1a3 100644 --- a/redbot/cogs/audio/__init__.py +++ b/redbot/cogs/audio/__init__.py @@ -34,14 +34,14 @@ async def download_lavalink(session): async def maybe_download_lavalink(loop, cog): jar_exists = LAVALINK_JAR_FILE.exists() - current_build = redbot.core.VersionInfo.from_json(await cog.config.current_build()) + current_build = redbot.core.VersionInfo.from_json(await cog.config.current_version()) if not jar_exists or current_build < redbot.core.version_info: log.info("Downloading Lavalink.jar") LAVALINK_DOWNLOAD_DIR.mkdir(parents=True, exist_ok=True) async with ClientSession(loop=loop) as session: await download_lavalink(session) - await cog.config.current_build.set(redbot.core.version_info.to_json()) + await cog.config.current_version.set(redbot.core.version_info.to_json()) shutil.copyfile(str(BUNDLED_APP_YML_FILE), str(APP_YML_FILE)) diff --git a/redbot/cogs/audio/audio.py b/redbot/cogs/audio/audio.py index 9984f841d..ec18d25d5 100644 --- a/redbot/cogs/audio/audio.py +++ b/redbot/cogs/audio/audio.py @@ -48,7 +48,7 @@ class Audio(commands.Cog): "ws_port": "2332", "password": "youshallnotpass", "status": False, - "current_build": redbot.core.VersionInfo.from_str("3.0.0a0").to_json(), + "current_version": redbot.core.VersionInfo.from_str("3.0.0a0").to_json(), "use_external_lavalink": False, } diff --git a/redbot/cogs/cleanup/cleanup.py b/redbot/cogs/cleanup/cleanup.py index b685af276..8c61dc172 100644 --- a/redbot/cogs/cleanup/cleanup.py +++ b/redbot/cogs/cleanup/cleanup.py @@ -169,7 +169,7 @@ class Cleanup(commands.Cog): member = None try: - member = await commands.converter.MemberConverter().convert(ctx, user) + member = await commands.MemberConverter().convert(ctx, user) except commands.BadArgument: try: _id = int(user) diff --git a/redbot/cogs/permissions/permissions.py b/redbot/cogs/permissions/permissions.py index f0bbd13d8..00854d6f1 100644 --- a/redbot/cogs/permissions/permissions.py +++ b/redbot/cogs/permissions/permissions.py @@ -542,7 +542,8 @@ class Permissions(commands.Cog): continue conf = self.config.custom(category) for cmd_name, cmd_rules in rules_dict.items(): - await conf.set_raw(cmd_name, guild_id, value=cmd_rules) + cmd_rules = {str(model_id): rule for model_id, rule in cmd_rules.items()} + await conf.set_raw(cmd_name, str(guild_id), value=cmd_rules) cmd_obj = getter(cmd_name) if cmd_obj is not None: self._load_rules_for(cmd_obj, {guild_id: cmd_rules}) @@ -651,14 +652,14 @@ class Permissions(commands.Cog): if category in old_rules: for name, rules in old_rules[category].items(): these_rules = new_rules.setdefault(name, {}) - guild_rules = these_rules.setdefault(guild_id, {}) + guild_rules = these_rules.setdefault(str(guild_id), {}) # Since allow rules would take precedence if the same model ID # sat in both the allow and deny list, we add the deny entries # first and let any conflicting allow entries overwrite. for model_id in rules.get("deny", []): - guild_rules[model_id] = False + guild_rules[str(model_id)] = False for model_id in rules.get("allow", []): - guild_rules[model_id] = True + guild_rules[str(model_id)] = True if "default" in rules: default = rules["default"] if default == "allow": @@ -689,7 +690,9 @@ class Permissions(commands.Cog): """ for guild_id, guild_dict in _int_key_map(rule_dict.items()): for model_id, rule in _int_key_map(guild_dict.items()): - if rule is True: + if model_id == "default": + cog_or_command.set_default_rule(rule, guild_id=guild_id) + elif rule is True: cog_or_command.allow_for(model_id, guild_id=guild_id) elif rule is False: cog_or_command.deny_to(model_id, guild_id=guild_id) @@ -724,9 +727,16 @@ class Permissions(commands.Cog): rules. """ for guild_id, guild_dict in _int_key_map(rule_dict.items()): - for model_id in map(int, guild_dict.keys()): - cog_or_command.clear_rule_for(model_id, guild_id) + for model_id in guild_dict.keys(): + if model_id == "default": + cog_or_command.set_default_rule(None, guild_id=guild_id) + else: + cog_or_command.clear_rule_for(int(model_id), guild_id=guild_id) -def _int_key_map(items_view: ItemsView[str, Any]) -> Iterator[Tuple[int, Any]]: - return map(lambda tup: (int(tup[0]), tup[1]), items_view) +def _int_key_map(items_view: ItemsView[str, Any]) -> Iterator[Tuple[Union[str, int], Any]]: + for k, v in items_view: + if k == "default": + yield k, v + else: + yield int(k), v diff --git a/redbot/cogs/trivia/session.py b/redbot/cogs/trivia/session.py index cc6e2e142..4cfcfd6a4 100644 --- a/redbot/cogs/trivia/session.py +++ b/redbot/cogs/trivia/session.py @@ -322,9 +322,9 @@ def _parse_answers(answers): for answer in answers: if isinstance(answer, bool): if answer is True: - ret.extend(["True", "Yes", _("Yes")]) + ret.extend(["True", "Yes", "On"]) else: - ret.extend(["False", "No", _("No")]) + ret.extend(["False", "No", "Off"]) else: ret.append(str(answer)) # Uniquify list diff --git a/redbot/cogs/warnings/helpers.py b/redbot/cogs/warnings/helpers.py index 39aae8739..05550ac91 100644 --- a/redbot/cogs/warnings/helpers.py +++ b/redbot/cogs/warnings/helpers.py @@ -19,9 +19,11 @@ async def warning_points_add_check( act = {} async with guild_settings.actions() as registered_actions: for a in registered_actions: + # Actions are sorted in decreasing order of points. + # The first action we find where the user is above the threshold will be the + # highest action we can take. if points >= a["points"]: act = a - else: break if act and act["exceed_command"] is not None: # some action needs to be taken await create_and_invoke_context(ctx, act["exceed_command"], user) diff --git a/redbot/cogs/warnings/warnings.py b/redbot/cogs/warnings/warnings.py index 1b545591b..21cd353b5 100644 --- a/redbot/cogs/warnings/warnings.py +++ b/redbot/cogs/warnings/warnings.py @@ -9,7 +9,7 @@ from redbot.cogs.warnings.helpers import ( get_command_for_dropping_points, warning_points_remove_check, ) -from redbot.core import Config, modlog, checks, commands +from redbot.core import Config, checks, commands from redbot.core.bot import Red from redbot.core.i18n import Translator, cog_i18n from redbot.core.utils.mod import is_admin_or_superior @@ -34,15 +34,14 @@ class Warnings(commands.Cog): self.config.register_guild(**self.default_guild) self.config.register_member(**self.default_member) self.bot = bot - loop = asyncio.get_event_loop() - loop.create_task(self.register_warningtype()) - @staticmethod - async def register_warningtype(): - try: - await modlog.register_casetype("warning", True, "\N{WARNING SIGN}", "Warning", None) - except RuntimeError: - pass + # We're not utilising modlog yet - no need to register a casetype + # @staticmethod + # async def register_warningtype(): + # try: + # await modlog.register_casetype("warning", True, "\N{WARNING SIGN}", "Warning", None) + # except RuntimeError: + # pass @commands.group() @commands.guild_only() diff --git a/redbot/core/config.py b/redbot/core/config.py index b9b7a0a76..70b37f227 100644 --- a/redbot/core/config.py +++ b/redbot/core/config.py @@ -39,17 +39,21 @@ class _ValueCtxManager(Awaitable[_T], AsyncContextManager[_T]): async def __aenter__(self): self.raw_value = await self - self.__original_value = deepcopy(self.raw_value) if not isinstance(self.raw_value, (list, dict)): raise TypeError( "Type of retrieved value must be mutable (i.e. " "list or dict) in order to use a config value as " "a context manager." ) + self.__original_value = deepcopy(self.raw_value) return self.raw_value async def __aexit__(self, exc_type, exc, tb): - if self.raw_value != self.__original_value: + if isinstance(self.raw_value, dict): + raw_value = _str_key_dict(self.raw_value) + else: + raw_value = self.raw_value + if raw_value != self.__original_value: await self.value_obj.set(self.raw_value) @@ -58,7 +62,7 @@ class Value: Attributes ---------- - identifiers : `tuple` of `str` + identifiers : Tuple[str] This attribute provides all the keys necessary to get a specific data element from a json document. default @@ -69,15 +73,10 @@ class Value: """ def __init__(self, identifiers: Tuple[str], default_value, driver): - self._identifiers = identifiers + self.identifiers = identifiers self.default = default_value - self.driver = driver - @property - def identifiers(self): - return tuple(str(i) for i in self._identifiers) - async def _get(self, default=...): try: ret = await self.driver.get(*self.identifiers) @@ -149,6 +148,8 @@ class Value: The new literal value of this attribute. """ + if isinstance(value, dict): + value = _str_key_dict(value) await self.driver.set(*self.identifiers, value=value) async def clear(self): @@ -192,7 +193,10 @@ class Group(Value): async def _get(self, default: Dict[str, Any] = ...) -> Dict[str, Any]: default = default if default is not ... else self.defaults raw = await super()._get(default) - return self.nested_update(raw, default) + if isinstance(raw, dict): + return self.nested_update(raw, default) + else: + return raw # noinspection PyTypeChecker def __getattr__(self, item: str) -> Union["Group", Value]: @@ -238,7 +242,7 @@ class Group(Value): else: return Value(identifiers=new_identifiers, default_value=None, driver=self.driver) - async def clear_raw(self, *nested_path: str): + async def clear_raw(self, *nested_path: Any): """ Allows a developer to clear data as if it was stored in a standard Python dictionary. @@ -254,44 +258,44 @@ class Group(Value): Parameters ---------- - nested_path : str + nested_path : Any Multiple arguments that mirror the arguments passed in for nested - dict access. + dict access. These are casted to `str` for you. """ path = [str(p) for p in nested_path] await self.driver.clear(*self.identifiers, *path) - def is_group(self, item: str) -> bool: + def is_group(self, item: Any) -> bool: """A helper method for `__getattr__`. Most developers will have no need to use this. Parameters ---------- - item : str + item : Any See `__getattr__`. """ - default = self._defaults.get(item) + default = self._defaults.get(str(item)) return isinstance(default, dict) - def is_value(self, item: str) -> bool: + def is_value(self, item: Any) -> bool: """A helper method for `__getattr__`. Most developers will have no need to use this. Parameters ---------- - item : str + item : Any See `__getattr__`. """ try: - default = self._defaults[item] + default = self._defaults[str(item)] except KeyError: return False return not isinstance(default, dict) - def get_attr(self, item: str): + def get_attr(self, item: Union[int, str]): """Manually get an attribute of this Group. This is available to use as an alternative to using normal Python @@ -312,7 +316,8 @@ class Group(Value): Parameters ---------- item : str - The name of the data field in `Config`. + The name of the data field in `Config`. This is casted to + `str` for you. Returns ------- @@ -320,9 +325,11 @@ class Group(Value): The attribute which was requested. """ + if isinstance(item, int): + item = str(item) return self.__getattr__(item) - async def get_raw(self, *nested_path: str, default=...): + async def get_raw(self, *nested_path: Any, default=...): """ Allows a developer to access data as if it was stored in a standard Python dictionary. @@ -345,7 +352,7 @@ class Group(Value): ---------- nested_path : str Multiple arguments that mirror the arguments passed in for nested - dict access. + dict access. These are casted to `str` for you. default Default argument for the value attempting to be accessed. If the value does not exist the default will be returned. @@ -410,7 +417,6 @@ class Group(Value): If no defaults are passed, then the instance attribute 'defaults' will be used. - """ if defaults is ...: defaults = self.defaults @@ -428,7 +434,7 @@ class Group(Value): raise ValueError("You may only set the value of a group to be a dict.") await super().set(value) - async def set_raw(self, *nested_path: str, value): + async def set_raw(self, *nested_path: Any, value): """ Allows a developer to set data as if it was stored in a standard Python dictionary. @@ -444,13 +450,15 @@ class Group(Value): Parameters ---------- - nested_path : str + nested_path : Any Multiple arguments that mirror the arguments passed in for nested - dict access. + `dict` access. These are casted to `str` for you. value The value to store. """ path = [str(p) for p in nested_path] + if isinstance(value, dict): + value = _str_key_dict(value) await self.driver.set(*self.identifiers, *path, value=value) @@ -461,9 +469,11 @@ class Config: `get_core_conf` for Config used in the core package. .. important:: - Most config data should be accessed through its respective group method (e.g. :py:meth:`guild`) - however the process for accessing global data is a bit different. There is no :python:`global` method - because global data is accessed by normal attribute access:: + Most config data should be accessed through its respective + group method (e.g. :py:meth:`guild`) however the process for + accessing global data is a bit different. There is no + :python:`global` method because global data is accessed by + normal attribute access:: await conf.foo() @@ -548,7 +558,7 @@ class Config: A new Config object. """ - if cog_instance is None and not cog_name is None: + if cog_instance is None and cog_name is not None: cog_path_override = cog_data_path(raw_name=cog_name) else: cog_path_override = cog_data_path(cog_instance=cog_instance) @@ -635,11 +645,8 @@ class Config: def _get_defaults_dict(key: str, value) -> dict: """ Since we're allowing nested config stuff now, not storing the - _defaults as a flat dict sounds like a good idea. May turn - out to be an awful one but we'll see. - :param key: - :param value: - :return: + _defaults as a flat dict sounds like a good idea. May turn out + to be an awful one but we'll see. """ ret = {} partial = ret @@ -655,15 +662,12 @@ class Config: return ret @staticmethod - def _update_defaults(to_add: dict, _partial: dict): + def _update_defaults(to_add: Dict[str, Any], _partial: Dict[str, Any]): """ This tries to update the _defaults dictionary with the nested - partial dict generated by _get_defaults_dict. This WILL - throw an error if you try to have both a value and a group - registered under the same name. - :param to_add: - :param _partial: - :return: + partial dict generated by _get_defaults_dict. This WILL + throw an error if you try to have both a value and a group + registered under the same name. """ for k, v in to_add.items(): val_is_dict = isinstance(v, dict) @@ -679,7 +683,7 @@ class Config: else: _partial[k] = v - def _register_default(self, key: str, **kwargs): + def _register_default(self, key: str, **kwargs: Any): if key not in self._defaults: self._defaults[key] = {} @@ -720,8 +724,8 @@ class Config: **_defaults ) - You can do the same thing without a :python:`_defaults` dict by using double underscore as a variable - name separator:: + You can do the same thing without a :python:`_defaults` dict by + using double underscore as a variable name separator:: # This is equivalent to the previous example conf.register_global( @@ -802,7 +806,7 @@ class Config: The guild's Group object. """ - return self._get_base_group(self.GUILD, guild.id) + return self._get_base_group(self.GUILD, str(guild.id)) def channel(self, channel: discord.TextChannel) -> Group: """Returns a `Group` for the given channel. @@ -820,7 +824,7 @@ class Config: The channel's Group object. """ - return self._get_base_group(self.CHANNEL, channel.id) + return self._get_base_group(self.CHANNEL, str(channel.id)) def role(self, role: discord.Role) -> Group: """Returns a `Group` for the given role. @@ -836,7 +840,7 @@ class Config: The role's Group object. """ - return self._get_base_group(self.ROLE, role.id) + return self._get_base_group(self.ROLE, str(role.id)) def user(self, user: discord.abc.User) -> Group: """Returns a `Group` for the given user. @@ -852,7 +856,7 @@ class Config: The user's Group object. """ - return self._get_base_group(self.USER, user.id) + return self._get_base_group(self.USER, str(user.id)) def member(self, member: discord.Member) -> Group: """Returns a `Group` for the given member. @@ -866,8 +870,9 @@ class Config: ------- `Group ` The member's Group object. + """ - return self._get_base_group(self.MEMBER, member.guild.id, member.id) + return self._get_base_group(self.MEMBER, str(member.guild.id), str(member.id)) def custom(self, group_identifier: str, *identifiers: str): """Returns a `Group` for the given custom group. @@ -876,17 +881,17 @@ class Config: ---------- group_identifier : str Used to identify the custom group. - identifiers : str The attributes necessary to uniquely identify an entry in the - custom group. + custom group. These are casted to `str` for you. Returns ------- `Group ` The custom group's Group object. + """ - return self._get_base_group(group_identifier, *identifiers) + return self._get_base_group(str(group_identifier), *map(str, identifiers)) async def _all_from_scope(self, scope: str) -> Dict[int, Dict[Any, Any]]: """Get a dict of all values from a particular scope of data. @@ -982,7 +987,8 @@ class Config: """ return await self._all_from_scope(self.USER) - def _all_members_from_guild(self, group: Group, guild_data: dict) -> dict: + @staticmethod + def _all_members_from_guild(group: Group, guild_data: dict) -> dict: ret = {} for member_id, member_data in guild_data.items(): new_member_data = group.defaults @@ -1026,7 +1032,7 @@ class Config: for guild_id, guild_data in dict_.items(): ret[int(guild_id)] = self._all_members_from_guild(group, guild_data) else: - group = self._get_base_group(self.MEMBER, guild.id) + group = self._get_base_group(self.MEMBER, str(guild.id)) try: guild_data = await self.driver.get(*group.identifiers) except KeyError: @@ -1054,7 +1060,8 @@ class Config: """ if not scopes: - group = Group(identifiers=[], defaults={}, driver=self.driver) + # noinspection PyTypeChecker + group = Group(identifiers=(), defaults={}, driver=self.driver) else: group = self._get_base_group(*scopes) await group.clear() @@ -1119,7 +1126,7 @@ class Config: """ if guild is not None: - await self._clear_scope(self.MEMBER, guild.id) + await self._clear_scope(self.MEMBER, str(guild.id)) return await self._clear_scope(self.MEMBER) @@ -1127,5 +1134,34 @@ class Config: """Clear all custom group data. This resets all custom group data to its registered defaults. + + Parameters + ---------- + group_identifier : str + The identifier for the custom group. This is casted to + `str` for you. """ - await self._clear_scope(group_identifier) + await self._clear_scope(str(group_identifier)) + + +def _str_key_dict(value: Dict[Any, _T]) -> Dict[str, _T]: + """ + Recursively casts all keys in the given `dict` to `str`. + + Parameters + ---------- + value : Dict[Any, Any] + The `dict` to cast keys to `str`. + + Returns + ------- + Dict[str, Any] + The `dict` with keys (and nested keys) casted to `str`. + + """ + ret = {} + for k, v in value.items(): + if isinstance(v, dict): + v = _str_key_dict(v) + ret[str(k)] = v + return ret diff --git a/redbot/core/core_commands.py b/redbot/core/core_commands.py index 3366b69b3..125aaaa67 100644 --- a/redbot/core/core_commands.py +++ b/redbot/core/core_commands.py @@ -1118,21 +1118,20 @@ class Core(commands.Cog, CoreLogic): if basic_config["STORAGE_TYPE"] == "MongoDB": from redbot.core.drivers.red_mongo import Mongo - m = Mongo("Core", **basic_config["STORAGE_DETAILS"]) + m = Mongo("Core", "0", **basic_config["STORAGE_DETAILS"]) db = m.db - collection_names = await db.collection_names(include_system_collections=False) + collection_names = await db.list_collection_names() for c_name in collection_names: if c_name == "Core": c_data_path = data_dir / basic_config["CORE_PATH_APPEND"] else: - c_data_path = data_dir / basic_config["COG_PATH_APPEND"] - output = {} + c_data_path = data_dir / basic_config["COG_PATH_APPEND"] / c_name docs = await db[c_name].find().to_list(None) for item in docs: item_id = str(item.pop("_id")) - output[item_id] = item - target = JSON(c_name, data_path_override=c_data_path) - await target.jsonIO._threadsafe_save_json(output) + output = item + target = JSON(c_name, item_id, data_path_override=c_data_path) + await target.jsonIO._threadsafe_save_json(output) backup_filename = "redv3-{}-{}.tar.gz".format( instance_name, ctx.message.created_at.strftime("%Y-%m-%d %H-%M-%S") ) diff --git a/redbot/core/drivers/red_mongo.py b/redbot/core/drivers/red_mongo.py index 0f9bc585c..6f8415bbd 100644 --- a/redbot/core/drivers/red_mongo.py +++ b/redbot/core/drivers/red_mongo.py @@ -1,7 +1,12 @@ -import motor.motor_asyncio -from .red_base import BaseDriver +import re +from typing import Match, Pattern from urllib.parse import quote_plus +import motor.core +import motor.motor_asyncio + +from .red_base import BaseDriver + __all__ = ["Mongo"] @@ -80,6 +85,7 @@ class Mongo(BaseDriver): async def get(self, *identifiers: str): mongo_collection = self.get_collection() + identifiers = (*map(self._escape_key, identifiers),) dot_identifiers = ".".join(identifiers) partial = await mongo_collection.find_one( @@ -91,10 +97,14 @@ class Mongo(BaseDriver): for i in identifiers: partial = partial[i] + if isinstance(partial, dict): + return self._unescape_dict_keys(partial) return partial async def set(self, *identifiers: str, value=None): - dot_identifiers = ".".join(identifiers) + dot_identifiers = ".".join(map(self._escape_key, identifiers)) + if isinstance(value, dict): + value = self._escape_dict_keys(value) mongo_collection = self.get_collection() @@ -105,7 +115,7 @@ class Mongo(BaseDriver): ) async def clear(self, *identifiers: str): - dot_identifiers = ".".join(identifiers) + dot_identifiers = ".".join(map(self._escape_key, identifiers)) mongo_collection = self.get_collection() if len(identifiers) > 0: @@ -115,6 +125,62 @@ class Mongo(BaseDriver): else: await mongo_collection.delete_one({"_id": self.unique_cog_identifier}) + @staticmethod + def _escape_key(key: str) -> str: + return _SPECIAL_CHAR_PATTERN.sub(_replace_with_escaped, key) + + @staticmethod + def _unescape_key(key: str) -> str: + return _CHAR_ESCAPE_PATTERN.sub(_replace_with_unescaped, key) + + @classmethod + def _escape_dict_keys(cls, data: dict) -> dict: + """Recursively escape all keys in a dict.""" + ret = {} + for key, value in data.items(): + key = cls._escape_key(key) + if isinstance(value, dict): + value = cls._escape_dict_keys(value) + ret[key] = value + return ret + + @classmethod + def _unescape_dict_keys(cls, data: dict) -> dict: + """Recursively unescape all keys in a dict.""" + ret = {} + for key, value in data.items(): + key = cls._unescape_key(key) + if isinstance(value, dict): + value = cls._unescape_dict_keys(value) + ret[key] = value + return ret + + +_SPECIAL_CHAR_PATTERN: Pattern[str] = re.compile(r"([.$]|\\U0000002E|\\U00000024)") +_SPECIAL_CHARS = { + ".": "\\U0000002E", + "$": "\\U00000024", + "\\U0000002E": "\\U&0000002E", + "\\U00000024": "\\U&00000024", +} + + +def _replace_with_escaped(match: Match[str]) -> str: + return _SPECIAL_CHARS[match[0]] + + +_CHAR_ESCAPE_PATTERN: Pattern[str] = re.compile(r"(\\U0000002E|\\U00000024)") +_CHAR_ESCAPES = { + "\\U0000002E": ".", + "\\U00000024": "$", + "\\U&0000002E": "\\U0000002E", + "\\U&00000024": "\\U00000024", +} + + +def _replace_with_unescaped(match: Match[str]) -> str: + return _CHAR_ESCAPES[match[0]] + def get_config_details(): uri = None diff --git a/tests/cogs/test_permissions.py b/tests/cogs/test_permissions.py index 679372732..84e41a49d 100644 --- a/tests/cogs/test_permissions.py +++ b/tests/cogs/test_permissions.py @@ -3,7 +3,7 @@ from redbot.cogs.permissions.permissions import Permissions, GLOBAL def test_schema_update(): old = { - GLOBAL: { + str(GLOBAL): { "owner_models": { "cogs": { "Admin": {"allow": [78631113035100160], "deny": [96733288462286848]}, @@ -19,7 +19,7 @@ def test_schema_update(): }, } }, - 43733288462286848: { + "43733288462286848": { "owner_models": { "cogs": { "Admin": { @@ -43,22 +43,22 @@ def test_schema_update(): assert new == ( { "Admin": { - GLOBAL: {78631113035100160: True, 96733288462286848: False}, - 43733288462286848: {24231113035100160: True, 35533288462286848: False}, + str(GLOBAL): {"78631113035100160": True, "96733288462286848": False}, + "43733288462286848": {"24231113035100160": True, "35533288462286848": False}, }, - "Audio": {GLOBAL: {133049272517001216: True, "default": False}}, - "General": {43733288462286848: {133049272517001216: True, "default": False}}, + "Audio": {str(GLOBAL): {"133049272517001216": True, "default": False}}, + "General": {"43733288462286848": {"133049272517001216": True, "default": False}}, }, { "cleanup bot": { - GLOBAL: {78631113035100160: True, "default": False}, - 43733288462286848: {17831113035100160: True, "default": True}, + str(GLOBAL): {"78631113035100160": True, "default": False}, + "43733288462286848": {"17831113035100160": True, "default": True}, }, - "ping": {GLOBAL: {96733288462286848: True, "default": True}}, + "ping": {str(GLOBAL): {"96733288462286848": True, "default": True}}, "set adminrole": { - 43733288462286848: { - 87733288462286848: True, - 95433288462286848: False, + "43733288462286848": { + "87733288462286848": True, + "95433288462286848": False, "default": True, } }, diff --git a/tests/core/test_config.py b/tests/core/test_config.py index abfafd8f2..b1c224fec 100644 --- a/tests/core/test_config.py +++ b/tests/core/test_config.py @@ -475,3 +475,18 @@ async def test_get_raw_mixes_defaults(config): subgroup = await config.get_raw("subgroup") assert subgroup == {"foo": True, "bar": False} + + +@pytest.mark.asyncio +async def test_cast_str_raw(config): + await config.set_raw(123, 456, value=True) + assert await config.get_raw(123, 456) is True + assert await config.get_raw("123", "456") is True + await config.clear_raw("123", 456) + + +@pytest.mark.asyncio +async def test_cast_str_nested(config): + config.register_global(foo={}) + await config.foo.set({123: True, 456: {789: False}}) + assert await config.foo() == {"123": True, "456": {"789": False}}