diff --git a/redbot/core/drivers/red_mongo.py b/redbot/core/drivers/red_mongo.py index 0f9bc585c..6f8415bbd 100644 --- a/redbot/core/drivers/red_mongo.py +++ b/redbot/core/drivers/red_mongo.py @@ -1,7 +1,12 @@ -import motor.motor_asyncio -from .red_base import BaseDriver +import re +from typing import Match, Pattern from urllib.parse import quote_plus +import motor.core +import motor.motor_asyncio + +from .red_base import BaseDriver + __all__ = ["Mongo"] @@ -80,6 +85,7 @@ class Mongo(BaseDriver): async def get(self, *identifiers: str): mongo_collection = self.get_collection() + identifiers = (*map(self._escape_key, identifiers),) dot_identifiers = ".".join(identifiers) partial = await mongo_collection.find_one( @@ -91,10 +97,14 @@ class Mongo(BaseDriver): for i in identifiers: partial = partial[i] + if isinstance(partial, dict): + return self._unescape_dict_keys(partial) return partial async def set(self, *identifiers: str, value=None): - dot_identifiers = ".".join(identifiers) + dot_identifiers = ".".join(map(self._escape_key, identifiers)) + if isinstance(value, dict): + value = self._escape_dict_keys(value) mongo_collection = self.get_collection() @@ -105,7 +115,7 @@ class Mongo(BaseDriver): ) async def clear(self, *identifiers: str): - dot_identifiers = ".".join(identifiers) + dot_identifiers = ".".join(map(self._escape_key, identifiers)) mongo_collection = self.get_collection() if len(identifiers) > 0: @@ -115,6 +125,62 @@ class Mongo(BaseDriver): else: await mongo_collection.delete_one({"_id": self.unique_cog_identifier}) + @staticmethod + def _escape_key(key: str) -> str: + return _SPECIAL_CHAR_PATTERN.sub(_replace_with_escaped, key) + + @staticmethod + def _unescape_key(key: str) -> str: + return _CHAR_ESCAPE_PATTERN.sub(_replace_with_unescaped, key) + + @classmethod + def _escape_dict_keys(cls, data: dict) -> dict: + """Recursively escape all keys in a dict.""" + ret = {} + for key, value in data.items(): + key = cls._escape_key(key) + if isinstance(value, dict): + value = cls._escape_dict_keys(value) + ret[key] = value + return ret + + @classmethod + def _unescape_dict_keys(cls, data: dict) -> dict: + """Recursively unescape all keys in a dict.""" + ret = {} + for key, value in data.items(): + key = cls._unescape_key(key) + if isinstance(value, dict): + value = cls._unescape_dict_keys(value) + ret[key] = value + return ret + + +_SPECIAL_CHAR_PATTERN: Pattern[str] = re.compile(r"([.$]|\\U0000002E|\\U00000024)") +_SPECIAL_CHARS = { + ".": "\\U0000002E", + "$": "\\U00000024", + "\\U0000002E": "\\U&0000002E", + "\\U00000024": "\\U&00000024", +} + + +def _replace_with_escaped(match: Match[str]) -> str: + return _SPECIAL_CHARS[match[0]] + + +_CHAR_ESCAPE_PATTERN: Pattern[str] = re.compile(r"(\\U0000002E|\\U00000024)") +_CHAR_ESCAPES = { + "\\U0000002E": ".", + "\\U00000024": "$", + "\\U&0000002E": "\\U0000002E", + "\\U&00000024": "\\U00000024", +} + + +def _replace_with_unescaped(match: Match[str]) -> str: + return _CHAR_ESCAPES[match[0]] + def get_config_details(): uri = None