mirror of
https://github.com/Cog-Creators/Red-DiscordBot.git
synced 2025-11-06 11:18:54 -05:00
[Commands Module] Improve usability of type hints (#3410)
* [Commands Module] Better Typehint Support We now do a lot more with type hints - No more rexporting d.py commands submodules - New type aliases for GuildContext & DMContext - More things are typehinted Note: Some things are still not typed, others are still incorrectly typed, This is progress. Co-authored-by: jack1142 <6032823+jack1142@users.noreply.github.com>
This commit is contained in:
parent
8654924869
commit
a8450580e8
@ -23,5 +23,14 @@ extend functionalities used throughout the bot, as outlined below.
|
||||
.. autoclass:: redbot.core.commands.Context
|
||||
:members:
|
||||
|
||||
.. autoclass:: redbot.core.commands.GuildContext
|
||||
|
||||
.. autoclass:: redbot.core.commands.DMContext
|
||||
|
||||
.. automodule:: redbot.core.commands.requires
|
||||
:members: PrivilegeLevel, PermState, Requires
|
||||
|
||||
.. automodule:: redbot.core.commands.converter
|
||||
:members:
|
||||
:exclude-members: convert
|
||||
:no-undoc-members:
|
||||
|
||||
@ -26,6 +26,7 @@ from typing import (
|
||||
from types import MappingProxyType
|
||||
|
||||
import discord
|
||||
from discord.ext import commands as dpy_commands
|
||||
from discord.ext.commands import when_mentioned_or
|
||||
from discord.ext.commands.bot import BotBase
|
||||
|
||||
@ -60,7 +61,9 @@ def _is_submodule(parent, child):
|
||||
|
||||
|
||||
# barely spurious warning caused by our intentional shadowing
|
||||
class RedBase(commands.GroupMixin, BotBase, RPCMixin): # pylint: disable=no-member
|
||||
class RedBase(
|
||||
commands.GroupMixin, dpy_commands.bot.BotBase, RPCMixin
|
||||
): # pylint: disable=no-member
|
||||
"""Mixin for the main bot class.
|
||||
|
||||
This exists because `Red` inherits from `discord.AutoShardedClient`, which
|
||||
@ -163,6 +166,16 @@ class RedBase(commands.GroupMixin, BotBase, RPCMixin): # pylint: disable=no-mem
|
||||
self._red_ready = asyncio.Event()
|
||||
self._red_before_invoke_objs: Set[PreInvokeCoroutine] = set()
|
||||
|
||||
def get_command(self, name: str) -> Optional[commands.Command]:
|
||||
com = super().get_command(name)
|
||||
assert com is None or isinstance(com, commands.Command)
|
||||
return com
|
||||
|
||||
def get_cog(self, name: str) -> Optional[commands.Cog]:
|
||||
cog = super().get_cog(name)
|
||||
assert cog is None or isinstance(cog, commands.Cog)
|
||||
return cog
|
||||
|
||||
@property
|
||||
def _before_invoke(self): # DEP-WARN
|
||||
return self._red_before_invoke_method
|
||||
|
||||
@ -1,7 +1,145 @@
|
||||
from discord.ext.commands import *
|
||||
from .commands import *
|
||||
from .context import *
|
||||
from .converter import *
|
||||
from .errors import *
|
||||
from .requires import *
|
||||
from .help import *
|
||||
########## SENSITIVE SECTION WARNING ###########
|
||||
################################################
|
||||
# Any edits of any of the exported names #
|
||||
# may result in a breaking change. #
|
||||
# Ensure no names are removed without warning. #
|
||||
################################################
|
||||
|
||||
from .commands import (
|
||||
Cog as Cog,
|
||||
CogMixin as CogMixin,
|
||||
CogCommandMixin as CogCommandMixin,
|
||||
CogGroupMixin as CogGroupMixin,
|
||||
Command as Command,
|
||||
Group as Group,
|
||||
GroupMixin as GroupMixin,
|
||||
command as command,
|
||||
group as group,
|
||||
RESERVED_COMMAND_NAMES as RESERVED_COMMAND_NAMES,
|
||||
)
|
||||
from .context import Context as Context, GuildContext as GuildContext, DMContext as DMContext
|
||||
from .converter import (
|
||||
APIToken as APIToken,
|
||||
DictConverter as DictConverter,
|
||||
GuildConverter as GuildConverter,
|
||||
TimedeltaConverter as TimedeltaConverter,
|
||||
get_dict_converter as get_dict_converter,
|
||||
get_timedelta_converter as get_timedelta_converter,
|
||||
parse_timedelta as parse_timedelta,
|
||||
NoParseOptional as NoParseOptional,
|
||||
UserInputOptional as UserInputOptional,
|
||||
Literal as Literal,
|
||||
)
|
||||
from .errors import (
|
||||
ConversionFailure as ConversionFailure,
|
||||
BotMissingPermissions as BotMissingPermissions,
|
||||
UserFeedbackCheckFailure as UserFeedbackCheckFailure,
|
||||
ArgParserFailure as ArgParserFailure,
|
||||
)
|
||||
from .help import (
|
||||
red_help as red_help,
|
||||
RedHelpFormatter as RedHelpFormatter,
|
||||
HelpSettings as HelpSettings,
|
||||
)
|
||||
from .requires import (
|
||||
CheckPredicate as CheckPredicate,
|
||||
DM_PERMS as DM_PERMS,
|
||||
GlobalPermissionModel as GlobalPermissionModel,
|
||||
GuildPermissionModel as GuildPermissionModel,
|
||||
PermissionModel as PermissionModel,
|
||||
PrivilegeLevel as PrivilegeLevel,
|
||||
PermState as PermState,
|
||||
Requires as Requires,
|
||||
permissions_check as permissions_check,
|
||||
bot_has_permissions as bot_has_permissions,
|
||||
has_permissions as has_permissions,
|
||||
has_guild_permissions as has_guild_permissions,
|
||||
is_owner as is_owner,
|
||||
guildowner as guildowner,
|
||||
guildowner_or_permissions as guildowner_or_permissions,
|
||||
admin as admin,
|
||||
admin_or_permissions as admin_or_permissions,
|
||||
mod as mod,
|
||||
mod_or_permissions as mod_or_permissions,
|
||||
)
|
||||
|
||||
from ._dpy_reimplements import (
|
||||
check as check,
|
||||
guild_only as guild_only,
|
||||
cooldown as cooldown,
|
||||
dm_only as dm_only,
|
||||
is_nsfw as is_nsfw,
|
||||
has_role as has_role,
|
||||
has_any_role as has_any_role,
|
||||
bot_has_role as bot_has_role,
|
||||
when_mentioned_or as when_mentioned_or,
|
||||
when_mentioned as when_mentioned,
|
||||
bot_has_any_role as bot_has_any_role,
|
||||
)
|
||||
|
||||
### DEP-WARN: Check this *every* discord.py update
|
||||
from discord.ext.commands import (
|
||||
BadArgument as BadArgument,
|
||||
EmojiConverter as EmojiConverter,
|
||||
InvalidEndOfQuotedStringError as InvalidEndOfQuotedStringError,
|
||||
MemberConverter as MemberConverter,
|
||||
BotMissingRole as BotMissingRole,
|
||||
PrivateMessageOnly as PrivateMessageOnly,
|
||||
HelpCommand as HelpCommand,
|
||||
MinimalHelpCommand as MinimalHelpCommand,
|
||||
DisabledCommand as DisabledCommand,
|
||||
ExtensionFailed as ExtensionFailed,
|
||||
Bot as Bot,
|
||||
NotOwner as NotOwner,
|
||||
CategoryChannelConverter as CategoryChannelConverter,
|
||||
CogMeta as CogMeta,
|
||||
ConversionError as ConversionError,
|
||||
UserInputError as UserInputError,
|
||||
Converter as Converter,
|
||||
InviteConverter as InviteConverter,
|
||||
ExtensionError as ExtensionError,
|
||||
Cooldown as Cooldown,
|
||||
CheckFailure as CheckFailure,
|
||||
MessageConverter as MessageConverter,
|
||||
MissingPermissions as MissingPermissions,
|
||||
BadUnionArgument as BadUnionArgument,
|
||||
DefaultHelpCommand as DefaultHelpCommand,
|
||||
ExtensionNotFound as ExtensionNotFound,
|
||||
UserConverter as UserConverter,
|
||||
MissingRole as MissingRole,
|
||||
CommandOnCooldown as CommandOnCooldown,
|
||||
MissingAnyRole as MissingAnyRole,
|
||||
ExtensionNotLoaded as ExtensionNotLoaded,
|
||||
clean_content as clean_content,
|
||||
CooldownMapping as CooldownMapping,
|
||||
ArgumentParsingError as ArgumentParsingError,
|
||||
RoleConverter as RoleConverter,
|
||||
CommandError as CommandError,
|
||||
TextChannelConverter as TextChannelConverter,
|
||||
UnexpectedQuoteError as UnexpectedQuoteError,
|
||||
Paginator as Paginator,
|
||||
BucketType as BucketType,
|
||||
NoEntryPointError as NoEntryPointError,
|
||||
CommandInvokeError as CommandInvokeError,
|
||||
TooManyArguments as TooManyArguments,
|
||||
Greedy as Greedy,
|
||||
ExpectedClosingQuoteError as ExpectedClosingQuoteError,
|
||||
ColourConverter as ColourConverter,
|
||||
VoiceChannelConverter as VoiceChannelConverter,
|
||||
NSFWChannelRequired as NSFWChannelRequired,
|
||||
IDConverter as IDConverter,
|
||||
MissingRequiredArgument as MissingRequiredArgument,
|
||||
GameConverter as GameConverter,
|
||||
CommandNotFound as CommandNotFound,
|
||||
BotMissingAnyRole as BotMissingAnyRole,
|
||||
NoPrivateMessage as NoPrivateMessage,
|
||||
AutoShardedBot as AutoShardedBot,
|
||||
ExtensionAlreadyLoaded as ExtensionAlreadyLoaded,
|
||||
PartialEmojiConverter as PartialEmojiConverter,
|
||||
check_any as check_any,
|
||||
max_concurrency as max_concurrency,
|
||||
CheckAnyFailure as CheckAnyFailure,
|
||||
MaxConcurrency as MaxConcurrency,
|
||||
MaxConcurrencyReached as MaxConcurrencyReached,
|
||||
bot_has_guild_permissions as bot_has_guild_permissions,
|
||||
)
|
||||
|
||||
126
redbot/core/commands/_dpy_reimplements.py
Normal file
126
redbot/core/commands/_dpy_reimplements.py
Normal file
@ -0,0 +1,126 @@
|
||||
from __future__ import annotations
|
||||
import inspect
|
||||
import functools
|
||||
from typing import (
|
||||
TypeVar,
|
||||
Callable,
|
||||
Awaitable,
|
||||
Coroutine,
|
||||
Union,
|
||||
Type,
|
||||
TYPE_CHECKING,
|
||||
List,
|
||||
Any,
|
||||
Generator,
|
||||
Protocol,
|
||||
overload,
|
||||
)
|
||||
|
||||
import discord
|
||||
from discord.ext import commands as dpy_commands
|
||||
|
||||
# So much of this can be stripped right back out with proper stubs.
|
||||
if not TYPE_CHECKING:
|
||||
from discord.ext.commands import (
|
||||
check as check,
|
||||
guild_only as guild_only,
|
||||
dm_only as dm_only,
|
||||
is_nsfw as is_nsfw,
|
||||
has_role as has_role,
|
||||
has_any_role as has_any_role,
|
||||
bot_has_role as bot_has_role,
|
||||
bot_has_any_role as bot_has_any_role,
|
||||
cooldown as cooldown,
|
||||
)
|
||||
|
||||
from ..i18n import Translator
|
||||
from .context import Context
|
||||
from .commands import Command
|
||||
|
||||
|
||||
_ = Translator("nah", __file__)
|
||||
|
||||
|
||||
"""
|
||||
Anything here is either a reimplementation or re-export
|
||||
of a discord.py funtion or class with more lies for mypy
|
||||
"""
|
||||
|
||||
__all__ = [
|
||||
"check",
|
||||
# "check_any", # discord.py 1.3
|
||||
"guild_only",
|
||||
"dm_only",
|
||||
"is_nsfw",
|
||||
"has_role",
|
||||
"has_any_role",
|
||||
"bot_has_role",
|
||||
"bot_has_any_role",
|
||||
"when_mentioned_or",
|
||||
"cooldown",
|
||||
"when_mentioned",
|
||||
]
|
||||
|
||||
_CT = TypeVar("_CT", bound=Context)
|
||||
_T = TypeVar("_T")
|
||||
_F = TypeVar("_F")
|
||||
CheckType = Union[Callable[[_CT], bool], Callable[[_CT], Coroutine[Any, Any, bool]]]
|
||||
CoroLike = Callable[..., Union[Awaitable[_T], Generator[Any, None, _T]]]
|
||||
|
||||
|
||||
class CheckDecorator(Protocol):
|
||||
predicate: Coroutine[Any, Any, bool]
|
||||
|
||||
@overload
|
||||
def __call__(self, func: _CT) -> _CT:
|
||||
...
|
||||
|
||||
@overload
|
||||
def __call__(self, func: CoroLike) -> CoroLike:
|
||||
...
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
def check(predicate: CheckType) -> CheckDecorator:
|
||||
...
|
||||
|
||||
def guild_only() -> CheckDecorator:
|
||||
...
|
||||
|
||||
def dm_only() -> CheckDecorator:
|
||||
...
|
||||
|
||||
def is_nsfw() -> CheckDecorator:
|
||||
...
|
||||
|
||||
def has_role() -> CheckDecorator:
|
||||
...
|
||||
|
||||
def has_any_role() -> CheckDecorator:
|
||||
...
|
||||
|
||||
def bot_has_role() -> CheckDecorator:
|
||||
...
|
||||
|
||||
def bot_has_any_role() -> CheckDecorator:
|
||||
...
|
||||
|
||||
def cooldown(rate: int, per: float, type: dpy_commands.BucketType = ...) -> Callable[[_F], _F]:
|
||||
...
|
||||
|
||||
|
||||
PrefixCallable = Callable[[dpy_commands.bot.BotBase, discord.Message], List[str]]
|
||||
|
||||
|
||||
def when_mentioned(bot: dpy_commands.bot.BotBase, msg: discord.Message) -> List[str]:
|
||||
return [f"<@{bot.user.id}> ", f"<@!{bot.user.id}> "]
|
||||
|
||||
|
||||
def when_mentioned_or(*prefixes) -> PrefixCallable:
|
||||
def inner(bot: dpy_commands.bot.BotBase, msg: discord.Message) -> List[str]:
|
||||
r = list(prefixes)
|
||||
r = when_mentioned(bot, msg) + r
|
||||
return r
|
||||
|
||||
return inner
|
||||
@ -1,24 +1,53 @@
|
||||
"""Module for command helpers and classes.
|
||||
|
||||
This module contains extended classes and functions which are intended to
|
||||
replace those from the `discord.ext.commands` module.
|
||||
be used instead of those from the `discord.ext.commands` module.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import re
|
||||
import weakref
|
||||
from typing import Awaitable, Callable, Dict, List, Optional, Tuple, Union, TYPE_CHECKING
|
||||
from typing import (
|
||||
Awaitable,
|
||||
Callable,
|
||||
Coroutine,
|
||||
TypeVar,
|
||||
Type,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
Tuple,
|
||||
Union,
|
||||
MutableMapping,
|
||||
TYPE_CHECKING,
|
||||
cast,
|
||||
)
|
||||
|
||||
import discord
|
||||
from discord.ext import commands
|
||||
from discord.ext.commands import (
|
||||
BadArgument,
|
||||
CommandError,
|
||||
CheckFailure,
|
||||
DisabledCommand,
|
||||
command as dpy_command_deco,
|
||||
Command as DPYCommand,
|
||||
Cog as DPYCog,
|
||||
CogMeta as DPYCogMeta,
|
||||
Group as DPYGroup,
|
||||
Greedy,
|
||||
)
|
||||
|
||||
from . import converter as converters
|
||||
from .errors import ConversionFailure
|
||||
from .requires import PermState, PrivilegeLevel, Requires
|
||||
from .requires import PermState, PrivilegeLevel, Requires, PermStateAllowedStates
|
||||
from ..i18n import Translator
|
||||
|
||||
if TYPE_CHECKING:
|
||||
# circular import avoidance
|
||||
from .context import Context
|
||||
|
||||
|
||||
__all__ = [
|
||||
"Cog",
|
||||
"CogMixin",
|
||||
@ -38,11 +67,17 @@ RESERVED_COMMAND_NAMES = (
|
||||
)
|
||||
|
||||
_ = Translator("commands.commands", __file__)
|
||||
DisablerDictType = MutableMapping[discord.Guild, Callable[["Context"], Awaitable[bool]]]
|
||||
|
||||
|
||||
class CogCommandMixin:
|
||||
"""A mixin for cogs and commands."""
|
||||
|
||||
@property
|
||||
def help(self) -> str:
|
||||
"""To be defined by subclasses"""
|
||||
...
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
if isinstance(self, Command):
|
||||
@ -182,7 +217,7 @@ class CogCommandMixin:
|
||||
self.deny_to(Requires.DEFAULT, guild_id=guild_id)
|
||||
|
||||
|
||||
class Command(CogCommandMixin, commands.Command):
|
||||
class Command(CogCommandMixin, DPYCommand):
|
||||
"""Command class for Red.
|
||||
|
||||
This should not be created directly, and instead via the decorator.
|
||||
@ -198,7 +233,10 @@ class Command(CogCommandMixin, commands.Command):
|
||||
`Requires.checks`.
|
||||
translator : Translator
|
||||
A translator for this command's help docstring.
|
||||
|
||||
ignore_optional_for_conversion : bool
|
||||
A value which can be set to not have discord.py's
|
||||
argument parsing behavior for ``typing.Optional``
|
||||
(type used will be of the inner type instead)
|
||||
"""
|
||||
|
||||
def __call__(self, *args, **kwargs):
|
||||
@ -209,6 +247,7 @@ class Command(CogCommandMixin, commands.Command):
|
||||
return self.callback(*args, **kwargs)
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.ignore_optional_for_conversion = kwargs.pop("ignore_optional_for_conversion", False)
|
||||
super().__init__(*args, **kwargs)
|
||||
self._help_override = kwargs.pop("help_override", None)
|
||||
self.translator = kwargs.pop("i18n", None)
|
||||
@ -229,8 +268,62 @@ class Command(CogCommandMixin, commands.Command):
|
||||
|
||||
# Red specific
|
||||
other.requires = self.requires
|
||||
other.ignore_optional_for_conversion = self.ignore_optional_for_conversion
|
||||
return other
|
||||
|
||||
@property
|
||||
def callback(self):
|
||||
return self._callback
|
||||
|
||||
@callback.setter
|
||||
def callback(self, function):
|
||||
"""
|
||||
Below should be mostly the same as discord.py
|
||||
The only (current) change is to filter out typing.Optional
|
||||
if a user has specified the desire for this behavior
|
||||
"""
|
||||
self._callback = function
|
||||
self.module = function.__module__
|
||||
|
||||
signature = inspect.signature(function)
|
||||
self.params = signature.parameters.copy()
|
||||
|
||||
# PEP-563 allows postponing evaluation of annotations with a __future__
|
||||
# import. When postponed, Parameter.annotation will be a string and must
|
||||
# be replaced with the real value for the converters to work later on
|
||||
for key, value in self.params.items():
|
||||
if isinstance(value.annotation, str):
|
||||
self.params[key] = value = value.replace(
|
||||
annotation=eval(value.annotation, function.__globals__)
|
||||
)
|
||||
|
||||
# fail early for when someone passes an unparameterized Greedy type
|
||||
if value.annotation is Greedy:
|
||||
raise TypeError("Unparameterized Greedy[...] is disallowed in signature.")
|
||||
|
||||
if not self.ignore_optional_for_conversion:
|
||||
continue # reduces indentation compared to alternative
|
||||
|
||||
try:
|
||||
vtype = value.annotation.__origin__
|
||||
if vtype is Union:
|
||||
_NoneType = type if TYPE_CHECKING else type(None)
|
||||
args = value.annotation.__args__
|
||||
if _NoneType in args:
|
||||
args = tuple(a for a in args if a is not _NoneType)
|
||||
if len(args) == 1:
|
||||
# can't have a union of 1 or 0 items
|
||||
# 1 prevents this from becoming 0
|
||||
# we need to prevent 2 become 1
|
||||
# (Don't change that to becoming, it's intentional :musical_note:)
|
||||
self.params[key] = value = value.replace(annotation=args[0])
|
||||
else:
|
||||
# and mypy wretches at the correct Union[args]
|
||||
temp_type = type if TYPE_CHECKING else Union[args]
|
||||
self.params[key] = value = value.replace(annotation=temp_type)
|
||||
except AttributeError:
|
||||
continue
|
||||
|
||||
@property
|
||||
def help(self):
|
||||
"""Help string for this command.
|
||||
@ -311,7 +404,7 @@ class Command(CogCommandMixin, commands.Command):
|
||||
for parent in reversed(self.parents):
|
||||
try:
|
||||
result = await parent.can_run(ctx, change_permission_state=True)
|
||||
except commands.CommandError:
|
||||
except CommandError:
|
||||
result = False
|
||||
|
||||
if result is False:
|
||||
@ -334,12 +427,10 @@ class Command(CogCommandMixin, commands.Command):
|
||||
ctx.command = self
|
||||
|
||||
if not self.enabled:
|
||||
raise commands.DisabledCommand(f"{self.name} command is disabled")
|
||||
raise DisabledCommand(f"{self.name} command is disabled")
|
||||
|
||||
if not await self.can_run(ctx, change_permission_state=True):
|
||||
raise commands.CheckFailure(
|
||||
f"The check functions for command {self.qualified_name} failed."
|
||||
)
|
||||
raise CheckFailure(f"The check functions for command {self.qualified_name} failed.")
|
||||
|
||||
if self.cooldown_after_parsing:
|
||||
await self._parse_arguments(ctx)
|
||||
@ -373,7 +464,7 @@ class Command(CogCommandMixin, commands.Command):
|
||||
|
||||
try:
|
||||
return await super().do_conversion(ctx, converter, argument, param)
|
||||
except commands.BadArgument as exc:
|
||||
except BadArgument as exc:
|
||||
raise ConversionFailure(converter, argument, param, *exc.args) from exc
|
||||
except ValueError as exc:
|
||||
# Some common converters need special treatment...
|
||||
@ -408,7 +499,7 @@ class Command(CogCommandMixin, commands.Command):
|
||||
can_run = await self.can_run(
|
||||
ctx, check_all_parents=True, change_permission_state=False
|
||||
)
|
||||
except (commands.CheckFailure, commands.errors.DisabledCommand):
|
||||
except (CheckFailure, DisabledCommand):
|
||||
return False
|
||||
else:
|
||||
if can_run is False:
|
||||
@ -564,10 +655,9 @@ class GroupMixin(discord.ext.commands.GroupMixin):
|
||||
|
||||
class CogGroupMixin:
|
||||
requires: Requires
|
||||
all_commands: Dict[str, Command]
|
||||
|
||||
def reevaluate_rules_for(
|
||||
self, model_id: Union[str, int], guild_id: Optional[int]
|
||||
self, model_id: Union[str, int], guild_id: int = 0
|
||||
) -> Tuple[PermState, bool]:
|
||||
"""Re-evaluate a rule by checking subcommand rules.
|
||||
|
||||
@ -590,15 +680,16 @@ class CogGroupMixin:
|
||||
|
||||
"""
|
||||
cur_rule = self.requires.get_rule(model_id, guild_id=guild_id)
|
||||
if cur_rule in (PermState.NORMAL, PermState.ACTIVE_ALLOW, PermState.ACTIVE_DENY):
|
||||
# These three states are unaffected by subcommand rules
|
||||
return cur_rule, False
|
||||
else:
|
||||
if cur_rule not in (PermState.NORMAL, PermState.ACTIVE_ALLOW, PermState.ACTIVE_DENY):
|
||||
# The above three states are unaffected by subcommand rules
|
||||
# Remaining states can be changed if there exists no actively-allowed
|
||||
# subcommand (this includes subcommands multiple levels below)
|
||||
|
||||
all_commands: Dict[str, Command] = getattr(self, "all_commands", {})
|
||||
|
||||
if any(
|
||||
cmd.requires.get_rule(model_id, guild_id=guild_id) in PermState.ALLOWED_STATES
|
||||
for cmd in self.all_commands.values()
|
||||
cmd.requires.get_rule(model_id, guild_id=guild_id) in PermStateAllowedStates
|
||||
for cmd in all_commands.values()
|
||||
):
|
||||
return cur_rule, False
|
||||
elif cur_rule is PermState.PASSIVE_ALLOW:
|
||||
@ -608,8 +699,11 @@ class CogGroupMixin:
|
||||
self.requires.set_rule(model_id, PermState.ACTIVE_DENY, guild_id=guild_id)
|
||||
return PermState.ACTIVE_DENY, True
|
||||
|
||||
# Default return value
|
||||
return cur_rule, False
|
||||
|
||||
class Group(GroupMixin, Command, CogGroupMixin, commands.Group):
|
||||
|
||||
class Group(GroupMixin, Command, CogGroupMixin, DPYGroup):
|
||||
"""Group command class for Red.
|
||||
|
||||
This class inherits from `Command`, with :class:`GroupMixin` and
|
||||
@ -653,14 +747,6 @@ class Group(GroupMixin, Command, CogGroupMixin, commands.Group):
|
||||
class CogMixin(CogGroupMixin, CogCommandMixin):
|
||||
"""Mixin class for a cog, intended for use with discord.py's cog class"""
|
||||
|
||||
@property
|
||||
def all_commands(self) -> Dict[str, Command]:
|
||||
"""
|
||||
This does not have identical behavior to
|
||||
Group.all_commands but should return what you expect
|
||||
"""
|
||||
return {cmd.name: cmd for cmd in self.__cog_commands__}
|
||||
|
||||
@property
|
||||
def help(self):
|
||||
doc = self.__doc__
|
||||
@ -689,7 +775,7 @@ class CogMixin(CogGroupMixin, CogCommandMixin):
|
||||
|
||||
try:
|
||||
can_run = await self.requires.verify(ctx)
|
||||
except commands.CommandError:
|
||||
except CommandError:
|
||||
return False
|
||||
|
||||
return can_run
|
||||
@ -718,16 +804,22 @@ class CogMixin(CogGroupMixin, CogCommandMixin):
|
||||
return await self.can_run(ctx)
|
||||
|
||||
|
||||
class Cog(CogMixin, commands.Cog):
|
||||
class Cog(CogMixin, DPYCog, metaclass=DPYCogMeta):
|
||||
"""
|
||||
Red's Cog base class
|
||||
|
||||
This includes a metaclass from discord.py
|
||||
"""
|
||||
|
||||
# NB: Do not move the inheritcance of this. Keeping the mix of that metaclass
|
||||
# seperate gives us more freedoms in several places.
|
||||
pass
|
||||
__cog_commands__: Tuple[Command]
|
||||
|
||||
@property
|
||||
def all_commands(self) -> Dict[str, Command]:
|
||||
"""
|
||||
This does not have identical behavior to
|
||||
Group.all_commands but should return what you expect
|
||||
"""
|
||||
return {cmd.name: cmd for cmd in self.__cog_commands__}
|
||||
|
||||
|
||||
def command(name=None, cls=Command, **attrs):
|
||||
@ -736,7 +828,8 @@ def command(name=None, cls=Command, **attrs):
|
||||
Same interface as `discord.ext.commands.command`.
|
||||
"""
|
||||
attrs["help_override"] = attrs.pop("help", None)
|
||||
return commands.command(name, cls, **attrs)
|
||||
|
||||
return dpy_command_deco(name, cls, **attrs)
|
||||
|
||||
|
||||
def group(name=None, cls=Group, **attrs):
|
||||
@ -744,10 +837,10 @@ def group(name=None, cls=Group, **attrs):
|
||||
|
||||
Same interface as `discord.ext.commands.group`.
|
||||
"""
|
||||
return command(name, cls, **attrs)
|
||||
return dpy_command_deco(name, cls, **attrs)
|
||||
|
||||
|
||||
__command_disablers = weakref.WeakValueDictionary()
|
||||
__command_disablers: DisablerDictType = weakref.WeakValueDictionary()
|
||||
|
||||
|
||||
def get_command_disabler(guild: discord.Guild) -> Callable[["Context"], Awaitable[bool]]:
|
||||
@ -762,7 +855,7 @@ def get_command_disabler(guild: discord.Guild) -> Callable[["Context"], Awaitabl
|
||||
|
||||
async def disabler(ctx: "Context") -> bool:
|
||||
if ctx.guild == guild:
|
||||
raise commands.DisabledCommand()
|
||||
raise DisabledCommand()
|
||||
return True
|
||||
|
||||
__command_disablers[guild] = disabler
|
||||
|
||||
@ -1,21 +1,28 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import contextlib
|
||||
import os
|
||||
import re
|
||||
from typing import Iterable, List, Union
|
||||
from typing import Iterable, List, Union, Optional, TYPE_CHECKING
|
||||
import discord
|
||||
from discord.ext import commands
|
||||
from discord.ext.commands import Context as DPYContext
|
||||
|
||||
from .requires import PermState
|
||||
from ..utils.chat_formatting import box
|
||||
from ..utils.predicates import MessagePredicate
|
||||
from ..utils import common_filters
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .commands import Command
|
||||
from ..bot import Red
|
||||
|
||||
TICK = "\N{WHITE HEAVY CHECK MARK}"
|
||||
|
||||
__all__ = ["Context"]
|
||||
__all__ = ["Context", "GuildContext", "DMContext"]
|
||||
|
||||
|
||||
class Context(commands.Context):
|
||||
class Context(DPYContext):
|
||||
"""Command invocation context for Red.
|
||||
|
||||
All context passed into commands will be of this type.
|
||||
@ -40,6 +47,10 @@ class Context(commands.Context):
|
||||
The permission state the current context is in.
|
||||
"""
|
||||
|
||||
command: "Command"
|
||||
invoked_subcommand: "Optional[Command]"
|
||||
bot: "Red"
|
||||
|
||||
def __init__(self, **attrs):
|
||||
self.assume_yes = attrs.pop("assume_yes", False)
|
||||
super().__init__(**attrs)
|
||||
@ -254,7 +265,7 @@ class Context(commands.Context):
|
||||
return pattern.sub(f"@{me.display_name}", self.prefix)
|
||||
|
||||
@property
|
||||
def me(self) -> discord.abc.User:
|
||||
def me(self) -> Union[discord.ClientUser, discord.Member]:
|
||||
"""discord.abc.User: The bot member or user object.
|
||||
|
||||
If the context is DM, this will be a `discord.User` object.
|
||||
@ -263,3 +274,63 @@ class Context(commands.Context):
|
||||
return self.guild.me
|
||||
else:
|
||||
return self.bot.user
|
||||
|
||||
|
||||
if TYPE_CHECKING or os.getenv("BUILDING_DOCS", False):
|
||||
|
||||
class DMContext(Context):
|
||||
"""
|
||||
At runtime, this will still be a normal context object.
|
||||
|
||||
This lies about some type narrowing for type analysis in commands
|
||||
using a dm_only decorator.
|
||||
|
||||
It is only correct to use when those types are already narrowed
|
||||
"""
|
||||
|
||||
@property
|
||||
def author(self) -> discord.User:
|
||||
...
|
||||
|
||||
@property
|
||||
def channel(self) -> discord.DMChannel:
|
||||
...
|
||||
|
||||
@property
|
||||
def guild(self) -> None:
|
||||
...
|
||||
|
||||
@property
|
||||
def me(self) -> discord.ClientUser:
|
||||
...
|
||||
|
||||
class GuildContext(Context):
|
||||
"""
|
||||
At runtime, this will still be a normal context object.
|
||||
|
||||
This lies about some type narrowing for type analysis in commands
|
||||
using a guild_only decorator.
|
||||
|
||||
It is only correct to use when those types are already narrowed
|
||||
"""
|
||||
|
||||
@property
|
||||
def author(self) -> discord.Member:
|
||||
...
|
||||
|
||||
@property
|
||||
def channel(self) -> discord.TextChannel:
|
||||
...
|
||||
|
||||
@property
|
||||
def guild(self) -> discord.Guild:
|
||||
...
|
||||
|
||||
@property
|
||||
def me(self) -> discord.Member:
|
||||
...
|
||||
|
||||
|
||||
else:
|
||||
GuildContext = Context
|
||||
DMContext = Context
|
||||
|
||||
@ -1,14 +1,33 @@
|
||||
"""
|
||||
commands.converter
|
||||
==================
|
||||
This module contains useful functions and classes for command argument conversion.
|
||||
|
||||
Some of the converters within are included provisionaly and are marked as such.
|
||||
"""
|
||||
import os
|
||||
import re
|
||||
import functools
|
||||
from datetime import timedelta
|
||||
from typing import TYPE_CHECKING, Optional, List, Dict
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Generic,
|
||||
Optional,
|
||||
Optional as NoParseOptional,
|
||||
Tuple,
|
||||
List,
|
||||
Dict,
|
||||
Type,
|
||||
TypeVar,
|
||||
Literal as Literal,
|
||||
)
|
||||
|
||||
import discord
|
||||
from discord.ext import commands as dpy_commands
|
||||
from discord.ext.commands import BadArgument
|
||||
|
||||
from . import BadArgument
|
||||
from ..i18n import Translator
|
||||
from ..utils.chat_formatting import humanize_timedelta
|
||||
from ..utils.chat_formatting import humanize_timedelta, humanize_list
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .context import Context
|
||||
@ -17,10 +36,13 @@ __all__ = [
|
||||
"APIToken",
|
||||
"DictConverter",
|
||||
"GuildConverter",
|
||||
"UserInputOptional",
|
||||
"NoParseOptional",
|
||||
"TimedeltaConverter",
|
||||
"get_dict_converter",
|
||||
"get_timedelta_converter",
|
||||
"parse_timedelta",
|
||||
"Literal",
|
||||
]
|
||||
|
||||
_ = Translator("commands.converter", __file__)
|
||||
@ -67,7 +89,7 @@ def parse_timedelta(
|
||||
allowed_units : Optional[List[str]]
|
||||
If provided, you can constrain a user to expressing the amount of time
|
||||
in specific units. The units you can chose to provide are the same as the
|
||||
parser understands. `weeks` `days` `hours` `minutes` `seconds`
|
||||
parser understands. (``weeks``, ``days``, ``hours``, ``minutes``, ``seconds``)
|
||||
|
||||
Returns
|
||||
-------
|
||||
@ -138,17 +160,18 @@ class APIToken(discord.ext.commands.Converter):
|
||||
This will parse the input argument separating the key value pairs into a
|
||||
format to be used for the core bots API token storage.
|
||||
|
||||
This will split the argument by either `;` ` `, or `,` and return a dict
|
||||
This will split the argument by a space, comma, or semicolon and return a dict
|
||||
to be stored. Since all API's are different and have different naming convention,
|
||||
this leaves the onus on the cog creator to clearly define how to setup the correct
|
||||
credential names for their cogs.
|
||||
|
||||
Note: Core usage of this has been replaced with DictConverter use instead.
|
||||
Note: Core usage of this has been replaced with `DictConverter` use instead.
|
||||
|
||||
This may be removed at a later date (with warning)
|
||||
.. warning::
|
||||
This will be removed in version 3.4.
|
||||
"""
|
||||
|
||||
async def convert(self, ctx, argument) -> dict:
|
||||
async def convert(self, ctx: "Context", argument) -> dict:
|
||||
bot = ctx.bot
|
||||
result = {}
|
||||
match = re.split(r";|,| ", argument)
|
||||
@ -162,140 +185,262 @@ class APIToken(discord.ext.commands.Converter):
|
||||
return result
|
||||
|
||||
|
||||
class DictConverter(dpy_commands.Converter):
|
||||
"""
|
||||
Converts pairs of space seperated values to a dict
|
||||
"""
|
||||
# Below this line are a lot of lies for mypy about things that *end up* correct when
|
||||
# These are used for command conversion purposes. Please refer to the portion
|
||||
# which is *not* for type checking for the actual implementation
|
||||
# and ensure the lies stay correct for how the object should look as a typehint
|
||||
|
||||
def __init__(self, *expected_keys: str, delims: Optional[List[str]] = None):
|
||||
self.expected_keys = expected_keys
|
||||
self.delims = delims or [" "]
|
||||
self.pattern = re.compile(r"|".join(re.escape(d) for d in self.delims))
|
||||
if TYPE_CHECKING:
|
||||
DictConverter = Dict[str, str]
|
||||
else:
|
||||
|
||||
async def convert(self, ctx: "Context", argument: str) -> Dict[str, str]:
|
||||
class DictConverter(dpy_commands.Converter):
|
||||
"""
|
||||
Converts pairs of space seperated values to a dict
|
||||
"""
|
||||
|
||||
ret: Dict[str, str] = {}
|
||||
args = self.pattern.split(argument)
|
||||
def __init__(self, *expected_keys: str, delims: Optional[List[str]] = None):
|
||||
self.expected_keys = expected_keys
|
||||
self.delims = delims or [" "]
|
||||
self.pattern = re.compile(r"|".join(re.escape(d) for d in self.delims))
|
||||
|
||||
if len(args) % 2 != 0:
|
||||
raise BadArgument()
|
||||
async def convert(self, ctx: "Context", argument: str) -> Dict[str, str]:
|
||||
ret: Dict[str, str] = {}
|
||||
args = self.pattern.split(argument)
|
||||
|
||||
iterator = iter(args)
|
||||
if len(args) % 2 != 0:
|
||||
raise BadArgument()
|
||||
|
||||
for key in iterator:
|
||||
if self.expected_keys and key not in self.expected_keys:
|
||||
raise BadArgument(_("Unexpected key {key}").format(key=key))
|
||||
iterator = iter(args)
|
||||
|
||||
ret[key] = next(iterator)
|
||||
for key in iterator:
|
||||
if self.expected_keys and key not in self.expected_keys:
|
||||
raise BadArgument(_("Unexpected key {key}").format(key=key))
|
||||
|
||||
return ret
|
||||
ret[key] = next(iterator)
|
||||
|
||||
return ret
|
||||
|
||||
|
||||
def get_dict_converter(*expected_keys: str, delims: Optional[List[str]] = None) -> type:
|
||||
"""
|
||||
Returns a typechecking safe `DictConverter` suitable for use with discord.py
|
||||
"""
|
||||
if TYPE_CHECKING:
|
||||
|
||||
class PartialMeta(type(DictConverter)):
|
||||
__call__ = functools.partialmethod(
|
||||
type(DictConverter).__call__, *expected_keys, delims=delims
|
||||
)
|
||||
|
||||
class ValidatedConverter(DictConverter, metaclass=PartialMeta):
|
||||
pass
|
||||
|
||||
return ValidatedConverter
|
||||
def get_dict_converter(*expected_keys: str, delims: Optional[List[str]] = None) -> Type[dict]:
|
||||
...
|
||||
|
||||
|
||||
class TimedeltaConverter(dpy_commands.Converter):
|
||||
"""
|
||||
This is a converter for timedeltas.
|
||||
The units should be in order from largest to smallest.
|
||||
This works with or without whitespace.
|
||||
else:
|
||||
|
||||
See `parse_timedelta` for more information about how this functions.
|
||||
def get_dict_converter(*expected_keys: str, delims: Optional[List[str]] = None) -> Type[dict]:
|
||||
"""
|
||||
Returns a typechecking safe `DictConverter` suitable for use with discord.py
|
||||
"""
|
||||
|
||||
Attributes
|
||||
----------
|
||||
maximum : Optional[timedelta]
|
||||
If provided, any parsed value higher than this will raise an exception
|
||||
minimum : Optional[timedelta]
|
||||
If provided, any parsed value lower than this will raise an exception
|
||||
allowed_units : Optional[List[str]]
|
||||
If provided, you can constrain a user to expressing the amount of time
|
||||
in specific units. The units you can chose to provide are the same as the
|
||||
parser understands: `weeks` `days` `hours` `minutes` `seconds`
|
||||
default_unit : Optional[str]
|
||||
If provided, it will additionally try to match integer-only input into
|
||||
a timedelta, using the unit specified. Same units as in `allowed_units`
|
||||
apply.
|
||||
"""
|
||||
|
||||
def __init__(self, *, minimum=None, maximum=None, allowed_units=None, default_unit=None):
|
||||
self.allowed_units = allowed_units
|
||||
self.default_unit = default_unit
|
||||
self.minimum = minimum
|
||||
self.maximum = maximum
|
||||
|
||||
async def convert(self, ctx: "Context", argument: str) -> timedelta:
|
||||
if self.default_unit and argument.isdecimal():
|
||||
delta = timedelta(**{self.default_unit: int(argument)})
|
||||
else:
|
||||
delta = parse_timedelta(
|
||||
argument,
|
||||
minimum=self.minimum,
|
||||
maximum=self.maximum,
|
||||
allowed_units=self.allowed_units,
|
||||
class PartialMeta(type):
|
||||
__call__ = functools.partialmethod(
|
||||
type(DictConverter).__call__, *expected_keys, delims=delims
|
||||
)
|
||||
if delta is not None:
|
||||
return delta
|
||||
raise BadArgument() # This allows this to be a required argument.
|
||||
|
||||
class ValidatedConverter(DictConverter, metaclass=PartialMeta):
|
||||
pass
|
||||
|
||||
return ValidatedConverter
|
||||
|
||||
|
||||
def get_timedelta_converter(
|
||||
*,
|
||||
default_unit: Optional[str] = None,
|
||||
maximum: Optional[timedelta] = None,
|
||||
minimum: Optional[timedelta] = None,
|
||||
allowed_units: Optional[List[str]] = None,
|
||||
) -> type:
|
||||
"""
|
||||
This creates a type suitable for typechecking which works with discord.py's
|
||||
commands.
|
||||
if TYPE_CHECKING:
|
||||
TimedeltaConverter = timedelta
|
||||
else:
|
||||
|
||||
See `parse_timedelta` for more information about how this functions.
|
||||
class TimedeltaConverter(dpy_commands.Converter):
|
||||
"""
|
||||
This is a converter for timedeltas.
|
||||
The units should be in order from largest to smallest.
|
||||
This works with or without whitespace.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
maximum : Optional[timedelta]
|
||||
If provided, any parsed value higher than this will raise an exception
|
||||
minimum : Optional[timedelta]
|
||||
If provided, any parsed value lower than this will raise an exception
|
||||
allowed_units : Optional[List[str]]
|
||||
If provided, you can constrain a user to expressing the amount of time
|
||||
in specific units. The units you can chose to provide are the same as the
|
||||
parser understands: `weeks` `days` `hours` `minutes` `seconds`
|
||||
default_unit : Optional[str]
|
||||
If provided, it will additionally try to match integer-only input into
|
||||
a timedelta, using the unit specified. Same units as in `allowed_units`
|
||||
apply.
|
||||
See `parse_timedelta` for more information about how this functions.
|
||||
|
||||
Returns
|
||||
-------
|
||||
type
|
||||
The converter class, which will be a subclass of `TimedeltaConverter`
|
||||
"""
|
||||
Attributes
|
||||
----------
|
||||
maximum : Optional[timedelta]
|
||||
If provided, any parsed value higher than this will raise an exception
|
||||
minimum : Optional[timedelta]
|
||||
If provided, any parsed value lower than this will raise an exception
|
||||
allowed_units : Optional[List[str]]
|
||||
If provided, you can constrain a user to expressing the amount of time
|
||||
in specific units. The units you can choose to provide are the same as the
|
||||
parser understands: (``weeks``, ``days``, ``hours``, ``minutes``, ``seconds``)
|
||||
default_unit : Optional[str]
|
||||
If provided, it will additionally try to match integer-only input into
|
||||
a timedelta, using the unit specified. Same units as in ``allowed_units``
|
||||
apply.
|
||||
"""
|
||||
|
||||
class PartialMeta(type(TimedeltaConverter)):
|
||||
__call__ = functools.partialmethod(
|
||||
type(DictConverter).__call__,
|
||||
allowed_units=allowed_units,
|
||||
default_unit=default_unit,
|
||||
minimum=minimum,
|
||||
maximum=maximum,
|
||||
)
|
||||
def __init__(self, *, minimum=None, maximum=None, allowed_units=None, default_unit=None):
|
||||
self.allowed_units = allowed_units
|
||||
self.default_unit = default_unit
|
||||
self.minimum = minimum
|
||||
self.maximum = maximum
|
||||
|
||||
class ValidatedConverter(TimedeltaConverter, metaclass=PartialMeta):
|
||||
pass
|
||||
async def convert(self, ctx: "Context", argument: str) -> timedelta:
|
||||
if self.default_unit and argument.isdecimal():
|
||||
delta = timedelta(**{self.default_unit: int(argument)})
|
||||
else:
|
||||
delta = parse_timedelta(
|
||||
argument,
|
||||
minimum=self.minimum,
|
||||
maximum=self.maximum,
|
||||
allowed_units=self.allowed_units,
|
||||
)
|
||||
if delta is not None:
|
||||
return delta
|
||||
raise BadArgument() # This allows this to be a required argument.
|
||||
|
||||
return ValidatedConverter
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
def get_timedelta_converter(
|
||||
*,
|
||||
default_unit: Optional[str] = None,
|
||||
maximum: Optional[timedelta] = None,
|
||||
minimum: Optional[timedelta] = None,
|
||||
allowed_units: Optional[List[str]] = None,
|
||||
) -> Type[timedelta]:
|
||||
...
|
||||
|
||||
|
||||
else:
|
||||
|
||||
def get_timedelta_converter(
|
||||
*,
|
||||
default_unit: Optional[str] = None,
|
||||
maximum: Optional[timedelta] = None,
|
||||
minimum: Optional[timedelta] = None,
|
||||
allowed_units: Optional[List[str]] = None,
|
||||
) -> Type[timedelta]:
|
||||
"""
|
||||
This creates a type suitable for typechecking which works with discord.py's
|
||||
commands.
|
||||
|
||||
See `parse_timedelta` for more information about how this functions.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
maximum : Optional[timedelta]
|
||||
If provided, any parsed value higher than this will raise an exception
|
||||
minimum : Optional[timedelta]
|
||||
If provided, any parsed value lower than this will raise an exception
|
||||
allowed_units : Optional[List[str]]
|
||||
If provided, you can constrain a user to expressing the amount of time
|
||||
in specific units. The units you can choose to provide are the same as the
|
||||
parser understands: (``weeks``, ``days``, ``hours``, ``minutes``, ``seconds``)
|
||||
default_unit : Optional[str]
|
||||
If provided, it will additionally try to match integer-only input into
|
||||
a timedelta, using the unit specified. Same units as in ``allowed_units``
|
||||
apply.
|
||||
|
||||
Returns
|
||||
-------
|
||||
type
|
||||
The converter class, which will be a subclass of `TimedeltaConverter`
|
||||
"""
|
||||
|
||||
class PartialMeta(type):
|
||||
__call__ = functools.partialmethod(
|
||||
type(DictConverter).__call__,
|
||||
allowed_units=allowed_units,
|
||||
default_unit=default_unit,
|
||||
minimum=minimum,
|
||||
maximum=maximum,
|
||||
)
|
||||
|
||||
class ValidatedConverter(TimedeltaConverter, metaclass=PartialMeta):
|
||||
pass
|
||||
|
||||
return ValidatedConverter
|
||||
|
||||
|
||||
if not TYPE_CHECKING:
|
||||
|
||||
class NoParseOptional:
|
||||
"""
|
||||
This can be used instead of `typing.Optional`
|
||||
to avoid discord.py special casing the conversion behavior.
|
||||
|
||||
.. warning::
|
||||
This converter class is still provisional.
|
||||
|
||||
.. seealso::
|
||||
The `ignore_optional_for_conversion` option of commands.
|
||||
"""
|
||||
|
||||
def __class_getitem__(cls, key):
|
||||
if isinstance(key, tuple):
|
||||
raise TypeError("Must only provide a single type to Optional")
|
||||
return key
|
||||
|
||||
|
||||
_T_OPT = TypeVar("_T_OPT", bound=Type)
|
||||
|
||||
if TYPE_CHECKING or os.getenv("BUILDING_DOCS", False):
|
||||
|
||||
class UserInputOptional(Generic[_T_OPT]):
|
||||
"""
|
||||
This can be used when user input should be converted as discord.py
|
||||
treats `typing.Optional`, but the type should not be equivalent to
|
||||
``typing.Union[DesiredType, None]`` for type checking.
|
||||
|
||||
|
||||
.. warning::
|
||||
This converter class is still provisional.
|
||||
|
||||
This class may not play well with mypy yet
|
||||
and may still require you guard this in a
|
||||
type checking conditional import vs the desired types
|
||||
|
||||
We're aware and looking into improving this.
|
||||
"""
|
||||
|
||||
def __class_getitem__(cls, key: _T_OPT) -> _T_OPT:
|
||||
if isinstance(key, tuple):
|
||||
raise TypeError("Must only provide a single type to Optional")
|
||||
return key
|
||||
|
||||
|
||||
else:
|
||||
UserInputOptional = Optional
|
||||
|
||||
|
||||
if not TYPE_CHECKING:
|
||||
|
||||
class Literal(dpy_commands.Converter):
|
||||
"""
|
||||
This can be used as a converter for `typing.Literal`.
|
||||
|
||||
In a type checking context it is `typing.Literal`.
|
||||
In a runtime context, it's a converter which only matches the literals it was given.
|
||||
|
||||
|
||||
.. warning::
|
||||
This converter class is still provisional.
|
||||
"""
|
||||
|
||||
def __init__(self, valid_names: Tuple[str]):
|
||||
self.valid_names = valid_names
|
||||
|
||||
def __call__(self, ctx, arg):
|
||||
# Callable's are treated as valid types:
|
||||
# https://github.com/python/cpython/blob/3.8/Lib/typing.py#L148
|
||||
# Without this, ``typing.Union[Literal["clear"], bool]`` would fail
|
||||
return self.convert(ctx, arg)
|
||||
|
||||
async def convert(self, ctx, arg):
|
||||
if arg in self.valid_names:
|
||||
return arg
|
||||
raise BadArgument(_("Expected one of: {}").format(humanize_list(self.valid_names)))
|
||||
|
||||
def __class_getitem__(cls, k):
|
||||
if not k:
|
||||
raise ValueError("Need at least one value for Literal")
|
||||
if isinstance(k, tuple):
|
||||
return cls(k)
|
||||
else:
|
||||
return cls((k,))
|
||||
|
||||
@ -8,6 +8,7 @@ checks like bot permissions checks.
|
||||
"""
|
||||
import asyncio
|
||||
import enum
|
||||
import inspect
|
||||
from typing import (
|
||||
Union,
|
||||
Optional,
|
||||
@ -45,6 +46,7 @@ __all__ = [
|
||||
"permissions_check",
|
||||
"bot_has_permissions",
|
||||
"has_permissions",
|
||||
"has_guild_permissions",
|
||||
"is_owner",
|
||||
"guildowner",
|
||||
"guildowner_or_permissions",
|
||||
@ -52,6 +54,9 @@ __all__ = [
|
||||
"admin_or_permissions",
|
||||
"mod",
|
||||
"mod_or_permissions",
|
||||
"transition_permstate_to",
|
||||
"PermStateTransitions",
|
||||
"PermStateAllowedStates",
|
||||
]
|
||||
|
||||
_T = TypeVar("_T")
|
||||
@ -182,11 +187,6 @@ class PermState(enum.Enum):
|
||||
"""This command has been actively denied by a permission hook
|
||||
check validation doesn't need this, but is useful to developers"""
|
||||
|
||||
def transition_to(
|
||||
self, next_state: "PermState"
|
||||
) -> Tuple[Optional[bool], Union["PermState", Dict[bool, "PermState"]]]:
|
||||
return self.TRANSITIONS[self][next_state]
|
||||
|
||||
@classmethod
|
||||
def from_bool(cls, value: Optional[bool]) -> "PermState":
|
||||
"""Get a PermState from a bool or ``NoneType``."""
|
||||
@ -211,7 +211,11 @@ class PermState(enum.Enum):
|
||||
# result of the default permission checks - the transition from NORMAL
|
||||
# to PASSIVE_ALLOW. In this case "next state" is a dict mapping the
|
||||
# permission check results to the actual next state.
|
||||
PermState.TRANSITIONS = {
|
||||
|
||||
TransitionResult = Tuple[Optional[bool], Union[PermState, Dict[bool, PermState]]]
|
||||
TransitionDict = Dict[PermState, Dict[PermState, TransitionResult]]
|
||||
|
||||
PermStateTransitions: TransitionDict = {
|
||||
PermState.ACTIVE_ALLOW: {
|
||||
PermState.ACTIVE_ALLOW: (True, PermState.ACTIVE_ALLOW),
|
||||
PermState.NORMAL: (True, PermState.ACTIVE_ALLOW),
|
||||
@ -248,13 +252,18 @@ PermState.TRANSITIONS = {
|
||||
PermState.ACTIVE_DENY: (False, PermState.ACTIVE_DENY),
|
||||
},
|
||||
}
|
||||
PermState.ALLOWED_STATES = (
|
||||
|
||||
PermStateAllowedStates = (
|
||||
PermState.ACTIVE_ALLOW,
|
||||
PermState.PASSIVE_ALLOW,
|
||||
PermState.CAUTIOUS_ALLOW,
|
||||
)
|
||||
|
||||
|
||||
def transition_permstate_to(prev: PermState, next_state: PermState) -> TransitionResult:
|
||||
return PermStateTransitions[prev][next_state]
|
||||
|
||||
|
||||
class Requires:
|
||||
"""This class describes the requirements for executing a specific command.
|
||||
|
||||
@ -326,13 +335,13 @@ class Requires:
|
||||
|
||||
@staticmethod
|
||||
def get_decorator(
|
||||
privilege_level: Optional[PrivilegeLevel], user_perms: Dict[str, bool]
|
||||
privilege_level: Optional[PrivilegeLevel], user_perms: Optional[Dict[str, bool]]
|
||||
) -> Callable[["_CommandOrCoro"], "_CommandOrCoro"]:
|
||||
if not user_perms:
|
||||
user_perms = None
|
||||
|
||||
def decorator(func: "_CommandOrCoro") -> "_CommandOrCoro":
|
||||
if asyncio.iscoroutinefunction(func):
|
||||
if inspect.iscoroutinefunction(func):
|
||||
func.__requires_privilege_level__ = privilege_level
|
||||
func.__requires_user_perms__ = user_perms
|
||||
else:
|
||||
@ -341,6 +350,7 @@ class Requires:
|
||||
func.requires.user_perms = None
|
||||
else:
|
||||
_validate_perms_dict(user_perms)
|
||||
assert func.requires.user_perms is not None
|
||||
func.requires.user_perms.update(**user_perms)
|
||||
return func
|
||||
|
||||
@ -488,7 +498,7 @@ class Requires:
|
||||
async def _transition_state(self, ctx: "Context") -> bool:
|
||||
prev_state = ctx.permission_state
|
||||
cur_state = self._get_rule_from_ctx(ctx)
|
||||
should_invoke, next_state = prev_state.transition_to(cur_state)
|
||||
should_invoke, next_state = transition_permstate_to(prev_state, cur_state)
|
||||
if should_invoke is None:
|
||||
# NORMAL invokation, we simply follow standard procedure
|
||||
should_invoke = await self._verify_user(ctx)
|
||||
@ -509,6 +519,7 @@ class Requires:
|
||||
would_invoke = await self._verify_user(ctx)
|
||||
next_state = next_state[would_invoke]
|
||||
|
||||
assert isinstance(next_state, PermState)
|
||||
ctx.permission_state = next_state
|
||||
return should_invoke
|
||||
|
||||
@ -635,6 +646,20 @@ def permissions_check(predicate: CheckPredicate):
|
||||
return decorator
|
||||
|
||||
|
||||
def has_guild_permissions(**perms):
|
||||
"""Restrict the command to users with these guild permissions.
|
||||
|
||||
This check can be overridden by rules.
|
||||
"""
|
||||
|
||||
_validate_perms_dict(perms)
|
||||
|
||||
def predicate(ctx):
|
||||
return ctx.guild and ctx.author.guild_permissions >= discord.Permissions(**perms)
|
||||
|
||||
return permissions_check(predicate)
|
||||
|
||||
|
||||
def bot_has_permissions(**perms: bool):
|
||||
"""Complain if the bot is missing permissions.
|
||||
|
||||
|
||||
@ -979,7 +979,7 @@ class Config:
|
||||
"""
|
||||
return self._get_base_group(self.CHANNEL, str(channel_id))
|
||||
|
||||
def channel(self, channel: discord.TextChannel) -> Group:
|
||||
def channel(self, channel: discord.abc.GuildChannel) -> Group:
|
||||
"""Returns a `Group` for the given channel.
|
||||
|
||||
This does not discriminate between text and voice channels.
|
||||
|
||||
@ -1,3 +1,5 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from typing import Callable, ClassVar, List, Optional, Pattern, Sequence, Tuple, Union, cast
|
||||
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user