diff --git a/redbot/core/drivers/base.py b/redbot/core/drivers/base.py index 59b1ac7d8..016438281 100644 --- a/redbot/core/drivers/base.py +++ b/redbot/core/drivers/base.py @@ -2,6 +2,8 @@ import abc import enum from typing import Tuple, Dict, Any, Union, List, AsyncIterator, Type +from redbot.core.utils._internal_utils import async_tqdm + __all__ = ["BaseDriver", "IdentifierData", "ConfigCategory"] @@ -280,12 +282,23 @@ class BaseDriver(abc.ABC): """ # Backend-agnostic method of migrating from one driver to another. - async for cog_name, cog_id in cls.aiter_cogs(): + cogs_progress_bar = async_tqdm( + (tup async for tup in cls.aiter_cogs()), + desc="Migration progress", + unit=" cogs", + bar_format="{desc}: {n_fmt}{unit} [{elapsed},{rate_noinv_fmt}{postfix}]", + leave=False, + dynamic_ncols=True, + miniters=1, + ) + async for cog_name, cog_id in cogs_progress_bar: + cogs_progress_bar.set_postfix_str(f"Working on {cog_name}...") this_driver = cls(cog_name, cog_id) other_driver = new_driver_cls(cog_name, cog_id) custom_group_data = all_custom_group_data.get(cog_name, {}).get(cog_id, {}) exported_data = await this_driver.export_data(custom_group_data) await other_driver.import_data(exported_data, custom_group_data) + print() @classmethod async def delete_all_data(cls, **kwargs) -> None: diff --git a/redbot/core/utils/_internal_utils.py b/redbot/core/utils/_internal_utils.py index f289ee3ed..f5e442a34 100644 --- a/redbot/core/utils/_internal_utils.py +++ b/redbot/core/utils/_internal_utils.py @@ -2,6 +2,7 @@ from __future__ import annotations import asyncio import collections.abc +import contextlib import json import logging import os @@ -12,23 +13,29 @@ import warnings from datetime import datetime from pathlib import Path from typing import ( + AsyncIterable, AsyncIterator, Awaitable, Callable, + Generator, + Iterable, Iterator, List, Optional, Union, + TypeVar, TYPE_CHECKING, Tuple, + cast, ) import aiohttp import discord import pkg_resources from fuzzywuzzy import fuzz, process -from redbot import VersionInfo +from tqdm import tqdm +from redbot import VersionInfo from redbot.core import data_manager from redbot.core.utils.chat_formatting import box @@ -48,8 +55,11 @@ __all__ = ( "expected_version", "fetch_latest_red_version_info", "deprecated_removed", + "async_tqdm", ) +_T = TypeVar("_T") + def safe_delete(pth: Path): if pth.exists(): @@ -335,3 +345,78 @@ def deprecated_removed( DeprecationWarning, stacklevel=stacklevel + 1, ) + + +class _AsyncTqdm(AsyncIterator[_T], tqdm): + def __init__(self, iterable: AsyncIterable[_T], *args, **kwargs) -> None: + self.async_iterator = iterable.__aiter__() + super().__init__(self.infinite_generator(), *args, **kwargs) + self.iterator = cast(Generator[None, bool, None], iter(self)) + + @staticmethod + def infinite_generator() -> Generator[None, bool, None]: + while True: + # Generator can be forced to raise StopIteration by calling `g.send(True)` + current = yield + if current: + break + + async def __anext__(self) -> _T: + try: + result = await self.async_iterator.__anext__() + except StopAsyncIteration: + # If the async iterator is exhausted, force-stop the tqdm iterator + with contextlib.suppress(StopIteration): + self.iterator.send(True) + raise + else: + next(self.iterator) + return result + + def __aiter__(self) -> _AsyncTqdm[_T]: + return self + + +def async_tqdm( + iterable: Optional[Union[Iterable, AsyncIterable]] = None, + *args, + refresh_interval: float = 0.5, + **kwargs, +) -> Union[tqdm, _AsyncTqdm]: + """Same as `tqdm() `_, except it can be used + in ``async for`` loops, and a task can be spawned to asynchronously + refresh the progress bar every ``refresh_interval`` seconds. + + This should only be used for ``async for`` loops, or ``for`` loops + which ``await`` something slow between iterations. + + Parameters + ---------- + iterable: Optional[Union[Iterable, AsyncIterable]] + The iterable to pass to ``tqdm()``. If this is an async + iterable, this function will return a wrapper + *args + Other positional arguments to ``tqdm()``. + refresh_interval : float + The sleep interval between the progress bar being refreshed, in + seconds. Defaults to 0.5. Set to 0 to disable the auto- + refresher. + **kwargs + Keyword arguments to ``tqdm()``. + + """ + if isinstance(iterable, AsyncIterable): + progress_bar = _AsyncTqdm(iterable, *args, **kwargs) + else: + progress_bar = tqdm(iterable, *args, **kwargs) + + if refresh_interval: + # The background task that refreshes the progress bar + async def _progress_bar_refresher() -> None: + while not progress_bar.disable: + await asyncio.sleep(refresh_interval) + progress_bar.refresh() + + asyncio.create_task(_progress_bar_refresher()) + + return progress_bar