From 8fa967cb9143ddc94547dd3cde7130cac108eec8 Mon Sep 17 00:00:00 2001 From: Will Date: Wed, 21 Jun 2017 16:59:26 -0400 Subject: [PATCH] Added Alias cog (#780) --- cogs/alias/__init__.py | 6 + cogs/alias/alias.py | 364 +++++++++++++++++++++++ cogs/alias/alias_entry.py | 63 ++++ tests/cogs/downloader/test_downloader.py | 2 +- tests/cogs/test_alias.py | 64 ++++ tests/conftest.py | 16 +- 6 files changed, 507 insertions(+), 8 deletions(-) create mode 100644 cogs/alias/__init__.py create mode 100644 cogs/alias/alias.py create mode 100644 cogs/alias/alias_entry.py create mode 100644 tests/cogs/test_alias.py diff --git a/cogs/alias/__init__.py b/cogs/alias/__init__.py new file mode 100644 index 000000000..b1b8587ca --- /dev/null +++ b/cogs/alias/__init__.py @@ -0,0 +1,6 @@ +from .alias import Alias +from discord.ext import commands + + +def setup(bot: commands.Bot): + bot.add_cog(Alias(bot)) diff --git a/cogs/alias/alias.py b/cogs/alias/alias.py new file mode 100644 index 000000000..2179a543e --- /dev/null +++ b/cogs/alias/alias.py @@ -0,0 +1,364 @@ +import discord +from copy import copy +from discord.ext import commands + +from typing import Generator, Tuple, Iterable +from core import Config +from core.bot import Red +from core.utils.chat_formatting import box +from .alias_entry import AliasEntry + + +class Alias: + """ + Alias + + Aliases are per server shortcuts for commands. They + can act as both a lambda (storing arguments for repeated use) + or as simply a shortcut to saying "x y z". + + When run, aliases will accept any additional arguments + and append them to the stored alias + """ + + default_global_settings = { + "entries": [] + } + + default_guild_settings = { + "enabled": False, + "entries": [] # Going to be a list of dicts + } + + def __init__(self, bot: Red): + self.bot = bot + self.file_path = "data/alias/aliases.json" + self._aliases = Config.get_conf(self, 8927348724) + + self._aliases.register_global(**self.default_global_settings) + self._aliases.register_guild(**self.default_guild_settings) + + def unloaded_aliases(self, guild: discord.Guild) -> Generator[AliasEntry, None, None]: + return (AliasEntry.from_json(d) for d in self._aliases.guild(guild).entries()) + + def unloaded_global_aliases(self) -> Generator[AliasEntry, None, None]: + return (AliasEntry.from_json(d) for d in self._aliases.entries()) + + def loaded_aliases(self, guild: discord.Guild) -> Generator[AliasEntry, None, None]: + return (AliasEntry.from_json(d, bot=self.bot) + for d in self._aliases.guild(guild).entries()) + + def loaded_global_aliases(self) -> Generator[AliasEntry, None, None]: + return (AliasEntry.from_json(d, bot=self.bot) for d in self._aliases.entries()) + + def is_alias(self, guild: discord.Guild, alias_name: str, + server_aliases: Iterable[AliasEntry]=()) -> (bool, AliasEntry): + + if not server_aliases: + server_aliases = self.unloaded_aliases(guild) + + global_aliases = self.unloaded_global_aliases() + + for aliases in (server_aliases, global_aliases): + for alias in aliases: + if alias.name == alias_name: + return True, alias + + return False, None + + def is_command(self, alias_name: str) -> bool: + command = self.bot.get_command(alias_name) + return command is not None + + @staticmethod + def is_valid_alias_name(alias_name: str) -> bool: + return alias_name.isidentifier() + + async def add_alias(self, ctx: commands.Context, alias_name: str, + command: Tuple[str], global_: bool=False) -> AliasEntry: + alias = AliasEntry(alias_name, command, ctx.author, global_=global_) + + if global_: + curr_aliases = self._aliases.entries() + curr_aliases.append(alias.to_json()) + await self._aliases.set("entries", curr_aliases) + else: + curr_aliases = self._aliases.guild(ctx.guild).entries() + + curr_aliases.append(alias.to_json()) + await self._aliases.guild(ctx.guild).set("entries", curr_aliases) + + await self._aliases.guild(ctx.guild).set("enabled", True) + return alias + + async def delete_alias(self, ctx: commands.Context, alias_name: str, + global_: bool=False) -> bool: + if global_: + aliases = self.unloaded_global_aliases() + setter_func = self._aliases.set + else: + aliases = self.unloaded_aliases(ctx.guild) + setter_func = self._aliases.guild(ctx.guild).set + + did_delete_alias = False + + to_keep = [] + for alias in aliases: + if alias.name != alias_name: + to_keep.append(alias) + else: + did_delete_alias = True + + await setter_func( + "entries", + [a.to_json() for a in to_keep] + ) + + return did_delete_alias + + def get_prefix(self, message: discord.Message) -> str: + """ + Tries to determine what prefix is used in a message object. + Looks to identify from longest prefix to smallest. + + Will raise ValueError if no prefix is found. + :param message: Message object + :return: + """ + guild = message.guild + content = message.content + prefixes = sorted(self.bot.command_prefix(self.bot, message), + key=lambda pfx: len(pfx), + reverse=True) + for p in prefixes: + if content.startswith(p): + return p + raise ValueError("No prefix found.") + + def get_extra_args_from_alias(self, message: discord.Message, prefix: str, + alias: AliasEntry) -> str: + """ + When an alias is executed by a user in chat this function tries + to get any extra arguments passed in with the call. + Whitespace will be trimmed from both ends. + :param message: + :param prefix: + :param alias: + :return: + """ + known_content_length = len(prefix) + len(alias.name) + extra = message.content[known_content_length:].strip() + return extra + + async def maybe_call_alias(self, message: discord.Message, + aliases: Iterable[AliasEntry]=None): + try: + prefix = self.get_prefix(message) + except ValueError: + return + + try: + potential_alias = message.content[len(prefix):].split(" ")[0] + except IndexError: + return False + + is_alias, alias = self.is_alias(message.guild, potential_alias, server_aliases=aliases) + + if is_alias: + await self.call_alias(message, prefix, alias) + + async def call_alias(self, message: discord.Message, prefix: str, + alias: AliasEntry): + new_message = copy(message) + args = self.get_extra_args_from_alias(message, prefix, alias) + + # noinspection PyDunderSlots + new_message.content = "{}{} {}".format(prefix, alias.command, args) + await self.bot.process_commands(new_message) + + @commands.group() + @commands.guild_only() + async def alias(self, ctx: commands.Context): + """Manage per-server aliases for commands""" + if ctx.invoked_subcommand is None: + await self.bot.send_cmd_help(ctx) + + @alias.group(name="global") + async def global_(self, ctx: commands.Context): + """ + Manage global aliases. + """ + if ctx.invoked_subcommand is None or \ + isinstance(ctx.invoked_subcommand, commands.Group): + await self.bot.send_cmd_help(ctx) + + @alias.command(name="add") + @commands.guild_only() + async def _add_alias(self, ctx: commands.Context, + alias_name: str, *, command): + """ + Add an alias for a command. + """ +#region Alias Add Validity Checking + is_command = self.is_command(alias_name) + if is_command: + await ctx.send(("You attempted to create a new alias" + " with the name {} but that" + " name is already a command on this bot.").format(alias_name)) + return + + is_alias, _ = self.is_alias(ctx.guild, alias_name) + if is_alias: + await ctx.send(("You attempted to create a new alias" + " with the name {} but that" + " alias already exists on this server.").format(alias_name)) + return + + is_valid_name = self.is_valid_alias_name(alias_name) + if not is_valid_name: + await ctx.send(("You attempted to create a new alias" + " with the name {} but that" + " name is an invalid alias name. Alias" + " names may only contain letters, numbers," + " and underscores and must start with a letter.").format(alias_name)) + return +#endregion + + # At this point we know we need to make a new alias + # and that the alias name is valid. + + await self.add_alias(ctx, alias_name, command) + + await ctx.send(("A new alias with the trigger `{}`" + " has been created.").format(alias_name)) + + @global_.command(name="add") + async def _add_global_alias(self, ctx: commands.Context, + alias_name: str, *, command): + """ + Add a global alias for a command. + """ +# region Alias Add Validity Checking + is_command = self.is_command(alias_name) + if is_command: + await ctx.send(("You attempted to create a new global alias" + " with the name {} but that" + " name is already a command on this bot.").format(alias_name)) + return + + is_alias, _ = self.is_alias(ctx.guild, alias_name) + if is_alias: + await ctx.send(("You attempted to create a new alias" + " with the name {} but that" + " alias already exists on this server.").format(alias_name)) + return + + is_valid_name = self.is_valid_alias_name(alias_name) + if not is_valid_name: + await ctx.send(("You attempted to create a new alias" + " with the name {} but that" + " name is an invalid alias name. Alias" + " names may only contain letters, numbers," + " and underscores and must start with a letter.").format(alias_name)) + return +# endregion + + await self.add_alias(ctx, alias_name, command, global_=True) + + await ctx.send(("A new global alias with the trigger `{}`" + " has been created.").format(alias_name)) + + @alias.command(name="help") + @commands.guild_only() + async def _help_alias(self, ctx: commands.Context, alias_name: str): + """Tries to execute help for the base command of the alias""" + is_alias, alias = self.is_alias(ctx.guild, alias_name=alias_name) + if is_alias: + base_cmd = alias.command[0] + + new_msg = copy(ctx.message) + new_msg.content = "{}help {}".format(ctx.prefix, base_cmd) + await self.bot.process_commands(new_msg) + else: + ctx.send("No such alias exists.") + + @alias.command(name="show") + @commands.guild_only() + async def _show_alias(self, ctx: commands.Context, alias_name: str): + """Shows what command the alias executes.""" + is_alias, alias = self.is_alias(ctx.guild, alias_name) + + if is_alias: + await ctx.send(("The `{}` alias will execute the" + " command `{}`").format(alias_name, alias.command)) + else: + await ctx.send("There is no alias with the name `{}`".format(alias_name)) + + @alias.command(name="del") + @commands.guild_only() + async def _del_alias(self, ctx: commands.Context, alias_name: str): + """ + Deletes an existing alias on this server. + """ + aliases = self.unloaded_aliases(ctx.guild) + try: + next(aliases) + except StopIteration: + await ctx.send("There are no aliases on this guild.") + return + + if await self.delete_alias(ctx, alias_name): + await ctx.send(("Alias with the name `{}` was successfully" + " deleted.").format(alias_name)) + else: + await ctx.send("Alias with name `{}` was not found.".format(alias_name)) + + @global_.command(name="del") + async def _del_global_alias(self, ctx: commands.Context, alias_name: str): + """ + Deletes an existing global alias. + """ + aliases = self.unloaded_global_aliases() + try: + next(aliases) + except StopIteration: + await ctx.send("There are no aliases on this bot.") + return + + if await self.delete_alias(ctx, alias_name, global_=True): + await ctx.send(("Alias with the name `{}` was successfully" + " deleted.").format(alias_name)) + else: + await ctx.send("Alias with name `{}` was not found.".format(alias_name)) + + @alias.command(name="list") + @commands.guild_only() + async def _list_alias(self, ctx: commands.Context): + """ + Lists the available aliases on this server. + """ + names = ["Aliases:", ] + sorted(["+ " + a.name for a in self.unloaded_aliases(ctx.guild)]) + if len(names) == 0: + await ctx.send("There are no aliases on this server.") + else: + await ctx.send(box("\n".join(names), "diff")) + + @global_.command(name="list") + async def _list_global_alias(self, ctx: commands.Context): + """ + Lists the available global aliases on this bot. + """ + names = ["Aliases:", ] + sorted(["+ " + a.name for a in self.unloaded_global_aliases()]) + if len(names) == 0: + await ctx.send("There are no aliases on this server.") + else: + await ctx.send(box("\n".join(names), "diff")) + + async def on_message(self, message: discord.Message): + aliases = list(self.unloaded_aliases(message.guild)) + \ + list(self.unloaded_global_aliases()) + + if len(aliases) == 0: + return + + await self.maybe_call_alias(message, aliases=aliases) diff --git a/cogs/alias/alias_entry.py b/cogs/alias/alias_entry.py new file mode 100644 index 000000000..ec3c373c0 --- /dev/null +++ b/cogs/alias/alias_entry.py @@ -0,0 +1,63 @@ +from typing import Tuple +from discord.ext import commands + +import discord + + +class AliasEntry: + def __init__(self, name: str, command: Tuple[str], + creator: discord.Member, global_: bool=False): + super().__init__() + self.has_real_data = False + self.name = name + self.command = command + self.creator = creator + + self.global_ = global_ + + self.guild = None + if hasattr(creator, "guild"): + self.guild = creator.guild + + self.uses = 0 + + def inc(self): + """ + Increases the `uses` stat by 1. + :return: new use count + """ + self.uses += 1 + return self.uses + + def to_json(self) -> dict: + try: + creator = str(self.creator.id) + guild = str(self.guild.id) + except AttributeError: + creator = self.creator + guild = self.guild + + return { + "name": self.name, + "command": self.command, + "creator": creator, + "guild": guild, + "global": self.global_, + "uses": self.uses + } + + @classmethod + def from_json(cls, data: dict, bot: commands.Bot=None): + ret = cls(data["name"], data["command"], + data["creator"], global_=data["global"]) + + if bot: + ret.has_real_data = True + ret.creator = bot.get_user(int(data["creator"])) + guild = bot.get_guild(int(data["guild"])) + ret.guild = guild + else: + ret.guild = data["guild"] + + ret.uses = data.get("uses", 0) + return ret diff --git a/tests/cogs/downloader/test_downloader.py b/tests/cogs/downloader/test_downloader.py index e718b57a5..1756d279a 100644 --- a/tests/cogs/downloader/test_downloader.py +++ b/tests/cogs/downloader/test_downloader.py @@ -27,7 +27,7 @@ def patch_relative_to(monkeysession): monkeysession.setattr("pathlib.Path.relative_to", fake_relative_to) -@pytest.fixture(scope="module") +@pytest.fixture def repo_manager(tmpdir_factory, config): config.register_global(repos={}) rm = RepoManager(config) diff --git a/tests/cogs/test_alias.py b/tests/cogs/test_alias.py new file mode 100644 index 000000000..9e39b6f51 --- /dev/null +++ b/tests/cogs/test_alias.py @@ -0,0 +1,64 @@ +from cogs.alias import Alias +import pytest + + +@pytest.fixture +def alias(monkeysession, config): + def get_mock_conf(*args, **kwargs): + return config + + monkeysession.setattr("core.config.Config.get_conf", get_mock_conf) + + return Alias(None) + + +def test_is_valid_alias_name(alias): + assert alias.is_valid_alias_name("valid") is True + assert alias.is_valid_alias_name("not valid name") is False + + +def test_empty_guild_aliases(alias, empty_guild): + assert list(alias.unloaded_aliases(empty_guild)) == [] + + +def test_empty_global_aliases(alias): + assert list(alias.unloaded_global_aliases()) == [] + + +@pytest.mark.asyncio +async def test_add_guild_alias(alias, ctx): + await alias.add_alias(ctx, "test", "ping", global_=False) + + is_alias, alias_obj = alias.is_alias(ctx.guild, "test") + assert is_alias is True + assert alias_obj.global_ is False + + +@pytest.mark.asyncio +async def test_delete_guild_alias(alias, ctx): + is_alias, _ = alias.is_alias(ctx.guild, "test") + assert is_alias is True + + await alias.delete_alias(ctx, "test") + + is_alias, _ = alias.is_alias(ctx.guild, "test") + assert is_alias is False + + +@pytest.mark.asyncio +async def test_add_global_alias(alias, ctx): + await alias.add_alias(ctx, "test", "ping", global_=True) + is_alias, alias_obj = alias.is_alias(ctx.guild, "test") + + assert is_alias is True + assert alias_obj.global_ is True + + +@pytest.mark.asyncio +async def test_delete_global_alias(alias, ctx): + is_alias, alias_obj = alias.is_alias(ctx.guild, "test") + assert is_alias is True + assert alias_obj.global_ is True + + did_delete = await alias.delete_alias(ctx, alias_name="test", global_=True) + assert did_delete is True diff --git a/tests/conftest.py b/tests/conftest.py index 39adadd12..86d70afef 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -84,7 +84,7 @@ def empty_message(): return mock_msg("No content.") -@pytest.fixture(scope="module") +@pytest.fixture def ctx(empty_member, empty_channel, red): mock_ctx = namedtuple("Context", "author guild channel message bot") return mock_ctx(empty_member, empty_member.guild, empty_channel, @@ -94,16 +94,18 @@ def ctx(empty_member, empty_channel, red): #region Red Mock @pytest.fixture -def red(monkeypatch, config_fr, event_loop): +def red(monkeysession, config_fr): from core.cli import parse_cli_flags cli_flags = parse_cli_flags() description = "Red v3 - Alpha" - monkeypatch.setattr("core.config.Config.get_core_conf", - lambda *args, **kwargs: config_fr) + monkeysession.setattr("core.config.Config.get_core_conf", + lambda *args, **kwargs: config_fr) - red = Red(cli_flags, description=description, pm_help=None, - loop=event_loop) - return red + red = Red(cli_flags, description=description, pm_help=None) + + yield red + + red.http._session.close() #endregion \ No newline at end of file