diff --git a/changelog.d/3108.feature.rst b/changelog.d/3108.feature.rst
new file mode 100644
index 000000000..5954b23ce
--- /dev/null
+++ b/changelog.d/3108.feature.rst
@@ -0,0 +1 @@
+Ensure people can migrate from MongoDB
diff --git a/redbot/core/drivers/__init__.py b/redbot/core/drivers/__init__.py
index 7b650b1be..13dbe5bb2 100644
--- a/redbot/core/drivers/__init__.py
+++ b/redbot/core/drivers/__init__.py
@@ -32,6 +32,18 @@ class BackendType(enum.Enum):
 _DRIVER_CLASSES = {BackendType.JSON: JsonDriver, BackendType.POSTGRES: PostgresDriver}
 
 
+def _get_driver_class_include_old(storage_type: Optional[BackendType] = None) -> Type[BaseDriver]:
+    """
+    ONLY for use in the CLI, for moving data away from a no-longer-supported backend.
+    """
+    if storage_type == BackendType.MONGO:
+        from ._mongo import MongoDriver
+
+        return MongoDriver
+    else:
+        return get_driver_class(storage_type)
+
+
 def get_driver_class(storage_type: Optional[BackendType] = None) -> Type[BaseDriver]:
     """Get the driver class for the given storage type.
 
diff --git a/redbot/core/drivers/_mongo.py b/redbot/core/drivers/_mongo.py
new file mode 100644
index 000000000..58704ddd5
--- /dev/null
+++ b/redbot/core/drivers/_mongo.py
@@ -0,0 +1,451 @@
+# The file below is not supported for anything other than conversion away
+# from MongoDB, and will be removed at a later date.
+# The state of the file below is "AS-IS" from before removal.
+import contextlib
+import itertools
+import re
+from getpass import getpass
+from typing import Match, Pattern, Tuple, Optional, AsyncIterator, Any, Dict, Iterator, List
+from urllib.parse import quote_plus
+
+try:
+    # pylint: disable=import-error
+    import pymongo.errors
+    import motor.core
+    import motor.motor_asyncio
+except ModuleNotFoundError:
+    motor = None
+    pymongo = None
+
+from .. import errors
+from .base import BaseDriver, IdentifierData
+
+__all__ = ["MongoDriver"]
+
+
+class MongoDriver(BaseDriver):
+    """
+    Subclass of :py:class:`.BaseDriver`.
+    """
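+
+    # Layout note: Config data is stored as documents whose compound _id is
+    # {"RED_uuid": <config UUID>, "RED_primary_key": [<primary key parts>]},
+    # in collections named "<cog_name>.<category>".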
+ """ + + _conn: Optional["motor.motor_asyncio.AsyncIOMotorClient"] = None + + @classmethod + async def initialize(cls, **storage_details) -> None: + if motor is None: + raise errors.MissingExtraRequirements( + "Red must be installed with the [mongo] extra to use the MongoDB driver" + ) + uri = storage_details.get("URI", "mongodb") + host = storage_details["HOST"] + port = storage_details["PORT"] + user = storage_details["USERNAME"] + password = storage_details["PASSWORD"] + database = storage_details.get("DB_NAME", "default_db") + + if port is 0: + ports = "" + else: + ports = ":{}".format(port) + + if user is not None and password is not None: + url = "{}://{}:{}@{}{}/{}".format( + uri, quote_plus(user), quote_plus(password), host, ports, database + ) + else: + url = "{}://{}{}/{}".format(uri, host, ports, database) + + cls._conn = motor.motor_asyncio.AsyncIOMotorClient(url, retryWrites=True) + + @classmethod + async def teardown(cls) -> None: + if cls._conn is not None: + cls._conn.close() + + @staticmethod + def get_config_details(): + while True: + uri = input("Enter URI scheme (mongodb or mongodb+srv): ") + if uri is "": + uri = "mongodb" + + if uri in ["mongodb", "mongodb+srv"]: + break + else: + print("Invalid URI scheme") + + host = input("Enter host address: ") + if uri is "mongodb": + port = int(input("Enter host port: ")) + else: + port = 0 + + admin_uname = input("Enter login username: ") + admin_password = getpass("Enter login password: ") + + db_name = input("Enter mongodb database name: ") + + if admin_uname == "": + admin_uname = admin_password = None + + ret = { + "HOST": host, + "PORT": port, + "USERNAME": admin_uname, + "PASSWORD": admin_password, + "DB_NAME": db_name, + "URI": uri, + } + return ret + + @property + def db(self) -> "motor.core.Database": + """ + Gets the mongo database for this cog's name. + + :return: + PyMongo Database object. + """ + return self._conn.get_database() + + def get_collection(self, category: str) -> "motor.core.Collection": + """ + Gets a specified collection within the PyMongo database for this cog. + + Unless you are doing custom stuff ``category`` should be one of the class + attributes of :py:class:`core.config.Config`. + + :param str category: + The group identifier of a category. + :return: + PyMongo collection object. 
+ """ + return self.db[self.cog_name][category] + + @staticmethod + def get_primary_key(identifier_data: IdentifierData) -> Tuple[str, ...]: + # noinspection PyTypeChecker + return identifier_data.primary_key + + async def rebuild_dataset( + self, identifier_data: IdentifierData, cursor: "motor.motor_asyncio.AsyncIOMotorCursor" + ): + ret = {} + async for doc in cursor: + pkeys = doc["_id"]["RED_primary_key"] + del doc["_id"] + doc = self._unescape_dict_keys(doc) + if len(pkeys) == 0: + # Global data + ret.update(**doc) + elif len(pkeys) > 0: + # All other data + partial = ret + for key in pkeys[:-1]: + if key in identifier_data.primary_key: + continue + if key not in partial: + partial[key] = {} + partial = partial[key] + if pkeys[-1] in identifier_data.primary_key: + partial.update(**doc) + else: + partial[pkeys[-1]] = doc + return ret + + async def get(self, identifier_data: IdentifierData): + mongo_collection = self.get_collection(identifier_data.category) + + pkey_filter = self.generate_primary_key_filter(identifier_data) + escaped_identifiers = list(map(self._escape_key, identifier_data.identifiers)) + if len(identifier_data.identifiers) > 0: + proj = {"_id": False, ".".join(escaped_identifiers): True} + + partial = await mongo_collection.find_one(filter=pkey_filter, projection=proj) + else: + # The case here is for partial primary keys like all_members() + cursor = mongo_collection.find(filter=pkey_filter) + partial = await self.rebuild_dataset(identifier_data, cursor) + + if partial is None: + raise KeyError("No matching document was found and Config expects a KeyError.") + + for i in escaped_identifiers: + partial = partial[i] + if isinstance(partial, dict): + return self._unescape_dict_keys(partial) + return partial + + async def set(self, identifier_data: IdentifierData, value=None): + uuid = self._escape_key(identifier_data.uuid) + primary_key = list(map(self._escape_key, self.get_primary_key(identifier_data))) + dot_identifiers = ".".join(map(self._escape_key, identifier_data.identifiers)) + if isinstance(value, dict): + if len(value) == 0: + await self.clear(identifier_data) + return + value = self._escape_dict_keys(value) + mongo_collection = self.get_collection(identifier_data.category) + num_pkeys = len(primary_key) + + if num_pkeys >= identifier_data.primary_key_len: + # We're setting at the document level or below. + dot_identifiers = ".".join(map(self._escape_key, identifier_data.identifiers)) + if dot_identifiers: + update_stmt = {"$set": {dot_identifiers: value}} + else: + update_stmt = {"$set": value} + + try: + await mongo_collection.update_one( + {"_id": {"RED_uuid": uuid, "RED_primary_key": primary_key}}, + update=update_stmt, + upsert=True, + ) + except pymongo.errors.WriteError as exc: + if exc.args and exc.args[0].startswith("Cannot create field"): + # There's a bit of a failing edge case here... + # If we accidentally set the sub-field of an array, and the key happens to be a + # digit, it will successfully set the value in the array, and not raise an + # error. This is different to how other drivers would behave, and could lead to + # unexpected behaviour. + raise errors.CannotSetSubfield + else: + # Unhandled driver exception, should expose. + raise + + else: + # We're setting above the document level. + # Easiest and most efficient thing to do is delete all documents that we're potentially + # replacing, then insert_many(). + # We'll do it in a transaction so we can roll-back in case something goes horribly + # wrong. 
+            pkey_filter = self.generate_primary_key_filter(identifier_data)
+            async with await self._conn.start_session() as session:
+                with contextlib.suppress(pymongo.errors.CollectionInvalid):
+                    # Collections must already exist when inserting documents within a transaction
+                    await self.db.create_collection(mongo_collection.name)
+                try:
+                    async with session.start_transaction():
+                        await mongo_collection.delete_many(pkey_filter, session=session)
+                        await mongo_collection.insert_many(
+                            self.generate_documents_to_insert(
+                                uuid, primary_key, value, identifier_data.primary_key_len
+                            ),
+                            session=session,
+                        )
+                except pymongo.errors.OperationFailure:
+                    # This DB version / setup doesn't support transactions, so we'll have to use
+                    # a less robust method.
+
+                    # The strategy here is to separate the existing documents and the new documents
+                    # into ones to be deleted, ones to be replaced, and new ones to be inserted.
+                    # Then we can do a bulk_write().
+
+                    # This is our list of (filter, new_document) tuples for replacing existing
+                    # documents. The `new_document` should be taken and removed from `value`, so
+                    # `value` only ends up containing documents which need to be inserted.
+                    to_replace: List[Tuple[Dict, Dict]] = []
+
+                    # This is our list of primary key filters which need deleting. They should
+                    # simply be all the primary keys which were part of existing documents but are
+                    # not included in the new documents.
+                    to_delete: List[Dict] = []
+                    async for document in mongo_collection.find(pkey_filter, session=session):
+                        pkey = document["_id"]["RED_primary_key"]
+                        new_document = value
+                        try:
+                            for pkey_part in pkey[num_pkeys:-1]:
+                                new_document = new_document[pkey_part]
+                            # This document is being replaced - remove it from `value`.
+                            new_document = new_document.pop(pkey[-1])
+                        except KeyError:
+                            # We've found the primary key of an old document which isn't in the
+                            # updated set of documents - it should be deleted.
+                            to_delete.append({"_id": {"RED_uuid": uuid, "RED_primary_key": pkey}})
+                        else:
+                            _filter = {"_id": {"RED_uuid": uuid, "RED_primary_key": pkey}}
+                            new_document.update(_filter)
+                            to_replace.append((_filter, new_document))
+
+                    # What's left of `value` should be the new documents needing to be inserted.
+                    to_insert = self.generate_documents_to_insert(
+                        uuid, primary_key, value, identifier_data.primary_key_len
+                    )
+                    requests = list(
+                        itertools.chain(
+                            (pymongo.DeleteOne(f) for f in to_delete),
+                            (pymongo.ReplaceOne(f, d) for f, d in to_replace),
+                            (pymongo.InsertOne(d) for d in to_insert if d),
+                        )
+                    )
+                    # This will pipeline the operations so they all complete quickly. However, if
+                    # any of them fail, the rest of them will still complete - i.e. this operation
+                    # is not atomic.
+                    await mongo_collection.bulk_write(requests, ordered=False)
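+
+    # Illustrative sketch of the bulk_write fallback above (hypothetical data,
+    # not from the original code): if existing documents have primary keys
+    # ["1"] and ["2"] while the new value is {"2": {...}, "3": {...}}, then
+    # ["1"] lands in to_delete, ["2"] in to_replace, and "3" stays in `value`
+    # to be yielded by generate_documents_to_insert().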
+
+    def generate_primary_key_filter(self, identifier_data: IdentifierData):
+        uuid = self._escape_key(identifier_data.uuid)
+        primary_key = list(map(self._escape_key, self.get_primary_key(identifier_data)))
+        ret = {"_id.RED_uuid": uuid}
+        if len(identifier_data.identifiers) > 0:
+            ret["_id.RED_primary_key"] = primary_key
+        elif len(identifier_data.primary_key) > 0:
+            for i, key in enumerate(primary_key):
+                keyname = f"_id.RED_primary_key.{i}"
+                ret[keyname] = key
+        else:
+            ret["_id.RED_primary_key"] = {"$exists": True}
+        return ret
+
+    @classmethod
+    def generate_documents_to_insert(
+        cls, uuid: str, primary_keys: List[str], data: Dict[str, Dict[str, Any]], pkey_len: int
+    ) -> Iterator[Dict[str, Any]]:
+        num_missing_pkeys = pkey_len - len(primary_keys)
+        if num_missing_pkeys == 1:
+            for pkey, document in data.items():
+                document["_id"] = {"RED_uuid": uuid, "RED_primary_key": primary_keys + [pkey]}
+                yield document
+        else:
+            for pkey, inner_data in data.items():
+                yield from cls.generate_documents_to_insert(
+                    uuid, primary_keys + [pkey], inner_data, pkey_len
+                )
+
+    async def clear(self, identifier_data: IdentifierData):
+        # There are five cases here:
+        # 1) We're clearing out a subset of identifiers (aka identifiers is NOT empty)
+        # 2) We're clearing out a full primary key and no identifiers
+        # 3) We're clearing out a partial primary key and no identifiers
+        # 4) The primary key is empty, which should wipe all documents in the collection
+        # 5) The category is empty, so all of this cog's data should be deleted
+        pkey_filter = self.generate_primary_key_filter(identifier_data)
+        if identifier_data.identifiers:
+            # This covers case 1
+            mongo_collection = self.get_collection(identifier_data.category)
+            dot_identifiers = ".".join(map(self._escape_key, identifier_data.identifiers))
+            await mongo_collection.update_one(pkey_filter, update={"$unset": {dot_identifiers: 1}})
+        elif identifier_data.category:
+            # This covers cases 2-4
+            mongo_collection = self.get_collection(identifier_data.category)
+            await mongo_collection.delete_many(pkey_filter)
+        else:
+            # This covers case 5
+            db = self.db
+            super_collection = db[self.cog_name]
+            results = await db.list_collections(
+                filter={"name": {"$regex": rf"^{super_collection.name}\."}}
+            )
+            for result in results:
+                await db[result["name"]].delete_many(pkey_filter)
+
+    @classmethod
+    async def aiter_cogs(cls) -> AsyncIterator[Tuple[str, str]]:
+        db = cls._conn.get_database()
+        for collection_name in await db.list_collection_names():
+            parts = collection_name.split(".")
+            if len(parts) != 2:
+                continue
+            cog_name = parts[0]
+            for cog_id in await db[collection_name].distinct("_id.RED_uuid"):
+                yield cog_name, cog_id
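+
+    # Note on aiter_cogs() above: a collection named e.g. "Economy.MEMBER"
+    # (hypothetical cog name) yields cog_name "Economy"; collection names
+    # without exactly two dot-separated parts are skipped.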
+
+    @classmethod
+    async def delete_all_data(
+        cls, *, interactive: bool = False, drop_db: Optional[bool] = None, **kwargs
+    ) -> None:
+        """Delete all data being stored by this driver.
+
+        Parameters
+        ----------
+        interactive : bool
+            Set to ``True`` to allow the method to ask the user for
+            input from the console, regarding the other unset parameters
+            for this method.
+        drop_db : Optional[bool]
+            Set to ``True`` to drop the entire database for the current
+            bot's instance. Otherwise, collections which appear to be
+            storing bot data will be dropped.
+
+        """
+        if interactive is True and drop_db is None:
+            print(
+                "Please choose from one of the following options:\n"
+                " 1. Drop the entire MongoDB database for this instance, or\n"
+                " 2. Delete all of Red's data within this database, without dropping the database "
+                "itself."
+            )
+            options = ("1", "2")
+            while True:
+                resp = input("> ")
+                try:
+                    drop_db = not bool(options.index(resp))
+                except ValueError:
+                    print("Please type a number corresponding to one of the options.")
+                else:
+                    break
+        db = cls._conn.get_database()
+        if drop_db is True:
+            await cls._conn.drop_database(db)
+        else:
+            async with await cls._conn.start_session() as session:
+                async for cog_name, cog_id in cls.aiter_cogs():
+                    await db.drop_collection(db[cog_name], session=session)
+
+    @staticmethod
+    def _escape_key(key: str) -> str:
+        return _SPECIAL_CHAR_PATTERN.sub(_replace_with_escaped, key)
+
+    @staticmethod
+    def _unescape_key(key: str) -> str:
+        return _CHAR_ESCAPE_PATTERN.sub(_replace_with_unescaped, key)
+
+    @classmethod
+    def _escape_dict_keys(cls, data: dict) -> dict:
+        """Recursively escape all keys in a dict."""
+        ret = {}
+        for key, value in data.items():
+            key = cls._escape_key(key)
+            if isinstance(value, dict):
+                value = cls._escape_dict_keys(value)
+            ret[key] = value
+        return ret
+
+    @classmethod
+    def _unescape_dict_keys(cls, data: dict) -> dict:
+        """Recursively unescape all keys in a dict."""
+        ret = {}
+        for key, value in data.items():
+            key = cls._unescape_key(key)
+            if isinstance(value, dict):
+                value = cls._unescape_dict_keys(value)
+            ret[key] = value
+        return ret
+
+
+_SPECIAL_CHAR_PATTERN: Pattern[str] = re.compile(r"([.$]|\\U0000002E|\\U00000024)")
+_SPECIAL_CHARS = {
+    ".": "\\U0000002E",
+    "$": "\\U00000024",
+    "\\U0000002E": "\\U&0000002E",
+    "\\U00000024": "\\U&00000024",
+}
+
+
+def _replace_with_escaped(match: Match[str]) -> str:
+    return _SPECIAL_CHARS[match[0]]
+
+
+_CHAR_ESCAPE_PATTERN: Pattern[str] = re.compile(r"(\\U&?0000002E|\\U&?00000024)")
+_CHAR_ESCAPES = {
+    "\\U0000002E": ".",
+    "\\U00000024": "$",
+    "\\U&0000002E": "\\U0000002E",
+    "\\U&00000024": "\\U00000024",
+}
+
+
+def _replace_with_unescaped(match: Match[str]) -> str:
+    return _CHAR_ESCAPES[match[0]]
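+
+
+# Illustrative round trip (hypothetical key, not from the original code):
+#   _escape_key("cog.name") == "cog\U0000002Ename"
+#   _unescape_key("cog\U0000002Ename") == "cog.name"
+# A key that already contains a literal "\U0000002E" escapes to "\U&0000002E",
+# keeping escaped sequences distinct from pre-existing ones.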
PostgreSQL (Requires a database server)" + "\n(Warning: You cannot convert postgres instances to other backends yet)" + ) storage = input("> ") try: storage = int(storage) @@ -196,7 +199,7 @@ def get_target_backend(backend) -> BackendType: async def do_migration( current_backend: BackendType, target_backend: BackendType ) -> Dict[str, Any]: - cur_driver_cls = drivers.get_driver_class(current_backend) + cur_driver_cls = drivers._get_driver_class_include_old(current_backend) new_driver_cls = drivers.get_driver_class(target_backend) cur_storage_details = data_manager.storage_details() new_storage_details = new_driver_cls.get_config_details() @@ -375,7 +378,7 @@ def delete( @cli.command() @click.argument("instance", type=click.Choice(instance_list)) -@click.argument("backend", type=click.Choice(["json", "postgres"])) +@click.argument("backend", type=click.Choice(["json"])) # TODO: GH-3115 def convert(instance, backend): """Convert data backend of an instance.""" current_backend = get_current_backend(instance) @@ -387,8 +390,10 @@ def convert(instance, backend): loop = asyncio.get_event_loop() - if current_backend in (BackendType.MONGOV1, BackendType.MONGO): + if current_backend == BackendType.MONGOV1: raise RuntimeError("Please see the 3.2 release notes for upgrading a bot using mongo.") + elif current_backend == BackendType.POSTGRES: # TODO: GH-3115 + raise RuntimeError("Converting away from postgres isn't currently supported") else: new_storage_details = loop.run_until_complete(do_migration(current_backend, target))