mirror of
https://github.com/Cog-Creators/Red-DiscordBot.git
synced 2025-11-06 03:08:55 -05:00
[Config] Group.__call__() has same behaviour as Group.all() (#2018)
* Make calling groups useful This makes config.Group.__call__ effectively an alias for Group.all(), with the added bonus of becoming a context manager. get_raw has been updated as well to reflect the new behaviour of __call__. * Fix unintended side-effects of new behaviour * Add tests * Add test for get_raw mixing in defaults * Another cleanup for relying on old behaviour internally * Fix bank relying on old behaviour * Reformat
This commit is contained in:
parent
48a7a21aca
commit
dbed24aaca
@ -400,19 +400,18 @@ async def get_account(member: Union[discord.Member, discord.User]) -> Account:
|
||||
|
||||
"""
|
||||
if await is_global():
|
||||
acc_data = (await _conf.user(member)()).copy()
|
||||
default = _DEFAULT_USER.copy()
|
||||
all_accounts = await _conf.all_users()
|
||||
else:
|
||||
acc_data = (await _conf.member(member)()).copy()
|
||||
default = _DEFAULT_MEMBER.copy()
|
||||
all_accounts = await _conf.all_members(member.guild)
|
||||
|
||||
if acc_data == {}:
|
||||
acc_data = default
|
||||
acc_data["name"] = member.display_name
|
||||
if member.id not in all_accounts:
|
||||
acc_data = {"name": member.display_name, "created_at": _DEFAULT_MEMBER["created_at"]}
|
||||
try:
|
||||
acc_data["balance"] = await get_default_balance(member.guild)
|
||||
except AttributeError:
|
||||
acc_data["balance"] = await get_default_balance()
|
||||
else:
|
||||
acc_data = all_accounts[member.id]
|
||||
|
||||
acc_data["created_at"] = _decode_time(acc_data["created_at"])
|
||||
return Account(**acc_data)
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import logging
|
||||
import collections
|
||||
from copy import deepcopy
|
||||
from typing import Union, Tuple, TYPE_CHECKING
|
||||
from typing import Any, Union, Tuple, Dict, Awaitable, AsyncContextManager, TypeVar, TYPE_CHECKING
|
||||
|
||||
import discord
|
||||
|
||||
@ -13,8 +13,10 @@ if TYPE_CHECKING:
|
||||
|
||||
log = logging.getLogger("red.config")
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
||||
class _ValueCtxManager:
|
||||
|
||||
class _ValueCtxManager(Awaitable[_T], AsyncContextManager[_T]):
|
||||
"""Context manager implementation of config values.
|
||||
|
||||
This class allows mutable config values to be both "get" and "set" from
|
||||
@ -46,7 +48,7 @@ class _ValueCtxManager:
|
||||
)
|
||||
return self.raw_value
|
||||
|
||||
async def __aexit__(self, *exc_info):
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
if self.raw_value != self.__original_value:
|
||||
await self.value_obj.set(self.raw_value)
|
||||
|
||||
@ -76,14 +78,14 @@ class Value:
|
||||
def identifiers(self):
|
||||
return tuple(str(i) for i in self._identifiers)
|
||||
|
||||
async def _get(self, default):
|
||||
async def _get(self, default=...):
|
||||
try:
|
||||
ret = await self.driver.get(*self.identifiers)
|
||||
except KeyError:
|
||||
return default if default is not None else self.default
|
||||
return default if default is not ... else self.default
|
||||
return ret
|
||||
|
||||
def __call__(self, default=None):
|
||||
def __call__(self, default=...) -> _ValueCtxManager[Any]:
|
||||
"""Get the literal value of this data element.
|
||||
|
||||
Each `Value` object is created by the `Group.__getattr__` method. The
|
||||
@ -187,6 +189,11 @@ class Group(Value):
|
||||
def defaults(self):
|
||||
return deepcopy(self._defaults)
|
||||
|
||||
async def _get(self, default: Dict[str, Any] = ...) -> Dict[str, Any]:
|
||||
default = default if default is not ... else self.defaults
|
||||
raw = await super()._get(default)
|
||||
return self.nested_update(raw, default)
|
||||
|
||||
# noinspection PyTypeChecker
|
||||
def __getattr__(self, item: str) -> Union["Group", Value]:
|
||||
"""Get an attribute of this group.
|
||||
@ -306,6 +313,11 @@ class Group(Value):
|
||||
data = {"foo": {"bar": "baz"}}
|
||||
d = data["foo"]["bar"]
|
||||
|
||||
Note
|
||||
----
|
||||
If retreiving a sub-group, the return value of this method will
|
||||
include registered defaults for values which have not yet been set.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
nested_path : str
|
||||
@ -339,15 +351,22 @@ class Group(Value):
|
||||
default = poss_default
|
||||
|
||||
try:
|
||||
return await self.driver.get(*self.identifiers, *path)
|
||||
raw = await self.driver.get(*self.identifiers, *path)
|
||||
except KeyError:
|
||||
if default is not ...:
|
||||
return default
|
||||
raise
|
||||
else:
|
||||
if isinstance(default, dict):
|
||||
return self.nested_update(raw, default)
|
||||
return raw
|
||||
|
||||
async def all(self) -> dict:
|
||||
def all(self) -> _ValueCtxManager[Dict[str, Any]]:
|
||||
"""Get a dictionary representation of this group's data.
|
||||
|
||||
The return value of this method can also be used as an asynchronous
|
||||
context manager, i.e. with :code:`async with` syntax.
|
||||
|
||||
Note
|
||||
----
|
||||
The return value of this method will include registered defaults for
|
||||
@ -359,16 +378,18 @@ class Group(Value):
|
||||
All of this Group's attributes, resolved as raw data values.
|
||||
|
||||
"""
|
||||
return self.nested_update(await self())
|
||||
return self()
|
||||
|
||||
def nested_update(self, current, defaults=None):
|
||||
def nested_update(
|
||||
self, current: collections.Mapping, defaults: Dict[str, Any] = ...
|
||||
) -> Dict[str, Any]:
|
||||
"""Robust updater for nested dictionaries
|
||||
|
||||
If no defaults are passed, then the instance attribute 'defaults'
|
||||
will be used.
|
||||
|
||||
"""
|
||||
if not defaults:
|
||||
if defaults is ...:
|
||||
defaults = self.defaults
|
||||
|
||||
for key, value in current.items():
|
||||
@ -844,7 +865,7 @@ class Config:
|
||||
"""
|
||||
return self._get_base_group(group_identifier, *identifiers)
|
||||
|
||||
async def _all_from_scope(self, scope: str):
|
||||
async def _all_from_scope(self, scope: str) -> Dict[int, Dict[Any, Any]]:
|
||||
"""Get a dict of all values from a particular scope of data.
|
||||
|
||||
:code:`scope` must be one of the constants attributed to
|
||||
@ -856,12 +877,18 @@ class Config:
|
||||
overwritten.
|
||||
"""
|
||||
group = self._get_base_group(scope)
|
||||
dict_ = await group()
|
||||
ret = {}
|
||||
for k, v in dict_.items():
|
||||
data = group.defaults
|
||||
data.update(v)
|
||||
ret[int(k)] = data
|
||||
|
||||
try:
|
||||
dict_ = await self.driver.get(*group.identifiers)
|
||||
except KeyError:
|
||||
pass
|
||||
else:
|
||||
for k, v in dict_.items():
|
||||
data = group.defaults
|
||||
data.update(v)
|
||||
ret[int(k)] = data
|
||||
|
||||
return ret
|
||||
|
||||
async def all_guilds(self) -> dict:
|
||||
@ -968,13 +995,21 @@ class Config:
|
||||
ret = {}
|
||||
if guild is None:
|
||||
group = self._get_base_group(self.MEMBER)
|
||||
dict_ = await group()
|
||||
for guild_id, guild_data in dict_.items():
|
||||
ret[int(guild_id)] = self._all_members_from_guild(group, guild_data)
|
||||
try:
|
||||
dict_ = await self.driver.get(*group.identifiers)
|
||||
except KeyError:
|
||||
pass
|
||||
else:
|
||||
for guild_id, guild_data in dict_.items():
|
||||
ret[int(guild_id)] = self._all_members_from_guild(group, guild_data)
|
||||
else:
|
||||
group = self._get_base_group(self.MEMBER, guild.id)
|
||||
guild_data = await group()
|
||||
ret = self._all_members_from_guild(group, guild_data)
|
||||
try:
|
||||
guild_data = await self.driver.get(*group.identifiers)
|
||||
except KeyError:
|
||||
pass
|
||||
else:
|
||||
ret = self._all_members_from_guild(group, guild_data)
|
||||
return ret
|
||||
|
||||
async def _clear_scope(self, *scopes: str):
|
||||
|
||||
@ -430,3 +430,39 @@ async def test_set_then_mutate(config):
|
||||
list1.append("foo")
|
||||
list1 = await config.list1()
|
||||
assert "foo" not in list1
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_call_group_fills_defaults(config):
|
||||
config.register_global(subgroup={"foo": True})
|
||||
subgroup = await config.subgroup()
|
||||
assert "foo" in subgroup
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_group_call_ctxmgr_writes(config):
|
||||
config.register_global(subgroup={"foo": True})
|
||||
async with config.subgroup() as subgroup:
|
||||
subgroup["bar"] = False
|
||||
|
||||
subgroup = await config.subgroup()
|
||||
assert subgroup == {"foo": True, "bar": False}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_all_works_as_ctxmgr(config):
|
||||
config.register_global(subgroup={"foo": True})
|
||||
async with config.subgroup.all() as subgroup:
|
||||
subgroup["bar"] = False
|
||||
|
||||
subgroup = await config.subgroup()
|
||||
assert subgroup == {"foo": True, "bar": False}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_raw_mixes_defaults(config):
|
||||
config.register_global(subgroup={"foo": True})
|
||||
await config.subgroup.set_raw("bar", value=False)
|
||||
|
||||
subgroup = await config.get_raw("subgroup")
|
||||
assert subgroup == {"foo": True, "bar": False}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user