Mirror of https://github.com/Cog-Creators/Red-DiscordBot.git (synced 2025-11-06 03:08:55 -05:00)
PostgreSQL driver, tests against DB backends, and general drivers cleanup (#2723)
* PostgreSQL driver and general drivers cleanup
* Make tests pass
* Add black --target-version flag in make.bat
* Rewrite postgres driver

  Most of the logic is now in PL/pgSQL. This completely avoids the use of Python f-strings to format identifiers into queries. Although an SQL-injection attack would have been impossible anyway (only the owner would have ever had the ability to do that), using PostgreSQL's format() is more reliable for unusual identifiers. Performance-wise, I'm not sure whether this is an improvement, but I highly doubt that it's worse.

* Reformat
* Fix PostgresDriver.delete_all_data()
* Clean up PL/pgSQL code
* More PL/pgSQL cleanup
* PL/pgSQL function optimisations
* Ensure compatibility with PostgreSQL 10 and below
* More/better docstrings for PG functions
* Fix typo in docstring
* Return correct value on toggle()
* Use composite type for PG function parameters
* Fix JSON driver's Config.clear_all()
* Correct description for Mongo tox recipe
* Fix linting errors
* Update dep specification after merging bumpdeps
* Add towncrier entries
* Update from merge
* Mention [postgres] extra in install docs
* Support more connection options and use better defaults
* Actually pass PG env vars in tox
* Replace event trigger with manual DELETE queries

Each commit signed-off-by: Toby Harradine <tobyharradine@gmail.com>
Parent: 57fa29dd64
Commit: d1a46acc9a
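The rationale in the commit message (preferring PostgreSQL's format() over Python f-strings for identifiers) can be illustrated with a small, self-contained sketch. This is not the actual driver code from this PR; asyncpg as the client library and the red_count_rows function are assumptions made purely for illustration.

.. code-block:: python

    import asyncio

    import asyncpg  # assumed client library; the real driver's choice may differ


    async def main() -> None:
        conn = await asyncpg.connect()  # connection details assumed to come from env vars

        # Fragile approach: interpolating an identifier with an f-string breaks on
        # unusual names (quotes, mixed case), even when injection isn't a concern.
        table = "My weird.table"
        # await conn.execute(f"SELECT count(*) FROM {table}")  # would fail to parse

        # More reliable approach: let PostgreSQL quote the identifier itself with
        # format('%I', ...) inside a PL/pgSQL function, passing the name as a parameter.
        await conn.execute(
            """
            CREATE OR REPLACE FUNCTION red_count_rows(tbl text) RETURNS bigint AS $$
            DECLARE result bigint;
            BEGIN
                EXECUTE format('SELECT count(*) FROM %I', tbl) INTO result;
                RETURN result;
            END;
            $$ LANGUAGE plpgsql;
            """
        )
        print(await conn.fetchval("SELECT red_count_rows($1)", "my_table"))
        await conn.close()


    asyncio.run(main())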
.travis.yml (19 changed lines)

@@ -5,14 +5,10 @@ notifications:
  email: false

python:
  - 3.7.2
  - 3.7.3
env:
  global:
    - PIPENV_IGNORE_VIRTUALENVS=1
  matrix:
    - TOXENV=py
    - TOXENV=docs
    - TOXENV=style

install:
  - pip install --upgrade pip tox

@@ -22,6 +18,19 @@ script:

jobs:
  include:
    - env: TOXENV=py
    - env: TOXENV=docs
    - env: TOXENV=style
    - env: TOXENV=postgres
      services: postgresql
      addons:
        postgresql: "10"
      before_script:
        - psql -c 'create database red_db;' -U postgres
    - env: TOXENV=mongo
      services: mongodb
      before_script:
        - mongo red_db --eval 'db.createUser({user:"red",pwd:"red",roles:["readWrite"]});'
    # These jobs only occur on tag creation if the prior ones succeed
    - stage: PyPi Deployment
      if: tag IS present
Makefile (4 changed lines)

@@ -1,8 +1,8 @@
# Python Code Style
reformat:
	black -l 99 `git ls-files "*.py"`
	black -l 99 --target-version py37 `git ls-files "*.py"`
stylecheck:
	black --check -l 99 `git ls-files "*.py"`
	black --check -l 99 --target-version py37 `git ls-files "*.py"`

# Translations
gettext:
changelog.d/2723.feature.rst (new file, 1 line)

@@ -0,0 +1 @@
Added a config driver for PostgreSQL
changelog.d/2723.misc.rst (new file, 32 lines)

@@ -0,0 +1,32 @@
Changes to the ``redbot.core.drivers`` package:

- The modules inside the ``redbot.core.drivers`` package no longer have the ``red_`` prefix in front of their names.
- All driver classes are now directly accessible as attributes to the ``redbot.core.drivers`` package.
- :func:`get_driver`'s signature has been changed.
- :func:`get_driver` can now use data manager to infer the backend type if it is not supplied as an argument.
- :func:`get_driver_class` has been added.

Changes to the :class:`BaseDriver` and :class:`JsonDriver` classes:

- :meth:`BaseDriver.get_config_details` is now an abstract staticmethod.
- :meth:`BaseDriver.initialize` and :meth:`BaseDriver.teardown` are two new abstract coroutine classmethods.
- :meth:`BaseDriver.delete_all_data` is a new concrete (but overridable) coroutine method.
- :meth:`BaseDriver.aiter_cogs` is a new abstract asynchronous iterator method.
- :meth:`BaseDriver.migrate_to` is a new concrete coroutine classmethod.
- :class:`JsonDriver` no longer requires the data path when constructed and will infer the data path from data manager.

Changes to the :class:`IdentifierData` class and :class:`ConfigCategory` enum:

- ``IdentifierData.custom_group_data`` has been replaced by :attr:`IdentifierData.primary_key_len`.
- :meth:`ConfigCategory.get_pkey_info` is a new classmethod.

Changes to the migration and backup system:

- All code in the ``redbot.setup`` script, excluding that regarding MongoV1, is now virtually backend-agnostic.
- All code in the ``[p]backup`` command is now backend-agnostic.
- :func:`redbot.core.config.migrate` is a new coroutine function.
- All a new driver needs to do now to be compatible with migrations and backups is to implement the :class:`BaseDriver` ABC.

Enhancements to unit tests:

- New tox recipes have been added for testing against Mongo and Postgres backends. See the ``tox.ini`` file for clues on how to run them.
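A minimal usage sketch of the reworked drivers API described in the changelog above, based only on the signatures visible in this diff (``get_driver_class()``, ``get_driver()``, ``config.migrate()``); the instance name "mybot" is an assumption for illustration, and the unused ``core_driver`` merely demonstrates the new ``get_driver(cog_name, identifier)`` call:

.. code-block:: python

    import asyncio

    from redbot.core import config, data_manager, drivers


    async def main() -> None:
        # Load instance metadata first so the backend type can be inferred.
        data_manager.load_basic_configuration("mybot")  # instance name is an assumption

        # Resolve the configured backend and initialize it.
        driver_cls = drivers.get_driver_class()
        await driver_cls.initialize(**data_manager.storage_details())

        # Per-cog driver instances are keyed by cog name and identifier.
        core_driver = drivers.get_driver("Core", "0")
        print(type(core_driver).__name__)

        # Backend-agnostic migration, e.g. to JSON before taking a backup.
        if driver_cls is not drivers.JsonDriver:
            await config.migrate(driver_cls, drivers.JsonDriver)

        await driver_cls.teardown()


    asyncio.run(main())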
@@ -421,15 +421,15 @@ Driver Reference

Base Driver
^^^^^^^^^^^
.. autoclass:: redbot.core.drivers.red_base.BaseDriver
.. autoclass:: redbot.core.drivers.BaseDriver
    :members:

JSON Driver
^^^^^^^^^^^
.. autoclass:: redbot.core.drivers.red_json.JSON
.. autoclass:: redbot.core.drivers.JsonDriver
    :members:

Mongo Driver
^^^^^^^^^^^^
.. autoclass:: redbot.core.drivers.red_mongo.Mongo
.. autoclass:: redbot.core.drivers.MongoDriver
    :members:
@@ -200,6 +200,12 @@ Or, to install with MongoDB support:

    python3.7 -m pip install -U Red-DiscordBot[mongo]

Or, to install with PostgreSQL support:

.. code-block:: none

    python3.7 -m pip install -U Red-DiscordBot[postgres]

.. note::

    To install the development version, replace ``Red-DiscordBot`` in the above commands with the

@@ -65,7 +65,7 @@ Installing Red

If you're not inside an activated virtual environment, include the ``--user`` flag with all
``pip`` commands.

* No MongoDB support:
* Normal installation:

.. code-block:: none

@@ -77,6 +77,12 @@ Installing Red

    python -m pip install -U Red-DiscordBot[mongo]

* With PostgreSQL support:

.. code-block:: none

    python3.7 -m pip install -U Red-DiscordBot[postgres]

.. note::

    To install the development version, replace ``Red-DiscordBot`` in the above commands with the
make.bat (4 changed lines)

@@ -14,11 +14,11 @@ for /F "tokens=* USEBACKQ" %%A in (`git ls-files "*.py"`) do (
goto %1

:reformat
black -l 99 !PYFILES!
black -l 99 --target-version py37 !PYFILES!
exit /B %ERRORLEVEL%

:stylecheck
black -l 99 --check !PYFILES!
black -l 99 --check --target-version py37 !PYFILES!
exit /B %ERRORLEVEL%

:newenv
@@ -33,7 +33,7 @@ from redbot.core.events import init_events
from redbot.core.cli import interactive_config, confirm, parse_cli_flags
from redbot.core.core_commands import Core
from redbot.core.dev_commands import Dev
from redbot.core import __version__, modlog, bank, data_manager
from redbot.core import __version__, modlog, bank, data_manager, drivers
from signal import SIGTERM

@@ -99,7 +99,11 @@ def main():
        )
        cli_flags.instance_name = "temporary_red"
        data_manager.create_temp_config()
    loop = asyncio.get_event_loop()

    data_manager.load_basic_configuration(cli_flags.instance_name)
    driver_cls = drivers.get_driver_class()
    loop.run_until_complete(driver_cls.initialize(**data_manager.storage_details()))
    redbot.logging.init_logging(
        level=cli_flags.logging_level, location=data_manager.core_data_path() / "logs"
    )

@@ -111,7 +115,6 @@ def main():
    red = Red(
        cli_flags=cli_flags, description=description, dm_help=None, fetch_offline_members=True
    )
    loop = asyncio.get_event_loop()
    loop.run_until_complete(red.maybe_update_config())
    init_global_checks(red)
    init_events(red, cli_flags)
@@ -251,7 +251,7 @@ class Warnings(commands.Cog):
        user: discord.Member,
        points: Optional[int] = 1,
        *,
        reason: str
        reason: str,
    ):
        """Warn the user for the specified reason.
@@ -1,7 +1,7 @@
import asyncio
import inspect
import os
import logging
import os
from collections import Counter
from enum import Enum
from importlib.machinery import ModuleSpec

@@ -9,10 +9,9 @@ from pathlib import Path
from typing import Optional, Union, List

import discord
import sys
from discord.ext.commands import when_mentioned_or

from . import Config, i18n, commands, errors
from . import Config, i18n, commands, errors, drivers
from .cog_manager import CogManager

from .rpc import RPCMixin

@@ -592,8 +591,8 @@ class Red(RedBase, discord.AutoShardedClient):

    async def logout(self):
        """Logs out of Discord and closes all connections."""

        await super().logout()
        await drivers.get_driver_class().teardown()

    async def shutdown(self, *, restart: bool = False):
        """Gracefully quit Red.
@@ -5,23 +5,22 @@ import pickle
import weakref
from typing import (
    Any,
    Union,
    Tuple,
    Dict,
    Awaitable,
    AsyncContextManager,
    TypeVar,
    Awaitable,
    Dict,
    MutableMapping,
    Optional,
    Tuple,
    Type,
    TypeVar,
    Union,
)

import discord

from .data_manager import cog_data_path, core_data_path
from .drivers import get_driver, IdentifierData, BackendType
from .drivers.red_base import BaseDriver
from .drivers import IdentifierData, get_driver, ConfigCategory, BaseDriver

__all__ = ["Config", "get_latest_confs"]
__all__ = ["Config", "get_latest_confs", "migrate"]

log = logging.getLogger("red.config")

@@ -101,7 +100,7 @@ class Value:
        Information on identifiers for this value.
    default
        The default value for the data element that `identifiers` points at.
    driver : `redbot.core.drivers.red_base.BaseDriver`
    driver : `redbot.core.drivers.BaseDriver`
        A reference to `Config.driver`.

    """

@@ -250,7 +249,7 @@ class Group(Value):
        All registered default values for this Group.
    force_registration : `bool`
        Same as `Config.force_registration`.
    driver : `redbot.core.drivers.red_base.BaseDriver`
    driver : `redbot.core.drivers.BaseDriver`
        A reference to `Config.driver`.

    """

@@ -586,7 +585,7 @@ class Config:
        Unique identifier provided to differentiate cog data when name
        conflicts occur.
    driver
        An instance of a driver that implements `redbot.core.drivers.red_base.BaseDriver`.
        An instance of a driver that implements `redbot.core.drivers.BaseDriver`.
    force_registration : `bool`
        Determines if Config should throw an error if a cog attempts to access
        an attribute which has not been previously registered.

@@ -634,7 +633,7 @@ class Config:
        self.force_registration = force_registration
        self._defaults = defaults or {}

        self.custom_groups = {}
        self.custom_groups: Dict[str, int] = {}
        self._lock_cache: MutableMapping[
            IdentifierData, asyncio.Lock
        ] = weakref.WeakValueDictionary()

@@ -643,10 +642,6 @@ class Config:
    def defaults(self):
        return pickle.loads(pickle.dumps(self._defaults, -1))

    @staticmethod
    def _create_uuid(identifier: int):
        return str(identifier)

    @classmethod
    def get_conf(cls, cog_instance, identifier: int, force_registration=False, cog_name=None):
        """Get a Config instance for your cog.

@@ -681,25 +676,12 @@ class Config:
            A new Config object.

        """
        if cog_instance is None and cog_name is not None:
            cog_path_override = cog_data_path(raw_name=cog_name)
        else:
            cog_path_override = cog_data_path(cog_instance=cog_instance)
        uuid = str(identifier)
        if cog_name is None:
            cog_name = type(cog_instance).__name__

        cog_name = cog_path_override.stem
        # uuid = str(hash(identifier))
        uuid = cls._create_uuid(identifier)

        # We have to import this here otherwise we have a circular dependency
        from .data_manager import basic_config

        driver_name = basic_config.get("STORAGE_TYPE", "JSON")
        driver_details = basic_config.get("STORAGE_DETAILS", {})

        driver = get_driver(
            driver_name, cog_name, uuid, data_path_override=cog_path_override, **driver_details
        )
        if driver_name == BackendType.JSON.value:
        driver = get_driver(cog_name, uuid)
        if hasattr(driver, "migrate_identifier"):
            driver.migrate_identifier(identifier)

        conf = cls(

@@ -712,7 +694,7 @@ class Config:

    @classmethod
    def get_core_conf(cls, force_registration: bool = False):
        """Get a Config instance for a core module.
        """Get a Config instance for the core bot.

        All core modules that require a config instance should use this
        classmethod instead of `get_conf`.

@@ -723,24 +705,9 @@ class Config:
            See `force_registration`.

        """
        core_path = core_data_path()

        # We have to import this here otherwise we have a circular dependency
        from .data_manager import basic_config

        driver_name = basic_config.get("STORAGE_TYPE", "JSON")
        driver_details = basic_config.get("STORAGE_DETAILS", {})

        driver = get_driver(
            driver_name, "Core", "0", data_path_override=core_path, **driver_details
        return cls.get_conf(
            None, cog_name="Core", identifier=0, force_registration=force_registration
        )
        conf = cls(
            cog_name="Core",
            driver=driver,
            unique_identifier="0",
            force_registration=force_registration,
        )
        return conf

    def __getattr__(self, item: str) -> Union[Group, Value]:
        """Same as `group.__getattr__` except for global data.

@@ -916,26 +883,18 @@ class Config:
        self.custom_groups[group_identifier] = identifier_count

    def _get_base_group(self, category: str, *primary_keys: str) -> Group:
        is_custom = category not in (
            self.GLOBAL,
            self.GUILD,
            self.USER,
            self.MEMBER,
            self.ROLE,
            self.CHANNEL,
        )
        # noinspection PyTypeChecker
        pkey_len, is_custom = ConfigCategory.get_pkey_info(category, self.custom_groups)
        identifier_data = IdentifierData(
            cog_name=self.cog_name,
            uuid=self.unique_identifier,
            category=category,
            primary_key=primary_keys,
            identifiers=(),
            custom_group_data=self.custom_groups,
            primary_key_len=pkey_len,
            is_custom=is_custom,
        )

        pkey_len = BaseDriver.get_pkey_len(identifier_data)
        if len(primary_keys) < pkey_len:
        if len(primary_keys) < identifier_data.primary_key_len:
            # Don't mix in defaults with groups higher than the document level
            defaults = {}
        else:

@@ -1220,9 +1179,7 @@ class Config:
        """
        if not scopes:
            # noinspection PyTypeChecker
            identifier_data = IdentifierData(
                self.unique_identifier, "", (), (), self.custom_groups
            )
            identifier_data = IdentifierData(self.cog_name, self.unique_identifier, "", (), (), 0)
            group = Group(identifier_data, defaults={}, driver=self.driver, config=self)
        else:
            cat, *scopes = scopes

@@ -1359,7 +1316,12 @@ class Config:
            return self.get_custom_lock(self.GUILD)
        else:
            id_data = IdentifierData(
                self.unique_identifier, self.MEMBER, (str(guild.id),), (), self.custom_groups
                self.cog_name,
                self.unique_identifier,
                category=self.MEMBER,
                primary_key=(str(guild.id),),
                identifiers=(),
                primary_key_len=2,
            )
            return self._lock_cache.setdefault(id_data, asyncio.Lock())

@@ -1375,10 +1337,33 @@
        -------
        asyncio.Lock
        """
        id_data = IdentifierData(
            self.unique_identifier, group_identifier, (), (), self.custom_groups
        )
        return self._lock_cache.setdefault(id_data, asyncio.Lock())
        try:
            pkey_len, is_custom = ConfigCategory.get_pkey_info(
                group_identifier, self.custom_groups
            )
        except KeyError:
            raise ValueError(f"Custom group not initialized: {group_identifier}") from None
        else:
            id_data = IdentifierData(
                self.cog_name,
                self.unique_identifier,
                category=group_identifier,
                primary_key=(),
                identifiers=(),
                primary_key_len=pkey_len,
                is_custom=is_custom,
            )
            return self._lock_cache.setdefault(id_data, asyncio.Lock())


async def migrate(cur_driver_cls: Type[BaseDriver], new_driver_cls: Type[BaseDriver]) -> None:
    """Migrate from one driver type to another."""
    # Get custom group data
    core_conf = Config.get_core_conf()
    core_conf.init_custom("CUSTOM_GROUPS", 2)
    all_custom_group_data = await core_conf.custom("CUSTOM_GROUPS").all()

    await cur_driver_cls.migrate_to(new_driver_cls, all_custom_group_data)


def _str_key_dict(value: Dict[Any, _T]) -> Dict[str, _T]:
@@ -3,15 +3,12 @@ import contextlib
import datetime
import importlib
import itertools
import json
import logging
import os
import pathlib
import sys
import platform
import getpass
import pip
import tarfile
import traceback
from collections import namedtuple
from pathlib import Path

@@ -23,15 +20,18 @@ import aiohttp
import discord
import pkg_resources

from redbot.core import (
from . import (
    __version__,
    version_info as red_version_info,
    VersionInfo,
    checks,
    commands,
    drivers,
    errors,
    i18n,
    config,
)
from .utils import create_backup
from .utils.predicates import MessagePredicate
from .utils.chat_formatting import humanize_timedelta, pagify, box, inline, humanize_list
from .commands.requires import PrivilegeLevel

@@ -1307,105 +1307,71 @@ class Core(commands.Cog, CoreLogic):

    @commands.command()
    @checks.is_owner()
    async def backup(self, ctx: commands.Context, *, backup_path: str = None):
        """Creates a backup of all data for the instance."""
        if backup_path:
            path = pathlib.Path(backup_path)
            if not (path.exists() and path.is_dir()):
                return await ctx.send(
                    _("That path doesn't seem to exist. Please provide a valid path.")
                )
        from redbot.core.data_manager import basic_config, instance_name
        from redbot.core.drivers.red_json import JSON
    async def backup(self, ctx: commands.Context, *, backup_dir: str = None):
        """Creates a backup of all data for the instance.

        data_dir = Path(basic_config["DATA_PATH"])
        if basic_config["STORAGE_TYPE"] == "MongoDB":
            from redbot.core.drivers.red_mongo import Mongo

            m = Mongo("Core", "0", **basic_config["STORAGE_DETAILS"])
            db = m.db
            collection_names = await db.list_collection_names()
            for c_name in collection_names:
                if c_name == "Core":
                    c_data_path = data_dir / basic_config["CORE_PATH_APPEND"]
                else:
                    c_data_path = data_dir / basic_config["COG_PATH_APPEND"] / c_name
                docs = await db[c_name].find().to_list(None)
                for item in docs:
                    item_id = str(item.pop("_id"))
                    target = JSON(c_name, item_id, data_path_override=c_data_path)
                    target.data = item
                    await target._save()
        backup_filename = "redv3-{}-{}.tar.gz".format(
            instance_name, ctx.message.created_at.strftime("%Y-%m-%d %H-%M-%S")
        )
        if data_dir.exists():
            if not backup_path:
                backup_pth = data_dir.home()
            else:
                backup_pth = Path(backup_path)
            backup_file = backup_pth / backup_filename

            to_backup = []
            exclusions = [
                "__pycache__",
                "Lavalink.jar",
                os.path.join("Downloader", "lib"),
                os.path.join("CogManager", "cogs"),
                os.path.join("RepoManager", "repos"),
            ]
            downloader_cog = ctx.bot.get_cog("Downloader")
            if downloader_cog and hasattr(downloader_cog, "_repo_manager"):
                repo_output = []
                repo_mgr = downloader_cog._repo_manager
                for repo in repo_mgr._repos.values():
                    repo_output.append({"url": repo.url, "name": repo.name, "branch": repo.branch})
                repo_filename = data_dir / "cogs" / "RepoManager" / "repos.json"
                with open(str(repo_filename), "w") as f:
                    f.write(json.dumps(repo_output, indent=4))
            instance_data = {instance_name: basic_config}
            instance_file = data_dir / "instance.json"
            with open(str(instance_file), "w") as instance_out:
                instance_out.write(json.dumps(instance_data, indent=4))
            for f in data_dir.glob("**/*"):
                if not any(ex in str(f) for ex in exclusions):
                    to_backup.append(f)
            with tarfile.open(str(backup_file), "w:gz") as tar:
                for f in to_backup:
                    tar.add(str(f), recursive=False)
            print(str(backup_file))
            await ctx.send(
                _("A backup has been made of this instance. It is at {}.").format(backup_file)
            )
            if backup_file.stat().st_size > 8_000_000:
                await ctx.send(_("This backup is too large to send via DM."))
                return
            await ctx.send(_("Would you like to receive a copy via DM? (y/n)"))

            pred = MessagePredicate.yes_or_no(ctx)
            try:
                await ctx.bot.wait_for("message", check=pred, timeout=60)
            except asyncio.TimeoutError:
                await ctx.send(_("Response timed out."))
            else:
                if pred.result is True:
                    await ctx.send(_("OK, it's on its way!"))
                    try:
                        async with ctx.author.typing():
                            await ctx.author.send(
                                _("Here's a copy of the backup"),
                                file=discord.File(str(backup_file)),
                            )
                    except discord.Forbidden:
                        await ctx.send(
                            _("I don't seem to be able to DM you. Do you have closed DMs?")
                        )
                    except discord.HTTPException:
                        await ctx.send(_("I could not send the backup file."))
                else:
                    await ctx.send(_("OK then."))
        You may provide a path to a directory for the backup archive to
        be placed in. If the directory does not exist, the bot will
        attempt to create it.
        """
        if backup_dir is None:
            dest = Path.home()
        else:
            await ctx.send(_("That directory doesn't seem to exist..."))
            dest = Path(backup_dir)

        driver_cls = drivers.get_driver_class()
        if driver_cls != drivers.JsonDriver:
            await ctx.send(_("Converting data to JSON for backup..."))
            async with ctx.typing():
                await config.migrate(driver_cls, drivers.JsonDriver)

        log.info("Creating backup for this instance...")
        try:
            backup_fpath = await create_backup(dest)
        except OSError as exc:
            await ctx.send(
                _(
                    "Creating the backup archive failed! Please check your console or logs for "
                    "details."
                )
            )
            log.exception("Failed to create backup archive", exc_info=exc)
            return

        if backup_fpath is None:
            await ctx.send(_("Your datapath appears to be empty."))
            return

        log.info("Backup archive created successfully at '%s'", backup_fpath)
        await ctx.send(
            _("A backup has been made of this instance. It is located at `{path}`.").format(
                path=backup_fpath
            )
        )
        if backup_fpath.stat().st_size > 8_000_000:
            await ctx.send(_("This backup is too large to send via DM."))
            return
        await ctx.send(_("Would you like to receive a copy via DM? (y/n)"))

        pred = MessagePredicate.yes_or_no(ctx)
        try:
            await ctx.bot.wait_for("message", check=pred, timeout=60)
        except asyncio.TimeoutError:
            await ctx.send(_("Response timed out."))
        else:
            if pred.result is True:
                await ctx.send(_("OK, it's on its way!"))
                try:
                    async with ctx.author.typing():
                        await ctx.author.send(
                            _("Here's a copy of the backup"), file=discord.File(str(backup_fpath))
                        )
                except discord.Forbidden:
                    await ctx.send(_("I don't seem to be able to DM you. Do you have closed DMs?"))
                except discord.HTTPException:
                    await ctx.send(_("I could not send the backup file."))
            else:
                await ctx.send(_("OK then."))

    @commands.command()
    @commands.cooldown(1, 60, commands.BucketType.user)
@@ -224,7 +224,4 @@ def storage_details() -> dict:
    -------
    dict
    """
    try:
        return basic_config["STORAGE_DETAILS"]
    except KeyError as e:
        raise RuntimeError("Bot basic config has not been loaded yet.") from e
    return basic_config.get("STORAGE_DETAILS", {})
@@ -1,50 +1,110 @@
import enum
from typing import Optional, Type

from .red_base import IdentifierData
from .. import data_manager
from .base import IdentifierData, BaseDriver, ConfigCategory
from .json import JsonDriver
from .mongo import MongoDriver
from .postgres import PostgresDriver

__all__ = ["get_driver", "IdentifierData", "BackendType"]
__all__ = [
    "get_driver",
    "ConfigCategory",
    "IdentifierData",
    "BaseDriver",
    "JsonDriver",
    "MongoDriver",
    "PostgresDriver",
    "BackendType",
]


class BackendType(enum.Enum):
    JSON = "JSON"
    MONGO = "MongoDBV2"
    MONGOV1 = "MongoDB"
    POSTGRES = "Postgres"


def get_driver(type, *args, **kwargs):
_DRIVER_CLASSES = {
    BackendType.JSON: JsonDriver,
    BackendType.MONGO: MongoDriver,
    BackendType.POSTGRES: PostgresDriver,
}


def get_driver_class(storage_type: Optional[BackendType] = None) -> Type[BaseDriver]:
    """Get the driver class for the given storage type.

    Parameters
    ----------
    storage_type : Optional[BackendType]
        The backend you want a driver class for. Omit to try to obtain
        the backend from data manager.

    Returns
    -------
    Type[BaseDriver]
        A subclass of `BaseDriver`.

    Raises
    ------
    ValueError
        If there is no driver for the given storage type.

    """
    Selectively import/load driver classes based on the selected type. This
    is required so that dependencies can differ between installs (e.g. so that
    you don't need to install a mongo dependency if you will just be running a
    json data backend).
    if storage_type is None:
        storage_type = BackendType(data_manager.storage_type())
    try:
        return _DRIVER_CLASSES[storage_type]
    except KeyError:
        raise ValueError(f"No driver found for storage type {storage_type}") from None

    .. note::

        See the respective classes for information on what ``args`` and ``kwargs``
        should be.
def get_driver(
    cog_name: str, identifier: str, storage_type: Optional[BackendType] = None, **kwargs
):
    """Get a driver instance.

    Parameters
    ----------
    cog_name : str
        The cog's name.
    identifier : str
        The cog's discriminator.
    storage_type : Optional[BackendType]
        The backend you want a driver for. Omit to try to obtain the
        backend from data manager.
    **kwargs
        Driver-specific keyword arguments.

    Returns
    -------
    BaseDriver
        A driver instance.

    Raises
    ------
    RuntimeError
        If the storage type is MongoV1 or invalid.

    :param str type:
        One of: json, mongo
    :param args:
        Dependent on driver type.
    :param kwargs:
        Dependent on driver type.
    :return:
        Subclass of :py:class:`.red_base.BaseDriver`.
    """
    if type == "JSON":
        from .red_json import JSON
    if storage_type is None:
        try:
            storage_type = BackendType(data_manager.storage_type())
        except RuntimeError:
            storage_type = BackendType.JSON

        return JSON(*args, **kwargs)
    elif type == "MongoDBV2":
        from .red_mongo import Mongo

        return Mongo(*args, **kwargs)
    elif type == "MongoDB":
        raise RuntimeError(
            "Please convert to JSON first to continue using the bot."
            " This is a required conversion prior to using the new Mongo driver."
            " This message will be updated with a link to the update docs once those"
            " docs have been created."
        )
    raise RuntimeError("Invalid driver type: '{}'".format(type))
    try:
        driver_cls: Type[BaseDriver] = get_driver_class(storage_type)
    except ValueError:
        if storage_type == BackendType.MONGOV1:
            raise RuntimeError(
                "Please convert to JSON first to continue using the bot."
                " This is a required conversion prior to using the new Mongo driver."
                " This message will be updated with a link to the update docs once those"
                " docs have been created."
            ) from None
        else:
            raise RuntimeError(f"Invalid driver type: '{storage_type}'") from None
    return driver_cls(cog_name, identifier, **kwargs)
redbot/core/drivers/base.py (new file, 342 lines)

@@ -0,0 +1,342 @@
import abc
import enum
from typing import Tuple, Dict, Any, Union, List, AsyncIterator, Type

__all__ = ["BaseDriver", "IdentifierData", "ConfigCategory"]


class ConfigCategory(str, enum.Enum):
    GLOBAL = "GLOBAL"
    GUILD = "GUILD"
    CHANNEL = "TEXTCHANNEL"
    ROLE = "ROLE"
    USER = "USER"
    MEMBER = "MEMBER"

    @classmethod
    def get_pkey_info(
        cls, category: Union[str, "ConfigCategory"], custom_group_data: Dict[str, int]
    ) -> Tuple[int, bool]:
        """Get the full primary key length for the given category,
        and whether or not the category is a custom category.
        """
        try:
            # noinspection PyArgumentList
            category_obj = cls(category)
        except ValueError:
            return custom_group_data[category], True
        else:
            return _CATEGORY_PKEY_COUNTS[category_obj], False


_CATEGORY_PKEY_COUNTS = {
    ConfigCategory.GLOBAL: 0,
    ConfigCategory.GUILD: 1,
    ConfigCategory.CHANNEL: 1,
    ConfigCategory.ROLE: 1,
    ConfigCategory.USER: 1,
    ConfigCategory.MEMBER: 2,
}


class IdentifierData:
    def __init__(
        self,
        cog_name: str,
        uuid: str,
        category: str,
        primary_key: Tuple[str, ...],
        identifiers: Tuple[str, ...],
        primary_key_len: int,
        is_custom: bool = False,
    ):
        self._cog_name = cog_name
        self._uuid = uuid
        self._category = category
        self._primary_key = primary_key
        self._identifiers = identifiers
        self.primary_key_len = primary_key_len
        self._is_custom = is_custom

    @property
    def cog_name(self) -> str:
        return self._cog_name

    @property
    def uuid(self) -> str:
        return self._uuid

    @property
    def category(self) -> str:
        return self._category

    @property
    def primary_key(self) -> Tuple[str, ...]:
        return self._primary_key

    @property
    def identifiers(self) -> Tuple[str, ...]:
        return self._identifiers

    @property
    def is_custom(self) -> bool:
        return self._is_custom

    def __repr__(self) -> str:
        return (
            f"<IdentifierData cog_name={self.cog_name} uuid={self.uuid} category={self.category} "
            f"primary_key={self.primary_key} identifiers={self.identifiers}>"
        )

    def __eq__(self, other) -> bool:
        if not isinstance(other, IdentifierData):
            return False
        return (
            self.uuid == other.uuid
            and self.category == other.category
            and self.primary_key == other.primary_key
            and self.identifiers == other.identifiers
        )

    def __hash__(self) -> int:
        return hash((self.uuid, self.category, self.primary_key, self.identifiers))

    def add_identifier(self, *identifier: str) -> "IdentifierData":
        if not all(isinstance(i, str) for i in identifier):
            raise ValueError("Identifiers must be strings.")

        return IdentifierData(
            self.cog_name,
            self.uuid,
            self.category,
            self.primary_key,
            self.identifiers + identifier,
            self.primary_key_len,
            is_custom=self.is_custom,
        )

    def to_tuple(self) -> Tuple[str, ...]:
        return tuple(
            filter(
                None,
                (self.cog_name, self.uuid, self.category, *self.primary_key, *self.identifiers),
            )
        )


class BaseDriver(abc.ABC):
    def __init__(self, cog_name: str, identifier: str, **kwargs):
        self.cog_name = cog_name
        self.unique_cog_identifier = identifier

    @classmethod
    @abc.abstractmethod
    async def initialize(cls, **storage_details) -> None:
        """
        Initialize this driver.

        Parameters
        ----------
        **storage_details
            The storage details required to initialize this driver.
            Should be the same as :func:`data_manager.storage_details`

        Raises
        ------
        MissingExtraRequirements
            If initializing the driver requires an extra which isn't
            installed.

        """
        raise NotImplementedError

    @classmethod
    @abc.abstractmethod
    async def teardown(cls) -> None:
        """
        Tear down this driver.
        """
        raise NotImplementedError

    @staticmethod
    @abc.abstractmethod
    def get_config_details() -> Dict[str, Any]:
        """
        Asks users for additional configuration information necessary
        to use this config driver.

        Returns
        -------
        Dict[str, Any]
            Dictionary of configuration details.
        """
        raise NotImplementedError

    @abc.abstractmethod
    async def get(self, identifier_data: IdentifierData) -> Any:
        """
        Finds the value indicated by the given identifiers.

        Parameters
        ----------
        identifier_data

        Returns
        -------
        Any
            Stored value.
        """
        raise NotImplementedError

    @abc.abstractmethod
    async def set(self, identifier_data: IdentifierData, value=None) -> None:
        """
        Sets the value of the key indicated by the given identifiers.

        Parameters
        ----------
        identifier_data
        value
            Any JSON serializable python object.
        """
        raise NotImplementedError

    @abc.abstractmethod
    async def clear(self, identifier_data: IdentifierData) -> None:
        """
        Clears out the value specified by the given identifiers.

        Equivalent to using ``del`` on a dict.

        Parameters
        ----------
        identifier_data
        """
        raise NotImplementedError

    @classmethod
    @abc.abstractmethod
    def aiter_cogs(cls) -> AsyncIterator[Tuple[str, str]]:
        """Get info for cogs which have data stored on this backend.

        Yields
        ------
        Tuple[str, str]
            Asynchronously yields (cog_name, cog_identifier) tuples.

        """
        raise NotImplementedError

    @classmethod
    async def migrate_to(
        cls,
        new_driver_cls: Type["BaseDriver"],
        all_custom_group_data: Dict[str, Dict[str, Dict[str, int]]],
    ) -> None:
        """Migrate data from this backend to another.

        Both drivers must be initialized beforehand.

        This will only move the data - no instance metadata is modified
        as a result of this operation.

        Parameters
        ----------
        new_driver_cls
            Subclass of `BaseDriver`.
        all_custom_group_data : Dict[str, Dict[str, Dict[str, int]]]
            Dict mapping cog names, to cog IDs, to custom groups, to
            primary key lengths.

        """
        # Backend-agnostic method of migrating from one driver to another.
        async for cog_name, cog_id in cls.aiter_cogs():
            this_driver = cls(cog_name, cog_id)
            other_driver = new_driver_cls(cog_name, cog_id)
            custom_group_data = all_custom_group_data.get(cog_name, {}).get(cog_id, {})
            exported_data = await this_driver.export_data(custom_group_data)
            await other_driver.import_data(exported_data, custom_group_data)

    @classmethod
    async def delete_all_data(cls, **kwargs) -> None:
        """Delete all data being stored by this driver.

        The driver must be initialized before this operation.

        The BaseDriver provides a generic method which may be overridden
        by subclasses.

        Parameters
        ----------
        **kwargs
            Driver-specific kwargs to change the way this method
            operates.

        """
        async for cog_name, cog_id in cls.aiter_cogs():
            driver = cls(cog_name, cog_id)
            await driver.clear(IdentifierData(cog_name, cog_id, "", (), (), 0))

    @staticmethod
    def _split_primary_key(
        category: Union[ConfigCategory, str],
        custom_group_data: Dict[str, int],
        data: Dict[str, Any],
    ) -> List[Tuple[Tuple[str, ...], Dict[str, Any]]]:
        pkey_len = ConfigCategory.get_pkey_info(category, custom_group_data)[0]
        if pkey_len == 0:
            return [((), data)]

        def flatten(levels_remaining, currdata, parent_key=()):
            items = []
            for _k, _v in currdata.items():
                new_key = parent_key + (_k,)
                if levels_remaining > 1:
                    items.extend(flatten(levels_remaining - 1, _v, new_key).items())
                else:
                    items.append((new_key, _v))
            return dict(items)

        ret = []
        for k, v in flatten(pkey_len, data).items():
            ret.append((k, v))
        return ret

    async def export_data(
        self, custom_group_data: Dict[str, int]
    ) -> List[Tuple[str, Dict[str, Any]]]:
        categories = [c.value for c in ConfigCategory]
        categories.extend(custom_group_data.keys())

        ret = []
        for c in categories:
            ident_data = IdentifierData(
                self.cog_name,
                self.unique_cog_identifier,
                c,
                (),
                (),
                *ConfigCategory.get_pkey_info(c, custom_group_data),
            )
            try:
                data = await self.get(ident_data)
            except KeyError:
                continue
            ret.append((c, data))
        return ret

    async def import_data(
        self, cog_data: List[Tuple[str, Dict[str, Any]]], custom_group_data: Dict[str, int]
    ) -> None:
        for category, all_data in cog_data:
            splitted_pkey = self._split_primary_key(category, custom_group_data, all_data)
            for pkey, data in splitted_pkey:
                ident_data = IdentifierData(
                    self.cog_name,
                    self.unique_cog_identifier,
                    category,
                    pkey,
                    (),
                    *ConfigCategory.get_pkey_info(category, custom_group_data),
                )
                await self.set(ident_data, data)
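A small sketch of how the new ``ConfigCategory.get_pkey_info()`` and ``IdentifierData`` primitives from the file above fit together; the cog name "MyCog", identifier "12345", and custom group "TWITCH_STREAMS" are made-up values for illustration:

.. code-block:: python

    from redbot.core.drivers import ConfigCategory, IdentifierData

    # Built-in categories have fixed primary key lengths (MEMBER is keyed by
    # guild ID then member ID, so its length is 2 and it is not custom).
    assert ConfigCategory.get_pkey_info("MEMBER", {}) == (2, False)

    # Custom groups supply their own primary key length via custom_group_data.
    assert ConfigCategory.get_pkey_info("TWITCH_STREAMS", {"TWITCH_STREAMS": 1}) == (1, True)

    # An IdentifierData pins down one location in a cog's data; add_identifier()
    # returns a new object pointing one level deeper.
    ident = IdentifierData("MyCog", "12345", "MEMBER", ("111", "222"), (), primary_key_len=2)
    deeper = ident.add_identifier("xp")
    print(deeper.identifiers)  # ('xp',)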
@@ -5,12 +5,13 @@ import os
import pickle
import weakref
from pathlib import Path
from typing import Any, Dict
from typing import Any, AsyncIterator, Dict, Optional, Tuple
from uuid import uuid4

from .red_base import BaseDriver, IdentifierData
from .. import data_manager, errors
from .base import BaseDriver, IdentifierData, ConfigCategory

__all__ = ["JSON"]
__all__ = ["JsonDriver"]


_shared_datastore = {}

@@ -35,9 +36,10 @@ def finalize_driver(cog_name):
        _finalizers.remove(f)


class JSON(BaseDriver):
# noinspection PyProtectedMember
class JsonDriver(BaseDriver):
    """
    Subclass of :py:class:`.red_base.BaseDriver`.
    Subclass of :py:class:`.BaseDriver`.

    .. py:attribute:: file_name

@@ -50,27 +52,26 @@ class JSON(BaseDriver):

    def __init__(
        self,
        cog_name,
        identifier,
        cog_name: str,
        identifier: str,
        *,
        data_path_override: Path = None,
        file_name_override: str = "settings.json"
        data_path_override: Optional[Path] = None,
        file_name_override: str = "settings.json",
    ):
        super().__init__(cog_name, identifier)
        self.file_name = file_name_override
        if data_path_override:
        if data_path_override is not None:
            self.data_path = data_path_override
        elif cog_name == "Core" and identifier == "0":
            self.data_path = data_manager.core_data_path()
        else:
            self.data_path = Path.cwd() / "cogs" / ".data" / self.cog_name
            self.data_path = data_manager.cog_data_path(raw_name=cog_name)
        self.data_path.mkdir(parents=True, exist_ok=True)
        self.data_path = self.data_path / self.file_name

        self._lock = asyncio.Lock()
        self._load_data()

    async def has_valid_connection(self) -> bool:
        return True

    @property
    def data(self):
        return _shared_datastore.get(self.cog_name)

@@ -79,6 +80,21 @@ class JSON(BaseDriver):
    def data(self, value):
        _shared_datastore[self.cog_name] = value

    @classmethod
    async def initialize(cls, **storage_details) -> None:
        # No initializing to do
        return

    @classmethod
    async def teardown(cls) -> None:
        # No tearing down to do
        return

    @staticmethod
    def get_config_details() -> Dict[str, Any]:
        # No driver-specific configuration needed
        return {}

    def _load_data(self):
        if self.cog_name not in _driver_counts:
            _driver_counts[self.cog_name] = 0

@@ -111,30 +127,32 @@ class JSON(BaseDriver):

    async def get(self, identifier_data: IdentifierData):
        partial = self.data
        full_identifiers = identifier_data.to_tuple()
        full_identifiers = identifier_data.to_tuple()[1:]
        for i in full_identifiers:
            partial = partial[i]
        return pickle.loads(pickle.dumps(partial, -1))

    async def set(self, identifier_data: IdentifierData, value=None):
        partial = self.data
        full_identifiers = identifier_data.to_tuple()
        full_identifiers = identifier_data.to_tuple()[1:]
        # This is both our deepcopy() and our way of making sure this value is actually JSON
        # serializable.
        value_copy = json.loads(json.dumps(value))

        async with self._lock:
            for i in full_identifiers[:-1]:
                if i not in partial:
                    partial[i] = {}
                partial = partial[i]
            partial[full_identifiers[-1]] = value_copy
                try:
                    partial = partial.setdefault(i, {})
                except AttributeError:
                    # Tried to set sub-field of non-object
                    raise errors.CannotSetSubfield

            partial[full_identifiers[-1]] = value_copy
            await self._save()

    async def clear(self, identifier_data: IdentifierData):
        partial = self.data
        full_identifiers = identifier_data.to_tuple()
        full_identifiers = identifier_data.to_tuple()[1:]
        try:
            for i in full_identifiers[:-1]:
                partial = partial[i]

@@ -149,14 +167,32 @@ class JSON(BaseDriver):
        else:
            await self._save()

    @classmethod
    async def aiter_cogs(cls) -> AsyncIterator[Tuple[str, str]]:
        yield "Core", "0"
        for _dir in data_manager.cog_data_path().iterdir():
            fpath = _dir / "settings.json"
            if not fpath.exists():
                continue
            with fpath.open() as f:
                try:
                    data = json.load(f)
                except json.JSONDecodeError:
                    continue
            if not isinstance(data, dict):
                continue
            for cog, inner in data.items():
                if not isinstance(inner, dict):
                    continue
                for cog_id in inner:
                    yield cog, cog_id

    async def import_data(self, cog_data, custom_group_data):
        def update_write_data(identifier_data: IdentifierData, _data):
            partial = self.data
            idents = identifier_data.to_tuple()
            idents = identifier_data.to_tuple()[1:]
            for ident in idents[:-1]:
                if ident not in partial:
                    partial[ident] = {}
                partial = partial[ident]
                partial = partial.setdefault(ident, {})
            partial[idents[-1]] = _data

        async with self._lock:

@@ -164,12 +200,12 @@ class JSON(BaseDriver):
                splitted_pkey = self._split_primary_key(category, custom_group_data, all_data)
                for pkey, data in splitted_pkey:
                    ident_data = IdentifierData(
                        self.cog_name,
                        self.unique_cog_identifier,
                        category,
                        pkey,
                        (),
                        custom_group_data,
                        is_custom=category in custom_group_data,
                        *ConfigCategory.get_pkey_info(category, custom_group_data),
                    )
                    update_write_data(ident_data, data)
            await self._save()

@@ -178,9 +214,6 @@ class JSON(BaseDriver):
        loop = asyncio.get_running_loop()
        await loop.run_in_executor(None, _save_json, self.data_path, self.data)

    def get_config_details(self):
        return


def _save_json(path: Path, data: Dict[str, Any]) -> None:
    """
redbot/core/drivers/log.py (new file, 11 lines)

@@ -0,0 +1,11 @@
import functools
import logging
import os

if os.getenv("RED_INSPECT_DRIVER_QUERIES"):
    LOGGING_INVISIBLE = logging.DEBUG
else:
    LOGGING_INVISIBLE = 0

log = logging.getLogger("red.driver")
log.invisible = functools.partial(log.log, LOGGING_INVISIBLE)
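A hedged usage sketch of the small logging helper above: with ``RED_INSPECT_DRIVER_QUERIES`` set, ``log.invisible(...)`` records at DEBUG level; without it, the level is 0 and the records are effectively suppressed. The query string here is made up, and the env var must be set before the module is imported.

.. code-block:: python

    import logging
    import os

    os.environ.setdefault("RED_INSPECT_DRIVER_QUERIES", "1")  # must be set before import

    from redbot.core.drivers.log import log  # noqa: E402

    logging.basicConfig(level=logging.DEBUG)
    # Logged at DEBUG because the env var is set; otherwise silently dropped.
    log.invisible("Running query: %s", "SELECT 1")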
@@ -2,77 +2,110 @@ import contextlib
import itertools
import re
from getpass import getpass
from typing import Match, Pattern, Tuple, Any, Dict, Iterator, List
from typing import Match, Pattern, Tuple, Optional, AsyncIterator, Any, Dict, Iterator, List
from urllib.parse import quote_plus

import motor.core
import motor.motor_asyncio
import pymongo.errors
try:
    # pylint: disable=import-error
    import pymongo.errors
    import motor.core
    import motor.motor_asyncio
except ModuleNotFoundError:
    motor = None
    pymongo = None

from .red_base import BaseDriver, IdentifierData
from .. import errors
from .base import BaseDriver, IdentifierData

__all__ = ["Mongo"]
__all__ = ["MongoDriver"]


_conn = None


def _initialize(**kwargs):
    uri = kwargs.get("URI", "mongodb")
    host = kwargs["HOST"]
    port = kwargs["PORT"]
    admin_user = kwargs["USERNAME"]
    admin_pass = kwargs["PASSWORD"]
    db_name = kwargs.get("DB_NAME", "default_db")

    if port is 0:
        ports = ""
    else:
        ports = ":{}".format(port)

    if admin_user is not None and admin_pass is not None:
        url = "{}://{}:{}@{}{}/{}".format(
            uri, quote_plus(admin_user), quote_plus(admin_pass), host, ports, db_name
        )
    else:
        url = "{}://{}{}/{}".format(uri, host, ports, db_name)

    global _conn
    _conn = motor.motor_asyncio.AsyncIOMotorClient(url, retryWrites=True)


class Mongo(BaseDriver):
class MongoDriver(BaseDriver):
    """
    Subclass of :py:class:`.red_base.BaseDriver`.
    Subclass of :py:class:`.BaseDriver`.
    """

    def __init__(self, cog_name, identifier, **kwargs):
        super().__init__(cog_name, identifier)
    _conn: Optional["motor.motor_asyncio.AsyncIOMotorClient"] = None

        if _conn is None:
            _initialize(**kwargs)
    @classmethod
    async def initialize(cls, **storage_details) -> None:
        if motor is None:
            raise errors.MissingExtraRequirements(
                "Red must be installed with the [mongo] extra to use the MongoDB driver"
            )
        uri = storage_details.get("URI", "mongodb")
        host = storage_details["HOST"]
        port = storage_details["PORT"]
        user = storage_details["USERNAME"]
        password = storage_details["PASSWORD"]
        database = storage_details.get("DB_NAME", "default_db")

    async def has_valid_connection(self) -> bool:
        # Maybe fix this?
        return True
        if port is 0:
            ports = ""
        else:
            ports = ":{}".format(port)

        if user is not None and password is not None:
            url = "{}://{}:{}@{}{}/{}".format(
                uri, quote_plus(user), quote_plus(password), host, ports, database
            )
        else:
            url = "{}://{}{}/{}".format(uri, host, ports, database)

        cls._conn = motor.motor_asyncio.AsyncIOMotorClient(url, retryWrites=True)

    @classmethod
    async def teardown(cls) -> None:
        if cls._conn is not None:
            cls._conn.close()

    @staticmethod
    def get_config_details():
        while True:
            uri = input("Enter URI scheme (mongodb or mongodb+srv): ")
            if uri is "":
                uri = "mongodb"

            if uri in ["mongodb", "mongodb+srv"]:
                break
            else:
                print("Invalid URI scheme")

        host = input("Enter host address: ")
        if uri is "mongodb":
            port = int(input("Enter host port: "))
        else:
            port = 0

        admin_uname = input("Enter login username: ")
        admin_password = getpass("Enter login password: ")

        db_name = input("Enter mongodb database name: ")

        if admin_uname == "":
            admin_uname = admin_password = None

        ret = {
            "HOST": host,
            "PORT": port,
            "USERNAME": admin_uname,
            "PASSWORD": admin_password,
            "DB_NAME": db_name,
            "URI": uri,
        }
        return ret

    @property
    def db(self) -> motor.core.Database:
    def db(self) -> "motor.core.Database":
        """
        Gets the mongo database for this cog's name.

        .. warning::

            Right now this will cause a new connection to be made every time the
            database is accessed. We will want to create a connection pool down the
            line to limit the number of connections.

        :return:
            PyMongo Database object.
        """
        return _conn.get_database()
        return self._conn.get_database()

    def get_collection(self, category: str) -> motor.core.Collection:
    def get_collection(self, category: str) -> "motor.core.Collection":
        """
        Gets a specified collection within the PyMongo database for this cog.

@@ -85,12 +118,13 @@ class Mongo(BaseDriver):
        """
        return self.db[self.cog_name][category]

    def get_primary_key(self, identifier_data: IdentifierData) -> Tuple[str]:
    @staticmethod
    def get_primary_key(identifier_data: IdentifierData) -> Tuple[str, ...]:
        # noinspection PyTypeChecker
        return identifier_data.primary_key

    async def rebuild_dataset(
        self, identifier_data: IdentifierData, cursor: motor.motor_asyncio.AsyncIOMotorCursor
        self, identifier_data: IdentifierData, cursor: "motor.motor_asyncio.AsyncIOMotorCursor"
    ):
        ret = {}
        async for doc in cursor:

@@ -141,16 +175,16 @@ class Mongo(BaseDriver):
    async def set(self, identifier_data: IdentifierData, value=None):
        uuid = self._escape_key(identifier_data.uuid)
        primary_key = list(map(self._escape_key, self.get_primary_key(identifier_data)))
        dot_identifiers = ".".join(map(self._escape_key, identifier_data.identifiers))
        if isinstance(value, dict):
            if len(value) == 0:
                await self.clear(identifier_data)
                return
            value = self._escape_dict_keys(value)
        mongo_collection = self.get_collection(identifier_data.category)
        pkey_len = self.get_pkey_len(identifier_data)
        num_pkeys = len(primary_key)

        if num_pkeys >= pkey_len:
        if num_pkeys >= identifier_data.primary_key_len:
            # We're setting at the document level or below.
            dot_identifiers = ".".join(map(self._escape_key, identifier_data.identifiers))
            if dot_identifiers:

@@ -158,11 +192,23 @@ class Mongo(BaseDriver):
            else:
                update_stmt = {"$set": value}

            await mongo_collection.update_one(
                {"_id": {"RED_uuid": uuid, "RED_primary_key": primary_key}},
                update=update_stmt,
                upsert=True,
            )
            try:
                await mongo_collection.update_one(
                    {"_id": {"RED_uuid": uuid, "RED_primary_key": primary_key}},
                    update=update_stmt,
                    upsert=True,
                )
            except pymongo.errors.WriteError as exc:
                if exc.args and exc.args[0].startswith("Cannot create field"):
                    # There's a bit of a failing edge case here...
                    # If we accidentally set the sub-field of an array, and the key happens to be a
                    # digit, it will successfully set the value in the array, and not raise an
                    # error. This is different to how other drivers would behave, and could lead to
                    # unexpected behaviour.
                    raise errors.CannotSetSubfield
                else:
                    # Unhandled driver exception, should expose.
                    raise

        else:
            # We're setting above the document level.

@@ -171,15 +217,17 @@ class Mongo(BaseDriver):
            # We'll do it in a transaction so we can roll-back in case something goes horribly
            # wrong.
            pkey_filter = self.generate_primary_key_filter(identifier_data)
            async with await _conn.start_session() as session:
            async with await self._conn.start_session() as session:
                with contextlib.suppress(pymongo.errors.CollectionInvalid):
                    # Collections must already exist when inserting documents within a transaction
                    await _conn.get_database().create_collection(mongo_collection.full_name)
                    await self.db.create_collection(mongo_collection.full_name)
                try:
                    async with session.start_transaction():
                        await mongo_collection.delete_many(pkey_filter, session=session)
                        await mongo_collection.insert_many(
                            self.generate_documents_to_insert(uuid, primary_key, value, pkey_len),
                            self.generate_documents_to_insert(
                                uuid, primary_key, value, identifier_data.primary_key_len
                            ),
                            session=session,
                        )
                except pymongo.errors.OperationFailure:

@@ -218,7 +266,7 @@ class Mongo(BaseDriver):

            # What's left of `value` should be the new documents needing to be inserted.
            to_insert = self.generate_documents_to_insert(
                uuid, primary_key, value, pkey_len
                uuid, primary_key, value, identifier_data.primary_key_len
            )
            requests = list(
                itertools.chain(

@@ -289,6 +337,59 @@ class Mongo(BaseDriver):
        for result in results:
            await db[result["name"]].delete_many(pkey_filter)

    @classmethod
    async def aiter_cogs(cls) -> AsyncIterator[Tuple[str, str]]:
        db = cls._conn.get_database()
        for collection_name in await db.list_collection_names():
            parts = collection_name.split(".")
            if not len(parts) == 2:
                continue
            cog_name = parts[0]
            for cog_id in await db[collection_name].distinct("_id.RED_uuid"):
                yield cog_name, cog_id

    @classmethod
    async def delete_all_data(
        cls, *, interactive: bool = False, drop_db: Optional[bool] = None, **kwargs
    ) -> None:
        """Delete all data being stored by this driver.

        Parameters
        ----------
        interactive : bool
            Set to ``True`` to allow the method to ask the user for
            input from the console, regarding the other unset parameters
            for this method.
        drop_db : Optional[bool]
            Set to ``True`` to drop the entire database for the current
            bot's instance. Otherwise, collections which appear to be
            storing bot data will be dropped.

        """
        if interactive is True and drop_db is None:
            print(
                "Please choose from one of the following options:\n"
                " 1. Drop the entire MongoDB database for this instance, or\n"
                " 2. Delete all of Red's data within this database, without dropping the database "
                "itself."
            )
            options = ("1", "2")
            while True:
                resp = input("> ")
                try:
                    drop_db = bool(options.index(resp))
                except ValueError:
                    print("Please type a number corresponding to one of the options.")
                else:
                    break
        db = cls._conn.get_database()
        if drop_db is True:
            await cls._conn.drop_database(db)
        else:
            async with await cls._conn.start_session() as session:
                async for cog_name, cog_id in cls.aiter_cogs():
                    await db.drop_collection(db[cog_name], session=session)

    @staticmethod
    def _escape_key(key: str) -> str:
        return _SPECIAL_CHAR_PATTERN.sub(_replace_with_escaped, key)

@@ -344,40 +445,3 @@ _CHAR_ESCAPES = {
|
||||
|
||||
def _replace_with_unescaped(match: Match[str]) -> str:
|
||||
return _CHAR_ESCAPES[match[0]]
|
||||
|
||||
|
||||
def get_config_details():
|
||||
uri = None
|
||||
while True:
|
||||
uri = input("Enter URI scheme (mongodb or mongodb+srv): ")
|
||||
if uri is "":
|
||||
uri = "mongodb"
|
||||
|
||||
if uri in ["mongodb", "mongodb+srv"]:
|
||||
break
|
||||
else:
|
||||
print("Invalid URI scheme")
|
||||
|
||||
host = input("Enter host address: ")
|
||||
if uri is "mongodb":
|
||||
port = int(input("Enter host port: "))
|
||||
else:
|
||||
port = 0
|
||||
|
||||
admin_uname = input("Enter login username: ")
|
||||
admin_password = getpass("Enter login password: ")
|
||||
|
||||
db_name = input("Enter mongodb database name: ")
|
||||
|
||||
if admin_uname == "":
|
||||
admin_uname = admin_password = None
|
||||
|
||||
ret = {
|
||||
"HOST": host,
|
||||
"PORT": port,
|
||||
"USERNAME": admin_uname,
|
||||
"PASSWORD": admin_password,
|
||||
"DB_NAME": db_name,
|
||||
"URI": uri,
|
||||
}
|
||||
return ret
|
||||
3
redbot/core/drivers/postgres/__init__.py
Normal file
@ -0,0 +1,3 @@
from .postgres import PostgresDriver

__all__ = ["PostgresDriver"]
839
redbot/core/drivers/postgres/ddl.sql
Normal file
@ -0,0 +1,839 @@
|
||||
/*
|
||||
************************************************************
|
||||
* PostgreSQL driver Data Definition Language (DDL) Script. *
|
||||
************************************************************
|
||||
*/
|
||||
|
||||
CREATE SCHEMA IF NOT EXISTS red_config;
|
||||
CREATE SCHEMA IF NOT EXISTS red_utils;
|
||||
|
||||
DO $$
|
||||
BEGIN
|
||||
PERFORM 'red_config.identifier_data'::regtype;
|
||||
EXCEPTION
|
||||
WHEN UNDEFINED_OBJECT THEN
|
||||
CREATE TYPE red_config.identifier_data AS (
|
||||
cog_name text,
|
||||
cog_id text,
|
||||
category text,
|
||||
pkeys text[],
|
||||
identifiers text[],
|
||||
pkey_len integer,
|
||||
is_custom boolean
|
||||
);
|
||||
END;
|
||||
$$;
|
||||
|
||||
|
||||
CREATE OR REPLACE FUNCTION
|
||||
/*
|
||||
* Create the config schema and/or table if they do not exist yet.
|
||||
*/
|
||||
red_config.maybe_create_table(
|
||||
id_data red_config.identifier_data
|
||||
)
|
||||
RETURNS void
|
||||
LANGUAGE 'plpgsql'
|
||||
AS $$
|
||||
DECLARE
|
||||
schemaname CONSTANT text := concat_ws('.', id_data.cog_name, id_data.cog_id);
|
||||
schema_exists CONSTANT boolean := exists(
|
||||
SELECT 1
|
||||
FROM red_config.red_cogs t
|
||||
WHERE t.cog_name = id_data.cog_name AND t.cog_id = id_data.cog_id);
|
||||
table_exists CONSTANT boolean := schema_exists AND exists(
|
||||
SELECT 1
|
||||
FROM information_schema.tables
|
||||
WHERE table_schema = schemaname AND table_name = id_data.category);
|
||||
|
||||
BEGIN
|
||||
IF NOT schema_exists THEN
|
||||
PERFORM red_config.create_schema(id_data.cog_name, id_data.cog_id);
|
||||
END IF;
|
||||
IF NOT table_exists THEN
|
||||
PERFORM red_config.create_table(id_data);
|
||||
END IF;
|
||||
END;
|
||||
$$;
|
||||
|
||||
|
||||
CREATE OR REPLACE FUNCTION
|
||||
/*
|
||||
* Create the config schema for the given cog.
|
||||
*/
|
||||
red_config.create_schema(new_cog_name text, new_cog_id text, OUT schemaname text)
|
||||
RETURNS text
|
||||
LANGUAGE 'plpgsql'
|
||||
AS $$
|
||||
BEGIN
|
||||
schemaname := concat_ws('.', new_cog_name, new_cog_id);
|
||||
|
||||
EXECUTE format('CREATE SCHEMA IF NOT EXISTS %I', schemaname);
|
||||
|
||||
INSERT INTO red_config.red_cogs AS t VALUES(new_cog_name, new_cog_id, schemaname)
|
||||
ON CONFLICT(cog_name, cog_id) DO UPDATE
|
||||
SET
|
||||
schemaname = excluded.schemaname;
|
||||
END;
|
||||
$$;
|
||||
|
||||
|
||||
CREATE OR REPLACE FUNCTION
|
||||
/*
|
||||
* Create the config table for the given category.
|
||||
*/
|
||||
red_config.create_table(id_data red_config.identifier_data)
|
||||
RETURNS void
|
||||
LANGUAGE 'plpgsql'
|
||||
AS $$
|
||||
DECLARE
|
||||
schemaname CONSTANT text := concat_ws('.', id_data.cog_name, id_data.cog_id);
|
||||
constraintname CONSTANT text := id_data.category||'_pkey';
|
||||
pkey_columns CONSTANT text := red_utils.gen_pkey_columns(1, id_data.pkey_len);
|
||||
pkey_type CONSTANT text := red_utils.get_pkey_type(id_data.is_custom);
|
||||
pkey_column_definitions CONSTANT text := red_utils.gen_pkey_column_definitions(
|
||||
1, id_data.pkey_len, pkey_type);
|
||||
|
||||
BEGIN
|
||||
EXECUTE format(
|
||||
$query$
|
||||
CREATE TABLE IF NOT EXISTS %I.%I (
|
||||
%s,
|
||||
json_data jsonb DEFAULT '{}' NOT NULL,
|
||||
CONSTRAINT %I PRIMARY KEY (%s)
|
||||
)
|
||||
$query$,
|
||||
schemaname,
|
||||
id_data.category,
|
||||
pkey_column_definitions,
|
||||
constraintname,
|
||||
pkey_columns);
|
||||
END;
|
||||
$$;
|
||||
|
||||
|
||||
CREATE OR REPLACE FUNCTION
|
||||
/*
|
||||
* Get config data.
|
||||
*
|
||||
* - When `pkeys` is a full primary key, all or part of a document
|
||||
* will be returned.
|
||||
* - When `pkeys` is not a full primary key, documents will be
|
||||
* aggregated together into a single JSONB object, with primary keys
|
||||
* as keys mapping to the documents.
|
||||
*/
|
||||
red_config.get(
|
||||
id_data red_config.identifier_data,
|
||||
OUT result jsonb
|
||||
)
|
||||
LANGUAGE 'plpgsql'
|
||||
STABLE
|
||||
PARALLEL SAFE
|
||||
AS $$
|
||||
DECLARE
|
||||
schemaname CONSTANT text := concat_ws('.', id_data.cog_name, id_data.cog_id);
|
||||
num_pkeys CONSTANT integer := coalesce(array_length(id_data.pkeys, 1), 0);
|
||||
num_missing_pkeys CONSTANT integer := id_data.pkey_len - num_pkeys;
|
||||
pkey_type CONSTANT text := red_utils.get_pkey_type(id_data.is_custom);
|
||||
whereclause CONSTANT text := red_utils.gen_whereclause(num_pkeys, pkey_type);
|
||||
|
||||
missing_pkey_columns text;
|
||||
|
||||
BEGIN
|
||||
IF num_missing_pkeys <= 0 THEN
|
||||
-- No missing primary keys: we're getting all or part of a document.
|
||||
EXECUTE format(
|
||||
'SELECT json_data #> $2 FROM %I.%I WHERE %s',
|
||||
schemaname,
|
||||
id_data.category,
|
||||
whereclause)
|
||||
INTO result
|
||||
USING id_data.pkeys, id_data.identifiers;
|
||||
|
||||
ELSIF num_missing_pkeys = 1 THEN
|
||||
-- 1 missing primary key: we can use the built-in jsonb_object_agg() aggregate function.
|
||||
EXECUTE format(
|
||||
'SELECT jsonb_object_agg(%I::text, json_data) FROM %I.%I WHERE %s',
|
||||
'primary_key_'||id_data.pkey_len,
|
||||
schemaname,
|
||||
id_data.category,
|
||||
whereclause)
|
||||
INTO result
|
||||
USING id_data.pkeys;
|
||||
ELSE
|
||||
-- Multiple missing primary keys: we must use our custom red_utils.jsonb_object_agg2()
|
||||
-- aggregate function.
|
||||
missing_pkey_columns := red_utils.gen_pkey_columns_casted(num_pkeys + 1, id_data.pkey_len);
|
||||
|
||||
EXECUTE format(
|
||||
'SELECT red_utils.jsonb_object_agg2(json_data, %s) FROM %I.%I WHERE %s',
|
||||
missing_pkey_columns,
|
||||
schemaname,
|
||||
id_data.category,
|
||||
whereclause)
|
||||
INTO result
|
||||
USING id_data.pkeys;
|
||||
END IF;
|
||||
END;
|
||||
$$;
|
||||
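As a rough illustration of the two lookup paths above (a sketch only, not part of this diff; it assumes an open asyncpg connection `conn` and made-up cog/guild/member IDs), the driver ends up issuing calls along these lines:

# Full primary key: all or part of one document comes back as encoded JSON.
prefix = await conn.fetchval(
    "SELECT red_config.get($1)",
    ("MyCog", "1234", "GUILD", ["111222333444555666"], ["prefix"], 1, False),
)
# -> e.g. '["!"]', which the Python driver decodes with json.loads()

# Partial primary key: MEMBER data queried with only the guild ID comes back as one
# JSONB object keyed by the missing member IDs, e.g. '{"42": {"xp": 10}, "43": {"xp": 3}}'.
members = await conn.fetchval(
    "SELECT red_config.get($1)",
    ("MyCog", "1234", "MEMBER", ["111222333444555666"], [], 2, False),
)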
|
||||
|
||||
CREATE OR REPLACE FUNCTION
|
||||
/*
|
||||
* Set config data.
|
||||
*
|
||||
* - When `pkeys` is a full primary key, all or part of a document
|
||||
* will be set.
|
||||
* - When `pkeys` is not a full set, multiple documents will be
|
||||
* replaced or removed - `new_value` must be a JSONB object mapping
|
||||
* primary keys to the new documents.
|
||||
*
|
||||
* Raises `error_in_assignment` error when trying to set a sub-key
|
||||
* of a non-document type.
|
||||
*/
|
||||
red_config.set(
|
||||
id_data red_config.identifier_data,
|
||||
new_value jsonb
|
||||
)
|
||||
RETURNS void
|
||||
LANGUAGE 'plpgsql'
|
||||
AS $$
|
||||
DECLARE
|
||||
schemaname CONSTANT text := concat_ws('.', id_data.cog_name, id_data.cog_id);
|
||||
constraintname CONSTANT text := id_data.category||'_pkey';
|
||||
num_pkeys CONSTANT integer := coalesce(array_length(id_data.pkeys, 1), 0);
|
||||
num_missing_pkeys CONSTANT integer := id_data.pkey_len - num_pkeys;
|
||||
pkey_type CONSTANT text := red_utils.get_pkey_type(id_data.is_custom);
|
||||
pkey_placeholders CONSTANT text := red_utils.gen_pkey_placeholders(num_pkeys, pkey_type);
|
||||
|
||||
new_document jsonb;
|
||||
pkey_column_definitions text;
|
||||
whereclause text;
|
||||
missing_pkey_columns text;
|
||||
|
||||
BEGIN
|
||||
PERFORM red_config.maybe_create_table(id_data);
|
||||
|
||||
IF num_missing_pkeys = 0 THEN
|
||||
-- Setting all or part of a document
|
||||
new_document := red_utils.jsonb_set2('{}', new_value, VARIADIC id_data.identifiers);
|
||||
|
||||
EXECUTE format(
|
||||
$query$
|
||||
INSERT INTO %I.%I AS t VALUES (%s, $2)
|
||||
ON CONFLICT ON CONSTRAINT %I DO UPDATE
|
||||
SET
|
||||
json_data = red_utils.jsonb_set2(t.json_data, $3, VARIADIC $4)
|
||||
$query$,
|
||||
schemaname,
|
||||
id_data.category,
|
||||
pkey_placeholders,
|
||||
constraintname)
|
||||
USING id_data.pkeys, new_document, new_value, id_data.identifiers;
|
||||
|
||||
ELSE
|
||||
-- Setting multiple documents
|
||||
whereclause := red_utils.gen_whereclause(num_pkeys, pkey_type);
|
||||
missing_pkey_columns := red_utils.gen_pkey_columns_casted(
|
||||
num_pkeys + 1, id_data.pkey_len, pkey_type);
|
||||
pkey_column_definitions := red_utils.gen_pkey_column_definitions(num_pkeys + 1, id_data.pkey_len);
|
||||
|
||||
-- Delete all documents which we're setting first, since we don't know whether they'll be
|
||||
-- replaced by the subsequent INSERT.
|
||||
EXECUTE format('DELETE FROM %I.%I WHERE %s', schemaname, id_data.category, whereclause)
|
||||
USING id_data.pkeys;
|
||||
|
||||
-- Insert all new documents
|
||||
EXECUTE format(
|
||||
$query$
|
||||
INSERT INTO %I.%I AS t
|
||||
SELECT %s, json_data
|
||||
FROM red_utils.generate_rows_from_object($2, $3) AS f(%s, json_data jsonb)
|
||||
ON CONFLICT ON CONSTRAINT %I DO UPDATE
|
||||
SET
|
||||
json_data = excluded.json_data
|
||||
$query$,
|
||||
schemaname,
|
||||
id_data.category,
|
||||
concat_ws(', ', pkey_placeholders, missing_pkey_columns),
|
||||
pkey_column_definitions,
|
||||
constraintname)
|
||||
USING id_data.pkeys, new_value, num_missing_pkeys;
|
||||
END IF;
|
||||
END;
|
||||
$$;
|
||||
|
||||
|
||||
CREATE OR REPLACE FUNCTION
|
||||
/*
|
||||
* Clear config data.
|
||||
*
|
||||
* - When `identifiers` is not empty, this will clear a key from a
|
||||
* document.
|
||||
* - When `identifiers` is empty and `pkeys` is not empty, it will
|
||||
* delete one or more documents.
|
||||
* - When `pkeys` is empty, it will drop the whole table.
|
||||
* - When `id_data.category` is NULL or an empty string, it will drop
|
||||
* the whole schema.
|
||||
*
|
||||
* Has no effect when the document or key does not exist.
|
||||
*/
|
||||
red_config.clear(
|
||||
id_data red_config.identifier_data
|
||||
)
|
||||
RETURNS void
|
||||
LANGUAGE 'plpgsql'
|
||||
AS $$
|
||||
DECLARE
|
||||
schemaname CONSTANT text := concat_ws('.', id_data.cog_name, id_data.cog_id);
|
||||
num_pkeys CONSTANT integer := coalesce(array_length(id_data.pkeys, 1), 0);
|
||||
num_identifiers CONSTANT integer := coalesce(array_length(id_data.identifiers, 1), 0);
|
||||
pkey_type CONSTANT text := red_utils.get_pkey_type(id_data.is_custom);
|
||||
|
||||
whereclause text;
|
||||
|
||||
BEGIN
|
||||
IF num_identifiers > 0 THEN
|
||||
-- Popping a key from a document or nested document.
|
||||
whereclause := red_utils.gen_whereclause(num_pkeys, pkey_type);
|
||||
|
||||
EXECUTE format(
|
||||
$query$
|
||||
UPDATE %I.%I AS t
|
||||
SET
|
||||
json_data = t.json_data #- $2
|
||||
WHERE %s
|
||||
$query$,
|
||||
schemaname,
|
||||
id_data.category,
|
||||
whereclause)
|
||||
USING id_data.pkeys, id_data.identifiers;
|
||||
|
||||
ELSIF num_pkeys > 0 THEN
|
||||
-- Deleting one or many documents
|
||||
whereclause := red_utils.gen_whereclause(num_pkeys, pkey_type);
|
||||
|
||||
EXECUTE format('DELETE FROM %I.%I WHERE %s', schemaname, id_data.category, whereclause)
|
||||
USING id_data.pkeys;
|
||||
|
||||
ELSIF id_data.category IS NOT NULL AND id_data.category != '' THEN
|
||||
-- Deleting an entire category
|
||||
EXECUTE format('DROP TABLE %I.%I CASCADE', schemaname, id_data.category);
|
||||
|
||||
ELSE
|
||||
-- Deleting an entire cog's data
|
||||
EXECUTE format('DROP SCHEMA %I CASCADE', schemaname);
|
||||
|
||||
DELETE FROM red_config.red_cogs
|
||||
WHERE cog_name = id_data.cog_name AND cog_id = id_data.cog_id;
|
||||
END IF;
|
||||
END;
|
||||
$$;
|
||||
|
||||
|
||||
CREATE OR REPLACE FUNCTION
|
||||
/*
|
||||
* Increment a number within a document.
|
||||
*
|
||||
* If the value doesn't already exist, it is inserted as
|
||||
* `default_value + amount`.
|
||||
*
|
||||
* Raises 'wrong_object_type' error when trying to increment a
|
||||
* non-numeric value.
|
||||
*/
|
||||
red_config.inc(
|
||||
id_data red_config.identifier_data,
|
||||
amount numeric,
|
||||
default_value numeric,
|
||||
OUT result numeric
|
||||
)
|
||||
LANGUAGE 'plpgsql'
|
||||
AS $$
|
||||
DECLARE
|
||||
schemaname CONSTANT text := concat_ws('.', id_data.cog_name, id_data.cog_id);
|
||||
num_identifiers CONSTANT integer := coalesce(array_length(id_data.identifiers, 1), 0);
|
||||
pkey_type CONSTANT text := red_utils.get_pkey_type(id_data.is_custom);
|
||||
whereclause CONSTANT text := red_utils.gen_whereclause(id_data.pkey_len, pkey_type);
|
||||
|
||||
new_document jsonb;
|
||||
existing_document jsonb;
|
||||
existing_value jsonb;
|
||||
pkey_placeholders text;
|
||||
|
||||
BEGIN
|
||||
IF num_identifiers = 0 THEN
|
||||
-- Without identifiers, there's no chance we're actually incrementing a number
|
||||
RAISE EXCEPTION 'Cannot increment document(s)'
|
||||
USING ERRCODE = 'wrong_object_type';
|
||||
END IF;
|
||||
|
||||
PERFORM red_config.maybe_create_table(id_data);
|
||||
|
||||
-- Look for the existing document
|
||||
EXECUTE format(
|
||||
'SELECT json_data FROM %I.%I WHERE %s',
|
||||
schemaname,
|
||||
id_data.category,
|
||||
whereclause)
|
||||
INTO existing_document USING id_data.pkeys;
|
||||
|
||||
IF existing_document IS NULL THEN
|
||||
-- We need to insert a new document
|
||||
result := default_value + amount;
|
||||
new_document := red_utils.jsonb_set2('{}', result, VARIADIC id_data.identifiers);
|
||||
pkey_placeholders := red_utils.gen_pkey_placeholders(id_data.pkey_len, pkey_type);
|
||||
|
||||
EXECUTE format(
|
||||
'INSERT INTO %I.%I VALUES(%s, $2)',
|
||||
schemaname,
|
||||
id_data.category,
|
||||
pkey_placeholders)
|
||||
USING id_data.pkeys, new_document;
|
||||
|
||||
ELSE
|
||||
-- We need to update the existing document
|
||||
existing_value := existing_document #> id_data.identifiers;
|
||||
|
||||
IF existing_value IS NULL THEN
|
||||
result := default_value + amount;
|
||||
|
||||
ELSIF jsonb_typeof(existing_value) = 'number' THEN
|
||||
result := existing_value::text::numeric + amount;
|
||||
|
||||
ELSE
|
||||
RAISE EXCEPTION 'Cannot increment non-numeric value %', existing_value
|
||||
USING ERRCODE = 'wrong_object_type';
|
||||
END IF;
|
||||
|
||||
new_document := red_utils.jsonb_set2(
|
||||
existing_document, to_jsonb(result), id_data.identifiers);
|
||||
|
||||
EXECUTE format(
|
||||
'UPDATE %I.%I SET json_data = $2 WHERE %s',
|
||||
schemaname,
|
||||
id_data.category,
|
||||
whereclause)
|
||||
USING id_data.pkeys, new_document;
|
||||
END IF;
|
||||
END;
|
||||
$$;
|
||||
|
||||
|
||||
CREATE OR REPLACE FUNCTION
|
||||
/*
|
||||
* Toggle a boolean within a document.
|
||||
*
|
||||
* If the value doesn't already exist, it is inserted as `NOT
|
||||
* default_value`.
|
||||
*
|
||||
* Raises 'wrong_object_type' error when trying to toggle a
|
||||
* non-boolean value.
|
||||
*/
|
||||
red_config.toggle(
|
||||
id_data red_config.identifier_data,
|
||||
default_value boolean,
|
||||
OUT result boolean
|
||||
)
|
||||
LANGUAGE 'plpgsql'
|
||||
AS $$
|
||||
DECLARE
|
||||
schemaname CONSTANT text := concat_ws('.', id_data.cog_name, id_data.cog_id);
|
||||
num_identifiers CONSTANT integer := coalesce(array_length(id_data.identifiers, 1), 0);
|
||||
pkey_type CONSTANT text := red_utils.get_pkey_type(id_data.is_custom);
|
||||
whereclause CONSTANT text := red_utils.gen_whereclause(id_data.pkey_len, pkey_type);
|
||||
|
||||
new_document jsonb;
|
||||
existing_document jsonb;
|
||||
existing_value jsonb;
|
||||
pkey_placeholders text;
|
||||
|
||||
BEGIN
|
||||
IF num_identifiers = 0 THEN
|
||||
-- Without identifiers, there's no chance we're actually toggling a boolean
|
||||
RAISE EXCEPTION 'Cannot toggle document(s)'
|
||||
USING ERRCODE = 'wrong_object_type';
|
||||
END IF;
|
||||
|
||||
PERFORM red_config.maybe_create_table(id_data);
|
||||
|
||||
-- Look for the existing document
|
||||
EXECUTE format(
|
||||
'SELECT json_data FROM %I.%I WHERE %s',
|
||||
schemaname,
|
||||
id_data.category,
|
||||
whereclause)
|
||||
INTO existing_document USING id_data.pkeys;
|
||||
|
||||
IF existing_document IS NULL THEN
|
||||
-- We need to insert a new document
|
||||
result := NOT default_value;
|
||||
new_document := red_utils.jsonb_set2('{}', result, VARIADIC id_data.identifiers);
|
||||
pkey_placeholders := red_utils.gen_pkey_placeholders(id_data.pkey_len, pkey_type);
|
||||
|
||||
EXECUTE format(
|
||||
'INSERT INTO %I.%I VALUES(%s, $2)',
|
||||
schemaname,
|
||||
id_data.category,
|
||||
pkey_placeholders)
|
||||
USING id_data.pkeys, new_document;
|
||||
|
||||
ELSE
|
||||
-- We need to update the existing document
|
||||
existing_value := existing_document #> id_data.identifiers;
|
||||
|
||||
IF existing_value IS NULL THEN
|
||||
result := NOT default_value;
|
||||
|
||||
ELSIF jsonb_typeof(existing_value) = 'boolean' THEN
|
||||
result := NOT existing_value::text::boolean;
|
||||
|
||||
ELSE
|
||||
RAISE EXCEPTION 'Cannot toggle non-boolean value %', existing_value
|
||||
USING ERRCODE = 'wrong_object_type';
|
||||
END IF;
|
||||
|
||||
new_document := red_utils.jsonb_set2(
|
||||
existing_document, to_jsonb(result), id_data.identifiers);
|
||||
|
||||
EXECUTE format(
|
||||
'UPDATE %I.%I SET json_data = $2 WHERE %s',
|
||||
schemaname,
|
||||
id_data.category,
|
||||
whereclause)
|
||||
USING id_data.pkeys, new_document;
|
||||
END IF;
|
||||
END;
|
||||
$$;
|
||||
|
||||
|
||||
CREATE OR REPLACE FUNCTION
|
||||
red_config.extend(
|
||||
id_data red_config.identifier_data,
|
||||
new_value text,
|
||||
default_value text,
|
||||
max_length integer DEFAULT NULL,
|
||||
extend_left boolean DEFAULT FALSE,
|
||||
OUT result jsonb
|
||||
)
|
||||
LANGUAGE 'plpgsql'
|
||||
AS $$
|
||||
DECLARE
|
||||
schemaname CONSTANT text := concat_ws('.', id_data.cog_name, id_data.cog_id);
|
||||
num_identifiers CONSTANT integer := coalesce(array_length(id_data.identifiers, 1), 0);
|
||||
pkey_type CONSTANT text := red_utils.get_pkey_type(id_data.is_custom);
|
||||
whereclause CONSTANT text := red_utils.gen_whereclause(id_data.pkey_len, pkey_type);
|
||||
pop_idx CONSTANT integer := CASE extend_left WHEN TRUE THEN -1 ELSE 0 END;
|
||||
|
||||
new_document jsonb;
|
||||
existing_document jsonb;
|
||||
existing_value jsonb;
|
||||
pkey_placeholders text;
|
||||
idx integer;
|
||||
BEGIN
|
||||
IF num_identifiers = 0 THEN
|
||||
-- Without identifiers, there's no chance we're actually appending to an array
|
||||
RAISE EXCEPTION 'Cannot append to document(s)'
|
||||
USING ERRCODE = 'wrong_object_type';
|
||||
END IF;
|
||||
|
||||
PERFORM red_config.maybe_create_table(id_data);
|
||||
|
||||
-- Look for the existing document
|
||||
EXECUTE format(
|
||||
'SELECT json_data FROM %I.%I WHERE %s',
|
||||
schemaname,
|
||||
id_data.category,
|
||||
whereclause)
|
||||
INTO existing_document USING id_data.pkeys;
|
||||
|
||||
IF existing_document IS NULL THEN
|
||||
result := default_value || new_value;
|
||||
new_document := red_utils.jsonb_set2('{}'::jsonb, result, id_data.identifiers);
|
||||
pkey_placeholders := red_utils.gen_pkey_placeholders(id_data.pkey_len, pkey_type);
|
||||
|
||||
EXECUTE format(
|
||||
'INSERT INTO %I.%I VALUES(%s, $2)',
|
||||
schemaname,
|
||||
id_data.category,
|
||||
pkey_placeholders)
|
||||
USING id_data.pkeys, new_document;
|
||||
|
||||
ELSE
|
||||
existing_value := existing_document #> id_data.identifiers;
|
||||
|
||||
IF existing_value IS NULL THEN
|
||||
existing_value := default_value;
|
||||
|
||||
ELSIF jsonb_typeof(existing_value) != 'array' THEN
|
||||
RAISE EXCEPTION 'Cannot append to non-array value %', existing_value
|
||||
USING ERRCODE = 'wrong_object_type';
|
||||
END IF;
|
||||
|
||||
CASE extend_left
|
||||
WHEN TRUE THEN
|
||||
result := new_value || existing_value;
|
||||
ELSE
|
||||
result := existing_value || new_value;
|
||||
END CASE;
|
||||
|
||||
IF max_length IS NOT NULL THEN
|
||||
FOR idx IN SELECT generate_series(1, jsonb_array_length(result) - max_length) LOOP
|
||||
result := result - pop_idx;
|
||||
END LOOP;
|
||||
END IF;
|
||||
|
||||
new_document := red_utils.jsonb_set2(existing_document, result, id_data.identifiers);
|
||||
|
||||
EXECUTE format(
|
||||
'UPDATE %I.%I SET json_data = $2 WHERE %s',
|
||||
schemaname,
|
||||
id_data.category,
|
||||
whereclause)
|
||||
USING id_data.pkeys, new_document;
|
||||
END IF;
|
||||
END;
|
||||
$$;
|
||||
|
||||
|
||||
CREATE OR REPLACE FUNCTION
|
||||
/*
|
||||
* Delete all schemas listed in the red_config.red_cogs table.
|
||||
*/
|
||||
red_config.delete_all_schemas()
|
||||
RETURNS void
|
||||
LANGUAGE 'plpgsql'
|
||||
AS $$
|
||||
DECLARE
|
||||
cog_entry record;
|
||||
BEGIN
|
||||
FOR cog_entry IN SELECT * FROM red_config.red_cogs t LOOP
|
||||
EXECUTE format('DROP SCHEMA %I CASCADE', cog_entry.schemaname);
|
||||
END LOOP;
|
||||
-- Clear out red_config.red_cogs table
|
||||
DELETE FROM red_config.red_cogs WHERE TRUE;
|
||||
END;
|
||||
$$;
|
||||
|
||||
|
||||
|
||||
CREATE OR REPLACE FUNCTION
|
||||
/*
|
||||
* Like `jsonb_set` but will insert new objects where one is missing
|
||||
* along the path.
|
||||
*
|
||||
* Raises `error_in_assignment` error when trying to set a sub-key
|
||||
* of a non-document type.
|
||||
*/
|
||||
red_utils.jsonb_set2(target jsonb, new_value jsonb, VARIADIC identifiers text[])
|
||||
RETURNS jsonb
|
||||
LANGUAGE 'plpgsql'
|
||||
IMMUTABLE
|
||||
PARALLEL SAFE
|
||||
AS $$
|
||||
DECLARE
|
||||
num_identifiers CONSTANT integer := coalesce(array_length(identifiers, 1), 0);
|
||||
|
||||
cur_value_type text;
|
||||
idx integer;
|
||||
|
||||
BEGIN
|
||||
IF num_identifiers = 0 THEN
|
||||
RETURN new_value;
|
||||
END IF;
|
||||
|
||||
FOR idx IN SELECT generate_series(1, num_identifiers - 1) LOOP
|
||||
cur_value_type := jsonb_typeof(target #> identifiers[:idx]);
|
||||
IF cur_value_type IS NULL THEN
|
||||
-- Parent key didn't exist in JSON before - insert new object
|
||||
target := jsonb_set(target, identifiers[:idx], '{}'::jsonb);
|
||||
|
||||
ELSIF cur_value_type != 'object' THEN
|
||||
-- We can't set the sub-field of a null, int, float, array etc.
|
||||
RAISE EXCEPTION 'Cannot set sub-field of "%s"', cur_value_type
|
||||
USING ERRCODE = 'error_in_assignment';
|
||||
END IF;
|
||||
END LOOP;
|
||||
|
||||
RETURN jsonb_set(target, identifiers, new_value);
|
||||
END;
|
||||
$$;
|
||||
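To make the difference from plain `jsonb_set` concrete (a minimal sketch, assuming an asyncpg connection `conn`; not part of the change):

# jsonb_set() leaves the target untouched when an intermediate key is missing;
# jsonb_set2() creates the missing parent objects along the path instead.
doc = await conn.fetchval(
    "SELECT red_utils.jsonb_set2('{}'::jsonb, '42'::jsonb, VARIADIC ARRAY['a', 'b'])"
)
# doc == '{"a": {"b": 42}}'
# Setting through a non-object value (e.g. a stored number) raises the
# error_in_assignment error, which the Python driver surfaces as CannotSetSubfield.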
|
||||
|
||||
CREATE OR REPLACE FUNCTION
|
||||
/*
|
||||
* Return a set of rows to insert into a table, from a single JSONB
|
||||
* object containing multiple documents.
|
||||
*/
|
||||
red_utils.generate_rows_from_object(object jsonb, num_missing_pkeys integer)
|
||||
RETURNS setof record
|
||||
LANGUAGE 'plpgsql'
|
||||
IMMUTABLE
|
||||
PARALLEL SAFE
|
||||
AS $$
|
||||
DECLARE
|
||||
pair record;
|
||||
column_definitions text;
|
||||
BEGIN
|
||||
IF num_missing_pkeys = 1 THEN
|
||||
-- Base case: Simply return (key, value) pairs
|
||||
RETURN QUERY
|
||||
SELECT key AS key_1, value AS json_data
|
||||
FROM jsonb_each(object);
|
||||
ELSE
|
||||
-- We need to return (key, key, ..., value) pairs: recurse into inner JSONB objects
|
||||
column_definitions := red_utils.gen_pkey_column_definitions(2, num_missing_pkeys);
|
||||
|
||||
FOR pair IN SELECT * FROM jsonb_each(object) LOOP
|
||||
RETURN QUERY
|
||||
EXECUTE format(
|
||||
$query$
|
||||
SELECT $1 AS key_1, *
|
||||
FROM red_utils.generate_rows_from_object($2, $3)
|
||||
AS f(%s, json_data jsonb)
|
||||
$query$,
|
||||
column_definitions)
|
||||
USING pair.key, pair.value, num_missing_pkeys - 1;
|
||||
END LOOP;
|
||||
END IF;
|
||||
RETURN;
|
||||
END;
|
||||
$$;
|
||||
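For example (illustrative only, assuming a connection `conn`; the IDs are invented), two nested member documents expand into one row each, with the missing primary keys spread across the leading columns:

rows = await conn.fetch(
    """
    SELECT * FROM red_utils.generate_rows_from_object($1::jsonb, 2)
        AS f(primary_key_1 text, primary_key_2 text, json_data jsonb)
    """,
    '{"111": {"42": {"xp": 10}, "43": {"xp": 3}}}',
)
# -> ('111', '42', '{"xp": 10}') and ('111', '43', '{"xp": 3}')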
|
||||
|
||||
CREATE OR REPLACE FUNCTION
|
||||
/*
|
||||
* Get a comma-separated list of primary key placeholders.
|
||||
*
|
||||
* The placeholder will always be $1. Particularly useful for
|
||||
* inserting values into a table from an array of primary keys.
|
||||
*/
|
||||
red_utils.gen_pkey_placeholders(num_pkeys integer, pkey_type text DEFAULT 'text')
|
||||
RETURNS text
|
||||
LANGUAGE 'sql'
|
||||
IMMUTABLE
|
||||
PARALLEL SAFE
|
||||
AS $$
|
||||
SELECT string_agg(t.item, ', ')
|
||||
FROM (
|
||||
SELECT format('$1[%s]::%s', idx, pkey_type) AS item
|
||||
FROM generate_series(1, num_pkeys) idx) t
|
||||
;
|
||||
$$;
|
||||
|
||||
CREATE OR REPLACE FUNCTION
|
||||
/*
|
||||
* Generate a whereclause for the given number of primary keys.
|
||||
*
|
||||
* When there are no primary keys, this will simply return the
|
||||
* string 'TRUE'. When there are multiple, it will return multiple
|
||||
* equality comparisons concatenated with 'AND'.
|
||||
*/
|
||||
red_utils.gen_whereclause(num_pkeys integer, pkey_type text)
|
||||
RETURNS text
|
||||
LANGUAGE 'sql'
|
||||
IMMUTABLE
|
||||
PARALLEL SAFE
|
||||
AS $$
|
||||
SELECT coalesce(string_agg(t.item, ' AND '), 'TRUE')
|
||||
FROM (
|
||||
SELECT format('%I = $1[%s]::%s', 'primary_key_'||idx, idx, pkey_type) AS item
|
||||
FROM generate_series(1, num_pkeys) idx) t
|
||||
;
|
||||
$$;
|
||||
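A quick sketch of the strings this helper produces (assuming a connection `conn`; not part of the diff):

assert await conn.fetchval("SELECT red_utils.gen_whereclause(0, 'text')") == "TRUE"
assert await conn.fetchval("SELECT red_utils.gen_whereclause(2, 'bigint')") == (
    "primary_key_1 = $1[1]::bigint AND primary_key_2 = $1[2]::bigint"
)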
|
||||
CREATE OR REPLACE FUNCTION
|
||||
/*
|
||||
* Generate a comma-separated list of primary key column names.
|
||||
*/
|
||||
red_utils.gen_pkey_columns(start integer, stop integer)
|
||||
RETURNS text
|
||||
LANGUAGE 'sql'
|
||||
IMMUTABLE
|
||||
PARALLEL SAFE
|
||||
AS $$
|
||||
SELECT string_agg(t.item, ', ')
|
||||
FROM (
|
||||
SELECT quote_ident('primary_key_'||idx) AS item
|
||||
FROM generate_series(start, stop) idx) t
|
||||
;
|
||||
$$;
|
||||
|
||||
CREATE OR REPLACE FUNCTION
|
||||
/*
|
||||
* Generate a comma-separated list of primary key column names casted
|
||||
* to the given type.
|
||||
*/
|
||||
red_utils.gen_pkey_columns_casted(start integer, stop integer, pkey_type text DEFAULT 'text')
|
||||
RETURNS text
|
||||
LANGUAGE 'sql'
|
||||
IMMUTABLE
|
||||
PARALLEL SAFE
|
||||
AS $$
|
||||
SELECT string_agg(t.item, ', ')
|
||||
FROM (
|
||||
SELECT format('%I::%s', 'primary_key_'||idx, pkey_type) AS item
|
||||
FROM generate_series(start, stop) idx) t
|
||||
;
|
||||
$$;
|
||||
|
||||
|
||||
CREATE OR REPLACE FUNCTION
|
||||
/*
|
||||
* Generate a primary key column definition list.
|
||||
*/
|
||||
red_utils.gen_pkey_column_definitions(
|
||||
start integer, stop integer, column_type text DEFAULT 'text'
|
||||
)
|
||||
RETURNS text
|
||||
LANGUAGE 'sql'
|
||||
IMMUTABLE
|
||||
PARALLEL SAFE
|
||||
AS $$
|
||||
SELECT string_agg(t.item, ', ')
|
||||
FROM (
|
||||
SELECT format('%I %s', 'primary_key_'||idx, column_type) AS item
|
||||
FROM generate_series(start, stop) idx) t
|
||||
;
|
||||
$$;
|
||||
|
||||
|
||||
CREATE OR REPLACE FUNCTION
|
||||
red_utils.get_pkey_type(is_custom boolean)
|
||||
RETURNS TEXT
|
||||
LANGUAGE 'sql'
|
||||
IMMUTABLE
|
||||
PARALLEL SAFE
|
||||
AS $$
|
||||
SELECT ('{bigint,text}'::text[])[is_custom::integer + 1];
|
||||
$$;
|
||||
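In other words, `is_custom` simply indexes a two-element array: FALSE picks `bigint` (Discord snowflake IDs), TRUE picks `text` (arbitrary custom-group keys). A sketch, assuming a connection `conn`:

assert await conn.fetchval("SELECT red_utils.get_pkey_type(FALSE)") == "bigint"
assert await conn.fetchval("SELECT red_utils.get_pkey_type(TRUE)") == "text"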
|
||||
|
||||
DROP AGGREGATE IF EXISTS red_utils.jsonb_object_agg2(jsonb, VARIADIC text[]);
|
||||
CREATE AGGREGATE
|
||||
/*
|
||||
* Like `jsonb_object_agg` but aggregates more than two columns into a
|
||||
* single JSONB object.
|
||||
*
|
||||
* If possible, use `jsonb_object_agg` instead for performance
|
||||
* reasons.
|
||||
*/
|
||||
red_utils.jsonb_object_agg2(json_data jsonb, VARIADIC primary_keys text[]) (
|
||||
SFUNC = red_utils.jsonb_set2,
|
||||
STYPE = jsonb,
|
||||
INITCOND = '{}',
|
||||
PARALLEL = SAFE
|
||||
)
|
||||
;
|
||||
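Conceptually, the aggregate folds each row's document into a nested object keyed by the remaining primary-key columns. A minimal sketch (assuming a connection `conn`; values are invented and this is not part of the diff):

nested = await conn.fetchval(
    """
    SELECT red_utils.jsonb_object_agg2(json_data, guild_id, member_id)
    FROM (VALUES ('{"xp": 10}'::jsonb, '111', '42'),
                 ('{"xp": 3}'::jsonb,  '111', '43')) AS t(json_data, guild_id, member_id)
    """
)
# nested == '{"111": {"42": {"xp": 10}, "43": {"xp": 3}}}'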
|
||||
|
||||
CREATE TABLE IF NOT EXISTS
|
||||
/*
|
||||
* Table to keep track of other cogs' schemas.
|
||||
*/
|
||||
red_config.red_cogs(
|
||||
cog_name text,
|
||||
cog_id text,
|
||||
schemaname text NOT NULL,
|
||||
PRIMARY KEY (cog_name, cog_id)
|
||||
)
|
||||
;
|
||||
3
redbot/core/drivers/postgres/drop_ddl.sql
Normal file
@ -0,0 +1,3 @@
SELECT red_config.delete_all_schemas();
DROP SCHEMA IF EXISTS red_config CASCADE;
DROP SCHEMA IF EXISTS red_utils CASCADE;
255
redbot/core/drivers/postgres/postgres.py
Normal file
@ -0,0 +1,255 @@
|
||||
import getpass
|
||||
import json
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from typing import Optional, Any, AsyncIterator, Tuple, Union, Callable, List
|
||||
|
||||
try:
|
||||
# pylint: disable=import-error
|
||||
import asyncpg
|
||||
except ModuleNotFoundError:
|
||||
asyncpg = None
|
||||
|
||||
from ... import data_manager, errors
|
||||
from ..base import BaseDriver, IdentifierData, ConfigCategory
|
||||
from ..log import log
|
||||
|
||||
__all__ = ["PostgresDriver"]
|
||||
|
||||
_PKG_PATH = Path(__file__).parent
|
||||
DDL_SCRIPT_PATH = _PKG_PATH / "ddl.sql"
|
||||
DROP_DDL_SCRIPT_PATH = _PKG_PATH / "drop_ddl.sql"
|
||||
|
||||
|
||||
def encode_identifier_data(
|
||||
id_data: IdentifierData
|
||||
) -> Tuple[str, str, str, List[str], List[str], int, bool]:
|
||||
return (
|
||||
id_data.cog_name,
|
||||
id_data.uuid,
|
||||
id_data.category,
|
||||
["0"] if id_data.category == ConfigCategory.GLOBAL else list(id_data.primary_key),
|
||||
list(id_data.identifiers),
|
||||
1 if id_data.category == ConfigCategory.GLOBAL else id_data.primary_key_len,
|
||||
id_data.is_custom,
|
||||
)
|
||||
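For reference, a rough sketch of what this encoding produces for global data (the cog name and ID below are invented; this is not part of the commit):

# A global value has no primary key of its own, so the encoder substitutes the
# sentinel key "0" and forces the primary key length to 1, ensuring every table
# has at least one primary-key column.  Roughly:
#   cog_name="Example", uuid="1234", category=GLOBAL,
#   primary_key=(), identifiers=("foo",), primary_key_len=0, is_custom=False
# encodes to the tuple asyncpg passes as the composite-typed argument $1:
("Example", "1234", "GLOBAL", ["0"], ["foo"], 1, False)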
|
||||
|
||||
class PostgresDriver(BaseDriver):
|
||||
|
||||
_pool: Optional["asyncpg.pool.Pool"] = None
|
||||
|
||||
@classmethod
|
||||
async def initialize(cls, **storage_details) -> None:
|
||||
if asyncpg is None:
|
||||
raise errors.MissingExtraRequirements(
|
||||
"Red must be installed with the [postgres] extra to use the PostgreSQL driver"
|
||||
)
|
||||
cls._pool = await asyncpg.create_pool(**storage_details)
|
||||
with DDL_SCRIPT_PATH.open() as fs:
|
||||
await cls._pool.execute(fs.read())
|
||||
|
||||
@classmethod
|
||||
async def teardown(cls) -> None:
|
||||
if cls._pool is not None:
|
||||
await cls._pool.close()
|
||||
|
||||
@staticmethod
|
||||
def get_config_details():
|
||||
unixmsg = (
|
||||
""
|
||||
if sys.platform != "win32"
|
||||
else (
|
||||
" - Common directories for PostgreSQL Unix-domain sockets (/run/postgresql, "
|
||||
"/var/run/postgresl, /var/pgsql_socket, /private/tmp, and /tmp),\n"
|
||||
)
|
||||
)
|
||||
host = (
|
||||
input(
|
||||
f"Enter the PostgreSQL server's address.\n"
|
||||
f"If left blank, Red will try the following, in order:\n"
|
||||
f" - The PGHOST environment variable,\n{unixmsg}"
|
||||
f" - localhost.\n"
|
||||
f"> "
|
||||
)
|
||||
or None
|
||||
)
|
||||
|
||||
print(
|
||||
"Enter the PostgreSQL server port.\n"
|
||||
"If left blank, this will default to either:\n"
|
||||
" - The PGPORT environment variable,\n"
|
||||
" - 5432."
|
||||
)
|
||||
while True:
|
||||
port = input("> ") or None
|
||||
if port is None:
|
||||
break
|
||||
|
||||
try:
|
||||
port = int(port)
|
||||
except ValueError:
|
||||
print("Port must be a number")
|
||||
else:
|
||||
break
|
||||
|
||||
user = (
|
||||
input(
|
||||
"Enter the PostgreSQL server username.\n"
|
||||
"If left blank, this will default to either:\n"
|
||||
" - The PGUSER environment variable,\n"
|
||||
" - The OS name of the user running Red (ident/peer authentication).\n"
|
||||
"> "
|
||||
)
|
||||
or None
|
||||
)
|
||||
|
||||
passfile = r"%APPDATA%\postgresql\pgpass.conf" if sys.platform != "win32" else "~/.pgpass"
|
||||
password = getpass.getpass(
|
||||
f"Enter the PostgreSQL server password. The input will be hidden.\n"
|
||||
f" NOTE: If using ident/peer authentication (no password), enter NONE.\n"
|
||||
f"When NONE is entered, this will default to:\n"
|
||||
f" - The PGPASSWORD environment variable,\n"
|
||||
f" - Looking up the password in the {passfile} passfile,\n"
|
||||
f" - No password.\n"
|
||||
f"> "
|
||||
)
|
||||
if password == "NONE":
|
||||
password = None
|
||||
|
||||
database = (
|
||||
input(
|
||||
"Enter the PostgreSQL database's name.\n"
|
||||
"If left blank, this will default to either:\n"
|
||||
" - The PGDATABASE environment variable,\n"
|
||||
" - The OS name of the user running Red.\n"
|
||||
"> "
|
||||
)
|
||||
or None
|
||||
)
|
||||
|
||||
return {
|
||||
"host": host,
|
||||
"port": port,
|
||||
"user": user,
|
||||
"password": password,
|
||||
"database": database,
|
||||
}
|
||||
|
||||
async def get(self, identifier_data: IdentifierData):
|
||||
try:
|
||||
result = await self._execute(
|
||||
"SELECT red_config.get($1)",
|
||||
encode_identifier_data(identifier_data),
|
||||
method=self._pool.fetchval,
|
||||
)
|
||||
except asyncpg.UndefinedTableError:
|
||||
raise KeyError from None
|
||||
|
||||
if result is None:
|
||||
# The result is None both when postgres yields no results, or when it yields a NULL row
|
||||
# A 'null' JSON value would be returned as encoded JSON, i.e. the string 'null'
|
||||
raise KeyError
|
||||
return json.loads(result)
|
||||
|
||||
async def set(self, identifier_data: IdentifierData, value=None):
|
||||
try:
|
||||
await self._execute(
|
||||
"SELECT red_config.set($1, $2::jsonb)",
|
||||
encode_identifier_data(identifier_data),
|
||||
json.dumps(value),
|
||||
)
|
||||
except asyncpg.ErrorInAssignmentError:
|
||||
raise errors.CannotSetSubfield
|
||||
|
||||
async def clear(self, identifier_data: IdentifierData):
|
||||
try:
|
||||
await self._execute(
|
||||
"SELECT red_config.clear($1)", encode_identifier_data(identifier_data)
|
||||
)
|
||||
except asyncpg.UndefinedTableError:
|
||||
pass
|
||||
|
||||
async def inc(
|
||||
self, identifier_data: IdentifierData, value: Union[int, float], default: Union[int, float]
|
||||
) -> Union[int, float]:
|
||||
try:
|
||||
return await self._execute(
|
||||
f"SELECT red_config.inc($1, $2, $3)",
|
||||
encode_identifier_data(identifier_data),
|
||||
value,
|
||||
default,
|
||||
method=self._pool.fetchval,
|
||||
)
|
||||
except asyncpg.WrongObjectTypeError as exc:
|
||||
raise errors.StoredTypeError(*exc.args)
|
||||
|
||||
async def toggle(self, identifier_data: IdentifierData, default: bool) -> bool:
|
||||
try:
|
||||
return await self._execute(
|
||||
"SELECT red_config.inc($1, $2)",
|
||||
encode_identifier_data(identifier_data),
|
||||
default,
|
||||
method=self._pool.fetchval,
|
||||
)
|
||||
except asyncpg.WrongObjectTypeError as exc:
|
||||
raise errors.StoredTypeError(*exc.args)
|
||||
|
||||
@classmethod
|
||||
async def aiter_cogs(cls) -> AsyncIterator[Tuple[str, str]]:
|
||||
query = "SELECT cog_name, cog_id FROM red_config.red_cogs"
|
||||
log.invisible(query)
|
||||
async with cls._pool.acquire() as conn, conn.transaction():
|
||||
async for row in conn.cursor(query):
|
||||
yield row["cog_name"], row["cog_id"]
|
||||
|
||||
@classmethod
|
||||
async def delete_all_data(
|
||||
cls, *, interactive: bool = False, drop_db: Optional[bool] = None, **kwargs
|
||||
) -> None:
|
||||
"""Delete all data being stored by this driver.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
interactive : bool
|
||||
Set to ``True`` to allow the method to ask the user for
|
||||
input from the console, regarding the other unset parameters
|
||||
for this method.
|
||||
drop_db : Optional[bool]
|
||||
Set to ``True`` to drop the entire database for the current
|
||||
bot's instance. Otherwise, schemas within the database which
|
||||
store bot data will be dropped, as well as functions,
|
||||
aggregates, event triggers, and meta-tables.
|
||||
|
||||
"""
|
||||
if interactive is True and drop_db is None:
|
||||
print(
|
||||
"Please choose from one of the following options:\n"
|
||||
" 1. Drop the entire PostgreSQL database for this instance, or\n"
|
||||
" 2. Delete all of Red's data within this database, without dropping the database "
|
||||
"itself."
|
||||
)
|
||||
options = ("1", "2")
|
||||
while True:
|
||||
resp = input("> ")
|
||||
try:
|
||||
drop_db = bool(options.index(resp))
|
||||
except ValueError:
|
||||
print("Please type a number corresponding to one of the options.")
|
||||
else:
|
||||
break
|
||||
if drop_db is True:
|
||||
storage_details = data_manager.storage_details()
|
||||
await cls._pool.execute(f"DROP DATABASE $1", storage_details["database"])
|
||||
else:
|
||||
with DROP_DDL_SCRIPT_PATH.open() as fs:
|
||||
await cls._pool.execute(fs.read())
|
||||
|
||||
@classmethod
|
||||
async def _execute(cls, query: str, *args, method: Optional[Callable] = None) -> Any:
|
||||
if method is None:
|
||||
method = cls._pool.execute
|
||||
log.invisible("Query: %s", query)
|
||||
if args:
|
||||
log.invisible("Args: %s", args)
|
||||
return await method(query, *args)
|
||||
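A usage sketch for this helper (it assumes the pool has already been created by initialize(), and `encoded_id` stands for a tuple from encode_identifier_data(); not part of the diff):

# Status-only statements go through the default pool.execute(); value-returning
# calls pass an alternative fetch method explicitly.
await PostgresDriver._execute("SELECT red_config.clear($1)", encoded_id)
current = await PostgresDriver._execute(
    "SELECT red_config.get($1)",
    encoded_id,
    method=PostgresDriver._pool.fetchval,
)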
@ -1,233 +0,0 @@
|
||||
import enum
|
||||
from typing import Tuple
|
||||
|
||||
__all__ = ["BaseDriver", "IdentifierData"]
|
||||
|
||||
|
||||
class ConfigCategory(enum.Enum):
|
||||
GLOBAL = "GLOBAL"
|
||||
GUILD = "GUILD"
|
||||
CHANNEL = "TEXTCHANNEL"
|
||||
ROLE = "ROLE"
|
||||
USER = "USER"
|
||||
MEMBER = "MEMBER"
|
||||
|
||||
|
||||
class IdentifierData:
|
||||
def __init__(
|
||||
self,
|
||||
uuid: str,
|
||||
category: str,
|
||||
primary_key: Tuple[str, ...],
|
||||
identifiers: Tuple[str, ...],
|
||||
custom_group_data: dict,
|
||||
is_custom: bool = False,
|
||||
):
|
||||
self._uuid = uuid
|
||||
self._category = category
|
||||
self._primary_key = primary_key
|
||||
self._identifiers = identifiers
|
||||
self.custom_group_data = custom_group_data
|
||||
self._is_custom = is_custom
|
||||
|
||||
@property
|
||||
def uuid(self):
|
||||
return self._uuid
|
||||
|
||||
@property
|
||||
def category(self):
|
||||
return self._category
|
||||
|
||||
@property
|
||||
def primary_key(self):
|
||||
return self._primary_key
|
||||
|
||||
@property
|
||||
def identifiers(self):
|
||||
return self._identifiers
|
||||
|
||||
@property
|
||||
def is_custom(self):
|
||||
return self._is_custom
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
f"<IdentifierData uuid={self.uuid} category={self.category} primary_key={self.primary_key}"
|
||||
f" identifiers={self.identifiers}>"
|
||||
)
|
||||
|
||||
def __eq__(self, other) -> bool:
|
||||
if not isinstance(other, IdentifierData):
|
||||
return False
|
||||
return (
|
||||
self.uuid == other.uuid
|
||||
and self.category == other.category
|
||||
and self.primary_key == other.primary_key
|
||||
and self.identifiers == other.identifiers
|
||||
)
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash((self.uuid, self.category, self.primary_key, self.identifiers))
|
||||
|
||||
def add_identifier(self, *identifier: str) -> "IdentifierData":
|
||||
if not all(isinstance(i, str) for i in identifier):
|
||||
raise ValueError("Identifiers must be strings.")
|
||||
|
||||
return IdentifierData(
|
||||
self.uuid,
|
||||
self.category,
|
||||
self.primary_key,
|
||||
self.identifiers + identifier,
|
||||
self.custom_group_data,
|
||||
is_custom=self.is_custom,
|
||||
)
|
||||
|
||||
def to_tuple(self):
|
||||
return tuple(
|
||||
item
|
||||
for item in (self.uuid, self.category, *self.primary_key, *self.identifiers)
|
||||
if len(item) > 0
|
||||
)
|
||||
|
||||
|
||||
class BaseDriver:
|
||||
def __init__(self, cog_name, identifier):
|
||||
self.cog_name = cog_name
|
||||
self.unique_cog_identifier = identifier
|
||||
|
||||
async def has_valid_connection(self) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
async def get(self, identifier_data: IdentifierData):
|
||||
"""
|
||||
Finds the value indicated by the given identifiers.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
identifier_data
|
||||
|
||||
Returns
|
||||
-------
|
||||
Any
|
||||
Stored value.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_config_details(self):
|
||||
"""
|
||||
Asks users for additional configuration information necessary
|
||||
to use this config driver.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Dict of configuration details.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def set(self, identifier_data: IdentifierData, value=None):
|
||||
"""
|
||||
Sets the value of the key indicated by the given identifiers.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
identifier_data
|
||||
value
|
||||
Any JSON serializable python object.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def clear(self, identifier_data: IdentifierData):
|
||||
"""
|
||||
Clears out the value specified by the given identifiers.
|
||||
|
||||
Equivalent to using ``del`` on a dict.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
identifier_data
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def _get_levels(self, category, custom_group_data):
|
||||
if category == ConfigCategory.GLOBAL.value:
|
||||
return 0
|
||||
elif category in (
|
||||
ConfigCategory.USER.value,
|
||||
ConfigCategory.GUILD.value,
|
||||
ConfigCategory.CHANNEL.value,
|
||||
ConfigCategory.ROLE.value,
|
||||
):
|
||||
return 1
|
||||
elif category == ConfigCategory.MEMBER.value:
|
||||
return 2
|
||||
elif category in custom_group_data:
|
||||
return custom_group_data[category]
|
||||
else:
|
||||
raise RuntimeError(f"Cannot convert due to group: {category}")
|
||||
|
||||
def _split_primary_key(self, category, custom_group_data, data):
|
||||
levels = self._get_levels(category, custom_group_data)
|
||||
if levels == 0:
|
||||
return (((), data),)
|
||||
|
||||
def flatten(levels_remaining, currdata, parent_key=()):
|
||||
items = []
|
||||
for k, v in currdata.items():
|
||||
new_key = parent_key + (k,)
|
||||
if levels_remaining > 1:
|
||||
items.extend(flatten(levels_remaining - 1, v, new_key).items())
|
||||
else:
|
||||
items.append((new_key, v))
|
||||
return dict(items)
|
||||
|
||||
ret = []
|
||||
for k, v in flatten(levels, data).items():
|
||||
ret.append((k, v))
|
||||
return tuple(ret)
|
||||
|
||||
async def export_data(self, custom_group_data):
|
||||
categories = [c.value for c in ConfigCategory]
|
||||
categories.extend(custom_group_data.keys())
|
||||
|
||||
ret = []
|
||||
for c in categories:
|
||||
ident_data = IdentifierData(
|
||||
self.unique_cog_identifier,
|
||||
c,
|
||||
(),
|
||||
(),
|
||||
custom_group_data,
|
||||
is_custom=c in custom_group_data,
|
||||
)
|
||||
try:
|
||||
data = await self.get(ident_data)
|
||||
except KeyError:
|
||||
continue
|
||||
ret.append((c, data))
|
||||
return ret
|
||||
|
||||
async def import_data(self, cog_data, custom_group_data):
|
||||
for category, all_data in cog_data:
|
||||
splitted_pkey = self._split_primary_key(category, custom_group_data, all_data)
|
||||
for pkey, data in splitted_pkey:
|
||||
ident_data = IdentifierData(
|
||||
self.unique_cog_identifier,
|
||||
category,
|
||||
pkey,
|
||||
(),
|
||||
custom_group_data,
|
||||
is_custom=category in custom_group_data,
|
||||
)
|
||||
await self.set(ident_data, data)
|
||||
|
||||
@staticmethod
|
||||
def get_pkey_len(identifier_data: IdentifierData) -> int:
|
||||
cat = identifier_data.category
|
||||
if cat == ConfigCategory.GLOBAL.value:
|
||||
return 0
|
||||
elif cat == ConfigCategory.MEMBER.value:
|
||||
return 2
|
||||
elif identifier_data.is_custom:
|
||||
return identifier_data.custom_group_data[cat]
|
||||
else:
|
||||
return 1
|
||||
@ -1,5 +1,4 @@
|
||||
import importlib.machinery
|
||||
from typing import Optional
|
||||
|
||||
import discord
|
||||
|
||||
@ -49,3 +48,37 @@ class BalanceTooHigh(BankError, OverflowError):
|
||||
return _("{user}'s balance cannot rise above {max:,} {currency}.").format(
|
||||
user=self.user, max=self.max_balance, currency=self.currency_name
|
||||
)
|
||||
|
||||
|
||||
class MissingExtraRequirements(RedError):
|
||||
"""Raised when an extra requirement is missing but required."""
|
||||
|
||||
|
||||
class ConfigError(RedError):
|
||||
"""Error in a Config operation."""
|
||||
|
||||
|
||||
class StoredTypeError(ConfigError, TypeError):
|
||||
"""A TypeError pertaining to stored Config data.
|
||||
|
||||
This error may arise when, for example, trying to increment a value
|
||||
which is not a number, or trying to toggle a value which is not a
|
||||
boolean.
|
||||
"""
|
||||
|
||||
|
||||
class CannotSetSubfield(StoredTypeError):
|
||||
"""Tried to set sub-field of an invalid data structure.
|
||||
|
||||
This would occur in the following example::
|
||||
|
||||
>>> import asyncio
|
||||
>>> from redbot.core import Config
|
||||
>>> config = Config.get_conf(None, 1234, cog_name="Example")
|
||||
>>> async def example():
|
||||
... await config.foo.set(True)
|
||||
... await config.set_raw("foo", "bar", False) # Should raise here
|
||||
...
|
||||
>>> asyncio.run(example())
|
||||
|
||||
"""
|
||||
|
||||
@ -1,7 +1,9 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import shutil
|
||||
import tarfile
|
||||
from asyncio import AbstractEventLoop, as_completed, Semaphore
|
||||
from asyncio.futures import isfuture
|
||||
from itertools import chain
|
||||
@ -24,8 +26,10 @@ from typing import (
|
||||
)
|
||||
|
||||
import discord
|
||||
from datetime import datetime
|
||||
from fuzzywuzzy import fuzz, process
|
||||
|
||||
from .. import commands, data_manager
|
||||
from .chat_formatting import box
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@ -37,6 +41,7 @@ __all__ = [
|
||||
"fuzzy_command_search",
|
||||
"format_fuzzy_results",
|
||||
"deduplicate_iterables",
|
||||
"create_backup",
|
||||
]
|
||||
|
||||
_T = TypeVar("_T")
|
||||
@ -397,3 +402,45 @@ def bounded_gather(
|
||||
tasks = (_sem_wrapper(semaphore, task) for task in coros_or_futures)
|
||||
|
||||
return asyncio.gather(*tasks, loop=loop, return_exceptions=return_exceptions)
|
||||
|
||||
|
||||
async def create_backup(dest: Path = Path.home()) -> Optional[Path]:
|
||||
data_path = Path(data_manager.core_data_path().parent)
|
||||
if not data_path.exists():
|
||||
return
|
||||
|
||||
dest.mkdir(parents=True, exist_ok=True)
|
||||
timestr = datetime.utcnow().isoformat(timespec="minutes")
|
||||
backup_fpath = dest / f"redv3_{data_manager.instance_name}_{timestr}.tar.gz"
|
||||
|
||||
to_backup = []
|
||||
exclusions = [
|
||||
"__pycache__",
|
||||
"Lavalink.jar",
|
||||
os.path.join("Downloader", "lib"),
|
||||
os.path.join("CogManager", "cogs"),
|
||||
os.path.join("RepoManager", "repos"),
|
||||
]
|
||||
|
||||
# Avoiding circular imports
|
||||
from ...cogs.downloader.repo_manager import RepoManager
|
||||
|
||||
repo_mgr = RepoManager()
|
||||
await repo_mgr.initialize()
|
||||
repo_output = []
|
||||
for _, repo in repo_mgr._repos:
|
||||
repo_output.append({"url": repo.url, "name": repo.name, "branch": repo.branch})
|
||||
repos_file = data_path / "cogs" / "RepoManager" / "repos.json"
|
||||
with repos_file.open("w") as fs:
|
||||
json.dump(repo_output, fs, indent=4)
|
||||
instance_file = data_path / "instance.json"
|
||||
with instance_file.open("w") as fs:
|
||||
json.dump({data_manager.instance_name: data_manager.basic_config}, fs, indent=4)
|
||||
for f in data_path.glob("**/*"):
|
||||
if not any(ex in str(f) for ex in exclusions) and f.is_file():
|
||||
to_backup.append(f)
|
||||
|
||||
with tarfile.open(str(backup_fpath), "w:gz") as tar:
|
||||
for f in to_backup:
|
||||
tar.add(str(f), arcname=f.relative_to(data_path), recursive=False)
|
||||
return backup_fpath
|
||||
|
||||
@ -89,7 +89,7 @@ class Tunnel(metaclass=TunnelMeta):
|
||||
destination: discord.abc.Messageable,
|
||||
content: str = None,
|
||||
embed=None,
|
||||
files: Optional[List[discord.File]] = None
|
||||
files: Optional[List[discord.File]] = None,
|
||||
) -> List[discord.Message]:
|
||||
"""
|
||||
This does the actual sending, use this instead of a full tunnel
|
||||
|
||||
@ -7,15 +7,13 @@ import pytest
|
||||
from _pytest.monkeypatch import MonkeyPatch
|
||||
from redbot.core import Config
|
||||
from redbot.core.bot import Red
|
||||
from redbot.core import config as config_module
|
||||
|
||||
from redbot.core.drivers import red_json
|
||||
from redbot.core import config as config_module, drivers
|
||||
|
||||
__all__ = [
|
||||
"monkeysession",
|
||||
"override_data_path",
|
||||
"coroutine",
|
||||
"json_driver",
|
||||
"driver",
|
||||
"config",
|
||||
"config_fr",
|
||||
"red",
|
||||
@ -56,34 +54,31 @@ def coroutine():
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def json_driver(tmpdir_factory):
|
||||
def driver(tmpdir_factory):
|
||||
import uuid
|
||||
|
||||
rand = str(uuid.uuid4())
|
||||
path = Path(str(tmpdir_factory.mktemp(rand)))
|
||||
driver = red_json.JSON("PyTest", identifier=str(uuid.uuid4()), data_path_override=path)
|
||||
return driver
|
||||
return drivers.get_driver("PyTest", str(random.randint(1, 999999)), data_path_override=path)
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def config(json_driver):
|
||||
def config(driver):
|
||||
config_module._config_cache = weakref.WeakValueDictionary()
|
||||
conf = Config(
|
||||
cog_name="PyTest", unique_identifier=json_driver.unique_cog_identifier, driver=json_driver
|
||||
)
|
||||
conf = Config(cog_name="PyTest", unique_identifier=driver.unique_cog_identifier, driver=driver)
|
||||
yield conf
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def config_fr(json_driver):
|
||||
def config_fr(driver):
|
||||
"""
|
||||
Mocked config object with force_register enabled.
|
||||
"""
|
||||
config_module._config_cache = weakref.WeakValueDictionary()
|
||||
conf = Config(
|
||||
cog_name="PyTest",
|
||||
unique_identifier=json_driver.unique_cog_identifier,
|
||||
driver=json_driver,
|
||||
unique_identifier=driver.unique_cog_identifier,
|
||||
driver=driver,
|
||||
force_registration=True,
|
||||
)
|
||||
yield conf
|
||||
@ -176,7 +171,7 @@ def red(config_fr):
|
||||
|
||||
Config.get_core_conf = lambda *args, **kwargs: config_fr
|
||||
|
||||
red = Red(cli_flags=cli_flags, description=description, dm_help=None)
|
||||
red = Red(cli_flags=cli_flags, description=description, dm_help=None, owner_id=None)
|
||||
|
||||
yield red
|
||||
|
||||
|
||||
355
redbot/setup.py
@ -1,32 +1,21 @@
|
||||
#!/usr/bin/env python3
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import tarfile
|
||||
from copy import deepcopy
|
||||
from datetime import datetime as dt
|
||||
from pathlib import Path
|
||||
import logging
|
||||
from typing import Dict, Any, Optional
|
||||
|
||||
import appdirs
|
||||
import click
|
||||
|
||||
import redbot.logging
|
||||
from redbot.core.cli import confirm
|
||||
from redbot.core.data_manager import (
|
||||
basic_config_default,
|
||||
load_basic_configuration,
|
||||
instance_name,
|
||||
basic_config,
|
||||
cog_data_path,
|
||||
core_data_path,
|
||||
storage_details,
|
||||
)
|
||||
from redbot.core.utils import safe_delete
|
||||
from redbot.core import Config
|
||||
from redbot.core.utils import safe_delete, create_backup as _create_backup
|
||||
from redbot.core import config, data_manager, drivers
|
||||
from redbot.core.drivers import BackendType, IdentifierData
|
||||
from redbot.core.drivers.red_json import JSON
|
||||
|
||||
conversion_log = logging.getLogger("red.converter")
|
||||
|
||||
@ -61,11 +50,11 @@ else:
|
||||
|
||||
|
||||
def save_config(name, data, remove=False):
|
||||
config = load_existing_config()
|
||||
if remove and name in config:
|
||||
config.pop(name)
|
||||
_config = load_existing_config()
|
||||
if remove and name in _config:
|
||||
_config.pop(name)
|
||||
else:
|
||||
if name in config:
|
||||
if name in _config:
|
||||
print(
|
||||
"WARNING: An instance already exists with this name. "
|
||||
"Continuing will overwrite the existing instance config."
|
||||
@ -73,10 +62,10 @@ def save_config(name, data, remove=False):
|
||||
if not confirm("Are you absolutely certain you want to continue (y/n)? "):
|
||||
print("Not continuing")
|
||||
sys.exit(0)
|
||||
config[name] = data
|
||||
_config[name] = data
|
||||
|
||||
with config_file.open("w", encoding="utf-8") as fs:
|
||||
json.dump(config, fs, indent=4)
|
||||
json.dump(_config, fs, indent=4)
|
||||
|
||||
|
||||
def get_data_dir():
|
||||
@ -118,13 +107,14 @@ def get_data_dir():
|
||||
|
||||
|
||||
def get_storage_type():
|
||||
storage_dict = {1: "JSON", 2: "MongoDB"}
|
||||
storage_dict = {1: "JSON", 2: "MongoDB", 3: "PostgreSQL"}
|
||||
storage = None
|
||||
while storage is None:
|
||||
print()
|
||||
print("Please choose your storage backend (if you're unsure, choose 1).")
|
||||
print("1. JSON (file storage, requires no database).")
|
||||
print("2. MongoDB")
|
||||
print("3. PostgreSQL")
|
||||
storage = input("> ")
|
||||
try:
|
||||
storage = int(storage)
|
||||
@ -158,21 +148,16 @@ def basic_setup():

    default_data_dir = get_data_dir()

    default_dirs = deepcopy(basic_config_default)
    default_dirs = deepcopy(data_manager.basic_config_default)
    default_dirs["DATA_PATH"] = str(default_data_dir.resolve())

    storage = get_storage_type()

    storage_dict = {1: BackendType.JSON, 2: BackendType.MONGO}
    storage_dict = {1: BackendType.JSON, 2: BackendType.MONGO, 3: BackendType.POSTGRES}
    storage_type: BackendType = storage_dict.get(storage, BackendType.JSON)
    default_dirs["STORAGE_TYPE"] = storage_type.value

    if storage_type == BackendType.MONGO:
        from redbot.core.drivers.red_mongo import get_config_details

        default_dirs["STORAGE_DETAILS"] = get_config_details()
    else:
        default_dirs["STORAGE_DETAILS"] = {}
    driver_cls = drivers.get_driver_class(storage_type)
    default_dirs["STORAGE_DETAILS"] = driver_cls.get_config_details()

    name = get_name()
    save_config(name, default_dirs)
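For illustration (not from the diff): the Mongo-specific branch above gives way to a generic dispatch, so each backend supplies its own setup prompts. A rough sketch of that call path using only names that appear in this commit; the exact prompts inside get_config_details() are driver-specific.

from redbot.core import drivers
from redbot.core.drivers import BackendType

driver_cls = drivers.get_driver_class(BackendType.POSTGRES)
# Presumably prompts interactively for connection options, as the Mongo driver did before.
storage_details = driver_cls.get_config_details()
# The result is stored under "STORAGE_DETAILS" in the instance config.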
@ -193,130 +178,38 @@ def get_target_backend(backend) -> BackendType:
        return BackendType.JSON
    elif backend == "mongo":
        return BackendType.MONGO
    elif backend == "postgres":
        return BackendType.POSTGRES

async def json_to_mongov2(instance):
    instance_vals = instance_data[instance]
    current_data_dir = Path(instance_vals["DATA_PATH"])
async def do_migration(
    current_backend: BackendType, target_backend: BackendType
) -> Dict[str, Any]:
    cur_driver_cls = drivers.get_driver_class(current_backend)
    new_driver_cls = drivers.get_driver_class(target_backend)
    cur_storage_details = data_manager.storage_details()
    new_storage_details = new_driver_cls.get_config_details()

    load_basic_configuration(instance)
    await cur_driver_cls.initialize(**cur_storage_details)
    await new_driver_cls.initialize(**new_storage_details)

    from redbot.core.drivers import red_mongo
    await config.migrate(cur_driver_cls, new_driver_cls)

    storage_details = red_mongo.get_config_details()
    await cur_driver_cls.teardown()
    await new_driver_cls.teardown()

    core_conf = Config.get_core_conf()
    new_driver = red_mongo.Mongo(cog_name="Core", identifier="0", **storage_details)

    core_conf.init_custom("CUSTOM_GROUPS", 2)
    custom_group_data = await core_conf.custom("CUSTOM_GROUPS").all()

    curr_custom_data = custom_group_data.get("Core", {}).get("0", {})
    exported_data = await core_conf.driver.export_data(curr_custom_data)
    conversion_log.info("Starting Core conversion...")
    await new_driver.import_data(exported_data, curr_custom_data)
    conversion_log.info("Core conversion complete.")

    for p in current_data_dir.glob("cogs/**/settings.json"):
        cog_name = p.parent.stem
        if "." in cog_name:
            # Garbage handler
            continue
        with p.open(mode="r", encoding="utf-8") as f:
            cog_data = json.load(f)
        for identifier, all_data in cog_data.items():
            try:
                conf = Config.get_conf(None, int(identifier), cog_name=cog_name)
            except ValueError:
                continue
            new_driver = red_mongo.Mongo(
                cog_name=cog_name, identifier=conf.driver.unique_cog_identifier, **storage_details
            )

            curr_custom_data = custom_group_data.get(cog_name, {}).get(identifier, {})

            exported_data = await conf.driver.export_data(curr_custom_data)
            conversion_log.info(f"Converting {cog_name} with identifier {identifier}...")
            await new_driver.import_data(exported_data, curr_custom_data)

    conversion_log.info("Cog conversion complete.")

    return storage_details
    return new_storage_details

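A sketch (not part of the diff) of how the new generic do_migration could be driven directly, mirroring the convert() command further down; the instance name is a placeholder.

import asyncio

from redbot.core import data_manager
from redbot.core.drivers import BackendType

data_manager.load_basic_configuration("myinstance")  # placeholder instance name
loop = asyncio.get_event_loop()
new_details = loop.run_until_complete(do_migration(BackendType.JSON, BackendType.POSTGRES))
# new_details would then be written back to the instance config via save_config().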
async def mongov2_to_json(instance):
    load_basic_configuration(instance)

    core_path = core_data_path()

    from redbot.core.drivers import red_json

    core_conf = Config.get_core_conf()
    new_driver = red_json.JSON(cog_name="Core", identifier="0", data_path_override=core_path)

    core_conf.init_custom("CUSTOM_GROUPS", 2)
    custom_group_data = await core_conf.custom("CUSTOM_GROUPS").all()

    curr_custom_data = custom_group_data.get("Core", {}).get("0", {})
    exported_data = await core_conf.driver.export_data(curr_custom_data)
    conversion_log.info("Starting Core conversion...")
    await new_driver.import_data(exported_data, curr_custom_data)
    conversion_log.info("Core conversion complete.")

    collection_names = await core_conf.driver.db.list_collection_names()
    splitted_names = list(
        filter(
            lambda elem: elem[1] != "" and elem[0] != "Core",
            [n.split(".") for n in collection_names],
        )
    )

    ident_map = {}  # Cogname: idents list
    for cog_name, category in splitted_names:
        if cog_name not in ident_map:
            ident_map[cog_name] = set()

        idents = await core_conf.driver.db[cog_name][category].distinct("_id.RED_uuid")
        ident_map[cog_name].update(set(idents))

    for cog_name, idents in ident_map.items():
        for identifier in idents:
            curr_custom_data = custom_group_data.get(cog_name, {}).get(identifier, {})
            try:
                conf = Config.get_conf(None, int(identifier), cog_name=cog_name)
            except ValueError:
                continue
            exported_data = await conf.driver.export_data(curr_custom_data)

            new_path = cog_data_path(raw_name=cog_name)
            new_driver = red_json.JSON(cog_name, identifier, data_path_override=new_path)
            conversion_log.info(f"Converting {cog_name} with identifier {identifier}...")
            await new_driver.import_data(exported_data, curr_custom_data)

            # cog_data_path(raw_name=cog_name)

    conversion_log.info("Cog conversion complete.")

    return {}

async def mongo_to_json(instance):
    load_basic_configuration(instance)

    from redbot.core.drivers.red_mongo import Mongo

    m = Mongo("Core", "0", **storage_details())
async def mongov1_to_json() -> Dict[str, Any]:
    await drivers.MongoDriver.initialize(**data_manager.storage_details())
    m = drivers.MongoDriver("Core", "0")
    db = m.db
    collection_names = await db.list_collection_names()
    for collection_name in collection_names:
        if "." in collection_name:
            # Fix for one of Zeph's problems
            continue
        elif collection_name == "Core":
            c_data_path = core_data_path()
        else:
            c_data_path = cog_data_path(raw_name=collection_name)
        c_data_path.mkdir(parents=True, exist_ok=True)
        # Every cog name has its own collection
        collection = db[collection_name]
        async for document in collection.find():
@ -329,16 +222,22 @@ async def mongo_to_json(instance):
                continue
            elif not str(cog_id).isdigit():
                continue
            driver = JSON(collection_name, cog_id, data_path_override=c_data_path)
            driver = drivers.JsonDriver(collection_name, cog_id)
            for category, value in document.items():
                ident_data = IdentifierData(str(cog_id), category, (), (), {})
                ident_data = IdentifierData(
                    str(collection_name), str(cog_id), category, tuple(), tuple(), 0
                )
                await driver.set(ident_data, value=value)

    conversion_log.info("Cog conversion complete.")
    await drivers.MongoDriver.teardown()

    return {}

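The six positional arguments in the new IdentifierData call above are easy to misread, so here is the same call with the argument meanings spelled out; the readings are inferred from the call sites in this diff, not taken from the driver source.

ident_data = IdentifierData(
    str(collection_name),  # cog name
    str(cog_id),           # the cog's unique identifier
    category,              # config category, e.g. "GLOBAL" or "GUILD"
    tuple(),               # primary keys
    tuple(),               # remaining identifiers
    0,                     # primary key length
)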
async def edit_instance():
    instance_list = load_existing_config()
    if not instance_list:
    _instance_list = load_existing_config()
    if not _instance_list:
        print("No instances have been set up!")
        return

@ -346,18 +245,18 @@ async def edit_instance():
        "You have chosen to edit an instance. The following "
        "is a list of instances that currently exist:\n"
    )
    for instance in instance_list.keys():
    for instance in _instance_list.keys():
        print("{}\n".format(instance))
    print("Please select one of the above by entering its name")
    selected = input("> ")

    if selected not in instance_list.keys():
    if selected not in _instance_list.keys():
        print("That isn't a valid instance!")
        return
    instance_data = instance_list[selected]
    default_dirs = deepcopy(basic_config_default)
    _instance_data = _instance_list[selected]
    default_dirs = deepcopy(data_manager.basic_config_default)

    current_data_dir = Path(instance_data["DATA_PATH"])
    current_data_dir = Path(_instance_data["DATA_PATH"])
    print("You have selected '{}' as the instance to modify.".format(selected))
    if not confirm("Please confirm (y/n):"):
        print("Ok, we will not continue then.")
@ -383,68 +282,47 @@ async def edit_instance():
    print("Your basic configuration has been edited")

async def create_backup(instance):
    instance_vals = instance_data[instance]
    if confirm("Would you like to make a backup of the data for this instance? (y/n)"):
        load_basic_configuration(instance)
        if instance_vals["STORAGE_TYPE"] == "MongoDB":
            await mongo_to_json(instance)
        print("Backing up the instance's data...")
        backup_filename = "redv3-{}-{}.tar.gz".format(
            instance, dt.utcnow().strftime("%Y-%m-%d %H-%M-%S")
        )
        pth = Path(instance_vals["DATA_PATH"])
        if pth.exists():
            backup_pth = pth.home()
            backup_file = backup_pth / backup_filename

            to_backup = []
            exclusions = [
                "__pycache__",
                "Lavalink.jar",
                os.path.join("Downloader", "lib"),
                os.path.join("CogManager", "cogs"),
                os.path.join("RepoManager", "repos"),
            ]
            from redbot.cogs.downloader.repo_manager import RepoManager

            repo_mgr = RepoManager()
            await repo_mgr.initialize()
            repo_output = []
            for repo in repo_mgr._repos.values():
                repo_output.append({"url": repo.url, "name": repo.name, "branch": repo.branch})
            repo_filename = pth / "cogs" / "RepoManager" / "repos.json"
            with open(str(repo_filename), "w") as f:
                f.write(json.dumps(repo_output, indent=4))
            instance_vals = {instance_name: basic_config}
            instance_file = pth / "instance.json"
            with open(str(instance_file), "w") as instance_out:
                instance_out.write(json.dumps(instance_vals, indent=4))
            for f in pth.glob("**/*"):
                if not any(ex in str(f) for ex in exclusions):
                    to_backup.append(f)
            with tarfile.open(str(backup_file), "w:gz") as tar:
                for f in to_backup:
                    tar.add(str(f), recursive=False)
            print("A backup of {} has been made. It is at {}".format(instance, backup_file))


async def remove_instance(instance):
    await create_backup(instance)

    instance_vals = instance_data[instance]
    if instance_vals["STORAGE_TYPE"] == "MongoDB":
        from redbot.core.drivers.red_mongo import Mongo

        m = Mongo("Core", **instance_vals["STORAGE_DETAILS"])
        db = m.db
        collections = await db.collection_names(include_system_collections=False)
        for name in collections:
            collection = await db.get_collection(name)
            await collection.drop()
async def create_backup(instance: str) -> None:
    data_manager.load_basic_configuration(instance)
    backend_type = get_current_backend(instance)
    if backend_type == BackendType.MONGOV1:
        await mongov1_to_json()
    elif backend_type != BackendType.JSON:
        await do_migration(backend_type, BackendType.JSON)
    print("Backing up the instance's data...")
    backup_fpath = await _create_backup()
    if backup_fpath is not None:
        print(f"A backup of {instance} has been made. It is at {backup_fpath}")
    else:
        pth = Path(instance_vals["DATA_PATH"])
        safe_delete(pth)
        print("Creating the backup failed.")


async def remove_instance(
    instance,
    interactive: bool = False,
    drop_db: Optional[bool] = None,
    remove_datapath: Optional[bool] = None,
):
    data_manager.load_basic_configuration(instance)

    if confirm("Would you like to make a backup of the data for this instance? (y/n)"):
        await create_backup(instance)

    backend = get_current_backend(instance)
    if backend == BackendType.MONGOV1:
        driver_cls = drivers.MongoDriver
    else:
        driver_cls = drivers.get_driver_class(backend)

    await driver_cls.delete_all_data(interactive=interactive, drop_db=drop_db)

    if interactive is True and remove_datapath is None:
        remove_datapath = confirm("Would you like to delete the instance's entire datapath? (y/n)")

    if remove_datapath is True:
        data_path = data_manager.core_data_path().parent
        safe_delete(data_path)

    save_config(instance, {}, remove=True)
    print("The instance {} has been removed\n".format(instance))
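For illustration (not from the diff): the same non-interactive removal driven directly rather than through the click command below; the instance name and flag values are placeholders.

import asyncio

loop = asyncio.get_event_loop()
loop.run_until_complete(
    remove_instance("myinstance", interactive=False, drop_db=True, remove_datapath=True)
)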
@ -467,8 +345,7 @@ async def remove_instance_interaction():
        print("That isn't a valid instance!")
        return

    await create_backup(selected)
    await remove_instance(selected)
    await remove_instance(selected, interactive=True)

@click.group(invoke_without_command=True)
@ -483,38 +360,56 @@ def cli(ctx, debug):

@cli.command()
@click.argument("instance", type=click.Choice(instance_list))
def delete(instance):
@click.option("--no-prompt", default=False, help="Don't ask for user input during the process.")
@click.option(
    "--drop-db",
    type=bool,
    default=None,
    help=(
        "Drop the entire database containing this instance's data. Has no effect on JSON "
        "instances. If this option and --no-prompt are omitted, you will be asked about this."
    ),
)
@click.option(
    "--remove-datapath",
    type=bool,
    default=None,
    help=(
        "Remove this entire instance's datapath. If this option and --no-prompt are omitted, you "
        "will be asked about this."
    ),
)
def delete(instance: str, no_prompt: Optional[bool], drop_db: Optional[bool]):
    loop = asyncio.get_event_loop()
    loop.run_until_complete(remove_instance(instance))
    if no_prompt is None:
        interactive = None
    else:
        interactive = not no_prompt
    loop.run_until_complete(remove_instance(instance, interactive, drop_db))


@cli.command()
@click.argument("instance", type=click.Choice(instance_list))
@click.argument("backend", type=click.Choice(["json", "mongo"]))
@click.argument("backend", type=click.Choice(["json", "mongo", "postgres"]))
def convert(instance, backend):
    current_backend = get_current_backend(instance)
    target = get_target_backend(backend)
    data_manager.load_basic_configuration(instance)

    default_dirs = deepcopy(basic_config_default)
    default_dirs = deepcopy(data_manager.basic_config_default)
    default_dirs["DATA_PATH"] = str(Path(instance_data[instance]["DATA_PATH"]))

    loop = asyncio.get_event_loop()

    new_storage_details = None

    if current_backend == BackendType.MONGOV1:
        if target == BackendType.MONGO:
        if target == BackendType.JSON:
            new_storage_details = loop.run_until_complete(mongov1_to_json())
        else:
            raise RuntimeError(
                "Please see conversion docs for updating to the latest mongo version."
            )
        elif target == BackendType.JSON:
            new_storage_details = loop.run_until_complete(mongo_to_json(instance))
    elif current_backend == BackendType.JSON:
        if target == BackendType.MONGO:
            new_storage_details = loop.run_until_complete(json_to_mongov2(instance))
    elif current_backend == BackendType.MONGO:
        if target == BackendType.JSON:
            new_storage_details = loop.run_until_complete(mongov2_to_json(instance))
    else:
        new_storage_details = loop.run_until_complete(do_migration(current_backend, target))

    if new_storage_details is not None:
        default_dirs["STORAGE_TYPE"] = target.value
@ -522,7 +417,9 @@ def convert(instance, backend):
        save_config(instance, default_dirs)
        conversion_log.info(f"Conversion to {target} complete.")
    else:
        conversion_log.info(f"Cannot convert {current_backend} to {target} at this time.")
        conversion_log.info(
            f"Cannot convert {current_backend.value} to {target.value} at this time."
        )


if __name__ == "__main__":

@ -80,6 +80,8 @@ mongo =
    dnspython==1.16.0
    motor==2.0.0
    pymongo==3.8.0
postgres =
    asyncpg==0.18.3
style =
    black==19.3b0
    toml==0.10.0
@ -123,3 +125,5 @@ include =
    **/locales/*.po
    data/*
    data/**/*
redbot.core.drivers.postgres =
    *.sql

45 tests/conftest.py (new file)
@ -0,0 +1,45 @@
import asyncio
import os

import pytest

from redbot.core import drivers, data_manager


@pytest.fixture(scope="session")
def event_loop(request):
    """Create an instance of the default event loop for entire session."""
    loop = asyncio.get_event_loop_policy().new_event_loop()
    yield loop
    loop.close()


def _get_backend_type():
    if os.getenv("RED_STORAGE_TYPE") == "postgres":
        return drivers.BackendType.POSTGRES
    elif os.getenv("RED_STORAGE_TYPE") == "mongo":
        return drivers.BackendType.MONGO
    else:
        return drivers.BackendType.JSON


@pytest.fixture(scope="session", autouse=True)
async def _setup_driver():
    backend_type = _get_backend_type()
    if backend_type == drivers.BackendType.MONGO:
        storage_details = {
            "URI": os.getenv("RED_MONGO_URI", "mongodb"),
            "HOST": os.getenv("RED_MONGO_HOST", "localhost"),
            "PORT": int(os.getenv("RED_MONGO_PORT", "27017")),
            "USERNAME": os.getenv("RED_MONGO_USER", "red"),
            "PASSWORD": os.getenv("RED_MONGO_PASSWORD", "red"),
            "DB_NAME": os.getenv("RED_MONGO_DATABASE", "red_db"),
        }
    else:
        storage_details = {}
    data_manager.storage_type = lambda: backend_type.value
    data_manager.storage_details = lambda: storage_details
    driver_cls = drivers.get_driver_class(backend_type)
    await driver_cls.initialize(**storage_details)
    yield
    await driver_cls.teardown()
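A sketch of the kind of backend-agnostic test this conftest enables: the same body runs against JSON, MongoDB or PostgreSQL depending on RED_STORAGE_TYPE. It assumes the existing config fixture from redbot's pytest helpers and pytest-asyncio, neither of which is added by this commit.

import pytest


@pytest.mark.asyncio
async def test_value_roundtrip(config):
    # register_global() is synchronous; reads and writes are coroutines.
    config.register_global(counter=0)
    await config.counter.set(5)
    assert await config.counter() == 5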
@ -1,3 +1,3 @@
packaging
tox
-e .[docs,mongo,style,test]
-e .[docs,mongo,postgres,style,test]

@ -34,6 +34,8 @@ docs =
mongo =
    dnspython
    motor
postgres =
    asyncpg
style =
    black
test =

37 tox.ini
@ -15,12 +15,47 @@ description = Run tests and basic automatic issue checking.
whitelist_externals =
    pytest
    pylint
extras = voice, test, mongo
extras = voice, test
commands =
    python -m compileall ./redbot/cogs
    pytest
    pylint ./redbot

[testenv:postgres]
description = Run pytest with PostgreSQL backend
whitelist_externals =
    pytest
extras = voice, test, postgres
setenv =
    RED_STORAGE_TYPE=postgres
passenv =
    # Use the following env vars for connection options, or other default options described here:
    # https://magicstack.github.io/asyncpg/current/index.html#asyncpg.connection.connect
    PGHOST
    PGPORT
    PGUSER
    PGPASSWORD
    PGDATABASE
commands =
    pytest

[testenv:mongo]
description = Run pytest with MongoDB backend
whitelist_externals =
    pytest
extras = voice, test, mongo
setenv =
    RED_STORAGE_TYPE=mongo
passenv =
    RED_MONGO_URI
    RED_MONGO_HOST
    RED_MONGO_PORT
    RED_MONGO_USER
    RED_MONGO_PASSWORD
    RED_MONGO_DATABASE
commands =
    pytest

[testenv:docs]
description = Attempt to build docs with sphinx-build
whitelist_externals =