mirror of
https://github.com/Cog-Creators/Red-DiscordBot.git
synced 2025-11-06 03:08:55 -05:00
Config locks (#2654)
* Config locks Signed-off-by: Toby Harradine <tobyharradine@gmail.com> * Add locks for all_XXX Signed-off-by: Toby Harradine <tobyharradine@gmail.com> * Remove a word Signed-off-by: Toby Harradine <tobyharradine@gmail.com> * Add acquire_lock kwarg for value context manager Signed-off-by: Toby Harradine <tobyharradine@gmail.com> * Add towncrier entry Signed-off-by: Toby <tobyharradine@gmail.com> * Fix issues with `get_custom_lock` and `get_members_lock` Signed-off-by: Toby Harradine <tobyharradine@gmail.com>
This commit is contained in:
parent
a8091332b8
commit
af096bc1cc
1
changelog.d/2654.feature.rst
Normal file
1
changelog.d/2654.feature.rst
Normal file
@ -0,0 +1 @@
|
|||||||
|
Added functions to acquire locks on Config groups and values. These locks are acquired by default when calling a value as a context manager. See :meth:`Value.get_lock` for details
|
||||||
@ -1,8 +1,19 @@
|
|||||||
|
import asyncio
|
||||||
import collections
|
import collections
|
||||||
import logging
|
import logging
|
||||||
import pickle
|
import pickle
|
||||||
import weakref
|
import weakref
|
||||||
from typing import Any, Union, Tuple, Dict, Awaitable, AsyncContextManager, TypeVar
|
from typing import (
|
||||||
|
Any,
|
||||||
|
Union,
|
||||||
|
Tuple,
|
||||||
|
Dict,
|
||||||
|
Awaitable,
|
||||||
|
AsyncContextManager,
|
||||||
|
TypeVar,
|
||||||
|
MutableMapping,
|
||||||
|
Optional,
|
||||||
|
)
|
||||||
|
|
||||||
import discord
|
import discord
|
||||||
|
|
||||||
@ -38,18 +49,26 @@ class _ValueCtxManager(Awaitable[_T], AsyncContextManager[_T]): # pylint: disab
|
|||||||
i.e. `dict`s or `list`s. This is because this class's ``raw_value``
|
i.e. `dict`s or `list`s. This is because this class's ``raw_value``
|
||||||
attribute must contain a reference to the object being modified within the
|
attribute must contain a reference to the object being modified within the
|
||||||
context manager.
|
context manager.
|
||||||
|
|
||||||
|
It should also be noted that the use of this context manager implies
|
||||||
|
the acquisition of the value's lock when the ``acquire_lock`` kwarg
|
||||||
|
to ``__init__`` is set to ``True``.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, value_obj, coro):
|
def __init__(self, value_obj: "Value", coro: Awaitable[Any], *, acquire_lock: bool):
|
||||||
self.value_obj = value_obj
|
self.value_obj = value_obj
|
||||||
self.coro = coro
|
self.coro = coro
|
||||||
self.raw_value = None
|
self.raw_value = None
|
||||||
self.__original_value = None
|
self.__original_value = None
|
||||||
|
self.__acquire_lock = acquire_lock
|
||||||
|
self.__lock = self.value_obj.get_lock()
|
||||||
|
|
||||||
def __await__(self):
|
def __await__(self):
|
||||||
return self.coro.__await__()
|
return self.coro.__await__()
|
||||||
|
|
||||||
async def __aenter__(self):
|
async def __aenter__(self):
|
||||||
|
if self.__acquire_lock is True:
|
||||||
|
await self.__lock.acquire()
|
||||||
self.raw_value = await self
|
self.raw_value = await self
|
||||||
if not isinstance(self.raw_value, (list, dict)):
|
if not isinstance(self.raw_value, (list, dict)):
|
||||||
raise TypeError(
|
raise TypeError(
|
||||||
@ -61,12 +80,16 @@ class _ValueCtxManager(Awaitable[_T], AsyncContextManager[_T]): # pylint: disab
|
|||||||
return self.raw_value
|
return self.raw_value
|
||||||
|
|
||||||
async def __aexit__(self, exc_type, exc, tb):
|
async def __aexit__(self, exc_type, exc, tb):
|
||||||
|
try:
|
||||||
if isinstance(self.raw_value, dict):
|
if isinstance(self.raw_value, dict):
|
||||||
raw_value = _str_key_dict(self.raw_value)
|
raw_value = _str_key_dict(self.raw_value)
|
||||||
else:
|
else:
|
||||||
raw_value = self.raw_value
|
raw_value = self.raw_value
|
||||||
if raw_value != self.__original_value:
|
if raw_value != self.__original_value:
|
||||||
await self.value_obj.set(self.raw_value)
|
await self.value_obj.set(self.raw_value)
|
||||||
|
finally:
|
||||||
|
if self.__acquire_lock is True:
|
||||||
|
self.__lock.release()
|
||||||
|
|
||||||
|
|
||||||
class Value:
|
class Value:
|
||||||
@ -74,9 +97,8 @@ class Value:
|
|||||||
|
|
||||||
Attributes
|
Attributes
|
||||||
----------
|
----------
|
||||||
identifiers : Tuple[str]
|
identifier_data : IdentifierData
|
||||||
This attribute provides all the keys necessary to get a specific data
|
Information on identifiers for this value.
|
||||||
element from a json document.
|
|
||||||
default
|
default
|
||||||
The default value for the data element that `identifiers` points at.
|
The default value for the data element that `identifiers` points at.
|
||||||
driver : `redbot.core.drivers.red_base.BaseDriver`
|
driver : `redbot.core.drivers.red_base.BaseDriver`
|
||||||
@ -84,10 +106,44 @@ class Value:
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, identifier_data: IdentifierData, default_value, driver):
|
def __init__(self, identifier_data: IdentifierData, default_value, driver, config: "Config"):
|
||||||
self.identifier_data = identifier_data
|
self.identifier_data = identifier_data
|
||||||
self.default = default_value
|
self.default = default_value
|
||||||
self.driver = driver
|
self.driver = driver
|
||||||
|
self._config = config
|
||||||
|
|
||||||
|
def get_lock(self) -> asyncio.Lock:
|
||||||
|
"""Get a lock to create a critical region where this value is accessed.
|
||||||
|
|
||||||
|
When using this lock, make sure you either use it with the
|
||||||
|
``async with`` syntax, or if that's not feasible, ensure you
|
||||||
|
keep a reference to it from the acquisition to the release of
|
||||||
|
the lock. That is, if you can't use ``async with`` syntax, use
|
||||||
|
the lock like this::
|
||||||
|
|
||||||
|
lock = config.foo.get_lock()
|
||||||
|
await lock.acquire()
|
||||||
|
# Do stuff...
|
||||||
|
lock.release()
|
||||||
|
|
||||||
|
Do not use it like this::
|
||||||
|
|
||||||
|
await config.foo.get_lock().acquire()
|
||||||
|
# Do stuff...
|
||||||
|
config.foo.get_lock().release()
|
||||||
|
|
||||||
|
Doing it the latter way will likely cause an error, as the
|
||||||
|
acquired lock will be cleaned up by the garbage collector before
|
||||||
|
it is released, meaning the second call to ``get_lock()`` will
|
||||||
|
return a different lock to the first call.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
asyncio.Lock
|
||||||
|
A lock which is weakly cached for this value object.
|
||||||
|
|
||||||
|
"""
|
||||||
|
return self._config._lock_cache.setdefault(self.identifier_data, asyncio.Lock())
|
||||||
|
|
||||||
async def _get(self, default=...):
|
async def _get(self, default=...):
|
||||||
try:
|
try:
|
||||||
@ -96,7 +152,7 @@ class Value:
|
|||||||
return default if default is not ... else self.default
|
return default if default is not ... else self.default
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
def __call__(self, default=...) -> _ValueCtxManager[Any]:
|
def __call__(self, default=..., *, acquire_lock: bool = True) -> _ValueCtxManager[Any]:
|
||||||
"""Get the literal value of this data element.
|
"""Get the literal value of this data element.
|
||||||
|
|
||||||
Each `Value` object is created by the `Group.__getattr__` method. The
|
Each `Value` object is created by the `Group.__getattr__` method. The
|
||||||
@ -106,7 +162,10 @@ class Value:
|
|||||||
The return value of this method can also be used as an asynchronous
|
The return value of this method can also be used as an asynchronous
|
||||||
context manager, i.e. with :code:`async with` syntax. This can only be
|
context manager, i.e. with :code:`async with` syntax. This can only be
|
||||||
used on values which are mutable (namely lists and dicts), and will
|
used on values which are mutable (namely lists and dicts), and will
|
||||||
set the value with its changes on exit of the context manager.
|
set the value with its changes on exit of the context manager. It will
|
||||||
|
also acquire this value's lock to protect the critical region inside
|
||||||
|
this context manager's body, unless the ``acquire_lock`` keyword
|
||||||
|
argument is set to ``False``.
|
||||||
|
|
||||||
Example
|
Example
|
||||||
-------
|
-------
|
||||||
@ -129,7 +188,14 @@ class Value:
|
|||||||
default : `object`, optional
|
default : `object`, optional
|
||||||
This argument acts as an override for the registered default
|
This argument acts as an override for the registered default
|
||||||
provided by `default`. This argument is ignored if its
|
provided by `default`. This argument is ignored if its
|
||||||
value is :code:`None`.
|
value is :code:`...`.
|
||||||
|
|
||||||
|
Other Parameters
|
||||||
|
----------------
|
||||||
|
acquire_lock : bool
|
||||||
|
Set to ``False`` to disable the acquisition of the value's
|
||||||
|
lock over the context manager body. Defaults to ``True``.
|
||||||
|
Has no effect when not used as a context manager.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
@ -139,7 +205,7 @@ class Value:
|
|||||||
with` syntax, on gets the value on entrance, and sets it on exit.
|
with` syntax, on gets the value on entrance, and sets it on exit.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
return _ValueCtxManager(self, self._get(default))
|
return _ValueCtxManager(self, self._get(default), acquire_lock=acquire_lock)
|
||||||
|
|
||||||
async def set(self, value):
|
async def set(self, value):
|
||||||
"""Set the value of the data elements pointed to by `identifiers`.
|
"""Set the value of the data elements pointed to by `identifiers`.
|
||||||
@ -194,13 +260,14 @@ class Group(Value):
|
|||||||
identifier_data: IdentifierData,
|
identifier_data: IdentifierData,
|
||||||
defaults: dict,
|
defaults: dict,
|
||||||
driver,
|
driver,
|
||||||
|
config: "Config",
|
||||||
force_registration: bool = False,
|
force_registration: bool = False,
|
||||||
):
|
):
|
||||||
self._defaults = defaults
|
self._defaults = defaults
|
||||||
self.force_registration = force_registration
|
self.force_registration = force_registration
|
||||||
self.driver = driver
|
self.driver = driver
|
||||||
|
|
||||||
super().__init__(identifier_data, {}, self.driver)
|
super().__init__(identifier_data, {}, self.driver, config)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def defaults(self):
|
def defaults(self):
|
||||||
@ -248,17 +315,24 @@ class Group(Value):
|
|||||||
defaults=self._defaults[item],
|
defaults=self._defaults[item],
|
||||||
driver=self.driver,
|
driver=self.driver,
|
||||||
force_registration=self.force_registration,
|
force_registration=self.force_registration,
|
||||||
|
config=self._config,
|
||||||
)
|
)
|
||||||
elif is_value:
|
elif is_value:
|
||||||
return Value(
|
return Value(
|
||||||
identifier_data=new_identifiers,
|
identifier_data=new_identifiers,
|
||||||
default_value=self._defaults[item],
|
default_value=self._defaults[item],
|
||||||
driver=self.driver,
|
driver=self.driver,
|
||||||
|
config=self._config,
|
||||||
)
|
)
|
||||||
elif self.force_registration:
|
elif self.force_registration:
|
||||||
raise AttributeError("'{}' is not a valid registered Group or value.".format(item))
|
raise AttributeError("'{}' is not a valid registered Group or value.".format(item))
|
||||||
else:
|
else:
|
||||||
return Value(identifier_data=new_identifiers, default_value=None, driver=self.driver)
|
return Value(
|
||||||
|
identifier_data=new_identifiers,
|
||||||
|
default_value=None,
|
||||||
|
driver=self.driver,
|
||||||
|
config=self._config,
|
||||||
|
)
|
||||||
|
|
||||||
async def clear_raw(self, *nested_path: Any):
|
async def clear_raw(self, *nested_path: Any):
|
||||||
"""
|
"""
|
||||||
@ -411,7 +485,7 @@ class Group(Value):
|
|||||||
return self.nested_update(raw, default)
|
return self.nested_update(raw, default)
|
||||||
return raw
|
return raw
|
||||||
|
|
||||||
def all(self) -> _ValueCtxManager[Dict[str, Any]]:
|
def all(self, *, acquire_lock: bool = True) -> _ValueCtxManager[Dict[str, Any]]:
|
||||||
"""Get a dictionary representation of this group's data.
|
"""Get a dictionary representation of this group's data.
|
||||||
|
|
||||||
The return value of this method can also be used as an asynchronous
|
The return value of this method can also be used as an asynchronous
|
||||||
@ -422,13 +496,19 @@ class Group(Value):
|
|||||||
The return value of this method will include registered defaults for
|
The return value of this method will include registered defaults for
|
||||||
values which have not yet been set.
|
values which have not yet been set.
|
||||||
|
|
||||||
|
Other Parameters
|
||||||
|
----------------
|
||||||
|
acquire_lock : bool
|
||||||
|
Same as the ``acquire_lock`` keyword parameter in
|
||||||
|
`Value.__call__`.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
dict
|
dict
|
||||||
All of this Group's attributes, resolved as raw data values.
|
All of this Group's attributes, resolved as raw data values.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
return self()
|
return self(acquire_lock=acquire_lock)
|
||||||
|
|
||||||
def nested_update(
|
def nested_update(
|
||||||
self, current: collections.Mapping, defaults: Dict[str, Any] = ...
|
self, current: collections.Mapping, defaults: Dict[str, Any] = ...
|
||||||
@ -555,6 +635,9 @@ class Config:
|
|||||||
self._defaults = defaults or {}
|
self._defaults = defaults or {}
|
||||||
|
|
||||||
self.custom_groups = {}
|
self.custom_groups = {}
|
||||||
|
self._lock_cache: MutableMapping[
|
||||||
|
IdentifierData, asyncio.Lock
|
||||||
|
] = weakref.WeakValueDictionary()
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def defaults(self):
|
def defaults(self):
|
||||||
@ -862,6 +945,7 @@ class Config:
|
|||||||
defaults=defaults,
|
defaults=defaults,
|
||||||
driver=self.driver,
|
driver=self.driver,
|
||||||
force_registration=self.force_registration,
|
force_registration=self.force_registration,
|
||||||
|
config=self,
|
||||||
)
|
)
|
||||||
|
|
||||||
def guild(self, guild: discord.Guild) -> Group:
|
def guild(self, guild: discord.Guild) -> Group:
|
||||||
@ -1139,7 +1223,7 @@ class Config:
|
|||||||
identifier_data = IdentifierData(
|
identifier_data = IdentifierData(
|
||||||
self.unique_identifier, "", (), (), self.custom_groups
|
self.unique_identifier, "", (), (), self.custom_groups
|
||||||
)
|
)
|
||||||
group = Group(identifier_data, defaults={}, driver=self.driver)
|
group = Group(identifier_data, defaults={}, driver=self.driver, config=self)
|
||||||
else:
|
else:
|
||||||
cat, *scopes = scopes
|
cat, *scopes = scopes
|
||||||
group = self._get_base_group(cat, *scopes)
|
group = self._get_base_group(cat, *scopes)
|
||||||
@ -1222,6 +1306,80 @@ class Config:
|
|||||||
"""
|
"""
|
||||||
await self._clear_scope(str(group_identifier))
|
await self._clear_scope(str(group_identifier))
|
||||||
|
|
||||||
|
def get_guilds_lock(self) -> asyncio.Lock:
|
||||||
|
"""Get a lock for all guild data.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
asyncio.Lock
|
||||||
|
"""
|
||||||
|
return self.get_custom_lock(self.GUILD)
|
||||||
|
|
||||||
|
def get_channels_lock(self) -> asyncio.Lock:
|
||||||
|
"""Get a lock for all channel data.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
asyncio.Lock
|
||||||
|
"""
|
||||||
|
return self.get_custom_lock(self.CHANNEL)
|
||||||
|
|
||||||
|
def get_roles_lock(self) -> asyncio.Lock:
|
||||||
|
"""Get a lock for all role data.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
asyncio.Lock
|
||||||
|
"""
|
||||||
|
return self.get_custom_lock(self.ROLE)
|
||||||
|
|
||||||
|
def get_users_lock(self) -> asyncio.Lock:
|
||||||
|
"""Get a lock for all user data.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
asyncio.Lock
|
||||||
|
"""
|
||||||
|
return self.get_custom_lock(self.USER)
|
||||||
|
|
||||||
|
def get_members_lock(self, guild: Optional[discord.Guild] = None) -> asyncio.Lock:
|
||||||
|
"""Get a lock for all member data.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
guild : Optional[discord.Guild]
|
||||||
|
The guild containing the members whose data you want to
|
||||||
|
lock. Omit to lock all data for all members in all guilds.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
asyncio.Lock
|
||||||
|
"""
|
||||||
|
if guild is None:
|
||||||
|
return self.get_custom_lock(self.GUILD)
|
||||||
|
else:
|
||||||
|
id_data = IdentifierData(
|
||||||
|
self.unique_identifier, self.MEMBER, (str(guild.id),), (), self.custom_groups
|
||||||
|
)
|
||||||
|
return self._lock_cache.setdefault(id_data, asyncio.Lock())
|
||||||
|
|
||||||
|
def get_custom_lock(self, group_identifier: str) -> asyncio.Lock:
|
||||||
|
"""Get a lock for all data in a custom scope.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
group_identifier : str
|
||||||
|
The group identifier for the custom scope you want to lock.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
asyncio.Lock
|
||||||
|
"""
|
||||||
|
id_data = IdentifierData(
|
||||||
|
self.unique_identifier, group_identifier, (), (), self.custom_groups
|
||||||
|
)
|
||||||
|
return self._lock_cache.setdefault(id_data, asyncio.Lock())
|
||||||
|
|
||||||
|
|
||||||
def _str_key_dict(value: Dict[Any, _T]) -> Dict[str, _T]:
|
def _str_key_dict(value: Dict[Any, _T]) -> Dict[str, _T]:
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -56,6 +56,19 @@ class IdentifierData:
|
|||||||
f" identifiers={self.identifiers}>"
|
f" identifiers={self.identifiers}>"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def __eq__(self, other) -> bool:
|
||||||
|
if not isinstance(other, IdentifierData):
|
||||||
|
return False
|
||||||
|
return (
|
||||||
|
self.uuid == other.uuid
|
||||||
|
and self.category == other.category
|
||||||
|
and self.primary_key == other.primary_key
|
||||||
|
and self.identifiers == other.identifiers
|
||||||
|
)
|
||||||
|
|
||||||
|
def __hash__(self) -> int:
|
||||||
|
return hash((self.uuid, self.category, self.primary_key, self.identifiers))
|
||||||
|
|
||||||
def add_identifier(self, *identifier: str) -> "IdentifierData":
|
def add_identifier(self, *identifier: str) -> "IdentifierData":
|
||||||
if not all(isinstance(i, str) for i in identifier):
|
if not all(isinstance(i, str) for i in identifier):
|
||||||
raise ValueError("Identifiers must be strings.")
|
raise ValueError("Identifiers must be strings.")
|
||||||
|
|||||||
@ -1,3 +1,4 @@
|
|||||||
|
import asyncio
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
@ -506,3 +507,50 @@ def test_config_custom_doubleinit(config):
|
|||||||
config.init_custom("TEST", 3)
|
config.init_custom("TEST", 3)
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
config.init_custom("TEST", 2)
|
config.init_custom("TEST", 2)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_config_locks_cache(config, empty_guild):
|
||||||
|
lock1 = config.foo.get_lock()
|
||||||
|
assert lock1 is config.foo.get_lock()
|
||||||
|
lock2 = config.guild(empty_guild).foo.get_lock()
|
||||||
|
assert lock2 is config.guild(empty_guild).foo.get_lock()
|
||||||
|
assert lock1 is not lock2
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_config_value_atomicity(config):
|
||||||
|
config.register_global(foo=[])
|
||||||
|
tasks = []
|
||||||
|
for _ in range(15):
|
||||||
|
|
||||||
|
async def func():
|
||||||
|
async with config.foo.get_lock():
|
||||||
|
foo = await config.foo()
|
||||||
|
foo.append(0)
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
await config.foo.set(foo)
|
||||||
|
|
||||||
|
tasks.append(func())
|
||||||
|
|
||||||
|
await asyncio.wait(tasks, return_when=asyncio.ALL_COMPLETED)
|
||||||
|
|
||||||
|
assert len(await config.foo()) == 15
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_config_ctxmgr_atomicity(config):
|
||||||
|
config.register_global(foo=[])
|
||||||
|
tasks = []
|
||||||
|
for _ in range(15):
|
||||||
|
|
||||||
|
async def func():
|
||||||
|
async with config.foo() as foo:
|
||||||
|
foo.append(0)
|
||||||
|
await asyncio.sleep(0.1)
|
||||||
|
|
||||||
|
tasks.append(func())
|
||||||
|
|
||||||
|
await asyncio.wait(tasks, return_when=asyncio.ALL_COMPLETED)
|
||||||
|
|
||||||
|
assert len(await config.foo()) == 15
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user