mirror of
https://github.com/Cog-Creators/Red-DiscordBot.git
synced 2025-11-07 11:48:55 -05:00
Various Config and Mongo Driver fixes (#2795)
- Fixes defaults being mixed into custom groups above the document level when doing `Group.all()` - Fixes `Config.clear_all()` with Mongo driver - Fixes `Group.set()` with Mongo driver on custom groups above the document level - Fixes `IdentifierData.custom_group_data` being set to the wrong thing in `BaseDriver.import/export_data` (although this was an inconsequential bug) Signed-off-by: Toby Harradine <tobyharradine@gmail.com>
This commit is contained in:
parent
6ae3040aac
commit
71d0bd0d07
@ -1,16 +1,14 @@
|
|||||||
import logging
|
import logging
|
||||||
import collections
|
import collections
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from typing import Any, Union, Tuple, Dict, Awaitable, AsyncContextManager, TypeVar, TYPE_CHECKING
|
from typing import Any, Union, Tuple, Dict, Awaitable, AsyncContextManager, TypeVar
|
||||||
import weakref
|
import weakref
|
||||||
|
|
||||||
import discord
|
import discord
|
||||||
|
|
||||||
from .data_manager import cog_data_path, core_data_path
|
from .data_manager import cog_data_path, core_data_path
|
||||||
from .drivers import get_driver, IdentifierData, BackendType
|
from .drivers import get_driver, IdentifierData, BackendType
|
||||||
|
from .drivers.red_base import BaseDriver
|
||||||
if TYPE_CHECKING:
|
|
||||||
from .drivers.red_base import BaseDriver
|
|
||||||
|
|
||||||
__all__ = ["Config", "get_latest_confs"]
|
__all__ = ["Config", "get_latest_confs"]
|
||||||
|
|
||||||
@ -545,7 +543,7 @@ class Config:
|
|||||||
self,
|
self,
|
||||||
cog_name: str,
|
cog_name: str,
|
||||||
unique_identifier: str,
|
unique_identifier: str,
|
||||||
driver: "BaseDriver",
|
driver: BaseDriver,
|
||||||
force_registration: bool = False,
|
force_registration: bool = False,
|
||||||
defaults: dict = None,
|
defaults: dict = None,
|
||||||
):
|
):
|
||||||
@ -852,9 +850,16 @@ class Config:
|
|||||||
custom_group_data=self.custom_groups,
|
custom_group_data=self.custom_groups,
|
||||||
is_custom=is_custom,
|
is_custom=is_custom,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
pkey_len = BaseDriver.get_pkey_len(identifier_data)
|
||||||
|
if len(primary_keys) < pkey_len:
|
||||||
|
# Don't mix in defaults with groups higher than the document level
|
||||||
|
defaults = {}
|
||||||
|
else:
|
||||||
|
defaults = self.defaults.get(category, {})
|
||||||
return Group(
|
return Group(
|
||||||
identifier_data=identifier_data,
|
identifier_data=identifier_data,
|
||||||
defaults=self.defaults.get(category, {}),
|
defaults=defaults,
|
||||||
driver=self.driver,
|
driver=self.driver,
|
||||||
force_registration=self.force_registration,
|
force_registration=self.force_registration,
|
||||||
)
|
)
|
||||||
@ -975,6 +980,7 @@ class Config:
|
|||||||
"""
|
"""
|
||||||
group = self._get_base_group(scope)
|
group = self._get_base_group(scope)
|
||||||
ret = {}
|
ret = {}
|
||||||
|
defaults = self.defaults.get(scope, {})
|
||||||
|
|
||||||
try:
|
try:
|
||||||
dict_ = await self.driver.get(group.identifier_data)
|
dict_ = await self.driver.get(group.identifier_data)
|
||||||
@ -982,7 +988,7 @@ class Config:
|
|||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
for k, v in dict_.items():
|
for k, v in dict_.items():
|
||||||
data = group.defaults
|
data = deepcopy(defaults)
|
||||||
data.update(v)
|
data.update(v)
|
||||||
ret[int(k)] = data
|
ret[int(k)] = data
|
||||||
|
|
||||||
@ -1056,11 +1062,11 @@ class Config:
|
|||||||
"""
|
"""
|
||||||
return await self._all_from_scope(self.USER)
|
return await self._all_from_scope(self.USER)
|
||||||
|
|
||||||
@staticmethod
|
def _all_members_from_guild(self, guild_data: dict) -> dict:
|
||||||
def _all_members_from_guild(group: Group, guild_data: dict) -> dict:
|
|
||||||
ret = {}
|
ret = {}
|
||||||
|
defaults = self.defaults.get(self.MEMBER, {})
|
||||||
for member_id, member_data in guild_data.items():
|
for member_id, member_data in guild_data.items():
|
||||||
new_member_data = group.defaults
|
new_member_data = deepcopy(defaults)
|
||||||
new_member_data.update(member_data)
|
new_member_data.update(member_data)
|
||||||
ret[int(member_id)] = new_member_data
|
ret[int(member_id)] = new_member_data
|
||||||
return ret
|
return ret
|
||||||
@ -1099,7 +1105,7 @@ class Config:
|
|||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
for guild_id, guild_data in dict_.items():
|
for guild_id, guild_data in dict_.items():
|
||||||
ret[int(guild_id)] = self._all_members_from_guild(group, guild_data)
|
ret[int(guild_id)] = self._all_members_from_guild(guild_data)
|
||||||
else:
|
else:
|
||||||
group = self._get_base_group(self.MEMBER, str(guild.id))
|
group = self._get_base_group(self.MEMBER, str(guild.id))
|
||||||
try:
|
try:
|
||||||
@ -1107,7 +1113,7 @@ class Config:
|
|||||||
except KeyError:
|
except KeyError:
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
ret = self._all_members_from_guild(group, guild_data)
|
ret = self._all_members_from_guild(guild_data)
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
async def _clear_scope(self, *scopes: str):
|
async def _clear_scope(self, *scopes: str):
|
||||||
|
|||||||
@ -18,8 +18,8 @@ class IdentifierData:
|
|||||||
self,
|
self,
|
||||||
uuid: str,
|
uuid: str,
|
||||||
category: str,
|
category: str,
|
||||||
primary_key: Tuple[str],
|
primary_key: Tuple[str, ...],
|
||||||
identifiers: Tuple[str],
|
identifiers: Tuple[str, ...],
|
||||||
custom_group_data: dict,
|
custom_group_data: dict,
|
||||||
is_custom: bool = False,
|
is_custom: bool = False,
|
||||||
):
|
):
|
||||||
@ -183,7 +183,7 @@ class BaseDriver:
|
|||||||
c,
|
c,
|
||||||
(),
|
(),
|
||||||
(),
|
(),
|
||||||
custom_group_data.get(c, {}),
|
custom_group_data,
|
||||||
is_custom=c in custom_group_data,
|
is_custom=c in custom_group_data,
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
@ -202,7 +202,19 @@ class BaseDriver:
|
|||||||
category,
|
category,
|
||||||
pkey,
|
pkey,
|
||||||
(),
|
(),
|
||||||
custom_group_data.get(category, {}),
|
custom_group_data,
|
||||||
is_custom=category in custom_group_data,
|
is_custom=category in custom_group_data,
|
||||||
)
|
)
|
||||||
await self.set(ident_data, data)
|
await self.set(ident_data, data)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def get_pkey_len(identifier_data: IdentifierData) -> int:
|
||||||
|
cat = identifier_data.category
|
||||||
|
if cat == ConfigCategory.GLOBAL.value:
|
||||||
|
return 0
|
||||||
|
elif cat == ConfigCategory.MEMBER.value:
|
||||||
|
return 2
|
||||||
|
elif identifier_data.is_custom:
|
||||||
|
return identifier_data.custom_group_data[cat]
|
||||||
|
else:
|
||||||
|
return 1
|
||||||
|
|||||||
@ -1,5 +1,4 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Tuple
|
|
||||||
import copy
|
import copy
|
||||||
import weakref
|
import weakref
|
||||||
import logging
|
import logging
|
||||||
@ -156,7 +155,7 @@ class JSON(BaseDriver):
|
|||||||
category,
|
category,
|
||||||
pkey,
|
pkey,
|
||||||
(),
|
(),
|
||||||
custom_group_data.get(category, {}),
|
custom_group_data,
|
||||||
is_custom=category in custom_group_data,
|
is_custom=category in custom_group_data,
|
||||||
)
|
)
|
||||||
update_write_data(ident_data, data)
|
update_write_data(ident_data, data)
|
||||||
|
|||||||
@ -1,11 +1,13 @@
|
|||||||
|
import contextlib
|
||||||
|
import itertools
|
||||||
import re
|
import re
|
||||||
from getpass import getpass
|
from getpass import getpass
|
||||||
from typing import Match, Pattern, Tuple
|
from typing import Match, Pattern, Tuple, Any, Dict, Iterator, List
|
||||||
from urllib.parse import quote_plus
|
from urllib.parse import quote_plus
|
||||||
|
|
||||||
import motor.core
|
import motor.core
|
||||||
import motor.motor_asyncio
|
import motor.motor_asyncio
|
||||||
from motor.motor_asyncio import AsyncIOMotorCursor
|
import pymongo.errors
|
||||||
|
|
||||||
from .red_base import BaseDriver, IdentifierData
|
from .red_base import BaseDriver, IdentifierData
|
||||||
|
|
||||||
@ -36,7 +38,7 @@ def _initialize(**kwargs):
|
|||||||
url = "{}://{}{}/{}".format(uri, host, ports, db_name)
|
url = "{}://{}{}/{}".format(uri, host, ports, db_name)
|
||||||
|
|
||||||
global _conn
|
global _conn
|
||||||
_conn = motor.motor_asyncio.AsyncIOMotorClient(url)
|
_conn = motor.motor_asyncio.AsyncIOMotorClient(url, retryWrites=True)
|
||||||
|
|
||||||
|
|
||||||
class Mongo(BaseDriver):
|
class Mongo(BaseDriver):
|
||||||
@ -87,7 +89,9 @@ class Mongo(BaseDriver):
|
|||||||
# noinspection PyTypeChecker
|
# noinspection PyTypeChecker
|
||||||
return identifier_data.primary_key
|
return identifier_data.primary_key
|
||||||
|
|
||||||
async def rebuild_dataset(self, identifier_data: IdentifierData, cursor: AsyncIOMotorCursor):
|
async def rebuild_dataset(
|
||||||
|
self, identifier_data: IdentifierData, cursor: motor.motor_asyncio.AsyncIOMotorCursor
|
||||||
|
):
|
||||||
ret = {}
|
ret = {}
|
||||||
async for doc in cursor:
|
async for doc in cursor:
|
||||||
pkeys = doc["_id"]["RED_primary_key"]
|
pkeys = doc["_id"]["RED_primary_key"]
|
||||||
@ -137,24 +141,96 @@ class Mongo(BaseDriver):
|
|||||||
async def set(self, identifier_data: IdentifierData, value=None):
|
async def set(self, identifier_data: IdentifierData, value=None):
|
||||||
uuid = self._escape_key(identifier_data.uuid)
|
uuid = self._escape_key(identifier_data.uuid)
|
||||||
primary_key = list(map(self._escape_key, self.get_primary_key(identifier_data)))
|
primary_key = list(map(self._escape_key, self.get_primary_key(identifier_data)))
|
||||||
dot_identifiers = ".".join(map(self._escape_key, identifier_data.identifiers))
|
|
||||||
if isinstance(value, dict):
|
if isinstance(value, dict):
|
||||||
if len(value) == 0:
|
if len(value) == 0:
|
||||||
await self.clear(identifier_data)
|
await self.clear(identifier_data)
|
||||||
return
|
return
|
||||||
value = self._escape_dict_keys(value)
|
value = self._escape_dict_keys(value)
|
||||||
|
|
||||||
mongo_collection = self.get_collection(identifier_data.category)
|
mongo_collection = self.get_collection(identifier_data.category)
|
||||||
if len(dot_identifiers) > 0:
|
pkey_len = self.get_pkey_len(identifier_data)
|
||||||
update_stmt = {"$set": {dot_identifiers: value}}
|
num_pkeys = len(primary_key)
|
||||||
else:
|
|
||||||
update_stmt = {"$set": value}
|
|
||||||
|
|
||||||
await mongo_collection.update_one(
|
if num_pkeys >= pkey_len:
|
||||||
{"_id": {"RED_uuid": uuid, "RED_primary_key": primary_key}},
|
# We're setting at the document level or below.
|
||||||
update=update_stmt,
|
dot_identifiers = ".".join(map(self._escape_key, identifier_data.identifiers))
|
||||||
upsert=True,
|
if dot_identifiers:
|
||||||
)
|
update_stmt = {"$set": {dot_identifiers: value}}
|
||||||
|
else:
|
||||||
|
update_stmt = {"$set": value}
|
||||||
|
|
||||||
|
await mongo_collection.update_one(
|
||||||
|
{"_id": {"RED_uuid": uuid, "RED_primary_key": primary_key}},
|
||||||
|
update=update_stmt,
|
||||||
|
upsert=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
else:
|
||||||
|
# We're setting above the document level.
|
||||||
|
# Easiest and most efficient thing to do is delete all documents that we're potentially
|
||||||
|
# replacing, then insert_many().
|
||||||
|
# We'll do it in a transaction so we can roll-back in case something goes horribly
|
||||||
|
# wrong.
|
||||||
|
pkey_filter = self.generate_primary_key_filter(identifier_data)
|
||||||
|
async with await _conn.start_session() as session:
|
||||||
|
with contextlib.suppress(pymongo.errors.CollectionInvalid):
|
||||||
|
# Collections must already exist when inserting documents within a transaction
|
||||||
|
await _conn.get_database().create_collection(mongo_collection.full_name)
|
||||||
|
try:
|
||||||
|
async with session.start_transaction():
|
||||||
|
await mongo_collection.delete_many(pkey_filter, session=session)
|
||||||
|
await mongo_collection.insert_many(
|
||||||
|
self.generate_documents_to_insert(uuid, primary_key, value, pkey_len),
|
||||||
|
session=session,
|
||||||
|
)
|
||||||
|
except pymongo.errors.OperationFailure:
|
||||||
|
# This DB version / setup doesn't support transactions, so we'll have to use
|
||||||
|
# a shittier method.
|
||||||
|
|
||||||
|
# The strategy here is to separate the existing documents and the new documents
|
||||||
|
# into ones to be deleted, ones to be replaced, and new ones to be inserted.
|
||||||
|
# Then we can do a bulk_write().
|
||||||
|
|
||||||
|
# This is our list of (filter, new_document) tuples for replacing existing
|
||||||
|
# documents. The `new_document` should be taken and removed from `value`, so
|
||||||
|
# `value` only ends up containing documents which need to be inserted.
|
||||||
|
to_replace: List[Tuple[Dict, Dict]] = []
|
||||||
|
|
||||||
|
# This is our list of primary key filters which need deleting. They should
|
||||||
|
# simply be all the primary keys which were part of existing documents but are
|
||||||
|
# not included in the new documents.
|
||||||
|
to_delete: List[Dict] = []
|
||||||
|
async for document in mongo_collection.find(pkey_filter, session=session):
|
||||||
|
pkey = document["_id"]["RED_primary_key"]
|
||||||
|
new_document = value
|
||||||
|
try:
|
||||||
|
for pkey_part in pkey[num_pkeys:-1]:
|
||||||
|
new_document = new_document[pkey_part]
|
||||||
|
# This document is being replaced - remove it from `value`.
|
||||||
|
new_document = new_document.pop(pkey[-1])
|
||||||
|
except KeyError:
|
||||||
|
# We've found the primary key of an old document which isn't in the
|
||||||
|
# updated set of documents - it should be deleted.
|
||||||
|
to_delete.append({"_id": {"RED_uuid": uuid, "RED_primary_key": pkey}})
|
||||||
|
else:
|
||||||
|
_filter = {"_id": {"RED_uuid": uuid, "RED_primary_key": pkey}}
|
||||||
|
new_document.update(_filter)
|
||||||
|
to_replace.append((_filter, new_document))
|
||||||
|
|
||||||
|
# What's left of `value` should be the new documents needing to be inserted.
|
||||||
|
to_insert = self.generate_documents_to_insert(
|
||||||
|
uuid, primary_key, value, pkey_len
|
||||||
|
)
|
||||||
|
requests = list(
|
||||||
|
itertools.chain(
|
||||||
|
(pymongo.DeleteOne(f) for f in to_delete),
|
||||||
|
(pymongo.ReplaceOne(f, d) for f, d in to_replace),
|
||||||
|
(pymongo.InsertOne(d) for d in to_insert if d),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
# This will pipeline the operations so they all complete quickly. However if
|
||||||
|
# any of them fail, the rest of them will complete - i.e. this operation is not
|
||||||
|
# atomic.
|
||||||
|
await mongo_collection.bulk_write(requests, ordered=False)
|
||||||
|
|
||||||
def generate_primary_key_filter(self, identifier_data: IdentifierData):
|
def generate_primary_key_filter(self, identifier_data: IdentifierData):
|
||||||
uuid = self._escape_key(identifier_data.uuid)
|
uuid = self._escape_key(identifier_data.uuid)
|
||||||
@ -170,20 +246,48 @@ class Mongo(BaseDriver):
|
|||||||
ret["_id.RED_primary_key"] = {"$exists": True}
|
ret["_id.RED_primary_key"] = {"$exists": True}
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def generate_documents_to_insert(
|
||||||
|
cls, uuid: str, primary_keys: List[str], data: Dict[str, Dict[str, Any]], pkey_len: int
|
||||||
|
) -> Iterator[Dict[str, Any]]:
|
||||||
|
num_missing_pkeys = pkey_len - len(primary_keys)
|
||||||
|
if num_missing_pkeys == 1:
|
||||||
|
for pkey, document in data.items():
|
||||||
|
document["_id"] = {"RED_uuid": uuid, "RED_primary_key": primary_keys + [pkey]}
|
||||||
|
yield document
|
||||||
|
else:
|
||||||
|
for pkey, inner_data in data.items():
|
||||||
|
for document in cls.generate_documents_to_insert(
|
||||||
|
uuid, primary_keys + [pkey], inner_data, pkey_len
|
||||||
|
):
|
||||||
|
yield document
|
||||||
|
|
||||||
async def clear(self, identifier_data: IdentifierData):
|
async def clear(self, identifier_data: IdentifierData):
|
||||||
# There are three cases here:
|
# There are five cases here:
|
||||||
# 1) We're clearing out a subset of identifiers (aka identifiers is NOT empty)
|
# 1) We're clearing out a subset of identifiers (aka identifiers is NOT empty)
|
||||||
# 2) We're clearing out full primary key and no identifiers
|
# 2) We're clearing out full primary key and no identifiers
|
||||||
# 3) We're clearing out partial primary key and no identifiers
|
# 3) We're clearing out partial primary key and no identifiers
|
||||||
# 4) Primary key is empty, should wipe all documents in the collection
|
# 4) Primary key is empty, should wipe all documents in the collection
|
||||||
mongo_collection = self.get_collection(identifier_data.category)
|
# 5) Category is empty, all of this cog's data should be deleted
|
||||||
pkey_filter = self.generate_primary_key_filter(identifier_data)
|
pkey_filter = self.generate_primary_key_filter(identifier_data)
|
||||||
if len(identifier_data.identifiers) == 0:
|
if identifier_data.identifiers:
|
||||||
# This covers cases 2-4
|
# This covers case 1
|
||||||
await mongo_collection.delete_many(pkey_filter)
|
mongo_collection = self.get_collection(identifier_data.category)
|
||||||
else:
|
|
||||||
dot_identifiers = ".".join(map(self._escape_key, identifier_data.identifiers))
|
dot_identifiers = ".".join(map(self._escape_key, identifier_data.identifiers))
|
||||||
await mongo_collection.update_one(pkey_filter, update={"$unset": {dot_identifiers: 1}})
|
await mongo_collection.update_one(pkey_filter, update={"$unset": {dot_identifiers: 1}})
|
||||||
|
elif identifier_data.category:
|
||||||
|
# This covers cases 2-4
|
||||||
|
mongo_collection = self.get_collection(identifier_data.category)
|
||||||
|
await mongo_collection.delete_many(pkey_filter)
|
||||||
|
else:
|
||||||
|
# This covers case 5
|
||||||
|
db = self.db
|
||||||
|
super_collection = db[self.cog_name]
|
||||||
|
results = await db.list_collections(
|
||||||
|
filter={"name": {"$regex": rf"^{super_collection.name}\."}}
|
||||||
|
)
|
||||||
|
for result in results:
|
||||||
|
await db[result["name"]].delete_many(pkey_filter)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _escape_key(key: str) -> str:
|
def _escape_key(key: str) -> str:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user