[Config] Asynchronous getters (#907)

* Make config get async

* Asyncify alias

* Asyncify bank

* Asyncify cog manager

* IT BOOTS

* Asyncify core commands

* Asyncify repo manager

* Asyncify downloader

* Asyncify economy

* Asyncify alias TESTS

* Asyncify economy TESTS

* Asyncify downloader TESTS

* Asyncify config TESTS

* A bank thing

* Asyncify Bank cog

* Warning message in docs

* Update docs with await syntax

* Update docs with await syntax
This commit is contained in:
Will 2017-08-11 21:43:21 -04:00 committed by GitHub
parent cf8e11238c
commit de912a3cfb
18 changed files with 371 additions and 296 deletions

View File

@ -38,26 +38,26 @@ class Alias:
self._aliases.register_global(**self.default_global_settings) self._aliases.register_global(**self.default_global_settings)
self._aliases.register_guild(**self.default_guild_settings) self._aliases.register_guild(**self.default_guild_settings)
def unloaded_aliases(self, guild: discord.Guild) -> Generator[AliasEntry, None, None]: async def unloaded_aliases(self, guild: discord.Guild) -> Generator[AliasEntry, None, None]:
return (AliasEntry.from_json(d) for d in self._aliases.guild(guild).entries()) return (AliasEntry.from_json(d) for d in (await self._aliases.guild(guild).entries()))
def unloaded_global_aliases(self) -> Generator[AliasEntry, None, None]: async def unloaded_global_aliases(self) -> Generator[AliasEntry, None, None]:
return (AliasEntry.from_json(d) for d in self._aliases.entries()) return (AliasEntry.from_json(d) for d in (await self._aliases.entries()))
def loaded_aliases(self, guild: discord.Guild) -> Generator[AliasEntry, None, None]: async def loaded_aliases(self, guild: discord.Guild) -> Generator[AliasEntry, None, None]:
return (AliasEntry.from_json(d, bot=self.bot) return (AliasEntry.from_json(d, bot=self.bot)
for d in self._aliases.guild(guild).entries()) for d in (await self._aliases.guild(guild).entries()))
def loaded_global_aliases(self) -> Generator[AliasEntry, None, None]: async def loaded_global_aliases(self) -> Generator[AliasEntry, None, None]:
return (AliasEntry.from_json(d, bot=self.bot) for d in self._aliases.entries()) return (AliasEntry.from_json(d, bot=self.bot) for d in (await self._aliases.entries()))
def is_alias(self, guild: discord.Guild, alias_name: str, async def is_alias(self, guild: discord.Guild, alias_name: str,
server_aliases: Iterable[AliasEntry]=()) -> (bool, AliasEntry): server_aliases: Iterable[AliasEntry]=()) -> (bool, AliasEntry):
if not server_aliases: if not server_aliases:
server_aliases = self.unloaded_aliases(guild) server_aliases = await self.unloaded_aliases(guild)
global_aliases = self.unloaded_global_aliases() global_aliases = await self.unloaded_global_aliases()
for aliases in (server_aliases, global_aliases): for aliases in (server_aliases, global_aliases):
for alias in aliases: for alias in aliases:
@ -79,11 +79,11 @@ class Alias:
alias = AliasEntry(alias_name, command, ctx.author, global_=global_) alias = AliasEntry(alias_name, command, ctx.author, global_=global_)
if global_: if global_:
curr_aliases = self._aliases.entries() curr_aliases = await self._aliases.entries()
curr_aliases.append(alias.to_json()) curr_aliases.append(alias.to_json())
await self._aliases.entries.set(curr_aliases) await self._aliases.entries.set(curr_aliases)
else: else:
curr_aliases = self._aliases.guild(ctx.guild).entries() curr_aliases = await self._aliases.guild(ctx.guild).entries()
curr_aliases.append(alias.to_json()) curr_aliases.append(alias.to_json())
await self._aliases.guild(ctx.guild).entries.set(curr_aliases) await self._aliases.guild(ctx.guild).entries.set(curr_aliases)
@ -94,10 +94,10 @@ class Alias:
async def delete_alias(self, ctx: commands.Context, alias_name: str, async def delete_alias(self, ctx: commands.Context, alias_name: str,
global_: bool=False) -> bool: global_: bool=False) -> bool:
if global_: if global_:
aliases = self.unloaded_global_aliases() aliases = await self.unloaded_global_aliases()
setter_func = self._aliases.entries.set setter_func = self._aliases.entries.set
else: else:
aliases = self.unloaded_aliases(ctx.guild) aliases = await self.unloaded_aliases(ctx.guild)
setter_func = self._aliases.guild(ctx.guild).entries.set setter_func = self._aliases.guild(ctx.guild).entries.set
did_delete_alias = False did_delete_alias = False
@ -161,7 +161,7 @@ class Alias:
except IndexError: except IndexError:
return False return False
is_alias, alias = self.is_alias(message.guild, potential_alias, server_aliases=aliases) is_alias, alias = await self.is_alias(message.guild, potential_alias, server_aliases=aliases)
if is_alias: if is_alias:
await self.call_alias(message, prefix, alias) await self.call_alias(message, prefix, alias)
@ -206,7 +206,7 @@ class Alias:
" name is already a command on this bot.").format(alias_name)) " name is already a command on this bot.").format(alias_name))
return return
is_alias, _ = self.is_alias(ctx.guild, alias_name) is_alias, _ = await self.is_alias(ctx.guild, alias_name)
if is_alias: if is_alias:
await ctx.send(("You attempted to create a new alias" await ctx.send(("You attempted to create a new alias"
" with the name {} but that" " with the name {} but that"
@ -285,7 +285,7 @@ class Alias:
@commands.guild_only() @commands.guild_only()
async def _show_alias(self, ctx: commands.Context, alias_name: str): async def _show_alias(self, ctx: commands.Context, alias_name: str):
"""Shows what command the alias executes.""" """Shows what command the alias executes."""
is_alias, alias = self.is_alias(ctx.guild, alias_name) is_alias, alias = await self.is_alias(ctx.guild, alias_name)
if is_alias: if is_alias:
await ctx.send(("The `{}` alias will execute the" await ctx.send(("The `{}` alias will execute the"
@ -299,7 +299,7 @@ class Alias:
""" """
Deletes an existing alias on this server. Deletes an existing alias on this server.
""" """
aliases = self.unloaded_aliases(ctx.guild) aliases = await self.unloaded_aliases(ctx.guild)
try: try:
next(aliases) next(aliases)
except StopIteration: except StopIteration:
@ -317,7 +317,7 @@ class Alias:
""" """
Deletes an existing global alias. Deletes an existing global alias.
""" """
aliases = self.unloaded_global_aliases() aliases = await self.unloaded_global_aliases()
try: try:
next(aliases) next(aliases)
except StopIteration: except StopIteration:
@ -336,7 +336,7 @@ class Alias:
""" """
Lists the available aliases on this server. Lists the available aliases on this server.
""" """
names = ["Aliases:", ] + sorted(["+ " + a.name for a in self.unloaded_aliases(ctx.guild)]) names = ["Aliases:", ] + sorted(["+ " + a.name for a in (await self.unloaded_aliases(ctx.guild))])
if len(names) == 0: if len(names) == 0:
await ctx.send("There are no aliases on this server.") await ctx.send("There are no aliases on this server.")
else: else:
@ -347,16 +347,16 @@ class Alias:
""" """
Lists the available global aliases on this bot. Lists the available global aliases on this bot.
""" """
names = ["Aliases:", ] + sorted(["+ " + a.name for a in self.unloaded_global_aliases()]) names = ["Aliases:", ] + sorted(["+ " + a.name for a in await self.unloaded_global_aliases()])
if len(names) == 0: if len(names) == 0:
await ctx.send("There are no aliases on this server.") await ctx.send("There are no aliases on this server.")
else: else:
await ctx.send(box("\n".join(names), "diff")) await ctx.send(box("\n".join(names), "diff"))
async def on_message(self, message: discord.Message): async def on_message(self, message: discord.Message):
aliases = list(self.unloaded_global_aliases()) aliases = list(await self.unloaded_global_aliases())
if message.guild is not None: if message.guild is not None:
aliases = aliases + list(self.unloaded_aliases(message.guild)) aliases = aliases + list(await self.unloaded_aliases(message.guild))
if len(aliases) == 0: if len(aliases) == 0:
return return

View File

@ -6,7 +6,7 @@ from core.bot import Red # Only used for type hints
def check_global_setting_guildowner(): def check_global_setting_guildowner():
async def pred(ctx: commands.Context): async def pred(ctx: commands.Context):
if bank.is_global(): if await bank.is_global():
return checks.is_owner() return checks.is_owner()
else: else:
return checks.guildowner_or_permissions(administrator=True) return checks.guildowner_or_permissions(administrator=True)
@ -15,7 +15,7 @@ def check_global_setting_guildowner():
def check_global_setting_admin(): def check_global_setting_admin():
async def pred(ctx: commands.Context): async def pred(ctx: commands.Context):
if bank.is_global(): if await bank.is_global():
return checks.is_owner() return checks.is_owner()
else: else:
return checks.admin_or_permissions(manage_guild=True) return checks.admin_or_permissions(manage_guild=True)
@ -43,7 +43,7 @@ class Bank:
"""Toggles whether the bank is global or not """Toggles whether the bank is global or not
If the bank is global, it will become per-guild If the bank is global, it will become per-guild
If the bank is per-guild, it will become global""" If the bank is per-guild, it will become global"""
cur_setting = bank.is_global() cur_setting = await bank.is_global()
await bank.set_global(not cur_setting, ctx.author) await bank.set_global(not cur_setting, ctx.author)
word = "per-guild" if cur_setting else "global" word = "per-guild" if cur_setting else "global"

View File

@ -49,22 +49,20 @@ class Downloader:
self._repo_manager = RepoManager(self.conf) self._repo_manager = RepoManager(self.conf)
@property async def cog_install_path(self):
def cog_install_path(self):
""" """
Returns the current cog install path. Returns the current cog install path.
:return: :return:
""" """
return self.bot.cog_mgr.install_path return await self.bot.cog_mgr.install_path()
@property async def installed_cogs(self) -> Tuple[Installable]:
def installed_cogs(self) -> Tuple[Installable]:
""" """
Returns the dictionary mapping cog name to install location Returns the dictionary mapping cog name to install location
and repo name. and repo name.
:return: :return:
""" """
installed = self.conf.installed() installed = await self.conf.installed()
# noinspection PyTypeChecker # noinspection PyTypeChecker
return tuple(Installable.from_json(v) for v in installed) return tuple(Installable.from_json(v) for v in installed)
@ -74,7 +72,7 @@ class Downloader:
:param cog: :param cog:
:return: :return:
""" """
installed = self.conf.installed() installed = await self.conf.installed()
cog_json = cog.to_json() cog_json = cog.to_json()
if cog_json not in installed: if cog_json not in installed:
@ -87,7 +85,7 @@ class Downloader:
:param cog: :param cog:
:return: :return:
""" """
installed = self.conf.installed() installed = await self.conf.installed()
cog_json = cog.to_json() cog_json = cog.to_json()
if cog_json in installed: if cog_json in installed:
@ -102,7 +100,7 @@ class Downloader:
""" """
failed = [] failed = []
for cog in cogs: for cog in cogs:
if not await cog.copy_to(self.cog_install_path): if not await cog.copy_to(await self.cog_install_path()):
failed.append(cog) failed.append(cog)
# noinspection PyTypeChecker # noinspection PyTypeChecker
@ -249,7 +247,7 @@ class Downloader:
" `{}`: `{}`".format(cog.name, cog.requirements)) " `{}`: `{}`".format(cog.name, cog.requirements))
return return
await repo_name.install_cog(cog, self.cog_install_path) await repo_name.install_cog(cog, await self.cog_install_path())
await self._add_to_installed(cog) await self._add_to_installed(cog)
@ -266,7 +264,7 @@ class Downloader:
# noinspection PyUnresolvedReferences,PyProtectedMember # noinspection PyUnresolvedReferences,PyProtectedMember
real_name = cog_name.name real_name = cog_name.name
poss_installed_path = self.cog_install_path / real_name poss_installed_path = (await self.cog_install_path()) / real_name
if poss_installed_path.exists(): if poss_installed_path.exists():
await self._delete_cog(poss_installed_path) await self._delete_cog(poss_installed_path)
# noinspection PyTypeChecker # noinspection PyTypeChecker
@ -284,7 +282,7 @@ class Downloader:
""" """
if cog_name is None: if cog_name is None:
updated = await self._repo_manager.update_all_repos() updated = await self._repo_manager.update_all_repos()
installed_cogs = set(self.installed_cogs) installed_cogs = set(await self.installed_cogs())
updated_cogs = set(cog for repo in updated.keys() for cog in repo.available_cogs) updated_cogs = set(cog for repo in updated.keys() for cog in repo.available_cogs)
installed_and_updated = updated_cogs & installed_cogs installed_and_updated = updated_cogs & installed_cogs
@ -325,14 +323,14 @@ class Downloader:
msg = "Information on {}:\n{}".format(cog.name, cog.description or "") msg = "Information on {}:\n{}".format(cog.name, cog.description or "")
await ctx.send(box(msg)) await ctx.send(box(msg))
def is_installed(self, cog_name: str) -> (bool, Union[Installable, None]): async def is_installed(self, cog_name: str) -> (bool, Union[Installable, None]):
""" """
Checks to see if a cog with the given name was installed Checks to see if a cog with the given name was installed
through Downloader. through Downloader.
:param cog_name: :param cog_name:
:return: is_installed, Installable :return: is_installed, Installable
""" """
for installable in self.installed_cogs: for installable in await self.installed_cogs():
if installable.name == cog_name: if installable.name == cog_name:
return True, installable return True, installable
return False, None return False, None
@ -384,7 +382,7 @@ class Downloader:
# Check if in installed cogs # Check if in installed cogs
cog_name = self.cog_name_from_instance(command.instance) cog_name = self.cog_name_from_instance(command.instance)
installed, cog_installable = self.is_installed(cog_name) installed, cog_installable = await self.is_installed(cog_name)
if installed: if installed:
msg = self.format_findcog_info(command_name, cog_installable) msg = self.format_findcog_info(command_name, cog_installable)
else: else:

View File

@ -430,7 +430,10 @@ class RepoManager:
self.repos_folder = Path(__file__).parent / 'repos' self.repos_folder = Path(__file__).parent / 'repos'
self._repos = self._load_repos() # str_name: Repo self._repos = {}
loop = asyncio.get_event_loop()
loop.run_until_complete(self._load_repos(set=True)) # str_name: Repo
def does_repo_exist(self, name: str) -> bool: def does_repo_exist(self, name: str) -> bool:
return name in self._repos return name in self._repos
@ -494,7 +497,6 @@ class RepoManager:
shutil.rmtree(str(repo.folder_path)) shutil.rmtree(str(repo.folder_path))
repos = self.downloader_config.repos()
try: try:
del self._repos[name] del self._repos[name]
except KeyError: except KeyError:
@ -518,11 +520,14 @@ class RepoManager:
await self._save_repos() await self._save_repos()
return ret return ret
def _load_repos(self) -> MutableMapping[str, Repo]: async def _load_repos(self, set=False) -> MutableMapping[str, Repo]:
return { ret = {
name: Repo.from_json(data) for name, data in name: Repo.from_json(data) for name, data in
self.downloader_config.repos().items() (await self.downloader_config.repos()).items()
} }
if set:
self._repos = ret
return ret
async def _save_repos(self): async def _save_repos(self):
repo_json_info = {name: r.to_json() for name, r in self._repos.items()} repo_json_info = {name: r.to_json() for name, r in self._repos.items()}

View File

@ -72,9 +72,9 @@ SLOT_PAYOUTS_MSG = ("Slot machine payouts:\n"
def guild_only_check(): def guild_only_check():
async def pred(ctx: commands.Context): async def pred(ctx: commands.Context):
if bank.is_global(): if await bank.is_global():
return True return True
elif not bank.is_global() and ctx.guild is not None: elif not await bank.is_global() and ctx.guild is not None:
return True return True
else: else:
return False return False
@ -146,8 +146,8 @@ class Economy:
if user is None: if user is None:
user = ctx.author user = ctx.author
bal = bank.get_balance(user) bal = await bank.get_balance(user)
currency = bank.get_currency_name(ctx.guild) currency = await bank.get_currency_name(ctx.guild)
await ctx.send("{}'s balance is {} {}".format( await ctx.send("{}'s balance is {} {}".format(
user.display_name, bal, currency)) user.display_name, bal, currency))
@ -156,7 +156,7 @@ class Economy:
async def transfer(self, ctx: commands.Context, to: discord.Member, amount: int): async def transfer(self, ctx: commands.Context, to: discord.Member, amount: int):
"""Transfer currency to other users""" """Transfer currency to other users"""
from_ = ctx.author from_ = ctx.author
currency = bank.get_currency_name(ctx.guild) currency = await bank.get_currency_name(ctx.guild)
try: try:
await bank.transfer_credits(from_, to, amount) await bank.transfer_credits(from_, to, amount)
@ -206,12 +206,12 @@ class Economy:
await ctx.send( await ctx.send(
"This will delete all bank accounts for {}.\nIf you're sure, type " "This will delete all bank accounts for {}.\nIf you're sure, type "
"{}bank reset yes".format( "{}bank reset yes".format(
self.bot.user.name if bank.is_global() else "this guild", self.bot.user.name if await bank.is_global() else "this guild",
ctx.prefix ctx.prefix
) )
) )
else: else:
if bank.is_global(): if await bank.is_global():
# Bank being global means that the check would cause only # Bank being global means that the check would cause only
# the owner and any co-owners to be able to run the command # the owner and any co-owners to be able to run the command
# so if we're in the function, it's safe to assume that the # so if we're in the function, it's safe to assume that the
@ -232,18 +232,18 @@ class Economy:
guild = ctx.guild guild = ctx.guild
cur_time = calendar.timegm(ctx.message.created_at.utctimetuple()) cur_time = calendar.timegm(ctx.message.created_at.utctimetuple())
credits_name = bank.get_currency_name(ctx.guild) credits_name = await bank.get_currency_name(ctx.guild)
if bank.is_global(): if await bank.is_global():
next_payday = self.config.user(author).next_payday() next_payday = await self.config.user(author).next_payday()
if cur_time >= next_payday: if cur_time >= next_payday:
await bank.deposit_credits(author, self.config.PAYDAY_CREDITS()) await bank.deposit_credits(author, await self.config.PAYDAY_CREDITS())
next_payday = cur_time + self.config.PAYDAY_TIME() next_payday = cur_time + await self.config.PAYDAY_TIME()
await self.config.user(author).next_payday.set(next_payday) await self.config.user(author).next_payday.set(next_payday)
await ctx.send( await ctx.send(
"{} Here, take some {}. Enjoy! (+{}" "{} Here, take some {}. Enjoy! (+{}"
" {}!)".format( " {}!)".format(
author.mention, credits_name, author.mention, credits_name,
str(self.config.PAYDAY_CREDITS()), str(await self.config.PAYDAY_CREDITS()),
credits_name credits_name
) )
) )
@ -254,16 +254,16 @@ class Economy:
" wait {}.".format(author.mention, dtime) " wait {}.".format(author.mention, dtime)
) )
else: else:
next_payday = self.config.member(author).next_payday() next_payday = await self.config.member(author).next_payday()
if cur_time >= next_payday: if cur_time >= next_payday:
await bank.deposit_credits(author, self.config.guild(guild).PAYDAY_CREDITS()) await bank.deposit_credits(author, await self.config.guild(guild).PAYDAY_CREDITS())
next_payday = cur_time + self.config.guild(guild).PAYDAY_TIME() next_payday = cur_time + await self.config.guild(guild).PAYDAY_TIME()
await self.config.member(author).next_payday.set(next_payday) await self.config.member(author).next_payday.set(next_payday)
await ctx.send( await ctx.send(
"{} Here, take some {}. Enjoy! (+{}" "{} Here, take some {}. Enjoy! (+{}"
" {}!)".format( " {}!)".format(
author.mention, credits_name, author.mention, credits_name,
str(self.config.guild(guild).PAYDAY_CREDITS()), str(await self.config.guild(guild).PAYDAY_CREDITS()),
credits_name)) credits_name))
else: else:
dtime = self.display_time(next_payday - cur_time) dtime = self.display_time(next_payday - cur_time)
@ -282,10 +282,10 @@ class Economy:
if top < 1: if top < 1:
top = 10 top = 10
if bank.is_global(): if bank.is_global():
bank_sorted = sorted(bank.get_global_accounts(ctx.author), bank_sorted = sorted(await bank.get_global_accounts(ctx.author),
key=lambda x: x.balance, reverse=True) key=lambda x: x.balance, reverse=True)
else: else:
bank_sorted = sorted(bank.get_guild_accounts(guild), bank_sorted = sorted(await bank.get_guild_accounts(guild),
key=lambda x: x.balance, reverse=True) key=lambda x: x.balance, reverse=True)
if len(bank_sorted) < top: if len(bank_sorted) < top:
top = len(bank_sorted) top = len(bank_sorted)
@ -320,14 +320,14 @@ class Economy:
author = ctx.author author = ctx.author
guild = ctx.guild guild = ctx.guild
channel = ctx.channel channel = ctx.channel
if bank.is_global(): if await bank.is_global():
valid_bid = self.config.SLOT_MIN() <= bid <= self.config.SLOT_MAX() valid_bid = await self.config.SLOT_MIN() <= bid <= await self.config.SLOT_MAX()
slot_time = self.config.SLOT_TIME() slot_time = await self.config.SLOT_TIME()
last_slot = self.config.user(author).last_slot() last_slot = await self.config.user(author).last_slot()
else: else:
valid_bid = self.config.guild(guild).SLOT_MIN() <= bid <= self.config.guild(guild).SLOT_MAX() valid_bid = await self.config.guild(guild).SLOT_MIN() <= bid <= await self.config.guild(guild).SLOT_MAX()
slot_time = self.config.guild(guild).SLOT_TIME() slot_time = await self.config.guild(guild).SLOT_TIME()
last_slot = self.config.member(author).last_slot() last_slot = await self.config.member(author).last_slot()
now = calendar.timegm(ctx.message.created_at.utctimetuple()) now = calendar.timegm(ctx.message.created_at.utctimetuple())
if (now - last_slot) < slot_time: if (now - last_slot) < slot_time:
@ -336,10 +336,10 @@ class Economy:
if not valid_bid: if not valid_bid:
await ctx.send("That's an invalid bid amount, sorry :/") await ctx.send("That's an invalid bid amount, sorry :/")
return return
if not bank.can_spend(author, bid): if not await bank.can_spend(author, bid):
await ctx.send("You ain't got enough money, friend.") await ctx.send("You ain't got enough money, friend.")
return return
if bank.is_global(): if await bank.is_global():
await self.config.user(author).last_slot.set(now) await self.config.user(author).last_slot.set(now)
else: else:
await self.config.member(author).last_slot.set(now) await self.config.member(author).last_slot.set(now)
@ -379,7 +379,7 @@ class Economy:
payout = PAYOUTS["2 symbols"] payout = PAYOUTS["2 symbols"]
if payout: if payout:
then = bank.get_balance(author) then = await bank.get_balance(author)
pay = payout["payout"](bid) pay = payout["payout"](bid)
now = then - bid + pay now = then - bid + pay
await bank.set_balance(author, now) await bank.set_balance(author, now)
@ -387,7 +387,7 @@ class Economy:
"".format(slot, author.mention, "".format(slot, author.mention,
payout["phrase"], bid, then, now)) payout["phrase"], bid, then, now))
else: else:
then = bank.get_balance(author) then = await bank.get_balance(author)
await bank.withdraw_credits(author, bid) await bank.withdraw_credits(author, bid)
now = then - bid now = then - bid
await channel.send("{}\n{} Nothing!\nYour bid: {}\n{}{}!" await channel.send("{}\n{} Nothing!\nYour bid: {}\n{}{}!"
@ -402,18 +402,18 @@ class Economy:
if ctx.invoked_subcommand is None: if ctx.invoked_subcommand is None:
await self.bot.send_cmd_help(ctx) await self.bot.send_cmd_help(ctx)
if bank.is_global(): if bank.is_global():
slot_min = self.config.SLOT_MIN() slot_min = await self.config.SLOT_MIN()
slot_max = self.config.SLOT_MAX() slot_max = await self.config.SLOT_MAX()
slot_time = self.config.SLOT_TIME() slot_time = await self.config.SLOT_TIME()
payday_time = self.config.PAYDAY_TIME() payday_time = await self.config.PAYDAY_TIME()
payday_amount = self.config.PAYDAY_CREDITS() payday_amount = await self.config.PAYDAY_CREDITS()
else: else:
slot_min = self.config.guild(guild).SLOT_MIN() slot_min = await self.config.guild(guild).SLOT_MIN()
slot_max = self.config.guild(guild).SLOT_MAX() slot_max = await self.config.guild(guild).SLOT_MAX()
slot_time = self.config.guild(guild).SLOT_TIME() slot_time = await self.config.guild(guild).SLOT_TIME()
payday_time = self.config.guild(guild).PAYDAY_TIME() payday_time = await self.config.guild(guild).PAYDAY_TIME()
payday_amount = self.config.guild(guild).PAYDAY_CREDITS() payday_amount = await self.config.guild(guild).PAYDAY_CREDITS()
register_amount = bank.get_default_balance(guild) register_amount = await bank.get_default_balance(guild)
msg = box( msg = box(
"Minimum slot bid: {}\n" "Minimum slot bid: {}\n"
"Maximum slot bid: {}\n" "Maximum slot bid: {}\n"
@ -436,24 +436,24 @@ class Economy:
await ctx.send('Invalid min bid amount.') await ctx.send('Invalid min bid amount.')
return return
guild = ctx.guild guild = ctx.guild
if bank.is_global(): if await bank.is_global():
await self.config.SLOT_MIN.set(bid) await self.config.SLOT_MIN.set(bid)
else: else:
await self.config.guild(guild).SLOT_MIN.set(bid) await self.config.guild(guild).SLOT_MIN.set(bid)
credits_name = bank.get_currency_name(guild) credits_name = await bank.get_currency_name(guild)
await ctx.send("Minimum bid is now {} {}.".format(bid, credits_name)) await ctx.send("Minimum bid is now {} {}.".format(bid, credits_name))
@economyset.command() @economyset.command()
async def slotmax(self, ctx: commands.Context, bid: int): async def slotmax(self, ctx: commands.Context, bid: int):
"""Maximum slot machine bid""" """Maximum slot machine bid"""
slot_min = self.config.SLOT_MIN() slot_min = await self.config.SLOT_MIN()
if bid < 1 or bid < slot_min: if bid < 1 or bid < slot_min:
await ctx.send('Invalid slotmax bid amount. Must be greater' await ctx.send('Invalid slotmax bid amount. Must be greater'
' than slotmin.') ' than slotmin.')
return return
guild = ctx.guild guild = ctx.guild
credits_name = bank.get_currency_name(guild) credits_name = await bank.get_currency_name(guild)
if bank.is_global(): if await bank.is_global():
await self.config.SLOT_MAX.set(bid) await self.config.SLOT_MAX.set(bid)
else: else:
await self.config.guild(guild).SLOT_MAX.set(bid) await self.config.guild(guild).SLOT_MAX.set(bid)
@ -463,7 +463,7 @@ class Economy:
async def slottime(self, ctx: commands.Context, seconds: int): async def slottime(self, ctx: commands.Context, seconds: int):
"""Seconds between each slots use""" """Seconds between each slots use"""
guild = ctx.guild guild = ctx.guild
if bank.is_global(): if await bank.is_global():
await self.config.SLOT_TIME.set(seconds) await self.config.SLOT_TIME.set(seconds)
else: else:
await self.config.guild(guild).SLOT_TIME.set(seconds) await self.config.guild(guild).SLOT_TIME.set(seconds)
@ -473,7 +473,7 @@ class Economy:
async def paydaytime(self, ctx: commands.Context, seconds: int): async def paydaytime(self, ctx: commands.Context, seconds: int):
"""Seconds between each payday""" """Seconds between each payday"""
guild = ctx.guild guild = ctx.guild
if bank.is_global(): if await bank.is_global():
await self.config.PAYDAY_TIME.set(seconds) await self.config.PAYDAY_TIME.set(seconds)
else: else:
await self.config.guild(guild).PAYDAY_TIME.set(seconds) await self.config.guild(guild).PAYDAY_TIME.set(seconds)
@ -484,11 +484,11 @@ class Economy:
async def paydayamount(self, ctx: commands.Context, creds: int): async def paydayamount(self, ctx: commands.Context, creds: int):
"""Amount earned each payday""" """Amount earned each payday"""
guild = ctx.guild guild = ctx.guild
credits_name = bank.get_currency_name(guild) credits_name = await bank.get_currency_name(guild)
if creds <= 0: if creds <= 0:
await ctx.send("Har har so funny.") await ctx.send("Har har so funny.")
return return
if bank.is_global(): if await bank.is_global():
await self.config.PAYDAY_CREDITS.set(creds) await self.config.PAYDAY_CREDITS.set(creds)
else: else:
await self.config.guild(guild).PAYDAY_CREDITS.set(creds) await self.config.guild(guild).PAYDAY_CREDITS.set(creds)
@ -501,7 +501,7 @@ class Economy:
guild = ctx.guild guild = ctx.guild
if creds < 0: if creds < 0:
creds = 0 creds = 0
credits_name = bank.get_currency_name(guild) credits_name = await bank.get_currency_name(guild)
await bank.set_default_balance(creds, guild) await bank.set_default_balance(creds, guild)
await ctx.send("Registering an account will now give {} {}." await ctx.send("Registering an account will now give {} {}."
"".format(creds, credits_name)) "".format(creds, credits_name))

View File

@ -1,6 +1,6 @@
import datetime import datetime
from collections import namedtuple from collections import namedtuple
from typing import Tuple, Generator, Union from typing import Tuple, Generator, Union, List
import discord import discord
from copy import deepcopy from copy import deepcopy
@ -78,17 +78,17 @@ def _decode_time(time: int) -> datetime.datetime:
return datetime.datetime.utcfromtimestamp(time) return datetime.datetime.utcfromtimestamp(time)
def get_balance(member: discord.Member) -> int: async def get_balance(member: discord.Member) -> int:
""" """
Gets the current balance of a member. Gets the current balance of a member.
:param member: :param member:
:return: :return:
""" """
acc = get_account(member) acc = await get_account(member)
return acc.balance return acc.balance
def can_spend(member: discord.Member, amount: int) -> bool: async def can_spend(member: discord.Member, amount: int) -> bool:
""" """
Determines if a member can spend the given amount. Determines if a member can spend the given amount.
:param member: :param member:
@ -97,7 +97,7 @@ def can_spend(member: discord.Member, amount: int) -> bool:
""" """
if _invalid_amount(amount): if _invalid_amount(amount):
return False return False
return get_balance(member) > amount return await get_balance(member) > amount
async def set_balance(member: discord.Member, amount: int) -> int: async def set_balance(member: discord.Member, amount: int) -> int:
@ -111,17 +111,17 @@ async def set_balance(member: discord.Member, amount: int) -> int:
""" """
if amount < 0: if amount < 0:
raise ValueError("Not allowed to have negative balance.") raise ValueError("Not allowed to have negative balance.")
if is_global(): if await is_global():
group = _conf.user(member) group = _conf.user(member)
else: else:
group = _conf.member(member) group = _conf.member(member)
await group.balance.set(amount) await group.balance.set(amount)
if group.created_at() == 0: if await group.created_at() == 0:
time = _encoded_current_time() time = _encoded_current_time()
await group.created_at.set(time) await group.created_at.set(time)
if group.name() == "": if await group.name() == "":
await group.name.set(member.display_name) await group.name.set(member.display_name)
return amount return amount
@ -144,7 +144,7 @@ async def withdraw_credits(member: discord.Member, amount: int) -> int:
if _invalid_amount(amount): if _invalid_amount(amount):
raise ValueError("Invalid withdrawal amount {} <= 0".format(amount)) raise ValueError("Invalid withdrawal amount {} <= 0".format(amount))
bal = get_balance(member) bal = await get_balance(member)
if amount > bal: if amount > bal:
raise ValueError("Insufficient funds {} > {}".format(amount, bal)) raise ValueError("Insufficient funds {} > {}".format(amount, bal))
@ -163,7 +163,7 @@ async def deposit_credits(member: discord.Member, amount: int) -> int:
if _invalid_amount(amount): if _invalid_amount(amount):
raise ValueError("Invalid withdrawal amount {} <= 0".format(amount)) raise ValueError("Invalid withdrawal amount {} <= 0".format(amount))
bal = get_balance(member) bal = await get_balance(member)
return await set_balance(member, amount + bal) return await set_balance(member, amount + bal)
@ -190,13 +190,13 @@ async def wipe_bank(user: Union[discord.User, discord.Member]):
Deletes all accounts from the bank. Deletes all accounts from the bank.
:return: :return:
""" """
if is_global(): if await is_global():
await _conf.user(user).clear() await _conf.user(user).clear()
else: else:
await _conf.member(user).clear() await _conf.member(user).clear()
def get_guild_accounts(guild: discord.Guild) -> Generator[Account, None, None]: async def get_guild_accounts(guild: discord.Guild) -> List[Account]:
""" """
Gets all account data for the given guild. Gets all account data for the given guild.
@ -207,14 +207,16 @@ def get_guild_accounts(guild: discord.Guild) -> Generator[Account, None, None]:
if is_global(): if is_global():
raise RuntimeError("The bank is currently global.") raise RuntimeError("The bank is currently global.")
accs = _conf.member(guild.owner).all_from_kind() ret = []
accs = await _conf.member(guild.owner).all_from_kind()
for user_id, acc in accs.items(): for user_id, acc in accs.items():
acc_data = acc.copy() # There ya go kowlin acc_data = acc.copy() # There ya go kowlin
acc_data['created_at'] = _decode_time(acc_data['created_at']) acc_data['created_at'] = _decode_time(acc_data['created_at'])
yield Account(**acc_data) ret.append(Account(**acc_data))
return ret
def get_global_accounts(user: discord.User) -> Generator[Account, None, None]: async def get_global_accounts(user: discord.User) -> List[Account]:
""" """
Gets all global account data. Gets all global account data.
@ -225,44 +227,47 @@ def get_global_accounts(user: discord.User) -> Generator[Account, None, None]:
if not is_global(): if not is_global():
raise RuntimeError("The bank is not currently global.") raise RuntimeError("The bank is not currently global.")
accs = _conf.user(user).all_from_kind() # this is a dict of user -> acc ret = []
accs = await _conf.user(user).all_from_kind() # this is a dict of user -> acc
for user_id, acc in accs.items(): for user_id, acc in accs.items():
acc_data = acc.copy() acc_data = acc.copy()
acc_data['created_at'] = _decode_time(acc_data['created_at']) acc_data['created_at'] = _decode_time(acc_data['created_at'])
yield Account(**acc_data) ret.append(Account(**acc_data))
return ret
def get_account(member: Union[discord.Member, discord.User]) -> Account: async def get_account(member: Union[discord.Member, discord.User]) -> Account:
""" """
Gets the appropriate account for the given member. Gets the appropriate account for the given member.
:param member: :param member:
:return: :return:
""" """
if is_global(): if await is_global():
acc_data = _conf.user(member)().copy() acc_data = (await _conf.user(member)()).copy()
default = _DEFAULT_USER.copy() default = _DEFAULT_USER.copy()
else: else:
acc_data = _conf.member(member)().copy() acc_data = (await _conf.member(member)()).copy()
default = _DEFAULT_MEMBER.copy() default = _DEFAULT_MEMBER.copy()
if acc_data == {}: if acc_data == {}:
acc_data = default acc_data = default
acc_data['name'] = member.display_name acc_data['name'] = member.display_name
try: try:
acc_data['balance'] = get_default_balance(member.guild) acc_data['balance'] = await get_default_balance(member.guild)
except AttributeError: except AttributeError:
acc_data['balance'] = get_default_balance() acc_data['balance'] = await get_default_balance()
acc_data['created_at'] = _decode_time(acc_data['created_at']) acc_data['created_at'] = _decode_time(acc_data['created_at'])
return Account(**acc_data) return Account(**acc_data)
def is_global() -> bool: async def is_global() -> bool:
""" """
Determines if the bank is currently global. Determines if the bank is currently global.
:return: :return:
""" """
return _conf.is_global() return await _conf.is_global()
async def set_global(global_: bool, user: Union[discord.User, discord.Member]) -> bool: async def set_global(global_: bool, user: Union[discord.User, discord.Member]) -> bool:
@ -272,7 +277,7 @@ async def set_global(global_: bool, user: Union[discord.User, discord.Member]) -
:param user: Must be a Member object if changing TO global mode. :param user: Must be a Member object if changing TO global mode.
:return: New bank mode, True is global. :return: New bank mode, True is global.
""" """
if is_global() is global_: if (await is_global()) is global_:
return global_ return global_
if is_global(): if is_global():
@ -287,7 +292,7 @@ async def set_global(global_: bool, user: Union[discord.User, discord.Member]) -
return global_ return global_
def get_bank_name(guild: discord.Guild=None) -> str: async def get_bank_name(guild: discord.Guild=None) -> str:
""" """
Gets the current bank name. If the bank is guild-specific the Gets the current bank name. If the bank is guild-specific the
guild parameter is required. guild parameter is required.
@ -296,10 +301,10 @@ def get_bank_name(guild: discord.Guild=None) -> str:
:param guild: :param guild:
:return: :return:
""" """
if is_global(): if await is_global():
return _conf.bank_name() return await _conf.bank_name()
elif guild is not None: elif guild is not None:
return _conf.guild(guild).bank_name() return await _conf.guild(guild).bank_name()
else: else:
raise RuntimeError("Guild parameter is required and missing.") raise RuntimeError("Guild parameter is required and missing.")
@ -314,7 +319,7 @@ async def set_bank_name(name: str, guild: discord.Guild=None) -> str:
:param guild: :param guild:
:return: :return:
""" """
if is_global(): if await is_global():
await _conf.bank_name.set(name) await _conf.bank_name.set(name)
elif guild is not None: elif guild is not None:
await _conf.guild(guild).bank_name.set(name) await _conf.guild(guild).bank_name.set(name)
@ -324,7 +329,7 @@ async def set_bank_name(name: str, guild: discord.Guild=None) -> str:
return name return name
def get_currency_name(guild: discord.Guild=None) -> str: async def get_currency_name(guild: discord.Guild=None) -> str:
""" """
Gets the currency name of the bank. The guild parameter is required if Gets the currency name of the bank. The guild parameter is required if
the bank is guild-specific. the bank is guild-specific.
@ -333,10 +338,10 @@ def get_currency_name(guild: discord.Guild=None) -> str:
:param guild: :param guild:
:return: :return:
""" """
if is_global(): if await is_global():
return _conf.currency() return await _conf.currency()
elif guild is not None: elif guild is not None:
return _conf.guild(guild).currency() return await _conf.guild(guild).currency()
else: else:
raise RuntimeError("Guild must be provided.") raise RuntimeError("Guild must be provided.")
@ -351,7 +356,7 @@ async def set_currency_name(name: str, guild: discord.Guild=None) -> str:
:param guild: :param guild:
:return: :return:
""" """
if is_global(): if await is_global():
await _conf.currency.set(name) await _conf.currency.set(name)
elif guild is not None: elif guild is not None:
await _conf.guild(guild).currency.set(name) await _conf.guild(guild).currency.set(name)
@ -361,7 +366,7 @@ async def set_currency_name(name: str, guild: discord.Guild=None) -> str:
return name return name
def get_default_balance(guild: discord.Guild=None) -> int: async def get_default_balance(guild: discord.Guild=None) -> int:
""" """
Gets the current default balance amount. If the bank is guild-specific Gets the current default balance amount. If the bank is guild-specific
you must pass guild. you must pass guild.
@ -370,10 +375,10 @@ def get_default_balance(guild: discord.Guild=None) -> int:
:param guild: :param guild:
:return: :return:
""" """
if is_global(): if await is_global():
return _conf.default_balance() return await _conf.default_balance()
elif guild is not None: elif guild is not None:
return _conf.guild(guild).default_balance() return await _conf.guild(guild).default_balance()
else: else:
raise RuntimeError("Guild is missing and required!") raise RuntimeError("Guild is missing and required!")
@ -393,9 +398,11 @@ async def set_default_balance(amount: int, guild: discord.Guild=None) -> int:
if amount < 0: if amount < 0:
raise ValueError("Amount must be greater than zero.") raise ValueError("Amount must be greater than zero.")
if is_global(): if await is_global():
await _conf.default_balance.set(amount) await _conf.default_balance.set(amount)
elif guild is not None: elif guild is not None:
await _conf.guild(guild).default_balance.set(amount) await _conf.guild(guild).default_balance.set(amount)
else: else:
raise RuntimeError("Guild is missing and required.") raise RuntimeError("Guild is missing and required.")
return amount

View File

@ -1,3 +1,4 @@
import asyncio
import importlib.util import importlib.util
from importlib.machinery import ModuleSpec from importlib.machinery import ModuleSpec
@ -39,14 +40,14 @@ class Red(commands.Bot):
mod_role=None mod_role=None
) )
def prefix_manager(bot, message): async def prefix_manager(bot, message):
if not cli_flags.prefix: if not cli_flags.prefix:
global_prefix = self.db.prefix() global_prefix = await bot.db.prefix()
else: else:
global_prefix = cli_flags.prefix global_prefix = cli_flags.prefix
if message.guild is None: if message.guild is None:
return global_prefix return global_prefix
server_prefix = self.db.guild(message.guild).prefix() server_prefix = await bot.db.guild(message.guild).prefix()
return server_prefix if server_prefix else global_prefix return server_prefix if server_prefix else global_prefix
if "command_prefix" not in kwargs: if "command_prefix" not in kwargs:
@ -56,7 +57,8 @@ class Red(commands.Bot):
kwargs["owner_id"] = cli_flags.owner kwargs["owner_id"] = cli_flags.owner
if "owner_id" not in kwargs: if "owner_id" not in kwargs:
kwargs["owner_id"] = self.db.owner() loop = asyncio.get_event_loop()
loop.run_until_complete(self._dict_abuse(kwargs))
self.counter = Counter() self.counter = Counter()
self.uptime = None self.uptime = None
@ -68,6 +70,15 @@ class Red(commands.Bot):
super().__init__(**kwargs) super().__init__(**kwargs)
async def _dict_abuse(self, indict):
"""
Please blame <@269933075037814786> for this.
:param indict:
:return:
"""
indict['owner_id'] = await self.db.owner()
async def is_owner(self, user): async def is_owner(self, user):
if user.id in self._co_owners: if user.id in self._co_owners:
return True return True
@ -103,13 +114,13 @@ class Red(commands.Bot):
await self.db.packages.set(packages) await self.db.packages.set(packages)
async def add_loaded_package(self, pkg_name: str): async def add_loaded_package(self, pkg_name: str):
curr_pkgs = self.db.packages() curr_pkgs = await self.db.packages()
if pkg_name not in curr_pkgs: if pkg_name not in curr_pkgs:
curr_pkgs.append(pkg_name) curr_pkgs.append(pkg_name)
await self.save_packages_status(curr_pkgs) await self.save_packages_status(curr_pkgs)
async def remove_loaded_package(self, pkg_name: str): async def remove_loaded_package(self, pkg_name: str):
curr_pkgs = self.db.packages() curr_pkgs = await self.db.packages()
if pkg_name in curr_pkgs: if pkg_name in curr_pkgs:
await self.save_packages_status([p for p in curr_pkgs if p != pkg_name]) await self.save_packages_status([p for p in curr_pkgs if p != pkg_name])

View File

@ -31,26 +31,29 @@ class CogManager:
install_path=str(bot_dir.resolve() / "cogs") install_path=str(bot_dir.resolve() / "cogs")
) )
self._paths = set(list(self.conf.paths()) + list(paths)) self._paths = list(paths)
@property async def paths(self) -> Tuple[Path, ...]:
def paths(self) -> Tuple[Path, ...]:
""" """
This will return all currently valid path directories. This will return all currently valid path directories.
:return: :return:
""" """
paths = [Path(p) for p in self._paths] conf_paths = await self.conf.paths()
other_paths = self._paths
all_paths = set(list(conf_paths) + list(other_paths))
paths = [Path(p) for p in all_paths]
if self.install_path not in paths: if self.install_path not in paths:
paths.insert(0, self.install_path) paths.insert(0, await self.install_path())
return tuple(p.resolve() for p in paths if p.is_dir()) return tuple(p.resolve() for p in paths if p.is_dir())
@property async def install_path(self) -> Path:
def install_path(self) -> Path:
""" """
Returns the install path for 3rd party cogs. Returns the install path for 3rd party cogs.
:return: :return:
""" """
p = Path(self.conf.install_path()) p = Path(await self.conf.install_path())
return p.resolve() return p.resolve()
async def set_install_path(self, path: Path) -> Path: async def set_install_path(self, path: Path) -> Path:
@ -99,10 +102,10 @@ class CogManager:
if not path.is_dir(): if not path.is_dir():
raise InvalidPath("'{}' is not a valid directory.".format(path)) raise InvalidPath("'{}' is not a valid directory.".format(path))
if path == self.install_path: if path == await self.install_path():
raise ValueError("Cannot add the install path as an additional path.") raise ValueError("Cannot add the install path as an additional path.")
all_paths = set(self.paths + (path, )) all_paths = set(await self.paths() + (path, ))
# noinspection PyTypeChecker # noinspection PyTypeChecker
await self.set_paths(all_paths) await self.set_paths(all_paths)
@ -113,7 +116,7 @@ class CogManager:
:return: :return:
""" """
path = self._ensure_path_obj(path) path = self._ensure_path_obj(path)
all_paths = list(self.paths) all_paths = list(await self.paths())
if path in all_paths: if path in all_paths:
all_paths.remove(path) # Modifies in place all_paths.remove(path) # Modifies in place
await self.set_paths(all_paths) await self.set_paths(all_paths)
@ -125,11 +128,10 @@ class CogManager:
:param paths_: :param paths_:
:return: :return:
""" """
self._paths = paths_
str_paths = [str(p) for p in paths_] str_paths = [str(p) for p in paths_]
await self.conf.paths.set(str_paths) await self.conf.paths.set(str_paths)
def find_cog(self, name: str) -> ModuleSpec: async def find_cog(self, name: str) -> ModuleSpec:
""" """
Finds a cog in the list of available path. Finds a cog in the list of available path.
@ -137,7 +139,7 @@ class CogManager:
:param name: :param name:
:return: :return:
""" """
resolved_paths = [str(p.resolve()) for p in self.paths] resolved_paths = [str(p.resolve()) for p in await self.paths()]
for finder, module_name, _ in pkgutil.iter_modules(resolved_paths): for finder, module_name, _ in pkgutil.iter_modules(resolved_paths):
if name == module_name: if name == module_name:
spec = finder.find_spec(name) spec = finder.find_spec(name)
@ -166,7 +168,7 @@ class CogManagerUI:
""" """
Lists current cog paths in order of priority. Lists current cog paths in order of priority.
""" """
install_path = ctx.bot.cog_mgr.install_path install_path = await ctx.bot.cog_mgr.install_path()
cog_paths = ctx.bot.cog_mgr.paths cog_paths = ctx.bot.cog_mgr.paths
cog_paths = [p for p in cog_paths if p != install_path] cog_paths = [p for p in cog_paths if p != install_path]
@ -204,7 +206,7 @@ class CogManagerUI:
Removes a path from the available cog paths given the path_number Removes a path from the available cog paths given the path_number
from !paths from !paths
""" """
cog_paths = ctx.bot.cog_mgr.paths cog_paths = await ctx.bot.cog_mgr.paths()
try: try:
to_remove = cog_paths[path_number] to_remove = cog_paths[path_number]
except IndexError: except IndexError:
@ -224,7 +226,7 @@ class CogManagerUI:
from_ -= 1 from_ -= 1
to -= 1 to -= 1
all_paths = list(ctx.bot.cog_mgr.paths) all_paths = list(await ctx.bot.cog_mgr.paths())
try: try:
to_move = all_paths.pop(from_) to_move = all_paths.pop(from_)
except IndexError: except IndexError:
@ -257,6 +259,6 @@ class CogManagerUI:
await ctx.send("That path does not exist.") await ctx.send("That path does not exist.")
return return
install_path = ctx.bot.cog_mgr.install_path install_path = await ctx.bot.cog_mgr.install_path()
await ctx.send("The bot will install new cogs to the `{}`" await ctx.send("The bot will install new cogs to the `{}`"
" directory.".format(install_path)) " directory.".format(install_path))

View File

@ -38,6 +38,14 @@ class Value:
def identifiers(self): def identifiers(self):
return tuple(str(i) for i in self._identifiers) return tuple(str(i) for i in self._identifiers)
async def _get(self, default):
driver = self.spawner.get_driver()
try:
ret = await driver.get(self.identifiers)
except KeyError:
return default or self.default
return ret
def __call__(self, default=None): def __call__(self, default=None):
""" """
Each :py:class:`Value` object is created by the :py:meth:`Group.__getattr__` method. Each :py:class:`Value` object is created by the :py:meth:`Group.__getattr__` method.
@ -46,25 +54,26 @@ class Value:
For example:: For example::
foo = conf.guild(some_guild).foo() foo = await conf.guild(some_guild).foo()
# Is equivalent to this # Is equivalent to this
group_obj = conf.guild(some_guild) group_obj = conf.guild(some_guild)
value_obj = conf.foo value_obj = conf.foo
foo = value_obj() foo = await value_obj()
.. important::
This is now, for all intents and purposes, a coroutine.
:param default: :param default:
This argument acts as an override for the registered default provided by :py:attr:`default`. This argument This argument acts as an override for the registered default provided by :py:attr:`default`. This argument
is ignored if its value is :python:`None`. is ignored if its value is :python:`None`.
:type default: Optional[object] :type default: Optional[object]
:return:
A coroutine object that must be awaited.
""" """
driver = self.spawner.get_driver() return self._get(default)
try:
ret = driver.get(self.identifiers)
except KeyError:
return default or self.default
return ret
async def set(self, value): async def set(self, value):
""" """
@ -182,7 +191,7 @@ class Group(Value):
return not isinstance(default, dict) return not isinstance(default, dict)
def get_attr(self, item: str, default=None, resolve=True): async def get_attr(self, item: str, default=None, resolve=True):
""" """
This is available to use as an alternative to using normal Python attribute access. It is required if you find This is available to use as an alternative to using normal Python attribute access. It is required if you find
a need for dynamic attribute access. a need for dynamic attribute access.
@ -198,7 +207,7 @@ class Group(Value):
user = ctx.author user = ctx.author
# Where the value of item is the name of the data field in Config # Where the value of item is the name of the data field in Config
await ctx.send(self.conf.user(user).get_attr(item)) await ctx.send(await self.conf.user(user).get_attr(item))
:param str item: :param str item:
The name of the data field in :py:class:`.Config`. The name of the data field in :py:class:`.Config`.
@ -211,20 +220,20 @@ class Group(Value):
""" """
value = getattr(self, item) value = getattr(self, item)
if resolve: if resolve:
return value(default=default) return await value(default=default)
else: else:
return value return value
def all(self) -> dict: async def all(self) -> dict:
""" """
This method allows you to get "all" of a particular group of data. It will return the dictionary of all data This method allows you to get "all" of a particular group of data. It will return the dictionary of all data
for a particular Guild/Channel/Role/User/Member etc. for a particular Guild/Channel/Role/User/Member etc.
:rtype: dict :rtype: dict
""" """
return self() return await self()
def all_from_kind(self) -> dict: async def all_from_kind(self) -> dict:
""" """
This method allows you to get all data from all entries in a given Kind. It will return a dictionary of Kind This method allows you to get all data from all entries in a given Kind. It will return a dictionary of Kind
ID's -> data. ID's -> data.
@ -232,7 +241,7 @@ class Group(Value):
:rtype: dict :rtype: dict
""" """
# noinspection PyTypeChecker # noinspection PyTypeChecker
return self._super_group() return await self._super_group()
async def set(self, value): async def set(self, value):
if not isinstance(value, dict): if not isinstance(value, dict):
@ -292,18 +301,18 @@ class MemberGroup(Group):
) )
return group_obj return group_obj
def all_guilds(self) -> dict: async def all_guilds(self) -> dict:
""" """
Returns a dict of :code:`GUILD_ID -> MEMBER_ID -> data`. Returns a dict of :code:`GUILD_ID -> MEMBER_ID -> data`.
:rtype: dict :rtype: dict
""" """
# noinspection PyTypeChecker # noinspection PyTypeChecker
return self._super_group() return await self._super_group()
def all(self) -> dict: async def all(self) -> dict:
# noinspection PyTypeChecker # noinspection PyTypeChecker
return self._guild_group() return await self._guild_group()
class Config: class Config:
@ -315,7 +324,7 @@ class Config:
however the process for accessing global data is a bit different. There is no :python:`global` method however the process for accessing global data is a bit different. There is no :python:`global` method
because global data is accessed by normal attribute access:: because global data is accessed by normal attribute access::
conf.foo() await conf.foo()
.. py:attribute:: cog_name .. py:attribute:: cog_name

View File

@ -29,7 +29,7 @@ class Core:
async def load(self, ctx, *, cog_name: str): async def load(self, ctx, *, cog_name: str):
"""Loads a package""" """Loads a package"""
try: try:
spec = ctx.bot.cog_mgr.find_cog(cog_name) spec = await ctx.bot.cog_mgr.find_cog(cog_name)
except NoModuleFound: except NoModuleFound:
await ctx.send("No module by that name was found in any" await ctx.send("No module by that name was found in any"
" cog path.") " cog path.")
@ -63,7 +63,7 @@ class Core:
ctx.bot.unload_extension(cog_name) ctx.bot.unload_extension(cog_name)
self.cleanup_and_refresh_modules(cog_name) self.cleanup_and_refresh_modules(cog_name)
try: try:
spec = ctx.bot.cog_mgr.find_cog(cog_name) spec = await ctx.bot.cog_mgr.find_cog(cog_name)
ctx.bot.load_extension(spec) ctx.bot.load_extension(spec)
except Exception as e: except Exception as e:
log.exception("Package reloading failed", exc_info=e) log.exception("Package reloading failed", exc_info=e)

View File

@ -5,7 +5,7 @@ class BaseDriver:
def get_driver(self): def get_driver(self):
raise NotImplementedError raise NotImplementedError
def get(self, identifiers: Tuple[str]): async def get(self, identifiers: Tuple[str]):
raise NotImplementedError raise NotImplementedError
async def set(self, identifiers: Tuple[str], value): async def set(self, identifiers: Tuple[str], value):

View File

@ -32,7 +32,7 @@ class JSON(BaseDriver):
def get_driver(self): def get_driver(self):
return self return self
def get(self, identifiers: Tuple[str]): async def get(self, identifiers: Tuple[str]):
partial = self.data partial = self.data
for i in identifiers: for i in identifiers:
partial = partial[i] partial = partial[i]

View File

@ -34,11 +34,11 @@ def init_events(bot, cli_flags):
if cli_flags.no_cogs is False: if cli_flags.no_cogs is False:
print("Loading packages...") print("Loading packages...")
failed = [] failed = []
packages = bot.db.packages() packages = await bot.db.packages()
for package in packages: for package in packages:
try: try:
spec = bot.cog_mgr.find_cog(package) spec = await bot.cog_mgr.find_cog(package)
bot.load_extension(spec) bot.load_extension(spec)
except Exception as e: except Exception as e:
log.exception("Failed to load package {}".format(package), log.exception("Failed to load package {}".format(package),

27
main.py
View File

@ -73,6 +73,17 @@ def determine_main_folder() -> Path:
return Path(os.path.dirname(__file__)).resolve() return Path(os.path.dirname(__file__)).resolve()
async def _get_prefix_and_token(red, indict):
"""
Again, please blame <@269933075037814786> for this.
:param indict:
:return:
"""
indict['token'] = await red.db.token()
indict['prefix'] = await red.db.prefix()
indict['enable_sentry'] = await red.db.enable_sentry()
if __name__ == '__main__': if __name__ == '__main__':
cli_flags = parse_cli_flags() cli_flags = parse_cli_flags()
log, sentry_log = init_loggers(cli_flags) log, sentry_log = init_loggers(cli_flags)
@ -89,8 +100,13 @@ if __name__ == '__main__':
if cli_flags.dev: if cli_flags.dev:
red.add_cog(Dev()) red.add_cog(Dev())
token = os.environ.get("RED_TOKEN", red.db.token()) loop = asyncio.get_event_loop()
prefix = cli_flags.prefix or red.db.prefix() tmp_data = {}
loop.run_until_complete(_get_prefix_and_token(red, tmp_data))
token = os.environ.get("RED_TOKEN", tmp_data['token'])
prefix = cli_flags.prefix or tmp_data['prefix']
enable_sentry = tmp_data['enable_sentry']
if token is None or not prefix: if token is None or not prefix:
if cli_flags.no_prompt is False: if cli_flags.no_prompt is False:
@ -102,13 +118,14 @@ if __name__ == '__main__':
log.critical("Token and prefix must be set in order to login.") log.critical("Token and prefix must be set in order to login.")
sys.exit(1) sys.exit(1)
if red.db.enable_sentry() is None: if enable_sentry is None:
ask_sentry(red) ask_sentry(red)
if red.db.enable_sentry(): loop.run_until_complete(_get_prefix_and_token(red, tmp_data))
if tmp_data['enable_sentry']:
init_sentry_logging(sentry_log) init_sentry_logging(sentry_log)
loop = asyncio.get_event_loop()
cleanup_tasks = True cleanup_tasks = True
try: try:

View File

@ -16,12 +16,14 @@ def test_is_valid_alias_name(alias):
assert alias.is_valid_alias_name("not valid name") is False assert alias.is_valid_alias_name("not valid name") is False
def test_empty_guild_aliases(alias, empty_guild): @pytest.mark.asyncio
assert list(alias.unloaded_aliases(empty_guild)) == [] async def test_empty_guild_aliases(alias, empty_guild):
assert list(await alias.unloaded_aliases(empty_guild)) == []
def test_empty_global_aliases(alias): @pytest.mark.asyncio
assert list(alias.unloaded_global_aliases()) == [] async def test_empty_global_aliases(alias):
assert list(await alias.unloaded_global_aliases()) == []
async def create_test_guild_alias(alias, ctx): async def create_test_guild_alias(alias, ctx):
@ -36,7 +38,7 @@ async def create_test_global_alias(alias, ctx):
async def test_add_guild_alias(alias, ctx): async def test_add_guild_alias(alias, ctx):
await create_test_guild_alias(alias, ctx) await create_test_guild_alias(alias, ctx)
is_alias, alias_obj = alias.is_alias(ctx.guild, "test") is_alias, alias_obj = await alias.is_alias(ctx.guild, "test")
assert is_alias is True assert is_alias is True
assert alias_obj.global_ is False assert alias_obj.global_ is False
@ -44,19 +46,19 @@ async def test_add_guild_alias(alias, ctx):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_delete_guild_alias(alias, ctx): async def test_delete_guild_alias(alias, ctx):
await create_test_guild_alias(alias, ctx) await create_test_guild_alias(alias, ctx)
is_alias, _ = alias.is_alias(ctx.guild, "test") is_alias, _ = await alias.is_alias(ctx.guild, "test")
assert is_alias is True assert is_alias is True
await alias.delete_alias(ctx, "test") await alias.delete_alias(ctx, "test")
is_alias, _ = alias.is_alias(ctx.guild, "test") is_alias, _ = await alias.is_alias(ctx.guild, "test")
assert is_alias is False assert is_alias is False
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_add_global_alias(alias, ctx): async def test_add_global_alias(alias, ctx):
await create_test_global_alias(alias, ctx) await create_test_global_alias(alias, ctx)
is_alias, alias_obj = alias.is_alias(ctx.guild, "test") is_alias, alias_obj = await alias.is_alias(ctx.guild, "test")
assert is_alias is True assert is_alias is True
assert alias_obj.global_ is True assert alias_obj.global_ is True
@ -65,7 +67,7 @@ async def test_add_global_alias(alias, ctx):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_delete_global_alias(alias, ctx): async def test_delete_global_alias(alias, ctx):
await create_test_global_alias(alias, ctx) await create_test_global_alias(alias, ctx)
is_alias, alias_obj = alias.is_alias(ctx.guild, "test") is_alias, alias_obj = await alias.is_alias(ctx.guild, "test")
assert is_alias is True assert is_alias is True
assert alias_obj.global_ is True assert alias_obj.global_ is True

View File

@ -11,13 +11,14 @@ def bank(config):
return bank return bank
def test_bank_register(bank, ctx): @pytest.mark.asyncio
default_bal = bank.get_default_balance(ctx.guild) async def test_bank_register(bank, ctx):
assert default_bal == bank.get_account(ctx.author).balance default_bal = await bank.get_default_balance(ctx.guild)
assert default_bal == (await bank.get_account(ctx.author)).balance
async def has_account(member, bank): async def has_account(member, bank):
balance = bank.get_balance(member) balance = await bank.get_balance(member)
if balance == 0: if balance == 0:
balance = 1 balance = 1
await bank.set_balance(member, balance) await bank.set_balance(member, balance)
@ -27,11 +28,11 @@ async def has_account(member, bank):
async def test_bank_transfer(bank, member_factory): async def test_bank_transfer(bank, member_factory):
mbr1 = member_factory.get() mbr1 = member_factory.get()
mbr2 = member_factory.get() mbr2 = member_factory.get()
bal1 = bank.get_account(mbr1).balance bal1 = (await bank.get_account(mbr1)).balance
bal2 = bank.get_account(mbr2).balance bal2 = (await bank.get_account(mbr2)).balance
await bank.transfer_credits(mbr1, mbr2, 50) await bank.transfer_credits(mbr1, mbr2, 50)
newbal1 = bank.get_account(mbr1).balance newbal1 = (await bank.get_account(mbr1)).balance
newbal2 = bank.get_account(mbr2).balance newbal2 = (await bank.get_account(mbr2)).balance
assert bal1 - 50 == newbal1 assert bal1 - 50 == newbal1
assert bal2 + 50 == newbal2 assert bal2 + 50 == newbal2
@ -40,16 +41,16 @@ async def test_bank_transfer(bank, member_factory):
async def test_bank_set(bank, member_factory): async def test_bank_set(bank, member_factory):
mbr = member_factory.get() mbr = member_factory.get()
await bank.set_balance(mbr, 250) await bank.set_balance(mbr, 250)
acc = bank.get_account(mbr) acc = await bank.get_account(mbr)
assert acc.balance == 250 assert acc.balance == 250
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_bank_can_spend(bank, member_factory): async def test_bank_can_spend(bank, member_factory):
mbr = member_factory.get() mbr = member_factory.get()
canspend = bank.can_spend(mbr, 50) canspend = await bank.can_spend(mbr, 50)
assert canspend == (50 < bank.get_default_balance(mbr.guild)) assert canspend == (50 < await bank.get_default_balance(mbr.guild))
await bank.set_balance(mbr, 200) await bank.set_balance(mbr, 200)
acc = bank.get_account(mbr) acc = await bank.get_account(mbr)
canspendnow = bank.can_spend(mbr, 100) canspendnow = await bank.can_spend(mbr, 100)
assert canspendnow assert canspendnow

View File

@ -14,16 +14,17 @@ def default_dir(red):
return red.main_dir return red.main_dir
def test_ensure_cogs_in_paths(cog_mgr, default_dir): @pytest.mark.asyncio
async def test_ensure_cogs_in_paths(cog_mgr, default_dir):
cogs_dir = default_dir / 'cogs' cogs_dir = default_dir / 'cogs'
assert cogs_dir in cog_mgr.paths assert cogs_dir in await cog_mgr.paths()
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_install_path_set(cog_mgr: cog_manager.CogManager, tmpdir): async def test_install_path_set(cog_mgr: cog_manager.CogManager, tmpdir):
path = Path(str(tmpdir)) path = Path(str(tmpdir))
await cog_mgr.set_install_path(path) await cog_mgr.set_install_path(path)
assert cog_mgr.install_path == path assert await cog_mgr.install_path() == path
@pytest.mark.asyncio @pytest.mark.asyncio
@ -38,7 +39,7 @@ async def test_install_path_set_bad(cog_mgr):
async def test_add_path(cog_mgr, tmpdir): async def test_add_path(cog_mgr, tmpdir):
path = Path(str(tmpdir)) path = Path(str(tmpdir))
await cog_mgr.add_path(path) await cog_mgr.add_path(path)
assert path in cog_mgr.paths assert path in await cog_mgr.paths()
@pytest.mark.asyncio @pytest.mark.asyncio
@ -54,4 +55,4 @@ async def test_remove_path(cog_mgr, tmpdir):
path = Path(str(tmpdir)) path = Path(str(tmpdir))
await cog_mgr.add_path(path) await cog_mgr.add_path(path)
await cog_mgr.remove_path(path) await cog_mgr.remove_path(path)
assert path not in cog_mgr.paths assert path not in await cog_mgr.paths()

View File

@ -2,10 +2,11 @@ import pytest
#region Register Tests #region Register Tests
def test_config_register_global(config): @pytest.mark.asyncio
async def test_config_register_global(config):
config.register_global(enabled=False) config.register_global(enabled=False)
assert config.defaults["GLOBAL"]["enabled"] is False assert config.defaults["GLOBAL"]["enabled"] is False
assert config.enabled() is False assert await config.enabled() is False
def test_config_register_global_badvalues(config): def test_config_register_global_badvalues(config):
@ -13,61 +14,69 @@ def test_config_register_global_badvalues(config):
config.register_global(**{"invalid var name": True}) config.register_global(**{"invalid var name": True})
def test_config_register_guild(config, empty_guild): @pytest.mark.asyncio
async def test_config_register_guild(config, empty_guild):
config.register_guild(enabled=False, some_list=[], some_dict={}) config.register_guild(enabled=False, some_list=[], some_dict={})
assert config.defaults[config.GUILD]["enabled"] is False assert config.defaults[config.GUILD]["enabled"] is False
assert config.defaults[config.GUILD]["some_list"] == [] assert config.defaults[config.GUILD]["some_list"] == []
assert config.defaults[config.GUILD]["some_dict"] == {} assert config.defaults[config.GUILD]["some_dict"] == {}
assert config.guild(empty_guild).enabled() is False assert await config.guild(empty_guild).enabled() is False
assert config.guild(empty_guild).some_list() == [] assert await config.guild(empty_guild).some_list() == []
assert config.guild(empty_guild).some_dict() == {} assert await config.guild(empty_guild).some_dict() == {}
def test_config_register_channel(config, empty_channel): @pytest.mark.asyncio
async def test_config_register_channel(config, empty_channel):
config.register_channel(enabled=False) config.register_channel(enabled=False)
assert config.defaults[config.CHANNEL]["enabled"] is False assert config.defaults[config.CHANNEL]["enabled"] is False
assert config.channel(empty_channel).enabled() is False assert await config.channel(empty_channel).enabled() is False
def test_config_register_role(config, empty_role): @pytest.mark.asyncio
async def test_config_register_role(config, empty_role):
config.register_role(enabled=False) config.register_role(enabled=False)
assert config.defaults[config.ROLE]["enabled"] is False assert config.defaults[config.ROLE]["enabled"] is False
assert config.role(empty_role).enabled() is False assert await config.role(empty_role).enabled() is False
def test_config_register_member(config, empty_member): @pytest.mark.asyncio
async def test_config_register_member(config, empty_member):
config.register_member(some_number=-1) config.register_member(some_number=-1)
assert config.defaults[config.MEMBER]["some_number"] == -1 assert config.defaults[config.MEMBER]["some_number"] == -1
assert config.member(empty_member).some_number() == -1 assert await config.member(empty_member).some_number() == -1
def test_config_register_user(config, empty_user): @pytest.mark.asyncio
async def test_config_register_user(config, empty_user):
config.register_user(some_value=None) config.register_user(some_value=None)
assert config.defaults[config.USER]["some_value"] is None assert config.defaults[config.USER]["some_value"] is None
assert config.user(empty_user).some_value() is None assert await config.user(empty_user).some_value() is None
def test_config_force_register_global(config_fr): @pytest.mark.asyncio
async def test_config_force_register_global(config_fr):
with pytest.raises(AttributeError): with pytest.raises(AttributeError):
config_fr.enabled() await config_fr.enabled()
config_fr.register_global(enabled=True) config_fr.register_global(enabled=True)
assert config_fr.enabled() is True assert await config_fr.enabled() is True
#endregion #endregion
# Test nested registration # Test nested registration
def test_nested_registration(config): @pytest.mark.asyncio
async def test_nested_registration(config):
config.register_global(foo__bar__baz=False) config.register_global(foo__bar__baz=False)
assert config.foo.bar.baz() is False assert await config.foo.bar.baz() is False
def test_nested_registration_asdict(config): @pytest.mark.asyncio
async def test_nested_registration_asdict(config):
defaults = {'bar': {'baz': False}} defaults = {'bar': {'baz': False}}
config.register_global(foo=defaults) config.register_global(foo=defaults)
assert config.foo.bar.baz() is False assert await config.foo.bar.baz() is False
@pytest.mark.asyncio @pytest.mark.asyncio
@ -75,20 +84,22 @@ async def test_nested_registration_and_changing(config):
defaults = {'bar': {'baz': False}} defaults = {'bar': {'baz': False}}
config.register_global(foo=defaults) config.register_global(foo=defaults)
assert config.foo.bar.baz() is False assert await config.foo.bar.baz() is False
with pytest.raises(ValueError): with pytest.raises(ValueError):
await config.foo.set(True) await config.foo.set(True)
def test_doubleset_default(config): @pytest.mark.asyncio
async def test_doubleset_default(config):
config.register_global(foo=True) config.register_global(foo=True)
config.register_global(foo=False) config.register_global(foo=False)
assert config.foo() is False assert await config.foo() is False
def test_nested_registration_multidict(config): @pytest.mark.asyncio
async def test_nested_registration_multidict(config):
defaults = { defaults = {
"foo": { "foo": {
"bar": { "bar": {
@ -99,8 +110,8 @@ def test_nested_registration_multidict(config):
} }
config.register_global(**defaults) config.register_global(**defaults)
assert config.foo.bar.baz() is True assert await config.foo.bar.baz() is True
assert config.blah() is True assert await config.blah() is True
def test_nested_group_value_badreg(config): def test_nested_group_value_badreg(config):
@ -109,56 +120,66 @@ def test_nested_group_value_badreg(config):
config.register_global(foo__bar=False) config.register_global(foo__bar=False)
def test_nested_toplevel_reg(config): @pytest.mark.asyncio
async def test_nested_toplevel_reg(config):
defaults = {'bar': True, 'baz': False} defaults = {'bar': True, 'baz': False}
config.register_global(foo=defaults) config.register_global(foo=defaults)
assert config.foo.bar() is True assert await config.foo.bar() is True
assert config.foo.baz() is False assert await config.foo.baz() is False
def test_nested_overlapping(config): @pytest.mark.asyncio
async def test_nested_overlapping(config):
config.register_global(foo__bar=True) config.register_global(foo__bar=True)
config.register_global(foo__baz=False) config.register_global(foo__baz=False)
assert config.foo.bar() is True assert await config.foo.bar() is True
assert config.foo.baz() is False assert await config.foo.baz() is False
def test_nesting_nofr(config): @pytest.mark.asyncio
async def test_nesting_nofr(config):
config.register_global(foo={}) config.register_global(foo={})
assert config.foo.bar() is None assert await config.foo.bar() is None
assert config.foo() == {} assert await config.foo() == {}
# region Default Value Overrides # region Default Value Overrides
def test_global_default_override(config): @pytest.mark.asyncio
assert config.enabled(True) is True async def test_global_default_override(config):
assert await config.enabled(True) is True
def test_global_default_nofr(config): @pytest.mark.asyncio
assert config.nofr() is None async def test_global_default_nofr(config):
assert config.nofr(True) is True assert await config.nofr() is None
assert await config.nofr(True) is True
def test_guild_default_override(config, empty_guild): @pytest.mark.asyncio
assert config.guild(empty_guild).enabled(True) is True async def test_guild_default_override(config, empty_guild):
assert await config.guild(empty_guild).enabled(True) is True
def test_channel_default_override(config, empty_channel): @pytest.mark.asyncio
assert config.channel(empty_channel).enabled(True) is True async def test_channel_default_override(config, empty_channel):
assert await config.channel(empty_channel).enabled(True) is True
def test_role_default_override(config, empty_role): @pytest.mark.asyncio
assert config.role(empty_role).enabled(True) is True async def test_role_default_override(config, empty_role):
assert await config.role(empty_role).enabled(True) is True
def test_member_default_override(config, empty_member): @pytest.mark.asyncio
assert config.member(empty_member).enabled(True) is True async def test_member_default_override(config, empty_member):
assert await config.member(empty_member).enabled(True) is True
def test_user_default_override(config, empty_user): @pytest.mark.asyncio
assert config.user(empty_user).some_value(True) is True async def test_user_default_override(config, empty_user):
assert await config.user(empty_user).some_value(True) is True
#endregion #endregion
@ -166,32 +187,32 @@ def test_user_default_override(config, empty_user):
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_set_global(config): async def test_set_global(config):
await config.enabled.set(True) await config.enabled.set(True)
assert config.enabled() is True assert await config.enabled() is True
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_set_guild(config, empty_guild): async def test_set_guild(config, empty_guild):
await config.guild(empty_guild).enabled.set(True) await config.guild(empty_guild).enabled.set(True)
assert config.guild(empty_guild).enabled() is True assert await config.guild(empty_guild).enabled() is True
curr_list = config.guild(empty_guild).some_list([1, 2, 3]) curr_list = await config.guild(empty_guild).some_list([1, 2, 3])
assert curr_list == [1, 2, 3] assert curr_list == [1, 2, 3]
curr_list.append(4) curr_list.append(4)
await config.guild(empty_guild).some_list.set(curr_list) await config.guild(empty_guild).some_list.set(curr_list)
assert config.guild(empty_guild).some_list() == curr_list assert await config.guild(empty_guild).some_list() == curr_list
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_set_channel(config, empty_channel): async def test_set_channel(config, empty_channel):
await config.channel(empty_channel).enabled.set(True) await config.channel(empty_channel).enabled.set(True)
assert config.channel(empty_channel).enabled() is True assert await config.channel(empty_channel).enabled() is True
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_set_channel_no_register(config, empty_channel): async def test_set_channel_no_register(config, empty_channel):
await config.channel(empty_channel).no_register.set(True) await config.channel(empty_channel).no_register.set(True)
assert config.channel(empty_channel).no_register() is True assert await config.channel(empty_channel).no_register() is True
#endregion #endregion
@ -200,11 +221,12 @@ async def test_set_channel_no_register(config, empty_channel):
async def test_set_dynamic_attr(config): async def test_set_dynamic_attr(config):
await config.set_attr("foobar", True) await config.set_attr("foobar", True)
assert config.foobar() is True assert await config.foobar() is True
def test_get_dynamic_attr(config): @pytest.mark.asyncio
assert config.get_attr("foobaz", True) is True async def test_get_dynamic_attr(config):
assert await config.get_attr("foobaz", True) is True
# Member Group testing # Member Group testing
@ -212,7 +234,7 @@ def test_get_dynamic_attr(config):
async def test_membergroup_allguilds(config, empty_member): async def test_membergroup_allguilds(config, empty_member):
await config.member(empty_member).foo.set(False) await config.member(empty_member).foo.set(False)
all_servers = config.member(empty_member).all_guilds() all_servers = await config.member(empty_member).all_guilds()
assert str(empty_member.guild.id) in all_servers assert str(empty_member.guild.id) in all_servers
@ -220,7 +242,7 @@ async def test_membergroup_allguilds(config, empty_member):
async def test_membergroup_allmembers(config, empty_member): async def test_membergroup_allmembers(config, empty_member):
await config.member(empty_member).foo.set(False) await config.member(empty_member).foo.set(False)
all_members = config.member(empty_member).all() all_members = await config.member(empty_member).all()
assert str(empty_member.id) in all_members assert str(empty_member.id) in all_members
@ -232,13 +254,13 @@ async def test_global_clear(config):
await config.foo.set(False) await config.foo.set(False)
await config.bar.set(True) await config.bar.set(True)
assert config.foo() is False assert await config.foo() is False
assert config.bar() is True assert await config.bar() is True
await config.clear() await config.clear()
assert config.foo() is True assert await config.foo() is True
assert config.bar() is False assert await config.bar() is False
@pytest.mark.asyncio @pytest.mark.asyncio
@ -247,17 +269,17 @@ async def test_member_clear(config, member_factory):
m1 = member_factory.get() m1 = member_factory.get()
await config.member(m1).foo.set(False) await config.member(m1).foo.set(False)
assert config.member(m1).foo() is False assert await config.member(m1).foo() is False
m2 = member_factory.get() m2 = member_factory.get()
await config.member(m2).foo.set(False) await config.member(m2).foo.set(False)
assert config.member(m2).foo() is False assert await config.member(m2).foo() is False
assert m1.guild.id != m2.guild.id assert m1.guild.id != m2.guild.id
await config.member(m1).clear() await config.member(m1).clear()
assert config.member(m1).foo() is True assert await config.member(m1).foo() is True
assert config.member(m2).foo() is False assert await config.member(m2).foo() is False
@pytest.mark.asyncio @pytest.mark.asyncio
@ -269,11 +291,11 @@ async def test_member_clear_all(config, member_factory):
server_ids.append(member.guild.id) server_ids.append(member.guild.id)
member = member_factory.get() member = member_factory.get()
assert len(config.member(member).all_guilds()) == len(server_ids) assert len(await config.member(member).all_guilds()) == len(server_ids)
await config.member(member).clear_all() await config.member(member).clear_all()
assert len(config.member(member).all_guilds()) == 0 assert len(await config.member(member).all_guilds()) == 0
# Get All testing # Get All testing
@ -284,7 +306,7 @@ async def test_user_get_all_from_kind(config, user_factory):
await config.user(user).foo.set(True) await config.user(user).foo.set(True)
user = user_factory.get() user = user_factory.get()
all_data = config.user(user).all_from_kind() all_data = await config.user(user).all_from_kind()
assert len(all_data) == 5 assert len(all_data) == 5
@ -294,4 +316,4 @@ async def test_user_getalldata(config, user_factory):
user = user_factory.get() user = user_factory.get()
await config.user(user).foo.set(False) await config.user(user).foo.set(False)
assert "foo" in config.user(user).all() assert "foo" in await config.user(user).all()