mirror of
https://github.com/Cog-Creators/Red-DiscordBot.git
synced 2025-11-06 03:08:55 -05:00
[Config] Retrieve/save values with async context manager (#1131)
* [Config] Retrieve/save values with async context manager * Add a little docstring * Documentation * Implement async with syntax in existing modules
This commit is contained in:
parent
9b1018fa96
commit
9dbf56f942
@ -29,7 +29,7 @@ Basic Usage
|
|||||||
|
|
||||||
@commands.command()
|
@commands.command()
|
||||||
async def return_some_data(self, ctx):
|
async def return_some_data(self, ctx):
|
||||||
await ctx.send(config.foo())
|
await ctx.send(await config.foo())
|
||||||
|
|
||||||
********
|
********
|
||||||
Tutorial
|
Tutorial
|
||||||
@ -117,11 +117,36 @@ Notice a few things in the above examples:
|
|||||||
|
|
||||||
3. If you're getting the value, the syntax is::
|
3. If you're getting the value, the syntax is::
|
||||||
|
|
||||||
self.config.<insert thing here, or nothing if global>.variable_name()
|
self.config.<insert scope here, or nothing if global>.variable_name()
|
||||||
|
|
||||||
4. If setting, it's::
|
4. If setting, it's::
|
||||||
|
|
||||||
self.config.<insert thing here, or nothing if global>.variable_name.set(new_value)
|
self.config.<insert scope here, or nothing if global>.variable_name.set(new_value)
|
||||||
|
|
||||||
|
It is also possible to use :code:`async with` syntax to get and set config
|
||||||
|
values. When entering the statement, the config value is retreived, and on exit,
|
||||||
|
it is saved. This puts a safeguard on any code within the :code:`async with`
|
||||||
|
block such that if it breaks from the block in any way (whether it be from
|
||||||
|
:code:`return`, :code:`break`, :code:`continue` or an exception), the value will
|
||||||
|
still be saved.
|
||||||
|
|
||||||
|
.. important::
|
||||||
|
|
||||||
|
Only mutable config values can be used in the :code:`async with` statement
|
||||||
|
(namely lists or dicts), and they must be modified *in place* for their
|
||||||
|
changes to be saved.
|
||||||
|
|
||||||
|
Here is an example of the :code:`async with` syntax:
|
||||||
|
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
@commands.command()
|
||||||
|
async def addblah(self, ctx, new_blah):
|
||||||
|
guild_group = self.config.guild(ctx.guild)
|
||||||
|
async with guild_group.blah() as blah:
|
||||||
|
blah.append(new_blah)
|
||||||
|
await ctx.send("The new blah value has been added!")
|
||||||
|
|
||||||
|
|
||||||
.. important::
|
.. important::
|
||||||
|
|
||||||
|
|||||||
@ -328,10 +328,9 @@ class Admin:
|
|||||||
"""
|
"""
|
||||||
Add a role to the list of available selfroles.
|
Add a role to the list of available selfroles.
|
||||||
"""
|
"""
|
||||||
curr_selfroles = await self.conf.guild(ctx.guild).selfroles()
|
async with self.conf.guild(ctx.guild).selfroles() as curr_selfroles:
|
||||||
if role.id not in curr_selfroles:
|
if role.id not in curr_selfroles:
|
||||||
curr_selfroles.append(role.id)
|
curr_selfroles.append(role.id)
|
||||||
await self.conf.guild(ctx.guild).selfroles.set(curr_selfroles)
|
|
||||||
|
|
||||||
await ctx.send("The selfroles list has been successfully modified.")
|
await ctx.send("The selfroles list has been successfully modified.")
|
||||||
|
|
||||||
@ -341,9 +340,8 @@ class Admin:
|
|||||||
"""
|
"""
|
||||||
Removes a role from the list of available selfroles.
|
Removes a role from the list of available selfroles.
|
||||||
"""
|
"""
|
||||||
curr_selfroles = await self.conf.guild(ctx.guild).selfroles()
|
async with self.conf.guild(ctx.guild).selfroles() as curr_selfroles:
|
||||||
curr_selfroles.remove(role.id)
|
curr_selfroles.remove(role.id)
|
||||||
await self.conf.guild(ctx.guild).selfroles.set(curr_selfroles)
|
|
||||||
|
|
||||||
await ctx.send("The selfroles list has been successfully modified.")
|
await ctx.send("The selfroles list has been successfully modified.")
|
||||||
|
|
||||||
|
|||||||
@ -82,41 +82,31 @@ class Alias:
|
|||||||
alias = AliasEntry(alias_name, command, ctx.author, global_=global_)
|
alias = AliasEntry(alias_name, command, ctx.author, global_=global_)
|
||||||
|
|
||||||
if global_:
|
if global_:
|
||||||
curr_aliases = await self._aliases.entries()
|
settings = self._aliases
|
||||||
curr_aliases.append(alias.to_json())
|
|
||||||
await self._aliases.entries.set(curr_aliases)
|
|
||||||
else:
|
else:
|
||||||
curr_aliases = await self._aliases.guild(ctx.guild).entries()
|
settings = self._aliases.guild(ctx.guild)
|
||||||
|
await settings.enabled.set(True)
|
||||||
|
|
||||||
|
async with settings.entries() as curr_aliases:
|
||||||
curr_aliases.append(alias.to_json())
|
curr_aliases.append(alias.to_json())
|
||||||
await self._aliases.guild(ctx.guild).entries.set(curr_aliases)
|
|
||||||
|
|
||||||
await self._aliases.guild(ctx.guild).enabled.set(True)
|
|
||||||
return alias
|
return alias
|
||||||
|
|
||||||
async def delete_alias(self, ctx: commands.Context, alias_name: str,
|
async def delete_alias(self, ctx: commands.Context, alias_name: str,
|
||||||
global_: bool=False) -> bool:
|
global_: bool=False) -> bool:
|
||||||
if global_:
|
if global_:
|
||||||
aliases = await self.unloaded_global_aliases()
|
settings = self._aliases
|
||||||
setter_func = self._aliases.entries.set
|
|
||||||
else:
|
else:
|
||||||
aliases = await self.unloaded_aliases(ctx.guild)
|
settings = self._aliases.guild(ctx.guild)
|
||||||
setter_func = self._aliases.guild(ctx.guild).entries.set
|
|
||||||
|
|
||||||
did_delete_alias = False
|
async with settings.entries() as aliases:
|
||||||
|
for alias in aliases:
|
||||||
|
alias_obj = AliasEntry.from_json(alias)
|
||||||
|
if alias_obj.name == alias_name:
|
||||||
|
aliases.remove(alias)
|
||||||
|
return True
|
||||||
|
|
||||||
to_keep = []
|
return False
|
||||||
for alias in aliases:
|
|
||||||
if alias.name != alias_name:
|
|
||||||
to_keep.append(alias)
|
|
||||||
else:
|
|
||||||
did_delete_alias = True
|
|
||||||
|
|
||||||
await setter_func(
|
|
||||||
[a.to_json() for a in to_keep]
|
|
||||||
)
|
|
||||||
|
|
||||||
return did_delete_alias
|
|
||||||
|
|
||||||
async def get_prefix(self, message: discord.Message) -> str:
|
async def get_prefix(self, message: discord.Message) -> str:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -183,30 +183,24 @@ class Filter:
|
|||||||
await ctx.send(_("Count and time have been set."))
|
await ctx.send(_("Count and time have been set."))
|
||||||
|
|
||||||
async def add_to_filter(self, server: discord.Guild, words: list) -> bool:
|
async def add_to_filter(self, server: discord.Guild, words: list) -> bool:
|
||||||
added = 0
|
added = False
|
||||||
cur_list = await self.settings.guild(server).filter()
|
async with self.settings.guild(server).filter() as cur_list:
|
||||||
for w in words:
|
for w in words:
|
||||||
if w.lower() not in cur_list and w != "":
|
if w.lower() not in cur_list and w:
|
||||||
cur_list.append(w.lower())
|
cur_list.append(w.lower())
|
||||||
added += 1
|
added = True
|
||||||
if added:
|
|
||||||
await self.settings.guild(server).filter.set(cur_list)
|
return added
|
||||||
return True
|
|
||||||
else:
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def remove_from_filter(self, server: discord.Guild, words: list) -> bool:
|
async def remove_from_filter(self, server: discord.Guild, words: list) -> bool:
|
||||||
removed = 0
|
removed = False
|
||||||
cur_list = await self.settings.guild(server).filter()
|
async with self.settings.guild(server).filter() as cur_list:
|
||||||
for w in words:
|
for w in words:
|
||||||
if w.lower() in cur_list:
|
if w.lower() in cur_list:
|
||||||
cur_list.remove(w.lower())
|
cur_list.remove(w.lower())
|
||||||
removed += 1
|
removed = True
|
||||||
if removed:
|
|
||||||
await self.settings.guild(server).filter.set(cur_list)
|
return removed
|
||||||
return True
|
|
||||||
else:
|
|
||||||
return False
|
|
||||||
|
|
||||||
async def check_filter(self, message: discord.Message):
|
async def check_filter(self, message: discord.Message):
|
||||||
server = message.guild
|
server = message.guild
|
||||||
|
|||||||
@ -1347,18 +1347,21 @@ class Mod:
|
|||||||
|
|
||||||
async def on_member_update(self, before, after):
|
async def on_member_update(self, before, after):
|
||||||
if before.name != after.name:
|
if before.name != after.name:
|
||||||
name_list = await self.settings.user(before).past_names()
|
async with self.settings.user(before).past_names() as name_list:
|
||||||
if after.name not in name_list:
|
if after.nick in name_list:
|
||||||
names = deque(name_list, maxlen=20)
|
# Ensure order is maintained without duplicates occuring
|
||||||
names.append(after.name)
|
name_list.remove(after.nick)
|
||||||
await self.settings.user(before).past_names.set(list(names))
|
name_list.append(after.nick)
|
||||||
|
while len(name_list) > 20:
|
||||||
|
name_list.pop(0)
|
||||||
|
|
||||||
if before.nick != after.nick and after.nick is not None:
|
if before.nick != after.nick and after.nick is not None:
|
||||||
nick_list = await self.settings.member(before).past_nicks()
|
async with self.settings.member(before).past_nicks() as nick_list:
|
||||||
nicks = deque(nick_list, maxlen=20)
|
if after.nick in nick_list:
|
||||||
if after.nick not in nicks:
|
nick_list.remove(after.nick)
|
||||||
nicks.append(after.nick)
|
nick_list.append(after.nick)
|
||||||
await self.settings.member(before).past_nicks.set(list(nicks))
|
while len(nick_list) > 20:
|
||||||
|
nick_list.pop(0)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def are_overwrites_empty(overwrites):
|
def are_overwrites_empty(overwrites):
|
||||||
|
|||||||
@ -142,15 +142,14 @@ class RedBase(BotBase, RpcMethodMixin):
|
|||||||
await self.db.packages.set(packages)
|
await self.db.packages.set(packages)
|
||||||
|
|
||||||
async def add_loaded_package(self, pkg_name: str):
|
async def add_loaded_package(self, pkg_name: str):
|
||||||
curr_pkgs = await self.db.packages()
|
async with self.db.packages() as curr_pkgs:
|
||||||
if pkg_name not in curr_pkgs:
|
if pkg_name not in curr_pkgs:
|
||||||
curr_pkgs.append(pkg_name)
|
curr_pkgs.append(pkg_name)
|
||||||
await self.save_packages_status(curr_pkgs)
|
|
||||||
|
|
||||||
async def remove_loaded_package(self, pkg_name: str):
|
async def remove_loaded_package(self, pkg_name: str):
|
||||||
curr_pkgs = await self.db.packages()
|
async with self.db.packages() as curr_pkgs:
|
||||||
if pkg_name in curr_pkgs:
|
while pkg_name in curr_pkgs:
|
||||||
await self.save_packages_status([p for p in curr_pkgs if p != pkg_name])
|
curr_pkgs.remove(pkg_name)
|
||||||
|
|
||||||
def load_extension(self, spec: ModuleSpec):
|
def load_extension(self, spec: ModuleSpec):
|
||||||
name = spec.name.split('.')[-1]
|
name = spec.name.split('.')[-1]
|
||||||
|
|||||||
@ -10,6 +10,37 @@ from .drivers import get_driver
|
|||||||
log = logging.getLogger("red.config")
|
log = logging.getLogger("red.config")
|
||||||
|
|
||||||
|
|
||||||
|
class _ValueCtxManager:
|
||||||
|
"""Context manager implementation of config values.
|
||||||
|
|
||||||
|
This class allows mutable config values to be both "get" and "set" from
|
||||||
|
within an async context manager.
|
||||||
|
|
||||||
|
The context manager can only be used to get and set a mutable data type,
|
||||||
|
i.e. `dict`s or `list`s. This is because this class's ``raw_value``
|
||||||
|
attribute must contain a reference to the object being modified within the
|
||||||
|
context manager.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, value_obj, coro):
|
||||||
|
self.value_obj = value_obj
|
||||||
|
self.coro = coro
|
||||||
|
|
||||||
|
def __await__(self):
|
||||||
|
return self.coro.__await__()
|
||||||
|
|
||||||
|
async def __aenter__(self):
|
||||||
|
self.raw_value = await self
|
||||||
|
if not isinstance(self.raw_value, (list, dict)):
|
||||||
|
raise TypeError("Type of retrieved value must be mutable (i.e. "
|
||||||
|
"list or dict) in order to use a config value as "
|
||||||
|
"a context manager.")
|
||||||
|
return self.raw_value
|
||||||
|
|
||||||
|
async def __aexit__(self, *exc_info):
|
||||||
|
await self.value_obj.set(self.raw_value)
|
||||||
|
|
||||||
|
|
||||||
class Value:
|
class Value:
|
||||||
"""A singular "value" of data.
|
"""A singular "value" of data.
|
||||||
|
|
||||||
@ -49,6 +80,11 @@ class Value:
|
|||||||
"real" data of the `Value` object is accessed by this method. It is a
|
"real" data of the `Value` object is accessed by this method. It is a
|
||||||
replacement for a :code:`get()` method.
|
replacement for a :code:`get()` method.
|
||||||
|
|
||||||
|
The return value of this method can also be used as an asynchronous
|
||||||
|
context manager, i.e. with :code:`async with` syntax. This can only be
|
||||||
|
used on values which are mutable (namely lists and dicts), and will
|
||||||
|
set the value with its changes on exit of the context manager.
|
||||||
|
|
||||||
Example
|
Example
|
||||||
-------
|
-------
|
||||||
::
|
::
|
||||||
@ -74,11 +110,13 @@ class Value:
|
|||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
types.coroutine
|
`awaitable` mixed with `asynchronous context manager`
|
||||||
A coroutine object that must be awaited.
|
A coroutine object mixed in with an async context manager. When
|
||||||
|
awaited, this returns the raw data value. When used in :code:`async
|
||||||
|
with` syntax, on gets the value on entrance, and sets it on exit.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
return self._get(default)
|
return _ValueCtxManager(self, self._get(default))
|
||||||
|
|
||||||
async def set(self, value):
|
async def set(self, value):
|
||||||
"""Set the value of the data elements pointed to by `identifiers`.
|
"""Set the value of the data elements pointed to by `identifiers`.
|
||||||
|
|||||||
@ -333,3 +333,46 @@ async def test_user_getalldata(config, user_factory):
|
|||||||
assert "bar" in all_data
|
assert "bar" in all_data
|
||||||
|
|
||||||
assert config.user(user).defaults['foo'] is True
|
assert config.user(user).defaults['foo'] is True
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_value_ctxmgr(config):
|
||||||
|
config.register_global(foo_list=[])
|
||||||
|
|
||||||
|
async with config.foo_list() as foo_list:
|
||||||
|
foo_list.append('foo')
|
||||||
|
|
||||||
|
foo_list = await config.foo_list()
|
||||||
|
|
||||||
|
assert 'foo' in foo_list
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_value_ctxmgr_saves(config):
|
||||||
|
config.register_global(bar_list=[])
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with config.bar_list() as bar_list:
|
||||||
|
bar_list.append('bar')
|
||||||
|
raise RuntimeError()
|
||||||
|
except RuntimeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
bar_list = await config.bar_list()
|
||||||
|
|
||||||
|
assert 'bar' in bar_list
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_value_ctxmgr_immutable(config):
|
||||||
|
config.register_global(foo=True)
|
||||||
|
|
||||||
|
try:
|
||||||
|
async with config.foo() as foo:
|
||||||
|
foo = False
|
||||||
|
except TypeError:
|
||||||
|
pass
|
||||||
|
else:
|
||||||
|
raise AssertionError
|
||||||
|
|
||||||
|
foo = await config.foo()
|
||||||
|
assert foo is True
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user