[Audio] Refactor internal Lavalink server management (#2495)

* Refactor internal Lavalink server management

Killing many birds with one stone here.
- Made server manager into class-based API with two public methods: `start()` and `shutdown()`. Must be re-instantiated each time it is restarted.
- Using V3 universal Lavalink.jar hosted on Cog-Creators/Lavalink-Jars repository.
- Uses output of `java -jar Lavalink.jar --version` to check if a new jar needs to be downloaded.
- `ServerManager.start()` won't return until server is ready, i.e. when "Started Launcher in X seconds" message is printed to STDOUT.
- `shlex.quote()` is used so spaces in path to Lavalink.jar don't cause issues.
- Enabling external Lavalink will cause internal server to be terminated.
- Disabling internal Lavalink will no longer reset settings in config - instead, hard-coded values will be used when connecting to an internal server.
- Internal server will now run both WS and REST servers on port 2333, meaning one less port will need to be taken up.
- Now using `asyncio.subprocess` module so waiting on and reading from subprocesses can be done asynchronously.

Signed-off-by: Toby Harradine <tobyharradine@gmail.com>

* Don't use shlex.quote on Windows

Signed-off-by: Toby <tobyharradine@gmail.com>

* Don't use shlex.quote at all

I misread a note in the python docs and assumed it was best to use it. Turns out the note only applies to `asyncio.create_subprocess_shell`.

Signed-off-by: Toby <tobyharradine@gmail.com>

* Missed the port on the rebase

* Ignore invalid architectures and inform users when commands are used.

* Style fix
This commit is contained in:
Toby Harradine 2019-04-30 11:31:28 +10:00 committed by Will
parent c79b5e6179
commit 476f441c9b
4 changed files with 287 additions and 202 deletions

View File

@ -1,31 +1,9 @@
from pathlib import Path
import logging
from redbot.core import commands
from .audio import Audio
from .manager import start_lavalink_server, maybe_download_lavalink
from redbot.core import commands
from redbot.core.data_manager import cog_data_path
import redbot.core
log = logging.getLogger("red.audio")
LAVALINK_DOWNLOAD_URL = (
"https://github.com/Cog-Creators/Red-DiscordBot/releases/download/{}/Lavalink.jar"
).format(redbot.core.__version__)
LAVALINK_DOWNLOAD_DIR = cog_data_path(raw_name="Audio")
LAVALINK_JAR_FILE = LAVALINK_DOWNLOAD_DIR / "Lavalink.jar"
APP_YML_FILE = LAVALINK_DOWNLOAD_DIR / "application.yml"
BUNDLED_APP_YML_FILE = Path(__file__).parent / "data/application.yml"
async def setup(bot: commands.Bot):
cog = Audio(bot)
if not await cog.config.use_external_lavalink():
await maybe_download_lavalink(bot.loop, cog)
await start_lavalink_server(bot.loop)
await cog.initialize()
bot.add_cog(cog)

View File

@ -14,6 +14,7 @@ import os
import random
import re
import time
from typing import Optional
import redbot.core
from redbot.core import Config, commands, checks, bank
from redbot.core.data_manager import cog_data_path
@ -29,7 +30,7 @@ from redbot.core.utils.menus import (
)
from redbot.core.utils.predicates import MessagePredicate, ReactionPredicate
from urllib.parse import urlparse
from .manager import shutdown_lavalink_server, start_lavalink_server, maybe_download_lavalink
from .manager import ServerManager
_ = Translator("Audio", __file__)
@ -43,41 +44,45 @@ log = logging.getLogger("red.audio")
class Audio(commands.Cog):
"""Play audio through voice channels."""
_default_lavalink_settings = {
"host": "localhost",
"rest_port": 2333,
"ws_port": 2333,
"password": "youshallnotpass",
}
def __init__(self, bot):
super().__init__()
self.bot = bot
self.config = Config.get_conf(self, 2711759130, force_registration=True)
default_global = {
"host": "localhost",
"rest_port": "2333",
"ws_port": "2332",
"password": "youshallnotpass",
"status": False,
"current_version": redbot.core.VersionInfo.from_str("3.0.0a0").to_json(),
"use_external_lavalink": False,
"restrict": True,
"localpath": str(cog_data_path(raw_name="Audio")),
}
default_global = dict(
status=False,
use_external_lavalink=False,
restrict=True,
current_version=redbot.core.VersionInfo.from_str("3.0.0a0").to_json(),
localpath=str(cog_data_path(raw_name="Audio")),
**self._default_lavalink_settings,
)
default_guild = {
"disconnect": False,
"dj_enabled": False,
"dj_role": None,
"emptydc_enabled": False,
"emptydc_timer": 0,
"jukebox": False,
"jukebox_price": 0,
"maxlength": 0,
"playlists": {},
"notify": False,
"repeat": False,
"shuffle": False,
"thumbnail": False,
"volume": 100,
"vote_enabled": False,
"vote_percent": 0,
}
default_guild = dict(
disconnect=False,
dj_enabled=False,
dj_role=None,
emptydc_enabled=False,
emptydc_timer=0,
jukebox=False,
jukebox_price=0,
maxlength=0,
playlists={},
notify=False,
repeat=False,
shuffle=False,
thumbnail=False,
volume=100,
vote_enabled=False,
vote_percent=0,
)
self.config.register_guild(**default_guild)
self.config.register_global(**default_global)
@ -86,9 +91,24 @@ class Audio(commands.Cog):
self._connect_task = None
self._disconnect_task = None
self._cleaned_up = False
self.spotify_token = None
self.play_lock = {}
self._manager: Optional[ServerManager] = None
async def cog_before_invoke(self, ctx):
if self.llsetup in [ctx.command, ctx.command.root_parent]:
pass
elif self._connect_task.cancelled:
await ctx.send(
"You have attempted to run Audio's Lavalink server on an unsupported"
" architecture. Only settings related commands will be available."
)
raise RuntimeError(
"Not running audio command due to invalid machine architecture for Lavalink."
)
async def initialize(self):
self._restart_connect()
self._disconnect_task = self.bot.loop.create_task(self.disconnect_timer())
@ -103,16 +123,33 @@ class Audio(commands.Cog):
async def attempt_connect(self, timeout: int = 30):
while True: # run until success
external = await self.config.use_external_lavalink()
if not external:
shutdown_lavalink_server()
await maybe_download_lavalink(self.bot.loop, self)
await start_lavalink_server(self.bot.loop)
try:
if external is False:
settings = self._default_lavalink_settings
host = settings["host"]
password = settings["password"]
rest_port = settings["rest_port"]
ws_port = settings["ws_port"]
if self._manager is not None:
await self._manager.shutdown()
self._manager = ServerManager()
try:
await self._manager.start()
except RuntimeError as exc:
log.exception(
"Exception whilst starting internal Lavalink server, retrying...",
exc_info=exc,
)
await asyncio.sleep(1)
continue
except asyncio.CancelledError:
log.exception("Invalid machine architecture, cannot run Lavalink.")
break
else:
host = await self.config.host()
password = await self.config.password()
rest_port = await self.config.rest_port()
ws_port = await self.config.ws_port()
try:
await lavalink.initialize(
bot=self.bot,
host=host,
@ -122,9 +159,10 @@ class Audio(commands.Cog):
timeout=timeout,
)
return # break infinite loop
except Exception:
if not external:
shutdown_lavalink_server()
except asyncio.TimeoutError:
log.error("Connecting to Lavalink server timed out, retrying...")
if external is False and self._manager is not None:
await self._manager.shutdown()
await asyncio.sleep(1) # prevent busylooping
async def event_handler(self, player, event_type, extra):
@ -3104,19 +3142,16 @@ class Audio(commands.Cog):
await self.config.use_external_lavalink.set(not external)
if external:
await self.config.host.set("localhost")
await self.config.password.set("youshallnotpass")
await self.config.rest_port.set(2333)
await self.config.ws_port.set(2332)
embed = discord.Embed(
colour=await ctx.embed_colour(),
title=_("External lavalink server: {true_or_false}.").format(
true_or_false=not external
),
)
embed.set_footer(text=_("Defaults reset."))
await ctx.send(embed=embed)
else:
if self._manager is not None:
await self._manager.shutdown()
await self._embed_msg(
ctx,
_("External lavalink server: {true_or_false}.").format(true_or_false=not external),
@ -3229,6 +3264,8 @@ class Audio(commands.Cog):
async def _check_external(self):
external = await self.config.use_external_lavalink()
if not external:
if self._manager is not None:
await self._manager.shutdown()
await self.config.use_external_lavalink.set(True)
return True
else:
@ -3597,7 +3634,8 @@ class Audio(commands.Cog):
lavalink.unregister_event_listener(self.event_handler)
self.bot.loop.create_task(lavalink.close())
shutdown_lavalink_server()
if self._manager is not None:
self.bot.loop.create_task(self._manager.shutdown())
self._cleaned_up = True
__del__ = cog_unload

View File

@ -1,11 +1,9 @@
server:
host: "localhost"
port: 2333 # REST server
lavalink:
server:
password: "youshallnotpass"
ws:
host: "localhost"
port: 2332
sources:
youtube: true
bandcamp: true

View File

@ -1,172 +1,243 @@
import shlex
import itertools
import pathlib
import platform
import shutil
import asyncio
import asyncio.subprocess
import os
import logging
import re
from subprocess import Popen, DEVNULL
from typing import Optional, Tuple
import tempfile
from typing import Optional, Tuple, ClassVar, List
from aiohttp import ClientSession
import aiohttp
import redbot.core
from redbot.core import data_manager
_JavaVersion = Tuple[int, int]
JAR_VERSION = "3.2.0.3"
JAR_BUILD = 751
LAVALINK_DOWNLOAD_URL = (
f"https://github.com/Cog-Creators/Lavalink-Jars/releases/download/{JAR_VERSION}_{JAR_BUILD}/"
f"Lavalink.jar"
)
LAVALINK_DOWNLOAD_DIR = data_manager.cog_data_path(raw_name="Audio")
LAVALINK_JAR_FILE = LAVALINK_DOWNLOAD_DIR / "Lavalink.jar"
BUNDLED_APP_YML = pathlib.Path(__file__).parent / "data" / "application.yml"
LAVALINK_APP_YML = LAVALINK_DOWNLOAD_DIR / "application.yml"
READY_LINE_RE = re.compile(rb"Started Launcher in \S+ seconds")
BUILD_LINE_RE = re.compile(rb"Build:\s+(?P<build>\d+)")
log = logging.getLogger("red.audio.manager")
proc = None
shutdown = False
class ServerManager:
def has_java_error(pid):
from . import LAVALINK_DOWNLOAD_DIR
_java_available: ClassVar[Optional[bool]] = None
_java_version: ClassVar[Optional[Tuple[int, int]]] = None
_up_to_date: ClassVar[Optional[bool]] = None
poss_error_file = LAVALINK_DOWNLOAD_DIR / "hs_err_pid{}.log".format(pid)
return poss_error_file.exists()
_blacklisted_archs = ["armv6l", "aarch32", "aarch64"]
def __init__(self) -> None:
self.ready = asyncio.Event()
async def monitor_lavalink_server(loop):
global shutdown
while shutdown is False:
if proc.poll() is not None:
break
await asyncio.sleep(0.5)
self._proc: Optional[asyncio.subprocess.Process] = None
self._monitor_task: Optional[asyncio.Task] = None
self._shutdown: bool = False
if shutdown is False:
# Lavalink was shut down by something else
log.info("Lavalink jar shutdown.")
shutdown = True
if not has_java_error(proc.pid):
log.info("Restarting Lavalink jar.")
await start_lavalink_server(loop)
else:
log.error(
"Your Java is borked. Please find the hs_err_pid{}.log file"
" in the Audio data folder and report this issue.".format(proc.pid)
async def start(self) -> None:
arch_name = platform.machine()
if arch_name in self._blacklisted_archs:
raise asyncio.CancelledError(
"You are attempting to run Lavalink audio on an unsupported machine architecture."
)
if self._proc is not None:
if self._proc.returncode is None:
raise RuntimeError("Internal Lavalink server is already running")
else:
raise RuntimeError("Server manager has already been used - create another one")
async def has_java(loop) -> Tuple[bool, Optional[_JavaVersion]]:
java_available = shutil.which("java") is not None
if not java_available:
return False, None
await self.maybe_download_jar()
version = await get_java_version(loop)
return (2, 0) > version >= (1, 8) or version >= (8, 0), version
# Copy the application.yml across.
# For people to customise their Lavalink server configuration they need to run it
# externally
shutil.copyfile(BUNDLED_APP_YML, LAVALINK_APP_YML)
args = await self._get_jar_args()
self._proc = await asyncio.subprocess.create_subprocess_exec(
*args,
cwd=str(LAVALINK_DOWNLOAD_DIR),
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.STDOUT,
)
async def get_java_version(loop) -> _JavaVersion:
"""
This assumes we've already checked that java exists.
"""
_proc: asyncio.subprocess.Process = await asyncio.create_subprocess_exec(
"java",
"-version",
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.PIPE,
loop=loop,
)
# java -version outputs to stderr
_, err = await _proc.communicate()
log.info("Internal Lavalink server started. PID: %s", self._proc.pid)
version_info: str = err.decode("utf-8")
# We expect the output to look something like:
# $ java -version
# ...
# ... version "MAJOR.MINOR.PATCH[_BUILD]" ...
# ...
# We only care about the major and minor parts though.
version_line_re = re.compile(
r'version "(?P<major>\d+).(?P<minor>\d+).\d+(?:_\d+)?(?:-[A-Za-z0-9]+)?"'
)
short_version_re = re.compile(r'version "(?P<major>\d+)"')
try:
await asyncio.wait_for(self._wait_for_launcher(), timeout=120)
except asyncio.TimeoutError:
log.warning("Timeout occurred whilst waiting for internal Lavalink server to be ready")
lines = version_info.splitlines()
for line in lines:
match = version_line_re.search(line)
short_match = short_version_re.search(line)
if match:
return int(match["major"]), int(match["minor"])
elif short_match:
return int(short_match["major"]), 0
self._monitor_task = asyncio.create_task(self._monitor())
raise RuntimeError(
"The output of `java -version` was unexpected. Please report this issue on Red's "
"issue tracker."
)
@classmethod
async def _get_jar_args(cls) -> List[str]:
java_available, java_version = await cls._has_java()
if not java_available:
raise RuntimeError("You must install Java 1.8+ for Lavalink to run.")
if java_version == (1, 8):
extra_flags = ["-Dsun.zip.disableMemoryMapping=true"]
elif java_version >= (11, 0):
extra_flags = ["-Djdk.tls.client.protocols=TLSv1.2"]
else:
extra_flags = []
async def start_lavalink_server(loop):
java_available, java_version = await has_java(loop)
if not java_available:
raise RuntimeError("You must install Java 1.8+ for Lavalink to run.")
return ["java", *extra_flags, "-jar", str(LAVALINK_JAR_FILE)]
if java_version == (1, 8):
extra_flags = "-Dsun.zip.disableMemoryMapping=true"
elif java_version >= (11, 0):
extra_flags = "-Djdk.tls.client.protocols=TLSv1.2"
else:
extra_flags = ""
@classmethod
async def _has_java(cls) -> Tuple[bool, Optional[Tuple[int, int]]]:
if cls._java_available is not None:
# Return cached value if we've checked this before
return cls._java_available, cls._java_version
java_available = shutil.which("java") is not None
if not java_available:
cls.java_available = False
cls.java_version = None
else:
cls._java_version = version = await cls._get_java_version()
cls._java_available = (2, 0) > version >= (1, 8) or version >= (8, 0)
return cls._java_available, cls._java_version
from . import LAVALINK_DOWNLOAD_DIR, LAVALINK_JAR_FILE
@staticmethod
async def _get_java_version() -> Tuple[int, int]:
"""
This assumes we've already checked that java exists.
"""
_proc: asyncio.subprocess.Process = await asyncio.create_subprocess_exec(
"java", "-version", stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE
)
# java -version outputs to stderr
_, err = await _proc.communicate()
start_cmd = "java {} -jar {}".format(extra_flags, LAVALINK_JAR_FILE.resolve())
version_info: str = err.decode("utf-8")
# We expect the output to look something like:
# $ java -version
# ...
# ... version "MAJOR.MINOR.PATCH[_BUILD]" ...
# ...
# We only care about the major and minor parts though.
version_line_re = re.compile(
r'version "(?P<major>\d+).(?P<minor>\d+).\d+(?:_\d+)?(?:-[A-Za-z0-9]+)?"'
)
short_version_re = re.compile(r'version "(?P<major>\d+)"')
global proc
lines = version_info.splitlines()
for line in lines:
match = version_line_re.search(line)
short_match = short_version_re.search(line)
if match:
return int(match["major"]), int(match["minor"])
elif short_match:
return int(short_match["major"]), 0
if proc and proc.poll() is None:
return # already running
raise RuntimeError(
"The output of `java -version` was unexpected. Please report this issue on Red's "
"issue tracker."
)
proc = Popen(
shlex.split(start_cmd, posix=os.name == "posix"),
cwd=str(LAVALINK_DOWNLOAD_DIR),
stdout=DEVNULL,
stderr=DEVNULL,
)
async def _wait_for_launcher(self) -> None:
log.debug("Waiting for Lavalink server to be ready")
for i in itertools.cycle(range(50)):
line = await self._proc.stdout.readline()
if READY_LINE_RE.search(line):
self.ready.set()
break
if self._proc.returncode is not None:
log.critical("Internal lavalink server exited early")
if i == 49:
# Sleep after 50 lines to prevent busylooping
await asyncio.sleep(0.1)
log.info("Lavalink jar started. PID: {}".format(proc.pid))
global shutdown
shutdown = False
async def _monitor(self) -> None:
while self._proc.returncode is None:
await asyncio.sleep(0.5)
loop.create_task(monitor_lavalink_server(loop))
# This task hasn't been cancelled - Lavalink was shut down by something else
log.info("Internal Lavalink jar shutdown unexpectedly")
if not self._has_java_error():
log.info("Restarting internal Lavalink server")
await self.start()
else:
log.critical(
"Your Java is borked. Please find the hs_err_pid{}.log file"
" in the Audio data folder and report this issue.",
self._proc.pid,
)
def _has_java_error(self) -> bool:
poss_error_file = LAVALINK_DOWNLOAD_DIR / "hs_err_pid{}.log".format(self._proc.pid)
return poss_error_file.exists()
def shutdown_lavalink_server():
global shutdown
shutdown = True
global proc
if proc is not None:
log.info("Shutting down lavalink server.")
proc.terminate()
proc.wait()
proc = None
async def shutdown(self) -> None:
if self._shutdown is True or self._proc is None:
# For convenience, calling this method more than once or calling it before starting it
# does nothing.
return
log.info("Shutting down internal Lavalink server")
if self._monitor_task is not None:
self._monitor_task.cancel()
self._proc.terminate()
await self._proc.wait()
self._shutdown = True
@staticmethod
async def _download_jar() -> None:
log.info("Downloading Lavalink.jar...")
async with aiohttp.ClientSession() as session:
async with session.get(LAVALINK_DOWNLOAD_URL) as response:
if response.status == 404:
raise RuntimeError(
f"Lavalink jar version {JAR_VERSION}_{JAR_BUILD} hasn't been published"
)
fd, path = tempfile.mkstemp()
file = open(fd, "wb")
try:
chunk = await response.content.read(1024)
while chunk:
file.write(chunk)
chunk = await response.content.read(1024)
file.flush()
finally:
file.close()
pathlib.Path(path).replace(LAVALINK_JAR_FILE)
async def download_lavalink(session):
from . import LAVALINK_DOWNLOAD_URL, LAVALINK_JAR_FILE
@classmethod
async def _is_up_to_date(cls):
if cls._up_to_date is True:
# Return cached value if we've checked this before
return True
args = await cls._get_jar_args()
args.append("--version")
_proc = await asyncio.subprocess.create_subprocess_exec(
*args,
cwd=str(LAVALINK_DOWNLOAD_DIR),
stdout=asyncio.subprocess.PIPE,
stderr=asyncio.subprocess.STDOUT,
)
stdout = (await _proc.communicate())[0]
match = BUILD_LINE_RE.search(stdout)
if not match:
# Output is unexpected, suspect corrupted jarfile
return False
build = int(match["build"])
cls._up_to_date = build == JAR_BUILD
return cls._up_to_date
with LAVALINK_JAR_FILE.open(mode="wb") as f:
async with session.get(LAVALINK_DOWNLOAD_URL) as resp:
while True:
chunk = await resp.content.read(512)
if not chunk:
break
f.write(chunk)
async def maybe_download_lavalink(loop, cog):
from . import LAVALINK_DOWNLOAD_DIR, LAVALINK_JAR_FILE, BUNDLED_APP_YML_FILE, APP_YML_FILE
jar_exists = LAVALINK_JAR_FILE.exists()
current_build = redbot.VersionInfo.from_json(await cog.config.current_version())
if not jar_exists or current_build < redbot.core.version_info:
log.info("Downloading Lavalink.jar")
LAVALINK_DOWNLOAD_DIR.mkdir(parents=True, exist_ok=True)
async with ClientSession(loop=loop) as session:
await download_lavalink(session)
await cog.config.current_version.set(redbot.core.version_info.to_json())
shutil.copyfile(str(BUNDLED_APP_YML_FILE), str(APP_YML_FILE))
@classmethod
async def maybe_download_jar(cls):
if not (LAVALINK_JAR_FILE.exists() and await cls._is_up_to_date()):
await cls._download_jar()