PostgreSQL driver, tests against DB backends, and general drivers cleanup (#2723)

* PostgreSQL driver and general drivers cleanup

Signed-off-by: Toby Harradine <tobyharradine@gmail.com>

* Make tests pass

Signed-off-by: Toby Harradine <tobyharradine@gmail.com>

* Add black --target-version flag in make.bat

Signed-off-by: Toby Harradine <tobyharradine@gmail.com>

* Rewrite postgres driver

Most of the logic is now in PL/pgSQL.

This completely avoids the use of Python f-strings to format identifiers into queries. Although an SQL-injection attack would have been impossible anyway (only the owner would have ever had the ability to do that), using PostgreSQL's format() is more reliable for unusual identifiers. Performance-wise, I'm not sure whether this is an improvement, but I highly doubt that it's worse.

Signed-off-by: Toby Harradine <tobyharradine@gmail.com>

* Reformat

Signed-off-by: Toby Harradine <tobyharradine@gmail.com>

* Fix PostgresDriver.delete_all_data()

Signed-off-by: Toby Harradine <tobyharradine@gmail.com>

* Clean up PL/pgSQL code

Signed-off-by: Toby Harradine <tobyharradine@gmail.com>

* More PL/pgSQL cleanup

Signed-off-by: Toby Harradine <tobyharradine@gmail.com>

* PL/pgSQL function optimisations

Signed-off-by: Toby Harradine <tobyharradine@gmail.com>

* Ensure compatibility with PostgreSQL 10 and below

Signed-off-by: Toby Harradine <tobyharradine@gmail.com>

* More/better docstrings for PG functions

Signed-off-by: Toby Harradine <tobyharradine@gmail.com>

* Fix typo in docstring

Signed-off-by: Toby Harradine <tobyharradine@gmail.com>

* Return correct value on toggle()

Signed-off-by: Toby Harradine <tobyharradine@gmail.com>

* Use composite type for PG function parameters

Signed-off-by: Toby Harradine <tobyharradine@gmail.com>

* Fix JSON driver's Config.clear_all()

Signed-off-by: Toby Harradine <tobyharradine@gmail.com>

* Correct description for Mongo tox recipe

Signed-off-by: Toby Harradine <tobyharradine@gmail.com>

* Fix linting errors

Signed-off-by: Toby Harradine <tobyharradine@gmail.com>

* Update dep specification after merging bumpdeps

Signed-off-by: Toby Harradine <tobyharradine@gmail.com>

* Add towncrier entries

Signed-off-by: Toby Harradine <tobyharradine@gmail.com>

* Update from merge

Signed-off-by: Toby Harradine <tobyharradine@gmail.com>

* Mention [postgres] extra in install docs

Signed-off-by: Toby Harradine <tobyharradine@gmail.com>

* Support more connection options and use better defaults

Signed-off-by: Toby Harradine <tobyharradine@gmail.com>

* Actually pass PG env vars in tox

Signed-off-by: Toby Harradine <tobyharradine@gmail.com>

* Replace event trigger with manual DELETE queries

Signed-off-by: Toby Harradine <tobyharradine@gmail.com>
This commit is contained in:
Toby Harradine 2019-08-27 12:02:26 +10:00 committed by Michael H
parent 57fa29dd64
commit d1a46acc9a
34 changed files with 2282 additions and 843 deletions

View File

@ -5,14 +5,10 @@ notifications:
email: false
python:
- 3.7.2
- 3.7.3
env:
global:
- PIPENV_IGNORE_VIRTUALENVS=1
matrix:
- TOXENV=py
- TOXENV=docs
- TOXENV=style
install:
- pip install --upgrade pip tox
@ -22,6 +18,19 @@ script:
jobs:
include:
- env: TOXENV=py
- env: TOXENV=docs
- env: TOXENV=style
- env: TOXENV=postgres
services: postgresql
addons:
postgresql: "10"
before_script:
- psql -c 'create database red_db;' -U postgres
- env: TOXENV=mongo
services: mongodb
before_script:
- mongo red_db --eval 'db.createUser({user:"red",pwd:"red",roles:["readWrite"]});'
# These jobs only occur on tag creation if the prior ones succeed
- stage: PyPi Deployment
if: tag IS present

View File

@ -1,8 +1,8 @@
# Python Code Style
reformat:
black -l 99 `git ls-files "*.py"`
black -l 99 --target-version py37 `git ls-files "*.py"`
stylecheck:
black --check -l 99 `git ls-files "*.py"`
black --check -l 99 --target-version py37 `git ls-files "*.py"`
# Translations
gettext:

View File

@ -0,0 +1 @@
Added a config driver for PostgreSQL

32
changelog.d/2723.misc.rst Normal file
View File

@ -0,0 +1,32 @@
Changes to the ``redbot.core.drivers`` package:
- The modules inside the ``redbot.core.drivers`` package no longer have the ``red_`` prefix in front of their names.
- All driver classes are now directly accessible as attributes to the ``redbot.core.drivers`` package.
- :func:`get_driver`'s signature has been changed.
- :func:`get_driver` can now use data manager to infer the backend type if it is not supplied as an argument.
- :func:`get_driver_class` has been added.
Changes to the :class:`BaseDriver` and :class:`JsonDriver` classes class:
- :meth:`BaseDriver.get_config_details` is an now abstract staticmethod.
- :meth:`BaseDriver.initialize` and :meth:`BaseDriver.teardown` are two new abstract coroutine classmethods.
- :meth:`BaseDriver.delete_all_data` is a new concrete (but overrideable) coroutine instance method.
- :meth:`BaseDriver.aiter_cogs` is a new abstract asynchronous iterator method.
- :meth:`BaseDriver.migrate_to` is a new concrete coroutine classmethod.
- :class:`JsonDriver` no longer requires the data path when constructed and will infer the data path from data manager.
Changes to the :class:`IdentifierData` class and :class:`ConfigCategory` enum:
- ``IdentifierData.custom_group_data`` has been replaced by :attr:`IdentifierData.primary_key_len`.
- :meth:`ConfigCategory.get_pkey_info` is a new classmethod.
Changes to the migration and backup system:
- All code in the ``redbot.setup`` script, excluding that regarding MongoV1, is now virtually backend-agnostic.
- All code in the ``[p]backup`` is now backend-agnostic.
- :func:`redbot.core.config.migrate` is a new coroutine function.
- All a new driver needs to do now to be compatible with migrations and backups is to implement the :class:`BaseDriver` ABC.
Enhancements to unit tests:
- New tox recipes have been added for testing against Mongo and Postgres backends. See the ``tox.ini`` file for clues on how to run them.

View File

@ -421,15 +421,15 @@ Driver Reference
Base Driver
^^^^^^^^^^^
.. autoclass:: redbot.core.drivers.red_base.BaseDriver
.. autoclass:: redbot.core.drivers.BaseDriver
:members:
JSON Driver
^^^^^^^^^^^
.. autoclass:: redbot.core.drivers.red_json.JSON
.. autoclass:: redbot.core.drivers.JsonDriver
:members:
Mongo Driver
^^^^^^^^^^^^
.. autoclass:: redbot.core.drivers.red_mongo.Mongo
.. autoclass:: redbot.core.drivers.MongoDriver
:members:

View File

@ -200,6 +200,12 @@ Or, to install with MongoDB support:
python3.7 -m pip install -U Red-DiscordBot[mongo]
Or, to install with PostgreSQL support:
.. code-block:: none
python3.7 -m pip install -U Red-DiscordBot[postgres]
.. note::
To install the development version, replace ``Red-DiscordBot`` in the above commands with the

View File

@ -65,7 +65,7 @@ Installing Red
If you're not inside an activated virtual environment, include the ``--user`` flag with all
``pip`` commands.
* No MongoDB support:
* Normal installation:
.. code-block:: none
@ -77,6 +77,12 @@ Installing Red
python -m pip install -U Red-DiscordBot[mongo]
* With PostgreSQL support:
.. code-block:: none
python3.7 -m pip install -U Red-DiscordBot[postgres]
.. note::
To install the development version, replace ``Red-DiscordBot`` in the above commands with the

View File

@ -14,11 +14,11 @@ for /F "tokens=* USEBACKQ" %%A in (`git ls-files "*.py"`) do (
goto %1
:reformat
black -l 99 !PYFILES!
black -l 99 --target-version py37 !PYFILES!
exit /B %ERRORLEVEL%
:stylecheck
black -l 99 --check !PYFILES!
black -l 99 --check --target-version py37 !PYFILES!
exit /B %ERRORLEVEL%
:newenv

View File

@ -33,7 +33,7 @@ from redbot.core.events import init_events
from redbot.core.cli import interactive_config, confirm, parse_cli_flags
from redbot.core.core_commands import Core
from redbot.core.dev_commands import Dev
from redbot.core import __version__, modlog, bank, data_manager
from redbot.core import __version__, modlog, bank, data_manager, drivers
from signal import SIGTERM
@ -99,7 +99,11 @@ def main():
)
cli_flags.instance_name = "temporary_red"
data_manager.create_temp_config()
loop = asyncio.get_event_loop()
data_manager.load_basic_configuration(cli_flags.instance_name)
driver_cls = drivers.get_driver_class()
loop.run_until_complete(driver_cls.initialize(**data_manager.storage_details()))
redbot.logging.init_logging(
level=cli_flags.logging_level, location=data_manager.core_data_path() / "logs"
)
@ -111,7 +115,6 @@ def main():
red = Red(
cli_flags=cli_flags, description=description, dm_help=None, fetch_offline_members=True
)
loop = asyncio.get_event_loop()
loop.run_until_complete(red.maybe_update_config())
init_global_checks(red)
init_events(red, cli_flags)

View File

@ -251,7 +251,7 @@ class Warnings(commands.Cog):
user: discord.Member,
points: Optional[int] = 1,
*,
reason: str
reason: str,
):
"""Warn the user for the specified reason.

View File

@ -1,7 +1,7 @@
import asyncio
import inspect
import os
import logging
import os
from collections import Counter
from enum import Enum
from importlib.machinery import ModuleSpec
@ -9,10 +9,9 @@ from pathlib import Path
from typing import Optional, Union, List
import discord
import sys
from discord.ext.commands import when_mentioned_or
from . import Config, i18n, commands, errors
from . import Config, i18n, commands, errors, drivers
from .cog_manager import CogManager
from .rpc import RPCMixin
@ -592,8 +591,8 @@ class Red(RedBase, discord.AutoShardedClient):
async def logout(self):
"""Logs out of Discord and closes all connections."""
await super().logout()
await drivers.get_driver_class().teardown()
async def shutdown(self, *, restart: bool = False):
"""Gracefully quit Red.

View File

