From ce25011f0d3ea46e32f31df1dc270df9733b12b7 Mon Sep 17 00:00:00 2001 From: Toby Harradine Date: Thu, 11 Oct 2018 11:18:57 +1100 Subject: [PATCH] [Config] Cast keys to str on get/set/clear (#2217) This is a step towards a more consistent front-end behaviour of Config, where errors are either circumvented or raised in the same way regardless of the driver being used. Signed-off-by: Toby Harradine --- docs/framework_config.rst | 15 ++++ redbot/core/config.py | 149 +++++++++++++++++++++++--------------- tests/core/test_config.py | 15 ++++ 3 files changed, 121 insertions(+), 58 deletions(-) diff --git a/docs/framework_config.rst b/docs/framework_config.rst index 3837d9915..96b8d5768 100644 --- a/docs/framework_config.rst +++ b/docs/framework_config.rst @@ -374,6 +374,21 @@ API Reference inside the bot itself! Simply take a peek inside of the :code:`tests/core/test_config.py` file for examples of using Config in all kinds of ways. +.. important:: + + When getting, setting or clearing values in Config, all keys are casted to `str` for you. This + includes keys within a `dict` when one is being set, as well as keys in nested dictionaries + within that `dict`. For example:: + + >>> conf = Config.get_conf(self, identifier=999) + >>> conf.register_global(foo={}) + >>> await conf.foo.set_raw(123, value=True) + >>> await conf.foo() + {'123': True} + >>> await conf.foo.set({123: True, 456: {789: False}} + >>> await conf.foo() + {'123': True, '456': {'789': False}} + .. automodule:: redbot.core.config Config diff --git a/redbot/core/config.py b/redbot/core/config.py index b9b7a0a76..ec6e27215 100644 --- a/redbot/core/config.py +++ b/redbot/core/config.py @@ -39,17 +39,21 @@ class _ValueCtxManager(Awaitable[_T], AsyncContextManager[_T]): async def __aenter__(self): self.raw_value = await self - self.__original_value = deepcopy(self.raw_value) if not isinstance(self.raw_value, (list, dict)): raise TypeError( "Type of retrieved value must be mutable (i.e. " "list or dict) in order to use a config value as " "a context manager." ) + self.__original_value = deepcopy(self.raw_value) return self.raw_value async def __aexit__(self, exc_type, exc, tb): - if self.raw_value != self.__original_value: + if isinstance(self.raw_value, dict): + raw_value = _str_key_dict(self.raw_value) + else: + raw_value = self.raw_value + if raw_value != self.__original_value: await self.value_obj.set(self.raw_value) @@ -58,7 +62,7 @@ class Value: Attributes ---------- - identifiers : `tuple` of `str` + identifiers : Tuple[str] This attribute provides all the keys necessary to get a specific data element from a json document. default @@ -69,15 +73,10 @@ class Value: """ def __init__(self, identifiers: Tuple[str], default_value, driver): - self._identifiers = identifiers + self.identifiers = identifiers self.default = default_value - self.driver = driver - @property - def identifiers(self): - return tuple(str(i) for i in self._identifiers) - async def _get(self, default=...): try: ret = await self.driver.get(*self.identifiers) @@ -149,6 +148,8 @@ class Value: The new literal value of this attribute. """ + if isinstance(value, dict): + value = _str_key_dict(value) await self.driver.set(*self.identifiers, value=value) async def clear(self): @@ -238,7 +239,7 @@ class Group(Value): else: return Value(identifiers=new_identifiers, default_value=None, driver=self.driver) - async def clear_raw(self, *nested_path: str): + async def clear_raw(self, *nested_path: Any): """ Allows a developer to clear data as if it was stored in a standard Python dictionary. @@ -254,44 +255,44 @@ class Group(Value): Parameters ---------- - nested_path : str + nested_path : Any Multiple arguments that mirror the arguments passed in for nested - dict access. + dict access. These are casted to `str` for you. """ path = [str(p) for p in nested_path] await self.driver.clear(*self.identifiers, *path) - def is_group(self, item: str) -> bool: + def is_group(self, item: Any) -> bool: """A helper method for `__getattr__`. Most developers will have no need to use this. Parameters ---------- - item : str + item : Any See `__getattr__`. """ - default = self._defaults.get(item) + default = self._defaults.get(str(item)) return isinstance(default, dict) - def is_value(self, item: str) -> bool: + def is_value(self, item: Any) -> bool: """A helper method for `__getattr__`. Most developers will have no need to use this. Parameters ---------- - item : str + item : Any See `__getattr__`. """ try: - default = self._defaults[item] + default = self._defaults[str(item)] except KeyError: return False return not isinstance(default, dict) - def get_attr(self, item: str): + def get_attr(self, item: Union[int, str]): """Manually get an attribute of this Group. This is available to use as an alternative to using normal Python @@ -312,7 +313,8 @@ class Group(Value): Parameters ---------- item : str - The name of the data field in `Config`. + The name of the data field in `Config`. This is casted to + `str` for you. Returns ------- @@ -320,9 +322,11 @@ class Group(Value): The attribute which was requested. """ + if isinstance(item, int): + item = str(item) return self.__getattr__(item) - async def get_raw(self, *nested_path: str, default=...): + async def get_raw(self, *nested_path: Any, default=...): """ Allows a developer to access data as if it was stored in a standard Python dictionary. @@ -345,7 +349,7 @@ class Group(Value): ---------- nested_path : str Multiple arguments that mirror the arguments passed in for nested - dict access. + dict access. These are casted to `str` for you. default Default argument for the value attempting to be accessed. If the value does not exist the default will be returned. @@ -410,7 +414,6 @@ class Group(Value): If no defaults are passed, then the instance attribute 'defaults' will be used. - """ if defaults is ...: defaults = self.defaults @@ -428,7 +431,7 @@ class Group(Value): raise ValueError("You may only set the value of a group to be a dict.") await super().set(value) - async def set_raw(self, *nested_path: str, value): + async def set_raw(self, *nested_path: Any, value): """ Allows a developer to set data as if it was stored in a standard Python dictionary. @@ -444,13 +447,15 @@ class Group(Value): Parameters ---------- - nested_path : str + nested_path : Any Multiple arguments that mirror the arguments passed in for nested - dict access. + `dict` access. These are casted to `str` for you. value The value to store. """ path = [str(p) for p in nested_path] + if isinstance(value, dict): + value = _str_key_dict(value) await self.driver.set(*self.identifiers, *path, value=value) @@ -461,9 +466,11 @@ class Config: `get_core_conf` for Config used in the core package. .. important:: - Most config data should be accessed through its respective group method (e.g. :py:meth:`guild`) - however the process for accessing global data is a bit different. There is no :python:`global` method - because global data is accessed by normal attribute access:: + Most config data should be accessed through its respective + group method (e.g. :py:meth:`guild`) however the process for + accessing global data is a bit different. There is no + :python:`global` method because global data is accessed by + normal attribute access:: await conf.foo() @@ -548,7 +555,7 @@ class Config: A new Config object. """ - if cog_instance is None and not cog_name is None: + if cog_instance is None and cog_name is not None: cog_path_override = cog_data_path(raw_name=cog_name) else: cog_path_override = cog_data_path(cog_instance=cog_instance) @@ -635,11 +642,8 @@ class Config: def _get_defaults_dict(key: str, value) -> dict: """ Since we're allowing nested config stuff now, not storing the - _defaults as a flat dict sounds like a good idea. May turn - out to be an awful one but we'll see. - :param key: - :param value: - :return: + _defaults as a flat dict sounds like a good idea. May turn out + to be an awful one but we'll see. """ ret = {} partial = ret @@ -655,15 +659,12 @@ class Config: return ret @staticmethod - def _update_defaults(to_add: dict, _partial: dict): + def _update_defaults(to_add: Dict[str, Any], _partial: Dict[str, Any]): """ This tries to update the _defaults dictionary with the nested - partial dict generated by _get_defaults_dict. This WILL - throw an error if you try to have both a value and a group - registered under the same name. - :param to_add: - :param _partial: - :return: + partial dict generated by _get_defaults_dict. This WILL + throw an error if you try to have both a value and a group + registered under the same name. """ for k, v in to_add.items(): val_is_dict = isinstance(v, dict) @@ -679,7 +680,7 @@ class Config: else: _partial[k] = v - def _register_default(self, key: str, **kwargs): + def _register_default(self, key: str, **kwargs: Any): if key not in self._defaults: self._defaults[key] = {} @@ -720,8 +721,8 @@ class Config: **_defaults ) - You can do the same thing without a :python:`_defaults` dict by using double underscore as a variable - name separator:: + You can do the same thing without a :python:`_defaults` dict by + using double underscore as a variable name separator:: # This is equivalent to the previous example conf.register_global( @@ -802,7 +803,7 @@ class Config: The guild's Group object. """ - return self._get_base_group(self.GUILD, guild.id) + return self._get_base_group(self.GUILD, str(guild.id)) def channel(self, channel: discord.TextChannel) -> Group: """Returns a `Group` for the given channel. @@ -820,7 +821,7 @@ class Config: The channel's Group object. """ - return self._get_base_group(self.CHANNEL, channel.id) + return self._get_base_group(self.CHANNEL, str(channel.id)) def role(self, role: discord.Role) -> Group: """Returns a `Group` for the given role. @@ -836,7 +837,7 @@ class Config: The role's Group object. """ - return self._get_base_group(self.ROLE, role.id) + return self._get_base_group(self.ROLE, str(role.id)) def user(self, user: discord.abc.User) -> Group: """Returns a `Group` for the given user. @@ -852,7 +853,7 @@ class Config: The user's Group object. """ - return self._get_base_group(self.USER, user.id) + return self._get_base_group(self.USER, str(user.id)) def member(self, member: discord.Member) -> Group: """Returns a `Group` for the given member. @@ -866,8 +867,9 @@ class Config: ------- `Group ` The member's Group object. + """ - return self._get_base_group(self.MEMBER, member.guild.id, member.id) + return self._get_base_group(self.MEMBER, str(member.guild.id), str(member.id)) def custom(self, group_identifier: str, *identifiers: str): """Returns a `Group` for the given custom group. @@ -876,17 +878,17 @@ class Config: ---------- group_identifier : str Used to identify the custom group. - identifiers : str The attributes necessary to uniquely identify an entry in the - custom group. + custom group. These are casted to `str` for you. Returns ------- `Group ` The custom group's Group object. + """ - return self._get_base_group(group_identifier, *identifiers) + return self._get_base_group(str(group_identifier), *map(str, identifiers)) async def _all_from_scope(self, scope: str) -> Dict[int, Dict[Any, Any]]: """Get a dict of all values from a particular scope of data. @@ -982,7 +984,8 @@ class Config: """ return await self._all_from_scope(self.USER) - def _all_members_from_guild(self, group: Group, guild_data: dict) -> dict: + @staticmethod + def _all_members_from_guild(group: Group, guild_data: dict) -> dict: ret = {} for member_id, member_data in guild_data.items(): new_member_data = group.defaults @@ -1026,7 +1029,7 @@ class Config: for guild_id, guild_data in dict_.items(): ret[int(guild_id)] = self._all_members_from_guild(group, guild_data) else: - group = self._get_base_group(self.MEMBER, guild.id) + group = self._get_base_group(self.MEMBER, str(guild.id)) try: guild_data = await self.driver.get(*group.identifiers) except KeyError: @@ -1054,7 +1057,8 @@ class Config: """ if not scopes: - group = Group(identifiers=[], defaults={}, driver=self.driver) + # noinspection PyTypeChecker + group = Group(identifiers=(), defaults={}, driver=self.driver) else: group = self._get_base_group(*scopes) await group.clear() @@ -1119,7 +1123,7 @@ class Config: """ if guild is not None: - await self._clear_scope(self.MEMBER, guild.id) + await self._clear_scope(self.MEMBER, str(guild.id)) return await self._clear_scope(self.MEMBER) @@ -1127,5 +1131,34 @@ class Config: """Clear all custom group data. This resets all custom group data to its registered defaults. + + Parameters + ---------- + group_identifier : str + The identifier for the custom group. This is casted to + `str` for you. """ - await self._clear_scope(group_identifier) + await self._clear_scope(str(group_identifier)) + + +def _str_key_dict(value: Dict[Any, _T]) -> Dict[str, _T]: + """ + Recursively casts all keys in the given `dict` to `str`. + + Parameters + ---------- + value : Dict[Any, Any] + The `dict` to cast keys to `str`. + + Returns + ------- + Dict[str, Any] + The `dict` with keys (and nested keys) casted to `str`. + + """ + ret = {} + for k, v in value.items(): + if isinstance(v, dict): + v = _str_key_dict(v) + ret[str(k)] = v + return ret diff --git a/tests/core/test_config.py b/tests/core/test_config.py index abfafd8f2..b1c224fec 100644 --- a/tests/core/test_config.py +++ b/tests/core/test_config.py @@ -475,3 +475,18 @@ async def test_get_raw_mixes_defaults(config): subgroup = await config.get_raw("subgroup") assert subgroup == {"foo": True, "bar": False} + + +@pytest.mark.asyncio +async def test_cast_str_raw(config): + await config.set_raw(123, 456, value=True) + assert await config.get_raw(123, 456) is True + assert await config.get_raw("123", "456") is True + await config.clear_raw("123", 456) + + +@pytest.mark.asyncio +async def test_cast_str_nested(config): + config.register_global(foo={}) + await config.foo.set({123: True, 456: {789: False}}) + assert await config.foo() == {"123": True, "456": {"789": False}}