Config locks (#2654)

* Config locks

Signed-off-by: Toby Harradine <tobyharradine@gmail.com>

* Add locks for all_XXX

Signed-off-by: Toby Harradine <tobyharradine@gmail.com>

* Remove a word

Signed-off-by: Toby Harradine <tobyharradine@gmail.com>

* Add acquire_lock kwarg for value context manager

Signed-off-by: Toby Harradine <tobyharradine@gmail.com>

* Add towncrier entry

Signed-off-by: Toby <tobyharradine@gmail.com>

* Fix issues with `get_custom_lock` and `get_members_lock`

Signed-off-by: Toby Harradine <tobyharradine@gmail.com>
This commit is contained in:
Toby Harradine 2019-07-24 06:50:07 +10:00 committed by Michael H
parent a8091332b8
commit af096bc1cc
4 changed files with 241 additions and 21 deletions

View File

@ -0,0 +1 @@
Added functions to acquire locks on Config groups and values. These locks are acquired by default when calling a value as a context manager. See :meth:`Value.get_lock` for details

View File

@ -1,8 +1,19 @@
import asyncio
import collections import collections
import logging import logging
import pickle import pickle
import weakref import weakref
from typing import Any, Union, Tuple, Dict, Awaitable, AsyncContextManager, TypeVar from typing import (
Any,
Union,
Tuple,
Dict,
Awaitable,
AsyncContextManager,
TypeVar,
MutableMapping,
Optional,
)
import discord import discord
@ -38,18 +49,26 @@ class _ValueCtxManager(Awaitable[_T], AsyncContextManager[_T]): # pylint: disab
i.e. `dict`s or `list`s. This is because this class's ``raw_value`` i.e. `dict`s or `list`s. This is because this class's ``raw_value``
attribute must contain a reference to the object being modified within the attribute must contain a reference to the object being modified within the
context manager. context manager.
It should also be noted that the use of this context manager implies
the acquisition of the value's lock when the ``acquire_lock`` kwarg
to ``__init__`` is set to ``True``.
""" """
def __init__(self, value_obj, coro): def __init__(self, value_obj: "Value", coro: Awaitable[Any], *, acquire_lock: bool):
self.value_obj = value_obj self.value_obj = value_obj
self.coro = coro self.coro = coro
self.raw_value = None self.raw_value = None
self.__original_value = None self.__original_value = None
self.__acquire_lock = acquire_lock
self.__lock = self.value_obj.get_lock()
def __await__(self): def __await__(self):
return self.coro.__await__() return self.coro.__await__()
async def __aenter__(self): async def __aenter__(self):
if self.__acquire_lock is True:
await self.__lock.acquire()
self.raw_value = await self self.raw_value = await self
if not isinstance(self.raw_value, (list, dict)): if not isinstance(self.raw_value, (list, dict)):
raise TypeError( raise TypeError(
@ -61,12 +80,16 @@ class _ValueCtxManager(Awaitable[_T], AsyncContextManager[_T]): # pylint: disab
return self.raw_value return self.raw_value
async def __aexit__(self, exc_type, exc, tb): async def __aexit__(self, exc_type, exc, tb):
if isinstance(self.raw_value, dict): try:
raw_value = _str_key_dict(self.raw_value) if isinstance(self.raw_value, dict):
else: raw_value = _str_key_dict(self.raw_value)
raw_value = self.raw_value else:
if raw_value != self.__original_value: raw_value = self.raw_value
await self.value_obj.set(self.raw_value) if raw_value != self.__original_value:
await self.value_obj.set(self.raw_value)
finally:
if self.__acquire_lock is True:
self.__lock.release()
class Value: class Value:
@ -74,9 +97,8 @@ class Value:
Attributes Attributes
---------- ----------
identifiers : Tuple[str] identifier_data : IdentifierData
This attribute provides all the keys necessary to get a specific data Information on identifiers for this value.
element from a json document.
default default
The default value for the data element that `identifiers` points at. The default value for the data element that `identifiers` points at.
driver : `redbot.core.drivers.red_base.BaseDriver` driver : `redbot.core.drivers.red_base.BaseDriver`
@ -84,10 +106,44 @@ class Value:
""" """
def __init__(self, identifier_data: IdentifierData, default_value, driver): def __init__(self, identifier_data: IdentifierData, default_value, driver, config: "Config"):
self.identifier_data = identifier_data self.identifier_data = identifier_data
self.default = default_value self.default = default_value
self.driver = driver self.driver = driver
self._config = config
def get_lock(self) -> asyncio.Lock:
"""Get a lock to create a critical region where this value is accessed.
When using this lock, make sure you either use it with the
``async with`` syntax, or if that's not feasible, ensure you
keep a reference to it from the acquisition to the release of
the lock. That is, if you can't use ``async with`` syntax, use
the lock like this::
lock = config.foo.get_lock()
await lock.acquire()
# Do stuff...
lock.release()
Do not use it like this::
await config.foo.get_lock().acquire()
# Do stuff...
config.foo.get_lock().release()
Doing it the latter way will likely cause an error, as the
acquired lock will be cleaned up by the garbage collector before
it is released, meaning the second call to ``get_lock()`` will
return a different lock to the first call.
Returns
-------
asyncio.Lock
A lock which is weakly cached for this value object.
"""
return self._config._lock_cache.setdefault(self.identifier_data, asyncio.Lock())
async def _get(self, default=...): async def _get(self, default=...):
try: try:
@ -96,7 +152,7 @@ class Value:
return default if default is not ... else self.default return default if default is not ... else self.default
return ret return ret
def __call__(self, default=...) -> _ValueCtxManager[Any]: def __call__(self, default=..., *, acquire_lock: bool = True) -> _ValueCtxManager[Any]:
"""Get the literal value of this data element. """Get the literal value of this data element.
Each `Value` object is created by the `Group.__getattr__` method. The Each `Value` object is created by the `Group.__getattr__` method. The
@ -106,7 +162,10 @@ class Value:
The return value of this method can also be used as an asynchronous The return value of this method can also be used as an asynchronous
context manager, i.e. with :code:`async with` syntax. This can only be context manager, i.e. with :code:`async with` syntax. This can only be
used on values which are mutable (namely lists and dicts), and will used on values which are mutable (namely lists and dicts), and will
set the value with its changes on exit of the context manager. set the value with its changes on exit of the context manager. It will
also acquire this value's lock to protect the critical region inside
this context manager's body, unless the ``acquire_lock`` keyword
argument is set to ``False``.
Example Example
------- -------
@ -129,7 +188,14 @@ class Value:
default : `object`, optional default : `object`, optional
This argument acts as an override for the registered default This argument acts as an override for the registered default
provided by `default`. This argument is ignored if its provided by `default`. This argument is ignored if its
value is :code:`None`. value is :code:`...`.
Other Parameters
----------------
acquire_lock : bool
Set to ``False`` to disable the acquisition of the value's
lock over the context manager body. Defaults to ``True``.
Has no effect when not used as a context manager.
Returns Returns
------- -------
@ -139,7 +205,7 @@ class Value:
with` syntax, on gets the value on entrance, and sets it on exit. with` syntax, on gets the value on entrance, and sets it on exit.
""" """
return _ValueCtxManager(self, self._get(default)) return _ValueCtxManager(self, self._get(default), acquire_lock=acquire_lock)
async def set(self, value): async def set(self, value):
"""Set the value of the data elements pointed to by `identifiers`. """Set the value of the data elements pointed to by `identifiers`.
@ -194,13 +260,14 @@ class Group(Value):
identifier_data: IdentifierData, identifier_data: IdentifierData,
defaults: dict, defaults: dict,
driver, driver,
config: "Config",
force_registration: bool = False, force_registration: bool = False,
): ):
self._defaults = defaults self._defaults = defaults
self.force_registration = force_registration self.force_registration = force_registration
self.driver = driver self.driver = driver
super().__init__(identifier_data, {}, self.driver) super().__init__(identifier_data, {}, self.driver, config)
@property @property
def defaults(self): def defaults(self):
@ -248,17 +315,24 @@ class Group(Value):
defaults=self._defaults[item], defaults=self._defaults[item],
driver=self.driver, driver=self.driver,
force_registration=self.force_registration, force_registration=self.force_registration,
config=self._config,
) )
elif is_value: elif is_value:
return Value( return Value(
identifier_data=new_identifiers, identifier_data=new_identifiers,
default_value=self._defaults[item], default_value=self._defaults[item],
driver=self.driver, driver=self.driver,
config=self._config,
) )
elif self.force_registration: elif self.force_registration:
raise AttributeError("'{}' is not a valid registered Group or value.".format(item)) raise AttributeError("'{}' is not a valid registered Group or value.".format(item))
else: else:
return Value(identifier_data=new_identifiers, default_value=None, driver=self.driver) return Value(
identifier_data=new_identifiers,
default_value=None,
driver=self.driver,
config=self._config,
)
async def clear_raw(self, *nested_path: Any): async def clear_raw(self, *nested_path: Any):
""" """
@ -411,7 +485,7 @@ class Group(Value):
return self.nested_update(raw, default) return self.nested_update(raw, default)
return raw return raw
def all(self) -> _ValueCtxManager[Dict[str, Any]]: def all(self, *, acquire_lock: bool = True) -> _ValueCtxManager[Dict[str, Any]]:
"""Get a dictionary representation of this group's data. """Get a dictionary representation of this group's data.
The return value of this method can also be used as an asynchronous The return value of this method can also be used as an asynchronous
@ -422,13 +496,19 @@ class Group(Value):
The return value of this method will include registered defaults for The return value of this method will include registered defaults for
values which have not yet been set. values which have not yet been set.
Other Parameters
----------------
acquire_lock : bool
Same as the ``acquire_lock`` keyword parameter in
`Value.__call__`.
Returns Returns
------- -------
dict dict
All of this Group's attributes, resolved as raw data values. All of this Group's attributes, resolved as raw data values.
""" """
return self() return self(acquire_lock=acquire_lock)
def nested_update( def nested_update(
self, current: collections.Mapping, defaults: Dict[str, Any] = ... self, current: collections.Mapping, defaults: Dict[str, Any] = ...
@ -555,6 +635,9 @@ class Config:
self._defaults = defaults or {} self._defaults = defaults or {}
self.custom_groups = {} self.custom_groups = {}
self._lock_cache: MutableMapping[
IdentifierData, asyncio.Lock
] = weakref.WeakValueDictionary()
@property @property
def defaults(self): def defaults(self):
@ -862,6 +945,7 @@ class Config:
defaults=defaults, defaults=defaults,
driver=self.driver, driver=self.driver,
force_registration=self.force_registration, force_registration=self.force_registration,
config=self,
) )
def guild(self, guild: discord.Guild) -> Group: def guild(self, guild: discord.Guild) -> Group:
@ -1139,7 +1223,7 @@ class Config:
identifier_data = IdentifierData( identifier_data = IdentifierData(
self.unique_identifier, "", (), (), self.custom_groups self.unique_identifier, "", (), (), self.custom_groups
) )
group = Group(identifier_data, defaults={}, driver=self.driver) group = Group(identifier_data, defaults={}, driver=self.driver, config=self)
else: else:
cat, *scopes = scopes cat, *scopes = scopes
group = self._get_base_group(cat, *scopes) group = self._get_base_group(cat, *scopes)
@ -1222,6 +1306,80 @@ class Config:
""" """
await self._clear_scope(str(group_identifier)) await self._clear_scope(str(group_identifier))
def get_guilds_lock(self) -> asyncio.Lock:
"""Get a lock for all guild data.
Returns
-------
asyncio.Lock
"""
return self.get_custom_lock(self.GUILD)
def get_channels_lock(self) -> asyncio.Lock:
"""Get a lock for all channel data.
Returns
-------
asyncio.Lock
"""
return self.get_custom_lock(self.CHANNEL)
def get_roles_lock(self) -> asyncio.Lock:
"""Get a lock for all role data.
Returns
-------
asyncio.Lock
"""
return self.get_custom_lock(self.ROLE)
def get_users_lock(self) -> asyncio.Lock:
"""Get a lock for all user data.
Returns
-------
asyncio.Lock
"""
return self.get_custom_lock(self.USER)
def get_members_lock(self, guild: Optional[discord.Guild] = None) -> asyncio.Lock:
"""Get a lock for all member data.
Parameters
----------
guild : Optional[discord.Guild]
The guild containing the members whose data you want to
lock. Omit to lock all data for all members in all guilds.
Returns
-------
asyncio.Lock
"""
if guild is None:
return self.get_custom_lock(self.GUILD)
else:
id_data = IdentifierData(
self.unique_identifier, self.MEMBER, (str(guild.id),), (), self.custom_groups
)
return self._lock_cache.setdefault(id_data, asyncio.Lock())
def get_custom_lock(self, group_identifier: str) -> asyncio.Lock:
"""Get a lock for all data in a custom scope.
Parameters
----------
group_identifier : str
The group identifier for the custom scope you want to lock.
Returns
-------
asyncio.Lock
"""
id_data = IdentifierData(
self.unique_identifier, group_identifier, (), (), self.custom_groups
)
return self._lock_cache.setdefault(id_data, asyncio.Lock())
def _str_key_dict(value: Dict[Any, _T]) -> Dict[str, _T]: def _str_key_dict(value: Dict[Any, _T]) -> Dict[str, _T]:
""" """

