From 145b2e43ce88cb6e669d4b1a87f7797ce9f60b4f Mon Sep 17 00:00:00 2001 From: Jakub Kuczys Date: Thu, 13 Apr 2023 21:21:36 +0200 Subject: [PATCH] Allow any Messageable in MessagePredicate's channel parameter (#5942) Co-authored-by: Flame442 <34169552+Flame442@users.noreply.github.com> --- redbot/core/utils/predicates.py | 109 ++++++++++++++------------------ 1 file changed, 47 insertions(+), 62 deletions(-) diff --git a/redbot/core/utils/predicates.py b/redbot/core/utils/predicates.py index 11a1d035e..0f383533f 100644 --- a/redbot/core/utils/predicates.py +++ b/redbot/core/utils/predicates.py @@ -4,6 +4,7 @@ import re from typing import Callable, ClassVar, List, Optional, Pattern, Sequence, Tuple, Union, cast import discord +from discord.ext import commands as dpy_commands from redbot.core import commands @@ -67,9 +68,7 @@ class MessagePredicate(Callable[[discord.Message], bool]): def same_context( cls, ctx: Optional[commands.Context] = None, - channel: Optional[ - Union[discord.TextChannel, discord.VoiceChannel, discord.Thread, discord.DMChannel] - ] = None, + channel: Optional[discord.abc.Messageable] = None, user: Optional[discord.abc.User] = None, ) -> "MessagePredicate": """Match if the message fits the described context. @@ -78,8 +77,8 @@ class MessagePredicate(Callable[[discord.Message], bool]): ---------- ctx : Optional[Context] The current invocation context. - channel : Optional[Union[`discord.TextChannel`, `discord.VoiceChannel`, `discord.Thread`, `discord.DMChannel`]] - The channel we expect a message in. If unspecified, + channel : Optional[discord.abc.Messageable] + The messageable object we expect a message in. If unspecified, defaults to ``ctx.channel``. If ``ctx`` is unspecified too, the message's channel will be ignored. user : Optional[discord.abc.User] @@ -93,22 +92,34 @@ class MessagePredicate(Callable[[discord.Message], bool]): The event predicate. """ + check_dm_channel = False + # using dpy_commands.Context to keep the Messageable contract in full + if isinstance(channel, dpy_commands.Context): + channel = channel.channel + elif isinstance(channel, (discord.User, discord.Member)): + check_dm_channel = True + if ctx is not None: channel = channel or ctx.channel user = user or ctx.author return cls( lambda self, m: (user is None or user.id == m.author.id) - and (channel is None or channel.id == m.channel.id) + and ( + channel is None + or ( + channel.id == m.author.id and isinstance(m.channel, discord.DMChannel) + if check_dm_channel + else channel.id == m.channel.id + ) + ) ) @classmethod def cancelled( cls, ctx: Optional[commands.Context] = None, - channel: Optional[ - Union[discord.TextChannel, discord.VoiceChannel, discord.Thread, discord.DMChannel] - ] = None, + channel: Optional[discord.abc.Messageable] = None, user: Optional[discord.abc.User] = None, ) -> "MessagePredicate": """Match if the message is ``[p]cancel``. @@ -117,7 +128,7 @@ class MessagePredicate(Callable[[discord.Message], bool]): ---------- ctx : Optional[Context] Same as ``ctx`` in :meth:`same_context`. - channel : Optional[Union[`discord.TextChannel`, `discord.VoiceChannel`, `discord.Thread`, `discord.DMChannel`]] + channel : Optional[discord.abc.Messageable] Same as ``channel`` in :meth:`same_context`. user : Optional[discord.abc.User] Same as ``user`` in :meth:`same_context`. @@ -137,9 +148,7 @@ class MessagePredicate(Callable[[discord.Message], bool]): def yes_or_no( cls, ctx: Optional[commands.Context] = None, - channel: Optional[ - Union[discord.TextChannel, discord.VoiceChannel, discord.Thread, discord.DMChannel] - ] = None, + channel: Optional[discord.abc.Messageable] = None, user: Optional[discord.abc.User] = None, ) -> "MessagePredicate": """Match if the message is "yes"/"y" or "no"/"n". @@ -151,7 +160,7 @@ class MessagePredicate(Callable[[discord.Message], bool]): ---------- ctx : Optional[Context] Same as ``ctx`` in :meth:`same_context`. - channel : Optional[Union[`discord.TextChannel`, `discord.VoiceChannel`, `discord.Thread`, `discord.DMChannel`]] + channel : Optional[discord.abc.Messageable] Same as ``channel`` in :meth:`same_context`. user : Optional[discord.abc.User] Same as ``user`` in :meth:`same_context`. @@ -182,9 +191,7 @@ class MessagePredicate(Callable[[discord.Message], bool]): def valid_int( cls, ctx: Optional[commands.Context] = None, - channel: Optional[ - Union[discord.TextChannel, discord.VoiceChannel, discord.Thread, discord.DMChannel] - ] = None, + channel: Optional[discord.abc.Messageable] = None, user: Optional[discord.abc.User] = None, ) -> "MessagePredicate": """Match if the response is an integer. @@ -195,7 +202,7 @@ class MessagePredicate(Callable[[discord.Message], bool]): ---------- ctx : Optional[Context] Same as ``ctx`` in :meth:`same_context`. - channel : Optional[Union[`discord.TextChannel`, `discord.VoiceChannel`, `discord.Thread`, `discord.DMChannel`]] + channel : Optional[discord.abc.Messageable] Same as ``channel`` in :meth:`same_context`. user : Optional[discord.abc.User] Same as ``user`` in :meth:`same_context`. @@ -224,9 +231,7 @@ class MessagePredicate(Callable[[discord.Message], bool]): def valid_float( cls, ctx: Optional[commands.Context] = None, - channel: Optional[ - Union[discord.TextChannel, discord.VoiceChannel, discord.Thread, discord.DMChannel] - ] = None, + channel: Optional[discord.abc.Messageable] = None, user: Optional[discord.abc.User] = None, ) -> "MessagePredicate": """Match if the response is a float. @@ -237,7 +242,7 @@ class MessagePredicate(Callable[[discord.Message], bool]): ---------- ctx : Optional[Context] Same as ``ctx`` in :meth:`same_context`. - channel : Optional[Union[`discord.TextChannel`, `discord.VoiceChannel`, `discord.Thread`, `discord.DMChannel`]] + channel : Optional[discord.abc.Messageable] Same as ``channel`` in :meth:`same_context`. user : Optional[discord.abc.User] Same as ``user`` in :meth:`same_context`. @@ -266,9 +271,7 @@ class MessagePredicate(Callable[[discord.Message], bool]): def positive( cls, ctx: Optional[commands.Context] = None, - channel: Optional[ - Union[discord.TextChannel, discord.VoiceChannel, discord.Thread, discord.DMChannel] - ] = None, + channel: Optional[discord.abc.Messageable] = None, user: Optional[discord.abc.User] = None, ) -> "MessagePredicate": """Match if the response is a positive number. @@ -279,7 +282,7 @@ class MessagePredicate(Callable[[discord.Message], bool]): ---------- ctx : Optional[Context] Same as ``ctx`` in :meth:`same_context`. - channel : Optional[Union[`discord.TextChannel`, `discord.VoiceChannel`, `discord.Thread`, `discord.DMChannel`]] + channel : Optional[discord.abc.Messageable] Same as ``channel`` in :meth:`same_context`. user : Optional[discord.abc.User] Same as ``user`` in :meth:`same_context`. @@ -504,9 +507,7 @@ class MessagePredicate(Callable[[discord.Message], bool]): cls, value: str, ctx: Optional[commands.Context] = None, - channel: Optional[ - Union[discord.TextChannel, discord.VoiceChannel, discord.Thread, discord.DMChannel] - ] = None, + channel: Optional[discord.abc.Messageable] = None, user: Optional[discord.abc.User] = None, ) -> "MessagePredicate": """Match if the response is equal to the specified value. @@ -517,7 +518,7 @@ class MessagePredicate(Callable[[discord.Message], bool]): The value to compare the response with. ctx : Optional[Context] Same as ``ctx`` in :meth:`same_context`. - channel : Optional[Union[`discord.TextChannel`, `discord.VoiceChannel`, `discord.Thread`, `discord.DMChannel`]] + channel : Optional[discord.abc.Messageable] Same as ``channel`` in :meth:`same_context`. user : Optional[discord.abc.User] Same as ``user`` in :meth:`same_context`. @@ -536,9 +537,7 @@ class MessagePredicate(Callable[[discord.Message], bool]): cls, value: str, ctx: Optional[commands.Context] = None, - channel: Optional[ - Union[discord.TextChannel, discord.VoiceChannel, discord.Thread, discord.DMChannel] - ] = None, + channel: Optional[discord.abc.Messageable] = None, user: Optional[discord.abc.User] = None, ) -> "MessagePredicate": """Match if the response *as lowercase* is equal to the specified value. @@ -549,7 +548,7 @@ class MessagePredicate(Callable[[discord.Message], bool]): The value to compare the response with. ctx : Optional[Context] Same as ``ctx`` in :meth:`same_context`. - channel : Optional[Union[`discord.TextChannel`, `discord.VoiceChannel`, `discord.Thread`, `discord.DMChannel`]] + channel : Optional[discord.abc.Messageable] Same as ``channel`` in :meth:`same_context`. user : Optional[discord.abc.User] Same as ``user`` in :meth:`same_context`. @@ -568,9 +567,7 @@ class MessagePredicate(Callable[[discord.Message], bool]): cls, value: Union[int, float], ctx: Optional[commands.Context] = None, - channel: Optional[ - Union[discord.TextChannel, discord.VoiceChannel, discord.Thread, discord.DMChannel] - ] = None, + channel: Optional[discord.abc.Messageable] = None, user: Optional[discord.abc.User] = None, ) -> "MessagePredicate": """Match if the response is less than the specified value. @@ -581,7 +578,7 @@ class MessagePredicate(Callable[[discord.Message], bool]): The value to compare the response with. ctx : Optional[Context] Same as ``ctx`` in :meth:`same_context`. - channel : Optional[Union[`discord.TextChannel`, `discord.VoiceChannel`, `discord.Thread`, `discord.DMChannel`]] + channel : Optional[discord.abc.Messageable] Same as ``channel`` in :meth:`same_context`. user : Optional[discord.abc.User] Same as ``user`` in :meth:`same_context`. @@ -601,9 +598,7 @@ class MessagePredicate(Callable[[discord.Message], bool]): cls, value: Union[int, float], ctx: Optional[commands.Context] = None, - channel: Optional[ - Union[discord.TextChannel, discord.VoiceChannel, discord.Thread, discord.DMChannel] - ] = None, + channel: Optional[discord.abc.Messageable] = None, user: Optional[discord.abc.User] = None, ) -> "MessagePredicate": """Match if the response is greater than the specified value. @@ -614,7 +609,7 @@ class MessagePredicate(Callable[[discord.Message], bool]): The value to compare the response with. ctx : Optional[Context] Same as ``ctx`` in :meth:`same_context`. - channel : Optional[Union[`discord.TextChannel`, `discord.VoiceChannel`, `discord.Thread`, `discord.DMChannel`]] + channel : Optional[discord.abc.Messageable] Same as ``channel`` in :meth:`same_context`. user : Optional[discord.abc.User] Same as ``user`` in :meth:`same_context`. @@ -634,9 +629,7 @@ class MessagePredicate(Callable[[discord.Message], bool]): cls, length: int, ctx: Optional[commands.Context] = None, - channel: Optional[ - Union[discord.TextChannel, discord.VoiceChannel, discord.Thread, discord.DMChannel] - ] = None, + channel: Optional[discord.abc.Messageable] = None, user: Optional[discord.abc.User] = None, ) -> "MessagePredicate": """Match if the response's length is less than the specified length. @@ -647,7 +640,7 @@ class MessagePredicate(Callable[[discord.Message], bool]): The value to compare the response's length with. ctx : Optional[Context] Same as ``ctx`` in :meth:`same_context`. - channel : Optional[Union[`discord.TextChannel`, `discord.VoiceChannel`, `discord.Thread`, `discord.DMChannel`]] + channel : Optional[discord.abc.Messageable] Same as ``channel`` in :meth:`same_context`. user : Optional[discord.abc.User] Same as ``user`` in :meth:`same_context`. @@ -666,9 +659,7 @@ class MessagePredicate(Callable[[discord.Message], bool]): cls, length: int, ctx: Optional[commands.Context] = None, - channel: Optional[ - Union[discord.TextChannel, discord.VoiceChannel, discord.Thread, discord.DMChannel] - ] = None, + channel: Optional[discord.abc.Messageable] = None, user: Optional[discord.abc.User] = None, ) -> "MessagePredicate": """Match if the response's length is greater than the specified length. @@ -679,7 +670,7 @@ class MessagePredicate(Callable[[discord.Message], bool]): The value to compare the response's length with. ctx : Optional[Context] Same as ``ctx`` in :meth:`same_context`. - channel : Optional[Union[`discord.TextChannel`, `discord.VoiceChannel`, `discord.Thread`, `discord.DMChannel`]] + channel : Optional[discord.abc.Messageable] Same as ``channel`` in :meth:`same_context`. user : Optional[discord.abc.User] Same as ``user`` in :meth:`same_context`. @@ -698,9 +689,7 @@ class MessagePredicate(Callable[[discord.Message], bool]): cls, collection: Sequence[str], ctx: Optional[commands.Context] = None, - channel: Optional[ - Union[discord.TextChannel, discord.VoiceChannel, discord.Thread, discord.DMChannel] - ] = None, + channel: Optional[discord.abc.Messageable] = None, user: Optional[discord.abc.User] = None, ) -> "MessagePredicate": """Match if the response is contained in the specified collection. @@ -714,7 +703,7 @@ class MessagePredicate(Callable[[discord.Message], bool]): The collection containing valid responses. ctx : Optional[Context] Same as ``ctx`` in :meth:`same_context`. - channel : Optional[Union[`discord.TextChannel`, `discord.VoiceChannel`, `discord.Thread`, `discord.DMChannel`]] + channel : Optional[discord.abc.Messageable] Same as ``channel`` in :meth:`same_context`. user : Optional[discord.abc.User] Same as ``user`` in :meth:`same_context`. @@ -744,9 +733,7 @@ class MessagePredicate(Callable[[discord.Message], bool]): cls, collection: Sequence[str], ctx: Optional[commands.Context] = None, - channel: Optional[ - Union[discord.TextChannel, discord.VoiceChannel, discord.Thread, discord.DMChannel] - ] = None, + channel: Optional[discord.abc.Messageable] = None, user: Optional[discord.abc.User] = None, ) -> "MessagePredicate": """Same as :meth:`contained_in`, but the response is set to lowercase before matching. @@ -757,7 +744,7 @@ class MessagePredicate(Callable[[discord.Message], bool]): The collection containing valid lowercase responses. ctx : Optional[Context] Same as ``ctx`` in :meth:`same_context`. - channel : Optional[Union[`discord.TextChannel`, `discord.VoiceChannel`, `discord.Thread`, `discord.DMChannel`]] + channel : Optional[discord.abc.Messageable] Same as ``channel`` in :meth:`same_context`. user : Optional[discord.abc.User] Same as ``user`` in :meth:`same_context`. @@ -787,9 +774,7 @@ class MessagePredicate(Callable[[discord.Message], bool]): cls, pattern: Union[Pattern[str], str], ctx: Optional[commands.Context] = None, - channel: Optional[ - Union[discord.TextChannel, discord.VoiceChannel, discord.Thread, discord.DMChannel] - ] = None, + channel: Optional[discord.abc.Messageable] = None, user: Optional[discord.abc.User] = None, ) -> "MessagePredicate": """Match if the response matches the specified regex pattern. @@ -804,7 +789,7 @@ class MessagePredicate(Callable[[discord.Message], bool]): The pattern to search for in the response. ctx : Optional[Context] Same as ``ctx`` in :meth:`same_context`. - channel : Optional[Union[`discord.TextChannel`, `discord.VoiceChannel`, `discord.Thread`, `discord.DMChannel`]] + channel : Optional[discord.abc.Messageable] Same as ``channel`` in :meth:`same_context`. user : Optional[discord.abc.User] Same as ``user`` in :meth:`same_context`.