diff --git a/.travis.yml b/.travis.yml index 11bc49b0a..a1748288b 100644 --- a/.travis.yml +++ b/.travis.yml @@ -5,14 +5,10 @@ notifications: email: false python: -- 3.7.2 +- 3.7.3 env: global: - PIPENV_IGNORE_VIRTUALENVS=1 - matrix: - - TOXENV=py - - TOXENV=docs - - TOXENV=style install: - pip install --upgrade pip tox @@ -22,6 +18,19 @@ script: jobs: include: + - env: TOXENV=py + - env: TOXENV=docs + - env: TOXENV=style + - env: TOXENV=postgres + services: postgresql + addons: + postgresql: "10" + before_script: + - psql -c 'create database red_db;' -U postgres + - env: TOXENV=mongo + services: mongodb + before_script: + - mongo red_db --eval 'db.createUser({user:"red",pwd:"red",roles:["readWrite"]});' # These jobs only occur on tag creation if the prior ones succeed - stage: PyPi Deployment if: tag IS present diff --git a/Makefile b/Makefile index 5d9afbf64..c6f92b713 100644 --- a/Makefile +++ b/Makefile @@ -1,8 +1,8 @@ # Python Code Style reformat: - black -l 99 `git ls-files "*.py"` + black -l 99 --target-version py37 `git ls-files "*.py"` stylecheck: - black --check -l 99 `git ls-files "*.py"` + black --check -l 99 --target-version py37 `git ls-files "*.py"` # Translations gettext: diff --git a/changelog.d/2723.feature.rst b/changelog.d/2723.feature.rst new file mode 100644 index 000000000..1e2f3d49b --- /dev/null +++ b/changelog.d/2723.feature.rst @@ -0,0 +1 @@ +Added a config driver for PostgreSQL diff --git a/changelog.d/2723.misc.rst b/changelog.d/2723.misc.rst new file mode 100644 index 000000000..2d4afcf25 --- /dev/null +++ b/changelog.d/2723.misc.rst @@ -0,0 +1,32 @@ +Changes to the ``redbot.core.drivers`` package: + +- The modules inside the ``redbot.core.drivers`` package no longer have the ``red_`` prefix in front of their names. +- All driver classes are now directly accessible as attributes of the ``redbot.core.drivers`` package. +- :func:`get_driver`'s signature has been changed. +- :func:`get_driver` can now use data manager to infer the backend type if it is not supplied as an argument. +- :func:`get_driver_class` has been added. + +Changes to the :class:`BaseDriver` and :class:`JsonDriver` classes: + +- :meth:`BaseDriver.get_config_details` is now an abstract staticmethod. +- :meth:`BaseDriver.initialize` and :meth:`BaseDriver.teardown` are two new abstract coroutine classmethods. +- :meth:`BaseDriver.delete_all_data` is a new concrete (but overridable) coroutine instance method. +- :meth:`BaseDriver.aiter_cogs` is a new abstract asynchronous iterator method. +- :meth:`BaseDriver.migrate_to` is a new concrete coroutine classmethod. +- :class:`JsonDriver` no longer requires the data path when constructed and will infer the data path from data manager. + +Changes to the :class:`IdentifierData` class and :class:`ConfigCategory` enum: + +- ``IdentifierData.custom_group_data`` has been replaced by :attr:`IdentifierData.primary_key_len`. +- :meth:`ConfigCategory.get_pkey_info` is a new classmethod. + +Changes to the migration and backup system: + +- All code in the ``redbot.setup`` script, excluding that regarding MongoV1, is now virtually backend-agnostic. +- All code in the ``[p]backup`` command is now backend-agnostic. +- :func:`redbot.core.config.migrate` is a new coroutine function. +- All a new driver needs to do now to be compatible with migrations and backups is to implement the :class:`BaseDriver` ABC. + +Enhancements to unit tests: + +- New tox recipes have been added for testing against Mongo and Postgres backends.
See the ``tox.ini`` file for clues on how to run them. diff --git a/docs/framework_config.rst b/docs/framework_config.rst index 96b8d5768..d18aed63b 100644 --- a/docs/framework_config.rst +++ b/docs/framework_config.rst @@ -421,15 +421,15 @@ Driver Reference Base Driver ^^^^^^^^^^^ -.. autoclass:: redbot.core.drivers.red_base.BaseDriver +.. autoclass:: redbot.core.drivers.BaseDriver :members: JSON Driver ^^^^^^^^^^^ -.. autoclass:: redbot.core.drivers.red_json.JSON +.. autoclass:: redbot.core.drivers.JsonDriver :members: Mongo Driver ^^^^^^^^^^^^ -.. autoclass:: redbot.core.drivers.red_mongo.Mongo +.. autoclass:: redbot.core.drivers.MongoDriver :members: diff --git a/docs/install_linux_mac.rst b/docs/install_linux_mac.rst index 22d0c492e..cb58b13db 100644 --- a/docs/install_linux_mac.rst +++ b/docs/install_linux_mac.rst @@ -200,6 +200,12 @@ Or, to install with MongoDB support: python3.7 -m pip install -U Red-DiscordBot[mongo] +Or, to install with PostgreSQL support: + +.. code-block:: none + + python3.7 -m pip install -U Red-DiscordBot[postgres] + .. note:: To install the development version, replace ``Red-DiscordBot`` in the above commands with the diff --git a/docs/install_windows.rst b/docs/install_windows.rst index f36753f70..ad316cb58 100644 --- a/docs/install_windows.rst +++ b/docs/install_windows.rst @@ -65,7 +65,7 @@ Installing Red If you're not inside an activated virtual environment, include the ``--user`` flag with all ``pip`` commands. - * No MongoDB support: + * Normal installation: .. code-block:: none @@ -77,6 +77,12 @@ Installing Red python -m pip install -U Red-DiscordBot[mongo] + * With PostgreSQL support: + + .. code-block:: none + + python -m pip install -U Red-DiscordBot[postgres] + .. note:: To install the development version, replace ``Red-DiscordBot`` in the above commands with the diff --git a/make.bat b/make.bat index 46ea86173..1e38b16e7 100644 --- a/make.bat +++ b/make.bat @@ -14,11 +14,11 @@ for /F "tokens=* USEBACKQ" %%A in (`git ls-files "*.py"`) do ( goto %1 :reformat -black -l 99 !PYFILES! +black -l 99 --target-version py37 !PYFILES! exit /B %ERRORLEVEL% :stylecheck -black -l 99 --check !PYFILES! +black -l 99 --check --target-version py37 !PYFILES!
exit /B %ERRORLEVEL% :newenv diff --git a/redbot/__main__.py b/redbot/__main__.py index 38338b0ca..2cac3e178 100644 --- a/redbot/__main__.py +++ b/redbot/__main__.py @@ -33,7 +33,7 @@ from redbot.core.events import init_events from redbot.core.cli import interactive_config, confirm, parse_cli_flags from redbot.core.core_commands import Core from redbot.core.dev_commands import Dev -from redbot.core import __version__, modlog, bank, data_manager +from redbot.core import __version__, modlog, bank, data_manager, drivers from signal import SIGTERM @@ -99,7 +99,11 @@ def main(): ) cli_flags.instance_name = "temporary_red" data_manager.create_temp_config() + loop = asyncio.get_event_loop() + data_manager.load_basic_configuration(cli_flags.instance_name) + driver_cls = drivers.get_driver_class() + loop.run_until_complete(driver_cls.initialize(**data_manager.storage_details())) redbot.logging.init_logging( level=cli_flags.logging_level, location=data_manager.core_data_path() / "logs" ) @@ -111,7 +115,6 @@ def main(): red = Red( cli_flags=cli_flags, description=description, dm_help=None, fetch_offline_members=True ) - loop = asyncio.get_event_loop() loop.run_until_complete(red.maybe_update_config()) init_global_checks(red) init_events(red, cli_flags) diff --git a/redbot/cogs/warnings/warnings.py b/redbot/cogs/warnings/warnings.py index b53fdaf8d..a871ae62f 100644 --- a/redbot/cogs/warnings/warnings.py +++ b/redbot/cogs/warnings/warnings.py @@ -251,7 +251,7 @@ class Warnings(commands.Cog): user: discord.Member, points: Optional[int] = 1, *, - reason: str + reason: str, ): """Warn the user for the specified reason. diff --git a/redbot/core/bot.py b/redbot/core/bot.py index 9d047aed0..a7c82f058 100644 --- a/redbot/core/bot.py +++ b/redbot/core/bot.py @@ -1,7 +1,7 @@ import asyncio import inspect -import os import logging +import os from collections import Counter from enum import Enum from importlib.machinery import ModuleSpec @@ -9,10 +9,9 @@ from pathlib import Path from typing import Optional, Union, List import discord -import sys from discord.ext.commands import when_mentioned_or -from . import Config, i18n, commands, errors +from . import Config, i18n, commands, errors, drivers from .cog_manager import CogManager from .rpc import RPCMixin @@ -592,8 +591,8 @@ class Red(RedBase, discord.AutoShardedClient): async def logout(self): """Logs out of Discord and closes all connections.""" - await super().logout() + await drivers.get_driver_class().teardown() async def shutdown(self, *, restart: bool = False): """Gracefully quit Red. diff --git a/redbot/core/config.py b/redbot/core/config.py index e81b20c0d..1fed17bba 100644 --- a/redbot/core/config.py +++ b/redbot/core/config.py @@ -5,23 +5,22 @@ import pickle import weakref from typing import ( Any, - Union, - Tuple, - Dict, - Awaitable, AsyncContextManager, - TypeVar, + Awaitable, + Dict, MutableMapping, Optional, + Tuple, + Type, + TypeVar, + Union, ) import discord -from .data_manager import cog_data_path, core_data_path -from .drivers import get_driver, IdentifierData, BackendType -from .drivers.red_base import BaseDriver +from .drivers import IdentifierData, get_driver, ConfigCategory, BaseDriver -__all__ = ["Config", "get_latest_confs"] +__all__ = ["Config", "get_latest_confs", "migrate"] log = logging.getLogger("red.config") @@ -101,7 +100,7 @@ class Value: Information on identifiers for this value. default The default value for the data element that `identifiers` points at. 
- driver : `redbot.core.drivers.red_base.BaseDriver` + driver : `redbot.core.drivers.BaseDriver` A reference to `Config.driver`. """ @@ -250,7 +249,7 @@ class Group(Value): All registered default values for this Group. force_registration : `bool` Same as `Config.force_registration`. - driver : `redbot.core.drivers.red_base.BaseDriver` + driver : `redbot.core.drivers.BaseDriver` A reference to `Config.driver`. """ @@ -586,7 +585,7 @@ class Config: Unique identifier provided to differentiate cog data when name conflicts occur. driver - An instance of a driver that implements `redbot.core.drivers.red_base.BaseDriver`. + An instance of a driver that implements `redbot.core.drivers.BaseDriver`. force_registration : `bool` Determines if Config should throw an error if a cog attempts to access an attribute which has not been previously registered. @@ -634,7 +633,7 @@ class Config: self.force_registration = force_registration self._defaults = defaults or {} - self.custom_groups = {} + self.custom_groups: Dict[str, int] = {} self._lock_cache: MutableMapping[ IdentifierData, asyncio.Lock ] = weakref.WeakValueDictionary() @@ -643,10 +642,6 @@ class Config: def defaults(self): return pickle.loads(pickle.dumps(self._defaults, -1)) - @staticmethod - def _create_uuid(identifier: int): - return str(identifier) - @classmethod def get_conf(cls, cog_instance, identifier: int, force_registration=False, cog_name=None): """Get a Config instance for your cog. @@ -681,25 +676,12 @@ class Config: A new Config object. """ - if cog_instance is None and cog_name is not None: - cog_path_override = cog_data_path(raw_name=cog_name) - else: - cog_path_override = cog_data_path(cog_instance=cog_instance) + uuid = str(identifier) + if cog_name is None: + cog_name = type(cog_instance).__name__ - cog_name = cog_path_override.stem - # uuid = str(hash(identifier)) - uuid = cls._create_uuid(identifier) - - # We have to import this here otherwise we have a circular dependency - from .data_manager import basic_config - - driver_name = basic_config.get("STORAGE_TYPE", "JSON") - driver_details = basic_config.get("STORAGE_DETAILS", {}) - - driver = get_driver( - driver_name, cog_name, uuid, data_path_override=cog_path_override, **driver_details - ) - if driver_name == BackendType.JSON.value: + driver = get_driver(cog_name, uuid) + if hasattr(driver, "migrate_identifier"): driver.migrate_identifier(identifier) conf = cls( @@ -712,7 +694,7 @@ class Config: @classmethod def get_core_conf(cls, force_registration: bool = False): - """Get a Config instance for a core module. + """Get a Config instance for the core bot. All core modules that require a config instance should use this classmethod instead of `get_conf`. @@ -723,24 +705,9 @@ class Config: See `force_registration`. """ - core_path = core_data_path() - - # We have to import this here otherwise we have a circular dependency - from .data_manager import basic_config - - driver_name = basic_config.get("STORAGE_TYPE", "JSON") - driver_details = basic_config.get("STORAGE_DETAILS", {}) - - driver = get_driver( - driver_name, "Core", "0", data_path_override=core_path, **driver_details + return cls.get_conf( + None, cog_name="Core", identifier=0, force_registration=force_registration ) - conf = cls( - cog_name="Core", - driver=driver, - unique_identifier="0", - force_registration=force_registration, - ) - return conf def __getattr__(self, item: str) -> Union[Group, Value]: """Same as `group.__getattr__` except for global data. 
@@ -916,26 +883,18 @@ class Config: self.custom_groups[group_identifier] = identifier_count def _get_base_group(self, category: str, *primary_keys: str) -> Group: - is_custom = category not in ( - self.GLOBAL, - self.GUILD, - self.USER, - self.MEMBER, - self.ROLE, - self.CHANNEL, - ) - # noinspection PyTypeChecker + pkey_len, is_custom = ConfigCategory.get_pkey_info(category, self.custom_groups) identifier_data = IdentifierData( + cog_name=self.cog_name, uuid=self.unique_identifier, category=category, primary_key=primary_keys, identifiers=(), - custom_group_data=self.custom_groups, + primary_key_len=pkey_len, is_custom=is_custom, ) - pkey_len = BaseDriver.get_pkey_len(identifier_data) - if len(primary_keys) < pkey_len: + if len(primary_keys) < identifier_data.primary_key_len: # Don't mix in defaults with groups higher than the document level defaults = {} else: @@ -1220,9 +1179,7 @@ class Config: """ if not scopes: # noinspection PyTypeChecker - identifier_data = IdentifierData( - self.unique_identifier, "", (), (), self.custom_groups - ) + identifier_data = IdentifierData(self.cog_name, self.unique_identifier, "", (), (), 0) group = Group(identifier_data, defaults={}, driver=self.driver, config=self) else: cat, *scopes = scopes @@ -1359,7 +1316,12 @@ class Config: return self.get_custom_lock(self.GUILD) else: id_data = IdentifierData( - self.unique_identifier, self.MEMBER, (str(guild.id),), (), self.custom_groups + self.cog_name, + self.unique_identifier, + category=self.MEMBER, + primary_key=(str(guild.id),), + identifiers=(), + primary_key_len=2, ) return self._lock_cache.setdefault(id_data, asyncio.Lock()) @@ -1375,10 +1337,33 @@ class Config: ------- asyncio.Lock """ - id_data = IdentifierData( - self.unique_identifier, group_identifier, (), (), self.custom_groups - ) - return self._lock_cache.setdefault(id_data, asyncio.Lock()) + try: + pkey_len, is_custom = ConfigCategory.get_pkey_info( + group_identifier, self.custom_groups + ) + except KeyError: + raise ValueError(f"Custom group not initialized: {group_identifier}") from None + else: + id_data = IdentifierData( + self.cog_name, + self.unique_identifier, + category=group_identifier, + primary_key=(), + identifiers=(), + primary_key_len=pkey_len, + is_custom=is_custom, + ) + return self._lock_cache.setdefault(id_data, asyncio.Lock()) + + +async def migrate(cur_driver_cls: Type[BaseDriver], new_driver_cls: Type[BaseDriver]) -> None: + """Migrate from one driver type to another.""" + # Get custom group data + core_conf = Config.get_core_conf() + core_conf.init_custom("CUSTOM_GROUPS", 2) + all_custom_group_data = await core_conf.custom("CUSTOM_GROUPS").all() + + await cur_driver_cls.migrate_to(new_driver_cls, all_custom_group_data) def _str_key_dict(value: Dict[Any, _T]) -> Dict[str, _T]: diff --git a/redbot/core/core_commands.py b/redbot/core/core_commands.py index 77ba280d3..fabd4cd78 100644 --- a/redbot/core/core_commands.py +++ b/redbot/core/core_commands.py @@ -3,15 +3,12 @@ import contextlib import datetime import importlib import itertools -import json import logging import os -import pathlib import sys import platform import getpass import pip -import tarfile import traceback from collections import namedtuple from pathlib import Path @@ -23,15 +20,18 @@ import aiohttp import discord import pkg_resources -from redbot.core import ( +from . 
import ( __version__, version_info as red_version_info, VersionInfo, checks, commands, + drivers, errors, i18n, + config, ) +from .utils import create_backup from .utils.predicates import MessagePredicate from .utils.chat_formatting import humanize_timedelta, pagify, box, inline, humanize_list from .commands.requires import PrivilegeLevel @@ -1307,105 +1307,71 @@ class Core(commands.Cog, CoreLogic): @commands.command() @checks.is_owner() - async def backup(self, ctx: commands.Context, *, backup_path: str = None): - """Creates a backup of all data for the instance.""" - if backup_path: - path = pathlib.Path(backup_path) - if not (path.exists() and path.is_dir()): - return await ctx.send( - _("That path doesn't seem to exist. Please provide a valid path.") - ) - from redbot.core.data_manager import basic_config, instance_name - from redbot.core.drivers.red_json import JSON + async def backup(self, ctx: commands.Context, *, backup_dir: str = None): + """Creates a backup of all data for the instance. - data_dir = Path(basic_config["DATA_PATH"]) - if basic_config["STORAGE_TYPE"] == "MongoDB": - from redbot.core.drivers.red_mongo import Mongo - - m = Mongo("Core", "0", **basic_config["STORAGE_DETAILS"]) - db = m.db - collection_names = await db.list_collection_names() - for c_name in collection_names: - if c_name == "Core": - c_data_path = data_dir / basic_config["CORE_PATH_APPEND"] - else: - c_data_path = data_dir / basic_config["COG_PATH_APPEND"] / c_name - docs = await db[c_name].find().to_list(None) - for item in docs: - item_id = str(item.pop("_id")) - target = JSON(c_name, item_id, data_path_override=c_data_path) - target.data = item - await target._save() - backup_filename = "redv3-{}-{}.tar.gz".format( - instance_name, ctx.message.created_at.strftime("%Y-%m-%d %H-%M-%S") - ) - if data_dir.exists(): - if not backup_path: - backup_pth = data_dir.home() - else: - backup_pth = Path(backup_path) - backup_file = backup_pth / backup_filename - - to_backup = [] - exclusions = [ - "__pycache__", - "Lavalink.jar", - os.path.join("Downloader", "lib"), - os.path.join("CogManager", "cogs"), - os.path.join("RepoManager", "repos"), - ] - downloader_cog = ctx.bot.get_cog("Downloader") - if downloader_cog and hasattr(downloader_cog, "_repo_manager"): - repo_output = [] - repo_mgr = downloader_cog._repo_manager - for repo in repo_mgr._repos.values(): - repo_output.append({"url": repo.url, "name": repo.name, "branch": repo.branch}) - repo_filename = data_dir / "cogs" / "RepoManager" / "repos.json" - with open(str(repo_filename), "w") as f: - f.write(json.dumps(repo_output, indent=4)) - instance_data = {instance_name: basic_config} - instance_file = data_dir / "instance.json" - with open(str(instance_file), "w") as instance_out: - instance_out.write(json.dumps(instance_data, indent=4)) - for f in data_dir.glob("**/*"): - if not any(ex in str(f) for ex in exclusions): - to_backup.append(f) - with tarfile.open(str(backup_file), "w:gz") as tar: - for f in to_backup: - tar.add(str(f), recursive=False) - print(str(backup_file)) - await ctx.send( - _("A backup has been made of this instance. It is at {}.").format(backup_file) - ) - if backup_file.stat().st_size > 8_000_000: - await ctx.send(_("This backup is too large to send via DM.")) - return - await ctx.send(_("Would you like to receive a copy via DM? 
(y/n)")) - - pred = MessagePredicate.yes_or_no(ctx) - try: - await ctx.bot.wait_for("message", check=pred, timeout=60) - except asyncio.TimeoutError: - await ctx.send(_("Response timed out.")) - else: - if pred.result is True: - await ctx.send(_("OK, it's on its way!")) - try: - async with ctx.author.typing(): - await ctx.author.send( - _("Here's a copy of the backup"), - file=discord.File(str(backup_file)), - ) - except discord.Forbidden: - await ctx.send( - _("I don't seem to be able to DM you. Do you have closed DMs?") - ) - except discord.HTTPException: - await ctx.send(_("I could not send the backup file.")) - else: - await ctx.send(_("OK then.")) + You may provide a path to a directory for the backup archive to + be placed in. If the directory does not exist, the bot will + attempt to create it. + """ + if backup_dir is None: + dest = Path.home() else: - await ctx.send(_("That directory doesn't seem to exist...")) + dest = Path(backup_dir) + + driver_cls = drivers.get_driver_class() + if driver_cls != drivers.JsonDriver: + await ctx.send(_("Converting data to JSON for backup...")) + async with ctx.typing(): + await config.migrate(driver_cls, drivers.JsonDriver) + + log.info("Creating backup for this instance...") + try: + backup_fpath = await create_backup(dest) + except OSError as exc: + await ctx.send( + _( + "Creating the backup archive failed! Please check your console or logs for " + "details." + ) + ) + log.exception("Failed to create backup archive", exc_info=exc) + return + + if backup_fpath is None: + await ctx.send(_("Your datapath appears to be empty.")) + return + + log.info("Backup archive created successfully at '%s'", backup_fpath) + await ctx.send( + _("A backup has been made of this instance. It is located at `{path}`.").format( + path=backup_fpath + ) + ) + if backup_fpath.stat().st_size > 8_000_000: + await ctx.send(_("This backup is too large to send via DM.")) + return + await ctx.send(_("Would you like to receive a copy via DM? (y/n)")) + + pred = MessagePredicate.yes_or_no(ctx) + try: + await ctx.bot.wait_for("message", check=pred, timeout=60) + except asyncio.TimeoutError: + await ctx.send(_("Response timed out.")) + else: + if pred.result is True: + await ctx.send(_("OK, it's on its way!")) + try: + async with ctx.author.typing(): + await ctx.author.send( + _("Here's a copy of the backup"), file=discord.File(str(backup_fpath)) + ) + except discord.Forbidden: + await ctx.send(_("I don't seem to be able to DM you. Do you have closed DMs?")) + except discord.HTTPException: + await ctx.send(_("I could not send the backup file.")) + else: + await ctx.send(_("OK then.")) @commands.command() @commands.cooldown(1, 60, commands.BucketType.user) diff --git a/redbot/core/data_manager.py b/redbot/core/data_manager.py index b031cd1cb..be1632809 100644 --- a/redbot/core/data_manager.py +++ b/redbot/core/data_manager.py @@ -224,7 +224,4 @@ def storage_details() -> dict: ------- dict """ - try: - return basic_config["STORAGE_DETAILS"] - except KeyError as e: - raise RuntimeError("Bot basic config has not been loaded yet.") from e + return basic_config.get("STORAGE_DETAILS", {}) diff --git a/redbot/core/drivers/__init__.py b/redbot/core/drivers/__init__.py index 5108049f0..f34727c86 100644 --- a/redbot/core/drivers/__init__.py +++ b/redbot/core/drivers/__init__.py @@ -1,50 +1,110 @@ import enum +from typing import Optional, Type -from .red_base import IdentifierData +from .. 
import data_manager +from .base import IdentifierData, BaseDriver, ConfigCategory +from .json import JsonDriver +from .mongo import MongoDriver +from .postgres import PostgresDriver -__all__ = ["get_driver", "IdentifierData", "BackendType"] +__all__ = [ + "get_driver", + "ConfigCategory", + "IdentifierData", + "BaseDriver", + "JsonDriver", + "MongoDriver", + "PostgresDriver", + "BackendType", +] class BackendType(enum.Enum): JSON = "JSON" MONGO = "MongoDBV2" MONGOV1 = "MongoDB" + POSTGRES = "Postgres" -def get_driver(type, *args, **kwargs): +_DRIVER_CLASSES = { + BackendType.JSON: JsonDriver, + BackendType.MONGO: MongoDriver, + BackendType.POSTGRES: PostgresDriver, +} + + +def get_driver_class(storage_type: Optional[BackendType] = None) -> Type[BaseDriver]: + """Get the driver class for the given storage type. + + Parameters + ---------- + storage_type : Optional[BackendType] + The backend you want a driver class for. Omit to try to obtain + the backend from data manager. + + Returns + ------- + Type[BaseDriver] + A subclass of `BaseDriver`. + + Raises + ------ + ValueError + If there is no driver for the given storage type. + """ - Selectively import/load driver classes based on the selected type. This - is required so that dependencies can differ between installs (e.g. so that - you don't need to install a mongo dependency if you will just be running a - json data backend). + if storage_type is None: + storage_type = BackendType(data_manager.storage_type()) + try: + return _DRIVER_CLASSES[storage_type] + except KeyError: + raise ValueError(f"No driver found for storage type {storage_type}") from None - .. note:: - See the respective classes for information on what ``args`` and ``kwargs`` - should be. +def get_driver( + cog_name: str, identifier: str, storage_type: Optional[BackendType] = None, **kwargs +): + """Get a driver instance. + + Parameters + ---------- + cog_name : str + The cog's name. + identifier : str + The cog's discriminator. + storage_type : Optional[BackendType] + The backend you want a driver for. Omit to try to obtain the + backend from data manager. + **kwargs + Driver-specific keyword arguments. + + Returns + ------- + BaseDriver + A driver instance. + + Raises + ------ + RuntimeError + If the storage type is MongoV1 or invalid. - :param str type: - One of: json, mongo - :param args: - Dependent on driver type. - :param kwargs: - Dependent on driver type. - :return: - Subclass of :py:class:`.red_base.BaseDriver`. """ - if type == "JSON": - from .red_json import JSON + if storage_type is None: + try: + storage_type = BackendType(data_manager.storage_type()) + except RuntimeError: + storage_type = BackendType.JSON - return JSON(*args, **kwargs) - elif type == "MongoDBV2": - from .red_mongo import Mongo - - return Mongo(*args, **kwargs) - elif type == "MongoDB": - raise RuntimeError( - "Please convert to JSON first to continue using the bot." - " This is a required conversion prior to using the new Mongo driver." - " This message will be updated with a link to the update docs once those" - " docs have been created." - ) - raise RuntimeError("Invalid driver type: '{}'".format(type)) + try: + driver_cls: Type[BaseDriver] = get_driver_class(storage_type) + except ValueError: + if storage_type == BackendType.MONGOV1: + raise RuntimeError( + "Please convert to JSON first to continue using the bot." + " This is a required conversion prior to using the new Mongo driver." + " This message will be updated with a link to the update docs once those" + " docs have been created." 
+ ) from None + else: + raise RuntimeError(f"Invalid driver type: '{storage_type}'") from None + return driver_cls(cog_name, identifier, **kwargs) diff --git a/redbot/core/drivers/base.py b/redbot/core/drivers/base.py new file mode 100644 index 000000000..8d8ed8a36 --- /dev/null +++ b/redbot/core/drivers/base.py @@ -0,0 +1,342 @@ +import abc +import enum +from typing import Tuple, Dict, Any, Union, List, AsyncIterator, Type + +__all__ = ["BaseDriver", "IdentifierData", "ConfigCategory"] + + +class ConfigCategory(str, enum.Enum): + GLOBAL = "GLOBAL" + GUILD = "GUILD" + CHANNEL = "TEXTCHANNEL" + ROLE = "ROLE" + USER = "USER" + MEMBER = "MEMBER" + + @classmethod + def get_pkey_info( + cls, category: Union[str, "ConfigCategory"], custom_group_data: Dict[str, int] + ) -> Tuple[int, bool]: + """Get the full primary key length for the given category, + and whether or not the category is a custom category. + """ + try: + # noinspection PyArgumentList + category_obj = cls(category) + except ValueError: + return custom_group_data[category], True + else: + return _CATEGORY_PKEY_COUNTS[category_obj], False + + +_CATEGORY_PKEY_COUNTS = { + ConfigCategory.GLOBAL: 0, + ConfigCategory.GUILD: 1, + ConfigCategory.CHANNEL: 1, + ConfigCategory.ROLE: 1, + ConfigCategory.USER: 1, + ConfigCategory.MEMBER: 2, +} + + +class IdentifierData: + def __init__( + self, + cog_name: str, + uuid: str, + category: str, + primary_key: Tuple[str, ...], + identifiers: Tuple[str, ...], + primary_key_len: int, + is_custom: bool = False, + ): + self._cog_name = cog_name + self._uuid = uuid + self._category = category + self._primary_key = primary_key + self._identifiers = identifiers + self.primary_key_len = primary_key_len + self._is_custom = is_custom + + @property + def cog_name(self) -> str: + return self._cog_name + + @property + def uuid(self) -> str: + return self._uuid + + @property + def category(self) -> str: + return self._category + + @property + def primary_key(self) -> Tuple[str, ...]: + return self._primary_key + + @property + def identifiers(self) -> Tuple[str, ...]: + return self._identifiers + + @property + def is_custom(self) -> bool: + return self._is_custom + + def __repr__(self) -> str: + return ( + f"<IdentifierData cog_name={self.cog_name} uuid={self.uuid} category={self.category} " + f"primary_key={self.primary_key} identifiers={self.identifiers}>" + ) + + def __eq__(self, other) -> bool: + if not isinstance(other, IdentifierData): + return False + return ( + self.uuid == other.uuid + and self.category == other.category + and self.primary_key == other.primary_key + and self.identifiers == other.identifiers + ) + + def __hash__(self) -> int: + return hash((self.uuid, self.category, self.primary_key, self.identifiers)) + + def add_identifier(self, *identifier: str) -> "IdentifierData": + if not all(isinstance(i, str) for i in identifier): + raise ValueError("Identifiers must be strings.") + + return IdentifierData( + self.cog_name, + self.uuid, + self.category, + self.primary_key, + self.identifiers + identifier, + self.primary_key_len, + is_custom=self.is_custom, + ) + + def to_tuple(self) -> Tuple[str, ...]: + return tuple( + filter( + None, + (self.cog_name, self.uuid, self.category, *self.primary_key, *self.identifiers), + ) + ) + + +class BaseDriver(abc.ABC): + def __init__(self, cog_name: str, identifier: str, **kwargs): + self.cog_name = cog_name + self.unique_cog_identifier = identifier + + @classmethod + @abc.abstractmethod + async def initialize(cls, **storage_details) -> None: + """ + Initialize this driver. + + Parameters + ---------- + **storage_details + The storage details required to initialize this driver.
+ Should be the same as :func:`data_manager.storage_details` + + Raises + ------ + MissingExtraRequirements + If initializing the driver requires an extra which isn't + installed. + + """ + raise NotImplementedError + + @classmethod + @abc.abstractmethod + async def teardown(cls) -> None: + """ + Tear down this driver. + """ + raise NotImplementedError + + @staticmethod + @abc.abstractmethod + def get_config_details() -> Dict[str, Any]: + """ + Asks users for additional configuration information necessary + to use this config driver. + + Returns + ------- + Dict[str, Any] + Dictionary of configuration details. + """ + raise NotImplementedError + + @abc.abstractmethod + async def get(self, identifier_data: IdentifierData) -> Any: + """ + Finds the value indicated by the given identifiers. + + Parameters + ---------- + identifier_data + + Returns + ------- + Any + Stored value. + """ + raise NotImplementedError + + @abc.abstractmethod + async def set(self, identifier_data: IdentifierData, value=None) -> None: + """ + Sets the value of the key indicated by the given identifiers. + + Parameters + ---------- + identifier_data + value + Any JSON serializable python object. + """ + raise NotImplementedError + + @abc.abstractmethod + async def clear(self, identifier_data: IdentifierData) -> None: + """ + Clears out the value specified by the given identifiers. + + Equivalent to using ``del`` on a dict. + + Parameters + ---------- + identifier_data + """ + raise NotImplementedError + + @classmethod + @abc.abstractmethod + def aiter_cogs(cls) -> AsyncIterator[Tuple[str, str]]: + """Get info for cogs which have data stored on this backend. + + Yields + ------ + Tuple[str, str] + Asynchronously yields (cog_name, cog_identifier) tuples. + + """ + raise NotImplementedError + + @classmethod + async def migrate_to( + cls, + new_driver_cls: Type["BaseDriver"], + all_custom_group_data: Dict[str, Dict[str, Dict[str, int]]], + ) -> None: + """Migrate data from this backend to another. + + Both drivers must be initialized beforehand. + + This will only move the data - no instance metadata is modified + as a result of this operation. + + Parameters + ---------- + new_driver_cls + Subclass of `BaseDriver`. + all_custom_group_data : Dict[str, Dict[str, Dict[str, int]]] + Dict mapping cog names, to cog IDs, to custom groups, to + primary key lengths. + + """ + # Backend-agnostic method of migrating from one driver to another. + async for cog_name, cog_id in cls.aiter_cogs(): + this_driver = cls(cog_name, cog_id) + other_driver = new_driver_cls(cog_name, cog_id) + custom_group_data = all_custom_group_data.get(cog_name, {}).get(cog_id, {}) + exported_data = await this_driver.export_data(custom_group_data) + await other_driver.import_data(exported_data, custom_group_data) + + @classmethod + async def delete_all_data(cls, **kwargs) -> None: + """Delete all data being stored by this driver. + + The driver must be initialized before this operation. + + The BaseDriver provides a generic method which may be overridden + by subclasses. + + Parameters + ---------- + **kwargs + Driver-specific kwargs to change the way this method + operates.
+ + """ + async for cog_name, cog_id in cls.aiter_cogs(): + driver = cls(cog_name, cog_id) + await driver.clear(IdentifierData(cog_name, cog_id, "", (), (), 0)) + + @staticmethod + def _split_primary_key( + category: Union[ConfigCategory, str], + custom_group_data: Dict[str, int], + data: Dict[str, Any], + ) -> List[Tuple[Tuple[str, ...], Dict[str, Any]]]: + pkey_len = ConfigCategory.get_pkey_info(category, custom_group_data)[0] + if pkey_len == 0: + return [((), data)] + + def flatten(levels_remaining, currdata, parent_key=()): + items = [] + for _k, _v in currdata.items(): + new_key = parent_key + (_k,) + if levels_remaining > 1: + items.extend(flatten(levels_remaining - 1, _v, new_key).items()) + else: + items.append((new_key, _v)) + return dict(items) + + ret = [] + for k, v in flatten(pkey_len, data).items(): + ret.append((k, v)) + return ret + + async def export_data( + self, custom_group_data: Dict[str, int] + ) -> List[Tuple[str, Dict[str, Any]]]: + categories = [c.value for c in ConfigCategory] + categories.extend(custom_group_data.keys()) + + ret = [] + for c in categories: + ident_data = IdentifierData( + self.cog_name, + self.unique_cog_identifier, + c, + (), + (), + *ConfigCategory.get_pkey_info(c, custom_group_data), + ) + try: + data = await self.get(ident_data) + except KeyError: + continue + ret.append((c, data)) + return ret + + async def import_data( + self, cog_data: List[Tuple[str, Dict[str, Any]]], custom_group_data: Dict[str, int] + ) -> None: + for category, all_data in cog_data: + splitted_pkey = self._split_primary_key(category, custom_group_data, all_data) + for pkey, data in splitted_pkey: + ident_data = IdentifierData( + self.cog_name, + self.unique_cog_identifier, + category, + pkey, + (), + *ConfigCategory.get_pkey_info(category, custom_group_data), + ) + await self.set(ident_data, data) diff --git a/redbot/core/drivers/red_json.py b/redbot/core/drivers/json.py similarity index 71% rename from redbot/core/drivers/red_json.py rename to redbot/core/drivers/json.py index 54e73e715..bd66c84f1 100644 --- a/redbot/core/drivers/red_json.py +++ b/redbot/core/drivers/json.py @@ -5,12 +5,13 @@ import os import pickle import weakref from pathlib import Path -from typing import Any, Dict +from typing import Any, AsyncIterator, Dict, Optional, Tuple from uuid import uuid4 -from .red_base import BaseDriver, IdentifierData +from .. import data_manager, errors +from .base import BaseDriver, IdentifierData, ConfigCategory -__all__ = ["JSON"] +__all__ = ["JsonDriver"] _shared_datastore = {} @@ -35,9 +36,10 @@ def finalize_driver(cog_name): _finalizers.remove(f) -class JSON(BaseDriver): +# noinspection PyProtectedMember +class JsonDriver(BaseDriver): """ - Subclass of :py:class:`.red_base.BaseDriver`. + Subclass of :py:class:`.BaseDriver`. .. 
py:attribute:: file_name @@ -50,27 +52,26 @@ class JSON(BaseDriver): def __init__( self, - cog_name, - identifier, + cog_name: str, + identifier: str, *, - data_path_override: Path = None, - file_name_override: str = "settings.json" + data_path_override: Optional[Path] = None, + file_name_override: str = "settings.json", ): super().__init__(cog_name, identifier) self.file_name = file_name_override - if data_path_override: + if data_path_override is not None: self.data_path = data_path_override + elif cog_name == "Core" and identifier == "0": + self.data_path = data_manager.core_data_path() else: - self.data_path = Path.cwd() / "cogs" / ".data" / self.cog_name + self.data_path = data_manager.cog_data_path(raw_name=cog_name) self.data_path.mkdir(parents=True, exist_ok=True) self.data_path = self.data_path / self.file_name self._lock = asyncio.Lock() self._load_data() - async def has_valid_connection(self) -> bool: - return True - @property def data(self): return _shared_datastore.get(self.cog_name) @@ -79,6 +80,21 @@ class JSON(BaseDriver): def data(self, value): _shared_datastore[self.cog_name] = value + @classmethod + async def initialize(cls, **storage_details) -> None: + # No initializing to do + return + + @classmethod + async def teardown(cls) -> None: + # No tearing down to do + return + + @staticmethod + def get_config_details() -> Dict[str, Any]: + # No driver-specific configuration needed + return {} + def _load_data(self): if self.cog_name not in _driver_counts: _driver_counts[self.cog_name] = 0 @@ -111,30 +127,32 @@ class JSON(BaseDriver): async def get(self, identifier_data: IdentifierData): partial = self.data - full_identifiers = identifier_data.to_tuple() + full_identifiers = identifier_data.to_tuple()[1:] for i in full_identifiers: partial = partial[i] return pickle.loads(pickle.dumps(partial, -1)) async def set(self, identifier_data: IdentifierData, value=None): partial = self.data - full_identifiers = identifier_data.to_tuple() + full_identifiers = identifier_data.to_tuple()[1:] # This is both our deepcopy() and our way of making sure this value is actually JSON # serializable. 
value_copy = json.loads(json.dumps(value)) async with self._lock: for i in full_identifiers[:-1]: - if i not in partial: - partial[i] = {} - partial = partial[i] - partial[full_identifiers[-1]] = value_copy + try: + partial = partial.setdefault(i, {}) + except AttributeError: + # Tried to set sub-field of non-object + raise errors.CannotSetSubfield + partial[full_identifiers[-1]] = value_copy await self._save() async def clear(self, identifier_data: IdentifierData): partial = self.data - full_identifiers = identifier_data.to_tuple() + full_identifiers = identifier_data.to_tuple()[1:] try: for i in full_identifiers[:-1]: partial = partial[i] @@ -149,14 +167,32 @@ class JSON(BaseDriver): else: await self._save() + @classmethod + async def aiter_cogs(cls) -> AsyncIterator[Tuple[str, str]]: + yield "Core", "0" + for _dir in data_manager.cog_data_path().iterdir(): + fpath = _dir / "settings.json" + if not fpath.exists(): + continue + with fpath.open() as f: + try: + data = json.load(f) + except json.JSONDecodeError: + continue + if not isinstance(data, dict): + continue + for cog, inner in data.items(): + if not isinstance(inner, dict): + continue + for cog_id in inner: + yield cog, cog_id + async def import_data(self, cog_data, custom_group_data): def update_write_data(identifier_data: IdentifierData, _data): partial = self.data - idents = identifier_data.to_tuple() + idents = identifier_data.to_tuple()[1:] for ident in idents[:-1]: - if ident not in partial: - partial[ident] = {} - partial = partial[ident] + partial = partial.setdefault(ident, {}) partial[idents[-1]] = _data async with self._lock: @@ -164,12 +200,12 @@ class JSON(BaseDriver): splitted_pkey = self._split_primary_key(category, custom_group_data, all_data) for pkey, data in splitted_pkey: ident_data = IdentifierData( + self.cog_name, self.unique_cog_identifier, category, pkey, (), - custom_group_data, - is_custom=category in custom_group_data, + *ConfigCategory.get_pkey_info(category, custom_group_data), ) update_write_data(ident_data, data) await self._save() @@ -178,9 +214,6 @@ class JSON(BaseDriver): loop = asyncio.get_running_loop() await loop.run_in_executor(None, _save_json, self.data_path, self.data) - def get_config_details(self): - return - def _save_json(path: Path, data: Dict[str, Any]) -> None: """ diff --git a/redbot/core/drivers/log.py b/redbot/core/drivers/log.py new file mode 100644 index 000000000..03b5e8bf8 --- /dev/null +++ b/redbot/core/drivers/log.py @@ -0,0 +1,11 @@ +import functools +import logging +import os + +if os.getenv("RED_INSPECT_DRIVER_QUERIES"): + LOGGING_INVISIBLE = logging.DEBUG +else: + LOGGING_INVISIBLE = 0 + +log = logging.getLogger("red.driver") +log.invisible = functools.partial(log.log, LOGGING_INVISIBLE) diff --git a/redbot/core/drivers/red_mongo.py b/redbot/core/drivers/mongo.py similarity index 64% rename from redbot/core/drivers/red_mongo.py rename to redbot/core/drivers/mongo.py index d500f32db..5d0c5d795 100644 --- a/redbot/core/drivers/red_mongo.py +++ b/redbot/core/drivers/mongo.py @@ -2,77 +2,110 @@ import contextlib import itertools import re from getpass import getpass -from typing import Match, Pattern, Tuple, Any, Dict, Iterator, List +from typing import Match, Pattern, Tuple, Optional, AsyncIterator, Any, Dict, Iterator, List from urllib.parse import quote_plus -import motor.core -import motor.motor_asyncio -import pymongo.errors +try: + # pylint: disable=import-error + import pymongo.errors + import motor.core + import motor.motor_asyncio +except ModuleNotFoundError: + 
motor = None + pymongo = None -from .red_base import BaseDriver, IdentifierData +from .. import errors +from .base import BaseDriver, IdentifierData -__all__ = ["Mongo"] +__all__ = ["MongoDriver"] -_conn = None - - -def _initialize(**kwargs): - uri = kwargs.get("URI", "mongodb") - host = kwargs["HOST"] - port = kwargs["PORT"] - admin_user = kwargs["USERNAME"] - admin_pass = kwargs["PASSWORD"] - db_name = kwargs.get("DB_NAME", "default_db") - - if port is 0: - ports = "" - else: - ports = ":{}".format(port) - - if admin_user is not None and admin_pass is not None: - url = "{}://{}:{}@{}{}/{}".format( - uri, quote_plus(admin_user), quote_plus(admin_pass), host, ports, db_name - ) - else: - url = "{}://{}{}/{}".format(uri, host, ports, db_name) - - global _conn - _conn = motor.motor_asyncio.AsyncIOMotorClient(url, retryWrites=True) - - -class Mongo(BaseDriver): +class MongoDriver(BaseDriver): """ - Subclass of :py:class:`.red_base.BaseDriver`. + Subclass of :py:class:`.BaseDriver`. """ - def __init__(self, cog_name, identifier, **kwargs): - super().__init__(cog_name, identifier) + _conn: Optional["motor.motor_asyncio.AsyncIOMotorClient"] = None - if _conn is None: - _initialize(**kwargs) + @classmethod + async def initialize(cls, **storage_details) -> None: + if motor is None: + raise errors.MissingExtraRequirements( + "Red must be installed with the [mongo] extra to use the MongoDB driver" + ) + uri = storage_details.get("URI", "mongodb") + host = storage_details["HOST"] + port = storage_details["PORT"] + user = storage_details["USERNAME"] + password = storage_details["PASSWORD"] + database = storage_details.get("DB_NAME", "default_db") - async def has_valid_connection(self) -> bool: - # Maybe fix this? - return True + if port is 0: + ports = "" + else: + ports = ":{}".format(port) + + if user is not None and password is not None: + url = "{}://{}:{}@{}{}/{}".format( + uri, quote_plus(user), quote_plus(password), host, ports, database + ) + else: + url = "{}://{}{}/{}".format(uri, host, ports, database) + + cls._conn = motor.motor_asyncio.AsyncIOMotorClient(url, retryWrites=True) + + @classmethod + async def teardown(cls) -> None: + if cls._conn is not None: + cls._conn.close() + + @staticmethod + def get_config_details(): + while True: + uri = input("Enter URI scheme (mongodb or mongodb+srv): ") + if uri is "": + uri = "mongodb" + + if uri in ["mongodb", "mongodb+srv"]: + break + else: + print("Invalid URI scheme") + + host = input("Enter host address: ") + if uri is "mongodb": + port = int(input("Enter host port: ")) + else: + port = 0 + + admin_uname = input("Enter login username: ") + admin_password = getpass("Enter login password: ") + + db_name = input("Enter mongodb database name: ") + + if admin_uname == "": + admin_uname = admin_password = None + + ret = { + "HOST": host, + "PORT": port, + "USERNAME": admin_uname, + "PASSWORD": admin_password, + "DB_NAME": db_name, + "URI": uri, + } + return ret @property - def db(self) -> motor.core.Database: + def db(self) -> "motor.core.Database": """ Gets the mongo database for this cog's name. - .. warning:: - - Right now this will cause a new connection to be made every time the - database is accessed. We will want to create a connection pool down the - line to limit the number of connections. - :return: PyMongo Database object. 
""" - return _conn.get_database() + return self._conn.get_database() - def get_collection(self, category: str) -> motor.core.Collection: + def get_collection(self, category: str) -> "motor.core.Collection": """ Gets a specified collection within the PyMongo database for this cog. @@ -85,12 +118,13 @@ class Mongo(BaseDriver): """ return self.db[self.cog_name][category] - def get_primary_key(self, identifier_data: IdentifierData) -> Tuple[str]: + @staticmethod + def get_primary_key(identifier_data: IdentifierData) -> Tuple[str, ...]: # noinspection PyTypeChecker return identifier_data.primary_key async def rebuild_dataset( - self, identifier_data: IdentifierData, cursor: motor.motor_asyncio.AsyncIOMotorCursor + self, identifier_data: IdentifierData, cursor: "motor.motor_asyncio.AsyncIOMotorCursor" ): ret = {} async for doc in cursor: @@ -141,16 +175,16 @@ class Mongo(BaseDriver): async def set(self, identifier_data: IdentifierData, value=None): uuid = self._escape_key(identifier_data.uuid) primary_key = list(map(self._escape_key, self.get_primary_key(identifier_data))) + dot_identifiers = ".".join(map(self._escape_key, identifier_data.identifiers)) if isinstance(value, dict): if len(value) == 0: await self.clear(identifier_data) return value = self._escape_dict_keys(value) mongo_collection = self.get_collection(identifier_data.category) - pkey_len = self.get_pkey_len(identifier_data) num_pkeys = len(primary_key) - if num_pkeys >= pkey_len: + if num_pkeys >= identifier_data.primary_key_len: # We're setting at the document level or below. dot_identifiers = ".".join(map(self._escape_key, identifier_data.identifiers)) if dot_identifiers: @@ -158,11 +192,23 @@ class Mongo(BaseDriver): else: update_stmt = {"$set": value} - await mongo_collection.update_one( - {"_id": {"RED_uuid": uuid, "RED_primary_key": primary_key}}, - update=update_stmt, - upsert=True, - ) + try: + await mongo_collection.update_one( + {"_id": {"RED_uuid": uuid, "RED_primary_key": primary_key}}, + update=update_stmt, + upsert=True, + ) + except pymongo.errors.WriteError as exc: + if exc.args and exc.args[0].startswith("Cannot create field"): + # There's a bit of a failing edge case here... + # If we accidentally set the sub-field of an array, and the key happens to be a + # digit, it will successfully set the value in the array, and not raise an + # error. This is different to how other drivers would behave, and could lead to + # unexpected behaviour. + raise errors.CannotSetSubfield + else: + # Unhandled driver exception, should expose. + raise else: # We're setting above the document level. @@ -171,15 +217,17 @@ class Mongo(BaseDriver): # We'll do it in a transaction so we can roll-back in case something goes horribly # wrong. 
pkey_filter = self.generate_primary_key_filter(identifier_data) - async with await _conn.start_session() as session: + async with await self._conn.start_session() as session: with contextlib.suppress(pymongo.errors.CollectionInvalid): # Collections must already exist when inserting documents within a transaction - await _conn.get_database().create_collection(mongo_collection.full_name) + await self.db.create_collection(mongo_collection.full_name) try: async with session.start_transaction(): await mongo_collection.delete_many(pkey_filter, session=session) await mongo_collection.insert_many( - self.generate_documents_to_insert(uuid, primary_key, value, pkey_len), + self.generate_documents_to_insert( + uuid, primary_key, value, identifier_data.primary_key_len + ), session=session, ) except pymongo.errors.OperationFailure: @@ -218,7 +266,7 @@ class Mongo(BaseDriver): # What's left of `value` should be the new documents needing to be inserted. to_insert = self.generate_documents_to_insert( - uuid, primary_key, value, pkey_len + uuid, primary_key, value, identifier_data.primary_key_len ) requests = list( itertools.chain( @@ -289,6 +337,59 @@ class Mongo(BaseDriver): for result in results: await db[result["name"]].delete_many(pkey_filter) + @classmethod + async def aiter_cogs(cls) -> AsyncIterator[Tuple[str, str]]: + db = cls._conn.get_database() + for collection_name in await db.list_collection_names(): + parts = collection_name.split(".") + if not len(parts) == 2: + continue + cog_name = parts[0] + for cog_id in await db[collection_name].distinct("_id.RED_uuid"): + yield cog_name, cog_id + + @classmethod + async def delete_all_data( + cls, *, interactive: bool = False, drop_db: Optional[bool] = None, **kwargs + ) -> None: + """Delete all data being stored by this driver. + + Parameters + ---------- + interactive : bool + Set to ``True`` to allow the method to ask the user for + input from the console, regarding the other unset parameters + for this method. + drop_db : Optional[bool] + Set to ``True`` to drop the entire database for the current + bot's instance. Otherwise, collections which appear to be + storing bot data will be dropped. + + """ + if interactive is True and drop_db is None: + print( + "Please choose from one of the following options:\n" + " 1. Drop the entire MongoDB database for this instance, or\n" + " 2. Delete all of Red's data within this database, without dropping the database " + "itself." 
+ ) + options = ("1", "2") + while True: + resp = input("> ") + try: + drop_db = bool(options.index(resp)) + except ValueError: + print("Please type a number corresponding to one of the options.") + else: + break + db = cls._conn.get_database() + if drop_db is True: + await cls._conn.drop_database(db) + else: + async with await cls._conn.start_session() as session: + async for cog_name, cog_id in cls.aiter_cogs(): + await db.drop_collection(db[cog_name], session=session) + @staticmethod def _escape_key(key: str) -> str: return _SPECIAL_CHAR_PATTERN.sub(_replace_with_escaped, key) @@ -344,40 +445,3 @@ _CHAR_ESCAPES = { def _replace_with_unescaped(match: Match[str]) -> str: return _CHAR_ESCAPES[match[0]] - - -def get_config_details(): - uri = None - while True: - uri = input("Enter URI scheme (mongodb or mongodb+srv): ") - if uri is "": - uri = "mongodb" - - if uri in ["mongodb", "mongodb+srv"]: - break - else: - print("Invalid URI scheme") - - host = input("Enter host address: ") - if uri is "mongodb": - port = int(input("Enter host port: ")) - else: - port = 0 - - admin_uname = input("Enter login username: ") - admin_password = getpass("Enter login password: ") - - db_name = input("Enter mongodb database name: ") - - if admin_uname == "": - admin_uname = admin_password = None - - ret = { - "HOST": host, - "PORT": port, - "USERNAME": admin_uname, - "PASSWORD": admin_password, - "DB_NAME": db_name, - "URI": uri, - } - return ret diff --git a/redbot/core/drivers/postgres/__init__.py b/redbot/core/drivers/postgres/__init__.py new file mode 100644 index 000000000..870df18c0 --- /dev/null +++ b/redbot/core/drivers/postgres/__init__.py @@ -0,0 +1,3 @@ +from .postgres import PostgresDriver + +__all__ = ["PostgresDriver"] diff --git a/redbot/core/drivers/postgres/ddl.sql b/redbot/core/drivers/postgres/ddl.sql new file mode 100644 index 000000000..9d1c3c7b0 --- /dev/null +++ b/redbot/core/drivers/postgres/ddl.sql @@ -0,0 +1,839 @@ +/* + ************************************************************ + * PostgreSQL driver Data Definition Language (DDL) Script. * + ************************************************************ + */ + +CREATE SCHEMA IF NOT EXISTS red_config; +CREATE SCHEMA IF NOT EXISTS red_utils; + +DO $$ +BEGIN + PERFORM 'red_config.identifier_data'::regtype; +EXCEPTION + WHEN UNDEFINED_OBJECT THEN + CREATE TYPE red_config.identifier_data AS ( + cog_name text, + cog_id text, + category text, + pkeys text[], + identifiers text[], + pkey_len integer, + is_custom boolean + ); +END; +$$; + + +CREATE OR REPLACE FUNCTION + /* + * Create the config schema and/or table if they do not exist yet. + */ + red_config.maybe_create_table( + id_data red_config.identifier_data + ) + RETURNS void + LANGUAGE 'plpgsql' + AS $$ + DECLARE + schemaname CONSTANT text := concat_ws('.', id_data.cog_name, id_data.cog_id); + schema_exists CONSTANT boolean := exists( + SELECT 1 + FROM red_config.red_cogs t + WHERE t.cog_name = id_data.cog_name AND t.cog_id = id_data.cog_id); + table_exists CONSTANT boolean := schema_exists AND exists( + SELECT 1 + FROM information_schema.tables + WHERE table_schema = schemaname AND table_name = id_data.category); + + BEGIN + IF NOT schema_exists THEN + PERFORM red_config.create_schema(id_data.cog_name, id_data.cog_id); + END IF; + IF NOT table_exists THEN + PERFORM red_config.create_table(id_data); + END IF; + END; +$$; + + +CREATE OR REPLACE FUNCTION + /* + * Create the config schema for the given cog. 
+ */ + red_config.create_schema(new_cog_name text, new_cog_id text, OUT schemaname text) + RETURNS text + LANGUAGE 'plpgsql' + AS $$ + BEGIN + schemaname := concat_ws('.', new_cog_name, new_cog_id); + + EXECUTE format('CREATE SCHEMA IF NOT EXISTS %I', schemaname); + + INSERT INTO red_config.red_cogs AS t VALUES(new_cog_name, new_cog_id, schemaname) + ON CONFLICT(cog_name, cog_id) DO UPDATE + SET + schemaname = excluded.schemaname; + END; +$$; + + +CREATE OR REPLACE FUNCTION + /* + * Create the config table for the given category. + */ + red_config.create_table(id_data red_config.identifier_data) + RETURNS void + LANGUAGE 'plpgsql' + AS $$ + DECLARE + schemaname CONSTANT text := concat_ws('.', id_data.cog_name, id_data.cog_id); + constraintname CONSTANT text := id_data.category||'_pkey'; + pkey_columns CONSTANT text := red_utils.gen_pkey_columns(1, id_data.pkey_len); + pkey_type CONSTANT text := red_utils.get_pkey_type(id_data.is_custom); + pkey_column_definitions CONSTANT text := red_utils.gen_pkey_column_definitions( + 1, id_data.pkey_len, pkey_type); + + BEGIN + EXECUTE format( + $query$ + CREATE TABLE IF NOT EXISTS %I.%I ( + %s, + json_data jsonb DEFAULT '{}' NOT NULL, + CONSTRAINT %I PRIMARY KEY (%s) + ) + $query$, + schemaname, + id_data.category, + pkey_column_definitions, + constraintname, + pkey_columns); + END; +$$; + + +CREATE OR REPLACE FUNCTION + /* + * Get config data. + * + * - When `pkeys` is a full primary key, all or part of a document + * will be returned. + * - When `pkeys` is not a full primary key, documents will be + * aggregated together into a single JSONB object, with primary keys + * as keys mapping to the documents. + */ + red_config.get( + id_data red_config.identifier_data, + OUT result jsonb + ) + LANGUAGE 'plpgsql' + STABLE + PARALLEL SAFE + AS $$ + DECLARE + schemaname CONSTANT text := concat_ws('.', id_data.cog_name, id_data.cog_id); + num_pkeys CONSTANT integer := coalesce(array_length(id_data.pkeys, 1), 0); + num_missing_pkeys CONSTANT integer := id_data.pkey_len - num_pkeys; + pkey_type CONSTANT text := red_utils.get_pkey_type(id_data.is_custom); + whereclause CONSTANT text := red_utils.gen_whereclause(num_pkeys, pkey_type); + + missing_pkey_columns text; + + BEGIN + IF num_missing_pkeys <= 0 THEN + -- No missing primary keys: we're getting all or part of a document. + EXECUTE format( + 'SELECT json_data #> $2 FROM %I.%I WHERE %s', + schemaname, + id_data.category, + whereclause) + INTO result + USING id_data.pkeys, id_data.identifiers; + + ELSIF num_missing_pkeys = 1 THEN + -- 1 missing primary key: we can use the built-in jsonb_object_agg() aggregate function. + EXECUTE format( + 'SELECT jsonb_object_agg(%I::text, json_data) FROM %I.%I WHERE %s', + 'primary_key_'||id_data.pkey_len, + schemaname, + id_data.category, + whereclause) + INTO result + USING id_data.pkeys; + ELSE + -- Multiple missing primary keys: we must use our custom red_utils.jsonb_object_agg2() + -- aggregate function. + missing_pkey_columns := red_utils.gen_pkey_columns_casted(num_pkeys + 1, id_data.pkey_len); + + EXECUTE format( + 'SELECT red_utils.jsonb_object_agg2(json_data, %s) FROM %I.%I WHERE %s', + missing_pkey_columns, + schemaname, + id_data.category, + whereclause) + INTO result + USING id_data.pkeys; + END IF; + END; +$$; + + +CREATE OR REPLACE FUNCTION + /* + * Set config data. + * + * - When `pkeys` is a full primary key, all or part of a document + * will be set. 
+ * - When `pkeys` is not a full set, multiple documents will be + * replaced or removed - `new_value` must be a JSONB object mapping + * primary keys to the new documents. + * + * Raises `error_in_assignment` error when trying to set a sub-key + * of a non-document type. + */ + red_config.set( + id_data red_config.identifier_data, + new_value jsonb + ) + RETURNS void + LANGUAGE 'plpgsql' + AS $$ + DECLARE + schemaname CONSTANT text := concat_ws('.', id_data.cog_name, id_data.cog_id); + constraintname CONSTANT text := id_data.category||'_pkey'; + num_pkeys CONSTANT integer := coalesce(array_length(id_data.pkeys, 1), 0); + num_missing_pkeys CONSTANT integer := id_data.pkey_len - num_pkeys; + pkey_type CONSTANT text := red_utils.get_pkey_type(id_data.is_custom); + pkey_placeholders CONSTANT text := red_utils.gen_pkey_placeholders(num_pkeys, pkey_type); + + new_document jsonb; + pkey_column_definitions text; + whereclause text; + missing_pkey_columns text; + + BEGIN + PERFORM red_config.maybe_create_table(id_data); + + IF num_missing_pkeys = 0 THEN + -- Setting all or part of a document + new_document := red_utils.jsonb_set2('{}', new_value, VARIADIC id_data.identifiers); + + EXECUTE format( + $query$ + INSERT INTO %I.%I AS t VALUES (%s, $2) + ON CONFLICT ON CONSTRAINT %I DO UPDATE + SET + json_data = red_utils.jsonb_set2(t.json_data, $3, VARIADIC $4) + $query$, + schemaname, + id_data.category, + pkey_placeholders, + constraintname) + USING id_data.pkeys, new_document, new_value, id_data.identifiers; + + ELSE + -- Setting multiple documents + whereclause := red_utils.gen_whereclause(num_pkeys, pkey_type); + missing_pkey_columns := red_utils.gen_pkey_columns_casted( + num_pkeys + 1, id_data.pkey_len, pkey_type); + pkey_column_definitions := red_utils.gen_pkey_column_definitions(num_pkeys + 1, id_data.pkey_len); + + -- Delete all documents which we're setting first, since we don't know whether they'll be + -- replaced by the subsequent INSERT. + EXECUTE format('DELETE FROM %I.%I WHERE %s', schemaname, id_data.category, whereclause) + USING id_data.pkeys; + + -- Insert all new documents + EXECUTE format( + $query$ + INSERT INTO %I.%I AS t + SELECT %s, json_data + FROM red_utils.generate_rows_from_object($2, $3) AS f(%s, json_data jsonb) + ON CONFLICT ON CONSTRAINT %I DO UPDATE + SET + json_data = excluded.json_data + $query$, + schemaname, + id_data.category, + concat_ws(', ', pkey_placeholders, missing_pkey_columns), + pkey_column_definitions, + constraintname) + USING id_data.pkeys, new_value, num_missing_pkeys; + END IF; + END; +$$; + + +CREATE OR REPLACE FUNCTION + /* + * Clear config data. + * + * - When `identifiers` is not empty, this will clear a key from a + * document. + * - When `identifiers` is empty and `pkeys` is not empty, it will + * delete one or more documents. + * - When `pkeys` is empty, it will drop the whole table. + * - When `id_data.category` is NULL or an empty string, it will drop + * the whole schema. + * + * Has no effect when the document or key does not exist. 
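+     *
+     * Illustrative example (key values assumed): with pkeys = ['123'] and
+     * identifiers = ['foo'], only the 'foo' key is removed from that
+     * document; with identifiers empty, the row(s) matching pkeys are
+     * deleted; with pkeys also empty, the whole category table (or, with
+     * no category, the cog's entire schema) is dropped.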
+ */ + red_config.clear( + id_data red_config.identifier_data + ) + RETURNS void + LANGUAGE 'plpgsql' + AS $$ + DECLARE + schemaname CONSTANT text := concat_ws('.', id_data.cog_name, id_data.cog_id); + num_pkeys CONSTANT integer := coalesce(array_length(id_data.pkeys, 1), 0); + num_identifiers CONSTANT integer := coalesce(array_length(id_data.identifiers, 1), 0); + pkey_type CONSTANT text := red_utils.get_pkey_type(id_data.is_custom); + + whereclause text; + + BEGIN + IF num_identifiers > 0 THEN + -- Popping a key from a document or nested document. + whereclause := red_utils.gen_whereclause(num_pkeys, pkey_type); + + EXECUTE format( + $query$ + UPDATE %I.%I AS t + SET + json_data = t.json_data #- $2 + WHERE %s + $query$, + schemaname, + id_data.category, + whereclause) + USING id_data.pkeys, id_data.identifiers; + + ELSIF num_pkeys > 0 THEN + -- Deleting one or many documents + whereclause := red_utils.gen_whereclause(num_pkeys, pkey_type); + + EXECUTE format('DELETE FROM %I.%I WHERE %s', schemaname, id_data.category, whereclause) + USING id_data.pkeys; + + ELSIF id_data.category IS NOT NULL AND id_data.category != '' THEN + -- Deleting an entire category + EXECUTE format('DROP TABLE %I.%I CASCADE', schemaname, id_data.category); + + ELSE + -- Deleting an entire cog's data + EXECUTE format('DROP SCHEMA %I CASCADE', schemaname); + + DELETE FROM red_config.red_cogs + WHERE cog_name = id_data.cog_name AND cog_id = id_data.cog_id; + END IF; + END; +$$; + + +CREATE OR REPLACE FUNCTION + /* + * Increment a number within a document. + * + * If the value doesn't already exist, it is inserted as + * `default_value + amount`. + * + * Raises 'wrong_object_type' error when trying to increment a + * non-numeric value. + */ + red_config.inc( + id_data red_config.identifier_data, + amount numeric, + default_value numeric, + OUT result numeric + ) + LANGUAGE 'plpgsql' + AS $$ + DECLARE + schemaname CONSTANT text := concat_ws('.', id_data.cog_name, id_data.cog_id); + num_identifiers CONSTANT integer := coalesce(array_length(id_data.identifiers, 1), 0); + pkey_type CONSTANT text := red_utils.get_pkey_type(id_data.is_custom); + whereclause CONSTANT text := red_utils.gen_whereclause(id_data.pkey_len, pkey_type); + + new_document jsonb; + existing_document jsonb; + existing_value jsonb; + pkey_placeholders text; + + BEGIN + IF num_identifiers = 0 THEN + -- Without identifiers, there's no chance we're actually incrementing a number + RAISE EXCEPTION 'Cannot increment document(s)' + USING ERRCODE = 'wrong_object_type'; + END IF; + + PERFORM red_config.maybe_create_table(id_data); + + -- Look for the existing document + EXECUTE format( + 'SELECT json_data FROM %I.%I WHERE %s', + schemaname, + id_data.category, + whereclause) + INTO existing_document USING id_data.pkeys; + + IF existing_document IS NULL THEN + -- We need to insert a new document + result := default_value + amount; + new_document := red_utils.jsonb_set2('{}', result, VARIADIC id_data.identifiers); + pkey_placeholders := red_utils.gen_pkey_placeholders(id_data.pkey_len, pkey_type); + + EXECUTE format( + 'INSERT INTO %I.%I VALUES(%s, $2)', + schemaname, + id_data.category, + pkey_placeholders) + USING id_data.pkeys, new_document; + + ELSE + -- We need to update the existing document + existing_value := existing_document #> id_data.identifiers; + + IF existing_value IS NULL THEN + result := default_value + amount; + + ELSIF jsonb_typeof(existing_value) = 'number' THEN + result := existing_value::text::numeric + amount; + + ELSE + RAISE EXCEPTION 
'Cannot increment non-numeric value %', existing_value
+                    USING ERRCODE = 'wrong_object_type';
+            END IF;
+
+            new_document := red_utils.jsonb_set2(
+                existing_document, to_jsonb(result), VARIADIC id_data.identifiers);
+
+            EXECUTE format(
+                'UPDATE %I.%I SET json_data = $2 WHERE %s',
+                schemaname,
+                id_data.category,
+                whereclause)
+            USING id_data.pkeys, new_document;
+        END IF;
+    END;
+$$;
+
+
+CREATE OR REPLACE FUNCTION
+    /*
+     * Toggle a boolean within a document.
+     *
+     * If the value doesn't already exist, it is inserted as `NOT
+     * default_value`.
+     *
+     * Raises 'wrong_object_type' error when trying to toggle a
+     * non-boolean value.
+     */
+    red_config.toggle(
+        id_data red_config.identifier_data,
+        default_value boolean,
+        OUT result boolean
+    )
+    LANGUAGE 'plpgsql'
+    AS $$
+    DECLARE
+        schemaname CONSTANT text := concat_ws('.', id_data.cog_name, id_data.cog_id);
+        num_identifiers CONSTANT integer := coalesce(array_length(id_data.identifiers, 1), 0);
+        pkey_type CONSTANT text := red_utils.get_pkey_type(id_data.is_custom);
+        whereclause CONSTANT text := red_utils.gen_whereclause(id_data.pkey_len, pkey_type);
+
+        new_document jsonb;
+        existing_document jsonb;
+        existing_value jsonb;
+        pkey_placeholders text;
+
+    BEGIN
+        IF num_identifiers = 0 THEN
+            -- Without identifiers, there's no chance we're actually toggling a boolean
+            RAISE EXCEPTION 'Cannot toggle document(s)'
+                USING ERRCODE = 'wrong_object_type';
+        END IF;
+
+        PERFORM red_config.maybe_create_table(id_data);
+
+        -- Look for the existing document
+        EXECUTE format(
+            'SELECT json_data FROM %I.%I WHERE %s',
+            schemaname,
+            id_data.category,
+            whereclause)
+        INTO existing_document USING id_data.pkeys;
+
+        IF existing_document IS NULL THEN
+            -- We need to insert a new document
+            result := NOT default_value;
+            new_document := red_utils.jsonb_set2('{}', to_jsonb(result), VARIADIC id_data.identifiers);
+            pkey_placeholders := red_utils.gen_pkey_placeholders(id_data.pkey_len, pkey_type);
+
+            EXECUTE format(
+                'INSERT INTO %I.%I VALUES(%s, $2)',
+                schemaname,
+                id_data.category,
+                pkey_placeholders)
+            USING id_data.pkeys, new_document;
+
+        ELSE
+            -- We need to update the existing document
+            existing_value := existing_document #> id_data.identifiers;
+
+            IF existing_value IS NULL THEN
+                result := NOT default_value;
+
+            ELSIF jsonb_typeof(existing_value) = 'boolean' THEN
+                result := NOT existing_value::text::boolean;
+
+            ELSE
+                RAISE EXCEPTION 'Cannot toggle non-boolean value %', existing_value
+                    USING ERRCODE = 'wrong_object_type';
+            END IF;
+
+            new_document := red_utils.jsonb_set2(
+                existing_document, to_jsonb(result), VARIADIC id_data.identifiers);
+
+            EXECUTE format(
+                'UPDATE %I.%I SET json_data = $2 WHERE %s',
+                schemaname,
+                id_data.category,
+                whereclause)
+            USING id_data.pkeys, new_document;
+        END IF;
+    END;
+$$;
+
+
+CREATE OR REPLACE FUNCTION
+    red_config.extend(
+        id_data red_config.identifier_data,
+        new_value text,
+        default_value text,
+        max_length integer DEFAULT NULL,
+        extend_left boolean DEFAULT FALSE,
+        OUT result jsonb
+    )
+    LANGUAGE 'plpgsql'
+    AS $$
+    DECLARE
+        schemaname CONSTANT text := concat_ws('.', id_data.cog_name, id_data.cog_id);
+        num_identifiers CONSTANT integer := coalesce(array_length(id_data.identifiers, 1), 0);
+        pkey_type CONSTANT text := red_utils.get_pkey_type(id_data.is_custom);
+        whereclause CONSTANT text := red_utils.gen_whereclause(id_data.pkey_len, pkey_type);
+        pop_idx CONSTANT integer := CASE extend_left WHEN TRUE THEN -1 ELSE 0 END;
+
+        new_document jsonb;
+        existing_document jsonb;
+        existing_value jsonb;
+        
pkey_placeholders text; + idx integer; + BEGIN + IF num_identifiers = 0 THEN + -- Without identifiers, there's no chance we're actually appending to an array + RAISE EXCEPTION 'Cannot append to document(s)' + USING ERRCODE = 'wrong_object_type'; + END IF; + + PERFORM red_config.maybe_create_table(id_data); + + -- Look for the existing document + EXECUTE format( + 'SELECT json_data FROM %I.%I WHERE %s', + schemaname, + id_data.category, + whereclause) + INTO existing_document USING id_data.pkeys; + + IF existing_document IS NULL THEN + result := default_value || new_value; + new_document := red_utils.jsonb_set2('{}'::jsonb, result, id_data.identifiers); + pkey_placeholders := red_utils.gen_pkey_placeholders(id_data.pkey_len, pkey_type); + + EXECUTE format( + 'INSERT INTO %I.%I VALUES(%s, $2)', + schemaname, + id_data.category, + pkey_placeholders) + USING id_data.pkeys, new_document; + + ELSE + existing_value := existing_document #> id_data.identifiers; + + IF existing_value IS NULL THEN + existing_value := default_value; + + ELSIF jsonb_typeof(existing_value) != 'array' THEN + RAISE EXCEPTION 'Cannot append to non-array value %', existing_value + USING ERRCODE = 'wrong_object_type'; + END IF; + + CASE extend_left + WHEN TRUE THEN + result := new_value || existing_value; + ELSE + result := existing_value || new_value; + END CASE; + + IF max_length IS NOT NULL THEN + FOR idx IN SELECT generate_series(1, jsonb_array_length(result) - max_length) LOOP + result := result - pop_idx; + END LOOP; + END IF; + + new_document := red_utils.jsonb_set2(existing_document, result, id_data.identifiers); + + EXECUTE format( + 'UPDATE %I.%I SET json_data = $2 WHERE %s', + schemaname, + id_data.category, + whereclause) + USING id_data.pkeys, new_document; + END IF; + END; +$$; + + +CREATE OR REPLACE FUNCTION + /* + * Delete all schemas listed in the red_config.red_cogs table. + */ + red_config.delete_all_schemas() + RETURNS void + LANGUAGE 'plpgsql' + AS $$ + DECLARE + cog_entry record; + BEGIN + FOR cog_entry IN SELECT * FROM red_config.red_cogs t LOOP + EXECUTE format('DROP SCHEMA %I CASCADE', cog_entry.schemaname); + END LOOP; + -- Clear out red_config.red_cogs table + DELETE FROM red_config.red_cogs WHERE TRUE; + END; +$$; + + + +CREATE OR REPLACE FUNCTION + /* + * Like `jsonb_set` but will insert new objects where one is missing + * along the path. + * + * Raises `error_in_assignment` error when trying to set a sub-key + * of a non-document type. + */ + red_utils.jsonb_set2(target jsonb, new_value jsonb, VARIADIC identifiers text[]) + RETURNS jsonb + LANGUAGE 'plpgsql' + IMMUTABLE + PARALLEL SAFE + AS $$ + DECLARE + num_identifiers CONSTANT integer := coalesce(array_length(identifiers, 1), 0); + + cur_value_type text; + idx integer; + + BEGIN + IF num_identifiers = 0 THEN + RETURN new_value; + END IF; + + FOR idx IN SELECT generate_series(1, num_identifiers - 1) LOOP + cur_value_type := jsonb_typeof(target #> identifiers[:idx]); + IF cur_value_type IS NULL THEN + -- Parent key didn't exist in JSON before - insert new object + target := jsonb_set(target, identifiers[:idx], '{}'::jsonb); + + ELSIF cur_value_type != 'object' THEN + -- We can't set the sub-field of a null, int, float, array etc. 
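+                -- e.g. (illustrative) target = '{"foo": 1}' with identifiers
+                -- ['foo', 'bar']: 'foo' holds a number, so there is nothing to
+                -- descend into and error_in_assignment is raised below.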
+ RAISE EXCEPTION 'Cannot set sub-field of "%s"', cur_value_type + USING ERRCODE = 'error_in_assignment'; + END IF; + END LOOP; + + RETURN jsonb_set(target, identifiers, new_value); + END; +$$; + + +CREATE OR REPLACE FUNCTION + /* + * Return a set of rows to insert into a table, from a single JSONB + * object containing multiple documents. + */ + red_utils.generate_rows_from_object(object jsonb, num_missing_pkeys integer) + RETURNS setof record + LANGUAGE 'plpgsql' + IMMUTABLE + PARALLEL SAFE + AS $$ + DECLARE + pair record; + column_definitions text; + BEGIN + IF num_missing_pkeys = 1 THEN + -- Base case: Simply return (key, value) pairs + RETURN QUERY + SELECT key AS key_1, value AS json_data + FROM jsonb_each(object); + ELSE + -- We need to return (key, key, ..., value) pairs: recurse into inner JSONB objects + column_definitions := red_utils.gen_pkey_column_definitions(2, num_missing_pkeys); + + FOR pair IN SELECT * FROM jsonb_each(object) LOOP + RETURN QUERY + EXECUTE format( + $query$ + SELECT $1 AS key_1, * + FROM red_utils.generate_rows_from_object($2, $3) + AS f(%s, json_data jsonb) + $query$, + column_definitions) + USING pair.key, pair.value, num_missing_pkeys - 1; + END LOOP; + END IF; + RETURN; + END; +$$; + + +CREATE OR REPLACE FUNCTION + /* + * Get a comma-separated list of primary key placeholders. + * + * The placeholder will always be $1. Particularly useful for + * inserting values into a table from an array of primary keys. + */ + red_utils.gen_pkey_placeholders(num_pkeys integer, pkey_type text DEFAULT 'text') + RETURNS text + LANGUAGE 'sql' + IMMUTABLE + PARALLEL SAFE + AS $$ + SELECT string_agg(t.item, ', ') + FROM ( + SELECT format('$1[%s]::%s', idx, pkey_type) AS item + FROM generate_series(1, num_pkeys) idx) t + ; +$$; + +CREATE OR REPLACE FUNCTION + /* + * Generate a whereclause for the given number of primary keys. + * + * When there are no primary keys, this will simply return the the + * string 'TRUE'. When there are multiple, it will return multiple + * equality comparisons concatenated with 'AND'. + */ + red_utils.gen_whereclause(num_pkeys integer, pkey_type text) + RETURNS text + LANGUAGE 'sql' + IMMUTABLE + PARALLEL SAFE + AS $$ + SELECT coalesce(string_agg(t.item, ' AND '), 'TRUE') + FROM ( + SELECT format('%I = $1[%s]::%s', 'primary_key_'||idx, idx, pkey_type) AS item + FROM generate_series(1, num_pkeys) idx) t + ; +$$; + +CREATE OR REPLACE FUNCTION + /* + * Generate a comma-separated list of primary key column names. + */ + red_utils.gen_pkey_columns(start integer, stop integer) + RETURNS text + LANGUAGE 'sql' + IMMUTABLE + PARALLEL SAFE + AS $$ + SELECT string_agg(t.item, ', ') + FROM ( + SELECT quote_ident('primary_key_'||idx) AS item + FROM generate_series(start, stop) idx) t + ; +$$; + +CREATE OR REPLACE FUNCTION + /* + * Generate a comma-separated list of primary key column names casted + * to the given type. + */ + red_utils.gen_pkey_columns_casted(start integer, stop integer, pkey_type text DEFAULT 'text') + RETURNS text + LANGUAGE 'sql' + IMMUTABLE + PARALLEL SAFE + AS $$ + SELECT string_agg(t.item, ', ') + FROM ( + SELECT format('%I::%s', 'primary_key_'||idx, pkey_type) AS item + FROM generate_series(start, stop) idx) t + ; +$$; + + +CREATE OR REPLACE FUNCTION + /* + * Generate a primary key column definition list. 
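+     *
+     * Illustrative example: gen_pkey_column_definitions(1, 2) returns
+     * 'primary_key_1 text, primary_key_2 text'.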
+ */ + red_utils.gen_pkey_column_definitions( + start integer, stop integer, column_type text DEFAULT 'text' + ) + RETURNS text + LANGUAGE 'sql' + IMMUTABLE + PARALLEL SAFE + AS $$ + SELECT string_agg(t.item, ', ') + FROM ( + SELECT format('%I %s', 'primary_key_'||idx, column_type) AS item + FROM generate_series(start, stop) idx) t + ; +$$; + + +CREATE OR REPLACE FUNCTION + red_utils.get_pkey_type(is_custom boolean) + RETURNS TEXT + LANGUAGE 'sql' + IMMUTABLE + PARALLEL SAFE + AS $$ + SELECT ('{bigint,text}'::text[])[is_custom::integer + 1]; +$$; + + +DROP AGGREGATE IF EXISTS red_utils.jsonb_object_agg2(jsonb, VARIADIC text[]); +CREATE AGGREGATE + /* + * Like `jsonb_object_agg` but aggregates more than two columns into a + * single JSONB object. + * + * If possible, use `jsonb_object_agg` instead for performance + * reasons. + */ + red_utils.jsonb_object_agg2(json_data jsonb, VARIADIC primary_keys text[]) ( + SFUNC = red_utils.jsonb_set2, + STYPE = jsonb, + INITCOND = '{}', + PARALLEL = SAFE + ) +; + + +CREATE TABLE IF NOT EXISTS + /* + * Table to keep track of other cogs' schemas. + */ + red_config.red_cogs( + cog_name text, + cog_id text, + schemaname text NOT NULL, + PRIMARY KEY (cog_name, cog_id) +) +; diff --git a/redbot/core/drivers/postgres/drop_ddl.sql b/redbot/core/drivers/postgres/drop_ddl.sql new file mode 100644 index 000000000..8587446d0 --- /dev/null +++ b/redbot/core/drivers/postgres/drop_ddl.sql @@ -0,0 +1,3 @@ +SELECT red_config.delete_all_schemas(); +DROP SCHEMA IF EXISTS red_config CASCADE; +DROP SCHEMA IF EXISTS red_utils CASCADE; diff --git a/redbot/core/drivers/postgres/postgres.py b/redbot/core/drivers/postgres/postgres.py new file mode 100644 index 000000000..926052b05 --- /dev/null +++ b/redbot/core/drivers/postgres/postgres.py @@ -0,0 +1,255 @@ +import getpass +import json +import sys +from pathlib import Path +from typing import Optional, Any, AsyncIterator, Tuple, Union, Callable, List + +try: + # pylint: disable=import-error + import asyncpg +except ModuleNotFoundError: + asyncpg = None + +from ... 
import data_manager, errors
+from ..base import BaseDriver, IdentifierData, ConfigCategory
+from ..log import log
+
+__all__ = ["PostgresDriver"]
+
+_PKG_PATH = Path(__file__).parent
+DDL_SCRIPT_PATH = _PKG_PATH / "ddl.sql"
+DROP_DDL_SCRIPT_PATH = _PKG_PATH / "drop_ddl.sql"
+
+
+def encode_identifier_data(
+    id_data: IdentifierData
+) -> Tuple[str, str, str, List[str], List[str], int, bool]:
+    return (
+        id_data.cog_name,
+        id_data.uuid,
+        id_data.category,
+        ["0"] if id_data.category == ConfigCategory.GLOBAL else list(id_data.primary_key),
+        list(id_data.identifiers),
+        1 if id_data.category == ConfigCategory.GLOBAL else id_data.primary_key_len,
+        id_data.is_custom,
+    )
+
+
+class PostgresDriver(BaseDriver):
+
+    _pool: Optional["asyncpg.pool.Pool"] = None
+
+    @classmethod
+    async def initialize(cls, **storage_details) -> None:
+        if asyncpg is None:
+            raise errors.MissingExtraRequirements(
+                "Red must be installed with the [postgres] extra to use the PostgreSQL driver"
+            )
+        cls._pool = await asyncpg.create_pool(**storage_details)
+        with DDL_SCRIPT_PATH.open() as fs:
+            await cls._pool.execute(fs.read())
+
+    @classmethod
+    async def teardown(cls) -> None:
+        if cls._pool is not None:
+            await cls._pool.close()
+
+    @staticmethod
+    def get_config_details():
+        unixmsg = (
+            ""
+            if sys.platform == "win32"
+            else (
+                " - Common directories for PostgreSQL Unix-domain sockets (/run/postgresql, "
+                "/var/run/postgresql, /var/pgsql_socket, /private/tmp, and /tmp),\n"
+            )
+        )
+        host = (
+            input(
+                f"Enter the PostgreSQL server's address.\n"
+                f"If left blank, Red will try the following, in order:\n"
+                f" - The PGHOST environment variable,\n{unixmsg}"
+                f" - localhost.\n"
+                f"> "
+            )
+            or None
+        )
+
+        print(
+            "Enter the PostgreSQL server port.\n"
+            "If left blank, this will default to either:\n"
+            " - The PGPORT environment variable,\n"
+            " - 5432."
+        )
+        while True:
+            port = input("> ") or None
+            if port is None:
+                break
+
+            try:
+                port = int(port)
+            except ValueError:
+                print("Port must be a number")
+            else:
+                break
+
+        user = (
+            input(
+                "Enter the PostgreSQL server username.\n"
+                "If left blank, this will default to either:\n"
+                " - The PGUSER environment variable,\n"
+                " - The OS name of the user running Red (ident/peer authentication).\n"
+                "> "
+            )
+            or None
+        )
+
+        passfile = r"%APPDATA%\postgresql\pgpass.conf" if sys.platform == "win32" else "~/.pgpass"
+        password = getpass.getpass(
+            f"Enter the PostgreSQL server password. 
The input will be hidden.\n"
+            f"  NOTE: If using ident/peer authentication (no password), enter NONE.\n"
+            f"When NONE is entered, this will default to:\n"
+            f" - The PGPASSWORD environment variable,\n"
+            f" - Looking up the password in the {passfile} passfile,\n"
+            f" - No password.\n"
+            f"> "
+        )
+        if password == "NONE":
+            password = None
+
+        database = (
+            input(
+                "Enter the PostgreSQL database's name.\n"
+                "If left blank, this will default to either:\n"
+                " - The PGDATABASE environment variable,\n"
+                " - The OS name of the user running Red.\n"
+                "> "
+            )
+            or None
+        )
+
+        return {
+            "host": host,
+            "port": port,
+            "user": user,
+            "password": password,
+            "database": database,
+        }
+
+    async def get(self, identifier_data: IdentifierData):
+        try:
+            result = await self._execute(
+                "SELECT red_config.get($1)",
+                encode_identifier_data(identifier_data),
+                method=self._pool.fetchval,
+            )
+        except asyncpg.UndefinedTableError:
+            raise KeyError from None
+
+        if result is None:
+            # The result is None both when postgres yields no results, or when it yields a NULL row
+            # A 'null' JSON value would be returned as encoded JSON, i.e. the string 'null'
+            raise KeyError
+        return json.loads(result)
+
+    async def set(self, identifier_data: IdentifierData, value=None):
+        try:
+            await self._execute(
+                "SELECT red_config.set($1, $2::jsonb)",
+                encode_identifier_data(identifier_data),
+                json.dumps(value),
+            )
+        except asyncpg.ErrorInAssignmentError:
+            raise errors.CannotSetSubfield
+
+    async def clear(self, identifier_data: IdentifierData):
+        try:
+            await self._execute(
+                "SELECT red_config.clear($1)", encode_identifier_data(identifier_data)
+            )
+        except asyncpg.UndefinedTableError:
+            pass
+
+    async def inc(
+        self, identifier_data: IdentifierData, value: Union[int, float], default: Union[int, float]
+    ) -> Union[int, float]:
+        try:
+            return await self._execute(
+                f"SELECT red_config.inc($1, $2, $3)",
+                encode_identifier_data(identifier_data),
+                value,
+                default,
+                method=self._pool.fetchval,
+            )
+        except asyncpg.WrongObjectTypeError as exc:
+            raise errors.StoredTypeError(*exc.args)
+
+    async def toggle(self, identifier_data: IdentifierData, default: bool) -> bool:
+        try:
+            return await self._execute(
+                "SELECT red_config.toggle($1, $2)",
+                encode_identifier_data(identifier_data),
+                default,
+                method=self._pool.fetchval,
+            )
+        except asyncpg.WrongObjectTypeError as exc:
+            raise errors.StoredTypeError(*exc.args)
+
+    @classmethod
+    async def aiter_cogs(cls) -> AsyncIterator[Tuple[str, str]]:
+        query = "SELECT cog_name, cog_id FROM red_config.red_cogs"
+        log.invisible(query)
+        async with cls._pool.acquire() as conn, conn.transaction():
+            async for row in conn.cursor(query):
+                yield row["cog_name"], row["cog_id"]
+
+    @classmethod
+    async def delete_all_data(
+        cls, *, interactive: bool = False, drop_db: Optional[bool] = None, **kwargs
+    ) -> None:
+        """Delete all data being stored by this driver.
+
+        Parameters
+        ----------
+        interactive : bool
+            Set to ``True`` to allow the method to ask the user for
+            input from the console, regarding the other unset parameters
+            for this method.
+        drop_db : Optional[bool]
+            Set to ``True`` to drop the entire database for the current
+            bot's instance. Otherwise, schemas within the database which
+            store bot data will be dropped, as well as functions,
+            aggregates, event triggers, and meta-tables.
+
+        """
+        if interactive is True and drop_db is None:
+            print(
+                "Please choose from one of the following options:\n"
+                " 1. 
Drop the entire PostgreSQL database for this instance, or\n" + " 2. Delete all of Red's data within this database, without dropping the database " + "itself." + ) + options = ("1", "2") + while True: + resp = input("> ") + try: + drop_db = bool(options.index(resp)) + except ValueError: + print("Please type a number corresponding to one of the options.") + else: + break + if drop_db is True: + storage_details = data_manager.storage_details() + await cls._pool.execute(f"DROP DATABASE $1", storage_details["database"]) + else: + with DROP_DDL_SCRIPT_PATH.open() as fs: + await cls._pool.execute(fs.read()) + + @classmethod + async def _execute(cls, query: str, *args, method: Optional[Callable] = None) -> Any: + if method is None: + method = cls._pool.execute + log.invisible("Query: %s", query) + if args: + log.invisible("Args: %s", args) + return await method(query, *args) diff --git a/redbot/core/drivers/red_base.py b/redbot/core/drivers/red_base.py deleted file mode 100644 index de28b1b9a..000000000 --- a/redbot/core/drivers/red_base.py +++ /dev/null @@ -1,233 +0,0 @@ -import enum -from typing import Tuple - -__all__ = ["BaseDriver", "IdentifierData"] - - -class ConfigCategory(enum.Enum): - GLOBAL = "GLOBAL" - GUILD = "GUILD" - CHANNEL = "TEXTCHANNEL" - ROLE = "ROLE" - USER = "USER" - MEMBER = "MEMBER" - - -class IdentifierData: - def __init__( - self, - uuid: str, - category: str, - primary_key: Tuple[str, ...], - identifiers: Tuple[str, ...], - custom_group_data: dict, - is_custom: bool = False, - ): - self._uuid = uuid - self._category = category - self._primary_key = primary_key - self._identifiers = identifiers - self.custom_group_data = custom_group_data - self._is_custom = is_custom - - @property - def uuid(self): - return self._uuid - - @property - def category(self): - return self._category - - @property - def primary_key(self): - return self._primary_key - - @property - def identifiers(self): - return self._identifiers - - @property - def is_custom(self): - return self._is_custom - - def __repr__(self): - return ( - f"" - ) - - def __eq__(self, other) -> bool: - if not isinstance(other, IdentifierData): - return False - return ( - self.uuid == other.uuid - and self.category == other.category - and self.primary_key == other.primary_key - and self.identifiers == other.identifiers - ) - - def __hash__(self) -> int: - return hash((self.uuid, self.category, self.primary_key, self.identifiers)) - - def add_identifier(self, *identifier: str) -> "IdentifierData": - if not all(isinstance(i, str) for i in identifier): - raise ValueError("Identifiers must be strings.") - - return IdentifierData( - self.uuid, - self.category, - self.primary_key, - self.identifiers + identifier, - self.custom_group_data, - is_custom=self.is_custom, - ) - - def to_tuple(self): - return tuple( - item - for item in (self.uuid, self.category, *self.primary_key, *self.identifiers) - if len(item) > 0 - ) - - -class BaseDriver: - def __init__(self, cog_name, identifier): - self.cog_name = cog_name - self.unique_cog_identifier = identifier - - async def has_valid_connection(self) -> bool: - raise NotImplementedError - - async def get(self, identifier_data: IdentifierData): - """ - Finds the value indicate by the given identifiers. - - Parameters - ---------- - identifier_data - - Returns - ------- - Any - Stored value. - """ - raise NotImplementedError - - def get_config_details(self): - """ - Asks users for additional configuration information necessary - to use this config driver. 
- - Returns - ------- - Dict of configuration details. - """ - raise NotImplementedError - - async def set(self, identifier_data: IdentifierData, value=None): - """ - Sets the value of the key indicated by the given identifiers. - - Parameters - ---------- - identifier_data - value - Any JSON serializable python object. - """ - raise NotImplementedError - - async def clear(self, identifier_data: IdentifierData): - """ - Clears out the value specified by the given identifiers. - - Equivalent to using ``del`` on a dict. - - Parameters - ---------- - identifier_data - """ - raise NotImplementedError - - def _get_levels(self, category, custom_group_data): - if category == ConfigCategory.GLOBAL.value: - return 0 - elif category in ( - ConfigCategory.USER.value, - ConfigCategory.GUILD.value, - ConfigCategory.CHANNEL.value, - ConfigCategory.ROLE.value, - ): - return 1 - elif category == ConfigCategory.MEMBER.value: - return 2 - elif category in custom_group_data: - return custom_group_data[category] - else: - raise RuntimeError(f"Cannot convert due to group: {category}") - - def _split_primary_key(self, category, custom_group_data, data): - levels = self._get_levels(category, custom_group_data) - if levels == 0: - return (((), data),) - - def flatten(levels_remaining, currdata, parent_key=()): - items = [] - for k, v in currdata.items(): - new_key = parent_key + (k,) - if levels_remaining > 1: - items.extend(flatten(levels_remaining - 1, v, new_key).items()) - else: - items.append((new_key, v)) - return dict(items) - - ret = [] - for k, v in flatten(levels, data).items(): - ret.append((k, v)) - return tuple(ret) - - async def export_data(self, custom_group_data): - categories = [c.value for c in ConfigCategory] - categories.extend(custom_group_data.keys()) - - ret = [] - for c in categories: - ident_data = IdentifierData( - self.unique_cog_identifier, - c, - (), - (), - custom_group_data, - is_custom=c in custom_group_data, - ) - try: - data = await self.get(ident_data) - except KeyError: - continue - ret.append((c, data)) - return ret - - async def import_data(self, cog_data, custom_group_data): - for category, all_data in cog_data: - splitted_pkey = self._split_primary_key(category, custom_group_data, all_data) - for pkey, data in splitted_pkey: - ident_data = IdentifierData( - self.unique_cog_identifier, - category, - pkey, - (), - custom_group_data, - is_custom=category in custom_group_data, - ) - await self.set(ident_data, data) - - @staticmethod - def get_pkey_len(identifier_data: IdentifierData) -> int: - cat = identifier_data.category - if cat == ConfigCategory.GLOBAL.value: - return 0 - elif cat == ConfigCategory.MEMBER.value: - return 2 - elif identifier_data.is_custom: - return identifier_data.custom_group_data[cat] - else: - return 1 diff --git a/redbot/core/errors.py b/redbot/core/errors.py index a67097dc0..5bd7f4fa2 100644 --- a/redbot/core/errors.py +++ b/redbot/core/errors.py @@ -1,5 +1,4 @@ import importlib.machinery -from typing import Optional import discord @@ -49,3 +48,37 @@ class BalanceTooHigh(BankError, OverflowError): return _("{user}'s balance cannot rise above {max:,} {currency}.").format( user=self.user, max=self.max_balance, currency=self.currency_name ) + + +class MissingExtraRequirements(RedError): + """Raised when an extra requirement is missing but required.""" + + +class ConfigError(RedError): + """Error in a Config operation.""" + + +class StoredTypeError(ConfigError, TypeError): + """A TypeError pertaining to stored Config data. 
+
+    This error may arise when, for example, trying to increment a value
+    which is not a number, or trying to toggle a value which is not a
+    boolean.
+    """
+
+
+class CannotSetSubfield(StoredTypeError):
+    """Tried to set sub-field of an invalid data structure.
+
+    This would occur in the following example::
+
+        >>> import asyncio
+        >>> from redbot.core import Config
+        >>> config = Config.get_conf(None, 1234, cog_name="Example")
+        >>> async def example():
+        ...     await config.foo.set(True)
+        ...     await config.set_raw("foo", "bar", False)  # Should raise here
+        ...
+        >>> asyncio.run(example())
+
+    """
diff --git a/redbot/core/utils/__init__.py b/redbot/core/utils/__init__.py
index 4dc805109..96a44c515 100644
--- a/redbot/core/utils/__init__.py
+++ b/redbot/core/utils/__init__.py
@@ -1,7 +1,9 @@
 import asyncio
+import json
 import logging
 import os
 import shutil
+import tarfile
 from asyncio import AbstractEventLoop, as_completed, Semaphore
 from asyncio.futures import isfuture
 from itertools import chain
@@ -24,8 +26,10 @@ from typing import (
 )
 
 import discord
+from datetime import datetime
 from fuzzywuzzy import fuzz, process
 
+from .. import commands, data_manager
 from .chat_formatting import box
 
 if TYPE_CHECKING:
@@ -37,6 +41,7 @@ __all__ = [
     "fuzzy_command_search",
     "format_fuzzy_results",
     "deduplicate_iterables",
+    "create_backup",
 ]
 
 _T = TypeVar("_T")
@@ -397,3 +402,45 @@ def bounded_gather(
     tasks = (_sem_wrapper(semaphore, task) for task in coros_or_futures)
 
     return asyncio.gather(*tasks, loop=loop, return_exceptions=return_exceptions)
+
+
+async def create_backup(dest: Path = Path.home()) -> Optional[Path]:
+    data_path = Path(data_manager.core_data_path().parent)
+    if not data_path.exists():
+        return
+
+    dest.mkdir(parents=True, exist_ok=True)
+    timestr = datetime.utcnow().isoformat(timespec="minutes")
+    backup_fpath = dest / f"redv3_{data_manager.instance_name}_{timestr}.tar.gz"
+
+    to_backup = []
+    exclusions = [
+        "__pycache__",
+        "Lavalink.jar",
+        os.path.join("Downloader", "lib"),
+        os.path.join("CogManager", "cogs"),
+        os.path.join("RepoManager", "repos"),
+    ]
+
+    # Avoiding circular imports
+    from ...cogs.downloader.repo_manager import RepoManager
+
+    repo_mgr = RepoManager()
+    await repo_mgr.initialize()
+    repo_output = []
+    for _, repo in repo_mgr._repos.items():
+        repo_output.append({"url": repo.url, "name": repo.name, "branch": repo.branch})
+    repos_file = data_path / "cogs" / "RepoManager" / "repos.json"
+    with repos_file.open("w") as fs:
+        json.dump(repo_output, fs, indent=4)
+    instance_file = data_path / "instance.json"
+    with instance_file.open("w") as fs:
+        json.dump({data_manager.instance_name: data_manager.basic_config}, fs, indent=4)
+    for f in data_path.glob("**/*"):
+        if not any(ex in str(f) for ex in exclusions) and f.is_file():
+            to_backup.append(f)
+
+    with tarfile.open(str(backup_fpath), "w:gz") as tar:
+        for f in to_backup:
+            tar.add(str(f), arcname=f.relative_to(data_path), recursive=False)
+    return backup_fpath
diff --git a/redbot/core/utils/tunnel.py b/redbot/core/utils/tunnel.py
index bae10b04c..e1655ee78 100644
--- a/redbot/core/utils/tunnel.py
+++ b/redbot/core/utils/tunnel.py
@@ -89,7 +89,7 @@ class Tunnel(metaclass=TunnelMeta):
         destination: discord.abc.Messageable,
         content: str = None,
         embed=None,
-        files: Optional[List[discord.File]] = None
+        files: Optional[List[discord.File]] = None,
     ) -> List[discord.Message]:
         """
         This does the actual sending, use this instead of a full tunnel
diff --git a/redbot/pytest/core.py b/redbot/pytest/core.py
index f22ce931d..b839d0732
100644 --- a/redbot/pytest/core.py +++ b/redbot/pytest/core.py @@ -7,15 +7,13 @@ import pytest from _pytest.monkeypatch import MonkeyPatch from redbot.core import Config from redbot.core.bot import Red -from redbot.core import config as config_module - -from redbot.core.drivers import red_json +from redbot.core import config as config_module, drivers __all__ = [ "monkeysession", "override_data_path", "coroutine", - "json_driver", + "driver", "config", "config_fr", "red", @@ -56,34 +54,31 @@ def coroutine(): @pytest.fixture() -def json_driver(tmpdir_factory): +def driver(tmpdir_factory): import uuid rand = str(uuid.uuid4()) path = Path(str(tmpdir_factory.mktemp(rand))) - driver = red_json.JSON("PyTest", identifier=str(uuid.uuid4()), data_path_override=path) - return driver + return drivers.get_driver("PyTest", str(random.randint(1, 999999)), data_path_override=path) @pytest.fixture() -def config(json_driver): +def config(driver): config_module._config_cache = weakref.WeakValueDictionary() - conf = Config( - cog_name="PyTest", unique_identifier=json_driver.unique_cog_identifier, driver=json_driver - ) + conf = Config(cog_name="PyTest", unique_identifier=driver.unique_cog_identifier, driver=driver) yield conf @pytest.fixture() -def config_fr(json_driver): +def config_fr(driver): """ Mocked config object with force_register enabled. """ config_module._config_cache = weakref.WeakValueDictionary() conf = Config( cog_name="PyTest", - unique_identifier=json_driver.unique_cog_identifier, - driver=json_driver, + unique_identifier=driver.unique_cog_identifier, + driver=driver, force_registration=True, ) yield conf @@ -176,7 +171,7 @@ def red(config_fr): Config.get_core_conf = lambda *args, **kwargs: config_fr - red = Red(cli_flags=cli_flags, description=description, dm_help=None) + red = Red(cli_flags=cli_flags, description=description, dm_help=None, owner_id=None) yield red diff --git a/redbot/setup.py b/redbot/setup.py index 797cbd841..0edaa6b32 100644 --- a/redbot/setup.py +++ b/redbot/setup.py @@ -1,32 +1,21 @@ #!/usr/bin/env python3 import asyncio import json +import logging import os import sys -import tarfile from copy import deepcopy -from datetime import datetime as dt from pathlib import Path -import logging +from typing import Dict, Any, Optional import appdirs import click import redbot.logging from redbot.core.cli import confirm -from redbot.core.data_manager import ( - basic_config_default, - load_basic_configuration, - instance_name, - basic_config, - cog_data_path, - core_data_path, - storage_details, -) -from redbot.core.utils import safe_delete -from redbot.core import Config +from redbot.core.utils import safe_delete, create_backup as _create_backup +from redbot.core import config, data_manager, drivers from redbot.core.drivers import BackendType, IdentifierData -from redbot.core.drivers.red_json import JSON conversion_log = logging.getLogger("red.converter") @@ -61,11 +50,11 @@ else: def save_config(name, data, remove=False): - config = load_existing_config() - if remove and name in config: - config.pop(name) + _config = load_existing_config() + if remove and name in _config: + _config.pop(name) else: - if name in config: + if name in _config: print( "WARNING: An instance already exists with this name. " "Continuing will overwrite the existing instance config." @@ -73,10 +62,10 @@ def save_config(name, data, remove=False): if not confirm("Are you absolutely certain you want to continue (y/n)? 
"): print("Not continuing") sys.exit(0) - config[name] = data + _config[name] = data with config_file.open("w", encoding="utf-8") as fs: - json.dump(config, fs, indent=4) + json.dump(_config, fs, indent=4) def get_data_dir(): @@ -118,13 +107,14 @@ def get_data_dir(): def get_storage_type(): - storage_dict = {1: "JSON", 2: "MongoDB"} + storage_dict = {1: "JSON", 2: "MongoDB", 3: "PostgreSQL"} storage = None while storage is None: print() print("Please choose your storage backend (if you're unsure, choose 1).") print("1. JSON (file storage, requires no database).") print("2. MongoDB") + print("3. PostgreSQL") storage = input("> ") try: storage = int(storage) @@ -158,21 +148,16 @@ def basic_setup(): default_data_dir = get_data_dir() - default_dirs = deepcopy(basic_config_default) + default_dirs = deepcopy(data_manager.basic_config_default) default_dirs["DATA_PATH"] = str(default_data_dir.resolve()) storage = get_storage_type() - storage_dict = {1: BackendType.JSON, 2: BackendType.MONGO} + storage_dict = {1: BackendType.JSON, 2: BackendType.MONGO, 3: BackendType.POSTGRES} storage_type: BackendType = storage_dict.get(storage, BackendType.JSON) default_dirs["STORAGE_TYPE"] = storage_type.value - - if storage_type == BackendType.MONGO: - from redbot.core.drivers.red_mongo import get_config_details - - default_dirs["STORAGE_DETAILS"] = get_config_details() - else: - default_dirs["STORAGE_DETAILS"] = {} + driver_cls = drivers.get_driver_class(storage_type) + default_dirs["STORAGE_DETAILS"] = driver_cls.get_config_details() name = get_name() save_config(name, default_dirs) @@ -193,130 +178,38 @@ def get_target_backend(backend) -> BackendType: return BackendType.JSON elif backend == "mongo": return BackendType.MONGO + elif backend == "postgres": + return BackendType.POSTGRES -async def json_to_mongov2(instance): - instance_vals = instance_data[instance] - current_data_dir = Path(instance_vals["DATA_PATH"]) +async def do_migration( + current_backend: BackendType, target_backend: BackendType +) -> Dict[str, Any]: + cur_driver_cls = drivers.get_driver_class(current_backend) + new_driver_cls = drivers.get_driver_class(target_backend) + cur_storage_details = data_manager.storage_details() + new_storage_details = new_driver_cls.get_config_details() - load_basic_configuration(instance) + await cur_driver_cls.initialize(**cur_storage_details) + await new_driver_cls.initialize(**new_storage_details) - from redbot.core.drivers import red_mongo + await config.migrate(cur_driver_cls, new_driver_cls) - storage_details = red_mongo.get_config_details() + await cur_driver_cls.teardown() + await new_driver_cls.teardown() - core_conf = Config.get_core_conf() - new_driver = red_mongo.Mongo(cog_name="Core", identifier="0", **storage_details) - - core_conf.init_custom("CUSTOM_GROUPS", 2) - custom_group_data = await core_conf.custom("CUSTOM_GROUPS").all() - - curr_custom_data = custom_group_data.get("Core", {}).get("0", {}) - exported_data = await core_conf.driver.export_data(curr_custom_data) - conversion_log.info("Starting Core conversion...") - await new_driver.import_data(exported_data, curr_custom_data) - conversion_log.info("Core conversion complete.") - - for p in current_data_dir.glob("cogs/**/settings.json"): - cog_name = p.parent.stem - if "." 
in cog_name: - # Garbage handler - continue - with p.open(mode="r", encoding="utf-8") as f: - cog_data = json.load(f) - for identifier, all_data in cog_data.items(): - try: - conf = Config.get_conf(None, int(identifier), cog_name=cog_name) - except ValueError: - continue - new_driver = red_mongo.Mongo( - cog_name=cog_name, identifier=conf.driver.unique_cog_identifier, **storage_details - ) - - curr_custom_data = custom_group_data.get(cog_name, {}).get(identifier, {}) - - exported_data = await conf.driver.export_data(curr_custom_data) - conversion_log.info(f"Converting {cog_name} with identifier {identifier}...") - await new_driver.import_data(exported_data, curr_custom_data) - - conversion_log.info("Cog conversion complete.") - - return storage_details + return new_storage_details -async def mongov2_to_json(instance): - load_basic_configuration(instance) - - core_path = core_data_path() - - from redbot.core.drivers import red_json - - core_conf = Config.get_core_conf() - new_driver = red_json.JSON(cog_name="Core", identifier="0", data_path_override=core_path) - - core_conf.init_custom("CUSTOM_GROUPS", 2) - custom_group_data = await core_conf.custom("CUSTOM_GROUPS").all() - - curr_custom_data = custom_group_data.get("Core", {}).get("0", {}) - exported_data = await core_conf.driver.export_data(curr_custom_data) - conversion_log.info("Starting Core conversion...") - await new_driver.import_data(exported_data, curr_custom_data) - conversion_log.info("Core conversion complete.") - - collection_names = await core_conf.driver.db.list_collection_names() - splitted_names = list( - filter( - lambda elem: elem[1] != "" and elem[0] != "Core", - [n.split(".") for n in collection_names], - ) - ) - - ident_map = {} # Cogname: idents list - for cog_name, category in splitted_names: - if cog_name not in ident_map: - ident_map[cog_name] = set() - - idents = await core_conf.driver.db[cog_name][category].distinct("_id.RED_uuid") - ident_map[cog_name].update(set(idents)) - - for cog_name, idents in ident_map.items(): - for identifier in idents: - curr_custom_data = custom_group_data.get(cog_name, {}).get(identifier, {}) - try: - conf = Config.get_conf(None, int(identifier), cog_name=cog_name) - except ValueError: - continue - exported_data = await conf.driver.export_data(curr_custom_data) - - new_path = cog_data_path(raw_name=cog_name) - new_driver = red_json.JSON(cog_name, identifier, data_path_override=new_path) - conversion_log.info(f"Converting {cog_name} with identifier {identifier}...") - await new_driver.import_data(exported_data, curr_custom_data) - - # cog_data_path(raw_name=cog_name) - - conversion_log.info("Cog conversion complete.") - - return {} - - -async def mongo_to_json(instance): - load_basic_configuration(instance) - - from redbot.core.drivers.red_mongo import Mongo - - m = Mongo("Core", "0", **storage_details()) +async def mongov1_to_json() -> Dict[str, Any]: + await drivers.MongoDriver.initialize(**data_manager.storage_details()) + m = drivers.MongoDriver("Core", "0") db = m.db collection_names = await db.list_collection_names() for collection_name in collection_names: if "." 
in collection_name: # Fix for one of Zeph's problems continue - elif collection_name == "Core": - c_data_path = core_data_path() - else: - c_data_path = cog_data_path(raw_name=collection_name) - c_data_path.mkdir(parents=True, exist_ok=True) # Every cog name has its own collection collection = db[collection_name] async for document in collection.find(): @@ -329,16 +222,22 @@ async def mongo_to_json(instance): continue elif not str(cog_id).isdigit(): continue - driver = JSON(collection_name, cog_id, data_path_override=c_data_path) + driver = drivers.JsonDriver(collection_name, cog_id) for category, value in document.items(): - ident_data = IdentifierData(str(cog_id), category, (), (), {}) + ident_data = IdentifierData( + str(collection_name), str(cog_id), category, tuple(), tuple(), 0 + ) await driver.set(ident_data, value=value) + + conversion_log.info("Cog conversion complete.") + await drivers.MongoDriver.teardown() + return {} async def edit_instance(): - instance_list = load_existing_config() - if not instance_list: + _instance_list = load_existing_config() + if not _instance_list: print("No instances have been set up!") return @@ -346,18 +245,18 @@ async def edit_instance(): "You have chosen to edit an instance. The following " "is a list of instances that currently exist:\n" ) - for instance in instance_list.keys(): + for instance in _instance_list.keys(): print("{}\n".format(instance)) print("Please select one of the above by entering its name") selected = input("> ") - if selected not in instance_list.keys(): + if selected not in _instance_list.keys(): print("That isn't a valid instance!") return - instance_data = instance_list[selected] - default_dirs = deepcopy(basic_config_default) + _instance_data = _instance_list[selected] + default_dirs = deepcopy(data_manager.basic_config_default) - current_data_dir = Path(instance_data["DATA_PATH"]) + current_data_dir = Path(_instance_data["DATA_PATH"]) print("You have selected '{}' as the instance to modify.".format(selected)) if not confirm("Please confirm (y/n):"): print("Ok, we will not continue then.") @@ -383,68 +282,47 @@ async def edit_instance(): print("Your basic configuration has been edited") -async def create_backup(instance): - instance_vals = instance_data[instance] - if confirm("Would you like to make a backup of the data for this instance? 
(y/n)"): - load_basic_configuration(instance) - if instance_vals["STORAGE_TYPE"] == "MongoDB": - await mongo_to_json(instance) - print("Backing up the instance's data...") - backup_filename = "redv3-{}-{}.tar.gz".format( - instance, dt.utcnow().strftime("%Y-%m-%d %H-%M-%S") - ) - pth = Path(instance_vals["DATA_PATH"]) - if pth.exists(): - backup_pth = pth.home() - backup_file = backup_pth / backup_filename - - to_backup = [] - exclusions = [ - "__pycache__", - "Lavalink.jar", - os.path.join("Downloader", "lib"), - os.path.join("CogManager", "cogs"), - os.path.join("RepoManager", "repos"), - ] - from redbot.cogs.downloader.repo_manager import RepoManager - - repo_mgr = RepoManager() - await repo_mgr.initialize() - repo_output = [] - for repo in repo_mgr._repos.values(): - repo_output.append({"url": repo.url, "name": repo.name, "branch": repo.branch}) - repo_filename = pth / "cogs" / "RepoManager" / "repos.json" - with open(str(repo_filename), "w") as f: - f.write(json.dumps(repo_output, indent=4)) - instance_vals = {instance_name: basic_config} - instance_file = pth / "instance.json" - with open(str(instance_file), "w") as instance_out: - instance_out.write(json.dumps(instance_vals, indent=4)) - for f in pth.glob("**/*"): - if not any(ex in str(f) for ex in exclusions): - to_backup.append(f) - with tarfile.open(str(backup_file), "w:gz") as tar: - for f in to_backup: - tar.add(str(f), recursive=False) - print("A backup of {} has been made. It is at {}".format(instance, backup_file)) - - -async def remove_instance(instance): - await create_backup(instance) - - instance_vals = instance_data[instance] - if instance_vals["STORAGE_TYPE"] == "MongoDB": - from redbot.core.drivers.red_mongo import Mongo - - m = Mongo("Core", **instance_vals["STORAGE_DETAILS"]) - db = m.db - collections = await db.collection_names(include_system_collections=False) - for name in collections: - collection = await db.get_collection(name) - await collection.drop() +async def create_backup(instance: str) -> None: + data_manager.load_basic_configuration(instance) + backend_type = get_current_backend(instance) + if backend_type == BackendType.MONGOV1: + await mongov1_to_json() + elif backend_type != BackendType.JSON: + await do_migration(backend_type, BackendType.JSON) + print("Backing up the instance's data...") + backup_fpath = await _create_backup() + if backup_fpath is not None: + print(f"A backup of {instance} has been made. It is at {backup_fpath}") else: - pth = Path(instance_vals["DATA_PATH"]) - safe_delete(pth) + print("Creating the backup failed.") + + +async def remove_instance( + instance, + interactive: bool = False, + drop_db: Optional[bool] = None, + remove_datapath: Optional[bool] = None, +): + data_manager.load_basic_configuration(instance) + + if confirm("Would you like to make a backup of the data for this instance? (y/n)"): + await create_backup(instance) + + backend = get_current_backend(instance) + if backend == BackendType.MONGOV1: + driver_cls = drivers.MongoDriver + else: + driver_cls = drivers.get_driver_class(backend) + + await driver_cls.delete_all_data(interactive=interactive, drop_db=drop_db) + + if interactive is True and remove_datapath is None: + remove_datapath = confirm("Would you like to delete the instance's entire datapath? 
(y/n)") + + if remove_datapath is True: + data_path = data_manager.core_data_path().parent + safe_delete(data_path) + save_config(instance, {}, remove=True) print("The instance {} has been removed\n".format(instance)) @@ -467,8 +345,7 @@ async def remove_instance_interaction(): print("That isn't a valid instance!") return - await create_backup(selected) - await remove_instance(selected) + await remove_instance(selected, interactive=True) @click.group(invoke_without_command=True) @@ -483,38 +360,56 @@ def cli(ctx, debug): @cli.command() @click.argument("instance", type=click.Choice(instance_list)) -def delete(instance): +@click.option("--no-prompt", default=False, help="Don't ask for user input during the process.") +@click.option( + "--drop-db", + type=bool, + default=None, + help=( + "Drop the entire database constaining this instance's data. Has no effect on JSON " + "instances. If this option and --no-prompt are omitted, you will be asked about this." + ), +) +@click.option( + "--remove-datapath", + type=bool, + default=None, + help=( + "Remove this entire instance's datapath. If this option and --no-prompt are omitted, you " + "will be asked about this." + ), +) +def delete(instance: str, no_prompt: Optional[bool], drop_db: Optional[bool]): loop = asyncio.get_event_loop() - loop.run_until_complete(remove_instance(instance)) + if no_prompt is None: + interactive = None + else: + interactive = not no_prompt + loop.run_until_complete(remove_instance(instance, interactive, drop_db)) @cli.command() @click.argument("instance", type=click.Choice(instance_list)) -@click.argument("backend", type=click.Choice(["json", "mongo"])) +@click.argument("backend", type=click.Choice(["json", "mongo", "postgres"])) def convert(instance, backend): current_backend = get_current_backend(instance) target = get_target_backend(backend) + data_manager.load_basic_configuration(instance) - default_dirs = deepcopy(basic_config_default) + default_dirs = deepcopy(data_manager.basic_config_default) default_dirs["DATA_PATH"] = str(Path(instance_data[instance]["DATA_PATH"])) loop = asyncio.get_event_loop() - new_storage_details = None - if current_backend == BackendType.MONGOV1: - if target == BackendType.MONGO: + if target == BackendType.JSON: + new_storage_details = loop.run_until_complete(mongov1_to_json()) + else: raise RuntimeError( "Please see conversion docs for updating to the latest mongo version." ) - elif target == BackendType.JSON: - new_storage_details = loop.run_until_complete(mongo_to_json(instance)) - elif current_backend == BackendType.JSON: - if target == BackendType.MONGO: - new_storage_details = loop.run_until_complete(json_to_mongov2(instance)) - elif current_backend == BackendType.MONGO: - if target == BackendType.JSON: - new_storage_details = loop.run_until_complete(mongov2_to_json(instance)) + else: + new_storage_details = loop.run_until_complete(do_migration(current_backend, target)) if new_storage_details is not None: default_dirs["STORAGE_TYPE"] = target.value @@ -522,7 +417,9 @@ def convert(instance, backend): save_config(instance, default_dirs) conversion_log.info(f"Conversion to {target} complete.") else: - conversion_log.info(f"Cannot convert {current_backend} to {target} at this time.") + conversion_log.info( + f"Cannot convert {current_backend.value} to {target.value} at this time." 
+ ) if __name__ == "__main__": diff --git a/setup.cfg b/setup.cfg index 399c55186..76acee512 100644 --- a/setup.cfg +++ b/setup.cfg @@ -80,6 +80,8 @@ mongo = dnspython==1.16.0 motor==2.0.0 pymongo==3.8.0 +postgres = + asyncpg==0.18.3 style = black==19.3b0 toml==0.10.0 @@ -123,3 +125,5 @@ include = **/locales/*.po data/* data/**/* +redbot.core.drivers.postgres = + *.sql diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 000000000..11d03fb88 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,45 @@ +import asyncio +import os + +import pytest + +from redbot.core import drivers, data_manager + + +@pytest.fixture(scope="session") +def event_loop(request): + """Create an instance of the default event loop for entire session.""" + loop = asyncio.get_event_loop_policy().new_event_loop() + yield loop + loop.close() + + +def _get_backend_type(): + if os.getenv("RED_STORAGE_TYPE") == "postgres": + return drivers.BackendType.POSTGRES + elif os.getenv("RED_STORAGE_TYPE") == "mongo": + return drivers.BackendType.MONGO + else: + return drivers.BackendType.JSON + + +@pytest.fixture(scope="session", autouse=True) +async def _setup_driver(): + backend_type = _get_backend_type() + if backend_type == drivers.BackendType.MONGO: + storage_details = { + "URI": os.getenv("RED_MONGO_URI", "mongodb"), + "HOST": os.getenv("RED_MONGO_HOST", "localhost"), + "PORT": int(os.getenv("RED_MONGO_PORT", "27017")), + "USERNAME": os.getenv("RED_MONGO_USER", "red"), + "PASSWORD": os.getenv("RED_MONGO_PASSWORD", "red"), + "DB_NAME": os.getenv("RED_MONGO_DATABASE", "red_db"), + } + else: + storage_details = {} + data_manager.storage_type = lambda: backend_type.value + data_manager.storage_details = lambda: storage_details + driver_cls = drivers.get_driver_class(backend_type) + await driver_cls.initialize(**storage_details) + yield + await driver_cls.teardown() diff --git a/tools/dev-requirements.txt b/tools/dev-requirements.txt index dbacdc621..ab52d3e8f 100644 --- a/tools/dev-requirements.txt +++ b/tools/dev-requirements.txt @@ -1,3 +1,3 @@ packaging tox --e .[docs,mongo,style,test] +-e .[docs,mongo,postgres,style,test] diff --git a/tools/primary_deps.ini b/tools/primary_deps.ini index 390920480..bd6e8819a 100644 --- a/tools/primary_deps.ini +++ b/tools/primary_deps.ini @@ -34,6 +34,8 @@ docs = mongo = dnspython motor +postgres = + asyncpg style = black test = diff --git a/tox.ini b/tox.ini index 965ea3e99..599d53eb1 100644 --- a/tox.ini +++ b/tox.ini @@ -15,12 +15,47 @@ description = Run tests and basic automatic issue checking. whitelist_externals = pytest pylint -extras = voice, test, mongo +extras = voice, test commands = python -m compileall ./redbot/cogs pytest pylint ./redbot +[testenv:postgres] +description = Run pytest with PostgreSQL backend +whitelist_externals = + pytest +extras = voice, test, postgres +setenv = + RED_STORAGE_TYPE=postgres +passenv = + # Use the following env vars for connection options, or other default options described here: + # https://magicstack.github.io/asyncpg/current/index.html#asyncpg.connection.connect + PGHOST + PGPORT + PGUSER + PGPASSWORD + PGDATABASE +commands = + pytest + +[testenv:mongo] +description = Run pytest with MongoDB backend +whitelist_externals = + pytest +extras = voice, test, mongo +setenv = + RED_STORAGE_TYPE=mongo +passenv = + RED_MONGO_URI + RED_MONGO_HOST + RED_MONGO_PORT + RED_MONGO_USER + RED_MONGO_PASSWORD + RED_MONGO_DATABASE +commands = + pytest + [testenv:docs] description = Attempt to build docs with sphinx-build whitelist_externals =