mirror of
https://github.com/Cog-Creators/Red-DiscordBot.git
synced 2025-11-06 03:08:55 -05:00
[V3 Config] Driver code initial cleanup (#1315)
* Remove get_driver * Rename self.driver to self._driver * Do not unnecessarily pass the cog identifier * Remove unused import * Fix type annotation * Missed a keyword rename * Modify signature of get/set methods in drivers
This commit is contained in:
parent
3984cb8f48
commit
f9d846a704
@ -1,13 +1,18 @@
|
||||
import logging
|
||||
import collections
|
||||
from copy import deepcopy
|
||||
from typing import Callable, Union, Tuple
|
||||
from typing import Union, Tuple
|
||||
|
||||
import discord
|
||||
|
||||
from .data_manager import cog_data_path, core_data_path
|
||||
from .drivers import get_driver
|
||||
|
||||
from .utils import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .drivers.red_base import BaseDriver
|
||||
|
||||
log = logging.getLogger("red.config")
|
||||
|
||||
|
||||
@ -52,24 +57,23 @@ class Value:
|
||||
element from a json document.
|
||||
default
|
||||
The default value for the data element that `identifiers` points at.
|
||||
spawner : `redbot.core.drivers.red_base.BaseDriver`
|
||||
A reference to `Config.spawner`.
|
||||
driver : `redbot.core.drivers.red_base.BaseDriver`
|
||||
A reference to `Config.driver`.
|
||||
|
||||
"""
|
||||
def __init__(self, identifiers: Tuple[str], default_value, spawner):
|
||||
def __init__(self, identifiers: Tuple[str], default_value, driver):
|
||||
self._identifiers = identifiers
|
||||
self.default = default_value
|
||||
|
||||
self.spawner = spawner
|
||||
self.driver = driver
|
||||
|
||||
@property
|
||||
def identifiers(self):
|
||||
return tuple(str(i) for i in self._identifiers)
|
||||
|
||||
async def _get(self, default):
|
||||
driver = self.spawner.get_driver()
|
||||
try:
|
||||
ret = await driver.get(self.identifiers)
|
||||
ret = await self.driver.get(*self.identifiers)
|
||||
except KeyError:
|
||||
return default if default is not None else self.default
|
||||
return ret
|
||||
@ -138,8 +142,7 @@ class Value:
|
||||
The new literal value of this attribute.
|
||||
|
||||
"""
|
||||
driver = self.spawner.get_driver()
|
||||
await driver.set(self.identifiers, value)
|
||||
await self.driver.set(*self.identifiers, value=value)
|
||||
|
||||
|
||||
class Group(Value):
|
||||
@ -155,19 +158,19 @@ class Group(Value):
|
||||
All registered default values for this Group.
|
||||
force_registration : `bool`
|
||||
Same as `Config.force_registration`.
|
||||
spawner : `redbot.core.drivers.red_base.BaseDriver`
|
||||
A reference to `Config.spawner`.
|
||||
driver : `redbot.core.drivers.red_base.BaseDriver`
|
||||
A reference to `Config.driver`.
|
||||
|
||||
"""
|
||||
def __init__(self, identifiers: Tuple[str],
|
||||
defaults: dict,
|
||||
spawner,
|
||||
driver,
|
||||
force_registration: bool=False):
|
||||
self._defaults = defaults
|
||||
self.force_registration = force_registration
|
||||
self.spawner = spawner
|
||||
self.driver = driver
|
||||
|
||||
super().__init__(identifiers, {}, self.spawner)
|
||||
super().__init__(identifiers, {}, self.driver)
|
||||
|
||||
@property
|
||||
def defaults(self):
|
||||
@ -205,14 +208,14 @@ class Group(Value):
|
||||
return Group(
|
||||
identifiers=new_identifiers,
|
||||
defaults=self._defaults[item],
|
||||
spawner=self.spawner,
|
||||
driver=self.driver,
|
||||
force_registration=self.force_registration
|
||||
)
|
||||
elif is_value:
|
||||
return Value(
|
||||
identifiers=new_identifiers,
|
||||
default_value=self._defaults[item],
|
||||
spawner=self.spawner
|
||||
driver=self.driver
|
||||
)
|
||||
elif self.force_registration:
|
||||
raise AttributeError(
|
||||
@ -223,7 +226,7 @@ class Group(Value):
|
||||
return Value(
|
||||
identifiers=new_identifiers,
|
||||
default_value=None,
|
||||
spawner=self.spawner
|
||||
driver=self.driver
|
||||
)
|
||||
|
||||
def is_group(self, item: str) -> bool:
|
||||
@ -394,9 +397,8 @@ class Config:
|
||||
unique_identifier : `int`
|
||||
Unique identifier provided to differentiate cog data when name
|
||||
conflicts occur.
|
||||
spawner
|
||||
A callable object that returns some driver that implements
|
||||
`redbot.core.drivers.red_base.BaseDriver`.
|
||||
driver
|
||||
An instance of a driver that implements `redbot.core.drivers.red_base.BaseDriver`.
|
||||
force_registration : `bool`
|
||||
Determines if Config should throw an error if a cog attempts to access
|
||||
an attribute which has not been previously registered.
|
||||
@ -416,13 +418,14 @@ class Config:
|
||||
MEMBER = "MEMBER"
|
||||
|
||||
def __init__(self, cog_name: str, unique_identifier: str,
|
||||
driver_spawn: Callable,
|
||||
driver: "BaseDriver",
|
||||
force_registration: bool=False,
|
||||
defaults: dict=None):
|
||||
self.cog_name = cog_name
|
||||
self.unique_identifier = unique_identifier
|
||||
|
||||
self.spawner = driver_spawn
|
||||
self.driver = driver
|
||||
self.driver.unique_cog_identifier = self.unique_identifier
|
||||
self.force_registration = force_registration
|
||||
self._defaults = defaults or {}
|
||||
|
||||
@ -468,11 +471,11 @@ class Config:
|
||||
|
||||
log.debug("Using driver: '{}'".format(driver_name))
|
||||
|
||||
spawner = get_driver(driver_name, cog_name, data_path_override=cog_path_override,
|
||||
driver = get_driver(driver_name, cog_name, data_path_override=cog_path_override,
|
||||
**driver_details)
|
||||
return cls(cog_name=cog_name, unique_identifier=uuid,
|
||||
force_registration=force_registration,
|
||||
driver_spawn=spawner)
|
||||
driver=driver)
|
||||
|
||||
@classmethod
|
||||
def get_core_conf(cls, force_registration: bool=False):
|
||||
@ -495,9 +498,9 @@ class Config:
|
||||
driver_name = basic_config.get('STORAGE_TYPE', 'JSON')
|
||||
driver_details = basic_config.get('STORAGE_DETAILS', {})
|
||||
|
||||
driver_spawn = get_driver(driver_name, "Core", data_path_override=core_path,
|
||||
driver = get_driver(driver_name, "Core", data_path_override=core_path,
|
||||
**driver_details)
|
||||
return cls(cog_name="Core", driver_spawn=driver_spawn,
|
||||
return cls(cog_name="Core", driver=driver,
|
||||
unique_identifier='0',
|
||||
force_registration=force_registration)
|
||||
|
||||
@ -668,9 +671,9 @@ class Config:
|
||||
def _get_base_group(self, key: str, *identifiers: str) -> Group:
|
||||
# noinspection PyTypeChecker
|
||||
return Group(
|
||||
identifiers=(self.unique_identifier, key) + identifiers,
|
||||
identifiers=(key, *identifiers),
|
||||
defaults=self.defaults.get(key, {}),
|
||||
spawner=self.spawner,
|
||||
driver=self.driver,
|
||||
force_registration=self.force_registration
|
||||
)
|
||||
|
||||
@ -911,7 +914,7 @@ class Config:
|
||||
if not scopes:
|
||||
group = Group(identifiers=(self.unique_identifier, ),
|
||||
defaults={},
|
||||
spawner=self.spawner)
|
||||
driver=self.driver)
|
||||
else:
|
||||
group = self._get_base_group(*scopes)
|
||||
await group.set({})
|
||||
|
||||
@ -2,14 +2,13 @@ from typing import Tuple
|
||||
|
||||
__all__ = ["BaseDriver"]
|
||||
|
||||
|
||||
class BaseDriver:
|
||||
def __init__(self, cog_name):
|
||||
self.cog_name = cog_name
|
||||
self.unique_cog_identifier = None # This is set by Config's init method
|
||||
|
||||
def get_driver(self):
|
||||
raise NotImplementedError
|
||||
|
||||
async def get(self, identifiers: Tuple[str]):
|
||||
async def get(self, *identifiers: Tuple[str]):
|
||||
"""
|
||||
Finds the value indicate by the given identifiers.
|
||||
|
||||
@ -30,7 +29,7 @@ class BaseDriver:
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
async def set(self, identifiers: Tuple[str], value):
|
||||
async def set(self, *identifiers: Tuple[str], value=None):
|
||||
"""
|
||||
Sets the value of the key indicated by the given identifiers.
|
||||
|
||||
|
||||
@ -41,21 +41,20 @@ class JSON(BaseDriver):
|
||||
self.data = {}
|
||||
self.jsonIO._save_json(self.data)
|
||||
|
||||
def get_driver(self):
|
||||
return self
|
||||
|
||||
async def get(self, identifiers: Tuple[str]):
|
||||
async def get(self, *identifiers: Tuple[str]):
|
||||
partial = self.data
|
||||
for i in identifiers:
|
||||
full_identifiers = (self.unique_cog_identifier, *identifiers)
|
||||
for i in full_identifiers:
|
||||
partial = partial[i]
|
||||
return partial
|
||||
|
||||
async def set(self, identifiers, value):
|
||||
async def set(self, *identifiers: str, value=None):
|
||||
partial = self.data
|
||||
for i in identifiers[:-1]:
|
||||
full_identifiers = (self.unique_cog_identifier, *identifiers)
|
||||
for i in full_identifiers[:-1]:
|
||||
if i not in partial:
|
||||
partial[i] = {}
|
||||
partial = partial[i]
|
||||
|
||||
partial[identifiers[-1]] = value
|
||||
partial[full_identifiers[-1]] = value
|
||||
await self.jsonIO._threadsafe_save_json(self.data)
|
||||
|
||||
@ -71,16 +71,15 @@ class Mongo(BaseDriver):
|
||||
uuid, identifiers = identifiers[0], identifiers[1:]
|
||||
return uuid, identifiers
|
||||
|
||||
async def get(self, identifiers: Tuple[str]):
|
||||
async def get(self, *identifiers: Tuple[str]):
|
||||
await self._ensure_connected()
|
||||
uuid, identifiers = self._parse_identifiers(identifiers)
|
||||
|
||||
mongo_collection = self.get_collection()
|
||||
|
||||
dot_identifiers = '.'.join(identifiers)
|
||||
|
||||
partial = await mongo_collection.find_one(
|
||||
filter={'_id': uuid},
|
||||
filter={'_id': self.unique_cog_identifier},
|
||||
projection={dot_identifiers: True}
|
||||
)
|
||||
|
||||
@ -92,23 +91,19 @@ class Mongo(BaseDriver):
|
||||
partial = partial[i]
|
||||
return partial
|
||||
|
||||
async def set(self, identifiers: Tuple[str], value):
|
||||
async def set(self, *identifiers: str, value=None):
|
||||
await self._ensure_connected()
|
||||
uuid, identifiers = self._parse_identifiers(identifiers)
|
||||
|
||||
dot_identifiers = '.'.join(identifiers)
|
||||
|
||||
mongo_collection = self.get_collection()
|
||||
|
||||
await mongo_collection.update_one(
|
||||
{'_id': uuid},
|
||||
{'_id': self.unique_cog_identifier},
|
||||
update={"$set": {dot_identifiers: value}},
|
||||
upsert=True
|
||||
)
|
||||
|
||||
def get_driver(self):
|
||||
return self
|
||||
|
||||
|
||||
def get_config_details():
|
||||
host = input("Enter host address: ")
|
||||
|
||||
@ -49,7 +49,7 @@ def config(json_driver):
|
||||
conf = Config(
|
||||
cog_name="PyTest",
|
||||
unique_identifier=str(uuid.uuid4()),
|
||||
driver_spawn=json_driver)
|
||||
driver=json_driver)
|
||||
yield conf
|
||||
conf._defaults = {}
|
||||
|
||||
@ -63,7 +63,7 @@ def config_fr(json_driver):
|
||||
conf = Config(
|
||||
cog_name="PyTest",
|
||||
unique_identifier=str(uuid.uuid4()),
|
||||
driver_spawn=json_driver,
|
||||
driver=json_driver,
|
||||
force_registration=True
|
||||
)
|
||||
yield conf
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user