mirror of
https://github.com/Cog-Creators/Red-DiscordBot.git
synced 2025-11-06 03:08:55 -05:00
[Config] Retrieve/save values with async context manager (#1131)
* [Config] Retrieve/save values with async context manager * Add a little docstring * Documentation * Implement async with syntax in existing modules
This commit is contained in:
parent
9b1018fa96
commit
9dbf56f942
@ -29,7 +29,7 @@ Basic Usage
|
||||
|
||||
@commands.command()
|
||||
async def return_some_data(self, ctx):
|
||||
await ctx.send(config.foo())
|
||||
await ctx.send(await config.foo())
|
||||
|
||||
********
|
||||
Tutorial
|
||||
@ -117,11 +117,36 @@ Notice a few things in the above examples:
|
||||
|
||||
3. If you're getting the value, the syntax is::
|
||||
|
||||
self.config.<insert thing here, or nothing if global>.variable_name()
|
||||
self.config.<insert scope here, or nothing if global>.variable_name()
|
||||
|
||||
4. If setting, it's::
|
||||
|
||||
self.config.<insert thing here, or nothing if global>.variable_name.set(new_value)
|
||||
self.config.<insert scope here, or nothing if global>.variable_name.set(new_value)
|
||||
|
||||
It is also possible to use :code:`async with` syntax to get and set config
|
||||
values. When entering the statement, the config value is retreived, and on exit,
|
||||
it is saved. This puts a safeguard on any code within the :code:`async with`
|
||||
block such that if it breaks from the block in any way (whether it be from
|
||||
:code:`return`, :code:`break`, :code:`continue` or an exception), the value will
|
||||
still be saved.
|
||||
|
||||
.. important::
|
||||
|
||||
Only mutable config values can be used in the :code:`async with` statement
|
||||
(namely lists or dicts), and they must be modified *in place* for their
|
||||
changes to be saved.
|
||||
|
||||
Here is an example of the :code:`async with` syntax:
|
||||
|
||||
.. code-block:: python
|
||||
|
||||
@commands.command()
|
||||
async def addblah(self, ctx, new_blah):
|
||||
guild_group = self.config.guild(ctx.guild)
|
||||
async with guild_group.blah() as blah:
|
||||
blah.append(new_blah)
|
||||
await ctx.send("The new blah value has been added!")
|
||||
|
||||
|
||||
.. important::
|
||||
|
||||
|
||||
@ -328,10 +328,9 @@ class Admin:
|
||||
"""
|
||||
Add a role to the list of available selfroles.
|
||||
"""
|
||||
curr_selfroles = await self.conf.guild(ctx.guild).selfroles()
|
||||
if role.id not in curr_selfroles:
|
||||
curr_selfroles.append(role.id)
|
||||
await self.conf.guild(ctx.guild).selfroles.set(curr_selfroles)
|
||||
async with self.conf.guild(ctx.guild).selfroles() as curr_selfroles:
|
||||
if role.id not in curr_selfroles:
|
||||
curr_selfroles.append(role.id)
|
||||
|
||||
await ctx.send("The selfroles list has been successfully modified.")
|
||||
|
||||
@ -341,9 +340,8 @@ class Admin:
|
||||
"""
|
||||
Removes a role from the list of available selfroles.
|
||||
"""
|
||||
curr_selfroles = await self.conf.guild(ctx.guild).selfroles()
|
||||
curr_selfroles.remove(role.id)
|
||||
await self.conf.guild(ctx.guild).selfroles.set(curr_selfroles)
|
||||
async with self.conf.guild(ctx.guild).selfroles() as curr_selfroles:
|
||||
curr_selfroles.remove(role.id)
|
||||
|
||||
await ctx.send("The selfroles list has been successfully modified.")
|
||||
|
||||
|
||||
@ -82,41 +82,31 @@ class Alias:
|
||||
alias = AliasEntry(alias_name, command, ctx.author, global_=global_)
|
||||
|
||||
if global_:
|
||||
curr_aliases = await self._aliases.entries()
|
||||
curr_aliases.append(alias.to_json())
|
||||
await self._aliases.entries.set(curr_aliases)
|
||||
settings = self._aliases
|
||||
else:
|
||||
curr_aliases = await self._aliases.guild(ctx.guild).entries()
|
||||
settings = self._aliases.guild(ctx.guild)
|
||||
await settings.enabled.set(True)
|
||||
|
||||
async with settings.entries() as curr_aliases:
|
||||
curr_aliases.append(alias.to_json())
|
||||
await self._aliases.guild(ctx.guild).entries.set(curr_aliases)
|
||||
|
||||
await self._aliases.guild(ctx.guild).enabled.set(True)
|
||||
return alias
|
||||
|
||||
async def delete_alias(self, ctx: commands.Context, alias_name: str,
|
||||
global_: bool=False) -> bool:
|
||||
if global_:
|
||||
aliases = await self.unloaded_global_aliases()
|
||||
setter_func = self._aliases.entries.set
|
||||
settings = self._aliases
|
||||
else:
|
||||
aliases = await self.unloaded_aliases(ctx.guild)
|
||||
setter_func = self._aliases.guild(ctx.guild).entries.set
|
||||
settings = self._aliases.guild(ctx.guild)
|
||||
|
||||
did_delete_alias = False
|
||||
async with settings.entries() as aliases:
|
||||
for alias in aliases:
|
||||
alias_obj = AliasEntry.from_json(alias)
|
||||
if alias_obj.name == alias_name:
|
||||
aliases.remove(alias)
|
||||
return True
|
||||
|
||||
to_keep = []
|
||||
for alias in aliases:
|
||||
if alias.name != alias_name:
|
||||
to_keep.append(alias)
|
||||
else:
|
||||
did_delete_alias = True
|
||||
|
||||
await setter_func(
|
||||
[a.to_json() for a in to_keep]
|
||||
)
|
||||
|
||||
return did_delete_alias
|
||||
return False
|
||||
|
||||
async def get_prefix(self, message: discord.Message) -> str:
|
||||
"""
|
||||
|
||||
@ -183,30 +183,24 @@ class Filter:
|
||||
await ctx.send(_("Count and time have been set."))
|
||||
|
||||
async def add_to_filter(self, server: discord.Guild, words: list) -> bool:
|
||||
added = 0
|
||||
cur_list = await self.settings.guild(server).filter()
|
||||
for w in words:
|
||||
if w.lower() not in cur_list and w != "":
|
||||
cur_list.append(w.lower())
|
||||
added += 1
|
||||
if added:
|
||||
await self.settings.guild(server).filter.set(cur_list)
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
added = False
|
||||
async with self.settings.guild(server).filter() as cur_list:
|
||||
for w in words:
|
||||
if w.lower() not in cur_list and w:
|
||||
cur_list.append(w.lower())
|
||||
added = True
|
||||
|
||||
return added
|
||||
|
||||
async def remove_from_filter(self, server: discord.Guild, words: list) -> bool:
|
||||
removed = 0
|
||||
cur_list = await self.settings.guild(server).filter()
|
||||
for w in words:
|
||||
if w.lower() in cur_list:
|
||||
cur_list.remove(w.lower())
|
||||
removed += 1
|
||||
if removed:
|
||||
await self.settings.guild(server).filter.set(cur_list)
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
removed = False
|
||||
async with self.settings.guild(server).filter() as cur_list:
|
||||
for w in words:
|
||||
if w.lower() in cur_list:
|
||||
cur_list.remove(w.lower())
|
||||
removed = True
|
||||
|
||||
return removed
|
||||
|
||||
async def check_filter(self, message: discord.Message):
|
||||
server = message.guild
|
||||
|
||||
@ -1347,18 +1347,21 @@ class Mod:
|
||||
|
||||
async def on_member_update(self, before, after):
|
||||
if before.name != after.name:
|
||||
name_list = await self.settings.user(before).past_names()
|
||||
if after.name not in name_list:
|
||||
names = deque(name_list, maxlen=20)
|
||||
names.append(after.name)
|
||||
await self.settings.user(before).past_names.set(list(names))
|
||||
async with self.settings.user(before).past_names() as name_list:
|
||||
if after.nick in name_list:
|
||||
# Ensure order is maintained without duplicates occuring
|
||||
name_list.remove(after.nick)
|
||||
name_list.append(after.nick)
|
||||
while len(name_list) > 20:
|
||||
name_list.pop(0)
|
||||
|
||||
if before.nick != after.nick and after.nick is not None:
|
||||
nick_list = await self.settings.member(before).past_nicks()
|
||||
nicks = deque(nick_list, maxlen=20)
|
||||
if after.nick not in nicks:
|
||||
nicks.append(after.nick)
|
||||
await self.settings.member(before).past_nicks.set(list(nicks))
|
||||
async with self.settings.member(before).past_nicks() as nick_list:
|
||||
if after.nick in nick_list:
|
||||
nick_list.remove(after.nick)
|
||||
nick_list.append(after.nick)
|
||||
while len(nick_list) > 20:
|
||||
nick_list.pop(0)
|
||||
|
||||
@staticmethod
|
||||
def are_overwrites_empty(overwrites):
|
||||
|
||||
@ -142,15 +142,14 @@ class RedBase(BotBase, RpcMethodMixin):
|
||||
await self.db.packages.set(packages)
|
||||
|
||||
async def add_loaded_package(self, pkg_name: str):
|
||||
curr_pkgs = await self.db.packages()
|
||||
if pkg_name not in curr_pkgs:
|
||||
curr_pkgs.append(pkg_name)
|
||||
await self.save_packages_status(curr_pkgs)
|
||||
async with self.db.packages() as curr_pkgs:
|
||||
if pkg_name not in curr_pkgs:
|
||||
curr_pkgs.append(pkg_name)
|
||||
|
||||
async def remove_loaded_package(self, pkg_name: str):
|
||||
curr_pkgs = await self.db.packages()
|
||||
if pkg_name in curr_pkgs:
|
||||
await self.save_packages_status([p for p in curr_pkgs if p != pkg_name])
|
||||
async with self.db.packages() as curr_pkgs:
|
||||
while pkg_name in curr_pkgs:
|
||||
curr_pkgs.remove(pkg_name)
|
||||
|
||||
def load_extension(self, spec: ModuleSpec):
|
||||
name = spec.name.split('.')[-1]
|
||||
|
||||
@ -10,6 +10,37 @@ from .drivers import get_driver
|
||||
log = logging.getLogger("red.config")
|
||||
|
||||
|
||||
class _ValueCtxManager:
|
||||
"""Context manager implementation of config values.
|
||||
|
||||
This class allows mutable config values to be both "get" and "set" from
|
||||
within an async context manager.
|
||||
|
||||
The context manager can only be used to get and set a mutable data type,
|
||||
i.e. `dict`s or `list`s. This is because this class's ``raw_value``
|
||||
attribute must contain a reference to the object being modified within the
|
||||
context manager.
|
||||
"""
|
||||
|
||||
def __init__(self, value_obj, coro):
|
||||
self.value_obj = value_obj
|
||||
self.coro = coro
|
||||
|
||||
def __await__(self):
|
||||
return self.coro.__await__()
|
||||
|
||||
async def __aenter__(self):
|
||||
self.raw_value = await self
|
||||
if not isinstance(self.raw_value, (list, dict)):
|
||||
raise TypeError("Type of retrieved value must be mutable (i.e. "
|
||||
"list or dict) in order to use a config value as "
|
||||
"a context manager.")
|
||||
return self.raw_value
|
||||
|
||||
async def __aexit__(self, *exc_info):
|
||||
await self.value_obj.set(self.raw_value)
|
||||
|
||||
|
||||
class Value:
|
||||
"""A singular "value" of data.
|
||||
|
||||
@ -49,6 +80,11 @@ class Value:
|
||||
"real" data of the `Value` object is accessed by this method. It is a
|
||||
replacement for a :code:`get()` method.
|
||||
|
||||
The return value of this method can also be used as an asynchronous
|
||||
context manager, i.e. with :code:`async with` syntax. This can only be
|
||||
used on values which are mutable (namely lists and dicts), and will
|
||||
set the value with its changes on exit of the context manager.
|
||||
|
||||
Example
|
||||
-------
|
||||
::
|
||||
@ -74,11 +110,13 @@ class Value:
|
||||
|
||||
Returns
|
||||
-------
|
||||
types.coroutine
|
||||
A coroutine object that must be awaited.
|
||||
`awaitable` mixed with `asynchronous context manager`
|
||||
A coroutine object mixed in with an async context manager. When
|
||||
awaited, this returns the raw data value. When used in :code:`async
|
||||
with` syntax, on gets the value on entrance, and sets it on exit.
|
||||
|
||||
"""
|
||||
return self._get(default)
|
||||
return _ValueCtxManager(self, self._get(default))
|
||||
|
||||
async def set(self, value):
|
||||
"""Set the value of the data elements pointed to by `identifiers`.
|
||||
|
||||
@ -333,3 +333,46 @@ async def test_user_getalldata(config, user_factory):
|
||||
assert "bar" in all_data
|
||||
|
||||
assert config.user(user).defaults['foo'] is True
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_value_ctxmgr(config):
|
||||
config.register_global(foo_list=[])
|
||||
|
||||
async with config.foo_list() as foo_list:
|
||||
foo_list.append('foo')
|
||||
|
||||
foo_list = await config.foo_list()
|
||||
|
||||
assert 'foo' in foo_list
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_value_ctxmgr_saves(config):
|
||||
config.register_global(bar_list=[])
|
||||
|
||||
try:
|
||||
async with config.bar_list() as bar_list:
|
||||
bar_list.append('bar')
|
||||
raise RuntimeError()
|
||||
except RuntimeError:
|
||||
pass
|
||||
|
||||
bar_list = await config.bar_list()
|
||||
|
||||
assert 'bar' in bar_list
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_value_ctxmgr_immutable(config):
|
||||
config.register_global(foo=True)
|
||||
|
||||
try:
|
||||
async with config.foo() as foo:
|
||||
foo = False
|
||||
except TypeError:
|
||||
pass
|
||||
else:
|
||||
raise AssertionError
|
||||
|
||||
foo = await config.foo()
|
||||
assert foo is True
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user