Use rich.progress instead of tqdm (#5064)

* Use rich progress instead of tqdm

* Remove tqdm from deps
This commit is contained in:
jack1142 2021-06-03 21:37:53 +02:00 committed by GitHub
parent 0ce2634bb3
commit 8f390147c1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 49 additions and 106 deletions

View File

@ -13,7 +13,7 @@ import time
from typing import ClassVar, Final, List, Optional, Pattern, Tuple from typing import ClassVar, Final, List, Optional, Pattern, Tuple
import aiohttp import aiohttp
from tqdm import tqdm import rich.progress
from redbot.core import data_manager from redbot.core import data_manager
from redbot.core.i18n import Translator from redbot.core.i18n import Translator
@ -297,22 +297,23 @@ class ServerManager:
fd, path = tempfile.mkstemp() fd, path = tempfile.mkstemp()
file = open(fd, "wb") file = open(fd, "wb")
nbytes = 0 nbytes = 0
with tqdm( with rich.progress.Progress(
desc="Lavalink.jar", rich.progress.SpinnerColumn(),
total=response.content_length, rich.progress.TextColumn("[progress.description]{task.description}"),
file=sys.stdout, rich.progress.BarColumn(),
unit="B", rich.progress.TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
unit_scale=True, rich.progress.TimeRemainingColumn(),
miniters=1, rich.progress.TimeElapsedColumn(),
dynamic_ncols=True, ) as progress:
leave=False, progress_task_id = progress.add_task(
) as progress_bar: "[red]Downloading Lavalink.jar", total=response.content_length
)
try: try:
chunk = await response.content.read(1024) chunk = await response.content.read(1024)
while chunk: while chunk:
chunk_size = file.write(chunk) chunk_size = file.write(chunk)
nbytes += chunk_size nbytes += chunk_size
progress_bar.update(chunk_size) progress.update(progress_task_id, advance=chunk_size)
chunk = await response.content.read(1024) chunk = await response.content.read(1024)
file.flush() file.flush()
finally: finally:

View File

@ -2,7 +2,9 @@ import abc
import enum import enum
from typing import Tuple, Dict, Any, Union, List, AsyncIterator, Type from typing import Tuple, Dict, Any, Union, List, AsyncIterator, Type
from redbot.core.utils._internal_utils import async_tqdm import rich.progress
from redbot.core.utils._internal_utils import RichIndefiniteBarColumn
__all__ = ["BaseDriver", "IdentifierData", "ConfigCategory"] __all__ = ["BaseDriver", "IdentifierData", "ConfigCategory"]
@ -282,22 +284,27 @@ class BaseDriver(abc.ABC):
""" """
# Backend-agnostic method of migrating from one driver to another. # Backend-agnostic method of migrating from one driver to another.
cogs_progress_bar = async_tqdm( with rich.progress.Progress(
(tup async for tup in cls.aiter_cogs()), rich.progress.SpinnerColumn(),
desc="Migration progress", rich.progress.TextColumn("[progress.description]{task.description}"),
unit=" cogs", RichIndefiniteBarColumn(),
bar_format="{desc}: {n_fmt}{unit} [{elapsed},{rate_noinv_fmt}{postfix}]", rich.progress.TextColumn("{task.completed} cogs processed"),
leave=False, rich.progress.TimeElapsedColumn(),
dynamic_ncols=True, ) as progress:
miniters=1, cog_count = 0
) tid = progress.add_task("[yellow]Migrating", completed=cog_count, total=cog_count + 1)
async for cog_name, cog_id in cogs_progress_bar: async for cog_name, cog_id in cls.aiter_cogs():
cogs_progress_bar.set_postfix_str(f"Working on {cog_name}...") progress.console.print(f"Working on {cog_name}...")
this_driver = cls(cog_name, cog_id) this_driver = cls(cog_name, cog_id)
other_driver = new_driver_cls(cog_name, cog_id) other_driver = new_driver_cls(cog_name, cog_id)
custom_group_data = all_custom_group_data.get(cog_name, {}).get(cog_id, {}) custom_group_data = all_custom_group_data.get(cog_name, {}).get(cog_id, {})
exported_data = await this_driver.export_data(custom_group_data) exported_data = await this_driver.export_data(custom_group_data)
await other_driver.import_data(exported_data, custom_group_data) await other_driver.import_data(exported_data, custom_group_data)
cog_count += 1
progress.update(tid, completed=cog_count, total=cog_count + 1)
progress.update(tid, total=cog_count)
print() print()
@classmethod @classmethod

View File

@ -33,7 +33,8 @@ import aiohttp
import discord import discord
import pkg_resources import pkg_resources
from fuzzywuzzy import fuzz, process from fuzzywuzzy import fuzz, process
from tqdm import tqdm from rich.progress import ProgressColumn
from rich.progress_bar import ProgressBar
from redbot import VersionInfo from redbot import VersionInfo
from redbot.core import data_manager from redbot.core import data_manager
@ -55,7 +56,7 @@ __all__ = (
"expected_version", "expected_version",
"fetch_latest_red_version_info", "fetch_latest_red_version_info",
"deprecated_removed", "deprecated_removed",
"async_tqdm", "RichIndefiniteBarColumn",
) )
_T = TypeVar("_T") _T = TypeVar("_T")
@ -347,76 +348,12 @@ def deprecated_removed(
) )
class _AsyncTqdm(AsyncIterator[_T], tqdm): class RichIndefiniteBarColumn(ProgressColumn):
def __init__(self, iterable: AsyncIterable[_T], *args, **kwargs) -> None: def render(self, task):
self.async_iterator = iterable.__aiter__() return ProgressBar(
super().__init__(self.infinite_generator(), *args, **kwargs) pulse=task.completed < task.total,
self.iterator = cast(Generator[None, bool, None], iter(self)) animation_time=task.get_time(),
width=40,
@staticmethod total=task.total,
def infinite_generator() -> Generator[None, bool, None]: completed=task.completed,
while True: )
# Generator can be forced to raise StopIteration by calling `g.send(True)`
current = yield
if current:
break
async def __anext__(self) -> _T:
try:
result = await self.async_iterator.__anext__()
except StopAsyncIteration:
# If the async iterator is exhausted, force-stop the tqdm iterator
with contextlib.suppress(StopIteration):
self.iterator.send(True)
raise
else:
next(self.iterator)
return result
def __aiter__(self) -> _AsyncTqdm[_T]:
return self
def async_tqdm(
iterable: Optional[Union[Iterable, AsyncIterable]] = None,
*args,
refresh_interval: float = 0.5,
**kwargs,
) -> Union[tqdm, _AsyncTqdm]:
"""Same as `tqdm() <https://tqdm.github.io>`_, except it can be used
in ``async for`` loops, and a task can be spawned to asynchronously
refresh the progress bar every ``refresh_interval`` seconds.
This should only be used for ``async for`` loops, or ``for`` loops
which ``await`` something slow between iterations.
Parameters
----------
iterable: Optional[Union[Iterable, AsyncIterable]]
The iterable to pass to ``tqdm()``. If this is an async
iterable, this function will return a wrapper
*args
Other positional arguments to ``tqdm()``.
refresh_interval : float
The sleep interval between the progress bar being refreshed, in
seconds. Defaults to 0.5. Set to 0 to disable the auto-
refresher.
**kwargs
Keyword arguments to ``tqdm()``.
"""
if isinstance(iterable, AsyncIterable):
progress_bar = _AsyncTqdm(iterable, *args, **kwargs)
else:
progress_bar = tqdm(iterable, *args, **kwargs)
if refresh_interval:
# The background task that refreshes the progress bar
async def _progress_bar_refresher() -> None:
while not progress_bar.disable:
await asyncio.sleep(refresh_interval)
progress_bar.refresh()
asyncio.create_task(_progress_bar_refresher())
return progress_bar

View File

@ -65,7 +65,6 @@ install_requires =
rich==9.9.0 rich==9.9.0
schema==0.7.4 schema==0.7.4
six==1.15.0 six==1.15.0
tqdm==4.56.2
typing-extensions==3.7.4.3 typing-extensions==3.7.4.3
uvloop==0.15.0; sys_platform != "win32" and platform_python_implementation == "CPython" uvloop==0.15.0; sys_platform != "win32" and platform_python_implementation == "CPython"
yarl==1.6.3 yarl==1.6.3

View File

@ -24,7 +24,6 @@ install_requires =
Red-Lavalink Red-Lavalink
rich rich
schema schema
tqdm
uvloop; sys_platform != "win32" and platform_python_implementation == "CPython" uvloop; sys_platform != "win32" and platform_python_implementation == "CPython"
PyNaCl PyNaCl