[V3 Config] Require custom group initialization before usage (#2545)

* Require custom group initialization before usage and write that data to disk

* Style

* add tests

* remove custom info update method from drivers

* clean up remnant

* Turn config objects into a singleton to deal with custom group identifiers

* Fix dumbassery

* Stupid stupid stupid
This commit is contained in:
Will
2019-04-04 21:47:08 -04:00
committed by GitHub
parent fb722c79be
commit 0852d1be9f
6 changed files with 86 additions and 5 deletions

View File

@@ -2,6 +2,7 @@ import logging
import collections
from copy import deepcopy
from typing import Any, Union, Tuple, Dict, Awaitable, AsyncContextManager, TypeVar, TYPE_CHECKING
import weakref
import discord
@@ -15,6 +16,8 @@ log = logging.getLogger("red.config")
_T = TypeVar("_T")
_config_cache = weakref.WeakValueDictionary()
class _ValueCtxManager(Awaitable[_T], AsyncContextManager[_T]):
"""Context manager implementation of config values.
@@ -514,6 +517,19 @@ class Config:
USER = "USER"
MEMBER = "MEMBER"
def __new__(cls, cog_name, unique_identifier, *args, **kwargs):
key = (cog_name, unique_identifier)
if key[0] is None:
raise ValueError("You must provide either the cog instance or a cog name.")
if key in _config_cache:
conf = _config_cache[key]
else:
conf = object.__new__(cls)
_config_cache[key] = conf
return conf
def __init__(
self,
cog_name: str,
@@ -529,6 +545,8 @@ class Config:
self.force_registration = force_registration
self._defaults = defaults or {}
self.custom_groups = {}
@property
def defaults(self):
return deepcopy(self._defaults)
@@ -788,13 +806,32 @@ class Config:
"""
self._register_default(group_identifier, **kwargs)
def init_custom(self, group_identifier: str, identifier_count: int):
"""
Initializes a custom group for usage. This method must be called first!
"""
if group_identifier in self.custom_groups:
raise ValueError(f"Group identifier already registered: {group_identifier}")
self.custom_groups[group_identifier] = identifier_count
def _get_base_group(self, category: str, *primary_keys: str) -> Group:
is_custom = category not in (
self.GLOBAL,
self.GUILD,
self.USER,
self.MEMBER,
self.ROLE,
self.CHANNEL,
)
# noinspection PyTypeChecker
identifier_data = IdentifierData(
uuid=self.unique_identifier,
category=category,
primary_key=primary_keys,
identifiers=(),
custom_group_data=self.custom_groups,
is_custom=is_custom,
)
return Group(
identifier_data=identifier_data,
@@ -902,6 +939,8 @@ class Config:
The custom group's Group object.
"""
if group_identifier not in self.custom_groups:
raise ValueError(f"Group identifier not initialized: {group_identifier}")
return self._get_base_group(str(group_identifier), *map(str, identifiers))
async def _all_from_scope(self, scope: str) -> Dict[int, Dict[Any, Any]]:
@@ -1072,7 +1111,9 @@ class Config:
"""
if not scopes:
# noinspection PyTypeChecker
identifier_data = IdentifierData(self.unique_identifier, "", (), ())
identifier_data = IdentifierData(
self.unique_identifier, "", (), (), self.custom_groups
)
group = Group(identifier_data, defaults={}, driver=self.driver)
else:
group = self._get_base_group(*scopes)

View File

@@ -4,11 +4,21 @@ __all__ = ["BaseDriver", "IdentifierData"]
class IdentifierData:
def __init__(self, uuid: str, category: str, primary_key: Tuple[str], identifiers: Tuple[str]):
def __init__(
self,
uuid: str,
category: str,
primary_key: Tuple[str],
identifiers: Tuple[str],
custom_group_data: dict,
is_custom: bool = False,
):
self._uuid = uuid
self._category = category
self._primary_key = primary_key
self._identifiers = identifiers
self.custom_group_data = custom_group_data
self._is_custom = is_custom
@property
def uuid(self):
@@ -26,6 +36,10 @@ class IdentifierData:
def identifiers(self):
return self._identifiers
@property
def is_custom(self):
return self._is_custom
def __repr__(self):
return (
f"<IdentifierData uuid={self.uuid} category={self.category} primary_key={self.primary_key}"
@@ -37,7 +51,12 @@ class IdentifierData:
raise ValueError("Identifiers must be strings.")
return IdentifierData(
self.uuid, self.category, self.primary_key, self.identifiers + identifier
self.uuid,
self.category,
self.primary_key,
self.identifiers + identifier,
self.custom_group_data,
is_custom=self.is_custom,
)
def to_tuple(self):