Various Config and Mongo Driver fixes (#2795)

- Fixes defaults being mixed into custom groups above the document level when doing `Group.all()`
- Fixes `Config.clear_all()` with Mongo driver
- Fixes `Group.set()` with Mongo driver on custom groups above the document level
- Fixes `IdentifierData.custom_group_data` being set to the wrong thing in `BaseDriver.import/export_data` (although this was an inconsequential bug)

Signed-off-by: Toby Harradine <tobyharradine@gmail.com>
This commit is contained in:
Toby Harradine
2019-06-24 12:55:49 +10:00
committed by GitHub
parent 6ae3040aac
commit 71d0bd0d07
4 changed files with 160 additions and 39 deletions

View File

@@ -1,16 +1,14 @@
import logging
import collections
from copy import deepcopy
from typing import Any, Union, Tuple, Dict, Awaitable, AsyncContextManager, TypeVar, TYPE_CHECKING
from typing import Any, Union, Tuple, Dict, Awaitable, AsyncContextManager, TypeVar
import weakref
import discord
from .data_manager import cog_data_path, core_data_path
from .drivers import get_driver, IdentifierData, BackendType
if TYPE_CHECKING:
from .drivers.red_base import BaseDriver
from .drivers.red_base import BaseDriver
__all__ = ["Config", "get_latest_confs"]
@@ -545,7 +543,7 @@ class Config:
self,
cog_name: str,
unique_identifier: str,
driver: "BaseDriver",
driver: BaseDriver,
force_registration: bool = False,
defaults: dict = None,
):
@@ -852,9 +850,16 @@ class Config:
custom_group_data=self.custom_groups,
is_custom=is_custom,
)
pkey_len = BaseDriver.get_pkey_len(identifier_data)
if len(primary_keys) < pkey_len:
# Don't mix in defaults with groups higher than the document level
defaults = {}
else:
defaults = self.defaults.get(category, {})
return Group(
identifier_data=identifier_data,
defaults=self.defaults.get(category, {}),
defaults=defaults,
driver=self.driver,
force_registration=self.force_registration,
)
@@ -975,6 +980,7 @@ class Config:
"""
group = self._get_base_group(scope)
ret = {}
defaults = self.defaults.get(scope, {})
try:
dict_ = await self.driver.get(group.identifier_data)
@@ -982,7 +988,7 @@ class Config:
pass
else:
for k, v in dict_.items():
data = group.defaults
data = deepcopy(defaults)
data.update(v)
ret[int(k)] = data
@@ -1056,11 +1062,11 @@ class Config:
"""
return await self._all_from_scope(self.USER)
@staticmethod
def _all_members_from_guild(group: Group, guild_data: dict) -> dict:
def _all_members_from_guild(self, guild_data: dict) -> dict:
ret = {}
defaults = self.defaults.get(self.MEMBER, {})
for member_id, member_data in guild_data.items():
new_member_data = group.defaults
new_member_data = deepcopy(defaults)
new_member_data.update(member_data)
ret[int(member_id)] = new_member_data
return ret
@@ -1099,7 +1105,7 @@ class Config:
pass
else:
for guild_id, guild_data in dict_.items():
ret[int(guild_id)] = self._all_members_from_guild(group, guild_data)
ret[int(guild_id)] = self._all_members_from_guild(guild_data)
else:
group = self._get_base_group(self.MEMBER, str(guild.id))
try:
@@ -1107,7 +1113,7 @@ class Config:
except KeyError:
pass
else:
ret = self._all_members_from_guild(group, guild_data)
ret = self._all_members_from_guild(guild_data)
return ret
async def _clear_scope(self, *scopes: str):