mirror of
https://github.com/Cog-Creators/Red-DiscordBot.git
synced 2025-11-09 04:38:55 -05:00
[CogManager, Utils] Handle missing cogs correctly, add some helpful algorithms (#1989)
* Handle missing cogs correctly, add some helpful algorithms For cog loading, only show "cog not found" if the module in question was the one that failed to import. ImportErrors within cogs will show an error as they should. - deduplicator, benchmarked to be the fastest - bounded gather and bounded async as_completed - tests for all additions * Requested changes + wrap as_completed instead So I went source diving and realized as_completed works the way I want it to, and I don't need to reinvent the wheel for cancelling tasks that remain if the generator is `break`ed out of. So there's that.
This commit is contained in:
parent
b550f38eed
commit
1329fa1b09
@ -4,6 +4,12 @@
|
|||||||
Utility Functions
|
Utility Functions
|
||||||
=================
|
=================
|
||||||
|
|
||||||
|
General Utility
|
||||||
|
===============
|
||||||
|
|
||||||
|
.. automodule:: redbot.core.utils
|
||||||
|
:members: deduplicate_iterables, bounded_gather, bounded_gather_iter
|
||||||
|
|
||||||
Chat Formatting
|
Chat Formatting
|
||||||
===============
|
===============
|
||||||
|
|
||||||
|
|||||||
@ -3,9 +3,10 @@ import pkgutil
|
|||||||
from importlib import import_module, invalidate_caches
|
from importlib import import_module, invalidate_caches
|
||||||
from importlib.machinery import ModuleSpec
|
from importlib.machinery import ModuleSpec
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Tuple, Union, List
|
from typing import Tuple, Union, List, Optional
|
||||||
|
|
||||||
import redbot.cogs
|
import redbot.cogs
|
||||||
|
from redbot.core.utils import deduplicate_iterables
|
||||||
import discord
|
import discord
|
||||||
|
|
||||||
from . import checks, commands
|
from . import checks, commands
|
||||||
@ -18,12 +19,13 @@ from .utils.chat_formatting import box, pagify
|
|||||||
__all__ = ["CogManager"]
|
__all__ = ["CogManager"]
|
||||||
|
|
||||||
|
|
||||||
def _deduplicate(xs):
|
class NoSuchCog(ImportError):
|
||||||
ret = []
|
"""Thrown when a cog is missing.
|
||||||
for x in xs:
|
|
||||||
if x not in ret:
|
Different from ImportError because some ImportErrors can happen inside cogs.
|
||||||
ret.append(x)
|
"""
|
||||||
return ret
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class CogManager:
|
class CogManager:
|
||||||
@ -56,7 +58,7 @@ class CogManager:
|
|||||||
conf_paths = [Path(p) for p in await self.conf.paths()]
|
conf_paths = [Path(p) for p in await self.conf.paths()]
|
||||||
other_paths = self._paths
|
other_paths = self._paths
|
||||||
|
|
||||||
all_paths = _deduplicate(list(conf_paths) + list(other_paths) + [self.CORE_PATH])
|
all_paths = deduplicate_iterables(conf_paths, other_paths, [self.CORE_PATH])
|
||||||
|
|
||||||
if self.install_path not in all_paths:
|
if self.install_path not in all_paths:
|
||||||
all_paths.insert(0, await self.install_path())
|
all_paths.insert(0, await self.install_path())
|
||||||
@ -209,11 +211,10 @@ class CogManager:
|
|||||||
|
|
||||||
Raises
|
Raises
|
||||||
------
|
------
|
||||||
RuntimeError
|
NoSuchCog
|
||||||
When no matching spec can be found.
|
When no cog with the requested name was found.
|
||||||
"""
|
"""
|
||||||
resolved_paths = _deduplicate(await self.paths())
|
resolved_paths = await self.paths()
|
||||||
|
|
||||||
real_paths = [str(p) for p in resolved_paths if p != self.CORE_PATH]
|
real_paths = [str(p) for p in resolved_paths if p != self.CORE_PATH]
|
||||||
|
|
||||||
for finder, module_name, _ in pkgutil.iter_modules(real_paths):
|
for finder, module_name, _ in pkgutil.iter_modules(real_paths):
|
||||||
@ -222,9 +223,11 @@ class CogManager:
|
|||||||
if spec:
|
if spec:
|
||||||
return spec
|
return spec
|
||||||
|
|
||||||
raise RuntimeError(
|
raise NoSuchCog(
|
||||||
"No 3rd party module by the name of '{}' was found"
|
"No 3rd party module by the name of '{}' was found in any available path.".format(
|
||||||
" in any available path.".format(name)
|
name
|
||||||
|
),
|
||||||
|
name=name,
|
||||||
)
|
)
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -246,16 +249,24 @@ class CogManager:
|
|||||||
When no matching spec can be found.
|
When no matching spec can be found.
|
||||||
"""
|
"""
|
||||||
real_name = ".{}".format(name)
|
real_name = ".{}".format(name)
|
||||||
|
package = "redbot.cogs"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
mod = import_module(real_name, package="redbot.cogs")
|
mod = import_module(real_name, package=package)
|
||||||
except ImportError as e:
|
except ImportError as e:
|
||||||
raise RuntimeError(
|
if e.name == package + real_name:
|
||||||
"No core cog by the name of '{}' could be found.".format(name)
|
raise NoSuchCog(
|
||||||
) from e
|
"No core cog by the name of '{}' could be found.".format(name),
|
||||||
|
path=e.path,
|
||||||
|
name=e.name,
|
||||||
|
) from e
|
||||||
|
|
||||||
|
raise
|
||||||
|
|
||||||
return mod.__spec__
|
return mod.__spec__
|
||||||
|
|
||||||
# noinspection PyUnreachableCode
|
# noinspection PyUnreachableCode
|
||||||
async def find_cog(self, name: str) -> ModuleSpec:
|
async def find_cog(self, name: str) -> Optional[ModuleSpec]:
|
||||||
"""Find a cog in the list of available paths.
|
"""Find a cog in the list of available paths.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
@ -265,23 +276,16 @@ class CogManager:
|
|||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
importlib.machinery.ModuleSpec
|
Optional[importlib.machinery.ModuleSpec]
|
||||||
A module spec to be used for specialized cog loading.
|
A module spec to be used for specialized cog loading, if found.
|
||||||
|
|
||||||
Raises
|
|
||||||
------
|
|
||||||
RuntimeError
|
|
||||||
If there is no cog with the given name.
|
|
||||||
|
|
||||||
"""
|
"""
|
||||||
with contextlib.suppress(RuntimeError):
|
with contextlib.suppress(NoSuchCog):
|
||||||
return await self._find_ext_cog(name)
|
return await self._find_ext_cog(name)
|
||||||
|
|
||||||
with contextlib.suppress(RuntimeError):
|
with contextlib.suppress(NoSuchCog):
|
||||||
return await self._find_core_cog(name)
|
return await self._find_core_cog(name)
|
||||||
|
|
||||||
raise RuntimeError("No cog with that name could be found.")
|
|
||||||
|
|
||||||
async def available_modules(self) -> List[str]:
|
async def available_modules(self) -> List[str]:
|
||||||
"""Finds the names of all available modules to load.
|
"""Finds the names of all available modules to load.
|
||||||
"""
|
"""
|
||||||
|
|||||||
@ -77,9 +77,17 @@ class CoreLogic:
|
|||||||
for name in cog_names:
|
for name in cog_names:
|
||||||
try:
|
try:
|
||||||
spec = await bot.cog_mgr.find_cog(name)
|
spec = await bot.cog_mgr.find_cog(name)
|
||||||
cogspecs.append((spec, name))
|
if spec:
|
||||||
except RuntimeError:
|
cogspecs.append((spec, name))
|
||||||
notfound_packages.append(name)
|
else:
|
||||||
|
notfound_packages.append(name)
|
||||||
|
except Exception as e:
|
||||||
|
log.exception("Package import failed", exc_info=e)
|
||||||
|
|
||||||
|
exception_log = "Exception during import of cog\n"
|
||||||
|
exception_log += "".join(traceback.format_exception(type(e), e, e.__traceback__))
|
||||||
|
bot._last_exception = exception_log
|
||||||
|
failed_packages.append(name)
|
||||||
|
|
||||||
for spec, name in cogspecs:
|
for spec, name in cogspecs:
|
||||||
try:
|
try:
|
||||||
@ -95,6 +103,7 @@ class CoreLogic:
|
|||||||
else:
|
else:
|
||||||
await bot.add_loaded_package(name)
|
await bot.add_loaded_package(name)
|
||||||
loaded_packages.append(name)
|
loaded_packages.append(name)
|
||||||
|
|
||||||
return loaded_packages, failed_packages, notfound_packages
|
return loaded_packages, failed_packages, notfound_packages
|
||||||
|
|
||||||
def _cleanup_and_refresh_modules(self, module_name: str):
|
def _cleanup_and_refresh_modules(self, module_name: str):
|
||||||
@ -511,7 +520,7 @@ class Core(CoreLogic):
|
|||||||
loaded, failed, not_found = await self._load(cog_names)
|
loaded, failed, not_found = await self._load(cog_names)
|
||||||
|
|
||||||
if loaded:
|
if loaded:
|
||||||
fmt = "Loaded {packs}"
|
fmt = "Loaded {packs}."
|
||||||
formed = self._get_package_strings(loaded, fmt)
|
formed = self._get_package_strings(loaded, fmt)
|
||||||
await ctx.send(formed)
|
await ctx.send(formed)
|
||||||
|
|
||||||
|
|||||||
@ -1,14 +1,31 @@
|
|||||||
__all__ = ["safe_delete", "fuzzy_command_search"]
|
__all__ = ["bounded_gather", "safe_delete", "fuzzy_command_search", "deduplicate_iterables"]
|
||||||
|
|
||||||
from pathlib import Path
|
import asyncio
|
||||||
import os
|
from asyncio import as_completed, AbstractEventLoop, Semaphore
|
||||||
import shutil
|
from asyncio.futures import isfuture
|
||||||
|
from itertools import chain
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
|
from pathlib import Path
|
||||||
|
import shutil
|
||||||
|
from typing import Any, Awaitable, Iterator, List, Optional
|
||||||
|
|
||||||
from redbot.core import commands
|
from redbot.core import commands
|
||||||
from fuzzywuzzy import process
|
from fuzzywuzzy import process
|
||||||
|
|
||||||
from .chat_formatting import box
|
from .chat_formatting import box
|
||||||
|
|
||||||
|
|
||||||
|
# Benchmarked to be the fastest method.
|
||||||
|
def deduplicate_iterables(*iterables):
|
||||||
|
"""
|
||||||
|
Returns a list of all unique items in ``iterables``, in the order they
|
||||||
|
were first encountered.
|
||||||
|
"""
|
||||||
|
# dict insertion order is guaranteed to be preserved in 3.6+
|
||||||
|
return list(dict.fromkeys(chain.from_iterable(iterables)))
|
||||||
|
|
||||||
|
|
||||||
def fuzzy_filter(record):
|
def fuzzy_filter(record):
|
||||||
return record.funcName != "extractWithoutOrder"
|
return record.funcName != "extractWithoutOrder"
|
||||||
|
|
||||||
@ -20,10 +37,13 @@ def safe_delete(pth: Path):
|
|||||||
if pth.exists():
|
if pth.exists():
|
||||||
for root, dirs, files in os.walk(str(pth)):
|
for root, dirs, files in os.walk(str(pth)):
|
||||||
os.chmod(root, 0o755)
|
os.chmod(root, 0o755)
|
||||||
|
|
||||||
for d in dirs:
|
for d in dirs:
|
||||||
os.chmod(os.path.join(root, d), 0o755)
|
os.chmod(os.path.join(root, d), 0o755)
|
||||||
|
|
||||||
for f in files:
|
for f in files:
|
||||||
os.chmod(os.path.join(root, f), 0o755)
|
os.chmod(os.path.join(root, f), 0o755)
|
||||||
|
|
||||||
shutil.rmtree(str(pth), ignore_errors=True)
|
shutil.rmtree(str(pth), ignore_errors=True)
|
||||||
|
|
||||||
|
|
||||||
@ -33,35 +53,41 @@ async def filter_commands(ctx: commands.Context, extracted: list):
|
|||||||
for i in extracted
|
for i in extracted
|
||||||
if i[1] >= 90
|
if i[1] >= 90
|
||||||
and not i[0].hidden
|
and not i[0].hidden
|
||||||
|
and not any([p.hidden for p in i[0].parents])
|
||||||
and await i[0].can_run(ctx)
|
and await i[0].can_run(ctx)
|
||||||
and all([await p.can_run(ctx) for p in i[0].parents])
|
and all([await p.can_run(ctx) for p in i[0].parents])
|
||||||
and not any([p.hidden for p in i[0].parents])
|
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
async def fuzzy_command_search(ctx: commands.Context, term: str):
|
async def fuzzy_command_search(ctx: commands.Context, term: str):
|
||||||
out = ""
|
out = []
|
||||||
|
|
||||||
if ctx.guild is not None:
|
if ctx.guild is not None:
|
||||||
enabled = await ctx.bot.db.guild(ctx.guild).fuzzy()
|
enabled = await ctx.bot.db.guild(ctx.guild).fuzzy()
|
||||||
else:
|
else:
|
||||||
enabled = await ctx.bot.db.fuzzy()
|
enabled = await ctx.bot.db.fuzzy()
|
||||||
|
|
||||||
if not enabled:
|
if not enabled:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
alias_cog = ctx.bot.get_cog("Alias")
|
alias_cog = ctx.bot.get_cog("Alias")
|
||||||
if alias_cog is not None:
|
if alias_cog is not None:
|
||||||
is_alias, alias = await alias_cog.is_alias(ctx.guild, term)
|
is_alias, alias = await alias_cog.is_alias(ctx.guild, term)
|
||||||
|
|
||||||
if is_alias:
|
if is_alias:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
customcom_cog = ctx.bot.get_cog("CustomCommands")
|
customcom_cog = ctx.bot.get_cog("CustomCommands")
|
||||||
if customcom_cog is not None:
|
if customcom_cog is not None:
|
||||||
cmd_obj = customcom_cog.commandobj
|
cmd_obj = customcom_cog.commandobj
|
||||||
|
|
||||||
try:
|
try:
|
||||||
ccinfo = await cmd_obj.get(ctx.message, term)
|
ccinfo = await cmd_obj.get(ctx.message, term)
|
||||||
except:
|
except:
|
||||||
pass
|
pass
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
extracted_cmds = await filter_commands(
|
extracted_cmds = await filter_commands(
|
||||||
ctx, process.extract(term, ctx.bot.walk_commands(), limit=5)
|
ctx, process.extract(term, ctx.bot.walk_commands(), limit=5)
|
||||||
)
|
)
|
||||||
@ -70,10 +96,101 @@ async def fuzzy_command_search(ctx: commands.Context, term: str):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
for pos, extracted in enumerate(extracted_cmds, 1):
|
for pos, extracted in enumerate(extracted_cmds, 1):
|
||||||
out += "{0}. {1.prefix}{2.qualified_name}{3}\n".format(
|
short = " - {}".format(extracted[0].short_doc) if extracted[0].short_doc else ""
|
||||||
pos,
|
out.append("{0}. {1.prefix}{2.qualified_name}{3}".format(pos, ctx, extracted[0], short))
|
||||||
ctx,
|
|
||||||
extracted[0],
|
return box("\n".join(out), lang="Perhaps you wanted one of these?")
|
||||||
" - {}".format(extracted[0].short_doc) if extracted[0].short_doc else "",
|
|
||||||
)
|
|
||||||
return box(out, lang="Perhaps you wanted one of these?")
|
async def _sem_wrapper(sem, task):
|
||||||
|
async with sem:
|
||||||
|
return await task
|
||||||
|
|
||||||
|
|
||||||
|
def bounded_gather_iter(
|
||||||
|
*coros_or_futures,
|
||||||
|
loop: Optional[AbstractEventLoop] = None,
|
||||||
|
limit: int = 4,
|
||||||
|
semaphore: Optional[Semaphore] = None,
|
||||||
|
) -> Iterator[Awaitable[Any]]:
|
||||||
|
"""
|
||||||
|
An iterator that returns tasks as they are ready, but limits the
|
||||||
|
number of tasks running at a time.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
*coros_or_futures
|
||||||
|
The awaitables to run in a bounded concurrent fashion.
|
||||||
|
loop : asyncio.AbstractEventLoop
|
||||||
|
The event loop to use for the semaphore and :meth:`asyncio.gather`.
|
||||||
|
limit : Optional[`int`]
|
||||||
|
The maximum number of concurrent tasks. Used when no ``semaphore`` is passed.
|
||||||
|
semaphore : Optional[:class:`asyncio.Semaphore`]
|
||||||
|
The semaphore to use for bounding tasks. If `None`, create one using ``loop`` and ``limit``.
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
TypeError
|
||||||
|
When invalid parameters are passed
|
||||||
|
"""
|
||||||
|
if loop is None:
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
|
||||||
|
if semaphore is None:
|
||||||
|
if not isinstance(limit, int) or limit <= 0:
|
||||||
|
raise TypeError("limit must be an int > 0")
|
||||||
|
|
||||||
|
semaphore = Semaphore(limit, loop=loop)
|
||||||
|
|
||||||
|
pending = []
|
||||||
|
|
||||||
|
for cof in coros_or_futures:
|
||||||
|
if isfuture(cof) and cof._loop is not loop:
|
||||||
|
raise ValueError("futures are tied to different event loops")
|
||||||
|
|
||||||
|
cof = _sem_wrapper(semaphore, cof)
|
||||||
|
pending.append(cof)
|
||||||
|
|
||||||
|
return as_completed(pending, loop=loop)
|
||||||
|
|
||||||
|
|
||||||
|
def bounded_gather(
|
||||||
|
*coros_or_futures,
|
||||||
|
loop: Optional[AbstractEventLoop] = None,
|
||||||
|
return_exceptions: bool = False,
|
||||||
|
limit: int = 4,
|
||||||
|
semaphore: Optional[Semaphore] = None,
|
||||||
|
) -> Awaitable[List[Any]]:
|
||||||
|
"""
|
||||||
|
A semaphore-bounded wrapper to :meth:`asyncio.gather`.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
*coros_or_futures
|
||||||
|
The awaitables to run in a bounded concurrent fashion.
|
||||||
|
loop : asyncio.AbstractEventLoop
|
||||||
|
The event loop to use for the semaphore and :meth:`asyncio.gather`.
|
||||||
|
return_exceptions : bool
|
||||||
|
If true, gather exceptions in the result list instead of raising.
|
||||||
|
limit : Optional[`int`]
|
||||||
|
The maximum number of concurrent tasks. Used when no ``semaphore`` is passed.
|
||||||
|
semaphore : Optional[:class:`asyncio.Semaphore`]
|
||||||
|
The semaphore to use for bounding tasks. If `None`, create one using ``loop`` and ``limit``.
|
||||||
|
|
||||||
|
Raises
|
||||||
|
------
|
||||||
|
TypeError
|
||||||
|
When invalid parameters are passed
|
||||||
|
"""
|
||||||
|
if loop is None:
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
|
||||||
|
if semaphore is None:
|
||||||
|
if not isinstance(limit, int) or limit <= 0:
|
||||||
|
raise TypeError("limit must be an int > 0")
|
||||||
|
|
||||||
|
semaphore = Semaphore(limit, loop=loop)
|
||||||
|
|
||||||
|
tasks = (_sem_wrapper(semaphore, task) for task in coros_or_futures)
|
||||||
|
|
||||||
|
return asyncio.gather(*tasks, loop=loop, return_exceptions=return_exceptions)
|
||||||
|
|||||||
@ -1,5 +1,14 @@
|
|||||||
|
import asyncio
|
||||||
|
import pytest
|
||||||
|
import random
|
||||||
import textwrap
|
import textwrap
|
||||||
from redbot.core.utils import chat_formatting
|
import warnings
|
||||||
|
from redbot.core.utils import (
|
||||||
|
chat_formatting,
|
||||||
|
bounded_gather,
|
||||||
|
bounded_gather_iter,
|
||||||
|
deduplicate_iterables,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_bordered_symmetrical():
|
def test_bordered_symmetrical():
|
||||||
@ -54,3 +63,131 @@ def test_bordered_ascii():
|
|||||||
)
|
)
|
||||||
col1, col2 = ["one", "two", "three"], ["four", "five", "six"]
|
col1, col2 = ["one", "two", "three"], ["four", "five", "six"]
|
||||||
assert chat_formatting.bordered(col1, col2, ascii_border=True) == expected
|
assert chat_formatting.bordered(col1, col2, ascii_border=True) == expected
|
||||||
|
|
||||||
|
|
||||||
|
def test_deduplicate_iterables():
|
||||||
|
expected = [1, 2, 3, 4, 5]
|
||||||
|
inputs = [[1, 2, 1], [3, 1, 2, 4], [5, 1, 2]]
|
||||||
|
assert deduplicate_iterables(*inputs) == expected
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_bounded_gather():
|
||||||
|
status = [0, 0] # num_running, max_running
|
||||||
|
|
||||||
|
async def wait_task(i, delay, status, fail=False):
|
||||||
|
status[0] += 1
|
||||||
|
await asyncio.sleep(delay)
|
||||||
|
status[1] = max(status)
|
||||||
|
status[0] -= 1
|
||||||
|
|
||||||
|
if fail:
|
||||||
|
raise RuntimeError
|
||||||
|
|
||||||
|
return i
|
||||||
|
|
||||||
|
num_concurrent = random.randint(2, 8)
|
||||||
|
num_tasks = random.randint(4 * num_concurrent, 5 * num_concurrent)
|
||||||
|
num_fail = random.randint(num_concurrent, num_tasks)
|
||||||
|
|
||||||
|
tasks = [wait_task(i, random.random() / 1000, status) for i in range(num_tasks)]
|
||||||
|
tasks += [wait_task(i, random.random() / 1000, status, fail=True) for i in range(num_fail)]
|
||||||
|
|
||||||
|
num_failed = 0
|
||||||
|
|
||||||
|
results = await bounded_gather(*tasks, limit=num_concurrent, return_exceptions=True)
|
||||||
|
|
||||||
|
for i, result in enumerate(results):
|
||||||
|
if isinstance(result, RuntimeError):
|
||||||
|
num_failed += 1
|
||||||
|
else:
|
||||||
|
assert result == i # verify original orde
|
||||||
|
assert 0 <= result < num_tasks
|
||||||
|
|
||||||
|
assert 0 < status[1] <= num_concurrent
|
||||||
|
assert num_fail == num_failed
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_bounded_gather_iter():
|
||||||
|
status = [0, 0] # num_running, max_running
|
||||||
|
|
||||||
|
async def wait_task(i, delay, status, fail=False):
|
||||||
|
status[0] += 1
|
||||||
|
await asyncio.sleep(delay)
|
||||||
|
status[1] = max(status)
|
||||||
|
status[0] -= 1
|
||||||
|
|
||||||
|
if fail:
|
||||||
|
raise RuntimeError
|
||||||
|
|
||||||
|
return i
|
||||||
|
|
||||||
|
num_concurrent = random.randint(2, 8)
|
||||||
|
num_tasks = random.randint(4 * num_concurrent, 16 * num_concurrent)
|
||||||
|
num_fail = random.randint(num_concurrent, num_tasks)
|
||||||
|
|
||||||
|
tasks = [wait_task(i, random.random() / 1000, status) for i in range(num_tasks)]
|
||||||
|
tasks += [wait_task(i, random.random() / 1000, status, fail=True) for i in range(num_fail)]
|
||||||
|
random.shuffle(tasks)
|
||||||
|
|
||||||
|
num_failed = 0
|
||||||
|
|
||||||
|
for result in bounded_gather_iter(*tasks, limit=num_concurrent):
|
||||||
|
try:
|
||||||
|
result = await result
|
||||||
|
except RuntimeError:
|
||||||
|
num_failed += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
assert 0 <= result < num_tasks
|
||||||
|
|
||||||
|
assert 0 < status[1] <= num_concurrent
|
||||||
|
assert num_fail == num_failed
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip(reason="spams logs with pending task warnings")
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_bounded_gather_iter_cancel():
|
||||||
|
status = [0, 0, 0] # num_running, max_running, num_ran
|
||||||
|
|
||||||
|
async def wait_task(i, delay, status, fail=False):
|
||||||
|
status[0] += 1
|
||||||
|
await asyncio.sleep(delay)
|
||||||
|
status[1] = max(status[:2])
|
||||||
|
status[0] -= 1
|
||||||
|
|
||||||
|
if fail:
|
||||||
|
raise RuntimeError
|
||||||
|
|
||||||
|
status[2] += 1
|
||||||
|
return i
|
||||||
|
|
||||||
|
num_concurrent = random.randint(2, 8)
|
||||||
|
num_tasks = random.randint(4 * num_concurrent, 16 * num_concurrent)
|
||||||
|
quit_on = random.randint(0, num_tasks)
|
||||||
|
num_fail = random.randint(num_concurrent, num_tasks)
|
||||||
|
|
||||||
|
tasks = [wait_task(i, random.random() / 1000, status) for i in range(num_tasks)]
|
||||||
|
tasks += [wait_task(i, random.random() / 1000, status, fail=True) for i in range(num_fail)]
|
||||||
|
random.shuffle(tasks)
|
||||||
|
|
||||||
|
num_failed = 0
|
||||||
|
i = 0
|
||||||
|
|
||||||
|
for result in bounded_gather_iter(*tasks, limit=num_concurrent):
|
||||||
|
try:
|
||||||
|
result = await result
|
||||||
|
except RuntimeError:
|
||||||
|
num_failed += 1
|
||||||
|
continue
|
||||||
|
|
||||||
|
if i == quit_on:
|
||||||
|
break
|
||||||
|
|
||||||
|
assert 0 <= result < num_tasks
|
||||||
|
i += 1
|
||||||
|
|
||||||
|
assert 0 < status[1] <= num_concurrent
|
||||||
|
assert quit_on <= status[2] <= quit_on + num_concurrent
|
||||||
|
assert num_failed <= num_fail
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user