[V3 Config] Fix unloading and implement singleton driver (#1458)

* Add the identifier as an initialization parameter

* Remove config object singleton and opt for a shared JSON datastore

* Fix bot unloading to deal with memory leaks

* Fix tests

* Fix clear all bug
This commit is contained in:
Will 2018-04-02 20:47:27 -04:00 committed by palmtree5
parent 720ef38886
commit 29ce2401ca
7 changed files with 94 additions and 37 deletions

View File

@ -7,6 +7,7 @@ from importlib.machinery import ModuleSpec
from pathlib import Path from pathlib import Path
import discord import discord
import sys
from discord.ext.commands.bot import BotBase from discord.ext.commands.bot import BotBase
from discord.ext.commands import GroupMixin from discord.ext.commands import GroupMixin
from discord.ext.commands import when_mentioned_or from discord.ext.commands import when_mentioned_or
@ -268,9 +269,15 @@ class RedBase(BotBase, RpcMethodMixin):
pass pass
finally: finally:
# finally remove the import.. # finally remove the import..
pkg_name = lib.__package__
del lib del lib
del self.extensions[name] del self.extensions[name]
# del sys.modules[name] for m, _ in sys.modules.copy().items():
if m.startswith(pkg_name):
del sys.modules[m]
if pkg_name.startswith('redbot.cogs'):
del sys.modules['redbot.cogs'].__dict__[name]
def register_rpc_methods(self): def register_rpc_methods(self):
rpc.add_method('bot', self.rpc__cogs) rpc.add_method('bot', self.rpc__cogs)

View File

@ -1,7 +1,5 @@
import contextlib
import logging import logging
import collections import collections
from weakref import ref
from copy import deepcopy from copy import deepcopy
from typing import Union, Tuple from typing import Union, Tuple
@ -418,10 +416,6 @@ class Group(Value):
await self.driver.set(*self.identifiers, *path, value=value) await self.driver.set(*self.identifiers, *path, value=value)
_config_cogrefs = {}
_config_coreref = None
class Config: class Config:
"""Configuration manager for cogs and Red. """Configuration manager for cogs and Red.
@ -470,7 +464,6 @@ class Config:
self.unique_identifier = unique_identifier self.unique_identifier = unique_identifier
self.driver = driver self.driver = driver
self.driver.unique_cog_identifier = self.unique_identifier
self.force_registration = force_registration self.force_registration = force_registration
self._defaults = defaults or {} self._defaults = defaults or {}
@ -483,6 +476,13 @@ class Config:
force_registration=False, cog_name=None): force_registration=False, cog_name=None):
"""Get a Config instance for your cog. """Get a Config instance for your cog.
.. warning::
If you are using this classmethod to get a second instance of an
existing Config object for a particular cog, you MUST provide the
correct identifier. If you do not, you *will* screw up all other
Config instances for that cog.
Parameters Parameters
---------- ----------
cog_instance cog_instance
@ -514,11 +514,6 @@ class Config:
cog_name = cog_path_override.stem cog_name = cog_path_override.stem
uuid = str(hash(identifier)) uuid = str(hash(identifier))
with contextlib.suppress(KeyError):
conf = _config_cogrefs[cog_name]()
if conf is not None:
return conf
# We have to import this here otherwise we have a circular dependency # We have to import this here otherwise we have a circular dependency
from .data_manager import basic_config from .data_manager import basic_config
@ -529,12 +524,11 @@ class Config:
log.debug("Using driver: '{}'".format(driver_name)) log.debug("Using driver: '{}'".format(driver_name))
driver = get_driver(driver_name, cog_name, data_path_override=cog_path_override, driver = get_driver(driver_name, cog_name, uuid, data_path_override=cog_path_override,
**driver_details) **driver_details)
conf = cls(cog_name=cog_name, unique_identifier=uuid, conf = cls(cog_name=cog_name, unique_identifier=uuid,
force_registration=force_registration, force_registration=force_registration,
driver=driver) driver=driver)
_config_cogrefs[cog_name] = ref(conf)
return conf return conf
@classmethod @classmethod
@ -550,10 +544,6 @@ class Config:
See `force_registration`. See `force_registration`.
""" """
global _config_coreref
if _config_coreref is not None and _config_coreref() is not None:
return _config_coreref()
core_path = core_data_path() core_path = core_data_path()
# We have to import this here otherwise we have a circular dependency # We have to import this here otherwise we have a circular dependency
@ -562,12 +552,11 @@ class Config:
driver_name = basic_config.get('STORAGE_TYPE', 'JSON') driver_name = basic_config.get('STORAGE_TYPE', 'JSON')
driver_details = basic_config.get('STORAGE_DETAILS', {}) driver_details = basic_config.get('STORAGE_DETAILS', {})
driver = get_driver(driver_name, "Core", data_path_override=core_path, driver = get_driver(driver_name, "Core", '0', data_path_override=core_path,
**driver_details) **driver_details)
conf = cls(cog_name="Core", driver=driver, conf = cls(cog_name="Core", driver=driver,
unique_identifier='0', unique_identifier='0',
force_registration=force_registration) force_registration=force_registration)
_config_coreref = ref(conf)
return conf return conf
def __getattr__(self, item: str) -> Union[Group, Value]: def __getattr__(self, item: str) -> Union[Group, Value]:
@ -1003,7 +992,7 @@ class Config:
""" """
if not scopes: if not scopes:
group = Group(identifiers=(self.unique_identifier, ), group = Group(identifiers=[],
defaults={}, defaults={},
driver=self.driver) driver=self.driver)
else: else:

