mirror of
https://github.com/Cog-Creators/Red-DiscordBot.git
synced 2025-11-06 11:18:54 -05:00
[V3 Config] Update Mongo document organization to bypass doc size restriction (#2536)
* modify config to use identifier data class and update json driver * move identifier data attributes into read only properties * Update mongo get and set methods * Update get/set to use UUID separately, make clear work * Remove not implemented and fix get_raw * Update remaining untouched get/set/clear * Fix get_raw * Finally fix get_raw and set_raw * style * This is better * Sorry guys * Update get behavior to handle "all" calls as expected * style again * Why do you do this to me * style once more * Update mongo schema
This commit is contained in:
parent
d6d6d14977
commit
1cd7e41f33
@ -6,7 +6,7 @@ from typing import Any, Union, Tuple, Dict, Awaitable, AsyncContextManager, Type
|
||||
import discord
|
||||
|
||||
from .data_manager import cog_data_path, core_data_path
|
||||
from .drivers import get_driver
|
||||
from .drivers import get_driver, IdentifierData
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .drivers.red_base import BaseDriver
|
||||
@ -72,14 +72,14 @@ class Value:
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self, identifiers: Tuple[str], default_value, driver):
|
||||
self.identifiers = identifiers
|
||||
def __init__(self, identifier_data: IdentifierData, default_value, driver):
|
||||
self.identifier_data = identifier_data
|
||||
self.default = default_value
|
||||
self.driver = driver
|
||||
|
||||
async def _get(self, default=...):
|
||||
try:
|
||||
ret = await self.driver.get(*self.identifiers)
|
||||
ret = await self.driver.get(self.identifier_data)
|
||||
except KeyError:
|
||||
return default if default is not ... else self.default
|
||||
return ret
|
||||
@ -150,13 +150,13 @@ class Value:
|
||||
"""
|
||||
if isinstance(value, dict):
|
||||
value = _str_key_dict(value)
|
||||
await self.driver.set(*self.identifiers, value=value)
|
||||
await self.driver.set(self.identifier_data, value=value)
|
||||
|
||||
async def clear(self):
|
||||
"""
|
||||
Clears the value from record for the data element pointed to by `identifiers`.
|
||||
"""
|
||||
await self.driver.clear(*self.identifiers)
|
||||
await self.driver.clear(self.identifier_data)
|
||||
|
||||
|
||||
class Group(Value):
|
||||
@ -178,13 +178,17 @@ class Group(Value):
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, identifiers: Tuple[str], defaults: dict, driver, force_registration: bool = False
|
||||
self,
|
||||
identifier_data: IdentifierData,
|
||||
defaults: dict,
|
||||
driver,
|
||||
force_registration: bool = False,
|
||||
):
|
||||
self._defaults = defaults
|
||||
self.force_registration = force_registration
|
||||
self.driver = driver
|
||||
|
||||
super().__init__(identifiers, {}, self.driver)
|
||||
super().__init__(identifier_data, {}, self.driver)
|
||||
|
||||
@property
|
||||
def defaults(self):
|
||||
@ -225,22 +229,24 @@ class Group(Value):
|
||||
"""
|
||||
is_group = self.is_group(item)
|
||||
is_value = not is_group and self.is_value(item)
|
||||
new_identifiers = self.identifiers + (item,)
|
||||
new_identifiers = self.identifier_data.add_identifier(item)
|
||||
if is_group:
|
||||
return Group(
|
||||
identifiers=new_identifiers,
|
||||
identifier_data=new_identifiers,
|
||||
defaults=self._defaults[item],
|
||||
driver=self.driver,
|
||||
force_registration=self.force_registration,
|
||||
)
|
||||
elif is_value:
|
||||
return Value(
|
||||
identifiers=new_identifiers, default_value=self._defaults[item], driver=self.driver
|
||||
identifier_data=new_identifiers,
|
||||
default_value=self._defaults[item],
|
||||
driver=self.driver,
|
||||
)
|
||||
elif self.force_registration:
|
||||
raise AttributeError("'{}' is not a valid registered Group or value.".format(item))
|
||||
else:
|
||||
return Value(identifiers=new_identifiers, default_value=None, driver=self.driver)
|
||||
return Value(identifier_data=new_identifiers, default_value=None, driver=self.driver)
|
||||
|
||||
async def clear_raw(self, *nested_path: Any):
|
||||
"""
|
||||
@ -262,8 +268,9 @@ class Group(Value):
|
||||
Multiple arguments that mirror the arguments passed in for nested
|
||||
dict access. These are casted to `str` for you.
|
||||
"""
|
||||
path = [str(p) for p in nested_path]
|
||||
await self.driver.clear(*self.identifiers, *path)
|
||||
path = tuple(str(p) for p in nested_path)
|
||||
identifier_data = self.identifier_data.add_identifier(*path)
|
||||
await self.driver.clear(identifier_data)
|
||||
|
||||
def is_group(self, item: Any) -> bool:
|
||||
"""A helper method for `__getattr__`. Most developers will have no need
|
||||
@ -368,7 +375,7 @@ class Group(Value):
|
||||
If the value does not exist yet in Config's internal storage.
|
||||
|
||||
"""
|
||||
path = [str(p) for p in nested_path]
|
||||
path = tuple(str(p) for p in nested_path)
|
||||
|
||||
if default is ...:
|
||||
poss_default = self.defaults
|
||||
@ -380,8 +387,9 @@ class Group(Value):
|
||||
else:
|
||||
default = poss_default
|
||||
|
||||
identifier_data = self.identifier_data.add_identifier(*path)
|
||||
try:
|
||||
raw = await self.driver.get(*self.identifiers, *path)
|
||||
raw = await self.driver.get(identifier_data)
|
||||
except KeyError:
|
||||
if default is not ...:
|
||||
return default
|
||||
@ -456,10 +464,11 @@ class Group(Value):
|
||||
value
|
||||
The value to store.
|
||||
"""
|
||||
path = [str(p) for p in nested_path]
|
||||
path = tuple(str(p) for p in nested_path)
|
||||
identifier_data = self.identifier_data.add_identifier(*path)
|
||||
if isinstance(value, dict):
|
||||
value = _str_key_dict(value)
|
||||
await self.driver.set(*self.identifiers, *path, value=value)
|
||||
await self.driver.set(identifier_data, value=value)
|
||||
|
||||
|
||||
class Config:
|
||||
@ -779,11 +788,17 @@ class Config:
|
||||
"""
|
||||
self._register_default(group_identifier, **kwargs)
|
||||
|
||||
def _get_base_group(self, key: str, *identifiers: str) -> Group:
|
||||
def _get_base_group(self, category: str, *primary_keys: str) -> Group:
|
||||
# noinspection PyTypeChecker
|
||||
identifier_data = IdentifierData(
|
||||
uuid=self.unique_identifier,
|
||||
category=category,
|
||||
primary_key=primary_keys,
|
||||
identifiers=(),
|
||||
)
|
||||
return Group(
|
||||
identifiers=(key, *identifiers),
|
||||
defaults=self.defaults.get(key, {}),
|
||||
identifier_data=identifier_data,
|
||||
defaults=self.defaults.get(category, {}),
|
||||
driver=self.driver,
|
||||
force_registration=self.force_registration,
|
||||
)
|
||||
@ -904,7 +919,7 @@ class Config:
|
||||
ret = {}
|
||||
|
||||
try:
|
||||
dict_ = await self.driver.get(*group.identifiers)
|
||||
dict_ = await self.driver.get(group.identifier_data)
|
||||
except KeyError:
|
||||
pass
|
||||
else:
|
||||
@ -1021,7 +1036,7 @@ class Config:
|
||||
if guild is None:
|
||||
group = self._get_base_group(self.MEMBER)
|
||||
try:
|
||||
dict_ = await self.driver.get(*group.identifiers)
|
||||
dict_ = await self.driver.get(group.identifier_data)
|
||||
except KeyError:
|
||||
pass
|
||||
else:
|
||||
@ -1030,7 +1045,7 @@ class Config:
|
||||
else:
|
||||
group = self._get_base_group(self.MEMBER, str(guild.id))
|
||||
try:
|
||||
guild_data = await self.driver.get(*group.identifiers)
|
||||
guild_data = await self.driver.get(group.identifier_data)
|
||||
except KeyError:
|
||||
pass
|
||||
else:
|
||||
@ -1057,7 +1072,8 @@ class Config:
|
||||
"""
|
||||
if not scopes:
|
||||
# noinspection PyTypeChecker
|
||||
group = Group(identifiers=(), defaults={}, driver=self.driver)
|
||||
identifier_data = IdentifierData(self.unique_identifier, "", (), ())
|
||||
group = Group(identifier_data, defaults={}, driver=self.driver)
|
||||
else:
|
||||
group = self._get_base_group(*scopes)
|
||||
await group.clear()
|
||||
|
||||
@ -1,4 +1,6 @@
|
||||
__all__ = ["get_driver"]
|
||||
from .red_base import IdentifierData
|
||||
|
||||
__all__ = ["get_driver", "IdentifierData"]
|
||||
|
||||
|
||||
def get_driver(type, *args, **kwargs):
|
||||
|
||||
@ -1,4 +1,51 @@
|
||||
__all__ = ["BaseDriver"]
|
||||
from typing import Tuple
|
||||
|
||||
__all__ = ["BaseDriver", "IdentifierData"]
|
||||
|
||||
|
||||
class IdentifierData:
|
||||
def __init__(self, uuid: str, category: str, primary_key: Tuple[str], identifiers: Tuple[str]):
|
||||
self._uuid = uuid
|
||||
self._category = category
|
||||
self._primary_key = primary_key
|
||||
self._identifiers = identifiers
|
||||
|
||||
@property
|
||||
def uuid(self):
|
||||
return self._uuid
|
||||
|
||||
@property
|
||||
def category(self):
|
||||
return self._category
|
||||
|
||||
@property
|
||||
def primary_key(self):
|
||||
return self._primary_key
|
||||
|
||||
@property
|
||||
def identifiers(self):
|
||||
return self._identifiers
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
f"<IdentifierData uuid={self.uuid} category={self.category} primary_key={self.primary_key}"
|
||||
f" identifiers={self.identifiers}>"
|
||||
)
|
||||
|
||||
def add_identifier(self, *identifier: str) -> "IdentifierData":
|
||||
if not all(isinstance(i, str) for i in identifier):
|
||||
raise ValueError("Identifiers must be strings.")
|
||||
|
||||
return IdentifierData(
|
||||
self.uuid, self.category, self.primary_key, self.identifiers + identifier
|
||||
)
|
||||
|
||||
def to_tuple(self):
|
||||
return tuple(
|
||||
item
|
||||
for item in (self.uuid, self.category, *self.primary_key, *self.identifiers)
|
||||
if len(item) > 0
|
||||
)
|
||||
|
||||
|
||||
class BaseDriver:
|
||||
@ -6,14 +53,13 @@ class BaseDriver:
|
||||
self.cog_name = cog_name
|
||||
self.unique_cog_identifier = identifier
|
||||
|
||||
async def get(self, *identifiers: str):
|
||||
async def get(self, identifier_data: IdentifierData):
|
||||
"""
|
||||
Finds the value indicate by the given identifiers.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
identifiers
|
||||
A list of identifiers that correspond to nested dict accesses.
|
||||
identifier_data
|
||||
|
||||
Returns
|
||||
-------
|
||||
@ -33,20 +79,19 @@ class BaseDriver:
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def set(self, *identifiers: str, value=None):
|
||||
async def set(self, identifier_data: IdentifierData, value=None):
|
||||
"""
|
||||
Sets the value of the key indicated by the given identifiers.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
identifiers
|
||||
A list of identifiers that correspond to nested dict accesses.
|
||||
identifier_data
|
||||
value
|
||||
Any JSON serializable python object.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def clear(self, *identifiers: str):
|
||||
async def clear(self, identifier_data: IdentifierData):
|
||||
"""
|
||||
Clears out the value specified by the given identifiers.
|
||||
|
||||
@ -54,7 +99,6 @@ class BaseDriver:
|
||||
|
||||
Parameters
|
||||
----------
|
||||
identifiers
|
||||
A list of identifiers that correspond to nested dict accesses.
|
||||
identifier_data
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
@ -6,7 +6,7 @@ import logging
|
||||
|
||||
from ..json_io import JsonIO
|
||||
|
||||
from .red_base import BaseDriver
|
||||
from .red_base import BaseDriver, IdentifierData
|
||||
|
||||
__all__ = ["JSON"]
|
||||
|
||||
@ -93,16 +93,16 @@ class JSON(BaseDriver):
|
||||
self.data = {}
|
||||
self.jsonIO._save_json(self.data)
|
||||
|
||||
async def get(self, *identifiers: Tuple[str]):
|
||||
async def get(self, identifier_data: IdentifierData):
|
||||
partial = self.data
|
||||
full_identifiers = (self.unique_cog_identifier, *identifiers)
|
||||
full_identifiers = identifier_data.to_tuple()
|
||||
for i in full_identifiers:
|
||||
partial = partial[i]
|
||||
return copy.deepcopy(partial)
|
||||
|
||||
async def set(self, *identifiers: str, value=None):
|
||||
async def set(self, identifier_data: IdentifierData, value=None):
|
||||
partial = self.data
|
||||
full_identifiers = (self.unique_cog_identifier, *identifiers)
|
||||
full_identifiers = identifier_data.to_tuple()
|
||||
for i in full_identifiers[:-1]:
|
||||
if i not in partial:
|
||||
partial[i] = {}
|
||||
@ -111,9 +111,9 @@ class JSON(BaseDriver):
|
||||
partial[full_identifiers[-1]] = copy.deepcopy(value)
|
||||
await self.jsonIO._threadsafe_save_json(self.data)
|
||||
|
||||
async def clear(self, *identifiers: str):
|
||||
async def clear(self, identifier_data: IdentifierData):
|
||||
partial = self.data
|
||||
full_identifiers = (self.unique_cog_identifier, *identifiers)
|
||||
full_identifiers = identifier_data.to_tuple()
|
||||
try:
|
||||
for i in full_identifiers[:-1]:
|
||||
partial = partial[i]
|
||||
|
||||
@ -1,11 +1,12 @@
|
||||
import re
|
||||
from typing import Match, Pattern
|
||||
from typing import Match, Pattern, Tuple
|
||||
from urllib.parse import quote_plus
|
||||
|
||||
import motor.core
|
||||
import motor.motor_asyncio
|
||||
from motor.motor_asyncio import AsyncIOMotorCursor
|
||||
|
||||
from .red_base import BaseDriver
|
||||
from .red_base import BaseDriver, IdentifierData
|
||||
|
||||
__all__ = ["Mongo"]
|
||||
|
||||
@ -64,66 +65,119 @@ class Mongo(BaseDriver):
|
||||
"""
|
||||
return _conn.get_database()
|
||||
|
||||
def get_collection(self) -> motor.core.Collection:
|
||||
def get_collection(self, category: str) -> motor.core.Collection:
|
||||
"""
|
||||
Gets a specified collection within the PyMongo database for this cog.
|
||||
|
||||
Unless you are doing custom stuff ``collection_name`` should be one of the class
|
||||
Unless you are doing custom stuff ``category`` should be one of the class
|
||||
attributes of :py:class:`core.config.Config`.
|
||||
|
||||
:param str collection_name:
|
||||
:param str category:
|
||||
:return:
|
||||
PyMongo collection object.
|
||||
"""
|
||||
return self.db[self.cog_name]
|
||||
return self.db[self.cog_name][category]
|
||||
|
||||
@staticmethod
|
||||
def _parse_identifiers(identifiers):
|
||||
uuid, identifiers = identifiers[0], identifiers[1:]
|
||||
return uuid, identifiers
|
||||
def get_primary_key(self, identifier_data: IdentifierData) -> Tuple[str]:
|
||||
# noinspection PyTypeChecker
|
||||
return identifier_data.primary_key
|
||||
|
||||
async def get(self, *identifiers: str):
|
||||
mongo_collection = self.get_collection()
|
||||
async def rebuild_dataset(self, identifier_data: IdentifierData, cursor: AsyncIOMotorCursor):
|
||||
ret = {}
|
||||
async for doc in cursor:
|
||||
pkeys = doc["_id"]["RED_primary_key"]
|
||||
del doc["_id"]
|
||||
if len(pkeys) == 1:
|
||||
# Global data
|
||||
ret.update(**doc)
|
||||
elif len(pkeys) > 1:
|
||||
# All other data
|
||||
partial = ret
|
||||
for key in pkeys[1:-1]:
|
||||
if key in identifier_data.primary_key:
|
||||
continue
|
||||
if key not in partial:
|
||||
partial[key] = {}
|
||||
partial = partial[key]
|
||||
if pkeys[-1] in identifier_data.primary_key:
|
||||
partial.update(**doc)
|
||||
else:
|
||||
partial[pkeys[-1]] = doc
|
||||
else:
|
||||
raise RuntimeError("This should not happen.")
|
||||
return ret
|
||||
|
||||
identifiers = (*map(self._escape_key, identifiers),)
|
||||
dot_identifiers = ".".join(identifiers)
|
||||
async def get(self, identifier_data: IdentifierData):
|
||||
mongo_collection = self.get_collection(identifier_data.category)
|
||||
|
||||
partial = await mongo_collection.find_one(
|
||||
filter={"_id": self.unique_cog_identifier}, projection={dot_identifiers: True}
|
||||
)
|
||||
pkey_filter = self.generate_primary_key_filter(identifier_data)
|
||||
if len(identifier_data.identifiers) > 0:
|
||||
dot_identifiers = ".".join(map(self._escape_key, identifier_data.identifiers))
|
||||
proj = {"_id": False, dot_identifiers: True}
|
||||
|
||||
partial = await mongo_collection.find_one(filter=pkey_filter, projection=proj)
|
||||
else:
|
||||
# The case here is for partial primary keys like all_members()
|
||||
cursor = mongo_collection.find(filter=pkey_filter)
|
||||
partial = await self.rebuild_dataset(identifier_data, cursor)
|
||||
|
||||
if partial is None:
|
||||
raise KeyError("No matching document was found and Config expects a KeyError.")
|
||||
|
||||
for i in identifiers:
|
||||
for i in identifier_data.identifiers:
|
||||
partial = partial[i]
|
||||
if isinstance(partial, dict):
|
||||
return self._unescape_dict_keys(partial)
|
||||
return partial
|
||||
|
||||
async def set(self, *identifiers: str, value=None):
|
||||
dot_identifiers = ".".join(map(self._escape_key, identifiers))
|
||||
async def set(self, identifier_data: IdentifierData, value=None):
|
||||
uuid = self._escape_key(identifier_data.uuid)
|
||||
primary_key = list(map(self._escape_key, self.get_primary_key(identifier_data)))
|
||||
dot_identifiers = ".".join(map(self._escape_key, identifier_data.identifiers))
|
||||
if isinstance(value, dict):
|
||||
if len(value) == 0:
|
||||
await self.clear(identifier_data)
|
||||
return
|
||||
value = self._escape_dict_keys(value)
|
||||
|
||||
mongo_collection = self.get_collection()
|
||||
mongo_collection = self.get_collection(identifier_data.category)
|
||||
if len(dot_identifiers) > 0:
|
||||
update_stmt = {"$set": {dot_identifiers: value}}
|
||||
else:
|
||||
update_stmt = {"$set": value}
|
||||
|
||||
await mongo_collection.update_one(
|
||||
{"_id": self.unique_cog_identifier},
|
||||
update={"$set": {dot_identifiers: value}},
|
||||
{"_id": {"RED_uuid": uuid, "RED_primary_key": primary_key}},
|
||||
update=update_stmt,
|
||||
upsert=True,
|
||||
)
|
||||
|
||||
async def clear(self, *identifiers: str):
|
||||
dot_identifiers = ".".join(map(self._escape_key, identifiers))
|
||||
mongo_collection = self.get_collection()
|
||||
|
||||
if len(identifiers) > 0:
|
||||
await mongo_collection.update_one(
|
||||
{"_id": self.unique_cog_identifier}, update={"$unset": {dot_identifiers: 1}}
|
||||
)
|
||||
def generate_primary_key_filter(self, identifier_data: IdentifierData):
|
||||
uuid = self._escape_key(identifier_data.uuid)
|
||||
primary_key = list(map(self._escape_key, self.get_primary_key(identifier_data)))
|
||||
ret = {"_id": {"RED_uuid": uuid}}
|
||||
if len(identifier_data.identifiers) > 0:
|
||||
ret["_id"]["RED_primary_key"] = primary_key
|
||||
else:
|
||||
await mongo_collection.delete_one({"_id": self.unique_cog_identifier})
|
||||
for i, key in enumerate(primary_key):
|
||||
keyname = f"RED_primary_key.{i}"
|
||||
ret["_id"][keyname] = key
|
||||
return ret
|
||||
|
||||
async def clear(self, identifier_data: IdentifierData):
|
||||
# There are three cases here:
|
||||
# 1) We're clearing out a subset of identifiers (aka identifiers is NOT empty)
|
||||
# 2) We're clearing out full primary key and no identifiers
|
||||
# 3) We're clearing out partial primary key and no identifiers
|
||||
# 4) Primary key is empty, should wipe all documents in the collection
|
||||
mongo_collection = self.get_collection(identifier_data.category)
|
||||
pkey_filter = self.generate_primary_key_filter(identifier_data)
|
||||
if len(identifier_data.identifiers) == 0:
|
||||
# This covers cases 2-4
|
||||
await mongo_collection.delete_many(pkey_filter)
|
||||
else:
|
||||
dot_identifiers = ".".join(map(self._escape_key, identifier_data.identifiers))
|
||||
await mongo_collection.update_one(pkey_filter, update={"$unset": {dot_identifiers: 1}})
|
||||
|
||||
@staticmethod
|
||||
def _escape_key(key: str) -> str:
|
||||
|
||||
@ -269,8 +269,9 @@ async def edit_instance():
|
||||
default_dirs["STORAGE_DETAILS"] = storage_details
|
||||
|
||||
if instance_data["STORAGE_TYPE"] == "JSON":
|
||||
if confirm("Would you like to import your data? (y/n) "):
|
||||
await json_to_mongo(current_data_dir, storage_details)
|
||||
raise NotImplementedError("We cannot convert from JSON to MongoDB at this time.")
|
||||
# if confirm("Would you like to import your data? (y/n) "):
|
||||
# await json_to_mongo(current_data_dir, storage_details)
|
||||
else:
|
||||
storage_details = instance_data["STORAGE_DETAILS"]
|
||||
default_dirs["STORAGE_DETAILS"] = {}
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user