[V3 Config] Require custom group initialization before usage (#2545)

* Require custom group initialization before usage and write that data to disk

* Style

* add tests

* remove custom info update method from drivers

* clean up remnant

* Turn config objects into a singleton to deal with custom group identifiers

* Fix dumbassery

* Stupid stupid stupid
This commit is contained in:
Will 2019-04-04 21:47:08 -04:00 committed by GitHub
parent fb722c79be
commit 0852d1be9f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 86 additions and 5 deletions

View File

@ -100,7 +100,9 @@ class Permissions(commands.Cog):
# Note that GLOBAL rules are denoted by an ID of 0. # Note that GLOBAL rules are denoted by an ID of 0.
self.config = config.Config.get_conf(self, identifier=78631113035100160) self.config = config.Config.get_conf(self, identifier=78631113035100160)
self.config.register_global(version="") self.config.register_global(version="")
self.config.init_custom(COG, 1)
self.config.register_custom(COG) self.config.register_custom(COG)
self.config.init_custom(COMMAND, 1)
self.config.register_custom(COMMAND) self.config.register_custom(COMMAND)
@commands.group() @commands.group()

View File

@ -45,6 +45,7 @@ class Reports(commands.Cog):
self.bot = bot self.bot = bot
self.config = Config.get_conf(self, 78631113035100160, force_registration=True) self.config = Config.get_conf(self, 78631113035100160, force_registration=True)
self.config.register_guild(**self.default_guild_settings) self.config.register_guild(**self.default_guild_settings)
self.config.init_custom("REPORT", 2)
self.config.register_custom("REPORT", **self.default_report) self.config.register_custom("REPORT", **self.default_report)
self.antispam = {} self.antispam = {}
self.user_cache = [] self.user_cache = []

View File

@ -2,6 +2,7 @@ import logging
import collections import collections
from copy import deepcopy from copy import deepcopy
from typing import Any, Union, Tuple, Dict, Awaitable, AsyncContextManager, TypeVar, TYPE_CHECKING from typing import Any, Union, Tuple, Dict, Awaitable, AsyncContextManager, TypeVar, TYPE_CHECKING
import weakref
import discord import discord
@ -15,6 +16,8 @@ log = logging.getLogger("red.config")
_T = TypeVar("_T") _T = TypeVar("_T")
_config_cache = weakref.WeakValueDictionary()
class _ValueCtxManager(Awaitable[_T], AsyncContextManager[_T]): class _ValueCtxManager(Awaitable[_T], AsyncContextManager[_T]):
"""Context manager implementation of config values. """Context manager implementation of config values.
@ -514,6 +517,19 @@ class Config:
USER = "USER" USER = "USER"
MEMBER = "MEMBER" MEMBER = "MEMBER"
def __new__(cls, cog_name, unique_identifier, *args, **kwargs):
key = (cog_name, unique_identifier)
if key[0] is None:
raise ValueError("You must provide either the cog instance or a cog name.")
if key in _config_cache:
conf = _config_cache[key]
else:
conf = object.__new__(cls)
_config_cache[key] = conf
return conf
def __init__( def __init__(
self, self,
cog_name: str, cog_name: str,
@ -529,6 +545,8 @@ class Config:
self.force_registration = force_registration self.force_registration = force_registration
self._defaults = defaults or {} self._defaults = defaults or {}
self.custom_groups = {}
@property @property
def defaults(self): def defaults(self):
return deepcopy(self._defaults) return deepcopy(self._defaults)
@ -788,13 +806,32 @@ class Config:
""" """
self._register_default(group_identifier, **kwargs) self._register_default(group_identifier, **kwargs)
def init_custom(self, group_identifier: str, identifier_count: int):
"""
Initializes a custom group for usage. This method must be called first!
"""
if group_identifier in self.custom_groups:
raise ValueError(f"Group identifier already registered: {group_identifier}")
self.custom_groups[group_identifier] = identifier_count
def _get_base_group(self, category: str, *primary_keys: str) -> Group: def _get_base_group(self, category: str, *primary_keys: str) -> Group:
is_custom = category not in (
self.GLOBAL,
self.GUILD,
self.USER,
self.MEMBER,
self.ROLE,
self.CHANNEL,
)
# noinspection PyTypeChecker # noinspection PyTypeChecker
identifier_data = IdentifierData( identifier_data = IdentifierData(
uuid=self.unique_identifier, uuid=self.unique_identifier,
category=category, category=category,
primary_key=primary_keys, primary_key=primary_keys,
identifiers=(), identifiers=(),
custom_group_data=self.custom_groups,
is_custom=is_custom,
) )
return Group( return Group(
identifier_data=identifier_data, identifier_data=identifier_data,
@ -902,6 +939,8 @@ class Config:
The custom group's Group object. The custom group's Group object.
""" """
if group_identifier not in self.custom_groups:
raise ValueError(f"Group identifier not initialized: {group_identifier}")
return self._get_base_group(str(group_identifier), *map(str, identifiers)) return self._get_base_group(str(group_identifier), *map(str, identifiers))
async def _all_from_scope(self, scope: str) -> Dict[int, Dict[Any, Any]]: async def _all_from_scope(self, scope: str) -> Dict[int, Dict[Any, Any]]:
@ -1072,7 +1111,9 @@ class Config:
""" """
if not scopes: if not scopes:
# noinspection PyTypeChecker # noinspection PyTypeChecker
identifier_data = IdentifierData(self.unique_identifier, "", (), ()) identifier_data = IdentifierData(
self.unique_identifier, "", (), (), self.custom_groups
)
group = Group(identifier_data, defaults={}, driver=self.driver) group = Group(identifier_data, defaults={}, driver=self.driver)
else: else:
group = self._get_base_group(*scopes) group = self._get_base_group(*scopes)

View File

@ -4,11 +4,21 @@ __all__ = ["BaseDriver", "IdentifierData"]
class IdentifierData: class IdentifierData:
def __init__(self, uuid: str, category: str, primary_key: Tuple[str], identifiers: Tuple[str]): def __init__(
self,
uuid: str,
category: str,
primary_key: Tuple[str],
identifiers: Tuple[str],
custom_group_data: dict,
is_custom: bool = False,
):
self._uuid = uuid self._uuid = uuid
self._category = category self._category = category
self._primary_key = primary_key self._primary_key = primary_key
self._identifiers = identifiers self._identifiers = identifiers
self.custom_group_data = custom_group_data
self._is_custom = is_custom
@property @property
def uuid(self): def uuid(self):
@ -26,6 +36,10 @@ class IdentifierData:
def identifiers(self): def identifiers(self):
return self._identifiers return self._identifiers
@property
def is_custom(self):
return self._is_custom
def __repr__(self): def __repr__(self):
return ( return (
f"<IdentifierData uuid={self.uuid} category={self.category} primary_key={self.primary_key}" f"<IdentifierData uuid={self.uuid} category={self.category} primary_key={self.primary_key}"
@ -37,7 +51,12 @@ class IdentifierData:
raise ValueError("Identifiers must be strings.") raise ValueError("Identifiers must be strings.")
return IdentifierData( return IdentifierData(
self.uuid, self.category, self.primary_key, self.identifiers + identifier self.uuid,
self.category,
self.primary_key,
self.identifiers + identifier,
self.custom_group_data,
is_custom=self.is_custom,
) )
def to_tuple(self): def to_tuple(self):

View File

@ -1,11 +1,13 @@
import random import random
from collections import namedtuple from collections import namedtuple
from pathlib import Path from pathlib import Path
import weakref
import pytest import pytest
from _pytest.monkeypatch import MonkeyPatch from _pytest.monkeypatch import MonkeyPatch
from redbot.core import Config from redbot.core import Config
from redbot.core.bot import Red from redbot.core.bot import Red
from redbot.core import config as config_module
from redbot.core.drivers import red_json from redbot.core.drivers import red_json
@ -65,11 +67,11 @@ def json_driver(tmpdir_factory):
@pytest.fixture() @pytest.fixture()
def config(json_driver): def config(json_driver):
config_module._config_cache = weakref.WeakValueDictionary()
conf = Config( conf = Config(
cog_name="PyTest", unique_identifier=json_driver.unique_cog_identifier, driver=json_driver cog_name="PyTest", unique_identifier=json_driver.unique_cog_identifier, driver=json_driver
) )
yield conf yield conf
conf._defaults = {}
@pytest.fixture() @pytest.fixture()
@ -77,6 +79,7 @@ def config_fr(json_driver):
""" """
Mocked config object with force_register enabled. Mocked config object with force_register enabled.
""" """
config_module._config_cache = weakref.WeakValueDictionary()
conf = Config( conf = Config(
cog_name="PyTest", cog_name="PyTest",
unique_identifier=json_driver.unique_cog_identifier, unique_identifier=json_driver.unique_cog_identifier,
@ -84,7 +87,6 @@ def config_fr(json_driver):
force_registration=True, force_registration=True,
) )
yield conf yield conf
conf._defaults = {}
# region Dpy Mocks # region Dpy Mocks

View File

@ -490,3 +490,19 @@ async def test_cast_str_nested(config):
config.register_global(foo={}) config.register_global(foo={})
await config.foo.set({123: True, 456: {789: False}}) await config.foo.set({123: True, 456: {789: False}})
assert await config.foo() == {"123": True, "456": {"789": False}} assert await config.foo() == {"123": True, "456": {"789": False}}
def test_config_custom_noinit(config):
with pytest.raises(ValueError):
config.custom("TEST", 1, 2, 3)
def test_config_custom_init(config):
config.init_custom("TEST", 3)
config.custom("TEST", 1, 2, 3)
def test_config_custom_doubleinit(config):
config.init_custom("TEST", 3)
with pytest.raises(ValueError):
config.init_custom("TEST", 2)