mirror of
https://github.com/Cog-Creators/Red-DiscordBot.git
synced 2025-11-09 12:48:54 -05:00
[MongoDB] Escape special characters in keys (#2212)
Essentially resolves #2038, although this is escaping and not rejecting keys as that issue implies. Signed-off-by: Toby Harradine <tobyharradine@gmail.com>
This commit is contained in:
parent
ce25011f0d
commit
f7b1f9f0dc
@ -1,7 +1,12 @@
|
|||||||
import motor.motor_asyncio
|
import re
|
||||||
from .red_base import BaseDriver
|
from typing import Match, Pattern
|
||||||
from urllib.parse import quote_plus
|
from urllib.parse import quote_plus
|
||||||
|
|
||||||
|
import motor.core
|
||||||
|
import motor.motor_asyncio
|
||||||
|
|
||||||
|
from .red_base import BaseDriver
|
||||||
|
|
||||||
__all__ = ["Mongo"]
|
__all__ = ["Mongo"]
|
||||||
|
|
||||||
|
|
||||||
@ -80,6 +85,7 @@ class Mongo(BaseDriver):
|
|||||||
async def get(self, *identifiers: str):
|
async def get(self, *identifiers: str):
|
||||||
mongo_collection = self.get_collection()
|
mongo_collection = self.get_collection()
|
||||||
|
|
||||||
|
identifiers = (*map(self._escape_key, identifiers),)
|
||||||
dot_identifiers = ".".join(identifiers)
|
dot_identifiers = ".".join(identifiers)
|
||||||
|
|
||||||
partial = await mongo_collection.find_one(
|
partial = await mongo_collection.find_one(
|
||||||
@ -91,10 +97,14 @@ class Mongo(BaseDriver):
|
|||||||
|
|
||||||
for i in identifiers:
|
for i in identifiers:
|
||||||
partial = partial[i]
|
partial = partial[i]
|
||||||
|
if isinstance(partial, dict):
|
||||||
|
return self._unescape_dict_keys(partial)
|
||||||
return partial
|
return partial
|
||||||
|
|
||||||
async def set(self, *identifiers: str, value=None):
|
async def set(self, *identifiers: str, value=None):
|
||||||
dot_identifiers = ".".join(identifiers)
|
dot_identifiers = ".".join(map(self._escape_key, identifiers))
|
||||||
|
if isinstance(value, dict):
|
||||||
|
value = self._escape_dict_keys(value)
|
||||||
|
|
||||||
mongo_collection = self.get_collection()
|
mongo_collection = self.get_collection()
|
||||||
|
|
||||||
@ -105,7 +115,7 @@ class Mongo(BaseDriver):
|
|||||||
)
|
)
|
||||||
|
|
||||||
async def clear(self, *identifiers: str):
|
async def clear(self, *identifiers: str):
|
||||||
dot_identifiers = ".".join(identifiers)
|
dot_identifiers = ".".join(map(self._escape_key, identifiers))
|
||||||
mongo_collection = self.get_collection()
|
mongo_collection = self.get_collection()
|
||||||
|
|
||||||
if len(identifiers) > 0:
|
if len(identifiers) > 0:
|
||||||
@ -115,6 +125,62 @@ class Mongo(BaseDriver):
|
|||||||
else:
|
else:
|
||||||
await mongo_collection.delete_one({"_id": self.unique_cog_identifier})
|
await mongo_collection.delete_one({"_id": self.unique_cog_identifier})
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _escape_key(key: str) -> str:
|
||||||
|
return _SPECIAL_CHAR_PATTERN.sub(_replace_with_escaped, key)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _unescape_key(key: str) -> str:
|
||||||
|
return _CHAR_ESCAPE_PATTERN.sub(_replace_with_unescaped, key)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _escape_dict_keys(cls, data: dict) -> dict:
|
||||||
|
"""Recursively escape all keys in a dict."""
|
||||||
|
ret = {}
|
||||||
|
for key, value in data.items():
|
||||||
|
key = cls._escape_key(key)
|
||||||
|
if isinstance(value, dict):
|
||||||
|
value = cls._escape_dict_keys(value)
|
||||||
|
ret[key] = value
|
||||||
|
return ret
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _unescape_dict_keys(cls, data: dict) -> dict:
|
||||||
|
"""Recursively unescape all keys in a dict."""
|
||||||
|
ret = {}
|
||||||
|
for key, value in data.items():
|
||||||
|
key = cls._unescape_key(key)
|
||||||
|
if isinstance(value, dict):
|
||||||
|
value = cls._unescape_dict_keys(value)
|
||||||
|
ret[key] = value
|
||||||
|
return ret
|
||||||
|
|
||||||
|
|
||||||
|
_SPECIAL_CHAR_PATTERN: Pattern[str] = re.compile(r"([.$]|\\U0000002E|\\U00000024)")
|
||||||
|
_SPECIAL_CHARS = {
|
||||||
|
".": "\\U0000002E",
|
||||||
|
"$": "\\U00000024",
|
||||||
|
"\\U0000002E": "\\U&0000002E",
|
||||||
|
"\\U00000024": "\\U&00000024",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _replace_with_escaped(match: Match[str]) -> str:
|
||||||
|
return _SPECIAL_CHARS[match[0]]
|
||||||
|
|
||||||
|
|
||||||
|
_CHAR_ESCAPE_PATTERN: Pattern[str] = re.compile(r"(\\U0000002E|\\U00000024)")
|
||||||
|
_CHAR_ESCAPES = {
|
||||||
|
"\\U0000002E": ".",
|
||||||
|
"\\U00000024": "$",
|
||||||
|
"\\U&0000002E": "\\U0000002E",
|
||||||
|
"\\U&00000024": "\\U00000024",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _replace_with_unescaped(match: Match[str]) -> str:
|
||||||
|
return _CHAR_ESCAPES[match[0]]
|
||||||
|
|
||||||
|
|
||||||
def get_config_details():
|
def get_config_details():
|
||||||
uri = None
|
uri = None
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user