Merge V3/release/3.0.0 into V3/develop

This commit is contained in:
Toby Harradine 2018-10-12 08:59:14 +11:00 committed by GitHub
commit 5ed8be9998
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
13 changed files with 249 additions and 107 deletions

View File

@ -374,6 +374,21 @@ API Reference
inside the bot itself! Simply take a peek inside of the :code:`tests/core/test_config.py` file for examples of using inside the bot itself! Simply take a peek inside of the :code:`tests/core/test_config.py` file for examples of using
Config in all kinds of ways. Config in all kinds of ways.
.. important::
When getting, setting or clearing values in Config, all keys are casted to `str` for you. This
includes keys within a `dict` when one is being set, as well as keys in nested dictionaries
within that `dict`. For example::
>>> conf = Config.get_conf(self, identifier=999)
>>> conf.register_global(foo={})
>>> await conf.foo.set_raw(123, value=True)
>>> await conf.foo()
{'123': True}
>>> await conf.foo.set({123: True, 456: {789: False}}
>>> await conf.foo()
{'123': True, '456': {'789': False}}
.. automodule:: redbot.core.config .. automodule:: redbot.core.config
Config Config

View File

@ -34,14 +34,14 @@ async def download_lavalink(session):
async def maybe_download_lavalink(loop, cog): async def maybe_download_lavalink(loop, cog):
jar_exists = LAVALINK_JAR_FILE.exists() jar_exists = LAVALINK_JAR_FILE.exists()
current_build = redbot.core.VersionInfo.from_json(await cog.config.current_build()) current_build = redbot.core.VersionInfo.from_json(await cog.config.current_version())
if not jar_exists or current_build < redbot.core.version_info: if not jar_exists or current_build < redbot.core.version_info:
log.info("Downloading Lavalink.jar") log.info("Downloading Lavalink.jar")
LAVALINK_DOWNLOAD_DIR.mkdir(parents=True, exist_ok=True) LAVALINK_DOWNLOAD_DIR.mkdir(parents=True, exist_ok=True)
async with ClientSession(loop=loop) as session: async with ClientSession(loop=loop) as session:
await download_lavalink(session) await download_lavalink(session)
await cog.config.current_build.set(redbot.core.version_info.to_json()) await cog.config.current_version.set(redbot.core.version_info.to_json())
shutil.copyfile(str(BUNDLED_APP_YML_FILE), str(APP_YML_FILE)) shutil.copyfile(str(BUNDLED_APP_YML_FILE), str(APP_YML_FILE))

View File

@ -48,7 +48,7 @@ class Audio(commands.Cog):
"ws_port": "2332", "ws_port": "2332",
"password": "youshallnotpass", "password": "youshallnotpass",
"status": False, "status": False,
"current_build": redbot.core.VersionInfo.from_str("3.0.0a0").to_json(), "current_version": redbot.core.VersionInfo.from_str("3.0.0a0").to_json(),
"use_external_lavalink": False, "use_external_lavalink": False,
} }

View File

@ -169,7 +169,7 @@ class Cleanup(commands.Cog):
member = None member = None
try: try:
member = await commands.converter.MemberConverter().convert(ctx, user) member = await commands.MemberConverter().convert(ctx, user)
except commands.BadArgument: except commands.BadArgument:
try: try:
_id = int(user) _id = int(user)

View File

