Fix managed LL subprocess's stdout overflowing and deadlocking (#5903)

Signed-off-by: Draper <27962761+Drapersniper@users.noreply.github.com>
Co-authored-by: Jakub Kuczys <me@jacken.men>
This commit is contained in:
Draper 2022-12-27 19:33:50 +00:00 committed by GitHub
parent 43ab6e2ef5
commit 1ab303bce7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 16 additions and 24 deletions

View File

@ -51,14 +51,6 @@ class NoProcessFound(ManagedLavalinkNodeException):
"""Exception thrown when the managed node process is not found"""
class IncorrectProcessFound(ManagedLavalinkNodeException):
"""Exception thrown when the managed node process is incorrect"""
class TooManyProcessFound(ManagedLavalinkNodeException):
"""Exception thrown when zombie processes are suspected"""
class LavalinkDownloadFailed(ManagedLavalinkNodeException, RuntimeError):
"""Downloading the Lavalink jar failed.

View File

@ -13,7 +13,6 @@ from typing import ClassVar, Final, List, Optional, Pattern, Tuple, Union, TYPE_
import aiohttp
import lavalink
import psutil
import rich.progress
import yaml
from discord.backoff import ExponentialBackoff
@ -32,8 +31,6 @@ from .errors import (
UnexpectedJavaResponseException,
EarlyExitException,
ManagedLavalinkNodeException,
TooManyProcessFound,
IncorrectProcessFound,
NoProcessFound,
NodeUnhealthy,
)
@ -261,12 +258,12 @@ class ServerManager:
self.ready: asyncio.Event = asyncio.Event()
self._config = config
self._proc: Optional[asyncio.subprocess.Process] = None # pylint:disable=no-member
self._node_pid: Optional[int] = None
self._shutdown: bool = False
self.start_monitor_task = None
self.timeout = timeout
self.cog = cog
self._args = []
self._pipe_task = None
@property
def path(self) -> Optional[str]:
@ -292,6 +289,11 @@ class ServerManager:
def build_time(self) -> Optional[str]:
return self._buildtime
async def _pipe_output(self):
with contextlib.suppress(asyncio.CancelledError):
async for __ in self._proc.stdout:
pass
async def _start(self, java_path: str) -> None:
arch_name = platform.machine()
self._java_exc = java_path
@ -333,8 +335,7 @@ class ServerManager:
stderr=asyncio.subprocess.STDOUT,
)
)
self._node_pid = self._proc.pid
log.info("Managed Lavalink node started. PID: %s", self._node_pid)
log.info("Managed Lavalink node started. PID: %s", self._proc.pid)
try:
await asyncio.wait_for(self._wait_for_launcher(), timeout=self.timeout)
except asyncio.TimeoutError:
@ -450,6 +451,7 @@ class ServerManager:
if b"Lavalink is ready to accept connections." in line:
self.ready.set()
log.info("Managed Lavalink node is ready to receive requests.")
self._pipe_task = asyncio.create_task(self._pipe_output())
break
if _FAILED_TO_START.search(line):
raise ManagedLavalinkStartFailure(
@ -469,19 +471,17 @@ class ServerManager:
async def _partial_shutdown(self) -> None:
self.ready.clear()
# In certain situations to await self._proc.wait() is invalid so waiting on it waits forever.
if self._shutdown is True:
# For convenience, calling this method more than once or calling it before starting it
# does nothing.
return
if self._node_pid:
with contextlib.suppress(psutil.Error):
p = psutil.Process(self._node_pid)
p.terminate()
p.kill()
if self._pipe_task:
self._pipe_task.cancel()
if self._proc is not None:
self._proc.terminate()
await self._proc.wait()
self._proc = None
self._shutdown = True
self._node_pid = None
async def _download_jar(self) -> None:
log.info("Downloading Lavalink.jar...")
@ -595,12 +595,12 @@ class ServerManager:
while True:
try:
self._shutdown = False
if self._node_pid is None or not psutil.pid_exists(self._node_pid):
if self._proc is None or self._proc.returncode is not None:
self.ready.clear()
await self._start(java_path=java_path)
while True:
await self.wait_until_ready(timeout=self.timeout)
if not psutil.pid_exists(self._node_pid):
if self._proc.returncode is not None:
raise NoProcessFound
try:
node = lavalink.get_all_nodes()[0]
@ -628,7 +628,7 @@ class ServerManager:
except Exception as exc:
log.debug(exc, exc_info=exc)
raise NodeUnhealthy(str(exc))
except (TooManyProcessFound, IncorrectProcessFound, NoProcessFound):
except NoProcessFound:
await self._partial_shutdown()
except asyncio.TimeoutError:
delay = backoff.delay()