diff --git a/redbot/core/config.py b/redbot/core/config.py
index 7abf8957e..569e33b70 100644
--- a/redbot/core/config.py
+++ b/redbot/core/config.py
@@ -1,16 +1,14 @@
 import logging
 import collections
 from copy import deepcopy
-from typing import Any, Union, Tuple, Dict, Awaitable, AsyncContextManager, TypeVar, TYPE_CHECKING
+from typing import Any, Union, Tuple, Dict, Awaitable, AsyncContextManager, TypeVar
 import weakref
 
 import discord
 
 from .data_manager import cog_data_path, core_data_path
 from .drivers import get_driver, IdentifierData, BackendType
-
-if TYPE_CHECKING:
-    from .drivers.red_base import BaseDriver
+from .drivers.red_base import BaseDriver
 
 __all__ = ["Config", "get_latest_confs"]
@@ -545,7 +543,7 @@ class Config:
         self,
         cog_name: str,
         unique_identifier: str,
-        driver: "BaseDriver",
+        driver: BaseDriver,
         force_registration: bool = False,
         defaults: dict = None,
     ):
@@ -852,9 +850,16 @@ class Config:
             custom_group_data=self.custom_groups,
             is_custom=is_custom,
         )
+
+        pkey_len = BaseDriver.get_pkey_len(identifier_data)
+        if len(primary_keys) < pkey_len:
+            # Don't mix in defaults with groups higher than the document level
+            defaults = {}
+        else:
+            defaults = self.defaults.get(category, {})
         return Group(
             identifier_data=identifier_data,
-            defaults=self.defaults.get(category, {}),
+            defaults=defaults,
             driver=self.driver,
             force_registration=self.force_registration,
         )
@@ -975,6 +980,7 @@ class Config:
         """
         group = self._get_base_group(scope)
         ret = {}
+        defaults = self.defaults.get(scope, {})
 
         try:
             dict_ = await self.driver.get(group.identifier_data)
@@ -982,7 +988,7 @@ class Config:
             pass
         else:
             for k, v in dict_.items():
-                data = group.defaults
+                data = deepcopy(defaults)
                 data.update(v)
                 ret[int(k)] = data
 
@@ -1056,11 +1062,11 @@ class Config:
         """
         return await self._all_from_scope(self.USER)
 
-    @staticmethod
-    def _all_members_from_guild(group: Group, guild_data: dict) -> dict:
+    def _all_members_from_guild(self, guild_data: dict) -> dict:
         ret = {}
+        defaults = self.defaults.get(self.MEMBER, {})
         for member_id, member_data in guild_data.items():
-            new_member_data = group.defaults
+            new_member_data = deepcopy(defaults)
             new_member_data.update(member_data)
             ret[int(member_id)] = new_member_data
         return ret
@@ -1099,7 +1105,7 @@ class Config:
                 pass
             else:
                 for guild_id, guild_data in dict_.items():
-                    ret[int(guild_id)] = self._all_members_from_guild(group, guild_data)
+                    ret[int(guild_id)] = self._all_members_from_guild(guild_data)
         else:
             group = self._get_base_group(self.MEMBER, str(guild.id))
             try:
@@ -1107,7 +1113,7 @@ class Config:
             except KeyError:
                 pass
             else:
-                ret = self._all_members_from_guild(group, guild_data)
+                ret = self._all_members_from_guild(guild_data)
         return ret
 
     async def _clear_scope(self, *scopes: str):
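The `_get_base_group` change above is the heart of this patch on the config side: registered defaults are only mixed into a `Group` once enough primary keys have been supplied to reach the document level, so a partial-key group (e.g. a MEMBER group addressed by guild ID alone) no longer leaks defaults into bulk lookups. A minimal sketch of the gating logic, with the key count assumed from `get_pkey_len` in red_base.py below and invented IDs:

```python
# MEMBER data lives two primary keys deep: (guild ID, member ID).
pkey_len = 2
primary_keys = ("11112222",)   # hypothetical guild ID only - above the document level
registered = {"balance": 0}    # stand-in for what register_member() would provide

defaults = {} if len(primary_keys) < pkey_len else registered
assert defaults == {}  # partial-key groups carry no defaults
```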
diff --git a/redbot/core/drivers/red_base.py b/redbot/core/drivers/red_base.py
index cd7091a71..2e5d41c6f 100644
--- a/redbot/core/drivers/red_base.py
+++ b/redbot/core/drivers/red_base.py
@@ -18,8 +18,8 @@ class IdentifierData:
         self,
         uuid: str,
         category: str,
-        primary_key: Tuple[str],
-        identifiers: Tuple[str],
+        primary_key: Tuple[str, ...],
+        identifiers: Tuple[str, ...],
         custom_group_data: dict,
         is_custom: bool = False,
     ):
@@ -183,7 +183,7 @@ class BaseDriver:
                 c,
                 (),
                 (),
-                custom_group_data.get(c, {}),
+                custom_group_data,
                 is_custom=c in custom_group_data,
             )
             try:
@@ -202,7 +202,19 @@ class BaseDriver:
                 category,
                 pkey,
                 (),
-                custom_group_data.get(category, {}),
+                custom_group_data,
                 is_custom=category in custom_group_data,
             )
             await self.set(ident_data, data)
+
+    @staticmethod
+    def get_pkey_len(identifier_data: IdentifierData) -> int:
+        cat = identifier_data.category
+        if cat == "GLOBAL":
+            return 0
+        elif cat == "MEMBER":
+            return 2
+        elif identifier_data.is_custom:
+            return identifier_data.custom_group_data[cat]
+        else:
+            return 1
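`get_pkey_len` is now the single place that knows how deep each category's documents sit, letting both `Config._get_base_group` and the Mongo driver agree on what "above the document level" means. The comparisons assume `Config`'s plain-string category constants (`"GLOBAL"`, `"MEMBER"`, ...), and for custom groups `custom_group_data` maps the group name to the primary-key count passed to `Config.init_custom`. A standalone sanity check of the mapping (hypothetical stand-in class, invented group name):

```python
from redbot.core.drivers.red_base import BaseDriver

class FakeIdent:
    """Stand-in exposing only the attributes get_pkey_len() reads."""

    def __init__(self, category, custom_group_data=None, is_custom=False):
        self.category = category
        self.custom_group_data = custom_group_data or {}
        self.is_custom = is_custom

assert BaseDriver.get_pkey_len(FakeIdent("GLOBAL")) == 0   # one document per scope
assert BaseDriver.get_pkey_len(FakeIdent("GUILD")) == 1    # one document per guild
assert BaseDriver.get_pkey_len(FakeIdent("MEMBER")) == 2   # guild ID, then member ID
# A custom group registered with conf.init_custom("TicTacToe", 3):
assert BaseDriver.get_pkey_len(FakeIdent("TicTacToe", {"TicTacToe": 3}, True)) == 3
```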
diff --git a/redbot/core/drivers/red_json.py b/redbot/core/drivers/red_json.py
index 73023ffeb..7e6f7c333 100644
--- a/redbot/core/drivers/red_json.py
+++ b/redbot/core/drivers/red_json.py
@@ -1,5 +1,4 @@
 from pathlib import Path
-from typing import Tuple
 import copy
 import weakref
 import logging
@@ -156,7 +155,7 @@ class JSON(BaseDriver):
                 category,
                 pkey,
                 (),
-                custom_group_data.get(category, {}),
+                custom_group_data,
                 is_custom=category in custom_group_data,
             )
             update_write_data(ident_data, data)
diff --git a/redbot/core/drivers/red_mongo.py b/redbot/core/drivers/red_mongo.py
index 4554aa0a5..aabd36adf 100644
--- a/redbot/core/drivers/red_mongo.py
+++ b/redbot/core/drivers/red_mongo.py
@@ -1,11 +1,13 @@
+import contextlib
+import itertools
 import re
 from getpass import getpass
-from typing import Match, Pattern, Tuple
+from typing import Match, Pattern, Tuple, Any, Dict, Iterator, List
 from urllib.parse import quote_plus
 
 import motor.core
 import motor.motor_asyncio
-from motor.motor_asyncio import AsyncIOMotorCursor
+import pymongo.errors
 
 from .red_base import BaseDriver, IdentifierData
@@ -36,7 +38,7 @@ def _initialize(**kwargs):
     url = "{}://{}{}/{}".format(uri, host, ports, db_name)
 
     global _conn
-    _conn = motor.motor_asyncio.AsyncIOMotorClient(url)
+    _conn = motor.motor_asyncio.AsyncIOMotorClient(url, retryWrites=True)
 
 
 class Mongo(BaseDriver):
@@ -87,7 +89,9 @@ class Mongo(BaseDriver):
         # noinspection PyTypeChecker
         return identifier_data.primary_key
 
-    async def rebuild_dataset(self, identifier_data: IdentifierData, cursor: AsyncIOMotorCursor):
+    async def rebuild_dataset(
+        self, identifier_data: IdentifierData, cursor: motor.motor_asyncio.AsyncIOMotorCursor
+    ):
         ret = {}
         async for doc in cursor:
             pkeys = doc["_id"]["RED_primary_key"]
@@ -137,24 +141,96 @@ class Mongo(BaseDriver):
     async def set(self, identifier_data: IdentifierData, value=None):
         uuid = self._escape_key(identifier_data.uuid)
         primary_key = list(map(self._escape_key, self.get_primary_key(identifier_data)))
-        dot_identifiers = ".".join(map(self._escape_key, identifier_data.identifiers))
         if isinstance(value, dict):
             if len(value) == 0:
                 await self.clear(identifier_data)
                 return
             value = self._escape_dict_keys(value)
         mongo_collection = self.get_collection(identifier_data.category)
-        if len(dot_identifiers) > 0:
-            update_stmt = {"$set": {dot_identifiers: value}}
-        else:
-            update_stmt = {"$set": value}
+        pkey_len = self.get_pkey_len(identifier_data)
+        num_pkeys = len(primary_key)
 
-        await mongo_collection.update_one(
-            {"_id": {"RED_uuid": uuid, "RED_primary_key": primary_key}},
-            update=update_stmt,
-            upsert=True,
-        )
+        if num_pkeys >= pkey_len:
+            # We're setting at the document level or below.
+            dot_identifiers = ".".join(map(self._escape_key, identifier_data.identifiers))
+            if dot_identifiers:
+                update_stmt = {"$set": {dot_identifiers: value}}
+            else:
+                update_stmt = {"$set": value}
+
+            await mongo_collection.update_one(
+                {"_id": {"RED_uuid": uuid, "RED_primary_key": primary_key}},
+                update=update_stmt,
+                upsert=True,
+            )
+
+        else:
+            # We're setting above the document level.
+            # Easiest and most efficient thing to do is delete all documents that we're potentially
+            # replacing, then insert_many().
+            # We'll do it in a transaction so we can roll back in case something goes horribly
+            # wrong.
+            pkey_filter = self.generate_primary_key_filter(identifier_data)
+            async with await _conn.start_session() as session:
+                with contextlib.suppress(pymongo.errors.CollectionInvalid):
+                    # Collections must already exist when inserting documents within a transaction
+                    await _conn.get_database().create_collection(mongo_collection.name)
+                try:
+                    async with session.start_transaction():
+                        await mongo_collection.delete_many(pkey_filter, session=session)
+                        await mongo_collection.insert_many(
+                            self.generate_documents_to_insert(uuid, primary_key, value, pkey_len),
+                            session=session,
+                        )
+                except pymongo.errors.OperationFailure:
+                    # This DB version / setup doesn't support transactions, so we'll have to fall
+                    # back to a cruder method.
+
+                    # The strategy here is to separate the existing documents and the new documents
+                    # into ones to be deleted, ones to be replaced, and new ones to be inserted.
+                    # Then we can do a bulk_write().
+
+                    # This is our list of (filter, new_document) tuples for replacing existing
+                    # documents. The `new_document` should be taken and removed from `value`, so
+                    # `value` only ends up containing documents which need to be inserted.
+                    to_replace: List[Tuple[Dict, Dict]] = []
+
+                    # This is our list of primary key filters which need deleting. They should
+                    # simply be all the primary keys which were part of existing documents but are
+                    # not included in the new documents.
+                    to_delete: List[Dict] = []
+                    async for document in mongo_collection.find(pkey_filter, session=session):
+                        pkey = document["_id"]["RED_primary_key"]
+                        new_document = value
+                        try:
+                            for pkey_part in pkey[num_pkeys:-1]:
+                                new_document = new_document[pkey_part]
+                            # This document is being replaced - remove it from `value`.
+                            new_document = new_document.pop(pkey[-1])
+                        except KeyError:
+                            # We've found the primary key of an old document which isn't in the
+                            # updated set of documents - it should be deleted.
+                            to_delete.append({"_id": {"RED_uuid": uuid, "RED_primary_key": pkey}})
+                        else:
+                            _filter = {"_id": {"RED_uuid": uuid, "RED_primary_key": pkey}}
+                            new_document.update(_filter)
+                            to_replace.append((_filter, new_document))
+
+                    # What's left of `value` should be the new documents needing to be inserted.
+                    to_insert = self.generate_documents_to_insert(
+                        uuid, primary_key, value, pkey_len
+                    )
+                    requests = list(
+                        itertools.chain(
+                            (pymongo.DeleteOne(f) for f in to_delete),
+                            (pymongo.ReplaceOne(f, d) for f, d in to_replace),
+                            (pymongo.InsertOne(d) for d in to_insert if d),
+                        )
+                    )
+                    # This will pipeline the operations so they all complete quickly. However if
+                    # any of them fail, the rest of them will complete - i.e. this operation is
+                    # not atomic.
+                    await mongo_collection.bulk_write(requests, ordered=False)
 
     def generate_primary_key_filter(self, identifier_data: IdentifierData):
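Two notes on the `set` rework above. Multi-document transactions require MongoDB 4.0+ on a replica set, which is why the `pymongo.errors.OperationFailure` fallback exists at all. And "setting above the document level" means `value` arrives keyed by the missing primary keys rather than by field names; for MEMBER data (`pkey_len == 2`) with only a guild ID in `primary_key`, each top-level key of `value` becomes its own document. A sketch with invented IDs:

```python
# set() called with primary_key == ["11112222"] (guild ID only), pkey_len == 2:
value = {
    "33334444": {"balance": 100},  # member 33334444's entire document
    "55556666": {"balance": 250},  # member 55556666's entire document
}
# After the delete/insert pair (or the bulk_write fallback), the collection holds:
#   {"_id": {"RED_uuid": ..., "RED_primary_key": ["11112222", "33334444"]}, "balance": 100}
#   {"_id": {"RED_uuid": ..., "RED_primary_key": ["11112222", "55556666"]}, "balance": 250}
```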
@@ -170,20 +246,48 @@ class Mongo(BaseDriver):
         ret["_id.RED_primary_key"] = {"$exists": True}
         return ret
 
+    @classmethod
+    def generate_documents_to_insert(
+        cls, uuid: str, primary_keys: List[str], data: Dict[str, Dict[str, Any]], pkey_len: int
+    ) -> Iterator[Dict[str, Any]]:
+        num_missing_pkeys = pkey_len - len(primary_keys)
+        if num_missing_pkeys == 1:
+            for pkey, document in data.items():
+                document["_id"] = {"RED_uuid": uuid, "RED_primary_key": primary_keys + [pkey]}
+                yield document
+        else:
+            for pkey, inner_data in data.items():
+                for document in cls.generate_documents_to_insert(
+                    uuid, primary_keys + [pkey], inner_data, pkey_len
+                ):
+                    yield document
+
     async def clear(self, identifier_data: IdentifierData):
-        # There are three cases here:
+        # There are five cases here:
         # 1) We're clearing out a subset of identifiers (aka identifiers is NOT empty)
         # 2) We're clearing out full primary key and no identifiers
         # 3) We're clearing out partial primary key and no identifiers
         # 4) Primary key is empty, should wipe all documents in the collection
-        mongo_collection = self.get_collection(identifier_data.category)
+        # 5) Category is empty, all of this cog's data should be deleted
         pkey_filter = self.generate_primary_key_filter(identifier_data)
-        if len(identifier_data.identifiers) == 0:
-            # This covers cases 2-4
-            await mongo_collection.delete_many(pkey_filter)
-        else:
+        if identifier_data.identifiers:
+            # This covers case 1
+            mongo_collection = self.get_collection(identifier_data.category)
             dot_identifiers = ".".join(map(self._escape_key, identifier_data.identifiers))
             await mongo_collection.update_one(pkey_filter, update={"$unset": {dot_identifiers: 1}})
+        elif identifier_data.category:
+            # This covers cases 2-4
+            mongo_collection = self.get_collection(identifier_data.category)
+            await mongo_collection.delete_many(pkey_filter)
+        else:
+            # This covers case 5
+            db = self.db
+            super_collection = db[self.cog_name]
+            results = await db.list_collections(
+                filter={"name": {"$regex": rf"^{super_collection.name}\."}}
+            )
+            for result in results:
+                await db[result["name"]].delete_many(pkey_filter)
 
     @staticmethod
     def _escape_key(key: str) -> str:
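`generate_documents_to_insert` recurses through the nested value until exactly one primary-key level is missing, then stamps each leaf dict with its `_id` and yields it as a document; note that it mutates `value` in place, which is what lets the fallback path above treat whatever survives the replace-scan as pure inserts. A standalone walk-through of the recursion (invented IDs; only the classmethod is exercised, so no database is needed):

```python
from redbot.core.drivers.red_mongo import Mongo

docs = list(
    Mongo.generate_documents_to_insert(
        "abc123",                                       # cog's unique identifier
        [],                                             # no primary keys supplied yet
        {"11112222": {"33334444": {"balance": 100}}},   # MEMBER data, two levels deep
        2,                                              # pkey_len for the MEMBER category
    )
)
assert docs == [
    {
        "balance": 100,
        "_id": {"RED_uuid": "abc123", "RED_primary_key": ["11112222", "33334444"]},
    }
]
```

The inner loop could equivalently be written `yield from cls.generate_documents_to_insert(...)`; the explicit loop in the patch behaves identically.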