Allow any Messageable in MessagePredicate's channel parameter (#5942)

Co-authored-by: Flame442 <34169552+Flame442@users.noreply.github.com>
This commit is contained in:
Jakub Kuczys 2023-04-13 21:21:36 +02:00 committed by GitHub
parent 533f036ed2
commit 145b2e43ce
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -4,6 +4,7 @@ import re
from typing import Callable, ClassVar, List, Optional, Pattern, Sequence, Tuple, Union, cast from typing import Callable, ClassVar, List, Optional, Pattern, Sequence, Tuple, Union, cast
import discord import discord
from discord.ext import commands as dpy_commands
from redbot.core import commands from redbot.core import commands
@ -67,9 +68,7 @@ class MessagePredicate(Callable[[discord.Message], bool]):
def same_context( def same_context(
cls, cls,
ctx: Optional[commands.Context] = None, ctx: Optional[commands.Context] = None,
channel: Optional[ channel: Optional[discord.abc.Messageable] = None,
Union[discord.TextChannel, discord.VoiceChannel, discord.Thread, discord.DMChannel]
] = None,
user: Optional[discord.abc.User] = None, user: Optional[discord.abc.User] = None,
) -> "MessagePredicate": ) -> "MessagePredicate":
"""Match if the message fits the described context. """Match if the message fits the described context.
@ -78,8 +77,8 @@ class MessagePredicate(Callable[[discord.Message], bool]):
---------- ----------
ctx : Optional[Context] ctx : Optional[Context]
The current invocation context. The current invocation context.
channel : Optional[Union[`discord.TextChannel`, `discord.VoiceChannel`, `discord.Thread`, `discord.DMChannel`]] channel : Optional[discord.abc.Messageable]
The channel we expect a message in. If unspecified, The messageable object we expect a message in. If unspecified,
defaults to ``ctx.channel``. If ``ctx`` is unspecified defaults to ``ctx.channel``. If ``ctx`` is unspecified
too, the message's channel will be ignored. too, the message's channel will be ignored.
user : Optional[discord.abc.User] user : Optional[discord.abc.User]
@ -93,22 +92,34 @@ class MessagePredicate(Callable[[discord.Message], bool]):
The event predicate. The event predicate.
""" """
check_dm_channel = False
# using dpy_commands.Context to keep the Messageable contract in full
if isinstance(channel, dpy_commands.Context):
channel = channel.channel
elif isinstance(channel, (discord.User, discord.Member)):
check_dm_channel = True
if ctx is not None: if ctx is not None:
channel = channel or ctx.channel channel = channel or ctx.channel
user = user or ctx.author user = user or ctx.author
return cls( return cls(
lambda self, m: (user is None or user.id == m.author.id) lambda self, m: (user is None or user.id == m.author.id)
and (channel is None or channel.id == m.channel.id) and (
channel is None
or (
channel.id == m.author.id and isinstance(m.channel, discord.DMChannel)
if check_dm_channel
else channel.id == m.channel.id
)
)
) )
@classmethod @classmethod
def cancelled( def cancelled(
cls, cls,
ctx: Optional[commands.Context] = None, ctx: Optional[commands.Context] = None,
channel: Optional[ channel: Optional[discord.abc.Messageable] = None,
Union[discord.TextChannel, discord.VoiceChannel, discord.Thread, discord.DMChannel]
] = None,
user: Optional[discord.abc.User] = None, user: Optional[discord.abc.User] = None,
) -> "MessagePredicate": ) -> "MessagePredicate":
"""Match if the message is ``[p]cancel``. """Match if the message is ``[p]cancel``.
@ -117,7 +128,7 @@ class MessagePredicate(Callable[[discord.Message], bool]):
---------- ----------
ctx : Optional[Context] ctx : Optional[Context]
Same as ``ctx`` in :meth:`same_context`. Same as ``ctx`` in :meth:`same_context`.
channel : Optional[Union[`discord.TextChannel`, `discord.VoiceChannel`, `discord.Thread`, `discord.DMChannel`]] channel : Optional[discord.abc.Messageable]
Same as ``channel`` in :meth:`same_context`. Same as ``channel`` in :meth:`same_context`.
user : Optional[discord.abc.User] user : Optional[discord.abc.User]
Same as ``user`` in :meth:`same_context`. Same as ``user`` in :meth:`same_context`.
@ -137,9 +148,7 @@ class MessagePredicate(Callable[[discord.Message], bool]):
def yes_or_no( def yes_or_no(
cls, cls,
ctx: Optional[commands.Context] = None, ctx: Optional[commands.Context] = None,
channel: Optional[ channel: Optional[discord.abc.Messageable] = None,
Union[discord.TextChannel, discord.VoiceChannel, discord.Thread, discord.DMChannel]
] = None,
user: Optional[discord.abc.User] = None, user: Optional[discord.abc.User] = None,
) -> "MessagePredicate": ) -> "MessagePredicate":
"""Match if the message is "yes"/"y" or "no"/"n". """Match if the message is "yes"/"y" or "no"/"n".
@ -151,7 +160,7 @@ class MessagePredicate(Callable[[discord.Message], bool]):
---------- ----------
ctx : Optional[Context] ctx : Optional[Context]
Same as ``ctx`` in :meth:`same_context`. Same as ``ctx`` in :meth:`same_context`.
channel : Optional[Union[`discord.TextChannel`, `discord.VoiceChannel`, `discord.Thread`, `discord.DMChannel`]] channel : Optional[discord.abc.Messageable]
Same as ``channel`` in :meth:`same_context`. Same as ``channel`` in :meth:`same_context`.
user : Optional[discord.abc.User] user : Optional[discord.abc.User]
Same as ``user`` in :meth:`same_context`. Same as ``user`` in :meth:`same_context`.
@ -182,9 +191,7 @@ class MessagePredicate(Callable[[discord.Message], bool]):
def valid_int( def valid_int(
cls, cls,
ctx: Optional[commands.Context] = None, ctx: Optional[commands.Context] = None,
channel: Optional[ channel: Optional[discord.abc.Messageable] = None,
Union[discord.TextChannel, discord.VoiceChannel, discord.Thread, discord.DMChannel]
] = None,
user: Optional[discord.abc.User] = None, user: Optional[discord.abc.User] = None,
) -> "MessagePredicate": ) -> "MessagePredicate":
"""Match if the response is an integer. """Match if the response is an integer.
@ -195,7 +202,7 @@ class MessagePredicate(Callable[[discord.Message], bool]):
---------- ----------
ctx : Optional[Context] ctx : Optional[Context]
Same as ``ctx`` in :meth:`same_context`. Same as ``ctx`` in :meth:`same_context`.
channel : Optional[Union[`discord.TextChannel`, `discord.VoiceChannel`, `discord.Thread`, `discord.DMChannel`]] channel : Optional[discord.abc.Messageable]
Same as ``channel`` in :meth:`same_context`. Same as ``channel`` in :meth:`same_context`.
user : Optional[discord.abc.User] user : Optional[discord.abc.User]
Same as ``user`` in :meth:`same_context`. Same as ``user`` in :meth:`same_context`.
@ -224,9 +231,7 @@ class MessagePredicate(Callable[[discord.Message], bool]):
def valid_float( def valid_float(
cls, cls,
ctx: Optional[commands.Context] = None, ctx: Optional[commands.Context] = None,
channel: Optional[ channel: Optional[discord.abc.Messageable] = None,
Union[discord.TextChannel, discord.VoiceChannel, discord.Thread, discord.DMChannel]
] = None,
user: Optional[discord.abc.User] = None, user: Optional[discord.abc.User] = None,
) -> "MessagePredicate": ) -> "MessagePredicate":
"""Match if the response is a float. """Match if the response is a float.
@ -237,7 +242,7 @@ class MessagePredicate(Callable[[discord.Message], bool]):
---------- ----------
ctx : Optional[Context] ctx : Optional[Context]
Same as ``ctx`` in :meth:`same_context`. Same as ``ctx`` in :meth:`same_context`.
channel : Optional[Union[`discord.TextChannel`, `discord.VoiceChannel`, `discord.Thread`, `discord.DMChannel`]] channel : Optional[discord.abc.Messageable]
Same as ``channel`` in :meth:`same_context`. Same as ``channel`` in :meth:`same_context`.
user : Optional[discord.abc.User] user : Optional[discord.abc.User]
Same as ``user`` in :meth:`same_context`. Same as ``user`` in :meth:`same_context`.
@ -266,9 +271,7 @@ class MessagePredicate(Callable[[discord.Message], bool]):
def positive( def positive(
cls, cls,
ctx: Optional[commands.Context] = None, ctx: Optional[commands.Context] = None,
channel: Optional[ channel: Optional[discord.abc.Messageable] = None,
Union[discord.TextChannel, discord.VoiceChannel, discord.Thread, discord.DMChannel]
] = None,
user: Optional[discord.abc.User] = None, user: Optional[discord.abc.User] = None,
) -> "MessagePredicate": ) -> "MessagePredicate":
"""Match if the response is a positive number. """Match if the response is a positive number.
@ -279,7 +282,7 @@ class MessagePredicate(Callable[[discord.Message], bool]):
---------- ----------
ctx : Optional[Context] ctx : Optional[Context]
Same as ``ctx`` in :meth:`same_context`. Same as ``ctx`` in :meth:`same_context`.
channel : Optional[Union[`discord.TextChannel`, `discord.VoiceChannel`, `discord.Thread`, `discord.DMChannel`]] channel : Optional[discord.abc.Messageable]
Same as ``channel`` in :meth:`same_context`. Same as ``channel`` in :meth:`same_context`.
user : Optional[discord.abc.User] user : Optional[discord.abc.User]
Same as ``user`` in :meth:`same_context`. Same as ``user`` in :meth:`same_context`.
@ -504,9 +507,7 @@ class MessagePredicate(Callable[[discord.Message], bool]):
cls, cls,
value: str, value: str,
ctx: Optional[commands.Context] = None, ctx: Optional[commands.Context] = None,
channel: Optional[ channel: Optional[discord.abc.Messageable] = None,
Union[discord.TextChannel, discord.VoiceChannel, discord.Thread, discord.DMChannel]
] = None,
user: Optional[discord.abc.User] = None, user: Optional[discord.abc.User] = None,
) -> "MessagePredicate": ) -> "MessagePredicate":
"""Match if the response is equal to the specified value. """Match if the response is equal to the specified value.
@ -517,7 +518,7 @@ class MessagePredicate(Callable[[discord.Message], bool]):
The value to compare the response with. The value to compare the response with.
ctx : Optional[Context] ctx : Optional[Context]
Same as ``ctx`` in :meth:`same_context`. Same as ``ctx`` in :meth:`same_context`.
channel : Optional[Union[`discord.TextChannel`, `discord.VoiceChannel`, `discord.Thread`, `discord.DMChannel`]] channel : Optional[discord.abc.Messageable]
Same as ``channel`` in :meth:`same_context`. Same as ``channel`` in :meth:`same_context`.
user : Optional[discord.abc.User] user : Optional[discord.abc.User]
Same as ``user`` in :meth:`same_context`. Same as ``user`` in :meth:`same_context`.
@ -536,9 +537,7 @@ class MessagePredicate(Callable[[discord.Message], bool]):
cls, cls,
value: str, value: str,
ctx: Optional[commands.Context] = None, ctx: Optional[commands.Context] = None,
channel: Optional[ channel: Optional[discord.abc.Messageable] = None,
Union[discord.TextChannel, discord.VoiceChannel, discord.Thread, discord.DMChannel]
] = None,
user: Optional[discord.abc.User] = None, user: Optional[discord.abc.User] = None,
) -> "MessagePredicate": ) -> "MessagePredicate":
"""Match if the response *as lowercase* is equal to the specified value. """Match if the response *as lowercase* is equal to the specified value.
@ -549,7 +548,7 @@ class MessagePredicate(Callable[[discord.Message], bool]):
The value to compare the response with. The value to compare the response with.
ctx : Optional[Context] ctx : Optional[Context]
Same as ``ctx`` in :meth:`same_context`. Same as ``ctx`` in :meth:`same_context`.
channel : Optional[Union[`discord.TextChannel`, `discord.VoiceChannel`, `discord.Thread`, `discord.DMChannel`]] channel : Optional[discord.abc.Messageable]
Same as ``channel`` in :meth:`same_context`. Same as ``channel`` in :meth:`same_context`.
user : Optional[discord.abc.User] user : Optional[discord.abc.User]
Same as ``user`` in :meth:`same_context`. Same as ``user`` in :meth:`same_context`.
@ -568,9 +567,7 @@ class MessagePredicate(Callable[[discord.Message], bool]):
cls, cls,
value: Union[int, float], value: Union[int, float],
ctx: Optional[commands.Context] = None, ctx: Optional[commands.Context] = None,
channel: Optional[ channel: Optional[discord.abc.Messageable] = None,
Union[discord.TextChannel, discord.VoiceChannel, discord.Thread, discord.DMChannel]
] = None,
user: Optional[discord.abc.User] = None, user: Optional[discord.abc.User] = None,
) -> "MessagePredicate": ) -> "MessagePredicate":
"""Match if the response is less than the specified value. """Match if the response is less than the specified value.
@ -581,7 +578,7 @@ class MessagePredicate(Callable[[discord.Message], bool]):
The value to compare the response with. The value to compare the response with.
ctx : Optional[Context] ctx : Optional[Context]
Same as ``ctx`` in :meth:`same_context`. Same as ``ctx`` in :meth:`same_context`.
channel : Optional[Union[`discord.TextChannel`, `discord.VoiceChannel`, `discord.Thread`, `discord.DMChannel`]] channel : Optional[discord.abc.Messageable]
Same as ``channel`` in :meth:`same_context`. Same as ``channel`` in :meth:`same_context`.
user : Optional[discord.abc.User] user : Optional[discord.abc.User]
Same as ``user`` in :meth:`same_context`. Same as ``user`` in :meth:`same_context`.
@ -601,9 +598,7 @@ class MessagePredicate(Callable[[discord.Message], bool]):
cls, cls,
value: Union[int, float], value: Union[int, float],
ctx: Optional[commands.Context] = None, ctx: Optional[commands.Context] = None,
channel: Optional[ channel: Optional[discord.abc.Messageable] = None,
Union[discord.TextChannel, discord.VoiceChannel, discord.Thread, discord.DMChannel]
] = None,
user: Optional[discord.abc.User] = None, user: Optional[discord.abc.User] = None,
) -> "MessagePredicate": ) -> "MessagePredicate":
"""Match if the response is greater than the specified value. """Match if the response is greater than the specified value.
@ -614,7 +609,7 @@ class MessagePredicate(Callable[[discord.Message], bool]):
The value to compare the response with. The value to compare the response with.
ctx : Optional[Context] ctx : Optional[Context]
Same as ``ctx`` in :meth:`same_context`. Same as ``ctx`` in :meth:`same_context`.
channel : Optional[Union[`discord.TextChannel`, `discord.VoiceChannel`, `discord.Thread`, `discord.DMChannel`]] channel : Optional[discord.abc.Messageable]
Same as ``channel`` in :meth:`same_context`. Same as ``channel`` in :meth:`same_context`.
user : Optional[discord.abc.User] user : Optional[discord.abc.User]
Same as ``user`` in :meth:`same_context`. Same as ``user`` in :meth:`same_context`.
@ -634,9 +629,7 @@ class MessagePredicate(Callable[[discord.Message], bool]):
cls, cls,
length: int, length: int,
ctx: Optional[commands.Context] = None, ctx: Optional[commands.Context] = None,
channel: Optional[ channel: Optional[discord.abc.Messageable] = None,
Union[discord.TextChannel, discord.VoiceChannel, discord.Thread, discord.DMChannel]
] = None,
user: Optional[discord.abc.User] = None, user: Optional[discord.abc.User] = None,
) -> "MessagePredicate": ) -> "MessagePredicate":
"""Match if the response's length is less than the specified length. """Match if the response's length is less than the specified length.
@ -647,7 +640,7 @@ class MessagePredicate(Callable[[discord.Message], bool]):
The value to compare the response's length with. The value to compare the response's length with.
ctx : Optional[Context] ctx : Optional[Context]
Same as ``ctx`` in :meth:`same_context`. Same as ``ctx`` in :meth:`same_context`.
channel : Optional[Union[`discord.TextChannel`, `discord.VoiceChannel`, `discord.Thread`, `discord.DMChannel`]] channel : Optional[discord.abc.Messageable]
Same as ``channel`` in :meth:`same_context`. Same as ``channel`` in :meth:`same_context`.
user : Optional[discord.abc.User] user : Optional[discord.abc.User]
Same as ``user`` in :meth:`same_context`. Same as ``user`` in :meth:`same_context`.
@ -666,9 +659,7 @@ class MessagePredicate(Callable[[discord.Message], bool]):
cls, cls,
length: int, length: int,
ctx: Optional[commands.Context] = None, ctx: Optional[commands.Context] = None,
channel: Optional[ channel: Optional[discord.abc.Messageable] = None,
Union[discord.TextChannel, discord.VoiceChannel, discord.Thread, discord.DMChannel]
] = None,
user: Optional[discord.abc.User] = None, user: Optional[discord.abc.User] = None,
) -> "MessagePredicate": ) -> "MessagePredicate":
"""Match if the response's length is greater than the specified length. """Match if the response's length is greater than the specified length.
@ -679,7 +670,7 @@ class MessagePredicate(Callable[[discord.Message], bool]):
The value to compare the response's length with. The value to compare the response's length with.
ctx : Optional[Context] ctx : Optional[Context]
Same as ``ctx`` in :meth:`same_context`. Same as ``ctx`` in :meth:`same_context`.
channel : Optional[Union[`discord.TextChannel`, `discord.VoiceChannel`, `discord.Thread`, `discord.DMChannel`]] channel : Optional[discord.abc.Messageable]
Same as ``channel`` in :meth:`same_context`. Same as ``channel`` in :meth:`same_context`.
user : Optional[discord.abc.User] user : Optional[discord.abc.User]
Same as ``user`` in :meth:`same_context`. Same as ``user`` in :meth:`same_context`.
@ -698,9 +689,7 @@ class MessagePredicate(Callable[[discord.Message], bool]):
cls, cls,
collection: Sequence[str], collection: Sequence[str],
ctx: Optional[commands.Context] = None, ctx: Optional[commands.Context] = None,
channel: Optional[ channel: Optional[discord.abc.Messageable] = None,
Union[discord.TextChannel, discord.VoiceChannel, discord.Thread, discord.DMChannel]
] = None,
user: Optional[discord.abc.User] = None, user: Optional[discord.abc.User] = None,
) -> "MessagePredicate": ) -> "MessagePredicate":
"""Match if the response is contained in the specified collection. """Match if the response is contained in the specified collection.
@ -714,7 +703,7 @@ class MessagePredicate(Callable[[discord.Message], bool]):
The collection containing valid responses. The collection containing valid responses.
ctx : Optional[Context] ctx : Optional[Context]
Same as ``ctx`` in :meth:`same_context`. Same as ``ctx`` in :meth:`same_context`.
channel : Optional[Union[`discord.TextChannel`, `discord.VoiceChannel`, `discord.Thread`, `discord.DMChannel`]] channel : Optional[discord.abc.Messageable]
Same as ``channel`` in :meth:`same_context`. Same as ``channel`` in :meth:`same_context`.
user : Optional[discord.abc.User] user : Optional[discord.abc.User]
Same as ``user`` in :meth:`same_context`. Same as ``user`` in :meth:`same_context`.
@ -744,9 +733,7 @@ class MessagePredicate(Callable[[discord.Message], bool]):
cls, cls,
collection: Sequence[str], collection: Sequence[str],
ctx: Optional[commands.Context] = None, ctx: Optional[commands.Context] = None,
channel: Optional[ channel: Optional[discord.abc.Messageable] = None,
Union[discord.TextChannel, discord.VoiceChannel, discord.Thread, discord.DMChannel]
] = None,
user: Optional[discord.abc.User] = None, user: Optional[discord.abc.User] = None,
) -> "MessagePredicate": ) -> "MessagePredicate":
"""Same as :meth:`contained_in`, but the response is set to lowercase before matching. """Same as :meth:`contained_in`, but the response is set to lowercase before matching.
@ -757,7 +744,7 @@ class MessagePredicate(Callable[[discord.Message], bool]):
The collection containing valid lowercase responses. The collection containing valid lowercase responses.
ctx : Optional[Context] ctx : Optional[Context]
Same as ``ctx`` in :meth:`same_context`. Same as ``ctx`` in :meth:`same_context`.
channel : Optional[Union[`discord.TextChannel`, `discord.VoiceChannel`, `discord.Thread`, `discord.DMChannel`]] channel : Optional[discord.abc.Messageable]
Same as ``channel`` in :meth:`same_context`. Same as ``channel`` in :meth:`same_context`.
user : Optional[discord.abc.User] user : Optional[discord.abc.User]
Same as ``user`` in :meth:`same_context`. Same as ``user`` in :meth:`same_context`.
@ -787,9 +774,7 @@ class MessagePredicate(Callable[[discord.Message], bool]):
cls, cls,
pattern: Union[Pattern[str], str], pattern: Union[Pattern[str], str],
ctx: Optional[commands.Context] = None, ctx: Optional[commands.Context] = None,
channel: Optional[ channel: Optional[discord.abc.Messageable] = None,
Union[discord.TextChannel, discord.VoiceChannel, discord.Thread, discord.DMChannel]
] = None,
user: Optional[discord.abc.User] = None, user: Optional[discord.abc.User] = None,
) -> "MessagePredicate": ) -> "MessagePredicate":
"""Match if the response matches the specified regex pattern. """Match if the response matches the specified regex pattern.
@ -804,7 +789,7 @@ class MessagePredicate(Callable[[discord.Message], bool]):
The pattern to search for in the response. The pattern to search for in the response.
ctx : Optional[Context] ctx : Optional[Context]
Same as ``ctx`` in :meth:`same_context`. Same as ``ctx`` in :meth:`same_context`.
channel : Optional[Union[`discord.TextChannel`, `discord.VoiceChannel`, `discord.Thread`, `discord.DMChannel`]] channel : Optional[discord.abc.Messageable]
Same as ``channel`` in :meth:`same_context`. Same as ``channel`` in :meth:`same_context`.
user : Optional[discord.abc.User] user : Optional[discord.abc.User]
Same as ``user`` in :meth:`same_context`. Same as ``user`` in :meth:`same_context`.