diff --git a/redbot/core/config.py b/redbot/core/config.py index c513e07f9..f0c6bb5d3 100644 --- a/redbot/core/config.py +++ b/redbot/core/config.py @@ -6,7 +6,7 @@ from typing import Any, Union, Tuple, Dict, Awaitable, AsyncContextManager, Type import discord from .data_manager import cog_data_path, core_data_path -from .drivers import get_driver +from .drivers import get_driver, IdentifierData if TYPE_CHECKING: from .drivers.red_base import BaseDriver @@ -72,14 +72,14 @@ class Value: """ - def __init__(self, identifiers: Tuple[str], default_value, driver): - self.identifiers = identifiers + def __init__(self, identifier_data: IdentifierData, default_value, driver): + self.identifier_data = identifier_data self.default = default_value self.driver = driver async def _get(self, default=...): try: - ret = await self.driver.get(*self.identifiers) + ret = await self.driver.get(self.identifier_data) except KeyError: return default if default is not ... else self.default return ret @@ -150,13 +150,13 @@ class Value: """ if isinstance(value, dict): value = _str_key_dict(value) - await self.driver.set(*self.identifiers, value=value) + await self.driver.set(self.identifier_data, value=value) async def clear(self): """ Clears the value from record for the data element pointed to by `identifiers`. """ - await self.driver.clear(*self.identifiers) + await self.driver.clear(self.identifier_data) class Group(Value): @@ -178,13 +178,17 @@ class Group(Value): """ def __init__( - self, identifiers: Tuple[str], defaults: dict, driver, force_registration: bool = False + self, + identifier_data: IdentifierData, + defaults: dict, + driver, + force_registration: bool = False, ): self._defaults = defaults self.force_registration = force_registration self.driver = driver - super().__init__(identifiers, {}, self.driver) + super().__init__(identifier_data, {}, self.driver) @property def defaults(self): @@ -225,22 +229,24 @@ class Group(Value): """ is_group = self.is_group(item) is_value = not is_group and self.is_value(item) - new_identifiers = self.identifiers + (item,) + new_identifiers = self.identifier_data.add_identifier(item) if is_group: return Group( - identifiers=new_identifiers, + identifier_data=new_identifiers, defaults=self._defaults[item], driver=self.driver, force_registration=self.force_registration, ) elif is_value: return Value( - identifiers=new_identifiers, default_value=self._defaults[item], driver=self.driver + identifier_data=new_identifiers, + default_value=self._defaults[item], + driver=self.driver, ) elif self.force_registration: raise AttributeError("'{}' is not a valid registered Group or value.".format(item)) else: - return Value(identifiers=new_identifiers, default_value=None, driver=self.driver) + return Value(identifier_data=new_identifiers, default_value=None, driver=self.driver) async def clear_raw(self, *nested_path: Any): """ @@ -262,8 +268,9 @@ class Group(Value): Multiple arguments that mirror the arguments passed in for nested dict access. These are casted to `str` for you. """ - path = [str(p) for p in nested_path] - await self.driver.clear(*self.identifiers, *path) + path = tuple(str(p) for p in nested_path) + identifier_data = self.identifier_data.add_identifier(*path) + await self.driver.clear(identifier_data) def is_group(self, item: Any) -> bool: """A helper method for `__getattr__`. Most developers will have no need @@ -368,7 +375,7 @@ class Group(Value): If the value does not exist yet in Config's internal storage. """ - path = [str(p) for p in nested_path] + path = tuple(str(p) for p in nested_path) if default is ...: poss_default = self.defaults @@ -380,8 +387,9 @@ class Group(Value): else: default = poss_default + identifier_data = self.identifier_data.add_identifier(*path) try: - raw = await self.driver.get(*self.identifiers, *path) + raw = await self.driver.get(identifier_data) except KeyError: if default is not ...: return default @@ -456,10 +464,11 @@ class Group(Value): value The value to store. """ - path = [str(p) for p in nested_path] + path = tuple(str(p) for p in nested_path) + identifier_data = self.identifier_data.add_identifier(*path) if isinstance(value, dict): value = _str_key_dict(value) - await self.driver.set(*self.identifiers, *path, value=value) + await self.driver.set(identifier_data, value=value) class Config: @@ -779,11 +788,17 @@ class Config: """ self._register_default(group_identifier, **kwargs) - def _get_base_group(self, key: str, *identifiers: str) -> Group: + def _get_base_group(self, category: str, *primary_keys: str) -> Group: # noinspection PyTypeChecker + identifier_data = IdentifierData( + uuid=self.unique_identifier, + category=category, + primary_key=primary_keys, + identifiers=(), + ) return Group( - identifiers=(key, *identifiers), - defaults=self.defaults.get(key, {}), + identifier_data=identifier_data, + defaults=self.defaults.get(category, {}), driver=self.driver, force_registration=self.force_registration, ) @@ -904,7 +919,7 @@ class Config: ret = {} try: - dict_ = await self.driver.get(*group.identifiers) + dict_ = await self.driver.get(group.identifier_data) except KeyError: pass else: @@ -1021,7 +1036,7 @@ class Config: if guild is None: group = self._get_base_group(self.MEMBER) try: - dict_ = await self.driver.get(*group.identifiers) + dict_ = await self.driver.get(group.identifier_data) except KeyError: pass else: @@ -1030,7 +1045,7 @@ class Config: else: group = self._get_base_group(self.MEMBER, str(guild.id)) try: - guild_data = await self.driver.get(*group.identifiers) + guild_data = await self.driver.get(group.identifier_data) except KeyError: pass else: @@ -1057,7 +1072,8 @@ class Config: """ if not scopes: # noinspection PyTypeChecker - group = Group(identifiers=(), defaults={}, driver=self.driver) + identifier_data = IdentifierData(self.unique_identifier, "", (), ()) + group = Group(identifier_data, defaults={}, driver=self.driver) else: group = self._get_base_group(*scopes) await group.clear() diff --git a/redbot/core/drivers/__init__.py b/redbot/core/drivers/__init__.py index 6cf6ca2ca..2809427d9 100644 --- a/redbot/core/drivers/__init__.py +++ b/redbot/core/drivers/__init__.py @@ -1,4 +1,6 @@ -__all__ = ["get_driver"] +from .red_base import IdentifierData + +__all__ = ["get_driver", "IdentifierData"] def get_driver(type, *args, **kwargs): diff --git a/redbot/core/drivers/red_base.py b/redbot/core/drivers/red_base.py index 38b8bb14f..11454008e 100644 --- a/redbot/core/drivers/red_base.py +++ b/redbot/core/drivers/red_base.py @@ -1,4 +1,51 @@ -__all__ = ["BaseDriver"] +from typing import Tuple + +__all__ = ["BaseDriver", "IdentifierData"] + + +class IdentifierData: + def __init__(self, uuid: str, category: str, primary_key: Tuple[str], identifiers: Tuple[str]): + self._uuid = uuid + self._category = category + self._primary_key = primary_key + self._identifiers = identifiers + + @property + def uuid(self): + return self._uuid + + @property + def category(self): + return self._category + + @property + def primary_key(self): + return self._primary_key + + @property + def identifiers(self): + return self._identifiers + + def __repr__(self): + return ( + f"" + ) + + def add_identifier(self, *identifier: str) -> "IdentifierData": + if not all(isinstance(i, str) for i in identifier): + raise ValueError("Identifiers must be strings.") + + return IdentifierData( + self.uuid, self.category, self.primary_key, self.identifiers + identifier + ) + + def to_tuple(self): + return tuple( + item + for item in (self.uuid, self.category, *self.primary_key, *self.identifiers) + if len(item) > 0 + ) class BaseDriver: @@ -6,14 +53,13 @@ class BaseDriver: self.cog_name = cog_name self.unique_cog_identifier = identifier - async def get(self, *identifiers: str): + async def get(self, identifier_data: IdentifierData): """ Finds the value indicate by the given identifiers. Parameters ---------- - identifiers - A list of identifiers that correspond to nested dict accesses. + identifier_data Returns ------- @@ -33,20 +79,19 @@ class BaseDriver: """ raise NotImplementedError - async def set(self, *identifiers: str, value=None): + async def set(self, identifier_data: IdentifierData, value=None): """ Sets the value of the key indicated by the given identifiers. Parameters ---------- - identifiers - A list of identifiers that correspond to nested dict accesses. + identifier_data value Any JSON serializable python object. """ raise NotImplementedError - async def clear(self, *identifiers: str): + async def clear(self, identifier_data: IdentifierData): """ Clears out the value specified by the given identifiers. @@ -54,7 +99,6 @@ class BaseDriver: Parameters ---------- - identifiers - A list of identifiers that correspond to nested dict accesses. + identifier_data """ raise NotImplementedError diff --git a/redbot/core/drivers/red_json.py b/redbot/core/drivers/red_json.py index 0bb4f9ae8..b9bf3ca00 100644 --- a/redbot/core/drivers/red_json.py +++ b/redbot/core/drivers/red_json.py @@ -6,7 +6,7 @@ import logging from ..json_io import JsonIO -from .red_base import BaseDriver +from .red_base import BaseDriver, IdentifierData __all__ = ["JSON"] @@ -93,16 +93,16 @@ class JSON(BaseDriver): self.data = {} self.jsonIO._save_json(self.data) - async def get(self, *identifiers: Tuple[str]): + async def get(self, identifier_data: IdentifierData): partial = self.data - full_identifiers = (self.unique_cog_identifier, *identifiers) + full_identifiers = identifier_data.to_tuple() for i in full_identifiers: partial = partial[i] return copy.deepcopy(partial) - async def set(self, *identifiers: str, value=None): + async def set(self, identifier_data: IdentifierData, value=None): partial = self.data - full_identifiers = (self.unique_cog_identifier, *identifiers) + full_identifiers = identifier_data.to_tuple() for i in full_identifiers[:-1]: if i not in partial: partial[i] = {} @@ -111,9 +111,9 @@ class JSON(BaseDriver): partial[full_identifiers[-1]] = copy.deepcopy(value) await self.jsonIO._threadsafe_save_json(self.data) - async def clear(self, *identifiers: str): + async def clear(self, identifier_data: IdentifierData): partial = self.data - full_identifiers = (self.unique_cog_identifier, *identifiers) + full_identifiers = identifier_data.to_tuple() try: for i in full_identifiers[:-1]: partial = partial[i] diff --git a/redbot/core/drivers/red_mongo.py b/redbot/core/drivers/red_mongo.py index 6f8415bbd..2d00d468a 100644 --- a/redbot/core/drivers/red_mongo.py +++ b/redbot/core/drivers/red_mongo.py @@ -1,11 +1,12 @@ import re -from typing import Match, Pattern +from typing import Match, Pattern, Tuple from urllib.parse import quote_plus import motor.core import motor.motor_asyncio +from motor.motor_asyncio import AsyncIOMotorCursor -from .red_base import BaseDriver +from .red_base import BaseDriver, IdentifierData __all__ = ["Mongo"] @@ -64,66 +65,119 @@ class Mongo(BaseDriver): """ return _conn.get_database() - def get_collection(self) -> motor.core.Collection: + def get_collection(self, category: str) -> motor.core.Collection: """ Gets a specified collection within the PyMongo database for this cog. - Unless you are doing custom stuff ``collection_name`` should be one of the class + Unless you are doing custom stuff ``category`` should be one of the class attributes of :py:class:`core.config.Config`. - :param str collection_name: + :param str category: :return: PyMongo collection object. """ - return self.db[self.cog_name] + return self.db[self.cog_name][category] - @staticmethod - def _parse_identifiers(identifiers): - uuid, identifiers = identifiers[0], identifiers[1:] - return uuid, identifiers + def get_primary_key(self, identifier_data: IdentifierData) -> Tuple[str]: + # noinspection PyTypeChecker + return identifier_data.primary_key - async def get(self, *identifiers: str): - mongo_collection = self.get_collection() + async def rebuild_dataset(self, identifier_data: IdentifierData, cursor: AsyncIOMotorCursor): + ret = {} + async for doc in cursor: + pkeys = doc["_id"]["RED_primary_key"] + del doc["_id"] + if len(pkeys) == 1: + # Global data + ret.update(**doc) + elif len(pkeys) > 1: + # All other data + partial = ret + for key in pkeys[1:-1]: + if key in identifier_data.primary_key: + continue + if key not in partial: + partial[key] = {} + partial = partial[key] + if pkeys[-1] in identifier_data.primary_key: + partial.update(**doc) + else: + partial[pkeys[-1]] = doc + else: + raise RuntimeError("This should not happen.") + return ret - identifiers = (*map(self._escape_key, identifiers),) - dot_identifiers = ".".join(identifiers) + async def get(self, identifier_data: IdentifierData): + mongo_collection = self.get_collection(identifier_data.category) - partial = await mongo_collection.find_one( - filter={"_id": self.unique_cog_identifier}, projection={dot_identifiers: True} - ) + pkey_filter = self.generate_primary_key_filter(identifier_data) + if len(identifier_data.identifiers) > 0: + dot_identifiers = ".".join(map(self._escape_key, identifier_data.identifiers)) + proj = {"_id": False, dot_identifiers: True} + + partial = await mongo_collection.find_one(filter=pkey_filter, projection=proj) + else: + # The case here is for partial primary keys like all_members() + cursor = mongo_collection.find(filter=pkey_filter) + partial = await self.rebuild_dataset(identifier_data, cursor) if partial is None: raise KeyError("No matching document was found and Config expects a KeyError.") - for i in identifiers: + for i in identifier_data.identifiers: partial = partial[i] if isinstance(partial, dict): return self._unescape_dict_keys(partial) return partial - async def set(self, *identifiers: str, value=None): - dot_identifiers = ".".join(map(self._escape_key, identifiers)) + async def set(self, identifier_data: IdentifierData, value=None): + uuid = self._escape_key(identifier_data.uuid) + primary_key = list(map(self._escape_key, self.get_primary_key(identifier_data))) + dot_identifiers = ".".join(map(self._escape_key, identifier_data.identifiers)) if isinstance(value, dict): + if len(value) == 0: + await self.clear(identifier_data) + return value = self._escape_dict_keys(value) - mongo_collection = self.get_collection() + mongo_collection = self.get_collection(identifier_data.category) + if len(dot_identifiers) > 0: + update_stmt = {"$set": {dot_identifiers: value}} + else: + update_stmt = {"$set": value} await mongo_collection.update_one( - {"_id": self.unique_cog_identifier}, - update={"$set": {dot_identifiers: value}}, + {"_id": {"RED_uuid": uuid, "RED_primary_key": primary_key}}, + update=update_stmt, upsert=True, ) - async def clear(self, *identifiers: str): - dot_identifiers = ".".join(map(self._escape_key, identifiers)) - mongo_collection = self.get_collection() - - if len(identifiers) > 0: - await mongo_collection.update_one( - {"_id": self.unique_cog_identifier}, update={"$unset": {dot_identifiers: 1}} - ) + def generate_primary_key_filter(self, identifier_data: IdentifierData): + uuid = self._escape_key(identifier_data.uuid) + primary_key = list(map(self._escape_key, self.get_primary_key(identifier_data))) + ret = {"_id": {"RED_uuid": uuid}} + if len(identifier_data.identifiers) > 0: + ret["_id"]["RED_primary_key"] = primary_key else: - await mongo_collection.delete_one({"_id": self.unique_cog_identifier}) + for i, key in enumerate(primary_key): + keyname = f"RED_primary_key.{i}" + ret["_id"][keyname] = key + return ret + + async def clear(self, identifier_data: IdentifierData): + # There are three cases here: + # 1) We're clearing out a subset of identifiers (aka identifiers is NOT empty) + # 2) We're clearing out full primary key and no identifiers + # 3) We're clearing out partial primary key and no identifiers + # 4) Primary key is empty, should wipe all documents in the collection + mongo_collection = self.get_collection(identifier_data.category) + pkey_filter = self.generate_primary_key_filter(identifier_data) + if len(identifier_data.identifiers) == 0: + # This covers cases 2-4 + await mongo_collection.delete_many(pkey_filter) + else: + dot_identifiers = ".".join(map(self._escape_key, identifier_data.identifiers)) + await mongo_collection.update_one(pkey_filter, update={"$unset": {dot_identifiers: 1}}) @staticmethod def _escape_key(key: str) -> str: diff --git a/redbot/setup.py b/redbot/setup.py index d34bb8035..00a93f8a3 100644 --- a/redbot/setup.py +++ b/redbot/setup.py @@ -269,8 +269,9 @@ async def edit_instance(): default_dirs["STORAGE_DETAILS"] = storage_details if instance_data["STORAGE_TYPE"] == "JSON": - if confirm("Would you like to import your data? (y/n) "): - await json_to_mongo(current_data_dir, storage_details) + raise NotImplementedError("We cannot convert from JSON to MongoDB at this time.") + # if confirm("Would you like to import your data? (y/n) "): + # await json_to_mongo(current_data_dir, storage_details) else: storage_details = instance_data["STORAGE_DETAILS"] default_dirs["STORAGE_DETAILS"] = {}