mirror of
https://github.com/Cog-Creators/Red-DiscordBot.git
synced 2025-11-22 18:57:59 -05:00
PostgreSQL driver, tests against DB backends, and general drivers cleanup (#2723)
* PostgreSQL driver and general drivers cleanup Signed-off-by: Toby Harradine <tobyharradine@gmail.com> * Make tests pass Signed-off-by: Toby Harradine <tobyharradine@gmail.com> * Add black --target-version flag in make.bat Signed-off-by: Toby Harradine <tobyharradine@gmail.com> * Rewrite postgres driver Most of the logic is now in PL/pgSQL. This completely avoids the use of Python f-strings to format identifiers into queries. Although an SQL-injection attack would have been impossible anyway (only the owner would have ever had the ability to do that), using PostgreSQL's format() is more reliable for unusual identifiers. Performance-wise, I'm not sure whether this is an improvement, but I highly doubt that it's worse. Signed-off-by: Toby Harradine <tobyharradine@gmail.com> * Reformat Signed-off-by: Toby Harradine <tobyharradine@gmail.com> * Fix PostgresDriver.delete_all_data() Signed-off-by: Toby Harradine <tobyharradine@gmail.com> * Clean up PL/pgSQL code Signed-off-by: Toby Harradine <tobyharradine@gmail.com> * More PL/pgSQL cleanup Signed-off-by: Toby Harradine <tobyharradine@gmail.com> * PL/pgSQL function optimisations Signed-off-by: Toby Harradine <tobyharradine@gmail.com> * Ensure compatibility with PostgreSQL 10 and below Signed-off-by: Toby Harradine <tobyharradine@gmail.com> * More/better docstrings for PG functions Signed-off-by: Toby Harradine <tobyharradine@gmail.com> * Fix typo in docstring Signed-off-by: Toby Harradine <tobyharradine@gmail.com> * Return correct value on toggle() Signed-off-by: Toby Harradine <tobyharradine@gmail.com> * Use composite type for PG function parameters Signed-off-by: Toby Harradine <tobyharradine@gmail.com> * Fix JSON driver's Config.clear_all() Signed-off-by: Toby Harradine <tobyharradine@gmail.com> * Correct description for Mongo tox recipe Signed-off-by: Toby Harradine <tobyharradine@gmail.com> * Fix linting errors Signed-off-by: Toby Harradine <tobyharradine@gmail.com> * Update dep specification after merging bumpdeps Signed-off-by: Toby Harradine <tobyharradine@gmail.com> * Add towncrier entries Signed-off-by: Toby Harradine <tobyharradine@gmail.com> * Update from merge Signed-off-by: Toby Harradine <tobyharradine@gmail.com> * Mention [postgres] extra in install docs Signed-off-by: Toby Harradine <tobyharradine@gmail.com> * Support more connection options and use better defaults Signed-off-by: Toby Harradine <tobyharradine@gmail.com> * Actually pass PG env vars in tox Signed-off-by: Toby Harradine <tobyharradine@gmail.com> * Replace event trigger with manual DELETE queries Signed-off-by: Toby Harradine <tobyharradine@gmail.com>
This commit is contained in:
committed by
Michael H
parent
57fa29dd64
commit
d1a46acc9a
342
redbot/core/drivers/base.py
Normal file
342
redbot/core/drivers/base.py
Normal file
@@ -0,0 +1,342 @@
|
||||
import abc
|
||||
import enum
|
||||
from typing import Tuple, Dict, Any, Union, List, AsyncIterator, Type
|
||||
|
||||
__all__ = ["BaseDriver", "IdentifierData", "ConfigCategory"]
|
||||
|
||||
|
||||
class ConfigCategory(str, enum.Enum):
|
||||
GLOBAL = "GLOBAL"
|
||||
GUILD = "GUILD"
|
||||
CHANNEL = "TEXTCHANNEL"
|
||||
ROLE = "ROLE"
|
||||
USER = "USER"
|
||||
MEMBER = "MEMBER"
|
||||
|
||||
@classmethod
|
||||
def get_pkey_info(
|
||||
cls, category: Union[str, "ConfigCategory"], custom_group_data: Dict[str, int]
|
||||
) -> Tuple[int, bool]:
|
||||
"""Get the full primary key length for the given category,
|
||||
and whether or not the category is a custom category.
|
||||
"""
|
||||
try:
|
||||
# noinspection PyArgumentList
|
||||
category_obj = cls(category)
|
||||
except ValueError:
|
||||
return custom_group_data[category], True
|
||||
else:
|
||||
return _CATEGORY_PKEY_COUNTS[category_obj], False
|
||||
|
||||
|
||||
_CATEGORY_PKEY_COUNTS = {
|
||||
ConfigCategory.GLOBAL: 0,
|
||||
ConfigCategory.GUILD: 1,
|
||||
ConfigCategory.CHANNEL: 1,
|
||||
ConfigCategory.ROLE: 1,
|
||||
ConfigCategory.USER: 1,
|
||||
ConfigCategory.MEMBER: 2,
|
||||
}
|
||||
|
||||
|
||||
class IdentifierData:
|
||||
def __init__(
|
||||
self,
|
||||
cog_name: str,
|
||||
uuid: str,
|
||||
category: str,
|
||||
primary_key: Tuple[str, ...],
|
||||
identifiers: Tuple[str, ...],
|
||||
primary_key_len: int,
|
||||
is_custom: bool = False,
|
||||
):
|
||||
self._cog_name = cog_name
|
||||
self._uuid = uuid
|
||||
self._category = category
|
||||
self._primary_key = primary_key
|
||||
self._identifiers = identifiers
|
||||
self.primary_key_len = primary_key_len
|
||||
self._is_custom = is_custom
|
||||
|
||||
@property
|
||||
def cog_name(self) -> str:
|
||||
return self._cog_name
|
||||
|
||||
@property
|
||||
def uuid(self) -> str:
|
||||
return self._uuid
|
||||
|
||||
@property
|
||||
def category(self) -> str:
|
||||
return self._category
|
||||
|
||||
@property
|
||||
def primary_key(self) -> Tuple[str, ...]:
|
||||
return self._primary_key
|
||||
|
||||
@property
|
||||
def identifiers(self) -> Tuple[str, ...]:
|
||||
return self._identifiers
|
||||
|
||||
@property
|
||||
def is_custom(self) -> bool:
|
||||
return self._is_custom
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"<IdentifierData cog_name={self.cog_name} uuid={self.uuid} category={self.category} "
|
||||
f"primary_key={self.primary_key} identifiers={self.identifiers}>"
|
||||
)
|
||||
|
||||
def __eq__(self, other) -> bool:
|
||||
if not isinstance(other, IdentifierData):
|
||||
return False
|
||||
return (
|
||||
self.uuid == other.uuid
|
||||
and self.category == other.category
|
||||
and self.primary_key == other.primary_key
|
||||
and self.identifiers == other.identifiers
|
||||
)
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash((self.uuid, self.category, self.primary_key, self.identifiers))
|
||||
|
||||
def add_identifier(self, *identifier: str) -> "IdentifierData":
|
||||
if not all(isinstance(i, str) for i in identifier):
|
||||
raise ValueError("Identifiers must be strings.")
|
||||
|
||||
return IdentifierData(
|
||||
self.cog_name,
|
||||
self.uuid,
|
||||
self.category,
|
||||
self.primary_key,
|
||||
self.identifiers + identifier,
|
||||
self.primary_key_len,
|
||||
is_custom=self.is_custom,
|
||||
)
|
||||
|
||||
def to_tuple(self) -> Tuple[str, ...]:
|
||||
return tuple(
|
||||
filter(
|
||||
None,
|
||||
(self.cog_name, self.uuid, self.category, *self.primary_key, *self.identifiers),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
class BaseDriver(abc.ABC):
|
||||
def __init__(self, cog_name: str, identifier: str, **kwargs):
|
||||
self.cog_name = cog_name
|
||||
self.unique_cog_identifier = identifier
|
||||
|
||||
@classmethod
|
||||
@abc.abstractmethod
|
||||
async def initialize(cls, **storage_details) -> None:
|
||||
"""
|
||||
Initialize this driver.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
**storage_details
|
||||
The storage details required to initialize this driver.
|
||||
Should be the same as :func:`data_manager.storage_details`
|
||||
|
||||
Raises
|
||||
------
|
||||
MissingExtraRequirements
|
||||
If initializing the driver requires an extra which isn't
|
||||
installed.
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
@abc.abstractmethod
|
||||
async def teardown(cls) -> None:
|
||||
"""
|
||||
Tear down this driver.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
@abc.abstractmethod
|
||||
def get_config_details() -> Dict[str, Any]:
|
||||
"""
|
||||
Asks users for additional configuration information necessary
|
||||
to use this config driver.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Dict[str, Any]
|
||||
Dictionary of configuration details.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
async def get(self, identifier_data: IdentifierData) -> Any:
|
||||
"""
|
||||
Finds the value indicate by the given identifiers.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
identifier_data
|
||||
|
||||
Returns
|
||||
-------
|
||||
Any
|
||||
Stored value.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
async def set(self, identifier_data: IdentifierData, value=None) -> None:
|
||||
"""
|
||||
Sets the value of the key indicated by the given identifiers.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
identifier_data
|
||||
value
|
||||
Any JSON serializable python object.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
async def clear(self, identifier_data: IdentifierData) -> None:
|
||||
"""
|
||||
Clears out the value specified by the given identifiers.
|
||||
|
||||
Equivalent to using ``del`` on a dict.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
identifier_data
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
@abc.abstractmethod
|
||||
def aiter_cogs(cls) -> AsyncIterator[Tuple[str, str]]:
|
||||
"""Get info for cogs which have data stored on this backend.
|
||||
|
||||
Yields
|
||||
------
|
||||
Tuple[str, str]
|
||||
Asynchronously yields (cog_name, cog_identifier) tuples.
|
||||
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
async def migrate_to(
|
||||
cls,
|
||||
new_driver_cls: Type["BaseDriver"],
|
||||
all_custom_group_data: Dict[str, Dict[str, Dict[str, int]]],
|
||||
) -> None:
|
||||
"""Migrate data from this backend to another.
|
||||
|
||||
Both drivers must be initialized beforehand.
|
||||
|
||||
This will only move the data - no instance metadata is modified
|
||||
as a result of this operation.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
new_driver_cls
|
||||
Subclass of `BaseDriver`.
|
||||
all_custom_group_data : Dict[str, Dict[str, Dict[str, int]]]
|
||||
Dict mapping cog names, to cog IDs, to custom groups, to
|
||||
primary key lengths.
|
||||
|
||||
"""
|
||||
# Backend-agnostic method of migrating from one driver to another.
|
||||
async for cog_name, cog_id in cls.aiter_cogs():
|
||||
this_driver = cls(cog_name, cog_id)
|
||||
other_driver = new_driver_cls(cog_name, cog_id)
|
||||
custom_group_data = all_custom_group_data.get(cog_name, {}).get(cog_id, {})
|
||||
exported_data = await this_driver.export_data(custom_group_data)
|
||||
await other_driver.import_data(exported_data, custom_group_data)
|
||||
|
||||
@classmethod
|
||||
async def delete_all_data(cls, **kwargs) -> None:
|
||||
"""Delete all data being stored by this driver.
|
||||
|
||||
The driver must be initialized before this operation.
|
||||
|
||||
The BaseDriver provides a generic method which may be overriden
|
||||
by subclasses.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
**kwargs
|
||||
Driver-specific kwargs to change the way this method
|
||||
operates.
|
||||
|
||||
"""
|
||||
async for cog_name, cog_id in cls.aiter_cogs():
|
||||
driver = cls(cog_name, cog_id)
|
||||
await driver.clear(IdentifierData(cog_name, cog_id, "", (), (), 0))
|
||||
|
||||
@staticmethod
|
||||
def _split_primary_key(
|
||||
category: Union[ConfigCategory, str],
|
||||
custom_group_data: Dict[str, int],
|
||||
data: Dict[str, Any],
|
||||
) -> List[Tuple[Tuple[str, ...], Dict[str, Any]]]:
|
||||
pkey_len = ConfigCategory.get_pkey_info(category, custom_group_data)[0]
|
||||
if pkey_len == 0:
|
||||
return [((), data)]
|
||||
|
||||
def flatten(levels_remaining, currdata, parent_key=()):
|
||||
items = []
|
||||
for _k, _v in currdata.items():
|
||||
new_key = parent_key + (_k,)
|
||||
if levels_remaining > 1:
|
||||
items.extend(flatten(levels_remaining - 1, _v, new_key).items())
|
||||
else:
|
||||
items.append((new_key, _v))
|
||||
return dict(items)
|
||||
|
||||
ret = []
|
||||
for k, v in flatten(pkey_len, data).items():
|
||||
ret.append((k, v))
|
||||
return ret
|
||||
|
||||
async def export_data(
|
||||
self, custom_group_data: Dict[str, int]
|
||||
) -> List[Tuple[str, Dict[str, Any]]]:
|
||||
categories = [c.value for c in ConfigCategory]
|
||||
categories.extend(custom_group_data.keys())
|
||||
|
||||
ret = []
|
||||
for c in categories:
|
||||
ident_data = IdentifierData(
|
||||
self.cog_name,
|
||||
self.unique_cog_identifier,
|
||||
c,
|
||||
(),
|
||||
(),
|
||||
*ConfigCategory.get_pkey_info(c, custom_group_data),
|
||||
)
|
||||
try:
|
||||
data = await self.get(ident_data)
|
||||
except KeyError:
|
||||
continue
|
||||
ret.append((c, data))
|
||||
return ret
|
||||
|
||||
async def import_data(
|
||||
self, cog_data: List[Tuple[str, Dict[str, Any]]], custom_group_data: Dict[str, int]
|
||||
) -> None:
|
||||
for category, all_data in cog_data:
|
||||
splitted_pkey = self._split_primary_key(category, custom_group_data, all_data)
|
||||
for pkey, data in splitted_pkey:
|
||||
ident_data = IdentifierData(
|
||||
self.cog_name,
|
||||
self.unique_cog_identifier,
|
||||
category,
|
||||
pkey,
|
||||
(),
|
||||
*ConfigCategory.get_pkey_info(category, custom_group_data),
|
||||
)
|
||||
await self.set(ident_data, data)
|
||||
Reference in New Issue
Block a user