From 49bf103891e5ccf043f566add095480ca1f28197 Mon Sep 17 00:00:00 2001 From: Jakub Kuczys Date: Wed, 21 Jun 2023 15:52:00 +0200 Subject: [PATCH] Update the Lavalink version parsing and add tests for it (#6093) --- .github/labeler.yml | 2 + redbot/cogs/audio/manager.py | 88 ++++++++++++++++++++++---------- tests/cogs/audio/__init__.py | 0 tests/cogs/audio/test_manager.py | 76 +++++++++++++++++++++++++++ 4 files changed, 139 insertions(+), 27 deletions(-) create mode 100644 tests/cogs/audio/__init__.py create mode 100644 tests/cogs/audio/test_manager.py diff --git a/.github/labeler.yml b/.github/labeler.yml index 849b7db01..3eff68540 100644 --- a/.github/labeler.yml +++ b/.github/labeler.yml @@ -24,6 +24,8 @@ - "!redbot/cogs/audio/**/locales/*" # Docs - docs/cog_guides/audio.rst + # Tests + - tests/cogs/audio/**/* "Category: Cogs - Bank": [] # historical label for a removed cog "Category: Cogs - Cleanup": # Source diff --git a/redbot/cogs/audio/manager.py b/redbot/cogs/audio/manager.py index 00a17a7e2..3658de2b3 100644 --- a/redbot/cogs/audio/manager.py +++ b/redbot/cogs/audio/manager.py @@ -10,6 +10,7 @@ import shlex import shutil import tempfile from typing import ClassVar, Final, List, Optional, Pattern, Tuple, Union, TYPE_CHECKING +from typing_extensions import Self import aiohttp import lavalink @@ -47,9 +48,6 @@ if TYPE_CHECKING: _ = Translator("Audio", pathlib.Path(__file__)) log = getLogger("red.Audio.manager") -LAVALINK_DOWNLOAD_DIR: Final[pathlib.Path] = data_manager.cog_data_path(raw_name="Audio") -LAVALINK_JAR_FILE: Final[pathlib.Path] = LAVALINK_DOWNLOAD_DIR / "Lavalink.jar" -LAVALINK_APP_YML: Final[pathlib.Path] = LAVALINK_DOWNLOAD_DIR / "application.yml" _FAILED_TO_START: Final[Pattern] = re.compile(rb"Web server failed to start\. (.*)") @@ -109,6 +107,9 @@ LAVALINK_VERSION_LINE_PRE35: Final[Pattern] = re.compile( rb"^Version:\s+(?P\S+)$", re.MULTILINE | re.VERBOSE ) # used for LL 3.5-rc4 and newer +# This regex is limited to the realistic usage in the LL version number, +# not everything that could be a part of it according to the spec. +# We can easily release an update to this regex in the future if it ever becomes necessary. LAVALINK_VERSION_LINE: Final[Pattern] = re.compile( rb""" ^ @@ -117,9 +118,11 @@ LAVALINK_VERSION_LINE: Final[Pattern] = re.compile( (?P0|[1-9]\d*)\.(?P0|[1-9]\d*) # Before LL 3.6, when patch version == 0, it was stripped from the version string (?:\.(?P0|[1-9]\d*))? - (?:-rc(?P0|[1-9]\d*))? - # only used by our downstream Lavalink if we need to make a release before upstream - (?:_red(?P[1-9]\d*))? + # Before LL 3.6, the dot in rc.N was optional + (?:-rc\.?(?P0|[1-9]\d*))? + # additional build metadata, can be used by our downstream Lavalink + # if we need to alter an upstream release + (?:\+red\.(?P[1-9]\d*))? ) $ """, @@ -135,6 +138,19 @@ class LavalinkOldVersion: def __str__(self) -> None: return f"{self.raw_version}_{self.build_number}" + @classmethod + def from_version_output(cls, output: bytes) -> Self: + build_match = LAVALINK_BUILD_LINE.search(output) + if build_match is None: + raise ValueError("Could not find Build line in the given `--version` output.") + version_match = LAVALINK_VERSION_LINE_PRE35.search(output) + if version_match is None: + raise ValueError("Could not find Version line in the given `--version` output.") + return cls( + raw_version=version_match["version"].decode(), + build_number=int(build_match["build"]), + ) + def __eq__(self, other: object) -> bool: if isinstance(other, LavalinkOldVersion): return self.build_number == other.build_number @@ -195,6 +211,19 @@ class LavalinkVersion: version += f"_red{self.red}" return version + @classmethod + def from_version_output(cls, output: bytes) -> Self: + match = LAVALINK_VERSION_LINE.search(output) + if match is None: + raise ValueError("Could not find Version line in the given `--version` output.") + return LavalinkVersion( + major=int(match["major"]), + minor=int(match["minor"]), + patch=int(match["patch"] or 0), + rc=int(match["rc"]) if match["rc"] is not None else None, + red=int(match["red"] or 0), + ) + def _get_comparison_tuple(self) -> Tuple[int, int, int, bool, int, int]: return self.major, self.minor, self.patch, self.rc is None, self.rc or 0, self.red @@ -265,6 +294,18 @@ class ServerManager: self._args = [] self._pipe_task = None + @property + def lavalink_download_dir(self) -> pathlib.Path: + return data_manager.cog_data_path(raw_name="Audio") + + @property + def lavalink_jar_file(self) -> pathlib.Path: + return self.lavalink_download_dir / "Lavalink.jar" + + @property + def lavalink_app_yml(self) -> pathlib.Path: + return self.lavalink_download_dir / "application.yml" + @property def path(self) -> Optional[str]: return self._java_exc @@ -330,7 +371,7 @@ class ServerManager: self._proc = ( await asyncio.subprocess.create_subprocess_exec( # pylint:disable=no-member *args, - cwd=str(LAVALINK_DOWNLOAD_DIR), + cwd=str(self.lavalink_download_dir), stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.STDOUT, ) @@ -351,7 +392,7 @@ class ServerManager: async def process_settings(self): data = change_dict_naming_convention(await self._config.yaml.all()) - with open(LAVALINK_APP_YML, "w") as f: + with open(self.lavalink_app_yml, "w") as f: yaml.safe_dump(data, f) async def _get_jar_args(self) -> Tuple[List[str], Optional[str]]: @@ -392,7 +433,7 @@ class ServerManager: "please fix this by setting the correct value with '[p]llset heapsize'.", ) - command_args.extend(["-jar", str(LAVALINK_JAR_FILE)]) + command_args.extend(["-jar", str(self.lavalink_jar_file)]) self._args = command_args return command_args, invalid @@ -522,7 +563,7 @@ class ServerManager: finally: file.close() - shutil.move(path, str(LAVALINK_JAR_FILE), copy_function=shutil.copyfile) + shutil.move(path, str(self.lavalink_jar_file), copy_function=shutil.copyfile) log.info("Successfully downloaded Lavalink.jar (%s bytes written)", format(nbytes, ",")) await self._is_up_to_date() @@ -535,7 +576,7 @@ class ServerManager: args.append("--version") _proc = await asyncio.subprocess.create_subprocess_exec( # pylint:disable=no-member *args, - cwd=str(LAVALINK_DOWNLOAD_DIR), + cwd=str(self.lavalink_download_dir), stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.STDOUT, ) @@ -554,24 +595,17 @@ class ServerManager: return False if (build := LAVALINK_BUILD_LINE.search(stdout)) is not None: - if (version := LAVALINK_VERSION_LINE_PRE35.search(stdout)) is None: + try: + self._lavalink_version = LavalinkOldVersion.from_version_output(stdout) + except ValueError: # Output is unexpected, suspect corrupted jarfile return False - self._lavalink_version = LavalinkOldVersion( - raw_version=version["version"].decode(), - build_number=int(build["build"]), - ) - elif (version := LAVALINK_VERSION_LINE.search(stdout)) is not None: - self._lavalink_version = LavalinkVersion( - major=int(version["major"]), - minor=int(version["minor"]), - patch=int(version["patch"] or 0), - rc=int(version["rc"]) if version["rc"] is not None else None, - red=int(version["red"] or 0), - ) else: - # Output is unexpected, suspect corrupted jarfile - return False + try: + self._lavalink_version = LavalinkVersion.from_version_output(stdout) + except ValueError: + # Output is unexpected, suspect corrupted jarfile + return False date = buildtime["build_time"].decode() date = date.replace(".", "/") self._lavalink_branch = branch["branch"].decode() @@ -582,7 +616,7 @@ class ServerManager: return self._up_to_date async def maybe_download_jar(self): - if not (LAVALINK_JAR_FILE.exists() and await self._is_up_to_date()): + if not (self.lavalink_jar_file.exists() and await self._is_up_to_date()): await self._download_jar() async def wait_until_ready(self, timeout: Optional[float] = None): diff --git a/tests/cogs/audio/__init__.py b/tests/cogs/audio/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/cogs/audio/test_manager.py b/tests/cogs/audio/test_manager.py new file mode 100644 index 000000000..9afb86e52 --- /dev/null +++ b/tests/cogs/audio/test_manager.py @@ -0,0 +1,76 @@ +import itertools + +import pytest + +from redbot.cogs.audio.manager import LavalinkOldVersion, LavalinkVersion + + +ORDERED_VERSIONS = [ + LavalinkOldVersion("3.3.2.3", build_number=1239), + LavalinkOldVersion("3.4.0", build_number=1275), + LavalinkOldVersion("3.4.0", build_number=1350), + # LavalinkVersion is always newer than LavalinkOldVersion + LavalinkVersion(3, 3), + LavalinkVersion(3, 4), + LavalinkVersion(3, 5, rc=1), + LavalinkVersion(3, 5, rc=2), + LavalinkVersion(3, 5, rc=3), + # version with `+red.N` build number is newer than an equivalent upstream version + LavalinkVersion(3, 5, rc=3, red=1), + LavalinkVersion(3, 5, rc=3, red=2), + # all RC versions (including ones with `+red.N`) are older than a stable version + LavalinkVersion(3, 5), + # version with `+red.N` build number is newer than an equivalent upstream version + LavalinkVersion(3, 5, red=1), + LavalinkVersion(3, 5, red=2), + # but newer version number without `+red.N` is still newer + LavalinkVersion(3, 5, 1), +] + + +@pytest.mark.parametrize( + "raw_version,raw_build_number,expected", + ( + # 3-segment version number + ("3.4.0", "1350", LavalinkOldVersion("3.4.0", build_number=1350)), + # 4-segment version number + ("3.3.2.3", "1239", LavalinkOldVersion("3.3.2.3", build_number=1239)), + # 3-segment version number with 3-digit build number + ("3.3.1", "987", LavalinkOldVersion("3.3.1", build_number=987)), + ), +) +def test_old_ll_version_parsing( + raw_version: str, raw_build_number: str, expected: LavalinkOldVersion +) -> None: + line = b"Version: %b\nBuild: %b" % (raw_version.encode(), raw_build_number.encode()) + assert LavalinkOldVersion.from_version_output(line) + + +@pytest.mark.parametrize( + "raw_version,expected", + ( + # older version format that allowed stripped `.0` and no dot in `rc.4`, used until LL 3.6 + ("3.5-rc4", LavalinkVersion(3, 5, rc=4)), + ("3.5", LavalinkVersion(3, 5)), + # newer version format + ("3.6.0-rc.1", LavalinkVersion(3, 6, 0, rc=1)), + # downstream RC version with `+red.N` suffix + ("3.7.5-rc.1+red.1", LavalinkVersion(3, 7, 5, rc=1, red=1)), + ("3.7.5-rc.1+red.123", LavalinkVersion(3, 7, 5, rc=1, red=123)), + # upstream stable version + ("3.7.5", LavalinkVersion(3, 7, 5)), + # downstream stable version with `+red.N` suffix + ("3.7.5+red.1", LavalinkVersion(3, 7, 5, red=1)), + ("3.7.5+red.123", LavalinkVersion(3, 7, 5, red=123)), + ), +) +def test_ll_version_parsing(raw_version: str, expected: LavalinkVersion) -> None: + line = b"Version: " + raw_version.encode() + assert LavalinkVersion.from_version_output(line) + + +def test_ll_version_comparison() -> None: + it1, it2 = itertools.tee(ORDERED_VERSIONS) + next(it2, None) + for a, b in zip(it1, it2): + assert a < b