[Commands Module] Improve usability of type hints (#3410)

* [Commands Module] Better Typehint Support

  We now do a lot more with type hints

  - No more rexporting d.py commands submodules
  - New type aliases for GuildContext & DMContext
  - More things are typehinted

  Note: Some things are still not typed, others are still incorrectly
  typed, This is progress.

Co-authored-by: jack1142 <6032823+jack1142@users.noreply.github.com>
This commit is contained in:
Michael H 2020-01-26 17:54:39 -05:00 committed by GitHub
parent 8654924869
commit a8450580e8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 807 additions and 185 deletions

View File

@ -23,5 +23,14 @@ extend functionalities used throughout the bot, as outlined below.
.. autoclass:: redbot.core.commands.Context .. autoclass:: redbot.core.commands.Context
:members: :members:
.. autoclass:: redbot.core.commands.GuildContext
.. autoclass:: redbot.core.commands.DMContext
.. automodule:: redbot.core.commands.requires .. automodule:: redbot.core.commands.requires
:members: PrivilegeLevel, PermState, Requires :members: PrivilegeLevel, PermState, Requires
.. automodule:: redbot.core.commands.converter
:members:
:exclude-members: convert
:no-undoc-members:

View File

@ -26,6 +26,7 @@ from typing import (
from types import MappingProxyType from types import MappingProxyType
import discord import discord
from discord.ext import commands as dpy_commands
from discord.ext.commands import when_mentioned_or from discord.ext.commands import when_mentioned_or
from discord.ext.commands.bot import BotBase from discord.ext.commands.bot import BotBase
@ -60,7 +61,9 @@ def _is_submodule(parent, child):
# barely spurious warning caused by our intentional shadowing # barely spurious warning caused by our intentional shadowing
class RedBase(commands.GroupMixin, BotBase, RPCMixin): # pylint: disable=no-member class RedBase(
commands.GroupMixin, dpy_commands.bot.BotBase, RPCMixin
): # pylint: disable=no-member
"""Mixin for the main bot class. """Mixin for the main bot class.
This exists because `Red` inherits from `discord.AutoShardedClient`, which This exists because `Red` inherits from `discord.AutoShardedClient`, which
@ -163,6 +166,16 @@ class RedBase(commands.GroupMixin, BotBase, RPCMixin): # pylint: disable=no-mem
self._red_ready = asyncio.Event() self._red_ready = asyncio.Event()
self._red_before_invoke_objs: Set[PreInvokeCoroutine] = set() self._red_before_invoke_objs: Set[PreInvokeCoroutine] = set()
def get_command(self, name: str) -> Optional[commands.Command]:
com = super().get_command(name)
assert com is None or isinstance(com, commands.Command)
return com
def get_cog(self, name: str) -> Optional[commands.Cog]:
cog = super().get_cog(name)
assert cog is None or isinstance(cog, commands.Cog)
return cog
@property @property
def _before_invoke(self): # DEP-WARN def _before_invoke(self): # DEP-WARN
return self._red_before_invoke_method return self._red_before_invoke_method

View File

@ -1,7 +1,145 @@
from discord.ext.commands import * ########## SENSITIVE SECTION WARNING ###########
from .commands import * ################################################
from .context import * # Any edits of any of the exported names #
from .converter import * # may result in a breaking change. #
from .errors import * # Ensure no names are removed without warning. #
from .requires import * ################################################
from .help import *
from .commands import (
Cog as Cog,
CogMixin as CogMixin,
CogCommandMixin as CogCommandMixin,
CogGroupMixin as CogGroupMixin,
Command as Command,
Group as Group,
GroupMixin as GroupMixin,
command as command,
group as group,
RESERVED_COMMAND_NAMES as RESERVED_COMMAND_NAMES,
)
from .context import Context as Context, GuildContext as GuildContext, DMContext as DMContext
from .converter import (
APIToken as APIToken,
DictConverter as DictConverter,
GuildConverter as GuildConverter,
TimedeltaConverter as TimedeltaConverter,
get_dict_converter as get_dict_converter,
get_timedelta_converter as get_timedelta_converter,
parse_timedelta as parse_timedelta,
NoParseOptional as NoParseOptional,
UserInputOptional as UserInputOptional,
Literal as Literal,
)
from .errors import (
ConversionFailure as ConversionFailure,
BotMissingPermissions as BotMissingPermissions,
UserFeedbackCheckFailure as UserFeedbackCheckFailure,
ArgParserFailure as ArgParserFailure,
)
from .help import (
red_help as red_help,
RedHelpFormatter as RedHelpFormatter,
HelpSettings as HelpSettings,
)
from .requires import (
CheckPredicate as CheckPredicate,
DM_PERMS as DM_PERMS,
GlobalPermissionModel as GlobalPermissionModel,
GuildPermissionModel as GuildPermissionModel,
PermissionModel as PermissionModel,
PrivilegeLevel as PrivilegeLevel,
PermState as PermState,
Requires as Requires,
permissions_check as permissions_check,
bot_has_permissions as bot_has_permissions,
has_permissions as has_permissions,
has_guild_permissions as has_guild_permissions,
is_owner as is_owner,
guildowner as guildowner,
guildowner_or_permissions as guildowner_or_permissions,
admin as admin,
admin_or_permissions as admin_or_permissions,
mod as mod,
mod_or_permissions as mod_or_permissions,
)
from ._dpy_reimplements import (
check as check,
guild_only as guild_only,
cooldown as cooldown,
dm_only as dm_only,
is_nsfw as is_nsfw,
has_role as has_role,
has_any_role as has_any_role,
bot_has_role as bot_has_role,
when_mentioned_or as when_mentioned_or,
when_mentioned as when_mentioned,
bot_has_any_role as bot_has_any_role,
)
### DEP-WARN: Check this *every* discord.py update
from discord.ext.commands import (
BadArgument as BadArgument,
EmojiConverter as EmojiConverter,
InvalidEndOfQuotedStringError as InvalidEndOfQuotedStringError,
MemberConverter as MemberConverter,
BotMissingRole as BotMissingRole,
PrivateMessageOnly as PrivateMessageOnly,
HelpCommand as HelpCommand,
MinimalHelpCommand as MinimalHelpCommand,
DisabledCommand as DisabledCommand,
ExtensionFailed as ExtensionFailed,
Bot as Bot,
NotOwner as NotOwner,
CategoryChannelConverter as CategoryChannelConverter,
CogMeta as CogMeta,
ConversionError as ConversionError,
UserInputError as UserInputError,
Converter as Converter,
InviteConverter as InviteConverter,
ExtensionError as ExtensionError,
Cooldown as Cooldown,
CheckFailure as CheckFailure,
MessageConverter as MessageConverter,
MissingPermissions as MissingPermissions,
BadUnionArgument as BadUnionArgument,
DefaultHelpCommand as DefaultHelpCommand,
ExtensionNotFound as ExtensionNotFound,
UserConverter as UserConverter,
MissingRole as MissingRole,
CommandOnCooldown as CommandOnCooldown,
MissingAnyRole as MissingAnyRole,
ExtensionNotLoaded as ExtensionNotLoaded,
clean_content as clean_content,
CooldownMapping as CooldownMapping,
ArgumentParsingError as ArgumentParsingError,
RoleConverter as RoleConverter,
CommandError as CommandError,
TextChannelConverter as TextChannelConverter,
UnexpectedQuoteError as UnexpectedQuoteError,
Paginator as Paginator,
BucketType as BucketType,
NoEntryPointError as NoEntryPointError,
CommandInvokeError as CommandInvokeError,
TooManyArguments as TooManyArguments,
Greedy as Greedy,
ExpectedClosingQuoteError as ExpectedClosingQuoteError,
ColourConverter as ColourConverter,
VoiceChannelConverter as VoiceChannelConverter,
NSFWChannelRequired as NSFWChannelRequired,
IDConverter as IDConverter,
MissingRequiredArgument as MissingRequiredArgument,
GameConverter as GameConverter,
CommandNotFound as CommandNotFound,
BotMissingAnyRole as BotMissingAnyRole,
NoPrivateMessage as NoPrivateMessage,
AutoShardedBot as AutoShardedBot,
ExtensionAlreadyLoaded as ExtensionAlreadyLoaded,
PartialEmojiConverter as PartialEmojiConverter,
check_any as check_any,
max_concurrency as max_concurrency,
CheckAnyFailure as CheckAnyFailure,
MaxConcurrency as MaxConcurrency,
MaxConcurrencyReached as MaxConcurrencyReached,
bot_has_guild_permissions as bot_has_guild_permissions,
)

View File

@ -0,0 +1,126 @@
from __future__ import annotations
import inspect
import functools
from typing import (
TypeVar,
Callable,
Awaitable,
Coroutine,
Union,
Type,
TYPE_CHECKING,
List,
Any,
Generator,
Protocol,
overload,
)
import discord
from discord.ext import commands as dpy_commands
# So much of this can be stripped right back out with proper stubs.
if not TYPE_CHECKING:
from discord.ext.commands import (
check as check,
guild_only as guild_only,
dm_only as dm_only,
is_nsfw as is_nsfw,
has_role as has_role,
has_any_role as has_any_role,
bot_has_role as bot_has_role,
bot_has_any_role as bot_has_any_role,
cooldown as cooldown,
)
from ..i18n import Translator
from .context import Context
from .commands import Command
_ = Translator("nah", __file__)
"""
Anything here is either a reimplementation or re-export
of a discord.py funtion or class with more lies for mypy
"""
__all__ = [
"check",
# "check_any", # discord.py 1.3
"guild_only",
"dm_only",
"is_nsfw",
"has_role",
"has_any_role",
"bot_has_role",
"bot_has_any_role",
"when_mentioned_or",
"cooldown",
"when_mentioned",
]
_CT = TypeVar("_CT", bound=Context)
_T = TypeVar("_T")
_F = TypeVar("_F")
CheckType = Union[Callable[[_CT], bool], Callable[[_CT], Coroutine[Any, Any, bool]]]
CoroLike = Callable[..., Union[Awaitable[_T], Generator[Any, None, _T]]]
class CheckDecorator(Protocol):
predicate: Coroutine[Any, Any, bool]
@overload
def __call__(self, func: _CT) -> _CT:
...
@overload
def __call__(self, func: CoroLike) -> CoroLike:
...
if TYPE_CHECKING:
def check(predicate: CheckType) -> CheckDecorator:
...
def guild_only() -> CheckDecorator:
...
def dm_only() -> CheckDecorator:
...
def is_nsfw() -> CheckDecorator:
...
def has_role() -> CheckDecorator:
...
def has_any_role() -> CheckDecorator:
...
def bot_has_role() -> CheckDecorator:
...
def bot_has_any_role() -> CheckDecorator:
...
def cooldown(rate: int, per: float, type: dpy_commands.BucketType = ...) -> Callable[[_F], _F]:
...
PrefixCallable = Callable[[dpy_commands.bot.BotBase, discord.Message], List[str]]
def when_mentioned(bot: dpy_commands.bot.BotBase, msg: discord.Message) -> List[str]:
return [f"<@{bot.user.id}> ", f"<@!{bot.user.id}> "]
def when_mentioned_or(*prefixes) -> PrefixCallable:
def inner(bot: dpy_commands.bot.BotBase, msg: discord.Message) -> List[str]:
r = list(prefixes)
r = when_mentioned(bot, msg) + r
return r
return inner

View File

@ -1,24 +1,53 @@
"""Module for command helpers and classes. """Module for command helpers and classes.
This module contains extended classes and functions which are intended to This module contains extended classes and functions which are intended to
replace those from the `discord.ext.commands` module. be used instead of those from the `discord.ext.commands` module.
""" """
from __future__ import annotations
import inspect import inspect
import re import re
import weakref import weakref
from typing import Awaitable, Callable, Dict, List, Optional, Tuple, Union, TYPE_CHECKING from typing import (
Awaitable,
Callable,
Coroutine,
TypeVar,
Type,
Dict,
List,
Optional,
Tuple,
Union,
MutableMapping,
TYPE_CHECKING,
cast,
)
import discord import discord
from discord.ext import commands from discord.ext.commands import (
BadArgument,
CommandError,
CheckFailure,
DisabledCommand,
command as dpy_command_deco,
Command as DPYCommand,
Cog as DPYCog,
CogMeta as DPYCogMeta,
Group as DPYGroup,
Greedy,
)
from . import converter as converters from . import converter as converters
from .errors import ConversionFailure from .errors import ConversionFailure
from .requires import PermState, PrivilegeLevel, Requires from .requires import PermState, PrivilegeLevel, Requires, PermStateAllowedStates
from ..i18n import Translator from ..i18n import Translator
if TYPE_CHECKING: if TYPE_CHECKING:
# circular import avoidance
from .context import Context from .context import Context
__all__ = [ __all__ = [
"Cog", "Cog",
"CogMixin", "CogMixin",
@ -38,11 +67,17 @@ RESERVED_COMMAND_NAMES = (
) )
_ = Translator("commands.commands", __file__) _ = Translator("commands.commands", __file__)
DisablerDictType = MutableMapping[discord.Guild, Callable[["Context"], Awaitable[bool]]]
class CogCommandMixin: class CogCommandMixin:
"""A mixin for cogs and commands.""" """A mixin for cogs and commands."""
@property
def help(self) -> str:
"""To be defined by subclasses"""
...
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
if isinstance(self, Command): if isinstance(self, Command):
@ -182,7 +217,7 @@ class CogCommandMixin:
self.deny_to(Requires.DEFAULT, guild_id=guild_id) self.deny_to(Requires.DEFAULT, guild_id=guild_id)
class Command(CogCommandMixin, commands.Command): class Command(CogCommandMixin, DPYCommand):
"""Command class for Red. """Command class for Red.
This should not be created directly, and instead via the decorator. This should not be created directly, and instead via the decorator.
@ -198,7 +233,10 @@ class Command(CogCommandMixin, commands.Command):
`Requires.checks`. `Requires.checks`.
translator : Translator translator : Translator
A translator for this command's help docstring. A translator for this command's help docstring.
ignore_optional_for_conversion : bool
A value which can be set to not have discord.py's
argument parsing behavior for ``typing.Optional``
(type used will be of the inner type instead)
""" """
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
@ -209,6 +247,7 @@ class Command(CogCommandMixin, commands.Command):
return self.callback(*args, **kwargs) return self.callback(*args, **kwargs)
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
self.ignore_optional_for_conversion = kwargs.pop("ignore_optional_for_conversion", False)
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
self._help_override = kwargs.pop("help_override", None) self._help_override = kwargs.pop("help_override", None)
self.translator = kwargs.pop("i18n", None) self.translator = kwargs.pop("i18n", None)
@ -229,8 +268,62 @@ class Command(CogCommandMixin, commands.Command):
# Red specific # Red specific
other.requires = self.requires other.requires = self.requires
other.ignore_optional_for_conversion = self.ignore_optional_for_conversion
return other return other
@property
def callback(self):
return self._callback
@callback.setter
def callback(self, function):
"""
Below should be mostly the same as discord.py
The only (current) change is to filter out typing.Optional
if a user has specified the desire for this behavior
"""
self._callback = function
self.module = function.__module__
signature = inspect.signature(function)
self.params = signature.parameters.copy()
# PEP-563 allows postponing evaluation of annotations with a __future__
# import. When postponed, Parameter.annotation will be a string and must
# be replaced with the real value for the converters to work later on
for key, value in self.params.items():
if isinstance(value.annotation, str):
self.params[key] = value = value.replace(
annotation=eval(value.annotation, function.__globals__)
)
# fail early for when someone passes an unparameterized Greedy type
if value.annotation is Greedy:
raise TypeError("Unparameterized Greedy[...] is disallowed in signature.")
if not self.ignore_optional_for_conversion:
continue # reduces indentation compared to alternative
try:
vtype = value.annotation.__origin__
if vtype is Union:
_NoneType = type if TYPE_CHECKING else type(None)
args = value.annotation.__args__
if _NoneType in args:
args = tuple(a for a in args if a is not _NoneType)
if len(args) == 1:
# can't have a union of 1 or 0 items
# 1 prevents this from becoming 0
# we need to prevent 2 become 1
# (Don't change that to becoming, it's intentional :musical_note:)
self.params[key] = value = value.replace(annotation=args[0])
else:
# and mypy wretches at the correct Union[args]
temp_type = type if TYPE_CHECKING else Union[args]
self.params[key] = value = value.replace(annotation=temp_type)
except AttributeError:
continue
@property @property
def help(self): def help(self):
"""Help string for this command. """Help string for this command.
@ -311,7 +404,7 @@ class Command(CogCommandMixin, commands.Command):
for parent in reversed(self.parents): for parent in reversed(self.parents):
try: try:
result = await parent.can_run(ctx, change_permission_state=True) result = await parent.can_run(ctx, change_permission_state=True)
except commands.CommandError: except CommandError:
result = False result = False
if result is False: if result is False:
@ -334,12 +427,10 @@ class Command(CogCommandMixin, commands.Command):
ctx.command = self ctx.command = self
if not self.enabled: if not self.enabled:
raise commands.DisabledCommand(f"{self.name} command is disabled") raise DisabledCommand(f"{self.name} command is disabled")
if not await self.can_run(ctx, change_permission_state=True): if not await self.can_run(ctx, change_permission_state=True):
raise commands.CheckFailure( raise CheckFailure(f"The check functions for command {self.qualified_name} failed.")
f"The check functions for command {self.qualified_name} failed."
)
if self.cooldown_after_parsing: if self.cooldown_after_parsing:
await self._parse_arguments(ctx) await self._parse_arguments(ctx)
@ -373,7 +464,7 @@ class Command(CogCommandMixin, commands.Command):
try: try:
return await super().do_conversion(ctx, converter, argument, param) return await super().do_conversion(ctx, converter, argument, param)
except commands.BadArgument as exc: except BadArgument as exc:
raise ConversionFailure(converter, argument, param, *exc.args) from exc raise ConversionFailure(converter, argument, param, *exc.args) from exc
except ValueError as exc: except ValueError as exc:
# Some common converters need special treatment... # Some common converters need special treatment...
@ -408,7 +499,7 @@ class Command(CogCommandMixin, commands.Command):
can_run = await self.can_run( can_run = await self.can_run(
ctx, check_all_parents=True, change_permission_state=False ctx, check_all_parents=True, change_permission_state=False
) )
except (commands.CheckFailure, commands.errors.DisabledCommand): except (CheckFailure, DisabledCommand):
return False return False
else: else:
if can_run is False: if can_run is False:
@ -564,10 +655,9 @@ class GroupMixin(discord.ext.commands.GroupMixin):
class CogGroupMixin: class CogGroupMixin:
requires: Requires requires: Requires
all_commands: Dict[str, Command]
def reevaluate_rules_for( def reevaluate_rules_for(
self, model_id: Union[str, int], guild_id: Optional[int] self, model_id: Union[str, int], guild_id: int = 0
) -> Tuple[PermState, bool]: ) -> Tuple[PermState, bool]:
"""Re-evaluate a rule by checking subcommand rules. """Re-evaluate a rule by checking subcommand rules.
@ -590,15 +680,16 @@ class CogGroupMixin:
""" """
cur_rule = self.requires.get_rule(model_id, guild_id=guild_id) cur_rule = self.requires.get_rule(model_id, guild_id=guild_id)
if cur_rule in (PermState.NORMAL, PermState.ACTIVE_ALLOW, PermState.ACTIVE_DENY): if cur_rule not in (PermState.NORMAL, PermState.ACTIVE_ALLOW, PermState.ACTIVE_DENY):
# These three states are unaffected by subcommand rules # The above three states are unaffected by subcommand rules
return cur_rule, False
else:
# Remaining states can be changed if there exists no actively-allowed # Remaining states can be changed if there exists no actively-allowed
# subcommand (this includes subcommands multiple levels below) # subcommand (this includes subcommands multiple levels below)
all_commands: Dict[str, Command] = getattr(self, "all_commands", {})
if any( if any(
cmd.requires.get_rule(model_id, guild_id=guild_id) in PermState.ALLOWED_STATES cmd.requires.get_rule(model_id, guild_id=guild_id) in PermStateAllowedStates
for cmd in self.all_commands.values() for cmd in all_commands.values()
): ):
return cur_rule, False return cur_rule, False
elif cur_rule is PermState.PASSIVE_ALLOW: elif cur_rule is PermState.PASSIVE_ALLOW:
@ -608,8 +699,11 @@ class CogGroupMixin:
self.requires.set_rule(model_id, PermState.ACTIVE_DENY, guild_id=guild_id) self.requires.set_rule(model_id, PermState.ACTIVE_DENY, guild_id=guild_id)
return PermState.ACTIVE_DENY, True return PermState.ACTIVE_DENY, True
# Default return value
return cur_rule, False
class Group(GroupMixin, Command, CogGroupMixin, commands.Group):
class Group(GroupMixin, Command, CogGroupMixin, DPYGroup):
"""Group command class for Red. """Group command class for Red.
This class inherits from `Command`, with :class:`GroupMixin` and This class inherits from `Command`, with :class:`GroupMixin` and
@ -653,14 +747,6 @@ class Group(GroupMixin, Command, CogGroupMixin, commands.Group):
class CogMixin(CogGroupMixin, CogCommandMixin): class CogMixin(CogGroupMixin, CogCommandMixin):
"""Mixin class for a cog, intended for use with discord.py's cog class""" """Mixin class for a cog, intended for use with discord.py's cog class"""
@property
def all_commands(self) -> Dict[str, Command]:
"""
This does not have identical behavior to
Group.all_commands but should return what you expect
"""
return {cmd.name: cmd for cmd in self.__cog_commands__}
@property @property
def help(self): def help(self):
doc = self.__doc__ doc = self.__doc__
@ -689,7 +775,7 @@ class CogMixin(CogGroupMixin, CogCommandMixin):
try: try:
can_run = await self.requires.verify(ctx) can_run = await self.requires.verify(ctx)
except commands.CommandError: except CommandError:
return False return False
return can_run return can_run
@ -718,16 +804,22 @@ class CogMixin(CogGroupMixin, CogCommandMixin):
return await self.can_run(ctx) return await self.can_run(ctx)
class Cog(CogMixin, commands.Cog): class Cog(CogMixin, DPYCog, metaclass=DPYCogMeta):
""" """
Red's Cog base class Red's Cog base class
This includes a metaclass from discord.py This includes a metaclass from discord.py
""" """
# NB: Do not move the inheritcance of this. Keeping the mix of that metaclass __cog_commands__: Tuple[Command]
# seperate gives us more freedoms in several places.
pass @property
def all_commands(self) -> Dict[str, Command]:
"""
This does not have identical behavior to
Group.all_commands but should return what you expect
"""
return {cmd.name: cmd for cmd in self.__cog_commands__}
def command(name=None, cls=Command, **attrs): def command(name=None, cls=Command, **attrs):
@ -736,7 +828,8 @@ def command(name=None, cls=Command, **attrs):
Same interface as `discord.ext.commands.command`. Same interface as `discord.ext.commands.command`.
""" """
attrs["help_override"] = attrs.pop("help", None) attrs["help_override"] = attrs.pop("help", None)
return commands.command(name, cls, **attrs)
return dpy_command_deco(name, cls, **attrs)
def group(name=None, cls=Group, **attrs): def group(name=None, cls=Group, **attrs):
@ -744,10 +837,10 @@ def group(name=None, cls=Group, **attrs):
Same interface as `discord.ext.commands.group`. Same interface as `discord.ext.commands.group`.
""" """
return command(name, cls, **attrs) return dpy_command_deco(name, cls, **attrs)
__command_disablers = weakref.WeakValueDictionary() __command_disablers: DisablerDictType = weakref.WeakValueDictionary()
def get_command_disabler(guild: discord.Guild) -> Callable[["Context"], Awaitable[bool]]: def get_command_disabler(guild: discord.Guild) -> Callable[["Context"], Awaitable[bool]]:
@ -762,7 +855,7 @@ def get_command_disabler(guild: discord.Guild) -> Callable[["Context"], Awaitabl
async def disabler(ctx: "Context") -> bool: async def disabler(ctx: "Context") -> bool:
if ctx.guild == guild: if ctx.guild == guild:
raise commands.DisabledCommand() raise DisabledCommand()
return True return True
__command_disablers[guild] = disabler __command_disablers[guild] = disabler

View File

@ -1,21 +1,28 @@
from __future__ import annotations
import asyncio import asyncio
import contextlib import contextlib
import os
import re import re
from typing import Iterable, List, Union from typing import Iterable, List, Union, Optional, TYPE_CHECKING
import discord import discord
from discord.ext import commands from discord.ext.commands import Context as DPYContext
from .requires import PermState from .requires import PermState
from ..utils.chat_formatting import box from ..utils.chat_formatting import box
from ..utils.predicates import MessagePredicate from ..utils.predicates import MessagePredicate
from ..utils import common_filters from ..utils import common_filters
if TYPE_CHECKING:
from .commands import Command
from ..bot import Red
TICK = "\N{WHITE HEAVY CHECK MARK}" TICK = "\N{WHITE HEAVY CHECK MARK}"
__all__ = ["Context"] __all__ = ["Context", "GuildContext", "DMContext"]
class Context(commands.Context): class Context(DPYContext):
"""Command invocation context for Red. """Command invocation context for Red.
All context passed into commands will be of this type. All context passed into commands will be of this type.
@ -40,6 +47,10 @@ class Context(commands.Context):
The permission state the current context is in. The permission state the current context is in.
""" """
command: "Command"
invoked_subcommand: "Optional[Command]"
bot: "Red"
def __init__(self, **attrs): def __init__(self, **attrs):
self.assume_yes = attrs.pop("assume_yes", False) self.assume_yes = attrs.pop("assume_yes", False)
super().__init__(**attrs) super().__init__(**attrs)
@ -254,7 +265,7 @@ class Context(commands.Context):
return pattern.sub(f"@{me.display_name}", self.prefix) return pattern.sub(f"@{me.display_name}", self.prefix)
@property @property
def me(self) -> discord.abc.User: def me(self) -> Union[discord.ClientUser, discord.Member]:
"""discord.abc.User: The bot member or user object. """discord.abc.User: The bot member or user object.
If the context is DM, this will be a `discord.User` object. If the context is DM, this will be a `discord.User` object.
@ -263,3 +274,63 @@ class Context(commands.Context):
return self.guild.me return self.guild.me
else: else:
return self.bot.user return self.bot.user
if TYPE_CHECKING or os.getenv("BUILDING_DOCS", False):
class DMContext(Context):
"""
At runtime, this will still be a normal context object.
This lies about some type narrowing for type analysis in commands
using a dm_only decorator.
It is only correct to use when those types are already narrowed
"""
@property
def author(self) -> discord.User:
...
@property
def channel(self) -> discord.DMChannel:
...
@property
def guild(self) -> None:
...
@property
def me(self) -> discord.ClientUser:
...
class GuildContext(Context):
"""
At runtime, this will still be a normal context object.
This lies about some type narrowing for type analysis in commands
using a guild_only decorator.
It is only correct to use when those types are already narrowed
"""
@property
def author(self) -> discord.Member:
...
@property
def channel(self) -> discord.TextChannel:
...
@property
def guild(self) -> discord.Guild:
...
@property
def me(self) -> discord.Member:
...
else:
GuildContext = Context
DMContext = Context

View File

@ -1,14 +1,33 @@
"""
commands.converter
==================
This module contains useful functions and classes for command argument conversion.
Some of the converters within are included provisionaly and are marked as such.
"""
import os
import re import re
import functools import functools
from datetime import timedelta from datetime import timedelta
from typing import TYPE_CHECKING, Optional, List, Dict from typing import (
TYPE_CHECKING,
Generic,
Optional,
Optional as NoParseOptional,
Tuple,
List,
Dict,
Type,
TypeVar,
Literal as Literal,
)
import discord import discord
from discord.ext import commands as dpy_commands from discord.ext import commands as dpy_commands
from discord.ext.commands import BadArgument
from . import BadArgument
from ..i18n import Translator from ..i18n import Translator
from ..utils.chat_formatting import humanize_timedelta from ..utils.chat_formatting import humanize_timedelta, humanize_list
if TYPE_CHECKING: if TYPE_CHECKING:
from .context import Context from .context import Context
@ -17,10 +36,13 @@ __all__ = [
"APIToken", "APIToken",
"DictConverter", "DictConverter",
"GuildConverter", "GuildConverter",
"UserInputOptional",
"NoParseOptional",
"TimedeltaConverter", "TimedeltaConverter",
"get_dict_converter", "get_dict_converter",
"get_timedelta_converter", "get_timedelta_converter",
"parse_timedelta", "parse_timedelta",
"Literal",
] ]
_ = Translator("commands.converter", __file__) _ = Translator("commands.converter", __file__)
@ -67,7 +89,7 @@ def parse_timedelta(
allowed_units : Optional[List[str]] allowed_units : Optional[List[str]]
If provided, you can constrain a user to expressing the amount of time If provided, you can constrain a user to expressing the amount of time
in specific units. The units you can chose to provide are the same as the in specific units. The units you can chose to provide are the same as the
parser understands. `weeks` `days` `hours` `minutes` `seconds` parser understands. (``weeks``, ``days``, ``hours``, ``minutes``, ``seconds``)
Returns Returns
------- -------
@ -138,17 +160,18 @@ class APIToken(discord.ext.commands.Converter):
This will parse the input argument separating the key value pairs into a This will parse the input argument separating the key value pairs into a
format to be used for the core bots API token storage. format to be used for the core bots API token storage.
This will split the argument by either `;` ` `, or `,` and return a dict This will split the argument by a space, comma, or semicolon and return a dict
to be stored. Since all API's are different and have different naming convention, to be stored. Since all API's are different and have different naming convention,
this leaves the onus on the cog creator to clearly define how to setup the correct this leaves the onus on the cog creator to clearly define how to setup the correct
credential names for their cogs. credential names for their cogs.
Note: Core usage of this has been replaced with DictConverter use instead. Note: Core usage of this has been replaced with `DictConverter` use instead.
This may be removed at a later date (with warning) .. warning::
This will be removed in version 3.4.
""" """
async def convert(self, ctx, argument) -> dict: async def convert(self, ctx: "Context", argument) -> dict:
bot = ctx.bot bot = ctx.bot
result = {} result = {}
match = re.split(r";|,| ", argument) match = re.split(r";|,| ", argument)
@ -162,6 +185,15 @@ class APIToken(discord.ext.commands.Converter):
return result return result
# Below this line are a lot of lies for mypy about things that *end up* correct when
# These are used for command conversion purposes. Please refer to the portion
# which is *not* for type checking for the actual implementation
# and ensure the lies stay correct for how the object should look as a typehint
if TYPE_CHECKING:
DictConverter = Dict[str, str]
else:
class DictConverter(dpy_commands.Converter): class DictConverter(dpy_commands.Converter):
""" """
Converts pairs of space seperated values to a dict Converts pairs of space seperated values to a dict
@ -173,7 +205,6 @@ class DictConverter(dpy_commands.Converter):
self.pattern = re.compile(r"|".join(re.escape(d) for d in self.delims)) self.pattern = re.compile(r"|".join(re.escape(d) for d in self.delims))
async def convert(self, ctx: "Context", argument: str) -> Dict[str, str]: async def convert(self, ctx: "Context", argument: str) -> Dict[str, str]:
ret: Dict[str, str] = {} ret: Dict[str, str] = {}
args = self.pattern.split(argument) args = self.pattern.split(argument)
@ -191,12 +222,20 @@ class DictConverter(dpy_commands.Converter):
return ret return ret
def get_dict_converter(*expected_keys: str, delims: Optional[List[str]] = None) -> type: if TYPE_CHECKING:
def get_dict_converter(*expected_keys: str, delims: Optional[List[str]] = None) -> Type[dict]:
...
else:
def get_dict_converter(*expected_keys: str, delims: Optional[List[str]] = None) -> Type[dict]:
""" """
Returns a typechecking safe `DictConverter` suitable for use with discord.py Returns a typechecking safe `DictConverter` suitable for use with discord.py
""" """
class PartialMeta(type(DictConverter)): class PartialMeta(type):
__call__ = functools.partialmethod( __call__ = functools.partialmethod(
type(DictConverter).__call__, *expected_keys, delims=delims type(DictConverter).__call__, *expected_keys, delims=delims
) )
@ -207,6 +246,10 @@ def get_dict_converter(*expected_keys: str, delims: Optional[List[str]] = None)
return ValidatedConverter return ValidatedConverter
if TYPE_CHECKING:
TimedeltaConverter = timedelta
else:
class TimedeltaConverter(dpy_commands.Converter): class TimedeltaConverter(dpy_commands.Converter):
""" """
This is a converter for timedeltas. This is a converter for timedeltas.
@ -223,11 +266,11 @@ class TimedeltaConverter(dpy_commands.Converter):
If provided, any parsed value lower than this will raise an exception If provided, any parsed value lower than this will raise an exception
allowed_units : Optional[List[str]] allowed_units : Optional[List[str]]
If provided, you can constrain a user to expressing the amount of time If provided, you can constrain a user to expressing the amount of time
in specific units. The units you can chose to provide are the same as the in specific units. The units you can choose to provide are the same as the
parser understands: `weeks` `days` `hours` `minutes` `seconds` parser understands: (``weeks``, ``days``, ``hours``, ``minutes``, ``seconds``)
default_unit : Optional[str] default_unit : Optional[str]
If provided, it will additionally try to match integer-only input into If provided, it will additionally try to match integer-only input into
a timedelta, using the unit specified. Same units as in `allowed_units` a timedelta, using the unit specified. Same units as in ``allowed_units``
apply. apply.
""" """
@ -252,13 +295,27 @@ class TimedeltaConverter(dpy_commands.Converter):
raise BadArgument() # This allows this to be a required argument. raise BadArgument() # This allows this to be a required argument.
if TYPE_CHECKING:
def get_timedelta_converter( def get_timedelta_converter(
*, *,
default_unit: Optional[str] = None, default_unit: Optional[str] = None,
maximum: Optional[timedelta] = None, maximum: Optional[timedelta] = None,
minimum: Optional[timedelta] = None, minimum: Optional[timedelta] = None,
allowed_units: Optional[List[str]] = None, allowed_units: Optional[List[str]] = None,
) -> type: ) -> Type[timedelta]:
...
else:
def get_timedelta_converter(
*,
default_unit: Optional[str] = None,
maximum: Optional[timedelta] = None,
minimum: Optional[timedelta] = None,
allowed_units: Optional[List[str]] = None,
) -> Type[timedelta]:
""" """
This creates a type suitable for typechecking which works with discord.py's This creates a type suitable for typechecking which works with discord.py's
commands. commands.
@ -273,11 +330,11 @@ def get_timedelta_converter(
If provided, any parsed value lower than this will raise an exception If provided, any parsed value lower than this will raise an exception
allowed_units : Optional[List[str]] allowed_units : Optional[List[str]]
If provided, you can constrain a user to expressing the amount of time If provided, you can constrain a user to expressing the amount of time
in specific units. The units you can chose to provide are the same as the in specific units. The units you can choose to provide are the same as the
parser understands: `weeks` `days` `hours` `minutes` `seconds` parser understands: (``weeks``, ``days``, ``hours``, ``minutes``, ``seconds``)
default_unit : Optional[str] default_unit : Optional[str]
If provided, it will additionally try to match integer-only input into If provided, it will additionally try to match integer-only input into
a timedelta, using the unit specified. Same units as in `allowed_units` a timedelta, using the unit specified. Same units as in ``allowed_units``
apply. apply.
Returns Returns
@ -286,7 +343,7 @@ def get_timedelta_converter(
The converter class, which will be a subclass of `TimedeltaConverter` The converter class, which will be a subclass of `TimedeltaConverter`
""" """
class PartialMeta(type(TimedeltaConverter)): class PartialMeta(type):
__call__ = functools.partialmethod( __call__ = functools.partialmethod(
type(DictConverter).__call__, type(DictConverter).__call__,
allowed_units=allowed_units, allowed_units=allowed_units,
@ -299,3 +356,91 @@ def get_timedelta_converter(
pass pass
return ValidatedConverter return ValidatedConverter
if not TYPE_CHECKING:
class NoParseOptional:
"""
This can be used instead of `typing.Optional`
to avoid discord.py special casing the conversion behavior.
.. warning::
This converter class is still provisional.
.. seealso::
The `ignore_optional_for_conversion` option of commands.
"""
def __class_getitem__(cls, key):
if isinstance(key, tuple):
raise TypeError("Must only provide a single type to Optional")
return key
_T_OPT = TypeVar("_T_OPT", bound=Type)
if TYPE_CHECKING or os.getenv("BUILDING_DOCS", False):
class UserInputOptional(Generic[_T_OPT]):
"""
This can be used when user input should be converted as discord.py
treats `typing.Optional`, but the type should not be equivalent to
``typing.Union[DesiredType, None]`` for type checking.
.. warning::
This converter class is still provisional.
This class may not play well with mypy yet
and may still require you guard this in a
type checking conditional import vs the desired types
We're aware and looking into improving this.
"""
def __class_getitem__(cls, key: _T_OPT) -> _T_OPT:
if isinstance(key, tuple):
raise TypeError("Must only provide a single type to Optional")
return key
else:
UserInputOptional = Optional
if not TYPE_CHECKING:
class Literal(dpy_commands.Converter):
"""
This can be used as a converter for `typing.Literal`.
In a type checking context it is `typing.Literal`.
In a runtime context, it's a converter which only matches the literals it was given.
.. warning::
This converter class is still provisional.
"""
def __init__(self, valid_names: Tuple[str]):
self.valid_names = valid_names
def __call__(self, ctx, arg):
# Callable's are treated as valid types:
# https://github.com/python/cpython/blob/3.8/Lib/typing.py#L148
# Without this, ``typing.Union[Literal["clear"], bool]`` would fail
return self.convert(ctx, arg)
async def convert(self, ctx, arg):
if arg in self.valid_names:
return arg
raise BadArgument(_("Expected one of: {}").format(humanize_list(self.valid_names)))
def __class_getitem__(cls, k):
if not k:
raise ValueError("Need at least one value for Literal")
if isinstance(k, tuple):
return cls(k)
else:
return cls((k,))

View File

@ -8,6 +8,7 @@ checks like bot permissions checks.
""" """
import asyncio import asyncio
import enum import enum
import inspect
from typing import ( from typing import (
Union, Union,
Optional, Optional,
@ -45,6 +46,7 @@ __all__ = [
"permissions_check", "permissions_check",
"bot_has_permissions", "bot_has_permissions",
"has_permissions", "has_permissions",
"has_guild_permissions",
"is_owner", "is_owner",
"guildowner", "guildowner",
"guildowner_or_permissions", "guildowner_or_permissions",
@ -52,6 +54,9 @@ __all__ = [
"admin_or_permissions", "admin_or_permissions",
"mod", "mod",
"mod_or_permissions", "mod_or_permissions",
"transition_permstate_to",
"PermStateTransitions",
"PermStateAllowedStates",
] ]
_T = TypeVar("_T") _T = TypeVar("_T")
@ -182,11 +187,6 @@ class PermState(enum.Enum):
"""This command has been actively denied by a permission hook """This command has been actively denied by a permission hook
check validation doesn't need this, but is useful to developers""" check validation doesn't need this, but is useful to developers"""
def transition_to(
self, next_state: "PermState"
) -> Tuple[Optional[bool], Union["PermState", Dict[bool, "PermState"]]]:
return self.TRANSITIONS[self][next_state]
@classmethod @classmethod
def from_bool(cls, value: Optional[bool]) -> "PermState": def from_bool(cls, value: Optional[bool]) -> "PermState":
"""Get a PermState from a bool or ``NoneType``.""" """Get a PermState from a bool or ``NoneType``."""
@ -211,7 +211,11 @@ class PermState(enum.Enum):
# result of the default permission checks - the transition from NORMAL # result of the default permission checks - the transition from NORMAL
# to PASSIVE_ALLOW. In this case "next state" is a dict mapping the # to PASSIVE_ALLOW. In this case "next state" is a dict mapping the
# permission check results to the actual next state. # permission check results to the actual next state.
PermState.TRANSITIONS = {
TransitionResult = Tuple[Optional[bool], Union[PermState, Dict[bool, PermState]]]
TransitionDict = Dict[PermState, Dict[PermState, TransitionResult]]
PermStateTransitions: TransitionDict = {
PermState.ACTIVE_ALLOW: { PermState.ACTIVE_ALLOW: {
PermState.ACTIVE_ALLOW: (True, PermState.ACTIVE_ALLOW), PermState.ACTIVE_ALLOW: (True, PermState.ACTIVE_ALLOW),
PermState.NORMAL: (True, PermState.ACTIVE_ALLOW), PermState.NORMAL: (True, PermState.ACTIVE_ALLOW),
@ -248,13 +252,18 @@ PermState.TRANSITIONS = {
PermState.ACTIVE_DENY: (False, PermState.ACTIVE_DENY), PermState.ACTIVE_DENY: (False, PermState.ACTIVE_DENY),
}, },
} }
PermState.ALLOWED_STATES = (
PermStateAllowedStates = (
PermState.ACTIVE_ALLOW, PermState.ACTIVE_ALLOW,
PermState.PASSIVE_ALLOW, PermState.PASSIVE_ALLOW,
PermState.CAUTIOUS_ALLOW, PermState.CAUTIOUS_ALLOW,
) )
def transition_permstate_to(prev: PermState, next_state: PermState) -> TransitionResult:
return PermStateTransitions[prev][next_state]
class Requires: class Requires:
"""This class describes the requirements for executing a specific command. """This class describes the requirements for executing a specific command.
@ -326,13 +335,13 @@ class Requires:
@staticmethod @staticmethod
def get_decorator( def get_decorator(
privilege_level: Optional[PrivilegeLevel], user_perms: Dict[str, bool] privilege_level: Optional[PrivilegeLevel], user_perms: Optional[Dict[str, bool]]
) -> Callable[["_CommandOrCoro"], "_CommandOrCoro"]: ) -> Callable[["_CommandOrCoro"], "_CommandOrCoro"]:
if not user_perms: if not user_perms:
user_perms = None user_perms = None
def decorator(func: "_CommandOrCoro") -> "_CommandOrCoro": def decorator(func: "_CommandOrCoro") -> "_CommandOrCoro":
if asyncio.iscoroutinefunction(func): if inspect.iscoroutinefunction(func):
func.__requires_privilege_level__ = privilege_level func.__requires_privilege_level__ = privilege_level
func.__requires_user_perms__ = user_perms func.__requires_user_perms__ = user_perms
else: else:
@ -341,6 +350,7 @@ class Requires:
func.requires.user_perms = None func.requires.user_perms = None
else: else:
_validate_perms_dict(user_perms) _validate_perms_dict(user_perms)
assert func.requires.user_perms is not None
func.requires.user_perms.update(**user_perms) func.requires.user_perms.update(**user_perms)
return func return func
@ -488,7 +498,7 @@ class Requires:
async def _transition_state(self, ctx: "Context") -> bool: async def _transition_state(self, ctx: "Context") -> bool:
prev_state = ctx.permission_state prev_state = ctx.permission_state
cur_state = self._get_rule_from_ctx(ctx) cur_state = self._get_rule_from_ctx(ctx)
should_invoke, next_state = prev_state.transition_to(cur_state) should_invoke, next_state = transition_permstate_to(prev_state, cur_state)
if should_invoke is None: if should_invoke is None:
# NORMAL invokation, we simply follow standard procedure # NORMAL invokation, we simply follow standard procedure
should_invoke = await self._verify_user(ctx) should_invoke = await self._verify_user(ctx)
@ -509,6 +519,7 @@ class Requires:
would_invoke = await self._verify_user(ctx) would_invoke = await self._verify_user(ctx)
next_state = next_state[would_invoke] next_state = next_state[would_invoke]
assert isinstance(next_state, PermState)
ctx.permission_state = next_state ctx.permission_state = next_state
return should_invoke return should_invoke
@ -635,6 +646,20 @@ def permissions_check(predicate: CheckPredicate):
return decorator return decorator
def has_guild_permissions(**perms):
"""Restrict the command to users with these guild permissions.
This check can be overridden by rules.
"""
_validate_perms_dict(perms)
def predicate(ctx):
return ctx.guild and ctx.author.guild_permissions >= discord.Permissions(**perms)
return permissions_check(predicate)
def bot_has_permissions(**perms: bool): def bot_has_permissions(**perms: bool):
"""Complain if the bot is missing permissions. """Complain if the bot is missing permissions.

View File

@ -979,7 +979,7 @@ class Config:
""" """
return self._get_base_group(self.CHANNEL, str(channel_id)) return self._get_base_group(self.CHANNEL, str(channel_id))
def channel(self, channel: discord.TextChannel) -> Group: def channel(self, channel: discord.abc.GuildChannel) -> Group:
"""Returns a `Group` for the given channel. """Returns a `Group` for the given channel.
This does not discriminate between text and voice channels. This does not discriminate between text and voice channels.

View File

@ -1,3 +1,5 @@
from __future__ import annotations
import re import re
from typing import Callable, ClassVar, List, Optional, Pattern, Sequence, Tuple, Union, cast from typing import Callable, ClassVar, List, Optional, Pattern, Sequence, Tuple, Union, cast