diff --git a/redbot/core/config.py b/redbot/core/config.py index 0528b1583..71b4055a2 100644 --- a/redbot/core/config.py +++ b/redbot/core/config.py @@ -1,13 +1,18 @@ import logging import collections from copy import deepcopy -from typing import Callable, Union, Tuple +from typing import Union, Tuple import discord from .data_manager import cog_data_path, core_data_path from .drivers import get_driver +from .utils import TYPE_CHECKING + +if TYPE_CHECKING: + from .drivers.red_base import BaseDriver + log = logging.getLogger("red.config") @@ -52,24 +57,23 @@ class Value: element from a json document. default The default value for the data element that `identifiers` points at. - spawner : `redbot.core.drivers.red_base.BaseDriver` - A reference to `Config.spawner`. + driver : `redbot.core.drivers.red_base.BaseDriver` + A reference to `Config.driver`. """ - def __init__(self, identifiers: Tuple[str], default_value, spawner): + def __init__(self, identifiers: Tuple[str], default_value, driver): self._identifiers = identifiers self.default = default_value - self.spawner = spawner + self.driver = driver @property def identifiers(self): return tuple(str(i) for i in self._identifiers) async def _get(self, default): - driver = self.spawner.get_driver() try: - ret = await driver.get(self.identifiers) + ret = await self.driver.get(*self.identifiers) except KeyError: return default if default is not None else self.default return ret @@ -138,8 +142,7 @@ class Value: The new literal value of this attribute. """ - driver = self.spawner.get_driver() - await driver.set(self.identifiers, value) + await self.driver.set(*self.identifiers, value=value) class Group(Value): @@ -155,19 +158,19 @@ class Group(Value): All registered default values for this Group. force_registration : `bool` Same as `Config.force_registration`. - spawner : `redbot.core.drivers.red_base.BaseDriver` - A reference to `Config.spawner`. + driver : `redbot.core.drivers.red_base.BaseDriver` + A reference to `Config.driver`. """ def __init__(self, identifiers: Tuple[str], defaults: dict, - spawner, + driver, force_registration: bool=False): self._defaults = defaults self.force_registration = force_registration - self.spawner = spawner + self.driver = driver - super().__init__(identifiers, {}, self.spawner) + super().__init__(identifiers, {}, self.driver) @property def defaults(self): @@ -205,14 +208,14 @@ class Group(Value): return Group( identifiers=new_identifiers, defaults=self._defaults[item], - spawner=self.spawner, + driver=self.driver, force_registration=self.force_registration ) elif is_value: return Value( identifiers=new_identifiers, default_value=self._defaults[item], - spawner=self.spawner + driver=self.driver ) elif self.force_registration: raise AttributeError( @@ -223,7 +226,7 @@ class Group(Value): return Value( identifiers=new_identifiers, default_value=None, - spawner=self.spawner + driver=self.driver ) def is_group(self, item: str) -> bool: @@ -394,9 +397,8 @@ class Config: unique_identifier : `int` Unique identifier provided to differentiate cog data when name conflicts occur. - spawner - A callable object that returns some driver that implements - `redbot.core.drivers.red_base.BaseDriver`. + driver + An instance of a driver that implements `redbot.core.drivers.red_base.BaseDriver`. force_registration : `bool` Determines if Config should throw an error if a cog attempts to access an attribute which has not been previously registered. @@ -416,13 +418,14 @@ class Config: MEMBER = "MEMBER" def __init__(self, cog_name: str, unique_identifier: str, - driver_spawn: Callable, + driver: "BaseDriver", force_registration: bool=False, defaults: dict=None): self.cog_name = cog_name self.unique_identifier = unique_identifier - self.spawner = driver_spawn + self.driver = driver + self.driver.unique_cog_identifier = self.unique_identifier self.force_registration = force_registration self._defaults = defaults or {} @@ -468,11 +471,11 @@ class Config: log.debug("Using driver: '{}'".format(driver_name)) - spawner = get_driver(driver_name, cog_name, data_path_override=cog_path_override, - **driver_details) + driver = get_driver(driver_name, cog_name, data_path_override=cog_path_override, + **driver_details) return cls(cog_name=cog_name, unique_identifier=uuid, force_registration=force_registration, - driver_spawn=spawner) + driver=driver) @classmethod def get_core_conf(cls, force_registration: bool=False): @@ -495,9 +498,9 @@ class Config: driver_name = basic_config.get('STORAGE_TYPE', 'JSON') driver_details = basic_config.get('STORAGE_DETAILS', {}) - driver_spawn = get_driver(driver_name, "Core", data_path_override=core_path, - **driver_details) - return cls(cog_name="Core", driver_spawn=driver_spawn, + driver = get_driver(driver_name, "Core", data_path_override=core_path, + **driver_details) + return cls(cog_name="Core", driver=driver, unique_identifier='0', force_registration=force_registration) @@ -668,9 +671,9 @@ class Config: def _get_base_group(self, key: str, *identifiers: str) -> Group: # noinspection PyTypeChecker return Group( - identifiers=(self.unique_identifier, key) + identifiers, + identifiers=(key, *identifiers), defaults=self.defaults.get(key, {}), - spawner=self.spawner, + driver=self.driver, force_registration=self.force_registration ) @@ -911,7 +914,7 @@ class Config: if not scopes: group = Group(identifiers=(self.unique_identifier, ), defaults={}, - spawner=self.spawner) + driver=self.driver) else: group = self._get_base_group(*scopes) await group.set({}) diff --git a/redbot/core/drivers/red_base.py b/redbot/core/drivers/red_base.py index 119270cc2..a94f86d0c 100644 --- a/redbot/core/drivers/red_base.py +++ b/redbot/core/drivers/red_base.py @@ -2,14 +2,13 @@ from typing import Tuple __all__ = ["BaseDriver"] + class BaseDriver: def __init__(self, cog_name): self.cog_name = cog_name + self.unique_cog_identifier = None # This is set by Config's init method - def get_driver(self): - raise NotImplementedError - - async def get(self, identifiers: Tuple[str]): + async def get(self, *identifiers: Tuple[str]): """ Finds the value indicate by the given identifiers. @@ -30,7 +29,7 @@ class BaseDriver: """ raise NotImplementedError - async def set(self, identifiers: Tuple[str], value): + async def set(self, *identifiers: Tuple[str], value=None): """ Sets the value of the key indicated by the given identifiers. diff --git a/redbot/core/drivers/red_json.py b/redbot/core/drivers/red_json.py index f029513ef..fc2a4ec69 100644 --- a/redbot/core/drivers/red_json.py +++ b/redbot/core/drivers/red_json.py @@ -41,21 +41,20 @@ class JSON(BaseDriver): self.data = {} self.jsonIO._save_json(self.data) - def get_driver(self): - return self - - async def get(self, identifiers: Tuple[str]): + async def get(self, *identifiers: Tuple[str]): partial = self.data - for i in identifiers: + full_identifiers = (self.unique_cog_identifier, *identifiers) + for i in full_identifiers: partial = partial[i] return partial - async def set(self, identifiers, value): + async def set(self, *identifiers: str, value=None): partial = self.data - for i in identifiers[:-1]: + full_identifiers = (self.unique_cog_identifier, *identifiers) + for i in full_identifiers[:-1]: if i not in partial: partial[i] = {} partial = partial[i] - partial[identifiers[-1]] = value + partial[full_identifiers[-1]] = value await self.jsonIO._threadsafe_save_json(self.data) diff --git a/redbot/core/drivers/red_mongo.py b/redbot/core/drivers/red_mongo.py index 33eed5788..73046680a 100644 --- a/redbot/core/drivers/red_mongo.py +++ b/redbot/core/drivers/red_mongo.py @@ -71,16 +71,15 @@ class Mongo(BaseDriver): uuid, identifiers = identifiers[0], identifiers[1:] return uuid, identifiers - async def get(self, identifiers: Tuple[str]): + async def get(self, *identifiers: Tuple[str]): await self._ensure_connected() - uuid, identifiers = self._parse_identifiers(identifiers) mongo_collection = self.get_collection() dot_identifiers = '.'.join(identifiers) partial = await mongo_collection.find_one( - filter={'_id': uuid}, + filter={'_id': self.unique_cog_identifier}, projection={dot_identifiers: True} ) @@ -92,23 +91,19 @@ class Mongo(BaseDriver): partial = partial[i] return partial - async def set(self, identifiers: Tuple[str], value): + async def set(self, *identifiers: str, value=None): await self._ensure_connected() - uuid, identifiers = self._parse_identifiers(identifiers) dot_identifiers = '.'.join(identifiers) mongo_collection = self.get_collection() await mongo_collection.update_one( - {'_id': uuid}, + {'_id': self.unique_cog_identifier}, update={"$set": {dot_identifiers: value}}, upsert=True ) - def get_driver(self): - return self - def get_config_details(): host = input("Enter host address: ") diff --git a/tests/conftest.py b/tests/conftest.py index 8e55c90c7..af102e2e1 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -49,7 +49,7 @@ def config(json_driver): conf = Config( cog_name="PyTest", unique_identifier=str(uuid.uuid4()), - driver_spawn=json_driver) + driver=json_driver) yield conf conf._defaults = {} @@ -63,7 +63,7 @@ def config_fr(json_driver): conf = Config( cog_name="PyTest", unique_identifier=str(uuid.uuid4()), - driver_spawn=json_driver, + driver=json_driver, force_registration=True ) yield conf