From 476f441c9bcbd73ae910bf6d2aca5f2be57b7980 Mon Sep 17 00:00:00 2001 From: Toby Harradine Date: Tue, 30 Apr 2019 11:31:28 +1000 Subject: [PATCH] [Audio] Refactor internal Lavalink server management (#2495) * Refactor internal Lavalink server management Killing many birds with one stone here. - Made server manager into class-based API with two public methods: `start()` and `shutdown()`. Must be re-instantiated each time it is restarted. - Using V3 universal Lavalink.jar hosted on Cog-Creators/Lavalink-Jars repository. - Uses output of `java -jar Lavalink.jar --version` to check if a new jar needs to be downloaded. - `ServerManager.start()` won't return until server is ready, i.e. when "Started Launcher in X seconds" message is printed to STDOUT. - `shlex.quote()` is used so spaces in path to Lavalink.jar don't cause issues. - Enabling external Lavalink will cause internal server to be terminated. - Disabling internal Lavalink will no longer reset settings in config - instead, hard-coded values will be used when connecting to an internal server. - Internal server will now run both WS and REST servers on port 2333, meaning one less port will need to be taken up. - Now using `asyncio.subprocess` module so waiting on and reading from subprocesses can be done asynchronously. Signed-off-by: Toby Harradine * Don't use shlex.quote on Windows Signed-off-by: Toby * Don't use shlex.quote at all I misread a note in the python docs and assumed it was best to use it. Turns out the note only applies to `asyncio.create_subprocess_shell`. Signed-off-by: Toby * Missed the port on the rebase * Ignore invalid architectures and inform users when commands are used. * Style fix --- redbot/cogs/audio/__init__.py | 24 +- redbot/cogs/audio/audio.py | 128 ++++++---- redbot/cogs/audio/data/application.yml | 4 +- redbot/cogs/audio/manager.py | 333 +++++++++++++++---------- 4 files changed, 287 insertions(+), 202 deletions(-) diff --git a/redbot/cogs/audio/__init__.py b/redbot/cogs/audio/__init__.py index 3f912df3d..d36ffc7e3 100644 --- a/redbot/cogs/audio/__init__.py +++ b/redbot/cogs/audio/__init__.py @@ -1,31 +1,9 @@ -from pathlib import Path -import logging +from redbot.core import commands from .audio import Audio -from .manager import start_lavalink_server, maybe_download_lavalink -from redbot.core import commands -from redbot.core.data_manager import cog_data_path -import redbot.core - -log = logging.getLogger("red.audio") - -LAVALINK_DOWNLOAD_URL = ( - "https://github.com/Cog-Creators/Red-DiscordBot/releases/download/{}/Lavalink.jar" -).format(redbot.core.__version__) - -LAVALINK_DOWNLOAD_DIR = cog_data_path(raw_name="Audio") -LAVALINK_JAR_FILE = LAVALINK_DOWNLOAD_DIR / "Lavalink.jar" - -APP_YML_FILE = LAVALINK_DOWNLOAD_DIR / "application.yml" -BUNDLED_APP_YML_FILE = Path(__file__).parent / "data/application.yml" async def setup(bot: commands.Bot): cog = Audio(bot) - if not await cog.config.use_external_lavalink(): - await maybe_download_lavalink(bot.loop, cog) - await start_lavalink_server(bot.loop) - await cog.initialize() - bot.add_cog(cog) diff --git a/redbot/cogs/audio/audio.py b/redbot/cogs/audio/audio.py index 6578c57c7..397de8ff2 100644 --- a/redbot/cogs/audio/audio.py +++ b/redbot/cogs/audio/audio.py @@ -14,6 +14,7 @@ import os import random import re import time +from typing import Optional import redbot.core from redbot.core import Config, commands, checks, bank from redbot.core.data_manager import cog_data_path @@ -29,7 +30,7 @@ from redbot.core.utils.menus import ( ) from redbot.core.utils.predicates import MessagePredicate, ReactionPredicate from urllib.parse import urlparse -from .manager import shutdown_lavalink_server, start_lavalink_server, maybe_download_lavalink +from .manager import ServerManager _ = Translator("Audio", __file__) @@ -43,41 +44,45 @@ log = logging.getLogger("red.audio") class Audio(commands.Cog): """Play audio through voice channels.""" + _default_lavalink_settings = { + "host": "localhost", + "rest_port": 2333, + "ws_port": 2333, + "password": "youshallnotpass", + } + def __init__(self, bot): super().__init__() self.bot = bot self.config = Config.get_conf(self, 2711759130, force_registration=True) - default_global = { - "host": "localhost", - "rest_port": "2333", - "ws_port": "2332", - "password": "youshallnotpass", - "status": False, - "current_version": redbot.core.VersionInfo.from_str("3.0.0a0").to_json(), - "use_external_lavalink": False, - "restrict": True, - "localpath": str(cog_data_path(raw_name="Audio")), - } + default_global = dict( + status=False, + use_external_lavalink=False, + restrict=True, + current_version=redbot.core.VersionInfo.from_str("3.0.0a0").to_json(), + localpath=str(cog_data_path(raw_name="Audio")), + **self._default_lavalink_settings, + ) - default_guild = { - "disconnect": False, - "dj_enabled": False, - "dj_role": None, - "emptydc_enabled": False, - "emptydc_timer": 0, - "jukebox": False, - "jukebox_price": 0, - "maxlength": 0, - "playlists": {}, - "notify": False, - "repeat": False, - "shuffle": False, - "thumbnail": False, - "volume": 100, - "vote_enabled": False, - "vote_percent": 0, - } + default_guild = dict( + disconnect=False, + dj_enabled=False, + dj_role=None, + emptydc_enabled=False, + emptydc_timer=0, + jukebox=False, + jukebox_price=0, + maxlength=0, + playlists={}, + notify=False, + repeat=False, + shuffle=False, + thumbnail=False, + volume=100, + vote_enabled=False, + vote_percent=0, + ) self.config.register_guild(**default_guild) self.config.register_global(**default_global) @@ -86,9 +91,24 @@ class Audio(commands.Cog): self._connect_task = None self._disconnect_task = None self._cleaned_up = False + self.spotify_token = None self.play_lock = {} + self._manager: Optional[ServerManager] = None + + async def cog_before_invoke(self, ctx): + if self.llsetup in [ctx.command, ctx.command.root_parent]: + pass + elif self._connect_task.cancelled: + await ctx.send( + "You have attempted to run Audio's Lavalink server on an unsupported" + " architecture. Only settings related commands will be available." + ) + raise RuntimeError( + "Not running audio command due to invalid machine architecture for Lavalink." + ) + async def initialize(self): self._restart_connect() self._disconnect_task = self.bot.loop.create_task(self.disconnect_timer()) @@ -103,16 +123,33 @@ class Audio(commands.Cog): async def attempt_connect(self, timeout: int = 30): while True: # run until success external = await self.config.use_external_lavalink() - if not external: - shutdown_lavalink_server() - await maybe_download_lavalink(self.bot.loop, self) - await start_lavalink_server(self.bot.loop) - try: + if external is False: + settings = self._default_lavalink_settings + host = settings["host"] + password = settings["password"] + rest_port = settings["rest_port"] + ws_port = settings["ws_port"] + if self._manager is not None: + await self._manager.shutdown() + self._manager = ServerManager() + try: + await self._manager.start() + except RuntimeError as exc: + log.exception( + "Exception whilst starting internal Lavalink server, retrying...", + exc_info=exc, + ) + await asyncio.sleep(1) + continue + except asyncio.CancelledError: + log.exception("Invalid machine architecture, cannot run Lavalink.") + break + else: host = await self.config.host() password = await self.config.password() rest_port = await self.config.rest_port() ws_port = await self.config.ws_port() - + try: await lavalink.initialize( bot=self.bot, host=host, @@ -122,9 +159,10 @@ class Audio(commands.Cog): timeout=timeout, ) return # break infinite loop - except Exception: - if not external: - shutdown_lavalink_server() + except asyncio.TimeoutError: + log.error("Connecting to Lavalink server timed out, retrying...") + if external is False and self._manager is not None: + await self._manager.shutdown() await asyncio.sleep(1) # prevent busylooping async def event_handler(self, player, event_type, extra): @@ -3104,19 +3142,16 @@ class Audio(commands.Cog): await self.config.use_external_lavalink.set(not external) if external: - await self.config.host.set("localhost") - await self.config.password.set("youshallnotpass") - await self.config.rest_port.set(2333) - await self.config.ws_port.set(2332) embed = discord.Embed( colour=await ctx.embed_colour(), title=_("External lavalink server: {true_or_false}.").format( true_or_false=not external ), ) - embed.set_footer(text=_("Defaults reset.")) await ctx.send(embed=embed) else: + if self._manager is not None: + await self._manager.shutdown() await self._embed_msg( ctx, _("External lavalink server: {true_or_false}.").format(true_or_false=not external), @@ -3229,6 +3264,8 @@ class Audio(commands.Cog): async def _check_external(self): external = await self.config.use_external_lavalink() if not external: + if self._manager is not None: + await self._manager.shutdown() await self.config.use_external_lavalink.set(True) return True else: @@ -3597,7 +3634,8 @@ class Audio(commands.Cog): lavalink.unregister_event_listener(self.event_handler) self.bot.loop.create_task(lavalink.close()) - shutdown_lavalink_server() + if self._manager is not None: + self.bot.loop.create_task(self._manager.shutdown()) self._cleaned_up = True __del__ = cog_unload diff --git a/redbot/cogs/audio/data/application.yml b/redbot/cogs/audio/data/application.yml index 9b8d7fe33..2c7e586c9 100644 --- a/redbot/cogs/audio/data/application.yml +++ b/redbot/cogs/audio/data/application.yml @@ -1,11 +1,9 @@ server: + host: "localhost" port: 2333 # REST server lavalink: server: password: "youshallnotpass" - ws: - host: "localhost" - port: 2332 sources: youtube: true bandcamp: true diff --git a/redbot/cogs/audio/manager.py b/redbot/cogs/audio/manager.py index 10cc7d4d9..dfcd46e62 100644 --- a/redbot/cogs/audio/manager.py +++ b/redbot/cogs/audio/manager.py @@ -1,172 +1,243 @@ -import shlex +import itertools +import pathlib +import platform import shutil import asyncio import asyncio.subprocess -import os import logging import re -from subprocess import Popen, DEVNULL -from typing import Optional, Tuple +import tempfile +from typing import Optional, Tuple, ClassVar, List -from aiohttp import ClientSession +import aiohttp -import redbot.core +from redbot.core import data_manager -_JavaVersion = Tuple[int, int] +JAR_VERSION = "3.2.0.3" +JAR_BUILD = 751 +LAVALINK_DOWNLOAD_URL = ( + f"https://github.com/Cog-Creators/Lavalink-Jars/releases/download/{JAR_VERSION}_{JAR_BUILD}/" + f"Lavalink.jar" +) +LAVALINK_DOWNLOAD_DIR = data_manager.cog_data_path(raw_name="Audio") +LAVALINK_JAR_FILE = LAVALINK_DOWNLOAD_DIR / "Lavalink.jar" + +BUNDLED_APP_YML = pathlib.Path(__file__).parent / "data" / "application.yml" +LAVALINK_APP_YML = LAVALINK_DOWNLOAD_DIR / "application.yml" + +READY_LINE_RE = re.compile(rb"Started Launcher in \S+ seconds") +BUILD_LINE_RE = re.compile(rb"Build:\s+(?P\d+)") log = logging.getLogger("red.audio.manager") -proc = None -shutdown = False +class ServerManager: -def has_java_error(pid): - from . import LAVALINK_DOWNLOAD_DIR + _java_available: ClassVar[Optional[bool]] = None + _java_version: ClassVar[Optional[Tuple[int, int]]] = None + _up_to_date: ClassVar[Optional[bool]] = None - poss_error_file = LAVALINK_DOWNLOAD_DIR / "hs_err_pid{}.log".format(pid) - return poss_error_file.exists() + _blacklisted_archs = ["armv6l", "aarch32", "aarch64"] + def __init__(self) -> None: + self.ready = asyncio.Event() -async def monitor_lavalink_server(loop): - global shutdown - while shutdown is False: - if proc.poll() is not None: - break - await asyncio.sleep(0.5) + self._proc: Optional[asyncio.subprocess.Process] = None + self._monitor_task: Optional[asyncio.Task] = None + self._shutdown: bool = False - if shutdown is False: - # Lavalink was shut down by something else - log.info("Lavalink jar shutdown.") - shutdown = True - if not has_java_error(proc.pid): - log.info("Restarting Lavalink jar.") - await start_lavalink_server(loop) - else: - log.error( - "Your Java is borked. Please find the hs_err_pid{}.log file" - " in the Audio data folder and report this issue.".format(proc.pid) + async def start(self) -> None: + arch_name = platform.machine() + if arch_name in self._blacklisted_archs: + raise asyncio.CancelledError( + "You are attempting to run Lavalink audio on an unsupported machine architecture." ) + if self._proc is not None: + if self._proc.returncode is None: + raise RuntimeError("Internal Lavalink server is already running") + else: + raise RuntimeError("Server manager has already been used - create another one") -async def has_java(loop) -> Tuple[bool, Optional[_JavaVersion]]: - java_available = shutil.which("java") is not None - if not java_available: - return False, None + await self.maybe_download_jar() - version = await get_java_version(loop) - return (2, 0) > version >= (1, 8) or version >= (8, 0), version + # Copy the application.yml across. + # For people to customise their Lavalink server configuration they need to run it + # externally + shutil.copyfile(BUNDLED_APP_YML, LAVALINK_APP_YML) + args = await self._get_jar_args() + self._proc = await asyncio.subprocess.create_subprocess_exec( + *args, + cwd=str(LAVALINK_DOWNLOAD_DIR), + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.STDOUT, + ) -async def get_java_version(loop) -> _JavaVersion: - """ - This assumes we've already checked that java exists. - """ - _proc: asyncio.subprocess.Process = await asyncio.create_subprocess_exec( - "java", - "-version", - stdout=asyncio.subprocess.PIPE, - stderr=asyncio.subprocess.PIPE, - loop=loop, - ) - # java -version outputs to stderr - _, err = await _proc.communicate() + log.info("Internal Lavalink server started. PID: %s", self._proc.pid) - version_info: str = err.decode("utf-8") - # We expect the output to look something like: - # $ java -version - # ... - # ... version "MAJOR.MINOR.PATCH[_BUILD]" ... - # ... - # We only care about the major and minor parts though. - version_line_re = re.compile( - r'version "(?P\d+).(?P\d+).\d+(?:_\d+)?(?:-[A-Za-z0-9]+)?"' - ) - short_version_re = re.compile(r'version "(?P\d+)"') + try: + await asyncio.wait_for(self._wait_for_launcher(), timeout=120) + except asyncio.TimeoutError: + log.warning("Timeout occurred whilst waiting for internal Lavalink server to be ready") - lines = version_info.splitlines() - for line in lines: - match = version_line_re.search(line) - short_match = short_version_re.search(line) - if match: - return int(match["major"]), int(match["minor"]) - elif short_match: - return int(short_match["major"]), 0 + self._monitor_task = asyncio.create_task(self._monitor()) - raise RuntimeError( - "The output of `java -version` was unexpected. Please report this issue on Red's " - "issue tracker." - ) + @classmethod + async def _get_jar_args(cls) -> List[str]: + java_available, java_version = await cls._has_java() + if not java_available: + raise RuntimeError("You must install Java 1.8+ for Lavalink to run.") + if java_version == (1, 8): + extra_flags = ["-Dsun.zip.disableMemoryMapping=true"] + elif java_version >= (11, 0): + extra_flags = ["-Djdk.tls.client.protocols=TLSv1.2"] + else: + extra_flags = [] -async def start_lavalink_server(loop): - java_available, java_version = await has_java(loop) - if not java_available: - raise RuntimeError("You must install Java 1.8+ for Lavalink to run.") + return ["java", *extra_flags, "-jar", str(LAVALINK_JAR_FILE)] - if java_version == (1, 8): - extra_flags = "-Dsun.zip.disableMemoryMapping=true" - elif java_version >= (11, 0): - extra_flags = "-Djdk.tls.client.protocols=TLSv1.2" - else: - extra_flags = "" + @classmethod + async def _has_java(cls) -> Tuple[bool, Optional[Tuple[int, int]]]: + if cls._java_available is not None: + # Return cached value if we've checked this before + return cls._java_available, cls._java_version + java_available = shutil.which("java") is not None + if not java_available: + cls.java_available = False + cls.java_version = None + else: + cls._java_version = version = await cls._get_java_version() + cls._java_available = (2, 0) > version >= (1, 8) or version >= (8, 0) + return cls._java_available, cls._java_version - from . import LAVALINK_DOWNLOAD_DIR, LAVALINK_JAR_FILE + @staticmethod + async def _get_java_version() -> Tuple[int, int]: + """ + This assumes we've already checked that java exists. + """ + _proc: asyncio.subprocess.Process = await asyncio.create_subprocess_exec( + "java", "-version", stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE + ) + # java -version outputs to stderr + _, err = await _proc.communicate() - start_cmd = "java {} -jar {}".format(extra_flags, LAVALINK_JAR_FILE.resolve()) + version_info: str = err.decode("utf-8") + # We expect the output to look something like: + # $ java -version + # ... + # ... version "MAJOR.MINOR.PATCH[_BUILD]" ... + # ... + # We only care about the major and minor parts though. + version_line_re = re.compile( + r'version "(?P\d+).(?P\d+).\d+(?:_\d+)?(?:-[A-Za-z0-9]+)?"' + ) + short_version_re = re.compile(r'version "(?P\d+)"') - global proc + lines = version_info.splitlines() + for line in lines: + match = version_line_re.search(line) + short_match = short_version_re.search(line) + if match: + return int(match["major"]), int(match["minor"]) + elif short_match: + return int(short_match["major"]), 0 - if proc and proc.poll() is None: - return # already running + raise RuntimeError( + "The output of `java -version` was unexpected. Please report this issue on Red's " + "issue tracker." + ) - proc = Popen( - shlex.split(start_cmd, posix=os.name == "posix"), - cwd=str(LAVALINK_DOWNLOAD_DIR), - stdout=DEVNULL, - stderr=DEVNULL, - ) + async def _wait_for_launcher(self) -> None: + log.debug("Waiting for Lavalink server to be ready") + for i in itertools.cycle(range(50)): + line = await self._proc.stdout.readline() + if READY_LINE_RE.search(line): + self.ready.set() + break + if self._proc.returncode is not None: + log.critical("Internal lavalink server exited early") + if i == 49: + # Sleep after 50 lines to prevent busylooping + await asyncio.sleep(0.1) - log.info("Lavalink jar started. PID: {}".format(proc.pid)) - global shutdown - shutdown = False + async def _monitor(self) -> None: + while self._proc.returncode is None: + await asyncio.sleep(0.5) - loop.create_task(monitor_lavalink_server(loop)) + # This task hasn't been cancelled - Lavalink was shut down by something else + log.info("Internal Lavalink jar shutdown unexpectedly") + if not self._has_java_error(): + log.info("Restarting internal Lavalink server") + await self.start() + else: + log.critical( + "Your Java is borked. Please find the hs_err_pid{}.log file" + " in the Audio data folder and report this issue.", + self._proc.pid, + ) + def _has_java_error(self) -> bool: + poss_error_file = LAVALINK_DOWNLOAD_DIR / "hs_err_pid{}.log".format(self._proc.pid) + return poss_error_file.exists() -def shutdown_lavalink_server(): - global shutdown - shutdown = True - global proc - if proc is not None: - log.info("Shutting down lavalink server.") - proc.terminate() - proc.wait() - proc = None + async def shutdown(self) -> None: + if self._shutdown is True or self._proc is None: + # For convenience, calling this method more than once or calling it before starting it + # does nothing. + return + log.info("Shutting down internal Lavalink server") + if self._monitor_task is not None: + self._monitor_task.cancel() + self._proc.terminate() + await self._proc.wait() + self._shutdown = True + @staticmethod + async def _download_jar() -> None: + log.info("Downloading Lavalink.jar...") + async with aiohttp.ClientSession() as session: + async with session.get(LAVALINK_DOWNLOAD_URL) as response: + if response.status == 404: + raise RuntimeError( + f"Lavalink jar version {JAR_VERSION}_{JAR_BUILD} hasn't been published" + ) + fd, path = tempfile.mkstemp() + file = open(fd, "wb") + try: + chunk = await response.content.read(1024) + while chunk: + file.write(chunk) + chunk = await response.content.read(1024) + file.flush() + finally: + file.close() + pathlib.Path(path).replace(LAVALINK_JAR_FILE) -async def download_lavalink(session): - from . import LAVALINK_DOWNLOAD_URL, LAVALINK_JAR_FILE + @classmethod + async def _is_up_to_date(cls): + if cls._up_to_date is True: + # Return cached value if we've checked this before + return True + args = await cls._get_jar_args() + args.append("--version") + _proc = await asyncio.subprocess.create_subprocess_exec( + *args, + cwd=str(LAVALINK_DOWNLOAD_DIR), + stdout=asyncio.subprocess.PIPE, + stderr=asyncio.subprocess.STDOUT, + ) + stdout = (await _proc.communicate())[0] + match = BUILD_LINE_RE.search(stdout) + if not match: + # Output is unexpected, suspect corrupted jarfile + return False + build = int(match["build"]) + cls._up_to_date = build == JAR_BUILD + return cls._up_to_date - with LAVALINK_JAR_FILE.open(mode="wb") as f: - async with session.get(LAVALINK_DOWNLOAD_URL) as resp: - while True: - chunk = await resp.content.read(512) - if not chunk: - break - f.write(chunk) - - -async def maybe_download_lavalink(loop, cog): - from . import LAVALINK_DOWNLOAD_DIR, LAVALINK_JAR_FILE, BUNDLED_APP_YML_FILE, APP_YML_FILE - - jar_exists = LAVALINK_JAR_FILE.exists() - current_build = redbot.VersionInfo.from_json(await cog.config.current_version()) - - if not jar_exists or current_build < redbot.core.version_info: - log.info("Downloading Lavalink.jar") - LAVALINK_DOWNLOAD_DIR.mkdir(parents=True, exist_ok=True) - async with ClientSession(loop=loop) as session: - await download_lavalink(session) - await cog.config.current_version.set(redbot.core.version_info.to_json()) - - shutil.copyfile(str(BUNDLED_APP_YML_FILE), str(APP_YML_FILE)) + @classmethod + async def maybe_download_jar(cls): + if not (LAVALINK_JAR_FILE.exists() and await cls._is_up_to_date()): + await cls._download_jar()