[CogManager, Utils] Handle missing cogs correctly, add some helpful algorithms (#1989)

* Handle missing cogs correctly, add some helpful algorithms

For cog loading, only show "cog not found" if the module in question was the one
that failed to import. ImportErrors within cogs will show an error as they should.

- deduplicator, benchmarked to be the fastest
- bounded gather and bounded async as_completed
- tests for all additions

* Requested changes + wrap as_completed instead

So I went source diving and realized as_completed works the way I want it to,
and I don't need to reinvent the wheel for cancelling tasks that remain
if the generator is exited early via `break`. So there's that.
This commit is contained in:
Caleb Johnson 2018-08-20 20:26:04 -05:00 committed by Toby Harradine
parent b550f38eed
commit 1329fa1b09
5 changed files with 322 additions and 49 deletions

View File

@ -4,6 +4,12 @@
Utility Functions
=================
General Utility
===============
.. automodule:: redbot.core.utils
:members: deduplicate_iterables, bounded_gather, bounded_gather_iter
Chat Formatting
===============

View File

@ -3,9 +3,10 @@ import pkgutil
from importlib import import_module, invalidate_caches
from importlib.machinery import ModuleSpec
from pathlib import Path
from typing import Tuple, Union, List
from typing import Tuple, Union, List, Optional
import redbot.cogs
from redbot.core.utils import deduplicate_iterables
import discord
from . import checks, commands
@ -18,12 +19,13 @@ from .utils.chat_formatting import box, pagify
__all__ = ["CogManager"]
def _deduplicate(xs):
    """Return a list of the unique items in ``xs``, preserving first-seen order.

    Uses a set for membership checks, making this O(n) instead of the
    quadratic list-membership approach; items must therefore be hashable
    (paths and strings, as used by the cog manager, are).
    """
    seen = set()
    ret = []
    for x in xs:
        if x not in seen:
            seen.add(x)
            ret.append(x)
    return ret
class NoSuchCog(ImportError):
    """Raised when a requested cog cannot be located.

    Kept distinct from a plain ImportError because an ImportError may also
    originate from code *inside* a cog; this one means the cog itself is
    missing.
    """
class CogManager:
@ -56,7 +58,7 @@ class CogManager:
conf_paths = [Path(p) for p in await self.conf.paths()]
other_paths = self._paths
all_paths = _deduplicate(list(conf_paths) + list(other_paths) + [self.CORE_PATH])
all_paths = deduplicate_iterables(conf_paths, other_paths, [self.CORE_PATH])
if self.install_path not in all_paths:
all_paths.insert(0, await self.install_path())
@ -209,11 +211,10 @@ class CogManager:
Raises
------
RuntimeError
When no matching spec can be found.
NoSuchCog
When no cog with the requested name was found.
"""
resolved_paths = _deduplicate(await self.paths())
resolved_paths = await self.paths()
real_paths = [str(p) for p in resolved_paths if p != self.CORE_PATH]
for finder, module_name, _ in pkgutil.iter_modules(real_paths):
@ -222,9 +223,11 @@ class CogManager:
if spec:
return spec
raise RuntimeError(
"No 3rd party module by the name of '{}' was found"
" in any available path.".format(name)
raise NoSuchCog(
"No 3rd party module by the name of '{}' was found in any available path.".format(
name
),
name=name,
)
@staticmethod
@ -246,16 +249,24 @@ class CogManager:
When no matching spec can be found.
"""
real_name = ".{}".format(name)
package = "redbot.cogs"
try:
mod = import_module(real_name, package="redbot.cogs")
mod = import_module(real_name, package=package)
except ImportError as e:
raise RuntimeError(
"No core cog by the name of '{}' could be found.".format(name)
) from e
if e.name == package + real_name:
raise NoSuchCog(
"No core cog by the name of '{}' could be found.".format(name),
path=e.path,
name=e.name,
) from e
raise
return mod.__spec__
# noinspection PyUnreachableCode
async def find_cog(self, name: str) -> ModuleSpec:
async def find_cog(self, name: str) -> Optional[ModuleSpec]:
"""Find a cog in the list of available paths.
Parameters
@ -265,23 +276,16 @@ class CogManager:
Returns
-------
importlib.machinery.ModuleSpec
A module spec to be used for specialized cog loading.
Raises
------
RuntimeError
If there is no cog with the given name.
Optional[importlib.machinery.ModuleSpec]
A module spec to be used for specialized cog loading, if found.
"""
with contextlib.suppress(RuntimeError):
with contextlib.suppress(NoSuchCog):
return await self._find_ext_cog(name)
with contextlib.suppress(RuntimeError):
with contextlib.suppress(NoSuchCog):
return await self._find_core_cog(name)
raise RuntimeError("No cog with that name could be found.")
async def available_modules(self) -> List[str]:
"""Finds the names of all available modules to load.
"""

View File

@ -77,9 +77,17 @@ class CoreLogic:
for name in cog_names:
try:
spec = await bot.cog_mgr.find_cog(name)
cogspecs.append((spec, name))
except RuntimeError:
notfound_packages.append(name)
if spec:
cogspecs.append((spec, name))
else:
notfound_packages.append(name)
except Exception as e:
log.exception("Package import failed", exc_info=e)
exception_log = "Exception during import of cog\n"
exception_log += "".join(traceback.format_exception(type(e), e, e.__traceback__))
bot._last_exception = exception_log
failed_packages.append(name)
for spec, name in cogspecs:
try:
@ -95,6 +103,7 @@ class CoreLogic:
else:
await bot.add_loaded_package(name)
loaded_packages.append(name)
return loaded_packages, failed_packages, notfound_packages
def _cleanup_and_refresh_modules(self, module_name: str):
@ -511,7 +520,7 @@ class Core(CoreLogic):
loaded, failed, not_found = await self._load(cog_names)
if loaded:
fmt = "Loaded {packs}"
fmt = "Loaded {packs}."
formed = self._get_package_strings(loaded, fmt)
await ctx.send(formed)

View File

@ -1,14 +1,31 @@
__all__ = ["safe_delete", "fuzzy_command_search"]
__all__ = ["bounded_gather", "safe_delete", "fuzzy_command_search", "deduplicate_iterables"]
from pathlib import Path
import os
import shutil
import asyncio
from asyncio import as_completed, AbstractEventLoop, Semaphore
from asyncio.futures import isfuture
from itertools import chain
import logging
import os
from pathlib import Path
import shutil
from typing import Any, Awaitable, Iterator, List, Optional
from redbot.core import commands
from fuzzywuzzy import process
from .chat_formatting import box
# Benchmarked to be the fastest method.
def deduplicate_iterables(*iterables):
    """Return every unique item from ``iterables`` in first-seen order."""
    # Dict keys preserve insertion order on 3.6+, so they behave as an
    # ordered set here; fromkeys drops duplicates in a single C-level pass.
    unique = dict.fromkeys(chain.from_iterable(iterables))
    return list(unique)
def fuzzy_filter(record):
    """Logging filter that drops fuzzywuzzy's noisy ``extractWithoutOrder`` records."""
    is_noise = record.funcName == "extractWithoutOrder"
    return not is_noise
def safe_delete(pth: Path):
    """Recursively delete ``pth``, forcing everything writable first.

    Read-only entries would otherwise make ``rmtree`` fail on some
    platforms; any remaining removal errors are ignored.
    """
    if not pth.exists():
        return
    for root, dirs, files in os.walk(str(pth)):
        os.chmod(root, 0o755)
        for entry in dirs + files:
            os.chmod(os.path.join(root, entry), 0o755)
    shutil.rmtree(str(pth), ignore_errors=True)
@ -33,35 +53,41 @@ async def filter_commands(ctx: commands.Context, extracted: list):
for i in extracted
if i[1] >= 90
and not i[0].hidden
and not any([p.hidden for p in i[0].parents])
and await i[0].can_run(ctx)
and all([await p.can_run(ctx) for p in i[0].parents])
and not any([p.hidden for p in i[0].parents])
]
async def fuzzy_command_search(ctx: commands.Context, term: str):
out = ""
out = []
if ctx.guild is not None:
enabled = await ctx.bot.db.guild(ctx.guild).fuzzy()
else:
enabled = await ctx.bot.db.fuzzy()
if not enabled:
return None
alias_cog = ctx.bot.get_cog("Alias")
if alias_cog is not None:
is_alias, alias = await alias_cog.is_alias(ctx.guild, term)
if is_alias:
return None
customcom_cog = ctx.bot.get_cog("CustomCommands")
if customcom_cog is not None:
cmd_obj = customcom_cog.commandobj
try:
ccinfo = await cmd_obj.get(ctx.message, term)
except:
pass
else:
return None
extracted_cmds = await filter_commands(
ctx, process.extract(term, ctx.bot.walk_commands(), limit=5)
)
@ -70,10 +96,101 @@ async def fuzzy_command_search(ctx: commands.Context, term: str):
return None
for pos, extracted in enumerate(extracted_cmds, 1):
out += "{0}. {1.prefix}{2.qualified_name}{3}\n".format(
pos,
ctx,
extracted[0],
" - {}".format(extracted[0].short_doc) if extracted[0].short_doc else "",
)
return box(out, lang="Perhaps you wanted one of these?")
short = " - {}".format(extracted[0].short_doc) if extracted[0].short_doc else ""
out.append("{0}. {1.prefix}{2.qualified_name}{3}".format(pos, ctx, extracted[0], short))
return box("\n".join(out), lang="Perhaps you wanted one of these?")
async def _sem_wrapper(sem, task):
    # Acquire the semaphore before awaiting the wrapped awaitable, so at
    # most ``sem``'s initial value of them can be running at once.
    async with sem:
        return await task
