[Config] Retrieve/save values with async context manager (#1131)

* [Config] Retrieve/save values with async context manager

* Add a little docstring

* Documentation

* Implement async with syntax in existing modules
This commit is contained in:
Tobotimus
2017-12-04 14:07:34 +11:00
committed by palmtree5
parent 9b1018fa96
commit 9dbf56f942
8 changed files with 165 additions and 75 deletions

View File

@@ -328,10 +328,9 @@ class Admin:
"""
Add a role to the list of available selfroles.
"""
curr_selfroles = await self.conf.guild(ctx.guild).selfroles()
if role.id not in curr_selfroles:
curr_selfroles.append(role.id)
await self.conf.guild(ctx.guild).selfroles.set(curr_selfroles)
async with self.conf.guild(ctx.guild).selfroles() as curr_selfroles:
if role.id not in curr_selfroles:
curr_selfroles.append(role.id)
await ctx.send("The selfroles list has been successfully modified.")
@@ -341,9 +340,8 @@ class Admin:
"""
Removes a role from the list of available selfroles.
"""
curr_selfroles = await self.conf.guild(ctx.guild).selfroles()
curr_selfroles.remove(role.id)
await self.conf.guild(ctx.guild).selfroles.set(curr_selfroles)
async with self.conf.guild(ctx.guild).selfroles() as curr_selfroles:
curr_selfroles.remove(role.id)
await ctx.send("The selfroles list has been successfully modified.")

View File

@@ -82,41 +82,31 @@ class Alias:
alias = AliasEntry(alias_name, command, ctx.author, global_=global_)
if global_:
curr_aliases = await self._aliases.entries()
curr_aliases.append(alias.to_json())
await self._aliases.entries.set(curr_aliases)
settings = self._aliases
else:
curr_aliases = await self._aliases.guild(ctx.guild).entries()
settings = self._aliases.guild(ctx.guild)
await settings.enabled.set(True)
async with settings.entries() as curr_aliases:
curr_aliases.append(alias.to_json())
await self._aliases.guild(ctx.guild).entries.set(curr_aliases)
await self._aliases.guild(ctx.guild).enabled.set(True)
return alias
async def delete_alias(self, ctx: commands.Context, alias_name: str,
global_: bool=False) -> bool:
if global_:
aliases = await self.unloaded_global_aliases()
setter_func = self._aliases.entries.set
settings = self._aliases
else:
aliases = await self.unloaded_aliases(ctx.guild)
setter_func = self._aliases.guild(ctx.guild).entries.set
settings = self._aliases.guild(ctx.guild)
did_delete_alias = False
async with settings.entries() as aliases:
for alias in aliases:
alias_obj = AliasEntry.from_json(alias)
if alias_obj.name == alias_name:
aliases.remove(alias)
return True
to_keep = []
for alias in aliases:
if alias.name != alias_name:
to_keep.append(alias)
else:
did_delete_alias = True
await setter_func(
[a.to_json() for a in to_keep]
)
return did_delete_alias
return False
async def get_prefix(self, message: discord.Message) -> str:
"""

View File

@@ -183,30 +183,24 @@ class Filter:
await ctx.send(_("Count and time have been set."))
async def add_to_filter(self, server: discord.Guild, words: list) -> bool:
added = 0
cur_list = await self.settings.guild(server).filter()
for w in words:
if w.lower() not in cur_list and w != "":
cur_list.append(w.lower())
added += 1
if added:
await self.settings.guild(server).filter.set(cur_list)
return True
else:
return False
added = False
async with self.settings.guild(server).filter() as cur_list:
for w in words:
if w.lower() not in cur_list and w:
cur_list.append(w.lower())
added = True
return added
async def remove_from_filter(self, server: discord.Guild, words: list) -> bool:
removed = 0
cur_list = await self.settings.guild(server).filter()
for w in words:
if w.lower() in cur_list:
cur_list.remove(w.lower())
removed += 1
if removed:
await self.settings.guild(server).filter.set(cur_list)
return True
else:
return False
removed = False
async with self.settings.guild(server).filter() as cur_list:
for w in words:
if w.lower() in cur_list:
cur_list.remove(w.lower())
removed = True
return removed
async def check_filter(self, message: discord.Message):
server = message.guild

View File

@@ -1347,18 +1347,21 @@ class Mod:
async def on_member_update(self, before, after):
if before.name != after.name:
name_list = await self.settings.user(before).past_names()
if after.name not in name_list:
names = deque(name_list, maxlen=20)
names.append(after.name)
await self.settings.user(before).past_names.set(list(names))
async with self.settings.user(before).past_names() as name_list:
if after.nick in name_list:
# Ensure order is maintained without duplicates occuring
name_list.remove(after.nick)
name_list.append(after.nick)
while len(name_list) > 20:
name_list.pop(0)
if before.nick != after.nick and after.nick is not None:
nick_list = await self.settings.member(before).past_nicks()
nicks = deque(nick_list, maxlen=20)
if after.nick not in nicks:
nicks.append(after.nick)
await self.settings.member(before).past_nicks.set(list(nicks))
async with self.settings.member(before).past_nicks() as nick_list:
if after.nick in nick_list:
nick_list.remove(after.nick)
nick_list.append(after.nick)
while len(nick_list) > 20:
nick_list.pop(0)
@staticmethod
def are_overwrites_empty(overwrites):

View File

@@ -142,15 +142,14 @@ class RedBase(BotBase, RpcMethodMixin):
await self.db.packages.set(packages)
async def add_loaded_package(self, pkg_name: str):
curr_pkgs = await self.db.packages()
if pkg_name not in curr_pkgs:
curr_pkgs.append(pkg_name)
await self.save_packages_status(curr_pkgs)
async with self.db.packages() as curr_pkgs:
if pkg_name not in curr_pkgs:
curr_pkgs.append(pkg_name)
async def remove_loaded_package(self, pkg_name: str):
curr_pkgs = await self.db.packages()
if pkg_name in curr_pkgs:
await self.save_packages_status([p for p in curr_pkgs if p != pkg_name])
async with self.db.packages() as curr_pkgs:
while pkg_name in curr_pkgs:
curr_pkgs.remove(pkg_name)
def load_extension(self, spec: ModuleSpec):
name = spec.name.split('.')[-1]

View File

@@ -10,6 +10,37 @@ from .drivers import get_driver
log = logging.getLogger("red.config")
class _ValueCtxManager:
"""Context manager implementation of config values.
This class allows mutable config values to be both "get" and "set" from
within an async context manager.
The context manager can only be used to get and set a mutable data type,
i.e. `dict`s or `list`s. This is because this class's ``raw_value``
attribute must contain a reference to the object being modified within the
context manager.
"""
def __init__(self, value_obj, coro):
self.value_obj = value_obj
self.coro = coro
def __await__(self):
return self.coro.__await__()
async def __aenter__(self):
self.raw_value = await self
if not isinstance(self.raw_value, (list, dict)):
raise TypeError("Type of retrieved value must be mutable (i.e. "
"list or dict) in order to use a config value as "
"a context manager.")
return self.raw_value
async def __aexit__(self, *exc_info):
await self.value_obj.set(self.raw_value)
class Value:
"""A singular "value" of data.
@@ -49,6 +80,11 @@ class Value:
"real" data of the `Value` object is accessed by this method. It is a
replacement for a :code:`get()` method.
The return value of this method can also be used as an asynchronous
context manager, i.e. with :code:`async with` syntax. This can only be
used on values which are mutable (namely lists and dicts), and will
set the value with its changes on exit of the context manager.
Example
-------
::
@@ -74,11 +110,13 @@ class Value:
Returns
-------
types.coroutine
A coroutine object that must be awaited.
`awaitable` mixed with `asynchronous context manager`
A coroutine object mixed in with an async context manager. When
awaited, this returns the raw data value. When used in :code:`async
with` syntax, on gets the value on entrance, and sets it on exit.
"""
return self._get(default)
return _ValueCtxManager(self, self._get(default))
async def set(self, value):
"""Set the value of the data elements pointed to by `identifiers`.