Use our own redbot.core.VersionInfo over distutils.StrictVersion (#2188)

* Implements our required subset of PEP 440 in redbot.core.VersionInfo
* Added unit tests for version string parsing and comparisons

Signed-off-by: Toby Harradine <tobyharradine@gmail.com>
This commit is contained in:
Toby Harradine 2018-10-06 19:11:05 +10:00 committed by GitHub
parent de4b42a11e
commit 91029b73e5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 187 additions and 49 deletions

View File

@ -34,7 +34,7 @@ async def download_lavalink(session):
async def maybe_download_lavalink(loop, cog): async def maybe_download_lavalink(loop, cog):
jar_exists = LAVALINK_JAR_FILE.exists() jar_exists = LAVALINK_JAR_FILE.exists()
current_build = redbot.core.VersionInfo(*await cog.config.current_build()) current_build = redbot.core.VersionInfo.from_json(await cog.config.current_build())
if not jar_exists or current_build < redbot.core.version_info: if not jar_exists or current_build < redbot.core.version_info:
log.info("Downloading Lavalink.jar") log.info("Downloading Lavalink.jar")

View File

@ -1,40 +1,152 @@
import re as _re
from math import inf as _inf
from typing import (
Any as _Any,
ClassVar as _ClassVar,
Dict as _Dict,
List as _List,
Optional as _Optional,
Pattern as _Pattern,
Tuple as _Tuple,
Union as _Union,
)
from .config import Config from .config import Config
__all__ = ["Config", "__version__"] __all__ = ["Config", "__version__", "version_info", "VersionInfo"]
class VersionInfo: class VersionInfo:
def __init__(self, major, minor, micro, releaselevel, serial): ALPHA = "alpha"
self._levels = ["alpha", "beta", "release candidate", "final"] BETA = "beta"
self.major = major RELEASE_CANDIDATE = "release candidate"
self.minor = minor FINAL = "final"
self.micro = micro
if releaselevel not in self._levels: _VERSION_STR_PATTERN: _ClassVar[_Pattern[str]] = _re.compile(
raise TypeError("'releaselevel' must be one of: {}".format(", ".join(self._levels))) r"^"
r"(?P<major>0|[1-9]\d*)\.(?P<minor>0|[1-9]\d*)\.(?P<micro>0|[1-9]\d*)"
self.releaselevel = releaselevel r"(?:(?P<releaselevel>a|b|rc)(?P<serial>0|[1-9]\d*))?"
self.serial = serial r"(?:\.post(?P<post_release>0|[1-9]\d*))?"
r"(?:\.dev(?P<dev_release>0|[1-9]\d*))?"
def __lt__(self, other): r"$",
my_index = self._levels.index(self.releaselevel) flags=_re.IGNORECASE,
other_index = self._levels.index(other.releaselevel)
return (self.major, self.minor, self.micro, my_index, self.serial) < (
other.major,
other.minor,
other.micro,
other_index,
other.serial,
) )
_RELEASE_LEVELS: _ClassVar[_List[str]] = [ALPHA, BETA, RELEASE_CANDIDATE, FINAL]
_SHORT_RELEASE_LEVELS: _ClassVar[_Dict[str, str]] = {
"a": ALPHA,
"b": BETA,
"rc": RELEASE_CANDIDATE,
}
def __repr__(self): def __init__(
return "VersionInfo(major={}, minor={}, micro={}, releaselevel={}, serial={})".format( self,
self.major, self.minor, self.micro, self.releaselevel, self.serial major: int,
minor: int,
micro: int,
releaselevel: str,
serial: _Optional[int] = None,
post_release: _Optional[int] = None,
dev_release: _Optional[int] = None,
) -> None:
self.major: int = major
self.minor: int = minor
self.micro: int = micro
if releaselevel not in self._RELEASE_LEVELS:
raise TypeError(f"'releaselevel' must be one of: {', '.join(self._RELEASE_LEVELS)}")
self.releaselevel: str = releaselevel
self.serial: _Optional[int] = serial
self.post_release: _Optional[int] = post_release
self.dev_release: _Optional[int] = dev_release
@classmethod
def from_str(cls, version_str: str) -> "VersionInfo":
"""Parse a string into a VersionInfo object.
Raises
------
ValueError
If the version info string is invalid.
"""
match = cls._VERSION_STR_PATTERN.match(version_str)
if not match:
raise ValueError(f"Invalid version string: {version_str}")
kwargs: _Dict[str, _Union[str, int]] = {}
for key in ("major", "minor", "micro"):
kwargs[key] = int(match[key])
releaselevel = match["releaselevel"]
if releaselevel is not None:
kwargs["releaselevel"] = cls._SHORT_RELEASE_LEVELS[releaselevel]
else:
kwargs["releaselevel"] = cls.FINAL
for key in ("serial", "post_release", "dev_release"):
if match[key] is not None:
kwargs[key] = int(match[key])
return cls(**kwargs)
@classmethod
def from_json(
cls, data: _Union[_Dict[str, _Union[int, str]], _List[_Union[int, str]]]
) -> "VersionInfo":
if isinstance(data, _List):
# For old versions, data was stored as a list:
# [MAJOR, MINOR, MICRO, RELEASELEVEL, SERIAL]
return cls(*data)
else:
return cls(**data)
def to_json(self) -> _Dict[str, _Union[int, str]]:
return {
"major": self.major,
"minor": self.minor,
"micro": self.micro,
"releaselevel": self.releaselevel,
"serial": self.serial,
"post_release": self.post_release,
"dev_release": self.dev_release,
}
def __lt__(self, other: _Any) -> bool:
if not isinstance(other, VersionInfo):
return NotImplemented
tups: _List[_Tuple[int, int, int, int, int, int, int]] = []
for obj in (self, other):
tups.append(
(
obj.major,
obj.minor,
obj.micro,
obj._RELEASE_LEVELS.index(obj.releaselevel),
obj.serial if obj.serial is not None else _inf,
obj.post_release if obj.post_release is not None else -_inf,
obj.dev_release if obj.dev_release is not None else _inf,
) )
)
return tups[0] < tups[1]
def to_json(self): def __str__(self) -> str:
return [self.major, self.minor, self.micro, self.releaselevel, self.serial] ret = f"{self.major}.{self.minor}.{self.micro}"
if self.releaselevel != self.FINAL:
short = next(
k for k, v in self._SHORT_RELEASE_LEVELS.items() if v == self.releaselevel
)
ret += f"{short}{self.serial}"
if self.post_release is not None:
ret += f".post{self.post_release}"
if self.dev_release is not None:
ret += f".dev{self.dev_release}"
return ret
def __repr__(self) -> str:
return (
"VersionInfo(major={major}, minor={minor}, micro={micro}, "
"releaselevel={releaselevel}, serial={serial}, post={post_release}, "
"dev={dev_release})".format(**self.to_json())
)
__version__ = "3.0.0rc1" __version__ = "3.0.0rc1"
version_info = VersionInfo(3, 0, 0, "release candidate", 1) version_info = VersionInfo.from_str(__version__)

View File

@ -13,17 +13,20 @@ from collections import namedtuple
from pathlib import Path from pathlib import Path
from random import SystemRandom from random import SystemRandom
from string import ascii_letters, digits from string import ascii_letters, digits
from distutils.version import StrictVersion
from typing import TYPE_CHECKING, Union from typing import TYPE_CHECKING, Union
import aiohttp import aiohttp
import discord import discord
import pkg_resources import pkg_resources
from redbot.core import __version__ from redbot.core import (
from redbot.core import checks __version__,
from redbot.core import i18n version_info as red_version_info,
from redbot.core import commands VersionInfo,
checks,
commands,
i18n,
)
from .utils.predicates import MessagePredicate from .utils.predicates import MessagePredicate
from .utils.chat_formatting import pagify, box, inline from .utils.chat_formatting import pagify, box, inline
@ -274,7 +277,7 @@ class Core(commands.Cog, CoreLogic):
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
async with session.get("{}/json".format(red_pypi)) as r: async with session.get("{}/json".format(red_pypi)) as r:
data = await r.json() data = await r.json()
outdated = StrictVersion(data["info"]["version"]) > StrictVersion(__version__) outdated = VersionInfo.from_str(data["info"]["version"]) > red_version_info
about = ( about = (
"This is an instance of [Red, an open source Discord bot]({}) " "This is an instance of [Red, an open source Discord bot]({}) "
"created by [Twentysix]({}) and [improved by many]({}).\n\n" "created by [Twentysix]({}) and [improved by many]({}).\n\n"

View File

@ -1,10 +1,10 @@
import contextlib
import sys import sys
import codecs import codecs
import datetime import datetime
import logging import logging
import traceback import traceback
from datetime import timedelta from datetime import timedelta
from distutils.version import StrictVersion
from typing import List from typing import List
import aiohttp import aiohttp
@ -13,7 +13,7 @@ import pkg_resources
from colorama import Fore, Style, init from colorama import Fore, Style, init
from pkg_resources import DistributionNotFound from pkg_resources import DistributionNotFound
from . import __version__, commands from . import __version__ as red_version, version_info as red_version_info, VersionInfo, commands
from .data_manager import storage_type from .data_manager import storage_type
from .utils.chat_formatting import inline, bordered, humanize_list from .utils.chat_formatting import inline, bordered, humanize_list
from .utils import fuzzy_command_search, format_fuzzy_results from .utils import fuzzy_command_search, format_fuzzy_results
@ -105,7 +105,6 @@ def init_events(bot, cli_flags):
prefixes = cli_flags.prefix or (await bot.db.prefix()) prefixes = cli_flags.prefix or (await bot.db.prefix())
lang = await bot.db.locale() lang = await bot.db.locale()
red_version = __version__
red_pkg = pkg_resources.get_distribution("Red-DiscordBot") red_pkg = pkg_resources.get_distribution("Red-DiscordBot")
dpy_version = discord.__version__ dpy_version = discord.__version__
@ -125,24 +124,22 @@ def init_events(bot, cli_flags):
INFO.append("{} cogs with {} commands".format(len(bot.cogs), len(bot.commands))) INFO.append("{} cogs with {} commands".format(len(bot.cogs), len(bot.commands)))
try: with contextlib.suppress(aiohttp.ClientError, discord.HTTPException):
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
async with session.get("https://pypi.python.org/pypi/red-discordbot/json") as r: async with session.get("https://pypi.python.org/pypi/red-discordbot/json") as r:
data = await r.json() data = await r.json()
if StrictVersion(data["info"]["version"]) > StrictVersion(red_version): if VersionInfo.from_str(data["info"]["version"]) > red_version_info:
INFO.append( INFO.append(
"Outdated version! {} is available " "Outdated version! {} is available "
"but you're using {}".format(data["info"]["version"], red_version) "but you're using {}".format(data["info"]["version"], red_version)
) )
owner = discord.utils.get(bot.get_all_members(), id=bot.owner_id) owner = await bot.get_user_info(bot.owner_id)
await owner.send( await owner.send(
"Your Red instance is out of date! {} is the current " "Your Red instance is out of date! {} is the current "
"version, however you are using {}!".format( "version, however you are using {}!".format(
data["info"]["version"], red_version data["info"]["version"], red_version
) )
) )
except:
pass
INFO2 = [] INFO2 = []
sentry = await bot.db.enable_sentry() sentry = await bot.db.enable_sentry()

View File

@ -8,18 +8,14 @@ import asyncio
import aiohttp import aiohttp
import pkg_resources import pkg_resources
from pathlib import Path
from distutils.version import StrictVersion
from redbot.setup import ( from redbot.setup import (
basic_setup, basic_setup,
load_existing_config, load_existing_config,
remove_instance, remove_instance,
remove_instance_interaction, remove_instance_interaction,
create_backup, create_backup,
save_config,
) )
from redbot.core import __version__ from redbot.core import __version__, version_info as red_version_info, VersionInfo
from redbot.core.utils import safe_delete
from redbot.core.cli import confirm from redbot.core.cli import confirm
if sys.platform == "linux": if sys.platform == "linux":
@ -390,7 +386,7 @@ async def is_outdated():
async with session.get("{}/json".format(red_pypi)) as r: async with session.get("{}/json".format(red_pypi)) as r:
data = await r.json() data = await r.json()
new_version = data["info"]["version"] new_version = data["info"]["version"]
return StrictVersion(new_version) > StrictVersion(__version__), new_version return VersionInfo.from_str(new_version) > red_version_info, new_version
def main_menu(): def main_menu():

View File

@ -1,6 +1,36 @@
from redbot import core from redbot import core
from redbot.core import VersionInfo
def test_version_working(): def test_version_working():
assert hasattr(core, "__version__") assert hasattr(core, "__version__")
assert core.__version__[0] == "3" assert core.__version__[0] == "3"
# When adding more of these, ensure they are added in ascending order of precedence
version_tests = (
"3.0.0a32.post10.dev12",
"3.0.0rc1.dev1",
"3.0.0rc1",
"3.0.0",
"3.0.1",
"3.0.1.post1.dev1",
"3.0.1.post1",
"2018.10.6b21",
)
def test_version_info_str_parsing():
for version_str in version_tests:
assert version_str == str(VersionInfo.from_str(version_str))
def test_version_info_lt():
for next_idx, cur in enumerate(version_tests[:-1], start=1):
cur_test = VersionInfo.from_str(cur)
next_test = VersionInfo.from_str(version_tests[next_idx])
assert cur_test < next_test
def test_version_info_gt():
assert VersionInfo.from_str(version_tests[1]) > VersionInfo.from_str(version_tests[0])