@ -542,7 +542,8 @@ class Permissions(commands.Cog):
continue continue
conf = self.config.custom(category) conf = self.config.custom(category)
for cmd_name, cmd_rules in rules_dict.items(): for cmd_name, cmd_rules in rules_dict.items():
await conf.set_raw(cmd_name, guild_id, value=cmd_rules) cmd_rules = {str(model_id): rule for model_id, rule in cmd_rules.items()}
await conf.set_raw(cmd_name, str(guild_id), value=cmd_rules)
cmd_obj = getter(cmd_name) cmd_obj = getter(cmd_name)
if cmd_obj is not None: if cmd_obj is not None:
self._load_rules_for(cmd_obj, {guild_id: cmd_rules}) self._load_rules_for(cmd_obj, {guild_id: cmd_rules})
@ -651,14 +652,14 @@ class Permissions(commands.Cog):
if category in old_rules: if category in old_rules:
for name, rules in old_rules[category].items(): for name, rules in old_rules[category].items():
these_rules = new_rules.setdefault(name, {}) these_rules = new_rules.setdefault(name, {})
guild_rules = these_rules.setdefault(guild_id, {}) guild_rules = these_rules.setdefault(str(guild_id), {})
# Since allow rules would take precedence if the same model ID # Since allow rules would take precedence if the same model ID
# sat in both the allow and deny list, we add the deny entries # sat in both the allow and deny list, we add the deny entries
# first and let any conflicting allow entries overwrite. # first and let any conflicting allow entries overwrite.
for model_id in rules.get("deny", []): for model_id in rules.get("deny", []):
guild_rules[model_id] = False guild_rules[str(model_id)] = False
for model_id in rules.get("allow", []): for model_id in rules.get("allow", []):
guild_rules[model_id] = True guild_rules[str(model_id)] = True
if "default" in rules: if "default" in rules:
default = rules["default"] default = rules["default"]
if default == "allow": if default == "allow":
@ -689,7 +690,9 @@ class Permissions(commands.Cog):
""" """
for guild_id, guild_dict in _int_key_map(rule_dict.items()): for guild_id, guild_dict in _int_key_map(rule_dict.items()):
for model_id, rule in _int_key_map(guild_dict.items()): for model_id, rule in _int_key_map(guild_dict.items()):
if rule is True: if model_id == "default":
cog_or_command.set_default_rule(rule, guild_id=guild_id)
elif rule is True:
cog_or_command.allow_for(model_id, guild_id=guild_id) cog_or_command.allow_for(model_id, guild_id=guild_id)
elif rule is False: elif rule is False:
cog_or_command.deny_to(model_id, guild_id=guild_id) cog_or_command.deny_to(model_id, guild_id=guild_id)
@ -724,9 +727,16 @@ class Permissions(commands.Cog):
rules. rules.
""" """
for guild_id, guild_dict in _int_key_map(rule_dict.items()): for guild_id, guild_dict in _int_key_map(rule_dict.items()):
for model_id in map(int, guild_dict.keys()): for model_id in guild_dict.keys():
cog_or_command.clear_rule_for(model_id, guild_id) if model_id == "default":
cog_or_command.set_default_rule(None, guild_id=guild_id)
else:
cog_or_command.clear_rule_for(int(model_id), guild_id=guild_id)
def _int_key_map(items_view: ItemsView[str, Any]) -> Iterator[Tuple[int, Any]]: def _int_key_map(items_view: ItemsView[str, Any]) -> Iterator[Tuple[Union[str, int], Any]]:
return map(lambda tup: (int(tup[0]), tup[1]), items_view) for k, v in items_view:
if k == "default":
yield k, v
else:
yield int(k), v

View File

@ -322,9 +322,9 @@ def _parse_answers(answers):
for answer in answers: for answer in answers:
if isinstance(answer, bool): if isinstance(answer, bool):
if answer is True: if answer is True:
ret.extend(["True", "Yes", _("Yes")]) ret.extend(["True", "Yes", "On"])
else: else:
ret.extend(["False", "No", _("No")]) ret.extend(["False", "No", "Off"])
else: else:
ret.append(str(answer)) ret.append(str(answer))
# Uniquify list # Uniquify list

View File

@ -19,9 +19,11 @@ async def warning_points_add_check(
act = {} act = {}
async with guild_settings.actions() as registered_actions: async with guild_settings.actions() as registered_actions:
for a in registered_actions: for a in registered_actions:
# Actions are sorted in decreasing order of points.
# The first action we find where the user is above the threshold will be the
# highest action we can take.
if points >= a["points"]: if points >= a["points"]:
act = a act = a
else:
break break
if act and act["exceed_command"] is not None: # some action needs to be taken if act and act["exceed_command"] is not None: # some action needs to be taken
await create_and_invoke_context(ctx, act["exceed_command"], user) await create_and_invoke_context(ctx, act["exceed_command"], user)

View File

@ -9,7 +9,7 @@ from redbot.cogs.warnings.helpers import (
get_command_for_dropping_points, get_command_for_dropping_points,
warning_points_remove_check, warning_points_remove_check,
) )
from redbot.core import Config, modlog, checks, commands from redbot.core import Config, checks, commands
from redbot.core.bot import Red from redbot.core.bot import Red
from redbot.core.i18n import Translator, cog_i18n from redbot.core.i18n import Translator, cog_i18n
from redbot.core.utils.mod import is_admin_or_superior from redbot.core.utils.mod import is_admin_or_superior
@ -34,15 +34,14 @@ class Warnings(commands.Cog):
self.config.register_guild(**self.default_guild) self.config.register_guild(**self.default_guild)
self.config.register_member(**self.default_member) self.config.register_member(**self.default_member)
self.bot = bot self.bot = bot
loop = asyncio.get_event_loop()
loop.create_task(self.register_warningtype())
@staticmethod # We're not utilising modlog yet - no need to register a casetype
async def register_warningtype(): # @staticmethod
try: # async def register_warningtype():
await modlog.register_casetype("warning", True, "\N{WARNING SIGN}", "Warning", None) # try:
except RuntimeError: # await modlog.register_casetype("warning", True, "\N{WARNING SIGN}", "Warning", None)
pass # except RuntimeError:
# pass
@commands.group() @commands.group()
@commands.guild_only() @commands.guild_only()

