[Config] Retrieve/save values with async context manager (#1131)

* [Config] Retrieve/save values with async context manager

* Add a little docstring

* Documentation

* Implement async with syntax in existing modules
This commit is contained in:
Tobotimus 2017-12-04 14:07:34 +11:00 committed by palmtree5
parent 9b1018fa96
commit 9dbf56f942
8 changed files with 165 additions and 75 deletions

View File

@ -29,7 +29,7 @@ Basic Usage
@commands.command() @commands.command()
async def return_some_data(self, ctx): async def return_some_data(self, ctx):
await ctx.send(config.foo()) await ctx.send(await config.foo())
******** ********
Tutorial Tutorial
@ -117,11 +117,36 @@ Notice a few things in the above examples:
3. If you're getting the value, the syntax is:: 3. If you're getting the value, the syntax is::
self.config.<insert thing here, or nothing if global>.variable_name() self.config.<insert scope here, or nothing if global>.variable_name()
4. If setting, it's:: 4. If setting, it's::
self.config.<insert thing here, or nothing if global>.variable_name.set(new_value) self.config.<insert scope here, or nothing if global>.variable_name.set(new_value)
It is also possible to use :code:`async with` syntax to get and set config
values. When entering the statement, the config value is retreived, and on exit,
it is saved. This puts a safeguard on any code within the :code:`async with`
block such that if it breaks from the block in any way (whether it be from
:code:`return`, :code:`break`, :code:`continue` or an exception), the value will
still be saved.
.. important::
Only mutable config values can be used in the :code:`async with` statement
(namely lists or dicts), and they must be modified *in place* for their
changes to be saved.
Here is an example of the :code:`async with` syntax:
.. code-block:: python
@commands.command()
async def addblah(self, ctx, new_blah):
guild_group = self.config.guild(ctx.guild)
async with guild_group.blah() as blah:
blah.append(new_blah)
await ctx.send("The new blah value has been added!")
.. important:: .. important::

View File

@ -328,10 +328,9 @@ class Admin:
""" """
Add a role to the list of available selfroles. Add a role to the list of available selfroles.
""" """
curr_selfroles = await self.conf.guild(ctx.guild).selfroles() async with self.conf.guild(ctx.guild).selfroles() as curr_selfroles:
if role.id not in curr_selfroles: if role.id not in curr_selfroles:
curr_selfroles.append(role.id) curr_selfroles.append(role.id)
await self.conf.guild(ctx.guild).selfroles.set(curr_selfroles)
await ctx.send("The selfroles list has been successfully modified.") await ctx.send("The selfroles list has been successfully modified.")
@ -341,9 +340,8 @@ class Admin:
""" """
Removes a role from the list of available selfroles. Removes a role from the list of available selfroles.
""" """
curr_selfroles = await self.conf.guild(ctx.guild).selfroles() async with self.conf.guild(ctx.guild).selfroles() as curr_selfroles:
curr_selfroles.remove(role.id) curr_selfroles.remove(role.id)
await self.conf.guild(ctx.guild).selfroles.set(curr_selfroles)
await ctx.send("The selfroles list has been successfully modified.") await ctx.send("The selfroles list has been successfully modified.")

View File

@ -82,41 +82,31 @@ class Alias:
alias = AliasEntry(alias_name, command, ctx.author, global_=global_) alias = AliasEntry(alias_name, command, ctx.author, global_=global_)
if global_: if global_:
curr_aliases = await self._aliases.entries() settings = self._aliases
curr_aliases.append(alias.to_json())
await self._aliases.entries.set(curr_aliases)
else: else:
curr_aliases = await self._aliases.guild(ctx.guild).entries() settings = self._aliases.guild(ctx.guild)
await settings.enabled.set(True)
async with settings.entries() as curr_aliases:
curr_aliases.append(alias.to_json()) curr_aliases.append(alias.to_json())
await self._aliases.guild(ctx.guild).entries.set(curr_aliases)
await self._aliases.guild(ctx.guild).enabled.set(True)
return alias return alias
async def delete_alias(self, ctx: commands.Context, alias_name: str, async def delete_alias(self, ctx: commands.Context, alias_name: str,
global_: bool=False) -> bool: global_: bool=False) -> bool:
if global_: if global_:
aliases = await self.unloaded_global_aliases() settings = self._aliases
setter_func = self._aliases.entries.set
else: else:
aliases = await self.unloaded_aliases(ctx.guild) settings = self._aliases.guild(ctx.guild)
setter_func = self._aliases.guild(ctx.guild).entries.set
did_delete_alias = False async with settings.entries() as aliases:
to_keep = []
for alias in aliases: for alias in aliases:
if alias.name != alias_name: alias_obj = AliasEntry.from_json(alias)
to_keep.append(alias) if alias_obj.name == alias_name:
else: aliases.remove(alias)
did_delete_alias = True return True
await setter_func( return False
[a.to_json() for a in to_keep]
)
return did_delete_alias
async def get_prefix(self, message: discord.Message) -> str: async def get_prefix(self, message: discord.Message) -> str:
""" """

View File

@ -183,30 +183,24 @@ class Filter:
await ctx.send(_("Count and time have been set.")) await ctx.send(_("Count and time have been set."))
async def add_to_filter(self, server: discord.Guild, words: list) -> bool: async def add_to_filter(self, server: discord.Guild, words: list) -> bool:
added = 0 added = False
cur_list = await self.settings.guild(server).filter() async with self.settings.guild(server).filter() as cur_list:
for w in words: for w in words:
if w.lower() not in cur_list and w != "": if w.lower() not in cur_list and w:
cur_list.append(w.lower()) cur_list.append(w.lower())
added += 1 added = True
if added:
await self.settings.guild(server).filter.set(cur_list) return added
return True
else:
return False
async def remove_from_filter(self, server: discord.Guild, words: list) -> bool: async def remove_from_filter(self, server: discord.Guild, words: list) -> bool:
removed = 0 removed = False
cur_list = await self.settings.guild(server).filter() async with self.settings.guild(server).filter() as cur_list:
for w in words: for w in words:
if w.lower() in cur_list: if w.lower() in cur_list:
cur_list.remove(w.lower()) cur_list.remove(w.lower())
removed += 1 removed = True
if removed:
await self.settings.guild(server).filter.set(cur_list) return removed
return True
else:
return False
async def check_filter(self, message: discord.Message): async def check_filter(self, message: discord.Message):
server = message.guild server = message.guild

View File

@ -1347,18 +1347,21 @@ class Mod:
async def on_member_update(self, before, after): async def on_member_update(self, before, after):
if before.name != after.name: if before.name != after.name:
name_list = await self.settings.user(before).past_names() async with self.settings.user(before).past_names() as name_list:
if after.name not in name_list: if after.nick in name_list:
names = deque(name_list, maxlen=20) # Ensure order is maintained without duplicates occuring
names.append(after.name) name_list.remove(after.nick)
await self.settings.user(before).past_names.set(list(names)) name_list.append(after.nick)
while len(name_list) > 20:
name_list.pop(0)
if before.nick != after.nick and after.nick is not None: if before.nick != after.nick and after.nick is not None:
nick_list = await self.settings.member(before).past_nicks() async with self.settings.member(before).past_nicks() as nick_list:
nicks = deque(nick_list, maxlen=20) if after.nick in nick_list:
if after.nick not in nicks: nick_list.remove(after.nick)
nicks.append(after.nick) nick_list.append(after.nick)
await self.settings.member(before).past_nicks.set(list(nicks)) while len(nick_list) > 20:
nick_list.pop(0)
@staticmethod @staticmethod
def are_overwrites_empty(overwrites): def are_overwrites_empty(overwrites):

View File

@ -142,15 +142,14 @@ class RedBase(BotBase, RpcMethodMixin):
await self.db.packages.set(packages) await self.db.packages.set(packages)
async def add_loaded_package(self, pkg_name: str): async def add_loaded_package(self, pkg_name: str):
curr_pkgs = await self.db.packages() async with self.db.packages() as curr_pkgs:
if pkg_name not in curr_pkgs: if pkg_name not in curr_pkgs:
curr_pkgs.append(pkg_name) curr_pkgs.append(pkg_name)
await self.save_packages_status(curr_pkgs)
async def remove_loaded_package(self, pkg_name: str): async def remove_loaded_package(self, pkg_name: str):
curr_pkgs = await self.db.packages() async with self.db.packages() as curr_pkgs:
if pkg_name in curr_pkgs: while pkg_name in curr_pkgs:
await self.save_packages_status([p for p in curr_pkgs if p != pkg_name]) curr_pkgs.remove(pkg_name)
def load_extension(self, spec: ModuleSpec): def load_extension(self, spec: ModuleSpec):
name = spec.name.split('.')[-1] name = spec.name.split('.')[-1]

View File

@ -10,6 +10,37 @@ from .drivers import get_driver
log = logging.getLogger("red.config") log = logging.getLogger("red.config")
class _ValueCtxManager:
"""Context manager implementation of config values.
This class allows mutable config values to be both "get" and "set" from
within an async context manager.
The context manager can only be used to get and set a mutable data type,
i.e. `dict`s or `list`s. This is because this class's ``raw_value``
attribute must contain a reference to the object being modified within the
context manager.
"""
def __init__(self, value_obj, coro):
self.value_obj = value_obj
self.coro = coro
def __await__(self):
return self.coro.__await__()
async def __aenter__(self):
self.raw_value = await self
if not isinstance(self.raw_value, (list, dict)):
raise TypeError("Type of retrieved value must be mutable (i.e. "
"list or dict) in order to use a config value as "
"a context manager.")
return self.raw_value
async def __aexit__(self, *exc_info):
await self.value_obj.set(self.raw_value)
class Value: class Value:
"""A singular "value" of data. """A singular "value" of data.
@ -49,6 +80,11 @@ class Value:
"real" data of the `Value` object is accessed by this method. It is a "real" data of the `Value` object is accessed by this method. It is a
replacement for a :code:`get()` method. replacement for a :code:`get()` method.
The return value of this method can also be used as an asynchronous
context manager, i.e. with :code:`async with` syntax. This can only be
used on values which are mutable (namely lists and dicts), and will
set the value with its changes on exit of the context manager.
Example Example
------- -------
:: ::
@ -74,11 +110,13 @@ class Value:
Returns Returns
------- -------
types.coroutine `awaitable` mixed with `asynchronous context manager`
A coroutine object that must be awaited. A coroutine object mixed in with an async context manager. When
awaited, this returns the raw data value. When used in :code:`async
with` syntax, on gets the value on entrance, and sets it on exit.
""" """
return self._get(default) return _ValueCtxManager(self, self._get(default))
async def set(self, value): async def set(self, value):
"""Set the value of the data elements pointed to by `identifiers`. """Set the value of the data elements pointed to by `identifiers`.

View File

@ -333,3 +333,46 @@ async def test_user_getalldata(config, user_factory):
assert "bar" in all_data assert "bar" in all_data
assert config.user(user).defaults['foo'] is True assert config.user(user).defaults['foo'] is True
@pytest.mark.asyncio
async def test_value_ctxmgr(config):
config.register_global(foo_list=[])
async with config.foo_list() as foo_list:
foo_list.append('foo')
foo_list = await config.foo_list()
assert 'foo' in foo_list
@pytest.mark.asyncio
async def test_value_ctxmgr_saves(config):
config.register_global(bar_list=[])
try:
async with config.bar_list() as bar_list:
bar_list.append('bar')
raise RuntimeError()
except RuntimeError:
pass
bar_list = await config.bar_list()
assert 'bar' in bar_list
@pytest.mark.asyncio
async def test_value_ctxmgr_immutable(config):
config.register_global(foo=True)
try:
async with config.foo() as foo:
foo = False
except TypeError:
pass
else:
raise AssertionError
foo = await config.foo()
assert foo is True