mirror of
https://github.com/Cog-Creators/Red-DiscordBot.git
synced 2025-11-06 11:18:54 -05:00
[Config] Cast keys to str on get/set/clear (#2217)
This is a step towards a more consistent front-end behaviour of Config, where errors are either circumvented or raised in the same way regardless of the driver being used. Signed-off-by: Toby Harradine <tobyharradine@gmail.com>
This commit is contained in:
parent
f85034eb27
commit
ce25011f0d
@ -374,6 +374,21 @@ API Reference
|
||||
inside the bot itself! Simply take a peek inside of the :code:`tests/core/test_config.py` file for examples of using
|
||||
Config in all kinds of ways.
|
||||
|
||||
.. important::
|
||||
|
||||
When getting, setting or clearing values in Config, all keys are casted to `str` for you. This
|
||||
includes keys within a `dict` when one is being set, as well as keys in nested dictionaries
|
||||
within that `dict`. For example::
|
||||
|
||||
>>> conf = Config.get_conf(self, identifier=999)
|
||||
>>> conf.register_global(foo={})
|
||||
>>> await conf.foo.set_raw(123, value=True)
|
||||
>>> await conf.foo()
|
||||
{'123': True}
|
||||
>>> await conf.foo.set({123: True, 456: {789: False}}
|
||||
>>> await conf.foo()
|
||||
{'123': True, '456': {'789': False}}
|
||||
|
||||
.. automodule:: redbot.core.config
|
||||
|
||||
Config
|
||||
|
||||
@ -39,17 +39,21 @@ class _ValueCtxManager(Awaitable[_T], AsyncContextManager[_T]):
|
||||
|
||||
async def __aenter__(self):
|
||||
self.raw_value = await self
|
||||
self.__original_value = deepcopy(self.raw_value)
|
||||
if not isinstance(self.raw_value, (list, dict)):
|
||||
raise TypeError(
|
||||
"Type of retrieved value must be mutable (i.e. "
|
||||
"list or dict) in order to use a config value as "
|
||||
"a context manager."
|
||||
)
|
||||
self.__original_value = deepcopy(self.raw_value)
|
||||
return self.raw_value
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
if self.raw_value != self.__original_value:
|
||||
if isinstance(self.raw_value, dict):
|
||||
raw_value = _str_key_dict(self.raw_value)
|
||||
else:
|
||||
raw_value = self.raw_value
|
||||
if raw_value != self.__original_value:
|
||||
await self.value_obj.set(self.raw_value)
|
||||
|
||||
|
||||
@ -58,7 +62,7 @@ class Value:
|
||||
|
||||
Attributes
|
||||
----------
|
||||
identifiers : `tuple` of `str`
|
||||
identifiers : Tuple[str]
|
||||
This attribute provides all the keys necessary to get a specific data
|
||||
element from a json document.
|
||||
default
|
||||
@ -69,15 +73,10 @@ class Value:
|
||||
"""
|
||||
|
||||
def __init__(self, identifiers: Tuple[str], default_value, driver):
|
||||
self._identifiers = identifiers
|
||||
self.identifiers = identifiers
|
||||
self.default = default_value
|
||||
|
||||
self.driver = driver
|
||||
|
||||
@property
|
||||
def identifiers(self):
|
||||
return tuple(str(i) for i in self._identifiers)
|
||||
|
||||
async def _get(self, default=...):
|
||||
try:
|
||||
ret = await self.driver.get(*self.identifiers)
|
||||
@ -149,6 +148,8 @@ class Value:
|
||||
The new literal value of this attribute.
|
||||
|
||||
"""
|
||||
if isinstance(value, dict):
|
||||
value = _str_key_dict(value)
|
||||
await self.driver.set(*self.identifiers, value=value)
|
||||
|
||||
async def clear(self):
|
||||
@ -238,7 +239,7 @@ class Group(Value):
|
||||
else:
|
||||
return Value(identifiers=new_identifiers, default_value=None, driver=self.driver)
|
||||
|
||||
async def clear_raw(self, *nested_path: str):
|
||||
async def clear_raw(self, *nested_path: Any):
|
||||
"""
|
||||
Allows a developer to clear data as if it was stored in a standard
|
||||
Python dictionary.
|
||||
@ -254,44 +255,44 @@ class Group(Value):
|
||||
|
||||
Parameters
|
||||
----------
|
||||
nested_path : str
|
||||
nested_path : Any
|
||||
Multiple arguments that mirror the arguments passed in for nested
|
||||
dict access.
|
||||
dict access. These are casted to `str` for you.
|
||||
"""
|
||||
path = [str(p) for p in nested_path]
|
||||
await self.driver.clear(*self.identifiers, *path)
|
||||
|
||||
def is_group(self, item: str) -> bool:
|
||||
def is_group(self, item: Any) -> bool:
|
||||
"""A helper method for `__getattr__`. Most developers will have no need
|
||||
to use this.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
item : str
|
||||
item : Any
|
||||
See `__getattr__`.
|
||||
|
||||
"""
|
||||
default = self._defaults.get(item)
|
||||
default = self._defaults.get(str(item))
|
||||
return isinstance(default, dict)
|
||||
|
||||
def is_value(self, item: str) -> bool:
|
||||
def is_value(self, item: Any) -> bool:
|
||||
"""A helper method for `__getattr__`. Most developers will have no need
|
||||
to use this.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
item : str
|
||||
item : Any
|
||||
See `__getattr__`.
|
||||
|
||||
"""
|
||||
try:
|
||||
default = self._defaults[item]
|
||||
default = self._defaults[str(item)]
|
||||
except KeyError:
|
||||
return False
|
||||
|
||||
return not isinstance(default, dict)
|
||||
|
||||
def get_attr(self, item: str):
|
||||
def get_attr(self, item: Union[int, str]):
|
||||
"""Manually get an attribute of this Group.
|
||||
|
||||
This is available to use as an alternative to using normal Python
|
||||
@ -312,7 +313,8 @@ class Group(Value):
|
||||
Parameters
|
||||
----------
|
||||
item : str
|
||||
The name of the data field in `Config`.
|
||||
The name of the data field in `Config`. This is casted to
|
||||
`str` for you.
|
||||
|
||||
Returns
|
||||
-------
|
||||
@ -320,9 +322,11 @@ class Group(Value):
|
||||
The attribute which was requested.
|
||||
|
||||
"""
|
||||
if isinstance(item, int):
|
||||
item = str(item)
|
||||
return self.__getattr__(item)
|
||||
|
||||
async def get_raw(self, *nested_path: str, default=...):
|
||||
async def get_raw(self, *nested_path: Any, default=...):
|
||||
"""
|
||||
Allows a developer to access data as if it was stored in a standard
|
||||
Python dictionary.
|
||||
@ -345,7 +349,7 @@ class Group(Value):
|
||||
----------
|
||||
nested_path : str
|
||||
Multiple arguments that mirror the arguments passed in for nested
|
||||
dict access.
|
||||
dict access. These are casted to `str` for you.
|
||||
default
|
||||
Default argument for the value attempting to be accessed. If the
|
||||
value does not exist the default will be returned.
|
||||
@ -410,7 +414,6 @@ class Group(Value):
|
||||
|
||||
If no defaults are passed, then the instance attribute 'defaults'
|
||||
will be used.
|
||||
|
||||
"""
|
||||
if defaults is ...:
|
||||
defaults = self.defaults
|
||||
@ -428,7 +431,7 @@ class Group(Value):
|
||||
raise ValueError("You may only set the value of a group to be a dict.")
|
||||
await super().set(value)
|
||||
|
||||
async def set_raw(self, *nested_path: str, value):
|
||||
async def set_raw(self, *nested_path: Any, value):
|
||||
"""
|
||||
Allows a developer to set data as if it was stored in a standard
|
||||
Python dictionary.
|
||||
@ -444,13 +447,15 @@ class Group(Value):
|
||||
|
||||
Parameters
|
||||
----------
|
||||
nested_path : str
|
||||
nested_path : Any
|
||||
Multiple arguments that mirror the arguments passed in for nested
|
||||
dict access.
|
||||
`dict` access. These are casted to `str` for you.
|
||||
value
|
||||
The value to store.
|
||||
"""
|
||||
path = [str(p) for p in nested_path]
|
||||
if isinstance(value, dict):
|
||||
value = _str_key_dict(value)
|
||||
await self.driver.set(*self.identifiers, *path, value=value)
|
||||
|
||||
|
||||
@ -461,9 +466,11 @@ class Config:
|
||||
`get_core_conf` for Config used in the core package.
|
||||
|
||||
.. important::
|
||||
Most config data should be accessed through its respective group method (e.g. :py:meth:`guild`)
|
||||
however the process for accessing global data is a bit different. There is no :python:`global` method
|
||||
because global data is accessed by normal attribute access::
|
||||
Most config data should be accessed through its respective
|
||||
group method (e.g. :py:meth:`guild`) however the process for
|
||||
accessing global data is a bit different. There is no
|
||||
:python:`global` method because global data is accessed by
|
||||
normal attribute access::
|
||||
|
||||
await conf.foo()
|
||||
|
||||
@ -548,7 +555,7 @@ class Config:
|
||||
A new Config object.
|
||||
|
||||
"""
|
||||
if cog_instance is None and not cog_name is None:
|
||||
if cog_instance is None and cog_name is not None:
|
||||
cog_path_override = cog_data_path(raw_name=cog_name)
|
||||
else:
|
||||
cog_path_override = cog_data_path(cog_instance=cog_instance)
|
||||
@ -635,11 +642,8 @@ class Config:
|
||||
def _get_defaults_dict(key: str, value) -> dict:
|
||||
"""
|
||||
Since we're allowing nested config stuff now, not storing the
|
||||
_defaults as a flat dict sounds like a good idea. May turn
|
||||
out to be an awful one but we'll see.
|
||||
:param key:
|
||||
:param value:
|
||||
:return:
|
||||
_defaults as a flat dict sounds like a good idea. May turn out
|
||||
to be an awful one but we'll see.
|
||||
"""
|
||||
ret = {}
|
||||
partial = ret
|
||||
@ -655,15 +659,12 @@ class Config:
|
||||
return ret
|
||||
|
||||
@staticmethod
|
||||
def _update_defaults(to_add: dict, _partial: dict):
|
||||
def _update_defaults(to_add: Dict[str, Any], _partial: Dict[str, Any]):
|
||||
"""
|
||||
This tries to update the _defaults dictionary with the nested
|
||||
partial dict generated by _get_defaults_dict. This WILL
|
||||
throw an error if you try to have both a value and a group
|
||||
registered under the same name.
|
||||
:param to_add:
|
||||
:param _partial:
|
||||
:return:
|
||||
partial dict generated by _get_defaults_dict. This WILL
|
||||
throw an error if you try to have both a value and a group
|
||||
registered under the same name.
|
||||
"""
|
||||
for k, v in to_add.items():
|
||||
val_is_dict = isinstance(v, dict)
|
||||
@ -679,7 +680,7 @@ class Config:
|
||||
else:
|
||||
_partial[k] = v
|
||||
|
||||
def _register_default(self, key: str, **kwargs):
|
||||
def _register_default(self, key: str, **kwargs: Any):
|
||||
if key not in self._defaults:
|
||||
self._defaults[key] = {}
|
||||
|
||||
@ -720,8 +721,8 @@ class Config:
|
||||
**_defaults
|
||||
)
|
||||
|
||||
You can do the same thing without a :python:`_defaults` dict by using double underscore as a variable
|
||||
name separator::
|
||||
You can do the same thing without a :python:`_defaults` dict by
|
||||
using double underscore as a variable name separator::
|
||||
|
||||
# This is equivalent to the previous example
|
||||
conf.register_global(
|
||||
@ -802,7 +803,7 @@ class Config:
|
||||
The guild's Group object.
|
||||
|
||||
"""
|
||||
return self._get_base_group(self.GUILD, guild.id)
|
||||
return self._get_base_group(self.GUILD, str(guild.id))
|
||||
|
||||
def channel(self, channel: discord.TextChannel) -> Group:
|
||||
"""Returns a `Group` for the given channel.
|
||||
@ -820,7 +821,7 @@ class Config:
|
||||
The channel's Group object.
|
||||
|
||||
"""
|
||||
return self._get_base_group(self.CHANNEL, channel.id)
|
||||
return self._get_base_group(self.CHANNEL, str(channel.id))
|
||||
|
||||
def role(self, role: discord.Role) -> Group:
|
||||
"""Returns a `Group` for the given role.
|
||||
@ -836,7 +837,7 @@ class Config:
|
||||
The role's Group object.
|
||||
|
||||
"""
|
||||
return self._get_base_group(self.ROLE, role.id)
|
||||
return self._get_base_group(self.ROLE, str(role.id))
|
||||
|
||||
def user(self, user: discord.abc.User) -> Group:
|
||||
"""Returns a `Group` for the given user.
|
||||
@ -852,7 +853,7 @@ class Config:
|
||||
The user's Group object.
|
||||
|
||||
"""
|
||||
return self._get_base_group(self.USER, user.id)
|
||||
return self._get_base_group(self.USER, str(user.id))
|
||||
|
||||
def member(self, member: discord.Member) -> Group:
|
||||
"""Returns a `Group` for the given member.
|
||||
@ -866,8 +867,9 @@ class Config:
|
||||
-------
|
||||
`Group <redbot.core.config.Group>`
|
||||
The member's Group object.
|
||||
|
||||
"""
|
||||
return self._get_base_group(self.MEMBER, member.guild.id, member.id)
|
||||
return self._get_base_group(self.MEMBER, str(member.guild.id), str(member.id))
|
||||
|
||||
def custom(self, group_identifier: str, *identifiers: str):
|
||||
"""Returns a `Group` for the given custom group.
|
||||
@ -876,17 +878,17 @@ class Config:
|
||||
----------
|
||||
group_identifier : str
|
||||
Used to identify the custom group.
|
||||
|
||||
identifiers : str
|
||||
The attributes necessary to uniquely identify an entry in the
|
||||
custom group.
|
||||
custom group. These are casted to `str` for you.
|
||||
|
||||
Returns
|
||||
-------
|
||||
`Group <redbot.core.config.Group>`
|
||||
The custom group's Group object.
|
||||
|
||||
"""
|
||||
return self._get_base_group(group_identifier, *identifiers)
|
||||
return self._get_base_group(str(group_identifier), *map(str, identifiers))
|
||||
|
||||
async def _all_from_scope(self, scope: str) -> Dict[int, Dict[Any, Any]]:
|
||||
"""Get a dict of all values from a particular scope of data.
|
||||
@ -982,7 +984,8 @@ class Config:
|
||||
"""
|
||||
return await self._all_from_scope(self.USER)
|
||||
|
||||
def _all_members_from_guild(self, group: Group, guild_data: dict) -> dict:
|
||||
@staticmethod
|
||||
def _all_members_from_guild(group: Group, guild_data: dict) -> dict:
|
||||
ret = {}
|
||||
for member_id, member_data in guild_data.items():
|
||||
new_member_data = group.defaults
|
||||
@ -1026,7 +1029,7 @@ class Config:
|
||||
for guild_id, guild_data in dict_.items():
|
||||
ret[int(guild_id)] = self._all_members_from_guild(group, guild_data)
|
||||
else:
|
||||
group = self._get_base_group(self.MEMBER, guild.id)
|
||||
group = self._get_base_group(self.MEMBER, str(guild.id))
|
||||
try:
|
||||
guild_data = await self.driver.get(*group.identifiers)
|
||||
except KeyError:
|
||||
@ -1054,7 +1057,8 @@ class Config:
|
||||
|
||||
"""
|
||||
if not scopes:
|
||||
group = Group(identifiers=[], defaults={}, driver=self.driver)
|
||||
# noinspection PyTypeChecker
|
||||
group = Group(identifiers=(), defaults={}, driver=self.driver)
|
||||
else:
|
||||
group = self._get_base_group(*scopes)
|
||||
await group.clear()
|
||||
@ -1119,7 +1123,7 @@ class Config:
|
||||
|
||||
"""
|
||||
if guild is not None:
|
||||
await self._clear_scope(self.MEMBER, guild.id)
|
||||
await self._clear_scope(self.MEMBER, str(guild.id))
|
||||
return
|
||||
await self._clear_scope(self.MEMBER)
|
||||
|
||||
@ -1127,5 +1131,34 @@ class Config:
|
||||
"""Clear all custom group data.
|
||||
|
||||
This resets all custom group data to its registered defaults.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
group_identifier : str
|
||||
The identifier for the custom group. This is casted to
|
||||
`str` for you.
|
||||
"""
|
||||
await self._clear_scope(group_identifier)
|
||||
await self._clear_scope(str(group_identifier))
|
||||
|
||||
|
||||
def _str_key_dict(value: Dict[Any, _T]) -> Dict[str, _T]:
|
||||
"""
|
||||
Recursively casts all keys in the given `dict` to `str`.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
value : Dict[Any, Any]
|
||||
The `dict` to cast keys to `str`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Dict[str, Any]
|
||||
The `dict` with keys (and nested keys) casted to `str`.
|
||||
|
||||
"""
|
||||
ret = {}
|
||||
for k, v in value.items():
|
||||
if isinstance(v, dict):
|
||||
v = _str_key_dict(v)
|
||||
ret[str(k)] = v
|
||||
return ret
|
||||
|
||||
@ -475,3 +475,18 @@ async def test_get_raw_mixes_defaults(config):
|
||||
|
||||
subgroup = await config.get_raw("subgroup")
|
||||
assert subgroup == {"foo": True, "bar": False}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cast_str_raw(config):
|
||||
await config.set_raw(123, 456, value=True)
|
||||
assert await config.get_raw(123, 456) is True
|
||||
assert await config.get_raw("123", "456") is True
|
||||
await config.clear_raw("123", 456)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_cast_str_nested(config):
|
||||
config.register_global(foo={})
|
||||
await config.foo.set({123: True, 456: {789: False}})
|
||||
assert await config.foo() == {"123": True, "456": {"789": False}}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user