mirror of
https://github.com/Cog-Creators/Red-DiscordBot.git
synced 2025-11-06 11:18:54 -05:00
[V3 Config] Driver code initial cleanup (#1315)
* Remove get_driver * Rename self.driver to self._driver * Do not unnecessarily pass the cog identifier * Remove unused import * Fix type annotation * Missed a keyword rename * Modify signature of get/set methods in drivers
This commit is contained in:
parent
3984cb8f48
commit
f9d846a704
@ -1,13 +1,18 @@
|
|||||||
import logging
|
import logging
|
||||||
import collections
|
import collections
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from typing import Callable, Union, Tuple
|
from typing import Union, Tuple
|
||||||
|
|
||||||
import discord
|
import discord
|
||||||
|
|
||||||
from .data_manager import cog_data_path, core_data_path
|
from .data_manager import cog_data_path, core_data_path
|
||||||
from .drivers import get_driver
|
from .drivers import get_driver
|
||||||
|
|
||||||
|
from .utils import TYPE_CHECKING
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .drivers.red_base import BaseDriver
|
||||||
|
|
||||||
log = logging.getLogger("red.config")
|
log = logging.getLogger("red.config")
|
||||||
|
|
||||||
|
|
||||||
@ -52,24 +57,23 @@ class Value:
|
|||||||
element from a json document.
|
element from a json document.
|
||||||
default
|
default
|
||||||
The default value for the data element that `identifiers` points at.
|
The default value for the data element that `identifiers` points at.
|
||||||
spawner : `redbot.core.drivers.red_base.BaseDriver`
|
driver : `redbot.core.drivers.red_base.BaseDriver`
|
||||||
A reference to `Config.spawner`.
|
A reference to `Config.driver`.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
def __init__(self, identifiers: Tuple[str], default_value, spawner):
|
def __init__(self, identifiers: Tuple[str], default_value, driver):
|
||||||
self._identifiers = identifiers
|
self._identifiers = identifiers
|
||||||
self.default = default_value
|
self.default = default_value
|
||||||
|
|
||||||
self.spawner = spawner
|
self.driver = driver
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def identifiers(self):
|
def identifiers(self):
|
||||||
return tuple(str(i) for i in self._identifiers)
|
return tuple(str(i) for i in self._identifiers)
|
||||||
|
|
||||||
async def _get(self, default):
|
async def _get(self, default):
|
||||||
driver = self.spawner.get_driver()
|
|
||||||
try:
|
try:
|
||||||
ret = await driver.get(self.identifiers)
|
ret = await self.driver.get(*self.identifiers)
|
||||||
except KeyError:
|
except KeyError:
|
||||||
return default if default is not None else self.default
|
return default if default is not None else self.default
|
||||||
return ret
|
return ret
|
||||||
@ -138,8 +142,7 @@ class Value:
|
|||||||
The new literal value of this attribute.
|
The new literal value of this attribute.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
driver = self.spawner.get_driver()
|
await self.driver.set(*self.identifiers, value=value)
|
||||||
await driver.set(self.identifiers, value)
|
|
||||||
|
|
||||||
|
|
||||||
class Group(Value):
|
class Group(Value):
|
||||||
@ -155,19 +158,19 @@ class Group(Value):
|
|||||||
All registered default values for this Group.
|
All registered default values for this Group.
|
||||||
force_registration : `bool`
|
force_registration : `bool`
|
||||||
Same as `Config.force_registration`.
|
Same as `Config.force_registration`.
|
||||||
spawner : `redbot.core.drivers.red_base.BaseDriver`
|
driver : `redbot.core.drivers.red_base.BaseDriver`
|
||||||
A reference to `Config.spawner`.
|
A reference to `Config.driver`.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
def __init__(self, identifiers: Tuple[str],
|
def __init__(self, identifiers: Tuple[str],
|
||||||
defaults: dict,
|
defaults: dict,
|
||||||
spawner,
|
driver,
|
||||||
force_registration: bool=False):
|
force_registration: bool=False):
|
||||||
self._defaults = defaults
|
self._defaults = defaults
|
||||||
self.force_registration = force_registration
|
self.force_registration = force_registration
|
||||||
self.spawner = spawner
|
self.driver = driver
|
||||||
|
|
||||||
super().__init__(identifiers, {}, self.spawner)
|
super().__init__(identifiers, {}, self.driver)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def defaults(self):
|
def defaults(self):
|
||||||
@ -205,14 +208,14 @@ class Group(Value):
|
|||||||
return Group(
|
return Group(
|
||||||
identifiers=new_identifiers,
|
identifiers=new_identifiers,
|
||||||
defaults=self._defaults[item],
|
defaults=self._defaults[item],
|
||||||
spawner=self.spawner,
|
driver=self.driver,
|
||||||
force_registration=self.force_registration
|
force_registration=self.force_registration
|
||||||
)
|
)
|
||||||
elif is_value:
|
elif is_value:
|
||||||
return Value(
|
return Value(
|
||||||
identifiers=new_identifiers,
|
identifiers=new_identifiers,
|
||||||
default_value=self._defaults[item],
|
default_value=self._defaults[item],
|
||||||
spawner=self.spawner
|
driver=self.driver
|
||||||
)
|
)
|
||||||
elif self.force_registration:
|
elif self.force_registration:
|
||||||
raise AttributeError(
|
raise AttributeError(
|
||||||
@ -223,7 +226,7 @@ class Group(Value):
|
|||||||
return Value(
|
return Value(
|
||||||
identifiers=new_identifiers,
|
identifiers=new_identifiers,
|
||||||
default_value=None,
|
default_value=None,
|
||||||
spawner=self.spawner
|
driver=self.driver
|
||||||
)
|
)
|
||||||
|
|
||||||
def is_group(self, item: str) -> bool:
|
def is_group(self, item: str) -> bool:
|
||||||
@ -394,9 +397,8 @@ class Config:
|
|||||||
unique_identifier : `int`
|
unique_identifier : `int`
|
||||||
Unique identifier provided to differentiate cog data when name
|
Unique identifier provided to differentiate cog data when name
|
||||||
conflicts occur.
|
conflicts occur.
|
||||||
spawner
|
driver
|
||||||
A callable object that returns some driver that implements
|
An instance of a driver that implements `redbot.core.drivers.red_base.BaseDriver`.
|
||||||
`redbot.core.drivers.red_base.BaseDriver`.
|
|
||||||
force_registration : `bool`
|
force_registration : `bool`
|
||||||
Determines if Config should throw an error if a cog attempts to access
|
Determines if Config should throw an error if a cog attempts to access
|
||||||
an attribute which has not been previously registered.
|
an attribute which has not been previously registered.
|
||||||
@ -416,13 +418,14 @@ class Config:
|
|||||||
MEMBER = "MEMBER"
|
MEMBER = "MEMBER"
|
||||||
|
|
||||||
def __init__(self, cog_name: str, unique_identifier: str,
|
def __init__(self, cog_name: str, unique_identifier: str,
|
||||||
driver_spawn: Callable,
|
driver: "BaseDriver",
|
||||||
force_registration: bool=False,
|
force_registration: bool=False,
|
||||||
defaults: dict=None):
|
defaults: dict=None):
|
||||||
self.cog_name = cog_name
|
self.cog_name = cog_name
|
||||||
self.unique_identifier = unique_identifier
|
self.unique_identifier = unique_identifier
|
||||||
|
|
||||||
self.spawner = driver_spawn
|
self.driver = driver
|
||||||
|
self.driver.unique_cog_identifier = self.unique_identifier
|
||||||
self.force_registration = force_registration
|
self.force_registration = force_registration
|
||||||
self._defaults = defaults or {}
|
self._defaults = defaults or {}
|
||||||
|
|
||||||
@ -468,11 +471,11 @@ class Config:
|
|||||||
|
|
||||||
log.debug("Using driver: '{}'".format(driver_name))
|
log.debug("Using driver: '{}'".format(driver_name))
|
||||||
|
|
||||||
spawner = get_driver(driver_name, cog_name, data_path_override=cog_path_override,
|
driver = get_driver(driver_name, cog_name, data_path_override=cog_path_override,
|
||||||
**driver_details)
|
**driver_details)
|
||||||
return cls(cog_name=cog_name, unique_identifier=uuid,
|
return cls(cog_name=cog_name, unique_identifier=uuid,
|
||||||
force_registration=force_registration,
|
force_registration=force_registration,
|
||||||
driver_spawn=spawner)
|
driver=driver)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def get_core_conf(cls, force_registration: bool=False):
|
def get_core_conf(cls, force_registration: bool=False):
|
||||||
@ -495,9 +498,9 @@ class Config:
|
|||||||
driver_name = basic_config.get('STORAGE_TYPE', 'JSON')
|
driver_name = basic_config.get('STORAGE_TYPE', 'JSON')
|
||||||
driver_details = basic_config.get('STORAGE_DETAILS', {})
|
driver_details = basic_config.get('STORAGE_DETAILS', {})
|
||||||
|
|
||||||
driver_spawn = get_driver(driver_name, "Core", data_path_override=core_path,
|
driver = get_driver(driver_name, "Core", data_path_override=core_path,
|
||||||
**driver_details)
|
**driver_details)
|
||||||
return cls(cog_name="Core", driver_spawn=driver_spawn,
|
return cls(cog_name="Core", driver=driver,
|
||||||
unique_identifier='0',
|
unique_identifier='0',
|
||||||
force_registration=force_registration)
|
force_registration=force_registration)
|
||||||
|
|
||||||
@ -668,9 +671,9 @@ class Config:
|
|||||||
def _get_base_group(self, key: str, *identifiers: str) -> Group:
|
def _get_base_group(self, key: str, *identifiers: str) -> Group:
|
||||||
# noinspection PyTypeChecker
|
# noinspection PyTypeChecker
|
||||||
return Group(
|
return Group(
|
||||||
identifiers=(self.unique_identifier, key) + identifiers,
|
identifiers=(key, *identifiers),
|
||||||
defaults=self.defaults.get(key, {}),
|
defaults=self.defaults.get(key, {}),
|
||||||
spawner=self.spawner,
|
driver=self.driver,
|
||||||
force_registration=self.force_registration
|
force_registration=self.force_registration
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -911,7 +914,7 @@ class Config:
|
|||||||
if not scopes:
|
if not scopes:
|
||||||
group = Group(identifiers=(self.unique_identifier, ),
|
group = Group(identifiers=(self.unique_identifier, ),
|
||||||
defaults={},
|
defaults={},
|
||||||
spawner=self.spawner)
|
driver=self.driver)
|
||||||
else:
|
else:
|
||||||
group = self._get_base_group(*scopes)
|
group = self._get_base_group(*scopes)
|
||||||
await group.set({})
|
await group.set({})
|
||||||
|
|||||||
@ -2,14 +2,13 @@ from typing import Tuple
|
|||||||
|
|
||||||
__all__ = ["BaseDriver"]
|
__all__ = ["BaseDriver"]
|
||||||
|
|
||||||
|
|
||||||
class BaseDriver:
|
class BaseDriver:
|
||||||
def __init__(self, cog_name):
|
def __init__(self, cog_name):
|
||||||
self.cog_name = cog_name
|
self.cog_name = cog_name
|
||||||
|
self.unique_cog_identifier = None # This is set by Config's init method
|
||||||
|
|
||||||
def get_driver(self):
|
async def get(self, *identifiers: Tuple[str]):
|
||||||
raise NotImplementedError
|
|
||||||
|
|
||||||
async def get(self, identifiers: Tuple[str]):
|
|
||||||
"""
|
"""
|
||||||
Finds the value indicate by the given identifiers.
|
Finds the value indicate by the given identifiers.
|
||||||
|
|
||||||
@ -30,7 +29,7 @@ class BaseDriver:
|
|||||||
"""
|
"""
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
async def set(self, identifiers: Tuple[str], value):
|
async def set(self, *identifiers: Tuple[str], value=None):
|
||||||
"""
|
"""
|
||||||
Sets the value of the key indicated by the given identifiers.
|
Sets the value of the key indicated by the given identifiers.
|
||||||
|
|
||||||
|
|||||||
@ -41,21 +41,20 @@ class JSON(BaseDriver):
|
|||||||
self.data = {}
|
self.data = {}
|
||||||
self.jsonIO._save_json(self.data)
|
self.jsonIO._save_json(self.data)
|
||||||
|
|
||||||
def get_driver(self):
|
async def get(self, *identifiers: Tuple[str]):
|
||||||
return self
|
|
||||||
|
|
||||||
async def get(self, identifiers: Tuple[str]):
|
|
||||||
partial = self.data
|
partial = self.data
|
||||||
for i in identifiers:
|
full_identifiers = (self.unique_cog_identifier, *identifiers)
|
||||||
|
for i in full_identifiers:
|
||||||
partial = partial[i]
|
partial = partial[i]
|
||||||
return partial
|
return partial
|
||||||
|
|
||||||
async def set(self, identifiers, value):
|
async def set(self, *identifiers: str, value=None):
|
||||||
partial = self.data
|
partial = self.data
|
||||||
for i in identifiers[:-1]:
|
full_identifiers = (self.unique_cog_identifier, *identifiers)
|
||||||
|
for i in full_identifiers[:-1]:
|
||||||
if i not in partial:
|
if i not in partial:
|
||||||
partial[i] = {}
|
partial[i] = {}
|
||||||
partial = partial[i]
|
partial = partial[i]
|
||||||
|
|
||||||
partial[identifiers[-1]] = value
|
partial[full_identifiers[-1]] = value
|
||||||
await self.jsonIO._threadsafe_save_json(self.data)
|
await self.jsonIO._threadsafe_save_json(self.data)
|
||||||
|
|||||||
@ -71,16 +71,15 @@ class Mongo(BaseDriver):
|
|||||||
uuid, identifiers = identifiers[0], identifiers[1:]
|
uuid, identifiers = identifiers[0], identifiers[1:]
|
||||||
return uuid, identifiers
|
return uuid, identifiers
|
||||||
|
|
||||||
async def get(self, identifiers: Tuple[str]):
|
async def get(self, *identifiers: Tuple[str]):
|
||||||
await self._ensure_connected()
|
await self._ensure_connected()
|
||||||
uuid, identifiers = self._parse_identifiers(identifiers)
|
|
||||||
|
|
||||||
mongo_collection = self.get_collection()
|
mongo_collection = self.get_collection()
|
||||||
|
|
||||||
dot_identifiers = '.'.join(identifiers)
|
dot_identifiers = '.'.join(identifiers)
|
||||||
|
|
||||||
partial = await mongo_collection.find_one(
|
partial = await mongo_collection.find_one(
|
||||||
filter={'_id': uuid},
|
filter={'_id': self.unique_cog_identifier},
|
||||||
projection={dot_identifiers: True}
|
projection={dot_identifiers: True}
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -92,23 +91,19 @@ class Mongo(BaseDriver):
|
|||||||
partial = partial[i]
|
partial = partial[i]
|
||||||
return partial
|
return partial
|
||||||
|
|
||||||
async def set(self, identifiers: Tuple[str], value):
|
async def set(self, *identifiers: str, value=None):
|
||||||
await self._ensure_connected()
|
await self._ensure_connected()
|
||||||
uuid, identifiers = self._parse_identifiers(identifiers)
|
|
||||||
|
|
||||||
dot_identifiers = '.'.join(identifiers)
|
dot_identifiers = '.'.join(identifiers)
|
||||||
|
|
||||||
mongo_collection = self.get_collection()
|
mongo_collection = self.get_collection()
|
||||||
|
|
||||||
await mongo_collection.update_one(
|
await mongo_collection.update_one(
|
||||||
{'_id': uuid},
|
{'_id': self.unique_cog_identifier},
|
||||||
update={"$set": {dot_identifiers: value}},
|
update={"$set": {dot_identifiers: value}},
|
||||||
upsert=True
|
upsert=True
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_driver(self):
|
|
||||||
return self
|
|
||||||
|
|
||||||
|
|
||||||
def get_config_details():
|
def get_config_details():
|
||||||
host = input("Enter host address: ")
|
host = input("Enter host address: ")
|
||||||
|
|||||||
@ -49,7 +49,7 @@ def config(json_driver):
|
|||||||
conf = Config(
|
conf = Config(
|
||||||
cog_name="PyTest",
|
cog_name="PyTest",
|
||||||
unique_identifier=str(uuid.uuid4()),
|
unique_identifier=str(uuid.uuid4()),
|
||||||
driver_spawn=json_driver)
|
driver=json_driver)
|
||||||
yield conf
|
yield conf
|
||||||
conf._defaults = {}
|
conf._defaults = {}
|
||||||
|
|
||||||
@ -63,7 +63,7 @@ def config_fr(json_driver):
|
|||||||
conf = Config(
|
conf = Config(
|
||||||
cog_name="PyTest",
|
cog_name="PyTest",
|
||||||
unique_identifier=str(uuid.uuid4()),
|
unique_identifier=str(uuid.uuid4()),
|
||||||
driver_spawn=json_driver,
|
driver=json_driver,
|
||||||
force_registration=True
|
force_registration=True
|
||||||
)
|
)
|
||||||
yield conf
|
yield conf
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user