From 1ab303bce782ef9a47d175c83bc22d085aeea1b0 Mon Sep 17 00:00:00 2001 From: Draper <27962761+Drapersniper@users.noreply.github.com> Date: Tue, 27 Dec 2022 19:33:50 +0000 Subject: [PATCH] Fix managed LL subprocess's stdout overflowing and deadlocking (#5903) Signed-off-by: Draper <27962761+Drapersniper@users.noreply.github.com> Co-authored-by: Jakub Kuczys --- redbot/cogs/audio/errors.py | 8 -------- redbot/cogs/audio/manager.py | 32 ++++++++++++++++---------------- 2 files changed, 16 insertions(+), 24 deletions(-) diff --git a/redbot/cogs/audio/errors.py b/redbot/cogs/audio/errors.py index 05d3d49c9..9bddf87aa 100644 --- a/redbot/cogs/audio/errors.py +++ b/redbot/cogs/audio/errors.py @@ -51,14 +51,6 @@ class NoProcessFound(ManagedLavalinkNodeException): """Exception thrown when the managed node process is not found""" -class IncorrectProcessFound(ManagedLavalinkNodeException): - """Exception thrown when the managed node process is incorrect""" - - -class TooManyProcessFound(ManagedLavalinkNodeException): - """Exception thrown when zombie processes are suspected""" - - class LavalinkDownloadFailed(ManagedLavalinkNodeException, RuntimeError): """Downloading the Lavalink jar failed. diff --git a/redbot/cogs/audio/manager.py b/redbot/cogs/audio/manager.py index b0c396872..921d0c9cc 100644 --- a/redbot/cogs/audio/manager.py +++ b/redbot/cogs/audio/manager.py @@ -13,7 +13,6 @@ from typing import ClassVar, Final, List, Optional, Pattern, Tuple, Union, TYPE_ import aiohttp import lavalink -import psutil import rich.progress import yaml from discord.backoff import ExponentialBackoff @@ -32,8 +31,6 @@ from .errors import ( UnexpectedJavaResponseException, EarlyExitException, ManagedLavalinkNodeException, - TooManyProcessFound, - IncorrectProcessFound, NoProcessFound, NodeUnhealthy, ) @@ -261,12 +258,12 @@ class ServerManager: self.ready: asyncio.Event = asyncio.Event() self._config = config self._proc: Optional[asyncio.subprocess.Process] = None # pylint:disable=no-member - self._node_pid: Optional[int] = None self._shutdown: bool = False self.start_monitor_task = None self.timeout = timeout self.cog = cog self._args = [] + self._pipe_task = None @property def path(self) -> Optional[str]: @@ -292,6 +289,11 @@ class ServerManager: def build_time(self) -> Optional[str]: return self._buildtime + async def _pipe_output(self): + with contextlib.suppress(asyncio.CancelledError): + async for __ in self._proc.stdout: + pass + async def _start(self, java_path: str) -> None: arch_name = platform.machine() self._java_exc = java_path @@ -333,8 +335,7 @@ class ServerManager: stderr=asyncio.subprocess.STDOUT, ) ) - self._node_pid = self._proc.pid - log.info("Managed Lavalink node started. PID: %s", self._node_pid) + log.info("Managed Lavalink node started. PID: %s", self._proc.pid) try: await asyncio.wait_for(self._wait_for_launcher(), timeout=self.timeout) except asyncio.TimeoutError: @@ -450,6 +451,7 @@ class ServerManager: if b"Lavalink is ready to accept connections." in line: self.ready.set() log.info("Managed Lavalink node is ready to receive requests.") + self._pipe_task = asyncio.create_task(self._pipe_output()) break if _FAILED_TO_START.search(line): raise ManagedLavalinkStartFailure( @@ -469,19 +471,17 @@ class ServerManager: async def _partial_shutdown(self) -> None: self.ready.clear() - # In certain situations to await self._proc.wait() is invalid so waiting on it waits forever. if self._shutdown is True: # For convenience, calling this method more than once or calling it before starting it # does nothing. return - if self._node_pid: - with contextlib.suppress(psutil.Error): - p = psutil.Process(self._node_pid) - p.terminate() - p.kill() + if self._pipe_task: + self._pipe_task.cancel() + if self._proc is not None: + self._proc.terminate() + await self._proc.wait() self._proc = None self._shutdown = True - self._node_pid = None async def _download_jar(self) -> None: log.info("Downloading Lavalink.jar...") @@ -595,12 +595,12 @@ class ServerManager: while True: try: self._shutdown = False - if self._node_pid is None or not psutil.pid_exists(self._node_pid): + if self._proc is None or self._proc.returncode is not None: self.ready.clear() await self._start(java_path=java_path) while True: await self.wait_until_ready(timeout=self.timeout) - if not psutil.pid_exists(self._node_pid): + if self._proc.returncode is not None: raise NoProcessFound try: node = lavalink.get_all_nodes()[0] @@ -628,7 +628,7 @@ class ServerManager: except Exception as exc: log.debug(exc, exc_info=exc) raise NodeUnhealthy(str(exc)) - except (TooManyProcessFound, IncorrectProcessFound, NoProcessFound): + except NoProcessFound: await self._partial_shutdown() except asyncio.TimeoutError: delay = backoff.delay()