mirror of https://github.com/Cog-Creators/Red-DiscordBot.git (synced 2025-11-20 18:06:08 -05:00)
PostgreSQL driver, tests against DB backends, and general drivers cleanup (#2723)
* PostgreSQL driver and general drivers cleanup
* Make tests pass
* Add black --target-version flag in make.bat
* Rewrite postgres driver

  Most of the logic is now in PL/pgSQL. This completely avoids the use of
  Python f-strings to format identifiers into queries. Although an
  SQL-injection attack would have been impossible anyway (only the owner would
  have ever had the ability to do that), using PostgreSQL's format() is more
  reliable for unusual identifiers. Performance-wise, I'm not sure whether
  this is an improvement, but I highly doubt that it's worse.

* Reformat
* Fix PostgresDriver.delete_all_data()
* Clean up PL/pgSQL code
* More PL/pgSQL cleanup
* PL/pgSQL function optimisations
* Ensure compatibility with PostgreSQL 10 and below
* More/better docstrings for PG functions
* Fix typo in docstring
* Return correct value on toggle()
* Use composite type for PG function parameters
* Fix JSON driver's Config.clear_all()
* Correct description for Mongo tox recipe
* Fix linting errors
* Update dep specification after merging bumpdeps
* Add towncrier entries
* Update from merge
* Mention [postgres] extra in install docs
* Support more connection options and use better defaults
* Actually pass PG env vars in tox
* Replace event trigger with manual DELETE queries

Each commit is Signed-off-by: Toby Harradine <tobyharradine@gmail.com>
Committed by: Michael H
Parent: 57fa29dd64
Commit: d1a46acc9a
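The key claim in the "Rewrite postgres driver" entry above is that PostgreSQL's format() handles unusual identifiers more reliably than Python-side f-strings. A minimal sketch (illustration only, not taken from the diff; the identifier values are invented):

-- Illustration only, not part of this commit.
-- Python-side interpolation leaves identifier quoting to the caller:
--     f'SELECT json_data FROM "{schemaname}"."{category}"'
-- format()'s %I specifier quotes identifiers server-side as needed:
SELECT format('SELECT json_data FROM %I.%I WHERE %I = $1',
              'My Cog.12345', 'GLOBAL', 'primary_key_1');
-- => SELECT json_data FROM "My Cog.12345"."GLOBAL" WHERE primary_key_1 = $1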
redbot/core/drivers/__init__.py
@@ -1,50 +1,110 @@
 import enum
+from typing import Optional, Type
 
-from .red_base import IdentifierData
+from .. import data_manager
+from .base import IdentifierData, BaseDriver, ConfigCategory
+from .json import JsonDriver
+from .mongo import MongoDriver
+from .postgres import PostgresDriver
 
-__all__ = ["get_driver", "IdentifierData", "BackendType"]
+__all__ = [
+    "get_driver",
+    "ConfigCategory",
+    "IdentifierData",
+    "BaseDriver",
+    "JsonDriver",
+    "MongoDriver",
+    "PostgresDriver",
+    "BackendType",
+]
 
 
 class BackendType(enum.Enum):
     JSON = "JSON"
     MONGO = "MongoDBV2"
     MONGOV1 = "MongoDB"
+    POSTGRES = "Postgres"
+
+
+_DRIVER_CLASSES = {
+    BackendType.JSON: JsonDriver,
+    BackendType.MONGO: MongoDriver,
+    BackendType.POSTGRES: PostgresDriver,
+}
+
+
+def get_driver_class(storage_type: Optional[BackendType] = None) -> Type[BaseDriver]:
+    """Get the driver class for the given storage type.
+
+    Parameters
+    ----------
+    storage_type : Optional[BackendType]
+        The backend you want a driver class for. Omit to try to obtain
+        the backend from data manager.
+
+    Returns
+    -------
+    Type[BaseDriver]
+        A subclass of `BaseDriver`.
+
+    Raises
+    ------
+    ValueError
+        If there is no driver for the given storage type.
+
+    """
+    if storage_type is None:
+        storage_type = BackendType(data_manager.storage_type())
+    try:
+        return _DRIVER_CLASSES[storage_type]
+    except KeyError:
+        raise ValueError(f"No driver found for storage type {storage_type}") from None
 
 
-def get_driver(type, *args, **kwargs):
-    """
-    Selectively import/load driver classes based on the selected type. This
-    is required so that dependencies can differ between installs (e.g. so that
-    you don't need to install a mongo dependency if you will just be running a
-    json data backend).
-
-    .. note::
-
-        See the respective classes for information on what ``args`` and ``kwargs``
-        should be.
-
-    :param str type:
-        One of: json, mongo
-    :param args:
-        Dependent on driver type.
-    :param kwargs:
-        Dependent on driver type.
-    :return:
-        Subclass of :py:class:`.red_base.BaseDriver`.
-    """
-    if type == "JSON":
-        from .red_json import JSON
-
-        return JSON(*args, **kwargs)
-    elif type == "MongoDBV2":
-        from .red_mongo import Mongo
-
-        return Mongo(*args, **kwargs)
-    elif type == "MongoDB":
-        raise RuntimeError(
-            "Please convert to JSON first to continue using the bot."
-            " This is a required conversion prior to using the new Mongo driver."
-            " This message will be updated with a link to the update docs once those"
-            " docs have been created."
-        )
-    raise RuntimeError("Invalid driver type: '{}'".format(type))
+def get_driver(
+    cog_name: str, identifier: str, storage_type: Optional[BackendType] = None, **kwargs
+):
+    """Get a driver instance.
+
+    Parameters
+    ----------
+    cog_name : str
+        The cog's name.
+    identifier : str
+        The cog's discriminator.
+    storage_type : Optional[BackendType]
+        The backend you want a driver for. Omit to try to obtain the
+        backend from data manager.
+    **kwargs
+        Driver-specific keyword arguments.
+
+    Returns
+    -------
+    BaseDriver
+        A driver instance.
+
+    Raises
+    ------
+    RuntimeError
+        If the storage type is MongoV1 or invalid.
+
+    """
+    if storage_type is None:
+        try:
+            storage_type = BackendType(data_manager.storage_type())
+        except RuntimeError:
+            storage_type = BackendType.JSON
+
+    try:
+        driver_cls: Type[BaseDriver] = get_driver_class(storage_type)
+    except ValueError:
+        if storage_type == BackendType.MONGOV1:
+            raise RuntimeError(
+                "Please convert to JSON first to continue using the bot."
+                " This is a required conversion prior to using the new Mongo driver."
+                " This message will be updated with a link to the update docs once those"
+                " docs have been created."
+            ) from None
+        else:
+            raise RuntimeError(f"Invalid driver type: '{storage_type}'") from None
+    return driver_cls(cog_name, identifier, **kwargs)
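For orientation, a short usage sketch of the two factory functions introduced above (hypothetical cog name and identifier, not taken from the diff):

# Hypothetical usage sketch, not part of this commit.
from redbot.core.drivers import BackendType, get_driver, get_driver_class

# Resolve only the class, e.g. to call classmethods such as initialize():
driver_cls = get_driver_class(BackendType.POSTGRES)

# Or build a driver instance for one cog's config data:
driver = get_driver("MyCog", "1234567890", storage_type=BackendType.JSON)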
redbot/core/drivers/base.py (new file, 342 lines)
@@ -0,0 +1,342 @@
import abc
import enum
from typing import Tuple, Dict, Any, Union, List, AsyncIterator, Type

__all__ = ["BaseDriver", "IdentifierData", "ConfigCategory"]


class ConfigCategory(str, enum.Enum):
    GLOBAL = "GLOBAL"
    GUILD = "GUILD"
    CHANNEL = "TEXTCHANNEL"
    ROLE = "ROLE"
    USER = "USER"
    MEMBER = "MEMBER"

    @classmethod
    def get_pkey_info(
        cls, category: Union[str, "ConfigCategory"], custom_group_data: Dict[str, int]
    ) -> Tuple[int, bool]:
        """Get the full primary key length for the given category,
        and whether or not the category is a custom category.
        """
        try:
            # noinspection PyArgumentList
            category_obj = cls(category)
        except ValueError:
            return custom_group_data[category], True
        else:
            return _CATEGORY_PKEY_COUNTS[category_obj], False


_CATEGORY_PKEY_COUNTS = {
    ConfigCategory.GLOBAL: 0,
    ConfigCategory.GUILD: 1,
    ConfigCategory.CHANNEL: 1,
    ConfigCategory.ROLE: 1,
    ConfigCategory.USER: 1,
    ConfigCategory.MEMBER: 2,
}


class IdentifierData:
    def __init__(
        self,
        cog_name: str,
        uuid: str,
        category: str,
        primary_key: Tuple[str, ...],
        identifiers: Tuple[str, ...],
        primary_key_len: int,
        is_custom: bool = False,
    ):
        self._cog_name = cog_name
        self._uuid = uuid
        self._category = category
        self._primary_key = primary_key
        self._identifiers = identifiers
        self.primary_key_len = primary_key_len
        self._is_custom = is_custom

    @property
    def cog_name(self) -> str:
        return self._cog_name

    @property
    def uuid(self) -> str:
        return self._uuid

    @property
    def category(self) -> str:
        return self._category

    @property
    def primary_key(self) -> Tuple[str, ...]:
        return self._primary_key

    @property
    def identifiers(self) -> Tuple[str, ...]:
        return self._identifiers

    @property
    def is_custom(self) -> bool:
        return self._is_custom

    def __repr__(self) -> str:
        return (
            f"<IdentifierData cog_name={self.cog_name} uuid={self.uuid} category={self.category} "
            f"primary_key={self.primary_key} identifiers={self.identifiers}>"
        )

    def __eq__(self, other) -> bool:
        if not isinstance(other, IdentifierData):
            return False
        return (
            self.uuid == other.uuid
            and self.category == other.category
            and self.primary_key == other.primary_key
            and self.identifiers == other.identifiers
        )

    def __hash__(self) -> int:
        return hash((self.uuid, self.category, self.primary_key, self.identifiers))

    def add_identifier(self, *identifier: str) -> "IdentifierData":
        if not all(isinstance(i, str) for i in identifier):
            raise ValueError("Identifiers must be strings.")

        return IdentifierData(
            self.cog_name,
            self.uuid,
            self.category,
            self.primary_key,
            self.identifiers + identifier,
            self.primary_key_len,
            is_custom=self.is_custom,
        )

    def to_tuple(self) -> Tuple[str, ...]:
        return tuple(
            filter(
                None,
                (self.cog_name, self.uuid, self.category, *self.primary_key, *self.identifiers),
            )
        )


class BaseDriver(abc.ABC):
    def __init__(self, cog_name: str, identifier: str, **kwargs):
        self.cog_name = cog_name
        self.unique_cog_identifier = identifier

    @classmethod
    @abc.abstractmethod
    async def initialize(cls, **storage_details) -> None:
        """
        Initialize this driver.

        Parameters
        ----------
        **storage_details
            The storage details required to initialize this driver.
            Should be the same as :func:`data_manager.storage_details`

        Raises
        ------
        MissingExtraRequirements
            If initializing the driver requires an extra which isn't
            installed.

        """
        raise NotImplementedError

    @classmethod
    @abc.abstractmethod
    async def teardown(cls) -> None:
        """
        Tear down this driver.
        """
        raise NotImplementedError

    @staticmethod
    @abc.abstractmethod
    def get_config_details() -> Dict[str, Any]:
        """
        Asks users for additional configuration information necessary
        to use this config driver.

        Returns
        -------
        Dict[str, Any]
            Dictionary of configuration details.
        """
        raise NotImplementedError

    @abc.abstractmethod
    async def get(self, identifier_data: IdentifierData) -> Any:
        """
        Finds the value indicated by the given identifiers.

        Parameters
        ----------
        identifier_data

        Returns
        -------
        Any
            Stored value.
        """
        raise NotImplementedError

    @abc.abstractmethod
    async def set(self, identifier_data: IdentifierData, value=None) -> None:
        """
        Sets the value of the key indicated by the given identifiers.

        Parameters
        ----------
        identifier_data
        value
            Any JSON serializable python object.
        """
        raise NotImplementedError

    @abc.abstractmethod
    async def clear(self, identifier_data: IdentifierData) -> None:
        """
        Clears out the value specified by the given identifiers.

        Equivalent to using ``del`` on a dict.

        Parameters
        ----------
        identifier_data
        """
        raise NotImplementedError

    @classmethod
    @abc.abstractmethod
    def aiter_cogs(cls) -> AsyncIterator[Tuple[str, str]]:
        """Get info for cogs which have data stored on this backend.

        Yields
        ------
        Tuple[str, str]
            Asynchronously yields (cog_name, cog_identifier) tuples.

        """
        raise NotImplementedError

    @classmethod
    async def migrate_to(
        cls,
        new_driver_cls: Type["BaseDriver"],
        all_custom_group_data: Dict[str, Dict[str, Dict[str, int]]],
    ) -> None:
        """Migrate data from this backend to another.

        Both drivers must be initialized beforehand.

        This will only move the data - no instance metadata is modified
        as a result of this operation.

        Parameters
        ----------
        new_driver_cls
            Subclass of `BaseDriver`.
        all_custom_group_data : Dict[str, Dict[str, Dict[str, int]]]
            Dict mapping cog names, to cog IDs, to custom groups, to
            primary key lengths.

        """
        # Backend-agnostic method of migrating from one driver to another.
        async for cog_name, cog_id in cls.aiter_cogs():
            this_driver = cls(cog_name, cog_id)
            other_driver = new_driver_cls(cog_name, cog_id)
            custom_group_data = all_custom_group_data.get(cog_name, {}).get(cog_id, {})
            exported_data = await this_driver.export_data(custom_group_data)
            await other_driver.import_data(exported_data, custom_group_data)

    @classmethod
    async def delete_all_data(cls, **kwargs) -> None:
        """Delete all data being stored by this driver.

        The driver must be initialized before this operation.

        The BaseDriver provides a generic method which may be overridden
        by subclasses.

        Parameters
        ----------
        **kwargs
            Driver-specific kwargs to change the way this method
            operates.

        """
        async for cog_name, cog_id in cls.aiter_cogs():
            driver = cls(cog_name, cog_id)
            await driver.clear(IdentifierData(cog_name, cog_id, "", (), (), 0))

    @staticmethod
    def _split_primary_key(
        category: Union[ConfigCategory, str],
        custom_group_data: Dict[str, int],
        data: Dict[str, Any],
    ) -> List[Tuple[Tuple[str, ...], Dict[str, Any]]]:
        pkey_len = ConfigCategory.get_pkey_info(category, custom_group_data)[0]
        if pkey_len == 0:
            return [((), data)]

        def flatten(levels_remaining, currdata, parent_key=()):
            items = []
            for _k, _v in currdata.items():
                new_key = parent_key + (_k,)
                if levels_remaining > 1:
                    items.extend(flatten(levels_remaining - 1, _v, new_key).items())
                else:
                    items.append((new_key, _v))
            return dict(items)

        ret = []
        for k, v in flatten(pkey_len, data).items():
            ret.append((k, v))
        return ret

    async def export_data(
        self, custom_group_data: Dict[str, int]
    ) -> List[Tuple[str, Dict[str, Any]]]:
        categories = [c.value for c in ConfigCategory]
        categories.extend(custom_group_data.keys())

        ret = []
        for c in categories:
            ident_data = IdentifierData(
                self.cog_name,
                self.unique_cog_identifier,
                c,
                (),
                (),
                *ConfigCategory.get_pkey_info(c, custom_group_data),
            )
            try:
                data = await self.get(ident_data)
            except KeyError:
                continue
            ret.append((c, data))
        return ret

    async def import_data(
        self, cog_data: List[Tuple[str, Dict[str, Any]]], custom_group_data: Dict[str, int]
    ) -> None:
        for category, all_data in cog_data:
            splitted_pkey = self._split_primary_key(category, custom_group_data, all_data)
            for pkey, data in splitted_pkey:
                ident_data = IdentifierData(
                    self.cog_name,
                    self.unique_cog_identifier,
                    category,
                    pkey,
                    (),
                    *ConfigCategory.get_pkey_info(category, custom_group_data),
                )
                await self.set(ident_data, data)
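A small sketch of how IdentifierData composes, based only on the class defined above (the cog name and IDs are invented):

# Hypothetical usage sketch based on the IdentifierData class above.
ident = IdentifierData(
    cog_name="MyCog",
    uuid="1234567890",
    category="GUILD",
    primary_key=("112233445566778899",),  # e.g. a guild ID
    identifiers=(),
    primary_key_len=1,
)
deeper = ident.add_identifier("prefix")  # returns a new, extended instance
assert deeper.to_tuple() == (
    "MyCog", "1234567890", "GUILD", "112233445566778899", "prefix"
)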
redbot/core/drivers/json.py
@@ -5,12 +5,13 @@ import os
 import pickle
 import weakref
 from pathlib import Path
-from typing import Any, Dict
+from typing import Any, AsyncIterator, Dict, Optional, Tuple
 from uuid import uuid4
 
-from .red_base import BaseDriver, IdentifierData
+from .. import data_manager, errors
+from .base import BaseDriver, IdentifierData, ConfigCategory
 
-__all__ = ["JSON"]
+__all__ = ["JsonDriver"]
 
 
 _shared_datastore = {}
@@ -35,9 +36,10 @@ def finalize_driver(cog_name):
     _finalizers.remove(f)
 
 
-class JSON(BaseDriver):
+# noinspection PyProtectedMember
+class JsonDriver(BaseDriver):
     """
-    Subclass of :py:class:`.red_base.BaseDriver`.
+    Subclass of :py:class:`.BaseDriver`.
 
     .. py:attribute:: file_name
 
@@ -50,27 +52,26 @@ class JSON(BaseDriver):
 
     def __init__(
         self,
-        cog_name,
-        identifier,
+        cog_name: str,
+        identifier: str,
         *,
-        data_path_override: Path = None,
-        file_name_override: str = "settings.json"
+        data_path_override: Optional[Path] = None,
+        file_name_override: str = "settings.json",
     ):
         super().__init__(cog_name, identifier)
         self.file_name = file_name_override
-        if data_path_override:
+        if data_path_override is not None:
             self.data_path = data_path_override
         elif cog_name == "Core" and identifier == "0":
             self.data_path = data_manager.core_data_path()
         else:
-            self.data_path = Path.cwd() / "cogs" / ".data" / self.cog_name
+            self.data_path = data_manager.cog_data_path(raw_name=cog_name)
         self.data_path.mkdir(parents=True, exist_ok=True)
         self.data_path = self.data_path / self.file_name
 
         self._lock = asyncio.Lock()
         self._load_data()
 
-    async def has_valid_connection(self) -> bool:
-        return True
-
     @property
     def data(self):
         return _shared_datastore.get(self.cog_name)
@@ -79,6 +80,21 @@ class JSON(BaseDriver):
     def data(self, value):
         _shared_datastore[self.cog_name] = value
 
+    @classmethod
+    async def initialize(cls, **storage_details) -> None:
+        # No initializing to do
+        return
+
+    @classmethod
+    async def teardown(cls) -> None:
+        # No tearing down to do
+        return
+
+    @staticmethod
+    def get_config_details() -> Dict[str, Any]:
+        # No driver-specific configuration needed
+        return {}
+
     def _load_data(self):
         if self.cog_name not in _driver_counts:
             _driver_counts[self.cog_name] = 0
@@ -111,30 +127,32 @@ class JSON(BaseDriver):
 
     async def get(self, identifier_data: IdentifierData):
         partial = self.data
-        full_identifiers = identifier_data.to_tuple()
+        full_identifiers = identifier_data.to_tuple()[1:]
        for i in full_identifiers:
             partial = partial[i]
         return pickle.loads(pickle.dumps(partial, -1))
 
     async def set(self, identifier_data: IdentifierData, value=None):
         partial = self.data
-        full_identifiers = identifier_data.to_tuple()
+        full_identifiers = identifier_data.to_tuple()[1:]
         # This is both our deepcopy() and our way of making sure this value is actually JSON
         # serializable.
         value_copy = json.loads(json.dumps(value))
 
         async with self._lock:
             for i in full_identifiers[:-1]:
-                if i not in partial:
-                    partial[i] = {}
-                partial = partial[i]
-            partial[full_identifiers[-1]] = value_copy
+                try:
+                    partial = partial.setdefault(i, {})
+                except AttributeError:
+                    # Tried to set sub-field of non-object
+                    raise errors.CannotSetSubfield
+
+            partial[full_identifiers[-1]] = value_copy
             await self._save()
 
     async def clear(self, identifier_data: IdentifierData):
         partial = self.data
-        full_identifiers = identifier_data.to_tuple()
+        full_identifiers = identifier_data.to_tuple()[1:]
         try:
             for i in full_identifiers[:-1]:
                 partial = partial[i]
@@ -149,14 +167,32 @@ class JSON(BaseDriver):
         else:
             await self._save()
 
+    @classmethod
+    async def aiter_cogs(cls) -> AsyncIterator[Tuple[str, str]]:
+        yield "Core", "0"
+        for _dir in data_manager.cog_data_path().iterdir():
+            fpath = _dir / "settings.json"
+            if not fpath.exists():
+                continue
+            with fpath.open() as f:
+                try:
+                    data = json.load(f)
+                except json.JSONDecodeError:
+                    continue
+            if not isinstance(data, dict):
+                continue
+            for cog, inner in data.items():
+                if not isinstance(inner, dict):
+                    continue
+                for cog_id in inner:
+                    yield cog, cog_id
+
     async def import_data(self, cog_data, custom_group_data):
         def update_write_data(identifier_data: IdentifierData, _data):
             partial = self.data
-            idents = identifier_data.to_tuple()
+            idents = identifier_data.to_tuple()[1:]
             for ident in idents[:-1]:
-                if ident not in partial:
-                    partial[ident] = {}
-                partial = partial[ident]
+                partial = partial.setdefault(ident, {})
             partial[idents[-1]] = _data
 
         async with self._lock:
@@ -164,12 +200,12 @@ class JSON(BaseDriver):
             splitted_pkey = self._split_primary_key(category, custom_group_data, all_data)
             for pkey, data in splitted_pkey:
                 ident_data = IdentifierData(
+                    self.cog_name,
                     self.unique_cog_identifier,
                     category,
                     pkey,
                     (),
-                    custom_group_data,
-                    is_custom=category in custom_group_data,
+                    *ConfigCategory.get_pkey_info(category, custom_group_data),
                 )
                 update_write_data(ident_data, data)
             await self._save()
@@ -178,9 +214,6 @@ class JSON(BaseDriver):
         loop = asyncio.get_running_loop()
         await loop.run_in_executor(None, _save_json, self.data_path, self.data)
 
-    def get_config_details(self):
-        return
-
 
 def _save_json(path: Path, data: Dict[str, Any]) -> None:
     """
redbot/core/drivers/log.py (new file, 11 lines)
@@ -0,0 +1,11 @@
import functools
import logging
import os

if os.getenv("RED_INSPECT_DRIVER_QUERIES"):
    LOGGING_INVISIBLE = logging.DEBUG
else:
    LOGGING_INVISIBLE = 0

log = logging.getLogger("red.driver")
log.invisible = functools.partial(log.log, LOGGING_INVISIBLE)
redbot/core/drivers/mongo.py
@@ -2,77 +2,110 @@ import contextlib
 import itertools
 import re
 from getpass import getpass
-from typing import Match, Pattern, Tuple, Any, Dict, Iterator, List
+from typing import Match, Pattern, Tuple, Optional, AsyncIterator, Any, Dict, Iterator, List
 from urllib.parse import quote_plus
 
-import motor.core
-import motor.motor_asyncio
-import pymongo.errors
+try:
+    # pylint: disable=import-error
+    import pymongo.errors
+    import motor.core
+    import motor.motor_asyncio
+except ModuleNotFoundError:
+    motor = None
+    pymongo = None
 
-from .red_base import BaseDriver, IdentifierData
+from .. import errors
+from .base import BaseDriver, IdentifierData
 
-__all__ = ["Mongo"]
+__all__ = ["MongoDriver"]
 
 
-_conn = None
-
-
-def _initialize(**kwargs):
-    uri = kwargs.get("URI", "mongodb")
-    host = kwargs["HOST"]
-    port = kwargs["PORT"]
-    admin_user = kwargs["USERNAME"]
-    admin_pass = kwargs["PASSWORD"]
-    db_name = kwargs.get("DB_NAME", "default_db")
-
-    if port is 0:
-        ports = ""
-    else:
-        ports = ":{}".format(port)
-
-    if admin_user is not None and admin_pass is not None:
-        url = "{}://{}:{}@{}{}/{}".format(
-            uri, quote_plus(admin_user), quote_plus(admin_pass), host, ports, db_name
-        )
-    else:
-        url = "{}://{}{}/{}".format(uri, host, ports, db_name)
-
-    global _conn
-    _conn = motor.motor_asyncio.AsyncIOMotorClient(url, retryWrites=True)
-
-
-class Mongo(BaseDriver):
+class MongoDriver(BaseDriver):
     """
-    Subclass of :py:class:`.red_base.BaseDriver`.
+    Subclass of :py:class:`.BaseDriver`.
     """
 
-    def __init__(self, cog_name, identifier, **kwargs):
-        super().__init__(cog_name, identifier)
+    _conn: Optional["motor.motor_asyncio.AsyncIOMotorClient"] = None
 
-        if _conn is None:
-            _initialize(**kwargs)
+    @classmethod
+    async def initialize(cls, **storage_details) -> None:
+        if motor is None:
+            raise errors.MissingExtraRequirements(
+                "Red must be installed with the [mongo] extra to use the MongoDB driver"
+            )
+        uri = storage_details.get("URI", "mongodb")
+        host = storage_details["HOST"]
+        port = storage_details["PORT"]
+        user = storage_details["USERNAME"]
+        password = storage_details["PASSWORD"]
+        database = storage_details.get("DB_NAME", "default_db")
 
-    async def has_valid_connection(self) -> bool:
-        # Maybe fix this?
-        return True
+        if port is 0:
+            ports = ""
+        else:
+            ports = ":{}".format(port)
+
+        if user is not None and password is not None:
+            url = "{}://{}:{}@{}{}/{}".format(
+                uri, quote_plus(user), quote_plus(password), host, ports, database
+            )
+        else:
+            url = "{}://{}{}/{}".format(uri, host, ports, database)
+
+        cls._conn = motor.motor_asyncio.AsyncIOMotorClient(url, retryWrites=True)
+
+    @classmethod
+    async def teardown(cls) -> None:
+        if cls._conn is not None:
+            cls._conn.close()
+
+    @staticmethod
+    def get_config_details():
+        while True:
+            uri = input("Enter URI scheme (mongodb or mongodb+srv): ")
+            if uri is "":
+                uri = "mongodb"
+
+            if uri in ["mongodb", "mongodb+srv"]:
+                break
+            else:
+                print("Invalid URI scheme")
+
+        host = input("Enter host address: ")
+        if uri is "mongodb":
+            port = int(input("Enter host port: "))
+        else:
+            port = 0
+
+        admin_uname = input("Enter login username: ")
+        admin_password = getpass("Enter login password: ")
+
+        db_name = input("Enter mongodb database name: ")
+
+        if admin_uname == "":
+            admin_uname = admin_password = None
+
+        ret = {
+            "HOST": host,
+            "PORT": port,
+            "USERNAME": admin_uname,
+            "PASSWORD": admin_password,
+            "DB_NAME": db_name,
+            "URI": uri,
+        }
+        return ret
 
     @property
-    def db(self) -> motor.core.Database:
+    def db(self) -> "motor.core.Database":
         """
         Gets the mongo database for this cog's name.
 
         .. warning::
 
             Right now this will cause a new connection to be made every time the
             database is accessed. We will want to create a connection pool down the
             line to limit the number of connections.
 
         :return:
             PyMongo Database object.
         """
-        return _conn.get_database()
+        return self._conn.get_database()
 
-    def get_collection(self, category: str) -> motor.core.Collection:
+    def get_collection(self, category: str) -> "motor.core.Collection":
         """
         Gets a specified collection within the PyMongo database for this cog.
 
@@ -85,12 +118,13 @@ class Mongo(BaseDriver):
         """
         return self.db[self.cog_name][category]
 
-    def get_primary_key(self, identifier_data: IdentifierData) -> Tuple[str]:
+    @staticmethod
+    def get_primary_key(identifier_data: IdentifierData) -> Tuple[str, ...]:
         # noinspection PyTypeChecker
         return identifier_data.primary_key
 
     async def rebuild_dataset(
-        self, identifier_data: IdentifierData, cursor: motor.motor_asyncio.AsyncIOMotorCursor
+        self, identifier_data: IdentifierData, cursor: "motor.motor_asyncio.AsyncIOMotorCursor"
     ):
         ret = {}
         async for doc in cursor:
@@ -141,16 +175,16 @@ class Mongo(BaseDriver):
     async def set(self, identifier_data: IdentifierData, value=None):
         uuid = self._escape_key(identifier_data.uuid)
         primary_key = list(map(self._escape_key, self.get_primary_key(identifier_data)))
-        dot_identifiers = ".".join(map(self._escape_key, identifier_data.identifiers))
         if isinstance(value, dict):
             if len(value) == 0:
                 await self.clear(identifier_data)
                 return
             value = self._escape_dict_keys(value)
         mongo_collection = self.get_collection(identifier_data.category)
-        pkey_len = self.get_pkey_len(identifier_data)
         num_pkeys = len(primary_key)
 
-        if num_pkeys >= pkey_len:
+        if num_pkeys >= identifier_data.primary_key_len:
             # We're setting at the document level or below.
+            dot_identifiers = ".".join(map(self._escape_key, identifier_data.identifiers))
             if dot_identifiers:
@@ -158,11 +192,23 @@ class Mongo(BaseDriver):
             else:
                 update_stmt = {"$set": value}
 
-            await mongo_collection.update_one(
-                {"_id": {"RED_uuid": uuid, "RED_primary_key": primary_key}},
-                update=update_stmt,
-                upsert=True,
-            )
+            try:
+                await mongo_collection.update_one(
+                    {"_id": {"RED_uuid": uuid, "RED_primary_key": primary_key}},
+                    update=update_stmt,
+                    upsert=True,
+                )
+            except pymongo.errors.WriteError as exc:
+                if exc.args and exc.args[0].startswith("Cannot create field"):
+                    # There's a bit of a failing edge case here...
+                    # If we accidentally set the sub-field of an array, and the key happens to be a
+                    # digit, it will successfully set the value in the array, and not raise an
+                    # error. This is different to how other drivers would behave, and could lead to
+                    # unexpected behaviour.
+                    raise errors.CannotSetSubfield
+                else:
+                    # Unhandled driver exception, should expose.
+                    raise
 
         else:
             # We're setting above the document level.
@@ -171,15 +217,17 @@ class Mongo(BaseDriver):
             # We'll do it in a transaction so we can roll-back in case something goes horribly
             # wrong.
             pkey_filter = self.generate_primary_key_filter(identifier_data)
-            async with await _conn.start_session() as session:
+            async with await self._conn.start_session() as session:
                 with contextlib.suppress(pymongo.errors.CollectionInvalid):
                     # Collections must already exist when inserting documents within a transaction
-                    await _conn.get_database().create_collection(mongo_collection.full_name)
+                    await self.db.create_collection(mongo_collection.full_name)
                 try:
                     async with session.start_transaction():
                         await mongo_collection.delete_many(pkey_filter, session=session)
                         await mongo_collection.insert_many(
-                            self.generate_documents_to_insert(uuid, primary_key, value, pkey_len),
+                            self.generate_documents_to_insert(
+                                uuid, primary_key, value, identifier_data.primary_key_len
+                            ),
                             session=session,
                         )
                 except pymongo.errors.OperationFailure:
@@ -218,7 +266,7 @@ class Mongo(BaseDriver):
 
             # What's left of `value` should be the new documents needing to be inserted.
             to_insert = self.generate_documents_to_insert(
-                uuid, primary_key, value, pkey_len
+                uuid, primary_key, value, identifier_data.primary_key_len
             )
             requests = list(
                 itertools.chain(
@@ -289,6 +337,59 @@ class Mongo(BaseDriver):
         for result in results:
             await db[result["name"]].delete_many(pkey_filter)
 
+    @classmethod
+    async def aiter_cogs(cls) -> AsyncIterator[Tuple[str, str]]:
+        db = cls._conn.get_database()
+        for collection_name in await db.list_collection_names():
+            parts = collection_name.split(".")
+            if not len(parts) == 2:
+                continue
+            cog_name = parts[0]
+            for cog_id in await db[collection_name].distinct("_id.RED_uuid"):
+                yield cog_name, cog_id
+
+    @classmethod
+    async def delete_all_data(
+        cls, *, interactive: bool = False, drop_db: Optional[bool] = None, **kwargs
+    ) -> None:
+        """Delete all data being stored by this driver.
+
+        Parameters
+        ----------
+        interactive : bool
+            Set to ``True`` to allow the method to ask the user for
+            input from the console, regarding the other unset parameters
+            for this method.
+        drop_db : Optional[bool]
+            Set to ``True`` to drop the entire database for the current
+            bot's instance. Otherwise, collections which appear to be
+            storing bot data will be dropped.
+
+        """
+        if interactive is True and drop_db is None:
+            print(
+                "Please choose from one of the following options:\n"
+                " 1. Drop the entire MongoDB database for this instance, or\n"
+                " 2. Delete all of Red's data within this database, without dropping the database "
+                "itself."
+            )
+            options = ("1", "2")
+            while True:
+                resp = input("> ")
+                try:
+                    drop_db = bool(options.index(resp))
+                except ValueError:
+                    print("Please type a number corresponding to one of the options.")
+                else:
+                    break
+        db = cls._conn.get_database()
+        if drop_db is True:
+            await cls._conn.drop_database(db)
+        else:
+            async with await cls._conn.start_session() as session:
+                async for cog_name, cog_id in cls.aiter_cogs():
+                    await db.drop_collection(db[cog_name], session=session)
+
     @staticmethod
     def _escape_key(key: str) -> str:
         return _SPECIAL_CHAR_PATTERN.sub(_replace_with_escaped, key)
@@ -344,40 +445,3 @@ _CHAR_ESCAPES = {
 
 def _replace_with_unescaped(match: Match[str]) -> str:
     return _CHAR_ESCAPES[match[0]]
-
-
-def get_config_details():
-    uri = None
-    while True:
-        uri = input("Enter URI scheme (mongodb or mongodb+srv): ")
-        if uri is "":
-            uri = "mongodb"
-
-        if uri in ["mongodb", "mongodb+srv"]:
-            break
-        else:
-            print("Invalid URI scheme")
-
-    host = input("Enter host address: ")
-    if uri is "mongodb":
-        port = int(input("Enter host port: "))
-    else:
-        port = 0
-
-    admin_uname = input("Enter login username: ")
-    admin_password = getpass("Enter login password: ")
-
-    db_name = input("Enter mongodb database name: ")
-
-    if admin_uname == "":
-        admin_uname = admin_password = None
-
-    ret = {
-        "HOST": host,
-        "PORT": port,
-        "USERNAME": admin_uname,
-        "PASSWORD": admin_password,
-        "DB_NAME": db_name,
-        "URI": uri,
-    }
-    return ret
redbot/core/drivers/postgres/__init__.py (new file, 3 lines)
@@ -0,0 +1,3 @@
from .postgres import PostgresDriver

__all__ = ["PostgresDriver"]
redbot/core/drivers/postgres/ddl.sql (new file, 839 lines)
@@ -0,0 +1,839 @@
/*
 ************************************************************
 * PostgreSQL driver Data Definition Language (DDL) Script. *
 ************************************************************
 */

CREATE SCHEMA IF NOT EXISTS red_config;
CREATE SCHEMA IF NOT EXISTS red_utils;

DO $$
BEGIN
    PERFORM 'red_config.identifier_data'::regtype;
EXCEPTION
    WHEN UNDEFINED_OBJECT THEN
        CREATE TYPE red_config.identifier_data AS (
            cog_name text,
            cog_id text,
            category text,
            pkeys text[],
            identifiers text[],
            pkey_len integer,
            is_custom boolean
        );
END;
$$;


CREATE OR REPLACE FUNCTION
    /*
     * Create the config schema and/or table if they do not exist yet.
     */
    red_config.maybe_create_table(
        id_data red_config.identifier_data
    )
    RETURNS void
    LANGUAGE 'plpgsql'
AS $$
DECLARE
    schemaname CONSTANT text := concat_ws('.', id_data.cog_name, id_data.cog_id);
    schema_exists CONSTANT boolean := exists(
        SELECT 1
        FROM red_config.red_cogs t
        WHERE t.cog_name = id_data.cog_name AND t.cog_id = id_data.cog_id);
    table_exists CONSTANT boolean := schema_exists AND exists(
        SELECT 1
        FROM information_schema.tables
        WHERE table_schema = schemaname AND table_name = id_data.category);

BEGIN
    IF NOT schema_exists THEN
        PERFORM red_config.create_schema(id_data.cog_name, id_data.cog_id);
    END IF;
    IF NOT table_exists THEN
        PERFORM red_config.create_table(id_data);
    END IF;
END;
$$;


CREATE OR REPLACE FUNCTION
    /*
     * Create the config schema for the given cog.
     */
    red_config.create_schema(new_cog_name text, new_cog_id text, OUT schemaname text)
    RETURNS text
    LANGUAGE 'plpgsql'
AS $$
BEGIN
    schemaname := concat_ws('.', new_cog_name, new_cog_id);

    EXECUTE format('CREATE SCHEMA IF NOT EXISTS %I', schemaname);

    INSERT INTO red_config.red_cogs AS t VALUES(new_cog_name, new_cog_id, schemaname)
    ON CONFLICT(cog_name, cog_id) DO UPDATE
    SET
        schemaname = excluded.schemaname;
END;
$$;


CREATE OR REPLACE FUNCTION
    /*
     * Create the config table for the given category.
     */
    red_config.create_table(id_data red_config.identifier_data)
    RETURNS void
    LANGUAGE 'plpgsql'
AS $$
DECLARE
    schemaname CONSTANT text := concat_ws('.', id_data.cog_name, id_data.cog_id);
    constraintname CONSTANT text := id_data.category||'_pkey';
    pkey_columns CONSTANT text := red_utils.gen_pkey_columns(1, id_data.pkey_len);
    pkey_type CONSTANT text := red_utils.get_pkey_type(id_data.is_custom);
    pkey_column_definitions CONSTANT text := red_utils.gen_pkey_column_definitions(
        1, id_data.pkey_len, pkey_type);

BEGIN
    EXECUTE format(
        $query$
        CREATE TABLE IF NOT EXISTS %I.%I (
            %s,
            json_data jsonb DEFAULT '{}' NOT NULL,
            CONSTRAINT %I PRIMARY KEY (%s)
        )
        $query$,
        schemaname,
        id_data.category,
        pkey_column_definitions,
        constraintname,
        pkey_columns);
END;
$$;


CREATE OR REPLACE FUNCTION
    /*
     * Get config data.
     *
     * - When `pkeys` is a full primary key, all or part of a document
     *   will be returned.
     * - When `pkeys` is not a full primary key, documents will be
     *   aggregated together into a single JSONB object, with primary keys
     *   as keys mapping to the documents.
     */
    red_config.get(
        id_data red_config.identifier_data,
        OUT result jsonb
    )
    LANGUAGE 'plpgsql'
    STABLE
    PARALLEL SAFE
AS $$
DECLARE
    schemaname CONSTANT text := concat_ws('.', id_data.cog_name, id_data.cog_id);
    num_pkeys CONSTANT integer := coalesce(array_length(id_data.pkeys, 1), 0);
    num_missing_pkeys CONSTANT integer := id_data.pkey_len - num_pkeys;
    pkey_type CONSTANT text := red_utils.get_pkey_type(id_data.is_custom);
    whereclause CONSTANT text := red_utils.gen_whereclause(num_pkeys, pkey_type);

    missing_pkey_columns text;

BEGIN
    IF num_missing_pkeys <= 0 THEN
        -- No missing primary keys: we're getting all or part of a document.
        EXECUTE format(
            'SELECT json_data #> $2 FROM %I.%I WHERE %s',
            schemaname,
            id_data.category,
            whereclause)
        INTO result
        USING id_data.pkeys, id_data.identifiers;

    ELSIF num_missing_pkeys = 1 THEN
        -- 1 missing primary key: we can use the built-in jsonb_object_agg() aggregate function.
        EXECUTE format(
            'SELECT jsonb_object_agg(%I::text, json_data) FROM %I.%I WHERE %s',
            'primary_key_'||id_data.pkey_len,
            schemaname,
            id_data.category,
            whereclause)
        INTO result
        USING id_data.pkeys;
    ELSE
        -- Multiple missing primary keys: we must use our custom red_utils.jsonb_object_agg2()
        -- aggregate function.
        missing_pkey_columns := red_utils.gen_pkey_columns_casted(num_pkeys + 1, id_data.pkey_len);

        EXECUTE format(
            'SELECT red_utils.jsonb_object_agg2(json_data, %s) FROM %I.%I WHERE %s',
            missing_pkey_columns,
            schemaname,
            id_data.category,
            whereclause)
        INTO result
        USING id_data.pkeys;
    END IF;
END;
$$;
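To make the aggregation behaviour of red_config.get concrete, a hypothetical example (data invented for illustration):

-- Illustration only, not part of this commit.
-- Full primary key (e.g. MEMBER data with pkeys = {guild_id, member_id}):
-- the call returns a single document, e.g. {"foo": true}.
-- Partial primary key (pkeys = {guild_id} only, one key missing):
-- matching documents are aggregated under their remaining key, e.g.
-- {"111": {"foo": true}, "222": {"foo": false}}.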
|
||||
|
||||
|
||||
CREATE OR REPLACE FUNCTION
|
||||
/*
|
||||
* Set config data.
|
||||
*
|
||||
* - When `pkeys` is a full primary key, all or part of a document
|
||||
* will be set.
|
||||
* - When `pkeys` is not a full set, multiple documents will be
|
||||
* replaced or removed - `new_value` must be a JSONB object mapping
|
||||
* primary keys to the new documents.
|
||||
*
|
||||
* Raises `error_in_assignment` error when trying to set a sub-key
|
||||
* of a non-document type.
|
||||
*/
|
||||
red_config.set(
|
||||
id_data red_config.identifier_data,
|
||||
new_value jsonb
|
||||
)
|
||||
RETURNS void
|
||||
LANGUAGE 'plpgsql'
|
||||
AS $$
|
||||
DECLARE
|
||||
schemaname CONSTANT text := concat_ws('.', id_data.cog_name, id_data.cog_id);
|
||||
constraintname CONSTANT text := id_data.category||'_pkey';
|
||||
num_pkeys CONSTANT integer := coalesce(array_length(id_data.pkeys, 1), 0);
|
||||
num_missing_pkeys CONSTANT integer := id_data.pkey_len - num_pkeys;
|
||||
pkey_type CONSTANT text := red_utils.get_pkey_type(id_data.is_custom);
|
||||
pkey_placeholders CONSTANT text := red_utils.gen_pkey_placeholders(num_pkeys, pkey_type);
|
||||
|
||||
new_document jsonb;
|
||||
pkey_column_definitions text;
|
||||
whereclause text;
|
||||
missing_pkey_columns text;
|
||||
|
||||
BEGIN
|
||||
PERFORM red_config.maybe_create_table(id_data);
|
||||
|
||||
IF num_missing_pkeys = 0 THEN
|
||||
-- Setting all or part of a document
|
||||
new_document := red_utils.jsonb_set2('{}', new_value, VARIADIC id_data.identifiers);
|
||||
|
||||
EXECUTE format(
|
||||
$query$
|
||||
INSERT INTO %I.%I AS t VALUES (%s, $2)
|
||||
ON CONFLICT ON CONSTRAINT %I DO UPDATE
|
||||
SET
|
||||
json_data = red_utils.jsonb_set2(t.json_data, $3, VARIADIC $4)
|
||||
$query$,
|
||||
schemaname,
|
||||
id_data.category,
|
||||
pkey_placeholders,
|
||||
constraintname)
|
||||
USING id_data.pkeys, new_document, new_value, id_data.identifiers;
|
||||
|
||||
ELSE
|
||||
-- Setting multiple documents
|
||||
whereclause := red_utils.gen_whereclause(num_pkeys, pkey_type);
|
||||
missing_pkey_columns := red_utils.gen_pkey_columns_casted(
|
||||
num_pkeys + 1, id_data.pkey_len, pkey_type);
|
||||
pkey_column_definitions := red_utils.gen_pkey_column_definitions(num_pkeys + 1, id_data.pkey_len);
|
||||
|
||||
-- Delete all documents which we're setting first, since we don't know whether they'll be
|
||||
-- replaced by the subsequent INSERT.
|
||||
EXECUTE format('DELETE FROM %I.%I WHERE %s', schemaname, id_data.category, whereclause)
|
||||
USING id_data.pkeys;
|
||||
|
||||
-- Insert all new documents
|
||||
EXECUTE format(
|
||||
$query$
|
||||
INSERT INTO %I.%I AS t
|
||||
SELECT %s, json_data
|
||||
FROM red_utils.generate_rows_from_object($2, $3) AS f(%s, json_data jsonb)
|
||||
ON CONFLICT ON CONSTRAINT %I DO UPDATE
|
||||
SET
|
||||
json_data = excluded.json_data
|
||||
$query$,
|
||||
schemaname,
|
||||
id_data.category,
|
||||
concat_ws(', ', pkey_placeholders, missing_pkey_columns),
|
||||
pkey_column_definitions,
|
||||
constraintname)
|
||||
USING id_data.pkeys, new_value, num_missing_pkeys;
|
||||
END IF;
|
||||
END;
|
||||
$$;
|
||||
|
||||
|
||||
CREATE OR REPLACE FUNCTION
|
||||
/*
|
||||
* Clear config data.
|
||||
*
|
||||
* - When `identifiers` is not empty, this will clear a key from a
|
||||
* document.
|
||||
* - When `identifiers` is empty and `pkeys` is not empty, it will
|
||||
* delete one or more documents.
|
||||
* - When `pkeys` is empty, it will drop the whole table.
|
||||
* - When `id_data.category` is NULL or an empty string, it will drop
|
||||
* the whole schema.
|
||||
*
|
||||
* Has no effect when the document or key does not exist.
|
||||
*/
|
||||
red_config.clear(
|
||||
id_data red_config.identifier_data
|
||||
)
|
||||
RETURNS void
|
||||
LANGUAGE 'plpgsql'
|
||||
AS $$
|
||||
DECLARE
|
||||
schemaname CONSTANT text := concat_ws('.', id_data.cog_name, id_data.cog_id);
|
||||
num_pkeys CONSTANT integer := coalesce(array_length(id_data.pkeys, 1), 0);
|
||||
num_identifiers CONSTANT integer := coalesce(array_length(id_data.identifiers, 1), 0);
|
||||
pkey_type CONSTANT text := red_utils.get_pkey_type(id_data.is_custom);
|
||||
|
||||
whereclause text;
|
||||
|
||||
BEGIN
|
||||
IF num_identifiers > 0 THEN
|
||||
-- Popping a key from a document or nested document.
|
||||
whereclause := red_utils.gen_whereclause(num_pkeys, pkey_type);
|
||||
|
||||
EXECUTE format(
|
||||
$query$
|
||||
UPDATE %I.%I AS t
|
||||
SET
|
||||
json_data = t.json_data #- $2
|
||||
WHERE %s
|
||||
$query$,
|
||||
schemaname,
|
||||
id_data.category,
|
||||
whereclause)
|
||||
USING id_data.pkeys, id_data.identifiers;
|
||||
|
||||
ELSIF num_pkeys > 0 THEN
|
||||
-- Deleting one or many documents
|
||||
whereclause := red_utils.gen_whereclause(num_pkeys, pkey_type);
|
||||
|
||||
EXECUTE format('DELETE FROM %I.%I WHERE %s', schemaname, id_data.category, whereclause)
|
||||
USING id_data.pkeys;
|
||||
|
||||
ELSIF id_data.category IS NOT NULL AND id_data.category != '' THEN
|
||||
-- Deleting an entire category
|
||||
EXECUTE format('DROP TABLE %I.%I CASCADE', schemaname, id_data.category);
|
||||
|
||||
ELSE
|
||||
-- Deleting an entire cog's data
|
||||
EXECUTE format('DROP SCHEMA %I CASCADE', schemaname);
|
||||
|
||||
DELETE FROM red_config.red_cogs
|
||||
WHERE cog_name = id_data.cog_name AND cog_id = id_data.cog_id;
|
||||
END IF;
|
||||
END;
|
||||
$$;
|
||||
|
||||
|
||||
CREATE OR REPLACE FUNCTION
|
||||
/*
|
||||
* Increment a number within a document.
|
||||
*
|
||||
* If the value doesn't already exist, it is inserted as
|
||||
* `default_value + amount`.
|
||||
*
|
||||
* Raises 'wrong_object_type' error when trying to increment a
|
||||
* non-numeric value.
|
||||
*/
|
||||
red_config.inc(
|
||||
id_data red_config.identifier_data,
|
||||
amount numeric,
|
||||
default_value numeric,
|
||||
OUT result numeric
|
||||
)
|
||||
LANGUAGE 'plpgsql'
|
||||
AS $$
|
||||
DECLARE
|
||||
schemaname CONSTANT text := concat_ws('.', id_data.cog_name, id_data.cog_id);
|
||||
num_identifiers CONSTANT integer := coalesce(array_length(id_data.identifiers, 1), 0);
|
||||
pkey_type CONSTANT text := red_utils.get_pkey_type(id_data.is_custom);
|
||||
whereclause CONSTANT text := red_utils.gen_whereclause(id_data.pkey_len, pkey_type);
|
||||
|
||||
new_document jsonb;
|
||||
existing_document jsonb;
|
||||
existing_value jsonb;
|
||||
pkey_placeholders text;
|
||||
|
||||
BEGIN
|
||||
IF num_identifiers = 0 THEN
|
||||
-- Without identifiers, there's no chance we're actually incrementing a number
|
||||
RAISE EXCEPTION 'Cannot increment document(s)'
|
||||
USING ERRCODE = 'wrong_object_type';
|
||||
END IF;
|
||||
|
||||
PERFORM red_config.maybe_create_table(id_data);
|
||||
|
||||
-- Look for the existing document
|
||||
EXECUTE format(
|
||||
'SELECT json_data FROM %I.%I WHERE %s',
|
||||
schemaname,
|
||||
id_data.category,
|
||||
whereclause)
|
||||
INTO existing_document USING id_data.pkeys;
|
||||
|
||||
IF existing_document IS NULL THEN
|
||||
-- We need to insert a new document
|
||||
result := default_value + amount;
|
||||
new_document := red_utils.jsonb_set2('{}', result, VARIADIC id_data.identifiers);
|
||||
pkey_placeholders := red_utils.gen_pkey_placeholders(id_data.pkey_len, pkey_type);
|
||||
|
||||
EXECUTE format(
|
||||
'INSERT INTO %I.%I VALUES(%s, $2)',
|
||||
schemaname,
|
||||
id_data.category,
|
||||
pkey_placeholders)
|
||||
USING id_data.pkeys, new_document;
|
||||
|
||||
ELSE
|
||||
-- We need to update the existing document
|
||||
existing_value := existing_document #> id_data.identifiers;
|
||||
|
||||
IF existing_value IS NULL THEN
|
||||
result := default_value + amount;
|
||||
|
||||
ELSIF jsonb_typeof(existing_value) = 'number' THEN
|
||||
result := existing_value::text::numeric + amount;
|
||||
|
||||
ELSE
|
||||
RAISE EXCEPTION 'Cannot increment non-numeric value %', existing_value
|
||||
USING ERRCODE = 'wrong_object_type';
|
||||
END IF;
|
||||
|
||||
new_document := red_utils.jsonb_set2(
|
||||
existing_document, to_jsonb(result), id_data.identifiers);
|
||||
|
||||
EXECUTE format(
|
||||
'UPDATE %I.%I SET json_data = $2 WHERE %s',
|
||||
schemaname,
|
||||
id_data.category,
|
||||
whereclause)
|
||||
USING id_data.pkeys, new_document;
|
||||
END IF;
|
||||
END;
|
||||
$$;
|
||||
|
||||
|
||||
CREATE OR REPLACE FUNCTION
|
||||
/*
|
||||
* Toggle a boolean within a document.
|
||||
*
|
||||
* If the value doesn't already exist, it is inserted as `NOT
|
||||
* default_value`.
|
||||
*
|
||||
* Raises 'wrong_object_type' error when trying to toggle a
|
||||
* non-boolean value.
|
||||
*/
|
||||
red_config.toggle(
|
||||
id_data red_config.identifier_data,
|
||||
default_value boolean,
|
||||
OUT result boolean
|
||||
)
|
||||
LANGUAGE 'plpgsql'
|
||||
AS $$
|
||||
DECLARE
|
||||
schemaname CONSTANT text := concat_ws('.', id_data.cog_name, id_data.cog_id);
|
||||
num_identifiers CONSTANT integer := coalesce(array_length(id_data.identifiers, 1), 0);
|
||||
pkey_type CONSTANT text := red_utils.get_pkey_type(id_data.is_custom);
|
||||
whereclause CONSTANT text := red_utils.gen_whereclause(id_data.pkey_len, pkey_type);
|
||||
|
||||
new_document jsonb;
|
||||
existing_document jsonb;
|
||||
existing_value jsonb;
|
||||
pkey_placeholders text;
|
||||
|
||||
BEGIN
|
||||
IF num_identifiers = 0 THEN
|
||||
-- Without identifiers, there's no chance we're actually toggling a boolean
|
||||
RAISE EXCEPTION 'Cannot increment document(s)'
|
||||
USING ERRCODE = 'wrong_object_type';
|
||||
END IF;
|
||||
|
||||
PERFORM red_config.maybe_create_table(id_data);
|
||||
|
||||
-- Look for the existing document
|
||||
EXECUTE format(
|
||||
'SELECT json_data FROM %I.%I WHERE %s',
|
||||
schemaname,
|
||||
id_data.category,
|
||||
whereclause)
|
||||
INTO existing_document USING id_data.pkeys;
|
||||
|
||||
IF existing_document IS NULL THEN
|
||||
-- We need to insert a new document
|
||||
result := NOT default_value;
|
||||
new_document := red_utils.jsonb_set2('{}', result, VARIADIC id_data.identifiers);
|
||||
pkey_placeholders := red_utils.gen_pkey_placeholders(id_data.pkey_len, pkey_type);
|
||||
|
||||
EXECUTE format(
|
||||
'INSERT INTO %I.%I VALUES(%s, $2)',
|
||||
schemaname,
|
||||
id_data.category,
|
||||
pkey_placeholders)
|
||||
USING id_data.pkeys, new_document;
|
||||
|
||||
ELSE
|
||||
-- We need to update the existing document
|
||||
existing_value := existing_document #> id_data.identifiers;
|
||||
|
||||
IF existing_value IS NULL THEN
|
||||
result := NOT default_value;
|
||||
|
||||
ELSIF jsonb_typeof(existing_value) = 'boolean' THEN
|
||||
result := NOT existing_value::text::boolean;
|
||||
|
||||
ELSE
|
||||
RAISE EXCEPTION 'Cannot increment non-boolean value %', existing_value
|
||||
USING ERRCODE = 'wrong_object_type';
|
||||
END IF;
|
||||
|
||||
new_document := red_utils.jsonb_set2(
|
||||
existing_document, to_jsonb(result), id_data.identifiers);
|
||||
|
||||
EXECUTE format(
|
||||
'UPDATE %I.%I SET json_data = $2 WHERE %s',
|
||||
schemaname,
|
||||
id_data.category,
|
||||
whereclause)
|
||||
USING id_data.pkeys, new_document;
|
||||
END IF;
|
||||
END;
|
||||
$$;


CREATE OR REPLACE FUNCTION
red_config.extend(
    id_data red_config.identifier_data,
    new_value text,
    default_value text,
    max_length integer DEFAULT NULL,
    extend_left boolean DEFAULT FALSE,
    OUT result jsonb
)
    LANGUAGE 'plpgsql'
AS $$
DECLARE
    schemaname CONSTANT text := concat_ws('.', id_data.cog_name, id_data.cog_id);
    num_identifiers CONSTANT integer := coalesce(array_length(id_data.identifiers, 1), 0);
    pkey_type CONSTANT text := red_utils.get_pkey_type(id_data.is_custom);
    whereclause CONSTANT text := red_utils.gen_whereclause(id_data.pkey_len, pkey_type);
    pop_idx CONSTANT integer := CASE extend_left WHEN TRUE THEN -1 ELSE 0 END;

    new_document jsonb;
    existing_document jsonb;
    existing_value jsonb;
    pkey_placeholders text;
    idx integer;
BEGIN
    IF num_identifiers = 0 THEN
        -- Without identifiers, there's no chance we're actually appending to an array
        RAISE EXCEPTION 'Cannot append to document(s)'
            USING ERRCODE = 'wrong_object_type';
    END IF;

    PERFORM red_config.maybe_create_table(id_data);

    -- Look for the existing document
    EXECUTE format(
        'SELECT json_data FROM %I.%I WHERE %s',
        schemaname,
        id_data.category,
        whereclause)
        INTO existing_document USING id_data.pkeys;

    IF existing_document IS NULL THEN
        result := default_value::jsonb || new_value::jsonb;
        new_document := red_utils.jsonb_set2('{}'::jsonb, result, VARIADIC id_data.identifiers);
        pkey_placeholders := red_utils.gen_pkey_placeholders(id_data.pkey_len, pkey_type);

        EXECUTE format(
            'INSERT INTO %I.%I VALUES(%s, $2)',
            schemaname,
            id_data.category,
            pkey_placeholders)
            USING id_data.pkeys, new_document;

    ELSE
        existing_value := existing_document #> id_data.identifiers;

        IF existing_value IS NULL THEN
            existing_value := default_value;

        ELSIF jsonb_typeof(existing_value) != 'array' THEN
            RAISE EXCEPTION 'Cannot append to non-array value %', existing_value
                USING ERRCODE = 'wrong_object_type';
        END IF;

        CASE extend_left
            WHEN TRUE THEN
                result := new_value::jsonb || existing_value;
            ELSE
                result := existing_value || new_value::jsonb;
        END CASE;

        IF max_length IS NOT NULL THEN
            FOR idx IN SELECT generate_series(1, jsonb_array_length(result) - max_length) LOOP
                result := result - pop_idx;
            END LOOP;
        END IF;

        new_document := red_utils.jsonb_set2(existing_document, result, VARIADIC id_data.identifiers);

        EXECUTE format(
            'UPDATE %I.%I SET json_data = $2 WHERE %s',
            schemaname,
            id_data.category,
            whereclause)
            USING id_data.pkeys, new_document;
    END IF;
END;
$$;
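
-- Illustrative semantics only; not part of the shipped DDL. Appending
-- new_value = '[3]' when the stored array is [1, 2]:
--   extend_left = FALSE  =>  [1, 2, 3]
--   extend_left = TRUE   =>  [3, 1, 2]
-- With max_length = 2, elements are then popped from the opposite end
-- (pop_idx) until the array is no longer over-length.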


CREATE OR REPLACE FUNCTION
/*
 * Delete all schemas listed in the red_config.red_cogs table.
 */
red_config.delete_all_schemas()
    RETURNS void
    LANGUAGE 'plpgsql'
AS $$
DECLARE
    cog_entry record;
BEGIN
    FOR cog_entry IN SELECT * FROM red_config.red_cogs t LOOP
        EXECUTE format('DROP SCHEMA %I CASCADE', cog_entry.schemaname);
    END LOOP;
    -- Clear out the red_config.red_cogs table
    DELETE FROM red_config.red_cogs WHERE TRUE;
END;
$$;


CREATE OR REPLACE FUNCTION
/*
 * Like `jsonb_set` but will insert new objects where one is missing
 * along the path.
 *
 * Raises `error_in_assignment` error when trying to set a sub-key
 * of a non-document type.
 */
red_utils.jsonb_set2(target jsonb, new_value jsonb, VARIADIC identifiers text[])
    RETURNS jsonb
    LANGUAGE 'plpgsql'
    IMMUTABLE
    PARALLEL SAFE
AS $$
DECLARE
    num_identifiers CONSTANT integer := coalesce(array_length(identifiers, 1), 0);

    cur_value_type text;
    idx integer;
BEGIN
    IF num_identifiers = 0 THEN
        RETURN new_value;
    END IF;

    FOR idx IN SELECT generate_series(1, num_identifiers - 1) LOOP
        cur_value_type := jsonb_typeof(target #> identifiers[:idx]);
        IF cur_value_type IS NULL THEN
            -- Parent key didn't exist in JSON before - insert new object
            target := jsonb_set(target, identifiers[:idx], '{}'::jsonb);

        ELSIF cur_value_type != 'object' THEN
            -- We can't set the sub-field of a null, int, float, array etc.
            RAISE EXCEPTION 'Cannot set sub-field of "%"', cur_value_type
                USING ERRCODE = 'error_in_assignment';
        END IF;
    END LOOP;

    RETURN jsonb_set(target, identifiers, new_value);
END;
$$;
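
-- Illustrative only; not part of the shipped DDL. Unlike jsonb_set,
-- missing parent objects are created along the way:
--   SELECT red_utils.jsonb_set2('{}'::jsonb, '42'::jsonb, 'a', 'b');
--   -- => {"a": {"b": 42}}
--   SELECT red_utils.jsonb_set2('{"a": 1}'::jsonb, '42'::jsonb, 'a', 'b');
--   -- => ERROR: Cannot set sub-field of "number" (error_in_assignment)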


CREATE OR REPLACE FUNCTION
/*
 * Return a set of rows to insert into a table, from a single JSONB
 * object containing multiple documents.
 */
red_utils.generate_rows_from_object(object jsonb, num_missing_pkeys integer)
    RETURNS setof record
    LANGUAGE 'plpgsql'
    IMMUTABLE
    PARALLEL SAFE
AS $$
DECLARE
    pair record;
    column_definitions text;
BEGIN
    IF num_missing_pkeys = 1 THEN
        -- Base case: Simply return (key, value) pairs
        RETURN QUERY
            SELECT key AS key_1, value AS json_data
            FROM jsonb_each(object);
    ELSE
        -- We need to return (key, key, ..., value) pairs: recurse into inner JSONB objects
        column_definitions := red_utils.gen_pkey_column_definitions(2, num_missing_pkeys);

        FOR pair IN SELECT * FROM jsonb_each(object) LOOP
            RETURN QUERY
                EXECUTE format(
                    $query$
                        SELECT $1 AS key_1, *
                        FROM red_utils.generate_rows_from_object($2, $3)
                        AS f(%s, json_data jsonb)
                    $query$,
                    column_definitions)
                USING pair.key, pair.value, num_missing_pkeys - 1;
        END LOOP;
    END IF;
    RETURN;
END;
$$;
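
-- Illustrative only; not part of the shipped DDL. The caller supplies the
-- column definition list for the returned records:
--   SELECT * FROM red_utils.generate_rows_from_object(
--       '{"11": {"22": {"balance": 100}}}'::jsonb, 2)
--       AS f(primary_key_1 text, primary_key_2 text, json_data jsonb);
--   -- => ('11', '22', '{"balance": 100}')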


CREATE OR REPLACE FUNCTION
/*
 * Get a comma-separated list of primary key placeholders.
 *
 * The placeholder will always be $1. Particularly useful for
 * inserting values into a table from an array of primary keys.
 */
red_utils.gen_pkey_placeholders(num_pkeys integer, pkey_type text DEFAULT 'text')
    RETURNS text
    LANGUAGE 'sql'
    IMMUTABLE
    PARALLEL SAFE
AS $$
    SELECT string_agg(t.item, ', ')
    FROM (
        SELECT format('$1[%s]::%s', idx, pkey_type) AS item
        FROM generate_series(1, num_pkeys) idx) t;
$$;
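
-- Illustrative only; not part of the shipped DDL:
--   SELECT red_utils.gen_pkey_placeholders(2, 'bigint');
--   -- => $1[1]::bigint, $1[2]::bigint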


CREATE OR REPLACE FUNCTION
/*
 * Generate a whereclause for the given number of primary keys.
 *
 * When there are no primary keys, this will simply return the
 * string 'TRUE'. When there are multiple, it will return multiple
 * equality comparisons concatenated with 'AND'.
 */
red_utils.gen_whereclause(num_pkeys integer, pkey_type text)
    RETURNS text
    LANGUAGE 'sql'
    IMMUTABLE
    PARALLEL SAFE
AS $$
    SELECT coalesce(string_agg(t.item, ' AND '), 'TRUE')
    FROM (
        SELECT format('%I = $1[%s]::%s', 'primary_key_'||idx, idx, pkey_type) AS item
        FROM generate_series(1, num_pkeys) idx) t;
$$;
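
-- Illustrative only; not part of the shipped DDL:
--   SELECT red_utils.gen_whereclause(2, 'bigint');
--   -- => primary_key_1 = $1[1]::bigint AND primary_key_2 = $1[2]::bigint
--   SELECT red_utils.gen_whereclause(0, 'bigint');
--   -- => TRUE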


CREATE OR REPLACE FUNCTION
/*
 * Generate a comma-separated list of primary key column names.
 */
red_utils.gen_pkey_columns(start integer, stop integer)
    RETURNS text
    LANGUAGE 'sql'
    IMMUTABLE
    PARALLEL SAFE
AS $$
    SELECT string_agg(t.item, ', ')
    FROM (
        SELECT quote_ident('primary_key_'||idx) AS item
        FROM generate_series(start, stop) idx) t;
$$;
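
-- Illustrative only; not part of the shipped DDL:
--   SELECT red_utils.gen_pkey_columns(1, 3);
--   -- => primary_key_1, primary_key_2, primary_key_3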


CREATE OR REPLACE FUNCTION
/*
 * Generate a comma-separated list of primary key column names casted
 * to the given type.
 */
red_utils.gen_pkey_columns_casted(start integer, stop integer, pkey_type text DEFAULT 'text')
    RETURNS text
    LANGUAGE 'sql'
    IMMUTABLE
    PARALLEL SAFE
AS $$
    SELECT string_agg(t.item, ', ')
    FROM (
        SELECT format('%I::%s', 'primary_key_'||idx, pkey_type) AS item
        FROM generate_series(start, stop) idx) t;
$$;
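
-- Illustrative only; not part of the shipped DDL:
--   SELECT red_utils.gen_pkey_columns_casted(1, 2, 'bigint');
--   -- => primary_key_1::bigint, primary_key_2::bigint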


CREATE OR REPLACE FUNCTION
/*
 * Generate a primary key column definition list.
 */
red_utils.gen_pkey_column_definitions(
    start integer, stop integer, column_type text DEFAULT 'text'
)
    RETURNS text
    LANGUAGE 'sql'
    IMMUTABLE
    PARALLEL SAFE
AS $$
    SELECT string_agg(t.item, ', ')
    FROM (
        SELECT format('%I %s', 'primary_key_'||idx, column_type) AS item
        FROM generate_series(start, stop) idx) t;
$$;
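
-- Illustrative only; not part of the shipped DDL:
--   SELECT red_utils.gen_pkey_column_definitions(2, 3);
--   -- => primary_key_2 text, primary_key_3 text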


CREATE OR REPLACE FUNCTION
red_utils.get_pkey_type(is_custom boolean)
    RETURNS TEXT
    LANGUAGE 'sql'
    IMMUTABLE
    PARALLEL SAFE
AS $$
    SELECT ('{bigint,text}'::text[])[is_custom::integer + 1];
$$;
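
-- Illustrative only; not part of the shipped DDL. Custom groups use text
-- primary keys; built-in categories use numeric Discord IDs:
--   SELECT red_utils.get_pkey_type(FALSE);  -- => bigint
--   SELECT red_utils.get_pkey_type(TRUE);   -- => text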


DROP AGGREGATE IF EXISTS red_utils.jsonb_object_agg2(jsonb, VARIADIC text[]);
CREATE AGGREGATE
/*
 * Like `jsonb_object_agg` but aggregates more than two columns into a
 * single JSONB object.
 *
 * If possible, use `jsonb_object_agg` instead for performance
 * reasons.
 */
red_utils.jsonb_object_agg2(json_data jsonb, VARIADIC primary_keys text[]) (
    SFUNC = red_utils.jsonb_set2,
    STYPE = jsonb,
    INITCOND = '{}',
    PARALLEL = SAFE
);
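
-- Illustrative only; not part of the shipped DDL. Each row's primary keys
-- become nested object keys:
--   SELECT red_utils.jsonb_object_agg2(v, k1, k2)
--   FROM (VALUES ('1'::jsonb, 'a', 'x'),
--                ('2'::jsonb, 'a', 'y')) AS t(v, k1, k2);
--   -- => {"a": {"x": 1, "y": 2}}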


CREATE TABLE IF NOT EXISTS
/*
 * Table to keep track of other cogs' schemas.
 */
red_config.red_cogs(
    cog_name text,
    cog_id text,
    schemaname text NOT NULL,
    PRIMARY KEY (cog_name, cog_id)
);

3
redbot/core/drivers/postgres/drop_ddl.sql
Normal file
@@ -0,0 +1,3 @@
SELECT red_config.delete_all_schemas();
DROP SCHEMA IF EXISTS red_config CASCADE;
DROP SCHEMA IF EXISTS red_utils CASCADE;

255
redbot/core/drivers/postgres/postgres.py
Normal file
@@ -0,0 +1,255 @@
import getpass
import json
import sys
from pathlib import Path
from typing import Optional, Any, AsyncIterator, Tuple, Union, Callable, List

try:
    # pylint: disable=import-error
    import asyncpg
except ModuleNotFoundError:
    asyncpg = None

from ... import data_manager, errors
from ..base import BaseDriver, IdentifierData, ConfigCategory
from ..log import log

__all__ = ["PostgresDriver"]

_PKG_PATH = Path(__file__).parent
DDL_SCRIPT_PATH = _PKG_PATH / "ddl.sql"
DROP_DDL_SCRIPT_PATH = _PKG_PATH / "drop_ddl.sql"


def encode_identifier_data(
    id_data: IdentifierData
) -> Tuple[str, str, str, List[str], List[str], int, bool]:
    return (
        id_data.cog_name,
        id_data.uuid,
        id_data.category,
        ["0"] if id_data.category == ConfigCategory.GLOBAL else list(id_data.primary_key),
        list(id_data.identifiers),
        1 if id_data.category == ConfigCategory.GLOBAL else id_data.primary_key_len,
        id_data.is_custom,
    )
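

# Illustrative only; not part of the shipped module. Assuming an
# IdentifierData for the GLOBAL category with cog_name "MyCog", uuid
# "12345", and identifier path ("foo",), the sentinel primary key "0" is
# substituted and the primary key length is forced to 1:
#   encode_identifier_data(id_data)
#   # => ("MyCog", "12345", "GLOBAL", ["0"], ["foo"], 1, False)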


class PostgresDriver(BaseDriver):

    _pool: Optional["asyncpg.pool.Pool"] = None

    @classmethod
    async def initialize(cls, **storage_details) -> None:
        if asyncpg is None:
            raise errors.MissingExtraRequirements(
                "Red must be installed with the [postgres] extra to use the PostgreSQL driver"
            )
        cls._pool = await asyncpg.create_pool(**storage_details)
        with DDL_SCRIPT_PATH.open() as fs:
            await cls._pool.execute(fs.read())

    @classmethod
    async def teardown(cls) -> None:
        if cls._pool is not None:
            await cls._pool.close()

    @staticmethod
    def get_config_details():
        unixmsg = (
            ""
            if sys.platform == "win32"
            else (
                " - Common directories for PostgreSQL Unix-domain sockets (/run/postgresql, "
                "/var/run/postgresql, /var/pgsql_socket, /private/tmp, and /tmp),\n"
            )
        )
        host = (
            input(
                f"Enter the PostgreSQL server's address.\n"
                f"If left blank, Red will try the following, in order:\n"
                f" - The PGHOST environment variable,\n{unixmsg}"
                f" - localhost.\n"
                f"> "
            )
            or None
        )

        print(
            "Enter the PostgreSQL server port.\n"
            "If left blank, this will default to either:\n"
            " - The PGPORT environment variable,\n"
            " - 5432."
        )
        while True:
            port = input("> ") or None
            if port is None:
                break

            try:
                port = int(port)
            except ValueError:
                print("Port must be a number")
            else:
                break

        user = (
            input(
                "Enter the PostgreSQL server username.\n"
                "If left blank, this will default to either:\n"
                " - The PGUSER environment variable,\n"
                " - The OS name of the user running Red (ident/peer authentication).\n"
                "> "
            )
            or None
        )

        passfile = r"%APPDATA%\postgresql\pgpass.conf" if sys.platform == "win32" else "~/.pgpass"
        password = getpass.getpass(
            f"Enter the PostgreSQL server password. The input will be hidden.\n"
            f" NOTE: If using ident/peer authentication (no password), enter NONE.\n"
            f"When NONE is entered, this will default to:\n"
            f" - The PGPASSWORD environment variable,\n"
            f" - Looking up the password in the {passfile} passfile,\n"
            f" - No password.\n"
            f"> "
        )
        if password == "NONE":
            password = None

        database = (
            input(
                "Enter the PostgreSQL database's name.\n"
                "If left blank, this will default to either:\n"
                " - The PGDATABASE environment variable,\n"
                " - The OS name of the user running Red.\n"
                "> "
            )
            or None
        )

        return {
            "host": host,
            "port": port,
            "user": user,
            "password": password,
            "database": database,
        }

    async def get(self, identifier_data: IdentifierData):
        try:
            result = await self._execute(
                "SELECT red_config.get($1)",
                encode_identifier_data(identifier_data),
                method=self._pool.fetchval,
            )
        except asyncpg.UndefinedTableError:
            raise KeyError from None

        if result is None:
            # The result is None both when postgres yields no results and when it yields a NULL row
            # A 'null' JSON value would be returned as encoded JSON, i.e. the string 'null'
            raise KeyError
        return json.loads(result)

    async def set(self, identifier_data: IdentifierData, value=None):
        try:
            await self._execute(
                "SELECT red_config.set($1, $2::jsonb)",
                encode_identifier_data(identifier_data),
                json.dumps(value),
            )
        except asyncpg.ErrorInAssignmentError:
            raise errors.CannotSetSubfield

    async def clear(self, identifier_data: IdentifierData):
        try:
            await self._execute(
                "SELECT red_config.clear($1)", encode_identifier_data(identifier_data)
            )
        except asyncpg.UndefinedTableError:
            pass

    async def inc(
        self, identifier_data: IdentifierData, value: Union[int, float], default: Union[int, float]
    ) -> Union[int, float]:
        try:
            return await self._execute(
                "SELECT red_config.inc($1, $2, $3)",
                encode_identifier_data(identifier_data),
                value,
                default,
                method=self._pool.fetchval,
            )
        except asyncpg.WrongObjectTypeError as exc:
            raise errors.StoredTypeError(*exc.args)

    async def toggle(self, identifier_data: IdentifierData, default: bool) -> bool:
        try:
            return await self._execute(
                "SELECT red_config.inc($1, $2)",
                encode_identifier_data(identifier_data),
                default,
                method=self._pool.fetchval,
            )
        except asyncpg.WrongObjectTypeError as exc:
            raise errors.StoredTypeError(*exc.args)

    @classmethod
    async def aiter_cogs(cls) -> AsyncIterator[Tuple[str, str]]:
        query = "SELECT cog_name, cog_id FROM red_config.red_cogs"
        log.invisible(query)
        async with cls._pool.acquire() as conn, conn.transaction():
            async for row in conn.cursor(query):
                yield row["cog_name"], row["cog_id"]

    @classmethod
    async def delete_all_data(
        cls, *, interactive: bool = False, drop_db: Optional[bool] = None, **kwargs
    ) -> None:
        """Delete all data being stored by this driver.

        Parameters
        ----------
        interactive : bool
            Set to ``True`` to allow the method to ask the user for
            input from the console, regarding the other unset parameters
            for this method.
        drop_db : Optional[bool]
            Set to ``True`` to drop the entire database for the current
            bot's instance. Otherwise, schemas within the database which
            store bot data will be dropped, as well as functions,
            aggregates, event triggers, and meta-tables.

        """
        if interactive is True and drop_db is None:
            print(
                "Please choose from one of the following options:\n"
                " 1. Drop the entire PostgreSQL database for this instance, or\n"
                " 2. Delete all of Red's data within this database, without dropping the database "
                "itself."
            )
            options = ("1", "2")
            while True:
                resp = input("> ")
                try:
                    # Option 1 drops the whole database; option 2 only clears Red's data.
                    drop_db = options.index(resp) == 0
                except ValueError:
                    print("Please type a number corresponding to one of the options.")
                else:
                    break
        if drop_db is True:
            storage_details = data_manager.storage_details()
            db_name = storage_details["database"]
            # Identifiers cannot be bound as query parameters, so the database
            # name is quoted manually rather than passed as $1.
            await cls._pool.execute('DROP DATABASE "{}"'.format(db_name.replace('"', '""')))
        else:
            with DROP_DDL_SCRIPT_PATH.open() as fs:
                await cls._pool.execute(fs.read())

    @classmethod
    async def _execute(cls, query: str, *args, method: Optional[Callable] = None) -> Any:
        if method is None:
            method = cls._pool.execute
        log.invisible("Query: %s", query)
        if args:
            log.invisible("Args: %s", args)
        return await method(query, *args)

@@ -1,233 +0,0 @@
import enum
from typing import Tuple

__all__ = ["BaseDriver", "IdentifierData"]


class ConfigCategory(enum.Enum):
    GLOBAL = "GLOBAL"
    GUILD = "GUILD"
    CHANNEL = "TEXTCHANNEL"
    ROLE = "ROLE"
    USER = "USER"
    MEMBER = "MEMBER"


class IdentifierData:
    def __init__(
        self,
        uuid: str,
        category: str,
        primary_key: Tuple[str, ...],
        identifiers: Tuple[str, ...],
        custom_group_data: dict,
        is_custom: bool = False,
    ):
        self._uuid = uuid
        self._category = category
        self._primary_key = primary_key
        self._identifiers = identifiers
        self.custom_group_data = custom_group_data
        self._is_custom = is_custom

    @property
    def uuid(self):
        return self._uuid

    @property
    def category(self):
        return self._category

    @property
    def primary_key(self):
        return self._primary_key

    @property
    def identifiers(self):
        return self._identifiers

    @property
    def is_custom(self):
        return self._is_custom

    def __repr__(self):
        return (
            f"<IdentifierData uuid={self.uuid} category={self.category} primary_key={self.primary_key}"
            f" identifiers={self.identifiers}>"
        )

    def __eq__(self, other) -> bool:
        if not isinstance(other, IdentifierData):
            return False
        return (
            self.uuid == other.uuid
            and self.category == other.category
            and self.primary_key == other.primary_key
            and self.identifiers == other.identifiers
        )

    def __hash__(self) -> int:
        return hash((self.uuid, self.category, self.primary_key, self.identifiers))

    def add_identifier(self, *identifier: str) -> "IdentifierData":
        if not all(isinstance(i, str) for i in identifier):
            raise ValueError("Identifiers must be strings.")

        return IdentifierData(
            self.uuid,
            self.category,
            self.primary_key,
            self.identifiers + identifier,
            self.custom_group_data,
            is_custom=self.is_custom,
        )

    def to_tuple(self):
        return tuple(
            item
            for item in (self.uuid, self.category, *self.primary_key, *self.identifiers)
            if len(item) > 0
        )


class BaseDriver:
    def __init__(self, cog_name, identifier):
        self.cog_name = cog_name
        self.unique_cog_identifier = identifier

    async def has_valid_connection(self) -> bool:
        raise NotImplementedError

    async def get(self, identifier_data: IdentifierData):
        """
        Finds the value indicated by the given identifiers.

        Parameters
        ----------
        identifier_data

        Returns
        -------
        Any
            Stored value.
        """
        raise NotImplementedError

    def get_config_details(self):
        """
        Asks users for additional configuration information necessary
        to use this config driver.

        Returns
        -------
        Dict of configuration details.
        """
        raise NotImplementedError

    async def set(self, identifier_data: IdentifierData, value=None):
        """
        Sets the value of the key indicated by the given identifiers.

        Parameters
        ----------
        identifier_data
        value
            Any JSON serializable python object.
        """
        raise NotImplementedError

    async def clear(self, identifier_data: IdentifierData):
        """
        Clears out the value specified by the given identifiers.

        Equivalent to using ``del`` on a dict.

        Parameters
        ----------
        identifier_data
        """
        raise NotImplementedError

    def _get_levels(self, category, custom_group_data):
        if category == ConfigCategory.GLOBAL.value:
            return 0
        elif category in (
            ConfigCategory.USER.value,
            ConfigCategory.GUILD.value,
            ConfigCategory.CHANNEL.value,
            ConfigCategory.ROLE.value,
        ):
            return 1
        elif category == ConfigCategory.MEMBER.value:
            return 2
        elif category in custom_group_data:
            return custom_group_data[category]
        else:
            raise RuntimeError(f"Cannot convert due to group: {category}")

    def _split_primary_key(self, category, custom_group_data, data):
        levels = self._get_levels(category, custom_group_data)
        if levels == 0:
            return (((), data),)

        def flatten(levels_remaining, currdata, parent_key=()):
            items = []
            for k, v in currdata.items():
                new_key = parent_key + (k,)
                if levels_remaining > 1:
                    items.extend(flatten(levels_remaining - 1, v, new_key).items())
                else:
                    items.append((new_key, v))
            return dict(items)

        ret = []
        for k, v in flatten(levels, data).items():
            ret.append((k, v))
        return tuple(ret)
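
    # Illustrative only; not part of the original module. For the MEMBER
    # category (two primary-key levels), nested data is flattened into
    # (primary-key tuple, document) pairs:
    #   self._split_primary_key("MEMBER", {}, {"123": {"456": {"x": 1}}})
    #   # => ((("123", "456"), {"x": 1}),)
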
    async def export_data(self, custom_group_data):
        categories = [c.value for c in ConfigCategory]
        categories.extend(custom_group_data.keys())

        ret = []
        for c in categories:
            ident_data = IdentifierData(
                self.unique_cog_identifier,
                c,
                (),
                (),
                custom_group_data,
                is_custom=c in custom_group_data,
            )
            try:
                data = await self.get(ident_data)
            except KeyError:
                continue
            ret.append((c, data))
        return ret

    async def import_data(self, cog_data, custom_group_data):
        for category, all_data in cog_data:
            splitted_pkey = self._split_primary_key(category, custom_group_data, all_data)
            for pkey, data in splitted_pkey:
                ident_data = IdentifierData(
                    self.unique_cog_identifier,
                    category,
                    pkey,
                    (),
                    custom_group_data,
                    is_custom=category in custom_group_data,
                )
                await self.set(ident_data, data)

    @staticmethod
    def get_pkey_len(identifier_data: IdentifierData) -> int:
        cat = identifier_data.category
        if cat == ConfigCategory.GLOBAL.value:
            return 0
        elif cat == ConfigCategory.MEMBER.value:
            return 2
        elif identifier_data.is_custom:
            return identifier_data.custom_group_data[cat]
        else:
            return 1