[V3 Config] Driver code initial cleanup (#1315)

* Remove get_driver

* Rename self.driver to self._driver

* Do not unnecessarily pass the cog identifier

* Remove unused import

* Fix type annotation

* Missed a keyword rename

* Modify signature of get/set methods in drivers
This commit is contained in:
Will 2018-02-18 22:30:32 -05:00 committed by palmtree5
parent 3984cb8f48
commit f9d846a704
5 changed files with 51 additions and 55 deletions

View File

@ -1,13 +1,18 @@
import logging import logging
import collections import collections
from copy import deepcopy from copy import deepcopy
from typing import Callable, Union, Tuple from typing import Union, Tuple
import discord import discord
from .data_manager import cog_data_path, core_data_path from .data_manager import cog_data_path, core_data_path
from .drivers import get_driver from .drivers import get_driver
from .utils import TYPE_CHECKING
if TYPE_CHECKING:
from .drivers.red_base import BaseDriver
log = logging.getLogger("red.config") log = logging.getLogger("red.config")
@ -52,24 +57,23 @@ class Value:
element from a json document. element from a json document.
default default
The default value for the data element that `identifiers` points at. The default value for the data element that `identifiers` points at.
spawner : `redbot.core.drivers.red_base.BaseDriver` driver : `redbot.core.drivers.red_base.BaseDriver`
A reference to `Config.spawner`. A reference to `Config.driver`.
""" """
def __init__(self, identifiers: Tuple[str], default_value, spawner): def __init__(self, identifiers: Tuple[str], default_value, driver):
self._identifiers = identifiers self._identifiers = identifiers
self.default = default_value self.default = default_value
self.spawner = spawner self.driver = driver
@property @property
def identifiers(self): def identifiers(self):
return tuple(str(i) for i in self._identifiers) return tuple(str(i) for i in self._identifiers)
async def _get(self, default): async def _get(self, default):
driver = self.spawner.get_driver()
try: try:
ret = await driver.get(self.identifiers) ret = await self.driver.get(*self.identifiers)
except KeyError: except KeyError:
return default if default is not None else self.default return default if default is not None else self.default
return ret return ret
@ -138,8 +142,7 @@ class Value:
The new literal value of this attribute. The new literal value of this attribute.
""" """
driver = self.spawner.get_driver() await self.driver.set(*self.identifiers, value=value)
await driver.set(self.identifiers, value)
class Group(Value): class Group(Value):
@ -155,19 +158,19 @@ class Group(Value):
All registered default values for this Group. All registered default values for this Group.
force_registration : `bool` force_registration : `bool`
Same as `Config.force_registration`. Same as `Config.force_registration`.
spawner : `redbot.core.drivers.red_base.BaseDriver` driver : `redbot.core.drivers.red_base.BaseDriver`
A reference to `Config.spawner`. A reference to `Config.driver`.
""" """
def __init__(self, identifiers: Tuple[str], def __init__(self, identifiers: Tuple[str],
defaults: dict, defaults: dict,
spawner, driver,
force_registration: bool=False): force_registration: bool=False):
self._defaults = defaults self._defaults = defaults
self.force_registration = force_registration self.force_registration = force_registration
self.spawner = spawner self.driver = driver
super().__init__(identifiers, {}, self.spawner) super().__init__(identifiers, {}, self.driver)
@property @property
def defaults(self): def defaults(self):
@ -205,14 +208,14 @@ class Group(Value):
return Group( return Group(
identifiers=new_identifiers, identifiers=new_identifiers,
defaults=self._defaults[item], defaults=self._defaults[item],
spawner=self.spawner, driver=self.driver,
force_registration=self.force_registration force_registration=self.force_registration
) )
elif is_value: elif is_value:
return Value( return Value(
identifiers=new_identifiers, identifiers=new_identifiers,
default_value=self._defaults[item], default_value=self._defaults[item],
spawner=self.spawner driver=self.driver
) )
elif self.force_registration: elif self.force_registration:
raise AttributeError( raise AttributeError(
@ -223,7 +226,7 @@ class Group(Value):
return Value( return Value(
identifiers=new_identifiers, identifiers=new_identifiers,
default_value=None, default_value=None,
spawner=self.spawner driver=self.driver
) )
def is_group(self, item: str) -> bool: def is_group(self, item: str) -> bool:
@ -394,9 +397,8 @@ class Config:
unique_identifier : `int` unique_identifier : `int`
Unique identifier provided to differentiate cog data when name Unique identifier provided to differentiate cog data when name
conflicts occur. conflicts occur.
spawner driver
A callable object that returns some driver that implements An instance of a driver that implements `redbot.core.drivers.red_base.BaseDriver`.
`redbot.core.drivers.red_base.BaseDriver`.
force_registration : `bool` force_registration : `bool`
Determines if Config should throw an error if a cog attempts to access Determines if Config should throw an error if a cog attempts to access
an attribute which has not been previously registered. an attribute which has not been previously registered.
@ -416,13 +418,14 @@ class Config:
MEMBER = "MEMBER" MEMBER = "MEMBER"
def __init__(self, cog_name: str, unique_identifier: str, def __init__(self, cog_name: str, unique_identifier: str,
driver_spawn: Callable, driver: "BaseDriver",
force_registration: bool=False, force_registration: bool=False,
defaults: dict=None): defaults: dict=None):
self.cog_name = cog_name self.cog_name = cog_name
self.unique_identifier = unique_identifier self.unique_identifier = unique_identifier
self.spawner = driver_spawn self.driver = driver
self.driver.unique_cog_identifier = self.unique_identifier
self.force_registration = force_registration self.force_registration = force_registration
self._defaults = defaults or {} self._defaults = defaults or {}
@ -468,11 +471,11 @@ class Config:
log.debug("Using driver: '{}'".format(driver_name)) log.debug("Using driver: '{}'".format(driver_name))
spawner = get_driver(driver_name, cog_name, data_path_override=cog_path_override, driver = get_driver(driver_name, cog_name, data_path_override=cog_path_override,
**driver_details) **driver_details)
return cls(cog_name=cog_name, unique_identifier=uuid, return cls(cog_name=cog_name, unique_identifier=uuid,
force_registration=force_registration, force_registration=force_registration,
driver_spawn=spawner) driver=driver)
@classmethod @classmethod
def get_core_conf(cls, force_registration: bool=False): def get_core_conf(cls, force_registration: bool=False):
@ -495,9 +498,9 @@ class Config:
driver_name = basic_config.get('STORAGE_TYPE', 'JSON') driver_name = basic_config.get('STORAGE_TYPE', 'JSON')
driver_details = basic_config.get('STORAGE_DETAILS', {}) driver_details = basic_config.get('STORAGE_DETAILS', {})
driver_spawn = get_driver(driver_name, "Core", data_path_override=core_path, driver = get_driver(driver_name, "Core", data_path_override=core_path,
**driver_details) **driver_details)
return cls(cog_name="Core", driver_spawn=driver_spawn, return cls(cog_name="Core", driver=driver,
unique_identifier='0', unique_identifier='0',
force_registration=force_registration) force_registration=force_registration)
@ -668,9 +671,9 @@ class Config:
def _get_base_group(self, key: str, *identifiers: str) -> Group: def _get_base_group(self, key: str, *identifiers: str) -> Group:
# noinspection PyTypeChecker # noinspection PyTypeChecker
return Group( return Group(
identifiers=(self.unique_identifier, key) + identifiers, identifiers=(key, *identifiers),
defaults=self.defaults.get(key, {}), defaults=self.defaults.get(key, {}),
spawner=self.spawner, driver=self.driver,
force_registration=self.force_registration force_registration=self.force_registration
) )
@ -911,7 +914,7 @@ class Config:
if not scopes: if not scopes:
group = Group(identifiers=(self.unique_identifier, ), group = Group(identifiers=(self.unique_identifier, ),
defaults={}, defaults={},
spawner=self.spawner) driver=self.driver)
else: else:
group = self._get_base_group(*scopes) group = self._get_base_group(*scopes)
await group.set({}) await group.set({})

View File

@ -2,14 +2,13 @@ from typing import Tuple
__all__ = ["BaseDriver"] __all__ = ["BaseDriver"]
class BaseDriver: class BaseDriver:
def __init__(self, cog_name): def __init__(self, cog_name):
self.cog_name = cog_name self.cog_name = cog_name
self.unique_cog_identifier = None # This is set by Config's init method
def get_driver(self): async def get(self, *identifiers: Tuple[str]):
raise NotImplementedError
async def get(self, identifiers: Tuple[str]):
""" """
Finds the value indicate by the given identifiers. Finds the value indicate by the given identifiers.
@ -30,7 +29,7 @@ class BaseDriver:
""" """
raise NotImplementedError raise NotImplementedError
async def set(self, identifiers: Tuple[str], value): async def set(self, *identifiers: Tuple[str], value=None):
""" """
Sets the value of the key indicated by the given identifiers. Sets the value of the key indicated by the given identifiers.

View File

@ -41,21 +41,20 @@ class JSON(BaseDriver):
self.data = {} self.data = {}
self.jsonIO._save_json(self.data) self.jsonIO._save_json(self.data)
def get_driver(self): async def get(self, *identifiers: Tuple[str]):
return self
async def get(self, identifiers: Tuple[str]):
partial = self.data partial = self.data
for i in identifiers: full_identifiers = (self.unique_cog_identifier, *identifiers)
for i in full_identifiers:
partial = partial[i] partial = partial[i]
return partial return partial
async def set(self, identifiers, value): async def set(self, *identifiers: str, value=None):
partial = self.data partial = self.data
for i in identifiers[:-1]: full_identifiers = (self.unique_cog_identifier, *identifiers)
for i in full_identifiers[:-1]:
if i not in partial: if i not in partial:
partial[i] = {} partial[i] = {}
partial = partial[i] partial = partial[i]
partial[identifiers[-1]] = value partial[full_identifiers[-1]] = value
await self.jsonIO._threadsafe_save_json(self.data) await self.jsonIO._threadsafe_save_json(self.data)

View File

@ -71,16 +71,15 @@ class Mongo(BaseDriver):
uuid, identifiers = identifiers[0], identifiers[1:] uuid, identifiers = identifiers[0], identifiers[1:]
return uuid, identifiers return uuid, identifiers
async def get(self, identifiers: Tuple[str]): async def get(self, *identifiers: Tuple[str]):
await self._ensure_connected() await self._ensure_connected()
uuid, identifiers = self._parse_identifiers(identifiers)
mongo_collection = self.get_collection() mongo_collection = self.get_collection()
dot_identifiers = '.'.join(identifiers) dot_identifiers = '.'.join(identifiers)
partial = await mongo_collection.find_one( partial = await mongo_collection.find_one(
filter={'_id': uuid}, filter={'_id': self.unique_cog_identifier},
projection={dot_identifiers: True} projection={dot_identifiers: True}
) )
@ -92,23 +91,19 @@ class Mongo(BaseDriver):
partial = partial[i] partial = partial[i]
return partial return partial
async def set(self, identifiers: Tuple[str], value): async def set(self, *identifiers: str, value=None):
await self._ensure_connected() await self._ensure_connected()
uuid, identifiers = self._parse_identifiers(identifiers)
dot_identifiers = '.'.join(identifiers) dot_identifiers = '.'.join(identifiers)
mongo_collection = self.get_collection() mongo_collection = self.get_collection()
await mongo_collection.update_one( await mongo_collection.update_one(
{'_id': uuid}, {'_id': self.unique_cog_identifier},
update={"$set": {dot_identifiers: value}}, update={"$set": {dot_identifiers: value}},
upsert=True upsert=True
) )
def get_driver(self):
return self
def get_config_details(): def get_config_details():
host = input("Enter host address: ") host = input("Enter host address: ")

View File

@ -49,7 +49,7 @@ def config(json_driver):
conf = Config( conf = Config(
cog_name="PyTest", cog_name="PyTest",
unique_identifier=str(uuid.uuid4()), unique_identifier=str(uuid.uuid4()),
driver_spawn=json_driver) driver=json_driver)
yield conf yield conf
conf._defaults = {} conf._defaults = {}
@ -63,7 +63,7 @@ def config_fr(json_driver):
conf = Config( conf = Config(
cog_name="PyTest", cog_name="PyTest",
unique_identifier=str(uuid.uuid4()), unique_identifier=str(uuid.uuid4()),
driver_spawn=json_driver, driver=json_driver,
force_registration=True force_registration=True
) )
yield conf yield conf