diff --git a/docs/framework_config.rst b/docs/framework_config.rst index b5388ffdf..fd392dd26 100644 --- a/docs/framework_config.rst +++ b/docs/framework_config.rst @@ -29,7 +29,7 @@ Basic Usage @commands.command() async def return_some_data(self, ctx): - await ctx.send(config.foo()) + await ctx.send(await config.foo()) ******** Tutorial @@ -117,11 +117,36 @@ Notice a few things in the above examples: 3. If you're getting the value, the syntax is:: - self.config..variable_name() + self.config..variable_name() 4. If setting, it's:: - self.config..variable_name.set(new_value) + self.config..variable_name.set(new_value) + +It is also possible to use :code:`async with` syntax to get and set config +values. When entering the statement, the config value is retreived, and on exit, +it is saved. This puts a safeguard on any code within the :code:`async with` +block such that if it breaks from the block in any way (whether it be from +:code:`return`, :code:`break`, :code:`continue` or an exception), the value will +still be saved. + +.. important:: + + Only mutable config values can be used in the :code:`async with` statement + (namely lists or dicts), and they must be modified *in place* for their + changes to be saved. + +Here is an example of the :code:`async with` syntax: + +.. code-block:: python + + @commands.command() + async def addblah(self, ctx, new_blah): + guild_group = self.config.guild(ctx.guild) + async with guild_group.blah() as blah: + blah.append(new_blah) + await ctx.send("The new blah value has been added!") + .. important:: diff --git a/redbot/cogs/admin/admin.py b/redbot/cogs/admin/admin.py index df516d2e6..424162151 100644 --- a/redbot/cogs/admin/admin.py +++ b/redbot/cogs/admin/admin.py @@ -328,10 +328,9 @@ class Admin: """ Add a role to the list of available selfroles. """ - curr_selfroles = await self.conf.guild(ctx.guild).selfroles() - if role.id not in curr_selfroles: - curr_selfroles.append(role.id) - await self.conf.guild(ctx.guild).selfroles.set(curr_selfroles) + async with self.conf.guild(ctx.guild).selfroles() as curr_selfroles: + if role.id not in curr_selfroles: + curr_selfroles.append(role.id) await ctx.send("The selfroles list has been successfully modified.") @@ -341,9 +340,8 @@ class Admin: """ Removes a role from the list of available selfroles. """ - curr_selfroles = await self.conf.guild(ctx.guild).selfroles() - curr_selfroles.remove(role.id) - await self.conf.guild(ctx.guild).selfroles.set(curr_selfroles) + async with self.conf.guild(ctx.guild).selfroles() as curr_selfroles: + curr_selfroles.remove(role.id) await ctx.send("The selfroles list has been successfully modified.") diff --git a/redbot/cogs/alias/alias.py b/redbot/cogs/alias/alias.py index 7907767a8..7fd25ec12 100644 --- a/redbot/cogs/alias/alias.py +++ b/redbot/cogs/alias/alias.py @@ -82,41 +82,31 @@ class Alias: alias = AliasEntry(alias_name, command, ctx.author, global_=global_) if global_: - curr_aliases = await self._aliases.entries() - curr_aliases.append(alias.to_json()) - await self._aliases.entries.set(curr_aliases) + settings = self._aliases else: - curr_aliases = await self._aliases.guild(ctx.guild).entries() + settings = self._aliases.guild(ctx.guild) + await settings.enabled.set(True) + async with settings.entries() as curr_aliases: curr_aliases.append(alias.to_json()) - await self._aliases.guild(ctx.guild).entries.set(curr_aliases) - await self._aliases.guild(ctx.guild).enabled.set(True) return alias async def delete_alias(self, ctx: commands.Context, alias_name: str, global_: bool=False) -> bool: if global_: - aliases = await self.unloaded_global_aliases() - setter_func = self._aliases.entries.set + settings = self._aliases else: - aliases = await self.unloaded_aliases(ctx.guild) - setter_func = self._aliases.guild(ctx.guild).entries.set + settings = self._aliases.guild(ctx.guild) - did_delete_alias = False + async with settings.entries() as aliases: + for alias in aliases: + alias_obj = AliasEntry.from_json(alias) + if alias_obj.name == alias_name: + aliases.remove(alias) + return True - to_keep = [] - for alias in aliases: - if alias.name != alias_name: - to_keep.append(alias) - else: - did_delete_alias = True - - await setter_func( - [a.to_json() for a in to_keep] - ) - - return did_delete_alias + return False async def get_prefix(self, message: discord.Message) -> str: """ diff --git a/redbot/cogs/filter/filter.py b/redbot/cogs/filter/filter.py index 4c9c21075..dc5dbc1e6 100644 --- a/redbot/cogs/filter/filter.py +++ b/redbot/cogs/filter/filter.py @@ -183,30 +183,24 @@ class Filter: await ctx.send(_("Count and time have been set.")) async def add_to_filter(self, server: discord.Guild, words: list) -> bool: - added = 0 - cur_list = await self.settings.guild(server).filter() - for w in words: - if w.lower() not in cur_list and w != "": - cur_list.append(w.lower()) - added += 1 - if added: - await self.settings.guild(server).filter.set(cur_list) - return True - else: - return False + added = False + async with self.settings.guild(server).filter() as cur_list: + for w in words: + if w.lower() not in cur_list and w: + cur_list.append(w.lower()) + added = True + + return added async def remove_from_filter(self, server: discord.Guild, words: list) -> bool: - removed = 0 - cur_list = await self.settings.guild(server).filter() - for w in words: - if w.lower() in cur_list: - cur_list.remove(w.lower()) - removed += 1 - if removed: - await self.settings.guild(server).filter.set(cur_list) - return True - else: - return False + removed = False + async with self.settings.guild(server).filter() as cur_list: + for w in words: + if w.lower() in cur_list: + cur_list.remove(w.lower()) + removed = True + + return removed async def check_filter(self, message: discord.Message): server = message.guild diff --git a/redbot/cogs/mod/mod.py b/redbot/cogs/mod/mod.py index 23790748c..21b333c2c 100644 --- a/redbot/cogs/mod/mod.py +++ b/redbot/cogs/mod/mod.py @@ -1347,18 +1347,21 @@ class Mod: async def on_member_update(self, before, after): if before.name != after.name: - name_list = await self.settings.user(before).past_names() - if after.name not in name_list: - names = deque(name_list, maxlen=20) - names.append(after.name) - await self.settings.user(before).past_names.set(list(names)) + async with self.settings.user(before).past_names() as name_list: + if after.nick in name_list: + # Ensure order is maintained without duplicates occuring + name_list.remove(after.nick) + name_list.append(after.nick) + while len(name_list) > 20: + name_list.pop(0) if before.nick != after.nick and after.nick is not None: - nick_list = await self.settings.member(before).past_nicks() - nicks = deque(nick_list, maxlen=20) - if after.nick not in nicks: - nicks.append(after.nick) - await self.settings.member(before).past_nicks.set(list(nicks)) + async with self.settings.member(before).past_nicks() as nick_list: + if after.nick in nick_list: + nick_list.remove(after.nick) + nick_list.append(after.nick) + while len(nick_list) > 20: + nick_list.pop(0) @staticmethod def are_overwrites_empty(overwrites): diff --git a/redbot/core/bot.py b/redbot/core/bot.py index 5c67d227a..6c5eb773d 100644 --- a/redbot/core/bot.py +++ b/redbot/core/bot.py @@ -142,15 +142,14 @@ class RedBase(BotBase, RpcMethodMixin): await self.db.packages.set(packages) async def add_loaded_package(self, pkg_name: str): - curr_pkgs = await self.db.packages() - if pkg_name not in curr_pkgs: - curr_pkgs.append(pkg_name) - await self.save_packages_status(curr_pkgs) + async with self.db.packages() as curr_pkgs: + if pkg_name not in curr_pkgs: + curr_pkgs.append(pkg_name) async def remove_loaded_package(self, pkg_name: str): - curr_pkgs = await self.db.packages() - if pkg_name in curr_pkgs: - await self.save_packages_status([p for p in curr_pkgs if p != pkg_name]) + async with self.db.packages() as curr_pkgs: + while pkg_name in curr_pkgs: + curr_pkgs.remove(pkg_name) def load_extension(self, spec: ModuleSpec): name = spec.name.split('.')[-1] diff --git a/redbot/core/config.py b/redbot/core/config.py index a53fb2185..f6fdcb082 100644 --- a/redbot/core/config.py +++ b/redbot/core/config.py @@ -10,6 +10,37 @@ from .drivers import get_driver log = logging.getLogger("red.config") +class _ValueCtxManager: + """Context manager implementation of config values. + + This class allows mutable config values to be both "get" and "set" from + within an async context manager. + + The context manager can only be used to get and set a mutable data type, + i.e. `dict`s or `list`s. This is because this class's ``raw_value`` + attribute must contain a reference to the object being modified within the + context manager. + """ + + def __init__(self, value_obj, coro): + self.value_obj = value_obj + self.coro = coro + + def __await__(self): + return self.coro.__await__() + + async def __aenter__(self): + self.raw_value = await self + if not isinstance(self.raw_value, (list, dict)): + raise TypeError("Type of retrieved value must be mutable (i.e. " + "list or dict) in order to use a config value as " + "a context manager.") + return self.raw_value + + async def __aexit__(self, *exc_info): + await self.value_obj.set(self.raw_value) + + class Value: """A singular "value" of data. @@ -49,6 +80,11 @@ class Value: "real" data of the `Value` object is accessed by this method. It is a replacement for a :code:`get()` method. + The return value of this method can also be used as an asynchronous + context manager, i.e. with :code:`async with` syntax. This can only be + used on values which are mutable (namely lists and dicts), and will + set the value with its changes on exit of the context manager. + Example ------- :: @@ -74,11 +110,13 @@ class Value: Returns ------- - types.coroutine - A coroutine object that must be awaited. + `awaitable` mixed with `asynchronous context manager` + A coroutine object mixed in with an async context manager. When + awaited, this returns the raw data value. When used in :code:`async + with` syntax, on gets the value on entrance, and sets it on exit. """ - return self._get(default) + return _ValueCtxManager(self, self._get(default)) async def set(self, value): """Set the value of the data elements pointed to by `identifiers`. diff --git a/tests/core/test_config.py b/tests/core/test_config.py index 184691cf0..3081ee0ce 100644 --- a/tests/core/test_config.py +++ b/tests/core/test_config.py @@ -333,3 +333,46 @@ async def test_user_getalldata(config, user_factory): assert "bar" in all_data assert config.user(user).defaults['foo'] is True + +@pytest.mark.asyncio +async def test_value_ctxmgr(config): + config.register_global(foo_list=[]) + + async with config.foo_list() as foo_list: + foo_list.append('foo') + + foo_list = await config.foo_list() + + assert 'foo' in foo_list + + +@pytest.mark.asyncio +async def test_value_ctxmgr_saves(config): + config.register_global(bar_list=[]) + + try: + async with config.bar_list() as bar_list: + bar_list.append('bar') + raise RuntimeError() + except RuntimeError: + pass + + bar_list = await config.bar_list() + + assert 'bar' in bar_list + + +@pytest.mark.asyncio +async def test_value_ctxmgr_immutable(config): + config.register_global(foo=True) + + try: + async with config.foo() as foo: + foo = False + except TypeError: + pass + else: + raise AssertionError + + foo = await config.foo() + assert foo is True