diff --git a/docs/framework_utils.rst b/docs/framework_utils.rst index f9573b05b..5b7037809 100644 --- a/docs/framework_utils.rst +++ b/docs/framework_utils.rst @@ -22,12 +22,18 @@ Embed Helpers .. automodule:: redbot.core.utils.embed :members: -Menu Helpers -============ +Reaction Menus +============== .. automodule:: redbot.core.utils.menus :members: +Event Predicates +================ + +.. automodule:: redbot.core.utils.predicates + :members: + Mod Helpers =========== diff --git a/redbot/cogs/audio/audio.py b/redbot/cogs/audio/audio.py index 1b8d5f2c5..dd17228d8 100644 --- a/redbot/cogs/audio/audio.py +++ b/redbot/cogs/audio/audio.py @@ -13,8 +13,16 @@ import time import redbot.core from redbot.core import Config, commands, checks, bank from redbot.core.data_manager import cog_data_path -from redbot.core.utils.menus import menu, DEFAULT_CONTROLS, prev_page, next_page, close_menu from redbot.core.i18n import Translator, cog_i18n +from redbot.core.utils.menus import ( + menu, + DEFAULT_CONTROLS, + prev_page, + next_page, + close_menu, + start_adding_reactions, +) +from redbot.core.utils.predicates import MessagePredicate, ReactionPredicate from urllib.parse import urlparse from .manager import shutdown_lavalink_server @@ -225,22 +233,17 @@ class Audio(commands.Cog): async def dj(self, ctx): """Toggle DJ mode (users need a role to use audio commands).""" dj_role_id = await self.config.guild(ctx.guild).dj_role() - if dj_role_id is None: + if dj_role_id is None and ctx.guild.get_role(dj_role_id): await self._embed_msg( - ctx, "Please set a role to use with DJ mode. Enter the role name now." + ctx, "Please set a role to use with DJ mode. Enter the role name or ID now." ) - def check(m): - return m.author == ctx.author - try: - dj_role = await ctx.bot.wait_for("message", timeout=15.0, check=check) - dj_role_obj = discord.utils.get(ctx.guild.roles, name=dj_role.content) - if dj_role_obj is None: - return await self._embed_msg(ctx, "No role with that name.") - await ctx.invoke(self.role, dj_role_obj) + pred = MessagePredicate.valid_role(ctx) + await ctx.bot.wait_for("message", timeout=15.0, check=pred) + await ctx.invoke(self.role, pred.result) except asyncio.TimeoutError: - return await self._embed_msg(ctx, "No role entered, try again later.") + return await self._embed_msg(ctx, "Response timed out, try again later.") dj_enabled = await self.config.guild(ctx.guild).dj_enabled() await self.config.guild(ctx.guild).dj_enabled.set(not dj_enabled) @@ -710,20 +713,21 @@ class Audio(commands.Cog): return if player.current: - for i in range(4): - await message.add_reaction(expected[i]) - - def check(r, u): - return ( - r.message.id == message.id - and u == ctx.message.author - and any(e in str(r.emoji) for e in expected) - ) + task = start_adding_reactions(message, expected[:4], ctx.bot.loop) + else: + task = None try: - (r, u) = await self.bot.wait_for("reaction_add", check=check, timeout=10.0) + (r, u) = await self.bot.wait_for( + "reaction_add", + check=ReactionPredicate.with_emojis(expected, message, ctx.author), + timeout=10.0, + ) except asyncio.TimeoutError: return await self._clear_react(message) + else: + if task is not None: + task.cancel() reacts = {v: k for k, v in emoji.items()} react = reacts[r.emoji] if react == "prev": @@ -1125,11 +1129,12 @@ class Audio(commands.Cog): if not playlist_name: await self._embed_msg(ctx, "Please enter a name for this playlist.") - def check(m): - return m.author == ctx.author and not m.content.startswith(ctx.prefix) - try: - playlist_name_msg = await ctx.bot.wait_for("message", timeout=15.0, check=check) + playlist_name_msg = await ctx.bot.wait_for( + "message", + timeout=15.0, + check=MessagePredicate.regex(fr"^(?!{ctx.prefix})", ctx), + ) playlist_name = playlist_name_msg.content.split(" ")[0].strip('"') if len(playlist_name) > 20: return await self._embed_msg(ctx, "Try the command again with a shorter name.") @@ -1238,11 +1243,10 @@ class Audio(commands.Cog): ctx, "Please upload the playlist file. Any other message will cancel this operation." ) - def check(m): - return m.author == ctx.author - try: - file_message = await ctx.bot.wait_for("message", timeout=30.0, check=check) + file_message = await ctx.bot.wait_for( + "message", timeout=30.0, check=MessagePredicate.same_context(ctx) + ) except asyncio.TimeoutError: return await self._embed_msg(ctx, "No file detected, try again later.") try: diff --git a/redbot/cogs/cleanup/cleanup.py b/redbot/cogs/cleanup/cleanup.py index 4feb4de45..7ca041e19 100644 --- a/redbot/cogs/cleanup/cleanup.py +++ b/redbot/cogs/cleanup/cleanup.py @@ -9,6 +9,7 @@ from redbot.core.bot import Red from redbot.core.i18n import Translator, cog_i18n from redbot.core.utils.mod import slow_deletion, mass_purge from redbot.cogs.mod.log import log +from redbot.core.utils.predicates import MessagePredicate _ = Translator("Cleanup", __file__) @@ -31,13 +32,10 @@ class Cleanup(commands.Cog): Tries its best to cleanup after itself if the response is positive. """ - def author_check(message): - return message.author == ctx.author - prompt = await ctx.send( _("Are you sure you want to delete {} messages? (y/n)").format(number) ) - response = await ctx.bot.wait_for("message", check=author_check) + response = await ctx.bot.wait_for("message", check=MessagePredicate.same_context(ctx)) if response.content.lower().startswith("y"): await prompt.delete() diff --git a/redbot/cogs/customcom/customcom.py b/redbot/cogs/customcom/customcom.py index df947d3c5..ab76794ec 100644 --- a/redbot/cogs/customcom/customcom.py +++ b/redbot/cogs/customcom/customcom.py @@ -11,6 +11,7 @@ import discord from redbot.core import Config, checks, commands from redbot.core.utils.chat_formatting import box, pagify from redbot.core.i18n import Translator, cog_i18n +from redbot.core.utils.predicates import MessagePredicate _ = Translator("CustomCommands", __file__) @@ -58,14 +59,11 @@ class CommandObj: ).format("customcommand", "customcommand", "exit()") await ctx.send(intro) - def check(m): - return m.channel == ctx.channel and m.author == ctx.message.author - responses = [] args = None while True: await ctx.send(_("Add a random response:")) - msg = await self.bot.wait_for("message", check=check) + msg = await self.bot.wait_for("message", check=MessagePredicate.same_context(ctx)) if msg.content.lower() == "exit()": break @@ -130,18 +128,27 @@ class CommandObj: author = ctx.message.author - def check(m): - return m.channel == ctx.channel and m.author == ctx.author - if ask_for and not response: - await ctx.send(_("Do you want to create a 'randomized' cc? {}").format("y/n")) + await ctx.send(_("Do you want to create a 'randomized' custom command? (y/n)")) - msg = await self.bot.wait_for("message", check=check) - if msg.content.lower() == "y": + pred = MessagePredicate.yes_or_no(ctx) + try: + await self.bot.wait_for("message", check=pred, timeout=30) + except TimeoutError: + await ctx.send(_("Response timed out, please try again later.")) + return + if pred.result is True: response = await self.get_responses(ctx=ctx) else: await ctx.send(_("What response do you want?")) - response = (await self.bot.wait_for("message", check=check)).content + try: + resp = await self.bot.wait_for( + "message", check=MessagePredicate.same_context(ctx), timeout=180 + ) + except TimeoutError: + await ctx.send(_("Response timed out, please try again later.")) + return + response = resp.content if response: # test to raise diff --git a/redbot/cogs/dataconverter/dataconverter.py b/redbot/cogs/dataconverter/dataconverter.py index f2dc05cd8..93696800b 100644 --- a/redbot/cogs/dataconverter/dataconverter.py +++ b/redbot/cogs/dataconverter/dataconverter.py @@ -6,6 +6,7 @@ from redbot.core.bot import Red from redbot.core.i18n import Translator, cog_i18n from redbot.cogs.dataconverter.core_specs import SpecResolver from redbot.core.utils.chat_formatting import box +from redbot.core.utils.predicates import MessagePredicate _ = Translator("DataConverter", __file__) @@ -48,11 +49,10 @@ class DataConverter(commands.Cog): menu_message = await ctx.send(box(menu)) - def pred(m): - return m.channel == ctx.channel and m.author == ctx.author - try: - message = await self.bot.wait_for("message", check=pred, timeout=60) + message = await self.bot.wait_for( + "message", check=MessagePredicate.same_context(ctx), timeout=60 + ) except asyncio.TimeoutError: return await ctx.send(_("Try this again when you are more ready")) else: diff --git a/redbot/cogs/downloader/checks.py b/redbot/cogs/downloader/checks.py index e5dc2e88c..33bd192d7 100644 --- a/redbot/cogs/downloader/checks.py +++ b/redbot/cogs/downloader/checks.py @@ -1,7 +1,7 @@ import asyncio -import discord from redbot.core import commands +from redbot.core.utils.predicates import MessagePredicate __all__ = ["do_install_agreement"] @@ -21,13 +21,12 @@ async def do_install_agreement(ctx: commands.Context): if downloader is None or downloader.already_agreed: return True - def does_agree(msg: discord.Message): - return ctx.author == msg.author and ctx.channel == msg.channel and msg.content == "I agree" - await ctx.send(REPO_INSTALL_MSG) try: - await ctx.bot.wait_for("message", check=does_agree, timeout=30) + await ctx.bot.wait_for( + "message", check=MessagePredicate.lower_equal_to("i agree", ctx), timeout=30 + ) except asyncio.TimeoutError: await ctx.send("Your response has timed out, please try again.") return False diff --git a/redbot/cogs/permissions/permissions.py b/redbot/cogs/permissions/permissions.py index 61e821760..48346b8d2 100644 --- a/redbot/cogs/permissions/permissions.py +++ b/redbot/cogs/permissions/permissions.py @@ -11,6 +11,8 @@ from redbot.core import checks, commands, config from redbot.core.bot import Red from redbot.core.i18n import Translator, cog_i18n from redbot.core.utils.chat_formatting import box +from redbot.core.utils.menus import start_adding_reactions +from redbot.core.utils.predicates import ReactionPredicate, MessagePredicate from .converters import CogOrCommand, RuleType, ClearableRuleType @@ -20,9 +22,6 @@ COG = "COG" COMMAND = "COMMAND" GLOBAL = 0 -# noinspection PyDictDuplicateKeys -REACTS = {"\N{WHITE HEAVY CHECK MARK}": True, "\N{NEGATIVE SQUARED CROSS MARK}": False} -Y_OR_N = {"y": True, "yes": True, "n": False, "no": False} # The strings in the schema are constants and should get extracted, but not translated until # runtime. translate = _ @@ -566,35 +565,29 @@ class Permissions(commands.Cog): """Ask "Are you sure?" and get the response as a bool.""" if ctx.guild is None or ctx.guild.me.permissions_in(ctx.channel).add_reactions: msg = await ctx.send(_("Are you sure?")) - for emoji in REACTS.keys(): - await msg.add_reaction(emoji) + # noinspection PyAsyncCall + task = start_adding_reactions(msg, ReactionPredicate.YES_OR_NO_EMOJIS, ctx.bot.loop) + pred = ReactionPredicate.yes_or_no(msg, ctx.author) try: - reaction, user = await ctx.bot.wait_for( - "reaction_add", - check=lambda r, u: ( - r.message.id == msg.id and u == ctx.author and r.emoji in REACTS - ), - timeout=30, - ) + await ctx.bot.wait_for("reaction_add", check=pred, timeout=30) except asyncio.TimeoutError: - agreed = False + await ctx.send(_("Response timed out.")) + return False else: - agreed = REACTS.get(reaction.emoji) + task.cancel() + agreed = pred.result + finally: await msg.delete() else: await ctx.send(_("Are you sure? (y/n)")) + pred = MessagePredicate.yes_or_no(ctx) try: - message = await ctx.bot.wait_for( - "message", - check=lambda m: m.author == ctx.author - and m.channel == ctx.channel - and m.content in Y_OR_N, - timeout=30, - ) + await ctx.bot.wait_for("message", check=pred, timeout=30) except asyncio.TimeoutError: - agreed = False + await ctx.send(_("Response timed out.")) + return False else: - agreed = Y_OR_N.get(message.content.lower()) + agreed = pred.result if agreed is False: await ctx.send(_("Action cancelled.")) diff --git a/redbot/cogs/reports/reports.py b/redbot/cogs/reports/reports.py index 1ff190dcb..08cfc1d8a 100644 --- a/redbot/cogs/reports/reports.py +++ b/redbot/cogs/reports/reports.py @@ -11,6 +11,7 @@ from redbot.core.utils.chat_formatting import pagify, box from redbot.core.utils.antispam import AntiSpam from redbot.core.bot import Red from redbot.core.i18n import Translator, cog_i18n +from redbot.core.utils.predicates import MessagePredicate from redbot.core.utils.tunnel import Tunnel @@ -136,13 +137,14 @@ class Reports(commands.Cog): output += "\n{}".format(prompt) for page in pagify(output, delims=["\n"]): - dm = await author.send(box(page)) - - def pred(m): - return m.author == author and m.channel == dm.channel + await author.send(box(page)) try: - message = await self.bot.wait_for("message", check=pred, timeout=45) + message = await self.bot.wait_for( + "message", + check=MessagePredicate.same_context(channel=author.dm_channel, user=author), + timeout=45, + ) except asyncio.TimeoutError: await author.send(_("You took too long to select. Try again later.")) return None @@ -247,7 +249,7 @@ class Reports(commands.Cog): val = await self.send_report(_m, guild) else: try: - dm = await author.send( + await author.send( _( "Please respond to this message with your Report." "\nYour report should be a single message" @@ -256,11 +258,12 @@ class Reports(commands.Cog): except discord.Forbidden: return await ctx.send(_("This requires DMs enabled.")) - def pred(m): - return m.author == author and m.channel == dm.channel - try: - message = await self.bot.wait_for("message", check=pred, timeout=180) + message = await self.bot.wait_for( + "message", + check=MessagePredicate.same_context(ctx, channel=author.dm_channel), + timeout=180, + ) except asyncio.TimeoutError: return await author.send(_("You took too long. Try again later.")) else: diff --git a/redbot/cogs/warnings/helpers.py b/redbot/cogs/warnings/helpers.py index 2ea504878..39aae8739 100644 --- a/redbot/cogs/warnings/helpers.py +++ b/redbot/cogs/warnings/helpers.py @@ -5,6 +5,7 @@ import discord from redbot.core import Config, checks, commands from redbot.core.i18n import Translator +from redbot.core.utils.predicates import MessagePredicate _ = Translator("Warnings", __file__) @@ -95,11 +96,10 @@ async def get_command_for_exceeded_points(ctx: commands.Context): await ctx.send(_("You may enter your response now.")) - def same_author_check(m): - return m.author == ctx.author - try: - msg = await ctx.bot.wait_for("message", check=same_author_check, timeout=30) + msg = await ctx.bot.wait_for( + "message", check=MessagePredicate.same_context(ctx), timeout=30 + ) except asyncio.TimeoutError: return None else: @@ -140,11 +140,10 @@ async def get_command_for_dropping_points(ctx: commands.Context): await ctx.send(_("You may enter your response now.")) - def same_author_check(m): - return m.author == ctx.author - try: - msg = await ctx.bot.wait_for("message", check=same_author_check, timeout=30) + msg = await ctx.bot.wait_for( + "message", check=MessagePredicate.same_context(ctx), timeout=30 + ) except asyncio.TimeoutError: return None else: diff --git a/redbot/cogs/warnings/warnings.py b/redbot/cogs/warnings/warnings.py index 75169e1b9..77937b645 100644 --- a/redbot/cogs/warnings/warnings.py +++ b/redbot/cogs/warnings/warnings.py @@ -15,6 +15,7 @@ from redbot.core.i18n import Translator, cog_i18n from redbot.core.utils.mod import is_admin_or_superior from redbot.core.utils.chat_formatting import warning, pagify from redbot.core.utils.menus import menu, DEFAULT_CONTROLS +from redbot.core.utils.predicates import MessagePredicate _ = Translator("Warnings", __file__) @@ -363,12 +364,11 @@ class Warnings(commands.Cog): """Handles getting description and points for custom reasons""" to_add = {"points": 0, "description": ""} - def same_author_check(m): - return m.author == ctx.author - await ctx.send(_("How many points should be given for this reason?")) try: - msg = await ctx.bot.wait_for("message", check=same_author_check, timeout=30) + msg = await ctx.bot.wait_for( + "message", check=MessagePredicate.same_context(ctx), timeout=30 + ) except asyncio.TimeoutError: await ctx.send(_("Ok then.")) return @@ -385,7 +385,9 @@ class Warnings(commands.Cog): await ctx.send(_("Enter a description for this reason.")) try: - msg = await ctx.bot.wait_for("message", check=same_author_check, timeout=30) + msg = await ctx.bot.wait_for( + "message", check=MessagePredicate.same_context(ctx), timeout=30 + ) except asyncio.TimeoutError: await ctx.send(_("Ok then.")) return diff --git a/redbot/core/commands/context.py b/redbot/core/commands/context.py index 65b5a862a..7398b80c0 100644 --- a/redbot/core/commands/context.py +++ b/redbot/core/commands/context.py @@ -6,6 +6,7 @@ from discord.ext import commands from .requires import PermState from ..utils.chat_formatting import box +from ..utils.predicates import MessagePredicate from ..utils import common_filters TICK = "\N{WHITE HEAVY CHECK MARK}" @@ -141,10 +142,6 @@ class Context(commands.Context): messages = tuple(messages) ret = [] - more_check = lambda m: ( - m.author == self.author and m.channel == self.channel and m.content.lower() == "more" - ) - for idx, page in enumerate(messages, 1): if box_lang is None: msg = await self.send(page) @@ -165,7 +162,11 @@ class Context(commands.Context): "".format(is_are, n_remaining, plural) ) try: - resp = await self.bot.wait_for("message", check=more_check, timeout=timeout) + resp = await self.bot.wait_for( + "message", + check=MessagePredicate.lower_equal_to("more", self), + timeout=timeout, + ) except asyncio.TimeoutError: await query.delete() break @@ -175,7 +176,7 @@ class Context(commands.Context): except (discord.HTTPException, AttributeError): # In case the bot can't delete other users' messages, # or is not a bot account - # or chanel is a DM + # or channel is a DM await query.delete() return ret diff --git a/redbot/core/core_commands.py b/redbot/core/core_commands.py index 284ae48fe..a70582063 100644 --- a/redbot/core/core_commands.py +++ b/redbot/core/core_commands.py @@ -24,6 +24,7 @@ from redbot.core import __version__ from redbot.core import checks from redbot.core import i18n from redbot.core import commands +from .utils.predicates import MessagePredicate from .utils.chat_formatting import pagify, box, inline if TYPE_CHECKING: @@ -438,73 +439,63 @@ class Core(commands.Cog, CoreLogic): @checks.is_owner() async def leave(self, ctx): """Leaves server""" - author = ctx.author - guild = ctx.guild + await ctx.send("Are you sure you want me to leave this server? (y/n)") - await ctx.send("Are you sure you want me to leave this server? Type yes to confirm.") - - def conf_check(m): - return m.author == author - - response = await self.bot.wait_for("message", check=conf_check) - - if response.content.lower().strip() == "yes": - await ctx.send("Alright. Bye :wave:") - log.debug("Leaving '{}'".format(guild.name)) - await guild.leave() + pred = MessagePredicate.yes_or_no(ctx) + try: + await self.bot.wait_for("message", check=MessagePredicate.yes_or_no(ctx)) + except asyncio.TimeoutError: + await ctx.send("Response timed out.") + return + else: + if pred.result is True: + await ctx.send("Alright. Bye :wave:") + log.debug("Leaving guild '{}'".format(ctx.guild.name)) + await ctx.guild.leave() + else: + await ctx.send("Alright, I'll stay then :)") @commands.command() @checks.is_owner() async def servers(self, ctx): """Lists and allows to leave servers""" - owner = ctx.author guilds = sorted(list(self.bot.guilds), key=lambda s: s.name.lower()) msg = "" + responses = [] for i, server in enumerate(guilds, 1): msg += "{}: {}\n".format(i, server.name) - - msg += "\nTo leave a server, just type its number." + responses.append(str(i)) for page in pagify(msg, ["\n"]): await ctx.send(page) - def msg_check(m): - return m.author == owner - - while msg is not None: - try: - msg = await self.bot.wait_for("message", check=msg_check, timeout=15) - except asyncio.TimeoutError: - await ctx.send("I guess not.") - break - try: - msg = int(msg.content) - 1 - if msg < 0: - break - await self.leave_confirmation(guilds[msg], owner, ctx) - break - except (IndexError, ValueError, AttributeError): - pass - - async def leave_confirmation(self, server, owner, ctx): - await ctx.send("Are you sure you want me to leave {}? (yes/no)".format(server.name)) - - def conf_check(m): - return m.author == owner + query = await ctx.send("To leave a server, just type its number.") + pred = MessagePredicate.contained_in(responses, ctx) try: - msg = await self.bot.wait_for("message", check=conf_check, timeout=15) - if msg.content.lower().strip() in ("yes", "y"): - if server.owner == ctx.bot.user: - await ctx.send("I cannot leave a guild I am the owner of.") - return - await server.leave() - if server != ctx.guild: + await self.bot.wait_for("message", check=pred, timeout=15) + except asyncio.TimeoutError: + await query.delete() + else: + await self.leave_confirmation(guilds[pred.result], ctx) + + async def leave_confirmation(self, guild, ctx): + if guild.owner.id == ctx.bot.user.id: + await ctx.send("I cannot leave a guild I am the owner of.") + return + + await ctx.send("Are you sure you want me to leave {}? (yes/no)".format(guild.name)) + pred = MessagePredicate.yes_or_no(ctx) + try: + await self.bot.wait_for("message", check=pred, timeout=15) + if pred.result is True: + await guild.leave() + if guild != ctx.guild: await ctx.send("Done.") else: await ctx.send("Alright then.") except asyncio.TimeoutError: - await ctx.send("I guess not.") + await ctx.send("Response timed out.") @commands.command() @checks.is_owner() @@ -892,10 +883,6 @@ class Core(commands.Cog, CoreLogic): @commands.cooldown(1, 60 * 10, commands.BucketType.default) async def owner(self, ctx): """Sets Red's main owner""" - - def check(m): - return m.author == ctx.author and m.channel == ctx.channel - # According to the Python docs this is suitable for cryptographic use random = SystemRandom() length = random.randint(25, 35) @@ -919,10 +906,14 @@ class Core(commands.Cog, CoreLogic): ) try: - message = await ctx.bot.wait_for("message", check=check, timeout=60) + message = await ctx.bot.wait_for( + "message", check=MessagePredicate.same_context(ctx), timeout=60 + ) except asyncio.TimeoutError: self.owner.reset_cooldown(ctx) - await ctx.send(_("The set owner request has timed out.")) + await ctx.send( + _("The `{prefix}set owner` request has timed out.").format(prefix=ctx.prefix) + ) else: if message.content.strip() == token: self.owner.reset_cooldown(ctx) @@ -1146,18 +1137,20 @@ class Core(commands.Cog, CoreLogic): ) await ctx.send(_("Would you like to receive a copy via DM? (y/n)")) - def same_author_check(m): - return m.author == ctx.author and m.channel == ctx.channel - + pred = MessagePredicate.yes_or_no(ctx) try: - msg = await ctx.bot.wait_for("message", check=same_author_check, timeout=60) + await ctx.bot.wait_for("message", check=pred, timeout=60) except asyncio.TimeoutError: - await ctx.send(_("Ok then.")) + await ctx.send(_("Response timed out.")) else: - if msg.content.lower().strip() == "y": - await ctx.author.send( - _("Here's a copy of the backup"), file=discord.File(str(backup_file)) - ) + if pred.result is True: + await ctx.send(_("OK, it's on its way!")) + async with ctx.author.typing(): + await ctx.author.send( + _("Here's a copy of the backup"), file=discord.File(str(backup_file)) + ) + else: + await ctx.send(_("OK then.")) else: await ctx.send(_("That directory doesn't seem to exist...")) diff --git a/redbot/core/dev_commands.py b/redbot/core/dev_commands.py index 58abb2463..1b17befa9 100644 --- a/redbot/core/dev_commands.py +++ b/redbot/core/dev_commands.py @@ -8,9 +8,11 @@ from contextlib import redirect_stdout from copy import copy import discord + from . import checks, commands from .i18n import Translator from .utils.chat_formatting import box, pagify +from .utils.predicates import MessagePredicate """ Notice: @@ -218,12 +220,8 @@ class Dev(commands.Cog): self.sessions.add(ctx.channel.id) await ctx.send(_("Enter code to execute or evaluate. `exit()` or `quit` to exit.")) - msg_check = lambda m: ( - m.author == ctx.author and m.channel == ctx.channel and m.content.startswith("`") - ) - while True: - response = await ctx.bot.wait_for("message", check=msg_check) + response = await ctx.bot.wait_for("message", check=MessagePredicate.regex(r"^`", ctx)) cleaned = self.cleanup_code(response.content) diff --git a/redbot/core/utils/menus.py b/redbot/core/utils/menus.py index e180ce039..2c9a22f0c 100644 --- a/redbot/core/utils/menus.py +++ b/redbot/core/utils/menus.py @@ -1,15 +1,14 @@ -""" -Original source of reaction-based menu idea from -https://github.com/Lunar-Dust/Dusty-Cogs/blob/master/menu/menu.py - -Ported to Red V3 by Palm\_\_ (https://github.com/palmtree5) -""" +# Original source of reaction-based menu idea from +# https://github.com/Lunar-Dust/Dusty-Cogs/blob/master/menu/menu.py +# +# Ported to Red V3 by Palm\_\_ (https://github.com/palmtree5) import asyncio import contextlib -from typing import Union, Iterable +from typing import Union, Iterable, Optional import discord -from redbot.core import commands +from .. import commands +from .predicates import ReactionPredicate _ReactableEmoji = Union[str, discord.Emoji] @@ -71,18 +70,20 @@ async def menu( else: message = await ctx.send(current_page) # Don't wait for reactions to be added (GH-1797) - ctx.bot.loop.create_task(_add_menu_reactions(message, controls.keys())) + # noinspection PyAsyncCall + start_adding_reactions(message, controls.keys(), ctx.bot.loop) else: if isinstance(current_page, discord.Embed): await message.edit(embed=current_page) else: await message.edit(content=current_page) - def react_check(r, u): - return u == ctx.author and r.message.id == message.id and str(r.emoji) in controls.keys() - try: - react, user = await ctx.bot.wait_for("reaction_add", check=react_check, timeout=timeout) + react, user = await ctx.bot.wait_for( + "reaction_add", + check=ReactionPredicate.with_emojis(tuple(controls.keys()), message, ctx.author), + timeout=timeout, + ) except asyncio.TimeoutError: try: await message.clear_reactions() @@ -152,12 +153,51 @@ async def close_menu( return None -async def _add_menu_reactions(message: discord.Message, emojis: Iterable[_ReactableEmoji]): - """Add the reactions""" - # The task should exit silently if the message is deleted - with contextlib.suppress(discord.NotFound): - for emoji in emojis: - await message.add_reaction(emoji) +def start_adding_reactions( + message: discord.Message, + emojis: Iterable[_ReactableEmoji], + loop: Optional[asyncio.AbstractEventLoop] = None, +) -> asyncio.Task: + """Start adding reactions to a message. + + This is a non-blocking operation - calling this will schedule the + reactions being added, but will the calling code will continue to + execute asynchronously. There is no need to await this function. + + This is particularly useful if you wish to start waiting for a + reaction whilst the reactions are still being added - in fact, + this is exactly what `menu` uses to do that. + + This spawns a `asyncio.Task` object and schedules it on ``loop``. + If ``loop`` omitted, the loop will be retreived with + `asyncio.get_event_loop`. + + Parameters + ---------- + message: discord.Message + The message to add reactions to. + emojis : Iterable[Union[str, discord.Emoji]] + The emojis to react to the message with. + loop : Optional[asyncio.AbstractEventLoop] + The event loop. + + Returns + ------- + asyncio.Task + The task for the coroutine adding the reactions. + + """ + + async def task(): + # The task should exit silently if the message is deleted + with contextlib.suppress(discord.NotFound): + for emoji in emojis: + await message.add_reaction(emoji) + + if loop is None: + loop = asyncio.get_event_loop() + + return loop.create_task(task()) DEFAULT_CONTROLS = {"⬅": prev_page, "❌": close_menu, "➡": next_page} diff --git a/redbot/core/utils/predicates.py b/redbot/core/utils/predicates.py index d117466d9..4df7154ee 100644 --- a/redbot/core/utils/predicates.py +++ b/redbot/core/utils/predicates.py @@ -1,149 +1,1015 @@ +import re +from typing import Callable, ClassVar, List, Optional, Pattern, Sequence, Tuple, Union, cast + import discord -from collections import Iterable + +from redbot.core import commands + +_ID_RE = re.compile(r"([0-9]{15,21})$") +_USER_MENTION_RE = re.compile(r"<@!?([0-9]{15,21})>$") +_CHAN_MENTION_RE = re.compile(r"<#([0-9]{15,21})>$") +_ROLE_MENTION_RE = re.compile(r"<&([0-9]{15,21})>$") -class MessagePredicate: - """A simple collection of predicates. +class MessagePredicate(Callable[[discord.Message], bool]): + """A simple collection of predicates for message events. - These predicates were made to help simplify checks in message events and - reduce boilerplate code. + These predicates intend to help simplify checks in message events + and reduce boilerplate code. - For examples: - # valid yes or no response - `ctx.bot.wait_for('message', timeout=15.0, check=Predicate(ctx).confirm)` + This class should be created through the provided classmethods. + Instances of this class are callable message predicates, i.e. they + return ``True`` if a message matches the criteria. - # check if message content in under 2000 characters - `check = Predicate(ctx, length=2000).length_under - ctx.bot.wait_for('message', timeout=15.0, check=check)` + All predicates are combined with :meth:`MessagePredicate.same_context`. + Examples + -------- + Waiting for a response in the same channel and from the same + author:: + + await bot.wait_for("message", check=MessagePredicate.same_context(ctx)) + + Waiting for a response to a yes or no question:: + + pred = MessagePredicate.yes_or_no(ctx) + await bot.wait_for("message", check=pred) + if pred.result is True: + # User responded "yes" + ... + + Getting a member object from a user's response:: + + pred = MessagePredicate.valid_member(ctx) + await bot.wait_for("message", check=pred) + member = pred.result Attributes ---------- - ctx - Context object. - collection : `Iterable` - Optional argument used for checking if the message content is inside the - declared collection. - length : `int` - Optional argument for comparing message lengths. - value - Optional argument that can be either a string, int, float, or object. - Used for comparison and equality. - - Returns - ------- - Boolean or it will raise a ValueError if you use a certain methods without an argument or - the value argument is set to an invalid type for a particular method. + result : Any + The object which the message content matched with. This is + dependent on the predicate used - see each predicate's + documentation for details, not every method will assign this + attribute. Defaults to ``None``. """ - def __init__(self, ctx, collection: Iterable = None, length: int = None, value=None): - self.ctx = ctx - self.collection = collection - self.length = length - self.value = value + def __init__(self, predicate: Callable[["MessagePredicate", discord.Message], bool]) -> None: + self._pred: Callable[["MessagePredicate", discord.Message], bool] = predicate + self.result = None - def valid_source(self, m): - return self.same(m) and self.channel(m) + def __call__(self, message: discord.Message) -> bool: + return self._pred(self, message) - def same(self, m): - """Checks if the author of the message is the same as the command issuer.""" - return self.ctx.author == m.author + @classmethod + def same_context( + cls, + ctx: Optional[commands.Context] = None, + channel: Optional[discord.TextChannel] = None, + user: Optional[discord.abc.User] = None, + ) -> "MessagePredicate": + """Match if the reaction fits the described context. - def channel(self, m): - """Verifies the message was sent from the same channel.""" - return self.ctx.channel == m.channel + Parameters + ---------- + ctx : Optional[Context] + The current invokation context. + channel : Optional[discord.TextChannel] + The channel we expect a message in. If unspecified, + defaults to ``ctx.channel``. If ``ctx`` is unspecified + too, the message's channel will be ignored. + user : Optional[discord.TextChannel] + The user we expect a message from. If unspecified, + defaults to ``ctx.author``. If ``ctx`` is unspecified + too, the message's author will be ignored. - def cancelled(self, m): - if self.valid_source(m) and m.content.lower() == f"{self.ctx.prefix}cancel": - raise RuntimeError + Returns + ------- + MessagePredicate + The event predicate. - def confirm(self, m): - """Checks if the author of the message is the same as the command issuer.""" - return self.valid_source(m) and m.content.lower() in ("yes", "no", "y", "n") + """ + if ctx is not None: + channel = channel or ctx.channel + user = user or ctx.author - def valid_int(self, m): - """Returns true if the message content is an integer.""" - return self.valid_source(m) and m.content.isdigit() - - def valid_float(self, m): - """Returns true if the message content is a float.""" - try: - return self.valid_source(m) and float(m.content) >= 1 - except ValueError: - return False - - def positive(self, m): - """Returns true if the message content is an integer and is positive""" - return self.valid_source(m) and m.content.isdigit() and int(m.content) >= 0 - - def valid_role(self, m): - """Returns true if the message content is an existing role on the server.""" - if self.valid_source(m): - if discord.utils.get(self.ctx.guild.roles, name=m.content) is not None: - return True - elif discord.utils.get(self.ctx.guild.roles, id=m.content) is not None: - return True - else: - return False - else: - return False - - def has_role(self, m): - """Returns true if the message content is a role the message sender has.""" - if self.valid_source(m): - if discord.utils.get(self.ctx.roles, name=m.content) is not None: - return True - elif discord.utils.get(self.ctx.roles, id=m.content) is not None: - return True - else: - return False - else: - return False - - def equal(self, m): - """Returns true if the message content is equal to the value set.""" - return self.valid_source(m) and m.content.lower() == self.value.lower() - - def greater(self, m): - """Returns true if the message content is greater than the value set.""" - try: - return self.valid_int(m) or self.valid_float(m) and float(m.content) > int(self.value) - except TypeError: - raise ValueError("Value argument in Predicate() must be an integer or float.") - - def less(self, m): - """Returns true if the message content is less than the value set.""" - try: - return self.valid_int(m) or self.valid_float(m) and float(m.content) < int(self.value) - except TypeError: - raise ValueError("Value argument in Predicate() must be an integer or float.") - - def member(self, m): - """Returns true if the message content is the name of a member in the server.""" - return ( - self.valid_source(m) - and discord.utils.get(self.ctx.guild.members, name=m.content) is not None + return cls( + lambda self, m: (user is None or user.id == m.author.id) + and (channel is None or channel.id == m.channel.id) ) - def length_less(self, m): - """Returns true if the message content length is less than the provided length.""" - try: - return self.valid_source(m) and len(m.content) <= self.length - except TypeError: - raise ValueError("A length must be specified in Predicate().") + @classmethod + def cancelled( + cls, + ctx: Optional[commands.Context] = None, + channel: Optional[discord.TextChannel] = None, + user: Optional[discord.abc.User] = None, + ) -> "MessagePredicate": + """Match if the message is ``[p]cancel``. - def length_greater(self, m): - """Returns true if the message content length is greater than or equal - to the provided length.""" - try: - return self.valid_source(m) and len(m.content) >= self.length - except TypeError: - raise ValueError("A length must be specified in Predicate().") + Parameters + ---------- + ctx : Optional[Context] + Same as ``ctx`` in :meth:`same_context`. + channel : Optional[discord.TextChannel] + Same as ``channel`` in :meth:`same_context`. + user : Optional[discord.TextChannel] + Same as ``user`` in :meth:`same_context`. - def contained(self, m): - """Returns true if the message content is a member of the provided collection.""" - try: - return self.valid_source(m) and m.content.lower() in self.collection - except TypeError: - raise ValueError("An iterable was not specified in Predicate().") + Returns + ------- + MessagePredicate + The event predicate. + + """ + same_context = cls.same_context(ctx, channel, user) + return cls( + lambda self, m: (same_context(m) and m.content.lower() == f"{ctx.prefix}cancel") + ) + + @classmethod + def yes_or_no( + cls, + ctx: Optional[commands.Context] = None, + channel: Optional[discord.TextChannel] = None, + user: Optional[discord.abc.User] = None, + ) -> "MessagePredicate": + """Match if the message is "yes"/"y" or "no"/"n". + + This will assign ``True`` for *yes*, or ``False`` for *no* to + the `result` attribute. + + Parameters + ---------- + ctx : Optional[Context] + Same as ``ctx`` in :meth:`same_context`. + channel : Optional[discord.TextChannel] + Same as ``channel`` in :meth:`same_context`. + user : Optional[discord.TextChannel] + Same as ``user`` in :meth:`same_context`. + + Returns + ------- + MessagePredicate + The event predicate. + + """ + same_context = cls.same_context(ctx, channel, user) + + def predicate(self: MessagePredicate, m: discord.Message) -> bool: + if not same_context(m): + return False + content = m.content.lower() + if content in ("yes", "y"): + self.result = True + elif content in ("no", "n"): + self.result = False + else: + return False + return True + + return cls(predicate) + + @classmethod + def valid_int( + cls, + ctx: Optional[commands.Context] = None, + channel: Optional[discord.TextChannel] = None, + user: Optional[discord.abc.User] = None, + ) -> "MessagePredicate": + """Match if the response is an integer. + + Assigns the response to `result` as an `int`. + + Parameters + ---------- + ctx : Optional[Context] + Same as ``ctx`` in :meth:`same_context`. + channel : Optional[discord.TextChannel] + Same as ``channel`` in :meth:`same_context`. + user : Optional[discord.TextChannel] + Same as ``user`` in :meth:`same_context`. + + Returns + ------- + MessagePredicate + The event predicate. + + """ + same_context = cls.same_context(ctx, channel, user) + + def predicate(self: MessagePredicate, m: discord.Message) -> bool: + if not same_context(m): + return False + try: + self.result = int(m.content) + except ValueError: + return False + else: + return True + + return cls(predicate) + + @classmethod + def valid_float( + cls, + ctx: Optional[commands.Context] = None, + channel: Optional[discord.TextChannel] = None, + user: Optional[discord.abc.User] = None, + ) -> "MessagePredicate": + """Match if the response is a float. + + Assigns the response to `result` as a `float`. + + Parameters + ---------- + ctx : Optional[Context] + Same as ``ctx`` in :meth:`same_context`. + channel : Optional[discord.TextChannel] + Same as ``channel`` in :meth:`same_context`. + user : Optional[discord.TextChannel] + Same as ``user`` in :meth:`same_context`. + + Returns + ------- + MessagePredicate + The event predicate. + + """ + same_context = cls.same_context(ctx, channel, user) + + def predicate(self: MessagePredicate, m: discord.Message) -> bool: + if not same_context(m): + return False + try: + self.result = float(m.content) + except ValueError: + return False + else: + return True + + return cls(predicate) + + @classmethod + def positive( + cls, + ctx: Optional[commands.Context] = None, + channel: Optional[discord.TextChannel] = None, + user: Optional[discord.abc.User] = None, + ) -> "MessagePredicate": + """Match if the response is a positive number. + + Assigns the response to `result` as a `float`. + + Parameters + ---------- + ctx : Optional[Context] + Same as ``ctx`` in :meth:`same_context`. + channel : Optional[discord.TextChannel] + Same as ``channel`` in :meth:`same_context`. + user : Optional[discord.TextChannel] + Same as ``user`` in :meth:`same_context`. + + Returns + ------- + MessagePredicate + The event predicate. + + """ + same_context = cls.same_context(ctx, channel, user) + + def predicate(self: MessagePredicate, m: discord.Message) -> bool: + if not same_context(m): + return False + try: + number = float(m.content) + except ValueError: + return False + else: + if number > 0: + self.result = number + return True + else: + return False + + return cls(predicate) + + @classmethod + def valid_role( + cls, + ctx: Optional[commands.Context] = None, + channel: Optional[discord.TextChannel] = None, + user: Optional[discord.abc.User] = None, + ) -> "MessagePredicate": + """Match if the response refers to a role in the current guild. + + Assigns the matching `discord.Role` object to `result`. + + This predicate cannot be used in DM. + + Parameters + ---------- + ctx : Optional[Context] + Same as ``ctx`` in :meth:`same_context`. + channel : Optional[discord.TextChannel] + Same as ``channel`` in :meth:`same_context`. + user : Optional[discord.TextChannel] + Same as ``user`` in :meth:`same_context`. + + Returns + ------- + MessagePredicate + The event predicate. + + """ + same_context = cls.same_context(ctx, channel, user) + guild = cls._get_guild(ctx, channel, cast(discord.Member, user)) + + def predicate(self: MessagePredicate, m: discord.Message) -> bool: + if not same_context(m): + return False + + role = self._find_role(guild, m.content) + if role is None: + return False + + self.result = role + return True + + return cls(predicate) + + @classmethod + def valid_member( + cls, + ctx: Optional[commands.Context] = None, + channel: Optional[discord.TextChannel] = None, + user: Optional[discord.abc.User] = None, + ) -> "MessagePredicate": + """Match if the response refers to a member in the current guild. + + Assigns the matching `discord.Member` object to `result`. + + This predicate cannot be used in DM. + + Parameters + ---------- + ctx : Optional[Context] + Same as ``ctx`` in :meth:`same_context`. + channel : Optional[discord.TextChannel] + Same as ``channel`` in :meth:`same_context`. + user : Optional[discord.TextChannel] + Same as ``user`` in :meth:`same_context`. + + Returns + ------- + MessagePredicate + The event predicate. + + """ + same_context = cls.same_context(ctx, channel, user) + guild = cls._get_guild(ctx, channel, cast(discord.Member, user)) + + def predicate(self: MessagePredicate, m: discord.Message) -> bool: + if not same_context(m): + return False + + match = _ID_RE.match(m.content) or _USER_MENTION_RE.match(m.content) + if match: + result = guild.get_member(int(match.group(1))) + else: + result = guild.get_member_named(m.content) + + if result is None: + return False + self.result = result + return True + + return cls(predicate) + + @classmethod + def valid_text_channel( + cls, + ctx: Optional[commands.Context] = None, + channel: Optional[discord.TextChannel] = None, + user: Optional[discord.abc.User] = None, + ) -> "MessagePredicate": + """Match if the response refers to a text channel in the current guild. + + Assigns the matching `discord.TextChannel` object to `result`. + + This predicate cannot be used in DM. + + Parameters + ---------- + ctx : Optional[Context] + Same as ``ctx`` in :meth:`same_context`. + channel : Optional[discord.TextChannel] + Same as ``channel`` in :meth:`same_context`. + user : Optional[discord.TextChannel] + Same as ``user`` in :meth:`same_context`. + + Returns + ------- + MessagePredicate + The event predicate. + + """ + same_context = cls.same_context(ctx, channel, user) + guild = cls._get_guild(ctx, channel, cast(discord.Member, user)) + + def predicate(self: MessagePredicate, m: discord.Message) -> bool: + if not same_context(m): + return False + + match = _ID_RE.match(m.content) or _CHAN_MENTION_RE.match(m.content) + if match: + result = guild.get_channel(int(match.group(1))) + else: + result = discord.utils.get(guild.text_channels, name=m.content) + + if not isinstance(result, discord.TextChannel): + return False + self.result = result + return True + + return cls(predicate) + + @classmethod + def has_role( + cls, + ctx: Optional[commands.Context] = None, + channel: Optional[discord.TextChannel] = None, + user: Optional[discord.abc.User] = None, + ) -> "MessagePredicate": + """Match if the response refers to a role which the author has. + + Assigns the matching `discord.Role` object to `result`. + + One of ``user`` or ``ctx`` must be supplied. This predicate + cannot be used in DM. + + Parameters + ---------- + ctx : Optional[Context] + Same as ``ctx`` in :meth:`same_context`. + channel : Optional[discord.TextChannel] + Same as ``channel`` in :meth:`same_context`. + user : Optional[discord.TextChannel] + Same as ``user`` in :meth:`same_context`. + + Returns + ------- + MessagePredicate + The event predicate. + + """ + same_context = cls.same_context(ctx, channel, user) + guild = cls._get_guild(ctx, channel, cast(discord.Member, user)) + if user is None: + if ctx is None: + raise TypeError( + "One of `user` or `ctx` must be supplied to `MessagePredicate.has_role`." + ) + user = ctx.author + + def predicate(self: MessagePredicate, m: discord.Message) -> bool: + if not same_context(m): + return False + + role = self._find_role(guild, m.content) + if role is None or role not in user.roles: + return False + + self.result = role + return True + + return cls(predicate) + + @classmethod + def equal_to( + cls, + value: str, + ctx: Optional[commands.Context] = None, + channel: Optional[discord.TextChannel] = None, + user: Optional[discord.abc.User] = None, + ) -> "MessagePredicate": + """Match if the response is equal to the specified value. + + Parameters + ---------- + value : str + The value to compare the response with. + ctx : Optional[Context] + Same as ``ctx`` in :meth:`same_context`. + channel : Optional[discord.TextChannel] + Same as ``channel`` in :meth:`same_context`. + user : Optional[discord.TextChannel] + Same as ``user`` in :meth:`same_context`. + + Returns + ------- + MessagePredicate + The event predicate. + + """ + same_context = cls.same_context(ctx, channel, user) + return cls(lambda self, m: same_context(m) and m.content == value) + + @classmethod + def lower_equal_to( + cls, + value: str, + ctx: Optional[commands.Context] = None, + channel: Optional[discord.TextChannel] = None, + user: Optional[discord.abc.User] = None, + ) -> "MessagePredicate": + """Match if the response *as lowercase* is equal to the specified value. + + Parameters + ---------- + value : str + The value to compare the response with. + ctx : Optional[Context] + Same as ``ctx`` in :meth:`same_context`. + channel : Optional[discord.TextChannel] + Same as ``channel`` in :meth:`same_context`. + user : Optional[discord.TextChannel] + Same as ``user`` in :meth:`same_context`. + + Returns + ------- + MessagePredicate + The event predicate. + + """ + same_context = cls.same_context(ctx, channel, user) + return cls(lambda self, m: same_context(m) and m.content.lower() == value) + + @classmethod + def less( + cls, + value: Union[int, float], + ctx: Optional[commands.Context] = None, + channel: Optional[discord.TextChannel] = None, + user: Optional[discord.abc.User] = None, + ) -> "MessagePredicate": + """Match if the response is less than the specified value. + + Parameters + ---------- + value : Union[int, float] + The value to compare the response with. + ctx : Optional[Context] + Same as ``ctx`` in :meth:`same_context`. + channel : Optional[discord.TextChannel] + Same as ``channel`` in :meth:`same_context`. + user : Optional[discord.TextChannel] + Same as ``user`` in :meth:`same_context`. + + Returns + ------- + MessagePredicate + The event predicate. + + """ + valid_int = cls.valid_int(ctx, channel, user) + valid_float = cls.valid_float(ctx, channel, user) + return cls(lambda self, m: valid_int(m) or valid_float(m) and float(m.content) < value) + + @classmethod + def greater( + cls, + value: Union[int, float], + ctx: Optional[commands.Context] = None, + channel: Optional[discord.TextChannel] = None, + user: Optional[discord.abc.User] = None, + ) -> "MessagePredicate": + """Match if the response is greater than the specified value. + + Parameters + ---------- + value : Union[int, float] + The value to compare the response with. + ctx : Optional[Context] + Same as ``ctx`` in :meth:`same_context`. + channel : Optional[discord.TextChannel] + Same as ``channel`` in :meth:`same_context`. + user : Optional[discord.TextChannel] + Same as ``user`` in :meth:`same_context`. + + Returns + ------- + MessagePredicate + The event predicate. + + """ + valid_int = cls.valid_int(ctx, channel, user) + valid_float = cls.valid_float(ctx, channel, user) + return cls(lambda self, m: valid_int(m) or valid_float(m) and float(m.content) > value) + + @classmethod + def length_less( + cls, + length: int, + ctx: Optional[commands.Context] = None, + channel: Optional[discord.TextChannel] = None, + user: Optional[discord.abc.User] = None, + ) -> "MessagePredicate": + """Match if the response's length is less than the specified length. + + Parameters + ---------- + length : int + The value to compare the response's length with. + ctx : Optional[Context] + Same as ``ctx`` in :meth:`same_context`. + channel : Optional[discord.TextChannel] + Same as ``channel`` in :meth:`same_context`. + user : Optional[discord.TextChannel] + Same as ``user`` in :meth:`same_context`. + + Returns + ------- + MessagePredicate + The event predicate. + + """ + same_context = cls.same_context(ctx, channel, user) + return cls(lambda self, m: same_context(m) and len(m.content) <= length) + + @classmethod + def length_greater( + cls, + length: int, + ctx: Optional[commands.Context] = None, + channel: Optional[discord.TextChannel] = None, + user: Optional[discord.abc.User] = None, + ) -> "MessagePredicate": + """Match if the response's length is greater than the specified length. + + Parameters + ---------- + length : int + The value to compare the response's length with. + ctx : Optional[Context] + Same as ``ctx`` in :meth:`same_context`. + channel : Optional[discord.TextChannel] + Same as ``channel`` in :meth:`same_context`. + user : Optional[discord.TextChannel] + Same as ``user`` in :meth:`same_context`. + + Returns + ------- + MessagePredicate + The event predicate. + + """ + same_context = cls.same_context(ctx, channel, user) + return cls(lambda self, m: same_context(m) and len(m.content) >= length) + + @classmethod + def contained_in( + cls, + collection: Sequence[str], + ctx: Optional[commands.Context] = None, + channel: Optional[discord.TextChannel] = None, + user: Optional[discord.abc.User] = None, + ) -> "MessagePredicate": + """Match if the response is contained in the specified collection. + + The index of the response in the ``collection`` sequence is + assigned to the `result` attribute. + + Parameters + ---------- + collection : Sequence[str] + The collection containing valid responses. + ctx : Optional[Context] + Same as ``ctx`` in :meth:`same_context`. + channel : Optional[discord.TextChannel] + Same as ``channel`` in :meth:`same_context`. + user : Optional[discord.TextChannel] + Same as ``user`` in :meth:`same_context`. + + Returns + ------- + MessagePredicate + The event predicate. + + """ + same_context = cls.same_context(ctx, channel, user) + + def predicate(self: MessagePredicate, m: discord.Message) -> bool: + if not same_context(m): + return False + try: + self.result = collection.index(m.content) + except ValueError: + return False + else: + return True + + return cls(predicate) + + @classmethod + def lower_contained_in( + cls, + collection: Sequence[str], + ctx: Optional[commands.Context] = None, + channel: Optional[discord.TextChannel] = None, + user: Optional[discord.abc.User] = None, + ) -> "MessagePredicate": + """Same as :meth:`contained_in`, but the response is set to lowercase before matching. + + Parameters + ---------- + collection : Sequence[str] + The collection containing valid lowercase responses. + ctx : Optional[Context] + Same as ``ctx`` in :meth:`same_context`. + channel : Optional[discord.TextChannel] + Same as ``channel`` in :meth:`same_context`. + user : Optional[discord.TextChannel] + Same as ``user`` in :meth:`same_context`. + + Returns + ------- + MessagePredicate + The event predicate. + + """ + same_context = cls.same_context(ctx, channel, user) + + def predicate(self: MessagePredicate, m: discord.Message) -> bool: + if not same_context(m): + return False + try: + self.result = collection.index(m.content) + except ValueError: + return False + else: + return True + + return cls(predicate) + + @classmethod + def regex( + cls, + pattern: Union[Pattern[str], str], + ctx: Optional[commands.Context] = None, + channel: Optional[discord.TextChannel] = None, + user: Optional[discord.abc.User] = None, + ) -> "MessagePredicate": + """Match if the response matches the specified regex pattern. + + This predicate will use `re.search` to find a match. The + resulting `match object ` will be assigned + to `result`. + + Parameters + ---------- + pattern : Union[`pattern object `, str] + The pattern to search for in the response. + ctx : Optional[Context] + Same as ``ctx`` in :meth:`same_context`. + channel : Optional[discord.TextChannel] + Same as ``channel`` in :meth:`same_context`. + user : Optional[discord.TextChannel] + Same as ``user`` in :meth:`same_context`. + + Returns + ------- + MessagePredicate + The event predicate. + + """ + same_context = cls.same_context(ctx, channel, user) + + def predicate(self: MessagePredicate, m: discord.Message) -> bool: + if not same_context(m): + return False + + if isinstance(pattern, str): + pattern_obj = re.compile(pattern) + else: + pattern_obj = pattern + + match = pattern_obj.search(m.content) + if match: + self.result = match + return True + return False + + return cls(predicate) + + @staticmethod + def _find_role(guild: discord.Guild, argument: str) -> Optional[discord.Role]: + match = _ID_RE.match(argument) or _ROLE_MENTION_RE.match(argument) + if match: + result = guild.get_role(int(match.group(1))) + else: + result = discord.utils.get(guild.roles, name=argument) + return result + + @staticmethod + def _get_guild( + ctx: commands.Context, channel: discord.TextChannel, user: discord.Member + ) -> discord.Guild: + if ctx is not None: + return ctx.guild + elif channel is not None: + return channel.guild + elif user is not None: + return user.guild + + +class ReactionPredicate(Callable[[discord.Reaction, discord.abc.User], bool]): + """A collection of predicates for reaction events. + + All checks are combined with :meth:`ReactionPredicate.same_context`. + + Examples + -------- + Confirming a yes/no question with a tick/cross reaction:: + + from redbot.core.utils.predicates import ReactionPredicate + from redbot.core.utils.menus import start_adding_reactions + + msg = await ctx.send("Yes or no?") + start_adding_reactions(msg, ReactionPredicate.YES_OR_NO_EMOJIS) + + pred = ReactionPredicate.yes_or_no(msg, ctx.author) + await ctx.bot.wait_for("reaction_add", check=pred) + if pred.result is True: + # User responded with tick + ... + else: + # User responded with cross + ... + + Waiting for the first reaction from any user with one of the first + 5 letters of the alphabet:: + + from redbot.core.utils.predicates import ReactionPredicate + from redbot.core.utils.menus import start_adding_reactions + + msg = await ctx.send("React to me!") + emojis = ReactionPredicate.ALPHABET_EMOJIS[:5] + start_adding_reactions(msg, emojis) + + pred = ReactionPredicate.with_emojis(emojis, msg) + await ctx.bot.wait_for("reaction_add", check=pred) + # pred.result is now the index of the letter in `emojis` + + Attributes + ---------- + result : Any + The object which the message content matched with. This is + dependent on the predicate used - see each predicate's + documentation for details, not every method will assign this + attribute. Defaults to ``None``. + + """ + + YES_OR_NO_EMOJIS: ClassVar[Tuple[str, str]] = ( + "\N{WHITE HEAVY CHECK MARK}", + "\N{NEGATIVE SQUARED CROSS MARK}", + ) + """Tuple[str, str] : A tuple containing the tick emoji and cross emoji, in that order.""" + + ALPHABET_EMOJIS: ClassVar[List[str]] = [ + chr(code) + for code in range( + ord("\N{REGIONAL INDICATOR SYMBOL LETTER A}"), + ord("\N{REGIONAL INDICATOR SYMBOL LETTER Z}") + 1, + ) + ] + """List[str] : A list of all 26 alphabetical letter emojis.""" + + NUMBER_EMOJIS: ClassVar[List[str]] = [ + chr(code) + "\N{COMBINING ENCLOSING KEYCAP}" for code in range(ord("0"), ord("9") + 1) + ] + """List[str] : A list of all single-digit number emojis, 0 through 9.""" + + def __init__( + self, predicate: Callable[["ReactionPredicate", discord.Reaction, discord.abc.User], bool] + ) -> None: + self._pred: Callable[ + ["ReactionPredicate", discord.Reaction, discord.abc.User], bool + ] = predicate + self.result = None + + def __call__(self, reaction: discord.Reaction, user: discord.abc.User) -> bool: + return self._pred(self, reaction, user) + + # noinspection PyUnusedLocal + @classmethod + def same_context( + cls, message: Optional[discord.Message] = None, user: Optional[discord.abc.User] = None + ) -> "ReactionPredicate": + """Match if a reaction fits the described context. + + This will ignore reactions added by the bot user, regardless + of whether or not ``user`` is supplied. + + Parameters + ---------- + message : Optional[discord.Message] + The message which we expect a reaction to. If unspecified, + the reaction's message will be ignored. + user : Optional[discord.abc.User] + The user we expect to react. If unspecified, the user who + added the reaction will be ignored. + + Returns + ------- + ReactionPredicate + The event predicate. + + """ + # noinspection PyProtectedMember + me_id = message._state.self_id + return cls( + lambda self, r, u: u.id != me_id + and (message is None or r.message.id == message.id) + and (user is None or u.id == user.id) + ) + + @classmethod + def with_emojis( + cls, + emojis: Sequence[Union[str, discord.Emoji, discord.PartialEmoji]], + message: Optional[discord.Message] = None, + user: Optional[discord.abc.User] = None, + ) -> "ReactionPredicate": + """Match if the reaction is one of the specified emojis. + + Parameters + ---------- + emojis : Sequence[Union[str, discord.Emoji, discord.PartialEmoji]] + The emojis of which one we expect to be reacted. + message : discord.Message + Same as ``message`` in :meth:`same_context`. + user : Optional[discord.abc.User] + Same as ``user`` in :meth:`same_context`. + + Returns + ------- + ReactionPredicate + The event predicate. + + """ + same_context = cls.same_context(message, user) + + def predicate(self: ReactionPredicate, r: discord.Reaction, u: discord.abc.User): + if not same_context(r, u): + return False + + try: + self.result = emojis.index(r.emoji) + except ValueError: + return False + else: + return True + + return cls(predicate) + + @classmethod + def yes_or_no( + cls, message: Optional[discord.Message] = None, user: Optional[discord.abc.User] = None + ) -> "ReactionPredicate": + """Match if the reaction is a tick or cross emoji. + + The emojis used can are in + `ReactionPredicate.YES_OR_NO_EMOJIS`. + + This will assign ``True`` for *yes*, or ``False`` for *no* to + the `result` attribute. + + Parameters + ---------- + message : discord.Message + Same as ``message`` in :meth:`same_context`. + user : Optional[discord.abc.User] + Same as ``user`` in :meth:`same_context`. + + Returns + ------- + ReactionPredicate + The event predicate. + + """ + same_context = cls.same_context(message, user) + + def predicate(self: ReactionPredicate, r: discord.Reaction, u: discord.abc.User) -> bool: + if not same_context(r, u): + return False + + try: + self.result = not bool(self.YES_OR_NO_EMOJIS.index(r.emoji)) + except ValueError: + return False + else: + return True + + return cls(predicate)