mirror of
https://github.com/Cog-Creators/Red-DiscordBot.git
synced 2025-11-06 11:18:54 -05:00
[V3 Config] Update Mongo document organization to bypass doc size restriction (#2536)
* modify config to use identifier data class and update json driver * move identifier data attributes into read only properties * Update mongo get and set methods * Update get/set to use UUID separately, make clear work * Remove not implemented and fix get_raw * Update remaining untouched get/set/clear * Fix get_raw * Finally fix get_raw and set_raw * style * This is better * Sorry guys * Update get behavior to handle "all" calls as expected * style again * Why do you do this to me * style once more * Update mongo schema
This commit is contained in:
parent
d6d6d14977
commit
1cd7e41f33
@ -6,7 +6,7 @@ from typing import Any, Union, Tuple, Dict, Awaitable, AsyncContextManager, Type
|
|||||||
import discord
|
import discord
|
||||||
|
|
||||||
from .data_manager import cog_data_path, core_data_path
|
from .data_manager import cog_data_path, core_data_path
|
||||||
from .drivers import get_driver
|
from .drivers import get_driver, IdentifierData
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from .drivers.red_base import BaseDriver
|
from .drivers.red_base import BaseDriver
|
||||||
@ -72,14 +72,14 @@ class Value:
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, identifiers: Tuple[str], default_value, driver):
|
def __init__(self, identifier_data: IdentifierData, default_value, driver):
|
||||||
self.identifiers = identifiers
|
self.identifier_data = identifier_data
|
||||||
self.default = default_value
|
self.default = default_value
|
||||||
self.driver = driver
|
self.driver = driver
|
||||||
|
|
||||||
async def _get(self, default=...):
|
async def _get(self, default=...):
|
||||||
try:
|
try:
|
||||||
ret = await self.driver.get(*self.identifiers)
|
ret = await self.driver.get(self.identifier_data)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
return default if default is not ... else self.default
|
return default if default is not ... else self.default
|
||||||
return ret
|
return ret
|
||||||
@ -150,13 +150,13 @@ class Value:
|
|||||||
"""
|
"""
|
||||||
if isinstance(value, dict):
|
if isinstance(value, dict):
|
||||||
value = _str_key_dict(value)
|
value = _str_key_dict(value)
|
||||||
await self.driver.set(*self.identifiers, value=value)
|
await self.driver.set(self.identifier_data, value=value)
|
||||||
|
|
||||||
async def clear(self):
|
async def clear(self):
|
||||||
"""
|
"""
|
||||||
Clears the value from record for the data element pointed to by `identifiers`.
|
Clears the value from record for the data element pointed to by `identifiers`.
|
||||||
"""
|
"""
|
||||||
await self.driver.clear(*self.identifiers)
|
await self.driver.clear(self.identifier_data)
|
||||||
|
|
||||||
|
|
||||||
class Group(Value):
|
class Group(Value):
|
||||||
@ -178,13 +178,17 @@ class Group(Value):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, identifiers: Tuple[str], defaults: dict, driver, force_registration: bool = False
|
self,
|
||||||
|
identifier_data: IdentifierData,
|
||||||
|
defaults: dict,
|
||||||
|
driver,
|
||||||
|
force_registration: bool = False,
|
||||||
):
|
):
|
||||||
self._defaults = defaults
|
self._defaults = defaults
|
||||||
self.force_registration = force_registration
|
self.force_registration = force_registration
|
||||||
self.driver = driver
|
self.driver = driver
|
||||||
|
|
||||||
super().__init__(identifiers, {}, self.driver)
|
super().__init__(identifier_data, {}, self.driver)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def defaults(self):
|
def defaults(self):
|
||||||
@ -225,22 +229,24 @@ class Group(Value):
|
|||||||
"""
|
"""
|
||||||
is_group = self.is_group(item)
|
is_group = self.is_group(item)
|
||||||
is_value = not is_group and self.is_value(item)
|
is_value = not is_group and self.is_value(item)
|
||||||
new_identifiers = self.identifiers + (item,)
|
new_identifiers = self.identifier_data.add_identifier(item)
|
||||||
if is_group:
|
if is_group:
|
||||||
return Group(
|
return Group(
|
||||||
identifiers=new_identifiers,
|
identifier_data=new_identifiers,
|
||||||
defaults=self._defaults[item],
|
defaults=self._defaults[item],
|
||||||
driver=self.driver,
|
driver=self.driver,
|
||||||
force_registration=self.force_registration,
|
force_registration=self.force_registration,
|
||||||
)
|
)
|
||||||
elif is_value:
|
elif is_value:
|
||||||
return Value(
|
return Value(
|
||||||
identifiers=new_identifiers, default_value=self._defaults[item], driver=self.driver
|
identifier_data=new_identifiers,
|
||||||
|
default_value=self._defaults[item],
|
||||||
|
driver=self.driver,
|
||||||
)
|
)
|
||||||
elif self.force_registration:
|
elif self.force_registration:
|
||||||
raise AttributeError("'{}' is not a valid registered Group or value.".format(item))
|
raise AttributeError("'{}' is not a valid registered Group or value.".format(item))
|
||||||
else:
|
else:
|
||||||
return Value(identifiers=new_identifiers, default_value=None, driver=self.driver)
|
return Value(identifier_data=new_identifiers, default_value=None, driver=self.driver)
|
||||||
|
|
||||||
async def clear_raw(self, *nested_path: Any):
|
async def clear_raw(self, *nested_path: Any):
|
||||||
"""
|
"""
|
||||||
@ -262,8 +268,9 @@ class Group(Value):
|
|||||||
Multiple arguments that mirror the arguments passed in for nested
|
Multiple arguments that mirror the arguments passed in for nested
|
||||||
dict access. These are casted to `str` for you.
|
dict access. These are casted to `str` for you.
|
||||||
"""
|
"""
|
||||||
path = [str(p) for p in nested_path]
|
path = tuple(str(p) for p in nested_path)
|
||||||
await self.driver.clear(*self.identifiers, *path)
|
identifier_data = self.identifier_data.add_identifier(*path)
|
||||||
|
await self.driver.clear(identifier_data)
|
||||||
|
|
||||||
def is_group(self, item: Any) -> bool:
|
def is_group(self, item: Any) -> bool:
|
||||||
"""A helper method for `__getattr__`. Most developers will have no need
|
"""A helper method for `__getattr__`. Most developers will have no need
|
||||||
@ -368,7 +375,7 @@ class Group(Value):
|
|||||||
If the value does not exist yet in Config's internal storage.
|
If the value does not exist yet in Config's internal storage.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
path = [str(p) for p in nested_path]
|
path = tuple(str(p) for p in nested_path)
|
||||||
|
|
||||||
if default is ...:
|
if default is ...:
|
||||||
poss_default = self.defaults
|
poss_default = self.defaults
|
||||||
@ -380,8 +387,9 @@ class Group(Value):
|
|||||||
else:
|
else:
|
||||||
default = poss_default
|
default = poss_default
|
||||||
|
|
||||||
|
identifier_data = self.identifier_data.add_identifier(*path)
|
||||||
try:
|
try:
|
||||||
raw = await self.driver.get(*self.identifiers, *path)
|
raw = await self.driver.get(identifier_data)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
if default is not ...:
|
if default is not ...:
|
||||||
return default
|
return default
|
||||||
@ -456,10 +464,11 @@ class Group(Value):
|
|||||||
value
|
value
|
||||||
The value to store.
|
The value to store.
|
||||||
"""
|
"""
|
||||||
path = [str(p) for p in nested_path]
|
path = tuple(str(p) for p in nested_path)
|
||||||
|
identifier_data = self.identifier_data.add_identifier(*path)
|
||||||
if isinstance(value, dict):
|
if isinstance(value, dict):
|
||||||
value = _str_key_dict(value)
|
value = _str_key_dict(value)
|
||||||
await self.driver.set(*self.identifiers, *path, value=value)
|
await self.driver.set(identifier_data, value=value)
|
||||||
|
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
@ -779,11 +788,17 @@ class Config:
|
|||||||
"""
|
"""
|
||||||
self._register_default(group_identifier, **kwargs)
|
self._register_default(group_identifier, **kwargs)
|
||||||
|
|
||||||
def _get_base_group(self, key: str, *identifiers: str) -> Group:
|
def _get_base_group(self, category: str, *primary_keys: str) -> Group:
|
||||||
# noinspection PyTypeChecker
|
# noinspection PyTypeChecker
|
||||||
|
identifier_data = IdentifierData(
|
||||||
|
uuid=self.unique_identifier,
|
||||||
|
category=category,
|
||||||
|
primary_key=primary_keys,
|
||||||
|
identifiers=(),
|
||||||
|
)
|
||||||
return Group(
|
return Group(
|
||||||
identifiers=(key, *identifiers),
|
identifier_data=identifier_data,
|
||||||
defaults=self.defaults.get(key, {}),
|
defaults=self.defaults.get(category, {}),
|
||||||
driver=self.driver,
|
driver=self.driver,
|
||||||
force_registration=self.force_registration,
|
force_registration=self.force_registration,
|
||||||
)
|
)
|
||||||
@ -904,7 +919,7 @@ class Config:
|
|||||||
ret = {}
|
ret = {}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
dict_ = await self.driver.get(*group.identifiers)
|
dict_ = await self.driver.get(group.identifier_data)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
@ -1021,7 +1036,7 @@ class Config:
|
|||||||
if guild is None:
|
if guild is None:
|
||||||
group = self._get_base_group(self.MEMBER)
|
group = self._get_base_group(self.MEMBER)
|
||||||
try:
|
try:
|
||||||
dict_ = await self.driver.get(*group.identifiers)
|
dict_ = await self.driver.get(group.identifier_data)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
@ -1030,7 +1045,7 @@ class Config:
|
|||||||
else:
|
else:
|
||||||
group = self._get_base_group(self.MEMBER, str(guild.id))
|
group = self._get_base_group(self.MEMBER, str(guild.id))
|
||||||
try:
|
try:
|
||||||
guild_data = await self.driver.get(*group.identifiers)
|
guild_data = await self.driver.get(group.identifier_data)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
@ -1057,7 +1072,8 @@ class Config:
|
|||||||
"""
|
"""
|
||||||
if not scopes:
|
if not scopes:
|
||||||
# noinspection PyTypeChecker
|
# noinspection PyTypeChecker
|
||||||
group = Group(identifiers=(), defaults={}, driver=self.driver)
|
identifier_data = IdentifierData(self.unique_identifier, "", (), ())
|
||||||
|
group = Group(identifier_data, defaults={}, driver=self.driver)
|
||||||
else:
|
else:
|
||||||
group = self._get_base_group(*scopes)
|
group = self._get_base_group(*scopes)
|
||||||
await group.clear()
|
await group.clear()
|
||||||
|
|||||||
@ -1,4 +1,6 @@
|
|||||||
__all__ = ["get_driver"]
|
from .red_base import IdentifierData
|
||||||
|
|
||||||
|
__all__ = ["get_driver", "IdentifierData"]
|
||||||
|
|
||||||
|
|
||||||
def get_driver(type, *args, **kwargs):
|
def get_driver(type, *args, **kwargs):
|
||||||
|
|||||||
@ -1,4 +1,51 @@
|
|||||||
__all__ = ["BaseDriver"]
|
from typing import Tuple
|
||||||
|
|
||||||
|
__all__ = ["BaseDriver", "IdentifierData"]
|
||||||
|
|
||||||
|
|
||||||
|
class IdentifierData:
|
||||||
|
def __init__(self, uuid: str, category: str, primary_key: Tuple[str], identifiers: Tuple[str]):
|
||||||
|
self._uuid = uuid
|
||||||
|
self._category = category
|
||||||
|
self._primary_key = primary_key
|
||||||
|
self._identifiers = identifiers
|
||||||
|
|
||||||
|
@property
|
||||||
|
def uuid(self):
|
||||||
|
return self._uuid
|
||||||
|
|
||||||
|
@property
|
||||||
|
def category(self):
|
||||||
|
return self._category
|
||||||
|
|
||||||
|
@property
|
||||||
|
def primary_key(self):
|
||||||
|
return self._primary_key
|
||||||
|
|
||||||
|
@property
|
||||||
|
def identifiers(self):
|
||||||
|
return self._identifiers
|
||||||
|
|
||||||
|
def __repr__(self):
|
||||||
|
return (
|
||||||
|
f"<IdentifierData uuid={self.uuid} category={self.category} primary_key={self.primary_key}"
|
||||||
|
f" identifiers={self.identifiers}>"
|
||||||
|
)
|
||||||
|
|
||||||
|
def add_identifier(self, *identifier: str) -> "IdentifierData":
|
||||||
|
if not all(isinstance(i, str) for i in identifier):
|
||||||
|
raise ValueError("Identifiers must be strings.")
|
||||||
|
|
||||||
|
return IdentifierData(
|
||||||
|
self.uuid, self.category, self.primary_key, self.identifiers + identifier
|
||||||
|
)
|
||||||
|
|
||||||
|
def to_tuple(self):
|
||||||
|
return tuple(
|
||||||
|
item
|
||||||
|
for item in (self.uuid, self.category, *self.primary_key, *self.identifiers)
|
||||||
|
if len(item) > 0
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class BaseDriver:
|
class BaseDriver:
|
||||||
@ -6,14 +53,13 @@ class BaseDriver:
|
|||||||
self.cog_name = cog_name
|
self.cog_name = cog_name
|
||||||
self.unique_cog_identifier = identifier
|
self.unique_cog_identifier = identifier
|
||||||
|
|
||||||
async def get(self, *identifiers: str):
|
async def get(self, identifier_data: IdentifierData):
|
||||||
"""
|
"""
|
||||||
Finds the value indicate by the given identifiers.
|
Finds the value indicate by the given identifiers.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
identifiers
|
identifier_data
|
||||||
A list of identifiers that correspond to nested dict accesses.
|
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
@ -33,20 +79,19 @@ class BaseDriver:
|
|||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
async def set(self, *identifiers: str, value=None):
|
async def set(self, identifier_data: IdentifierData, value=None):
|
||||||
"""
|
"""
|
||||||
Sets the value of the key indicated by the given identifiers.
|
Sets the value of the key indicated by the given identifiers.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
identifiers
|
identifier_data
|
||||||
A list of identifiers that correspond to nested dict accesses.
|
|
||||||
value
|
value
|
||||||
Any JSON serializable python object.
|
Any JSON serializable python object.
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
async def clear(self, *identifiers: str):
|
async def clear(self, identifier_data: IdentifierData):
|
||||||
"""
|
"""
|
||||||
Clears out the value specified by the given identifiers.
|
Clears out the value specified by the given identifiers.
|
||||||
|
|
||||||
@ -54,7 +99,6 @@ class BaseDriver:
|
|||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
identifiers
|
identifier_data
|
||||||
A list of identifiers that correspond to nested dict accesses.
|
|
||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|||||||
@ -6,7 +6,7 @@ import logging
|
|||||||
|
|
||||||
from ..json_io import JsonIO
|
from ..json_io import JsonIO
|
||||||
|
|
||||||
from .red_base import BaseDriver
|
from .red_base import BaseDriver, IdentifierData
|
||||||
|
|
||||||
__all__ = ["JSON"]
|
__all__ = ["JSON"]
|
||||||
|
|
||||||
@ -93,16 +93,16 @@ class JSON(BaseDriver):
|
|||||||
self.data = {}
|
self.data = {}
|
||||||
self.jsonIO._save_json(self.data)
|
self.jsonIO._save_json(self.data)
|
||||||
|
|
||||||
async def get(self, *identifiers: Tuple[str]):
|
async def get(self, identifier_data: IdentifierData):
|
||||||
partial = self.data
|
partial = self.data
|
||||||
full_identifiers = (self.unique_cog_identifier, *identifiers)
|
full_identifiers = identifier_data.to_tuple()
|
||||||
for i in full_identifiers:
|
for i in full_identifiers:
|
||||||
partial = partial[i]
|
partial = partial[i]
|
||||||
return copy.deepcopy(partial)
|
return copy.deepcopy(partial)
|
||||||
|
|
||||||
async def set(self, *identifiers: str, value=None):
|
async def set(self, identifier_data: IdentifierData, value=None):
|
||||||
partial = self.data
|
partial = self.data
|
||||||
full_identifiers = (self.unique_cog_identifier, *identifiers)
|
full_identifiers = identifier_data.to_tuple()
|
||||||
for i in full_identifiers[:-1]:
|
for i in full_identifiers[:-1]:
|
||||||
if i not in partial:
|
if i not in partial:
|
||||||
partial[i] = {}
|
partial[i] = {}
|
||||||
@ -111,9 +111,9 @@ class JSON(BaseDriver):
|
|||||||
partial[full_identifiers[-1]] = copy.deepcopy(value)
|
partial[full_identifiers[-1]] = copy.deepcopy(value)
|
||||||
await self.jsonIO._threadsafe_save_json(self.data)
|
await self.jsonIO._threadsafe_save_json(self.data)
|
||||||
|
|
||||||
async def clear(self, *identifiers: str):
|
async def clear(self, identifier_data: IdentifierData):
|
||||||
partial = self.data
|
partial = self.data
|
||||||
full_identifiers = (self.unique_cog_identifier, *identifiers)
|
full_identifiers = identifier_data.to_tuple()
|
||||||
try:
|
try:
|
||||||
for i in full_identifiers[:-1]:
|
for i in full_identifiers[:-1]:
|
||||||
partial = partial[i]
|
partial = partial[i]
|
||||||
|
|||||||
@ -1,11 +1,12 @@
|
|||||||
import re
|
import re
|
||||||
from typing import Match, Pattern
|
from typing import Match, Pattern, Tuple
|
||||||
from urllib.parse import quote_plus
|
from urllib.parse import quote_plus
|
||||||
|
|
||||||
import motor.core
|
import motor.core
|
||||||
import motor.motor_asyncio
|
import motor.motor_asyncio
|
||||||
|
from motor.motor_asyncio import AsyncIOMotorCursor
|
||||||
|
|
||||||
from .red_base import BaseDriver
|
from .red_base import BaseDriver, IdentifierData
|
||||||
|
|
||||||
__all__ = ["Mongo"]
|
__all__ = ["Mongo"]
|
||||||
|
|
||||||
@ -64,66 +65,119 @@ class Mongo(BaseDriver):
|
|||||||
"""
|
"""
|
||||||
return _conn.get_database()
|
return _conn.get_database()
|
||||||
|
|
||||||
def get_collection(self) -> motor.core.Collection:
|
def get_collection(self, category: str) -> motor.core.Collection:
|
||||||
"""
|
"""
|
||||||
Gets a specified collection within the PyMongo database for this cog.
|
Gets a specified collection within the PyMongo database for this cog.
|
||||||
|
|
||||||
Unless you are doing custom stuff ``collection_name`` should be one of the class
|
Unless you are doing custom stuff ``category`` should be one of the class
|
||||||
attributes of :py:class:`core.config.Config`.
|
attributes of :py:class:`core.config.Config`.
|
||||||
|
|
||||||
:param str collection_name:
|
:param str category:
|
||||||
:return:
|
:return:
|
||||||
PyMongo collection object.
|
PyMongo collection object.
|
||||||
"""
|
"""
|
||||||
return self.db[self.cog_name]
|
return self.db[self.cog_name][category]
|
||||||
|
|
||||||
@staticmethod
|
def get_primary_key(self, identifier_data: IdentifierData) -> Tuple[str]:
|
||||||
def _parse_identifiers(identifiers):
|
# noinspection PyTypeChecker
|
||||||
uuid, identifiers = identifiers[0], identifiers[1:]
|
return identifier_data.primary_key
|
||||||
return uuid, identifiers
|
|
||||||
|
|
||||||
async def get(self, *identifiers: str):
|
async def rebuild_dataset(self, identifier_data: IdentifierData, cursor: AsyncIOMotorCursor):
|
||||||
mongo_collection = self.get_collection()
|
ret = {}
|
||||||
|
async for doc in cursor:
|
||||||
|
pkeys = doc["_id"]["RED_primary_key"]
|
||||||
|
del doc["_id"]
|
||||||
|
if len(pkeys) == 1:
|
||||||
|
# Global data
|
||||||
|
ret.update(**doc)
|
||||||
|
elif len(pkeys) > 1:
|
||||||
|
# All other data
|
||||||
|
partial = ret
|
||||||
|
for key in pkeys[1:-1]:
|
||||||
|
if key in identifier_data.primary_key:
|
||||||
|
continue
|
||||||
|
if key not in partial:
|
||||||
|
partial[key] = {}
|
||||||
|
partial = partial[key]
|
||||||
|
if pkeys[-1] in identifier_data.primary_key:
|
||||||
|
partial.update(**doc)
|
||||||
|
else:
|
||||||
|
partial[pkeys[-1]] = doc
|
||||||
|
else:
|
||||||
|
raise RuntimeError("This should not happen.")
|
||||||
|
return ret
|
||||||
|
|
||||||
identifiers = (*map(self._escape_key, identifiers),)
|
async def get(self, identifier_data: IdentifierData):
|
||||||
dot_identifiers = ".".join(identifiers)
|
mongo_collection = self.get_collection(identifier_data.category)
|
||||||
|
|
||||||
partial = await mongo_collection.find_one(
|
pkey_filter = self.generate_primary_key_filter(identifier_data)
|
||||||
filter={"_id": self.unique_cog_identifier}, projection={dot_identifiers: True}
|
if len(identifier_data.identifiers) > 0:
|
||||||
)
|
dot_identifiers = ".".join(map(self._escape_key, identifier_data.identifiers))
|
||||||
|
proj = {"_id": False, dot_identifiers: True}
|
||||||
|
|
||||||
|
partial = await mongo_collection.find_one(filter=pkey_filter, projection=proj)
|
||||||
|
else:
|
||||||
|
# The case here is for partial primary keys like all_members()
|
||||||
|
cursor = mongo_collection.find(filter=pkey_filter)
|
||||||
|
partial = await self.rebuild_dataset(identifier_data, cursor)
|
||||||
|
|
||||||
if partial is None:
|
if partial is None:
|
||||||
raise KeyError("No matching document was found and Config expects a KeyError.")
|
raise KeyError("No matching document was found and Config expects a KeyError.")
|
||||||
|
|
||||||
for i in identifiers:
|
for i in identifier_data.identifiers:
|
||||||
partial = partial[i]
|
partial = partial[i]
|
||||||
if isinstance(partial, dict):
|
if isinstance(partial, dict):
|
||||||
return self._unescape_dict_keys(partial)
|
return self._unescape_dict_keys(partial)
|
||||||
return partial
|
return partial
|
||||||
|
|
||||||
async def set(self, *identifiers: str, value=None):
|
async def set(self, identifier_data: IdentifierData, value=None):
|
||||||
dot_identifiers = ".".join(map(self._escape_key, identifiers))
|
uuid = self._escape_key(identifier_data.uuid)
|
||||||
|
primary_key = list(map(self._escape_key, self.get_primary_key(identifier_data)))
|
||||||
|
dot_identifiers = ".".join(map(self._escape_key, identifier_data.identifiers))
|
||||||
if isinstance(value, dict):
|
if isinstance(value, dict):
|
||||||
|
if len(value) == 0:
|
||||||
|
await self.clear(identifier_data)
|
||||||
|
return
|
||||||
value = self._escape_dict_keys(value)
|
value = self._escape_dict_keys(value)
|
||||||
|
|
||||||
mongo_collection = self.get_collection()
|
mongo_collection = self.get_collection(identifier_data.category)
|
||||||
|
if len(dot_identifiers) > 0:
|
||||||
|
update_stmt = {"$set": {dot_identifiers: value}}
|
||||||
|
else:
|
||||||
|
update_stmt = {"$set": value}
|
||||||
|
|
||||||
await mongo_collection.update_one(
|
await mongo_collection.update_one(
|
||||||
{"_id": self.unique_cog_identifier},
|
{"_id": {"RED_uuid": uuid, "RED_primary_key": primary_key}},
|
||||||
update={"$set": {dot_identifiers: value}},
|
update=update_stmt,
|
||||||
upsert=True,
|
upsert=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
async def clear(self, *identifiers: str):
|
def generate_primary_key_filter(self, identifier_data: IdentifierData):
|
||||||
dot_identifiers = ".".join(map(self._escape_key, identifiers))
|
uuid = self._escape_key(identifier_data.uuid)
|
||||||
mongo_collection = self.get_collection()
|
primary_key = list(map(self._escape_key, self.get_primary_key(identifier_data)))
|
||||||
|
ret = {"_id": {"RED_uuid": uuid}}
|
||||||
if len(identifiers) > 0:
|
if len(identifier_data.identifiers) > 0:
|
||||||
await mongo_collection.update_one(
|
ret["_id"]["RED_primary_key"] = primary_key
|
||||||
{"_id": self.unique_cog_identifier}, update={"$unset": {dot_identifiers: 1}}
|
|
||||||
)
|
|
||||||
else:
|
else:
|
||||||
await mongo_collection.delete_one({"_id": self.unique_cog_identifier})
|
for i, key in enumerate(primary_key):
|
||||||
|
keyname = f"RED_primary_key.{i}"
|
||||||
|
ret["_id"][keyname] = key
|
||||||
|
return ret
|
||||||
|
|
||||||
|
async def clear(self, identifier_data: IdentifierData):
|
||||||
|
# There are three cases here:
|
||||||
|
# 1) We're clearing out a subset of identifiers (aka identifiers is NOT empty)
|
||||||
|
# 2) We're clearing out full primary key and no identifiers
|
||||||
|
# 3) We're clearing out partial primary key and no identifiers
|
||||||
|
# 4) Primary key is empty, should wipe all documents in the collection
|
||||||
|
mongo_collection = self.get_collection(identifier_data.category)
|
||||||
|
pkey_filter = self.generate_primary_key_filter(identifier_data)
|
||||||
|
if len(identifier_data.identifiers) == 0:
|
||||||
|
# This covers cases 2-4
|
||||||
|
await mongo_collection.delete_many(pkey_filter)
|
||||||
|
else:
|
||||||
|
dot_identifiers = ".".join(map(self._escape_key, identifier_data.identifiers))
|
||||||
|
await mongo_collection.update_one(pkey_filter, update={"$unset": {dot_identifiers: 1}})
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _escape_key(key: str) -> str:
|
def _escape_key(key: str) -> str:
|
||||||
|
|||||||
@ -269,8 +269,9 @@ async def edit_instance():
|
|||||||
default_dirs["STORAGE_DETAILS"] = storage_details
|
default_dirs["STORAGE_DETAILS"] = storage_details
|
||||||
|
|
||||||
if instance_data["STORAGE_TYPE"] == "JSON":
|
if instance_data["STORAGE_TYPE"] == "JSON":
|
||||||
if confirm("Would you like to import your data? (y/n) "):
|
raise NotImplementedError("We cannot convert from JSON to MongoDB at this time.")
|
||||||
await json_to_mongo(current_data_dir, storage_details)
|
# if confirm("Would you like to import your data? (y/n) "):
|
||||||
|
# await json_to_mongo(current_data_dir, storage_details)
|
||||||
else:
|
else:
|
||||||
storage_details = instance_data["STORAGE_DETAILS"]
|
storage_details = instance_data["STORAGE_DETAILS"]
|
||||||
default_dirs["STORAGE_DETAILS"] = {}
|
default_dirs["STORAGE_DETAILS"] = {}
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user