diff --git a/redbot/cogs/filter/filter.py b/redbot/cogs/filter/filter.py index 716bd5af1..b445f6e38 100644 --- a/redbot/cogs/filter/filter.py +++ b/redbot/cogs/filter/filter.py @@ -7,7 +7,6 @@ from redbot.core.bot import Red from redbot.core.i18n import Translator, cog_i18n from redbot.core.utils.chat_formatting import pagify -RE_WORD_SPLIT = re.compile(r"[^\w]") _ = Translator("Filter", __file__) @@ -32,6 +31,7 @@ class Filter(commands.Cog): self.settings.register_member(**default_member_settings) self.settings.register_channel(**default_channel_settings) self.register_task = self.bot.loop.create_task(self.register_filterban()) + self.pattern_cache = {} def __unload(self): self.register_task.cancel() @@ -165,6 +165,7 @@ class Filter(commands.Cog): tmp += word + " " added = await self.add_to_filter(channel, word_list) if added: + self.invalidate_cache(ctx.guild, ctx.channel) await ctx.send(_("Words added to filter.")) else: await ctx.send(_("Words already in the filter.")) @@ -198,6 +199,7 @@ class Filter(commands.Cog): removed = await self.remove_from_filter(channel, word_list) if removed: await ctx.send(_("Words removed from filter.")) + self.invalidate_cache(ctx.guild, ctx.channel) else: await ctx.send(_("Those words weren't in the filter.")) @@ -229,6 +231,7 @@ class Filter(commands.Cog): tmp += word + " " added = await self.add_to_filter(server, word_list) if added: + self.invalidate_cache(ctx.guild) await ctx.send(_("Words successfully added to filter.")) else: await ctx.send(_("Those words were already in the filter.")) @@ -261,6 +264,7 @@ class Filter(commands.Cog): tmp += word + " " removed = await self.remove_from_filter(server, word_list) if removed: + self.invalidate_cache(ctx.guild) await ctx.send(_("Words successfully removed from filter.")) else: await ctx.send(_("Those words weren't in the filter.")) @@ -279,6 +283,10 @@ class Filter(commands.Cog): else: await ctx.send(_("Names and nicknames will now be filtered.")) + def invalidate_cache(self, guild: discord.Guild, channel: discord.TextChannel = None): + """ Invalidate a cached pattern""" + self.pattern_cache.pop((guild, channel), None) + async def add_to_filter( self, server_or_channel: Union[discord.Guild, discord.TextChannel], words: list ) -> bool: @@ -322,24 +330,34 @@ class Filter(commands.Cog): async def filter_hits( self, text: str, server_or_channel: Union[discord.Guild, discord.TextChannel] ) -> Set[str]: - if isinstance(server_or_channel, discord.Guild): - word_list = set(await self.settings.guild(server_or_channel).filter()) - elif isinstance(server_or_channel, discord.TextChannel): - word_list = set( - await self.settings.guild(server_or_channel.guild).filter() - + await self.settings.channel(server_or_channel).filter() - ) - else: - raise TypeError("%r should be Guild or TextChannel" % server_or_channel) - content = text.lower() - msg_words = set(RE_WORD_SPLIT.split(content)) + try: + guild = server_or_channel.guild + channel = server_or_channel + except AttributeError: + guild = server_or_channel + channel = None - filtered_phrases = {x for x in word_list if len(RE_WORD_SPLIT.split(x)) > 1} - filtered_words = word_list - filtered_phrases + hits: Set[str] = set() - hits = {p for p in filtered_phrases if p in content} - hits |= filtered_words & msg_words + try: + pattern = self.pattern_cache[(guild, channel)] + except KeyError: + word_list = set(await self.settings.guild(guild).filter()) + if channel: + word_list |= set(await self.settings.channel(channel).filter()) + + if word_list: + pattern = re.compile( + "|".join(rf"\b{re.escape(w)}\b" for w in word_list), flags=re.I + ) + else: + pattern = None + + self.pattern_cache[(guild, channel)] = pattern + + if pattern: + hits |= set(pattern.findall(text)) return hits async def check_filter(self, message: discord.Message):