diff --git a/redbot/cogs/alias/__init__.py b/redbot/cogs/alias/__init__.py index df7b4900e..c4ff8ea95 100644 --- a/redbot/cogs/alias/__init__.py +++ b/redbot/cogs/alias/__init__.py @@ -2,5 +2,7 @@ from .alias import Alias from redbot.core.bot import Red -def setup(bot: Red): - bot.add_cog(Alias(bot)) +async def setup(bot: Red): + cog = Alias(bot) + await cog.initialize() + bot.add_cog(cog) diff --git a/redbot/cogs/alias/alias.py b/redbot/cogs/alias/alias.py index 639dfb377..89f0bb741 100644 --- a/redbot/cogs/alias/alias.py +++ b/redbot/cogs/alias/alias.py @@ -1,16 +1,15 @@ from copy import copy -from re import findall, search +from re import search from string import Formatter -from typing import Generator, Tuple, Iterable, Optional +from typing import Dict import discord -from discord.ext.commands.view import StringView from redbot.core import Config, commands, checks from redbot.core.i18n import Translator, cog_i18n from redbot.core.utils.chat_formatting import box from redbot.core.bot import Red -from .alias_entry import AliasEntry +from .alias_entry import AliasEntry, AliasCache, ArgParseError _ = Translator("Alias", __file__) @@ -26,10 +25,6 @@ class _TrackingFormatter(Formatter): return super().get_value(key, args, kwargs) -class ArgParseError(Exception): - pass - - @cog_i18n(_) class Alias(commands.Cog): """Create aliases for commands. @@ -42,9 +37,9 @@ class Alias(commands.Cog): and append them to the stored alias. """ - default_global_settings = {"entries": []} + default_global_settings: Dict[str, list] = {"entries": []} - default_guild_settings = {"enabled": False, "entries": []} # Going to be a list of dicts + default_guild_settings: Dict[str, list] = {"entries": []} # Going to be a list of dicts def __init__(self, bot: Red): super().__init__() @@ -53,40 +48,12 @@ class Alias(commands.Cog): self.config.register_global(**self.default_global_settings) self.config.register_guild(**self.default_guild_settings) + self._aliases: AliasCache = AliasCache(config=self.config, cache_enabled=True) - async def unloaded_aliases(self, guild: discord.Guild) -> Generator[AliasEntry, None, None]: - return (AliasEntry.from_json(d) for d in (await self.config.guild(guild).entries())) - - async def unloaded_global_aliases(self) -> Generator[AliasEntry, None, None]: - return (AliasEntry.from_json(d) for d in (await self.config.entries())) - - async def loaded_aliases(self, guild: discord.Guild) -> Generator[AliasEntry, None, None]: - return ( - AliasEntry.from_json(d, bot=self.bot) - for d in (await self.config.guild(guild).entries()) - ) - - async def loaded_global_aliases(self) -> Generator[AliasEntry, None, None]: - return (AliasEntry.from_json(d, bot=self.bot) for d in (await self.config.entries())) - - async def is_alias( - self, - guild: Optional[discord.Guild], - alias_name: str, - server_aliases: Iterable[AliasEntry] = (), - ) -> Tuple[bool, Optional[AliasEntry]]: - - if not server_aliases and guild is not None: - server_aliases = await self.unloaded_aliases(guild) - - global_aliases = await self.unloaded_global_aliases() - - for aliases in (server_aliases, global_aliases): - for alias in aliases: - if alias.name == alias_name: - return True, alias - - return False, None + async def initialize(self): + # This can be where we set the cache_enabled attribute later + if not self._aliases._loaded: + await self._aliases.load_aliases() def is_command(self, alias_name: str) -> bool: """ @@ -100,56 +67,6 @@ class Alias(commands.Cog): def is_valid_alias_name(alias_name: str) -> bool: return not bool(search(r"\s", alias_name)) and alias_name.isprintable() - async def add_alias( - self, ctx: commands.Context, alias_name: str, command: str, global_: bool = False - ) -> AliasEntry: - indices = findall(r"{(\d*)}", command) - if indices: - try: - indices = [int(a[0]) for a in indices] - except IndexError: - raise ArgParseError(_("Arguments must be specified with a number.")) - low = min(indices) - indices = [a - low for a in indices] - high = max(indices) - gaps = set(indices).symmetric_difference(range(high + 1)) - if gaps: - raise ArgParseError( - _("Arguments must be sequential. Missing arguments: ") - + ", ".join(str(i + low) for i in gaps) - ) - command = command.format(*(f"{{{i}}}" for i in range(-low, high + low + 1))) - - alias = AliasEntry(alias_name, command, ctx.author, global_=global_) - - if global_: - settings = self.config - else: - settings = self.config.guild(ctx.guild) - await settings.enabled.set(True) - - async with settings.entries() as curr_aliases: - curr_aliases.append(alias.to_json()) - - return alias - - async def delete_alias( - self, ctx: commands.Context, alias_name: str, global_: bool = False - ) -> bool: - if global_: - settings = self.config - else: - settings = self.config.guild(ctx.guild) - - async with settings.entries() as aliases: - for alias in aliases: - alias_obj = AliasEntry.from_json(alias) - if alias_obj.name == alias_name: - aliases.remove(alias) - return True - - return False - async def get_prefix(self, message: discord.Message) -> str: """ Tries to determine what prefix is used in a message object. @@ -167,57 +84,11 @@ class Alias(commands.Cog): return p raise ValueError(_("No prefix found.")) - def get_extra_args_from_alias( - self, message: discord.Message, prefix: str, alias: AliasEntry - ) -> str: - """ - When an alias is executed by a user in chat this function tries - to get any extra arguments passed in with the call. - Whitespace will be trimmed from both ends. - :param message: - :param prefix: - :param alias: - :return: - """ - known_content_length = len(prefix) + len(alias.name) - extra = message.content[known_content_length:] - view = StringView(extra) - view.skip_ws() - extra = [] - while not view.eof: - prev = view.index - word = view.get_quoted_word() - if len(word) < view.index - prev: - word = "".join((view.buffer[prev], word, view.buffer[view.index - 1])) - extra.append(word) - view.skip_ws() - return extra - - async def maybe_call_alias( - self, message: discord.Message, aliases: Iterable[AliasEntry] = None - ): - try: - prefix = await self.get_prefix(message) - except ValueError: - return - - try: - potential_alias = message.content[len(prefix) :].split(" ")[0] - except IndexError: - return False - - is_alias, alias = await self.is_alias( - message.guild, potential_alias, server_aliases=aliases - ) - - if is_alias: - await self.call_alias(message, prefix, alias) - async def call_alias(self, message: discord.Message, prefix: str, alias: AliasEntry): new_message = copy(message) try: - args = self.get_extra_args_from_alias(message, prefix, alias) - except commands.BadArgument as bae: + args = alias.get_extra_args_from_alias(message, prefix) + except commands.BadArgument: return trackform = _TrackingFormatter() @@ -257,8 +128,8 @@ class Alias(commands.Cog): ) return - is_alias, something_useless = await self.is_alias(ctx.guild, alias_name) - if is_alias: + alias = await self._aliases.get_alias(ctx.guild, alias_name) + if alias: await ctx.send( _( "You attempted to create a new alias" @@ -292,7 +163,7 @@ class Alias(commands.Cog): # and that the alias name is valid. try: - await self.add_alias(ctx, alias_name, command) + await self._aliases.add_alias(ctx, alias_name, command) except ArgParseError as e: return await ctx.send(" ".join(e.args)) @@ -316,8 +187,8 @@ class Alias(commands.Cog): ) return - is_alias, something_useless = await self.is_alias(ctx.guild, alias_name) - if is_alias: + alias = await self._aliases.get_alias(ctx.guild, alias_name) + if alias: await ctx.send( _( "You attempted to create a new global alias" @@ -341,7 +212,7 @@ class Alias(commands.Cog): # endregion try: - await self.add_alias(ctx, alias_name, command, global_=True) + await self._aliases.add_alias(ctx, alias_name, command, global_=True) except ArgParseError as e: return await ctx.send(" ".join(e.args)) @@ -355,8 +226,8 @@ class Alias(commands.Cog): @commands.guild_only() async def _help_alias(self, ctx: commands.Context, alias_name: str): """Try to execute help for the base command of the alias.""" - is_alias, alias = await self.is_alias(ctx.guild, alias_name=alias_name) - if is_alias: + alias = await self._aliases.get_alias(ctx.guild, alias_name=alias_name) + if alias: if self.is_command(alias.command): base_cmd = alias.command else: @@ -372,9 +243,9 @@ class Alias(commands.Cog): @commands.guild_only() async def _show_alias(self, ctx: commands.Context, alias_name: str): """Show what command the alias executes.""" - is_alias, alias = await self.is_alias(ctx.guild, alias_name) + alias = await self._aliases.get_alias(ctx.guild, alias_name) - if is_alias: + if alias: await ctx.send( _("The `{alias_name}` alias will execute the command `{command}`").format( alias_name=alias_name, command=alias.command @@ -388,14 +259,11 @@ class Alias(commands.Cog): @commands.guild_only() async def _del_alias(self, ctx: commands.Context, alias_name: str): """Delete an existing alias on this server.""" - aliases = await self.unloaded_aliases(ctx.guild) - try: - next(aliases) - except StopIteration: + if not await self._aliases.get_guild_aliases(ctx.guild): await ctx.send(_("There are no aliases on this server.")) return - if await self.delete_alias(ctx, alias_name): + if await self._aliases.delete_alias(ctx, alias_name): await ctx.send( _("Alias with the name `{name}` was successfully deleted.").format(name=alias_name) ) @@ -406,14 +274,11 @@ class Alias(commands.Cog): @global_.command(name="delete", aliases=["del", "remove"]) async def _del_global_alias(self, ctx: commands.Context, alias_name: str): """Delete an existing global alias.""" - aliases = await self.unloaded_global_aliases() - try: - next(aliases) - except StopIteration: - await ctx.send(_("There are no aliases on this bot.")) + if not await self._aliases.get_global_aliases(): + await ctx.send(_("There are no global aliases on this bot.")) return - if await self.delete_alias(ctx, alias_name, global_=True): + if await self._aliases.delete_alias(ctx, alias_name, global_=True): await ctx.send( _("Alias with the name `{name}` was successfully deleted.").format(name=alias_name) ) @@ -424,32 +289,34 @@ class Alias(commands.Cog): @commands.guild_only() async def _list_alias(self, ctx: commands.Context): """List the available aliases on this server.""" - names = [_("Aliases:")] + sorted( - ["+ " + a.name for a in (await self.unloaded_aliases(ctx.guild))] - ) - if len(names) == 0: - await ctx.send(_("There are no aliases on this server.")) - else: - await ctx.send(box("\n".join(names), "diff")) + guild_aliases = await self._aliases.get_guild_aliases(ctx.guild) + if not guild_aliases: + return await ctx.send(_("There are no aliases on this server.")) + names = [_("Aliases:")] + sorted(["+ " + a.name for a in guild_aliases]) + await ctx.send(box("\n".join(names), "diff")) @global_.command(name="list") async def _list_global_alias(self, ctx: commands.Context): """List the available global aliases on this bot.""" - names = [_("Aliases:")] + sorted( - ["+ " + a.name for a in await self.unloaded_global_aliases()] - ) - if len(names) == 0: - await ctx.send(_("There are no aliases on this server.")) - else: - await ctx.send(box("\n".join(names), "diff")) + global_aliases = await self._aliases.get_global_aliases() + if not global_aliases: + return await ctx.send(_("There are no global aliases.")) + names = [_("Aliases:")] + sorted(["+ " + a.name for a in global_aliases]) + await ctx.send(box("\n".join(names), "diff")) @commands.Cog.listener() - async def on_message(self, message: discord.Message): - aliases = list(await self.unloaded_global_aliases()) - if message.guild is not None: - aliases = aliases + list(await self.unloaded_aliases(message.guild)) - - if len(aliases) == 0: + async def on_message_without_command(self, message: discord.Message): + try: + prefix = await self.get_prefix(message) + except ValueError: return - await self.maybe_call_alias(message, aliases=aliases) + try: + potential_alias = message.content[len(prefix) :].split(" ")[0] + except IndexError: + return + + alias = await self._aliases.get_alias(message.guild, potential_alias) + + if alias: + await self.call_alias(message, prefix, alias) diff --git a/redbot/cogs/alias/alias_entry.py b/redbot/cogs/alias/alias_entry.py index 9bf06d41a..78e1026ba 100644 --- a/redbot/cogs/alias/alias_entry.py +++ b/redbot/cogs/alias/alias_entry.py @@ -1,25 +1,37 @@ -from typing import Tuple +from typing import Tuple, Dict, Optional, List, Union +from re import findall import discord -from redbot.core import commands +from discord.ext.commands.view import StringView +from redbot.core import commands, Config +from redbot.core.i18n import Translator +from redbot.core.utils import AsyncIter + +_ = Translator("Alias", __file__) + + +class ArgParseError(Exception): + pass class AliasEntry: + """An object containing all required information about an alias""" + + name: str + command: Union[Tuple[str], str] + creator: int + guild: Optional[int] + uses: int + def __init__( - self, name: str, command: Tuple[str], creator: discord.Member, global_: bool = False + self, name: str, command: Union[Tuple[str], str], creator: int, guild: Optional[int], ): super().__init__() - self.has_real_data = False self.name = name self.command = command self.creator = creator - self.global_ = global_ - - self.guild = None - if hasattr(creator, "guild"): - self.guild = creator.guild - + self.guild = guild self.uses = 0 def inc(self): @@ -30,34 +42,182 @@ class AliasEntry: self.uses += 1 return self.uses + def get_extra_args_from_alias(self, message: discord.Message, prefix: str) -> str: + """ + When an alias is executed by a user in chat this function tries + to get any extra arguments passed in with the call. + Whitespace will be trimmed from both ends. + :param message: + :param prefix: + :param alias: + :return: + """ + known_content_length = len(prefix) + len(self.name) + extra = message.content[known_content_length:] + view = StringView(extra) + view.skip_ws() + extra = [] + while not view.eof: + prev = view.index + word = view.get_quoted_word() + if len(word) < view.index - prev: + word = "".join((view.buffer[prev], word, view.buffer[view.index - 1])) + extra.append(word) + view.skip_ws() + return extra + def to_json(self) -> dict: - try: - creator = str(self.creator.id) - guild = str(self.guild.id) - except AttributeError: - creator = self.creator - guild = self.guild return { "name": self.name, "command": self.command, - "creator": creator, - "guild": guild, - "global": self.global_, + "creator": self.creator, + "guild": self.guild, "uses": self.uses, } @classmethod - def from_json(cls, data: dict, bot: commands.Bot = None): - ret = cls(data["name"], data["command"], data["creator"], global_=data["global"]) - - if bot: - ret.has_real_data = True - ret.creator = bot.get_user(int(data["creator"])) - guild = bot.get_guild(int(data["guild"])) - ret.guild = guild - else: - ret.guild = data["guild"] - + def from_json(cls, data: dict): + ret = cls(data["name"], data["command"], data["creator"], data["guild"]) ret.uses = data.get("uses", 0) return ret + + +class AliasCache: + def __init__(self, config: Config, cache_enabled: bool = True): + self.config = config + self._cache_enabled = cache_enabled + self._loaded = False + self._aliases: Dict[Optional[int], Dict[str, AliasEntry]] = {None: {}} + + async def load_aliases(self): + if not self._cache_enabled: + self._loaded = True + return + for alias in await self.config.entries(): + self._aliases[None][alias["name"]] = AliasEntry.from_json(alias) + + all_guilds = await self.config.all_guilds() + async for guild_id, guild_data in AsyncIter(all_guilds.items(), steps=100): + if guild_id not in self._aliases: + self._aliases[guild_id] = {} + for alias in guild_data["entries"]: + self._aliases[guild_id][alias["name"]] = AliasEntry.from_json(alias) + self._loaded = True + + async def get_aliases(self, ctx: commands.Context) -> List[AliasEntry]: + """Returns all possible aliases with the given context""" + global_aliases: List[AliasEntry] = [] + server_aliases: List[AliasEntry] = [] + global_aliases = await self.get_global_aliases() + if ctx.guild and ctx.guild.id in self._aliases: + server_aliases = await self.get_guild_aliases(ctx.guild) + return global_aliases + server_aliases + + async def get_guild_aliases(self, guild: discord.Guild) -> List[AliasEntry]: + """Returns all guild specific aliases""" + aliases: List[AliasEntry] = [] + + if self._cache_enabled: + if guild.id in self._aliases: + for _, alias in self._aliases[guild.id].items(): + aliases.append(alias) + else: + aliases = [AliasEntry.from_json(d) for d in await self.config.guild(guild).entries()] + return aliases + + async def get_global_aliases(self) -> List[AliasEntry]: + """Returns all global specific aliases""" + aliases: List[AliasEntry] = [] + if self._cache_enabled: + for _, alias in self._aliases[None].items(): + aliases.append(alias) + else: + aliases = [AliasEntry.from_json(d) for d in await self.config.entries()] + return aliases + + async def get_alias( + self, guild: Optional[discord.Guild], alias_name: str, + ) -> Optional[AliasEntry]: + """Returns an AliasEntry object if the provided alias_name is a registered alias""" + server_aliases: List[AliasEntry] = [] + + if self._cache_enabled: + if alias_name in self._aliases[None]: + return self._aliases[None][alias_name] + if guild is not None: + if guild.id in self._aliases: + if alias_name in self._aliases[guild.id]: + return self._aliases[guild.id][alias_name] + else: + if guild: + server_aliases = [ + AliasEntry.from_json(d) for d in await self.config.guild(guild.id).entries() + ] + global_aliases = [AliasEntry.from_json(d) for d in await self.config.entries()] + all_aliases = global_aliases + server_aliases + + for alias in all_aliases: + if alias.name == alias_name: + return alias + + return None + + async def add_alias( + self, ctx: commands.Context, alias_name: str, command: str, global_: bool = False + ) -> AliasEntry: + indices = findall(r"{(\d*)}", command) + if indices: + try: + indices = [int(a[0]) for a in indices] + except IndexError: + raise ArgParseError(_("Arguments must be specified with a number.")) + low = min(indices) + indices = [a - low for a in indices] + high = max(indices) + gaps = set(indices).symmetric_difference(range(high + 1)) + if gaps: + raise ArgParseError( + _("Arguments must be sequential. Missing arguments: ") + + ", ".join(str(i + low) for i in gaps) + ) + command = command.format(*(f"{{{i}}}" for i in range(-low, high + low + 1))) + + if global_: + alias = AliasEntry(alias_name, command, ctx.author.id, None) + settings = self.config + if self._cache_enabled: + self._aliases[None][alias.name] = alias + else: + alias = AliasEntry(alias_name, command, ctx.author.id, ctx.guild.id) + settings = self.config.guild(ctx.guild) + if self._cache_enabled: + if ctx.guild.id not in self._aliases: + self._aliases[ctx.guild.id] = {} + self._aliases[ctx.guild.id][alias.name] = alias + + async with settings.entries() as curr_aliases: + curr_aliases.append(alias.to_json()) + + return alias + + async def delete_alias( + self, ctx: commands.Context, alias_name: str, global_: bool = False + ) -> bool: + if global_: + settings = self.config + else: + settings = self.config.guild(ctx.guild) + + async with settings.entries() as aliases: + for alias in aliases: + if alias["name"] == alias_name: + aliases.remove(alias) + if self._cache_enabled: + if global_: + del self._aliases[None][alias_name] + else: + del self._aliases[ctx.guild.id][alias_name] + return True + + return False diff --git a/redbot/cogs/cleanup/cleanup.py b/redbot/cogs/cleanup/cleanup.py index 31170a07f..42c222ff5 100644 --- a/redbot/cogs/cleanup/cleanup.py +++ b/redbot/cogs/cleanup/cleanup.py @@ -409,8 +409,8 @@ class Cleanup(commands.Cog): alias_cog = self.bot.get_cog("Alias") if alias_cog is not None: alias_names: Set[str] = ( - set((a.name for a in await alias_cog.unloaded_global_aliases())) - | set(a.name for a in await alias_cog.unloaded_aliases(ctx.guild)) + set((a.name for a in await alias_cog._aliases.get_global_aliases())) + | set(a.name for a in await alias_cog._aliases.get_guild_aliases(ctx.guild)) ) is_alias = lambda name: name in alias_names else: @@ -538,7 +538,7 @@ class Cleanup(commands.Cog): @commands.bot_has_permissions(manage_messages=True) async def cleanup_spam(self, ctx: commands.Context, number: int = 50): """Deletes duplicate messages in the channel from the last X messages and keeps only one copy. - + Defaults to 50. """ msgs = [] diff --git a/redbot/core/utils/_internal_utils.py b/redbot/core/utils/_internal_utils.py index 51d2c4828..238aca44c 100644 --- a/redbot/core/utils/_internal_utils.py +++ b/redbot/core/utils/_internal_utils.py @@ -103,9 +103,9 @@ async def fuzzy_command_search( # If the term is an alias or CC, we don't want to send a supplementary fuzzy search. alias_cog = ctx.bot.get_cog("Alias") if alias_cog is not None: - is_alias, alias = await alias_cog.is_alias(ctx.guild, term) + alias = await alias_cog._aliases.get_alias(ctx.guild, term) - if is_alias: + if alias: return None customcom_cog = ctx.bot.get_cog("CustomCommands") if customcom_cog is not None: diff --git a/tests/cogs/test_alias.py b/tests/cogs/test_alias.py index 73f51b4c0..27fdc9833 100644 --- a/tests/cogs/test_alias.py +++ b/tests/cogs/test_alias.py @@ -9,58 +9,59 @@ def test_is_valid_alias_name(alias): @pytest.mark.asyncio async def test_empty_guild_aliases(alias, empty_guild): - assert list(await alias.unloaded_aliases(empty_guild)) == [] + assert list(await alias._aliases.get_guild_aliases(empty_guild)) == [] @pytest.mark.asyncio async def test_empty_global_aliases(alias): - assert list(await alias.unloaded_global_aliases()) == [] + assert list(await alias._aliases.get_global_aliases()) == [] async def create_test_guild_alias(alias, ctx): - await alias.add_alias(ctx, "test", "ping", global_=False) + await alias._aliases.add_alias(ctx, "test", "ping", global_=False) async def create_test_global_alias(alias, ctx): - await alias.add_alias(ctx, "test", "ping", global_=True) + await alias._aliases.add_alias(ctx, "test_global", "ping", global_=True) @pytest.mark.asyncio async def test_add_guild_alias(alias, ctx): await create_test_guild_alias(alias, ctx) - is_alias, alias_obj = await alias.is_alias(ctx.guild, "test") - assert is_alias is True - assert alias_obj.global_ is False + alias_obj = await alias._aliases.get_alias(ctx.guild, "test") + assert alias_obj.name == "test" @pytest.mark.asyncio async def test_delete_guild_alias(alias, ctx): await create_test_guild_alias(alias, ctx) - is_alias, _ = await alias.is_alias(ctx.guild, "test") - assert is_alias is True + alias_obj = await alias._aliases.get_alias(ctx.guild, "test") + assert alias_obj.name == "test" - await alias.delete_alias(ctx, "test") + did_delete = await alias._aliases.delete_alias(ctx, "test") + assert did_delete is True - is_alias, _ = await alias.is_alias(ctx.guild, "test") - assert is_alias is False + alias_obj = await alias._aliases.get_alias(ctx.guild, "test") + assert alias_obj is None @pytest.mark.asyncio async def test_add_global_alias(alias, ctx): await create_test_global_alias(alias, ctx) - is_alias, alias_obj = await alias.is_alias(ctx.guild, "test") + alias_obj = await alias._aliases.get_alias(ctx.guild, "test_global") - assert is_alias is True - assert alias_obj.global_ is True + assert alias_obj.name == "test_global" @pytest.mark.asyncio async def test_delete_global_alias(alias, ctx): await create_test_global_alias(alias, ctx) - is_alias, alias_obj = await alias.is_alias(ctx.guild, "test") - assert is_alias is True - assert alias_obj.global_ is True + alias_obj = await alias._aliases.get_alias(ctx.guild, "test_global") + assert alias_obj.name == "test_global" - did_delete = await alias.delete_alias(ctx, alias_name="test", global_=True) + did_delete = await alias._aliases.delete_alias(ctx, alias_name="test_global", global_=True) assert did_delete is True + + alias_obj = await alias._aliases.get_alias(None, "test_global") + assert alias_obj is None