[V3 Core] Fix unload_extension for module-less cogs (#1984)

Fixes #1943

This just skips cogs when  `__module__` is None. Also:
- backported rapptz/discord.py#621
- moved RPC handler unregister to remove_cog()
- raise an exception when a load would overwrite an existing extension
This commit is contained in:
Caleb Johnson 2018-08-07 20:26:14 -04:00 committed by Toby Harradine
parent aec3ad382a
commit 872cce784a
2 changed files with 25 additions and 22 deletions

View File

@ -24,6 +24,10 @@ from .help_formatter import Help, help as help_
from .sentry import SentryManager
def _is_submodule(parent, child):
return parent == child or child.startswith(parent + ".")
class RedBase(BotBase, RPCMixin):
"""Mixin for the main bot class.
@ -211,12 +215,12 @@ class RedBase(BotBase, RPCMixin):
async def load_extension(self, spec: ModuleSpec):
name = spec.name.split(".")[-1]
if name in self.extensions:
return
raise discord.ClientException(f"there is already a package named {name} loaded")
lib = spec.loader.load_module()
if not hasattr(lib, "setup"):
del lib
raise discord.ClientException("extension does not have a setup function")
raise discord.ClientException(f"extension {name} does not have a setup function")
if asyncio.iscoroutinefunction(lib.setup):
await lib.setup(self)
@ -225,44 +229,41 @@ class RedBase(BotBase, RPCMixin):
self.extensions[name] = lib
def remove_cog(self, cogname):
super().remove_cog(cogname)
for meth in self.rpc_handlers.pop(cogname.upper(), ()):
self.unregister_rpc_handler(meth)
def unload_extension(self, name):
lib = self.extensions.get(name)
if lib is None:
return
lib_name = lib.__name__ # Thank you
# find all references to the module
cog_names = []
# remove the cogs registered from the module
for cogname, cog in self.cogs.copy().items():
if cog.__module__.startswith(lib_name):
if cog.__module__ and _is_submodule(lib_name, cog.__module__):
self.remove_cog(cogname)
cog_names.append(cogname)
# remove all rpc handlers
for cogname in cog_names:
if cogname.upper() in self.rpc_handlers:
methods = self.rpc_handlers[cogname]
for meth in methods:
self.unregister_rpc_handler(meth)
del self.rpc_handlers[cogname]
# first remove all the commands from the module
for cmd in self.all_commands.copy().values():
if cmd.module.startswith(lib_name):
if cmd.module and _is_submodule(lib_name, cmd.module):
if isinstance(cmd, GroupMixin):
cmd.recursively_remove_all_commands()
self.remove_command(cmd.name)
# then remove all the listeners from the module
for event_list in self.extra_events.copy().values():
remove = []
for index, event in enumerate(event_list):
if event.__module__.startswith(lib_name):
if event.__module__ and _is_submodule(lib_name, event.__module__):
remove.append(index)
for index in reversed(remove):
@ -282,11 +283,12 @@ class RedBase(BotBase, RPCMixin):
pkg_name = lib.__package__
del lib
del self.extensions[name]
for m, _ in sys.modules.copy().items():
if m.startswith(pkg_name):
del sys.modules[m]
if pkg_name.startswith("redbot.cogs"):
for module in list(sys.modules):
if _is_submodule(lib_name, module):
del sys.modules[module]
if pkg_name.startswith("redbot.cogs."):
del sys.modules["redbot.cogs"].__dict__[name]

View File

@ -111,7 +111,7 @@ class RPCMixin:
super().__init__(**kwargs)
self.rpc = RPC()
self.rpc_handlers = {} # Lowered cog name to method
self.rpc_handlers = {} # Uppercase cog name to method
def register_rpc_handler(self, method):
"""
@ -132,6 +132,7 @@ class RPCMixin:
self.rpc.add_method(method)
cog_name = method.__self__.__class__.__name__.upper()
if cog_name not in self.rpc_handlers:
self.rpc_handlers[cog_name] = []