mirror of
https://github.com/Cog-Creators/Red-DiscordBot.git
synced 2025-11-20 18:06:08 -05:00
[Config] Retrieve/save values with async context manager (#1131)
* [Config] Retrieve/save values with async context manager * Add a little docstring * Documentation * Implement async with syntax in existing modules
This commit is contained in:
@@ -328,10 +328,9 @@ class Admin:
|
||||
"""
|
||||
Add a role to the list of available selfroles.
|
||||
"""
|
||||
curr_selfroles = await self.conf.guild(ctx.guild).selfroles()
|
||||
if role.id not in curr_selfroles:
|
||||
curr_selfroles.append(role.id)
|
||||
await self.conf.guild(ctx.guild).selfroles.set(curr_selfroles)
|
||||
async with self.conf.guild(ctx.guild).selfroles() as curr_selfroles:
|
||||
if role.id not in curr_selfroles:
|
||||
curr_selfroles.append(role.id)
|
||||
|
||||
await ctx.send("The selfroles list has been successfully modified.")
|
||||
|
||||
@@ -341,9 +340,8 @@ class Admin:
|
||||
"""
|
||||
Removes a role from the list of available selfroles.
|
||||
"""
|
||||
curr_selfroles = await self.conf.guild(ctx.guild).selfroles()
|
||||
curr_selfroles.remove(role.id)
|
||||
await self.conf.guild(ctx.guild).selfroles.set(curr_selfroles)
|
||||
async with self.conf.guild(ctx.guild).selfroles() as curr_selfroles:
|
||||
curr_selfroles.remove(role.id)
|
||||
|
||||
await ctx.send("The selfroles list has been successfully modified.")
|
||||
|
||||
|
||||
@@ -82,41 +82,31 @@ class Alias:
|
||||
alias = AliasEntry(alias_name, command, ctx.author, global_=global_)
|
||||
|
||||
if global_:
|
||||
curr_aliases = await self._aliases.entries()
|
||||
curr_aliases.append(alias.to_json())
|
||||
await self._aliases.entries.set(curr_aliases)
|
||||
settings = self._aliases
|
||||
else:
|
||||
curr_aliases = await self._aliases.guild(ctx.guild).entries()
|
||||
settings = self._aliases.guild(ctx.guild)
|
||||
await settings.enabled.set(True)
|
||||
|
||||
async with settings.entries() as curr_aliases:
|
||||
curr_aliases.append(alias.to_json())
|
||||
await self._aliases.guild(ctx.guild).entries.set(curr_aliases)
|
||||
|
||||
await self._aliases.guild(ctx.guild).enabled.set(True)
|
||||
return alias
|
||||
|
||||
async def delete_alias(self, ctx: commands.Context, alias_name: str,
|
||||
global_: bool=False) -> bool:
|
||||
if global_:
|
||||
aliases = await self.unloaded_global_aliases()
|
||||
setter_func = self._aliases.entries.set
|
||||
settings = self._aliases
|
||||
else:
|
||||
aliases = await self.unloaded_aliases(ctx.guild)
|
||||
setter_func = self._aliases.guild(ctx.guild).entries.set
|
||||
settings = self._aliases.guild(ctx.guild)
|
||||
|
||||
did_delete_alias = False
|
||||
async with settings.entries() as aliases:
|
||||
for alias in aliases:
|
||||
alias_obj = AliasEntry.from_json(alias)
|
||||
if alias_obj.name == alias_name:
|
||||
aliases.remove(alias)
|
||||
return True
|
||||
|
||||
to_keep = []
|
||||
for alias in aliases:
|
||||
if alias.name != alias_name:
|
||||
to_keep.append(alias)
|
||||
else:
|
||||
did_delete_alias = True
|
||||
|
||||
await setter_func(
|
||||
[a.to_json() for a in to_keep]
|
||||
)
|
||||
|
||||
return did_delete_alias
|
||||
return False
|
||||
|
||||
async def get_prefix(self, message: discord.Message) -> str:
|
||||
"""
|
||||
|
||||
@@ -183,30 +183,24 @@ class Filter:
|
||||
await ctx.send(_("Count and time have been set."))
|
||||
|
||||
async def add_to_filter(self, server: discord.Guild, words: list) -> bool:
|
||||
added = 0
|
||||
cur_list = await self.settings.guild(server).filter()
|
||||
for w in words:
|
||||
if w.lower() not in cur_list and w != "":
|
||||
cur_list.append(w.lower())
|
||||
added += 1
|
||||
if added:
|
||||
await self.settings.guild(server).filter.set(cur_list)
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
added = False
|
||||
async with self.settings.guild(server).filter() as cur_list:
|
||||
for w in words:
|
||||
if w.lower() not in cur_list and w:
|
||||
cur_list.append(w.lower())
|
||||
added = True
|
||||
|
||||
return added
|
||||
|
||||
async def remove_from_filter(self, server: discord.Guild, words: list) -> bool:
|
||||
removed = 0
|
||||
cur_list = await self.settings.guild(server).filter()
|
||||
for w in words:
|
||||
if w.lower() in cur_list:
|
||||
cur_list.remove(w.lower())
|
||||
removed += 1
|
||||
if removed:
|
||||
await self.settings.guild(server).filter.set(cur_list)
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
removed = False
|
||||
async with self.settings.guild(server).filter() as cur_list:
|
||||
for w in words:
|
||||
if w.lower() in cur_list:
|
||||
cur_list.remove(w.lower())
|
||||
removed = True
|
||||
|
||||
return removed
|
||||
|
||||
async def check_filter(self, message: discord.Message):
|
||||
server = message.guild
|
||||
|
||||
@@ -1347,18 +1347,21 @@ class Mod:
|
||||
|
||||
async def on_member_update(self, before, after):
|
||||
if before.name != after.name:
|
||||
name_list = await self.settings.user(before).past_names()
|
||||
if after.name not in name_list:
|
||||
names = deque(name_list, maxlen=20)
|
||||
names.append(after.name)
|
||||
await self.settings.user(before).past_names.set(list(names))
|
||||
async with self.settings.user(before).past_names() as name_list:
|
||||
if after.nick in name_list:
|
||||
# Ensure order is maintained without duplicates occuring
|
||||
name_list.remove(after.nick)
|
||||
name_list.append(after.nick)
|
||||
while len(name_list) > 20:
|
||||
name_list.pop(0)
|
||||
|
||||
if before.nick != after.nick and after.nick is not None:
|
||||
nick_list = await self.settings.member(before).past_nicks()
|
||||
nicks = deque(nick_list, maxlen=20)
|
||||
if after.nick not in nicks:
|
||||
nicks.append(after.nick)
|
||||
await self.settings.member(before).past_nicks.set(list(nicks))
|
||||
async with self.settings.member(before).past_nicks() as nick_list:
|
||||
if after.nick in nick_list:
|
||||
nick_list.remove(after.nick)
|
||||
nick_list.append(after.nick)
|
||||
while len(nick_list) > 20:
|
||||
nick_list.pop(0)
|
||||
|
||||
@staticmethod
|
||||
def are_overwrites_empty(overwrites):
|
||||
|
||||
@@ -142,15 +142,14 @@ class RedBase(BotBase, RpcMethodMixin):
|
||||
await self.db.packages.set(packages)
|
||||
|
||||
async def add_loaded_package(self, pkg_name: str):
|
||||
curr_pkgs = await self.db.packages()
|
||||
if pkg_name not in curr_pkgs:
|
||||
curr_pkgs.append(pkg_name)
|
||||
await self.save_packages_status(curr_pkgs)
|
||||
async with self.db.packages() as curr_pkgs:
|
||||
if pkg_name not in curr_pkgs:
|
||||
curr_pkgs.append(pkg_name)
|
||||
|
||||
async def remove_loaded_package(self, pkg_name: str):
|
||||
curr_pkgs = await self.db.packages()
|
||||
if pkg_name in curr_pkgs:
|
||||
await self.save_packages_status([p for p in curr_pkgs if p != pkg_name])
|
||||
async with self.db.packages() as curr_pkgs:
|
||||
while pkg_name in curr_pkgs:
|
||||
curr_pkgs.remove(pkg_name)
|
||||
|
||||
def load_extension(self, spec: ModuleSpec):
|
||||
name = spec.name.split('.')[-1]
|
||||
|
||||
@@ -10,6 +10,37 @@ from .drivers import get_driver
|
||||
log = logging.getLogger("red.config")
|
||||
|
||||
|
||||
class _ValueCtxManager:
|
||||
"""Context manager implementation of config values.
|
||||
|
||||
This class allows mutable config values to be both "get" and "set" from
|
||||
within an async context manager.
|
||||
|
||||
The context manager can only be used to get and set a mutable data type,
|
||||
i.e. `dict`s or `list`s. This is because this class's ``raw_value``
|
||||
attribute must contain a reference to the object being modified within the
|
||||
context manager.
|
||||
"""
|
||||
|
||||
def __init__(self, value_obj, coro):
|
||||
self.value_obj = value_obj
|
||||
self.coro = coro
|
||||
|
||||
def __await__(self):
|
||||
return self.coro.__await__()
|
||||
|
||||
async def __aenter__(self):
|
||||
self.raw_value = await self
|
||||
if not isinstance(self.raw_value, (list, dict)):
|
||||
raise TypeError("Type of retrieved value must be mutable (i.e. "
|
||||
"list or dict) in order to use a config value as "
|
||||
"a context manager.")
|
||||
return self.raw_value
|
||||
|
||||
async def __aexit__(self, *exc_info):
|
||||
await self.value_obj.set(self.raw_value)
|
||||
|
||||
|
||||
class Value:
|
||||
"""A singular "value" of data.
|
||||
|
||||
@@ -49,6 +80,11 @@ class Value:
|
||||
"real" data of the `Value` object is accessed by this method. It is a
|
||||
replacement for a :code:`get()` method.
|
||||
|
||||
The return value of this method can also be used as an asynchronous
|
||||
context manager, i.e. with :code:`async with` syntax. This can only be
|
||||
used on values which are mutable (namely lists and dicts), and will
|
||||
set the value with its changes on exit of the context manager.
|
||||
|
||||
Example
|
||||
-------
|
||||
::
|
||||
@@ -74,11 +110,13 @@ class Value:
|
||||
|
||||
Returns
|
||||
-------
|
||||
types.coroutine
|
||||
A coroutine object that must be awaited.
|
||||
`awaitable` mixed with `asynchronous context manager`
|
||||
A coroutine object mixed in with an async context manager. When
|
||||
awaited, this returns the raw data value. When used in :code:`async
|
||||
with` syntax, on gets the value on entrance, and sets it on exit.
|
||||
|
||||
"""
|
||||
return self._get(default)
|
||||
return _ValueCtxManager(self, self._get(default))
|
||||
|
||||
async def set(self, value):
|
||||
"""Set the value of the data elements pointed to by `identifiers`.
|
||||
|
||||
Reference in New Issue
Block a user