View File

@ -56,6 +56,19 @@ class IdentifierData:
f" identifiers={self.identifiers}>" f" identifiers={self.identifiers}>"
) )
def __eq__(self, other) -> bool:
if not isinstance(other, IdentifierData):
return False
return (
self.uuid == other.uuid
and self.category == other.category
and self.primary_key == other.primary_key
and self.identifiers == other.identifiers
)
def __hash__(self) -> int:
return hash((self.uuid, self.category, self.primary_key, self.identifiers))
def add_identifier(self, *identifier: str) -> "IdentifierData": def add_identifier(self, *identifier: str) -> "IdentifierData":
if not all(isinstance(i, str) for i in identifier): if not all(isinstance(i, str) for i in identifier):
raise ValueError("Identifiers must be strings.") raise ValueError("Identifiers must be strings.")

View File

@ -1,3 +1,4 @@
import asyncio
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
@ -506,3 +507,50 @@ def test_config_custom_doubleinit(config):
config.init_custom("TEST", 3) config.init_custom("TEST", 3)
with pytest.raises(ValueError): with pytest.raises(ValueError):
config.init_custom("TEST", 2) config.init_custom("TEST", 2)
@pytest.mark.asyncio
async def test_config_locks_cache(config, empty_guild):
lock1 = config.foo.get_lock()
assert lock1 is config.foo.get_lock()
lock2 = config.guild(empty_guild).foo.get_lock()
assert lock2 is config.guild(empty_guild).foo.get_lock()
assert lock1 is not lock2
@pytest.mark.asyncio
async def test_config_value_atomicity(config):
config.register_global(foo=[])
tasks = []
for _ in range(15):
async def func():
async with config.foo.get_lock():
foo = await config.foo()
foo.append(0)
await asyncio.sleep(0.1)
await config.foo.set(foo)
tasks.append(func())
await asyncio.wait(tasks, return_when=asyncio.ALL_COMPLETED)
assert len(await config.foo()) == 15
@pytest.mark.asyncio
async def test_config_ctxmgr_atomicity(config):
config.register_global(foo=[])
tasks = []
for _ in range(15):
async def func():
async with config.foo() as foo:
foo.append(0)
await asyncio.sleep(0.1)
tasks.append(func())
await asyncio.wait(tasks, return_when=asyncio.ALL_COMPLETED)
assert len(await config.foo()) == 15