mirror of
https://github.com/Cog-Creators/Red-DiscordBot.git
synced 2025-11-06 11:18:54 -05:00
[V3 Config] Fix unloading and implement singleton driver (#1458)
* Add the identifier as an initialization parameter * Remove config object singleton and opt for a shared JSON datastore * Fix bot unloading to deal with memory leaks * Fix tests * Fix clear all bug
This commit is contained in:
parent
720ef38886
commit
29ce2401ca
@ -7,6 +7,7 @@ from importlib.machinery import ModuleSpec
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import discord
|
import discord
|
||||||
|
import sys
|
||||||
from discord.ext.commands.bot import BotBase
|
from discord.ext.commands.bot import BotBase
|
||||||
from discord.ext.commands import GroupMixin
|
from discord.ext.commands import GroupMixin
|
||||||
from discord.ext.commands import when_mentioned_or
|
from discord.ext.commands import when_mentioned_or
|
||||||
@ -268,9 +269,15 @@ class RedBase(BotBase, RpcMethodMixin):
|
|||||||
pass
|
pass
|
||||||
finally:
|
finally:
|
||||||
# finally remove the import..
|
# finally remove the import..
|
||||||
|
pkg_name = lib.__package__
|
||||||
del lib
|
del lib
|
||||||
del self.extensions[name]
|
del self.extensions[name]
|
||||||
# del sys.modules[name]
|
for m, _ in sys.modules.copy().items():
|
||||||
|
if m.startswith(pkg_name):
|
||||||
|
del sys.modules[m]
|
||||||
|
|
||||||
|
if pkg_name.startswith('redbot.cogs'):
|
||||||
|
del sys.modules['redbot.cogs'].__dict__[name]
|
||||||
|
|
||||||
def register_rpc_methods(self):
|
def register_rpc_methods(self):
|
||||||
rpc.add_method('bot', self.rpc__cogs)
|
rpc.add_method('bot', self.rpc__cogs)
|
||||||
|
|||||||
@ -1,7 +1,5 @@
|
|||||||
import contextlib
|
|
||||||
import logging
|
import logging
|
||||||
import collections
|
import collections
|
||||||
from weakref import ref
|
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from typing import Union, Tuple
|
from typing import Union, Tuple
|
||||||
|
|
||||||
@ -418,10 +416,6 @@ class Group(Value):
|
|||||||
await self.driver.set(*self.identifiers, *path, value=value)
|
await self.driver.set(*self.identifiers, *path, value=value)
|
||||||
|
|
||||||
|
|
||||||
_config_cogrefs = {}
|
|
||||||
_config_coreref = None
|
|
||||||
|
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
"""Configuration manager for cogs and Red.
|
"""Configuration manager for cogs and Red.
|
||||||
|
|
||||||
@ -470,7 +464,6 @@ class Config:
|
|||||||
self.unique_identifier = unique_identifier
|
self.unique_identifier = unique_identifier
|
||||||
|
|
||||||
self.driver = driver
|
self.driver = driver
|
||||||
self.driver.unique_cog_identifier = self.unique_identifier
|
|
||||||
self.force_registration = force_registration
|
self.force_registration = force_registration
|
||||||
self._defaults = defaults or {}
|
self._defaults = defaults or {}
|
||||||
|
|
||||||
@ -483,6 +476,13 @@ class Config:
|
|||||||
force_registration=False, cog_name=None):
|
force_registration=False, cog_name=None):
|
||||||
"""Get a Config instance for your cog.
|
"""Get a Config instance for your cog.
|
||||||
|
|
||||||
|
.. warning::
|
||||||
|
|
||||||
|
If you are using this classmethod to get a second instance of an
|
||||||
|
existing Config object for a particular cog, you MUST provide the
|
||||||
|
correct identifier. If you do not, you *will* screw up all other
|
||||||
|
Config instances for that cog.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
cog_instance
|
cog_instance
|
||||||
@ -514,11 +514,6 @@ class Config:
|
|||||||
cog_name = cog_path_override.stem
|
cog_name = cog_path_override.stem
|
||||||
uuid = str(hash(identifier))
|
uuid = str(hash(identifier))
|
||||||
|
|
||||||
with contextlib.suppress(KeyError):
|
|
||||||
conf = _config_cogrefs[cog_name]()
|
|
||||||
if conf is not None:
|
|
||||||
return conf
|
|
||||||
|
|
||||||
# We have to import this here otherwise we have a circular dependency
|
# We have to import this here otherwise we have a circular dependency
|
||||||
from .data_manager import basic_config
|
from .data_manager import basic_config
|
||||||
|
|
||||||
@ -529,12 +524,11 @@ class Config:
|
|||||||
|
|
||||||
log.debug("Using driver: '{}'".format(driver_name))
|
log.debug("Using driver: '{}'".format(driver_name))
|
||||||
|
|
||||||
driver = get_driver(driver_name, cog_name, data_path_override=cog_path_override,
|
driver = get_driver(driver_name, cog_name, uuid, data_path_override=cog_path_override,
|
||||||
**driver_details)
|
**driver_details)
|
||||||
conf = cls(cog_name=cog_name, unique_identifier=uuid,
|
conf = cls(cog_name=cog_name, unique_identifier=uuid,
|
||||||
force_registration=force_registration,
|
force_registration=force_registration,
|
||||||
driver=driver)
|
driver=driver)
|
||||||
_config_cogrefs[cog_name] = ref(conf)
|
|
||||||
return conf
|
return conf
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -550,10 +544,6 @@ class Config:
|
|||||||
See `force_registration`.
|
See `force_registration`.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
global _config_coreref
|
|
||||||
if _config_coreref is not None and _config_coreref() is not None:
|
|
||||||
return _config_coreref()
|
|
||||||
|
|
||||||
core_path = core_data_path()
|
core_path = core_data_path()
|
||||||
|
|
||||||
# We have to import this here otherwise we have a circular dependency
|
# We have to import this here otherwise we have a circular dependency
|
||||||
@ -562,12 +552,11 @@ class Config:
|
|||||||
driver_name = basic_config.get('STORAGE_TYPE', 'JSON')
|
driver_name = basic_config.get('STORAGE_TYPE', 'JSON')
|
||||||
driver_details = basic_config.get('STORAGE_DETAILS', {})
|
driver_details = basic_config.get('STORAGE_DETAILS', {})
|
||||||
|
|
||||||
driver = get_driver(driver_name, "Core", data_path_override=core_path,
|
driver = get_driver(driver_name, "Core", '0', data_path_override=core_path,
|
||||||
**driver_details)
|
**driver_details)
|
||||||
conf = cls(cog_name="Core", driver=driver,
|
conf = cls(cog_name="Core", driver=driver,
|
||||||
unique_identifier='0',
|
unique_identifier='0',
|
||||||
force_registration=force_registration)
|
force_registration=force_registration)
|
||||||
_config_coreref = ref(conf)
|
|
||||||
return conf
|
return conf
|
||||||
|
|
||||||
def __getattr__(self, item: str) -> Union[Group, Value]:
|
def __getattr__(self, item: str) -> Union[Group, Value]:
|
||||||
@ -1003,7 +992,7 @@ class Config:
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
if not scopes:
|
if not scopes:
|
||||||
group = Group(identifiers=(self.unique_identifier, ),
|
group = Group(identifiers=[],
|
||||||
defaults={},
|
defaults={},
|
||||||
driver=self.driver)
|
driver=self.driver)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@ -2,9 +2,9 @@ __all__ = ["BaseDriver"]
|
|||||||
|
|
||||||
|
|
||||||
class BaseDriver:
|
class BaseDriver:
|
||||||
def __init__(self, cog_name):
|
def __init__(self, cog_name, identifier):
|
||||||
self.cog_name = cog_name
|
self.cog_name = cog_name
|
||||||
self.unique_cog_identifier = None # This is set by Config's init method
|
self.unique_cog_identifier = identifier
|
||||||
|
|
||||||
async def get(self, *identifiers: str):
|
async def get(self, *identifiers: str):
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -1,5 +1,7 @@
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Tuple
|
from typing import Tuple
|
||||||
|
import weakref
|
||||||
|
import logging
|
||||||
|
|
||||||
from ..json_io import JsonIO
|
from ..json_io import JsonIO
|
||||||
|
|
||||||
@ -8,6 +10,28 @@ from .red_base import BaseDriver
|
|||||||
__all__ = ["JSON"]
|
__all__ = ["JSON"]
|
||||||
|
|
||||||
|
|
||||||
|
_shared_datastore = {}
|
||||||
|
_driver_counts = {}
|
||||||
|
_finalizers = []
|
||||||
|
|
||||||
|
log = logging.getLogger("redbot.json_driver")
|
||||||
|
|
||||||
|
|
||||||
|
def finalize_driver(cog_name):
|
||||||
|
if cog_name not in _driver_counts:
|
||||||
|
return
|
||||||
|
|
||||||
|
_driver_counts[cog_name] -= 1
|
||||||
|
|
||||||
|
if _driver_counts[cog_name] == 0:
|
||||||
|
if cog_name in _shared_datastore:
|
||||||
|
del _shared_datastore[cog_name]
|
||||||
|
|
||||||
|
for f in _finalizers:
|
||||||
|
if not f.alive:
|
||||||
|
_finalizers.remove(f)
|
||||||
|
|
||||||
|
|
||||||
class JSON(BaseDriver):
|
class JSON(BaseDriver):
|
||||||
"""
|
"""
|
||||||
Subclass of :py:class:`.red_base.BaseDriver`.
|
Subclass of :py:class:`.red_base.BaseDriver`.
|
||||||
@ -20,9 +44,9 @@ class JSON(BaseDriver):
|
|||||||
|
|
||||||
The path in which to store the file indicated by :py:attr:`file_name`.
|
The path in which to store the file indicated by :py:attr:`file_name`.
|
||||||
"""
|
"""
|
||||||
def __init__(self, cog_name, *, data_path_override: Path=None,
|
def __init__(self, cog_name, identifier, *, data_path_override: Path=None,
|
||||||
file_name_override: str="settings.json"):
|
file_name_override: str="settings.json"):
|
||||||
super().__init__(cog_name)
|
super().__init__(cog_name, identifier)
|
||||||
self.file_name = file_name_override
|
self.file_name = file_name_override
|
||||||
if data_path_override:
|
if data_path_override:
|
||||||
self.data_path = data_path_override
|
self.data_path = data_path_override
|
||||||
@ -35,6 +59,26 @@ class JSON(BaseDriver):
|
|||||||
|
|
||||||
self.jsonIO = JsonIO(self.data_path)
|
self.jsonIO = JsonIO(self.data_path)
|
||||||
|
|
||||||
|
self._load_data()
|
||||||
|
|
||||||
|
@property
|
||||||
|
def data(self):
|
||||||
|
return _shared_datastore.get(self.cog_name)
|
||||||
|
|
||||||
|
@data.setter
|
||||||
|
def data(self, value):
|
||||||
|
_shared_datastore[self.cog_name] = value
|
||||||
|
|
||||||
|
def _load_data(self):
|
||||||
|
if self.cog_name not in _driver_counts:
|
||||||
|
_driver_counts[self.cog_name] = 0
|
||||||
|
_driver_counts[self.cog_name] += 1
|
||||||
|
|
||||||
|
_finalizers.append(weakref.finalize(self, finalize_driver, self.cog_name))
|
||||||
|
|
||||||
|
if self.data is not None:
|
||||||
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
self.data = self.jsonIO._load_json()
|
self.data = self.jsonIO._load_json()
|
||||||
except FileNotFoundError:
|
except FileNotFoundError:
|
||||||
@ -65,8 +109,11 @@ class JSON(BaseDriver):
|
|||||||
try:
|
try:
|
||||||
for i in full_identifiers[:-1]:
|
for i in full_identifiers[:-1]:
|
||||||
partial = partial[i]
|
partial = partial[i]
|
||||||
del partial[identifiers[-1]]
|
del partial[full_identifiers[-1]]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
await self.jsonIO._threadsafe_save_json(self.data)
|
await self.jsonIO._threadsafe_save_json(self.data)
|
||||||
|
|
||||||
|
def get_config_details(self):
|
||||||
|
return
|
||||||
|
|||||||
@ -32,8 +32,8 @@ class Mongo(BaseDriver):
|
|||||||
"""
|
"""
|
||||||
Subclass of :py:class:`.red_base.BaseDriver`.
|
Subclass of :py:class:`.red_base.BaseDriver`.
|
||||||
"""
|
"""
|
||||||
def __init__(self, cog_name, **kwargs):
|
def __init__(self, cog_name, identifier, **kwargs):
|
||||||
super().__init__(cog_name)
|
super().__init__(cog_name, identifier)
|
||||||
|
|
||||||
if _conn is None:
|
if _conn is None:
|
||||||
_initialize(**kwargs)
|
_initialize(**kwargs)
|
||||||
@ -105,10 +105,15 @@ class Mongo(BaseDriver):
|
|||||||
dot_identifiers = '.'.join(identifiers)
|
dot_identifiers = '.'.join(identifiers)
|
||||||
mongo_collection = self.get_collection()
|
mongo_collection = self.get_collection()
|
||||||
|
|
||||||
await mongo_collection.update_one(
|
if len(identifiers) > 0:
|
||||||
{'_id': self.unique_cog_identifier},
|
await mongo_collection.update_one(
|
||||||
update={"$unset": {dot_identifiers: 1}}
|
{'_id': self.unique_cog_identifier},
|
||||||
)
|
update={"$unset": {dot_identifiers: 1}}
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
await mongo_collection.delete_one(
|
||||||
|
{'_id': self.unique_cog_identifier}
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_config_details():
|
def get_config_details():
|
||||||
|
|||||||
@ -38,6 +38,7 @@ def json_driver(tmpdir_factory):
|
|||||||
path = Path(str(tmpdir_factory.mktemp(rand)))
|
path = Path(str(tmpdir_factory.mktemp(rand)))
|
||||||
driver = red_json.JSON(
|
driver = red_json.JSON(
|
||||||
"PyTest",
|
"PyTest",
|
||||||
|
identifier=str(uuid.uuid4()),
|
||||||
data_path_override=path
|
data_path_override=path
|
||||||
)
|
)
|
||||||
return driver
|
return driver
|
||||||
@ -45,10 +46,9 @@ def json_driver(tmpdir_factory):
|
|||||||
|
|
||||||
@pytest.fixture()
|
@pytest.fixture()
|
||||||
def config(json_driver):
|
def config(json_driver):
|
||||||
import uuid
|
|
||||||
conf = Config(
|
conf = Config(
|
||||||
cog_name="PyTest",
|
cog_name="PyTest",
|
||||||
unique_identifier=str(uuid.uuid4()),
|
unique_identifier=json_driver.unique_cog_identifier,
|
||||||
driver=json_driver)
|
driver=json_driver)
|
||||||
yield conf
|
yield conf
|
||||||
conf._defaults = {}
|
conf._defaults = {}
|
||||||
@ -59,10 +59,9 @@ def config_fr(json_driver):
|
|||||||
"""
|
"""
|
||||||
Mocked config object with force_register enabled.
|
Mocked config object with force_register enabled.
|
||||||
"""
|
"""
|
||||||
import uuid
|
|
||||||
conf = Config(
|
conf = Config(
|
||||||
cog_name="PyTest",
|
cog_name="PyTest",
|
||||||
unique_identifier=str(uuid.uuid4()),
|
unique_identifier=json_driver.unique_cog_identifier,
|
||||||
driver=json_driver,
|
driver=json_driver,
|
||||||
force_registration=True
|
force_registration=True
|
||||||
)
|
)
|
||||||
|
|||||||
@ -298,6 +298,16 @@ async def test_member_clear_all(config, member_factory):
|
|||||||
assert len(await config.all_members()) == 0
|
assert len(await config.all_members()) == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_clear_all(config):
|
||||||
|
await config.foo.set(True)
|
||||||
|
assert await config.foo() is True
|
||||||
|
|
||||||
|
await config.clear_all()
|
||||||
|
with pytest.raises(KeyError):
|
||||||
|
await config.get_raw('foo')
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_clear_value(config):
|
async def test_clear_value(config):
|
||||||
await config.foo.set(True)
|
await config.foo.set(True)
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user