[MongoDB] Escape special characters in keys (#2212)

Essentially resolves #2038, although this is escaping and not rejecting keys as that issue implies.

Signed-off-by: Toby Harradine <tobyharradine@gmail.com>
This commit is contained in:
Toby Harradine 2018-10-11 11:20:42 +11:00 committed by GitHub
parent ce25011f0d
commit f7b1f9f0dc
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,7 +1,12 @@
import motor.motor_asyncio
from .red_base import BaseDriver
import re
from typing import Match, Pattern
from urllib.parse import quote_plus
import motor.core
import motor.motor_asyncio
from .red_base import BaseDriver
__all__ = ["Mongo"]
@ -80,6 +85,7 @@ class Mongo(BaseDriver):
async def get(self, *identifiers: str):
mongo_collection = self.get_collection()
identifiers = (*map(self._escape_key, identifiers),)
dot_identifiers = ".".join(identifiers)
partial = await mongo_collection.find_one(
@ -91,10 +97,14 @@ class Mongo(BaseDriver):
for i in identifiers:
partial = partial[i]
if isinstance(partial, dict):
return self._unescape_dict_keys(partial)
return partial
async def set(self, *identifiers: str, value=None):
dot_identifiers = ".".join(identifiers)
dot_identifiers = ".".join(map(self._escape_key, identifiers))
if isinstance(value, dict):
value = self._escape_dict_keys(value)
mongo_collection = self.get_collection()
@ -105,7 +115,7 @@ class Mongo(BaseDriver):
)
async def clear(self, *identifiers: str):
dot_identifiers = ".".join(identifiers)
dot_identifiers = ".".join(map(self._escape_key, identifiers))
mongo_collection = self.get_collection()
if len(identifiers) > 0:
@ -115,6 +125,62 @@ class Mongo(BaseDriver):
else:
await mongo_collection.delete_one({"_id": self.unique_cog_identifier})
@staticmethod
def _escape_key(key: str) -> str:
return _SPECIAL_CHAR_PATTERN.sub(_replace_with_escaped, key)
@staticmethod
def _unescape_key(key: str) -> str:
return _CHAR_ESCAPE_PATTERN.sub(_replace_with_unescaped, key)
@classmethod
def _escape_dict_keys(cls, data: dict) -> dict:
"""Recursively escape all keys in a dict."""
ret = {}
for key, value in data.items():
key = cls._escape_key(key)
if isinstance(value, dict):
value = cls._escape_dict_keys(value)
ret[key] = value
return ret
@classmethod
def _unescape_dict_keys(cls, data: dict) -> dict:
"""Recursively unescape all keys in a dict."""
ret = {}
for key, value in data.items():
key = cls._unescape_key(key)
if isinstance(value, dict):
value = cls._unescape_dict_keys(value)
ret[key] = value
return ret
_SPECIAL_CHAR_PATTERN: Pattern[str] = re.compile(r"([.$]|\\U0000002E|\\U00000024)")
_SPECIAL_CHARS = {
".": "\\U0000002E",
"$": "\\U00000024",
"\\U0000002E": "\\U&0000002E",
"\\U00000024": "\\U&00000024",
}
def _replace_with_escaped(match: Match[str]) -> str:
return _SPECIAL_CHARS[match[0]]
_CHAR_ESCAPE_PATTERN: Pattern[str] = re.compile(r"(\\U0000002E|\\U00000024)")
_CHAR_ESCAPES = {
"\\U0000002E": ".",
"\\U00000024": "$",
"\\U&0000002E": "\\U0000002E",
"\\U&00000024": "\\U00000024",
}
def _replace_with_unescaped(match: Match[str]) -> str:
return _CHAR_ESCAPES[match[0]]
def get_config_details():
uri = None