Use rich.progress instead of tqdm (#5064)

* Use rich progress instead of tqdm

* Remove tqdm from deps
This commit is contained in:
jack1142 2021-06-03 21:37:53 +02:00 committed by GitHub
parent 0ce2634bb3
commit 8f390147c1
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 49 additions and 106 deletions

View File

@ -13,7 +13,7 @@ import time
from typing import ClassVar, Final, List, Optional, Pattern, Tuple
import aiohttp
from tqdm import tqdm
import rich.progress
from redbot.core import data_manager
from redbot.core.i18n import Translator
@ -297,22 +297,23 @@ class ServerManager:
fd, path = tempfile.mkstemp()
file = open(fd, "wb")
nbytes = 0
with tqdm(
desc="Lavalink.jar",
total=response.content_length,
file=sys.stdout,
unit="B",
unit_scale=True,
miniters=1,
dynamic_ncols=True,
leave=False,
) as progress_bar:
with rich.progress.Progress(
rich.progress.SpinnerColumn(),
rich.progress.TextColumn("[progress.description]{task.description}"),
rich.progress.BarColumn(),
rich.progress.TextColumn("[progress.percentage]{task.percentage:>3.0f}%"),
rich.progress.TimeRemainingColumn(),
rich.progress.TimeElapsedColumn(),
) as progress:
progress_task_id = progress.add_task(
"[red]Downloading Lavalink.jar", total=response.content_length
)
try:
chunk = await response.content.read(1024)
while chunk:
chunk_size = file.write(chunk)
nbytes += chunk_size
progress_bar.update(chunk_size)
progress.update(progress_task_id, advance=chunk_size)
chunk = await response.content.read(1024)
file.flush()
finally:

View File

@ -2,7 +2,9 @@ import abc
import enum
from typing import Tuple, Dict, Any, Union, List, AsyncIterator, Type
from redbot.core.utils._internal_utils import async_tqdm
import rich.progress
from redbot.core.utils._internal_utils import RichIndefiniteBarColumn
__all__ = ["BaseDriver", "IdentifierData", "ConfigCategory"]
@ -282,22 +284,27 @@ class BaseDriver(abc.ABC):
"""
# Backend-agnostic method of migrating from one driver to another.
cogs_progress_bar = async_tqdm(
(tup async for tup in cls.aiter_cogs()),
desc="Migration progress",
unit=" cogs",
bar_format="{desc}: {n_fmt}{unit} [{elapsed},{rate_noinv_fmt}{postfix}]",
leave=False,
dynamic_ncols=True,
miniters=1,
)
async for cog_name, cog_id in cogs_progress_bar:
cogs_progress_bar.set_postfix_str(f"Working on {cog_name}...")
with rich.progress.Progress(
rich.progress.SpinnerColumn(),
rich.progress.TextColumn("[progress.description]{task.description}"),
RichIndefiniteBarColumn(),
rich.progress.TextColumn("{task.completed} cogs processed"),
rich.progress.TimeElapsedColumn(),
) as progress:
cog_count = 0
tid = progress.add_task("[yellow]Migrating", completed=cog_count, total=cog_count + 1)
async for cog_name, cog_id in cls.aiter_cogs():
progress.console.print(f"Working on {cog_name}...")
this_driver = cls(cog_name, cog_id)
other_driver = new_driver_cls(cog_name, cog_id)
custom_group_data = all_custom_group_data.get(cog_name, {}).get(cog_id, {})
exported_data = await this_driver.export_data(custom_group_data)
await other_driver.import_data(exported_data, custom_group_data)
cog_count += 1
progress.update(tid, completed=cog_count, total=cog_count + 1)
progress.update(tid, total=cog_count)
print()
@classmethod

View File

@ -33,7 +33,8 @@ import aiohttp
import discord
import pkg_resources
from fuzzywuzzy import fuzz, process
from tqdm import tqdm
from rich.progress import ProgressColumn
from rich.progress_bar import ProgressBar
from redbot import VersionInfo
from redbot.core import data_manager
@ -55,7 +56,7 @@ __all__ = (
"expected_version",
"fetch_latest_red_version_info",
"deprecated_removed",
"async_tqdm",
"RichIndefiniteBarColumn",
)
_T = TypeVar("_T")
@ -347,76 +348,12 @@ def deprecated_removed(
)
class _AsyncTqdm(AsyncIterator[_T], tqdm):
def __init__(self, iterable: AsyncIterable[_T], *args, **kwargs) -> None:
self.async_iterator = iterable.__aiter__()
super().__init__(self.infinite_generator(), *args, **kwargs)
self.iterator = cast(Generator[None, bool, None], iter(self))
@staticmethod
def infinite_generator() -> Generator[None, bool, None]:
while True:
# Generator can be forced to raise StopIteration by calling `g.send(True)`
current = yield
if current:
break
async def __anext__(self) -> _T:
try:
result = await self.async_iterator.__anext__()
except StopAsyncIteration:
# If the async iterator is exhausted, force-stop the tqdm iterator
with contextlib.suppress(StopIteration):
self.iterator.send(True)
raise
else:
next(self.iterator)
return result
def __aiter__(self) -> _AsyncTqdm[_T]:
return self
def async_tqdm(
iterable: Optional[Union[Iterable, AsyncIterable]] = None,
*args,
refresh_interval: float = 0.5,
**kwargs,
) -> Union[tqdm, _AsyncTqdm]:
"""Same as `tqdm() <https://tqdm.github.io>`_, except it can be used
in ``async for`` loops, and a task can be spawned to asynchronously
refresh the progress bar every ``refresh_interval`` seconds.
This should only be used for ``async for`` loops, or ``for`` loops
which ``await`` something slow between iterations.
Parameters
----------
iterable: Optional[Union[Iterable, AsyncIterable]]
The iterable to pass to ``tqdm()``. If this is an async
iterable, this function will return a wrapper
*args
Other positional arguments to ``tqdm()``.
refresh_interval : float
The sleep interval between the progress bar being refreshed, in
seconds. Defaults to 0.5. Set to 0 to disable the auto-
refresher.
**kwargs
Keyword arguments to ``tqdm()``.
"""
if isinstance(iterable, AsyncIterable):
progress_bar = _AsyncTqdm(iterable, *args, **kwargs)
else:
progress_bar = tqdm(iterable, *args, **kwargs)
if refresh_interval:
# The background task that refreshes the progress bar
async def _progress_bar_refresher() -> None:
while not progress_bar.disable:
await asyncio.sleep(refresh_interval)
progress_bar.refresh()
asyncio.create_task(_progress_bar_refresher())
return progress_bar
class RichIndefiniteBarColumn(ProgressColumn):
def render(self, task):
return ProgressBar(
pulse=task.completed < task.total,
animation_time=task.get_time(),
width=40,
total=task.total,
completed=task.completed,
)

View File

@ -65,7 +65,6 @@ install_requires =
rich==9.9.0
schema==0.7.4
six==1.15.0
tqdm==4.56.2
typing-extensions==3.7.4.3
uvloop==0.15.0; sys_platform != "win32" and platform_python_implementation == "CPython"
yarl==1.6.3

View File

@ -24,7 +24,6 @@ install_requires =
Red-Lavalink
rich
schema
tqdm
uvloop; sys_platform != "win32" and platform_python_implementation == "CPython"
PyNaCl