Fix creation of IdentifierData in config raw methods (#3829)

* Fix creation of IdentifierData in config

Also adds some new tests regarding partial primary keys.

Resolves #3796.

Signed-off-by: Toby Harradine <tobyharradine@gmail.com>
Co-authored-by: jack1142 <6032823+jack1142@users.noreply.github.com>
This commit is contained in:
Toby Harradine 2020-05-09 15:11:23 +10:00 committed by GitHub
parent 1a96f276f8
commit 7aff7962f0
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 132 additions and 4 deletions

View File

@ -334,7 +334,7 @@ class Group(Value):
""" """
is_group = self.is_group(item) is_group = self.is_group(item)
is_value = not is_group and self.is_value(item) is_value = not is_group and self.is_value(item)
new_identifiers = self.identifier_data.add_identifier(item) new_identifiers = self.identifier_data.get_child(item)
if is_group: if is_group:
return Group( return Group(
identifier_data=new_identifiers, identifier_data=new_identifiers,
@ -381,7 +381,7 @@ class Group(Value):
dict access. These are casted to `str` for you. dict access. These are casted to `str` for you.
""" """
path = tuple(str(p) for p in nested_path) path = tuple(str(p) for p in nested_path)
identifier_data = self.identifier_data.add_identifier(*path) identifier_data = self.identifier_data.get_child(*path)
await self.driver.clear(identifier_data) await self.driver.clear(identifier_data)
def is_group(self, item: Any) -> bool: def is_group(self, item: Any) -> bool:
@ -499,7 +499,7 @@ class Group(Value):
else: else:
default = poss_default default = poss_default
identifier_data = self.identifier_data.add_identifier(*path) identifier_data = self.identifier_data.get_child(*path)
try: try:
raw = await self.driver.get(identifier_data) raw = await self.driver.get(identifier_data)
except KeyError: except KeyError:
@ -583,7 +583,7 @@ class Group(Value):
The value to store. The value to store.
""" """
path = tuple(str(p) for p in nested_path) path = tuple(str(p) for p in nested_path)
identifier_data = self.identifier_data.add_identifier(*path) identifier_data = self.identifier_data.get_child(*path)
if isinstance(value, dict): if isinstance(value, dict):
value = _str_key_dict(value) value = _str_key_dict(value)
await self.driver.set(identifier_data, value=value) await self.driver.set(identifier_data, value=value)

View File

@ -109,6 +109,28 @@ class IdentifierData:
def __hash__(self) -> int: def __hash__(self) -> int:
return hash((self.uuid, self.category, self.primary_key, self.identifiers)) return hash((self.uuid, self.category, self.primary_key, self.identifiers))
def get_child(self, *keys: str) -> "IdentifierData":
if not all(isinstance(i, str) for i in keys):
raise ValueError("Identifiers must be strings.")
primary_keys = self.primary_key
identifiers = self.identifiers
num_missing_pkeys = self.primary_key_len - len(self.primary_key)
if num_missing_pkeys > 0:
primary_keys += keys[:num_missing_pkeys]
if len(keys) > num_missing_pkeys:
identifiers += keys[num_missing_pkeys:]
return IdentifierData(
self.cog_name,
self.uuid,
self.category,
primary_keys,
identifiers,
self.primary_key_len,
self.is_custom,
)
def add_identifier(self, *identifier: str) -> "IdentifierData": def add_identifier(self, *identifier: str) -> "IdentifierData":
if not all(isinstance(i, str) for i in identifier): if not all(isinstance(i, str) for i in identifier):
raise ValueError("Identifiers must be strings.") raise ValueError("Identifiers must be strings.")

View File

@ -554,3 +554,109 @@ async def test_config_ctxmgr_atomicity(config):
await asyncio.wait(tasks, return_when=asyncio.ALL_COMPLETED) await asyncio.wait(tasks, return_when=asyncio.ALL_COMPLETED)
assert len(await config.foo()) == 15 assert len(await config.foo()) == 15
@pytest.mark.asyncio
async def test_set_with_partial_primary_keys(config):
config.init_custom("CUSTOM", 3)
await config.custom("CUSTOM", "1").set({"11": {"111": {"foo": "bar"}}})
assert await config.custom("CUSTOM", "1", "11", "111").foo() == "bar"
await config.custom("CUSTOM", "2").set(
{
"11": {"111": {"foo": "bad"}},
"22": {"111": {"foo": "baz"}},
"33": {"111": {"foo": "boo"}, "222": {"foo": "boz"}},
}
)
assert await config.custom("CUSTOM", "2", "11", "111").foo() == "bad"
assert await config.custom("CUSTOM", "2", "22", "111").foo() == "baz"
assert await config.custom("CUSTOM", "2", "33", "111").foo() == "boo"
assert await config.custom("CUSTOM", "2", "33", "222").foo() == "boz"
await config.custom("CUSTOM", "2").set({"22": {}, "33": {"111": {}, "222": {"foo": "biz"}}})
with pytest.raises(KeyError):
await config.custom("CUSTOM").get_raw("2", "11")
with pytest.raises(KeyError):
await config.custom("CUSTOM").get_raw("2", "22", "111")
with pytest.raises(KeyError):
await config.custom("CUSTOM").get_raw("2", "33", "111", "foo")
assert await config.custom("CUSTOM", "2", "33", "222").foo() == "biz"
@pytest.mark.asyncio
async def test_raw_with_partial_primary_keys(config):
config.init_custom("CUSTOM", 1)
await config.custom("CUSTOM").set_raw("primary_key", "identifier", value=True)
assert await config.custom("CUSTOM", "primary_key").identifier() is True
await config.custom("CUSTOM").set_raw(value={"primary_key": {"identifier": False}})
assert await config.custom("CUSTOM", "primary_key").identifier() is False
"""
Following PARAMS can be generated with:
from functools import reduce
from pprint import pprint
def generate_test_args(print_args=True):
pkeys = ("1", "2", "3")
identifiers = ("foo",)
full_dict = {"1": {"2": {"3": {"foo": "bar"}}}}
argvalues = [
(
pkeys[:x],
(pkeys[x:] + identifiers)[:y],
reduce(lambda d, k: d[k], (pkeys + identifiers)[:x+y], full_dict),
)
for x in range(len(pkeys) + 1)
for y in range(len(pkeys) + len(identifiers) - x + 1)
]
if print_args:
print("[")
for args in argvalues:
print(f" {args!r},")
print("]")
else:
return argvalues
generate_test_args()
"""
PARAMS = [
((), (), {"1": {"2": {"3": {"foo": "bar"}}}}),
((), (1,), {"2": {"3": {"foo": "bar"}}}),
((), (1, 2), {"3": {"foo": "bar"}}),
((), (1, 2, 3), {"foo": "bar"}),
((), (1, 2, 3, "foo"), "bar"),
((1,), (), {"2": {"3": {"foo": "bar"}}}),
((1,), (2,), {"3": {"foo": "bar"}}),
((1,), (2, 3), {"foo": "bar"}),
((1,), (2, 3, "foo"), "bar"),
((1, 2), (), {"3": {"foo": "bar"}}),
((1, 2), (3,), {"foo": "bar"}),
((1, 2), (3, "foo"), "bar"),
((1, 2, 3), (), {"foo": "bar"}),
((1, 2, 3), ("foo",), "bar"),
]
@pytest.mark.parametrize("pkeys, raw_args, result", PARAMS)
@pytest.mark.asyncio
async def test_config_custom_partial_pkeys_get(config, pkeys, raw_args, result):
# setup
config.init_custom("TEST", 3)
config.register_custom("TEST")
await config.custom("TEST", 1, 2, 3).set({"foo": "bar"})
group = config.custom("TEST", *pkeys)
assert await group.get_raw(*raw_args) == result
@pytest.mark.parametrize("pkeys, raw_args, result", PARAMS)
@pytest.mark.asyncio
async def test_config_custom_partial_pkeys_set(config, pkeys, raw_args, result):
# setup
config.init_custom("TEST", 3)
config.register_custom("TEST")
await config.custom("TEST", 1, 2, 3).set({"foo": "blah"})
group = config.custom("TEST", *pkeys)
await group.set_raw(*raw_args, value=result)
assert await group.get_raw(*raw_args) == result