[V3 Config] Driver code initial cleanup (#1315)

* Remove get_driver

* Rename self.driver to self._driver

* Do not unnecessarily pass the cog identifier

* Remove unused import

* Fix type annotation

* Missed a keyword rename

* Modify signature of get/set methods in drivers
This commit is contained in:
Will 2018-02-18 22:30:32 -05:00 committed by palmtree5
parent 3984cb8f48
commit f9d846a704
5 changed files with 51 additions and 55 deletions

View File

@ -1,13 +1,18 @@
import logging
import collections
from copy import deepcopy
from typing import Callable, Union, Tuple
from typing import Union, Tuple
import discord
from .data_manager import cog_data_path, core_data_path
from .drivers import get_driver
from .utils import TYPE_CHECKING
if TYPE_CHECKING:
from .drivers.red_base import BaseDriver
log = logging.getLogger("red.config")
@ -52,24 +57,23 @@ class Value:
element from a json document.
default
The default value for the data element that `identifiers` points at.
spawner : `redbot.core.drivers.red_base.BaseDriver`
A reference to `Config.spawner`.
driver : `redbot.core.drivers.red_base.BaseDriver`
A reference to `Config.driver`.
"""
def __init__(self, identifiers: Tuple[str], default_value, spawner):
def __init__(self, identifiers: Tuple[str], default_value, driver):
self._identifiers = identifiers
self.default = default_value
self.spawner = spawner
self.driver = driver
@property
def identifiers(self):
return tuple(str(i) for i in self._identifiers)
async def _get(self, default):
driver = self.spawner.get_driver()
try:
ret = await driver.get(self.identifiers)
ret = await self.driver.get(*self.identifiers)
except KeyError:
return default if default is not None else self.default
return ret
@ -138,8 +142,7 @@ class Value:
The new literal value of this attribute.
"""
driver = self.spawner.get_driver()
await driver.set(self.identifiers, value)
await self.driver.set(*self.identifiers, value=value)
class Group(Value):
@ -155,19 +158,19 @@ class Group(Value):
All registered default values for this Group.
force_registration : `bool`
Same as `Config.force_registration`.
spawner : `redbot.core.drivers.red_base.BaseDriver`
A reference to `Config.spawner`.
driver : `redbot.core.drivers.red_base.BaseDriver`
A reference to `Config.driver`.
"""
def __init__(self, identifiers: Tuple[str],
defaults: dict,
spawner,
driver,
force_registration: bool=False):
self._defaults = defaults
self.force_registration = force_registration
self.spawner = spawner
self.driver = driver
super().__init__(identifiers, {}, self.spawner)
super().__init__(identifiers, {}, self.driver)
@property
def defaults(self):
@ -205,14 +208,14 @@ class Group(Value):
return Group(
identifiers=new_identifiers,
defaults=self._defaults[item],
spawner=self.spawner,
driver=self.driver,
force_registration=self.force_registration
)
elif is_value:
return Value(
identifiers=new_identifiers,
default_value=self._defaults[item],
spawner=self.spawner
driver=self.driver
)
elif self.force_registration:
raise AttributeError(
@ -223,7 +226,7 @@ class Group(Value):
return Value(
identifiers=new_identifiers,
default_value=None,
spawner=self.spawner
driver=self.driver
)
def is_group(self, item: str) -> bool:
@ -394,9 +397,8 @@ class Config:
unique_identifier : `int`
Unique identifier provided to differentiate cog data when name
conflicts occur.
spawner
A callable object that returns some driver that implements
`redbot.core.drivers.red_base.BaseDriver`.
driver
An instance of a driver that implements `redbot.core.drivers.red_base.BaseDriver`.
force_registration : `bool`
Determines if Config should throw an error if a cog attempts to access
an attribute which has not been previously registered.
@ -416,13 +418,14 @@ class Config:
MEMBER = "MEMBER"
def __init__(self, cog_name: str, unique_identifier: str,
driver_spawn: Callable,
driver: "BaseDriver",
force_registration: bool=False,
defaults: dict=None):
self.cog_name = cog_name
self.unique_identifier = unique_identifier
self.spawner = driver_spawn
self.driver = driver
self.driver.unique_cog_identifier = self.unique_identifier
self.force_registration = force_registration
self._defaults = defaults or {}
@ -468,11 +471,11 @@ class Config:
log.debug("Using driver: '{}'".format(driver_name))
spawner = get_driver(driver_name, cog_name, data_path_override=cog_path_override,
driver = get_driver(driver_name, cog_name, data_path_override=cog_path_override,
**driver_details)
return cls(cog_name=cog_name, unique_identifier=uuid,
force_registration=force_registration,
driver_spawn=spawner)
driver=driver)
@classmethod
def get_core_conf(cls, force_registration: bool=False):
@ -495,9 +498,9 @@ class Config:
driver_name = basic_config.get('STORAGE_TYPE', 'JSON')
driver_details = basic_config.get('STORAGE_DETAILS', {})
driver_spawn = get_driver(driver_name, "Core", data_path_override=core_path,
driver = get_driver(driver_name, "Core", data_path_override=core_path,
**driver_details)
return cls(cog_name="Core", driver_spawn=driver_spawn,
return cls(cog_name="Core", driver=driver,
unique_identifier='0',
force_registration=force_registration)
@ -668,9 +671,9 @@ class Config:
def _get_base_group(self, key: str, *identifiers: str) -> Group:
# noinspection PyTypeChecker
return Group(
identifiers=(self.unique_identifier, key) + identifiers,
identifiers=(key, *identifiers),
defaults=self.defaults.get(key, {}),
spawner=self.spawner,
driver=self.driver,
force_registration=self.force_registration
)
@ -911,7 +914,7 @@ class Config:
if not scopes:
group = Group(identifiers=(self.unique_identifier, ),
defaults={},
spawner=self.spawner)
driver=self.driver)
else:
group = self._get_base_group(*scopes)
await group.set({})

View File

@ -2,14 +2,13 @@ from typing import Tuple
__all__ = ["BaseDriver"]
class BaseDriver:
def __init__(self, cog_name):
self.cog_name = cog_name
self.unique_cog_identifier = None # This is set by Config's init method
def get_driver(self):
raise NotImplementedError
async def get(self, identifiers: Tuple[str]):
async def get(self, *identifiers: Tuple[str]):
"""
Finds the value indicate by the given identifiers.
@ -30,7 +29,7 @@ class BaseDriver:
"""
raise NotImplementedError
async def set(self, identifiers: Tuple[str], value):
async def set(self, *identifiers: Tuple[str], value=None):
"""
Sets the value of the key indicated by the given identifiers.

View File

@ -41,21 +41,20 @@ class JSON(BaseDriver):
self.data = {}
self.jsonIO._save_json(self.data)
def get_driver(self):
return self
async def get(self, identifiers: Tuple[str]):
async def get(self, *identifiers: Tuple[str]):
partial = self.data
for i in identifiers:
full_identifiers = (self.unique_cog_identifier, *identifiers)
for i in full_identifiers:
partial = partial[i]
return partial
async def set(self, identifiers, value):
async def set(self, *identifiers: str, value=None):
partial = self.data
for i in identifiers[:-1]:
full_identifiers = (self.unique_cog_identifier, *identifiers)
for i in full_identifiers[:-1]:
if i not in partial:
partial[i] = {}
partial = partial[i]
partial[identifiers[-1]] = value
partial[full_identifiers[-1]] = value
await self.jsonIO._threadsafe_save_json(self.data)

View File

@ -71,16 +71,15 @@ class Mongo(BaseDriver):
uuid, identifiers = identifiers[0], identifiers[1:]
return uuid, identifiers
async def get(self, identifiers: Tuple[str]):
async def get(self, *identifiers: Tuple[str]):
await self._ensure_connected()
uuid, identifiers = self._parse_identifiers(identifiers)
mongo_collection = self.get_collection()
dot_identifiers = '.'.join(identifiers)
partial = await mongo_collection.find_one(
filter={'_id': uuid},
filter={'_id': self.unique_cog_identifier},
projection={dot_identifiers: True}
)
@ -92,23 +91,19 @@ class Mongo(BaseDriver):
partial = partial[i]
return partial
async def set(self, identifiers: Tuple[str], value):
async def set(self, *identifiers: str, value=None):
await self._ensure_connected()
uuid, identifiers = self._parse_identifiers(identifiers)
dot_identifiers = '.'.join(identifiers)
mongo_collection = self.get_collection()
await mongo_collection.update_one(
{'_id': uuid},
{'_id': self.unique_cog_identifier},
update={"$set": {dot_identifiers: value}},
upsert=True
)
def get_driver(self):
return self
def get_config_details():
host = input("Enter host address: ")

View File

@ -49,7 +49,7 @@ def config(json_driver):
conf = Config(
cog_name="PyTest",
unique_identifier=str(uuid.uuid4()),
driver_spawn=json_driver)
driver=json_driver)
yield conf
conf._defaults = {}
@ -63,7 +63,7 @@ def config_fr(json_driver):
conf = Config(
cog_name="PyTest",
unique_identifier=str(uuid.uuid4()),
driver_spawn=json_driver,
driver=json_driver,
force_registration=True
)
yield conf