View File

@ -2,9 +2,9 @@ __all__ = ["BaseDriver"]
class BaseDriver: class BaseDriver:
def __init__(self, cog_name): def __init__(self, cog_name, identifier):
self.cog_name = cog_name self.cog_name = cog_name
self.unique_cog_identifier = None # This is set by Config's init method self.unique_cog_identifier = identifier
async def get(self, *identifiers: str): async def get(self, *identifiers: str):
""" """

View File

@ -1,5 +1,7 @@
from pathlib import Path from pathlib import Path
from typing import Tuple from typing import Tuple
import weakref
import logging
from ..json_io import JsonIO from ..json_io import JsonIO
@ -8,6 +10,28 @@ from .red_base import BaseDriver
__all__ = ["JSON"] __all__ = ["JSON"]
_shared_datastore = {}
_driver_counts = {}
_finalizers = []
log = logging.getLogger("redbot.json_driver")
def finalize_driver(cog_name):
if cog_name not in _driver_counts:
return
_driver_counts[cog_name] -= 1
if _driver_counts[cog_name] == 0:
if cog_name in _shared_datastore:
del _shared_datastore[cog_name]
for f in _finalizers:
if not f.alive:
_finalizers.remove(f)
class JSON(BaseDriver): class JSON(BaseDriver):
""" """
Subclass of :py:class:`.red_base.BaseDriver`. Subclass of :py:class:`.red_base.BaseDriver`.
@ -20,9 +44,9 @@ class JSON(BaseDriver):
The path in which to store the file indicated by :py:attr:`file_name`. The path in which to store the file indicated by :py:attr:`file_name`.
""" """
def __init__(self, cog_name, *, data_path_override: Path=None, def __init__(self, cog_name, identifier, *, data_path_override: Path=None,
file_name_override: str="settings.json"): file_name_override: str="settings.json"):
super().__init__(cog_name) super().__init__(cog_name, identifier)
self.file_name = file_name_override self.file_name = file_name_override
if data_path_override: if data_path_override:
self.data_path = data_path_override self.data_path = data_path_override
@ -35,6 +59,26 @@ class JSON(BaseDriver):
self.jsonIO = JsonIO(self.data_path) self.jsonIO = JsonIO(self.data_path)
self._load_data()
@property
def data(self):
return _shared_datastore.get(self.cog_name)
@data.setter
def data(self, value):
_shared_datastore[self.cog_name] = value
def _load_data(self):
if self.cog_name not in _driver_counts:
_driver_counts[self.cog_name] = 0
_driver_counts[self.cog_name] += 1
_finalizers.append(weakref.finalize(self, finalize_driver, self.cog_name))
if self.data is not None:
return
try: try:
self.data = self.jsonIO._load_json() self.data = self.jsonIO._load_json()
except FileNotFoundError: except FileNotFoundError:
@ -65,8 +109,11 @@ class JSON(BaseDriver):
try: try:
for i in full_identifiers[:-1]: for i in full_identifiers[:-1]:
partial = partial[i] partial = partial[i]
del partial[identifiers[-1]] del partial[full_identifiers[-1]]
except KeyError: except KeyError:
pass pass
else: else:
await self.jsonIO._threadsafe_save_json(self.data) await self.jsonIO._threadsafe_save_json(self.data)
def get_config_details(self):
return

