[Config] Group.__call__() has same behaviour as Group.all() (#2018)

* Make calling groups useful

This makes config.Group.__call__ effectively an alias for Group.all(),
with the added bonus of becoming a context manager.

get_raw has been updated as well to reflect the new behaviour of
__call__.

* Fix unintended side-effects of new behaviour

* Add tests

* Add test for get_raw mixing in defaults

* Another cleanup for relying on old behaviour internally

* Fix bank relying on old behaviour

* Reformat
This commit is contained in:
Toby Harradine 2018-08-26 23:30:36 +10:00 committed by GitHub
parent 48a7a21aca
commit dbed24aaca
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 99 additions and 29 deletions

View File

@ -400,19 +400,18 @@ async def get_account(member: Union[discord.Member, discord.User]) -> Account:
""" """
if await is_global(): if await is_global():
acc_data = (await _conf.user(member)()).copy() all_accounts = await _conf.all_users()
default = _DEFAULT_USER.copy()
else: else:
acc_data = (await _conf.member(member)()).copy() all_accounts = await _conf.all_members(member.guild)
default = _DEFAULT_MEMBER.copy()
if acc_data == {}: if member.id not in all_accounts:
acc_data = default acc_data = {"name": member.display_name, "created_at": _DEFAULT_MEMBER["created_at"]}
acc_data["name"] = member.display_name
try: try:
acc_data["balance"] = await get_default_balance(member.guild) acc_data["balance"] = await get_default_balance(member.guild)
except AttributeError: except AttributeError:
acc_data["balance"] = await get_default_balance() acc_data["balance"] = await get_default_balance()
else:
acc_data = all_accounts[member.id]
acc_data["created_at"] = _decode_time(acc_data["created_at"]) acc_data["created_at"] = _decode_time(acc_data["created_at"])
return Account(**acc_data) return Account(**acc_data)

View File

@ -1,7 +1,7 @@
import logging import logging
import collections import collections
from copy import deepcopy from copy import deepcopy
from typing import Union, Tuple, TYPE_CHECKING from typing import Any, Union, Tuple, Dict, Awaitable, AsyncContextManager, TypeVar, TYPE_CHECKING
import discord import discord
@ -13,8 +13,10 @@ if TYPE_CHECKING:
log = logging.getLogger("red.config") log = logging.getLogger("red.config")
_T = TypeVar("_T")
class _ValueCtxManager:
class _ValueCtxManager(Awaitable[_T], AsyncContextManager[_T]):
"""Context manager implementation of config values. """Context manager implementation of config values.
This class allows mutable config values to be both "get" and "set" from This class allows mutable config values to be both "get" and "set" from
@ -46,7 +48,7 @@ class _ValueCtxManager:
) )
return self.raw_value return self.raw_value
async def __aexit__(self, *exc_info): async def __aexit__(self, exc_type, exc, tb):
if self.raw_value != self.__original_value: if self.raw_value != self.__original_value:
await self.value_obj.set(self.raw_value) await self.value_obj.set(self.raw_value)
@ -76,14 +78,14 @@ class Value:
def identifiers(self): def identifiers(self):
return tuple(str(i) for i in self._identifiers) return tuple(str(i) for i in self._identifiers)
async def _get(self, default): async def _get(self, default=...):
try: try:
ret = await self.driver.get(*self.identifiers) ret = await self.driver.get(*self.identifiers)
except KeyError: except KeyError:
return default if default is not None else self.default return default if default is not ... else self.default
return ret return ret
def __call__(self, default=None): def __call__(self, default=...) -> _ValueCtxManager[Any]:
"""Get the literal value of this data element. """Get the literal value of this data element.
Each `Value` object is created by the `Group.__getattr__` method. The Each `Value` object is created by the `Group.__getattr__` method. The
@ -187,6 +189,11 @@ class Group(Value):
def defaults(self): def defaults(self):
return deepcopy(self._defaults) return deepcopy(self._defaults)
async def _get(self, default: Dict[str, Any] = ...) -> Dict[str, Any]:
default = default if default is not ... else self.defaults
raw = await super()._get(default)
return self.nested_update(raw, default)
# noinspection PyTypeChecker # noinspection PyTypeChecker
def __getattr__(self, item: str) -> Union["Group", Value]: def __getattr__(self, item: str) -> Union["Group", Value]:
"""Get an attribute of this group. """Get an attribute of this group.
@ -306,6 +313,11 @@ class Group(Value):
data = {"foo": {"bar": "baz"}} data = {"foo": {"bar": "baz"}}
d = data["foo"]["bar"] d = data["foo"]["bar"]
Note
----
If retreiving a sub-group, the return value of this method will
include registered defaults for values which have not yet been set.
Parameters Parameters
---------- ----------
nested_path : str nested_path : str
@ -339,15 +351,22 @@ class Group(Value):
default = poss_default default = poss_default
try: try:
return await self.driver.get(*self.identifiers, *path) raw = await self.driver.get(*self.identifiers, *path)
except KeyError: except KeyError:
if default is not ...: if default is not ...:
return default return default
raise raise
else:
if isinstance(default, dict):
return self.nested_update(raw, default)
return raw
async def all(self) -> dict: def all(self) -> _ValueCtxManager[Dict[str, Any]]:
"""Get a dictionary representation of this group's data. """Get a dictionary representation of this group's data.
The return value of this method can also be used as an asynchronous
context manager, i.e. with :code:`async with` syntax.
Note Note
---- ----
The return value of this method will include registered defaults for The return value of this method will include registered defaults for
@ -359,16 +378,18 @@ class Group(Value):
All of this Group's attributes, resolved as raw data values. All of this Group's attributes, resolved as raw data values.
""" """
return self.nested_update(await self()) return self()
def nested_update(self, current, defaults=None): def nested_update(
self, current: collections.Mapping, defaults: Dict[str, Any] = ...
) -> Dict[str, Any]:
"""Robust updater for nested dictionaries """Robust updater for nested dictionaries
If no defaults are passed, then the instance attribute 'defaults' If no defaults are passed, then the instance attribute 'defaults'
will be used. will be used.
""" """
if not defaults: if defaults is ...:
defaults = self.defaults defaults = self.defaults
for key, value in current.items(): for key, value in current.items():
@ -844,7 +865,7 @@ class Config:
""" """
return self._get_base_group(group_identifier, *identifiers) return self._get_base_group(group_identifier, *identifiers)
async def _all_from_scope(self, scope: str): async def _all_from_scope(self, scope: str) -> Dict[int, Dict[Any, Any]]:
"""Get a dict of all values from a particular scope of data. """Get a dict of all values from a particular scope of data.
:code:`scope` must be one of the constants attributed to :code:`scope` must be one of the constants attributed to
@ -856,12 +877,18 @@ class Config:
overwritten. overwritten.
""" """
group = self._get_base_group(scope) group = self._get_base_group(scope)
dict_ = await group()
ret = {} ret = {}
for k, v in dict_.items():
data = group.defaults try:
data.update(v) dict_ = await self.driver.get(*group.identifiers)
ret[int(k)] = data except KeyError:
pass
else:
for k, v in dict_.items():
data = group.defaults
data.update(v)
ret[int(k)] = data
return ret return ret
async def all_guilds(self) -> dict: async def all_guilds(self) -> dict:
@ -968,13 +995,21 @@ class Config:
ret = {} ret = {}
if guild is None: if guild is None:
group = self._get_base_group(self.MEMBER) group = self._get_base_group(self.MEMBER)
dict_ = await group() try:
for guild_id, guild_data in dict_.items(): dict_ = await self.driver.get(*group.identifiers)
ret[int(guild_id)] = self._all_members_from_guild(group, guild_data) except KeyError:
pass
else:
for guild_id, guild_data in dict_.items():
ret[int(guild_id)] = self._all_members_from_guild(group, guild_data)
else: else:
group = self._get_base_group(self.MEMBER, guild.id) group = self._get_base_group(self.MEMBER, guild.id)
guild_data = await group() try:
ret = self._all_members_from_guild(group, guild_data) guild_data = await self.driver.get(*group.identifiers)
except KeyError:
pass
else:
ret = self._all_members_from_guild(group, guild_data)
return ret return ret
async def _clear_scope(self, *scopes: str): async def _clear_scope(self, *scopes: str):

View File

@ -430,3 +430,39 @@ async def test_set_then_mutate(config):
list1.append("foo") list1.append("foo")
list1 = await config.list1() list1 = await config.list1()
assert "foo" not in list1 assert "foo" not in list1
@pytest.mark.asyncio
async def test_call_group_fills_defaults(config):
config.register_global(subgroup={"foo": True})
subgroup = await config.subgroup()
assert "foo" in subgroup
@pytest.mark.asyncio
async def test_group_call_ctxmgr_writes(config):
config.register_global(subgroup={"foo": True})
async with config.subgroup() as subgroup:
subgroup["bar"] = False
subgroup = await config.subgroup()
assert subgroup == {"foo": True, "bar": False}
@pytest.mark.asyncio
async def test_all_works_as_ctxmgr(config):
config.register_global(subgroup={"foo": True})
async with config.subgroup.all() as subgroup:
subgroup["bar"] = False
subgroup = await config.subgroup()
assert subgroup == {"foo": True, "bar": False}
@pytest.mark.asyncio
async def test_get_raw_mixes_defaults(config):
config.register_global(subgroup={"foo": True})
await config.subgroup.set_raw("bar", value=False)
subgroup = await config.get_raw("subgroup")
assert subgroup == {"foo": True, "bar": False}