From 7aff7962f07b0802edc6531f1e292a37e2d0bbe9 Mon Sep 17 00:00:00 2001 From: Toby Harradine Date: Sat, 9 May 2020 15:11:23 +1000 Subject: [PATCH] Fix creation of IdentifierData in config raw methods (#3829) * Fix creation of IdentifierData in config Also adds some new tests regarding partial primary keys. Resolves #3796. Signed-off-by: Toby Harradine Co-authored-by: jack1142 <6032823+jack1142@users.noreply.github.com> --- redbot/core/config.py | 8 +-- redbot/core/drivers/base.py | 22 ++++++++ tests/core/test_config.py | 106 ++++++++++++++++++++++++++++++++++++ 3 files changed, 132 insertions(+), 4 deletions(-) diff --git a/redbot/core/config.py b/redbot/core/config.py index ccabbd328..e0f39baad 100644 --- a/redbot/core/config.py +++ b/redbot/core/config.py @@ -334,7 +334,7 @@ class Group(Value): """ is_group = self.is_group(item) is_value = not is_group and self.is_value(item) - new_identifiers = self.identifier_data.add_identifier(item) + new_identifiers = self.identifier_data.get_child(item) if is_group: return Group( identifier_data=new_identifiers, @@ -381,7 +381,7 @@ class Group(Value): dict access. These are casted to `str` for you. """ path = tuple(str(p) for p in nested_path) - identifier_data = self.identifier_data.add_identifier(*path) + identifier_data = self.identifier_data.get_child(*path) await self.driver.clear(identifier_data) def is_group(self, item: Any) -> bool: @@ -499,7 +499,7 @@ class Group(Value): else: default = poss_default - identifier_data = self.identifier_data.add_identifier(*path) + identifier_data = self.identifier_data.get_child(*path) try: raw = await self.driver.get(identifier_data) except KeyError: @@ -583,7 +583,7 @@ class Group(Value): The value to store. """ path = tuple(str(p) for p in nested_path) - identifier_data = self.identifier_data.add_identifier(*path) + identifier_data = self.identifier_data.get_child(*path) if isinstance(value, dict): value = _str_key_dict(value) await self.driver.set(identifier_data, value=value) diff --git a/redbot/core/drivers/base.py b/redbot/core/drivers/base.py index 3d4ac42e3..59b1ac7d8 100644 --- a/redbot/core/drivers/base.py +++ b/redbot/core/drivers/base.py @@ -109,6 +109,28 @@ class IdentifierData: def __hash__(self) -> int: return hash((self.uuid, self.category, self.primary_key, self.identifiers)) + def get_child(self, *keys: str) -> "IdentifierData": + if not all(isinstance(i, str) for i in keys): + raise ValueError("Identifiers must be strings.") + + primary_keys = self.primary_key + identifiers = self.identifiers + num_missing_pkeys = self.primary_key_len - len(self.primary_key) + if num_missing_pkeys > 0: + primary_keys += keys[:num_missing_pkeys] + if len(keys) > num_missing_pkeys: + identifiers += keys[num_missing_pkeys:] + + return IdentifierData( + self.cog_name, + self.uuid, + self.category, + primary_keys, + identifiers, + self.primary_key_len, + self.is_custom, + ) + def add_identifier(self, *identifier: str) -> "IdentifierData": if not all(isinstance(i, str) for i in identifier): raise ValueError("Identifiers must be strings.") diff --git a/tests/core/test_config.py b/tests/core/test_config.py index 5cd9e0e90..d955a3a0a 100644 --- a/tests/core/test_config.py +++ b/tests/core/test_config.py @@ -554,3 +554,109 @@ async def test_config_ctxmgr_atomicity(config): await asyncio.wait(tasks, return_when=asyncio.ALL_COMPLETED) assert len(await config.foo()) == 15 + + +@pytest.mark.asyncio +async def test_set_with_partial_primary_keys(config): + config.init_custom("CUSTOM", 3) + await config.custom("CUSTOM", "1").set({"11": {"111": {"foo": "bar"}}}) + assert await config.custom("CUSTOM", "1", "11", "111").foo() == "bar" + + await config.custom("CUSTOM", "2").set( + { + "11": {"111": {"foo": "bad"}}, + "22": {"111": {"foo": "baz"}}, + "33": {"111": {"foo": "boo"}, "222": {"foo": "boz"}}, + } + ) + assert await config.custom("CUSTOM", "2", "11", "111").foo() == "bad" + assert await config.custom("CUSTOM", "2", "22", "111").foo() == "baz" + assert await config.custom("CUSTOM", "2", "33", "111").foo() == "boo" + assert await config.custom("CUSTOM", "2", "33", "222").foo() == "boz" + + await config.custom("CUSTOM", "2").set({"22": {}, "33": {"111": {}, "222": {"foo": "biz"}}}) + with pytest.raises(KeyError): + await config.custom("CUSTOM").get_raw("2", "11") + with pytest.raises(KeyError): + await config.custom("CUSTOM").get_raw("2", "22", "111") + with pytest.raises(KeyError): + await config.custom("CUSTOM").get_raw("2", "33", "111", "foo") + assert await config.custom("CUSTOM", "2", "33", "222").foo() == "biz" + + +@pytest.mark.asyncio +async def test_raw_with_partial_primary_keys(config): + config.init_custom("CUSTOM", 1) + await config.custom("CUSTOM").set_raw("primary_key", "identifier", value=True) + assert await config.custom("CUSTOM", "primary_key").identifier() is True + await config.custom("CUSTOM").set_raw(value={"primary_key": {"identifier": False}}) + assert await config.custom("CUSTOM", "primary_key").identifier() is False + + +""" +Following PARAMS can be generated with: +from functools import reduce +from pprint import pprint +def generate_test_args(print_args=True): + pkeys = ("1", "2", "3") + identifiers = ("foo",) + full_dict = {"1": {"2": {"3": {"foo": "bar"}}}} + argvalues = [ + ( + pkeys[:x], + (pkeys[x:] + identifiers)[:y], + reduce(lambda d, k: d[k], (pkeys + identifiers)[:x+y], full_dict), + ) + for x in range(len(pkeys) + 1) + for y in range(len(pkeys) + len(identifiers) - x + 1) + ] + if print_args: + print("[") + for args in argvalues: + print(f" {args!r},") + print("]") + else: + return argvalues +generate_test_args() +""" +PARAMS = [ + ((), (), {"1": {"2": {"3": {"foo": "bar"}}}}), + ((), (1,), {"2": {"3": {"foo": "bar"}}}), + ((), (1, 2), {"3": {"foo": "bar"}}), + ((), (1, 2, 3), {"foo": "bar"}), + ((), (1, 2, 3, "foo"), "bar"), + ((1,), (), {"2": {"3": {"foo": "bar"}}}), + ((1,), (2,), {"3": {"foo": "bar"}}), + ((1,), (2, 3), {"foo": "bar"}), + ((1,), (2, 3, "foo"), "bar"), + ((1, 2), (), {"3": {"foo": "bar"}}), + ((1, 2), (3,), {"foo": "bar"}), + ((1, 2), (3, "foo"), "bar"), + ((1, 2, 3), (), {"foo": "bar"}), + ((1, 2, 3), ("foo",), "bar"), +] + + +@pytest.mark.parametrize("pkeys, raw_args, result", PARAMS) +@pytest.mark.asyncio +async def test_config_custom_partial_pkeys_get(config, pkeys, raw_args, result): + # setup + config.init_custom("TEST", 3) + config.register_custom("TEST") + await config.custom("TEST", 1, 2, 3).set({"foo": "bar"}) + + group = config.custom("TEST", *pkeys) + assert await group.get_raw(*raw_args) == result + + +@pytest.mark.parametrize("pkeys, raw_args, result", PARAMS) +@pytest.mark.asyncio +async def test_config_custom_partial_pkeys_set(config, pkeys, raw_args, result): + # setup + config.init_custom("TEST", 3) + config.register_custom("TEST") + await config.custom("TEST", 1, 2, 3).set({"foo": "blah"}) + + group = config.custom("TEST", *pkeys) + await group.set_raw(*raw_args, value=result) + assert await group.get_raw(*raw_args) == result