PostgreSQL driver, tests against DB backends, and general drivers cleanup (#2723)

* PostgreSQL driver and general drivers cleanup

Signed-off-by: Toby Harradine <tobyharradine@gmail.com>

* Make tests pass

Signed-off-by: Toby Harradine <tobyharradine@gmail.com>

* Add black --target-version flag in make.bat

Signed-off-by: Toby Harradine <tobyharradine@gmail.com>

* Rewrite postgres driver

Most of the logic is now in PL/pgSQL.

This completely avoids the use of Python f-strings to format identifiers into queries. Although an SQL-injection attack would have been impossible anyway (only the owner would have ever had the ability to do that), using PostgreSQL's format() is more reliable for unusual identifiers. Performance-wise, I'm not sure whether this is an improvement, but I highly doubt that it's worse.

Signed-off-by: Toby Harradine <tobyharradine@gmail.com>
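
A rough sketch of the difference (the identifier values and query here are hypothetical):

    # Old approach: Python interpolates identifiers into the SQL text itself.
    schema, table = "MyCog.1234567890", "GUILD"
    query = f'SELECT json_data FROM "{schema}"."{table}"'

    # New approach: identifiers travel as ordinary query arguments, and the PL/pgSQL
    # side builds the statement with format(), whose %I specifier quotes each identifier:
    #     EXECUTE format('SELECT json_data FROM %I.%I WHERE %s', schemaname, category, whereclause);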

* Reformat

Signed-off-by: Toby Harradine <tobyharradine@gmail.com>

* Fix PostgresDriver.delete_all_data()

Signed-off-by: Toby Harradine <tobyharradine@gmail.com>

* Clean up PL/pgSQL code

Signed-off-by: Toby Harradine <tobyharradine@gmail.com>

* More PL/pgSQL cleanup

Signed-off-by: Toby Harradine <tobyharradine@gmail.com>

* PL/pgSQL function optimisations

Signed-off-by: Toby Harradine <tobyharradine@gmail.com>

* Ensure compatibility with PostgreSQL 10 and below

Signed-off-by: Toby Harradine <tobyharradine@gmail.com>

* More/better docstrings for PG functions

Signed-off-by: Toby Harradine <tobyharradine@gmail.com>

* Fix typo in docstring

Signed-off-by: Toby Harradine <tobyharradine@gmail.com>

* Return correct value on toggle()

Signed-off-by: Toby Harradine <tobyharradine@gmail.com>

* Use composite type for PG function parameters

Signed-off-by: Toby Harradine <tobyharradine@gmail.com>

* Fix JSON driver's Config.clear_all()

Signed-off-by: Toby Harradine <tobyharradine@gmail.com>

* Correct description for Mongo tox recipe

Signed-off-by: Toby Harradine <tobyharradine@gmail.com>

* Fix linting errors

Signed-off-by: Toby Harradine <tobyharradine@gmail.com>

* Update dep specification after merging bumpdeps

Signed-off-by: Toby Harradine <tobyharradine@gmail.com>

* Add towncrier entries

Signed-off-by: Toby Harradine <tobyharradine@gmail.com>

* Update from merge

Signed-off-by: Toby Harradine <tobyharradine@gmail.com>

* Mention [postgres] extra in install docs

Signed-off-by: Toby Harradine <tobyharradine@gmail.com>

* Support more connection options and use better defaults

Signed-off-by: Toby Harradine <tobyharradine@gmail.com>

* Actually pass PG env vars in tox

Signed-off-by: Toby Harradine <tobyharradine@gmail.com>

* Replace event trigger with manual DELETE queries

Signed-off-by: Toby Harradine <tobyharradine@gmail.com>
Author: Toby Harradine
Date: 2019-08-27 12:02:26 +10:00
Committed by: Michael H
Parent: 57fa29dd64
Commit: d1a46acc9a
34 changed files with 2282 additions and 843 deletions


@@ -1,50 +1,110 @@
import enum
from typing import Optional, Type
from .red_base import IdentifierData
from .. import data_manager
from .base import IdentifierData, BaseDriver, ConfigCategory
from .json import JsonDriver
from .mongo import MongoDriver
from .postgres import PostgresDriver
__all__ = ["get_driver", "IdentifierData", "BackendType"]
__all__ = [
"get_driver",
"ConfigCategory",
"IdentifierData",
"BaseDriver",
"JsonDriver",
"MongoDriver",
"PostgresDriver",
"BackendType",
]
class BackendType(enum.Enum):
JSON = "JSON"
MONGO = "MongoDBV2"
MONGOV1 = "MongoDB"
POSTGRES = "Postgres"
def get_driver(type, *args, **kwargs):
_DRIVER_CLASSES = {
BackendType.JSON: JsonDriver,
BackendType.MONGO: MongoDriver,
BackendType.POSTGRES: PostgresDriver,
}
def get_driver_class(storage_type: Optional[BackendType] = None) -> Type[BaseDriver]:
"""Get the driver class for the given storage type.
Parameters
----------
storage_type : Optional[BackendType]
The backend you want a driver class for. Omit to try to obtain
the backend from data manager.
Returns
-------
Type[BaseDriver]
A subclass of `BaseDriver`.
Raises
------
ValueError
If there is no driver for the given storage type.
"""
Selectively import/load driver classes based on the selected type. This
is required so that dependencies can differ between installs (e.g. so that
you don't need to install a mongo dependency if you will just be running a
json data backend).
if storage_type is None:
storage_type = BackendType(data_manager.storage_type())
try:
return _DRIVER_CLASSES[storage_type]
except KeyError:
raise ValueError(f"No driver found for storage type {storage_type}") from None
.. note::
See the respective classes for information on what ``args`` and ``kwargs``
should be.
def get_driver(
cog_name: str, identifier: str, storage_type: Optional[BackendType] = None, **kwargs
):
"""Get a driver instance.
Parameters
----------
cog_name : str
The cog's name.
identifier : str
The cog's discriminator.
storage_type : Optional[BackendType]
The backend you want a driver for. Omit to try to obtain the
backend from data manager.
**kwargs
Driver-specific keyword arguments.
Returns
-------
BaseDriver
A driver instance.
Raises
------
RuntimeError
If the storage type is MongoV1 or invalid.
:param str type:
One of: json, mongo
:param args:
Dependent on driver type.
:param kwargs:
Dependent on driver type.
:return:
Subclass of :py:class:`.red_base.BaseDriver`.
"""
if type == "JSON":
from .red_json import JSON
if storage_type is None:
try:
storage_type = BackendType(data_manager.storage_type())
except RuntimeError:
storage_type = BackendType.JSON
return JSON(*args, **kwargs)
elif type == "MongoDBV2":
from .red_mongo import Mongo
return Mongo(*args, **kwargs)
elif type == "MongoDB":
raise RuntimeError(
"Please convert to JSON first to continue using the bot."
" This is a required conversion prior to using the new Mongo driver."
" This message will be updated with a link to the update docs once those"
" docs have been created."
)
raise RuntimeError("Invalid driver type: '{}'".format(type))
try:
driver_cls: Type[BaseDriver] = get_driver_class(storage_type)
except ValueError:
if storage_type == BackendType.MONGOV1:
raise RuntimeError(
"Please convert to JSON first to continue using the bot."
" This is a required conversion prior to using the new Mongo driver."
" This message will be updated with a link to the update docs once those"
" docs have been created."
) from None
else:
raise RuntimeError(f"Invalid driver type: '{storage_type}'") from None
return driver_cls(cog_name, identifier, **kwargs)
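
A minimal usage sketch of the new factory functions (the cog name and identifier are hypothetical, and the backend is assumed to have been configured through data_manager):

    from redbot.core import data_manager
    from redbot.core.drivers import BackendType, get_driver, get_driver_class

    async def obtain_driver():
        driver_cls = get_driver_class(BackendType.JSON)
        await driver_cls.initialize(**data_manager.storage_details())
        return get_driver("MyCog", "1234567890", storage_type=BackendType.JSON)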

redbot/core/drivers/base.py (new file, 342 lines)

@@ -0,0 +1,342 @@
import abc
import enum
from typing import Tuple, Dict, Any, Union, List, AsyncIterator, Type
__all__ = ["BaseDriver", "IdentifierData", "ConfigCategory"]
class ConfigCategory(str, enum.Enum):
GLOBAL = "GLOBAL"
GUILD = "GUILD"
CHANNEL = "TEXTCHANNEL"
ROLE = "ROLE"
USER = "USER"
MEMBER = "MEMBER"
@classmethod
def get_pkey_info(
cls, category: Union[str, "ConfigCategory"], custom_group_data: Dict[str, int]
) -> Tuple[int, bool]:
"""Get the full primary key length for the given category,
and whether or not the category is a custom category.
"""
try:
# noinspection PyArgumentList
category_obj = cls(category)
except ValueError:
return custom_group_data[category], True
else:
return _CATEGORY_PKEY_COUNTS[category_obj], False
_CATEGORY_PKEY_COUNTS = {
ConfigCategory.GLOBAL: 0,
ConfigCategory.GUILD: 1,
ConfigCategory.CHANNEL: 1,
ConfigCategory.ROLE: 1,
ConfigCategory.USER: 1,
ConfigCategory.MEMBER: 2,
}
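
For instance, a sketch of how these primary key counts are looked up (the custom group name is hypothetical):

    from redbot.core.drivers import ConfigCategory

    # Built-in category: two primary keys (guild ID and member ID), not custom.
    assert ConfigCategory.get_pkey_info("MEMBER", {}) == (2, False)
    # Unknown categories fall back to the custom-group mapping.
    assert ConfigCategory.get_pkey_info("TwentyQuestions", {"TwentyQuestions": 3}) == (3, True)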
class IdentifierData:
def __init__(
self,
cog_name: str,
uuid: str,
category: str,
primary_key: Tuple[str, ...],
identifiers: Tuple[str, ...],
primary_key_len: int,
is_custom: bool = False,
):
self._cog_name = cog_name
self._uuid = uuid
self._category = category
self._primary_key = primary_key
self._identifiers = identifiers
self.primary_key_len = primary_key_len
self._is_custom = is_custom
@property
def cog_name(self) -> str:
return self._cog_name
@property
def uuid(self) -> str:
return self._uuid
@property
def category(self) -> str:
return self._category
@property
def primary_key(self) -> Tuple[str, ...]:
return self._primary_key
@property
def identifiers(self) -> Tuple[str, ...]:
return self._identifiers
@property
def is_custom(self) -> bool:
return self._is_custom
def __repr__(self) -> str:
return (
f"<IdentifierData cog_name={self.cog_name} uuid={self.uuid} category={self.category} "
f"primary_key={self.primary_key} identifiers={self.identifiers}>"
)
def __eq__(self, other) -> bool:
if not isinstance(other, IdentifierData):
return False
return (
self.uuid == other.uuid
and self.category == other.category
and self.primary_key == other.primary_key
and self.identifiers == other.identifiers
)
def __hash__(self) -> int:
return hash((self.uuid, self.category, self.primary_key, self.identifiers))
def add_identifier(self, *identifier: str) -> "IdentifierData":
if not all(isinstance(i, str) for i in identifier):
raise ValueError("Identifiers must be strings.")
return IdentifierData(
self.cog_name,
self.uuid,
self.category,
self.primary_key,
self.identifiers + identifier,
self.primary_key_len,
is_custom=self.is_custom,
)
def to_tuple(self) -> Tuple[str, ...]:
return tuple(
filter(
None,
(self.cog_name, self.uuid, self.category, *self.primary_key, *self.identifiers),
)
)
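
A sketch of how an IdentifierData addresses one stored value, with made-up IDs:

    from redbot.core.drivers import IdentifierData

    id_data = IdentifierData(
        cog_name="MyCog",
        uuid="1234567890",           # the cog's unique identifier
        category="MEMBER",
        primary_key=("111", "222"),  # (guild ID, member ID)
        identifiers=(),
        primary_key_len=2,
    )
    nested = id_data.add_identifier("settings", "enabled")
    assert nested.to_tuple() == (
        "MyCog", "1234567890", "MEMBER", "111", "222", "settings", "enabled"
    )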
class BaseDriver(abc.ABC):
def __init__(self, cog_name: str, identifier: str, **kwargs):
self.cog_name = cog_name
self.unique_cog_identifier = identifier
@classmethod
@abc.abstractmethod
async def initialize(cls, **storage_details) -> None:
"""
Initialize this driver.
Parameters
----------
**storage_details
The storage details required to initialize this driver.
Should be the same as :func:`data_manager.storage_details`
Raises
------
MissingExtraRequirements
If initializing the driver requires an extra which isn't
installed.
"""
raise NotImplementedError
@classmethod
@abc.abstractmethod
async def teardown(cls) -> None:
"""
Tear down this driver.
"""
raise NotImplementedError
@staticmethod
@abc.abstractmethod
def get_config_details() -> Dict[str, Any]:
"""
Asks users for additional configuration information necessary
to use this config driver.
Returns
-------
Dict[str, Any]
Dictionary of configuration details.
"""
raise NotImplementedError
@abc.abstractmethod
async def get(self, identifier_data: IdentifierData) -> Any:
"""
Finds the value indicated by the given identifiers.
Parameters
----------
identifier_data
Returns
-------
Any
Stored value.
"""
raise NotImplementedError
@abc.abstractmethod
async def set(self, identifier_data: IdentifierData, value=None) -> None:
"""
Sets the value of the key indicated by the given identifiers.
Parameters
----------
identifier_data
value
Any JSON serializable python object.
"""
raise NotImplementedError
@abc.abstractmethod
async def clear(self, identifier_data: IdentifierData) -> None:
"""
Clears out the value specified by the given identifiers.
Equivalent to using ``del`` on a dict.
Parameters
----------
identifier_data
"""
raise NotImplementedError
@classmethod
@abc.abstractmethod
def aiter_cogs(cls) -> AsyncIterator[Tuple[str, str]]:
"""Get info for cogs which have data stored on this backend.
Yields
------
Tuple[str, str]
Asynchronously yields (cog_name, cog_identifier) tuples.
"""
raise NotImplementedError
@classmethod
async def migrate_to(
cls,
new_driver_cls: Type["BaseDriver"],
all_custom_group_data: Dict[str, Dict[str, Dict[str, int]]],
) -> None:
"""Migrate data from this backend to another.
Both drivers must be initialized beforehand.
This will only move the data - no instance metadata is modified
as a result of this operation.
Parameters
----------
new_driver_cls
Subclass of `BaseDriver`.
all_custom_group_data : Dict[str, Dict[str, Dict[str, int]]]
Dict mapping cog names, to cog IDs, to custom groups, to
primary key lengths.
"""
# Backend-agnostic method of migrating from one driver to another.
async for cog_name, cog_id in cls.aiter_cogs():
this_driver = cls(cog_name, cog_id)
other_driver = new_driver_cls(cog_name, cog_id)
custom_group_data = all_custom_group_data.get(cog_name, {}).get(cog_id, {})
exported_data = await this_driver.export_data(custom_group_data)
await other_driver.import_data(exported_data, custom_group_data)
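
For example, a JSON-to-PostgreSQL migration would look roughly like this (a sketch which assumes both drivers have already been initialized and all_custom_group_data has been collected):

    from redbot.core.drivers import JsonDriver, PostgresDriver

    async def migrate(all_custom_group_data):
        # Copies every cog's data from the JSON backend into PostgreSQL;
        # instance metadata is not touched.
        await JsonDriver.migrate_to(PostgresDriver, all_custom_group_data)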
@classmethod
async def delete_all_data(cls, **kwargs) -> None:
"""Delete all data being stored by this driver.
The driver must be initialized before this operation.
The BaseDriver provides a generic method which may be overridden
by subclasses.
Parameters
----------
**kwargs
Driver-specific kwargs to change the way this method
operates.
"""
async for cog_name, cog_id in cls.aiter_cogs():
driver = cls(cog_name, cog_id)
await driver.clear(IdentifierData(cog_name, cog_id, "", (), (), 0))
@staticmethod
def _split_primary_key(
category: Union[ConfigCategory, str],
custom_group_data: Dict[str, int],
data: Dict[str, Any],
) -> List[Tuple[Tuple[str, ...], Dict[str, Any]]]:
pkey_len = ConfigCategory.get_pkey_info(category, custom_group_data)[0]
if pkey_len == 0:
return [((), data)]
def flatten(levels_remaining, currdata, parent_key=()):
items = []
for _k, _v in currdata.items():
new_key = parent_key + (_k,)
if levels_remaining > 1:
items.extend(flatten(levels_remaining - 1, _v, new_key).items())
else:
items.append((new_key, _v))
return dict(items)
ret = []
for k, v in flatten(pkey_len, data).items():
ret.append((k, v))
return ret
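
As an illustration of the flattening above, a MEMBER category (primary key length 2) splits a guild-ID to member-ID mapping into one (pkey, document) pair per member (IDs made up):

    data = {"111": {"222": {"balance": 100}, "333": {"balance": 5}}}
    pairs = BaseDriver._split_primary_key("MEMBER", {}, data)
    assert pairs == [
        (("111", "222"), {"balance": 100}),
        (("111", "333"), {"balance": 5}),
    ]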
async def export_data(
self, custom_group_data: Dict[str, int]
) -> List[Tuple[str, Dict[str, Any]]]:
categories = [c.value for c in ConfigCategory]
categories.extend(custom_group_data.keys())
ret = []
for c in categories:
ident_data = IdentifierData(
self.cog_name,
self.unique_cog_identifier,
c,
(),
(),
*ConfigCategory.get_pkey_info(c, custom_group_data),
)
try:
data = await self.get(ident_data)
except KeyError:
continue
ret.append((c, data))
return ret
async def import_data(
self, cog_data: List[Tuple[str, Dict[str, Any]]], custom_group_data: Dict[str, int]
) -> None:
for category, all_data in cog_data:
splitted_pkey = self._split_primary_key(category, custom_group_data, all_data)
for pkey, data in splitted_pkey:
ident_data = IdentifierData(
self.cog_name,
self.unique_cog_identifier,
category,
pkey,
(),
*ConfigCategory.get_pkey_info(category, custom_group_data),
)
await self.set(ident_data, data)


@@ -5,12 +5,13 @@ import os
import pickle
import weakref
from pathlib import Path
from typing import Any, Dict
from typing import Any, AsyncIterator, Dict, Optional, Tuple
from uuid import uuid4
from .red_base import BaseDriver, IdentifierData
from .. import data_manager, errors
from .base import BaseDriver, IdentifierData, ConfigCategory
__all__ = ["JSON"]
__all__ = ["JsonDriver"]
_shared_datastore = {}
@@ -35,9 +36,10 @@ def finalize_driver(cog_name):
_finalizers.remove(f)
class JSON(BaseDriver):
# noinspection PyProtectedMember
class JsonDriver(BaseDriver):
"""
Subclass of :py:class:`.red_base.BaseDriver`.
Subclass of :py:class:`.BaseDriver`.
.. py:attribute:: file_name
@@ -50,27 +52,26 @@ class JSON(BaseDriver):
def __init__(
self,
cog_name,
identifier,
cog_name: str,
identifier: str,
*,
data_path_override: Path = None,
file_name_override: str = "settings.json"
data_path_override: Optional[Path] = None,
file_name_override: str = "settings.json",
):
super().__init__(cog_name, identifier)
self.file_name = file_name_override
if data_path_override:
if data_path_override is not None:
self.data_path = data_path_override
elif cog_name == "Core" and identifier == "0":
self.data_path = data_manager.core_data_path()
else:
self.data_path = Path.cwd() / "cogs" / ".data" / self.cog_name
self.data_path = data_manager.cog_data_path(raw_name=cog_name)
self.data_path.mkdir(parents=True, exist_ok=True)
self.data_path = self.data_path / self.file_name
self._lock = asyncio.Lock()
self._load_data()
async def has_valid_connection(self) -> bool:
return True
@property
def data(self):
return _shared_datastore.get(self.cog_name)
@@ -79,6 +80,21 @@ class JSON(BaseDriver):
def data(self, value):
_shared_datastore[self.cog_name] = value
@classmethod
async def initialize(cls, **storage_details) -> None:
# No initializing to do
return
@classmethod
async def teardown(cls) -> None:
# No tearing down to do
return
@staticmethod
def get_config_details() -> Dict[str, Any]:
# No driver-specific configuration needed
return {}
def _load_data(self):
if self.cog_name not in _driver_counts:
_driver_counts[self.cog_name] = 0
@@ -111,30 +127,32 @@ class JSON(BaseDriver):
async def get(self, identifier_data: IdentifierData):
partial = self.data
full_identifiers = identifier_data.to_tuple()
full_identifiers = identifier_data.to_tuple()[1:]
for i in full_identifiers:
partial = partial[i]
return pickle.loads(pickle.dumps(partial, -1))
async def set(self, identifier_data: IdentifierData, value=None):
partial = self.data
full_identifiers = identifier_data.to_tuple()
full_identifiers = identifier_data.to_tuple()[1:]
# This is both our deepcopy() and our way of making sure this value is actually JSON
# serializable.
value_copy = json.loads(json.dumps(value))
async with self._lock:
for i in full_identifiers[:-1]:
if i not in partial:
partial[i] = {}
partial = partial[i]
partial[full_identifiers[-1]] = value_copy
try:
partial = partial.setdefault(i, {})
except AttributeError:
# Tried to set sub-field of non-object
raise errors.CannotSetSubfield
partial[full_identifiers[-1]] = value_copy
await self._save()
async def clear(self, identifier_data: IdentifierData):
partial = self.data
full_identifiers = identifier_data.to_tuple()
full_identifiers = identifier_data.to_tuple()[1:]
try:
for i in full_identifiers[:-1]:
partial = partial[i]
@@ -149,14 +167,32 @@ class JSON(BaseDriver):
else:
await self._save()
@classmethod
async def aiter_cogs(cls) -> AsyncIterator[Tuple[str, str]]:
yield "Core", "0"
for _dir in data_manager.cog_data_path().iterdir():
fpath = _dir / "settings.json"
if not fpath.exists():
continue
with fpath.open() as f:
try:
data = json.load(f)
except json.JSONDecodeError:
continue
if not isinstance(data, dict):
continue
for cog, inner in data.items():
if not isinstance(inner, dict):
continue
for cog_id in inner:
yield cog, cog_id
async def import_data(self, cog_data, custom_group_data):
def update_write_data(identifier_data: IdentifierData, _data):
partial = self.data
idents = identifier_data.to_tuple()
idents = identifier_data.to_tuple()[1:]
for ident in idents[:-1]:
if ident not in partial:
partial[ident] = {}
partial = partial[ident]
partial = partial.setdefault(ident, {})
partial[idents[-1]] = _data
async with self._lock:
@@ -164,12 +200,12 @@ class JSON(BaseDriver):
splitted_pkey = self._split_primary_key(category, custom_group_data, all_data)
for pkey, data in splitted_pkey:
ident_data = IdentifierData(
self.cog_name,
self.unique_cog_identifier,
category,
pkey,
(),
custom_group_data,
is_custom=category in custom_group_data,
*ConfigCategory.get_pkey_info(category, custom_group_data),
)
update_write_data(ident_data, data)
await self._save()
@@ -178,9 +214,6 @@ class JSON(BaseDriver):
loop = asyncio.get_running_loop()
await loop.run_in_executor(None, _save_json, self.data_path, self.data)
def get_config_details(self):
return
def _save_json(path: Path, data: Dict[str, Any]) -> None:
"""


@@ -0,0 +1,11 @@
import functools
import logging
import os
if os.getenv("RED_INSPECT_DRIVER_QUERIES"):
LOGGING_INVISIBLE = logging.DEBUG
else:
LOGGING_INVISIBLE = 0
log = logging.getLogger("red.driver")
log.invisible = functools.partial(log.log, LOGGING_INVISIBLE)


@@ -2,77 +2,110 @@ import contextlib
import itertools
import re
from getpass import getpass
from typing import Match, Pattern, Tuple, Any, Dict, Iterator, List
from typing import Match, Pattern, Tuple, Optional, AsyncIterator, Any, Dict, Iterator, List
from urllib.parse import quote_plus
import motor.core
import motor.motor_asyncio
import pymongo.errors
try:
# pylint: disable=import-error
import pymongo.errors
import motor.core
import motor.motor_asyncio
except ModuleNotFoundError:
motor = None
pymongo = None
from .red_base import BaseDriver, IdentifierData
from .. import errors
from .base import BaseDriver, IdentifierData
__all__ = ["Mongo"]
__all__ = ["MongoDriver"]
_conn = None
def _initialize(**kwargs):
uri = kwargs.get("URI", "mongodb")
host = kwargs["HOST"]
port = kwargs["PORT"]
admin_user = kwargs["USERNAME"]
admin_pass = kwargs["PASSWORD"]
db_name = kwargs.get("DB_NAME", "default_db")
if port is 0:
ports = ""
else:
ports = ":{}".format(port)
if admin_user is not None and admin_pass is not None:
url = "{}://{}:{}@{}{}/{}".format(
uri, quote_plus(admin_user), quote_plus(admin_pass), host, ports, db_name
)
else:
url = "{}://{}{}/{}".format(uri, host, ports, db_name)
global _conn
_conn = motor.motor_asyncio.AsyncIOMotorClient(url, retryWrites=True)
class Mongo(BaseDriver):
class MongoDriver(BaseDriver):
"""
Subclass of :py:class:`.red_base.BaseDriver`.
Subclass of :py:class:`.BaseDriver`.
"""
def __init__(self, cog_name, identifier, **kwargs):
super().__init__(cog_name, identifier)
_conn: Optional["motor.motor_asyncio.AsyncIOMotorClient"] = None
if _conn is None:
_initialize(**kwargs)
@classmethod
async def initialize(cls, **storage_details) -> None:
if motor is None:
raise errors.MissingExtraRequirements(
"Red must be installed with the [mongo] extra to use the MongoDB driver"
)
uri = storage_details.get("URI", "mongodb")
host = storage_details["HOST"]
port = storage_details["PORT"]
user = storage_details["USERNAME"]
password = storage_details["PASSWORD"]
database = storage_details.get("DB_NAME", "default_db")
async def has_valid_connection(self) -> bool:
# Maybe fix this?
return True
if port == 0:
ports = ""
else:
ports = ":{}".format(port)
if user is not None and password is not None:
url = "{}://{}:{}@{}{}/{}".format(
uri, quote_plus(user), quote_plus(password), host, ports, database
)
else:
url = "{}://{}{}/{}".format(uri, host, ports, database)
cls._conn = motor.motor_asyncio.AsyncIOMotorClient(url, retryWrites=True)
@classmethod
async def teardown(cls) -> None:
if cls._conn is not None:
cls._conn.close()
@staticmethod
def get_config_details():
while True:
uri = input("Enter URI scheme (mongodb or mongodb+srv): ")
if uri == "":
uri = "mongodb"
if uri in ["mongodb", "mongodb+srv"]:
break
else:
print("Invalid URI scheme")
host = input("Enter host address: ")
if uri == "mongodb":
port = int(input("Enter host port: "))
else:
port = 0
admin_uname = input("Enter login username: ")
admin_password = getpass("Enter login password: ")
db_name = input("Enter mongodb database name: ")
if admin_uname == "":
admin_uname = admin_password = None
ret = {
"HOST": host,
"PORT": port,
"USERNAME": admin_uname,
"PASSWORD": admin_password,
"DB_NAME": db_name,
"URI": uri,
}
return ret
@property
def db(self) -> motor.core.Database:
def db(self) -> "motor.core.Database":
"""
Gets the mongo database for this cog's name.
.. warning::
Right now this will cause a new connection to be made every time the
database is accessed. We will want to create a connection pool down the
line to limit the number of connections.
:return:
PyMongo Database object.
"""
return _conn.get_database()
return self._conn.get_database()
def get_collection(self, category: str) -> motor.core.Collection:
def get_collection(self, category: str) -> "motor.core.Collection":
"""
Gets a specified collection within the PyMongo database for this cog.
@@ -85,12 +118,13 @@ class Mongo(BaseDriver):
"""
return self.db[self.cog_name][category]
def get_primary_key(self, identifier_data: IdentifierData) -> Tuple[str]:
@staticmethod
def get_primary_key(identifier_data: IdentifierData) -> Tuple[str, ...]:
# noinspection PyTypeChecker
return identifier_data.primary_key
async def rebuild_dataset(
self, identifier_data: IdentifierData, cursor: motor.motor_asyncio.AsyncIOMotorCursor
self, identifier_data: IdentifierData, cursor: "motor.motor_asyncio.AsyncIOMotorCursor"
):
ret = {}
async for doc in cursor:
@@ -141,16 +175,16 @@ class Mongo(BaseDriver):
async def set(self, identifier_data: IdentifierData, value=None):
uuid = self._escape_key(identifier_data.uuid)
primary_key = list(map(self._escape_key, self.get_primary_key(identifier_data)))
dot_identifiers = ".".join(map(self._escape_key, identifier_data.identifiers))
if isinstance(value, dict):
if len(value) == 0:
await self.clear(identifier_data)
return
value = self._escape_dict_keys(value)
mongo_collection = self.get_collection(identifier_data.category)
pkey_len = self.get_pkey_len(identifier_data)
num_pkeys = len(primary_key)
if num_pkeys >= pkey_len:
if num_pkeys >= identifier_data.primary_key_len:
# We're setting at the document level or below.
dot_identifiers = ".".join(map(self._escape_key, identifier_data.identifiers))
if dot_identifiers:
@@ -158,11 +192,23 @@ class Mongo(BaseDriver):
else:
update_stmt = {"$set": value}
await mongo_collection.update_one(
{"_id": {"RED_uuid": uuid, "RED_primary_key": primary_key}},
update=update_stmt,
upsert=True,
)
try:
await mongo_collection.update_one(
{"_id": {"RED_uuid": uuid, "RED_primary_key": primary_key}},
update=update_stmt,
upsert=True,
)
except pymongo.errors.WriteError as exc:
if exc.args and exc.args[0].startswith("Cannot create field"):
# There's a bit of a failing edge case here...
# If we accidentally set the sub-field of an array, and the key happens to be a
# digit, it will successfully set the value in the array, and not raise an
# error. This is different to how other drivers would behave, and could lead to
# unexpected behaviour.
raise errors.CannotSetSubfield
else:
# Unhandled driver exception, should expose.
raise
else:
# We're setting above the document level.
@@ -171,15 +217,17 @@ class Mongo(BaseDriver):
# We'll do it in a transaction so we can roll-back in case something goes horribly
# wrong.
pkey_filter = self.generate_primary_key_filter(identifier_data)
async with await _conn.start_session() as session:
async with await self._conn.start_session() as session:
with contextlib.suppress(pymongo.errors.CollectionInvalid):
# Collections must already exist when inserting documents within a transaction
await _conn.get_database().create_collection(mongo_collection.full_name)
await self.db.create_collection(mongo_collection.full_name)
try:
async with session.start_transaction():
await mongo_collection.delete_many(pkey_filter, session=session)
await mongo_collection.insert_many(
self.generate_documents_to_insert(uuid, primary_key, value, pkey_len),
self.generate_documents_to_insert(
uuid, primary_key, value, identifier_data.primary_key_len
),
session=session,
)
except pymongo.errors.OperationFailure:
@@ -218,7 +266,7 @@ class Mongo(BaseDriver):
# What's left of `value` should be the new documents needing to be inserted.
to_insert = self.generate_documents_to_insert(
uuid, primary_key, value, pkey_len
uuid, primary_key, value, identifier_data.primary_key_len
)
requests = list(
itertools.chain(
@@ -289,6 +337,59 @@ class Mongo(BaseDriver):
for result in results:
await db[result["name"]].delete_many(pkey_filter)
@classmethod
async def aiter_cogs(cls) -> AsyncIterator[Tuple[str, str]]:
db = cls._conn.get_database()
for collection_name in await db.list_collection_names():
parts = collection_name.split(".")
if not len(parts) == 2:
continue
cog_name = parts[0]
for cog_id in await db[collection_name].distinct("_id.RED_uuid"):
yield cog_name, cog_id
@classmethod
async def delete_all_data(
cls, *, interactive: bool = False, drop_db: Optional[bool] = None, **kwargs
) -> None:
"""Delete all data being stored by this driver.
Parameters
----------
interactive : bool
Set to ``True`` to allow the method to ask the user for
input from the console, regarding the other unset parameters
for this method.
drop_db : Optional[bool]
Set to ``True`` to drop the entire database for the current
bot's instance. Otherwise, collections which appear to be
storing bot data will be dropped.
"""
if interactive is True and drop_db is None:
print(
"Please choose from one of the following options:\n"
" 1. Drop the entire MongoDB database for this instance, or\n"
" 2. Delete all of Red's data within this database, without dropping the database "
"itself."
)
options = ("1", "2")
while True:
resp = input("> ")
try:
drop_db = bool(options.index(resp))
except ValueError:
print("Please type a number corresponding to one of the options.")
else:
break
db = cls._conn.get_database()
if drop_db is True:
await cls._conn.drop_database(db)
else:
async with await cls._conn.start_session() as session:
async for cog_name, cog_id in cls.aiter_cogs():
await db.drop_collection(db[cog_name], session=session)
@staticmethod
def _escape_key(key: str) -> str:
return _SPECIAL_CHAR_PATTERN.sub(_replace_with_escaped, key)
@@ -344,40 +445,3 @@ _CHAR_ESCAPES = {
def _replace_with_unescaped(match: Match[str]) -> str:
return _CHAR_ESCAPES[match[0]]
def get_config_details():
uri = None
while True:
uri = input("Enter URI scheme (mongodb or mongodb+srv): ")
if uri is "":
uri = "mongodb"
if uri in ["mongodb", "mongodb+srv"]:
break
else:
print("Invalid URI scheme")
host = input("Enter host address: ")
if uri is "mongodb":
port = int(input("Enter host port: "))
else:
port = 0
admin_uname = input("Enter login username: ")
admin_password = getpass("Enter login password: ")
db_name = input("Enter mongodb database name: ")
if admin_uname == "":
admin_uname = admin_password = None
ret = {
"HOST": host,
"PORT": port,
"USERNAME": admin_uname,
"PASSWORD": admin_password,
"DB_NAME": db_name,
"URI": uri,
}
return ret


@@ -0,0 +1,3 @@
from .postgres import PostgresDriver
__all__ = ["PostgresDriver"]


@@ -0,0 +1,839 @@
/*
************************************************************
* PostgreSQL driver Data Definition Language (DDL) Script. *
************************************************************
*/
CREATE SCHEMA IF NOT EXISTS red_config;
CREATE SCHEMA IF NOT EXISTS red_utils;
DO $$
BEGIN
PERFORM 'red_config.identifier_data'::regtype;
EXCEPTION
WHEN UNDEFINED_OBJECT THEN
CREATE TYPE red_config.identifier_data AS (
cog_name text,
cog_id text,
category text,
pkeys text[],
identifiers text[],
pkey_len integer,
is_custom boolean
);
END;
$$;
CREATE OR REPLACE FUNCTION
/*
* Create the config schema and/or table if they do not exist yet.
*/
red_config.maybe_create_table(
id_data red_config.identifier_data
)
RETURNS void
LANGUAGE 'plpgsql'
AS $$
DECLARE
schemaname CONSTANT text := concat_ws('.', id_data.cog_name, id_data.cog_id);
schema_exists CONSTANT boolean := exists(
SELECT 1
FROM red_config.red_cogs t
WHERE t.cog_name = id_data.cog_name AND t.cog_id = id_data.cog_id);
table_exists CONSTANT boolean := schema_exists AND exists(
SELECT 1
FROM information_schema.tables
WHERE table_schema = schemaname AND table_name = id_data.category);
BEGIN
IF NOT schema_exists THEN
PERFORM red_config.create_schema(id_data.cog_name, id_data.cog_id);
END IF;
IF NOT table_exists THEN
PERFORM red_config.create_table(id_data);
END IF;
END;
$$;
CREATE OR REPLACE FUNCTION
/*
* Create the config schema for the given cog.
*/
red_config.create_schema(new_cog_name text, new_cog_id text, OUT schemaname text)
RETURNS text
LANGUAGE 'plpgsql'
AS $$
BEGIN
schemaname := concat_ws('.', new_cog_name, new_cog_id);
EXECUTE format('CREATE SCHEMA IF NOT EXISTS %I', schemaname);
INSERT INTO red_config.red_cogs AS t VALUES(new_cog_name, new_cog_id, schemaname)
ON CONFLICT(cog_name, cog_id) DO UPDATE
SET
schemaname = excluded.schemaname;
END;
$$;
CREATE OR REPLACE FUNCTION
/*
* Create the config table for the given category.
*/
red_config.create_table(id_data red_config.identifier_data)
RETURNS void
LANGUAGE 'plpgsql'
AS $$
DECLARE
schemaname CONSTANT text := concat_ws('.', id_data.cog_name, id_data.cog_id);
constraintname CONSTANT text := id_data.category||'_pkey';
pkey_columns CONSTANT text := red_utils.gen_pkey_columns(1, id_data.pkey_len);
pkey_type CONSTANT text := red_utils.get_pkey_type(id_data.is_custom);
pkey_column_definitions CONSTANT text := red_utils.gen_pkey_column_definitions(
1, id_data.pkey_len, pkey_type);
BEGIN
EXECUTE format(
$query$
CREATE TABLE IF NOT EXISTS %I.%I (
%s,
json_data jsonb DEFAULT '{}' NOT NULL,
CONSTRAINT %I PRIMARY KEY (%s)
)
$query$,
schemaname,
id_data.category,
pkey_column_definitions,
constraintname,
pkey_columns);
END;
$$;
CREATE OR REPLACE FUNCTION
/*
* Get config data.
*
* - When `pkeys` is a full primary key, all or part of a document
* will be returned.
* - When `pkeys` is not a full primary key, documents will be
* aggregated together into a single JSONB object, with primary keys
* as keys mapping to the documents.
*/
red_config.get(
id_data red_config.identifier_data,
OUT result jsonb
)
LANGUAGE 'plpgsql'
STABLE
PARALLEL SAFE
AS $$
DECLARE
schemaname CONSTANT text := concat_ws('.', id_data.cog_name, id_data.cog_id);
num_pkeys CONSTANT integer := coalesce(array_length(id_data.pkeys, 1), 0);
num_missing_pkeys CONSTANT integer := id_data.pkey_len - num_pkeys;
pkey_type CONSTANT text := red_utils.get_pkey_type(id_data.is_custom);
whereclause CONSTANT text := red_utils.gen_whereclause(num_pkeys, pkey_type);
missing_pkey_columns text;
BEGIN
IF num_missing_pkeys <= 0 THEN
-- No missing primary keys: we're getting all or part of a document.
EXECUTE format(
'SELECT json_data #> $2 FROM %I.%I WHERE %s',
schemaname,
id_data.category,
whereclause)
INTO result
USING id_data.pkeys, id_data.identifiers;
ELSIF num_missing_pkeys = 1 THEN
-- 1 missing primary key: we can use the built-in jsonb_object_agg() aggregate function.
EXECUTE format(
'SELECT jsonb_object_agg(%I::text, json_data) FROM %I.%I WHERE %s',
'primary_key_'||id_data.pkey_len,
schemaname,
id_data.category,
whereclause)
INTO result
USING id_data.pkeys;
ELSE
-- Multiple missing primary keys: we must use our custom red_utils.jsonb_object_agg2()
-- aggregate function.
missing_pkey_columns := red_utils.gen_pkey_columns_casted(num_pkeys + 1, id_data.pkey_len);
EXECUTE format(
'SELECT red_utils.jsonb_object_agg2(json_data, %s) FROM %I.%I WHERE %s',
missing_pkey_columns,
schemaname,
id_data.category,
whereclause)
INTO result
USING id_data.pkeys;
END IF;
END;
$$;
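-- Illustration (hypothetical IDs): a full primary key returns all or part of a single
-- document, while a partial one aggregates documents keyed by the missing pkey columns.
--   SELECT red_config.get(ROW('MyCog', '1234567890', 'MEMBER',
--                              '{111,222}', '{balance}', 2, FALSE)::red_config.identifier_data);
--     -> 100
--   SELECT red_config.get(ROW('MyCog', '1234567890', 'MEMBER',
--                              '{111}', '{}', 2, FALSE)::red_config.identifier_data);
--     -> {"222": {"balance": 100}, "333": {"balance": 5}}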
CREATE OR REPLACE FUNCTION
/*
* Set config data.
*
* - When `pkeys` is a full primary key, all or part of a document
* will be set.
* - When `pkeys` is not a full set, multiple documents will be
* replaced or removed - `new_value` must be a JSONB object mapping
* primary keys to the new documents.
*
* Raises `error_in_assignment` error when trying to set a sub-key
* of a non-document type.
*/
red_config.set(
id_data red_config.identifier_data,
new_value jsonb
)
RETURNS void
LANGUAGE 'plpgsql'
AS $$
DECLARE
schemaname CONSTANT text := concat_ws('.', id_data.cog_name, id_data.cog_id);
constraintname CONSTANT text := id_data.category||'_pkey';
num_pkeys CONSTANT integer := coalesce(array_length(id_data.pkeys, 1), 0);
num_missing_pkeys CONSTANT integer := id_data.pkey_len - num_pkeys;
pkey_type CONSTANT text := red_utils.get_pkey_type(id_data.is_custom);
pkey_placeholders CONSTANT text := red_utils.gen_pkey_placeholders(num_pkeys, pkey_type);
new_document jsonb;
pkey_column_definitions text;
whereclause text;
missing_pkey_columns text;
BEGIN
PERFORM red_config.maybe_create_table(id_data);
IF num_missing_pkeys = 0 THEN
-- Setting all or part of a document
new_document := red_utils.jsonb_set2('{}', new_value, VARIADIC id_data.identifiers);
EXECUTE format(
$query$
INSERT INTO %I.%I AS t VALUES (%s, $2)
ON CONFLICT ON CONSTRAINT %I DO UPDATE
SET
json_data = red_utils.jsonb_set2(t.json_data, $3, VARIADIC $4)
$query$,
schemaname,
id_data.category,
pkey_placeholders,
constraintname)
USING id_data.pkeys, new_document, new_value, id_data.identifiers;
ELSE
-- Setting multiple documents
whereclause := red_utils.gen_whereclause(num_pkeys, pkey_type);
missing_pkey_columns := red_utils.gen_pkey_columns_casted(
num_pkeys + 1, id_data.pkey_len, pkey_type);
pkey_column_definitions := red_utils.gen_pkey_column_definitions(num_pkeys + 1, id_data.pkey_len);
-- Delete all documents which we're setting first, since we don't know whether they'll be
-- replaced by the subsequent INSERT.
EXECUTE format('DELETE FROM %I.%I WHERE %s', schemaname, id_data.category, whereclause)
USING id_data.pkeys;
-- Insert all new documents
EXECUTE format(
$query$
INSERT INTO %I.%I AS t
SELECT %s, json_data
FROM red_utils.generate_rows_from_object($2, $3) AS f(%s, json_data jsonb)
ON CONFLICT ON CONSTRAINT %I DO UPDATE
SET
json_data = excluded.json_data
$query$,
schemaname,
id_data.category,
concat_ws(', ', pkey_placeholders, missing_pkey_columns),
pkey_column_definitions,
constraintname)
USING id_data.pkeys, new_value, num_missing_pkeys;
END IF;
END;
$$;
CREATE OR REPLACE FUNCTION
/*
* Clear config data.
*
* - When `identifiers` is not empty, this will clear a key from a
* document.
* - When `identifiers` is empty and `pkeys` is not empty, it will
* delete one or more documents.
* - When `pkeys` is empty, it will drop the whole table.
* - When `id_data.category` is NULL or an empty string, it will drop
* the whole schema.
*
* Has no effect when the document or key does not exist.
*/
red_config.clear(
id_data red_config.identifier_data
)
RETURNS void
LANGUAGE 'plpgsql'
AS $$
DECLARE
schemaname CONSTANT text := concat_ws('.', id_data.cog_name, id_data.cog_id);
num_pkeys CONSTANT integer := coalesce(array_length(id_data.pkeys, 1), 0);
num_identifiers CONSTANT integer := coalesce(array_length(id_data.identifiers, 1), 0);
pkey_type CONSTANT text := red_utils.get_pkey_type(id_data.is_custom);
whereclause text;
BEGIN
IF num_identifiers > 0 THEN
-- Popping a key from a document or nested document.
whereclause := red_utils.gen_whereclause(num_pkeys, pkey_type);
EXECUTE format(
$query$
UPDATE %I.%I AS t
SET
json_data = t.json_data #- $2
WHERE %s
$query$,
schemaname,
id_data.category,
whereclause)
USING id_data.pkeys, id_data.identifiers;
ELSIF num_pkeys > 0 THEN
-- Deleting one or many documents
whereclause := red_utils.gen_whereclause(num_pkeys, pkey_type);
EXECUTE format('DELETE FROM %I.%I WHERE %s', schemaname, id_data.category, whereclause)
USING id_data.pkeys;
ELSIF id_data.category IS NOT NULL AND id_data.category != '' THEN
-- Deleting an entire category
EXECUTE format('DROP TABLE %I.%I CASCADE', schemaname, id_data.category);
ELSE
-- Deleting an entire cog's data
EXECUTE format('DROP SCHEMA %I CASCADE', schemaname);
DELETE FROM red_config.red_cogs
WHERE cog_name = id_data.cog_name AND cog_id = id_data.cog_id;
END IF;
END;
$$;
CREATE OR REPLACE FUNCTION
/*
* Increment a number within a document.
*
* If the value doesn't already exist, it is inserted as
* `default_value + amount`.
*
* Raises 'wrong_object_type' error when trying to increment a
* non-numeric value.
*/
red_config.inc(
id_data red_config.identifier_data,
amount numeric,
default_value numeric,
OUT result numeric
)
LANGUAGE 'plpgsql'
AS $$
DECLARE
schemaname CONSTANT text := concat_ws('.', id_data.cog_name, id_data.cog_id);
num_identifiers CONSTANT integer := coalesce(array_length(id_data.identifiers, 1), 0);
pkey_type CONSTANT text := red_utils.get_pkey_type(id_data.is_custom);
whereclause CONSTANT text := red_utils.gen_whereclause(id_data.pkey_len, pkey_type);
new_document jsonb;
existing_document jsonb;
existing_value jsonb;
pkey_placeholders text;
BEGIN
IF num_identifiers = 0 THEN
-- Without identifiers, there's no chance we're actually incrementing a number
RAISE EXCEPTION 'Cannot increment document(s)'
USING ERRCODE = 'wrong_object_type';
END IF;
PERFORM red_config.maybe_create_table(id_data);
-- Look for the existing document
EXECUTE format(
'SELECT json_data FROM %I.%I WHERE %s',
schemaname,
id_data.category,
whereclause)
INTO existing_document USING id_data.pkeys;
IF existing_document IS NULL THEN
-- We need to insert a new document
result := default_value + amount;
new_document := red_utils.jsonb_set2('{}', result, VARIADIC id_data.identifiers);
pkey_placeholders := red_utils.gen_pkey_placeholders(id_data.pkey_len, pkey_type);
EXECUTE format(
'INSERT INTO %I.%I VALUES(%s, $2)',
schemaname,
id_data.category,
pkey_placeholders)
USING id_data.pkeys, new_document;
ELSE
-- We need to update the existing document
existing_value := existing_document #> id_data.identifiers;
IF existing_value IS NULL THEN
result := default_value + amount;
ELSIF jsonb_typeof(existing_value) = 'number' THEN
result := existing_value::text::numeric + amount;
ELSE
RAISE EXCEPTION 'Cannot increment non-numeric value %', existing_value
USING ERRCODE = 'wrong_object_type';
END IF;
new_document := red_utils.jsonb_set2(
existing_document, to_jsonb(result), id_data.identifiers);
EXECUTE format(
'UPDATE %I.%I SET json_data = $2 WHERE %s',
schemaname,
id_data.category,
whereclause)
USING id_data.pkeys, new_document;
END IF;
END;
$$;
CREATE OR REPLACE FUNCTION
/*
* Toggle a boolean within a document.
*
* If the value doesn't already exist, it is inserted as `NOT
* default_value`.
*
* Raises 'wrong_object_type' error when trying to toggle a
* non-boolean value.
*/
red_config.toggle(
id_data red_config.identifier_data,
default_value boolean,
OUT result boolean
)
LANGUAGE 'plpgsql'
AS $$
DECLARE
schemaname CONSTANT text := concat_ws('.', id_data.cog_name, id_data.cog_id);
num_identifiers CONSTANT integer := coalesce(array_length(id_data.identifiers, 1), 0);
pkey_type CONSTANT text := red_utils.get_pkey_type(id_data.is_custom);
whereclause CONSTANT text := red_utils.gen_whereclause(id_data.pkey_len, pkey_type);
new_document jsonb;
existing_document jsonb;
existing_value jsonb;
pkey_placeholders text;
BEGIN
IF num_identifiers = 0 THEN
-- Without identifiers, there's no chance we're actually toggling a boolean
RAISE EXCEPTION 'Cannot toggle document(s)'
USING ERRCODE = 'wrong_object_type';
END IF;
PERFORM red_config.maybe_create_table(id_data);
-- Look for the existing document
EXECUTE format(
'SELECT json_data FROM %I.%I WHERE %s',
schemaname,
id_data.category,
whereclause)
INTO existing_document USING id_data.pkeys;
IF existing_document IS NULL THEN
-- We need to insert a new document
result := NOT default_value;
new_document := red_utils.jsonb_set2('{}', result, VARIADIC id_data.identifiers);
pkey_placeholders := red_utils.gen_pkey_placeholders(id_data.pkey_len, pkey_type);
EXECUTE format(
'INSERT INTO %I.%I VALUES(%s, $2)',
schemaname,
id_data.category,
pkey_placeholders)
USING id_data.pkeys, new_document;
ELSE
-- We need to update the existing document
existing_value := existing_document #> id_data.identifiers;
IF existing_value IS NULL THEN
result := NOT default_value;
ELSIF jsonb_typeof(existing_value) = 'boolean' THEN
result := NOT existing_value::text::boolean;
ELSE
RAISE EXCEPTION 'Cannot toggle non-boolean value %', existing_value
USING ERRCODE = 'wrong_object_type';
END IF;
new_document := red_utils.jsonb_set2(
existing_document, to_jsonb(result), id_data.identifiers);
EXECUTE format(
'UPDATE %I.%I SET json_data = $2 WHERE %s',
schemaname,
id_data.category,
whereclause)
USING id_data.pkeys, new_document;
END IF;
END;
$$;
CREATE OR REPLACE FUNCTION
red_config.extend(
id_data red_config.identifier_data,
new_value text,
default_value text,
max_length integer DEFAULT NULL,
extend_left boolean DEFAULT FALSE,
OUT result jsonb
)
LANGUAGE 'plpgsql'
AS $$
DECLARE
schemaname CONSTANT text := concat_ws('.', id_data.cog_name, id_data.cog_id);
num_identifiers CONSTANT integer := coalesce(array_length(id_data.identifiers, 1), 0);
pkey_type CONSTANT text := red_utils.get_pkey_type(id_data.is_custom);
whereclause CONSTANT text := red_utils.gen_whereclause(id_data.pkey_len, pkey_type);
pop_idx CONSTANT integer := CASE extend_left WHEN TRUE THEN -1 ELSE 0 END;
new_document jsonb;
existing_document jsonb;
existing_value jsonb;
pkey_placeholders text;
idx integer;
BEGIN
IF num_identifiers = 0 THEN
-- Without identifiers, there's no chance we're actually appending to an array
RAISE EXCEPTION 'Cannot append to document(s)'
USING ERRCODE = 'wrong_object_type';
END IF;
PERFORM red_config.maybe_create_table(id_data);
-- Look for the existing document
EXECUTE format(
'SELECT json_data FROM %I.%I WHERE %s',
schemaname,
id_data.category,
whereclause)
INTO existing_document USING id_data.pkeys;
IF existing_document IS NULL THEN
result := default_value || new_value;
new_document := red_utils.jsonb_set2('{}'::jsonb, result, id_data.identifiers);
pkey_placeholders := red_utils.gen_pkey_placeholders(id_data.pkey_len, pkey_type);
EXECUTE format(
'INSERT INTO %I.%I VALUES(%s, $2)',
schemaname,
id_data.category,
pkey_placeholders)
USING id_data.pkeys, new_document;
ELSE
existing_value := existing_document #> id_data.identifiers;
IF existing_value IS NULL THEN
existing_value := default_value;
ELSIF jsonb_typeof(existing_value) != 'array' THEN
RAISE EXCEPTION 'Cannot append to non-array value %', existing_value
USING ERRCODE = 'wrong_object_type';
END IF;
CASE extend_left
WHEN TRUE THEN
result := new_value || existing_value;
ELSE
result := existing_value || new_value;
END CASE;
IF max_length IS NOT NULL THEN
FOR idx IN SELECT generate_series(1, jsonb_array_length(result) - max_length) LOOP
result := result - pop_idx;
END LOOP;
END IF;
new_document := red_utils.jsonb_set2(existing_document, result, id_data.identifiers);
EXECUTE format(
'UPDATE %I.%I SET json_data = $2 WHERE %s',
schemaname,
id_data.category,
whereclause)
USING id_data.pkeys, new_document;
END IF;
END;
$$;
CREATE OR REPLACE FUNCTION
/*
* Delete all schemas listed in the red_config.red_cogs table.
*/
red_config.delete_all_schemas()
RETURNS void
LANGUAGE 'plpgsql'
AS $$
DECLARE
cog_entry record;
BEGIN
FOR cog_entry IN SELECT * FROM red_config.red_cogs t LOOP
EXECUTE format('DROP SCHEMA %I CASCADE', cog_entry.schemaname);
END LOOP;
-- Clear out red_config.red_cogs table
DELETE FROM red_config.red_cogs WHERE TRUE;
END;
$$;
CREATE OR REPLACE FUNCTION
/*
* Like `jsonb_set` but will insert new objects where one is missing
* along the path.
*
* Raises `error_in_assignment` error when trying to set a sub-key
* of a non-document type.
*/
red_utils.jsonb_set2(target jsonb, new_value jsonb, VARIADIC identifiers text[])
RETURNS jsonb
LANGUAGE 'plpgsql'
IMMUTABLE
PARALLEL SAFE
AS $$
DECLARE
num_identifiers CONSTANT integer := coalesce(array_length(identifiers, 1), 0);
cur_value_type text;
idx integer;
BEGIN
IF num_identifiers = 0 THEN
RETURN new_value;
END IF;
FOR idx IN SELECT generate_series(1, num_identifiers - 1) LOOP
cur_value_type := jsonb_typeof(target #> identifiers[:idx]);
IF cur_value_type IS NULL THEN
-- Parent key didn't exist in JSON before - insert new object
target := jsonb_set(target, identifiers[:idx], '{}'::jsonb);
ELSIF cur_value_type != 'object' THEN
-- We can't set the sub-field of a null, int, float, array etc.
RAISE EXCEPTION 'Cannot set sub-field of "%"', cur_value_type
USING ERRCODE = 'error_in_assignment';
END IF;
END LOOP;
RETURN jsonb_set(target, identifiers, new_value);
END;
$$;
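-- For example, the intended behaviour is:
--   SELECT red_utils.jsonb_set2('{}', '42', 'a', 'b');
--     -> {"a": {"b": 42}}           (missing parent objects are created along the path)
--   SELECT red_utils.jsonb_set2('{"a": 5}', '42', 'a', 'b');
--     -> raises error_in_assignment (cannot set a sub-key of a number)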
CREATE OR REPLACE FUNCTION
/*
* Return a set of rows to insert into a table, from a single JSONB
* object containing multiple documents.
*/
red_utils.generate_rows_from_object(object jsonb, num_missing_pkeys integer)
RETURNS setof record
LANGUAGE 'plpgsql'
IMMUTABLE
PARALLEL SAFE
AS $$
DECLARE
pair record;
column_definitions text;
BEGIN
IF num_missing_pkeys = 1 THEN
-- Base case: Simply return (key, value) pairs
RETURN QUERY
SELECT key AS key_1, value AS json_data
FROM jsonb_each(object);
ELSE
-- We need to return (key, key, ..., value) pairs: recurse into inner JSONB objects
column_definitions := red_utils.gen_pkey_column_definitions(2, num_missing_pkeys);
FOR pair IN SELECT * FROM jsonb_each(object) LOOP
RETURN QUERY
EXECUTE format(
$query$
SELECT $1 AS key_1, *
FROM red_utils.generate_rows_from_object($2, $3)
AS f(%s, json_data jsonb)
$query$,
column_definitions)
USING pair.key, pair.value, num_missing_pkeys - 1;
END LOOP;
END IF;
RETURN;
END;
$$;
CREATE OR REPLACE FUNCTION
/*
* Get a comma-separated list of primary key placeholders.
*
* The placeholder will always be $1. Particularly useful for
* inserting values into a table from an array of primary keys.
*/
red_utils.gen_pkey_placeholders(num_pkeys integer, pkey_type text DEFAULT 'text')
RETURNS text
LANGUAGE 'sql'
IMMUTABLE
PARALLEL SAFE
AS $$
SELECT string_agg(t.item, ', ')
FROM (
SELECT format('$1[%s]::%s', idx, pkey_type) AS item
FROM generate_series(1, num_pkeys) idx) t
;
$$;
CREATE OR REPLACE FUNCTION
/*
* Generate a whereclause for the given number of primary keys.
*
* When there are no primary keys, this will simply return the
* string 'TRUE'. When there are multiple, it will return multiple
* equality comparisons concatenated with 'AND'.
*/
red_utils.gen_whereclause(num_pkeys integer, pkey_type text)
RETURNS text
LANGUAGE 'sql'
IMMUTABLE
PARALLEL SAFE
AS $$
SELECT coalesce(string_agg(t.item, ' AND '), 'TRUE')
FROM (
SELECT format('%I = $1[%s]::%s', 'primary_key_'||idx, idx, pkey_type) AS item
FROM generate_series(1, num_pkeys) idx) t
;
$$;
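-- For instance:
--   SELECT red_utils.gen_whereclause(0, 'text');
--     -> TRUE
--   SELECT red_utils.gen_whereclause(2, 'bigint');
--     -> primary_key_1 = $1[1]::bigint AND primary_key_2 = $1[2]::bigint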
CREATE OR REPLACE FUNCTION
/*
* Generate a comma-separated list of primary key column names.
*/
red_utils.gen_pkey_columns(start integer, stop integer)
RETURNS text
LANGUAGE 'sql'
IMMUTABLE
PARALLEL SAFE
AS $$
SELECT string_agg(t.item, ', ')
FROM (
SELECT quote_ident('primary_key_'||idx) AS item
FROM generate_series(start, stop) idx) t
;
$$;
CREATE OR REPLACE FUNCTION
/*
* Generate a comma-separated list of primary key column names casted
* to the given type.
*/
red_utils.gen_pkey_columns_casted(start integer, stop integer, pkey_type text DEFAULT 'text')
RETURNS text
LANGUAGE 'sql'
IMMUTABLE
PARALLEL SAFE
AS $$
SELECT string_agg(t.item, ', ')
FROM (
SELECT format('%I::%s', 'primary_key_'||idx, pkey_type) AS item
FROM generate_series(start, stop) idx) t
;
$$;
CREATE OR REPLACE FUNCTION
/*
* Generate a primary key column definition list.
*/
red_utils.gen_pkey_column_definitions(
start integer, stop integer, column_type text DEFAULT 'text'
)
RETURNS text
LANGUAGE 'sql'
IMMUTABLE
PARALLEL SAFE
AS $$
SELECT string_agg(t.item, ', ')
FROM (
SELECT format('%I %s', 'primary_key_'||idx, column_type) AS item
FROM generate_series(start, stop) idx) t
;
$$;
CREATE OR REPLACE FUNCTION
red_utils.get_pkey_type(is_custom boolean)
RETURNS TEXT
LANGUAGE 'sql'
IMMUTABLE
PARALLEL SAFE
AS $$
SELECT ('{bigint,text}'::text[])[is_custom::integer + 1];
$$;
DROP AGGREGATE IF EXISTS red_utils.jsonb_object_agg2(jsonb, VARIADIC text[]);
CREATE AGGREGATE
/*
* Like `jsonb_object_agg` but aggregates more than two columns into a
* single JSONB object.
*
* If possible, use `jsonb_object_agg` instead for performance
* reasons.
*/
red_utils.jsonb_object_agg2(json_data jsonb, VARIADIC primary_keys text[]) (
SFUNC = red_utils.jsonb_set2,
STYPE = jsonb,
INITCOND = '{}',
PARALLEL = SAFE
)
;
CREATE TABLE IF NOT EXISTS
/*
* Table to keep track of other cogs' schemas.
*/
red_config.red_cogs(
cog_name text,
cog_id text,
schemaname text NOT NULL,
PRIMARY KEY (cog_name, cog_id)
)
;


@@ -0,0 +1,3 @@
SELECT red_config.delete_all_schemas();
DROP SCHEMA IF EXISTS red_config CASCADE;
DROP SCHEMA IF EXISTS red_utils CASCADE;


@@ -0,0 +1,255 @@
import getpass
import json
import sys
from pathlib import Path
from typing import Optional, Any, AsyncIterator, Tuple, Union, Callable, List
try:
# pylint: disable=import-error
import asyncpg
except ModuleNotFoundError:
asyncpg = None
from ... import data_manager, errors
from ..base import BaseDriver, IdentifierData, ConfigCategory
from ..log import log
__all__ = ["PostgresDriver"]
_PKG_PATH = Path(__file__).parent
DDL_SCRIPT_PATH = _PKG_PATH / "ddl.sql"
DROP_DDL_SCRIPT_PATH = _PKG_PATH / "drop_ddl.sql"
def encode_identifier_data(
id_data: IdentifierData
) -> Tuple[str, str, str, List[str], List[str], int, bool]:
return (
id_data.cog_name,
id_data.uuid,
id_data.category,
["0"] if id_data.category == ConfigCategory.GLOBAL else list(id_data.primary_key),
list(id_data.identifiers),
1 if id_data.category == ConfigCategory.GLOBAL else id_data.primary_key_len,
id_data.is_custom,
)
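# For example (hypothetical cog name and ID), GLOBAL data is given the dummy primary key
# ["0"] and a primary key length of 1 so that it occupies a single row:
#
#     id_data = IdentifierData("MyCog", "1234567890", "GLOBAL", (), ("api_key",), 0)
#     assert encode_identifier_data(id_data) == (
#         "MyCog", "1234567890", "GLOBAL", ["0"], ["api_key"], 1, False
#     )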
class PostgresDriver(BaseDriver):
_pool: Optional["asyncpg.pool.Pool"] = None
@classmethod
async def initialize(cls, **storage_details) -> None:
if asyncpg is None:
raise errors.MissingExtraRequirements(
"Red must be installed with the [postgres] extra to use the PostgreSQL driver"
)
cls._pool = await asyncpg.create_pool(**storage_details)
with DDL_SCRIPT_PATH.open() as fs:
await cls._pool.execute(fs.read())
@classmethod
async def teardown(cls) -> None:
if cls._pool is not None:
await cls._pool.close()
@staticmethod
def get_config_details():
unixmsg = (
""
if sys.platform == "win32"
else (
" - Common directories for PostgreSQL Unix-domain sockets (/run/postgresql, "
"/var/run/postgresl, /var/pgsql_socket, /private/tmp, and /tmp),\n"
)
)
host = (
input(
f"Enter the PostgreSQL server's address.\n"
f"If left blank, Red will try the following, in order:\n"
f" - The PGHOST environment variable,\n{unixmsg}"
f" - localhost.\n"
f"> "
)
or None
)
print(
"Enter the PostgreSQL server port.\n"
"If left blank, this will default to either:\n"
" - The PGPORT environment variable,\n"
" - 5432."
)
while True:
port = input("> ") or None
if port is None:
break
try:
port = int(port)
except ValueError:
print("Port must be a number")
else:
break
user = (
input(
"Enter the PostgreSQL server username.\n"
"If left blank, this will default to either:\n"
" - The PGUSER environment variable,\n"
" - The OS name of the user running Red (ident/peer authentication).\n"
"> "
)
or None
)
passfile = r"%APPDATA%\postgresql\pgpass.conf" if sys.platform == "win32" else "~/.pgpass"
password = getpass.getpass(
f"Enter the PostgreSQL server password. The input will be hidden.\n"
f" NOTE: If using ident/peer authentication (no password), enter NONE.\n"
f"When NONE is entered, this will default to:\n"
f" - The PGPASSWORD environment variable,\n"
f" - Looking up the password in the {passfile} passfile,\n"
f" - No password.\n"
f"> "
)
if password == "NONE":
password = None
database = (
input(
"Enter the PostgreSQL database's name.\n"
"If left blank, this will default to either:\n"
" - The PGDATABASE environment variable,\n"
" - The OS name of the user running Red.\n"
"> "
)
or None
)
return {
"host": host,
"port": port,
"user": user,
"password": password,
"database": database,
}
async def get(self, identifier_data: IdentifierData):
try:
result = await self._execute(
"SELECT red_config.get($1)",
encode_identifier_data(identifier_data),
method=self._pool.fetchval,
)
except asyncpg.UndefinedTableError:
raise KeyError from None
if result is None:
# The result is None both when postgres yields no results and when it yields a NULL row
# A 'null' JSON value would be returned as encoded JSON, i.e. the string 'null'
raise KeyError
return json.loads(result)
async def set(self, identifier_data: IdentifierData, value=None):
try:
await self._execute(
"SELECT red_config.set($1, $2::jsonb)",
encode_identifier_data(identifier_data),
json.dumps(value),
)
except asyncpg.ErrorInAssignmentError:
raise errors.CannotSetSubfield
async def clear(self, identifier_data: IdentifierData):
try:
await self._execute(
"SELECT red_config.clear($1)", encode_identifier_data(identifier_data)
)
except asyncpg.UndefinedTableError:
pass
async def inc(
self, identifier_data: IdentifierData, value: Union[int, float], default: Union[int, float]
) -> Union[int, float]:
try:
return await self._execute(
f"SELECT red_config.inc($1, $2, $3)",
encode_identifier_data(identifier_data),
value,
default,
method=self._pool.fetchval,
)
except asyncpg.WrongObjectTypeError as exc:
raise errors.StoredTypeError(*exc.args)
async def toggle(self, identifier_data: IdentifierData, default: bool) -> bool:
try:
return await self._execute(
"SELECT red_config.inc($1, $2)",
encode_identifier_data(identifier_data),
default,
method=self._pool.fetchval,
)
except asyncpg.WrongObjectTypeError as exc:
raise errors.StoredTypeError(*exc.args)
@classmethod
async def aiter_cogs(cls) -> AsyncIterator[Tuple[str, str]]:
query = "SELECT cog_name, cog_id FROM red_config.red_cogs"
log.invisible(query)
async with cls._pool.acquire() as conn, conn.transaction():
async for row in conn.cursor(query):
yield row["cog_name"], row["cog_id"]
@classmethod
async def delete_all_data(
cls, *, interactive: bool = False, drop_db: Optional[bool] = None, **kwargs
) -> None:
"""Delete all data being stored by this driver.
Parameters
----------
interactive : bool
Set to ``True`` to allow the method to ask the user for
input from the console, regarding the other unset parameters
for this method.
drop_db : Optional[bool]
Set to ``True`` to drop the entire database for the current
bot's instance. Otherwise, schemas within the database which
store bot data will be dropped, as well as functions,
aggregates, event triggers, and meta-tables.
"""
if interactive is True and drop_db is None:
print(
"Please choose from one of the following options:\n"
" 1. Drop the entire PostgreSQL database for this instance, or\n"
" 2. Delete all of Red's data within this database, without dropping the database "
"itself."
)
options = ("1", "2")
while True:
resp = input("> ")
try:
drop_db = bool(options.index(resp))
except ValueError:
print("Please type a number corresponding to one of the options.")
else:
break
if drop_db is True:
storage_details = data_manager.storage_details()
await cls._pool.execute(f"DROP DATABASE $1", storage_details["database"])
else:
with DROP_DDL_SCRIPT_PATH.open() as fs:
await cls._pool.execute(fs.read())
@classmethod
async def _execute(cls, query: str, *args, method: Optional[Callable] = None) -> Any:
if method is None:
method = cls._pool.execute
log.invisible("Query: %s", query)
if args:
log.invisible("Args: %s", args)
return await method(query, *args)


@@ -1,233 +0,0 @@
import enum
from typing import Tuple
__all__ = ["BaseDriver", "IdentifierData"]
class ConfigCategory(enum.Enum):
GLOBAL = "GLOBAL"
GUILD = "GUILD"
CHANNEL = "TEXTCHANNEL"
ROLE = "ROLE"
USER = "USER"
MEMBER = "MEMBER"
class IdentifierData:
def __init__(
self,
uuid: str,
category: str,
primary_key: Tuple[str, ...],
identifiers: Tuple[str, ...],
custom_group_data: dict,
is_custom: bool = False,
):
self._uuid = uuid
self._category = category
self._primary_key = primary_key
self._identifiers = identifiers
self.custom_group_data = custom_group_data
self._is_custom = is_custom
@property
def uuid(self):
return self._uuid
@property
def category(self):
return self._category
@property
def primary_key(self):
return self._primary_key
@property
def identifiers(self):
return self._identifiers
@property
def is_custom(self):
return self._is_custom
def __repr__(self):
return (
f"<IdentifierData uuid={self.uuid} category={self.category} primary_key={self.primary_key}"
f" identifiers={self.identifiers}>"
)
def __eq__(self, other) -> bool:
if not isinstance(other, IdentifierData):
return False
return (
self.uuid == other.uuid
and self.category == other.category
and self.primary_key == other.primary_key
and self.identifiers == other.identifiers
)
def __hash__(self) -> int:
return hash((self.uuid, self.category, self.primary_key, self.identifiers))
def add_identifier(self, *identifier: str) -> "IdentifierData":
if not all(isinstance(i, str) for i in identifier):
raise ValueError("Identifiers must be strings.")
return IdentifierData(
self.uuid,
self.category,
self.primary_key,
self.identifiers + identifier,
self.custom_group_data,
is_custom=self.is_custom,
)
def to_tuple(self):
return tuple(
item
for item in (self.uuid, self.category, *self.primary_key, *self.identifiers)
if len(item) > 0
)
class BaseDriver:
def __init__(self, cog_name, identifier):
self.cog_name = cog_name
self.unique_cog_identifier = identifier
async def has_valid_connection(self) -> bool:
raise NotImplementedError
async def get(self, identifier_data: IdentifierData):
"""
Finds the value indicated by the given identifiers.
Parameters
----------
identifier_data
Returns
-------
Any
Stored value.
"""
raise NotImplementedError
def get_config_details(self):
"""
Asks users for additional configuration information necessary
to use this config driver.
Returns
-------
Dict of configuration details.
"""
raise NotImplementedError
async def set(self, identifier_data: IdentifierData, value=None):
"""
Sets the value of the key indicated by the given identifiers.
Parameters
----------
identifier_data
value
Any JSON serializable python object.
"""
raise NotImplementedError
async def clear(self, identifier_data: IdentifierData):
"""
Clears out the value specified by the given identifiers.
Equivalent to using ``del`` on a dict.
Parameters
----------
identifier_data
"""
raise NotImplementedError
def _get_levels(self, category, custom_group_data):
if category == ConfigCategory.GLOBAL.value:
return 0
elif category in (
ConfigCategory.USER.value,
ConfigCategory.GUILD.value,
ConfigCategory.CHANNEL.value,
ConfigCategory.ROLE.value,
):
return 1
elif category == ConfigCategory.MEMBER.value:
return 2
elif category in custom_group_data:
return custom_group_data[category]
else:
raise RuntimeError(f"Cannot convert due to group: {category}")
def _split_primary_key(self, category, custom_group_data, data):
levels = self._get_levels(category, custom_group_data)
if levels == 0:
return (((), data),)
def flatten(levels_remaining, currdata, parent_key=()):
items = []
for k, v in currdata.items():
new_key = parent_key + (k,)
if levels_remaining > 1:
items.extend(flatten(levels_remaining - 1, v, new_key).items())
else:
items.append((new_key, v))
return dict(items)
ret = []
for k, v in flatten(levels, data).items():
ret.append((k, v))
return tuple(ret)
async def export_data(self, custom_group_data):
categories = [c.value for c in ConfigCategory]
categories.extend(custom_group_data.keys())
ret = []
for c in categories:
ident_data = IdentifierData(
self.unique_cog_identifier,
c,
(),
(),
custom_group_data,
is_custom=c in custom_group_data,
)
try:
data = await self.get(ident_data)
except KeyError:
continue
ret.append((c, data))
return ret
async def import_data(self, cog_data, custom_group_data):
for category, all_data in cog_data:
splitted_pkey = self._split_primary_key(category, custom_group_data, all_data)
for pkey, data in splitted_pkey:
ident_data = IdentifierData(
self.unique_cog_identifier,
category,
pkey,
(),
custom_group_data,
is_custom=category in custom_group_data,
)
await self.set(ident_data, data)
@staticmethod
def get_pkey_len(identifier_data: IdentifierData) -> int:
cat = identifier_data.category
if cat == ConfigCategory.GLOBAL.value:
return 0
elif cat == ConfigCategory.MEMBER.value:
return 2
elif identifier_data.is_custom:
return identifier_data.custom_group_data[cat]
else:
return 1