mirror of
https://github.com/Cog-Creators/Red-DiscordBot.git
synced 2025-11-06 03:08:55 -05:00
[Config] Group.__call__() has same behaviour as Group.all() (#2018)
* Make calling groups useful This makes config.Group.__call__ effectively an alias for Group.all(), with the added bonus of becoming a context manager. get_raw has been updated as well to reflect the new behaviour of __call__. * Fix unintended side-effects of new behaviour * Add tests * Add test for get_raw mixing in defaults * Another cleanup for relying on old behaviour internally * Fix bank relying on old behaviour * Reformat
This commit is contained in:
parent
48a7a21aca
commit
dbed24aaca
@ -400,19 +400,18 @@ async def get_account(member: Union[discord.Member, discord.User]) -> Account:
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
if await is_global():
|
if await is_global():
|
||||||
acc_data = (await _conf.user(member)()).copy()
|
all_accounts = await _conf.all_users()
|
||||||
default = _DEFAULT_USER.copy()
|
|
||||||
else:
|
else:
|
||||||
acc_data = (await _conf.member(member)()).copy()
|
all_accounts = await _conf.all_members(member.guild)
|
||||||
default = _DEFAULT_MEMBER.copy()
|
|
||||||
|
|
||||||
if acc_data == {}:
|
if member.id not in all_accounts:
|
||||||
acc_data = default
|
acc_data = {"name": member.display_name, "created_at": _DEFAULT_MEMBER["created_at"]}
|
||||||
acc_data["name"] = member.display_name
|
|
||||||
try:
|
try:
|
||||||
acc_data["balance"] = await get_default_balance(member.guild)
|
acc_data["balance"] = await get_default_balance(member.guild)
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
acc_data["balance"] = await get_default_balance()
|
acc_data["balance"] = await get_default_balance()
|
||||||
|
else:
|
||||||
|
acc_data = all_accounts[member.id]
|
||||||
|
|
||||||
acc_data["created_at"] = _decode_time(acc_data["created_at"])
|
acc_data["created_at"] = _decode_time(acc_data["created_at"])
|
||||||
return Account(**acc_data)
|
return Account(**acc_data)
|
||||||
|
|||||||
@ -1,7 +1,7 @@
|
|||||||
import logging
|
import logging
|
||||||
import collections
|
import collections
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from typing import Union, Tuple, TYPE_CHECKING
|
from typing import Any, Union, Tuple, Dict, Awaitable, AsyncContextManager, TypeVar, TYPE_CHECKING
|
||||||
|
|
||||||
import discord
|
import discord
|
||||||
|
|
||||||
@ -13,8 +13,10 @@ if TYPE_CHECKING:
|
|||||||
|
|
||||||
log = logging.getLogger("red.config")
|
log = logging.getLogger("red.config")
|
||||||
|
|
||||||
|
_T = TypeVar("_T")
|
||||||
|
|
||||||
class _ValueCtxManager:
|
|
||||||
|
class _ValueCtxManager(Awaitable[_T], AsyncContextManager[_T]):
|
||||||
"""Context manager implementation of config values.
|
"""Context manager implementation of config values.
|
||||||
|
|
||||||
This class allows mutable config values to be both "get" and "set" from
|
This class allows mutable config values to be both "get" and "set" from
|
||||||
@ -46,7 +48,7 @@ class _ValueCtxManager:
|
|||||||
)
|
)
|
||||||
return self.raw_value
|
return self.raw_value
|
||||||
|
|
||||||
async def __aexit__(self, *exc_info):
|
async def __aexit__(self, exc_type, exc, tb):
|
||||||
if self.raw_value != self.__original_value:
|
if self.raw_value != self.__original_value:
|
||||||
await self.value_obj.set(self.raw_value)
|
await self.value_obj.set(self.raw_value)
|
||||||
|
|
||||||
@ -76,14 +78,14 @@ class Value:
|
|||||||
def identifiers(self):
|
def identifiers(self):
|
||||||
return tuple(str(i) for i in self._identifiers)
|
return tuple(str(i) for i in self._identifiers)
|
||||||
|
|
||||||
async def _get(self, default):
|
async def _get(self, default=...):
|
||||||
try:
|
try:
|
||||||
ret = await self.driver.get(*self.identifiers)
|
ret = await self.driver.get(*self.identifiers)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
return default if default is not None else self.default
|
return default if default is not ... else self.default
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
def __call__(self, default=None):
|
def __call__(self, default=...) -> _ValueCtxManager[Any]:
|
||||||
"""Get the literal value of this data element.
|
"""Get the literal value of this data element.
|
||||||
|
|
||||||
Each `Value` object is created by the `Group.__getattr__` method. The
|
Each `Value` object is created by the `Group.__getattr__` method. The
|
||||||
@ -187,6 +189,11 @@ class Group(Value):
|
|||||||
def defaults(self):
|
def defaults(self):
|
||||||
return deepcopy(self._defaults)
|
return deepcopy(self._defaults)
|
||||||
|
|
||||||
|
async def _get(self, default: Dict[str, Any] = ...) -> Dict[str, Any]:
|
||||||
|
default = default if default is not ... else self.defaults
|
||||||
|
raw = await super()._get(default)
|
||||||
|
return self.nested_update(raw, default)
|
||||||
|
|
||||||
# noinspection PyTypeChecker
|
# noinspection PyTypeChecker
|
||||||
def __getattr__(self, item: str) -> Union["Group", Value]:
|
def __getattr__(self, item: str) -> Union["Group", Value]:
|
||||||
"""Get an attribute of this group.
|
"""Get an attribute of this group.
|
||||||
@ -306,6 +313,11 @@ class Group(Value):
|
|||||||
data = {"foo": {"bar": "baz"}}
|
data = {"foo": {"bar": "baz"}}
|
||||||
d = data["foo"]["bar"]
|
d = data["foo"]["bar"]
|
||||||
|
|
||||||
|
Note
|
||||||
|
----
|
||||||
|
If retreiving a sub-group, the return value of this method will
|
||||||
|
include registered defaults for values which have not yet been set.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
nested_path : str
|
nested_path : str
|
||||||
@ -339,15 +351,22 @@ class Group(Value):
|
|||||||
default = poss_default
|
default = poss_default
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return await self.driver.get(*self.identifiers, *path)
|
raw = await self.driver.get(*self.identifiers, *path)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
if default is not ...:
|
if default is not ...:
|
||||||
return default
|
return default
|
||||||
raise
|
raise
|
||||||
|
else:
|
||||||
|
if isinstance(default, dict):
|
||||||
|
return self.nested_update(raw, default)
|
||||||
|
return raw
|
||||||
|
|
||||||
async def all(self) -> dict:
|
def all(self) -> _ValueCtxManager[Dict[str, Any]]:
|
||||||
"""Get a dictionary representation of this group's data.
|
"""Get a dictionary representation of this group's data.
|
||||||
|
|
||||||
|
The return value of this method can also be used as an asynchronous
|
||||||
|
context manager, i.e. with :code:`async with` syntax.
|
||||||
|
|
||||||
Note
|
Note
|
||||||
----
|
----
|
||||||
The return value of this method will include registered defaults for
|
The return value of this method will include registered defaults for
|
||||||
@ -359,16 +378,18 @@ class Group(Value):
|
|||||||
All of this Group's attributes, resolved as raw data values.
|
All of this Group's attributes, resolved as raw data values.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
return self.nested_update(await self())
|
return self()
|
||||||
|
|
||||||
def nested_update(self, current, defaults=None):
|
def nested_update(
|
||||||
|
self, current: collections.Mapping, defaults: Dict[str, Any] = ...
|
||||||
|
) -> Dict[str, Any]:
|
||||||
"""Robust updater for nested dictionaries
|
"""Robust updater for nested dictionaries
|
||||||
|
|
||||||
If no defaults are passed, then the instance attribute 'defaults'
|
If no defaults are passed, then the instance attribute 'defaults'
|
||||||
will be used.
|
will be used.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if not defaults:
|
if defaults is ...:
|
||||||
defaults = self.defaults
|
defaults = self.defaults
|
||||||
|
|
||||||
for key, value in current.items():
|
for key, value in current.items():
|
||||||
@ -844,7 +865,7 @@ class Config:
|
|||||||
"""
|
"""
|
||||||
return self._get_base_group(group_identifier, *identifiers)
|
return self._get_base_group(group_identifier, *identifiers)
|
||||||
|
|
||||||
async def _all_from_scope(self, scope: str):
|
async def _all_from_scope(self, scope: str) -> Dict[int, Dict[Any, Any]]:
|
||||||
"""Get a dict of all values from a particular scope of data.
|
"""Get a dict of all values from a particular scope of data.
|
||||||
|
|
||||||
:code:`scope` must be one of the constants attributed to
|
:code:`scope` must be one of the constants attributed to
|
||||||
@ -856,12 +877,18 @@ class Config:
|
|||||||
overwritten.
|
overwritten.
|
||||||
"""
|
"""
|
||||||
group = self._get_base_group(scope)
|
group = self._get_base_group(scope)
|
||||||
dict_ = await group()
|
|
||||||
ret = {}
|
ret = {}
|
||||||
for k, v in dict_.items():
|
|
||||||
data = group.defaults
|
try:
|
||||||
data.update(v)
|
dict_ = await self.driver.get(*group.identifiers)
|
||||||
ret[int(k)] = data
|
except KeyError:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
for k, v in dict_.items():
|
||||||
|
data = group.defaults
|
||||||
|
data.update(v)
|
||||||
|
ret[int(k)] = data
|
||||||
|
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
async def all_guilds(self) -> dict:
|
async def all_guilds(self) -> dict:
|
||||||
@ -968,13 +995,21 @@ class Config:
|
|||||||
ret = {}
|
ret = {}
|
||||||
if guild is None:
|
if guild is None:
|
||||||
group = self._get_base_group(self.MEMBER)
|
group = self._get_base_group(self.MEMBER)
|
||||||
dict_ = await group()
|
try:
|
||||||
for guild_id, guild_data in dict_.items():
|
dict_ = await self.driver.get(*group.identifiers)
|
||||||
ret[int(guild_id)] = self._all_members_from_guild(group, guild_data)
|
except KeyError:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
for guild_id, guild_data in dict_.items():
|
||||||
|
ret[int(guild_id)] = self._all_members_from_guild(group, guild_data)
|
||||||
else:
|
else:
|
||||||
group = self._get_base_group(self.MEMBER, guild.id)
|
group = self._get_base_group(self.MEMBER, guild.id)
|
||||||
guild_data = await group()
|
try:
|
||||||
ret = self._all_members_from_guild(group, guild_data)
|
guild_data = await self.driver.get(*group.identifiers)
|
||||||
|
except KeyError:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
ret = self._all_members_from_guild(group, guild_data)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
async def _clear_scope(self, *scopes: str):
|
async def _clear_scope(self, *scopes: str):
|
||||||
|
|||||||
@ -430,3 +430,39 @@ async def test_set_then_mutate(config):
|
|||||||
list1.append("foo")
|
list1.append("foo")
|
||||||
list1 = await config.list1()
|
list1 = await config.list1()
|
||||||
assert "foo" not in list1
|
assert "foo" not in list1
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_call_group_fills_defaults(config):
|
||||||
|
config.register_global(subgroup={"foo": True})
|
||||||
|
subgroup = await config.subgroup()
|
||||||
|
assert "foo" in subgroup
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_group_call_ctxmgr_writes(config):
|
||||||
|
config.register_global(subgroup={"foo": True})
|
||||||
|
async with config.subgroup() as subgroup:
|
||||||
|
subgroup["bar"] = False
|
||||||
|
|
||||||
|
subgroup = await config.subgroup()
|
||||||
|
assert subgroup == {"foo": True, "bar": False}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_all_works_as_ctxmgr(config):
|
||||||
|
config.register_global(subgroup={"foo": True})
|
||||||
|
async with config.subgroup.all() as subgroup:
|
||||||
|
subgroup["bar"] = False
|
||||||
|
|
||||||
|
subgroup = await config.subgroup()
|
||||||
|
assert subgroup == {"foo": True, "bar": False}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_raw_mixes_defaults(config):
|
||||||
|
config.register_global(subgroup={"foo": True})
|
||||||
|
await config.subgroup.set_raw("bar", value=False)
|
||||||
|
|
||||||
|
subgroup = await config.get_raw("subgroup")
|
||||||
|
assert subgroup == {"foo": True, "bar": False}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user