[V3 Config] Update Mongo document organization to bypass doc size restriction (#2536)

* modify config to use identifier data class and update json driver

* move identifier data attributes into read only properties

* Update mongo get and set methods

* Update get/set to use UUID separately, make clear work

* Remove not implemented and fix get_raw

* Update remaining untouched get/set/clear

* Fix get_raw

* Finally fix get_raw and set_raw

* style

* This is better

* Sorry guys

* Update get behavior to handle "all" calls as expected

* style again

* Why do you do this to me

* style once more

* Update mongo schema
Will, 2019-04-03 09:04:47 -04:00, committed by GitHub
parent d6d6d14977
commit 1cd7e41f33
6 changed files with 194 additions and 77 deletions
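In short (see the diffs below): instead of keeping a cog's entire configuration in a single MongoDB document per cog, which eventually runs into MongoDB's 16 MB per-document limit, data is now spread across one document per primary key inside a per-category collection, keyed by a compound `_id` of the cog's UUID and the primary key. A rough before/after sketch of the document shape; the UUID, guild ID, member ID, and settings below are made-up placeholders, not values from this commit:

# Hypothetical illustration only; IDs and settings are placeholders.

# Old layout: a single document per cog, with every category nested inside it.
old_document = {
    "_id": "354269633059405824",  # the cog's unique identifier
    "GUILD": {"133049272517001216": {"prefix": ["!"]}},
    "MEMBER": {"133049272517001216": {"95932766180343808": {"points": 5}}},
}

# New layout: a collection per category, one document per primary key, so no
# single document grows without bound.
new_guild_document = {
    "_id": {
        "RED_uuid": "354269633059405824",
        "RED_primary_key": ["133049272517001216"],
    },
    "prefix": ["!"],
}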

View File

@@ -6,7 +6,7 @@ from typing import Any, Union, Tuple, Dict, Awaitable, AsyncContextManager, Type

 import discord

 from .data_manager import cog_data_path, core_data_path
-from .drivers import get_driver
+from .drivers import get_driver, IdentifierData

 if TYPE_CHECKING:
     from .drivers.red_base import BaseDriver
@@ -72,14 +72,14 @@ class Value:
     """

-    def __init__(self, identifiers: Tuple[str], default_value, driver):
-        self.identifiers = identifiers
+    def __init__(self, identifier_data: IdentifierData, default_value, driver):
+        self.identifier_data = identifier_data
         self.default = default_value
         self.driver = driver

     async def _get(self, default=...):
         try:
-            ret = await self.driver.get(*self.identifiers)
+            ret = await self.driver.get(self.identifier_data)
         except KeyError:
             return default if default is not ... else self.default
         return ret
@@ -150,13 +150,13 @@ class Value:
         """
         if isinstance(value, dict):
             value = _str_key_dict(value)
-        await self.driver.set(*self.identifiers, value=value)
+        await self.driver.set(self.identifier_data, value=value)

     async def clear(self):
         """
         Clears the value from record for the data element pointed to by `identifiers`.
         """
-        await self.driver.clear(*self.identifiers)
+        await self.driver.clear(self.identifier_data)


 class Group(Value):
@@ -178,13 +178,17 @@ class Group(Value):
     """

     def __init__(
-        self, identifiers: Tuple[str], defaults: dict, driver, force_registration: bool = False
+        self,
+        identifier_data: IdentifierData,
+        defaults: dict,
+        driver,
+        force_registration: bool = False,
     ):
         self._defaults = defaults
         self.force_registration = force_registration
         self.driver = driver

-        super().__init__(identifiers, {}, self.driver)
+        super().__init__(identifier_data, {}, self.driver)

     @property
     def defaults(self):
@@ -225,22 +229,24 @@ class Group(Value):
         """
         is_group = self.is_group(item)
         is_value = not is_group and self.is_value(item)
-        new_identifiers = self.identifiers + (item,)
+        new_identifiers = self.identifier_data.add_identifier(item)
         if is_group:
             return Group(
-                identifiers=new_identifiers,
+                identifier_data=new_identifiers,
                 defaults=self._defaults[item],
                 driver=self.driver,
                 force_registration=self.force_registration,
             )
         elif is_value:
             return Value(
-                identifiers=new_identifiers, default_value=self._defaults[item], driver=self.driver
+                identifier_data=new_identifiers,
+                default_value=self._defaults[item],
+                driver=self.driver,
             )
         elif self.force_registration:
             raise AttributeError("'{}' is not a valid registered Group or value.".format(item))
         else:
-            return Value(identifiers=new_identifiers, default_value=None, driver=self.driver)
+            return Value(identifier_data=new_identifiers, default_value=None, driver=self.driver)

     async def clear_raw(self, *nested_path: Any):
         """
@@ -262,8 +268,9 @@ class Group(Value):
             Multiple arguments that mirror the arguments passed in for nested
             dict access. These are casted to `str` for you.
         """
-        path = [str(p) for p in nested_path]
-        await self.driver.clear(*self.identifiers, *path)
+        path = tuple(str(p) for p in nested_path)
+        identifier_data = self.identifier_data.add_identifier(*path)
+        await self.driver.clear(identifier_data)

     def is_group(self, item: Any) -> bool:
         """A helper method for `__getattr__`. Most developers will have no need
@@ -368,7 +375,7 @@ class Group(Value):
             If the value does not exist yet in Config's internal storage.
         """
-        path = [str(p) for p in nested_path]
+        path = tuple(str(p) for p in nested_path)

         if default is ...:
             poss_default = self.defaults
@@ -380,8 +387,9 @@
             else:
                 default = poss_default

+        identifier_data = self.identifier_data.add_identifier(*path)
         try:
-            raw = await self.driver.get(*self.identifiers, *path)
+            raw = await self.driver.get(identifier_data)
         except KeyError:
             if default is not ...:
                 return default
@@ -456,10 +464,11 @@ class Group(Value):
         value
             The value to store.
         """
-        path = [str(p) for p in nested_path]
+        path = tuple(str(p) for p in nested_path)
+        identifier_data = self.identifier_data.add_identifier(*path)
         if isinstance(value, dict):
             value = _str_key_dict(value)
-        await self.driver.set(*self.identifiers, *path, value=value)
+        await self.driver.set(identifier_data, value=value)


 class Config:
@@ -779,11 +788,17 @@ class Config:
         """
         self._register_default(group_identifier, **kwargs)

-    def _get_base_group(self, key: str, *identifiers: str) -> Group:
+    def _get_base_group(self, category: str, *primary_keys: str) -> Group:
         # noinspection PyTypeChecker
+        identifier_data = IdentifierData(
+            uuid=self.unique_identifier,
+            category=category,
+            primary_key=primary_keys,
+            identifiers=(),
+        )
         return Group(
-            identifiers=(key, *identifiers),
-            defaults=self.defaults.get(key, {}),
+            identifier_data=identifier_data,
+            defaults=self.defaults.get(category, {}),
             driver=self.driver,
             force_registration=self.force_registration,
         )
@@ -904,7 +919,7 @@ class Config:
         ret = {}

         try:
-            dict_ = await self.driver.get(*group.identifiers)
+            dict_ = await self.driver.get(group.identifier_data)
         except KeyError:
             pass
         else:
@@ -1021,7 +1036,7 @@ class Config:
         if guild is None:
             group = self._get_base_group(self.MEMBER)
             try:
-                dict_ = await self.driver.get(*group.identifiers)
+                dict_ = await self.driver.get(group.identifier_data)
             except KeyError:
                 pass
             else:
@@ -1030,7 +1045,7 @@ class Config:
         else:
             group = self._get_base_group(self.MEMBER, str(guild.id))
             try:
-                guild_data = await self.driver.get(*group.identifiers)
+                guild_data = await self.driver.get(group.identifier_data)
             except KeyError:
                 pass
             else:
@@ -1057,7 +1072,8 @@ class Config:
         """
         if not scopes:
             # noinspection PyTypeChecker
-            group = Group(identifiers=(), defaults={}, driver=self.driver)
+            identifier_data = IdentifierData(self.unique_identifier, "", (), ())
+            group = Group(identifier_data, defaults={}, driver=self.driver)
         else:
             group = self._get_base_group(*scopes)
         await group.clear()

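The config.py changes above boil down to this: every Group and Value now carries an IdentifierData describing the cog's UUID, the category, the primary keys, and the remaining identifiers, rather than a flat tuple of strings. A minimal sketch of what `_get_base_group` and attribute access now build, assuming the drivers package is importable as `redbot.core.drivers` and using made-up IDs:

from redbot.core.drivers import IdentifierData  # assumed import path

# Roughly what Config._get_base_group("GUILD", "133049272517001216") now builds:
base = IdentifierData(
    uuid="354269633059405824",          # Config.unique_identifier
    category="GUILD",
    primary_key=("133049272517001216",),
    identifiers=(),
)

# Attribute access such as conf.guild(guild).prefix no longer extends a flat
# tuple; it appends to the identifiers of a new IdentifierData:
prefix_value_ident = base.add_identifier("prefix")
print(prefix_value_ident)
# <IdentifierData uuid=354269633059405824 category=GUILD
#  primary_key=('133049272517001216',) identifiers=('prefix',)>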
View File

@@ -1,4 +1,6 @@
-__all__ = ["get_driver"]
+from .red_base import IdentifierData
+
+__all__ = ["get_driver", "IdentifierData"]


 def get_driver(type, *args, **kwargs):

View File

@@ -1,4 +1,51 @@
-__all__ = ["BaseDriver"]
+from typing import Tuple
+
+__all__ = ["BaseDriver", "IdentifierData"]
+
+
+class IdentifierData:
+    def __init__(self, uuid: str, category: str, primary_key: Tuple[str], identifiers: Tuple[str]):
+        self._uuid = uuid
+        self._category = category
+        self._primary_key = primary_key
+        self._identifiers = identifiers
+
+    @property
+    def uuid(self):
+        return self._uuid
+
+    @property
+    def category(self):
+        return self._category
+
+    @property
+    def primary_key(self):
+        return self._primary_key
+
+    @property
+    def identifiers(self):
+        return self._identifiers
+
+    def __repr__(self):
+        return (
+            f"<IdentifierData uuid={self.uuid} category={self.category} primary_key={self.primary_key}"
+            f" identifiers={self.identifiers}>"
+        )
+
+    def add_identifier(self, *identifier: str) -> "IdentifierData":
+        if not all(isinstance(i, str) for i in identifier):
+            raise ValueError("Identifiers must be strings.")
+
+        return IdentifierData(
+            self.uuid, self.category, self.primary_key, self.identifiers + identifier
+        )
+
+    def to_tuple(self):
+        return tuple(
+            item
+            for item in (self.uuid, self.category, *self.primary_key, *self.identifiers)
+            if len(item) > 0
+        )


 class BaseDriver:
@@ -6,14 +53,13 @@ class BaseDriver:
         self.cog_name = cog_name
         self.unique_cog_identifier = identifier

-    async def get(self, *identifiers: str):
+    async def get(self, identifier_data: IdentifierData):
         """
         Finds the value indicate by the given identifiers.

         Parameters
         ----------
-        identifiers
-            A list of identifiers that correspond to nested dict accesses.
+        identifier_data

         Returns
         -------
@@ -33,20 +79,19 @@ class BaseDriver:
         """
         raise NotImplementedError

-    async def set(self, *identifiers: str, value=None):
+    async def set(self, identifier_data: IdentifierData, value=None):
         """
         Sets the value of the key indicated by the given identifiers.

         Parameters
         ----------
-        identifiers
-            A list of identifiers that correspond to nested dict accesses.
+        identifier_data
         value
             Any JSON serializable python object.
         """
         raise NotImplementedError

-    async def clear(self, *identifiers: str):
+    async def clear(self, identifier_data: IdentifierData):
         """
         Clears out the value specified by the given identifiers.
@@ -54,7 +99,6 @@ class BaseDriver:

         Parameters
         ----------
-        identifiers
-            A list of identifiers that correspond to nested dict accesses.
+        identifier_data
         """
         raise NotImplementedError

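Two behaviors of the new IdentifierData class that the drivers rely on, shown as a small hedged sketch (same assumed import path and placeholder UUID as above): add_identifier accepts only strings and returns a new object, and to_tuple skips empty components, which matters for global data where the primary key is empty:

from redbot.core.drivers import IdentifierData  # assumed import path

ident = IdentifierData(
    uuid="354269633059405824", category="GLOBAL", primary_key=(), identifiers=()
)

# add_identifier() rejects non-string parts and returns a new IdentifierData:
global_setting = ident.add_identifier("schema_version")
# ident.add_identifier(123) would raise ValueError("Identifiers must be strings.")

# to_tuple() flattens uuid/category/primary_key/identifiers; the empty primary
# key contributes nothing and empty strings are dropped:
print(global_setting.to_tuple())
# ('354269633059405824', 'GLOBAL', 'schema_version')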
View File

@@ -6,7 +6,7 @@ import logging

 from ..json_io import JsonIO

-from .red_base import BaseDriver
+from .red_base import BaseDriver, IdentifierData

 __all__ = ["JSON"]
@@ -93,16 +93,16 @@ class JSON(BaseDriver):
         self.data = {}
         self.jsonIO._save_json(self.data)

-    async def get(self, *identifiers: Tuple[str]):
+    async def get(self, identifier_data: IdentifierData):
         partial = self.data
-        full_identifiers = (self.unique_cog_identifier, *identifiers)
+        full_identifiers = identifier_data.to_tuple()
         for i in full_identifiers:
             partial = partial[i]
         return copy.deepcopy(partial)

-    async def set(self, *identifiers: str, value=None):
+    async def set(self, identifier_data: IdentifierData, value=None):
         partial = self.data
-        full_identifiers = (self.unique_cog_identifier, *identifiers)
+        full_identifiers = identifier_data.to_tuple()
         for i in full_identifiers[:-1]:
             if i not in partial:
                 partial[i] = {}
@@ -111,9 +111,9 @@ class JSON(BaseDriver):
         partial[full_identifiers[-1]] = copy.deepcopy(value)
         await self.jsonIO._threadsafe_save_json(self.data)

-    async def clear(self, *identifiers: str):
+    async def clear(self, identifier_data: IdentifierData):
         partial = self.data
-        full_identifiers = (self.unique_cog_identifier, *identifiers)
+        full_identifiers = identifier_data.to_tuple()
         try:
             for i in full_identifiers[:-1]:
                 partial = partial[i]

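For the JSON driver the same IdentifierData simply becomes a path of nested dictionary keys via to_tuple(). A hedged sketch of the walk that JSON.set() performs, with placeholder IDs (the real driver operates on self.data and persists through jsonIO):

from redbot.core.drivers import IdentifierData  # assumed import path

ident = IdentifierData(
    uuid="354269633059405824",
    category="MEMBER",
    primary_key=("133049272517001216", "95932766180343808"),  # guild id, member id
    identifiers=("points",),
)

data = {}  # stands in for JSON.data
partial = data
full_identifiers = ident.to_tuple()
for key in full_identifiers[:-1]:  # mirrors the loop in JSON.set()
    partial = partial.setdefault(key, {})
partial[full_identifiers[-1]] = 5

# data is now nested uuid -> category -> guild id -> member id -> setting:
# {'354269633059405824': {'MEMBER': {'133049272517001216':
#     {'95932766180343808': {'points': 5}}}}}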
View File

@@ -1,11 +1,12 @@
 import re
-from typing import Match, Pattern
+from typing import Match, Pattern, Tuple
 from urllib.parse import quote_plus

 import motor.core
 import motor.motor_asyncio
+from motor.motor_asyncio import AsyncIOMotorCursor

-from .red_base import BaseDriver
+from .red_base import BaseDriver, IdentifierData

 __all__ = ["Mongo"]
@@ -64,66 +65,119 @@
         """
         return _conn.get_database()

-    def get_collection(self) -> motor.core.Collection:
+    def get_collection(self, category: str) -> motor.core.Collection:
         """
         Gets a specified collection within the PyMongo database for this cog.

-        Unless you are doing custom stuff ``collection_name`` should be one of the class
+        Unless you are doing custom stuff ``category`` should be one of the class
         attributes of :py:class:`core.config.Config`.

-        :param str collection_name:
+        :param str category:
         :return:
             PyMongo collection object.
         """
-        return self.db[self.cog_name]
+        return self.db[self.cog_name][category]

-    @staticmethod
-    def _parse_identifiers(identifiers):
-        uuid, identifiers = identifiers[0], identifiers[1:]
-        return uuid, identifiers
-
-    async def get(self, *identifiers: str):
-        mongo_collection = self.get_collection()
-
-        identifiers = (*map(self._escape_key, identifiers),)
-        dot_identifiers = ".".join(identifiers)
-
-        partial = await mongo_collection.find_one(
-            filter={"_id": self.unique_cog_identifier}, projection={dot_identifiers: True}
-        )
-
+    def get_primary_key(self, identifier_data: IdentifierData) -> Tuple[str]:
+        # noinspection PyTypeChecker
+        return identifier_data.primary_key
+
+    async def rebuild_dataset(self, identifier_data: IdentifierData, cursor: AsyncIOMotorCursor):
+        ret = {}
+        async for doc in cursor:
+            pkeys = doc["_id"]["RED_primary_key"]
+            del doc["_id"]
+            if len(pkeys) == 1:
+                # Global data
+                ret.update(**doc)
+            elif len(pkeys) > 1:
+                # All other data
+                partial = ret
+                for key in pkeys[1:-1]:
+                    if key in identifier_data.primary_key:
+                        continue
+                    if key not in partial:
+                        partial[key] = {}
+                    partial = partial[key]
+                if pkeys[-1] in identifier_data.primary_key:
+                    partial.update(**doc)
+                else:
+                    partial[pkeys[-1]] = doc
+            else:
+                raise RuntimeError("This should not happen.")
+        return ret
+
+    async def get(self, identifier_data: IdentifierData):
+        mongo_collection = self.get_collection(identifier_data.category)
+
+        pkey_filter = self.generate_primary_key_filter(identifier_data)
+        if len(identifier_data.identifiers) > 0:
+            dot_identifiers = ".".join(map(self._escape_key, identifier_data.identifiers))
+            proj = {"_id": False, dot_identifiers: True}
+            partial = await mongo_collection.find_one(filter=pkey_filter, projection=proj)
+        else:
+            # The case here is for partial primary keys like all_members()
+            cursor = mongo_collection.find(filter=pkey_filter)
+            partial = await self.rebuild_dataset(identifier_data, cursor)
+
         if partial is None:
             raise KeyError("No matching document was found and Config expects a KeyError.")

-        for i in identifiers:
+        for i in identifier_data.identifiers:
             partial = partial[i]
         if isinstance(partial, dict):
             return self._unescape_dict_keys(partial)
         return partial

-    async def set(self, *identifiers: str, value=None):
-        dot_identifiers = ".".join(map(self._escape_key, identifiers))
+    async def set(self, identifier_data: IdentifierData, value=None):
+        uuid = self._escape_key(identifier_data.uuid)
+        primary_key = list(map(self._escape_key, self.get_primary_key(identifier_data)))
+        dot_identifiers = ".".join(map(self._escape_key, identifier_data.identifiers))
         if isinstance(value, dict):
+            if len(value) == 0:
+                await self.clear(identifier_data)
+                return
             value = self._escape_dict_keys(value)
-
-        mongo_collection = self.get_collection()
-
+        mongo_collection = self.get_collection(identifier_data.category)
+        if len(dot_identifiers) > 0:
+            update_stmt = {"$set": {dot_identifiers: value}}
+        else:
+            update_stmt = {"$set": value}
+
         await mongo_collection.update_one(
-            {"_id": self.unique_cog_identifier},
-            update={"$set": {dot_identifiers: value}},
+            {"_id": {"RED_uuid": uuid, "RED_primary_key": primary_key}},
+            update=update_stmt,
             upsert=True,
         )

-    async def clear(self, *identifiers: str):
-        dot_identifiers = ".".join(map(self._escape_key, identifiers))
-        mongo_collection = self.get_collection()
-
-        if len(identifiers) > 0:
-            await mongo_collection.update_one(
-                {"_id": self.unique_cog_identifier}, update={"$unset": {dot_identifiers: 1}}
-            )
-        else:
-            await mongo_collection.delete_one({"_id": self.unique_cog_identifier})
+    def generate_primary_key_filter(self, identifier_data: IdentifierData):
+        uuid = self._escape_key(identifier_data.uuid)
+        primary_key = list(map(self._escape_key, self.get_primary_key(identifier_data)))
+        ret = {"_id": {"RED_uuid": uuid}}
+        if len(identifier_data.identifiers) > 0:
+            ret["_id"]["RED_primary_key"] = primary_key
+        else:
+            for i, key in enumerate(primary_key):
+                keyname = f"RED_primary_key.{i}"
+                ret["_id"][keyname] = key
+        return ret
+
+    async def clear(self, identifier_data: IdentifierData):
+        # There are three cases here:
+        # 1) We're clearing out a subset of identifiers (aka identifiers is NOT empty)
+        # 2) We're clearing out full primary key and no identifiers
+        # 3) We're clearing out partial primary key and no identifiers
+        # 4) Primary key is empty, should wipe all documents in the collection
+        mongo_collection = self.get_collection(identifier_data.category)
+        pkey_filter = self.generate_primary_key_filter(identifier_data)
+        if len(identifier_data.identifiers) == 0:
+            # This covers cases 2-4
+            await mongo_collection.delete_many(pkey_filter)
+        else:
+            dot_identifiers = ".".join(map(self._escape_key, identifier_data.identifiers))
+            await mongo_collection.update_one(pkey_filter, update={"$unset": {dot_identifiers: 1}})

     @staticmethod
     def _escape_key(key: str) -> str:

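The filters built by generate_primary_key_filter() above are intended to support partial lookups such as all_members(): with identifiers present the compound _id is matched as a whole, while with only a partial primary key and no identifiers the known key parts are addressed positionally. Hypothetical filter shapes, using the same placeholder IDs as earlier:

# Hypothetical shapes only; compare with the method body above.

# Identifiers present (a single member's setting): match the whole compound _id.
member_setting_filter = {
    "_id": {
        "RED_uuid": "354269633059405824",
        "RED_primary_key": ["133049272517001216", "95932766180343808"],
    }
}

# No identifiers and only a partial primary key (e.g. all_members(guild)):
# each known key element is addressed by its position instead.
all_members_filter = {
    "_id": {
        "RED_uuid": "354269633059405824",
        "RED_primary_key.0": "133049272517001216",
    }
}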
View File

@@ -269,8 +269,9 @@ async def edit_instance():
         default_dirs["STORAGE_DETAILS"] = storage_details

         if instance_data["STORAGE_TYPE"] == "JSON":
-            if confirm("Would you like to import your data? (y/n) "):
-                await json_to_mongo(current_data_dir, storage_details)
+            raise NotImplementedError("We cannot convert from JSON to MongoDB at this time.")
+            # if confirm("Would you like to import your data? (y/n) "):
+            #     await json_to_mongo(current_data_dir, storage_details)
     else:
         storage_details = instance_data["STORAGE_DETAILS"]
         default_dirs["STORAGE_DETAILS"] = {}