[Alias] Create caching to call config less frequently (#3788)

This commit is contained in:
TrustyJAID 2020-04-26 18:25:41 -06:00 committed by GitHub
parent a1095285e4
commit 6f6c536236
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 269 additions and 239 deletions

View File

@ -2,5 +2,7 @@ from .alias import Alias
from redbot.core.bot import Red
def setup(bot: Red):
bot.add_cog(Alias(bot))
async def setup(bot: Red):
cog = Alias(bot)
await cog.initialize()
bot.add_cog(cog)

View File

@ -1,16 +1,15 @@
from copy import copy
from re import findall, search
from re import search
from string import Formatter
from typing import Generator, Tuple, Iterable, Optional
from typing import Dict
import discord
from discord.ext.commands.view import StringView
from redbot.core import Config, commands, checks
from redbot.core.i18n import Translator, cog_i18n
from redbot.core.utils.chat_formatting import box
from redbot.core.bot import Red
from .alias_entry import AliasEntry
from .alias_entry import AliasEntry, AliasCache, ArgParseError
_ = Translator("Alias", __file__)
@ -26,10 +25,6 @@ class _TrackingFormatter(Formatter):
return super().get_value(key, args, kwargs)
class ArgParseError(Exception):
pass
@cog_i18n(_)
class Alias(commands.Cog):
"""Create aliases for commands.
@ -42,9 +37,9 @@ class Alias(commands.Cog):
and append them to the stored alias.
"""
default_global_settings = {"entries": []}
default_global_settings: Dict[str, list] = {"entries": []}
default_guild_settings = {"enabled": False, "entries": []} # Going to be a list of dicts
default_guild_settings: Dict[str, list] = {"entries": []} # Going to be a list of dicts
def __init__(self, bot: Red):
super().__init__()
@ -53,40 +48,12 @@ class Alias(commands.Cog):
self.config.register_global(**self.default_global_settings)
self.config.register_guild(**self.default_guild_settings)
self._aliases: AliasCache = AliasCache(config=self.config, cache_enabled=True)
async def unloaded_aliases(self, guild: discord.Guild) -> Generator[AliasEntry, None, None]:
return (AliasEntry.from_json(d) for d in (await self.config.guild(guild).entries()))
async def unloaded_global_aliases(self) -> Generator[AliasEntry, None, None]:
return (AliasEntry.from_json(d) for d in (await self.config.entries()))
async def loaded_aliases(self, guild: discord.Guild) -> Generator[AliasEntry, None, None]:
return (
AliasEntry.from_json(d, bot=self.bot)
for d in (await self.config.guild(guild).entries())
)
async def loaded_global_aliases(self) -> Generator[AliasEntry, None, None]:
return (AliasEntry.from_json(d, bot=self.bot) for d in (await self.config.entries()))
async def is_alias(
self,
guild: Optional[discord.Guild],
alias_name: str,
server_aliases: Iterable[AliasEntry] = (),
) -> Tuple[bool, Optional[AliasEntry]]:
if not server_aliases and guild is not None:
server_aliases = await self.unloaded_aliases(guild)
global_aliases = await self.unloaded_global_aliases()
for aliases in (server_aliases, global_aliases):
for alias in aliases:
if alias.name == alias_name:
return True, alias
return False, None
async def initialize(self):
# This can be where we set the cache_enabled attribute later
if not self._aliases._loaded:
await self._aliases.load_aliases()
def is_command(self, alias_name: str) -> bool:
"""
@ -100,56 +67,6 @@ class Alias(commands.Cog):
def is_valid_alias_name(alias_name: str) -> bool:
return not bool(search(r"\s", alias_name)) and alias_name.isprintable()
async def add_alias(
self, ctx: commands.Context, alias_name: str, command: str, global_: bool = False
) -> AliasEntry:
indices = findall(r"{(\d*)}", command)
if indices:
try:
indices = [int(a[0]) for a in indices]
except IndexError:
raise ArgParseError(_("Arguments must be specified with a number."))
low = min(indices)
indices = [a - low for a in indices]
high = max(indices)
gaps = set(indices).symmetric_difference(range(high + 1))
if gaps:
raise ArgParseError(
_("Arguments must be sequential. Missing arguments: ")
+ ", ".join(str(i + low) for i in gaps)
)
command = command.format(*(f"{{{i}}}" for i in range(-low, high + low + 1)))
alias = AliasEntry(alias_name, command, ctx.author, global_=global_)
if global_:
settings = self.config
else:
settings = self.config.guild(ctx.guild)
await settings.enabled.set(True)
async with settings.entries() as curr_aliases:
curr_aliases.append(alias.to_json())
return alias
async def delete_alias(
self, ctx: commands.Context, alias_name: str, global_: bool = False
) -> bool:
if global_:
settings = self.config
else:
settings = self.config.guild(ctx.guild)
async with settings.entries() as aliases:
for alias in aliases:
alias_obj = AliasEntry.from_json(alias)
if alias_obj.name == alias_name:
aliases.remove(alias)
return True
return False
async def get_prefix(self, message: discord.Message) -> str:
"""
Tries to determine what prefix is used in a message object.
@ -167,57 +84,11 @@ class Alias(commands.Cog):
return p
raise ValueError(_("No prefix found."))
def get_extra_args_from_alias(
self, message: discord.Message, prefix: str, alias: AliasEntry
) -> str:
"""
When an alias is executed by a user in chat this function tries
to get any extra arguments passed in with the call.
Whitespace will be trimmed from both ends.
:param message:
:param prefix:
:param alias:
:return:
"""
known_content_length = len(prefix) + len(alias.name)
extra = message.content[known_content_length:]
view = StringView(extra)
view.skip_ws()
extra = []
while not view.eof:
prev = view.index
word = view.get_quoted_word()
if len(word) < view.index - prev:
word = "".join((view.buffer[prev], word, view.buffer[view.index - 1]))
extra.append(word)
view.skip_ws()
return extra
async def maybe_call_alias(
self, message: discord.Message, aliases: Iterable[AliasEntry] = None
):
try:
prefix = await self.get_prefix(message)
except ValueError:
return
try:
potential_alias = message.content[len(prefix) :].split(" ")[0]
except IndexError:
return False
is_alias, alias = await self.is_alias(
message.guild, potential_alias, server_aliases=aliases
)
if is_alias:
await self.call_alias(message, prefix, alias)
async def call_alias(self, message: discord.Message, prefix: str, alias: AliasEntry):
new_message = copy(message)
try:
args = self.get_extra_args_from_alias(message, prefix, alias)
except commands.BadArgument as bae:
args = alias.get_extra_args_from_alias(message, prefix)
except commands.BadArgument:
return
trackform = _TrackingFormatter()
@ -257,8 +128,8 @@ class Alias(commands.Cog):
)
return
is_alias, something_useless = await self.is_alias(ctx.guild, alias_name)
if is_alias:
alias = await self._aliases.get_alias(ctx.guild, alias_name)
if alias:
await ctx.send(
_(
"You attempted to create a new alias"
@ -292,7 +163,7 @@ class Alias(commands.Cog):
# and that the alias name is valid.
try:
await self.add_alias(ctx, alias_name, command)
await self._aliases.add_alias(ctx, alias_name, command)
except ArgParseError as e:
return await ctx.send(" ".join(e.args))
@ -316,8 +187,8 @@ class Alias(commands.Cog):
)
return
is_alias, something_useless = await self.is_alias(ctx.guild, alias_name)
if is_alias:
alias = await self._aliases.get_alias(ctx.guild, alias_name)
if alias:
await ctx.send(
_(
"You attempted to create a new global alias"
@ -341,7 +212,7 @@ class Alias(commands.Cog):
# endregion
try:
await self.add_alias(ctx, alias_name, command, global_=True)
await self._aliases.add_alias(ctx, alias_name, command, global_=True)
except ArgParseError as e:
return await ctx.send(" ".join(e.args))
@ -355,8 +226,8 @@ class Alias(commands.Cog):
@commands.guild_only()
async def _help_alias(self, ctx: commands.Context, alias_name: str):
"""Try to execute help for the base command of the alias."""
is_alias, alias = await self.is_alias(ctx.guild, alias_name=alias_name)
if is_alias:
alias = await self._aliases.get_alias(ctx.guild, alias_name=alias_name)
if alias:
if self.is_command(alias.command):
base_cmd = alias.command
else:
@ -372,9 +243,9 @@ class Alias(commands.Cog):
@commands.guild_only()
async def _show_alias(self, ctx: commands.Context, alias_name: str):
"""Show what command the alias executes."""
is_alias, alias = await self.is_alias(ctx.guild, alias_name)
alias = await self._aliases.get_alias(ctx.guild, alias_name)
if is_alias:
if alias:
await ctx.send(
_("The `{alias_name}` alias will execute the command `{command}`").format(
alias_name=alias_name, command=alias.command
@ -388,14 +259,11 @@ class Alias(commands.Cog):
@commands.guild_only()
async def _del_alias(self, ctx: commands.Context, alias_name: str):
"""Delete an existing alias on this server."""
aliases = await self.unloaded_aliases(ctx.guild)
try:
next(aliases)
except StopIteration:
if not await self._aliases.get_guild_aliases(ctx.guild):
await ctx.send(_("There are no aliases on this server."))
return
if await self.delete_alias(ctx, alias_name):
if await self._aliases.delete_alias(ctx, alias_name):
await ctx.send(
_("Alias with the name `{name}` was successfully deleted.").format(name=alias_name)
)
@ -406,14 +274,11 @@ class Alias(commands.Cog):
@global_.command(name="delete", aliases=["del", "remove"])
async def _del_global_alias(self, ctx: commands.Context, alias_name: str):
"""Delete an existing global alias."""
aliases = await self.unloaded_global_aliases()
try:
next(aliases)
except StopIteration:
await ctx.send(_("There are no aliases on this bot."))
if not await self._aliases.get_global_aliases():
await ctx.send(_("There are no global aliases on this bot."))
return
if await self.delete_alias(ctx, alias_name, global_=True):
if await self._aliases.delete_alias(ctx, alias_name, global_=True):
await ctx.send(
_("Alias with the name `{name}` was successfully deleted.").format(name=alias_name)
)
@ -424,32 +289,34 @@ class Alias(commands.Cog):
@commands.guild_only()
async def _list_alias(self, ctx: commands.Context):
"""List the available aliases on this server."""
names = [_("Aliases:")] + sorted(
["+ " + a.name for a in (await self.unloaded_aliases(ctx.guild))]
)
if len(names) == 0:
await ctx.send(_("There are no aliases on this server."))
else:
await ctx.send(box("\n".join(names), "diff"))
guild_aliases = await self._aliases.get_guild_aliases(ctx.guild)
if not guild_aliases:
return await ctx.send(_("There are no aliases on this server."))
names = [_("Aliases:")] + sorted(["+ " + a.name for a in guild_aliases])
await ctx.send(box("\n".join(names), "diff"))
@global_.command(name="list")
async def _list_global_alias(self, ctx: commands.Context):
"""List the available global aliases on this bot."""
names = [_("Aliases:")] + sorted(
["+ " + a.name for a in await self.unloaded_global_aliases()]
)
if len(names) == 0:
await ctx.send(_("There are no aliases on this server."))
else:
await ctx.send(box("\n".join(names), "diff"))
global_aliases = await self._aliases.get_global_aliases()
if not global_aliases:
return await ctx.send(_("There are no global aliases."))
names = [_("Aliases:")] + sorted(["+ " + a.name for a in global_aliases])
await ctx.send(box("\n".join(names), "diff"))
@commands.Cog.listener()
async def on_message(self, message: discord.Message):
aliases = list(await self.unloaded_global_aliases())
if message.guild is not None:
aliases = aliases + list(await self.unloaded_aliases(message.guild))
if len(aliases) == 0:
async def on_message_without_command(self, message: discord.Message):
try:
prefix = await self.get_prefix(message)
except ValueError:
return
await self.maybe_call_alias(message, aliases=aliases)
try:
potential_alias = message.content[len(prefix) :].split(" ")[0]
except IndexError:
return
alias = await self._aliases.get_alias(message.guild, potential_alias)
if alias:
await self.call_alias(message, prefix, alias)

View File

@ -1,25 +1,37 @@
from typing import Tuple
from typing import Tuple, Dict, Optional, List, Union
from re import findall
import discord
from redbot.core import commands
from discord.ext.commands.view import StringView
from redbot.core import commands, Config
from redbot.core.i18n import Translator
from redbot.core.utils import AsyncIter
_ = Translator("Alias", __file__)
class ArgParseError(Exception):
pass
class AliasEntry:
"""An object containing all required information about an alias"""
name: str
command: Union[Tuple[str], str]
creator: int
guild: Optional[int]
uses: int
def __init__(
self, name: str, command: Tuple[str], creator: discord.Member, global_: bool = False
self, name: str, command: Union[Tuple[str], str], creator: int, guild: Optional[int],
):
super().__init__()
self.has_real_data = False
self.name = name
self.command = command
self.creator = creator
self.global_ = global_
self.guild = None
if hasattr(creator, "guild"):
self.guild = creator.guild
self.guild = guild
self.uses = 0
def inc(self):
@ -30,34 +42,182 @@ class AliasEntry:
self.uses += 1
return self.uses
def get_extra_args_from_alias(self, message: discord.Message, prefix: str) -> str:
"""
When an alias is executed by a user in chat this function tries
to get any extra arguments passed in with the call.
Whitespace will be trimmed from both ends.
:param message:
:param prefix:
:param alias:
:return:
"""
known_content_length = len(prefix) + len(self.name)
extra = message.content[known_content_length:]
view = StringView(extra)
view.skip_ws()
extra = []
while not view.eof:
prev = view.index
word = view.get_quoted_word()
if len(word) < view.index - prev:
word = "".join((view.buffer[prev], word, view.buffer[view.index - 1]))
extra.append(word)
view.skip_ws()
return extra
def to_json(self) -> dict:
try:
creator = str(self.creator.id)
guild = str(self.guild.id)
except AttributeError:
creator = self.creator
guild = self.guild
return {
"name": self.name,
"command": self.command,
"creator": creator,
"guild": guild,
"global": self.global_,
"creator": self.creator,
"guild": self.guild,
"uses": self.uses,
}
@classmethod
def from_json(cls, data: dict, bot: commands.Bot = None):
ret = cls(data["name"], data["command"], data["creator"], global_=data["global"])
if bot:
ret.has_real_data = True
ret.creator = bot.get_user(int(data["creator"]))
guild = bot.get_guild(int(data["guild"]))
ret.guild = guild
else:
ret.guild = data["guild"]
def from_json(cls, data: dict):
ret = cls(data["name"], data["command"], data["creator"], data["guild"])
ret.uses = data.get("uses", 0)
return ret
class AliasCache:
def __init__(self, config: Config, cache_enabled: bool = True):
self.config = config
self._cache_enabled = cache_enabled
self._loaded = False
self._aliases: Dict[Optional[int], Dict[str, AliasEntry]] = {None: {}}
async def load_aliases(self):
if not self._cache_enabled:
self._loaded = True
return
for alias in await self.config.entries():
self._aliases[None][alias["name"]] = AliasEntry.from_json(alias)
all_guilds = await self.config.all_guilds()
async for guild_id, guild_data in AsyncIter(all_guilds.items(), steps=100):
if guild_id not in self._aliases:
self._aliases[guild_id] = {}
for alias in guild_data["entries"]:
self._aliases[guild_id][alias["name"]] = AliasEntry.from_json(alias)
self._loaded = True
async def get_aliases(self, ctx: commands.Context) -> List[AliasEntry]:
"""Returns all possible aliases with the given context"""
global_aliases: List[AliasEntry] = []
server_aliases: List[AliasEntry] = []
global_aliases = await self.get_global_aliases()
if ctx.guild and ctx.guild.id in self._aliases:
server_aliases = await self.get_guild_aliases(ctx.guild)
return global_aliases + server_aliases
async def get_guild_aliases(self, guild: discord.Guild) -> List[AliasEntry]:
"""Returns all guild specific aliases"""
aliases: List[AliasEntry] = []
if self._cache_enabled:
if guild.id in self._aliases:
for _, alias in self._aliases[guild.id].items():
aliases.append(alias)
else:
aliases = [AliasEntry.from_json(d) for d in await self.config.guild(guild).entries()]
return aliases
async def get_global_aliases(self) -> List[AliasEntry]:
"""Returns all global specific aliases"""
aliases: List[AliasEntry] = []
if self._cache_enabled:
for _, alias in self._aliases[None].items():
aliases.append(alias)
else:
aliases = [AliasEntry.from_json(d) for d in await self.config.entries()]
return aliases
async def get_alias(
self, guild: Optional[discord.Guild], alias_name: str,
) -> Optional[AliasEntry]:
"""Returns an AliasEntry object if the provided alias_name is a registered alias"""
server_aliases: List[AliasEntry] = []
if self._cache_enabled:
if alias_name in self._aliases[None]:
return self._aliases[None][alias_name]
if guild is not None:
if guild.id in self._aliases:
if alias_name in self._aliases[guild.id]:
return self._aliases[guild.id][alias_name]
else:
if guild:
server_aliases = [
AliasEntry.from_json(d) for d in await self.config.guild(guild.id).entries()
]
global_aliases = [AliasEntry.from_json(d) for d in await self.config.entries()]
all_aliases = global_aliases + server_aliases
for alias in all_aliases:
if alias.name == alias_name:
return alias
return None
async def add_alias(
self, ctx: commands.Context, alias_name: str, command: str, global_: bool = False
) -> AliasEntry:
indices = findall(r"{(\d*)}", command)
if indices:
try:
indices = [int(a[0]) for a in indices]
except IndexError:
raise ArgParseError(_("Arguments must be specified with a number."))
low = min(indices)
indices = [a - low for a in indices]
high = max(indices)
gaps = set(indices).symmetric_difference(range(high + 1))
if gaps:
raise ArgParseError(
_("Arguments must be sequential. Missing arguments: ")
+ ", ".join(str(i + low) for i in gaps)
)
command = command.format(*(f"{{{i}}}" for i in range(-low, high + low + 1)))
if global_:
alias = AliasEntry(alias_name, command, ctx.author.id, None)
settings = self.config
if self._cache_enabled:
self._aliases[None][alias.name] = alias
else:
alias = AliasEntry(alias_name, command, ctx.author.id, ctx.guild.id)
settings = self.config.guild(ctx.guild)
if self._cache_enabled:
if ctx.guild.id not in self._aliases:
self._aliases[ctx.guild.id] = {}
self._aliases[ctx.guild.id][alias.name] = alias
async with settings.entries() as curr_aliases:
curr_aliases.append(alias.to_json())
return alias
async def delete_alias(
self, ctx: commands.Context, alias_name: str, global_: bool = False
) -> bool:
if global_:
settings = self.config
else:
settings = self.config.guild(ctx.guild)
async with settings.entries() as aliases:
for alias in aliases:
if alias["name"] == alias_name:
aliases.remove(alias)
if self._cache_enabled:
if global_:
del self._aliases[None][alias_name]
else:
del self._aliases[ctx.guild.id][alias_name]
return True
return False

View File

@ -409,8 +409,8 @@ class Cleanup(commands.Cog):
alias_cog = self.bot.get_cog("Alias")
if alias_cog is not None:
alias_names: Set[str] = (
set((a.name for a in await alias_cog.unloaded_global_aliases()))
| set(a.name for a in await alias_cog.unloaded_aliases(ctx.guild))
set((a.name for a in await alias_cog._aliases.get_global_aliases()))
| set(a.name for a in await alias_cog._aliases.get_guild_aliases(ctx.guild))
)
is_alias = lambda name: name in alias_names
else:
@ -538,7 +538,7 @@ class Cleanup(commands.Cog):
@commands.bot_has_permissions(manage_messages=True)
async def cleanup_spam(self, ctx: commands.Context, number: int = 50):
"""Deletes duplicate messages in the channel from the last X messages and keeps only one copy.
Defaults to 50.
"""
msgs = []

View File

@ -103,9 +103,9 @@ async def fuzzy_command_search(
# If the term is an alias or CC, we don't want to send a supplementary fuzzy search.
alias_cog = ctx.bot.get_cog("Alias")
if alias_cog is not None:
is_alias, alias = await alias_cog.is_alias(ctx.guild, term)
alias = await alias_cog._aliases.get_alias(ctx.guild, term)
if is_alias:
if alias:
return None
customcom_cog = ctx.bot.get_cog("CustomCommands")
if customcom_cog is not None:

View File

@ -9,58 +9,59 @@ def test_is_valid_alias_name(alias):
@pytest.mark.asyncio
async def test_empty_guild_aliases(alias, empty_guild):
assert list(await alias.unloaded_aliases(empty_guild)) == []
assert list(await alias._aliases.get_guild_aliases(empty_guild)) == []
@pytest.mark.asyncio
async def test_empty_global_aliases(alias):
assert list(await alias.unloaded_global_aliases()) == []
assert list(await alias._aliases.get_global_aliases()) == []
async def create_test_guild_alias(alias, ctx):
await alias.add_alias(ctx, "test", "ping", global_=False)
await alias._aliases.add_alias(ctx, "test", "ping", global_=False)
async def create_test_global_alias(alias, ctx):
await alias.add_alias(ctx, "test", "ping", global_=True)
await alias._aliases.add_alias(ctx, "test_global", "ping", global_=True)
@pytest.mark.asyncio
async def test_add_guild_alias(alias, ctx):
await create_test_guild_alias(alias, ctx)
is_alias, alias_obj = await alias.is_alias(ctx.guild, "test")
assert is_alias is True
assert alias_obj.global_ is False
alias_obj = await alias._aliases.get_alias(ctx.guild, "test")
assert alias_obj.name == "test"
@pytest.mark.asyncio
async def test_delete_guild_alias(alias, ctx):
await create_test_guild_alias(alias, ctx)
is_alias, _ = await alias.is_alias(ctx.guild, "test")
assert is_alias is True
alias_obj = await alias._aliases.get_alias(ctx.guild, "test")
assert alias_obj.name == "test"
await alias.delete_alias(ctx, "test")
did_delete = await alias._aliases.delete_alias(ctx, "test")
assert did_delete is True
is_alias, _ = await alias.is_alias(ctx.guild, "test")
assert is_alias is False
alias_obj = await alias._aliases.get_alias(ctx.guild, "test")
assert alias_obj is None
@pytest.mark.asyncio
async def test_add_global_alias(alias, ctx):
await create_test_global_alias(alias, ctx)
is_alias, alias_obj = await alias.is_alias(ctx.guild, "test")
alias_obj = await alias._aliases.get_alias(ctx.guild, "test_global")
assert is_alias is True
assert alias_obj.global_ is True
assert alias_obj.name == "test_global"
@pytest.mark.asyncio
async def test_delete_global_alias(alias, ctx):
await create_test_global_alias(alias, ctx)
is_alias, alias_obj = await alias.is_alias(ctx.guild, "test")
assert is_alias is True
assert alias_obj.global_ is True
alias_obj = await alias._aliases.get_alias(ctx.guild, "test_global")
assert alias_obj.name == "test_global"
did_delete = await alias.delete_alias(ctx, alias_name="test", global_=True)
did_delete = await alias._aliases.delete_alias(ctx, alias_name="test_global", global_=True)
assert did_delete is True
alias_obj = await alias._aliases.get_alias(None, "test_global")
assert alias_obj is None