[Config] Group.__call__() has same behaviour as Group.all() (#2018)

* Make calling groups useful

This makes config.Group.__call__ effectively an alias for Group.all(),
with the added bonus of becoming a context manager.

get_raw has been updated as well to reflect the new behaviour of
__call__.

* Fix unintended side-effects of new behaviour

* Add tests

* Add test for get_raw mixing in defaults

* Another cleanup for relying on old behaviour internally

* Fix bank relying on old behaviour

* Reformat
This commit is contained in:
Toby Harradine 2018-08-26 23:30:36 +10:00 committed by GitHub
parent 48a7a21aca
commit dbed24aaca
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 99 additions and 29 deletions

View File

@ -400,19 +400,18 @@ async def get_account(member: Union[discord.Member, discord.User]) -> Account:
"""
if await is_global():
acc_data = (await _conf.user(member)()).copy()
default = _DEFAULT_USER.copy()
all_accounts = await _conf.all_users()
else:
acc_data = (await _conf.member(member)()).copy()
default = _DEFAULT_MEMBER.copy()
all_accounts = await _conf.all_members(member.guild)
if acc_data == {}:
acc_data = default
acc_data["name"] = member.display_name
if member.id not in all_accounts:
acc_data = {"name": member.display_name, "created_at": _DEFAULT_MEMBER["created_at"]}
try:
acc_data["balance"] = await get_default_balance(member.guild)
except AttributeError:
acc_data["balance"] = await get_default_balance()
else:
acc_data = all_accounts[member.id]
acc_data["created_at"] = _decode_time(acc_data["created_at"])
return Account(**acc_data)

View File

@ -1,7 +1,7 @@
import logging
import collections
from copy import deepcopy
from typing import Union, Tuple, TYPE_CHECKING
from typing import Any, Union, Tuple, Dict, Awaitable, AsyncContextManager, TypeVar, TYPE_CHECKING
import discord
@ -13,8 +13,10 @@ if TYPE_CHECKING:
log = logging.getLogger("red.config")
_T = TypeVar("_T")
class _ValueCtxManager:
class _ValueCtxManager(Awaitable[_T], AsyncContextManager[_T]):
"""Context manager implementation of config values.
This class allows mutable config values to be both "get" and "set" from
@ -46,7 +48,7 @@ class _ValueCtxManager:
)
return self.raw_value
async def __aexit__(self, *exc_info):
async def __aexit__(self, exc_type, exc, tb):
if self.raw_value != self.__original_value:
await self.value_obj.set(self.raw_value)
@ -76,14 +78,14 @@ class Value:
def identifiers(self):
return tuple(str(i) for i in self._identifiers)
async def _get(self, default):
async def _get(self, default=...):
try:
ret = await self.driver.get(*self.identifiers)
except KeyError:
return default if default is not None else self.default
return default if default is not ... else self.default
return ret
def __call__(self, default=None):
def __call__(self, default=...) -> _ValueCtxManager[Any]:
"""Get the literal value of this data element.
Each `Value` object is created by the `Group.__getattr__` method. The
@ -187,6 +189,11 @@ class Group(Value):
def defaults(self):
return deepcopy(self._defaults)
async def _get(self, default: Dict[str, Any] = ...) -> Dict[str, Any]:
default = default if default is not ... else self.defaults
raw = await super()._get(default)
return self.nested_update(raw, default)
# noinspection PyTypeChecker
def __getattr__(self, item: str) -> Union["Group", Value]:
"""Get an attribute of this group.
@ -306,6 +313,11 @@ class Group(Value):
data = {"foo": {"bar": "baz"}}
d = data["foo"]["bar"]
Note
----
If retreiving a sub-group, the return value of this method will
include registered defaults for values which have not yet been set.
Parameters
----------
nested_path : str
@ -339,15 +351,22 @@ class Group(Value):
default = poss_default
try:
return await self.driver.get(*self.identifiers, *path)
raw = await self.driver.get(*self.identifiers, *path)
except KeyError:
if default is not ...:
return default
raise
else:
if isinstance(default, dict):
return self.nested_update(raw, default)
return raw
async def all(self) -> dict:
def all(self) -> _ValueCtxManager[Dict[str, Any]]:
"""Get a dictionary representation of this group's data.
The return value of this method can also be used as an asynchronous
context manager, i.e. with :code:`async with` syntax.
Note
----
The return value of this method will include registered defaults for
@ -359,16 +378,18 @@ class Group(Value):
All of this Group's attributes, resolved as raw data values.
"""
return self.nested_update(await self())
return self()
def nested_update(self, current, defaults=None):
def nested_update(
self, current: collections.Mapping, defaults: Dict[str, Any] = ...
) -> Dict[str, Any]:
"""Robust updater for nested dictionaries
If no defaults are passed, then the instance attribute 'defaults'
will be used.
"""
if not defaults:
if defaults is ...:
defaults = self.defaults
for key, value in current.items():
@ -844,7 +865,7 @@ class Config:
"""
return self._get_base_group(group_identifier, *identifiers)
async def _all_from_scope(self, scope: str):
async def _all_from_scope(self, scope: str) -> Dict[int, Dict[Any, Any]]:
"""Get a dict of all values from a particular scope of data.
:code:`scope` must be one of the constants attributed to
@ -856,12 +877,18 @@ class Config:
overwritten.
"""
group = self._get_base_group(scope)
dict_ = await group()
ret = {}
try:
dict_ = await self.driver.get(*group.identifiers)
except KeyError:
pass
else:
for k, v in dict_.items():
data = group.defaults
data.update(v)
ret[int(k)] = data
return ret
async def all_guilds(self) -> dict:
@ -968,12 +995,20 @@ class Config:
ret = {}
if guild is None:
group = self._get_base_group(self.MEMBER)
dict_ = await group()
try:
dict_ = await self.driver.get(*group.identifiers)
except KeyError:
pass
else:
for guild_id, guild_data in dict_.items():
ret[int(guild_id)] = self._all_members_from_guild(group, guild_data)
else:
group = self._get_base_group(self.MEMBER, guild.id)
guild_data = await group()
try:
guild_data = await self.driver.get(*group.identifiers)
except KeyError:
pass
else:
ret = self._all_members_from_guild(group, guild_data)
return ret

View File

@ -430,3 +430,39 @@ async def test_set_then_mutate(config):
list1.append("foo")
list1 = await config.list1()
assert "foo" not in list1
@pytest.mark.asyncio
async def test_call_group_fills_defaults(config):
config.register_global(subgroup={"foo": True})
subgroup = await config.subgroup()
assert "foo" in subgroup
@pytest.mark.asyncio
async def test_group_call_ctxmgr_writes(config):
config.register_global(subgroup={"foo": True})
async with config.subgroup() as subgroup:
subgroup["bar"] = False
subgroup = await config.subgroup()
assert subgroup == {"foo": True, "bar": False}
@pytest.mark.asyncio
async def test_all_works_as_ctxmgr(config):
config.register_global(subgroup={"foo": True})
async with config.subgroup.all() as subgroup:
subgroup["bar"] = False
subgroup = await config.subgroup()
assert subgroup == {"foo": True, "bar": False}
@pytest.mark.asyncio
async def test_get_raw_mixes_defaults(config):
config.register_global(subgroup={"foo": True})
await config.subgroup.set_raw("bar", value=False)
subgroup = await config.get_raw("subgroup")
assert subgroup == {"foo": True, "bar": False}