diff --git a/changelog.d/2654.feature.rst b/changelog.d/2654.feature.rst new file mode 100644 index 000000000..eac3ab395 --- /dev/null +++ b/changelog.d/2654.feature.rst @@ -0,0 +1 @@ +Added functions to acquire locks on Config groups and values. These locks are acquired by default when calling a value as a context manager. See :meth:`Value.get_lock` for details diff --git a/redbot/core/config.py b/redbot/core/config.py index 00d9305a1..e81b20c0d 100644 --- a/redbot/core/config.py +++ b/redbot/core/config.py @@ -1,8 +1,19 @@ +import asyncio import collections import logging import pickle import weakref -from typing import Any, Union, Tuple, Dict, Awaitable, AsyncContextManager, TypeVar +from typing import ( + Any, + Union, + Tuple, + Dict, + Awaitable, + AsyncContextManager, + TypeVar, + MutableMapping, + Optional, +) import discord @@ -38,18 +49,26 @@ class _ValueCtxManager(Awaitable[_T], AsyncContextManager[_T]): # pylint: disab i.e. `dict`s or `list`s. This is because this class's ``raw_value`` attribute must contain a reference to the object being modified within the context manager. + + It should also be noted that the use of this context manager implies + the acquisition of the value's lock when the ``acquire_lock`` kwarg + to ``__init__`` is set to ``True``. """ - def __init__(self, value_obj, coro): + def __init__(self, value_obj: "Value", coro: Awaitable[Any], *, acquire_lock: bool): self.value_obj = value_obj self.coro = coro self.raw_value = None self.__original_value = None + self.__acquire_lock = acquire_lock + self.__lock = self.value_obj.get_lock() def __await__(self): return self.coro.__await__() async def __aenter__(self): + if self.__acquire_lock is True: + await self.__lock.acquire() self.raw_value = await self if not isinstance(self.raw_value, (list, dict)): raise TypeError( @@ -61,12 +80,16 @@ class _ValueCtxManager(Awaitable[_T], AsyncContextManager[_T]): # pylint: disab return self.raw_value async def __aexit__(self, exc_type, exc, tb): - if isinstance(self.raw_value, dict): - raw_value = _str_key_dict(self.raw_value) - else: - raw_value = self.raw_value - if raw_value != self.__original_value: - await self.value_obj.set(self.raw_value) + try: + if isinstance(self.raw_value, dict): + raw_value = _str_key_dict(self.raw_value) + else: + raw_value = self.raw_value + if raw_value != self.__original_value: + await self.value_obj.set(self.raw_value) + finally: + if self.__acquire_lock is True: + self.__lock.release() class Value: @@ -74,9 +97,8 @@ class Value: Attributes ---------- - identifiers : Tuple[str] - This attribute provides all the keys necessary to get a specific data - element from a json document. + identifier_data : IdentifierData + Information on identifiers for this value. default The default value for the data element that `identifiers` points at. driver : `redbot.core.drivers.red_base.BaseDriver` @@ -84,10 +106,44 @@ class Value: """ - def __init__(self, identifier_data: IdentifierData, default_value, driver): + def __init__(self, identifier_data: IdentifierData, default_value, driver, config: "Config"): self.identifier_data = identifier_data self.default = default_value self.driver = driver + self._config = config + + def get_lock(self) -> asyncio.Lock: + """Get a lock to create a critical region where this value is accessed. + + When using this lock, make sure you either use it with the + ``async with`` syntax, or if that's not feasible, ensure you + keep a reference to it from the acquisition to the release of + the lock. That is, if you can't use ``async with`` syntax, use + the lock like this:: + + lock = config.foo.get_lock() + await lock.acquire() + # Do stuff... + lock.release() + + Do not use it like this:: + + await config.foo.get_lock().acquire() + # Do stuff... + config.foo.get_lock().release() + + Doing it the latter way will likely cause an error, as the + acquired lock will be cleaned up by the garbage collector before + it is released, meaning the second call to ``get_lock()`` will + return a different lock to the first call. + + Returns + ------- + asyncio.Lock + A lock which is weakly cached for this value object. + + """ + return self._config._lock_cache.setdefault(self.identifier_data, asyncio.Lock()) async def _get(self, default=...): try: @@ -96,7 +152,7 @@ class Value: return default if default is not ... else self.default return ret - def __call__(self, default=...) -> _ValueCtxManager[Any]: + def __call__(self, default=..., *, acquire_lock: bool = True) -> _ValueCtxManager[Any]: """Get the literal value of this data element. Each `Value` object is created by the `Group.__getattr__` method. The @@ -106,7 +162,10 @@ class Value: The return value of this method can also be used as an asynchronous context manager, i.e. with :code:`async with` syntax. This can only be used on values which are mutable (namely lists and dicts), and will - set the value with its changes on exit of the context manager. + set the value with its changes on exit of the context manager. It will + also acquire this value's lock to protect the critical region inside + this context manager's body, unless the ``acquire_lock`` keyword + argument is set to ``False``. Example ------- @@ -129,7 +188,14 @@ class Value: default : `object`, optional This argument acts as an override for the registered default provided by `default`. This argument is ignored if its - value is :code:`None`. + value is :code:`...`. + + Other Parameters + ---------------- + acquire_lock : bool + Set to ``False`` to disable the acquisition of the value's + lock over the context manager body. Defaults to ``True``. + Has no effect when not used as a context manager. Returns ------- @@ -139,7 +205,7 @@ class Value: with` syntax, on gets the value on entrance, and sets it on exit. """ - return _ValueCtxManager(self, self._get(default)) + return _ValueCtxManager(self, self._get(default), acquire_lock=acquire_lock) async def set(self, value): """Set the value of the data elements pointed to by `identifiers`. @@ -194,13 +260,14 @@ class Group(Value): identifier_data: IdentifierData, defaults: dict, driver, + config: "Config", force_registration: bool = False, ): self._defaults = defaults self.force_registration = force_registration self.driver = driver - super().__init__(identifier_data, {}, self.driver) + super().__init__(identifier_data, {}, self.driver, config) @property def defaults(self): @@ -248,17 +315,24 @@ class Group(Value): defaults=self._defaults[item], driver=self.driver, force_registration=self.force_registration, + config=self._config, ) elif is_value: return Value( identifier_data=new_identifiers, default_value=self._defaults[item], driver=self.driver, + config=self._config, ) elif self.force_registration: raise AttributeError("'{}' is not a valid registered Group or value.".format(item)) else: - return Value(identifier_data=new_identifiers, default_value=None, driver=self.driver) + return Value( + identifier_data=new_identifiers, + default_value=None, + driver=self.driver, + config=self._config, + ) async def clear_raw(self, *nested_path: Any): """ @@ -411,7 +485,7 @@ class Group(Value): return self.nested_update(raw, default) return raw - def all(self) -> _ValueCtxManager[Dict[str, Any]]: + def all(self, *, acquire_lock: bool = True) -> _ValueCtxManager[Dict[str, Any]]: """Get a dictionary representation of this group's data. The return value of this method can also be used as an asynchronous @@ -422,13 +496,19 @@ class Group(Value): The return value of this method will include registered defaults for values which have not yet been set. + Other Parameters + ---------------- + acquire_lock : bool + Same as the ``acquire_lock`` keyword parameter in + `Value.__call__`. + Returns ------- dict All of this Group's attributes, resolved as raw data values. """ - return self() + return self(acquire_lock=acquire_lock) def nested_update( self, current: collections.Mapping, defaults: Dict[str, Any] = ... @@ -555,6 +635,9 @@ class Config: self._defaults = defaults or {} self.custom_groups = {} + self._lock_cache: MutableMapping[ + IdentifierData, asyncio.Lock + ] = weakref.WeakValueDictionary() @property def defaults(self): @@ -862,6 +945,7 @@ class Config: defaults=defaults, driver=self.driver, force_registration=self.force_registration, + config=self, ) def guild(self, guild: discord.Guild) -> Group: @@ -1139,7 +1223,7 @@ class Config: identifier_data = IdentifierData( self.unique_identifier, "", (), (), self.custom_groups ) - group = Group(identifier_data, defaults={}, driver=self.driver) + group = Group(identifier_data, defaults={}, driver=self.driver, config=self) else: cat, *scopes = scopes group = self._get_base_group(cat, *scopes) @@ -1222,6 +1306,80 @@ class Config: """ await self._clear_scope(str(group_identifier)) + def get_guilds_lock(self) -> asyncio.Lock: + """Get a lock for all guild data. + + Returns + ------- + asyncio.Lock + """ + return self.get_custom_lock(self.GUILD) + + def get_channels_lock(self) -> asyncio.Lock: + """Get a lock for all channel data. + + Returns + ------- + asyncio.Lock + """ + return self.get_custom_lock(self.CHANNEL) + + def get_roles_lock(self) -> asyncio.Lock: + """Get a lock for all role data. + + Returns + ------- + asyncio.Lock + """ + return self.get_custom_lock(self.ROLE) + + def get_users_lock(self) -> asyncio.Lock: + """Get a lock for all user data. + + Returns + ------- + asyncio.Lock + """ + return self.get_custom_lock(self.USER) + + def get_members_lock(self, guild: Optional[discord.Guild] = None) -> asyncio.Lock: + """Get a lock for all member data. + + Parameters + ---------- + guild : Optional[discord.Guild] + The guild containing the members whose data you want to + lock. Omit to lock all data for all members in all guilds. + + Returns + ------- + asyncio.Lock + """ + if guild is None: + return self.get_custom_lock(self.GUILD) + else: + id_data = IdentifierData( + self.unique_identifier, self.MEMBER, (str(guild.id),), (), self.custom_groups + ) + return self._lock_cache.setdefault(id_data, asyncio.Lock()) + + def get_custom_lock(self, group_identifier: str) -> asyncio.Lock: + """Get a lock for all data in a custom scope. + + Parameters + ---------- + group_identifier : str + The group identifier for the custom scope you want to lock. + + Returns + ------- + asyncio.Lock + """ + id_data = IdentifierData( + self.unique_identifier, group_identifier, (), (), self.custom_groups + ) + return self._lock_cache.setdefault(id_data, asyncio.Lock()) + def _str_key_dict(value: Dict[Any, _T]) -> Dict[str, _T]: """ diff --git a/redbot/core/drivers/red_base.py b/redbot/core/drivers/red_base.py index 2e5d41c6f..de28b1b9a 100644 --- a/redbot/core/drivers/red_base.py +++ b/redbot/core/drivers/red_base.py @@ -56,6 +56,19 @@ class IdentifierData: f" identifiers={self.identifiers}>" ) + def __eq__(self, other) -> bool: + if not isinstance(other, IdentifierData): + return False + return ( + self.uuid == other.uuid + and self.category == other.category + and self.primary_key == other.primary_key + and self.identifiers == other.identifiers + ) + + def __hash__(self) -> int: + return hash((self.uuid, self.category, self.primary_key, self.identifiers)) + def add_identifier(self, *identifier: str) -> "IdentifierData": if not all(isinstance(i, str) for i in identifier): raise ValueError("Identifiers must be strings.") diff --git a/tests/core/test_config.py b/tests/core/test_config.py index a1e74dd2f..5cd9e0e90 100644 --- a/tests/core/test_config.py +++ b/tests/core/test_config.py @@ -1,3 +1,4 @@ +import asyncio from unittest.mock import patch import pytest @@ -506,3 +507,50 @@ def test_config_custom_doubleinit(config): config.init_custom("TEST", 3) with pytest.raises(ValueError): config.init_custom("TEST", 2) + + +@pytest.mark.asyncio +async def test_config_locks_cache(config, empty_guild): + lock1 = config.foo.get_lock() + assert lock1 is config.foo.get_lock() + lock2 = config.guild(empty_guild).foo.get_lock() + assert lock2 is config.guild(empty_guild).foo.get_lock() + assert lock1 is not lock2 + + +@pytest.mark.asyncio +async def test_config_value_atomicity(config): + config.register_global(foo=[]) + tasks = [] + for _ in range(15): + + async def func(): + async with config.foo.get_lock(): + foo = await config.foo() + foo.append(0) + await asyncio.sleep(0.1) + await config.foo.set(foo) + + tasks.append(func()) + + await asyncio.wait(tasks, return_when=asyncio.ALL_COMPLETED) + + assert len(await config.foo()) == 15 + + +@pytest.mark.asyncio +async def test_config_ctxmgr_atomicity(config): + config.register_global(foo=[]) + tasks = [] + for _ in range(15): + + async def func(): + async with config.foo() as foo: + foo.append(0) + await asyncio.sleep(0.1) + + tasks.append(func()) + + await asyncio.wait(tasks, return_when=asyncio.ALL_COMPLETED) + + assert len(await config.foo()) == 15