mirror of
https://github.com/Cog-Creators/Red-DiscordBot.git
synced 2025-11-06 03:08:55 -05:00
Config locks (#2654)
* Config locks Signed-off-by: Toby Harradine <tobyharradine@gmail.com> * Add locks for all_XXX Signed-off-by: Toby Harradine <tobyharradine@gmail.com> * Remove a word Signed-off-by: Toby Harradine <tobyharradine@gmail.com> * Add acquire_lock kwarg for value context manager Signed-off-by: Toby Harradine <tobyharradine@gmail.com> * Add towncrier entry Signed-off-by: Toby <tobyharradine@gmail.com> * Fix issues with `get_custom_lock` and `get_members_lock` Signed-off-by: Toby Harradine <tobyharradine@gmail.com>
This commit is contained in:
parent
a8091332b8
commit
af096bc1cc
1
changelog.d/2654.feature.rst
Normal file
1
changelog.d/2654.feature.rst
Normal file
@ -0,0 +1 @@
|
||||
Added functions to acquire locks on Config groups and values. These locks are acquired by default when calling a value as a context manager. See :meth:`Value.get_lock` for details
|
||||
@ -1,8 +1,19 @@
|
||||
import asyncio
|
||||
import collections
|
||||
import logging
|
||||
import pickle
|
||||
import weakref
|
||||
from typing import Any, Union, Tuple, Dict, Awaitable, AsyncContextManager, TypeVar
|
||||
from typing import (
|
||||
Any,
|
||||
Union,
|
||||
Tuple,
|
||||
Dict,
|
||||
Awaitable,
|
||||
AsyncContextManager,
|
||||
TypeVar,
|
||||
MutableMapping,
|
||||
Optional,
|
||||
)
|
||||
|
||||
import discord
|
||||
|
||||
@ -38,18 +49,26 @@ class _ValueCtxManager(Awaitable[_T], AsyncContextManager[_T]): # pylint: disab
|
||||
i.e. `dict`s or `list`s. This is because this class's ``raw_value``
|
||||
attribute must contain a reference to the object being modified within the
|
||||
context manager.
|
||||
|
||||
It should also be noted that the use of this context manager implies
|
||||
the acquisition of the value's lock when the ``acquire_lock`` kwarg
|
||||
to ``__init__`` is set to ``True``.
|
||||
"""
|
||||
|
||||
def __init__(self, value_obj, coro):
|
||||
def __init__(self, value_obj: "Value", coro: Awaitable[Any], *, acquire_lock: bool):
|
||||
self.value_obj = value_obj
|
||||
self.coro = coro
|
||||
self.raw_value = None
|
||||
self.__original_value = None
|
||||
self.__acquire_lock = acquire_lock
|
||||
self.__lock = self.value_obj.get_lock()
|
||||
|
||||
def __await__(self):
|
||||
return self.coro.__await__()
|
||||
|
||||
async def __aenter__(self):
|
||||
if self.__acquire_lock is True:
|
||||
await self.__lock.acquire()
|
||||
self.raw_value = await self
|
||||
if not isinstance(self.raw_value, (list, dict)):
|
||||
raise TypeError(
|
||||
@ -61,12 +80,16 @@ class _ValueCtxManager(Awaitable[_T], AsyncContextManager[_T]): # pylint: disab
|
||||
return self.raw_value
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
try:
|
||||
if isinstance(self.raw_value, dict):
|
||||
raw_value = _str_key_dict(self.raw_value)
|
||||
else:
|
||||
raw_value = self.raw_value
|
||||
if raw_value != self.__original_value:
|
||||
await self.value_obj.set(self.raw_value)
|
||||
finally:
|
||||
if self.__acquire_lock is True:
|
||||
self.__lock.release()
|
||||
|
||||
|
||||
class Value:
|
||||
@ -74,9 +97,8 @@ class Value:
|
||||
|
||||
Attributes
|
||||
----------
|
||||
identifiers : Tuple[str]
|
||||
This attribute provides all the keys necessary to get a specific data
|
||||
element from a json document.
|
||||
identifier_data : IdentifierData
|
||||
Information on identifiers for this value.
|
||||
default
|
||||
The default value for the data element that `identifiers` points at.
|
||||
driver : `redbot.core.drivers.red_base.BaseDriver`
|
||||
@ -84,10 +106,44 @@ class Value:
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, identifier_data: IdentifierData, default_value, driver):
|
||||
def __init__(self, identifier_data: IdentifierData, default_value, driver, config: "Config"):
|
||||
self.identifier_data = identifier_data
|
||||
self.default = default_value
|
||||
self.driver = driver
|
||||
self._config = config
|
||||
|
||||
def get_lock(self) -> asyncio.Lock:
|
||||
"""Get a lock to create a critical region where this value is accessed.
|
||||
|
||||
When using this lock, make sure you either use it with the
|
||||
``async with`` syntax, or if that's not feasible, ensure you
|
||||
keep a reference to it from the acquisition to the release of
|
||||
the lock. That is, if you can't use ``async with`` syntax, use
|
||||
the lock like this::
|
||||
|
||||
lock = config.foo.get_lock()
|
||||
await lock.acquire()
|
||||
# Do stuff...
|
||||
lock.release()
|
||||
|
||||
Do not use it like this::
|
||||
|
||||
await config.foo.get_lock().acquire()
|
||||
# Do stuff...
|
||||
config.foo.get_lock().release()
|
||||
|
||||
Doing it the latter way will likely cause an error, as the
|
||||
acquired lock will be cleaned up by the garbage collector before
|
||||
it is released, meaning the second call to ``get_lock()`` will
|
||||
return a different lock to the first call.
|
||||
|
||||
Returns
|
||||
-------
|
||||
asyncio.Lock
|
||||
A lock which is weakly cached for this value object.
|
||||
|
||||
"""
|
||||
return self._config._lock_cache.setdefault(self.identifier_data, asyncio.Lock())
|
||||
|
||||
async def _get(self, default=...):
|
||||
try:
|
||||
@ -96,7 +152,7 @@ class Value:
|
||||
return default if default is not ... else self.default
|
||||
return ret
|
||||
|
||||
def __call__(self, default=...) -> _ValueCtxManager[Any]:
|
||||
def __call__(self, default=..., *, acquire_lock: bool = True) -> _ValueCtxManager[Any]:
|
||||
"""Get the literal value of this data element.
|
||||
|
||||
Each `Value` object is created by the `Group.__getattr__` method. The
|
||||
@ -106,7 +162,10 @@ class Value:
|
||||
The return value of this method can also be used as an asynchronous
|
||||
context manager, i.e. with :code:`async with` syntax. This can only be
|
||||
used on values which are mutable (namely lists and dicts), and will
|
||||
set the value with its changes on exit of the context manager.
|
||||
set the value with its changes on exit of the context manager. It will
|
||||
also acquire this value's lock to protect the critical region inside
|
||||
this context manager's body, unless the ``acquire_lock`` keyword
|
||||
argument is set to ``False``.
|
||||
|
||||
Example
|
||||
-------
|
||||
@ -129,7 +188,14 @@ class Value:
|
||||
default : `object`, optional
|
||||
This argument acts as an override for the registered default
|
||||
provided by `default`. This argument is ignored if its
|
||||
value is :code:`None`.
|
||||
value is :code:`...`.
|
||||
|
||||
Other Parameters
|
||||
----------------
|
||||
acquire_lock : bool
|
||||
Set to ``False`` to disable the acquisition of the value's
|
||||
lock over the context manager body. Defaults to ``True``.
|
||||
Has no effect when not used as a context manager.
|
||||
|
||||
Returns
|
||||
-------
|
||||
@ -139,7 +205,7 @@ class Value:
|
||||
with` syntax, on gets the value on entrance, and sets it on exit.
|
||||
|
||||
"""
|
||||
return _ValueCtxManager(self, self._get(default))
|
||||
return _ValueCtxManager(self, self._get(default), acquire_lock=acquire_lock)
|
||||
|
||||
async def set(self, value):
|
||||
"""Set the value of the data elements pointed to by `identifiers`.
|
||||
@ -194,13 +260,14 @@ class Group(Value):
|
||||
identifier_data: IdentifierData,
|
||||
defaults: dict,
|
||||
driver,
|
||||
config: "Config",
|
||||
force_registration: bool = False,
|
||||
):
|
||||
self._defaults = defaults
|
||||
self.force_registration = force_registration
|
||||
self.driver = driver
|
||||
|
||||
super().__init__(identifier_data, {}, self.driver)
|
||||
super().__init__(identifier_data, {}, self.driver, config)
|
||||
|
||||
@property
|
||||
def defaults(self):
|
||||
@ -248,17 +315,24 @@ class Group(Value):
|
||||
defaults=self._defaults[item],
|
||||
driver=self.driver,
|
||||
force_registration=self.force_registration,
|
||||
config=self._config,
|
||||
)
|
||||
elif is_value:
|
||||
return Value(
|
||||
identifier_data=new_identifiers,
|
||||
default_value=self._defaults[item],
|
||||
driver=self.driver,
|
||||
config=self._config,
|
||||
)
|
||||
elif self.force_registration:
|
||||
raise AttributeError("'{}' is not a valid registered Group or value.".format(item))
|
||||
else:
|
||||
return Value(identifier_data=new_identifiers, default_value=None, driver=self.driver)
|
||||
return Value(
|
||||
identifier_data=new_identifiers,
|
||||
default_value=None,
|
||||
driver=self.driver,
|
||||
config=self._config,
|
||||
)
|
||||
|
||||
async def clear_raw(self, *nested_path: Any):
|
||||
"""
|
||||
@ -411,7 +485,7 @@ class Group(Value):
|
||||
return self.nested_update(raw, default)
|
||||
return raw
|
||||
|
||||
def all(self) -> _ValueCtxManager[Dict[str, Any]]:
|
||||
def all(self, *, acquire_lock: bool = True) -> _ValueCtxManager[Dict[str, Any]]:
|
||||
"""Get a dictionary representation of this group's data.
|
||||
|
||||
The return value of this method can also be used as an asynchronous
|
||||
@ -422,13 +496,19 @@ class Group(Value):
|
||||
The return value of this method will include registered defaults for
|
||||
values which have not yet been set.
|
||||
|
||||
Other Parameters
|
||||
----------------
|
||||
acquire_lock : bool
|
||||
Same as the ``acquire_lock`` keyword parameter in
|
||||
`Value.__call__`.
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict
|
||||
All of this Group's attributes, resolved as raw data values.
|
||||
|
||||
"""
|
||||
return self()
|
||||
return self(acquire_lock=acquire_lock)
|
||||
|
||||
def nested_update(
|
||||
self, current: collections.Mapping, defaults: Dict[str, Any] = ...
|
||||
@ -555,6 +635,9 @@ class Config:
|
||||
self._defaults = defaults or {}
|
||||
|
||||
self.custom_groups = {}
|
||||
self._lock_cache: MutableMapping[
|
||||
IdentifierData, asyncio.Lock
|
||||
] = weakref.WeakValueDictionary()
|
||||
|
||||
@property
|
||||
def defaults(self):
|
||||
@ -862,6 +945,7 @@ class Config:
|
||||
defaults=defaults,
|
||||
driver=self.driver,
|
||||
force_registration=self.force_registration,
|
||||
config=self,
|
||||
)
|
||||
|
||||
def guild(self, guild: discord.Guild) -> Group:
|
||||
@ -1139,7 +1223,7 @@ class Config:
|
||||
identifier_data = IdentifierData(
|
||||
self.unique_identifier, "", (), (), self.custom_groups
|
||||
)
|
||||
group = Group(identifier_data, defaults={}, driver=self.driver)
|
||||
group = Group(identifier_data, defaults={}, driver=self.driver, config=self)
|
||||
else:
|
||||
cat, *scopes = scopes
|
||||
group = self._get_base_group(cat, *scopes)
|
||||
@ -1222,6 +1306,80 @@ class Config:
|
||||
"""
|
||||
await self._clear_scope(str(group_identifier))
|
||||
|
||||
def get_guilds_lock(self) -> asyncio.Lock:
|
||||
"""Get a lock for all guild data.
|
||||
|
||||
Returns
|
||||
-------
|
||||
asyncio.Lock
|
||||
"""
|
||||
return self.get_custom_lock(self.GUILD)
|
||||
|
||||
def get_channels_lock(self) -> asyncio.Lock:
|
||||
"""Get a lock for all channel data.
|
||||
|
||||
Returns
|
||||
-------
|
||||
asyncio.Lock
|
||||
"""
|
||||
return self.get_custom_lock(self.CHANNEL)
|
||||
|
||||
def get_roles_lock(self) -> asyncio.Lock:
|
||||
"""Get a lock for all role data.
|
||||
|
||||
Returns
|
||||
-------
|
||||
asyncio.Lock
|
||||
"""
|
||||
return self.get_custom_lock(self.ROLE)
|
||||
|
||||
def get_users_lock(self) -> asyncio.Lock:
|
||||
"""Get a lock for all user data.
|
||||
|
||||
Returns
|
||||
-------
|
||||
asyncio.Lock
|
||||
"""
|
||||
return self.get_custom_lock(self.USER)
|
||||
|
||||
def get_members_lock(self, guild: Optional[discord.Guild] = None) -> asyncio.Lock:
|
||||
"""Get a lock for all member data.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
guild : Optional[discord.Guild]
|
||||
The guild containing the members whose data you want to
|
||||
lock. Omit to lock all data for all members in all guilds.
|
||||
|
||||
Returns
|
||||
-------
|
||||
asyncio.Lock
|
||||
"""
|
||||
if guild is None:
|
||||
return self.get_custom_lock(self.GUILD)
|
||||
else:
|
||||
id_data = IdentifierData(
|
||||
self.unique_identifier, self.MEMBER, (str(guild.id),), (), self.custom_groups
|
||||
)
|
||||
return self._lock_cache.setdefault(id_data, asyncio.Lock())
|
||||
|
||||
def get_custom_lock(self, group_identifier: str) -> asyncio.Lock:
|
||||
"""Get a lock for all data in a custom scope.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
group_identifier : str
|
||||
The group identifier for the custom scope you want to lock.
|
||||
|
||||
Returns
|
||||
-------
|
||||
asyncio.Lock
|
||||
"""
|
||||
id_data = IdentifierData(
|
||||
self.unique_identifier, group_identifier, (), (), self.custom_groups
|
||||
)
|
||||
return self._lock_cache.setdefault(id_data, asyncio.Lock())
|
||||
|
||||
|
||||
def _str_key_dict(value: Dict[Any, _T]) -> Dict[str, _T]:
|
||||
"""
|
||||
|
||||
@ -56,6 +56,19 @@ class IdentifierData:
|
||||
f" identifiers={self.identifiers}>"
|
||||
)
|
||||
|
||||
def __eq__(self, other) -> bool:
|
||||
if not isinstance(other, IdentifierData):
|
||||
return False
|
||||
return (
|
||||
self.uuid == other.uuid
|
||||
and self.category == other.category
|
||||
and self.primary_key == other.primary_key
|
||||
and self.identifiers == other.identifiers
|
||||
)
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash((self.uuid, self.category, self.primary_key, self.identifiers))
|
||||
|
||||
def add_identifier(self, *identifier: str) -> "IdentifierData":
|
||||
if not all(isinstance(i, str) for i in identifier):
|
||||
raise ValueError("Identifiers must be strings.")
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
from unittest.mock import patch
|
||||
import pytest
|
||||
|
||||
@ -506,3 +507,50 @@ def test_config_custom_doubleinit(config):
|
||||
config.init_custom("TEST", 3)
|
||||
with pytest.raises(ValueError):
|
||||
config.init_custom("TEST", 2)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_config_locks_cache(config, empty_guild):
|
||||
lock1 = config.foo.get_lock()
|
||||
assert lock1 is config.foo.get_lock()
|
||||
lock2 = config.guild(empty_guild).foo.get_lock()
|
||||
assert lock2 is config.guild(empty_guild).foo.get_lock()
|
||||
assert lock1 is not lock2
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_config_value_atomicity(config):
|
||||
config.register_global(foo=[])
|
||||
tasks = []
|
||||
for _ in range(15):
|
||||
|
||||
async def func():
|
||||
async with config.foo.get_lock():
|
||||
foo = await config.foo()
|
||||
foo.append(0)
|
||||
await asyncio.sleep(0.1)
|
||||
await config.foo.set(foo)
|
||||
|
||||
tasks.append(func())
|
||||
|
||||
await asyncio.wait(tasks, return_when=asyncio.ALL_COMPLETED)
|
||||
|
||||
assert len(await config.foo()) == 15
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_config_ctxmgr_atomicity(config):
|
||||
config.register_global(foo=[])
|
||||
tasks = []
|
||||
for _ in range(15):
|
||||
|
||||
async def func():
|
||||
async with config.foo() as foo:
|
||||
foo.append(0)
|
||||
await asyncio.sleep(0.1)
|
||||
|
||||
tasks.append(func())
|
||||
|
||||
await asyncio.wait(tasks, return_when=asyncio.ALL_COMPLETED)
|
||||
|
||||
assert len(await config.foo()) == 15
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user