[Config] Retrieve/save values with async context manager (#1131)

* [Config] Retrieve/save values with async context manager

* Add a little docstring

* Documentation

* Implement async with syntax in existing modules
This commit is contained in:
Tobotimus 2017-12-04 14:07:34 +11:00 committed by palmtree5
parent 9b1018fa96
commit 9dbf56f942
8 changed files with 165 additions and 75 deletions

View File

@ -29,7 +29,7 @@ Basic Usage
@commands.command()
async def return_some_data(self, ctx):
await ctx.send(config.foo())
await ctx.send(await config.foo())
********
Tutorial
@ -117,11 +117,36 @@ Notice a few things in the above examples:
3. If you're getting the value, the syntax is::
self.config.<insert thing here, or nothing if global>.variable_name()
self.config.<insert scope here, or nothing if global>.variable_name()
4. If setting, it's::
self.config.<insert thing here, or nothing if global>.variable_name.set(new_value)
self.config.<insert scope here, or nothing if global>.variable_name.set(new_value)
It is also possible to use :code:`async with` syntax to get and set config
values. When entering the statement, the config value is retreived, and on exit,
it is saved. This puts a safeguard on any code within the :code:`async with`
block such that if it breaks from the block in any way (whether it be from
:code:`return`, :code:`break`, :code:`continue` or an exception), the value will
still be saved.
.. important::
Only mutable config values can be used in the :code:`async with` statement
(namely lists or dicts), and they must be modified *in place* for their
changes to be saved.
Here is an example of the :code:`async with` syntax:
.. code-block:: python
@commands.command()
async def addblah(self, ctx, new_blah):
guild_group = self.config.guild(ctx.guild)
async with guild_group.blah() as blah:
blah.append(new_blah)
await ctx.send("The new blah value has been added!")
.. important::

View File

@ -328,10 +328,9 @@ class Admin:
"""
Add a role to the list of available selfroles.
"""
curr_selfroles = await self.conf.guild(ctx.guild).selfroles()
if role.id not in curr_selfroles:
curr_selfroles.append(role.id)
await self.conf.guild(ctx.guild).selfroles.set(curr_selfroles)
async with self.conf.guild(ctx.guild).selfroles() as curr_selfroles:
if role.id not in curr_selfroles:
curr_selfroles.append(role.id)
await ctx.send("The selfroles list has been successfully modified.")
@ -341,9 +340,8 @@ class Admin:
"""
Removes a role from the list of available selfroles.
"""
curr_selfroles = await self.conf.guild(ctx.guild).selfroles()
curr_selfroles.remove(role.id)
await self.conf.guild(ctx.guild).selfroles.set(curr_selfroles)
async with self.conf.guild(ctx.guild).selfroles() as curr_selfroles:
curr_selfroles.remove(role.id)
await ctx.send("The selfroles list has been successfully modified.")

View File

@ -82,41 +82,31 @@ class Alias:
alias = AliasEntry(alias_name, command, ctx.author, global_=global_)
if global_:
curr_aliases = await self._aliases.entries()
curr_aliases.append(alias.to_json())
await self._aliases.entries.set(curr_aliases)
settings = self._aliases
else:
curr_aliases = await self._aliases.guild(ctx.guild).entries()
settings = self._aliases.guild(ctx.guild)
await settings.enabled.set(True)
async with settings.entries() as curr_aliases:
curr_aliases.append(alias.to_json())
await self._aliases.guild(ctx.guild).entries.set(curr_aliases)
await self._aliases.guild(ctx.guild).enabled.set(True)
return alias
async def delete_alias(self, ctx: commands.Context, alias_name: str,
global_: bool=False) -> bool:
if global_:
aliases = await self.unloaded_global_aliases()
setter_func = self._aliases.entries.set
settings = self._aliases
else:
aliases = await self.unloaded_aliases(ctx.guild)
setter_func = self._aliases.guild(ctx.guild).entries.set
settings = self._aliases.guild(ctx.guild)
did_delete_alias = False
async with settings.entries() as aliases:
for alias in aliases:
alias_obj = AliasEntry.from_json(alias)
if alias_obj.name == alias_name:
aliases.remove(alias)
return True
to_keep = []
for alias in aliases:
if alias.name != alias_name:
to_keep.append(alias)
else:
did_delete_alias = True
await setter_func(
[a.to_json() for a in to_keep]
)
return did_delete_alias
return False
async def get_prefix(self, message: discord.Message) -> str:
"""

View File

@ -183,30 +183,24 @@ class Filter:
await ctx.send(_("Count and time have been set."))
async def add_to_filter(self, server: discord.Guild, words: list) -> bool:
added = 0
cur_list = await self.settings.guild(server).filter()
for w in words:
if w.lower() not in cur_list and w != "":
cur_list.append(w.lower())
added += 1
if added:
await self.settings.guild(server).filter.set(cur_list)
return True
else:
return False
added = False
async with self.settings.guild(server).filter() as cur_list:
for w in words:
if w.lower() not in cur_list and w:
cur_list.append(w.lower())
added = True
return added
async def remove_from_filter(self, server: discord.Guild, words: list) -> bool:
removed = 0
cur_list = await self.settings.guild(server).filter()
for w in words:
if w.lower() in cur_list:
cur_list.remove(w.lower())
removed += 1
if removed:
await self.settings.guild(server).filter.set(cur_list)
return True
else:
return False
removed = False
async with self.settings.guild(server).filter() as cur_list:
for w in words:
if w.lower() in cur_list:
cur_list.remove(w.lower())
removed = True
return removed
async def check_filter(self, message: discord.Message):
server = message.guild

View File

@ -1347,18 +1347,21 @@ class Mod:
async def on_member_update(self, before, after):
if before.name != after.name:
name_list = await self.settings.user(before).past_names()
if after.name not in name_list:
names = deque(name_list, maxlen=20)
names.append(after.name)
await self.settings.user(before).past_names.set(list(names))
async with self.settings.user(before).past_names() as name_list:
if after.nick in name_list:
# Ensure order is maintained without duplicates occuring
name_list.remove(after.nick)
name_list.append(after.nick)
while len(name_list) > 20:
name_list.pop(0)
if before.nick != after.nick and after.nick is not None:
nick_list = await self.settings.member(before).past_nicks()
nicks = deque(nick_list, maxlen=20)
if after.nick not in nicks:
nicks.append(after.nick)
await self.settings.member(before).past_nicks.set(list(nicks))
async with self.settings.member(before).past_nicks() as nick_list:
if after.nick in nick_list:
nick_list.remove(after.nick)
nick_list.append(after.nick)
while len(nick_list) > 20:
nick_list.pop(0)
@staticmethod
def are_overwrites_empty(overwrites):

View File

@ -142,15 +142,14 @@ class RedBase(BotBase, RpcMethodMixin):
await self.db.packages.set(packages)
async def add_loaded_package(self, pkg_name: str):
curr_pkgs = await self.db.packages()
if pkg_name not in curr_pkgs:
curr_pkgs.append(pkg_name)
await self.save_packages_status(curr_pkgs)
async with self.db.packages() as curr_pkgs:
if pkg_name not in curr_pkgs:
curr_pkgs.append(pkg_name)
async def remove_loaded_package(self, pkg_name: str):
curr_pkgs = await self.db.packages()
if pkg_name in curr_pkgs:
await self.save_packages_status([p for p in curr_pkgs if p != pkg_name])
async with self.db.packages() as curr_pkgs:
while pkg_name in curr_pkgs:
curr_pkgs.remove(pkg_name)
def load_extension(self, spec: ModuleSpec):
name = spec.name.split('.')[-1]

View File

@ -10,6 +10,37 @@ from .drivers import get_driver
log = logging.getLogger("red.config")
class _ValueCtxManager:
"""Context manager implementation of config values.
This class allows mutable config values to be both "get" and "set" from
within an async context manager.
The context manager can only be used to get and set a mutable data type,
i.e. `dict`s or `list`s. This is because this class's ``raw_value``
attribute must contain a reference to the object being modified within the
context manager.
"""
def __init__(self, value_obj, coro):
self.value_obj = value_obj
self.coro = coro
def __await__(self):
return self.coro.__await__()
async def __aenter__(self):
self.raw_value = await self
if not isinstance(self.raw_value, (list, dict)):
raise TypeError("Type of retrieved value must be mutable (i.e. "
"list or dict) in order to use a config value as "
"a context manager.")
return self.raw_value
async def __aexit__(self, *exc_info):
await self.value_obj.set(self.raw_value)
class Value:
"""A singular "value" of data.
@ -49,6 +80,11 @@ class Value:
"real" data of the `Value` object is accessed by this method. It is a
replacement for a :code:`get()` method.
The return value of this method can also be used as an asynchronous
context manager, i.e. with :code:`async with` syntax. This can only be
used on values which are mutable (namely lists and dicts), and will
set the value with its changes on exit of the context manager.
Example
-------
::
@ -74,11 +110,13 @@ class Value:
Returns
-------
types.coroutine
A coroutine object that must be awaited.
`awaitable` mixed with `asynchronous context manager`
A coroutine object mixed in with an async context manager. When
awaited, this returns the raw data value. When used in :code:`async
with` syntax, on gets the value on entrance, and sets it on exit.
"""
return self._get(default)
return _ValueCtxManager(self, self._get(default))
async def set(self, value):
"""Set the value of the data elements pointed to by `identifiers`.

View File

@ -333,3 +333,46 @@ async def test_user_getalldata(config, user_factory):
assert "bar" in all_data
assert config.user(user).defaults['foo'] is True
@pytest.mark.asyncio
async def test_value_ctxmgr(config):
config.register_global(foo_list=[])
async with config.foo_list() as foo_list:
foo_list.append('foo')
foo_list = await config.foo_list()
assert 'foo' in foo_list
@pytest.mark.asyncio
async def test_value_ctxmgr_saves(config):
config.register_global(bar_list=[])
try:
async with config.bar_list() as bar_list:
bar_list.append('bar')
raise RuntimeError()
except RuntimeError:
pass
bar_list = await config.bar_list()
assert 'bar' in bar_list
@pytest.mark.asyncio
async def test_value_ctxmgr_immutable(config):
config.register_global(foo=True)
try:
async with config.foo() as foo:
foo = False
except TypeError:
pass
else:
raise AssertionError
foo = await config.foo()
assert foo is True