[Config] Cast keys to str on get/set/clear (#2217)

This is a step towards a more consistent front-end behaviour of Config, where errors are either circumvented or raised in the same way regardless of the driver being used.

Signed-off-by: Toby Harradine <tobyharradine@gmail.com>
This commit is contained in:
Toby Harradine 2018-10-11 11:18:57 +11:00 committed by GitHub
parent f85034eb27
commit ce25011f0d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 121 additions and 58 deletions

View File

@ -374,6 +374,21 @@ API Reference
inside the bot itself! Simply take a peek inside of the :code:`tests/core/test_config.py` file for examples of using inside the bot itself! Simply take a peek inside of the :code:`tests/core/test_config.py` file for examples of using
Config in all kinds of ways. Config in all kinds of ways.
.. important::
When getting, setting or clearing values in Config, all keys are casted to `str` for you. This
includes keys within a `dict` when one is being set, as well as keys in nested dictionaries
within that `dict`. For example::
>>> conf = Config.get_conf(self, identifier=999)
>>> conf.register_global(foo={})
>>> await conf.foo.set_raw(123, value=True)
>>> await conf.foo()
{'123': True}
>>> await conf.foo.set({123: True, 456: {789: False}}
>>> await conf.foo()
{'123': True, '456': {'789': False}}
.. automodule:: redbot.core.config .. automodule:: redbot.core.config
Config Config

View File

@ -39,17 +39,21 @@ class _ValueCtxManager(Awaitable[_T], AsyncContextManager[_T]):
async def __aenter__(self): async def __aenter__(self):
self.raw_value = await self self.raw_value = await self
self.__original_value = deepcopy(self.raw_value)
if not isinstance(self.raw_value, (list, dict)): if not isinstance(self.raw_value, (list, dict)):
raise TypeError( raise TypeError(
"Type of retrieved value must be mutable (i.e. " "Type of retrieved value must be mutable (i.e. "
"list or dict) in order to use a config value as " "list or dict) in order to use a config value as "
"a context manager." "a context manager."
) )
self.__original_value = deepcopy(self.raw_value)
return self.raw_value return self.raw_value
async def __aexit__(self, exc_type, exc, tb): async def __aexit__(self, exc_type, exc, tb):
if self.raw_value != self.__original_value: if isinstance(self.raw_value, dict):
raw_value = _str_key_dict(self.raw_value)
else:
raw_value = self.raw_value
if raw_value != self.__original_value:
await self.value_obj.set(self.raw_value) await self.value_obj.set(self.raw_value)
@ -58,7 +62,7 @@ class Value:
Attributes Attributes
---------- ----------
identifiers : `tuple` of `str` identifiers : Tuple[str]
This attribute provides all the keys necessary to get a specific data This attribute provides all the keys necessary to get a specific data
element from a json document. element from a json document.
default default
@ -69,15 +73,10 @@ class Value:
""" """
def __init__(self, identifiers: Tuple[str], default_value, driver): def __init__(self, identifiers: Tuple[str], default_value, driver):
self._identifiers = identifiers self.identifiers = identifiers
self.default = default_value self.default = default_value
self.driver = driver self.driver = driver
@property
def identifiers(self):
return tuple(str(i) for i in self._identifiers)
async def _get(self, default=...): async def _get(self, default=...):
try: try:
ret = await self.driver.get(*self.identifiers) ret = await self.driver.get(*self.identifiers)
@ -149,6 +148,8 @@ class Value:
The new literal value of this attribute. The new literal value of this attribute.
""" """
if isinstance(value, dict):
value = _str_key_dict(value)
await self.driver.set(*self.identifiers, value=value) await self.driver.set(*self.identifiers, value=value)
async def clear(self): async def clear(self):
@ -238,7 +239,7 @@ class Group(Value):
else: else:
return Value(identifiers=new_identifiers, default_value=None, driver=self.driver) return Value(identifiers=new_identifiers, default_value=None, driver=self.driver)
async def clear_raw(self, *nested_path: str): async def clear_raw(self, *nested_path: Any):
""" """
Allows a developer to clear data as if it was stored in a standard Allows a developer to clear data as if it was stored in a standard
Python dictionary. Python dictionary.
@ -254,44 +255,44 @@ class Group(Value):
Parameters Parameters
---------- ----------
nested_path : str nested_path : Any
Multiple arguments that mirror the arguments passed in for nested Multiple arguments that mirror the arguments passed in for nested
dict access. dict access. These are casted to `str` for you.
""" """
path = [str(p) for p in nested_path] path = [str(p) for p in nested_path]
await self.driver.clear(*self.identifiers, *path) await self.driver.clear(*self.identifiers, *path)
def is_group(self, item: str) -> bool: def is_group(self, item: Any) -> bool:
"""A helper method for `__getattr__`. Most developers will have no need """A helper method for `__getattr__`. Most developers will have no need
to use this. to use this.
Parameters Parameters
---------- ----------
item : str item : Any
See `__getattr__`. See `__getattr__`.
""" """
default = self._defaults.get(item) default = self._defaults.get(str(item))
return isinstance(default, dict) return isinstance(default, dict)
def is_value(self, item: str) -> bool: def is_value(self, item: Any) -> bool:
"""A helper method for `__getattr__`. Most developers will have no need """A helper method for `__getattr__`. Most developers will have no need
to use this. to use this.
Parameters Parameters
---------- ----------
item : str item : Any
See `__getattr__`. See `__getattr__`.
""" """
try: try:
default = self._defaults[item] default = self._defaults[str(item)]
except KeyError: except KeyError:
return False return False
return not isinstance(default, dict) return not isinstance(default, dict)
def get_attr(self, item: str): def get_attr(self, item: Union[int, str]):
"""Manually get an attribute of this Group. """Manually get an attribute of this Group.
This is available to use as an alternative to using normal Python This is available to use as an alternative to using normal Python
@ -312,7 +313,8 @@ class Group(Value):
Parameters Parameters
---------- ----------
item : str item : str
The name of the data field in `Config`. The name of the data field in `Config`. This is casted to
`str` for you.
Returns Returns
------- -------
@ -320,9 +322,11 @@ class Group(Value):
The attribute which was requested. The attribute which was requested.
""" """
if isinstance(item, int):
item = str(item)
return self.__getattr__(item) return self.__getattr__(item)
async def get_raw(self, *nested_path: str, default=...): async def get_raw(self, *nested_path: Any, default=...):
""" """
Allows a developer to access data as if it was stored in a standard Allows a developer to access data as if it was stored in a standard
Python dictionary. Python dictionary.
@ -345,7 +349,7 @@ class Group(Value):
---------- ----------
nested_path : str nested_path : str
Multiple arguments that mirror the arguments passed in for nested Multiple arguments that mirror the arguments passed in for nested
dict access. dict access. These are casted to `str` for you.
default default
Default argument for the value attempting to be accessed. If the Default argument for the value attempting to be accessed. If the
value does not exist the default will be returned. value does not exist the default will be returned.
@ -410,7 +414,6 @@ class Group(Value):
If no defaults are passed, then the instance attribute 'defaults' If no defaults are passed, then the instance attribute 'defaults'
will be used. will be used.
""" """
if defaults is ...: if defaults is ...:
defaults = self.defaults defaults = self.defaults
@ -428,7 +431,7 @@ class Group(Value):
raise ValueError("You may only set the value of a group to be a dict.") raise ValueError("You may only set the value of a group to be a dict.")
await super().set(value) await super().set(value)
async def set_raw(self, *nested_path: str, value): async def set_raw(self, *nested_path: Any, value):
""" """
Allows a developer to set data as if it was stored in a standard Allows a developer to set data as if it was stored in a standard
Python dictionary. Python dictionary.
@ -444,13 +447,15 @@ class Group(Value):
Parameters Parameters
---------- ----------
nested_path : str nested_path : Any
Multiple arguments that mirror the arguments passed in for nested Multiple arguments that mirror the arguments passed in for nested
dict access. `dict` access. These are casted to `str` for you.
value value
The value to store. The value to store.
""" """
path = [str(p) for p in nested_path] path = [str(p) for p in nested_path]
if isinstance(value, dict):
value = _str_key_dict(value)
await self.driver.set(*self.identifiers, *path, value=value) await self.driver.set(*self.identifiers, *path, value=value)
@ -461,9 +466,11 @@ class Config:
`get_core_conf` for Config used in the core package. `get_core_conf` for Config used in the core package.
.. important:: .. important::
Most config data should be accessed through its respective group method (e.g. :py:meth:`guild`) Most config data should be accessed through its respective
however the process for accessing global data is a bit different. There is no :python:`global` method group method (e.g. :py:meth:`guild`) however the process for
because global data is accessed by normal attribute access:: accessing global data is a bit different. There is no
:python:`global` method because global data is accessed by
normal attribute access::
await conf.foo() await conf.foo()
@ -548,7 +555,7 @@ class Config:
A new Config object. A new Config object.
""" """
if cog_instance is None and not cog_name is None: if cog_instance is None and cog_name is not None:
cog_path_override = cog_data_path(raw_name=cog_name) cog_path_override = cog_data_path(raw_name=cog_name)
else: else:
cog_path_override = cog_data_path(cog_instance=cog_instance) cog_path_override = cog_data_path(cog_instance=cog_instance)
@ -635,11 +642,8 @@ class Config:
def _get_defaults_dict(key: str, value) -> dict: def _get_defaults_dict(key: str, value) -> dict:
""" """
Since we're allowing nested config stuff now, not storing the Since we're allowing nested config stuff now, not storing the
_defaults as a flat dict sounds like a good idea. May turn _defaults as a flat dict sounds like a good idea. May turn out
out to be an awful one but we'll see. to be an awful one but we'll see.
:param key:
:param value:
:return:
""" """
ret = {} ret = {}
partial = ret partial = ret
@ -655,15 +659,12 @@ class Config:
return ret return ret
@staticmethod @staticmethod
def _update_defaults(to_add: dict, _partial: dict): def _update_defaults(to_add: Dict[str, Any], _partial: Dict[str, Any]):
""" """
This tries to update the _defaults dictionary with the nested This tries to update the _defaults dictionary with the nested
partial dict generated by _get_defaults_dict. This WILL partial dict generated by _get_defaults_dict. This WILL
throw an error if you try to have both a value and a group throw an error if you try to have both a value and a group
registered under the same name. registered under the same name.
:param to_add:
:param _partial:
:return:
""" """
for k, v in to_add.items(): for k, v in to_add.items():
val_is_dict = isinstance(v, dict) val_is_dict = isinstance(v, dict)
@ -679,7 +680,7 @@ class Config:
else: else:
_partial[k] = v _partial[k] = v
def _register_default(self, key: str, **kwargs): def _register_default(self, key: str, **kwargs: Any):
if key not in self._defaults: if key not in self._defaults:
self._defaults[key] = {} self._defaults[key] = {}
@ -720,8 +721,8 @@ class Config:
**_defaults **_defaults
) )
You can do the same thing without a :python:`_defaults` dict by using double underscore as a variable You can do the same thing without a :python:`_defaults` dict by
name separator:: using double underscore as a variable name separator::
# This is equivalent to the previous example # This is equivalent to the previous example
conf.register_global( conf.register_global(
@ -802,7 +803,7 @@ class Config:
The guild's Group object. The guild's Group object.
""" """
return self._get_base_group(self.GUILD, guild.id) return self._get_base_group(self.GUILD, str(guild.id))
def channel(self, channel: discord.TextChannel) -> Group: def channel(self, channel: discord.TextChannel) -> Group:
"""Returns a `Group` for the given channel. """Returns a `Group` for the given channel.
@ -820,7 +821,7 @@ class Config:
The channel's Group object. The channel's Group object.
""" """
return self._get_base_group(self.CHANNEL, channel.id) return self._get_base_group(self.CHANNEL, str(channel.id))
def role(self, role: discord.Role) -> Group: def role(self, role: discord.Role) -> Group:
"""Returns a `Group` for the given role. """Returns a `Group` for the given role.
@ -836,7 +837,7 @@ class Config:
The role's Group object. The role's Group object.
""" """
return self._get_base_group(self.ROLE, role.id) return self._get_base_group(self.ROLE, str(role.id))
def user(self, user: discord.abc.User) -> Group: def user(self, user: discord.abc.User) -> Group:
"""Returns a `Group` for the given user. """Returns a `Group` for the given user.
@ -852,7 +853,7 @@ class Config:
The user's Group object. The user's Group object.
""" """
return self._get_base_group(self.USER, user.id) return self._get_base_group(self.USER, str(user.id))
def member(self, member: discord.Member) -> Group: def member(self, member: discord.Member) -> Group:
"""Returns a `Group` for the given member. """Returns a `Group` for the given member.
@ -866,8 +867,9 @@ class Config:
------- -------
`Group <redbot.core.config.Group>` `Group <redbot.core.config.Group>`
The member's Group object. The member's Group object.
""" """
return self._get_base_group(self.MEMBER, member.guild.id, member.id) return self._get_base_group(self.MEMBER, str(member.guild.id), str(member.id))
def custom(self, group_identifier: str, *identifiers: str): def custom(self, group_identifier: str, *identifiers: str):
"""Returns a `Group` for the given custom group. """Returns a `Group` for the given custom group.
@ -876,17 +878,17 @@ class Config:
---------- ----------
group_identifier : str group_identifier : str
Used to identify the custom group. Used to identify the custom group.
identifiers : str identifiers : str
The attributes necessary to uniquely identify an entry in the The attributes necessary to uniquely identify an entry in the
custom group. custom group. These are casted to `str` for you.
Returns Returns
------- -------
`Group <redbot.core.config.Group>` `Group <redbot.core.config.Group>`
The custom group's Group object. The custom group's Group object.
""" """
return self._get_base_group(group_identifier, *identifiers) return self._get_base_group(str(group_identifier), *map(str, identifiers))
async def _all_from_scope(self, scope: str) -> Dict[int, Dict[Any, Any]]: async def _all_from_scope(self, scope: str) -> Dict[int, Dict[Any, Any]]:
"""Get a dict of all values from a particular scope of data. """Get a dict of all values from a particular scope of data.
@ -982,7 +984,8 @@ class Config:
""" """
return await self._all_from_scope(self.USER) return await self._all_from_scope(self.USER)
def _all_members_from_guild(self, group: Group, guild_data: dict) -> dict: @staticmethod
def _all_members_from_guild(group: Group, guild_data: dict) -> dict:
ret = {} ret = {}
for member_id, member_data in guild_data.items(): for member_id, member_data in guild_data.items():
new_member_data = group.defaults new_member_data = group.defaults
@ -1026,7 +1029,7 @@ class Config:
for guild_id, guild_data in dict_.items(): for guild_id, guild_data in dict_.items():
ret[int(guild_id)] = self._all_members_from_guild(group, guild_data) ret[int(guild_id)] = self._all_members_from_guild(group, guild_data)
else: else:
group = self._get_base_group(self.MEMBER, guild.id) group = self._get_base_group(self.MEMBER, str(guild.id))
try: try:
guild_data = await self.driver.get(*group.identifiers) guild_data = await self.driver.get(*group.identifiers)
except KeyError: except KeyError:
@ -1054,7 +1057,8 @@ class Config:
""" """
if not scopes: if not scopes:
group = Group(identifiers=[], defaults={}, driver=self.driver) # noinspection PyTypeChecker
group = Group(identifiers=(), defaults={}, driver=self.driver)
else: else:
group = self._get_base_group(*scopes) group = self._get_base_group(*scopes)
await group.clear() await group.clear()
@ -1119,7 +1123,7 @@ class Config:
""" """
if guild is not None: if guild is not None:
await self._clear_scope(self.MEMBER, guild.id) await self._clear_scope(self.MEMBER, str(guild.id))
return return
await self._clear_scope(self.MEMBER) await self._clear_scope(self.MEMBER)
@ -1127,5 +1131,34 @@ class Config:
"""Clear all custom group data. """Clear all custom group data.
This resets all custom group data to its registered defaults. This resets all custom group data to its registered defaults.
Parameters
----------
group_identifier : str
The identifier for the custom group. This is casted to
`str` for you.
""" """
await self._clear_scope(group_identifier) await self._clear_scope(str(group_identifier))
def _str_key_dict(value: Dict[Any, _T]) -> Dict[str, _T]:
"""
Recursively casts all keys in the given `dict` to `str`.
Parameters
----------
value : Dict[Any, Any]
The `dict` to cast keys to `str`.
Returns
-------
Dict[str, Any]
The `dict` with keys (and nested keys) casted to `str`.
"""
ret = {}
for k, v in value.items():
if isinstance(v, dict):
v = _str_key_dict(v)
ret[str(k)] = v
return ret

View File

@ -475,3 +475,18 @@ async def test_get_raw_mixes_defaults(config):
subgroup = await config.get_raw("subgroup") subgroup = await config.get_raw("subgroup")
assert subgroup == {"foo": True, "bar": False} assert subgroup == {"foo": True, "bar": False}
@pytest.mark.asyncio
async def test_cast_str_raw(config):
await config.set_raw(123, 456, value=True)
assert await config.get_raw(123, 456) is True
assert await config.get_raw("123", "456") is True
await config.clear_raw("123", 456)
@pytest.mark.asyncio
async def test_cast_str_nested(config):
config.register_global(foo={})
await config.foo.set({123: True, 456: {789: False}})
assert await config.foo() == {"123": True, "456": {"789": False}}