def bounded_gather_iter(
    *coros_or_futures,
    loop: Optional[AbstractEventLoop] = None,
    limit: int = 4,
    semaphore: Optional[Semaphore] = None,
) -> Iterator[Awaitable[Any]]:
    """
    An iterator that returns tasks as they are ready, but limits the
    number of tasks running at a time.

    Parameters
    ----------
    *coros_or_futures
        The awaitables to run in a bounded concurrent fashion.
    loop : asyncio.AbstractEventLoop
        The event loop used to check future ownership. Retained for
        backward compatibility; defaults to the current event loop.
    limit : Optional[`int`]
        The maximum number of concurrent tasks. Used when no ``semaphore`` is passed.
    semaphore : Optional[:class:`asyncio.Semaphore`]
        The semaphore to use for bounding tasks. If `None`, create one using ``limit``.

    Raises
    ------
    TypeError
        When invalid parameters are passed
    """
    if loop is None:
        loop = asyncio.get_event_loop()

    if semaphore is None:
        if not isinstance(limit, int) or limit <= 0:
            raise TypeError("limit must be an int > 0")

        # The explicit ``loop`` argument to Semaphore/as_completed was
        # deprecated in Python 3.8 and removed in 3.10; the primitives now
        # bind to the running loop, so we must not forward it.
        semaphore = Semaphore(limit)

    pending = []

    for cof in coros_or_futures:
        # Reject futures attached to a foreign loop before wrapping anything.
        if isfuture(cof) and cof._loop is not loop:
            raise ValueError("futures are tied to different event loops")

        pending.append(_sem_wrapper(semaphore, cof))

    return as_completed(pending)
