Add map(), find() and next() methods to AsyncIter (#3921)

* properly handle prefixes

* Docsss and typehinting

* aaaaaaaaaaa

* Apply suggestions from code review

* ffs

* docs

* docs

* Apply suggestions from code review

Co-authored-by: jack1142 <6032823+jack1142@users.noreply.github.com>

* skip await if map is none

* implement `.next()`

Co-authored-by: jack1142 <6032823+jack1142@users.noreply.github.com>
This commit is contained in:
Draper 2020-06-21 18:47:48 +01:00 committed by GitHub
parent 81f146a2ef
commit 477186d09d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -17,15 +17,17 @@ from typing import (
Tuple,
TypeVar,
Union,
Set,
TYPE_CHECKING,
Generator,
Coroutine,
)
from discord.utils import maybe_coroutine
__all__ = ("bounded_gather", "bounded_gather_iter", "deduplicate_iterables", "AsyncIter")
_T = TypeVar("_T")
_T = TypeVar("_T")
_S = TypeVar("_S")
# Benchmarked to be the fastest method.
def deduplicate_iterables(*iterables):
@ -297,6 +299,7 @@ class AsyncIter(AsyncIterator[_T], Awaitable[List[_T]]): # pylint: disable=dupl
self._iterator = iter(iterable)
self._i = 0
self._steps = steps
self._map = None
def __aiter__(self) -> AsyncIter[_T]:
return self
@ -310,7 +313,7 @@ class AsyncIter(AsyncIterator[_T], Awaitable[List[_T]]): # pylint: disable=dupl
self._i = 0
await asyncio.sleep(self._delay)
self._i += 1
return item
return await maybe_coroutine(self._map, item) if self._map is not None else item
def __await__(self) -> Generator[Any, None, List[_T]]:
"""Returns a list of the iterable.
@ -325,6 +328,37 @@ class AsyncIter(AsyncIterator[_T], Awaitable[List[_T]]): # pylint: disable=dupl
"""
return self.flatten().__await__()
async def next(self, default: Any = ...) -> _T:
"""Returns a next entry of the iterable.
Parameters
----------
default: Optional[Any]
The value to return if the iterator is exhausted.
Raises
------
StopAsyncIteration
When ``default`` is not specified and the iterator has been exhausted.
Examples
--------
>>> from redbot.core.utils import AsyncIter
>>> iterator = AsyncIter(range(5))
>>> await iterator.next()
0
>>> await iterator.next()
1
"""
try:
value = await self.__anext__()
except StopAsyncIteration:
if default is ...:
raise
value = default
return value
async def flatten(self) -> List[_T]:
"""Returns a list of the iterable.
@ -339,8 +373,7 @@ class AsyncIter(AsyncIterator[_T], Awaitable[List[_T]]): # pylint: disable=dupl
return [item async for item in self]
def filter(self, function: Callable[[_T], Union[bool, Awaitable[bool]]]) -> AsyncFilter[_T]:
"""
Filter the iterable with an (optionally async) predicate.
"""Filter the iterable with an (optionally async) predicate.
Parameters
----------
@ -424,3 +457,69 @@ class AsyncIter(AsyncIterator[_T], Awaitable[List[_T]]): # pylint: disable=dupl
yield item
_temp.add(item)
del _temp
async def find(
self,
predicate: Callable[[_T], Union[bool, Awaitable[bool]]],
default: Optional[Any] = None,
) -> AsyncIterator[_T]:
"""Calls ``predicate`` over items in iterable and return first value to match.
Parameters
----------
predicate: Union[Callable, Coroutine]
A function that returns a boolean-like result. The predicate provided can be a coroutine.
default: Optional[Any]
The value to return if there are no matches.
Raises
------
TypeError
When ``predicate`` is not a callable.
Examples
--------
>>> from redbot.core.utils import AsyncIter
>>> await AsyncIter(range(3)).find(lambda x: x == 1)
1
"""
while True:
try:
elem = await self.__anext__()
except StopAsyncIteration:
return default
ret = await maybe_coroutine(predicate, elem)
if ret:
return elem
def map(self, func: Callable[[_T], Union[_S, Awaitable[_S]]]) -> AsyncIter[_S]:
"""Set the mapping callable for this instance of `AsyncIter`.
.. important::
This should be called after AsyncIter initialization and before any other of its methods.
Parameters
----------
func: Union[Callable, Coroutine]
The function to map values to. The function provided can be a coroutine.
Raises
------
TypeError
When ``func`` is not a callable.
Examples
--------
>>> from redbot.core.utils import AsyncIter
>>> async for value in AsyncIter(range(3)).map(bool):
... print(value)
False
True
True
"""
if not callable(func):
raise TypeError("Mapping must be a callable.")
self._map = func
return self