From dbed24aacaae7d9f3c393b5c00290bc120a16896 Mon Sep 17 00:00:00 2001 From: Toby Harradine Date: Sun, 26 Aug 2018 23:30:36 +1000 Subject: [PATCH] [Config] Group.__call__() has same behaviour as Group.all() (#2018) * Make calling groups useful This makes config.Group.__call__ effectively an alias for Group.all(), with the added bonus of becoming a context manager. get_raw has been updated as well to reflect the new behaviour of __call__. * Fix unintended side-effects of new behaviour * Add tests * Add test for get_raw mixing in defaults * Another cleanup for relying on old behaviour internally * Fix bank relying on old behaviour * Reformat --- redbot/core/bank.py | 13 +++---- redbot/core/config.py | 79 ++++++++++++++++++++++++++++----------- tests/core/test_config.py | 36 ++++++++++++++++++ 3 files changed, 99 insertions(+), 29 deletions(-) diff --git a/redbot/core/bank.py b/redbot/core/bank.py index 04362c2f3..2bedae9a7 100644 --- a/redbot/core/bank.py +++ b/redbot/core/bank.py @@ -400,19 +400,18 @@ async def get_account(member: Union[discord.Member, discord.User]) -> Account: """ if await is_global(): - acc_data = (await _conf.user(member)()).copy() - default = _DEFAULT_USER.copy() + all_accounts = await _conf.all_users() else: - acc_data = (await _conf.member(member)()).copy() - default = _DEFAULT_MEMBER.copy() + all_accounts = await _conf.all_members(member.guild) - if acc_data == {}: - acc_data = default - acc_data["name"] = member.display_name + if member.id not in all_accounts: + acc_data = {"name": member.display_name, "created_at": _DEFAULT_MEMBER["created_at"]} try: acc_data["balance"] = await get_default_balance(member.guild) except AttributeError: acc_data["balance"] = await get_default_balance() + else: + acc_data = all_accounts[member.id] acc_data["created_at"] = _decode_time(acc_data["created_at"]) return Account(**acc_data) diff --git a/redbot/core/config.py b/redbot/core/config.py index cf95ed640..c0b411666 100644 --- a/redbot/core/config.py +++ b/redbot/core/config.py @@ -1,7 +1,7 @@ import logging import collections from copy import deepcopy -from typing import Union, Tuple, TYPE_CHECKING +from typing import Any, Union, Tuple, Dict, Awaitable, AsyncContextManager, TypeVar, TYPE_CHECKING import discord @@ -13,8 +13,10 @@ if TYPE_CHECKING: log = logging.getLogger("red.config") +_T = TypeVar("_T") -class _ValueCtxManager: + +class _ValueCtxManager(Awaitable[_T], AsyncContextManager[_T]): """Context manager implementation of config values. This class allows mutable config values to be both "get" and "set" from @@ -46,7 +48,7 @@ class _ValueCtxManager: ) return self.raw_value - async def __aexit__(self, *exc_info): + async def __aexit__(self, exc_type, exc, tb): if self.raw_value != self.__original_value: await self.value_obj.set(self.raw_value) @@ -76,14 +78,14 @@ class Value: def identifiers(self): return tuple(str(i) for i in self._identifiers) - async def _get(self, default): + async def _get(self, default=...): try: ret = await self.driver.get(*self.identifiers) except KeyError: - return default if default is not None else self.default + return default if default is not ... else self.default return ret - def __call__(self, default=None): + def __call__(self, default=...) -> _ValueCtxManager[Any]: """Get the literal value of this data element. Each `Value` object is created by the `Group.__getattr__` method. The @@ -187,6 +189,11 @@ class Group(Value): def defaults(self): return deepcopy(self._defaults) + async def _get(self, default: Dict[str, Any] = ...) -> Dict[str, Any]: + default = default if default is not ... else self.defaults + raw = await super()._get(default) + return self.nested_update(raw, default) + # noinspection PyTypeChecker def __getattr__(self, item: str) -> Union["Group", Value]: """Get an attribute of this group. @@ -306,6 +313,11 @@ class Group(Value): data = {"foo": {"bar": "baz"}} d = data["foo"]["bar"] + Note + ---- + If retreiving a sub-group, the return value of this method will + include registered defaults for values which have not yet been set. + Parameters ---------- nested_path : str @@ -339,15 +351,22 @@ class Group(Value): default = poss_default try: - return await self.driver.get(*self.identifiers, *path) + raw = await self.driver.get(*self.identifiers, *path) except KeyError: if default is not ...: return default raise + else: + if isinstance(default, dict): + return self.nested_update(raw, default) + return raw - async def all(self) -> dict: + def all(self) -> _ValueCtxManager[Dict[str, Any]]: """Get a dictionary representation of this group's data. + The return value of this method can also be used as an asynchronous + context manager, i.e. with :code:`async with` syntax. + Note ---- The return value of this method will include registered defaults for @@ -359,16 +378,18 @@ class Group(Value): All of this Group's attributes, resolved as raw data values. """ - return self.nested_update(await self()) + return self() - def nested_update(self, current, defaults=None): + def nested_update( + self, current: collections.Mapping, defaults: Dict[str, Any] = ... + ) -> Dict[str, Any]: """Robust updater for nested dictionaries If no defaults are passed, then the instance attribute 'defaults' will be used. """ - if not defaults: + if defaults is ...: defaults = self.defaults for key, value in current.items(): @@ -844,7 +865,7 @@ class Config: """ return self._get_base_group(group_identifier, *identifiers) - async def _all_from_scope(self, scope: str): + async def _all_from_scope(self, scope: str) -> Dict[int, Dict[Any, Any]]: """Get a dict of all values from a particular scope of data. :code:`scope` must be one of the constants attributed to @@ -856,12 +877,18 @@ class Config: overwritten. """ group = self._get_base_group(scope) - dict_ = await group() ret = {} - for k, v in dict_.items(): - data = group.defaults - data.update(v) - ret[int(k)] = data + + try: + dict_ = await self.driver.get(*group.identifiers) + except KeyError: + pass + else: + for k, v in dict_.items(): + data = group.defaults + data.update(v) + ret[int(k)] = data + return ret async def all_guilds(self) -> dict: @@ -968,13 +995,21 @@ class Config: ret = {} if guild is None: group = self._get_base_group(self.MEMBER) - dict_ = await group() - for guild_id, guild_data in dict_.items(): - ret[int(guild_id)] = self._all_members_from_guild(group, guild_data) + try: + dict_ = await self.driver.get(*group.identifiers) + except KeyError: + pass + else: + for guild_id, guild_data in dict_.items(): + ret[int(guild_id)] = self._all_members_from_guild(group, guild_data) else: group = self._get_base_group(self.MEMBER, guild.id) - guild_data = await group() - ret = self._all_members_from_guild(group, guild_data) + try: + guild_data = await self.driver.get(*group.identifiers) + except KeyError: + pass + else: + ret = self._all_members_from_guild(group, guild_data) return ret async def _clear_scope(self, *scopes: str): diff --git a/tests/core/test_config.py b/tests/core/test_config.py index 1a4541647..c333201e0 100644 --- a/tests/core/test_config.py +++ b/tests/core/test_config.py @@ -430,3 +430,39 @@ async def test_set_then_mutate(config): list1.append("foo") list1 = await config.list1() assert "foo" not in list1 + + +@pytest.mark.asyncio +async def test_call_group_fills_defaults(config): + config.register_global(subgroup={"foo": True}) + subgroup = await config.subgroup() + assert "foo" in subgroup + + +@pytest.mark.asyncio +async def test_group_call_ctxmgr_writes(config): + config.register_global(subgroup={"foo": True}) + async with config.subgroup() as subgroup: + subgroup["bar"] = False + + subgroup = await config.subgroup() + assert subgroup == {"foo": True, "bar": False} + + +@pytest.mark.asyncio +async def test_all_works_as_ctxmgr(config): + config.register_global(subgroup={"foo": True}) + async with config.subgroup.all() as subgroup: + subgroup["bar"] = False + + subgroup = await config.subgroup() + assert subgroup == {"foo": True, "bar": False} + + +@pytest.mark.asyncio +async def test_get_raw_mixes_defaults(config): + config.register_global(subgroup={"foo": True}) + await config.subgroup.set_raw("bar", value=False) + + subgroup = await config.get_raw("subgroup") + assert subgroup == {"foo": True, "bar": False}