Avoid potential memory leak in Filter cog (#5578)

This commit is contained in:
jack1142 2022-02-20 22:52:58 +01:00 committed by GitHub
parent 0338e8e0a8
commit eeffbf8231
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -2,7 +2,7 @@ import asyncio
import discord import discord
import re import re
from datetime import timezone from datetime import timezone
from typing import Union, Set, Literal from typing import Union, Set, Literal, Optional
from redbot.core import checks, Config, modlog, commands from redbot.core import checks, Config, modlog, commands
from redbot.core.bot import Red from redbot.core.bot import Red
@ -346,12 +346,14 @@ class Filter(commands.Cog):
else: else:
await ctx.send(_("Names and nicknames will now be filtered.")) await ctx.send(_("Names and nicknames will now be filtered."))
def invalidate_cache(self, guild: discord.Guild, channel: discord.TextChannel = None): def invalidate_cache(
self, guild: discord.Guild, channel: Optional[discord.TextChannel] = None
) -> None:
""" Invalidate a cached pattern""" """ Invalidate a cached pattern"""
self.pattern_cache.pop((guild, channel), None) self.pattern_cache.pop((guild.id, channel and channel.id), None)
if channel is None: if channel is None:
for keyset in list(self.pattern_cache.keys()): # cast needed, no remove for keyset in list(self.pattern_cache.keys()): # cast needed, no remove
if guild in keyset: if guild.id == keyset[0]:
self.pattern_cache.pop(keyset, None) self.pattern_cache.pop(keyset, None)
async def add_to_filter( async def add_to_filter(
@ -408,7 +410,7 @@ class Filter(commands.Cog):
hits: Set[str] = set() hits: Set[str] = set()
try: try:
pattern = self.pattern_cache[(guild, channel)] pattern = self.pattern_cache[(guild.id, channel and channel.id)]
except KeyError: except KeyError:
word_list = set(await self.config.guild(guild).filter()) word_list = set(await self.config.guild(guild).filter())
if channel: if channel:
@ -421,7 +423,7 @@ class Filter(commands.Cog):
else: else:
pattern = None pattern = None
self.pattern_cache[(guild, channel)] = pattern self.pattern_cache[(guild.id, channel and channel.id)] = pattern
if pattern: if pattern:
hits |= set(pattern.findall(text)) hits |= set(pattern.findall(text))