mirror of
https://github.com/Cog-Creators/Red-DiscordBot.git
synced 2025-11-06 11:18:54 -05:00
Add map(), find() and next() methods to AsyncIter (#3921)
* properly handle prefixes * Docsss and typehinting * aaaaaaaaaaa * Apply suggestions from code review * ffs * docs * docs * Apply suggestions from code review Co-authored-by: jack1142 <6032823+jack1142@users.noreply.github.com> * skip await if map is none * implement `.next()` Co-authored-by: jack1142 <6032823+jack1142@users.noreply.github.com>
This commit is contained in:
parent
81f146a2ef
commit
477186d09d
@ -17,15 +17,17 @@ from typing import (
|
||||
Tuple,
|
||||
TypeVar,
|
||||
Union,
|
||||
Set,
|
||||
TYPE_CHECKING,
|
||||
Generator,
|
||||
Coroutine,
|
||||
)
|
||||
|
||||
from discord.utils import maybe_coroutine
|
||||
|
||||
__all__ = ("bounded_gather", "bounded_gather_iter", "deduplicate_iterables", "AsyncIter")
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
||||
_T = TypeVar("_T")
|
||||
_S = TypeVar("_S")
|
||||
|
||||
# Benchmarked to be the fastest method.
|
||||
def deduplicate_iterables(*iterables):
|
||||
@ -297,6 +299,7 @@ class AsyncIter(AsyncIterator[_T], Awaitable[List[_T]]): # pylint: disable=dupl
|
||||
self._iterator = iter(iterable)
|
||||
self._i = 0
|
||||
self._steps = steps
|
||||
self._map = None
|
||||
|
||||
def __aiter__(self) -> AsyncIter[_T]:
|
||||
return self
|
||||
@ -310,7 +313,7 @@ class AsyncIter(AsyncIterator[_T], Awaitable[List[_T]]): # pylint: disable=dupl
|
||||
self._i = 0
|
||||
await asyncio.sleep(self._delay)
|
||||
self._i += 1
|
||||
return item
|
||||
return await maybe_coroutine(self._map, item) if self._map is not None else item
|
||||
|
||||
def __await__(self) -> Generator[Any, None, List[_T]]:
|
||||
"""Returns a list of the iterable.
|
||||
@ -325,6 +328,37 @@ class AsyncIter(AsyncIterator[_T], Awaitable[List[_T]]): # pylint: disable=dupl
|
||||
"""
|
||||
return self.flatten().__await__()
|
||||
|
||||
async def next(self, default: Any = ...) -> _T:
|
||||
"""Returns a next entry of the iterable.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
default: Optional[Any]
|
||||
The value to return if the iterator is exhausted.
|
||||
|
||||
Raises
|
||||
------
|
||||
StopAsyncIteration
|
||||
When ``default`` is not specified and the iterator has been exhausted.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> from redbot.core.utils import AsyncIter
|
||||
>>> iterator = AsyncIter(range(5))
|
||||
>>> await iterator.next()
|
||||
0
|
||||
>>> await iterator.next()
|
||||
1
|
||||
|
||||
"""
|
||||
try:
|
||||
value = await self.__anext__()
|
||||
except StopAsyncIteration:
|
||||
if default is ...:
|
||||
raise
|
||||
value = default
|
||||
return value
|
||||
|
||||
async def flatten(self) -> List[_T]:
|
||||
"""Returns a list of the iterable.
|
||||
|
||||
@ -339,8 +373,7 @@ class AsyncIter(AsyncIterator[_T], Awaitable[List[_T]]): # pylint: disable=dupl
|
||||
return [item async for item in self]
|
||||
|
||||
def filter(self, function: Callable[[_T], Union[bool, Awaitable[bool]]]) -> AsyncFilter[_T]:
|
||||
"""
|
||||
Filter the iterable with an (optionally async) predicate.
|
||||
"""Filter the iterable with an (optionally async) predicate.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@ -424,3 +457,69 @@ class AsyncIter(AsyncIterator[_T], Awaitable[List[_T]]): # pylint: disable=dupl
|
||||
yield item
|
||||
_temp.add(item)
|
||||
del _temp
|
||||
|
||||
async def find(
|
||||
self,
|
||||
predicate: Callable[[_T], Union[bool, Awaitable[bool]]],
|
||||
default: Optional[Any] = None,
|
||||
) -> AsyncIterator[_T]:
|
||||
"""Calls ``predicate`` over items in iterable and return first value to match.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
predicate: Union[Callable, Coroutine]
|
||||
A function that returns a boolean-like result. The predicate provided can be a coroutine.
|
||||
default: Optional[Any]
|
||||
The value to return if there are no matches.
|
||||
|
||||
Raises
|
||||
------
|
||||
TypeError
|
||||
When ``predicate`` is not a callable.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> from redbot.core.utils import AsyncIter
|
||||
>>> await AsyncIter(range(3)).find(lambda x: x == 1)
|
||||
1
|
||||
"""
|
||||
while True:
|
||||
try:
|
||||
elem = await self.__anext__()
|
||||
except StopAsyncIteration:
|
||||
return default
|
||||
ret = await maybe_coroutine(predicate, elem)
|
||||
if ret:
|
||||
return elem
|
||||
|
||||
def map(self, func: Callable[[_T], Union[_S, Awaitable[_S]]]) -> AsyncIter[_S]:
|
||||
"""Set the mapping callable for this instance of `AsyncIter`.
|
||||
|
||||
.. important::
|
||||
This should be called after AsyncIter initialization and before any other of its methods.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
func: Union[Callable, Coroutine]
|
||||
The function to map values to. The function provided can be a coroutine.
|
||||
|
||||
Raises
|
||||
------
|
||||
TypeError
|
||||
When ``func`` is not a callable.
|
||||
|
||||
Examples
|
||||
--------
|
||||
>>> from redbot.core.utils import AsyncIter
|
||||
>>> async for value in AsyncIter(range(3)).map(bool):
|
||||
... print(value)
|
||||
False
|
||||
True
|
||||
True
|
||||
|
||||
"""
|
||||
|
||||
if not callable(func):
|
||||
raise TypeError("Mapping must be a callable.")
|
||||
self._map = func
|
||||
return self
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user