mirror of
https://github.com/Cog-Creators/Red-DiscordBot.git
synced 2025-11-21 02:16:09 -05:00
[V3 RPC] Swap back to initial RPC library and hook into core commands (#1780)
* Switch RPC libs for websockets support * Implement RPC handling for core * Black reformat * Fix docs for build on travis * Modify RPC to use a Cog base class * Refactor rpc server reference as global * Handle cogbase unload method * Add an init call to handle mutable base attributes * Move RPC server reference back to the bot object * Remove unused import * Add tests for rpc method add/removal * Add tests for rpc method add/removal and cog base unloading * Add one more test * Black reformat * Add RPC mixin...fix MRO * Correct internal rpc method names * Add rpc test html file for debugging/example purposes * Add documentation * Add get_method_info * Update docs with an example RPC call specifying parameter formatting * Make rpc methods UPPER * Black reformat * Fix doc example * Modify this to match new method naming convention * Add more tests
This commit is contained in:
@@ -13,8 +13,7 @@ from redbot.core.events import init_events
|
||||
from redbot.core.cli import interactive_config, confirm, parse_cli_flags, ask_sentry
|
||||
from redbot.core.core_commands import Core
|
||||
from redbot.core.dev_commands import Dev
|
||||
from redbot.core import rpc, __version__
|
||||
import redbot.meta
|
||||
from redbot.core import __version__
|
||||
import asyncio
|
||||
import logging.handlers
|
||||
import logging
|
||||
@@ -112,7 +111,7 @@ def main():
|
||||
sys.exit(1)
|
||||
load_basic_configuration(cli_flags.instance_name)
|
||||
log, sentry_log = init_loggers(cli_flags)
|
||||
red = Red(cli_flags, description=description, pm_help=None)
|
||||
red = Red(cli_flags=cli_flags, description=description, pm_help=None)
|
||||
init_global_checks(red)
|
||||
init_events(red, cli_flags)
|
||||
red.add_cog(Core(red))
|
||||
@@ -166,6 +165,10 @@ def main():
|
||||
pending = asyncio.Task.all_tasks(loop=red.loop)
|
||||
gathered = asyncio.gather(*pending, loop=red.loop, return_exceptions=True)
|
||||
gathered.cancel()
|
||||
try:
|
||||
red.rpc.server.close()
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
sys.exit(red._shutdown_mode.value)
|
||||
|
||||
|
||||
@@ -18,12 +18,13 @@ from discord.voice_client import VoiceClient
|
||||
VoiceClient.warn_nacl = False
|
||||
|
||||
from .cog_manager import CogManager
|
||||
from . import Config, i18n, commands, rpc
|
||||
from . import Config, i18n, commands
|
||||
from .rpc import RPCMixin
|
||||
from .help_formatter import Help, help as help_
|
||||
from .sentry import SentryManager
|
||||
|
||||
|
||||
class RedBase(BotBase):
|
||||
class RedBase(BotBase, RPCMixin):
|
||||
"""Mixin for the main bot class.
|
||||
|
||||
This exists because `Red` inherits from `discord.AutoShardedClient`, which
|
||||
@@ -33,7 +34,7 @@ class RedBase(BotBase):
|
||||
Selfbots should inherit from this mixin along with `discord.Client`.
|
||||
"""
|
||||
|
||||
def __init__(self, cli_flags, bot_dir: Path = Path.cwd(), **kwargs):
|
||||
def __init__(self, *args, cli_flags=None, bot_dir: Path = Path.cwd(), **kwargs):
|
||||
self._shutdown_mode = ExitCodes.CRITICAL
|
||||
self.db = Config.get_core_conf(force_registration=True)
|
||||
self._co_owners = cli_flags.co_owner
|
||||
@@ -107,10 +108,7 @@ class RedBase(BotBase):
|
||||
|
||||
self.cog_mgr = CogManager(paths=(str(self.main_dir / "cogs"),))
|
||||
|
||||
super().__init__(formatter=Help(), **kwargs)
|
||||
|
||||
if self.rpc_enabled:
|
||||
self.rpc = rpc.RPC(self)
|
||||
super().__init__(*args, formatter=Help(), **kwargs)
|
||||
|
||||
self.remove_command("help")
|
||||
|
||||
@@ -235,12 +233,24 @@ class RedBase(BotBase):
|
||||
lib_name = lib.__name__ # Thank you
|
||||
|
||||
# find all references to the module
|
||||
cog_names = []
|
||||
|
||||
# remove the cogs registered from the module
|
||||
for cogname, cog in self.cogs.copy().items():
|
||||
if cog.__module__.startswith(lib_name):
|
||||
self.remove_cog(cogname)
|
||||
|
||||
cog_names.append(cogname)
|
||||
|
||||
# remove all rpc handlers
|
||||
for cogname in cog_names:
|
||||
if cogname.upper() in self.rpc_handlers:
|
||||
methods = self.rpc_handlers[cogname]
|
||||
for meth in methods:
|
||||
self.unregister_rpc_handler(meth)
|
||||
|
||||
del self.rpc_handlers[cogname]
|
||||
|
||||
# first remove all the commands from the module
|
||||
for cmd in self.all_commands.copy().values():
|
||||
if cmd.module.startswith(lib_name):
|
||||
|
||||
@@ -46,6 +46,13 @@ _ = i18n.Translator("Core", __file__)
|
||||
class CoreLogic:
|
||||
def __init__(self, bot: "Red"):
|
||||
self.bot = bot
|
||||
self.bot.register_rpc_handler(self._load)
|
||||
self.bot.register_rpc_handler(self._unload)
|
||||
self.bot.register_rpc_handler(self._reload)
|
||||
self.bot.register_rpc_handler(self._name)
|
||||
self.bot.register_rpc_handler(self._prefixes)
|
||||
self.bot.register_rpc_handler(self._version_info)
|
||||
self.bot.register_rpc_handler(self._invite_url)
|
||||
|
||||
async def _load(self, cog_names: list):
|
||||
"""
|
||||
|
||||
@@ -18,6 +18,7 @@ from .data_manager import storage_type
|
||||
from .utils.chat_formatting import inline, bordered, pagify, box
|
||||
from .utils import fuzzy_command_search
|
||||
from colorama import Fore, Style, init
|
||||
from . import rpc
|
||||
|
||||
log = logging.getLogger("red")
|
||||
sentry_log = logging.getLogger("red.sentry")
|
||||
@@ -84,6 +85,9 @@ def init_events(bot, cli_flags):
|
||||
if packages:
|
||||
print("Loaded packages: " + ", ".join(packages))
|
||||
|
||||
if bot.rpc_enabled:
|
||||
await bot.rpc.initialize()
|
||||
|
||||
guilds = len(bot.guilds)
|
||||
users = len(set([m for m in bot.get_all_members()]))
|
||||
|
||||
@@ -172,8 +176,6 @@ def init_events(bot, cli_flags):
|
||||
print("\nInvite URL: {}\n".format(invite_url))
|
||||
|
||||
bot.color = discord.Colour(await bot.db.color())
|
||||
if bot.rpc_enabled:
|
||||
await bot.rpc.initialize()
|
||||
|
||||
@bot.event
|
||||
async def on_error(event_method, *args, **kwargs):
|
||||
|
||||
@@ -1,116 +1,74 @@
|
||||
import weakref
|
||||
import asyncio
|
||||
|
||||
from aiohttp import web
|
||||
import jsonrpcserver.aio
|
||||
from aiohttp_json_rpc import JsonRpc
|
||||
from aiohttp_json_rpc.rpc import unpack_request_args
|
||||
|
||||
import inspect
|
||||
import logging
|
||||
|
||||
__all__ = ["methods", "RPC", "Methods"]
|
||||
|
||||
log = logging.getLogger("red.rpc")
|
||||
|
||||
|
||||
class Methods(jsonrpcserver.aio.AsyncMethods):
|
||||
"""
|
||||
Container class for all registered RPC methods, please use the existing `methods`
|
||||
attribute rather than creating a new instance of this class.
|
||||
|
||||
.. warning::
|
||||
|
||||
**NEVER** create a new instance of this class!
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
|
||||
self._items = weakref.WeakValueDictionary()
|
||||
|
||||
def add(self, method, name: str = None):
|
||||
"""
|
||||
Registers a method to the internal RPC server making it available for
|
||||
RPC users to call.
|
||||
|
||||
.. important::
|
||||
|
||||
Any method added here must take ONLY JSON serializable parameters and
|
||||
MUST return a JSON serializable object.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
method : function
|
||||
A reference to the function to register.
|
||||
|
||||
name : str
|
||||
Name of the function as seen by the RPC clients.
|
||||
"""
|
||||
if not inspect.iscoroutinefunction(method):
|
||||
raise TypeError("Method must be a coroutine.")
|
||||
|
||||
if name is None:
|
||||
name = method.__qualname__
|
||||
|
||||
self._items[str(name)] = method
|
||||
|
||||
def remove(self, *, name: str = None, method=None):
|
||||
"""
|
||||
Unregisters an RPC method. Either a name or reference to the method must
|
||||
be provided and name will take priority.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
name : str
|
||||
method : function
|
||||
"""
|
||||
if name and name in self._items:
|
||||
del self._items[name]
|
||||
|
||||
elif method and method in self._items.values():
|
||||
to_remove = []
|
||||
for name, val in self._items.items():
|
||||
if method == val:
|
||||
to_remove.append(name)
|
||||
|
||||
for name in to_remove:
|
||||
del self._items[name]
|
||||
|
||||
def all_methods(self):
|
||||
"""
|
||||
Lists all available method names.
|
||||
|
||||
Returns
|
||||
-------
|
||||
list of str
|
||||
"""
|
||||
return self._items.keys()
|
||||
__all__ = ["RPC", "RPCMixin", "get_name"]
|
||||
|
||||
|
||||
methods = Methods()
|
||||
def get_name(func, prefix=None):
|
||||
class_name = prefix or func.__self__.__class__.__name__.lower()
|
||||
func_name = func.__name__.strip("_")
|
||||
if class_name == "redrpc":
|
||||
return func_name.upper()
|
||||
return f"{class_name}__{func_name}".upper()
|
||||
|
||||
|
||||
class BaseRPCMethodMixin:
|
||||
def __init__(self):
|
||||
methods.add(self.all_methods, name="all_methods")
|
||||
class RedRpc(JsonRpc):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.add_methods(("", self.get_method_info))
|
||||
|
||||
async def all_methods(self):
|
||||
return list(methods.all_methods())
|
||||
def _add_method(self, method, prefix=""):
|
||||
if not asyncio.iscoroutinefunction(method):
|
||||
return
|
||||
|
||||
name = get_name(method, prefix)
|
||||
|
||||
self.methods[name] = method
|
||||
|
||||
def remove_method(self, method):
|
||||
meth_name = get_name(method)
|
||||
new_methods = {}
|
||||
for name, meth in self.methods.items():
|
||||
if name != meth_name:
|
||||
new_methods[name] = meth
|
||||
self.methods = new_methods
|
||||
|
||||
def remove_methods(self, prefix: str):
|
||||
new_methods = {}
|
||||
for name, meth in self.methods.items():
|
||||
splitted = name.split("__")
|
||||
if len(splitted) < 2 or splitted[0] != prefix:
|
||||
new_methods[name] = meth
|
||||
self.methods = new_methods
|
||||
|
||||
async def get_method_info(self, request):
|
||||
method_name = request.params[0]
|
||||
if method_name in self.methods:
|
||||
return self.methods[method_name].__doc__
|
||||
return "No docstring available."
|
||||
|
||||
|
||||
class RPC(BaseRPCMethodMixin):
|
||||
class RPC:
|
||||
"""
|
||||
RPC server manager.
|
||||
"""
|
||||
|
||||
def __init__(self, bot):
|
||||
self.app = web.Application(loop=bot.loop)
|
||||
self.app.router.add_post("/rpc", self.handle)
|
||||
def __init__(self):
|
||||
self.app = web.Application()
|
||||
self._rpc = RedRpc()
|
||||
self.app.router.add_route("*", "/", self._rpc)
|
||||
|
||||
self.app_handler = self.app.make_handler()
|
||||
|
||||
self.server = None
|
||||
|
||||
super().__init__()
|
||||
|
||||
async def initialize(self):
|
||||
"""
|
||||
Finalizes the initialization of the RPC server and allows it to begin
|
||||
@@ -125,10 +83,79 @@ class RPC(BaseRPCMethodMixin):
|
||||
"""
|
||||
self.server.close()
|
||||
|
||||
async def handle(self, request):
|
||||
request = await request.text()
|
||||
response = await methods.dispatch(request)
|
||||
if response.is_notification:
|
||||
return web.Response()
|
||||
else:
|
||||
return web.json_response(response, status=response.http_status)
|
||||
def add_method(self, method, prefix: str = None):
|
||||
if prefix is None:
|
||||
prefix = method.__self__.__class__.__name__.lower()
|
||||
|
||||
if not asyncio.iscoroutinefunction(method):
|
||||
raise TypeError("RPC methods must be coroutines.")
|
||||
|
||||
self._rpc.add_methods((prefix, unpack_request_args(method)))
|
||||
|
||||
def add_multi_method(self, *methods, prefix: str = None):
|
||||
if not all(asyncio.iscoroutinefunction(m) for m in methods):
|
||||
raise TypeError("RPC methods must be coroutines.")
|
||||
|
||||
for method in methods:
|
||||
self.add_method(method, prefix=prefix)
|
||||
|
||||
def remove_method(self, method):
|
||||
self._rpc.remove_method(method)
|
||||
|
||||
def remove_methods(self, prefix: str):
|
||||
self._rpc.remove_methods(prefix)
|
||||
|
||||
|
||||
class RPCMixin:
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.rpc = RPC()
|
||||
|
||||
self.rpc_handlers = {} # Lowered cog name to method
|
||||
|
||||
def register_rpc_handler(self, method):
|
||||
"""
|
||||
Registers a method to act as an RPC handler if the internal RPC server is active.
|
||||
|
||||
When calling this method through the RPC server, use the naming scheme "cogname__methodname".
|
||||
|
||||
.. important::
|
||||
|
||||
All parameters to RPC handler methods must be JSON serializable objects.
|
||||
The return value of handler methods must also be JSON serializable.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
method : coroutine
|
||||
The method to register with the internal RPC server.
|
||||
"""
|
||||
self.rpc.add_method(method)
|
||||
|
||||
cog_name = method.__self__.__class__.__name__.upper()
|
||||
if cog_name not in self.rpc_handlers:
|
||||
self.rpc_handlers[cog_name] = []
|
||||
|
||||
self.rpc_handlers[cog_name].append(method)
|
||||
|
||||
def unregister_rpc_handler(self, method):
|
||||
"""
|
||||
Unregisters an RPC method handler.
|
||||
|
||||
This will be called automatically for you on cog unload and will pass silently if the
|
||||
method is not previously registered.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
method : coroutine
|
||||
The method to unregister from the internal RPC server.
|
||||
"""
|
||||
self.rpc.remove_method(method)
|
||||
|
||||
name = get_name(method)
|
||||
cog_name = name.split("__")[0]
|
||||
|
||||
if cog_name in self.rpc_handlers:
|
||||
try:
|
||||
self.rpc_handlers[cog_name].remove(method)
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user