diff --git a/redbot/core/utils/__init__.py b/redbot/core/utils/__init__.py index 9653f0d7e..0d2066a70 100644 --- a/redbot/core/utils/__init__.py +++ b/redbot/core/utils/__init__.py @@ -17,15 +17,17 @@ from typing import ( Tuple, TypeVar, Union, - Set, - TYPE_CHECKING, Generator, + Coroutine, ) +from discord.utils import maybe_coroutine + __all__ = ("bounded_gather", "bounded_gather_iter", "deduplicate_iterables", "AsyncIter") -_T = TypeVar("_T") +_T = TypeVar("_T") +_S = TypeVar("_S") # Benchmarked to be the fastest method. def deduplicate_iterables(*iterables): @@ -297,6 +299,7 @@ class AsyncIter(AsyncIterator[_T], Awaitable[List[_T]]): # pylint: disable=dupl self._iterator = iter(iterable) self._i = 0 self._steps = steps + self._map = None def __aiter__(self) -> AsyncIter[_T]: return self @@ -310,7 +313,7 @@ class AsyncIter(AsyncIterator[_T], Awaitable[List[_T]]): # pylint: disable=dupl self._i = 0 await asyncio.sleep(self._delay) self._i += 1 - return item + return await maybe_coroutine(self._map, item) if self._map is not None else item def __await__(self) -> Generator[Any, None, List[_T]]: """Returns a list of the iterable. @@ -325,6 +328,37 @@ class AsyncIter(AsyncIterator[_T], Awaitable[List[_T]]): # pylint: disable=dupl """ return self.flatten().__await__() + async def next(self, default: Any = ...) -> _T: + """Returns a next entry of the iterable. + + Parameters + ---------- + default: Optional[Any] + The value to return if the iterator is exhausted. + + Raises + ------ + StopAsyncIteration + When ``default`` is not specified and the iterator has been exhausted. + + Examples + -------- + >>> from redbot.core.utils import AsyncIter + >>> iterator = AsyncIter(range(5)) + >>> await iterator.next() + 0 + >>> await iterator.next() + 1 + + """ + try: + value = await self.__anext__() + except StopAsyncIteration: + if default is ...: + raise + value = default + return value + async def flatten(self) -> List[_T]: """Returns a list of the iterable. @@ -339,8 +373,7 @@ class AsyncIter(AsyncIterator[_T], Awaitable[List[_T]]): # pylint: disable=dupl return [item async for item in self] def filter(self, function: Callable[[_T], Union[bool, Awaitable[bool]]]) -> AsyncFilter[_T]: - """ - Filter the iterable with an (optionally async) predicate. + """Filter the iterable with an (optionally async) predicate. Parameters ---------- @@ -424,3 +457,69 @@ class AsyncIter(AsyncIterator[_T], Awaitable[List[_T]]): # pylint: disable=dupl yield item _temp.add(item) del _temp + + async def find( + self, + predicate: Callable[[_T], Union[bool, Awaitable[bool]]], + default: Optional[Any] = None, + ) -> AsyncIterator[_T]: + """Calls ``predicate`` over items in iterable and return first value to match. + + Parameters + ---------- + predicate: Union[Callable, Coroutine] + A function that returns a boolean-like result. The predicate provided can be a coroutine. + default: Optional[Any] + The value to return if there are no matches. + + Raises + ------ + TypeError + When ``predicate`` is not a callable. + + Examples + -------- + >>> from redbot.core.utils import AsyncIter + >>> await AsyncIter(range(3)).find(lambda x: x == 1) + 1 + """ + while True: + try: + elem = await self.__anext__() + except StopAsyncIteration: + return default + ret = await maybe_coroutine(predicate, elem) + if ret: + return elem + + def map(self, func: Callable[[_T], Union[_S, Awaitable[_S]]]) -> AsyncIter[_S]: + """Set the mapping callable for this instance of `AsyncIter`. + + .. important:: + This should be called after AsyncIter initialization and before any other of its methods. + + Parameters + ---------- + func: Union[Callable, Coroutine] + The function to map values to. The function provided can be a coroutine. + + Raises + ------ + TypeError + When ``func`` is not a callable. + + Examples + -------- + >>> from redbot.core.utils import AsyncIter + >>> async for value in AsyncIter(range(3)).map(bool): + ... print(value) + False + True + True + + """ + + if not callable(func): + raise TypeError("Mapping must be a callable.") + self._map = func + return self