diff --git a/docs/framework_utils.rst b/docs/framework_utils.rst index 9be4d44fc..e3ff54688 100644 --- a/docs/framework_utils.rst +++ b/docs/framework_utils.rst @@ -4,6 +4,12 @@ Utility Functions ================= +General Utility +=============== + +.. automodule:: redbot.core.utils + :members: deduplicate_iterables, bounded_gather, bounded_gather_iter + Chat Formatting =============== diff --git a/redbot/core/cog_manager.py b/redbot/core/cog_manager.py index c6791a7a0..229321512 100644 --- a/redbot/core/cog_manager.py +++ b/redbot/core/cog_manager.py @@ -3,9 +3,10 @@ import pkgutil from importlib import import_module, invalidate_caches from importlib.machinery import ModuleSpec from pathlib import Path -from typing import Tuple, Union, List +from typing import Tuple, Union, List, Optional import redbot.cogs +from redbot.core.utils import deduplicate_iterables import discord from . import checks, commands @@ -18,12 +19,13 @@ from .utils.chat_formatting import box, pagify __all__ = ["CogManager"] -def _deduplicate(xs): - ret = [] - for x in xs: - if x not in ret: - ret.append(x) - return ret +class NoSuchCog(ImportError): + """Thrown when a cog is missing. + + Different from ImportError because some ImportErrors can happen inside cogs. + """ + + pass class CogManager: @@ -56,7 +58,7 @@ class CogManager: conf_paths = [Path(p) for p in await self.conf.paths()] other_paths = self._paths - all_paths = _deduplicate(list(conf_paths) + list(other_paths) + [self.CORE_PATH]) + all_paths = deduplicate_iterables(conf_paths, other_paths, [self.CORE_PATH]) if self.install_path not in all_paths: all_paths.insert(0, await self.install_path()) @@ -209,11 +211,10 @@ class CogManager: Raises ------ - RuntimeError - When no matching spec can be found. + NoSuchCog + When no cog with the requested name was found. """ - resolved_paths = _deduplicate(await self.paths()) - + resolved_paths = await self.paths() real_paths = [str(p) for p in resolved_paths if p != self.CORE_PATH] for finder, module_name, _ in pkgutil.iter_modules(real_paths): @@ -222,9 +223,11 @@ class CogManager: if spec: return spec - raise RuntimeError( - "No 3rd party module by the name of '{}' was found" - " in any available path.".format(name) + raise NoSuchCog( + "No 3rd party module by the name of '{}' was found in any available path.".format( + name + ), + name=name, ) @staticmethod @@ -246,16 +249,24 @@ class CogManager: When no matching spec can be found. """ real_name = ".{}".format(name) + package = "redbot.cogs" + try: - mod = import_module(real_name, package="redbot.cogs") + mod = import_module(real_name, package=package) except ImportError as e: - raise RuntimeError( - "No core cog by the name of '{}' could be found.".format(name) - ) from e + if e.name == package + real_name: + raise NoSuchCog( + "No core cog by the name of '{}' could be found.".format(name), + path=e.path, + name=e.name, + ) from e + + raise + return mod.__spec__ # noinspection PyUnreachableCode - async def find_cog(self, name: str) -> ModuleSpec: + async def find_cog(self, name: str) -> Optional[ModuleSpec]: """Find a cog in the list of available paths. Parameters @@ -265,23 +276,16 @@ class CogManager: Returns ------- - importlib.machinery.ModuleSpec - A module spec to be used for specialized cog loading. - - Raises - ------ - RuntimeError - If there is no cog with the given name. + Optional[importlib.machinery.ModuleSpec] + A module spec to be used for specialized cog loading, if found. """ - with contextlib.suppress(RuntimeError): + with contextlib.suppress(NoSuchCog): return await self._find_ext_cog(name) - with contextlib.suppress(RuntimeError): + with contextlib.suppress(NoSuchCog): return await self._find_core_cog(name) - raise RuntimeError("No cog with that name could be found.") - async def available_modules(self) -> List[str]: """Finds the names of all available modules to load. """ diff --git a/redbot/core/core_commands.py b/redbot/core/core_commands.py index fd29c7f8e..ce757d490 100644 --- a/redbot/core/core_commands.py +++ b/redbot/core/core_commands.py @@ -77,9 +77,17 @@ class CoreLogic: for name in cog_names: try: spec = await bot.cog_mgr.find_cog(name) - cogspecs.append((spec, name)) - except RuntimeError: - notfound_packages.append(name) + if spec: + cogspecs.append((spec, name)) + else: + notfound_packages.append(name) + except Exception as e: + log.exception("Package import failed", exc_info=e) + + exception_log = "Exception during import of cog\n" + exception_log += "".join(traceback.format_exception(type(e), e, e.__traceback__)) + bot._last_exception = exception_log + failed_packages.append(name) for spec, name in cogspecs: try: @@ -95,6 +103,7 @@ class CoreLogic: else: await bot.add_loaded_package(name) loaded_packages.append(name) + return loaded_packages, failed_packages, notfound_packages def _cleanup_and_refresh_modules(self, module_name: str): @@ -511,7 +520,7 @@ class Core(CoreLogic): loaded, failed, not_found = await self._load(cog_names) if loaded: - fmt = "Loaded {packs}" + fmt = "Loaded {packs}." formed = self._get_package_strings(loaded, fmt) await ctx.send(formed) diff --git a/redbot/core/utils/__init__.py b/redbot/core/utils/__init__.py index f279f46c0..d6599dce7 100644 --- a/redbot/core/utils/__init__.py +++ b/redbot/core/utils/__init__.py @@ -1,14 +1,31 @@ -__all__ = ["safe_delete", "fuzzy_command_search"] +__all__ = ["bounded_gather", "safe_delete", "fuzzy_command_search", "deduplicate_iterables"] -from pathlib import Path -import os -import shutil +import asyncio +from asyncio import as_completed, AbstractEventLoop, Semaphore +from asyncio.futures import isfuture +from itertools import chain import logging +import os +from pathlib import Path +import shutil +from typing import Any, Awaitable, Iterator, List, Optional + from redbot.core import commands from fuzzywuzzy import process + from .chat_formatting import box +# Benchmarked to be the fastest method. +def deduplicate_iterables(*iterables): + """ + Returns a list of all unique items in ``iterables``, in the order they + were first encountered. + """ + # dict insertion order is guaranteed to be preserved in 3.6+ + return list(dict.fromkeys(chain.from_iterable(iterables))) + + def fuzzy_filter(record): return record.funcName != "extractWithoutOrder" @@ -20,10 +37,13 @@ def safe_delete(pth: Path): if pth.exists(): for root, dirs, files in os.walk(str(pth)): os.chmod(root, 0o755) + for d in dirs: os.chmod(os.path.join(root, d), 0o755) + for f in files: os.chmod(os.path.join(root, f), 0o755) + shutil.rmtree(str(pth), ignore_errors=True) @@ -33,35 +53,41 @@ async def filter_commands(ctx: commands.Context, extracted: list): for i in extracted if i[1] >= 90 and not i[0].hidden + and not any([p.hidden for p in i[0].parents]) and await i[0].can_run(ctx) and all([await p.can_run(ctx) for p in i[0].parents]) - and not any([p.hidden for p in i[0].parents]) ] async def fuzzy_command_search(ctx: commands.Context, term: str): - out = "" + out = [] + if ctx.guild is not None: enabled = await ctx.bot.db.guild(ctx.guild).fuzzy() else: enabled = await ctx.bot.db.fuzzy() + if not enabled: return None + alias_cog = ctx.bot.get_cog("Alias") if alias_cog is not None: is_alias, alias = await alias_cog.is_alias(ctx.guild, term) + if is_alias: return None customcom_cog = ctx.bot.get_cog("CustomCommands") if customcom_cog is not None: cmd_obj = customcom_cog.commandobj + try: ccinfo = await cmd_obj.get(ctx.message, term) except: pass else: return None + extracted_cmds = await filter_commands( ctx, process.extract(term, ctx.bot.walk_commands(), limit=5) ) @@ -70,10 +96,101 @@ async def fuzzy_command_search(ctx: commands.Context, term: str): return None for pos, extracted in enumerate(extracted_cmds, 1): - out += "{0}. {1.prefix}{2.qualified_name}{3}\n".format( - pos, - ctx, - extracted[0], - " - {}".format(extracted[0].short_doc) if extracted[0].short_doc else "", - ) - return box(out, lang="Perhaps you wanted one of these?") + short = " - {}".format(extracted[0].short_doc) if extracted[0].short_doc else "" + out.append("{0}. {1.prefix}{2.qualified_name}{3}".format(pos, ctx, extracted[0], short)) + + return box("\n".join(out), lang="Perhaps you wanted one of these?") + + +async def _sem_wrapper(sem, task): + async with sem: + return await task + + +def bounded_gather_iter( + *coros_or_futures, + loop: Optional[AbstractEventLoop] = None, + limit: int = 4, + semaphore: Optional[Semaphore] = None, +) -> Iterator[Awaitable[Any]]: + """ + An iterator that returns tasks as they are ready, but limits the + number of tasks running at a time. + + Parameters + ---------- + *coros_or_futures + The awaitables to run in a bounded concurrent fashion. + loop : asyncio.AbstractEventLoop + The event loop to use for the semaphore and :meth:`asyncio.gather`. + limit : Optional[`int`] + The maximum number of concurrent tasks. Used when no ``semaphore`` is passed. + semaphore : Optional[:class:`asyncio.Semaphore`] + The semaphore to use for bounding tasks. If `None`, create one using ``loop`` and ``limit``. + + Raises + ------ + TypeError + When invalid parameters are passed + """ + if loop is None: + loop = asyncio.get_event_loop() + + if semaphore is None: + if not isinstance(limit, int) or limit <= 0: + raise TypeError("limit must be an int > 0") + + semaphore = Semaphore(limit, loop=loop) + + pending = [] + + for cof in coros_or_futures: + if isfuture(cof) and cof._loop is not loop: + raise ValueError("futures are tied to different event loops") + + cof = _sem_wrapper(semaphore, cof) + pending.append(cof) + + return as_completed(pending, loop=loop) + + +def bounded_gather( + *coros_or_futures, + loop: Optional[AbstractEventLoop] = None, + return_exceptions: bool = False, + limit: int = 4, + semaphore: Optional[Semaphore] = None, +) -> Awaitable[List[Any]]: + """ + A semaphore-bounded wrapper to :meth:`asyncio.gather`. + + Parameters + ---------- + *coros_or_futures + The awaitables to run in a bounded concurrent fashion. + loop : asyncio.AbstractEventLoop + The event loop to use for the semaphore and :meth:`asyncio.gather`. + return_exceptions : bool + If true, gather exceptions in the result list instead of raising. + limit : Optional[`int`] + The maximum number of concurrent tasks. Used when no ``semaphore`` is passed. + semaphore : Optional[:class:`asyncio.Semaphore`] + The semaphore to use for bounding tasks. If `None`, create one using ``loop`` and ``limit``. + + Raises + ------ + TypeError + When invalid parameters are passed + """ + if loop is None: + loop = asyncio.get_event_loop() + + if semaphore is None: + if not isinstance(limit, int) or limit <= 0: + raise TypeError("limit must be an int > 0") + + semaphore = Semaphore(limit, loop=loop) + + tasks = (_sem_wrapper(semaphore, task) for task in coros_or_futures) + + return asyncio.gather(*tasks, loop=loop, return_exceptions=return_exceptions) diff --git a/tests/core/test_utils.py b/tests/core/test_utils.py index c76b066a6..3f6056bfe 100644 --- a/tests/core/test_utils.py +++ b/tests/core/test_utils.py @@ -1,5 +1,14 @@ +import asyncio +import pytest +import random import textwrap -from redbot.core.utils import chat_formatting +import warnings +from redbot.core.utils import ( + chat_formatting, + bounded_gather, + bounded_gather_iter, + deduplicate_iterables, +) def test_bordered_symmetrical(): @@ -54,3 +63,131 @@ def test_bordered_ascii(): ) col1, col2 = ["one", "two", "three"], ["four", "five", "six"] assert chat_formatting.bordered(col1, col2, ascii_border=True) == expected + + +def test_deduplicate_iterables(): + expected = [1, 2, 3, 4, 5] + inputs = [[1, 2, 1], [3, 1, 2, 4], [5, 1, 2]] + assert deduplicate_iterables(*inputs) == expected + + +@pytest.mark.asyncio +async def test_bounded_gather(): + status = [0, 0] # num_running, max_running + + async def wait_task(i, delay, status, fail=False): + status[0] += 1 + await asyncio.sleep(delay) + status[1] = max(status) + status[0] -= 1 + + if fail: + raise RuntimeError + + return i + + num_concurrent = random.randint(2, 8) + num_tasks = random.randint(4 * num_concurrent, 5 * num_concurrent) + num_fail = random.randint(num_concurrent, num_tasks) + + tasks = [wait_task(i, random.random() / 1000, status) for i in range(num_tasks)] + tasks += [wait_task(i, random.random() / 1000, status, fail=True) for i in range(num_fail)] + + num_failed = 0 + + results = await bounded_gather(*tasks, limit=num_concurrent, return_exceptions=True) + + for i, result in enumerate(results): + if isinstance(result, RuntimeError): + num_failed += 1 + else: + assert result == i # verify original orde + assert 0 <= result < num_tasks + + assert 0 < status[1] <= num_concurrent + assert num_fail == num_failed + + +@pytest.mark.asyncio +async def test_bounded_gather_iter(): + status = [0, 0] # num_running, max_running + + async def wait_task(i, delay, status, fail=False): + status[0] += 1 + await asyncio.sleep(delay) + status[1] = max(status) + status[0] -= 1 + + if fail: + raise RuntimeError + + return i + + num_concurrent = random.randint(2, 8) + num_tasks = random.randint(4 * num_concurrent, 16 * num_concurrent) + num_fail = random.randint(num_concurrent, num_tasks) + + tasks = [wait_task(i, random.random() / 1000, status) for i in range(num_tasks)] + tasks += [wait_task(i, random.random() / 1000, status, fail=True) for i in range(num_fail)] + random.shuffle(tasks) + + num_failed = 0 + + for result in bounded_gather_iter(*tasks, limit=num_concurrent): + try: + result = await result + except RuntimeError: + num_failed += 1 + continue + + assert 0 <= result < num_tasks + + assert 0 < status[1] <= num_concurrent + assert num_fail == num_failed + + +@pytest.mark.skip(reason="spams logs with pending task warnings") +@pytest.mark.asyncio +async def test_bounded_gather_iter_cancel(): + status = [0, 0, 0] # num_running, max_running, num_ran + + async def wait_task(i, delay, status, fail=False): + status[0] += 1 + await asyncio.sleep(delay) + status[1] = max(status[:2]) + status[0] -= 1 + + if fail: + raise RuntimeError + + status[2] += 1 + return i + + num_concurrent = random.randint(2, 8) + num_tasks = random.randint(4 * num_concurrent, 16 * num_concurrent) + quit_on = random.randint(0, num_tasks) + num_fail = random.randint(num_concurrent, num_tasks) + + tasks = [wait_task(i, random.random() / 1000, status) for i in range(num_tasks)] + tasks += [wait_task(i, random.random() / 1000, status, fail=True) for i in range(num_fail)] + random.shuffle(tasks) + + num_failed = 0 + i = 0 + + for result in bounded_gather_iter(*tasks, limit=num_concurrent): + try: + result = await result + except RuntimeError: + num_failed += 1 + continue + + if i == quit_on: + break + + assert 0 <= result < num_tasks + i += 1 + + assert 0 < status[1] <= num_concurrent + assert quit_on <= status[2] <= quit_on + num_concurrent + assert num_failed <= num_fail