mirror of
https://github.com/Cog-Creators/Red-DiscordBot.git
synced 2025-11-07 03:38:53 -05:00
[Config] Asynchronous getters (#907)
* Make config get async * Asyncify alias * Asyncify bank * Asyncify cog manager * IT BOOTS * Asyncify core commands * Asyncify repo manager * Asyncify downloader * Asyncify economy * Asyncify alias TESTS * Asyncify economy TESTS * Asyncify downloader TESTS * Asyncify config TESTS * A bank thing * Asyncify Bank cog * Warning message in docs * Update docs with await syntax * Update docs with await syntax
This commit is contained in:
parent
cf8e11238c
commit
de912a3cfb
@ -38,26 +38,26 @@ class Alias:
|
|||||||
self._aliases.register_global(**self.default_global_settings)
|
self._aliases.register_global(**self.default_global_settings)
|
||||||
self._aliases.register_guild(**self.default_guild_settings)
|
self._aliases.register_guild(**self.default_guild_settings)
|
||||||
|
|
||||||
def unloaded_aliases(self, guild: discord.Guild) -> Generator[AliasEntry, None, None]:
|
async def unloaded_aliases(self, guild: discord.Guild) -> Generator[AliasEntry, None, None]:
|
||||||
return (AliasEntry.from_json(d) for d in self._aliases.guild(guild).entries())
|
return (AliasEntry.from_json(d) for d in (await self._aliases.guild(guild).entries()))
|
||||||
|
|
||||||
def unloaded_global_aliases(self) -> Generator[AliasEntry, None, None]:
|
async def unloaded_global_aliases(self) -> Generator[AliasEntry, None, None]:
|
||||||
return (AliasEntry.from_json(d) for d in self._aliases.entries())
|
return (AliasEntry.from_json(d) for d in (await self._aliases.entries()))
|
||||||
|
|
||||||
def loaded_aliases(self, guild: discord.Guild) -> Generator[AliasEntry, None, None]:
|
async def loaded_aliases(self, guild: discord.Guild) -> Generator[AliasEntry, None, None]:
|
||||||
return (AliasEntry.from_json(d, bot=self.bot)
|
return (AliasEntry.from_json(d, bot=self.bot)
|
||||||
for d in self._aliases.guild(guild).entries())
|
for d in (await self._aliases.guild(guild).entries()))
|
||||||
|
|
||||||
def loaded_global_aliases(self) -> Generator[AliasEntry, None, None]:
|
async def loaded_global_aliases(self) -> Generator[AliasEntry, None, None]:
|
||||||
return (AliasEntry.from_json(d, bot=self.bot) for d in self._aliases.entries())
|
return (AliasEntry.from_json(d, bot=self.bot) for d in (await self._aliases.entries()))
|
||||||
|
|
||||||
def is_alias(self, guild: discord.Guild, alias_name: str,
|
async def is_alias(self, guild: discord.Guild, alias_name: str,
|
||||||
server_aliases: Iterable[AliasEntry]=()) -> (bool, AliasEntry):
|
server_aliases: Iterable[AliasEntry]=()) -> (bool, AliasEntry):
|
||||||
|
|
||||||
if not server_aliases:
|
if not server_aliases:
|
||||||
server_aliases = self.unloaded_aliases(guild)
|
server_aliases = await self.unloaded_aliases(guild)
|
||||||
|
|
||||||
global_aliases = self.unloaded_global_aliases()
|
global_aliases = await self.unloaded_global_aliases()
|
||||||
|
|
||||||
for aliases in (server_aliases, global_aliases):
|
for aliases in (server_aliases, global_aliases):
|
||||||
for alias in aliases:
|
for alias in aliases:
|
||||||
@ -79,11 +79,11 @@ class Alias:
|
|||||||
alias = AliasEntry(alias_name, command, ctx.author, global_=global_)
|
alias = AliasEntry(alias_name, command, ctx.author, global_=global_)
|
||||||
|
|
||||||
if global_:
|
if global_:
|
||||||
curr_aliases = self._aliases.entries()
|
curr_aliases = await self._aliases.entries()
|
||||||
curr_aliases.append(alias.to_json())
|
curr_aliases.append(alias.to_json())
|
||||||
await self._aliases.entries.set(curr_aliases)
|
await self._aliases.entries.set(curr_aliases)
|
||||||
else:
|
else:
|
||||||
curr_aliases = self._aliases.guild(ctx.guild).entries()
|
curr_aliases = await self._aliases.guild(ctx.guild).entries()
|
||||||
|
|
||||||
curr_aliases.append(alias.to_json())
|
curr_aliases.append(alias.to_json())
|
||||||
await self._aliases.guild(ctx.guild).entries.set(curr_aliases)
|
await self._aliases.guild(ctx.guild).entries.set(curr_aliases)
|
||||||
@ -94,10 +94,10 @@ class Alias:
|
|||||||
async def delete_alias(self, ctx: commands.Context, alias_name: str,
|
async def delete_alias(self, ctx: commands.Context, alias_name: str,
|
||||||
global_: bool=False) -> bool:
|
global_: bool=False) -> bool:
|
||||||
if global_:
|
if global_:
|
||||||
aliases = self.unloaded_global_aliases()
|
aliases = await self.unloaded_global_aliases()
|
||||||
setter_func = self._aliases.entries.set
|
setter_func = self._aliases.entries.set
|
||||||
else:
|
else:
|
||||||
aliases = self.unloaded_aliases(ctx.guild)
|
aliases = await self.unloaded_aliases(ctx.guild)
|
||||||
setter_func = self._aliases.guild(ctx.guild).entries.set
|
setter_func = self._aliases.guild(ctx.guild).entries.set
|
||||||
|
|
||||||
did_delete_alias = False
|
did_delete_alias = False
|
||||||
@ -161,7 +161,7 @@ class Alias:
|
|||||||
except IndexError:
|
except IndexError:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
is_alias, alias = self.is_alias(message.guild, potential_alias, server_aliases=aliases)
|
is_alias, alias = await self.is_alias(message.guild, potential_alias, server_aliases=aliases)
|
||||||
|
|
||||||
if is_alias:
|
if is_alias:
|
||||||
await self.call_alias(message, prefix, alias)
|
await self.call_alias(message, prefix, alias)
|
||||||
@ -206,7 +206,7 @@ class Alias:
|
|||||||
" name is already a command on this bot.").format(alias_name))
|
" name is already a command on this bot.").format(alias_name))
|
||||||
return
|
return
|
||||||
|
|
||||||
is_alias, _ = self.is_alias(ctx.guild, alias_name)
|
is_alias, _ = await self.is_alias(ctx.guild, alias_name)
|
||||||
if is_alias:
|
if is_alias:
|
||||||
await ctx.send(("You attempted to create a new alias"
|
await ctx.send(("You attempted to create a new alias"
|
||||||
" with the name {} but that"
|
" with the name {} but that"
|
||||||
@ -285,7 +285,7 @@ class Alias:
|
|||||||
@commands.guild_only()
|
@commands.guild_only()
|
||||||
async def _show_alias(self, ctx: commands.Context, alias_name: str):
|
async def _show_alias(self, ctx: commands.Context, alias_name: str):
|
||||||
"""Shows what command the alias executes."""
|
"""Shows what command the alias executes."""
|
||||||
is_alias, alias = self.is_alias(ctx.guild, alias_name)
|
is_alias, alias = await self.is_alias(ctx.guild, alias_name)
|
||||||
|
|
||||||
if is_alias:
|
if is_alias:
|
||||||
await ctx.send(("The `{}` alias will execute the"
|
await ctx.send(("The `{}` alias will execute the"
|
||||||
@ -299,7 +299,7 @@ class Alias:
|
|||||||
"""
|
"""
|
||||||
Deletes an existing alias on this server.
|
Deletes an existing alias on this server.
|
||||||
"""
|
"""
|
||||||
aliases = self.unloaded_aliases(ctx.guild)
|
aliases = await self.unloaded_aliases(ctx.guild)
|
||||||
try:
|
try:
|
||||||
next(aliases)
|
next(aliases)
|
||||||
except StopIteration:
|
except StopIteration:
|
||||||
@ -317,7 +317,7 @@ class Alias:
|
|||||||
"""
|
"""
|
||||||
Deletes an existing global alias.
|
Deletes an existing global alias.
|
||||||
"""
|
"""
|
||||||
aliases = self.unloaded_global_aliases()
|
aliases = await self.unloaded_global_aliases()
|
||||||
try:
|
try:
|
||||||
next(aliases)
|
next(aliases)
|
||||||
except StopIteration:
|
except StopIteration:
|
||||||
@ -336,7 +336,7 @@ class Alias:
|
|||||||
"""
|
"""
|
||||||
Lists the available aliases on this server.
|
Lists the available aliases on this server.
|
||||||
"""
|
"""
|
||||||
names = ["Aliases:", ] + sorted(["+ " + a.name for a in self.unloaded_aliases(ctx.guild)])
|
names = ["Aliases:", ] + sorted(["+ " + a.name for a in (await self.unloaded_aliases(ctx.guild))])
|
||||||
if len(names) == 0:
|
if len(names) == 0:
|
||||||
await ctx.send("There are no aliases on this server.")
|
await ctx.send("There are no aliases on this server.")
|
||||||
else:
|
else:
|
||||||
@ -347,16 +347,16 @@ class Alias:
|
|||||||
"""
|
"""
|
||||||
Lists the available global aliases on this bot.
|
Lists the available global aliases on this bot.
|
||||||
"""
|
"""
|
||||||
names = ["Aliases:", ] + sorted(["+ " + a.name for a in self.unloaded_global_aliases()])
|
names = ["Aliases:", ] + sorted(["+ " + a.name for a in await self.unloaded_global_aliases()])
|
||||||
if len(names) == 0:
|
if len(names) == 0:
|
||||||
await ctx.send("There are no aliases on this server.")
|
await ctx.send("There are no aliases on this server.")
|
||||||
else:
|
else:
|
||||||
await ctx.send(box("\n".join(names), "diff"))
|
await ctx.send(box("\n".join(names), "diff"))
|
||||||
|
|
||||||
async def on_message(self, message: discord.Message):
|
async def on_message(self, message: discord.Message):
|
||||||
aliases = list(self.unloaded_global_aliases())
|
aliases = list(await self.unloaded_global_aliases())
|
||||||
if message.guild is not None:
|
if message.guild is not None:
|
||||||
aliases = aliases + list(self.unloaded_aliases(message.guild))
|
aliases = aliases + list(await self.unloaded_aliases(message.guild))
|
||||||
|
|
||||||
if len(aliases) == 0:
|
if len(aliases) == 0:
|
||||||
return
|
return
|
||||||
|
|||||||
@ -6,7 +6,7 @@ from core.bot import Red # Only used for type hints
|
|||||||
|
|
||||||
def check_global_setting_guildowner():
|
def check_global_setting_guildowner():
|
||||||
async def pred(ctx: commands.Context):
|
async def pred(ctx: commands.Context):
|
||||||
if bank.is_global():
|
if await bank.is_global():
|
||||||
return checks.is_owner()
|
return checks.is_owner()
|
||||||
else:
|
else:
|
||||||
return checks.guildowner_or_permissions(administrator=True)
|
return checks.guildowner_or_permissions(administrator=True)
|
||||||
@ -15,7 +15,7 @@ def check_global_setting_guildowner():
|
|||||||
|
|
||||||
def check_global_setting_admin():
|
def check_global_setting_admin():
|
||||||
async def pred(ctx: commands.Context):
|
async def pred(ctx: commands.Context):
|
||||||
if bank.is_global():
|
if await bank.is_global():
|
||||||
return checks.is_owner()
|
return checks.is_owner()
|
||||||
else:
|
else:
|
||||||
return checks.admin_or_permissions(manage_guild=True)
|
return checks.admin_or_permissions(manage_guild=True)
|
||||||
@ -43,7 +43,7 @@ class Bank:
|
|||||||
"""Toggles whether the bank is global or not
|
"""Toggles whether the bank is global or not
|
||||||
If the bank is global, it will become per-guild
|
If the bank is global, it will become per-guild
|
||||||
If the bank is per-guild, it will become global"""
|
If the bank is per-guild, it will become global"""
|
||||||
cur_setting = bank.is_global()
|
cur_setting = await bank.is_global()
|
||||||
await bank.set_global(not cur_setting, ctx.author)
|
await bank.set_global(not cur_setting, ctx.author)
|
||||||
|
|
||||||
word = "per-guild" if cur_setting else "global"
|
word = "per-guild" if cur_setting else "global"
|
||||||
|
|||||||
@ -49,22 +49,20 @@ class Downloader:
|
|||||||
|
|
||||||
self._repo_manager = RepoManager(self.conf)
|
self._repo_manager = RepoManager(self.conf)
|
||||||
|
|
||||||
@property
|
async def cog_install_path(self):
|
||||||
def cog_install_path(self):
|
|
||||||
"""
|
"""
|
||||||
Returns the current cog install path.
|
Returns the current cog install path.
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
return self.bot.cog_mgr.install_path
|
return await self.bot.cog_mgr.install_path()
|
||||||
|
|
||||||
@property
|
async def installed_cogs(self) -> Tuple[Installable]:
|
||||||
def installed_cogs(self) -> Tuple[Installable]:
|
|
||||||
"""
|
"""
|
||||||
Returns the dictionary mapping cog name to install location
|
Returns the dictionary mapping cog name to install location
|
||||||
and repo name.
|
and repo name.
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
installed = self.conf.installed()
|
installed = await self.conf.installed()
|
||||||
# noinspection PyTypeChecker
|
# noinspection PyTypeChecker
|
||||||
return tuple(Installable.from_json(v) for v in installed)
|
return tuple(Installable.from_json(v) for v in installed)
|
||||||
|
|
||||||
@ -74,7 +72,7 @@ class Downloader:
|
|||||||
:param cog:
|
:param cog:
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
installed = self.conf.installed()
|
installed = await self.conf.installed()
|
||||||
cog_json = cog.to_json()
|
cog_json = cog.to_json()
|
||||||
|
|
||||||
if cog_json not in installed:
|
if cog_json not in installed:
|
||||||
@ -87,7 +85,7 @@ class Downloader:
|
|||||||
:param cog:
|
:param cog:
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
installed = self.conf.installed()
|
installed = await self.conf.installed()
|
||||||
cog_json = cog.to_json()
|
cog_json = cog.to_json()
|
||||||
|
|
||||||
if cog_json in installed:
|
if cog_json in installed:
|
||||||
@ -102,7 +100,7 @@ class Downloader:
|
|||||||
"""
|
"""
|
||||||
failed = []
|
failed = []
|
||||||
for cog in cogs:
|
for cog in cogs:
|
||||||
if not await cog.copy_to(self.cog_install_path):
|
if not await cog.copy_to(await self.cog_install_path()):
|
||||||
failed.append(cog)
|
failed.append(cog)
|
||||||
|
|
||||||
# noinspection PyTypeChecker
|
# noinspection PyTypeChecker
|
||||||
@ -249,7 +247,7 @@ class Downloader:
|
|||||||
" `{}`: `{}`".format(cog.name, cog.requirements))
|
" `{}`: `{}`".format(cog.name, cog.requirements))
|
||||||
return
|
return
|
||||||
|
|
||||||
await repo_name.install_cog(cog, self.cog_install_path)
|
await repo_name.install_cog(cog, await self.cog_install_path())
|
||||||
|
|
||||||
await self._add_to_installed(cog)
|
await self._add_to_installed(cog)
|
||||||
|
|
||||||
@ -266,7 +264,7 @@ class Downloader:
|
|||||||
# noinspection PyUnresolvedReferences,PyProtectedMember
|
# noinspection PyUnresolvedReferences,PyProtectedMember
|
||||||
real_name = cog_name.name
|
real_name = cog_name.name
|
||||||
|
|
||||||
poss_installed_path = self.cog_install_path / real_name
|
poss_installed_path = (await self.cog_install_path()) / real_name
|
||||||
if poss_installed_path.exists():
|
if poss_installed_path.exists():
|
||||||
await self._delete_cog(poss_installed_path)
|
await self._delete_cog(poss_installed_path)
|
||||||
# noinspection PyTypeChecker
|
# noinspection PyTypeChecker
|
||||||
@ -284,7 +282,7 @@ class Downloader:
|
|||||||
"""
|
"""
|
||||||
if cog_name is None:
|
if cog_name is None:
|
||||||
updated = await self._repo_manager.update_all_repos()
|
updated = await self._repo_manager.update_all_repos()
|
||||||
installed_cogs = set(self.installed_cogs)
|
installed_cogs = set(await self.installed_cogs())
|
||||||
updated_cogs = set(cog for repo in updated.keys() for cog in repo.available_cogs)
|
updated_cogs = set(cog for repo in updated.keys() for cog in repo.available_cogs)
|
||||||
|
|
||||||
installed_and_updated = updated_cogs & installed_cogs
|
installed_and_updated = updated_cogs & installed_cogs
|
||||||
@ -325,14 +323,14 @@ class Downloader:
|
|||||||
msg = "Information on {}:\n{}".format(cog.name, cog.description or "")
|
msg = "Information on {}:\n{}".format(cog.name, cog.description or "")
|
||||||
await ctx.send(box(msg))
|
await ctx.send(box(msg))
|
||||||
|
|
||||||
def is_installed(self, cog_name: str) -> (bool, Union[Installable, None]):
|
async def is_installed(self, cog_name: str) -> (bool, Union[Installable, None]):
|
||||||
"""
|
"""
|
||||||
Checks to see if a cog with the given name was installed
|
Checks to see if a cog with the given name was installed
|
||||||
through Downloader.
|
through Downloader.
|
||||||
:param cog_name:
|
:param cog_name:
|
||||||
:return: is_installed, Installable
|
:return: is_installed, Installable
|
||||||
"""
|
"""
|
||||||
for installable in self.installed_cogs:
|
for installable in await self.installed_cogs():
|
||||||
if installable.name == cog_name:
|
if installable.name == cog_name:
|
||||||
return True, installable
|
return True, installable
|
||||||
return False, None
|
return False, None
|
||||||
@ -384,7 +382,7 @@ class Downloader:
|
|||||||
|
|
||||||
# Check if in installed cogs
|
# Check if in installed cogs
|
||||||
cog_name = self.cog_name_from_instance(command.instance)
|
cog_name = self.cog_name_from_instance(command.instance)
|
||||||
installed, cog_installable = self.is_installed(cog_name)
|
installed, cog_installable = await self.is_installed(cog_name)
|
||||||
if installed:
|
if installed:
|
||||||
msg = self.format_findcog_info(command_name, cog_installable)
|
msg = self.format_findcog_info(command_name, cog_installable)
|
||||||
else:
|
else:
|
||||||
|
|||||||
@ -430,7 +430,10 @@ class RepoManager:
|
|||||||
|
|
||||||
self.repos_folder = Path(__file__).parent / 'repos'
|
self.repos_folder = Path(__file__).parent / 'repos'
|
||||||
|
|
||||||
self._repos = self._load_repos() # str_name: Repo
|
self._repos = {}
|
||||||
|
|
||||||
|
loop = asyncio.get_event_loop()
|
||||||
|
loop.run_until_complete(self._load_repos(set=True)) # str_name: Repo
|
||||||
|
|
||||||
def does_repo_exist(self, name: str) -> bool:
|
def does_repo_exist(self, name: str) -> bool:
|
||||||
return name in self._repos
|
return name in self._repos
|
||||||
@ -494,7 +497,6 @@ class RepoManager:
|
|||||||
|
|
||||||
shutil.rmtree(str(repo.folder_path))
|
shutil.rmtree(str(repo.folder_path))
|
||||||
|
|
||||||
repos = self.downloader_config.repos()
|
|
||||||
try:
|
try:
|
||||||
del self._repos[name]
|
del self._repos[name]
|
||||||
except KeyError:
|
except KeyError:
|
||||||
@ -518,11 +520,14 @@ class RepoManager:
|
|||||||
await self._save_repos()
|
await self._save_repos()
|
||||||
return ret
|
return ret
|
||||||
|
|
||||||
def _load_repos(self) -> MutableMapping[str, Repo]:
|
async def _load_repos(self, set=False) -> MutableMapping[str, Repo]:
|
||||||
return {
|
ret = {
|
||||||
name: Repo.from_json(data) for name, data in
|
name: Repo.from_json(data) for name, data in
|
||||||
self.downloader_config.repos().items()
|
(await self.downloader_config.repos()).items()
|
||||||
}
|
}
|
||||||
|
if set:
|
||||||
|
self._repos = ret
|
||||||
|
return ret
|
||||||
|
|
||||||
async def _save_repos(self):
|
async def _save_repos(self):
|
||||||
repo_json_info = {name: r.to_json() for name, r in self._repos.items()}
|
repo_json_info = {name: r.to_json() for name, r in self._repos.items()}
|
||||||
|
|||||||
@ -72,9 +72,9 @@ SLOT_PAYOUTS_MSG = ("Slot machine payouts:\n"
|
|||||||
|
|
||||||
def guild_only_check():
|
def guild_only_check():
|
||||||
async def pred(ctx: commands.Context):
|
async def pred(ctx: commands.Context):
|
||||||
if bank.is_global():
|
if await bank.is_global():
|
||||||
return True
|
return True
|
||||||
elif not bank.is_global() and ctx.guild is not None:
|
elif not await bank.is_global() and ctx.guild is not None:
|
||||||
return True
|
return True
|
||||||
else:
|
else:
|
||||||
return False
|
return False
|
||||||
@ -146,8 +146,8 @@ class Economy:
|
|||||||
if user is None:
|
if user is None:
|
||||||
user = ctx.author
|
user = ctx.author
|
||||||
|
|
||||||
bal = bank.get_balance(user)
|
bal = await bank.get_balance(user)
|
||||||
currency = bank.get_currency_name(ctx.guild)
|
currency = await bank.get_currency_name(ctx.guild)
|
||||||
|
|
||||||
await ctx.send("{}'s balance is {} {}".format(
|
await ctx.send("{}'s balance is {} {}".format(
|
||||||
user.display_name, bal, currency))
|
user.display_name, bal, currency))
|
||||||
@ -156,7 +156,7 @@ class Economy:
|
|||||||
async def transfer(self, ctx: commands.Context, to: discord.Member, amount: int):
|
async def transfer(self, ctx: commands.Context, to: discord.Member, amount: int):
|
||||||
"""Transfer currency to other users"""
|
"""Transfer currency to other users"""
|
||||||
from_ = ctx.author
|
from_ = ctx.author
|
||||||
currency = bank.get_currency_name(ctx.guild)
|
currency = await bank.get_currency_name(ctx.guild)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await bank.transfer_credits(from_, to, amount)
|
await bank.transfer_credits(from_, to, amount)
|
||||||
@ -206,12 +206,12 @@ class Economy:
|
|||||||
await ctx.send(
|
await ctx.send(
|
||||||
"This will delete all bank accounts for {}.\nIf you're sure, type "
|
"This will delete all bank accounts for {}.\nIf you're sure, type "
|
||||||
"{}bank reset yes".format(
|
"{}bank reset yes".format(
|
||||||
self.bot.user.name if bank.is_global() else "this guild",
|
self.bot.user.name if await bank.is_global() else "this guild",
|
||||||
ctx.prefix
|
ctx.prefix
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if bank.is_global():
|
if await bank.is_global():
|
||||||
# Bank being global means that the check would cause only
|
# Bank being global means that the check would cause only
|
||||||
# the owner and any co-owners to be able to run the command
|
# the owner and any co-owners to be able to run the command
|
||||||
# so if we're in the function, it's safe to assume that the
|
# so if we're in the function, it's safe to assume that the
|
||||||
@ -232,18 +232,18 @@ class Economy:
|
|||||||
guild = ctx.guild
|
guild = ctx.guild
|
||||||
|
|
||||||
cur_time = calendar.timegm(ctx.message.created_at.utctimetuple())
|
cur_time = calendar.timegm(ctx.message.created_at.utctimetuple())
|
||||||
credits_name = bank.get_currency_name(ctx.guild)
|
credits_name = await bank.get_currency_name(ctx.guild)
|
||||||
if bank.is_global():
|
if await bank.is_global():
|
||||||
next_payday = self.config.user(author).next_payday()
|
next_payday = await self.config.user(author).next_payday()
|
||||||
if cur_time >= next_payday:
|
if cur_time >= next_payday:
|
||||||
await bank.deposit_credits(author, self.config.PAYDAY_CREDITS())
|
await bank.deposit_credits(author, await self.config.PAYDAY_CREDITS())
|
||||||
next_payday = cur_time + self.config.PAYDAY_TIME()
|
next_payday = cur_time + await self.config.PAYDAY_TIME()
|
||||||
await self.config.user(author).next_payday.set(next_payday)
|
await self.config.user(author).next_payday.set(next_payday)
|
||||||
await ctx.send(
|
await ctx.send(
|
||||||
"{} Here, take some {}. Enjoy! (+{}"
|
"{} Here, take some {}. Enjoy! (+{}"
|
||||||
" {}!)".format(
|
" {}!)".format(
|
||||||
author.mention, credits_name,
|
author.mention, credits_name,
|
||||||
str(self.config.PAYDAY_CREDITS()),
|
str(await self.config.PAYDAY_CREDITS()),
|
||||||
credits_name
|
credits_name
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@ -254,16 +254,16 @@ class Economy:
|
|||||||
" wait {}.".format(author.mention, dtime)
|
" wait {}.".format(author.mention, dtime)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
next_payday = self.config.member(author).next_payday()
|
next_payday = await self.config.member(author).next_payday()
|
||||||
if cur_time >= next_payday:
|
if cur_time >= next_payday:
|
||||||
await bank.deposit_credits(author, self.config.guild(guild).PAYDAY_CREDITS())
|
await bank.deposit_credits(author, await self.config.guild(guild).PAYDAY_CREDITS())
|
||||||
next_payday = cur_time + self.config.guild(guild).PAYDAY_TIME()
|
next_payday = cur_time + await self.config.guild(guild).PAYDAY_TIME()
|
||||||
await self.config.member(author).next_payday.set(next_payday)
|
await self.config.member(author).next_payday.set(next_payday)
|
||||||
await ctx.send(
|
await ctx.send(
|
||||||
"{} Here, take some {}. Enjoy! (+{}"
|
"{} Here, take some {}. Enjoy! (+{}"
|
||||||
" {}!)".format(
|
" {}!)".format(
|
||||||
author.mention, credits_name,
|
author.mention, credits_name,
|
||||||
str(self.config.guild(guild).PAYDAY_CREDITS()),
|
str(await self.config.guild(guild).PAYDAY_CREDITS()),
|
||||||
credits_name))
|
credits_name))
|
||||||
else:
|
else:
|
||||||
dtime = self.display_time(next_payday - cur_time)
|
dtime = self.display_time(next_payday - cur_time)
|
||||||
@ -282,10 +282,10 @@ class Economy:
|
|||||||
if top < 1:
|
if top < 1:
|
||||||
top = 10
|
top = 10
|
||||||
if bank.is_global():
|
if bank.is_global():
|
||||||
bank_sorted = sorted(bank.get_global_accounts(ctx.author),
|
bank_sorted = sorted(await bank.get_global_accounts(ctx.author),
|
||||||
key=lambda x: x.balance, reverse=True)
|
key=lambda x: x.balance, reverse=True)
|
||||||
else:
|
else:
|
||||||
bank_sorted = sorted(bank.get_guild_accounts(guild),
|
bank_sorted = sorted(await bank.get_guild_accounts(guild),
|
||||||
key=lambda x: x.balance, reverse=True)
|
key=lambda x: x.balance, reverse=True)
|
||||||
if len(bank_sorted) < top:
|
if len(bank_sorted) < top:
|
||||||
top = len(bank_sorted)
|
top = len(bank_sorted)
|
||||||
@ -320,14 +320,14 @@ class Economy:
|
|||||||
author = ctx.author
|
author = ctx.author
|
||||||
guild = ctx.guild
|
guild = ctx.guild
|
||||||
channel = ctx.channel
|
channel = ctx.channel
|
||||||
if bank.is_global():
|
if await bank.is_global():
|
||||||
valid_bid = self.config.SLOT_MIN() <= bid <= self.config.SLOT_MAX()
|
valid_bid = await self.config.SLOT_MIN() <= bid <= await self.config.SLOT_MAX()
|
||||||
slot_time = self.config.SLOT_TIME()
|
slot_time = await self.config.SLOT_TIME()
|
||||||
last_slot = self.config.user(author).last_slot()
|
last_slot = await self.config.user(author).last_slot()
|
||||||
else:
|
else:
|
||||||
valid_bid = self.config.guild(guild).SLOT_MIN() <= bid <= self.config.guild(guild).SLOT_MAX()
|
valid_bid = await self.config.guild(guild).SLOT_MIN() <= bid <= await self.config.guild(guild).SLOT_MAX()
|
||||||
slot_time = self.config.guild(guild).SLOT_TIME()
|
slot_time = await self.config.guild(guild).SLOT_TIME()
|
||||||
last_slot = self.config.member(author).last_slot()
|
last_slot = await self.config.member(author).last_slot()
|
||||||
now = calendar.timegm(ctx.message.created_at.utctimetuple())
|
now = calendar.timegm(ctx.message.created_at.utctimetuple())
|
||||||
|
|
||||||
if (now - last_slot) < slot_time:
|
if (now - last_slot) < slot_time:
|
||||||
@ -336,10 +336,10 @@ class Economy:
|
|||||||
if not valid_bid:
|
if not valid_bid:
|
||||||
await ctx.send("That's an invalid bid amount, sorry :/")
|
await ctx.send("That's an invalid bid amount, sorry :/")
|
||||||
return
|
return
|
||||||
if not bank.can_spend(author, bid):
|
if not await bank.can_spend(author, bid):
|
||||||
await ctx.send("You ain't got enough money, friend.")
|
await ctx.send("You ain't got enough money, friend.")
|
||||||
return
|
return
|
||||||
if bank.is_global():
|
if await bank.is_global():
|
||||||
await self.config.user(author).last_slot.set(now)
|
await self.config.user(author).last_slot.set(now)
|
||||||
else:
|
else:
|
||||||
await self.config.member(author).last_slot.set(now)
|
await self.config.member(author).last_slot.set(now)
|
||||||
@ -379,7 +379,7 @@ class Economy:
|
|||||||
payout = PAYOUTS["2 symbols"]
|
payout = PAYOUTS["2 symbols"]
|
||||||
|
|
||||||
if payout:
|
if payout:
|
||||||
then = bank.get_balance(author)
|
then = await bank.get_balance(author)
|
||||||
pay = payout["payout"](bid)
|
pay = payout["payout"](bid)
|
||||||
now = then - bid + pay
|
now = then - bid + pay
|
||||||
await bank.set_balance(author, now)
|
await bank.set_balance(author, now)
|
||||||
@ -387,7 +387,7 @@ class Economy:
|
|||||||
"".format(slot, author.mention,
|
"".format(slot, author.mention,
|
||||||
payout["phrase"], bid, then, now))
|
payout["phrase"], bid, then, now))
|
||||||
else:
|
else:
|
||||||
then = bank.get_balance(author)
|
then = await bank.get_balance(author)
|
||||||
await bank.withdraw_credits(author, bid)
|
await bank.withdraw_credits(author, bid)
|
||||||
now = then - bid
|
now = then - bid
|
||||||
await channel.send("{}\n{} Nothing!\nYour bid: {}\n{} → {}!"
|
await channel.send("{}\n{} Nothing!\nYour bid: {}\n{} → {}!"
|
||||||
@ -402,18 +402,18 @@ class Economy:
|
|||||||
if ctx.invoked_subcommand is None:
|
if ctx.invoked_subcommand is None:
|
||||||
await self.bot.send_cmd_help(ctx)
|
await self.bot.send_cmd_help(ctx)
|
||||||
if bank.is_global():
|
if bank.is_global():
|
||||||
slot_min = self.config.SLOT_MIN()
|
slot_min = await self.config.SLOT_MIN()
|
||||||
slot_max = self.config.SLOT_MAX()
|
slot_max = await self.config.SLOT_MAX()
|
||||||
slot_time = self.config.SLOT_TIME()
|
slot_time = await self.config.SLOT_TIME()
|
||||||
payday_time = self.config.PAYDAY_TIME()
|
payday_time = await self.config.PAYDAY_TIME()
|
||||||
payday_amount = self.config.PAYDAY_CREDITS()
|
payday_amount = await self.config.PAYDAY_CREDITS()
|
||||||
else:
|
else:
|
||||||
slot_min = self.config.guild(guild).SLOT_MIN()
|
slot_min = await self.config.guild(guild).SLOT_MIN()
|
||||||
slot_max = self.config.guild(guild).SLOT_MAX()
|
slot_max = await self.config.guild(guild).SLOT_MAX()
|
||||||
slot_time = self.config.guild(guild).SLOT_TIME()
|
slot_time = await self.config.guild(guild).SLOT_TIME()
|
||||||
payday_time = self.config.guild(guild).PAYDAY_TIME()
|
payday_time = await self.config.guild(guild).PAYDAY_TIME()
|
||||||
payday_amount = self.config.guild(guild).PAYDAY_CREDITS()
|
payday_amount = await self.config.guild(guild).PAYDAY_CREDITS()
|
||||||
register_amount = bank.get_default_balance(guild)
|
register_amount = await bank.get_default_balance(guild)
|
||||||
msg = box(
|
msg = box(
|
||||||
"Minimum slot bid: {}\n"
|
"Minimum slot bid: {}\n"
|
||||||
"Maximum slot bid: {}\n"
|
"Maximum slot bid: {}\n"
|
||||||
@ -436,24 +436,24 @@ class Economy:
|
|||||||
await ctx.send('Invalid min bid amount.')
|
await ctx.send('Invalid min bid amount.')
|
||||||
return
|
return
|
||||||
guild = ctx.guild
|
guild = ctx.guild
|
||||||
if bank.is_global():
|
if await bank.is_global():
|
||||||
await self.config.SLOT_MIN.set(bid)
|
await self.config.SLOT_MIN.set(bid)
|
||||||
else:
|
else:
|
||||||
await self.config.guild(guild).SLOT_MIN.set(bid)
|
await self.config.guild(guild).SLOT_MIN.set(bid)
|
||||||
credits_name = bank.get_currency_name(guild)
|
credits_name = await bank.get_currency_name(guild)
|
||||||
await ctx.send("Minimum bid is now {} {}.".format(bid, credits_name))
|
await ctx.send("Minimum bid is now {} {}.".format(bid, credits_name))
|
||||||
|
|
||||||
@economyset.command()
|
@economyset.command()
|
||||||
async def slotmax(self, ctx: commands.Context, bid: int):
|
async def slotmax(self, ctx: commands.Context, bid: int):
|
||||||
"""Maximum slot machine bid"""
|
"""Maximum slot machine bid"""
|
||||||
slot_min = self.config.SLOT_MIN()
|
slot_min = await self.config.SLOT_MIN()
|
||||||
if bid < 1 or bid < slot_min:
|
if bid < 1 or bid < slot_min:
|
||||||
await ctx.send('Invalid slotmax bid amount. Must be greater'
|
await ctx.send('Invalid slotmax bid amount. Must be greater'
|
||||||
' than slotmin.')
|
' than slotmin.')
|
||||||
return
|
return
|
||||||
guild = ctx.guild
|
guild = ctx.guild
|
||||||
credits_name = bank.get_currency_name(guild)
|
credits_name = await bank.get_currency_name(guild)
|
||||||
if bank.is_global():
|
if await bank.is_global():
|
||||||
await self.config.SLOT_MAX.set(bid)
|
await self.config.SLOT_MAX.set(bid)
|
||||||
else:
|
else:
|
||||||
await self.config.guild(guild).SLOT_MAX.set(bid)
|
await self.config.guild(guild).SLOT_MAX.set(bid)
|
||||||
@ -463,7 +463,7 @@ class Economy:
|
|||||||
async def slottime(self, ctx: commands.Context, seconds: int):
|
async def slottime(self, ctx: commands.Context, seconds: int):
|
||||||
"""Seconds between each slots use"""
|
"""Seconds between each slots use"""
|
||||||
guild = ctx.guild
|
guild = ctx.guild
|
||||||
if bank.is_global():
|
if await bank.is_global():
|
||||||
await self.config.SLOT_TIME.set(seconds)
|
await self.config.SLOT_TIME.set(seconds)
|
||||||
else:
|
else:
|
||||||
await self.config.guild(guild).SLOT_TIME.set(seconds)
|
await self.config.guild(guild).SLOT_TIME.set(seconds)
|
||||||
@ -473,7 +473,7 @@ class Economy:
|
|||||||
async def paydaytime(self, ctx: commands.Context, seconds: int):
|
async def paydaytime(self, ctx: commands.Context, seconds: int):
|
||||||
"""Seconds between each payday"""
|
"""Seconds between each payday"""
|
||||||
guild = ctx.guild
|
guild = ctx.guild
|
||||||
if bank.is_global():
|
if await bank.is_global():
|
||||||
await self.config.PAYDAY_TIME.set(seconds)
|
await self.config.PAYDAY_TIME.set(seconds)
|
||||||
else:
|
else:
|
||||||
await self.config.guild(guild).PAYDAY_TIME.set(seconds)
|
await self.config.guild(guild).PAYDAY_TIME.set(seconds)
|
||||||
@ -484,11 +484,11 @@ class Economy:
|
|||||||
async def paydayamount(self, ctx: commands.Context, creds: int):
|
async def paydayamount(self, ctx: commands.Context, creds: int):
|
||||||
"""Amount earned each payday"""
|
"""Amount earned each payday"""
|
||||||
guild = ctx.guild
|
guild = ctx.guild
|
||||||
credits_name = bank.get_currency_name(guild)
|
credits_name = await bank.get_currency_name(guild)
|
||||||
if creds <= 0:
|
if creds <= 0:
|
||||||
await ctx.send("Har har so funny.")
|
await ctx.send("Har har so funny.")
|
||||||
return
|
return
|
||||||
if bank.is_global():
|
if await bank.is_global():
|
||||||
await self.config.PAYDAY_CREDITS.set(creds)
|
await self.config.PAYDAY_CREDITS.set(creds)
|
||||||
else:
|
else:
|
||||||
await self.config.guild(guild).PAYDAY_CREDITS.set(creds)
|
await self.config.guild(guild).PAYDAY_CREDITS.set(creds)
|
||||||
@ -501,7 +501,7 @@ class Economy:
|
|||||||
guild = ctx.guild
|
guild = ctx.guild
|
||||||
if creds < 0:
|
if creds < 0:
|
||||||
creds = 0
|
creds = 0
|
||||||
credits_name = bank.get_currency_name(guild)
|
credits_name = await bank.get_currency_name(guild)
|
||||||
await bank.set_default_balance(creds, guild)
|
await bank.set_default_balance(creds, guild)
|
||||||
await ctx.send("Registering an account will now give {} {}."
|
await ctx.send("Registering an account will now give {} {}."
|
||||||
"".format(creds, credits_name))
|
"".format(creds, credits_name))
|
||||||
|
|||||||
89
core/bank.py
89
core/bank.py
@ -1,6 +1,6 @@
|
|||||||
import datetime
|
import datetime
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
from typing import Tuple, Generator, Union
|
from typing import Tuple, Generator, Union, List
|
||||||
|
|
||||||
import discord
|
import discord
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
@ -78,17 +78,17 @@ def _decode_time(time: int) -> datetime.datetime:
|
|||||||
return datetime.datetime.utcfromtimestamp(time)
|
return datetime.datetime.utcfromtimestamp(time)
|
||||||
|
|
||||||
|
|
||||||
def get_balance(member: discord.Member) -> int:
|
async def get_balance(member: discord.Member) -> int:
|
||||||
"""
|
"""
|
||||||
Gets the current balance of a member.
|
Gets the current balance of a member.
|
||||||
:param member:
|
:param member:
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
acc = get_account(member)
|
acc = await get_account(member)
|
||||||
return acc.balance
|
return acc.balance
|
||||||
|
|
||||||
|
|
||||||
def can_spend(member: discord.Member, amount: int) -> bool:
|
async def can_spend(member: discord.Member, amount: int) -> bool:
|
||||||
"""
|
"""
|
||||||
Determines if a member can spend the given amount.
|
Determines if a member can spend the given amount.
|
||||||
:param member:
|
:param member:
|
||||||
@ -97,7 +97,7 @@ def can_spend(member: discord.Member, amount: int) -> bool:
|
|||||||
"""
|
"""
|
||||||
if _invalid_amount(amount):
|
if _invalid_amount(amount):
|
||||||
return False
|
return False
|
||||||
return get_balance(member) > amount
|
return await get_balance(member) > amount
|
||||||
|
|
||||||
|
|
||||||
async def set_balance(member: discord.Member, amount: int) -> int:
|
async def set_balance(member: discord.Member, amount: int) -> int:
|
||||||
@ -111,17 +111,17 @@ async def set_balance(member: discord.Member, amount: int) -> int:
|
|||||||
"""
|
"""
|
||||||
if amount < 0:
|
if amount < 0:
|
||||||
raise ValueError("Not allowed to have negative balance.")
|
raise ValueError("Not allowed to have negative balance.")
|
||||||
if is_global():
|
if await is_global():
|
||||||
group = _conf.user(member)
|
group = _conf.user(member)
|
||||||
else:
|
else:
|
||||||
group = _conf.member(member)
|
group = _conf.member(member)
|
||||||
await group.balance.set(amount)
|
await group.balance.set(amount)
|
||||||
|
|
||||||
if group.created_at() == 0:
|
if await group.created_at() == 0:
|
||||||
time = _encoded_current_time()
|
time = _encoded_current_time()
|
||||||
await group.created_at.set(time)
|
await group.created_at.set(time)
|
||||||
|
|
||||||
if group.name() == "":
|
if await group.name() == "":
|
||||||
await group.name.set(member.display_name)
|
await group.name.set(member.display_name)
|
||||||
|
|
||||||
return amount
|
return amount
|
||||||
@ -144,7 +144,7 @@ async def withdraw_credits(member: discord.Member, amount: int) -> int:
|
|||||||
if _invalid_amount(amount):
|
if _invalid_amount(amount):
|
||||||
raise ValueError("Invalid withdrawal amount {} <= 0".format(amount))
|
raise ValueError("Invalid withdrawal amount {} <= 0".format(amount))
|
||||||
|
|
||||||
bal = get_balance(member)
|
bal = await get_balance(member)
|
||||||
if amount > bal:
|
if amount > bal:
|
||||||
raise ValueError("Insufficient funds {} > {}".format(amount, bal))
|
raise ValueError("Insufficient funds {} > {}".format(amount, bal))
|
||||||
|
|
||||||
@ -163,7 +163,7 @@ async def deposit_credits(member: discord.Member, amount: int) -> int:
|
|||||||
if _invalid_amount(amount):
|
if _invalid_amount(amount):
|
||||||
raise ValueError("Invalid withdrawal amount {} <= 0".format(amount))
|
raise ValueError("Invalid withdrawal amount {} <= 0".format(amount))
|
||||||
|
|
||||||
bal = get_balance(member)
|
bal = await get_balance(member)
|
||||||
return await set_balance(member, amount + bal)
|
return await set_balance(member, amount + bal)
|
||||||
|
|
||||||
|
|
||||||
@ -190,13 +190,13 @@ async def wipe_bank(user: Union[discord.User, discord.Member]):
|
|||||||
Deletes all accounts from the bank.
|
Deletes all accounts from the bank.
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
if is_global():
|
if await is_global():
|
||||||
await _conf.user(user).clear()
|
await _conf.user(user).clear()
|
||||||
else:
|
else:
|
||||||
await _conf.member(user).clear()
|
await _conf.member(user).clear()
|
||||||
|
|
||||||
|
|
||||||
def get_guild_accounts(guild: discord.Guild) -> Generator[Account, None, None]:
|
async def get_guild_accounts(guild: discord.Guild) -> List[Account]:
|
||||||
"""
|
"""
|
||||||
Gets all account data for the given guild.
|
Gets all account data for the given guild.
|
||||||
|
|
||||||
@ -207,14 +207,16 @@ def get_guild_accounts(guild: discord.Guild) -> Generator[Account, None, None]:
|
|||||||
if is_global():
|
if is_global():
|
||||||
raise RuntimeError("The bank is currently global.")
|
raise RuntimeError("The bank is currently global.")
|
||||||
|
|
||||||
accs = _conf.member(guild.owner).all_from_kind()
|
ret = []
|
||||||
|
accs = await _conf.member(guild.owner).all_from_kind()
|
||||||
for user_id, acc in accs.items():
|
for user_id, acc in accs.items():
|
||||||
acc_data = acc.copy() # There ya go kowlin
|
acc_data = acc.copy() # There ya go kowlin
|
||||||
acc_data['created_at'] = _decode_time(acc_data['created_at'])
|
acc_data['created_at'] = _decode_time(acc_data['created_at'])
|
||||||
yield Account(**acc_data)
|
ret.append(Account(**acc_data))
|
||||||
|
return ret
|
||||||
|
|
||||||
|
|
||||||
def get_global_accounts(user: discord.User) -> Generator[Account, None, None]:
|
async def get_global_accounts(user: discord.User) -> List[Account]:
|
||||||
"""
|
"""
|
||||||
Gets all global account data.
|
Gets all global account data.
|
||||||
|
|
||||||
@ -225,44 +227,47 @@ def get_global_accounts(user: discord.User) -> Generator[Account, None, None]:
|
|||||||
if not is_global():
|
if not is_global():
|
||||||
raise RuntimeError("The bank is not currently global.")
|
raise RuntimeError("The bank is not currently global.")
|
||||||
|
|
||||||
accs = _conf.user(user).all_from_kind() # this is a dict of user -> acc
|
ret = []
|
||||||
|
accs = await _conf.user(user).all_from_kind() # this is a dict of user -> acc
|
||||||
for user_id, acc in accs.items():
|
for user_id, acc in accs.items():
|
||||||
acc_data = acc.copy()
|
acc_data = acc.copy()
|
||||||
acc_data['created_at'] = _decode_time(acc_data['created_at'])
|
acc_data['created_at'] = _decode_time(acc_data['created_at'])
|
||||||
yield Account(**acc_data)
|
ret.append(Account(**acc_data))
|
||||||
|
|
||||||
|
return ret
|
||||||
|
|
||||||
|
|
||||||
def get_account(member: Union[discord.Member, discord.User]) -> Account:
|
async def get_account(member: Union[discord.Member, discord.User]) -> Account:
|
||||||
"""
|
"""
|
||||||
Gets the appropriate account for the given member.
|
Gets the appropriate account for the given member.
|
||||||
:param member:
|
:param member:
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
if is_global():
|
if await is_global():
|
||||||
acc_data = _conf.user(member)().copy()
|
acc_data = (await _conf.user(member)()).copy()
|
||||||
default = _DEFAULT_USER.copy()
|
default = _DEFAULT_USER.copy()
|
||||||
else:
|
else:
|
||||||
acc_data = _conf.member(member)().copy()
|
acc_data = (await _conf.member(member)()).copy()
|
||||||
default = _DEFAULT_MEMBER.copy()
|
default = _DEFAULT_MEMBER.copy()
|
||||||
|
|
||||||
if acc_data == {}:
|
if acc_data == {}:
|
||||||
acc_data = default
|
acc_data = default
|
||||||
acc_data['name'] = member.display_name
|
acc_data['name'] = member.display_name
|
||||||
try:
|
try:
|
||||||
acc_data['balance'] = get_default_balance(member.guild)
|
acc_data['balance'] = await get_default_balance(member.guild)
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
acc_data['balance'] = get_default_balance()
|
acc_data['balance'] = await get_default_balance()
|
||||||
|
|
||||||
acc_data['created_at'] = _decode_time(acc_data['created_at'])
|
acc_data['created_at'] = _decode_time(acc_data['created_at'])
|
||||||
return Account(**acc_data)
|
return Account(**acc_data)
|
||||||
|
|
||||||
|
|
||||||
def is_global() -> bool:
|
async def is_global() -> bool:
|
||||||
"""
|
"""
|
||||||
Determines if the bank is currently global.
|
Determines if the bank is currently global.
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
return _conf.is_global()
|
return await _conf.is_global()
|
||||||
|
|
||||||
|
|
||||||
async def set_global(global_: bool, user: Union[discord.User, discord.Member]) -> bool:
|
async def set_global(global_: bool, user: Union[discord.User, discord.Member]) -> bool:
|
||||||
@ -272,7 +277,7 @@ async def set_global(global_: bool, user: Union[discord.User, discord.Member]) -
|
|||||||
:param user: Must be a Member object if changing TO global mode.
|
:param user: Must be a Member object if changing TO global mode.
|
||||||
:return: New bank mode, True is global.
|
:return: New bank mode, True is global.
|
||||||
"""
|
"""
|
||||||
if is_global() is global_:
|
if (await is_global()) is global_:
|
||||||
return global_
|
return global_
|
||||||
|
|
||||||
if is_global():
|
if is_global():
|
||||||
@ -287,7 +292,7 @@ async def set_global(global_: bool, user: Union[discord.User, discord.Member]) -
|
|||||||
return global_
|
return global_
|
||||||
|
|
||||||
|
|
||||||
def get_bank_name(guild: discord.Guild=None) -> str:
|
async def get_bank_name(guild: discord.Guild=None) -> str:
|
||||||
"""
|
"""
|
||||||
Gets the current bank name. If the bank is guild-specific the
|
Gets the current bank name. If the bank is guild-specific the
|
||||||
guild parameter is required.
|
guild parameter is required.
|
||||||
@ -296,10 +301,10 @@ def get_bank_name(guild: discord.Guild=None) -> str:
|
|||||||
:param guild:
|
:param guild:
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
if is_global():
|
if await is_global():
|
||||||
return _conf.bank_name()
|
return await _conf.bank_name()
|
||||||
elif guild is not None:
|
elif guild is not None:
|
||||||
return _conf.guild(guild).bank_name()
|
return await _conf.guild(guild).bank_name()
|
||||||
else:
|
else:
|
||||||
raise RuntimeError("Guild parameter is required and missing.")
|
raise RuntimeError("Guild parameter is required and missing.")
|
||||||
|
|
||||||
@ -314,7 +319,7 @@ async def set_bank_name(name: str, guild: discord.Guild=None) -> str:
|
|||||||
:param guild:
|
:param guild:
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
if is_global():
|
if await is_global():
|
||||||
await _conf.bank_name.set(name)
|
await _conf.bank_name.set(name)
|
||||||
elif guild is not None:
|
elif guild is not None:
|
||||||
await _conf.guild(guild).bank_name.set(name)
|
await _conf.guild(guild).bank_name.set(name)
|
||||||
@ -324,7 +329,7 @@ async def set_bank_name(name: str, guild: discord.Guild=None) -> str:
|
|||||||
return name
|
return name
|
||||||
|
|
||||||
|
|
||||||
def get_currency_name(guild: discord.Guild=None) -> str:
|
async def get_currency_name(guild: discord.Guild=None) -> str:
|
||||||
"""
|
"""
|
||||||
Gets the currency name of the bank. The guild parameter is required if
|
Gets the currency name of the bank. The guild parameter is required if
|
||||||
the bank is guild-specific.
|
the bank is guild-specific.
|
||||||
@ -333,10 +338,10 @@ def get_currency_name(guild: discord.Guild=None) -> str:
|
|||||||
:param guild:
|
:param guild:
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
if is_global():
|
if await is_global():
|
||||||
return _conf.currency()
|
return await _conf.currency()
|
||||||
elif guild is not None:
|
elif guild is not None:
|
||||||
return _conf.guild(guild).currency()
|
return await _conf.guild(guild).currency()
|
||||||
else:
|
else:
|
||||||
raise RuntimeError("Guild must be provided.")
|
raise RuntimeError("Guild must be provided.")
|
||||||
|
|
||||||
@ -351,7 +356,7 @@ async def set_currency_name(name: str, guild: discord.Guild=None) -> str:
|
|||||||
:param guild:
|
:param guild:
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
if is_global():
|
if await is_global():
|
||||||
await _conf.currency.set(name)
|
await _conf.currency.set(name)
|
||||||
elif guild is not None:
|
elif guild is not None:
|
||||||
await _conf.guild(guild).currency.set(name)
|
await _conf.guild(guild).currency.set(name)
|
||||||
@ -361,7 +366,7 @@ async def set_currency_name(name: str, guild: discord.Guild=None) -> str:
|
|||||||
return name
|
return name
|
||||||
|
|
||||||
|
|
||||||
def get_default_balance(guild: discord.Guild=None) -> int:
|
async def get_default_balance(guild: discord.Guild=None) -> int:
|
||||||
"""
|
"""
|
||||||
Gets the current default balance amount. If the bank is guild-specific
|
Gets the current default balance amount. If the bank is guild-specific
|
||||||
you must pass guild.
|
you must pass guild.
|
||||||
@ -370,10 +375,10 @@ def get_default_balance(guild: discord.Guild=None) -> int:
|
|||||||
:param guild:
|
:param guild:
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
if is_global():
|
if await is_global():
|
||||||
return _conf.default_balance()
|
return await _conf.default_balance()
|
||||||
elif guild is not None:
|
elif guild is not None:
|
||||||
return _conf.guild(guild).default_balance()
|
return await _conf.guild(guild).default_balance()
|
||||||
else:
|
else:
|
||||||
raise RuntimeError("Guild is missing and required!")
|
raise RuntimeError("Guild is missing and required!")
|
||||||
|
|
||||||
@ -393,9 +398,11 @@ async def set_default_balance(amount: int, guild: discord.Guild=None) -> int:
|
|||||||
if amount < 0:
|
if amount < 0:
|
||||||
raise ValueError("Amount must be greater than zero.")
|
raise ValueError("Amount must be greater than zero.")
|
||||||
|
|
||||||
if is_global():
|
if await is_global():
|
||||||
await _conf.default_balance.set(amount)
|
await _conf.default_balance.set(amount)
|
||||||
elif guild is not None:
|
elif guild is not None:
|
||||||
await _conf.guild(guild).default_balance.set(amount)
|
await _conf.guild(guild).default_balance.set(amount)
|
||||||
else:
|
else:
|
||||||
raise RuntimeError("Guild is missing and required.")
|
raise RuntimeError("Guild is missing and required.")
|
||||||
|
|
||||||
|
return amount
|
||||||
|
|||||||
23
core/bot.py
23
core/bot.py
@ -1,3 +1,4 @@
|
|||||||
|
import asyncio
|
||||||
import importlib.util
|
import importlib.util
|
||||||
from importlib.machinery import ModuleSpec
|
from importlib.machinery import ModuleSpec
|
||||||
|
|
||||||
@ -39,14 +40,14 @@ class Red(commands.Bot):
|
|||||||
mod_role=None
|
mod_role=None
|
||||||
)
|
)
|
||||||
|
|
||||||
def prefix_manager(bot, message):
|
async def prefix_manager(bot, message):
|
||||||
if not cli_flags.prefix:
|
if not cli_flags.prefix:
|
||||||
global_prefix = self.db.prefix()
|
global_prefix = await bot.db.prefix()
|
||||||
else:
|
else:
|
||||||
global_prefix = cli_flags.prefix
|
global_prefix = cli_flags.prefix
|
||||||
if message.guild is None:
|
if message.guild is None:
|
||||||
return global_prefix
|
return global_prefix
|
||||||
server_prefix = self.db.guild(message.guild).prefix()
|
server_prefix = await bot.db.guild(message.guild).prefix()
|
||||||
return server_prefix if server_prefix else global_prefix
|
return server_prefix if server_prefix else global_prefix
|
||||||
|
|
||||||
if "command_prefix" not in kwargs:
|
if "command_prefix" not in kwargs:
|
||||||
@ -56,7 +57,8 @@ class Red(commands.Bot):
|
|||||||
kwargs["owner_id"] = cli_flags.owner
|
kwargs["owner_id"] = cli_flags.owner
|
||||||
|
|
||||||
if "owner_id" not in kwargs:
|
if "owner_id" not in kwargs:
|
||||||
kwargs["owner_id"] = self.db.owner()
|
loop = asyncio.get_event_loop()
|
||||||
|
loop.run_until_complete(self._dict_abuse(kwargs))
|
||||||
|
|
||||||
self.counter = Counter()
|
self.counter = Counter()
|
||||||
self.uptime = None
|
self.uptime = None
|
||||||
@ -68,6 +70,15 @@ class Red(commands.Bot):
|
|||||||
|
|
||||||
super().__init__(**kwargs)
|
super().__init__(**kwargs)
|
||||||
|
|
||||||
|
async def _dict_abuse(self, indict):
|
||||||
|
"""
|
||||||
|
Please blame <@269933075037814786> for this.
|
||||||
|
:param indict:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
|
||||||
|
indict['owner_id'] = await self.db.owner()
|
||||||
|
|
||||||
async def is_owner(self, user):
|
async def is_owner(self, user):
|
||||||
if user.id in self._co_owners:
|
if user.id in self._co_owners:
|
||||||
return True
|
return True
|
||||||
@ -103,13 +114,13 @@ class Red(commands.Bot):
|
|||||||
await self.db.packages.set(packages)
|
await self.db.packages.set(packages)
|
||||||
|
|
||||||
async def add_loaded_package(self, pkg_name: str):
|
async def add_loaded_package(self, pkg_name: str):
|
||||||
curr_pkgs = self.db.packages()
|
curr_pkgs = await self.db.packages()
|
||||||
if pkg_name not in curr_pkgs:
|
if pkg_name not in curr_pkgs:
|
||||||
curr_pkgs.append(pkg_name)
|
curr_pkgs.append(pkg_name)
|
||||||
await self.save_packages_status(curr_pkgs)
|
await self.save_packages_status(curr_pkgs)
|
||||||
|
|
||||||
async def remove_loaded_package(self, pkg_name: str):
|
async def remove_loaded_package(self, pkg_name: str):
|
||||||
curr_pkgs = self.db.packages()
|
curr_pkgs = await self.db.packages()
|
||||||
if pkg_name in curr_pkgs:
|
if pkg_name in curr_pkgs:
|
||||||
await self.save_packages_status([p for p in curr_pkgs if p != pkg_name])
|
await self.save_packages_status([p for p in curr_pkgs if p != pkg_name])
|
||||||
|
|
||||||
|
|||||||
@ -31,26 +31,29 @@ class CogManager:
|
|||||||
install_path=str(bot_dir.resolve() / "cogs")
|
install_path=str(bot_dir.resolve() / "cogs")
|
||||||
)
|
)
|
||||||
|
|
||||||
self._paths = set(list(self.conf.paths()) + list(paths))
|
self._paths = list(paths)
|
||||||
|
|
||||||
@property
|
async def paths(self) -> Tuple[Path, ...]:
|
||||||
def paths(self) -> Tuple[Path, ...]:
|
|
||||||
"""
|
"""
|
||||||
This will return all currently valid path directories.
|
This will return all currently valid path directories.
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
paths = [Path(p) for p in self._paths]
|
conf_paths = await self.conf.paths()
|
||||||
|
other_paths = self._paths
|
||||||
|
|
||||||
|
all_paths = set(list(conf_paths) + list(other_paths))
|
||||||
|
|
||||||
|
paths = [Path(p) for p in all_paths]
|
||||||
if self.install_path not in paths:
|
if self.install_path not in paths:
|
||||||
paths.insert(0, self.install_path)
|
paths.insert(0, await self.install_path())
|
||||||
return tuple(p.resolve() for p in paths if p.is_dir())
|
return tuple(p.resolve() for p in paths if p.is_dir())
|
||||||
|
|
||||||
@property
|
async def install_path(self) -> Path:
|
||||||
def install_path(self) -> Path:
|
|
||||||
"""
|
"""
|
||||||
Returns the install path for 3rd party cogs.
|
Returns the install path for 3rd party cogs.
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
p = Path(self.conf.install_path())
|
p = Path(await self.conf.install_path())
|
||||||
return p.resolve()
|
return p.resolve()
|
||||||
|
|
||||||
async def set_install_path(self, path: Path) -> Path:
|
async def set_install_path(self, path: Path) -> Path:
|
||||||
@ -99,10 +102,10 @@ class CogManager:
|
|||||||
if not path.is_dir():
|
if not path.is_dir():
|
||||||
raise InvalidPath("'{}' is not a valid directory.".format(path))
|
raise InvalidPath("'{}' is not a valid directory.".format(path))
|
||||||
|
|
||||||
if path == self.install_path:
|
if path == await self.install_path():
|
||||||
raise ValueError("Cannot add the install path as an additional path.")
|
raise ValueError("Cannot add the install path as an additional path.")
|
||||||
|
|
||||||
all_paths = set(self.paths + (path, ))
|
all_paths = set(await self.paths() + (path, ))
|
||||||
# noinspection PyTypeChecker
|
# noinspection PyTypeChecker
|
||||||
await self.set_paths(all_paths)
|
await self.set_paths(all_paths)
|
||||||
|
|
||||||
@ -113,7 +116,7 @@ class CogManager:
|
|||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
path = self._ensure_path_obj(path)
|
path = self._ensure_path_obj(path)
|
||||||
all_paths = list(self.paths)
|
all_paths = list(await self.paths())
|
||||||
if path in all_paths:
|
if path in all_paths:
|
||||||
all_paths.remove(path) # Modifies in place
|
all_paths.remove(path) # Modifies in place
|
||||||
await self.set_paths(all_paths)
|
await self.set_paths(all_paths)
|
||||||
@ -125,11 +128,10 @@ class CogManager:
|
|||||||
:param paths_:
|
:param paths_:
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
self._paths = paths_
|
|
||||||
str_paths = [str(p) for p in paths_]
|
str_paths = [str(p) for p in paths_]
|
||||||
await self.conf.paths.set(str_paths)
|
await self.conf.paths.set(str_paths)
|
||||||
|
|
||||||
def find_cog(self, name: str) -> ModuleSpec:
|
async def find_cog(self, name: str) -> ModuleSpec:
|
||||||
"""
|
"""
|
||||||
Finds a cog in the list of available path.
|
Finds a cog in the list of available path.
|
||||||
|
|
||||||
@ -137,7 +139,7 @@ class CogManager:
|
|||||||
:param name:
|
:param name:
|
||||||
:return:
|
:return:
|
||||||
"""
|
"""
|
||||||
resolved_paths = [str(p.resolve()) for p in self.paths]
|
resolved_paths = [str(p.resolve()) for p in await self.paths()]
|
||||||
for finder, module_name, _ in pkgutil.iter_modules(resolved_paths):
|
for finder, module_name, _ in pkgutil.iter_modules(resolved_paths):
|
||||||
if name == module_name:
|
if name == module_name:
|
||||||
spec = finder.find_spec(name)
|
spec = finder.find_spec(name)
|
||||||
@ -166,7 +168,7 @@ class CogManagerUI:
|
|||||||
"""
|
"""
|
||||||
Lists current cog paths in order of priority.
|
Lists current cog paths in order of priority.
|
||||||
"""
|
"""
|
||||||
install_path = ctx.bot.cog_mgr.install_path
|
install_path = await ctx.bot.cog_mgr.install_path()
|
||||||
cog_paths = ctx.bot.cog_mgr.paths
|
cog_paths = ctx.bot.cog_mgr.paths
|
||||||
cog_paths = [p for p in cog_paths if p != install_path]
|
cog_paths = [p for p in cog_paths if p != install_path]
|
||||||
|
|
||||||
@ -204,7 +206,7 @@ class CogManagerUI:
|
|||||||
Removes a path from the available cog paths given the path_number
|
Removes a path from the available cog paths given the path_number
|
||||||
from !paths
|
from !paths
|
||||||
"""
|
"""
|
||||||
cog_paths = ctx.bot.cog_mgr.paths
|
cog_paths = await ctx.bot.cog_mgr.paths()
|
||||||
try:
|
try:
|
||||||
to_remove = cog_paths[path_number]
|
to_remove = cog_paths[path_number]
|
||||||
except IndexError:
|
except IndexError:
|
||||||
@ -224,7 +226,7 @@ class CogManagerUI:
|
|||||||
from_ -= 1
|
from_ -= 1
|
||||||
to -= 1
|
to -= 1
|
||||||
|
|
||||||
all_paths = list(ctx.bot.cog_mgr.paths)
|
all_paths = list(await ctx.bot.cog_mgr.paths())
|
||||||
try:
|
try:
|
||||||
to_move = all_paths.pop(from_)
|
to_move = all_paths.pop(from_)
|
||||||
except IndexError:
|
except IndexError:
|
||||||
@ -257,6 +259,6 @@ class CogManagerUI:
|
|||||||
await ctx.send("That path does not exist.")
|
await ctx.send("That path does not exist.")
|
||||||
return
|
return
|
||||||
|
|
||||||
install_path = ctx.bot.cog_mgr.install_path
|
install_path = await ctx.bot.cog_mgr.install_path()
|
||||||
await ctx.send("The bot will install new cogs to the `{}`"
|
await ctx.send("The bot will install new cogs to the `{}`"
|
||||||
" directory.".format(install_path))
|
" directory.".format(install_path))
|
||||||
|
|||||||
@ -38,6 +38,14 @@ class Value:
|
|||||||
def identifiers(self):
|
def identifiers(self):
|
||||||
return tuple(str(i) for i in self._identifiers)
|
return tuple(str(i) for i in self._identifiers)
|
||||||
|
|
||||||
|
async def _get(self, default):
|
||||||
|
driver = self.spawner.get_driver()
|
||||||
|
try:
|
||||||
|
ret = await driver.get(self.identifiers)
|
||||||
|
except KeyError:
|
||||||
|
return default or self.default
|
||||||
|
return ret
|
||||||
|
|
||||||
def __call__(self, default=None):
|
def __call__(self, default=None):
|
||||||
"""
|
"""
|
||||||
Each :py:class:`Value` object is created by the :py:meth:`Group.__getattr__` method.
|
Each :py:class:`Value` object is created by the :py:meth:`Group.__getattr__` method.
|
||||||
@ -46,25 +54,26 @@ class Value:
|
|||||||
|
|
||||||
For example::
|
For example::
|
||||||
|
|
||||||
foo = conf.guild(some_guild).foo()
|
foo = await conf.guild(some_guild).foo()
|
||||||
|
|
||||||
# Is equivalent to this
|
# Is equivalent to this
|
||||||
|
|
||||||
group_obj = conf.guild(some_guild)
|
group_obj = conf.guild(some_guild)
|
||||||
value_obj = conf.foo
|
value_obj = conf.foo
|
||||||
foo = value_obj()
|
foo = await value_obj()
|
||||||
|
|
||||||
|
.. important::
|
||||||
|
|
||||||
|
This is now, for all intents and purposes, a coroutine.
|
||||||
|
|
||||||
:param default:
|
:param default:
|
||||||
This argument acts as an override for the registered default provided by :py:attr:`default`. This argument
|
This argument acts as an override for the registered default provided by :py:attr:`default`. This argument
|
||||||
is ignored if its value is :python:`None`.
|
is ignored if its value is :python:`None`.
|
||||||
:type default: Optional[object]
|
:type default: Optional[object]
|
||||||
|
:return:
|
||||||
|
A coroutine object that must be awaited.
|
||||||
"""
|
"""
|
||||||
driver = self.spawner.get_driver()
|
return self._get(default)
|
||||||
try:
|
|
||||||
ret = driver.get(self.identifiers)
|
|
||||||
except KeyError:
|
|
||||||
return default or self.default
|
|
||||||
return ret
|
|
||||||
|
|
||||||
async def set(self, value):
|
async def set(self, value):
|
||||||
"""
|
"""
|
||||||
@ -182,7 +191,7 @@ class Group(Value):
|
|||||||
|
|
||||||
return not isinstance(default, dict)
|
return not isinstance(default, dict)
|
||||||
|
|
||||||
def get_attr(self, item: str, default=None, resolve=True):
|
async def get_attr(self, item: str, default=None, resolve=True):
|
||||||
"""
|
"""
|
||||||
This is available to use as an alternative to using normal Python attribute access. It is required if you find
|
This is available to use as an alternative to using normal Python attribute access. It is required if you find
|
||||||
a need for dynamic attribute access.
|
a need for dynamic attribute access.
|
||||||
@ -198,7 +207,7 @@ class Group(Value):
|
|||||||
user = ctx.author
|
user = ctx.author
|
||||||
|
|
||||||
# Where the value of item is the name of the data field in Config
|
# Where the value of item is the name of the data field in Config
|
||||||
await ctx.send(self.conf.user(user).get_attr(item))
|
await ctx.send(await self.conf.user(user).get_attr(item))
|
||||||
|
|
||||||
:param str item:
|
:param str item:
|
||||||
The name of the data field in :py:class:`.Config`.
|
The name of the data field in :py:class:`.Config`.
|
||||||
@ -211,20 +220,20 @@ class Group(Value):
|
|||||||
"""
|
"""
|
||||||
value = getattr(self, item)
|
value = getattr(self, item)
|
||||||
if resolve:
|
if resolve:
|
||||||
return value(default=default)
|
return await value(default=default)
|
||||||
else:
|
else:
|
||||||
return value
|
return value
|
||||||
|
|
||||||
def all(self) -> dict:
|
async def all(self) -> dict:
|
||||||
"""
|
"""
|
||||||
This method allows you to get "all" of a particular group of data. It will return the dictionary of all data
|
This method allows you to get "all" of a particular group of data. It will return the dictionary of all data
|
||||||
for a particular Guild/Channel/Role/User/Member etc.
|
for a particular Guild/Channel/Role/User/Member etc.
|
||||||
|
|
||||||
:rtype: dict
|
:rtype: dict
|
||||||
"""
|
"""
|
||||||
return self()
|
return await self()
|
||||||
|
|
||||||
def all_from_kind(self) -> dict:
|
async def all_from_kind(self) -> dict:
|
||||||
"""
|
"""
|
||||||
This method allows you to get all data from all entries in a given Kind. It will return a dictionary of Kind
|
This method allows you to get all data from all entries in a given Kind. It will return a dictionary of Kind
|
||||||
ID's -> data.
|
ID's -> data.
|
||||||
@ -232,7 +241,7 @@ class Group(Value):
|
|||||||
:rtype: dict
|
:rtype: dict
|
||||||
"""
|
"""
|
||||||
# noinspection PyTypeChecker
|
# noinspection PyTypeChecker
|
||||||
return self._super_group()
|
return await self._super_group()
|
||||||
|
|
||||||
async def set(self, value):
|
async def set(self, value):
|
||||||
if not isinstance(value, dict):
|
if not isinstance(value, dict):
|
||||||
@ -292,18 +301,18 @@ class MemberGroup(Group):
|
|||||||
)
|
)
|
||||||
return group_obj
|
return group_obj
|
||||||
|
|
||||||
def all_guilds(self) -> dict:
|
async def all_guilds(self) -> dict:
|
||||||
"""
|
"""
|
||||||
Returns a dict of :code:`GUILD_ID -> MEMBER_ID -> data`.
|
Returns a dict of :code:`GUILD_ID -> MEMBER_ID -> data`.
|
||||||
|
|
||||||
:rtype: dict
|
:rtype: dict
|
||||||
"""
|
"""
|
||||||
# noinspection PyTypeChecker
|
# noinspection PyTypeChecker
|
||||||
return self._super_group()
|
return await self._super_group()
|
||||||
|
|
||||||
def all(self) -> dict:
|
async def all(self) -> dict:
|
||||||
# noinspection PyTypeChecker
|
# noinspection PyTypeChecker
|
||||||
return self._guild_group()
|
return await self._guild_group()
|
||||||
|
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
@ -315,7 +324,7 @@ class Config:
|
|||||||
however the process for accessing global data is a bit different. There is no :python:`global` method
|
however the process for accessing global data is a bit different. There is no :python:`global` method
|
||||||
because global data is accessed by normal attribute access::
|
because global data is accessed by normal attribute access::
|
||||||
|
|
||||||
conf.foo()
|
await conf.foo()
|
||||||
|
|
||||||
.. py:attribute:: cog_name
|
.. py:attribute:: cog_name
|
||||||
|
|
||||||
|
|||||||
@ -29,7 +29,7 @@ class Core:
|
|||||||
async def load(self, ctx, *, cog_name: str):
|
async def load(self, ctx, *, cog_name: str):
|
||||||
"""Loads a package"""
|
"""Loads a package"""
|
||||||
try:
|
try:
|
||||||
spec = ctx.bot.cog_mgr.find_cog(cog_name)
|
spec = await ctx.bot.cog_mgr.find_cog(cog_name)
|
||||||
except NoModuleFound:
|
except NoModuleFound:
|
||||||
await ctx.send("No module by that name was found in any"
|
await ctx.send("No module by that name was found in any"
|
||||||
" cog path.")
|
" cog path.")
|
||||||
@ -63,7 +63,7 @@ class Core:
|
|||||||
ctx.bot.unload_extension(cog_name)
|
ctx.bot.unload_extension(cog_name)
|
||||||
self.cleanup_and_refresh_modules(cog_name)
|
self.cleanup_and_refresh_modules(cog_name)
|
||||||
try:
|
try:
|
||||||
spec = ctx.bot.cog_mgr.find_cog(cog_name)
|
spec = await ctx.bot.cog_mgr.find_cog(cog_name)
|
||||||
ctx.bot.load_extension(spec)
|
ctx.bot.load_extension(spec)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.exception("Package reloading failed", exc_info=e)
|
log.exception("Package reloading failed", exc_info=e)
|
||||||
|
|||||||
@ -5,7 +5,7 @@ class BaseDriver:
|
|||||||
def get_driver(self):
|
def get_driver(self):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
def get(self, identifiers: Tuple[str]):
|
async def get(self, identifiers: Tuple[str]):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
async def set(self, identifiers: Tuple[str], value):
|
async def set(self, identifiers: Tuple[str], value):
|
||||||
|
|||||||
@ -32,7 +32,7 @@ class JSON(BaseDriver):
|
|||||||
def get_driver(self):
|
def get_driver(self):
|
||||||
return self
|
return self
|
||||||
|
|
||||||
def get(self, identifiers: Tuple[str]):
|
async def get(self, identifiers: Tuple[str]):
|
||||||
partial = self.data
|
partial = self.data
|
||||||
for i in identifiers:
|
for i in identifiers:
|
||||||
partial = partial[i]
|
partial = partial[i]
|
||||||
|
|||||||
@ -34,11 +34,11 @@ def init_events(bot, cli_flags):
|
|||||||
if cli_flags.no_cogs is False:
|
if cli_flags.no_cogs is False:
|
||||||
print("Loading packages...")
|
print("Loading packages...")
|
||||||
failed = []
|
failed = []
|
||||||
packages = bot.db.packages()
|
packages = await bot.db.packages()
|
||||||
|
|
||||||
for package in packages:
|
for package in packages:
|
||||||
try:
|
try:
|
||||||
spec = bot.cog_mgr.find_cog(package)
|
spec = await bot.cog_mgr.find_cog(package)
|
||||||
bot.load_extension(spec)
|
bot.load_extension(spec)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
log.exception("Failed to load package {}".format(package),
|
log.exception("Failed to load package {}".format(package),
|
||||||
|
|||||||
27
main.py
27
main.py
@ -73,6 +73,17 @@ def determine_main_folder() -> Path:
|
|||||||
return Path(os.path.dirname(__file__)).resolve()
|
return Path(os.path.dirname(__file__)).resolve()
|
||||||
|
|
||||||
|
|
||||||
|
async def _get_prefix_and_token(red, indict):
|
||||||
|
"""
|
||||||
|
Again, please blame <@269933075037814786> for this.
|
||||||
|
:param indict:
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
indict['token'] = await red.db.token()
|
||||||
|
indict['prefix'] = await red.db.prefix()
|
||||||
|
indict['enable_sentry'] = await red.db.enable_sentry()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
cli_flags = parse_cli_flags()
|
cli_flags = parse_cli_flags()
|
||||||
log, sentry_log = init_loggers(cli_flags)
|
log, sentry_log = init_loggers(cli_flags)
|
||||||
@ -89,8 +100,13 @@ if __name__ == '__main__':
|
|||||||
if cli_flags.dev:
|
if cli_flags.dev:
|
||||||
red.add_cog(Dev())
|
red.add_cog(Dev())
|
||||||
|
|
||||||
token = os.environ.get("RED_TOKEN", red.db.token())
|
loop = asyncio.get_event_loop()
|
||||||
prefix = cli_flags.prefix or red.db.prefix()
|
tmp_data = {}
|
||||||
|
loop.run_until_complete(_get_prefix_and_token(red, tmp_data))
|
||||||
|
|
||||||
|
token = os.environ.get("RED_TOKEN", tmp_data['token'])
|
||||||
|
prefix = cli_flags.prefix or tmp_data['prefix']
|
||||||
|
enable_sentry = tmp_data['enable_sentry']
|
||||||
|
|
||||||
if token is None or not prefix:
|
if token is None or not prefix:
|
||||||
if cli_flags.no_prompt is False:
|
if cli_flags.no_prompt is False:
|
||||||
@ -102,13 +118,14 @@ if __name__ == '__main__':
|
|||||||
log.critical("Token and prefix must be set in order to login.")
|
log.critical("Token and prefix must be set in order to login.")
|
||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
if red.db.enable_sentry() is None:
|
if enable_sentry is None:
|
||||||
ask_sentry(red)
|
ask_sentry(red)
|
||||||
|
|
||||||
if red.db.enable_sentry():
|
loop.run_until_complete(_get_prefix_and_token(red, tmp_data))
|
||||||
|
|
||||||
|
if tmp_data['enable_sentry']:
|
||||||
init_sentry_logging(sentry_log)
|
init_sentry_logging(sentry_log)
|
||||||
|
|
||||||
loop = asyncio.get_event_loop()
|
|
||||||
cleanup_tasks = True
|
cleanup_tasks = True
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|||||||
@ -16,12 +16,14 @@ def test_is_valid_alias_name(alias):
|
|||||||
assert alias.is_valid_alias_name("not valid name") is False
|
assert alias.is_valid_alias_name("not valid name") is False
|
||||||
|
|
||||||
|
|
||||||
def test_empty_guild_aliases(alias, empty_guild):
|
@pytest.mark.asyncio
|
||||||
assert list(alias.unloaded_aliases(empty_guild)) == []
|
async def test_empty_guild_aliases(alias, empty_guild):
|
||||||
|
assert list(await alias.unloaded_aliases(empty_guild)) == []
|
||||||
|
|
||||||
|
|
||||||
def test_empty_global_aliases(alias):
|
@pytest.mark.asyncio
|
||||||
assert list(alias.unloaded_global_aliases()) == []
|
async def test_empty_global_aliases(alias):
|
||||||
|
assert list(await alias.unloaded_global_aliases()) == []
|
||||||
|
|
||||||
|
|
||||||
async def create_test_guild_alias(alias, ctx):
|
async def create_test_guild_alias(alias, ctx):
|
||||||
@ -36,7 +38,7 @@ async def create_test_global_alias(alias, ctx):
|
|||||||
async def test_add_guild_alias(alias, ctx):
|
async def test_add_guild_alias(alias, ctx):
|
||||||
await create_test_guild_alias(alias, ctx)
|
await create_test_guild_alias(alias, ctx)
|
||||||
|
|
||||||
is_alias, alias_obj = alias.is_alias(ctx.guild, "test")
|
is_alias, alias_obj = await alias.is_alias(ctx.guild, "test")
|
||||||
assert is_alias is True
|
assert is_alias is True
|
||||||
assert alias_obj.global_ is False
|
assert alias_obj.global_ is False
|
||||||
|
|
||||||
@ -44,19 +46,19 @@ async def test_add_guild_alias(alias, ctx):
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_delete_guild_alias(alias, ctx):
|
async def test_delete_guild_alias(alias, ctx):
|
||||||
await create_test_guild_alias(alias, ctx)
|
await create_test_guild_alias(alias, ctx)
|
||||||
is_alias, _ = alias.is_alias(ctx.guild, "test")
|
is_alias, _ = await alias.is_alias(ctx.guild, "test")
|
||||||
assert is_alias is True
|
assert is_alias is True
|
||||||
|
|
||||||
await alias.delete_alias(ctx, "test")
|
await alias.delete_alias(ctx, "test")
|
||||||
|
|
||||||
is_alias, _ = alias.is_alias(ctx.guild, "test")
|
is_alias, _ = await alias.is_alias(ctx.guild, "test")
|
||||||
assert is_alias is False
|
assert is_alias is False
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_add_global_alias(alias, ctx):
|
async def test_add_global_alias(alias, ctx):
|
||||||
await create_test_global_alias(alias, ctx)
|
await create_test_global_alias(alias, ctx)
|
||||||
is_alias, alias_obj = alias.is_alias(ctx.guild, "test")
|
is_alias, alias_obj = await alias.is_alias(ctx.guild, "test")
|
||||||
|
|
||||||
assert is_alias is True
|
assert is_alias is True
|
||||||
assert alias_obj.global_ is True
|
assert alias_obj.global_ is True
|
||||||
@ -65,7 +67,7 @@ async def test_add_global_alias(alias, ctx):
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_delete_global_alias(alias, ctx):
|
async def test_delete_global_alias(alias, ctx):
|
||||||
await create_test_global_alias(alias, ctx)
|
await create_test_global_alias(alias, ctx)
|
||||||
is_alias, alias_obj = alias.is_alias(ctx.guild, "test")
|
is_alias, alias_obj = await alias.is_alias(ctx.guild, "test")
|
||||||
assert is_alias is True
|
assert is_alias is True
|
||||||
assert alias_obj.global_ is True
|
assert alias_obj.global_ is True
|
||||||
|
|
||||||
|
|||||||
@ -11,13 +11,14 @@ def bank(config):
|
|||||||
return bank
|
return bank
|
||||||
|
|
||||||
|
|
||||||
def test_bank_register(bank, ctx):
|
@pytest.mark.asyncio
|
||||||
default_bal = bank.get_default_balance(ctx.guild)
|
async def test_bank_register(bank, ctx):
|
||||||
assert default_bal == bank.get_account(ctx.author).balance
|
default_bal = await bank.get_default_balance(ctx.guild)
|
||||||
|
assert default_bal == (await bank.get_account(ctx.author)).balance
|
||||||
|
|
||||||
|
|
||||||
async def has_account(member, bank):
|
async def has_account(member, bank):
|
||||||
balance = bank.get_balance(member)
|
balance = await bank.get_balance(member)
|
||||||
if balance == 0:
|
if balance == 0:
|
||||||
balance = 1
|
balance = 1
|
||||||
await bank.set_balance(member, balance)
|
await bank.set_balance(member, balance)
|
||||||
@ -27,11 +28,11 @@ async def has_account(member, bank):
|
|||||||
async def test_bank_transfer(bank, member_factory):
|
async def test_bank_transfer(bank, member_factory):
|
||||||
mbr1 = member_factory.get()
|
mbr1 = member_factory.get()
|
||||||
mbr2 = member_factory.get()
|
mbr2 = member_factory.get()
|
||||||
bal1 = bank.get_account(mbr1).balance
|
bal1 = (await bank.get_account(mbr1)).balance
|
||||||
bal2 = bank.get_account(mbr2).balance
|
bal2 = (await bank.get_account(mbr2)).balance
|
||||||
await bank.transfer_credits(mbr1, mbr2, 50)
|
await bank.transfer_credits(mbr1, mbr2, 50)
|
||||||
newbal1 = bank.get_account(mbr1).balance
|
newbal1 = (await bank.get_account(mbr1)).balance
|
||||||
newbal2 = bank.get_account(mbr2).balance
|
newbal2 = (await bank.get_account(mbr2)).balance
|
||||||
assert bal1 - 50 == newbal1
|
assert bal1 - 50 == newbal1
|
||||||
assert bal2 + 50 == newbal2
|
assert bal2 + 50 == newbal2
|
||||||
|
|
||||||
@ -40,16 +41,16 @@ async def test_bank_transfer(bank, member_factory):
|
|||||||
async def test_bank_set(bank, member_factory):
|
async def test_bank_set(bank, member_factory):
|
||||||
mbr = member_factory.get()
|
mbr = member_factory.get()
|
||||||
await bank.set_balance(mbr, 250)
|
await bank.set_balance(mbr, 250)
|
||||||
acc = bank.get_account(mbr)
|
acc = await bank.get_account(mbr)
|
||||||
assert acc.balance == 250
|
assert acc.balance == 250
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_bank_can_spend(bank, member_factory):
|
async def test_bank_can_spend(bank, member_factory):
|
||||||
mbr = member_factory.get()
|
mbr = member_factory.get()
|
||||||
canspend = bank.can_spend(mbr, 50)
|
canspend = await bank.can_spend(mbr, 50)
|
||||||
assert canspend == (50 < bank.get_default_balance(mbr.guild))
|
assert canspend == (50 < await bank.get_default_balance(mbr.guild))
|
||||||
await bank.set_balance(mbr, 200)
|
await bank.set_balance(mbr, 200)
|
||||||
acc = bank.get_account(mbr)
|
acc = await bank.get_account(mbr)
|
||||||
canspendnow = bank.can_spend(mbr, 100)
|
canspendnow = await bank.can_spend(mbr, 100)
|
||||||
assert canspendnow
|
assert canspendnow
|
||||||
|
|||||||
@ -14,16 +14,17 @@ def default_dir(red):
|
|||||||
return red.main_dir
|
return red.main_dir
|
||||||
|
|
||||||
|
|
||||||
def test_ensure_cogs_in_paths(cog_mgr, default_dir):
|
@pytest.mark.asyncio
|
||||||
|
async def test_ensure_cogs_in_paths(cog_mgr, default_dir):
|
||||||
cogs_dir = default_dir / 'cogs'
|
cogs_dir = default_dir / 'cogs'
|
||||||
assert cogs_dir in cog_mgr.paths
|
assert cogs_dir in await cog_mgr.paths()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_install_path_set(cog_mgr: cog_manager.CogManager, tmpdir):
|
async def test_install_path_set(cog_mgr: cog_manager.CogManager, tmpdir):
|
||||||
path = Path(str(tmpdir))
|
path = Path(str(tmpdir))
|
||||||
await cog_mgr.set_install_path(path)
|
await cog_mgr.set_install_path(path)
|
||||||
assert cog_mgr.install_path == path
|
assert await cog_mgr.install_path() == path
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@ -38,7 +39,7 @@ async def test_install_path_set_bad(cog_mgr):
|
|||||||
async def test_add_path(cog_mgr, tmpdir):
|
async def test_add_path(cog_mgr, tmpdir):
|
||||||
path = Path(str(tmpdir))
|
path = Path(str(tmpdir))
|
||||||
await cog_mgr.add_path(path)
|
await cog_mgr.add_path(path)
|
||||||
assert path in cog_mgr.paths
|
assert path in await cog_mgr.paths()
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@ -54,4 +55,4 @@ async def test_remove_path(cog_mgr, tmpdir):
|
|||||||
path = Path(str(tmpdir))
|
path = Path(str(tmpdir))
|
||||||
await cog_mgr.add_path(path)
|
await cog_mgr.add_path(path)
|
||||||
await cog_mgr.remove_path(path)
|
await cog_mgr.remove_path(path)
|
||||||
assert path not in cog_mgr.paths
|
assert path not in await cog_mgr.paths()
|
||||||
|
|||||||
@ -2,10 +2,11 @@ import pytest
|
|||||||
|
|
||||||
|
|
||||||
#region Register Tests
|
#region Register Tests
|
||||||
def test_config_register_global(config):
|
@pytest.mark.asyncio
|
||||||
|
async def test_config_register_global(config):
|
||||||
config.register_global(enabled=False)
|
config.register_global(enabled=False)
|
||||||
assert config.defaults["GLOBAL"]["enabled"] is False
|
assert config.defaults["GLOBAL"]["enabled"] is False
|
||||||
assert config.enabled() is False
|
assert await config.enabled() is False
|
||||||
|
|
||||||
|
|
||||||
def test_config_register_global_badvalues(config):
|
def test_config_register_global_badvalues(config):
|
||||||
@ -13,61 +14,69 @@ def test_config_register_global_badvalues(config):
|
|||||||
config.register_global(**{"invalid var name": True})
|
config.register_global(**{"invalid var name": True})
|
||||||
|
|
||||||
|
|
||||||
def test_config_register_guild(config, empty_guild):
|
@pytest.mark.asyncio
|
||||||
|
async def test_config_register_guild(config, empty_guild):
|
||||||
config.register_guild(enabled=False, some_list=[], some_dict={})
|
config.register_guild(enabled=False, some_list=[], some_dict={})
|
||||||
assert config.defaults[config.GUILD]["enabled"] is False
|
assert config.defaults[config.GUILD]["enabled"] is False
|
||||||
assert config.defaults[config.GUILD]["some_list"] == []
|
assert config.defaults[config.GUILD]["some_list"] == []
|
||||||
assert config.defaults[config.GUILD]["some_dict"] == {}
|
assert config.defaults[config.GUILD]["some_dict"] == {}
|
||||||
|
|
||||||
assert config.guild(empty_guild).enabled() is False
|
assert await config.guild(empty_guild).enabled() is False
|
||||||
assert config.guild(empty_guild).some_list() == []
|
assert await config.guild(empty_guild).some_list() == []
|
||||||
assert config.guild(empty_guild).some_dict() == {}
|
assert await config.guild(empty_guild).some_dict() == {}
|
||||||
|
|
||||||
|
|
||||||
def test_config_register_channel(config, empty_channel):
|
@pytest.mark.asyncio
|
||||||
|
async def test_config_register_channel(config, empty_channel):
|
||||||
config.register_channel(enabled=False)
|
config.register_channel(enabled=False)
|
||||||
assert config.defaults[config.CHANNEL]["enabled"] is False
|
assert config.defaults[config.CHANNEL]["enabled"] is False
|
||||||
assert config.channel(empty_channel).enabled() is False
|
assert await config.channel(empty_channel).enabled() is False
|
||||||
|
|
||||||
|
|
||||||
def test_config_register_role(config, empty_role):
|
@pytest.mark.asyncio
|
||||||
|
async def test_config_register_role(config, empty_role):
|
||||||
config.register_role(enabled=False)
|
config.register_role(enabled=False)
|
||||||
assert config.defaults[config.ROLE]["enabled"] is False
|
assert config.defaults[config.ROLE]["enabled"] is False
|
||||||
assert config.role(empty_role).enabled() is False
|
assert await config.role(empty_role).enabled() is False
|
||||||
|
|
||||||
|
|
||||||
def test_config_register_member(config, empty_member):
|
@pytest.mark.asyncio
|
||||||
|
async def test_config_register_member(config, empty_member):
|
||||||
config.register_member(some_number=-1)
|
config.register_member(some_number=-1)
|
||||||
assert config.defaults[config.MEMBER]["some_number"] == -1
|
assert config.defaults[config.MEMBER]["some_number"] == -1
|
||||||
assert config.member(empty_member).some_number() == -1
|
assert await config.member(empty_member).some_number() == -1
|
||||||
|
|
||||||
|
|
||||||
def test_config_register_user(config, empty_user):
|
@pytest.mark.asyncio
|
||||||
|
async def test_config_register_user(config, empty_user):
|
||||||
config.register_user(some_value=None)
|
config.register_user(some_value=None)
|
||||||
assert config.defaults[config.USER]["some_value"] is None
|
assert config.defaults[config.USER]["some_value"] is None
|
||||||
assert config.user(empty_user).some_value() is None
|
assert await config.user(empty_user).some_value() is None
|
||||||
|
|
||||||
|
|
||||||
def test_config_force_register_global(config_fr):
|
@pytest.mark.asyncio
|
||||||
|
async def test_config_force_register_global(config_fr):
|
||||||
with pytest.raises(AttributeError):
|
with pytest.raises(AttributeError):
|
||||||
config_fr.enabled()
|
await config_fr.enabled()
|
||||||
|
|
||||||
config_fr.register_global(enabled=True)
|
config_fr.register_global(enabled=True)
|
||||||
assert config_fr.enabled() is True
|
assert await config_fr.enabled() is True
|
||||||
#endregion
|
#endregion
|
||||||
|
|
||||||
|
|
||||||
# Test nested registration
|
# Test nested registration
|
||||||
def test_nested_registration(config):
|
@pytest.mark.asyncio
|
||||||
|
async def test_nested_registration(config):
|
||||||
config.register_global(foo__bar__baz=False)
|
config.register_global(foo__bar__baz=False)
|
||||||
assert config.foo.bar.baz() is False
|
assert await config.foo.bar.baz() is False
|
||||||
|
|
||||||
|
|
||||||
def test_nested_registration_asdict(config):
|
@pytest.mark.asyncio
|
||||||
|
async def test_nested_registration_asdict(config):
|
||||||
defaults = {'bar': {'baz': False}}
|
defaults = {'bar': {'baz': False}}
|
||||||
config.register_global(foo=defaults)
|
config.register_global(foo=defaults)
|
||||||
|
|
||||||
assert config.foo.bar.baz() is False
|
assert await config.foo.bar.baz() is False
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@ -75,20 +84,22 @@ async def test_nested_registration_and_changing(config):
|
|||||||
defaults = {'bar': {'baz': False}}
|
defaults = {'bar': {'baz': False}}
|
||||||
config.register_global(foo=defaults)
|
config.register_global(foo=defaults)
|
||||||
|
|
||||||
assert config.foo.bar.baz() is False
|
assert await config.foo.bar.baz() is False
|
||||||
|
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
await config.foo.set(True)
|
await config.foo.set(True)
|
||||||
|
|
||||||
|
|
||||||
def test_doubleset_default(config):
|
@pytest.mark.asyncio
|
||||||
|
async def test_doubleset_default(config):
|
||||||
config.register_global(foo=True)
|
config.register_global(foo=True)
|
||||||
config.register_global(foo=False)
|
config.register_global(foo=False)
|
||||||
|
|
||||||
assert config.foo() is False
|
assert await config.foo() is False
|
||||||
|
|
||||||
|
|
||||||
def test_nested_registration_multidict(config):
|
@pytest.mark.asyncio
|
||||||
|
async def test_nested_registration_multidict(config):
|
||||||
defaults = {
|
defaults = {
|
||||||
"foo": {
|
"foo": {
|
||||||
"bar": {
|
"bar": {
|
||||||
@ -99,8 +110,8 @@ def test_nested_registration_multidict(config):
|
|||||||
}
|
}
|
||||||
config.register_global(**defaults)
|
config.register_global(**defaults)
|
||||||
|
|
||||||
assert config.foo.bar.baz() is True
|
assert await config.foo.bar.baz() is True
|
||||||
assert config.blah() is True
|
assert await config.blah() is True
|
||||||
|
|
||||||
|
|
||||||
def test_nested_group_value_badreg(config):
|
def test_nested_group_value_badreg(config):
|
||||||
@ -109,56 +120,66 @@ def test_nested_group_value_badreg(config):
|
|||||||
config.register_global(foo__bar=False)
|
config.register_global(foo__bar=False)
|
||||||
|
|
||||||
|
|
||||||
def test_nested_toplevel_reg(config):
|
@pytest.mark.asyncio
|
||||||
|
async def test_nested_toplevel_reg(config):
|
||||||
defaults = {'bar': True, 'baz': False}
|
defaults = {'bar': True, 'baz': False}
|
||||||
config.register_global(foo=defaults)
|
config.register_global(foo=defaults)
|
||||||
|
|
||||||
assert config.foo.bar() is True
|
assert await config.foo.bar() is True
|
||||||
assert config.foo.baz() is False
|
assert await config.foo.baz() is False
|
||||||
|
|
||||||
|
|
||||||
def test_nested_overlapping(config):
|
@pytest.mark.asyncio
|
||||||
|
async def test_nested_overlapping(config):
|
||||||
config.register_global(foo__bar=True)
|
config.register_global(foo__bar=True)
|
||||||
config.register_global(foo__baz=False)
|
config.register_global(foo__baz=False)
|
||||||
|
|
||||||
assert config.foo.bar() is True
|
assert await config.foo.bar() is True
|
||||||
assert config.foo.baz() is False
|
assert await config.foo.baz() is False
|
||||||
|
|
||||||
|
|
||||||
def test_nesting_nofr(config):
|
@pytest.mark.asyncio
|
||||||
|
async def test_nesting_nofr(config):
|
||||||
config.register_global(foo={})
|
config.register_global(foo={})
|
||||||
assert config.foo.bar() is None
|
assert await config.foo.bar() is None
|
||||||
assert config.foo() == {}
|
assert await config.foo() == {}
|
||||||
|
|
||||||
|
|
||||||
#region Default Value Overrides
|
# region Default Value Overrides
|
||||||
def test_global_default_override(config):
|
@pytest.mark.asyncio
|
||||||
assert config.enabled(True) is True
|
async def test_global_default_override(config):
|
||||||
|
assert await config.enabled(True) is True
|
||||||
|
|
||||||
|
|
||||||
def test_global_default_nofr(config):
|
@pytest.mark.asyncio
|
||||||
assert config.nofr() is None
|
async def test_global_default_nofr(config):
|
||||||
assert config.nofr(True) is True
|
assert await config.nofr() is None
|
||||||
|
assert await config.nofr(True) is True
|
||||||
|
|
||||||
|
|
||||||
def test_guild_default_override(config, empty_guild):
|
@pytest.mark.asyncio
|
||||||
assert config.guild(empty_guild).enabled(True) is True
|
async def test_guild_default_override(config, empty_guild):
|
||||||
|
assert await config.guild(empty_guild).enabled(True) is True
|
||||||
|
|
||||||
|
|
||||||
def test_channel_default_override(config, empty_channel):
|
@pytest.mark.asyncio
|
||||||
assert config.channel(empty_channel).enabled(True) is True
|
async def test_channel_default_override(config, empty_channel):
|
||||||
|
assert await config.channel(empty_channel).enabled(True) is True
|
||||||
|
|
||||||
|
|
||||||
def test_role_default_override(config, empty_role):
|
@pytest.mark.asyncio
|
||||||
assert config.role(empty_role).enabled(True) is True
|
async def test_role_default_override(config, empty_role):
|
||||||
|
assert await config.role(empty_role).enabled(True) is True
|
||||||
|
|
||||||
|
|
||||||
def test_member_default_override(config, empty_member):
|
@pytest.mark.asyncio
|
||||||
assert config.member(empty_member).enabled(True) is True
|
async def test_member_default_override(config, empty_member):
|
||||||
|
assert await config.member(empty_member).enabled(True) is True
|
||||||
|
|
||||||
|
|
||||||
def test_user_default_override(config, empty_user):
|
@pytest.mark.asyncio
|
||||||
assert config.user(empty_user).some_value(True) is True
|
async def test_user_default_override(config, empty_user):
|
||||||
|
assert await config.user(empty_user).some_value(True) is True
|
||||||
#endregion
|
#endregion
|
||||||
|
|
||||||
|
|
||||||
@ -166,32 +187,32 @@ def test_user_default_override(config, empty_user):
|
|||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_set_global(config):
|
async def test_set_global(config):
|
||||||
await config.enabled.set(True)
|
await config.enabled.set(True)
|
||||||
assert config.enabled() is True
|
assert await config.enabled() is True
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_set_guild(config, empty_guild):
|
async def test_set_guild(config, empty_guild):
|
||||||
await config.guild(empty_guild).enabled.set(True)
|
await config.guild(empty_guild).enabled.set(True)
|
||||||
assert config.guild(empty_guild).enabled() is True
|
assert await config.guild(empty_guild).enabled() is True
|
||||||
|
|
||||||
curr_list = config.guild(empty_guild).some_list([1, 2, 3])
|
curr_list = await config.guild(empty_guild).some_list([1, 2, 3])
|
||||||
assert curr_list == [1, 2, 3]
|
assert curr_list == [1, 2, 3]
|
||||||
curr_list.append(4)
|
curr_list.append(4)
|
||||||
|
|
||||||
await config.guild(empty_guild).some_list.set(curr_list)
|
await config.guild(empty_guild).some_list.set(curr_list)
|
||||||
assert config.guild(empty_guild).some_list() == curr_list
|
assert await config.guild(empty_guild).some_list() == curr_list
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_set_channel(config, empty_channel):
|
async def test_set_channel(config, empty_channel):
|
||||||
await config.channel(empty_channel).enabled.set(True)
|
await config.channel(empty_channel).enabled.set(True)
|
||||||
assert config.channel(empty_channel).enabled() is True
|
assert await config.channel(empty_channel).enabled() is True
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_set_channel_no_register(config, empty_channel):
|
async def test_set_channel_no_register(config, empty_channel):
|
||||||
await config.channel(empty_channel).no_register.set(True)
|
await config.channel(empty_channel).no_register.set(True)
|
||||||
assert config.channel(empty_channel).no_register() is True
|
assert await config.channel(empty_channel).no_register() is True
|
||||||
#endregion
|
#endregion
|
||||||
|
|
||||||
|
|
||||||
@ -200,11 +221,12 @@ async def test_set_channel_no_register(config, empty_channel):
|
|||||||
async def test_set_dynamic_attr(config):
|
async def test_set_dynamic_attr(config):
|
||||||
await config.set_attr("foobar", True)
|
await config.set_attr("foobar", True)
|
||||||
|
|
||||||
assert config.foobar() is True
|
assert await config.foobar() is True
|
||||||
|
|
||||||
|
|
||||||
def test_get_dynamic_attr(config):
|
@pytest.mark.asyncio
|
||||||
assert config.get_attr("foobaz", True) is True
|
async def test_get_dynamic_attr(config):
|
||||||
|
assert await config.get_attr("foobaz", True) is True
|
||||||
|
|
||||||
|
|
||||||
# Member Group testing
|
# Member Group testing
|
||||||
@ -212,7 +234,7 @@ def test_get_dynamic_attr(config):
|
|||||||
async def test_membergroup_allguilds(config, empty_member):
|
async def test_membergroup_allguilds(config, empty_member):
|
||||||
await config.member(empty_member).foo.set(False)
|
await config.member(empty_member).foo.set(False)
|
||||||
|
|
||||||
all_servers = config.member(empty_member).all_guilds()
|
all_servers = await config.member(empty_member).all_guilds()
|
||||||
assert str(empty_member.guild.id) in all_servers
|
assert str(empty_member.guild.id) in all_servers
|
||||||
|
|
||||||
|
|
||||||
@ -220,7 +242,7 @@ async def test_membergroup_allguilds(config, empty_member):
|
|||||||
async def test_membergroup_allmembers(config, empty_member):
|
async def test_membergroup_allmembers(config, empty_member):
|
||||||
await config.member(empty_member).foo.set(False)
|
await config.member(empty_member).foo.set(False)
|
||||||
|
|
||||||
all_members = config.member(empty_member).all()
|
all_members = await config.member(empty_member).all()
|
||||||
assert str(empty_member.id) in all_members
|
assert str(empty_member.id) in all_members
|
||||||
|
|
||||||
|
|
||||||
@ -232,13 +254,13 @@ async def test_global_clear(config):
|
|||||||
await config.foo.set(False)
|
await config.foo.set(False)
|
||||||
await config.bar.set(True)
|
await config.bar.set(True)
|
||||||
|
|
||||||
assert config.foo() is False
|
assert await config.foo() is False
|
||||||
assert config.bar() is True
|
assert await config.bar() is True
|
||||||
|
|
||||||
await config.clear()
|
await config.clear()
|
||||||
|
|
||||||
assert config.foo() is True
|
assert await config.foo() is True
|
||||||
assert config.bar() is False
|
assert await config.bar() is False
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@ -247,17 +269,17 @@ async def test_member_clear(config, member_factory):
|
|||||||
|
|
||||||
m1 = member_factory.get()
|
m1 = member_factory.get()
|
||||||
await config.member(m1).foo.set(False)
|
await config.member(m1).foo.set(False)
|
||||||
assert config.member(m1).foo() is False
|
assert await config.member(m1).foo() is False
|
||||||
|
|
||||||
m2 = member_factory.get()
|
m2 = member_factory.get()
|
||||||
await config.member(m2).foo.set(False)
|
await config.member(m2).foo.set(False)
|
||||||
assert config.member(m2).foo() is False
|
assert await config.member(m2).foo() is False
|
||||||
|
|
||||||
assert m1.guild.id != m2.guild.id
|
assert m1.guild.id != m2.guild.id
|
||||||
|
|
||||||
await config.member(m1).clear()
|
await config.member(m1).clear()
|
||||||
assert config.member(m1).foo() is True
|
assert await config.member(m1).foo() is True
|
||||||
assert config.member(m2).foo() is False
|
assert await config.member(m2).foo() is False
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
@ -269,11 +291,11 @@ async def test_member_clear_all(config, member_factory):
|
|||||||
server_ids.append(member.guild.id)
|
server_ids.append(member.guild.id)
|
||||||
|
|
||||||
member = member_factory.get()
|
member = member_factory.get()
|
||||||
assert len(config.member(member).all_guilds()) == len(server_ids)
|
assert len(await config.member(member).all_guilds()) == len(server_ids)
|
||||||
|
|
||||||
await config.member(member).clear_all()
|
await config.member(member).clear_all()
|
||||||
|
|
||||||
assert len(config.member(member).all_guilds()) == 0
|
assert len(await config.member(member).all_guilds()) == 0
|
||||||
|
|
||||||
|
|
||||||
# Get All testing
|
# Get All testing
|
||||||
@ -284,7 +306,7 @@ async def test_user_get_all_from_kind(config, user_factory):
|
|||||||
await config.user(user).foo.set(True)
|
await config.user(user).foo.set(True)
|
||||||
|
|
||||||
user = user_factory.get()
|
user = user_factory.get()
|
||||||
all_data = config.user(user).all_from_kind()
|
all_data = await config.user(user).all_from_kind()
|
||||||
|
|
||||||
assert len(all_data) == 5
|
assert len(all_data) == 5
|
||||||
|
|
||||||
@ -294,4 +316,4 @@ async def test_user_getalldata(config, user_factory):
|
|||||||
user = user_factory.get()
|
user = user_factory.get()
|
||||||
await config.user(user).foo.set(False)
|
await config.user(user).foo.set(False)
|
||||||
|
|
||||||
assert "foo" in config.user(user).all()
|
assert "foo" in await config.user(user).all()
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user