def bounded_gather(
    *coros_or_futures,
    loop: Optional[AbstractEventLoop] = None,
    return_exceptions: bool = False,
    limit: int = 4,
    semaphore: Optional[Semaphore] = None,
) -> Awaitable[List[Any]]:
    """
    A semaphore-bounded wrapper to :meth:`asyncio.gather`.

    Parameters
    ----------
    *coros_or_futures
        The awaitables to run in a bounded concurrent fashion.
    loop : asyncio.AbstractEventLoop
        Retained for backward compatibility; ``gather`` and ``Semaphore``
        bind to the running event loop, so this argument is no longer used.
    return_exceptions : bool
        If true, gather exceptions in the result list instead of raising.
    limit : Optional[`int`]
        The maximum number of concurrent tasks. Used when no ``semaphore`` is passed.
    semaphore : Optional[:class:`asyncio.Semaphore`]
        The semaphore to use for bounding tasks. If `None`, create one using ``limit``.

    Raises
    ------
    TypeError
        When invalid parameters are passed
    """
    if semaphore is None:
        if not isinstance(limit, int) or limit <= 0:
            raise TypeError("limit must be an int > 0")

        # The explicit ``loop`` argument to Semaphore/gather was deprecated
        # in Python 3.8 and removed in 3.10; do not forward it.
        semaphore = Semaphore(limit)

    tasks = (_sem_wrapper(semaphore, task) for task in coros_or_futures)

    return asyncio.gather(*tasks, return_exceptions=return_exceptions)

View File

@ -1,5 +1,14 @@
import asyncio
import pytest
import random
import textwrap
from redbot.core.utils import chat_formatting
import warnings
from redbot.core.utils import (
chat_formatting,
bounded_gather,
bounded_gather_iter,
deduplicate_iterables,
)
def test_bordered_symmetrical():
@ -54,3 +63,131 @@ def test_bordered_ascii():
)
col1, col2 = ["one", "two", "three"], ["four", "five", "six"]
assert chat_formatting.bordered(col1, col2, ascii_border=True) == expected
def test_deduplicate_iterables():
    # Duplicates must be dropped while preserving first-appearance order
    # across all of the input iterables.
    sources = [[1, 2, 1], [3, 1, 2, 4], [5, 1, 2]]
    assert deduplicate_iterables(*sources) == [1, 2, 3, 4, 5]