View File

@ -39,17 +39,21 @@ class _ValueCtxManager(Awaitable[_T], AsyncContextManager[_T]):
async def __aenter__(self): async def __aenter__(self):
self.raw_value = await self self.raw_value = await self
self.__original_value = deepcopy(self.raw_value)
if not isinstance(self.raw_value, (list, dict)): if not isinstance(self.raw_value, (list, dict)):
raise TypeError( raise TypeError(
"Type of retrieved value must be mutable (i.e. " "Type of retrieved value must be mutable (i.e. "
"list or dict) in order to use a config value as " "list or dict) in order to use a config value as "
"a context manager." "a context manager."
) )
self.__original_value = deepcopy(self.raw_value)
return self.raw_value return self.raw_value
async def __aexit__(self, exc_type, exc, tb): async def __aexit__(self, exc_type, exc, tb):
if self.raw_value != self.__original_value: if isinstance(self.raw_value, dict):
raw_value = _str_key_dict(self.raw_value)
else:
raw_value = self.raw_value
if raw_value != self.__original_value:
await self.value_obj.set(self.raw_value) await self.value_obj.set(self.raw_value)
@ -58,7 +62,7 @@ class Value:
Attributes Attributes
---------- ----------
identifiers : `tuple` of `str` identifiers : Tuple[str]
This attribute provides all the keys necessary to get a specific data This attribute provides all the keys necessary to get a specific data
element from a json document. element from a json document.
default default
@ -69,15 +73,10 @@ class Value:
""" """
def __init__(self, identifiers: Tuple[str], default_value, driver): def __init__(self, identifiers: Tuple[str], default_value, driver):
self._identifiers = identifiers self.identifiers = identifiers
self.default = default_value self.default = default_value
self.driver = driver self.driver = driver
@property
def identifiers(self):
return tuple(str(i) for i in self._identifiers)
async def _get(self, default=...): async def _get(self, default=...):
try: try:
ret = await self.driver.get(*self.identifiers) ret = await self.driver.get(*self.identifiers)
@ -149,6 +148,8 @@ class Value:
The new literal value of this attribute. The new literal value of this attribute.
""" """
if isinstance(value, dict):
value = _str_key_dict(value)
await self.driver.set(*self.identifiers, value=value) await self.driver.set(*self.identifiers, value=value)
async def clear(self): async def clear(self):
@ -192,7 +193,10 @@ class Group(Value):
async def _get(self, default: Dict[str, Any] = ...) -> Dict[str, Any]: async def _get(self, default: Dict[str, Any] = ...) -> Dict[str, Any]:
default = default if default is not ... else self.defaults default = default if default is not ... else self.defaults
raw = await super()._get(default) raw = await super()._get(default)
return self.nested_update(raw, default) if isinstance(raw, dict):
return self.nested_update(raw, default)
else:
return raw
# noinspection PyTypeChecker # noinspection PyTypeChecker
def __getattr__(self, item: str) -> Union["Group", Value]: def __getattr__(self, item: str) -> Union["Group", Value]:
@ -238,7 +242,7 @@ class Group(Value):
else: else:
return Value(identifiers=new_identifiers, default_value=None, driver=self.driver) return Value(identifiers=new_identifiers, default_value=None, driver=self.driver)
async def clear_raw(self, *nested_path: str): async def clear_raw(self, *nested_path: Any):
""" """
Allows a developer to clear data as if it was stored in a standard Allows a developer to clear data as if it was stored in a standard
Python dictionary. Python dictionary.
@ -254,44 +258,44 @@ class Group(Value):
Parameters Parameters
---------- ----------
nested_path : str nested_path : Any
Multiple arguments that mirror the arguments passed in for nested Multiple arguments that mirror the arguments passed in for nested
dict access. dict access. These are casted to `str` for you.
""" """
path = [str(p) for p in nested_path] path = [str(p) for p in nested_path]
await self.driver.clear(*self.identifiers, *path) await self.driver.clear(*self.identifiers, *path)
def is_group(self, item: str) -> bool: def is_group(self, item: Any) -> bool:
"""A helper method for `__getattr__`. Most developers will have no need """A helper method for `__getattr__`. Most developers will have no need
to use this. to use this.
Parameters Parameters
---------- ----------
item : str item : Any
See `__getattr__`. See `__getattr__`.
""" """
default = self._defaults.get(item) default = self._defaults.get(str(item))
return isinstance(default, dict) return isinstance(default, dict)
def is_value(self, item: str) -> bool: def is_value(self, item: Any) -> bool:
"""A helper method for `__getattr__`. Most developers will have no need """A helper method for `__getattr__`. Most developers will have no need
to use this. to use this.
Parameters Parameters
---------- ----------
item : str item : Any
See `__getattr__`. See `__getattr__`.
""" """
try: try:
default = self._defaults[item] default = self._defaults[str(item)]
except KeyError: except KeyError:
return False return False
return not isinstance(default, dict) return not isinstance(default, dict)
def get_attr(self, item: str): def get_attr(self, item: Union[int, str]):
"""Manually get an attribute of this Group. """Manually get an attribute of this Group.
This is available to use as an alternative to using normal Python This is available to use as an alternative to using normal Python
@ -312,7 +316,8 @@ class Group(Value):
Parameters Parameters
---------- ----------
item : str item : str
The name of the data field in `Config`. The name of the data field in `Config`. This is casted to
`str` for you.
Returns Returns
------- -------
@ -320,9 +325,11 @@ class Group(Value):
The attribute which was requested. The attribute which was requested.
""" """
if isinstance(item, int):
item = str(item)
return self.__getattr__(item) return self.__getattr__(item)
async def get_raw(self, *nested_path: str, default=...): async def get_raw(self, *nested_path: Any, default=...):
""" """
Allows a developer to access data as if it was stored in a standard Allows a developer to access data as if it was stored in a standard
Python dictionary. Python dictionary.
@ -345,7 +352,7 @@ class Group(Value):
---------- ----------
nested_path : str nested_path : str
Multiple arguments that mirror the arguments passed in for nested Multiple arguments that mirror the arguments passed in for nested
dict access. dict access. These are casted to `str` for you.
default default
Default argument for the value attempting to be accessed. If the Default argument for the value attempting to be accessed. If the
value does not exist the default will be returned. value does not exist the default will be returned.
@ -410,7 +417,6 @@ class Group(Value):
If no defaults are passed, then the instance attribute 'defaults' If no defaults are passed, then the instance attribute 'defaults'
will be used. will be used.
""" """
if defaults is ...: if defaults is ...:
defaults = self.defaults defaults = self.defaults
@ -428,7 +434,7 @@ class Group(Value):
raise ValueError("You may only set the value of a group to be a dict.") raise ValueError("You may only set the value of a group to be a dict.")
await super().set(value) await super().set(value)
async def set_raw(self, *nested_path: str, value): async def set_raw(self, *nested_path: Any, value):
""" """
Allows a developer to set data as if it was stored in a standard Allows a developer to set data as if it was stored in a standard
Python dictionary. Python dictionary.
@ -444,13 +450,15 @@ class Group(Value):
Parameters Parameters
---------- ----------
nested_path : str nested_path : Any
Multiple arguments that mirror the arguments passed in for nested Multiple arguments that mirror the arguments passed in for nested
dict access. `dict` access. These are casted to `str` for you.
value value
The value to store. The value to store.
""" """
path = [str(p) for p in nested_path] path = [str(p) for p in nested_path]
if isinstance(value, dict):
value = _str_key_dict(value)
await self.driver.set(*self.identifiers, *path, value=value) await self.driver.set(*self.identifiers, *path, value=value)
@ -461,9 +469,11 @@ class Config:
`get_core_conf` for Config used in the core package. `get_core_conf` for Config used in the core package.
.. important:: .. important::
Most config data should be accessed through its respective group method (e.g. :py:meth:`guild`) Most config data should be accessed through its respective
however the process for accessing global data is a bit different. There is no :python:`global` method group method (e.g. :py:meth:`guild`) however the process for
because global data is accessed by normal attribute access:: accessing global data is a bit different. There is no
:python:`global` method because global data is accessed by
normal attribute access::
await conf.foo() await conf.foo()
@ -548,7 +558,7 @@ class Config:
A new Config object. A new Config object.
""" """
if cog_instance is None and not cog_name is None: if cog_instance is None and cog_name is not None:
cog_path_override = cog_data_path(raw_name=cog_name) cog_path_override = cog_data_path(raw_name=cog_name)
else: else:
cog_path_override = cog_data_path(cog_instance=cog_instance) cog_path_override = cog_data_path(cog_instance=cog_instance)
@ -635,11 +645,8 @@ class Config:
def _get_defaults_dict(key: str, value) -> dict: def _get_defaults_dict(key: str, value) -> dict:
""" """
Since we're allowing nested config stuff now, not storing the Since we're allowing nested config stuff now, not storing the
_defaults as a flat dict sounds like a good idea. May turn _defaults as a flat dict sounds like a good idea. May turn out
out to be an awful one but we'll see. to be an awful one but we'll see.
:param key:
:param value:
:return:
""" """
ret = {} ret = {}
partial = ret partial = ret
@ -655,15 +662,12 @@ class Config:
return ret return ret
@staticmethod @staticmethod
def _update_defaults(to_add: dict, _partial: dict): def _update_defaults(to_add: Dict[str, Any], _partial: Dict[str, Any]):
""" """
This tries to update the _defaults dictionary with the nested This tries to update the _defaults dictionary with the nested
partial dict generated by _get_defaults_dict. This WILL partial dict generated by _get_defaults_dict. This WILL
throw an error if you try to have both a value and a group throw an error if you try to have both a value and a group
registered under the same name. registered under the same name.
:param to_add:
:param _partial:
:return:
""" """
for k, v in to_add.items(): for k, v in to_add.items():
val_is_dict = isinstance(v, dict) val_is_dict = isinstance(v, dict)
@ -679,7 +683,7 @@ class Config:
else: else:
_partial[k] = v _partial[k] = v
def _register_default(self, key: str, **kwargs): def _register_default(self, key: str, **kwargs: Any):
if key not in self._defaults: if key not in self._defaults:
self._defaults[key] = {} self._defaults[key] = {}
@ -720,8 +724,8 @@ class Config:
**_defaults **_defaults
) )
You can do the same thing without a :python:`_defaults` dict by using double underscore as a variable You can do the same thing without a :python:`_defaults` dict by
name separator:: using double underscore as a variable name separator::
# This is equivalent to the previous example # This is equivalent to the previous example
conf.register_global( conf.register_global(
@ -802,7 +806,7 @@ class Config:
The guild's Group object. The guild's Group object.
""" """
return self._get_base_group(self.GUILD, guild.id) return self._get_base_group(self.GUILD, str(guild.id))
def channel(self, channel: discord.TextChannel) -> Group: def channel(self, channel: discord.TextChannel) -> Group:
"""Returns a `Group` for the given channel. """Returns a `Group` for the given channel.
@ -820,7 +824,7 @@ class Config:
The channel's Group object. The channel's Group object.
""" """
return self._get_base_group(self.CHANNEL, channel.id) return self._get_base_group(self.CHANNEL, str(channel.id))
def role(self, role: discord.Role) -> Group: def role(self, role: discord.Role) -> Group:
"""Returns a `Group` for the given role. """Returns a `Group` for the given role.
@ -836,7 +840,7 @@ class Config:
The role's Group object. The role's Group object.
""" """
return self._get_base_group(self.ROLE, role.id) return self._get_base_group(self.ROLE, str(role.id))
def user(self, user: discord.abc.User) -> Group: def user(self, user: discord.abc.User) -> Group:
"""Returns a `Group` for the given user. """Returns a `Group` for the given user.
@ -852,7 +856,7 @@ class Config:
The user's Group object. The user's Group object.
""" """
return self._get_base_group(self.USER, user.id) return self._get_base_group(self.USER, str(user.id))
def member(self, member: discord.Member) -> Group: def member(self, member: discord.Member) -> Group:
"""Returns a `Group` for the given member. """Returns a `Group` for the given member.
@ -866,8 +870,9 @@ class Config:
------- -------
`Group <redbot.core.config.Group>` `Group <redbot.core.config.Group>`
The member's Group object. The member's Group object.
""" """
return self._get_base_group(self.MEMBER, member.guild.id, member.id) return self._get_base_group(self.MEMBER, str(member.guild.id), str(member.id))
def custom(self, group_identifier: str, *identifiers: str): def custom(self, group_identifier: str, *identifiers: str):
"""Returns a `Group` for the given custom group. """Returns a `Group` for the given custom group.
@ -876,17 +881,17 @@ class Config:
---------- ----------
group_identifier : str group_identifier : str
Used to identify the custom group. Used to identify the custom group.
identifiers : str identifiers : str
The attributes necessary to uniquely identify an entry in the The attributes necessary to uniquely identify an entry in the
custom group. custom group. These are casted to `str` for you.
Returns Returns
------- -------
`Group <redbot.core.config.Group>` `Group <redbot.core.config.Group>`
The custom group's Group object. The custom group's Group object.
""" """
return self._get_base_group(group_identifier, *identifiers) return self._get_base_group(str(group_identifier), *map(str, identifiers))
async def _all_from_scope(self, scope: str) -> Dict[int, Dict[Any, Any]]: async def _all_from_scope(self, scope: str) -> Dict[int, Dict[Any, Any]]:
"""Get a dict of all values from a particular scope of data. """Get a dict of all values from a particular scope of data.
@ -982,7 +987,8 @@ class Config:
""" """
return await self._all_from_scope(self.USER) return await self._all_from_scope(self.USER)
def _all_members_from_guild(self, group: Group, guild_data: dict) -> dict: @staticmethod
def _all_members_from_guild(group: Group, guild_data: dict) -> dict:
ret = {} ret = {}
for member_id, member_data in guild_data.items(): for member_id, member_data in guild_data.items():
new_member_data = group.defaults new_member_data = group.defaults
@ -1026,7 +1032,7 @@ class Config:
for guild_id, guild_data in dict_.items(): for guild_id, guild_data in dict_.items():
ret[int(guild_id)] = self._all_members_from_guild(group, guild_data) ret[int(guild_id)] = self._all_members_from_guild(group, guild_data)
else: else:
group = self._get_base_group(self.MEMBER, guild.id) group = self._get_base_group(self.MEMBER, str(guild.id))
try: try:
guild_data = await self.driver.get(*group.identifiers) guild_data = await self.driver.get(*group.identifiers)
except KeyError: except KeyError:
@ -1054,7 +1060,8 @@ class Config:
""" """
if not scopes: if not scopes:
group = Group(identifiers=[], defaults={}, driver=self.driver) # noinspection PyTypeChecker
group = Group(identifiers=(), defaults={}, driver=self.driver)
else: else:
group = self._get_base_group(*scopes) group = self._get_base_group(*scopes)
await group.clear() await group.clear()
@ -1119,7 +1126,7 @@ class Config:
""" """
if guild is not None: if guild is not None:
await self._clear_scope(self.MEMBER, guild.id) await self._clear_scope(self.MEMBER, str(guild.id))
return return
await self._clear_scope(self.MEMBER) await self._clear_scope(self.MEMBER)
@ -1127,5 +1134,34 @@ class Config:
"""Clear all custom group data. """Clear all custom group data.
This resets all custom group data to its registered defaults. This resets all custom group data to its registered defaults.
Parameters
----------
group_identifier : str
The identifier for the custom group. This is casted to
`str` for you.
""" """
await self._clear_scope(group_identifier) await self._clear_scope(str(group_identifier))
def _str_key_dict(value: Dict[Any, _T]) -> Dict[str, _T]:
"""
Recursively casts all keys in the given `dict` to `str`.
Parameters
----------
value : Dict[Any, Any]
The `dict` to cast keys to `str`.
Returns
-------
Dict[str, Any]
The `dict` with keys (and nested keys) casted to `str`.
"""
ret = {}
for k, v in value.items():
if isinstance(v, dict):
v = _str_key_dict(v)
ret[str(k)] = v
return ret

View File

@ -1118,21 +1118,20 @@ class Core(commands.Cog, CoreLogic):
if basic_config["STORAGE_TYPE"] == "MongoDB": if basic_config["STORAGE_TYPE"] == "MongoDB":
from redbot.core.drivers.red_mongo import Mongo from redbot.core.drivers.red_mongo import Mongo
m = Mongo("Core", **basic_config["STORAGE_DETAILS"]) m = Mongo("Core", "0", **basic_config["STORAGE_DETAILS"])
db = m.db db = m.db
collection_names = await db.collection_names(include_system_collections=False) collection_names = await db.list_collection_names()
for c_name in collection_names: for c_name in collection_names:
if c_name == "Core": if c_name == "Core":
c_data_path = data_dir / basic_config["CORE_PATH_APPEND"] c_data_path = data_dir / basic_config["CORE_PATH_APPEND"]
else: else:
c_data_path = data_dir / basic_config["COG_PATH_APPEND"] c_data_path = data_dir / basic_config["COG_PATH_APPEND"] / c_name
output = {}
docs = await db[c_name].find().to_list(None) docs = await db[c_name].find().to_list(None)
for item in docs: for item in docs:
item_id = str(item.pop("_id")) item_id = str(item.pop("_id"))
output[item_id] = item output = item
target = JSON(c_name, data_path_override=c_data_path) target = JSON(c_name, item_id, data_path_override=c_data_path)
await target.jsonIO._threadsafe_save_json(output) await target.jsonIO._threadsafe_save_json(output)
backup_filename = "redv3-{}-{}.tar.gz".format( backup_filename = "redv3-{}-{}.tar.gz".format(
instance_name, ctx.message.created_at.strftime("%Y-%m-%d %H-%M-%S") instance_name, ctx.message.created_at.strftime("%Y-%m-%d %H-%M-%S")
) )

View File

@ -1,7 +1,12 @@
import motor.motor_asyncio import re
from .red_base import BaseDriver from typing import Match, Pattern
from urllib.parse import quote_plus from urllib.parse import quote_plus
import motor.core
import motor.motor_asyncio
from .red_base import BaseDriver
__all__ = ["Mongo"] __all__ = ["Mongo"]
@ -80,6 +85,7 @@ class Mongo(BaseDriver):
async def get(self, *identifiers: str): async def get(self, *identifiers: str):
mongo_collection = self.get_collection() mongo_collection = self.get_collection()
identifiers = (*map(self._escape_key, identifiers),)
dot_identifiers = ".".join(identifiers) dot_identifiers = ".".join(identifiers)
partial = await mongo_collection.find_one( partial = await mongo_collection.find_one(
@ -91,10 +97,14 @@ class Mongo(BaseDriver):
for i in identifiers: for i in identifiers:
partial = partial[i] partial = partial[i]
if isinstance(partial, dict):
return self._unescape_dict_keys(partial)
return partial return partial
async def set(self, *identifiers: str, value=None): async def set(self, *identifiers: str, value=None):
dot_identifiers = ".".join(identifiers) dot_identifiers = ".".join(map(self._escape_key, identifiers))
if isinstance(value, dict):
value = self._escape_dict_keys(value)
mongo_collection = self.get_collection() mongo_collection = self.get_collection()
@ -105,7 +115,7 @@ class Mongo(BaseDriver):
) )
async def clear(self, *identifiers: str): async def clear(self, *identifiers: str):
dot_identifiers = ".".join(identifiers) dot_identifiers = ".".join(map(self._escape_key, identifiers))
mongo_collection = self.get_collection() mongo_collection = self.get_collection()
if len(identifiers) > 0: if len(identifiers) > 0:
@ -115,6 +125,62 @@ class Mongo(BaseDriver):
else: else:
await mongo_collection.delete_one({"_id": self.unique_cog_identifier}) await mongo_collection.delete_one({"_id": self.unique_cog_identifier})
@staticmethod
def _escape_key(key: str) -> str:
return _SPECIAL_CHAR_PATTERN.sub(_replace_with_escaped, key)
@staticmethod
def _unescape_key(key: str) -> str:
return _CHAR_ESCAPE_PATTERN.sub(_replace_with_unescaped, key)
@classmethod
def _escape_dict_keys(cls, data: dict) -> dict:
"""Recursively escape all keys in a dict."""
ret = {}
for key, value in data.items():
key = cls._escape_key(key)
if isinstance(value, dict):
value = cls._escape_dict_keys(value)
ret[key] = value
return ret
@classmethod
def _unescape_dict_keys(cls, data: dict) -> dict:
"""Recursively unescape all keys in a dict."""
ret = {}
for key, value in data.items():
key = cls._unescape_key(key)
if isinstance(value, dict):
value = cls._unescape_dict_keys(value)
ret[key] = value
return ret
_SPECIAL_CHAR_PATTERN: Pattern[str] = re.compile(r"([.$]|\\U0000002E|\\U00000024)")
_SPECIAL_CHARS = {
".": "\\U0000002E",
"$": "\\U00000024",
"\\U0000002E": "\\U&0000002E",
"\\U00000024": "\\U&00000024",
}
def _replace_with_escaped(match: Match[str]) -> str:
return _SPECIAL_CHARS[match[0]]
_CHAR_ESCAPE_PATTERN: Pattern[str] = re.compile(r"(\\U0000002E|\\U00000024)")
_CHAR_ESCAPES = {
"\\U0000002E": ".",
"\\U00000024": "$",
"\\U&0000002E": "\\U0000002E",
"\\U&00000024": "\\U00000024",
}
def _replace_with_unescaped(match: Match[str]) -> str:
return _CHAR_ESCAPES[match[0]]
def get_config_details(): def get_config_details():
uri = None uri = None

View File

@ -3,7 +3,7 @@ from redbot.cogs.permissions.permissions import Permissions, GLOBAL
def test_schema_update(): def test_schema_update():
old = { old = {
GLOBAL: { str(GLOBAL): {
"owner_models": { "owner_models": {
"cogs": { "cogs": {
"Admin": {"allow": [78631113035100160], "deny": [96733288462286848]}, "Admin": {"allow": [78631113035100160], "deny": [96733288462286848]},
@ -19,7 +19,7 @@ def test_schema_update():
}, },
} }
}, },
43733288462286848: { "43733288462286848": {
"owner_models": { "owner_models": {
"cogs": { "cogs": {
"Admin": { "Admin": {
@ -43,22 +43,22 @@ def test_schema_update():
assert new == ( assert new == (
{ {
"Admin": { "Admin": {
GLOBAL: {78631113035100160: True, 96733288462286848: False}, str(GLOBAL): {"78631113035100160": True, "96733288462286848": False},
43733288462286848: {24231113035100160: True, 35533288462286848: False}, "43733288462286848": {"24231113035100160": True, "35533288462286848": False},
}, },
"Audio": {GLOBAL: {133049272517001216: True, "default": False}}, "Audio": {str(GLOBAL): {"133049272517001216": True, "default": False}},
"General": {43733288462286848: {133049272517001216: True, "default": False}}, "General": {"43733288462286848": {"133049272517001216": True, "default": False}},
}, },
{ {
"cleanup bot": { "cleanup bot": {
GLOBAL: {78631113035100160: True, "default": False}, str(GLOBAL): {"78631113035100160": True, "default": False},
43733288462286848: {17831113035100160: True, "default": True}, "43733288462286848": {"17831113035100160": True, "default": True},
}, },
"ping": {GLOBAL: {96733288462286848: True, "default": True}}, "ping": {str(GLOBAL): {"96733288462286848": True, "default": True}},
"set adminrole": { "set adminrole": {
43733288462286848: { "43733288462286848": {
87733288462286848: True, "87733288462286848": True,
95433288462286848: False, "95433288462286848": False,
"default": True, "default": True,
} }
}, },

View File

@ -475,3 +475,18 @@ async def test_get_raw_mixes_defaults(config):
subgroup = await config.get_raw("subgroup") subgroup = await config.get_raw("subgroup")
assert subgroup == {"foo": True, "bar": False} assert subgroup == {"foo": True, "bar": False}
@pytest.mark.asyncio
async def test_cast_str_raw(config):
await config.set_raw(123, 456, value=True)
assert await config.get_raw(123, 456) is True
assert await config.get_raw("123", "456") is True
await config.clear_raw("123", 456)
@pytest.mark.asyncio
async def test_cast_str_nested(config):
config.register_global(foo={})
await config.foo.set({123: True, 456: {789: False}})
assert await config.foo() == {"123": True, "456": {"789": False}}