@ -5,23 +5,22 @@ import pickle
import weakref
from typing import (
Any,
Union,
Tuple,
Dict,
Awaitable,
AsyncContextManager,
TypeVar,
Awaitable,
Dict,
MutableMapping,
Optional,
Tuple,
Type,
TypeVar,
Union,
)
import discord
from .data_manager import cog_data_path, core_data_path
from .drivers import get_driver, IdentifierData, BackendType
from .drivers.red_base import BaseDriver
from .drivers import IdentifierData, get_driver, ConfigCategory, BaseDriver
__all__ = ["Config", "get_latest_confs"]
__all__ = ["Config", "get_latest_confs", "migrate"]
log = logging.getLogger("red.config")
@ -101,7 +100,7 @@ class Value:
Information on identifiers for this value.
default
The default value for the data element that `identifiers` points at.
driver : `redbot.core.drivers.red_base.BaseDriver`
driver : `redbot.core.drivers.BaseDriver`
A reference to `Config.driver`.
"""
@ -250,7 +249,7 @@ class Group(Value):
All registered default values for this Group.
force_registration : `bool`
Same as `Config.force_registration`.
driver : `redbot.core.drivers.red_base.BaseDriver`
driver : `redbot.core.drivers.BaseDriver`
A reference to `Config.driver`.
"""
@ -586,7 +585,7 @@ class Config:
Unique identifier provided to differentiate cog data when name
conflicts occur.
driver
An instance of a driver that implements `redbot.core.drivers.red_base.BaseDriver`.
An instance of a driver that implements `redbot.core.drivers.BaseDriver`.
force_registration : `bool`
Determines if Config should throw an error if a cog attempts to access
an attribute which has not been previously registered.
@ -634,7 +633,7 @@ class Config:
self.force_registration = force_registration
self._defaults = defaults or {}
self.custom_groups = {}
self.custom_groups: Dict[str, int] = {}
self._lock_cache: MutableMapping[
IdentifierData, asyncio.Lock
] = weakref.WeakValueDictionary()
@ -643,10 +642,6 @@ class Config:
def defaults(self):
return pickle.loads(pickle.dumps(self._defaults, -1))
@staticmethod
def _create_uuid(identifier: int):
return str(identifier)
@classmethod
def get_conf(cls, cog_instance, identifier: int, force_registration=False, cog_name=None):
"""Get a Config instance for your cog.
@ -681,25 +676,12 @@ class Config:
A new Config object.
"""
if cog_instance is None and cog_name is not None:
cog_path_override = cog_data_path(raw_name=cog_name)
else:
cog_path_override = cog_data_path(cog_instance=cog_instance)
uuid = str(identifier)
if cog_name is None:
cog_name = type(cog_instance).__name__
cog_name = cog_path_override.stem
# uuid = str(hash(identifier))
uuid = cls._create_uuid(identifier)
# We have to import this here otherwise we have a circular dependency
from .data_manager import basic_config
driver_name = basic_config.get("STORAGE_TYPE", "JSON")
driver_details = basic_config.get("STORAGE_DETAILS", {})
driver = get_driver(
driver_name, cog_name, uuid, data_path_override=cog_path_override, **driver_details
)
if driver_name == BackendType.JSON.value:
driver = get_driver(cog_name, uuid)
if hasattr(driver, "migrate_identifier"):
driver.migrate_identifier(identifier)
conf = cls(
@ -712,7 +694,7 @@ class Config:
@classmethod
def get_core_conf(cls, force_registration: bool = False):
"""Get a Config instance for a core module.
"""Get a Config instance for the core bot.
All core modules that require a config instance should use this
classmethod instead of `get_conf`.
@ -723,24 +705,9 @@ class Config:
See `force_registration`.
"""
core_path = core_data_path()
# We have to import this here otherwise we have a circular dependency
from .data_manager import basic_config
driver_name = basic_config.get("STORAGE_TYPE", "JSON")
driver_details = basic_config.get("STORAGE_DETAILS", {})
driver = get_driver(
driver_name, "Core", "0", data_path_override=core_path, **driver_details
return cls.get_conf(
None, cog_name="Core", identifier=0, force_registration=force_registration
)
conf = cls(
cog_name="Core",
driver=driver,
unique_identifier="0",
force_registration=force_registration,
)
return conf
def __getattr__(self, item: str) -> Union[Group, Value]:
"""Same as `group.__getattr__` except for global data.
@ -916,26 +883,18 @@ class Config:
self.custom_groups[group_identifier] = identifier_count
def _get_base_group(self, category: str, *primary_keys: str) -> Group:
is_custom = category not in (
self.GLOBAL,
self.GUILD,
self.USER,
self.MEMBER,
self.ROLE,
self.CHANNEL,
)
# noinspection PyTypeChecker
pkey_len, is_custom = ConfigCategory.get_pkey_info(category, self.custom_groups)
identifier_data = IdentifierData(
cog_name=self.cog_name,
uuid=self.unique_identifier,
category=category,
primary_key=primary_keys,
identifiers=(),
custom_group_data=self.custom_groups,
primary_key_len=pkey_len,
is_custom=is_custom,
)
pkey_len = BaseDriver.get_pkey_len(identifier_data)
if len(primary_keys) < pkey_len:
if len(primary_keys) < identifier_data.primary_key_len:
# Don't mix in defaults with groups higher than the document level
defaults = {}
else:
@ -1220,9 +1179,7 @@ class Config:
"""
if not scopes:
# noinspection PyTypeChecker
identifier_data = IdentifierData(
self.unique_identifier, "", (), (), self.custom_groups
)
identifier_data = IdentifierData(self.cog_name, self.unique_identifier, "", (), (), 0)
group = Group(identifier_data, defaults={}, driver=self.driver, config=self)
else:
cat, *scopes = scopes
@ -1359,7 +1316,12 @@ class Config:
return self.get_custom_lock(self.GUILD)
else:
id_data = IdentifierData(
self.unique_identifier, self.MEMBER, (str(guild.id),), (), self.custom_groups
self.cog_name,
self.unique_identifier,
category=self.MEMBER,
primary_key=(str(guild.id),),
identifiers=(),
primary_key_len=2,
)
return self._lock_cache.setdefault(id_data, asyncio.Lock())
@ -1375,10 +1337,33 @@ class Config:
-------
asyncio.Lock
"""
id_data = IdentifierData(
self.unique_identifier, group_identifier, (), (), self.custom_groups
)
return self._lock_cache.setdefault(id_data, asyncio.Lock())
try:
pkey_len, is_custom = ConfigCategory.get_pkey_info(
group_identifier, self.custom_groups
)
except KeyError:
raise ValueError(f"Custom group not initialized: {group_identifier}") from None
else:
id_data = IdentifierData(
self.cog_name,
self.unique_identifier,
category=group_identifier,
primary_key=(),
identifiers=(),
primary_key_len=pkey_len,
is_custom=is_custom,
)
return self._lock_cache.setdefault(id_data, asyncio.Lock())
async def migrate(cur_driver_cls: Type[BaseDriver], new_driver_cls: Type[BaseDriver]) -> None:
"""Migrate from one driver type to another."""
# Get custom group data
core_conf = Config.get_core_conf()
core_conf.init_custom("CUSTOM_GROUPS", 2)
all_custom_group_data = await core_conf.custom("CUSTOM_GROUPS").all()
await cur_driver_cls.migrate_to(new_driver_cls, all_custom_group_data)
def _str_key_dict(value: Dict[Any, _T]) -> Dict[str, _T]:

View File

@ -3,15 +3,12 @@ import contextlib
import datetime
import importlib
import itertools
import json
import logging
import os
import pathlib
import sys
import platform
import getpass
import pip
import tarfile
import traceback
from collections import namedtuple
from pathlib import Path
@ -23,15 +20,18 @@ import aiohttp
import discord
import pkg_resources
from redbot.core import (
from . import (
__version__,
version_info as red_version_info,
VersionInfo,
checks,
commands,
drivers,
errors,
i18n,
config,
)
from .utils import create_backup
from .utils.predicates import MessagePredicate
from .utils.chat_formatting import humanize_timedelta, pagify, box, inline, humanize_list
from .commands.requires import PrivilegeLevel
@ -1307,105 +1307,71 @@ class Core(commands.Cog, CoreLogic):
@commands.command()
@checks.is_owner()
async def backup(self, ctx: commands.Context, *, backup_path: str = None):
"""Creates a backup of all data for the instance."""
if backup_path:
path = pathlib.Path(backup_path)
if not (path.exists() and path.is_dir()):
return await ctx.send(
_("That path doesn't seem to exist. Please provide a valid path.")
)
from redbot.core.data_manager import basic_config, instance_name
from redbot.core.drivers.red_json import JSON
async def backup(self, ctx: commands.Context, *, backup_dir: str = None):
"""Creates a backup of all data for the instance.
data_dir = Path(basic_config["DATA_PATH"])
if basic_config["STORAGE_TYPE"] == "MongoDB":
from redbot.core.drivers.red_mongo import Mongo
m = Mongo("Core", "0", **basic_config["STORAGE_DETAILS"])
db = m.db
collection_names = await db.list_collection_names()
for c_name in collection_names:
if c_name == "Core":
c_data_path = data_dir / basic_config["CORE_PATH_APPEND"]
else:
c_data_path = data_dir / basic_config["COG_PATH_APPEND"] / c_name
docs = await db[c_name].find().to_list(None)
for item in docs:
item_id = str(item.pop("_id"))
target = JSON(c_name, item_id, data_path_override=c_data_path)
target.data = item
await target._save()
backup_filename = "redv3-{}-{}.tar.gz".format(
instance_name, ctx.message.created_at.strftime("%Y-%m-%d %H-%M-%S")
)
if data_dir.exists():
if not backup_path:
backup_pth = data_dir.home()
else:
backup_pth = Path(backup_path)
backup_file = backup_pth / backup_filename
to_backup = []
exclusions = [
"__pycache__",
"Lavalink.jar",
os.path.join("Downloader", "lib"),
os.path.join("CogManager", "cogs"),
os.path.join("RepoManager", "repos"),
]
downloader_cog = ctx.bot.get_cog("Downloader")
if downloader_cog and hasattr(downloader_cog, "_repo_manager"):
repo_output = []
repo_mgr = downloader_cog._repo_manager
for repo in repo_mgr._repos.values():
repo_output.append({"url": repo.url, "name": repo.name, "branch": repo.branch})
repo_filename = data_dir / "cogs" / "RepoManager" / "repos.json"
with open(str(repo_filename), "w") as f:
f.write(json.dumps(repo_output, indent=4))
instance_data = {instance_name: basic_config}
instance_file = data_dir / "instance.json"
with open(str(instance_file), "w") as instance_out:
instance_out.write(json.dumps(instance_data, indent=4))
for f in data_dir.glob("**/*"):
if not any(ex in str(f) for ex in exclusions):
to_backup.append(f)
with tarfile.open(str(backup_file), "w:gz") as tar:
for f in to_backup:
tar.add(str(f), recursive=False)
print(str(backup_file))
await ctx.send(
_("A backup has been made of this instance. It is at {}.").format(backup_file)
)
if backup_file.stat().st_size > 8_000_000:
await ctx.send(_("This backup is too large to send via DM."))
return
await ctx.send(_("Would you like to receive a copy via DM? (y/n)"))
pred = MessagePredicate.yes_or_no(ctx)
try:
await ctx.bot.wait_for("message", check=pred, timeout=60)
except asyncio.TimeoutError:
await ctx.send(_("Response timed out."))
else:
if pred.result is True:
await ctx.send(_("OK, it's on its way!"))
try:
async with ctx.author.typing():
await ctx.author.send(
_("Here's a copy of the backup"),
file=discord.File(str(backup_file)),
)
except discord.Forbidden:
await ctx.send(
_("I don't seem to be able to DM you. Do you have closed DMs?")
)
except discord.HTTPException:
await ctx.send(_("I could not send the backup file."))
else:
await ctx.send(_("OK then."))
You may provide a path to a directory for the backup archive to
be placed in. If the directory does not exist, the bot will
attempt to create it.
"""
if backup_dir is None:
dest = Path.home()
else:
await ctx.send(_("That directory doesn't seem to exist..."))
dest = Path(backup_dir)
driver_cls = drivers.get_driver_class()
if driver_cls != drivers.JsonDriver:
await ctx.send(_("Converting data to JSON for backup..."))
async with ctx.typing():
await config.migrate(driver_cls, drivers.JsonDriver)
log.info("Creating backup for this instance...")
try:
backup_fpath = await create_backup(dest)
except OSError as exc:
await ctx.send(
_(
"Creating the backup archive failed! Please check your console or logs for "
"details."
)
)
log.exception("Failed to create backup archive", exc_info=exc)
return
if backup_fpath is None:
await ctx.send(_("Your datapath appears to be empty."))
return
log.info("Backup archive created successfully at '%s'", backup_fpath)
await ctx.send(
_("A backup has been made of this instance. It is located at `{path}`.").format(
path=backup_fpath
)
)
if backup_fpath.stat().st_size > 8_000_000:
await ctx.send(_("This backup is too large to send via DM."))
return
await ctx.send(_("Would you like to receive a copy via DM? (y/n)"))
pred = MessagePredicate.yes_or_no(ctx)
try:
await ctx.bot.wait_for("message", check=pred, timeout=60)
except asyncio.TimeoutError:
await ctx.send(_("Response timed out."))
else:
if pred.result is True:
await ctx.send(_("OK, it's on its way!"))
try:
async with ctx.author.typing():
await ctx.author.send(
_("Here's a copy of the backup"), file=discord.File(str(backup_fpath))
)
except discord.Forbidden:
await ctx.send(_("I don't seem to be able to DM you. Do you have closed DMs?"))
except discord.HTTPException:
await ctx.send(_("I could not send the backup file."))
else:
await ctx.send(_("OK then."))
@commands.command()
@commands.cooldown(1, 60, commands.BucketType.user)

View File

@ -224,7 +224,4 @@ def storage_details() -> dict:
-------
dict
"""
try:
return basic_config["STORAGE_DETAILS"]
except KeyError as e:
raise RuntimeError("Bot basic config has not been loaded yet.") from e
return basic_config.get("STORAGE_DETAILS", {})

View File

@ -1,50 +1,110 @@
import enum
from typing import Optional, Type
from .red_base import IdentifierData
from .. import data_manager
from .base import IdentifierData, BaseDriver, ConfigCategory
from .json import JsonDriver
from .mongo import MongoDriver
from .postgres import PostgresDriver
__all__ = ["get_driver", "IdentifierData", "BackendType"]
__all__ = [
"get_driver",
"ConfigCategory",
"IdentifierData",
"BaseDriver",
"JsonDriver",
"MongoDriver",
"PostgresDriver",
"BackendType",
]
class BackendType(enum.Enum):
JSON = "JSON"
MONGO = "MongoDBV2"
MONGOV1 = "MongoDB"
POSTGRES = "Postgres"
def get_driver(type, *args, **kwargs):
_DRIVER_CLASSES = {
BackendType.JSON: JsonDriver,
BackendType.MONGO: MongoDriver,
BackendType.POSTGRES: PostgresDriver,
}
def get_driver_class(storage_type: Optional[BackendType] = None) -> Type[BaseDriver]:
"""Get the driver class for the given storage type.
Parameters
----------
storage_type : Optional[BackendType]
The backend you want a driver class for. Omit to try to obtain
the backend from data manager.
Returns
-------
Type[BaseDriver]
A subclass of `BaseDriver`.
Raises
------
ValueError
If there is no driver for the given storage type.
"""
Selectively import/load driver classes based on the selected type. This
is required so that dependencies can differ between installs (e.g. so that
you don't need to install a mongo dependency if you will just be running a
json data backend).
if storage_type is None:
storage_type = BackendType(data_manager.storage_type())
try:
return _DRIVER_CLASSES[storage_type]
except KeyError:
raise ValueError(f"No driver found for storage type {storage_type}") from None
.. note::
See the respective classes for information on what ``args`` and ``kwargs``
should be.
def get_driver(
cog_name: str, identifier: str, storage_type: Optional[BackendType] = None, **kwargs
):
"""Get a driver instance.
Parameters
----------
cog_name : str
The cog's name.
identifier : str
The cog's discriminator.
storage_type : Optional[BackendType]
The backend you want a driver for. Omit to try to obtain the
backend from data manager.
**kwargs
Driver-specific keyword arguments.
Returns
-------
BaseDriver
A driver instance.
Raises
------
RuntimeError
If the storage type is MongoV1 or invalid.
:param str type:
One of: json, mongo
:param args:
Dependent on driver type.
:param kwargs:
Dependent on driver type.
:return:
Subclass of :py:class:`.red_base.BaseDriver`.
"""
if type == "JSON":
from .red_json import JSON
if storage_type is None:
try:
storage_type = BackendType(data_manager.storage_type())
except RuntimeError:
storage_type = BackendType.JSON
return JSON(*args, **kwargs)
elif type == "MongoDBV2":
from .red_mongo import Mongo
return Mongo(*args, **kwargs)
elif type == "MongoDB":
raise RuntimeError(
"Please convert to JSON first to continue using the bot."
" This is a required conversion prior to using the new Mongo driver."
" This message will be updated with a link to the update docs once those"
" docs have been created."
)
raise RuntimeError("Invalid driver type: '{}'".format(type))
try:
driver_cls: Type[BaseDriver] = get_driver_class(storage_type)
except ValueError:
if storage_type == BackendType.MONGOV1:
raise RuntimeError(
"Please convert to JSON first to continue using the bot."
" This is a required conversion prior to using the new Mongo driver."
" This message will be updated with a link to the update docs once those"
" docs have been created."
) from None
else:
raise RuntimeError(f"Invalid driver type: '{storage_type}'") from None
return driver_cls(cog_name, identifier, **kwargs)

342
redbot/core/drivers/base.py Normal file
View File

@ -0,0 +1,342 @@
import abc
import enum
from typing import Tuple, Dict, Any, Union, List, AsyncIterator, Type
__all__ = ["BaseDriver", "IdentifierData", "ConfigCategory"]
class ConfigCategory(str, enum.Enum):
GLOBAL = "GLOBAL"
GUILD = "GUILD"
CHANNEL = "TEXTCHANNEL"
ROLE = "ROLE"
USER = "USER"
MEMBER = "MEMBER"
@classmethod
def get_pkey_info(
cls, category: Union[str, "ConfigCategory"], custom_group_data: Dict[str, int]
) -> Tuple[int, bool]:
"""Get the full primary key length for the given category,
and whether or not the category is a custom category.
"""
try:
# noinspection PyArgumentList
category_obj = cls(category)
except ValueError:
return custom_group_data[category], True
else:
return _CATEGORY_PKEY_COUNTS[category_obj], False
_CATEGORY_PKEY_COUNTS = {
ConfigCategory.GLOBAL: 0,
ConfigCategory.GUILD: 1,
ConfigCategory.CHANNEL: 1,
ConfigCategory.ROLE: 1,
ConfigCategory.USER: 1,
ConfigCategory.MEMBER: 2,
}
class IdentifierData:
def __init__(
self,
cog_name: str,
uuid: str,
category: str,
primary_key: Tuple[str, ...],
identifiers: Tuple[str, ...],
primary_key_len: int,
is_custom: bool = False,
):
self._cog_name = cog_name
self._uuid = uuid
self._category = category
self._primary_key = primary_key
self._identifiers = identifiers
self.primary_key_len = primary_key_len
self._is_custom = is_custom
@property
def cog_name(self) -> str:
return self._cog_name
@property
def uuid(self) -> str:
return self._uuid
@property
def category(self) -> str:
return self._category
@property
def primary_key(self) -> Tuple[str, ...]:
return self._primary_key
@property
def identifiers(self) -> Tuple[str, ...]:
return self._identifiers
@property
def is_custom(self) -> bool:
return self._is_custom
def __repr__(self) -> str:
return (
f"<IdentifierData cog_name={self.cog_name} uuid={self.uuid} category={self.category} "
f"primary_key={self.primary_key} identifiers={self.identifiers}>"
)
def __eq__(self, other) -> bool:
if not isinstance(other, IdentifierData):
return False
return (
self.uuid == other.uuid
and self.category == other.category
and self.primary_key == other.primary_key
and self.identifiers == other.identifiers
)
def __hash__(self) -> int:
return hash((self.uuid, self.category, self.primary_key, self.identifiers))
def add_identifier(self, *identifier: str) -> "IdentifierData":
if not all(isinstance(i, str) for i in identifier):
raise ValueError("Identifiers must be strings.")
return IdentifierData(
self.cog_name,
self.uuid,
self.category,
self.primary_key,
self.identifiers + identifier,
self.primary_key_len,
is_custom=self.is_custom,
)
def to_tuple(self) -> Tuple[str, ...]:
return tuple(
filter(
None,
(self.cog_name, self.uuid, self.category, *self.primary_key, *self.identifiers),
)
)
class BaseDriver(abc.ABC):
def __init__(self, cog_name: str, identifier: str, **kwargs):
self.cog_name = cog_name
self.unique_cog_identifier = identifier
@classmethod
@abc.abstractmethod
async def initialize(cls, **storage_details) -> None:
"""
Initialize this driver.
Parameters
----------
**storage_details
The storage details required to initialize this driver.
Should be the same as :func:`data_manager.storage_details`
Raises
------
MissingExtraRequirements
If initializing the driver requires an extra which isn't
installed.
"""
raise NotImplementedError
@classmethod
@abc.abstractmethod
async def teardown(cls) -> None:
"""
Tear down this driver.
"""
raise NotImplementedError
@staticmethod
@abc.abstractmethod
def get_config_details() -> Dict[str, Any]:
"""
Asks users for additional configuration information necessary
to use this config driver.
Returns
-------
Dict[str, Any]
Dictionary of configuration details.
"""
raise NotImplementedError
@abc.abstractmethod
async def get(self, identifier_data: IdentifierData) -> Any:
"""
Finds the value indicate by the given identifiers.
Parameters
----------
identifier_data
Returns
-------
Any
Stored value.
"""
raise NotImplementedError
@abc.abstractmethod
async def set(self, identifier_data: IdentifierData, value=None) -> None:
"""
Sets the value of the key indicated by the given identifiers.
Parameters
----------
identifier_data
value
Any JSON serializable python object.
"""
raise NotImplementedError
@abc.abstractmethod
async def clear(self, identifier_data: IdentifierData) -> None:
"""
Clears out the value specified by the given identifiers.
Equivalent to using ``del`` on a dict.
Parameters
----------
identifier_data
"""
raise NotImplementedError
@classmethod
@abc.abstractmethod
def aiter_cogs(cls) -> AsyncIterator[Tuple[str, str]]:
"""Get info for cogs which have data stored on this backend.
Yields
------
Tuple[str, str]
Asynchronously yields (cog_name, cog_identifier) tuples.
"""
raise NotImplementedError
@classmethod
async def migrate_to(
cls,
new_driver_cls: Type["BaseDriver"],
all_custom_group_data: Dict[str, Dict[str, Dict[str, int]]],
) -> None:
"""Migrate data from this backend to another.
Both drivers must be initialized beforehand.
This will only move the data - no instance metadata is modified
as a result of this operation.
Parameters
----------
new_driver_cls
Subclass of `BaseDriver`.
all_custom_group_data : Dict[str, Dict[str, Dict[str, int]]]
Dict mapping cog names, to cog IDs, to custom groups, to
primary key lengths.
"""
# Backend-agnostic method of migrating from one driver to another.
async for cog_name, cog_id in cls.aiter_cogs():
this_driver = cls(cog_name, cog_id)
other_driver = new_driver_cls(cog_name, cog_id)
custom_group_data = all_custom_group_data.get(cog_name, {}).get(cog_id, {})
exported_data = await this_driver.export_data(custom_group_data)
await other_driver.import_data(exported_data, custom_group_data)
@classmethod
async def delete_all_data(cls, **kwargs) -> None:
"""Delete all data being stored by this driver.
The driver must be initialized before this operation.
The BaseDriver provides a generic method which may be overriden
by subclasses.
Parameters
----------
**kwargs
Driver-specific kwargs to change the way this method
operates.
"""
async for cog_name, cog_id in cls.aiter_cogs():
driver = cls(cog_name, cog_id)
await driver.clear(IdentifierData(cog_name, cog_id, "", (), (), 0))
@staticmethod
def _split_primary_key(
category: Union[ConfigCategory, str],
custom_group_data: Dict[str, int],
data: Dict[str, Any],
) -> List[Tuple[Tuple[str, ...], Dict[str, Any]]]:
pkey_len = ConfigCategory.get_pkey_info(category, custom_group_data)[0]
if pkey_len == 0:
return [((), data)]
def flatten(levels_remaining, currdata, parent_key=()):
items = []
for _k, _v in currdata.items():
new_key = parent_key + (_k,)
if levels_remaining > 1:
items.extend(flatten(levels_remaining - 1, _v, new_key).items())
else:
items.append((new_key, _v))
return dict(items)
ret = []
for k, v in flatten(pkey_len, data).items():
ret.append((k, v))
return ret
async def export_data(
self, custom_group_data: Dict[str, int]
) -> List[Tuple[str, Dict[str, Any]]]:
categories = [c.value for c in ConfigCategory]
categories.extend(custom_group_data.keys())
ret = []
for c in categories:
ident_data = IdentifierData(
self.cog_name,
self.unique_cog_identifier,
c,
(),
(),
*ConfigCategory.get_pkey_info(c, custom_group_data),
)
try:
data = await self.get(ident_data)
except KeyError:
continue
ret.append((c, data))
return ret
async def import_data(
self, cog_data: List[Tuple[str, Dict[str, Any]]], custom_group_data: Dict[str, int]
) -> None:
for category, all_data in cog_data:
splitted_pkey = self._split_primary_key(category, custom_group_data, all_data)
for pkey, data in splitted_pkey:
ident_data = IdentifierData(
self.cog_name,
self.unique_cog_identifier,
category,
pkey,
(),
*ConfigCategory.get_pkey_info(category, custom_group_data),
)
await self.set(ident_data, data)

View File

@ -5,12 +5,13 @@ import os
import pickle
import weakref
from pathlib import Path
from typing import Any, Dict
from typing import Any, AsyncIterator, Dict, Optional, Tuple
from uuid import uuid4
from .red_base import BaseDriver, IdentifierData
from .. import data_manager, errors
from .base import BaseDriver, IdentifierData, ConfigCategory
__all__ = ["JSON"]
__all__ = ["JsonDriver"]
_shared_datastore = {}
@ -35,9 +36,10 @@ def finalize_driver(cog_name):
_finalizers.remove(f)
class JSON(BaseDriver):
# noinspection PyProtectedMember
class JsonDriver(BaseDriver):
"""
Subclass of :py:class:`.red_base.BaseDriver`.
Subclass of :py:class:`.BaseDriver`.
.. py:attribute:: file_name
@ -50,27 +52,26 @@ class JSON(BaseDriver):
def __init__(
self,
cog_name,
identifier,
cog_name: str,
identifier: str,
*,
data_path_override: Path = None,
file_name_override: str = "settings.json"
data_path_override: Optional[Path] = None,
file_name_override: str = "settings.json",
):
super().__init__(cog_name, identifier)
self.file_name = file_name_override
if data_path_override:
if data_path_override is not None:
self.data_path = data_path_override
elif cog_name == "Core" and identifier == "0":
self.data_path = data_manager.core_data_path()
else:
self.data_path = Path.cwd() / "cogs" / ".data" / self.cog_name
self.data_path = data_manager.cog_data_path(raw_name=cog_name)
self.data_path.mkdir(parents=True, exist_ok=True)
self.data_path = self.data_path / self.file_name
self._lock = asyncio.Lock()
self._load_data()
async def has_valid_connection(self) -> bool:
return True
@property
def data(self):
return _shared_datastore.get(self.cog_name)
@ -79,6 +80,21 @@ class JSON(BaseDriver):
def data(self, value):
_shared_datastore[self.cog_name] = value
@classmethod
async def initialize(cls, **storage_details) -> None:
# No initializing to do
return
@classmethod
async def teardown(cls) -> None:
# No tearing down to do
return
@staticmethod
def get_config_details() -> Dict[str, Any]:
# No driver-specific configuration needed
return {}
def _load_data(self):
if self.cog_name not in _driver_counts:
_driver_counts[self.cog_name] = 0
@ -111,30 +127,32 @@ class JSON(BaseDriver):
async def get(self, identifier_data: IdentifierData):
partial = self.data
full_identifiers = identifier_data.to_tuple()
full_identifiers = identifier_data.to_tuple()[1:]
for i in full_identifiers:
partial = partial[i]
return pickle.loads(pickle.dumps(partial, -1))
async def set(self, identifier_data: IdentifierData, value=None):
partial = self.data
full_identifiers = identifier_data.to_tuple()
full_identifiers = identifier_data.to_tuple()[1:]
# This is both our deepcopy() and our way of making sure this value is actually JSON
# serializable.
value_copy = json.loads(json.dumps(value))
async with self._lock:
for i in full_identifiers[:-1]:
if i not in partial:
partial[i] = {}
partial = partial[i]
partial[full_identifiers[-1]] = value_copy
try:
partial = partial.setdefault(i, {})
except AttributeError:
# Tried to set sub-field of non-object
raise errors.CannotSetSubfield
partial[full_identifiers[-1]] = value_copy
await self._save()
async def clear(self, identifier_data: IdentifierData):
partial = self.data
full_identifiers = identifier_data.to_tuple()
full_identifiers = identifier_data.to_tuple()[1:]
try:
for i in full_identifiers[:-1]:
partial = partial[i]
@ -149,14 +167,32 @@ class JSON(BaseDriver):
else:
await self._save()
@classmethod
async def aiter_cogs(cls) -> AsyncIterator[Tuple[str, str]]:
yield "Core", "0"
for _dir in data_manager.cog_data_path().iterdir():
fpath = _dir / "settings.json"
if not fpath.exists():
continue
with fpath.open() as f:
try:
data = json.load(f)
except json.JSONDecodeError:
continue
if not isinstance(data, dict):
continue
for cog, inner in data.items():
if not isinstance(inner, dict):
continue
for cog_id in inner:
yield cog, cog_id
async def import_data(self, cog_data, custom_group_data):
def update_write_data(identifier_data: IdentifierData, _data):
partial = self.data
idents = identifier_data.to_tuple()
idents = identifier_data.to_tuple()[1:]
for ident in idents[:-1]:
if ident not in partial:
partial[ident] = {}
partial = partial[ident]
partial = partial.setdefault(ident, {})
partial[idents[-1]] = _data
async with self._lock:
@ -164,12 +200,12 @@ class JSON(BaseDriver):
splitted_pkey = self._split_primary_key(category, custom_group_data, all_data)
for pkey, data in splitted_pkey:
ident_data = IdentifierData(
self.cog_name,
self.unique_cog_identifier,
category,
pkey,
(),
custom_group_data,
is_custom=category in custom_group_data,
*ConfigCategory.get_pkey_info(category, custom_group_data),
)
update_write_data(ident_data, data)
await self._save()
@ -178,9 +214,6 @@ class JSON(BaseDriver):
loop = asyncio.get_running_loop()
await loop.run_in_executor(None, _save_json, self.data_path, self.data)
def get_config_details(self):
return
def _save_json(path: Path, data: Dict[str, Any]) -> None:
"""

View File

@ -0,0 +1,11 @@
import functools
import logging
import os
if os.getenv("RED_INSPECT_DRIVER_QUERIES"):
LOGGING_INVISIBLE = logging.DEBUG
else:
LOGGING_INVISIBLE = 0
log = logging.getLogger("red.driver")
log.invisible = functools.partial(log.log, LOGGING_INVISIBLE)

View File

@ -2,77 +2,110 @@ import contextlib
import itertools
import re
from getpass import getpass
from typing import Match, Pattern, Tuple, Any, Dict, Iterator, List
from typing import Match, Pattern, Tuple, Optional, AsyncIterator, Any, Dict, Iterator, List
from urllib.parse import quote_plus
import motor.core
import motor.motor_asyncio
import pymongo.errors
try:
# pylint: disable=import-error
import pymongo.errors
import motor.core
import motor.motor_asyncio
except ModuleNotFoundError:
motor = None
pymongo = None
from .red_base import BaseDriver, IdentifierData
from .. import errors
from .base import BaseDriver, IdentifierData
__all__ = ["Mongo"]
__all__ = ["MongoDriver"]
_conn = None
def _initialize(**kwargs):
uri = kwargs.get("URI", "mongodb")
host = kwargs["HOST"]
port = kwargs["PORT"]
admin_user = kwargs["USERNAME"]
admin_pass = kwargs["PASSWORD"]
db_name = kwargs.get("DB_NAME", "default_db")
if port is 0:
ports = ""
else:
ports = ":{}".format(port)
if admin_user is not None and admin_pass is not None:
url = "{}://{}:{}@{}{}/{}".format(
uri, quote_plus(admin_user), quote_plus(admin_pass), host, ports, db_name
)
else:
url = "{}://{}{}/{}".format(uri, host, ports, db_name)
global _conn
_conn = motor.motor_asyncio.AsyncIOMotorClient(url, retryWrites=True)
class Mongo(BaseDriver):
class MongoDriver(BaseDriver):
"""
Subclass of :py:class:`.red_base.BaseDriver`.
Subclass of :py:class:`.BaseDriver`.
"""
def __init__(self, cog_name, identifier, **kwargs):
super().__init__(cog_name, identifier)
_conn: Optional["motor.motor_asyncio.AsyncIOMotorClient"] = None
if _conn is None:
_initialize(**kwargs)
@classmethod
async def initialize(cls, **storage_details) -> None:
if motor is None:
raise errors.MissingExtraRequirements(
"Red must be installed with the [mongo] extra to use the MongoDB driver"
)
uri = storage_details.get("URI", "mongodb")
host = storage_details["HOST"]
port = storage_details["PORT"]
user = storage_details["USERNAME"]
password = storage_details["PASSWORD"]
database = storage_details.get("DB_NAME", "default_db")
async def has_valid_connection(self) -> bool:
# Maybe fix this?
return True
if port is 0:
ports = ""
else:
ports = ":{}".format(port)
if user is not None and password is not None:
url = "{}://{}:{}@{}{}/{}".format(
uri, quote_plus(user), quote_plus(password), host, ports, database
)
else:
url = "{}://{}{}/{}".format(uri, host, ports, database)
cls._conn = motor.motor_asyncio.AsyncIOMotorClient(url, retryWrites=True)
@classmethod
async def teardown(cls) -> None:
if cls._conn is not None:
cls._conn.close()
@staticmethod
def get_config_details():
while True:
uri = input("Enter URI scheme (mongodb or mongodb+srv): ")
if uri is "":
uri = "mongodb"
if uri in ["mongodb", "mongodb+srv"]:
break
else:
print("Invalid URI scheme")
host = input("Enter host address: ")
if uri is "mongodb":
port = int(input("Enter host port: "))
else:
port = 0
admin_uname = input("Enter login username: ")
admin_password = getpass("Enter login password: ")
db_name = input("Enter mongodb database name: ")
if admin_uname == "":
admin_uname = admin_password = None
ret = {
"HOST": host,
"PORT": port,
"USERNAME": admin_uname,
"PASSWORD": admin_password,
"DB_NAME": db_name,
"URI": uri,
}
return ret
@property
def db(self) -> motor.core.Database:
def db(self) -> "motor.core.Database":
"""
Gets the mongo database for this cog's name.
.. warning::
Right now this will cause a new connection to be made every time the
database is accessed. We will want to create a connection pool down the
line to limit the number of connections.
:return:
PyMongo Database object.
"""
return _conn.get_database()
return self._conn.get_database()
def get_collection(self, category: str) -> motor.core.Collection:
def get_collection(self, category: str) -> "motor.core.Collection":
"""
Gets a specified collection within the PyMongo database for this cog.
@ -85,12 +118,13 @@ class Mongo(BaseDriver):
"""
return self.db[self.cog_name][category]
def get_primary_key(self, identifier_data: IdentifierData) -> Tuple[str]:
@staticmethod
def get_primary_key(identifier_data: IdentifierData) -> Tuple[str, ...]:
# noinspection PyTypeChecker
return identifier_data.primary_key
async def rebuild_dataset(
self, identifier_data: IdentifierData, cursor: motor.motor_asyncio.AsyncIOMotorCursor
self, identifier_data: IdentifierData, cursor: "motor.motor_asyncio.AsyncIOMotorCursor"
):
ret = {}
async for doc in cursor:
@ -141,16 +175,16 @@ class Mongo(BaseDriver):
async def set(self, identifier_data: IdentifierData, value=None):
uuid = self._escape_key(identifier_data.uuid)
primary_key = list(map(self._escape_key, self.get_primary_key(identifier_data)))
dot_identifiers = ".".join(map(self._escape_key, identifier_data.identifiers))
if isinstance(value, dict):
if len(value) == 0:
await self.clear(identifier_data)
return
value = self._escape_dict_keys(value)
mongo_collection = self.get_collection(identifier_data.category)
pkey_len = self.get_pkey_len(identifier_data)
num_pkeys = len(primary_key)
if num_pkeys >= pkey_len:
if num_pkeys >= identifier_data.primary_key_len:
# We're setting at the document level or below.
dot_identifiers = ".".join(map(self._escape_key, identifier_data.identifiers))
if dot_identifiers:
@ -158,11 +192,23 @@ class Mongo(BaseDriver):
else:
update_stmt = {"$set": value}
await mongo_collection.update_one(
{"_id": {"RED_uuid": uuid, "RED_primary_key": primary_key}},
update=update_stmt,
upsert=True,
)
try:
await mongo_collection.update_one(
{"_id": {"RED_uuid": uuid, "RED_primary_key": primary_key}},
update=update_stmt,
upsert=True,
)
except pymongo.errors.WriteError as exc:
if exc.args and exc.args[0].startswith("Cannot create field"):
# There's a bit of a failing edge case here...
# If we accidentally set the sub-field of an array, and the key happens to be a
# digit, it will successfully set the value in the array, and not raise an
# error. This is different to how other drivers would behave, and could lead to
# unexpected behaviour.
raise errors.CannotSetSubfield
else:
# Unhandled driver exception, should expose.
raise
else:
# We're setting above the document level.
@ -171,15 +217,17 @@ class Mongo(BaseDriver):
# We'll do it in a transaction so we can roll-back in case something goes horribly
# wrong.
pkey_filter = self.generate_primary_key_filter(identifier_data)
async with await _conn.start_session() as session:
async with await self._conn.start_session() as session:
with contextlib.suppress(pymongo.errors.CollectionInvalid):
# Collections must already exist when inserting documents within a transaction
await _conn.get_database().create_collection(mongo_collection.full_name)
await self.db.create_collection(mongo_collection.full_name)
try:
async with session.start_transaction():
await mongo_collection.delete_many(pkey_filter, session=session)
await mongo_collection.insert_many(
self.generate_documents_to_insert(uuid, primary_key, value, pkey_len),
self.generate_documents_to_insert(
uuid, primary_key, value, identifier_data.primary_key_len
),
session=session,
)
except pymongo.errors.OperationFailure:
@ -218,7 +266,7 @@ class Mongo(BaseDriver):
# What's left of `value` should be the new documents needing to be inserted.
to_insert = self.generate_documents_to_insert(
uuid, primary_key, value, pkey_len
uuid, primary_key, value, identifier_data.primary_key_len
)
requests = list(
itertools.chain(
@ -289,6 +337,59 @@ class Mongo(BaseDriver):
for result in results:
await db[result["name"]].delete_many(pkey_filter)
@classmethod
async def aiter_cogs(cls) -> AsyncIterator[Tuple[str, str]]:
db = cls._conn.get_database()
for collection_name in await db.list_collection_names():
parts = collection_name.split(".")
if not len(parts) == 2:
continue
cog_name = parts[0]
for cog_id in await db[collection_name].distinct("_id.RED_uuid"):
yield cog_name, cog_id
@classmethod
async def delete_all_data(
cls, *, interactive: bool = False, drop_db: Optional[bool] = None, **kwargs
) -> None:
"""Delete all data being stored by this driver.
Parameters
----------
interactive : bool
Set to ``True`` to allow the method to ask the user for
input from the console, regarding the other unset parameters
for this method.
drop_db : Optional[bool]
Set to ``True`` to drop the entire database for the current
bot's instance. Otherwise, collections which appear to be
storing bot data will be dropped.
"""
if interactive is True and drop_db is None:
print(
"Please choose from one of the following options:\n"
" 1. Drop the entire MongoDB database for this instance, or\n"
" 2. Delete all of Red's data within this database, without dropping the database "
"itself."
)
options = ("1", "2")
while True:
resp = input("> ")
try:
drop_db = bool(options.index(resp))
except ValueError:
print("Please type a number corresponding to one of the options.")
else:
break
db = cls._conn.get_database()
if drop_db is True:
await cls._conn.drop_database(db)
else:
async with await cls._conn.start_session() as session:
async for cog_name, cog_id in cls.aiter_cogs():
await db.drop_collection(db[cog_name], session=session)
@staticmethod
def _escape_key(key: str) -> str:
return _SPECIAL_CHAR_PATTERN.sub(_replace_with_escaped, key)
@ -344,40 +445,3 @@ _CHAR_ESCAPES = {
def _replace_with_unescaped(match: Match[str]) -> str:
return _CHAR_ESCAPES[match[0]]
def get_config_details():
uri = None
while True:
uri = input("Enter URI scheme (mongodb or mongodb+srv): ")
if uri is "":
uri = "mongodb"
if uri in ["mongodb", "mongodb+srv"]:
break
else:
print("Invalid URI scheme")
host = input("Enter host address: ")
if uri is "mongodb":
port = int(input("Enter host port: "))
else:
port = 0
admin_uname = input("Enter login username: ")
admin_password = getpass("Enter login password: ")
db_name = input("Enter mongodb database name: ")
if admin_uname == "":
admin_uname = admin_password = None
ret = {
"HOST": host,
"PORT": port,
"USERNAME": admin_uname,
"PASSWORD": admin_password,
"DB_NAME": db_name,
"URI": uri,
}
return ret

View File

@ -0,0 +1,3 @@
from .postgres import PostgresDriver
__all__ = ["PostgresDriver"]

View File

@ -0,0 +1,839 @@
/*
************************************************************
* PostgreSQL driver Data Definition Language (DDL) Script. *
************************************************************
*/
CREATE SCHEMA IF NOT EXISTS red_config;
CREATE SCHEMA IF NOT EXISTS red_utils;
DO $$
BEGIN
PERFORM 'red_config.identifier_data'::regtype;
EXCEPTION
WHEN UNDEFINED_OBJECT THEN
CREATE TYPE red_config.identifier_data AS (
cog_name text,
cog_id text,
category text,
pkeys text[],
identifiers text[],
pkey_len integer,
is_custom boolean
);
END;
$$;
CREATE OR REPLACE FUNCTION
/*
* Create the config schema and/or table if they do not exist yet.
*/
red_config.maybe_create_table(
id_data red_config.identifier_data
)
RETURNS void
LANGUAGE 'plpgsql'
AS $$
DECLARE
schemaname CONSTANT text := concat_ws('.', id_data.cog_name, id_data.cog_id);
schema_exists CONSTANT boolean := exists(
SELECT 1
FROM red_config.red_cogs t
WHERE t.cog_name = id_data.cog_name AND t.cog_id = id_data.cog_id);
table_exists CONSTANT boolean := schema_exists AND exists(
SELECT 1
FROM information_schema.tables
WHERE table_schema = schemaname AND table_name = id_data.category);
BEGIN
IF NOT schema_exists THEN
PERFORM red_config.create_schema(id_data.cog_name, id_data.cog_id);
END IF;
IF NOT table_exists THEN
PERFORM red_config.create_table(id_data);
END IF;
END;
$$;
CREATE OR REPLACE FUNCTION
/*
* Create the config schema for the given cog.
*/
red_config.create_schema(new_cog_name text, new_cog_id text, OUT schemaname text)
RETURNS text
LANGUAGE 'plpgsql'
AS $$
BEGIN
schemaname := concat_ws('.', new_cog_name, new_cog_id);
EXECUTE format('CREATE SCHEMA IF NOT EXISTS %I', schemaname);
INSERT INTO red_config.red_cogs AS t VALUES(new_cog_name, new_cog_id, schemaname)
ON CONFLICT(cog_name, cog_id) DO UPDATE
SET
schemaname = excluded.schemaname;
END;
$$;
CREATE OR REPLACE FUNCTION
/*
* Create the config table for the given category.
*/
red_config.create_table(id_data red_config.identifier_data)
RETURNS void
LANGUAGE 'plpgsql'
AS $$
DECLARE
schemaname CONSTANT text := concat_ws('.', id_data.cog_name, id_data.cog_id);
constraintname CONSTANT text := id_data.category||'_pkey';
pkey_columns CONSTANT text := red_utils.gen_pkey_columns(1, id_data.pkey_len);
pkey_type CONSTANT text := red_utils.get_pkey_type(id_data.is_custom);
pkey_column_definitions CONSTANT text := red_utils.gen_pkey_column_definitions(
1, id_data.pkey_len, pkey_type);
BEGIN
EXECUTE format(
$query$
CREATE TABLE IF NOT EXISTS %I.%I (
%s,
json_data jsonb DEFAULT '{}' NOT NULL,
CONSTRAINT %I PRIMARY KEY (%s)
)
$query$,
schemaname,
id_data.category,
pkey_column_definitions,
constraintname,
pkey_columns);
END;
$$;
CREATE OR REPLACE FUNCTION
/*
* Get config data.
*
* - When `pkeys` is a full primary key, all or part of a document
* will be returned.
* - When `pkeys` is not a full primary key, documents will be
* aggregated together into a single JSONB object, with primary keys
* as keys mapping to the documents.
*/
red_config.get(
id_data red_config.identifier_data,
OUT result jsonb
)
LANGUAGE 'plpgsql'
STABLE
PARALLEL SAFE
AS $$
DECLARE
schemaname CONSTANT text := concat_ws('.', id_data.cog_name, id_data.cog_id);
num_pkeys CONSTANT integer := coalesce(array_length(id_data.pkeys, 1), 0);
num_missing_pkeys CONSTANT integer := id_data.pkey_len - num_pkeys;
pkey_type CONSTANT text := red_utils.get_pkey_type(id_data.is_custom);
whereclause CONSTANT text := red_utils.gen_whereclause(num_pkeys, pkey_type);
missing_pkey_columns text;
BEGIN
IF num_missing_pkeys <= 0 THEN
-- No missing primary keys: we're getting all or part of a document.
EXECUTE format(
'SELECT json_data #> $2 FROM %I.%I WHERE %s',
schemaname,
id_data.category,
whereclause)
INTO result
USING id_data.pkeys, id_data.identifiers;
ELSIF num_missing_pkeys = 1 THEN
-- 1 missing primary key: we can use the built-in jsonb_object_agg() aggregate function.
EXECUTE format(
'SELECT jsonb_object_agg(%I::text, json_data) FROM %I.%I WHERE %s',
'primary_key_'||id_data.pkey_len,
schemaname,
id_data.category,
whereclause)
INTO result
USING id_data.pkeys;
ELSE
-- Multiple missing primary keys: we must use our custom red_utils.jsonb_object_agg2()
-- aggregate function.
missing_pkey_columns := red_utils.gen_pkey_columns_casted(num_pkeys + 1, id_data.pkey_len);
EXECUTE format(
'SELECT red_utils.jsonb_object_agg2(json_data, %s) FROM %I.%I WHERE %s',
missing_pkey_columns,
schemaname,
id_data.category,
whereclause)
INTO result
USING id_data.pkeys;
END IF;
END;
$$;
CREATE OR REPLACE FUNCTION
/*
* Set config data.
*
* - When `pkeys` is a full primary key, all or part of a document
* will be set.
* - When `pkeys` is not a full set, multiple documents will be
* replaced or removed - `new_value` must be a JSONB object mapping
* primary keys to the new documents.
*
* Raises `error_in_assignment` error when trying to set a sub-key
* of a non-document type.
*/
red_config.set(
id_data red_config.identifier_data,
new_value jsonb
)
RETURNS void
LANGUAGE 'plpgsql'
AS $$
DECLARE
schemaname CONSTANT text := concat_ws('.', id_data.cog_name, id_data.cog_id);
constraintname CONSTANT text := id_data.category||'_pkey';
num_pkeys CONSTANT integer := coalesce(array_length(id_data.pkeys, 1), 0);
num_missing_pkeys CONSTANT integer := id_data.pkey_len - num_pkeys;
pkey_type CONSTANT text := red_utils.get_pkey_type(id_data.is_custom);
pkey_placeholders CONSTANT text := red_utils.gen_pkey_placeholders(num_pkeys, pkey_type);
new_document jsonb;
pkey_column_definitions text;
whereclause text;
missing_pkey_columns text;
BEGIN
PERFORM red_config.maybe_create_table(id_data);
IF num_missing_pkeys = 0 THEN
-- Setting all or part of a document
new_document := red_utils.jsonb_set2('{}', new_value, VARIADIC id_data.identifiers);
EXECUTE format(
$query$
INSERT INTO %I.%I AS t VALUES (%s, $2)
ON CONFLICT ON CONSTRAINT %I DO UPDATE
SET
json_data = red_utils.jsonb_set2(t.json_data, $3, VARIADIC $4)
$query$,
schemaname,
id_data.category,
pkey_placeholders,
constraintname)
USING id_data.pkeys, new_document, new_value, id_data.identifiers;
ELSE
-- Setting multiple documents
whereclause := red_utils.gen_whereclause(num_pkeys, pkey_type);
missing_pkey_columns := red_utils.gen_pkey_columns_casted(
num_pkeys + 1, id_data.pkey_len, pkey_type);
pkey_column_definitions := red_utils.gen_pkey_column_definitions(num_pkeys + 1, id_data.pkey_len);
-- Delete all documents which we're setting first, since we don't know whether they'll be
-- replaced by the subsequent INSERT.
EXECUTE format('DELETE FROM %I.%I WHERE %s', schemaname, id_data.category, whereclause)
USING id_data.pkeys;
-- Insert all new documents
EXECUTE format(
$query$
INSERT INTO %I.%I AS t
SELECT %s, json_data
FROM red_utils.generate_rows_from_object($2, $3) AS f(%s, json_data jsonb)
ON CONFLICT ON CONSTRAINT %I DO UPDATE
SET
json_data = excluded.json_data
$query$,
schemaname,
id_data.category,
concat_ws(', ', pkey_placeholders, missing_pkey_columns),
pkey_column_definitions,
constraintname)
USING id_data.pkeys, new_value, num_missing_pkeys;
END IF;
END;
$$;
CREATE OR REPLACE FUNCTION
/*
* Clear config data.
*
* - When `identifiers` is not empty, this will clear a key from a
* document.
* - When `identifiers` is empty and `pkeys` is not empty, it will
* delete one or more documents.
* - When `pkeys` is empty, it will drop the whole table.
* - When `id_data.category` is NULL or an empty string, it will drop
* the whole schema.
*
* Has no effect when the document or key does not exist.
*/
red_config.clear(
id_data red_config.identifier_data
)
RETURNS void
LANGUAGE 'plpgsql'
AS $$
DECLARE
schemaname CONSTANT text := concat_ws('.', id_data.cog_name, id_data.cog_id);
num_pkeys CONSTANT integer := coalesce(array_length(id_data.pkeys, 1), 0);
num_identifiers CONSTANT integer := coalesce(array_length(id_data.identifiers, 1), 0);
pkey_type CONSTANT text := red_utils.get_pkey_type(id_data.is_custom);
whereclause text;
BEGIN
IF num_identifiers > 0 THEN
-- Popping a key from a document or nested document.
whereclause := red_utils.gen_whereclause(num_pkeys, pkey_type);
EXECUTE format(
$query$
UPDATE %I.%I AS t
SET
json_data = t.json_data #- $2
WHERE %s
$query$,
schemaname,
id_data.category,
whereclause)
USING id_data.pkeys, id_data.identifiers;
ELSIF num_pkeys > 0 THEN
-- Deleting one or many documents
whereclause := red_utils.gen_whereclause(num_pkeys, pkey_type);
EXECUTE format('DELETE FROM %I.%I WHERE %s', schemaname, id_data.category, whereclause)
USING id_data.pkeys;
ELSIF id_data.category IS NOT NULL AND id_data.category != '' THEN
-- Deleting an entire category
EXECUTE format('DROP TABLE %I.%I CASCADE', schemaname, id_data.category);
ELSE
-- Deleting an entire cog's data
EXECUTE format('DROP SCHEMA %I CASCADE', schemaname);
DELETE FROM red_config.red_cogs
WHERE cog_name = id_data.cog_name AND cog_id = id_data.cog_id;
END IF;
END;
$$;
CREATE OR REPLACE FUNCTION
/*
* Increment a number within a document.
*
* If the value doesn't already exist, it is inserted as
* `default_value + amount`.
*
* Raises 'wrong_object_type' error when trying to increment a
* non-numeric value.
*/
red_config.inc(
id_data red_config.identifier_data,
amount numeric,
default_value numeric,
OUT result numeric
)
LANGUAGE 'plpgsql'
AS $$
DECLARE
schemaname CONSTANT text := concat_ws('.', id_data.cog_name, id_data.cog_id);
num_identifiers CONSTANT integer := coalesce(array_length(id_data.identifiers, 1), 0);
pkey_type CONSTANT text := red_utils.get_pkey_type(id_data.is_custom);
whereclause CONSTANT text := red_utils.gen_whereclause(id_data.pkey_len, pkey_type);
new_document jsonb;
existing_document jsonb;
existing_value jsonb;
pkey_placeholders text;
BEGIN
IF num_identifiers = 0 THEN
-- Without identifiers, there's no chance we're actually incrementing a number
RAISE EXCEPTION 'Cannot increment document(s)'
USING ERRCODE = 'wrong_object_type';
END IF;
PERFORM red_config.maybe_create_table(id_data);
-- Look for the existing document
EXECUTE format(
'SELECT json_data FROM %I.%I WHERE %s',
schemaname,
id_data.category,
whereclause)
INTO existing_document USING id_data.pkeys;
IF existing_document IS NULL THEN
-- We need to insert a new document
result := default_value + amount;
new_document := red_utils.jsonb_set2('{}', result, VARIADIC id_data.identifiers);
pkey_placeholders := red_utils.gen_pkey_placeholders(id_data.pkey_len, pkey_type);
EXECUTE format(
'INSERT INTO %I.%I VALUES(%s, $2)',
schemaname,
id_data.category,
pkey_placeholders)
USING id_data.pkeys, new_document;
ELSE
-- We need to update the existing document
existing_value := existing_document #> id_data.identifiers;
IF existing_value IS NULL THEN
result := default_value + amount;
ELSIF jsonb_typeof(existing_value) = 'number' THEN
result := existing_value::text::numeric + amount;
ELSE
RAISE EXCEPTION 'Cannot increment non-numeric value %', existing_value
USING ERRCODE = 'wrong_object_type';
END IF;
new_document := red_utils.jsonb_set2(
existing_document, to_jsonb(result), id_data.identifiers);
EXECUTE format(
'UPDATE %I.%I SET json_data = $2 WHERE %s',
schemaname,
id_data.category,
whereclause)
USING id_data.pkeys, new_document;
END IF;
END;
$$;
CREATE OR REPLACE FUNCTION
/*
* Toggle a boolean within a document.
*
* If the value doesn't already exist, it is inserted as `NOT
* default_value`.
*
* Raises 'wrong_object_type' error when trying to toggle a
* non-boolean value.
*/
red_config.toggle(
id_data red_config.identifier_data,
default_value boolean,
OUT result boolean
)
LANGUAGE 'plpgsql'
AS $$
DECLARE
schemaname CONSTANT text := concat_ws('.', id_data.cog_name, id_data.cog_id);
num_identifiers CONSTANT integer := coalesce(array_length(id_data.identifiers, 1), 0);
pkey_type CONSTANT text := red_utils.get_pkey_type(id_data.is_custom);
whereclause CONSTANT text := red_utils.gen_whereclause(id_data.pkey_len, pkey_type);
new_document jsonb;
existing_document jsonb;
existing_value jsonb;
pkey_placeholders text;
BEGIN
IF num_identifiers = 0 THEN
-- Without identifiers, there's no chance we're actually toggling a boolean
RAISE EXCEPTION 'Cannot increment document(s)'
USING ERRCODE = 'wrong_object_type';
END IF;
PERFORM red_config.maybe_create_table(id_data);
-- Look for the existing document
EXECUTE format(
'SELECT json_data FROM %I.%I WHERE %s',
schemaname,
id_data.category,
whereclause)
INTO existing_document USING id_data.pkeys;
IF existing_document IS NULL THEN
-- We need to insert a new document
result := NOT default_value;
new_document := red_utils.jsonb_set2('{}', result, VARIADIC id_data.identifiers);
pkey_placeholders := red_utils.gen_pkey_placeholders(id_data.pkey_len, pkey_type);
EXECUTE format(
'INSERT INTO %I.%I VALUES(%s, $2)',
schemaname,
id_data.category,
pkey_placeholders)
USING id_data.pkeys, new_document;
ELSE
-- We need to update the existing document
existing_value := existing_document #> id_data.identifiers;
IF existing_value IS NULL THEN
result := NOT default_value;
ELSIF jsonb_typeof(existing_value) = 'boolean' THEN
result := NOT existing_value::text::boolean;
ELSE
RAISE EXCEPTION 'Cannot increment non-boolean value %', existing_value
USING ERRCODE = 'wrong_object_type';
END IF;
new_document := red_utils.jsonb_set2(
existing_document, to_jsonb(result), id_data.identifiers);
EXECUTE format(
'UPDATE %I.%I SET json_data = $2 WHERE %s',
schemaname,
id_data.category,
whereclause)
USING id_data.pkeys, new_document;
END IF;
END;
$$;
CREATE OR REPLACE FUNCTION
red_config.extend(
id_data red_config.identifier_data,
new_value text,
default_value text,
max_length integer DEFAULT NULL,
extend_left boolean DEFAULT FALSE,
OUT result jsonb
)
LANGUAGE 'plpgsql'
AS $$
DECLARE
schemaname CONSTANT text := concat_ws('.', id_data.cog_name, id_data.cog_id);
num_identifiers CONSTANT integer := coalesce(array_length(id_data.identifiers, 1), 0);
pkey_type CONSTANT text := red_utils.get_pkey_type(id_data.is_custom);
whereclause CONSTANT text := red_utils.gen_whereclause(id_data.pkey_len, pkey_type);
pop_idx CONSTANT integer := CASE extend_left WHEN TRUE THEN -1 ELSE 0 END;
new_document jsonb;
existing_document jsonb;
existing_value jsonb;
pkey_placeholders text;
idx integer;
BEGIN
IF num_identifiers = 0 THEN
-- Without identifiers, there's no chance we're actually appending to an array
RAISE EXCEPTION 'Cannot append to document(s)'
USING ERRCODE = 'wrong_object_type';
END IF;
PERFORM red_config.maybe_create_table(id_data);
-- Look for the existing document
EXECUTE format(
'SELECT json_data FROM %I.%I WHERE %s',
schemaname,
id_data.category,
whereclause)
INTO existing_document USING id_data.pkeys;
IF existing_document IS NULL THEN
result := default_value || new_value;
new_document := red_utils.jsonb_set2('{}'::jsonb, result, id_data.identifiers);
pkey_placeholders := red_utils.gen_pkey_placeholders(id_data.pkey_len, pkey_type);
EXECUTE format(
'INSERT INTO %I.%I VALUES(%s, $2)',
schemaname,
id_data.category,
pkey_placeholders)
USING id_data.pkeys, new_document;
ELSE
existing_value := existing_document #> id_data.identifiers;
IF existing_value IS NULL THEN
existing_value := default_value;
ELSIF jsonb_typeof(existing_value) != 'array' THEN
RAISE EXCEPTION 'Cannot append to non-array value %', existing_value
USING ERRCODE = 'wrong_object_type';
END IF;
CASE extend_left
WHEN TRUE THEN
result := new_value || existing_value;
ELSE
result := existing_value || new_value;
END CASE;
IF max_length IS NOT NULL THEN
FOR idx IN SELECT generate_series(1, jsonb_array_length(result) - max_length) LOOP
result := result - pop_idx;
END LOOP;
END IF;
new_document := red_utils.jsonb_set2(existing_document, result, id_data.identifiers);
EXECUTE format(
'UPDATE %I.%I SET json_data = $2 WHERE %s',
schemaname,
id_data.category,
whereclause)
USING id_data.pkeys, new_document;
END IF;
END;
$$;
CREATE OR REPLACE FUNCTION
/*
* Delete all schemas listed in the red_config.red_cogs table.
*/
red_config.delete_all_schemas()
RETURNS void
LANGUAGE 'plpgsql'
AS $$
DECLARE
cog_entry record;
BEGIN
FOR cog_entry IN SELECT * FROM red_config.red_cogs t LOOP
EXECUTE format('DROP SCHEMA %I CASCADE', cog_entry.schemaname);
END LOOP;
-- Clear out red_config.red_cogs table
DELETE FROM red_config.red_cogs WHERE TRUE;
END;
$$;
CREATE OR REPLACE FUNCTION
/*
* Like `jsonb_set` but will insert new objects where one is missing
* along the path.
*
* Raises `error_in_assignment` error when trying to set a sub-key
* of a non-document type.
*/
red_utils.jsonb_set2(target jsonb, new_value jsonb, VARIADIC identifiers text[])
RETURNS jsonb
LANGUAGE 'plpgsql'
IMMUTABLE
PARALLEL SAFE
AS $$
DECLARE
num_identifiers CONSTANT integer := coalesce(array_length(identifiers, 1), 0);
cur_value_type text;
idx integer;
BEGIN
IF num_identifiers = 0 THEN
RETURN new_value;
END IF;
FOR idx IN SELECT generate_series(1, num_identifiers - 1) LOOP
cur_value_type := jsonb_typeof(target #> identifiers[:idx]);
IF cur_value_type IS NULL THEN
-- Parent key didn't exist in JSON before - insert new object
target := jsonb_set(target, identifiers[:idx], '{}'::jsonb);
ELSIF cur_value_type != 'object' THEN
-- We can't set the sub-field of a null, int, float, array etc.
RAISE EXCEPTION 'Cannot set sub-field of "%s"', cur_value_type
USING ERRCODE = 'error_in_assignment';
END IF;
END LOOP;
RETURN jsonb_set(target, identifiers, new_value);
END;
$$;
CREATE OR REPLACE FUNCTION
/*
* Return a set of rows to insert into a table, from a single JSONB
* object containing multiple documents.
*/
red_utils.generate_rows_from_object(object jsonb, num_missing_pkeys integer)
RETURNS setof record
LANGUAGE 'plpgsql'
IMMUTABLE
PARALLEL SAFE
AS $$
DECLARE
pair record;
column_definitions text;
BEGIN
IF num_missing_pkeys = 1 THEN
-- Base case: Simply return (key, value) pairs
RETURN QUERY
SELECT key AS key_1, value AS json_data
FROM jsonb_each(object);
ELSE
-- We need to return (key, key, ..., value) pairs: recurse into inner JSONB objects
column_definitions := red_utils.gen_pkey_column_definitions(2, num_missing_pkeys);
FOR pair IN SELECT * FROM jsonb_each(object) LOOP
RETURN QUERY
EXECUTE format(
$query$
SELECT $1 AS key_1, *
FROM red_utils.generate_rows_from_object($2, $3)
AS f(%s, json_data jsonb)
$query$,
column_definitions)
USING pair.key, pair.value, num_missing_pkeys - 1;
END LOOP;
END IF;
RETURN;
END;
$$;
CREATE OR REPLACE FUNCTION
/*
* Get a comma-separated list of primary key placeholders.
*
* The placeholder will always be $1. Particularly useful for
* inserting values into a table from an array of primary keys.
*/
red_utils.gen_pkey_placeholders(num_pkeys integer, pkey_type text DEFAULT 'text')
RETURNS text
LANGUAGE 'sql'
IMMUTABLE
PARALLEL SAFE
AS $$
SELECT string_agg(t.item, ', ')
FROM (
SELECT format('$1[%s]::%s', idx, pkey_type) AS item
FROM generate_series(1, num_pkeys) idx) t
;
$$;
CREATE OR REPLACE FUNCTION
/*
* Generate a whereclause for the given number of primary keys.
*
* When there are no primary keys, this will simply return the the
* string 'TRUE'. When there are multiple, it will return multiple
* equality comparisons concatenated with 'AND'.
*/
red_utils.gen_whereclause(num_pkeys integer, pkey_type text)
RETURNS text
LANGUAGE 'sql'
IMMUTABLE
PARALLEL SAFE
AS $$
SELECT coalesce(string_agg(t.item, ' AND '), 'TRUE')
FROM (
SELECT format('%I = $1[%s]::%s', 'primary_key_'||idx, idx, pkey_type) AS item
FROM generate_series(1, num_pkeys) idx) t
;
$$;
CREATE OR REPLACE FUNCTION
/*
* Generate a comma-separated list of primary key column names.
*/
red_utils.gen_pkey_columns(start integer, stop integer)
RETURNS text
LANGUAGE 'sql'
IMMUTABLE
PARALLEL SAFE
AS $$
SELECT string_agg(t.item, ', ')
FROM (
SELECT quote_ident('primary_key_'||idx) AS item
FROM generate_series(start, stop) idx) t
;
$$;
CREATE OR REPLACE FUNCTION
/*
* Generate a comma-separated list of primary key column names casted
* to the given type.
*/
red_utils.gen_pkey_columns_casted(start integer, stop integer, pkey_type text DEFAULT 'text')
RETURNS text
LANGUAGE 'sql'
IMMUTABLE
PARALLEL SAFE
AS $$
SELECT string_agg(t.item, ', ')
FROM (
SELECT format('%I::%s', 'primary_key_'||idx, pkey_type) AS item
FROM generate_series(start, stop) idx) t
;
$$;
CREATE OR REPLACE FUNCTION
/*
* Generate a primary key column definition list.
*/
red_utils.gen_pkey_column_definitions(
start integer, stop integer, column_type text DEFAULT 'text'
)
RETURNS text
LANGUAGE 'sql'
IMMUTABLE
PARALLEL SAFE
AS $$
SELECT string_agg(t.item, ', ')
FROM (
SELECT format('%I %s', 'primary_key_'||idx, column_type) AS item
FROM generate_series(start, stop) idx) t
;
$$;
CREATE OR REPLACE FUNCTION
red_utils.get_pkey_type(is_custom boolean)
RETURNS TEXT
LANGUAGE 'sql'
IMMUTABLE
PARALLEL SAFE
AS $$
SELECT ('{bigint,text}'::text[])[is_custom::integer + 1];
$$;
DROP AGGREGATE IF EXISTS red_utils.jsonb_object_agg2(jsonb, VARIADIC text[]);
CREATE AGGREGATE
/*
* Like `jsonb_object_agg` but aggregates more than two columns into a
* single JSONB object.
*
* If possible, use `jsonb_object_agg` instead for performance
* reasons.
*/
red_utils.jsonb_object_agg2(json_data jsonb, VARIADIC primary_keys text[]) (
SFUNC = red_utils.jsonb_set2,
STYPE = jsonb,
INITCOND = '{}',
PARALLEL = SAFE
)
;
CREATE TABLE IF NOT EXISTS
/*
* Table to keep track of other cogs' schemas.
*/
red_config.red_cogs(
cog_name text,
cog_id text,
schemaname text NOT NULL,
PRIMARY KEY (cog_name, cog_id)
)
;

View File

@ -0,0 +1,3 @@
SELECT red_config.delete_all_schemas();
DROP SCHEMA IF EXISTS red_config CASCADE;
DROP SCHEMA IF EXISTS red_utils CASCADE;

View File

@ -0,0 +1,255 @@
import getpass
import json
import sys
from pathlib import Path
from typing import Optional, Any, AsyncIterator, Tuple, Union, Callable, List
try:
# pylint: disable=import-error
import asyncpg
except ModuleNotFoundError:
asyncpg = None
from ... import data_manager, errors
from ..base import BaseDriver, IdentifierData, ConfigCategory
from ..log import log
__all__ = ["PostgresDriver"]
_PKG_PATH = Path(__file__).parent
DDL_SCRIPT_PATH = _PKG_PATH / "ddl.sql"
DROP_DDL_SCRIPT_PATH = _PKG_PATH / "drop_ddl.sql"
def encode_identifier_data(
id_data: IdentifierData
) -> Tuple[str, str, str, List[str], List[str], int, bool]:
return (
id_data.cog_name,
id_data.uuid,
id_data.category,
["0"] if id_data.category == ConfigCategory.GLOBAL else list(id_data.primary_key),
list(id_data.identifiers),
1 if id_data.category == ConfigCategory.GLOBAL else id_data.primary_key_len,
id_data.is_custom,
)
class PostgresDriver(BaseDriver):
_pool: Optional["asyncpg.pool.Pool"] = None
@classmethod
async def initialize(cls, **storage_details) -> None:
if asyncpg is None:
raise errors.MissingExtraRequirements(
"Red must be installed with the [postgres] extra to use the PostgreSQL driver"
)
cls._pool = await asyncpg.create_pool(**storage_details)
with DDL_SCRIPT_PATH.open() as fs:
await cls._pool.execute(fs.read())
@classmethod
async def teardown(cls) -> None:
if cls._pool is not None:
await cls._pool.close()
@staticmethod
def get_config_details():
unixmsg = (
""
if sys.platform != "win32"
else (
" - Common directories for PostgreSQL Unix-domain sockets (/run/postgresql, "
"/var/run/postgresl, /var/pgsql_socket, /private/tmp, and /tmp),\n"
)
)
host = (
input(
f"Enter the PostgreSQL server's address.\n"
f"If left blank, Red will try the following, in order:\n"
f" - The PGHOST environment variable,\n{unixmsg}"
f" - localhost.\n"
f"> "
)
or None
)
print(
"Enter the PostgreSQL server port.\n"
"If left blank, this will default to either:\n"
" - The PGPORT environment variable,\n"
" - 5432."
)
while True:
port = input("> ") or None
if port is None:
break
try:
port = int(port)
except ValueError:
print("Port must be a number")
else:
break
user = (
input(
"Enter the PostgreSQL server username.\n"
"If left blank, this will default to either:\n"
" - The PGUSER environment variable,\n"
" - The OS name of the user running Red (ident/peer authentication).\n"
"> "
)
or None
)
passfile = r"%APPDATA%\postgresql\pgpass.conf" if sys.platform != "win32" else "~/.pgpass"
password = getpass.getpass(
f"Enter the PostgreSQL server password. The input will be hidden.\n"
f" NOTE: If using ident/peer authentication (no password), enter NONE.\n"
f"When NONE is entered, this will default to:\n"
f" - The PGPASSWORD environment variable,\n"
f" - Looking up the password in the {passfile} passfile,\n"
f" - No password.\n"
f"> "
)
if password == "NONE":
password = None
database = (
input(
"Enter the PostgreSQL database's name.\n"
"If left blank, this will default to either:\n"
" - The PGDATABASE environment variable,\n"
" - The OS name of the user running Red.\n"
"> "
)
or None
)
return {
"host": host,
"port": port,
"user": user,
"password": password,
"database": database,
}
async def get(self, identifier_data: IdentifierData):
try:
result = await self._execute(
"SELECT red_config.get($1)",
encode_identifier_data(identifier_data),
method=self._pool.fetchval,
)
except asyncpg.UndefinedTableError:
raise KeyError from None
if result is None:
# The result is None both when postgres yields no results, or when it yields a NULL row
# A 'null' JSON value would be returned as encoded JSON, i.e. the string 'null'
raise KeyError
return json.loads(result)
async def set(self, identifier_data: IdentifierData, value=None):
try:
await self._execute(
"SELECT red_config.set($1, $2::jsonb)",
encode_identifier_data(identifier_data),
json.dumps(value),
)
except asyncpg.ErrorInAssignmentError:
raise errors.CannotSetSubfield
async def clear(self, identifier_data: IdentifierData):
try:
await self._execute(
"SELECT red_config.clear($1)", encode_identifier_data(identifier_data)
)
except asyncpg.UndefinedTableError:
pass
async def inc(
self, identifier_data: IdentifierData, value: Union[int, float], default: Union[int, float]
) -> Union[int, float]:
try:
return await self._execute(
f"SELECT red_config.inc($1, $2, $3)",
encode_identifier_data(identifier_data),
value,
default,
method=self._pool.fetchval,
)
except asyncpg.WrongObjectTypeError as exc:
raise errors.StoredTypeError(*exc.args)
async def toggle(self, identifier_data: IdentifierData, default: bool) -> bool:
try:
return await self._execute(
"SELECT red_config.inc($1, $2)",
encode_identifier_data(identifier_data),
default,
method=self._pool.fetchval,
)
except asyncpg.WrongObjectTypeError as exc:
raise errors.StoredTypeError(*exc.args)
@classmethod
async def aiter_cogs(cls) -> AsyncIterator[Tuple[str, str]]:
query = "SELECT cog_name, cog_id FROM red_config.red_cogs"
log.invisible(query)
async with cls._pool.acquire() as conn, conn.transaction():
async for row in conn.cursor(query):
yield row["cog_name"], row["cog_id"]
@classmethod
async def delete_all_data(
cls, *, interactive: bool = False, drop_db: Optional[bool] = None, **kwargs
) -> None:
"""Delete all data being stored by this driver.
Parameters
----------
interactive : bool
Set to ``True`` to allow the method to ask the user for
input from the console, regarding the other unset parameters
for this method.
drop_db : Optional[bool]
Set to ``True`` to drop the entire database for the current
bot's instance. Otherwise, schemas within the database which
store bot data will be dropped, as well as functions,
aggregates, event triggers, and meta-tables.
"""
if interactive is True and drop_db is None:
print(
"Please choose from one of the following options:\n"
" 1. Drop the entire PostgreSQL database for this instance, or\n"
" 2. Delete all of Red's data within this database, without dropping the database "
"itself."
)
options = ("1", "2")
while True:
resp = input("> ")
try:
drop_db = bool(options.index(resp))
except ValueError:
print("Please type a number corresponding to one of the options.")
else:
break
if drop_db is True:
storage_details = data_manager.storage_details()
await cls._pool.execute(f"DROP DATABASE $1", storage_details["database"])
else:
with DROP_DDL_SCRIPT_PATH.open() as fs:
await cls._pool.execute(fs.read())
@classmethod
async def _execute(cls, query: str, *args, method: Optional[Callable] = None) -> Any:
if method is None:
method = cls._pool.execute
log.invisible("Query: %s", query)
if args:
log.invisible("Args: %s", args)
return await method(query, *args)

View File

@ -1,233 +0,0 @@
import enum
from typing import Tuple
__all__ = ["BaseDriver", "IdentifierData"]
class ConfigCategory(enum.Enum):
GLOBAL = "GLOBAL"
GUILD = "GUILD"
CHANNEL = "TEXTCHANNEL"
ROLE = "ROLE"
USER = "USER"
MEMBER = "MEMBER"
class IdentifierData:
def __init__(
self,
uuid: str,
category: str,
primary_key: Tuple[str, ...],
identifiers: Tuple[str, ...],
custom_group_data: dict,
is_custom: bool = False,
):
self._uuid = uuid
self._category = category
self._primary_key = primary_key
self._identifiers = identifiers
self.custom_group_data = custom_group_data
self._is_custom = is_custom
@property
def uuid(self):
return self._uuid
@property
def category(self):
return self._category
@property
def primary_key(self):
return self._primary_key
@property
def identifiers(self):
return self._identifiers
@property
def is_custom(self):
return self._is_custom
def __repr__(self):
return (
f"<IdentifierData uuid={self.uuid} category={self.category} primary_key={self.primary_key}"
f" identifiers={self.identifiers}>"
)
def __eq__(self, other) -> bool:
if not isinstance(other, IdentifierData):
return False
return (
self.uuid == other.uuid
and self.category == other.category
and self.primary_key == other.primary_key
and self.identifiers == other.identifiers
)
def __hash__(self) -> int:
return hash((self.uuid, self.category, self.primary_key, self.identifiers))
def add_identifier(self, *identifier: str) -> "IdentifierData":
if not all(isinstance(i, str) for i in identifier):
raise ValueError("Identifiers must be strings.")
return IdentifierData(
self.uuid,
self.category,
self.primary_key,
self.identifiers + identifier,
self.custom_group_data,
is_custom=self.is_custom,
)
def to_tuple(self):
return tuple(
item
for item in (self.uuid, self.category, *self.primary_key, *self.identifiers)
if len(item) > 0
)
class BaseDriver:
def __init__(self, cog_name, identifier):
self.cog_name = cog_name
self.unique_cog_identifier = identifier
async def has_valid_connection(self) -> bool:
raise NotImplementedError
async def get(self, identifier_data: IdentifierData):
"""
Finds the value indicate by the given identifiers.
Parameters
----------
identifier_data
Returns
-------
Any
Stored value.
"""
raise NotImplementedError
def get_config_details(self):
"""
Asks users for additional configuration information necessary
to use this config driver.
Returns
-------
Dict of configuration details.
"""
raise NotImplementedError
async def set(self, identifier_data: IdentifierData, value=None):
"""
Sets the value of the key indicated by the given identifiers.
Parameters
----------
identifier_data
value
Any JSON serializable python object.
"""
raise NotImplementedError
async def clear(self, identifier_data: IdentifierData):
"""
Clears out the value specified by the given identifiers.
Equivalent to using ``del`` on a dict.
Parameters
----------
identifier_data
"""
raise NotImplementedError
def _get_levels(self, category, custom_group_data):
if category == ConfigCategory.GLOBAL.value:
return 0
elif category in (
ConfigCategory.USER.value,
ConfigCategory.GUILD.value,
ConfigCategory.CHANNEL.value,
ConfigCategory.ROLE.value,
):
return 1
elif category == ConfigCategory.MEMBER.value:
return 2
elif category in custom_group_data:
return custom_group_data[category]
else:
raise RuntimeError(f"Cannot convert due to group: {category}")
def _split_primary_key(self, category, custom_group_data, data):
levels = self._get_levels(category, custom_group_data)
if levels == 0:
return (((), data),)
def flatten(levels_remaining, currdata, parent_key=()):
items = []
for k, v in currdata.items():
new_key = parent_key + (k,)
if levels_remaining > 1:
items.extend(flatten(levels_remaining - 1, v, new_key).items())
else:
items.append((new_key, v))
return dict(items)
ret = []
for k, v in flatten(levels, data).items():
ret.append((k, v))
return tuple(ret)
async def export_data(self, custom_group_data):
categories = [c.value for c in ConfigCategory]
categories.extend(custom_group_data.keys())
ret = []
for c in categories:
ident_data = IdentifierData(
self.unique_cog_identifier,
c,
(),
(),
custom_group_data,
is_custom=c in custom_group_data,
)
try:
data = await self.get(ident_data)
except KeyError:
continue
ret.append((c, data))
return ret
async def import_data(self, cog_data, custom_group_data):
for category, all_data in cog_data:
splitted_pkey = self._split_primary_key(category, custom_group_data, all_data)
for pkey, data in splitted_pkey:
ident_data = IdentifierData(
self.unique_cog_identifier,
category,
pkey,
(),
custom_group_data,
is_custom=category in custom_group_data,
)
await self.set(ident_data, data)
@staticmethod
def get_pkey_len(identifier_data: IdentifierData) -> int:
cat = identifier_data.category
if cat == ConfigCategory.GLOBAL.value:
return 0
elif cat == ConfigCategory.MEMBER.value:
return 2
elif identifier_data.is_custom:
return identifier_data.custom_group_data[cat]
else:
return 1

View File

@ -1,5 +1,4 @@
import importlib.machinery
from typing import Optional
import discord
@ -49,3 +48,37 @@ class BalanceTooHigh(BankError, OverflowError):
return _("{user}'s balance cannot rise above {max:,} {currency}.").format(
user=self.user, max=self.max_balance, currency=self.currency_name
)
class MissingExtraRequirements(RedError):
"""Raised when an extra requirement is missing but required."""
class ConfigError(RedError):
"""Error in a Config operation."""
class StoredTypeError(ConfigError, TypeError):
"""A TypeError pertaining to stored Config data.
This error may arise when, for example, trying to increment a value
which is not a number, or trying to toggle a value which is not a
boolean.
"""
class CannotSetSubfield(StoredTypeError):
"""Tried to set sub-field of an invalid data structure.
This would occur in the following example::
>>> import asyncio
>>> from redbot.core import Config
>>> config = Config.get_conf(None, 1234, cog_name="Example")
>>> async def example():
... await config.foo.set(True)
... await config.set_raw("foo", "bar", False) # Should raise here
...
>>> asyncio.run(example())
"""

View File

@ -1,7 +1,9 @@
import asyncio
import json
import logging
import os
import shutil
import tarfile
from asyncio import AbstractEventLoop, as_completed, Semaphore
from asyncio.futures import isfuture
from itertools import chain
@ -24,8 +26,10 @@ from typing import (
)
import discord
from datetime import datetime
from fuzzywuzzy import fuzz, process
from .. import commands, data_manager
from .chat_formatting import box
if TYPE_CHECKING:
@ -37,6 +41,7 @@ __all__ = [
"fuzzy_command_search",
"format_fuzzy_results",
"deduplicate_iterables",
"create_backup",
]
_T = TypeVar("_T")
@ -397,3 +402,45 @@ def bounded_gather(
tasks = (_sem_wrapper(semaphore, task) for task in coros_or_futures)
return asyncio.gather(*tasks, loop=loop, return_exceptions=return_exceptions)
async def create_backup(dest: Path = Path.home()) -> Optional[Path]:
data_path = Path(data_manager.core_data_path().parent)
if not data_path.exists():
return
dest.mkdir(parents=True, exist_ok=True)
timestr = datetime.utcnow().isoformat(timespec="minutes")
backup_fpath = dest / f"redv3_{data_manager.instance_name}_{timestr}.tar.gz"
to_backup = []
exclusions = [
"__pycache__",
"Lavalink.jar",
os.path.join("Downloader", "lib"),
os.path.join("CogManager", "cogs"),
os.path.join("RepoManager", "repos"),
]
# Avoiding circular imports
from ...cogs.downloader.repo_manager import RepoManager
repo_mgr = RepoManager()
await repo_mgr.initialize()
repo_output = []
for _, repo in repo_mgr._repos:
repo_output.append({"url": repo.url, "name": repo.name, "branch": repo.branch})
repos_file = data_path / "cogs" / "RepoManager" / "repos.json"
with repos_file.open("w") as fs:
json.dump(repo_output, fs, indent=4)
instance_file = data_path / "instance.json"
with instance_file.open("w") as fs:
json.dump({data_manager.instance_name: data_manager.basic_config}, fs, indent=4)
for f in data_path.glob("**/*"):
if not any(ex in str(f) for ex in exclusions) and f.is_file():
to_backup.append(f)
with tarfile.open(str(backup_fpath), "w:gz") as tar:
for f in to_backup:
tar.add(str(f), arcname=f.relative_to(data_path), recursive=False)
return backup_fpath

View File

@ -89,7 +89,7 @@ class Tunnel(metaclass=TunnelMeta):
destination: discord.abc.Messageable,
content: str = None,
embed=None,
files: Optional[List[discord.File]] = None
files: Optional[List[discord.File]] = None,
) -> List[discord.Message]:
"""
This does the actual sending, use this instead of a full tunnel

View File

@ -7,15 +7,13 @@ import pytest
from _pytest.monkeypatch import MonkeyPatch
from redbot.core import Config
from redbot.core.bot import Red
from redbot.core import config as config_module
from redbot.core.drivers import red_json
from redbot.core import config as config_module, drivers
__all__ = [
"monkeysession",
"override_data_path",
"coroutine",
"json_driver",
"driver",
"config",
"config_fr",
"red",
@ -56,34 +54,31 @@ def coroutine():
@pytest.fixture()
def json_driver(tmpdir_factory):
def driver(tmpdir_factory):
import uuid
rand = str(uuid.uuid4())
path = Path(str(tmpdir_factory.mktemp(rand)))
driver = red_json.JSON("PyTest", identifier=str(uuid.uuid4()), data_path_override=path)
return driver
return drivers.get_driver("PyTest", str(random.randint(1, 999999)), data_path_override=path)
@pytest.fixture()
def config(json_driver):
def config(driver):
config_module._config_cache = weakref.WeakValueDictionary()
conf = Config(
cog_name="PyTest", unique_identifier=json_driver.unique_cog_identifier, driver=json_driver
)
conf = Config(cog_name="PyTest", unique_identifier=driver.unique_cog_identifier, driver=driver)
yield conf
@pytest.fixture()
def config_fr(json_driver):
def config_fr(driver):
"""
Mocked config object with force_register enabled.
"""
config_module._config_cache = weakref.WeakValueDictionary()
conf = Config(
cog_name="PyTest",
unique_identifier=json_driver.unique_cog_identifier,
driver=json_driver,
unique_identifier=driver.unique_cog_identifier,
driver=driver,
force_registration=True,
)
yield conf
@ -176,7 +171,7 @@ def red(config_fr):
Config.get_core_conf = lambda *args, **kwargs: config_fr
red = Red(cli_flags=cli_flags, description=description, dm_help=None)
red = Red(cli_flags=cli_flags, description=description, dm_help=None, owner_id=None)
yield red

View File

@ -1,32 +1,21 @@
#!/usr/bin/env python3
import asyncio
import json
import logging
import os
import sys
import tarfile
from copy import deepcopy
from datetime import datetime as dt
from pathlib import Path
import logging
from typing import Dict, Any, Optional
import appdirs
import click
import redbot.logging
from redbot.core.cli import confirm
from redbot.core.data_manager import (
basic_config_default,
load_basic_configuration,
instance_name,
basic_config,
cog_data_path,
core_data_path,
storage_details,
)
from redbot.core.utils import safe_delete
from redbot.core import Config
from redbot.core.utils import safe_delete, create_backup as _create_backup
from redbot.core import config, data_manager, drivers
from redbot.core.drivers import BackendType, IdentifierData
from redbot.core.drivers.red_json import JSON
conversion_log = logging.getLogger("red.converter")
@ -61,11 +50,11 @@ else:
def save_config(name, data, remove=False):
config = load_existing_config()
if remove and name in config:
config.pop(name)
_config = load_existing_config()
if remove and name in _config:
_config.pop(name)
else:
if name in config:
if name in _config:
print(
"WARNING: An instance already exists with this name. "
"Continuing will overwrite the existing instance config."
@ -73,10 +62,10 @@ def save_config(name, data, remove=False):
if not confirm("Are you absolutely certain you want to continue (y/n)? "):
print("Not continuing")
sys.exit(0)
config[name] = data
_config[name] = data
with config_file.open("w", encoding="utf-8") as fs:
json.dump(config, fs, indent=4)
json.dump(_config, fs, indent=4)
def get_data_dir():
@ -118,13 +107,14 @@ def get_data_dir():
def get_storage_type():
storage_dict = {1: "JSON", 2: "MongoDB"}
storage_dict = {1: "JSON", 2: "MongoDB", 3: "PostgreSQL"}
storage = None
while storage is None:
print()
print("Please choose your storage backend (if you're unsure, choose 1).")
print("1. JSON (file storage, requires no database).")
print("2. MongoDB")
print("3. PostgreSQL")
storage = input("> ")
try:
storage = int(storage)
@ -158,21 +148,16 @@ def basic_setup():
default_data_dir = get_data_dir()
default_dirs = deepcopy(basic_config_default)
default_dirs = deepcopy(data_manager.basic_config_default)
default_dirs["DATA_PATH"] = str(default_data_dir.resolve())
storage = get_storage_type()
storage_dict = {1: BackendType.JSON, 2: BackendType.MONGO}
storage_dict = {1: BackendType.JSON, 2: BackendType.MONGO, 3: BackendType.POSTGRES}
storage_type: BackendType = storage_dict.get(storage, BackendType.JSON)
default_dirs["STORAGE_TYPE"] = storage_type.value
if storage_type == BackendType.MONGO:
from redbot.core.drivers.red_mongo import get_config_details
default_dirs["STORAGE_DETAILS"] = get_config_details()
else:
default_dirs["STORAGE_DETAILS"] = {}
driver_cls = drivers.get_driver_class(storage_type)
default_dirs["STORAGE_DETAILS"] = driver_cls.get_config_details()
name = get_name()
save_config(name, default_dirs)
@ -193,130 +178,38 @@ def get_target_backend(backend) -> BackendType:
return BackendType.JSON
elif backend == "mongo":
return BackendType.MONGO
elif backend == "postgres":
return BackendType.POSTGRES
async def json_to_mongov2(instance):
instance_vals = instance_data[instance]
current_data_dir = Path(instance_vals["DATA_PATH"])
async def do_migration(
current_backend: BackendType, target_backend: BackendType
) -> Dict[str, Any]:
cur_driver_cls = drivers.get_driver_class(current_backend)
new_driver_cls = drivers.get_driver_class(target_backend)
cur_storage_details = data_manager.storage_details()
new_storage_details = new_driver_cls.get_config_details()
load_basic_configuration(instance)
await cur_driver_cls.initialize(**cur_storage_details)
await new_driver_cls.initialize(**new_storage_details)
from redbot.core.drivers import red_mongo
await config.migrate(cur_driver_cls, new_driver_cls)
storage_details = red_mongo.get_config_details()
await cur_driver_cls.teardown()
await new_driver_cls.teardown()
core_conf = Config.get_core_conf()
new_driver = red_mongo.Mongo(cog_name="Core", identifier="0", **storage_details)
core_conf.init_custom("CUSTOM_GROUPS", 2)
custom_group_data = await core_conf.custom("CUSTOM_GROUPS").all()
curr_custom_data = custom_group_data.get("Core", {}).get("0", {})
exported_data = await core_conf.driver.export_data(curr_custom_data)
conversion_log.info("Starting Core conversion...")
await new_driver.import_data(exported_data, curr_custom_data)
conversion_log.info("Core conversion complete.")
for p in current_data_dir.glob("cogs/**/settings.json"):
cog_name = p.parent.stem
if "." in cog_name:
# Garbage handler
continue
with p.open(mode="r", encoding="utf-8") as f:
cog_data = json.load(f)
for identifier, all_data in cog_data.items():
try:
conf = Config.get_conf(None, int(identifier), cog_name=cog_name)
except ValueError:
continue
new_driver = red_mongo.Mongo(
cog_name=cog_name, identifier=conf.driver.unique_cog_identifier, **storage_details
)
curr_custom_data = custom_group_data.get(cog_name, {}).get(identifier, {})
exported_data = await conf.driver.export_data(curr_custom_data)
conversion_log.info(f"Converting {cog_name} with identifier {identifier}...")
await new_driver.import_data(exported_data, curr_custom_data)
conversion_log.info("Cog conversion complete.")
return storage_details
return new_storage_details
async def mongov2_to_json(instance):
load_basic_configuration(instance)
core_path = core_data_path()
from redbot.core.drivers import red_json
core_conf = Config.get_core_conf()
new_driver = red_json.JSON(cog_name="Core", identifier="0", data_path_override=core_path)
core_conf.init_custom("CUSTOM_GROUPS", 2)
custom_group_data = await core_conf.custom("CUSTOM_GROUPS").all()
curr_custom_data = custom_group_data.get("Core", {}).get("0", {})
exported_data = await core_conf.driver.export_data(curr_custom_data)
conversion_log.info("Starting Core conversion...")
await new_driver.import_data(exported_data, curr_custom_data)
conversion_log.info("Core conversion complete.")
collection_names = await core_conf.driver.db.list_collection_names()
splitted_names = list(
filter(
lambda elem: elem[1] != "" and elem[0] != "Core",
[n.split(".") for n in collection_names],
)
)
ident_map = {} # Cogname: idents list
for cog_name, category in splitted_names:
if cog_name not in ident_map:
ident_map[cog_name] = set()
idents = await core_conf.driver.db[cog_name][category].distinct("_id.RED_uuid")
ident_map[cog_name].update(set(idents))
for cog_name, idents in ident_map.items():
for identifier in idents:
curr_custom_data = custom_group_data.get(cog_name, {}).get(identifier, {})
try:
conf = Config.get_conf(None, int(identifier), cog_name=cog_name)
except ValueError:
continue
exported_data = await conf.driver.export_data(curr_custom_data)
new_path = cog_data_path(raw_name=cog_name)
new_driver = red_json.JSON(cog_name, identifier, data_path_override=new_path)
conversion_log.info(f"Converting {cog_name} with identifier {identifier}...")
await new_driver.import_data(exported_data, curr_custom_data)
# cog_data_path(raw_name=cog_name)
conversion_log.info("Cog conversion complete.")
return {}
async def mongo_to_json(instance):
load_basic_configuration(instance)
from redbot.core.drivers.red_mongo import Mongo
m = Mongo("Core", "0", **storage_details())
async def mongov1_to_json() -> Dict[str, Any]:
await drivers.MongoDriver.initialize(**data_manager.storage_details())
m = drivers.MongoDriver("Core", "0")
db = m.db
collection_names = await db.list_collection_names()
for collection_name in collection_names:
if "." in collection_name:
# Fix for one of Zeph's problems
continue
elif collection_name == "Core":
c_data_path = core_data_path()
else:
c_data_path = cog_data_path(raw_name=collection_name)
c_data_path.mkdir(parents=True, exist_ok=True)
# Every cog name has its own collection
collection = db[collection_name]
async for document in collection.find():
@ -329,16 +222,22 @@ async def mongo_to_json(instance):
continue
elif not str(cog_id).isdigit():
continue
driver = JSON(collection_name, cog_id, data_path_override=c_data_path)
driver = drivers.JsonDriver(collection_name, cog_id)
for category, value in document.items():
ident_data = IdentifierData(str(cog_id), category, (), (), {})
ident_data = IdentifierData(
str(collection_name), str(cog_id), category, tuple(), tuple(), 0
)
await driver.set(ident_data, value=value)
conversion_log.info("Cog conversion complete.")
await drivers.MongoDriver.teardown()
return {}
async def edit_instance():
instance_list = load_existing_config()
if not instance_list:
_instance_list = load_existing_config()
if not _instance_list:
print("No instances have been set up!")
return
@ -346,18 +245,18 @@ async def edit_instance():
"You have chosen to edit an instance. The following "
"is a list of instances that currently exist:\n"
)
for instance in instance_list.keys():
for instance in _instance_list.keys():
print("{}\n".format(instance))
print("Please select one of the above by entering its name")
selected = input("> ")
if selected not in instance_list.keys():
if selected not in _instance_list.keys():
print("That isn't a valid instance!")
return
instance_data = instance_list[selected]
default_dirs = deepcopy(basic_config_default)
_instance_data = _instance_list[selected]
default_dirs = deepcopy(data_manager.basic_config_default)
current_data_dir = Path(instance_data["DATA_PATH"])
current_data_dir = Path(_instance_data["DATA_PATH"])
print("You have selected '{}' as the instance to modify.".format(selected))
if not confirm("Please confirm (y/n):"):
print("Ok, we will not continue then.")
@ -383,68 +282,47 @@ async def edit_instance():
print("Your basic configuration has been edited")
async def create_backup(instance):
instance_vals = instance_data[instance]
if confirm("Would you like to make a backup of the data for this instance? (y/n)"):
load_basic_configuration(instance)
if instance_vals["STORAGE_TYPE"] == "MongoDB":
await mongo_to_json(instance)
print("Backing up the instance's data...")
backup_filename = "redv3-{}-{}.tar.gz".format(
instance, dt.utcnow().strftime("%Y-%m-%d %H-%M-%S")
)
pth = Path(instance_vals["DATA_PATH"])
if pth.exists():
backup_pth = pth.home()
backup_file = backup_pth / backup_filename
to_backup = []
exclusions = [
"__pycache__",
"Lavalink.jar",
os.path.join("Downloader", "lib"),
os.path.join("CogManager", "cogs"),
os.path.join("RepoManager", "repos"),
]
from redbot.cogs.downloader.repo_manager import RepoManager
repo_mgr = RepoManager()
await repo_mgr.initialize()
repo_output = []
for repo in repo_mgr._repos.values():
repo_output.append({"url": repo.url, "name": repo.name, "branch": repo.branch})
repo_filename = pth / "cogs" / "RepoManager" / "repos.json"
with open(str(repo_filename), "w") as f:
f.write(json.dumps(repo_output, indent=4))
instance_vals = {instance_name: basic_config}
instance_file = pth / "instance.json"
with open(str(instance_file), "w") as instance_out:
instance_out.write(json.dumps(instance_vals, indent=4))
for f in pth.glob("**/*"):
if not any(ex in str(f) for ex in exclusions):
to_backup.append(f)
with tarfile.open(str(backup_file), "w:gz") as tar:
for f in to_backup:
tar.add(str(f), recursive=False)
print("A backup of {} has been made. It is at {}".format(instance, backup_file))
async def remove_instance(instance):
await create_backup(instance)
instance_vals = instance_data[instance]
if instance_vals["STORAGE_TYPE"] == "MongoDB":
from redbot.core.drivers.red_mongo import Mongo
m = Mongo("Core", **instance_vals["STORAGE_DETAILS"])
db = m.db
collections = await db.collection_names(include_system_collections=False)
for name in collections:
collection = await db.get_collection(name)
await collection.drop()
async def create_backup(instance: str) -> None:
data_manager.load_basic_configuration(instance)
backend_type = get_current_backend(instance)
if backend_type == BackendType.MONGOV1:
await mongov1_to_json()
elif backend_type != BackendType.JSON:
await do_migration(backend_type, BackendType.JSON)
print("Backing up the instance's data...")
backup_fpath = await _create_backup()
if backup_fpath is not None:
print(f"A backup of {instance} has been made. It is at {backup_fpath}")
else:
pth = Path(instance_vals["DATA_PATH"])
safe_delete(pth)
print("Creating the backup failed.")
async def remove_instance(
instance,
interactive: bool = False,
drop_db: Optional[bool] = None,
remove_datapath: Optional[bool] = None,
):
data_manager.load_basic_configuration(instance)
if confirm("Would you like to make a backup of the data for this instance? (y/n)"):
await create_backup(instance)
backend = get_current_backend(instance)
if backend == BackendType.MONGOV1:
driver_cls = drivers.MongoDriver
else:
driver_cls = drivers.get_driver_class(backend)
await driver_cls.delete_all_data(interactive=interactive, drop_db=drop_db)
if interactive is True and remove_datapath is None:
remove_datapath = confirm("Would you like to delete the instance's entire datapath? (y/n)")
if remove_datapath is True:
data_path = data_manager.core_data_path().parent
safe_delete(data_path)
save_config(instance, {}, remove=True)
print("The instance {} has been removed\n".format(instance))
@ -467,8 +345,7 @@ async def remove_instance_interaction():
print("That isn't a valid instance!")
return
await create_backup(selected)
await remove_instance(selected)
await remove_instance(selected, interactive=True)
@click.group(invoke_without_command=True)
@ -483,38 +360,56 @@ def cli(ctx, debug):
@cli.command()
@click.argument("instance", type=click.Choice(instance_list))
def delete(instance):
@click.option("--no-prompt", default=False, help="Don't ask for user input during the process.")
@click.option(
"--drop-db",
type=bool,
default=None,
help=(
"Drop the entire database constaining this instance's data. Has no effect on JSON "
"instances. If this option and --no-prompt are omitted, you will be asked about this."
),
)
@click.option(
"--remove-datapath",
type=bool,
default=None,
help=(
"Remove this entire instance's datapath. If this option and --no-prompt are omitted, you "
"will be asked about this."
),
)
def delete(instance: str, no_prompt: Optional[bool], drop_db: Optional[bool]):
loop = asyncio.get_event_loop()
loop.run_until_complete(remove_instance(instance))
if no_prompt is None:
interactive = None
else:
interactive = not no_prompt
loop.run_until_complete(remove_instance(instance, interactive, drop_db))
@cli.command()
@click.argument("instance", type=click.Choice(instance_list))
@click.argument("backend", type=click.Choice(["json", "mongo"]))
@click.argument("backend", type=click.Choice(["json", "mongo", "postgres"]))
def convert(instance, backend):
current_backend = get_current_backend(instance)
target = get_target_backend(backend)
data_manager.load_basic_configuration(instance)
default_dirs = deepcopy(basic_config_default)
default_dirs = deepcopy(data_manager.basic_config_default)
default_dirs["DATA_PATH"] = str(Path(instance_data[instance]["DATA_PATH"]))
loop = asyncio.get_event_loop()
new_storage_details = None
if current_backend == BackendType.MONGOV1:
if target == BackendType.MONGO:
if target == BackendType.JSON:
new_storage_details = loop.run_until_complete(mongov1_to_json())
else:
raise RuntimeError(
"Please see conversion docs for updating to the latest mongo version."
)
elif target == BackendType.JSON:
new_storage_details = loop.run_until_complete(mongo_to_json(instance))
elif current_backend == BackendType.JSON:
if target == BackendType.MONGO:
new_storage_details = loop.run_until_complete(json_to_mongov2(instance))
elif current_backend == BackendType.MONGO:
if target == BackendType.JSON:
new_storage_details = loop.run_until_complete(mongov2_to_json(instance))
else:
new_storage_details = loop.run_until_complete(do_migration(current_backend, target))
if new_storage_details is not None:
default_dirs["STORAGE_TYPE"] = target.value
@ -522,7 +417,9 @@ def convert(instance, backend):
save_config(instance, default_dirs)
conversion_log.info(f"Conversion to {target} complete.")
else:
conversion_log.info(f"Cannot convert {current_backend} to {target} at this time.")
conversion_log.info(
f"Cannot convert {current_backend.value} to {target.value} at this time."
)
if __name__ == "__main__":

View File

@ -80,6 +80,8 @@ mongo =
dnspython==1.16.0
motor==2.0.0
pymongo==3.8.0
postgres =
asyncpg==0.18.3
style =
black==19.3b0
toml==0.10.0
@ -123,3 +125,5 @@ include =
**/locales/*.po
data/*
data/**/*
redbot.core.drivers.postgres =
*.sql

45
tests/conftest.py Normal file
View File

@ -0,0 +1,45 @@
import asyncio
import os
import pytest
from redbot.core import drivers, data_manager
@pytest.fixture(scope="session")
def event_loop(request):
"""Create an instance of the default event loop for entire session."""
loop = asyncio.get_event_loop_policy().new_event_loop()
yield loop
loop.close()
def _get_backend_type():
if os.getenv("RED_STORAGE_TYPE") == "postgres":
return drivers.BackendType.POSTGRES
elif os.getenv("RED_STORAGE_TYPE") == "mongo":
return drivers.BackendType.MONGO
else:
return drivers.BackendType.JSON
@pytest.fixture(scope="session", autouse=True)
async def _setup_driver():
backend_type = _get_backend_type()
if backend_type == drivers.BackendType.MONGO:
storage_details = {
"URI": os.getenv("RED_MONGO_URI", "mongodb"),
"HOST": os.getenv("RED_MONGO_HOST", "localhost"),
"PORT": int(os.getenv("RED_MONGO_PORT", "27017")),
"USERNAME": os.getenv("RED_MONGO_USER", "red"),
"PASSWORD": os.getenv("RED_MONGO_PASSWORD", "red"),
"DB_NAME": os.getenv("RED_MONGO_DATABASE", "red_db"),
}
else:
storage_details = {}
data_manager.storage_type = lambda: backend_type.value
data_manager.storage_details = lambda: storage_details
driver_cls = drivers.get_driver_class(backend_type)
await driver_cls.initialize(**storage_details)
yield
await driver_cls.teardown()

View File

@ -1,3 +1,3 @@
packaging
tox
-e .[docs,mongo,style,test]
-e .[docs,mongo,postgres,style,test]

View File

@ -34,6 +34,8 @@ docs =
mongo =
dnspython
motor
postgres =
asyncpg
style =
black
test =

37
tox.ini
View File

@ -15,12 +15,47 @@ description = Run tests and basic automatic issue checking.
whitelist_externals =
pytest
pylint
extras = voice, test, mongo
extras = voice, test
commands =
python -m compileall ./redbot/cogs
pytest
pylint ./redbot
[testenv:postgres]
description = Run pytest with PostgreSQL backend
whitelist_externals =
pytest
extras = voice, test, postgres
setenv =
RED_STORAGE_TYPE=postgres
passenv =
# Use the following env vars for connection options, or other default options described here:
# https://magicstack.github.io/asyncpg/current/index.html#asyncpg.connection.connect
PGHOST
PGPORT
PGUSER
PGPASSWORD
PGDATABASE
commands =
pytest
[testenv:mongo]
description = Run pytest with MongoDB backend
whitelist_externals =
pytest
extras = voice, test, mongo
setenv =
RED_STORAGE_TYPE=mongo
passenv =
RED_MONGO_URI
RED_MONGO_HOST
RED_MONGO_PORT
RED_MONGO_USER
RED_MONGO_PASSWORD
RED_MONGO_DATABASE
commands =
pytest
[testenv:docs]
description = Attempt to build docs with sphinx-build
whitelist_externals =