View File

@ -32,8 +32,8 @@ class Mongo(BaseDriver):
""" """
Subclass of :py:class:`.red_base.BaseDriver`. Subclass of :py:class:`.red_base.BaseDriver`.
""" """
def __init__(self, cog_name, **kwargs): def __init__(self, cog_name, identifier, **kwargs):
super().__init__(cog_name) super().__init__(cog_name, identifier)
if _conn is None: if _conn is None:
_initialize(**kwargs) _initialize(**kwargs)
@ -105,10 +105,15 @@ class Mongo(BaseDriver):
dot_identifiers = '.'.join(identifiers) dot_identifiers = '.'.join(identifiers)
mongo_collection = self.get_collection() mongo_collection = self.get_collection()
if len(identifiers) > 0:
await mongo_collection.update_one( await mongo_collection.update_one(
{'_id': self.unique_cog_identifier}, {'_id': self.unique_cog_identifier},
update={"$unset": {dot_identifiers: 1}} update={"$unset": {dot_identifiers: 1}}
) )
else:
await mongo_collection.delete_one(
{'_id': self.unique_cog_identifier}
)
def get_config_details(): def get_config_details():

View File

@ -38,6 +38,7 @@ def json_driver(tmpdir_factory):
path = Path(str(tmpdir_factory.mktemp(rand))) path = Path(str(tmpdir_factory.mktemp(rand)))
driver = red_json.JSON( driver = red_json.JSON(
"PyTest", "PyTest",
identifier=str(uuid.uuid4()),
data_path_override=path data_path_override=path
) )
return driver return driver
@ -45,10 +46,9 @@ def json_driver(tmpdir_factory):
@pytest.fixture() @pytest.fixture()
def config(json_driver): def config(json_driver):
import uuid
conf = Config( conf = Config(
cog_name="PyTest", cog_name="PyTest",
unique_identifier=str(uuid.uuid4()), unique_identifier=json_driver.unique_cog_identifier,
driver=json_driver) driver=json_driver)
yield conf yield conf
conf._defaults = {} conf._defaults = {}
@ -59,10 +59,9 @@ def config_fr(json_driver):
""" """
Mocked config object with force_register enabled. Mocked config object with force_register enabled.
""" """
import uuid
conf = Config( conf = Config(
cog_name="PyTest", cog_name="PyTest",
unique_identifier=str(uuid.uuid4()), unique_identifier=json_driver.unique_cog_identifier,
driver=json_driver, driver=json_driver,
force_registration=True force_registration=True
) )

View File

@ -298,6 +298,16 @@ async def test_member_clear_all(config, member_factory):
assert len(await config.all_members()) == 0 assert len(await config.all_members()) == 0
@pytest.mark.asyncio
async def test_clear_all(config):
await config.foo.set(True)
assert await config.foo() is True
await config.clear_all()
with pytest.raises(KeyError):
await config.get_raw('foo')
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_clear_value(config): async def test_clear_value(config):
await config.foo.set(True) await config.foo.set(True)