Various Config and Mongo Driver fixes (#2795)

- Fixes defaults being mixed into custom groups above the document level when doing `Group.all()`
- Fixes `Config.clear_all()` with Mongo driver
- Fixes `Group.set()` with Mongo driver on custom groups above the document level
- Fixes `IdentifierData.custom_group_data` being set to the wrong thing in `BaseDriver.import/export_data` (although this was an inconsequential bug)

Signed-off-by: Toby Harradine <tobyharradine@gmail.com>
Author: Toby Harradine
Date:   2019-06-24 12:55:49 +10:00 (committed by GitHub)
Parent: 6ae3040aac
Commit: 71d0bd0d07
4 changed files with 160 additions and 39 deletions
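
For context on the first fix listed above, a rough usage sketch (hedged: the cog name, custom group, and IDs below are illustrative, not taken from this commit). A custom group registered with two primary keys stores one Mongo/JSON document per key pair, so `Group.all()` called one level above the document should return only stored data, without the registered defaults merged into every entry:

    # Sketch only - "ExampleCog", "quotes" and the IDs are made up.
    config = Config.get_conf(None, identifier=1234567890, cog_name="ExampleCog")
    config.init_custom("quotes", 2)  # documents keyed by two primary keys
    config.register_custom("quotes", text=None)

    async def demo():
        # Document level: registered defaults apply as usual.
        text = await config.custom("quotes", "123", "0").text()  # None

        # One level above the document: all() should now contain only the
        # stored documents, with no {"text": None} defaults mixed in.
        all_quotes = await config.custom("quotes", "123").all()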

redbot/core/config.py

@@ -1,16 +1,14 @@
 import logging
 import collections
 from copy import deepcopy
-from typing import Any, Union, Tuple, Dict, Awaitable, AsyncContextManager, TypeVar, TYPE_CHECKING
+from typing import Any, Union, Tuple, Dict, Awaitable, AsyncContextManager, TypeVar
 import weakref
 
 import discord
 
 from .data_manager import cog_data_path, core_data_path
 from .drivers import get_driver, IdentifierData, BackendType
+from .drivers.red_base import BaseDriver
 
-if TYPE_CHECKING:
-    from .drivers.red_base import BaseDriver
 
 __all__ = ["Config", "get_latest_confs"]
@@ -545,7 +543,7 @@ class Config:
         self,
         cog_name: str,
         unique_identifier: str,
-        driver: "BaseDriver",
+        driver: BaseDriver,
         force_registration: bool = False,
         defaults: dict = None,
     ):
@@ -852,9 +850,16 @@
             custom_group_data=self.custom_groups,
             is_custom=is_custom,
         )
+        pkey_len = BaseDriver.get_pkey_len(identifier_data)
+        if len(primary_keys) < pkey_len:
+            # Don't mix in defaults with groups higher than the document level
+            defaults = {}
+        else:
+            defaults = self.defaults.get(category, {})
         return Group(
             identifier_data=identifier_data,
-            defaults=self.defaults.get(category, {}),
+            defaults=defaults,
             driver=self.driver,
             force_registration=self.force_registration,
         )
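
The new branch above is the heart of that fix: only a full-length primary key identifies a single document, so any shorter key must not have defaults merged in. A hedged illustration of the comparison (values made up):

    # For the MEMBER category, documents are keyed by (guild ID, member ID).
    pkey_len = 2                   # BaseDriver.get_pkey_len(identifier_data)
    primary_keys = ("123456789",)  # guild ID only -> above the document level
    defaults = {} if len(primary_keys) < pkey_len else {"xp": 0}
    assert defaults == {}          # defaults no longer leak into Group.all()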
@@ -975,6 +980,7 @@
         """
         group = self._get_base_group(scope)
         ret = {}
+        defaults = self.defaults.get(scope, {})
 
         try:
             dict_ = await self.driver.get(group.identifier_data)
@@ -982,7 +988,7 @@
             pass
         else:
             for k, v in dict_.items():
-                data = group.defaults
+                data = deepcopy(defaults)
                 data.update(v)
                 ret[int(k)] = data
 
@@ -1056,11 +1062,11 @@
         """
         return await self._all_from_scope(self.USER)
 
-    @staticmethod
-    def _all_members_from_guild(group: Group, guild_data: dict) -> dict:
+    def _all_members_from_guild(self, guild_data: dict) -> dict:
         ret = {}
+        defaults = self.defaults.get(self.MEMBER, {})
         for member_id, member_data in guild_data.items():
-            new_member_data = group.defaults
+            new_member_data = deepcopy(defaults)
             new_member_data.update(member_data)
             ret[int(member_id)] = new_member_data
         return ret
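
Both rewritten helpers now take a fresh `deepcopy` of the scope defaults for each entry before applying `update()`. Without a copy, every entry would mutate the same dict and stored data would bleed between entries; a minimal illustration (not from this commit):

    from copy import deepcopy

    defaults = {"xp": 0}
    shared = defaults              # no copy: entry A pollutes the defaults...
    shared.update({"xp": 10})
    assert defaults["xp"] == 10    # ...and every later entry inherits xp=10

    defaults = {"xp": 0}
    fresh = deepcopy(defaults)     # what the fixed code does per entry
    fresh.update({"xp": 10})
    assert defaults["xp"] == 0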
@@ -1099,7 +1105,7 @@
                 pass
             else:
                 for guild_id, guild_data in dict_.items():
-                    ret[int(guild_id)] = self._all_members_from_guild(group, guild_data)
+                    ret[int(guild_id)] = self._all_members_from_guild(guild_data)
         else:
             group = self._get_base_group(self.MEMBER, str(guild.id))
             try:
@@ -1107,7 +1113,7 @@
             except KeyError:
                 pass
             else:
-                ret = self._all_members_from_guild(group, guild_data)
+                ret = self._all_members_from_guild(guild_data)
         return ret
 
     async def _clear_scope(self, *scopes: str):

redbot/core/drivers/red_base.py

@@ -18,8 +18,8 @@ class IdentifierData:
         self,
         uuid: str,
         category: str,
-        primary_key: Tuple[str],
-        identifiers: Tuple[str],
+        primary_key: Tuple[str, ...],
+        identifiers: Tuple[str, ...],
         custom_group_data: dict,
         is_custom: bool = False,
     ):
@@ -183,7 +183,7 @@ class BaseDriver:
                 c,
                 (),
                 (),
-                custom_group_data.get(c, {}),
+                custom_group_data,
                 is_custom=c in custom_group_data,
             )
             try:
@@ -202,7 +202,19 @@
                     category,
                     pkey,
                     (),
-                    custom_group_data.get(category, {}),
+                    custom_group_data,
                     is_custom=category in custom_group_data,
                 )
                 await self.set(ident_data, data)
+
+    @staticmethod
+    def get_pkey_len(identifier_data: IdentifierData) -> int:
+        cat = identifier_data.category
+        if cat == ConfigCategory.GLOBAL.value:
+            return 0
+        elif cat == ConfigCategory.MEMBER.value:
+            return 2
+        elif identifier_data.is_custom:
+            return identifier_data.custom_group_data[cat]
+        else:
+            return 1
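
The new `get_pkey_len()` helper encodes how many primary keys identify a single document in each category; the defaults gate in config.py and the Mongo driver's `set()` both rely on it. A hedged usage sketch (the UUID and IDs are illustrative):

    # GLOBAL -> 0 (one document, no primary keys)
    # MEMBER -> 2 (guild ID, member ID)
    # custom groups -> the count registered via Config.init_custom()
    # everything else (USER, GUILD, CHANNEL, ROLE) -> 1
    ident = IdentifierData("someuuid", "MEMBER", ("123", "456"), (), {})
    assert BaseDriver.get_pkey_len(ident) == 2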

redbot/core/drivers/red_json.py

@@ -1,5 +1,4 @@
 from pathlib import Path
-from typing import Tuple
 import copy
 import weakref
 import logging
@@ -156,7 +155,7 @@ class JSON(BaseDriver):
                     category,
                     pkey,
                     (),
-                    custom_group_data.get(category, {}),
+                    custom_group_data,
                     is_custom=category in custom_group_data,
                 )
                 update_write_data(ident_data, data)

redbot/core/drivers/red_mongo.py

@@ -1,11 +1,13 @@
+import contextlib
+import itertools
 import re
 from getpass import getpass
-from typing import Match, Pattern, Tuple
+from typing import Match, Pattern, Tuple, Any, Dict, Iterator, List
 from urllib.parse import quote_plus
 
 import motor.core
 import motor.motor_asyncio
-from motor.motor_asyncio import AsyncIOMotorCursor
+import pymongo.errors
 
 from .red_base import BaseDriver, IdentifierData
@@ -36,7 +38,7 @@ def _initialize(**kwargs):
     url = "{}://{}{}/{}".format(uri, host, ports, db_name)
 
     global _conn
-    _conn = motor.motor_asyncio.AsyncIOMotorClient(url)
+    _conn = motor.motor_asyncio.AsyncIOMotorClient(url, retryWrites=True)
 
 
 class Mongo(BaseDriver):
@@ -87,7 +89,9 @@ class Mongo(BaseDriver):
         # noinspection PyTypeChecker
         return identifier_data.primary_key
 
-    async def rebuild_dataset(self, identifier_data: IdentifierData, cursor: AsyncIOMotorCursor):
+    async def rebuild_dataset(
+        self, identifier_data: IdentifierData, cursor: motor.motor_asyncio.AsyncIOMotorCursor
+    ):
         ret = {}
         async for doc in cursor:
             pkeys = doc["_id"]["RED_primary_key"]
@@ -137,24 +141,96 @@
     async def set(self, identifier_data: IdentifierData, value=None):
         uuid = self._escape_key(identifier_data.uuid)
         primary_key = list(map(self._escape_key, self.get_primary_key(identifier_data)))
-        dot_identifiers = ".".join(map(self._escape_key, identifier_data.identifiers))
         if isinstance(value, dict):
             if len(value) == 0:
                 await self.clear(identifier_data)
                 return
             value = self._escape_dict_keys(value)
         mongo_collection = self.get_collection(identifier_data.category)
-        if len(dot_identifiers) > 0:
-            update_stmt = {"$set": {dot_identifiers: value}}
-        else:
-            update_stmt = {"$set": value}
-
-        await mongo_collection.update_one(
-            {"_id": {"RED_uuid": uuid, "RED_primary_key": primary_key}},
-            update=update_stmt,
-            upsert=True,
-        )
+        pkey_len = self.get_pkey_len(identifier_data)
+        num_pkeys = len(primary_key)
+
+        if num_pkeys >= pkey_len:
+            # We're setting at the document level or below.
+            dot_identifiers = ".".join(map(self._escape_key, identifier_data.identifiers))
+            if dot_identifiers:
+                update_stmt = {"$set": {dot_identifiers: value}}
+            else:
+                update_stmt = {"$set": value}
+
+            await mongo_collection.update_one(
+                {"_id": {"RED_uuid": uuid, "RED_primary_key": primary_key}},
+                update=update_stmt,
+                upsert=True,
+            )
+        else:
+            # We're setting above the document level.
+            # Easiest and most efficient thing to do is delete all documents that we're potentially
+            # replacing, then insert_many().
+            # We'll do it in a transaction so we can roll-back in case something goes horribly
+            # wrong.
+            pkey_filter = self.generate_primary_key_filter(identifier_data)
+            async with await _conn.start_session() as session:
+                with contextlib.suppress(pymongo.errors.CollectionInvalid):
+                    # Collections must already exist when inserting documents within a transaction
+                    await _conn.get_database().create_collection(mongo_collection.full_name)
+                try:
+                    async with session.start_transaction():
+                        await mongo_collection.delete_many(pkey_filter, session=session)
+                        await mongo_collection.insert_many(
+                            self.generate_documents_to_insert(uuid, primary_key, value, pkey_len),
+                            session=session,
+                        )
+                except pymongo.errors.OperationFailure:
+                    # This DB version / setup doesn't support transactions, so we'll have to use
+                    # a shittier method.
+
+                    # The strategy here is to separate the existing documents and the new documents
+                    # into ones to be deleted, ones to be replaced, and new ones to be inserted.
+                    # Then we can do a bulk_write().
+
+                    # This is our list of (filter, new_document) tuples for replacing existing
+                    # documents. The `new_document` should be taken and removed from `value`, so
+                    # `value` only ends up containing documents which need to be inserted.
+                    to_replace: List[Tuple[Dict, Dict]] = []
+
+                    # This is our list of primary key filters which need deleting. They should
+                    # simply be all the primary keys which were part of existing documents but are
+                    # not included in the new documents.
+                    to_delete: List[Dict] = []
+
+                    async for document in mongo_collection.find(pkey_filter, session=session):
+                        pkey = document["_id"]["RED_primary_key"]
+                        new_document = value
+                        try:
+                            for pkey_part in pkey[num_pkeys:-1]:
+                                new_document = new_document[pkey_part]
+                            # This document is being replaced - remove it from `value`.
+                            new_document = new_document.pop(pkey[-1])
+                        except KeyError:
+                            # We've found the primary key of an old document which isn't in the
+                            # updated set of documents - it should be deleted.
+                            to_delete.append({"_id": {"RED_uuid": uuid, "RED_primary_key": pkey}})
+                        else:
+                            _filter = {"_id": {"RED_uuid": uuid, "RED_primary_key": pkey}}
+                            new_document.update(_filter)
+                            to_replace.append((_filter, new_document))
+
+                    # What's left of `value` should be the new documents needing to be inserted.
+                    to_insert = self.generate_documents_to_insert(
+                        uuid, primary_key, value, pkey_len
+                    )
+                    requests = list(
+                        itertools.chain(
+                            (pymongo.DeleteOne(f) for f in to_delete),
+                            (pymongo.ReplaceOne(f, d) for f, d in to_replace),
+                            (pymongo.InsertOne(d) for d in to_insert if d),
+                        )
+                    )
+                    # This will pipeline the operations so they all complete quickly. However if
+                    # any of them fail, the rest of them will complete - i.e. this operation is not
+                    # atomic.
+                    await mongo_collection.bulk_write(requests, ordered=False)
 
     def generate_primary_key_filter(self, identifier_data: IdentifierData):
         uuid = self._escape_key(identifier_data.uuid)
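
To make the new above-document path concrete: when `set()` receives fewer primary keys than `pkey_len`, the nested dict is flattened into one document per full primary key by `generate_documents_to_insert()` (added in the next hunk), using the driver's `_id` convention. A hedged example (UUID and IDs made up):

    # Setting all MEMBER data for guild "123": pkey_len is 2, one primary key
    # is supplied, so each member's dict becomes its own document.
    value = {"456": {"xp": 10}, "789": {"xp": 3}}
    docs = list(Mongo.generate_documents_to_insert("someuuid", ["123"], value, 2))
    # [{"xp": 10, "_id": {"RED_uuid": "someuuid", "RED_primary_key": ["123", "456"]}},
    #  {"xp": 3, "_id": {"RED_uuid": "someuuid", "RED_primary_key": ["123", "789"]}}]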
@@ -170,20 +246,48 @@ class Mongo(BaseDriver):
         ret["_id.RED_primary_key"] = {"$exists": True}
         return ret
 
+    @classmethod
+    def generate_documents_to_insert(
+        cls, uuid: str, primary_keys: List[str], data: Dict[str, Dict[str, Any]], pkey_len: int
+    ) -> Iterator[Dict[str, Any]]:
+        num_missing_pkeys = pkey_len - len(primary_keys)
+        if num_missing_pkeys == 1:
+            for pkey, document in data.items():
+                document["_id"] = {"RED_uuid": uuid, "RED_primary_key": primary_keys + [pkey]}
+                yield document
+        else:
+            for pkey, inner_data in data.items():
+                for document in cls.generate_documents_to_insert(
+                    uuid, primary_keys + [pkey], inner_data, pkey_len
+                ):
+                    yield document
+
     async def clear(self, identifier_data: IdentifierData):
-        # There are three cases here:
+        # There are five cases here:
         # 1) We're clearing out a subset of identifiers (aka identifiers is NOT empty)
         # 2) We're clearing out full primary key and no identifiers
         # 3) We're clearing out partial primary key and no identifiers
         # 4) Primary key is empty, should wipe all documents in the collection
-        mongo_collection = self.get_collection(identifier_data.category)
+        # 5) Category is empty, all of this cog's data should be deleted
         pkey_filter = self.generate_primary_key_filter(identifier_data)
-        if len(identifier_data.identifiers) == 0:
-            # This covers cases 2-4
-            await mongo_collection.delete_many(pkey_filter)
-        else:
+        if identifier_data.identifiers:
+            # This covers case 1
+            mongo_collection = self.get_collection(identifier_data.category)
             dot_identifiers = ".".join(map(self._escape_key, identifier_data.identifiers))
             await mongo_collection.update_one(pkey_filter, update={"$unset": {dot_identifiers: 1}})
+        elif identifier_data.category:
+            # This covers cases 2-4
+            mongo_collection = self.get_collection(identifier_data.category)
+            await mongo_collection.delete_many(pkey_filter)
+        else:
+            # This covers case 5
+            db = self.db
+            super_collection = db[self.cog_name]
+            results = await db.list_collections(
+                filter={"name": {"$regex": rf"^{super_collection.name}\."}}
+            )
+            for result in results:
+                await db[result["name"]].delete_many(pkey_filter)
 
     @staticmethod
     def _escape_key(key: str) -> str:
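
On the new case 5 in `clear()`: the Mongo driver keeps one collection per category, namespaced under the cog's name, so wiping a cog's data means listing every collection with that prefix and clearing each. A hedged sketch of the names involved (cog name made up):

    # Collections for a cog named "ExampleCog" might include:
    #   "ExampleCog.GLOBAL", "ExampleCog.MEMBER", "ExampleCog.quotes", ...
    # Case 5 matches them all by name prefix, then delete_many()'s each one:
    name_filter = {"name": {"$regex": r"^ExampleCog\."}}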