mirror of
https://github.com/Cog-Creators/Red-DiscordBot.git
synced 2025-11-09 12:48:54 -05:00
[MongoDB] Escape special characters in keys (#2212)
Essentially resolves #2038, although this is escaping and not rejecting keys as that issue implies. Signed-off-by: Toby Harradine <tobyharradine@gmail.com>
This commit is contained in:
parent
ce25011f0d
commit
f7b1f9f0dc
@ -1,7 +1,12 @@
|
||||
import motor.motor_asyncio
|
||||
from .red_base import BaseDriver
|
||||
import re
|
||||
from typing import Match, Pattern
|
||||
from urllib.parse import quote_plus
|
||||
|
||||
import motor.core
|
||||
import motor.motor_asyncio
|
||||
|
||||
from .red_base import BaseDriver
|
||||
|
||||
__all__ = ["Mongo"]
|
||||
|
||||
|
||||
@ -80,6 +85,7 @@ class Mongo(BaseDriver):
|
||||
async def get(self, *identifiers: str):
|
||||
mongo_collection = self.get_collection()
|
||||
|
||||
identifiers = (*map(self._escape_key, identifiers),)
|
||||
dot_identifiers = ".".join(identifiers)
|
||||
|
||||
partial = await mongo_collection.find_one(
|
||||
@ -91,10 +97,14 @@ class Mongo(BaseDriver):
|
||||
|
||||
for i in identifiers:
|
||||
partial = partial[i]
|
||||
if isinstance(partial, dict):
|
||||
return self._unescape_dict_keys(partial)
|
||||
return partial
|
||||
|
||||
async def set(self, *identifiers: str, value=None):
|
||||
dot_identifiers = ".".join(identifiers)
|
||||
dot_identifiers = ".".join(map(self._escape_key, identifiers))
|
||||
if isinstance(value, dict):
|
||||
value = self._escape_dict_keys(value)
|
||||
|
||||
mongo_collection = self.get_collection()
|
||||
|
||||
@ -105,7 +115,7 @@ class Mongo(BaseDriver):
|
||||
)
|
||||
|
||||
async def clear(self, *identifiers: str):
|
||||
dot_identifiers = ".".join(identifiers)
|
||||
dot_identifiers = ".".join(map(self._escape_key, identifiers))
|
||||
mongo_collection = self.get_collection()
|
||||
|
||||
if len(identifiers) > 0:
|
||||
@ -115,6 +125,62 @@ class Mongo(BaseDriver):
|
||||
else:
|
||||
await mongo_collection.delete_one({"_id": self.unique_cog_identifier})
|
||||
|
||||
@staticmethod
|
||||
def _escape_key(key: str) -> str:
|
||||
return _SPECIAL_CHAR_PATTERN.sub(_replace_with_escaped, key)
|
||||
|
||||
@staticmethod
|
||||
def _unescape_key(key: str) -> str:
|
||||
return _CHAR_ESCAPE_PATTERN.sub(_replace_with_unescaped, key)
|
||||
|
||||
@classmethod
|
||||
def _escape_dict_keys(cls, data: dict) -> dict:
|
||||
"""Recursively escape all keys in a dict."""
|
||||
ret = {}
|
||||
for key, value in data.items():
|
||||
key = cls._escape_key(key)
|
||||
if isinstance(value, dict):
|
||||
value = cls._escape_dict_keys(value)
|
||||
ret[key] = value
|
||||
return ret
|
||||
|
||||
@classmethod
|
||||
def _unescape_dict_keys(cls, data: dict) -> dict:
|
||||
"""Recursively unescape all keys in a dict."""
|
||||
ret = {}
|
||||
for key, value in data.items():
|
||||
key = cls._unescape_key(key)
|
||||
if isinstance(value, dict):
|
||||
value = cls._unescape_dict_keys(value)
|
||||
ret[key] = value
|
||||
return ret
|
||||
|
||||
|
||||
_SPECIAL_CHAR_PATTERN: Pattern[str] = re.compile(r"([.$]|\\U0000002E|\\U00000024)")
|
||||
_SPECIAL_CHARS = {
|
||||
".": "\\U0000002E",
|
||||
"$": "\\U00000024",
|
||||
"\\U0000002E": "\\U&0000002E",
|
||||
"\\U00000024": "\\U&00000024",
|
||||
}
|
||||
|
||||
|
||||
def _replace_with_escaped(match: Match[str]) -> str:
|
||||
return _SPECIAL_CHARS[match[0]]
|
||||
|
||||
|
||||
_CHAR_ESCAPE_PATTERN: Pattern[str] = re.compile(r"(\\U0000002E|\\U00000024)")
|
||||
_CHAR_ESCAPES = {
|
||||
"\\U0000002E": ".",
|
||||
"\\U00000024": "$",
|
||||
"\\U&0000002E": "\\U0000002E",
|
||||
"\\U&00000024": "\\U00000024",
|
||||
}
|
||||
|
||||
|
||||
def _replace_with_unescaped(match: Match[str]) -> str:
|
||||
return _CHAR_ESCAPES[match[0]]
|
||||
|
||||
|
||||
def get_config_details():
|
||||
uri = None
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user