@pytest.mark.asyncio
async def test_bounded_gather():
    # status[0] tracks currently-running tasks; status[1] records the peak.
    status = [0, 0]  # num_running, max_running

    async def wait_task(i, delay, status, fail=False):
        # Bump the running counter, note the peak concurrency, then
        # optionally fail to exercise return_exceptions handling.
        status[0] += 1
        await asyncio.sleep(delay)
        status[1] = max(status)
        status[0] -= 1

        if fail:
            raise RuntimeError

        return i

    num_concurrent = random.randint(2, 8)
    num_tasks = random.randint(4 * num_concurrent, 5 * num_concurrent)
    num_fail = random.randint(num_concurrent, num_tasks)

    tasks = [wait_task(i, random.random() / 1000, status) for i in range(num_tasks)]
    tasks += [wait_task(i, random.random() / 1000, status, fail=True) for i in range(num_fail)]

    num_failed = 0

    results = await bounded_gather(*tasks, limit=num_concurrent, return_exceptions=True)

    for i, result in enumerate(results):
        if isinstance(result, RuntimeError):
            num_failed += 1
        else:
            assert result == i  # verify original order
            assert 0 <= result < num_tasks

    # Concurrency must have been bounded by the limit, and every failing
    # task must surface as an exception in the results.
    assert 0 < status[1] <= num_concurrent
    assert num_fail == num_failed
@pytest.mark.asyncio
async def test_bounded_gather_iter():
    # status[0] tracks currently-running tasks; status[1] records the peak.
    status = [0, 0]  # num_running, max_running

    async def wait_task(i, delay, status, fail=False):
        # Bump the running counter, note the peak concurrency, then
        # optionally fail so the caller must handle exceptions per-task.
        status[0] += 1
        await asyncio.sleep(delay)
        status[1] = max(status)
        status[0] -= 1

        if fail:
            raise RuntimeError

        return i

    num_concurrent = random.randint(2, 8)
    num_tasks = random.randint(4 * num_concurrent, 16 * num_concurrent)
    num_fail = random.randint(num_concurrent, num_tasks)

    tasks = [wait_task(i, random.random() / 1000, status) for i in range(num_tasks)]
    tasks += [wait_task(i, random.random() / 1000, status, fail=True) for i in range(num_fail)]
    random.shuffle(tasks)

    num_failed = 0

    for result in bounded_gather_iter(*tasks, limit=num_concurrent):
        try:
            result = await result
        except RuntimeError:
            num_failed += 1
            continue

        # as_completed yields in completion order, so only a range check
        # on the result value is possible here.
        assert 0 <= result < num_tasks

    assert 0 < status[1] <= num_concurrent
    assert num_fail == num_failed
@pytest.mark.skip(reason="spams logs with pending task warnings")
@pytest.mark.asyncio
async def test_bounded_gather_iter_cancel():
    # status[0]: currently running, status[1]: peak concurrency,
    # status[2]: tasks that ran to successful completion.
    status = [0, 0, 0]  # num_running, max_running, num_ran

    async def wait_task(i, delay, status, fail=False):
        status[0] += 1
        await asyncio.sleep(delay)
        status[1] = max(status[:2])  # peak is over the first two slots only
        status[0] -= 1

        if fail:
            raise RuntimeError

        status[2] += 1
        return i

    num_concurrent = random.randint(2, 8)
    num_tasks = random.randint(4 * num_concurrent, 16 * num_concurrent)
    quit_on = random.randint(0, num_tasks)
    num_fail = random.randint(num_concurrent, num_tasks)

    tasks = [wait_task(i, random.random() / 1000, status) for i in range(num_tasks)]
    tasks += [wait_task(i, random.random() / 1000, status, fail=True) for i in range(num_fail)]
    random.shuffle(tasks)

    num_failed = 0
    i = 0

    for result in bounded_gather_iter(*tasks, limit=num_concurrent):
        try:
            result = await result
        except RuntimeError:
            num_failed += 1
            continue

        if i == quit_on:
            # Breaking out of the as_completed iterator should cancel the
            # tasks that have not finished yet.
            break

        assert 0 <= result < num_tasks
        i += 1

    assert 0 < status[1] <= num_concurrent
    # Up to ``limit`` in-flight tasks may still finish while breaking out.
    assert quit_on <= status[2] <= quit_on + num_concurrent
    assert num_failed <= num_fail