diff --git a/cogs/alias/alias.py b/cogs/alias/alias.py index 959ede213..23ae4df29 100644 --- a/cogs/alias/alias.py +++ b/cogs/alias/alias.py @@ -38,26 +38,26 @@ class Alias: self._aliases.register_global(**self.default_global_settings) self._aliases.register_guild(**self.default_guild_settings) - def unloaded_aliases(self, guild: discord.Guild) -> Generator[AliasEntry, None, None]: - return (AliasEntry.from_json(d) for d in self._aliases.guild(guild).entries()) + async def unloaded_aliases(self, guild: discord.Guild) -> Generator[AliasEntry, None, None]: + return (AliasEntry.from_json(d) for d in (await self._aliases.guild(guild).entries())) - def unloaded_global_aliases(self) -> Generator[AliasEntry, None, None]: - return (AliasEntry.from_json(d) for d in self._aliases.entries()) + async def unloaded_global_aliases(self) -> Generator[AliasEntry, None, None]: + return (AliasEntry.from_json(d) for d in (await self._aliases.entries())) - def loaded_aliases(self, guild: discord.Guild) -> Generator[AliasEntry, None, None]: + async def loaded_aliases(self, guild: discord.Guild) -> Generator[AliasEntry, None, None]: return (AliasEntry.from_json(d, bot=self.bot) - for d in self._aliases.guild(guild).entries()) + for d in (await self._aliases.guild(guild).entries())) - def loaded_global_aliases(self) -> Generator[AliasEntry, None, None]: - return (AliasEntry.from_json(d, bot=self.bot) for d in self._aliases.entries()) + async def loaded_global_aliases(self) -> Generator[AliasEntry, None, None]: + return (AliasEntry.from_json(d, bot=self.bot) for d in (await self._aliases.entries())) - def is_alias(self, guild: discord.Guild, alias_name: str, + async def is_alias(self, guild: discord.Guild, alias_name: str, server_aliases: Iterable[AliasEntry]=()) -> (bool, AliasEntry): if not server_aliases: - server_aliases = self.unloaded_aliases(guild) + server_aliases = await self.unloaded_aliases(guild) - global_aliases = self.unloaded_global_aliases() + global_aliases = await self.unloaded_global_aliases() for aliases in (server_aliases, global_aliases): for alias in aliases: @@ -79,11 +79,11 @@ class Alias: alias = AliasEntry(alias_name, command, ctx.author, global_=global_) if global_: - curr_aliases = self._aliases.entries() + curr_aliases = await self._aliases.entries() curr_aliases.append(alias.to_json()) await self._aliases.entries.set(curr_aliases) else: - curr_aliases = self._aliases.guild(ctx.guild).entries() + curr_aliases = await self._aliases.guild(ctx.guild).entries() curr_aliases.append(alias.to_json()) await self._aliases.guild(ctx.guild).entries.set(curr_aliases) @@ -94,10 +94,10 @@ class Alias: async def delete_alias(self, ctx: commands.Context, alias_name: str, global_: bool=False) -> bool: if global_: - aliases = self.unloaded_global_aliases() + aliases = await self.unloaded_global_aliases() setter_func = self._aliases.entries.set else: - aliases = self.unloaded_aliases(ctx.guild) + aliases = await self.unloaded_aliases(ctx.guild) setter_func = self._aliases.guild(ctx.guild).entries.set did_delete_alias = False @@ -161,7 +161,7 @@ class Alias: except IndexError: return False - is_alias, alias = self.is_alias(message.guild, potential_alias, server_aliases=aliases) + is_alias, alias = await self.is_alias(message.guild, potential_alias, server_aliases=aliases) if is_alias: await self.call_alias(message, prefix, alias) @@ -206,7 +206,7 @@ class Alias: " name is already a command on this bot.").format(alias_name)) return - is_alias, _ = self.is_alias(ctx.guild, alias_name) + is_alias, _ = await self.is_alias(ctx.guild, alias_name) if is_alias: await ctx.send(("You attempted to create a new alias" " with the name {} but that" @@ -285,7 +285,7 @@ class Alias: @commands.guild_only() async def _show_alias(self, ctx: commands.Context, alias_name: str): """Shows what command the alias executes.""" - is_alias, alias = self.is_alias(ctx.guild, alias_name) + is_alias, alias = await self.is_alias(ctx.guild, alias_name) if is_alias: await ctx.send(("The `{}` alias will execute the" @@ -299,7 +299,7 @@ class Alias: """ Deletes an existing alias on this server. """ - aliases = self.unloaded_aliases(ctx.guild) + aliases = await self.unloaded_aliases(ctx.guild) try: next(aliases) except StopIteration: @@ -317,7 +317,7 @@ class Alias: """ Deletes an existing global alias. """ - aliases = self.unloaded_global_aliases() + aliases = await self.unloaded_global_aliases() try: next(aliases) except StopIteration: @@ -336,7 +336,7 @@ class Alias: """ Lists the available aliases on this server. """ - names = ["Aliases:", ] + sorted(["+ " + a.name for a in self.unloaded_aliases(ctx.guild)]) + names = ["Aliases:", ] + sorted(["+ " + a.name for a in (await self.unloaded_aliases(ctx.guild))]) if len(names) == 0: await ctx.send("There are no aliases on this server.") else: @@ -347,16 +347,16 @@ class Alias: """ Lists the available global aliases on this bot. """ - names = ["Aliases:", ] + sorted(["+ " + a.name for a in self.unloaded_global_aliases()]) + names = ["Aliases:", ] + sorted(["+ " + a.name for a in await self.unloaded_global_aliases()]) if len(names) == 0: await ctx.send("There are no aliases on this server.") else: await ctx.send(box("\n".join(names), "diff")) async def on_message(self, message: discord.Message): - aliases = list(self.unloaded_global_aliases()) + aliases = list(await self.unloaded_global_aliases()) if message.guild is not None: - aliases = aliases + list(self.unloaded_aliases(message.guild)) + aliases = aliases + list(await self.unloaded_aliases(message.guild)) if len(aliases) == 0: return diff --git a/cogs/bank/bank.py b/cogs/bank/bank.py index dbdab2684..55b8084ca 100644 --- a/cogs/bank/bank.py +++ b/cogs/bank/bank.py @@ -6,7 +6,7 @@ from core.bot import Red # Only used for type hints def check_global_setting_guildowner(): async def pred(ctx: commands.Context): - if bank.is_global(): + if await bank.is_global(): return checks.is_owner() else: return checks.guildowner_or_permissions(administrator=True) @@ -15,7 +15,7 @@ def check_global_setting_guildowner(): def check_global_setting_admin(): async def pred(ctx: commands.Context): - if bank.is_global(): + if await bank.is_global(): return checks.is_owner() else: return checks.admin_or_permissions(manage_guild=True) @@ -43,7 +43,7 @@ class Bank: """Toggles whether the bank is global or not If the bank is global, it will become per-guild If the bank is per-guild, it will become global""" - cur_setting = bank.is_global() + cur_setting = await bank.is_global() await bank.set_global(not cur_setting, ctx.author) word = "per-guild" if cur_setting else "global" diff --git a/cogs/downloader/downloader.py b/cogs/downloader/downloader.py index 32d9a4d49..254e2c6b0 100644 --- a/cogs/downloader/downloader.py +++ b/cogs/downloader/downloader.py @@ -49,22 +49,20 @@ class Downloader: self._repo_manager = RepoManager(self.conf) - @property - def cog_install_path(self): + async def cog_install_path(self): """ Returns the current cog install path. :return: """ - return self.bot.cog_mgr.install_path + return await self.bot.cog_mgr.install_path() - @property - def installed_cogs(self) -> Tuple[Installable]: + async def installed_cogs(self) -> Tuple[Installable]: """ Returns the dictionary mapping cog name to install location and repo name. :return: """ - installed = self.conf.installed() + installed = await self.conf.installed() # noinspection PyTypeChecker return tuple(Installable.from_json(v) for v in installed) @@ -74,7 +72,7 @@ class Downloader: :param cog: :return: """ - installed = self.conf.installed() + installed = await self.conf.installed() cog_json = cog.to_json() if cog_json not in installed: @@ -87,7 +85,7 @@ class Downloader: :param cog: :return: """ - installed = self.conf.installed() + installed = await self.conf.installed() cog_json = cog.to_json() if cog_json in installed: @@ -102,7 +100,7 @@ class Downloader: """ failed = [] for cog in cogs: - if not await cog.copy_to(self.cog_install_path): + if not await cog.copy_to(await self.cog_install_path()): failed.append(cog) # noinspection PyTypeChecker @@ -249,7 +247,7 @@ class Downloader: " `{}`: `{}`".format(cog.name, cog.requirements)) return - await repo_name.install_cog(cog, self.cog_install_path) + await repo_name.install_cog(cog, await self.cog_install_path()) await self._add_to_installed(cog) @@ -266,7 +264,7 @@ class Downloader: # noinspection PyUnresolvedReferences,PyProtectedMember real_name = cog_name.name - poss_installed_path = self.cog_install_path / real_name + poss_installed_path = (await self.cog_install_path()) / real_name if poss_installed_path.exists(): await self._delete_cog(poss_installed_path) # noinspection PyTypeChecker @@ -284,7 +282,7 @@ class Downloader: """ if cog_name is None: updated = await self._repo_manager.update_all_repos() - installed_cogs = set(self.installed_cogs) + installed_cogs = set(await self.installed_cogs()) updated_cogs = set(cog for repo in updated.keys() for cog in repo.available_cogs) installed_and_updated = updated_cogs & installed_cogs @@ -325,14 +323,14 @@ class Downloader: msg = "Information on {}:\n{}".format(cog.name, cog.description or "") await ctx.send(box(msg)) - def is_installed(self, cog_name: str) -> (bool, Union[Installable, None]): + async def is_installed(self, cog_name: str) -> (bool, Union[Installable, None]): """ Checks to see if a cog with the given name was installed through Downloader. :param cog_name: :return: is_installed, Installable """ - for installable in self.installed_cogs: + for installable in await self.installed_cogs(): if installable.name == cog_name: return True, installable return False, None @@ -384,7 +382,7 @@ class Downloader: # Check if in installed cogs cog_name = self.cog_name_from_instance(command.instance) - installed, cog_installable = self.is_installed(cog_name) + installed, cog_installable = await self.is_installed(cog_name) if installed: msg = self.format_findcog_info(command_name, cog_installable) else: diff --git a/cogs/downloader/repo_manager.py b/cogs/downloader/repo_manager.py index ea2120f4c..d4729a228 100644 --- a/cogs/downloader/repo_manager.py +++ b/cogs/downloader/repo_manager.py @@ -430,7 +430,10 @@ class RepoManager: self.repos_folder = Path(__file__).parent / 'repos' - self._repos = self._load_repos() # str_name: Repo + self._repos = {} + + loop = asyncio.get_event_loop() + loop.run_until_complete(self._load_repos(set=True)) # str_name: Repo def does_repo_exist(self, name: str) -> bool: return name in self._repos @@ -494,7 +497,6 @@ class RepoManager: shutil.rmtree(str(repo.folder_path)) - repos = self.downloader_config.repos() try: del self._repos[name] except KeyError: @@ -518,11 +520,14 @@ class RepoManager: await self._save_repos() return ret - def _load_repos(self) -> MutableMapping[str, Repo]: - return { + async def _load_repos(self, set=False) -> MutableMapping[str, Repo]: + ret = { name: Repo.from_json(data) for name, data in - self.downloader_config.repos().items() + (await self.downloader_config.repos()).items() } + if set: + self._repos = ret + return ret async def _save_repos(self): repo_json_info = {name: r.to_json() for name, r in self._repos.items()} diff --git a/cogs/economy/economy.py b/cogs/economy/economy.py index b0ac80a53..5c38f3c5f 100644 --- a/cogs/economy/economy.py +++ b/cogs/economy/economy.py @@ -72,9 +72,9 @@ SLOT_PAYOUTS_MSG = ("Slot machine payouts:\n" def guild_only_check(): async def pred(ctx: commands.Context): - if bank.is_global(): + if await bank.is_global(): return True - elif not bank.is_global() and ctx.guild is not None: + elif not await bank.is_global() and ctx.guild is not None: return True else: return False @@ -146,8 +146,8 @@ class Economy: if user is None: user = ctx.author - bal = bank.get_balance(user) - currency = bank.get_currency_name(ctx.guild) + bal = await bank.get_balance(user) + currency = await bank.get_currency_name(ctx.guild) await ctx.send("{}'s balance is {} {}".format( user.display_name, bal, currency)) @@ -156,7 +156,7 @@ class Economy: async def transfer(self, ctx: commands.Context, to: discord.Member, amount: int): """Transfer currency to other users""" from_ = ctx.author - currency = bank.get_currency_name(ctx.guild) + currency = await bank.get_currency_name(ctx.guild) try: await bank.transfer_credits(from_, to, amount) @@ -206,12 +206,12 @@ class Economy: await ctx.send( "This will delete all bank accounts for {}.\nIf you're sure, type " "{}bank reset yes".format( - self.bot.user.name if bank.is_global() else "this guild", + self.bot.user.name if await bank.is_global() else "this guild", ctx.prefix ) ) else: - if bank.is_global(): + if await bank.is_global(): # Bank being global means that the check would cause only # the owner and any co-owners to be able to run the command # so if we're in the function, it's safe to assume that the @@ -232,18 +232,18 @@ class Economy: guild = ctx.guild cur_time = calendar.timegm(ctx.message.created_at.utctimetuple()) - credits_name = bank.get_currency_name(ctx.guild) - if bank.is_global(): - next_payday = self.config.user(author).next_payday() + credits_name = await bank.get_currency_name(ctx.guild) + if await bank.is_global(): + next_payday = await self.config.user(author).next_payday() if cur_time >= next_payday: - await bank.deposit_credits(author, self.config.PAYDAY_CREDITS()) - next_payday = cur_time + self.config.PAYDAY_TIME() + await bank.deposit_credits(author, await self.config.PAYDAY_CREDITS()) + next_payday = cur_time + await self.config.PAYDAY_TIME() await self.config.user(author).next_payday.set(next_payday) await ctx.send( "{} Here, take some {}. Enjoy! (+{}" " {}!)".format( author.mention, credits_name, - str(self.config.PAYDAY_CREDITS()), + str(await self.config.PAYDAY_CREDITS()), credits_name ) ) @@ -254,16 +254,16 @@ class Economy: " wait {}.".format(author.mention, dtime) ) else: - next_payday = self.config.member(author).next_payday() + next_payday = await self.config.member(author).next_payday() if cur_time >= next_payday: - await bank.deposit_credits(author, self.config.guild(guild).PAYDAY_CREDITS()) - next_payday = cur_time + self.config.guild(guild).PAYDAY_TIME() + await bank.deposit_credits(author, await self.config.guild(guild).PAYDAY_CREDITS()) + next_payday = cur_time + await self.config.guild(guild).PAYDAY_TIME() await self.config.member(author).next_payday.set(next_payday) await ctx.send( "{} Here, take some {}. Enjoy! (+{}" " {}!)".format( author.mention, credits_name, - str(self.config.guild(guild).PAYDAY_CREDITS()), + str(await self.config.guild(guild).PAYDAY_CREDITS()), credits_name)) else: dtime = self.display_time(next_payday - cur_time) @@ -282,10 +282,10 @@ class Economy: if top < 1: top = 10 if bank.is_global(): - bank_sorted = sorted(bank.get_global_accounts(ctx.author), + bank_sorted = sorted(await bank.get_global_accounts(ctx.author), key=lambda x: x.balance, reverse=True) else: - bank_sorted = sorted(bank.get_guild_accounts(guild), + bank_sorted = sorted(await bank.get_guild_accounts(guild), key=lambda x: x.balance, reverse=True) if len(bank_sorted) < top: top = len(bank_sorted) @@ -320,14 +320,14 @@ class Economy: author = ctx.author guild = ctx.guild channel = ctx.channel - if bank.is_global(): - valid_bid = self.config.SLOT_MIN() <= bid <= self.config.SLOT_MAX() - slot_time = self.config.SLOT_TIME() - last_slot = self.config.user(author).last_slot() + if await bank.is_global(): + valid_bid = await self.config.SLOT_MIN() <= bid <= await self.config.SLOT_MAX() + slot_time = await self.config.SLOT_TIME() + last_slot = await self.config.user(author).last_slot() else: - valid_bid = self.config.guild(guild).SLOT_MIN() <= bid <= self.config.guild(guild).SLOT_MAX() - slot_time = self.config.guild(guild).SLOT_TIME() - last_slot = self.config.member(author).last_slot() + valid_bid = await self.config.guild(guild).SLOT_MIN() <= bid <= await self.config.guild(guild).SLOT_MAX() + slot_time = await self.config.guild(guild).SLOT_TIME() + last_slot = await self.config.member(author).last_slot() now = calendar.timegm(ctx.message.created_at.utctimetuple()) if (now - last_slot) < slot_time: @@ -336,10 +336,10 @@ class Economy: if not valid_bid: await ctx.send("That's an invalid bid amount, sorry :/") return - if not bank.can_spend(author, bid): + if not await bank.can_spend(author, bid): await ctx.send("You ain't got enough money, friend.") return - if bank.is_global(): + if await bank.is_global(): await self.config.user(author).last_slot.set(now) else: await self.config.member(author).last_slot.set(now) @@ -379,7 +379,7 @@ class Economy: payout = PAYOUTS["2 symbols"] if payout: - then = bank.get_balance(author) + then = await bank.get_balance(author) pay = payout["payout"](bid) now = then - bid + pay await bank.set_balance(author, now) @@ -387,7 +387,7 @@ class Economy: "".format(slot, author.mention, payout["phrase"], bid, then, now)) else: - then = bank.get_balance(author) + then = await bank.get_balance(author) await bank.withdraw_credits(author, bid) now = then - bid await channel.send("{}\n{} Nothing!\nYour bid: {}\n{} → {}!" @@ -402,18 +402,18 @@ class Economy: if ctx.invoked_subcommand is None: await self.bot.send_cmd_help(ctx) if bank.is_global(): - slot_min = self.config.SLOT_MIN() - slot_max = self.config.SLOT_MAX() - slot_time = self.config.SLOT_TIME() - payday_time = self.config.PAYDAY_TIME() - payday_amount = self.config.PAYDAY_CREDITS() + slot_min = await self.config.SLOT_MIN() + slot_max = await self.config.SLOT_MAX() + slot_time = await self.config.SLOT_TIME() + payday_time = await self.config.PAYDAY_TIME() + payday_amount = await self.config.PAYDAY_CREDITS() else: - slot_min = self.config.guild(guild).SLOT_MIN() - slot_max = self.config.guild(guild).SLOT_MAX() - slot_time = self.config.guild(guild).SLOT_TIME() - payday_time = self.config.guild(guild).PAYDAY_TIME() - payday_amount = self.config.guild(guild).PAYDAY_CREDITS() - register_amount = bank.get_default_balance(guild) + slot_min = await self.config.guild(guild).SLOT_MIN() + slot_max = await self.config.guild(guild).SLOT_MAX() + slot_time = await self.config.guild(guild).SLOT_TIME() + payday_time = await self.config.guild(guild).PAYDAY_TIME() + payday_amount = await self.config.guild(guild).PAYDAY_CREDITS() + register_amount = await bank.get_default_balance(guild) msg = box( "Minimum slot bid: {}\n" "Maximum slot bid: {}\n" @@ -436,24 +436,24 @@ class Economy: await ctx.send('Invalid min bid amount.') return guild = ctx.guild - if bank.is_global(): + if await bank.is_global(): await self.config.SLOT_MIN.set(bid) else: await self.config.guild(guild).SLOT_MIN.set(bid) - credits_name = bank.get_currency_name(guild) + credits_name = await bank.get_currency_name(guild) await ctx.send("Minimum bid is now {} {}.".format(bid, credits_name)) @economyset.command() async def slotmax(self, ctx: commands.Context, bid: int): """Maximum slot machine bid""" - slot_min = self.config.SLOT_MIN() + slot_min = await self.config.SLOT_MIN() if bid < 1 or bid < slot_min: await ctx.send('Invalid slotmax bid amount. Must be greater' ' than slotmin.') return guild = ctx.guild - credits_name = bank.get_currency_name(guild) - if bank.is_global(): + credits_name = await bank.get_currency_name(guild) + if await bank.is_global(): await self.config.SLOT_MAX.set(bid) else: await self.config.guild(guild).SLOT_MAX.set(bid) @@ -463,7 +463,7 @@ class Economy: async def slottime(self, ctx: commands.Context, seconds: int): """Seconds between each slots use""" guild = ctx.guild - if bank.is_global(): + if await bank.is_global(): await self.config.SLOT_TIME.set(seconds) else: await self.config.guild(guild).SLOT_TIME.set(seconds) @@ -473,7 +473,7 @@ class Economy: async def paydaytime(self, ctx: commands.Context, seconds: int): """Seconds between each payday""" guild = ctx.guild - if bank.is_global(): + if await bank.is_global(): await self.config.PAYDAY_TIME.set(seconds) else: await self.config.guild(guild).PAYDAY_TIME.set(seconds) @@ -484,11 +484,11 @@ class Economy: async def paydayamount(self, ctx: commands.Context, creds: int): """Amount earned each payday""" guild = ctx.guild - credits_name = bank.get_currency_name(guild) + credits_name = await bank.get_currency_name(guild) if creds <= 0: await ctx.send("Har har so funny.") return - if bank.is_global(): + if await bank.is_global(): await self.config.PAYDAY_CREDITS.set(creds) else: await self.config.guild(guild).PAYDAY_CREDITS.set(creds) @@ -501,7 +501,7 @@ class Economy: guild = ctx.guild if creds < 0: creds = 0 - credits_name = bank.get_currency_name(guild) + credits_name = await bank.get_currency_name(guild) await bank.set_default_balance(creds, guild) await ctx.send("Registering an account will now give {} {}." "".format(creds, credits_name)) diff --git a/core/bank.py b/core/bank.py index f62032f09..f8f326b22 100644 --- a/core/bank.py +++ b/core/bank.py @@ -1,6 +1,6 @@ import datetime from collections import namedtuple -from typing import Tuple, Generator, Union +from typing import Tuple, Generator, Union, List import discord from copy import deepcopy @@ -78,17 +78,17 @@ def _decode_time(time: int) -> datetime.datetime: return datetime.datetime.utcfromtimestamp(time) -def get_balance(member: discord.Member) -> int: +async def get_balance(member: discord.Member) -> int: """ Gets the current balance of a member. :param member: :return: """ - acc = get_account(member) + acc = await get_account(member) return acc.balance -def can_spend(member: discord.Member, amount: int) -> bool: +async def can_spend(member: discord.Member, amount: int) -> bool: """ Determines if a member can spend the given amount. :param member: @@ -97,7 +97,7 @@ def can_spend(member: discord.Member, amount: int) -> bool: """ if _invalid_amount(amount): return False - return get_balance(member) > amount + return await get_balance(member) > amount async def set_balance(member: discord.Member, amount: int) -> int: @@ -111,17 +111,17 @@ async def set_balance(member: discord.Member, amount: int) -> int: """ if amount < 0: raise ValueError("Not allowed to have negative balance.") - if is_global(): + if await is_global(): group = _conf.user(member) else: group = _conf.member(member) await group.balance.set(amount) - if group.created_at() == 0: + if await group.created_at() == 0: time = _encoded_current_time() await group.created_at.set(time) - if group.name() == "": + if await group.name() == "": await group.name.set(member.display_name) return amount @@ -144,7 +144,7 @@ async def withdraw_credits(member: discord.Member, amount: int) -> int: if _invalid_amount(amount): raise ValueError("Invalid withdrawal amount {} <= 0".format(amount)) - bal = get_balance(member) + bal = await get_balance(member) if amount > bal: raise ValueError("Insufficient funds {} > {}".format(amount, bal)) @@ -163,7 +163,7 @@ async def deposit_credits(member: discord.Member, amount: int) -> int: if _invalid_amount(amount): raise ValueError("Invalid withdrawal amount {} <= 0".format(amount)) - bal = get_balance(member) + bal = await get_balance(member) return await set_balance(member, amount + bal) @@ -190,13 +190,13 @@ async def wipe_bank(user: Union[discord.User, discord.Member]): Deletes all accounts from the bank. :return: """ - if is_global(): + if await is_global(): await _conf.user(user).clear() else: await _conf.member(user).clear() -def get_guild_accounts(guild: discord.Guild) -> Generator[Account, None, None]: +async def get_guild_accounts(guild: discord.Guild) -> List[Account]: """ Gets all account data for the given guild. @@ -207,14 +207,16 @@ def get_guild_accounts(guild: discord.Guild) -> Generator[Account, None, None]: if is_global(): raise RuntimeError("The bank is currently global.") - accs = _conf.member(guild.owner).all_from_kind() + ret = [] + accs = await _conf.member(guild.owner).all_from_kind() for user_id, acc in accs.items(): acc_data = acc.copy() # There ya go kowlin acc_data['created_at'] = _decode_time(acc_data['created_at']) - yield Account(**acc_data) + ret.append(Account(**acc_data)) + return ret -def get_global_accounts(user: discord.User) -> Generator[Account, None, None]: +async def get_global_accounts(user: discord.User) -> List[Account]: """ Gets all global account data. @@ -225,44 +227,47 @@ def get_global_accounts(user: discord.User) -> Generator[Account, None, None]: if not is_global(): raise RuntimeError("The bank is not currently global.") - accs = _conf.user(user).all_from_kind() # this is a dict of user -> acc + ret = [] + accs = await _conf.user(user).all_from_kind() # this is a dict of user -> acc for user_id, acc in accs.items(): acc_data = acc.copy() acc_data['created_at'] = _decode_time(acc_data['created_at']) - yield Account(**acc_data) + ret.append(Account(**acc_data)) + + return ret -def get_account(member: Union[discord.Member, discord.User]) -> Account: +async def get_account(member: Union[discord.Member, discord.User]) -> Account: """ Gets the appropriate account for the given member. :param member: :return: """ - if is_global(): - acc_data = _conf.user(member)().copy() + if await is_global(): + acc_data = (await _conf.user(member)()).copy() default = _DEFAULT_USER.copy() else: - acc_data = _conf.member(member)().copy() + acc_data = (await _conf.member(member)()).copy() default = _DEFAULT_MEMBER.copy() if acc_data == {}: acc_data = default acc_data['name'] = member.display_name try: - acc_data['balance'] = get_default_balance(member.guild) + acc_data['balance'] = await get_default_balance(member.guild) except AttributeError: - acc_data['balance'] = get_default_balance() + acc_data['balance'] = await get_default_balance() acc_data['created_at'] = _decode_time(acc_data['created_at']) return Account(**acc_data) -def is_global() -> bool: +async def is_global() -> bool: """ Determines if the bank is currently global. :return: """ - return _conf.is_global() + return await _conf.is_global() async def set_global(global_: bool, user: Union[discord.User, discord.Member]) -> bool: @@ -272,7 +277,7 @@ async def set_global(global_: bool, user: Union[discord.User, discord.Member]) - :param user: Must be a Member object if changing TO global mode. :return: New bank mode, True is global. """ - if is_global() is global_: + if (await is_global()) is global_: return global_ if is_global(): @@ -287,7 +292,7 @@ async def set_global(global_: bool, user: Union[discord.User, discord.Member]) - return global_ -def get_bank_name(guild: discord.Guild=None) -> str: +async def get_bank_name(guild: discord.Guild=None) -> str: """ Gets the current bank name. If the bank is guild-specific the guild parameter is required. @@ -296,10 +301,10 @@ def get_bank_name(guild: discord.Guild=None) -> str: :param guild: :return: """ - if is_global(): - return _conf.bank_name() + if await is_global(): + return await _conf.bank_name() elif guild is not None: - return _conf.guild(guild).bank_name() + return await _conf.guild(guild).bank_name() else: raise RuntimeError("Guild parameter is required and missing.") @@ -314,7 +319,7 @@ async def set_bank_name(name: str, guild: discord.Guild=None) -> str: :param guild: :return: """ - if is_global(): + if await is_global(): await _conf.bank_name.set(name) elif guild is not None: await _conf.guild(guild).bank_name.set(name) @@ -324,7 +329,7 @@ async def set_bank_name(name: str, guild: discord.Guild=None) -> str: return name -def get_currency_name(guild: discord.Guild=None) -> str: +async def get_currency_name(guild: discord.Guild=None) -> str: """ Gets the currency name of the bank. The guild parameter is required if the bank is guild-specific. @@ -333,10 +338,10 @@ def get_currency_name(guild: discord.Guild=None) -> str: :param guild: :return: """ - if is_global(): - return _conf.currency() + if await is_global(): + return await _conf.currency() elif guild is not None: - return _conf.guild(guild).currency() + return await _conf.guild(guild).currency() else: raise RuntimeError("Guild must be provided.") @@ -351,7 +356,7 @@ async def set_currency_name(name: str, guild: discord.Guild=None) -> str: :param guild: :return: """ - if is_global(): + if await is_global(): await _conf.currency.set(name) elif guild is not None: await _conf.guild(guild).currency.set(name) @@ -361,7 +366,7 @@ async def set_currency_name(name: str, guild: discord.Guild=None) -> str: return name -def get_default_balance(guild: discord.Guild=None) -> int: +async def get_default_balance(guild: discord.Guild=None) -> int: """ Gets the current default balance amount. If the bank is guild-specific you must pass guild. @@ -370,10 +375,10 @@ def get_default_balance(guild: discord.Guild=None) -> int: :param guild: :return: """ - if is_global(): - return _conf.default_balance() + if await is_global(): + return await _conf.default_balance() elif guild is not None: - return _conf.guild(guild).default_balance() + return await _conf.guild(guild).default_balance() else: raise RuntimeError("Guild is missing and required!") @@ -393,9 +398,11 @@ async def set_default_balance(amount: int, guild: discord.Guild=None) -> int: if amount < 0: raise ValueError("Amount must be greater than zero.") - if is_global(): + if await is_global(): await _conf.default_balance.set(amount) elif guild is not None: await _conf.guild(guild).default_balance.set(amount) else: raise RuntimeError("Guild is missing and required.") + + return amount diff --git a/core/bot.py b/core/bot.py index 0c5b04bb4..8f295ac56 100644 --- a/core/bot.py +++ b/core/bot.py @@ -1,3 +1,4 @@ +import asyncio import importlib.util from importlib.machinery import ModuleSpec @@ -39,14 +40,14 @@ class Red(commands.Bot): mod_role=None ) - def prefix_manager(bot, message): + async def prefix_manager(bot, message): if not cli_flags.prefix: - global_prefix = self.db.prefix() + global_prefix = await bot.db.prefix() else: global_prefix = cli_flags.prefix if message.guild is None: return global_prefix - server_prefix = self.db.guild(message.guild).prefix() + server_prefix = await bot.db.guild(message.guild).prefix() return server_prefix if server_prefix else global_prefix if "command_prefix" not in kwargs: @@ -56,7 +57,8 @@ class Red(commands.Bot): kwargs["owner_id"] = cli_flags.owner if "owner_id" not in kwargs: - kwargs["owner_id"] = self.db.owner() + loop = asyncio.get_event_loop() + loop.run_until_complete(self._dict_abuse(kwargs)) self.counter = Counter() self.uptime = None @@ -68,6 +70,15 @@ class Red(commands.Bot): super().__init__(**kwargs) + async def _dict_abuse(self, indict): + """ + Please blame <@269933075037814786> for this. + :param indict: + :return: + """ + + indict['owner_id'] = await self.db.owner() + async def is_owner(self, user): if user.id in self._co_owners: return True @@ -103,13 +114,13 @@ class Red(commands.Bot): await self.db.packages.set(packages) async def add_loaded_package(self, pkg_name: str): - curr_pkgs = self.db.packages() + curr_pkgs = await self.db.packages() if pkg_name not in curr_pkgs: curr_pkgs.append(pkg_name) await self.save_packages_status(curr_pkgs) async def remove_loaded_package(self, pkg_name: str): - curr_pkgs = self.db.packages() + curr_pkgs = await self.db.packages() if pkg_name in curr_pkgs: await self.save_packages_status([p for p in curr_pkgs if p != pkg_name]) diff --git a/core/cog_manager.py b/core/cog_manager.py index a819313c6..8a2daed88 100644 --- a/core/cog_manager.py +++ b/core/cog_manager.py @@ -31,26 +31,29 @@ class CogManager: install_path=str(bot_dir.resolve() / "cogs") ) - self._paths = set(list(self.conf.paths()) + list(paths)) + self._paths = list(paths) - @property - def paths(self) -> Tuple[Path, ...]: + async def paths(self) -> Tuple[Path, ...]: """ This will return all currently valid path directories. :return: """ - paths = [Path(p) for p in self._paths] + conf_paths = await self.conf.paths() + other_paths = self._paths + + all_paths = set(list(conf_paths) + list(other_paths)) + + paths = [Path(p) for p in all_paths] if self.install_path not in paths: - paths.insert(0, self.install_path) + paths.insert(0, await self.install_path()) return tuple(p.resolve() for p in paths if p.is_dir()) - @property - def install_path(self) -> Path: + async def install_path(self) -> Path: """ Returns the install path for 3rd party cogs. :return: """ - p = Path(self.conf.install_path()) + p = Path(await self.conf.install_path()) return p.resolve() async def set_install_path(self, path: Path) -> Path: @@ -99,10 +102,10 @@ class CogManager: if not path.is_dir(): raise InvalidPath("'{}' is not a valid directory.".format(path)) - if path == self.install_path: + if path == await self.install_path(): raise ValueError("Cannot add the install path as an additional path.") - all_paths = set(self.paths + (path, )) + all_paths = set(await self.paths() + (path, )) # noinspection PyTypeChecker await self.set_paths(all_paths) @@ -113,7 +116,7 @@ class CogManager: :return: """ path = self._ensure_path_obj(path) - all_paths = list(self.paths) + all_paths = list(await self.paths()) if path in all_paths: all_paths.remove(path) # Modifies in place await self.set_paths(all_paths) @@ -125,11 +128,10 @@ class CogManager: :param paths_: :return: """ - self._paths = paths_ str_paths = [str(p) for p in paths_] await self.conf.paths.set(str_paths) - def find_cog(self, name: str) -> ModuleSpec: + async def find_cog(self, name: str) -> ModuleSpec: """ Finds a cog in the list of available path. @@ -137,7 +139,7 @@ class CogManager: :param name: :return: """ - resolved_paths = [str(p.resolve()) for p in self.paths] + resolved_paths = [str(p.resolve()) for p in await self.paths()] for finder, module_name, _ in pkgutil.iter_modules(resolved_paths): if name == module_name: spec = finder.find_spec(name) @@ -166,7 +168,7 @@ class CogManagerUI: """ Lists current cog paths in order of priority. """ - install_path = ctx.bot.cog_mgr.install_path + install_path = await ctx.bot.cog_mgr.install_path() cog_paths = ctx.bot.cog_mgr.paths cog_paths = [p for p in cog_paths if p != install_path] @@ -204,7 +206,7 @@ class CogManagerUI: Removes a path from the available cog paths given the path_number from !paths """ - cog_paths = ctx.bot.cog_mgr.paths + cog_paths = await ctx.bot.cog_mgr.paths() try: to_remove = cog_paths[path_number] except IndexError: @@ -224,7 +226,7 @@ class CogManagerUI: from_ -= 1 to -= 1 - all_paths = list(ctx.bot.cog_mgr.paths) + all_paths = list(await ctx.bot.cog_mgr.paths()) try: to_move = all_paths.pop(from_) except IndexError: @@ -257,6 +259,6 @@ class CogManagerUI: await ctx.send("That path does not exist.") return - install_path = ctx.bot.cog_mgr.install_path + install_path = await ctx.bot.cog_mgr.install_path() await ctx.send("The bot will install new cogs to the `{}`" " directory.".format(install_path)) diff --git a/core/config.py b/core/config.py index 3fe245604..e7e096bd2 100644 --- a/core/config.py +++ b/core/config.py @@ -38,6 +38,14 @@ class Value: def identifiers(self): return tuple(str(i) for i in self._identifiers) + async def _get(self, default): + driver = self.spawner.get_driver() + try: + ret = await driver.get(self.identifiers) + except KeyError: + return default or self.default + return ret + def __call__(self, default=None): """ Each :py:class:`Value` object is created by the :py:meth:`Group.__getattr__` method. @@ -46,25 +54,26 @@ class Value: For example:: - foo = conf.guild(some_guild).foo() + foo = await conf.guild(some_guild).foo() # Is equivalent to this group_obj = conf.guild(some_guild) value_obj = conf.foo - foo = value_obj() + foo = await value_obj() + + .. important:: + + This is now, for all intents and purposes, a coroutine. :param default: This argument acts as an override for the registered default provided by :py:attr:`default`. This argument is ignored if its value is :python:`None`. :type default: Optional[object] + :return: + A coroutine object that must be awaited. """ - driver = self.spawner.get_driver() - try: - ret = driver.get(self.identifiers) - except KeyError: - return default or self.default - return ret + return self._get(default) async def set(self, value): """ @@ -182,7 +191,7 @@ class Group(Value): return not isinstance(default, dict) - def get_attr(self, item: str, default=None, resolve=True): + async def get_attr(self, item: str, default=None, resolve=True): """ This is available to use as an alternative to using normal Python attribute access. It is required if you find a need for dynamic attribute access. @@ -198,7 +207,7 @@ class Group(Value): user = ctx.author # Where the value of item is the name of the data field in Config - await ctx.send(self.conf.user(user).get_attr(item)) + await ctx.send(await self.conf.user(user).get_attr(item)) :param str item: The name of the data field in :py:class:`.Config`. @@ -211,20 +220,20 @@ class Group(Value): """ value = getattr(self, item) if resolve: - return value(default=default) + return await value(default=default) else: return value - def all(self) -> dict: + async def all(self) -> dict: """ This method allows you to get "all" of a particular group of data. It will return the dictionary of all data for a particular Guild/Channel/Role/User/Member etc. :rtype: dict """ - return self() + return await self() - def all_from_kind(self) -> dict: + async def all_from_kind(self) -> dict: """ This method allows you to get all data from all entries in a given Kind. It will return a dictionary of Kind ID's -> data. @@ -232,7 +241,7 @@ class Group(Value): :rtype: dict """ # noinspection PyTypeChecker - return self._super_group() + return await self._super_group() async def set(self, value): if not isinstance(value, dict): @@ -292,18 +301,18 @@ class MemberGroup(Group): ) return group_obj - def all_guilds(self) -> dict: + async def all_guilds(self) -> dict: """ Returns a dict of :code:`GUILD_ID -> MEMBER_ID -> data`. :rtype: dict """ # noinspection PyTypeChecker - return self._super_group() + return await self._super_group() - def all(self) -> dict: + async def all(self) -> dict: # noinspection PyTypeChecker - return self._guild_group() + return await self._guild_group() class Config: @@ -315,7 +324,7 @@ class Config: however the process for accessing global data is a bit different. There is no :python:`global` method because global data is accessed by normal attribute access:: - conf.foo() + await conf.foo() .. py:attribute:: cog_name diff --git a/core/core_commands.py b/core/core_commands.py index 1bd667390..47a37c7c3 100644 --- a/core/core_commands.py +++ b/core/core_commands.py @@ -29,7 +29,7 @@ class Core: async def load(self, ctx, *, cog_name: str): """Loads a package""" try: - spec = ctx.bot.cog_mgr.find_cog(cog_name) + spec = await ctx.bot.cog_mgr.find_cog(cog_name) except NoModuleFound: await ctx.send("No module by that name was found in any" " cog path.") @@ -63,7 +63,7 @@ class Core: ctx.bot.unload_extension(cog_name) self.cleanup_and_refresh_modules(cog_name) try: - spec = ctx.bot.cog_mgr.find_cog(cog_name) + spec = await ctx.bot.cog_mgr.find_cog(cog_name) ctx.bot.load_extension(spec) except Exception as e: log.exception("Package reloading failed", exc_info=e) diff --git a/core/drivers/red_base.py b/core/drivers/red_base.py index a0d838272..708063003 100644 --- a/core/drivers/red_base.py +++ b/core/drivers/red_base.py @@ -5,7 +5,7 @@ class BaseDriver: def get_driver(self): raise NotImplementedError - def get(self, identifiers: Tuple[str]): + async def get(self, identifiers: Tuple[str]): raise NotImplementedError async def set(self, identifiers: Tuple[str], value): diff --git a/core/drivers/red_json.py b/core/drivers/red_json.py index 7dfc64a5a..81f242238 100644 --- a/core/drivers/red_json.py +++ b/core/drivers/red_json.py @@ -32,7 +32,7 @@ class JSON(BaseDriver): def get_driver(self): return self - def get(self, identifiers: Tuple[str]): + async def get(self, identifiers: Tuple[str]): partial = self.data for i in identifiers: partial = partial[i] diff --git a/core/events.py b/core/events.py index 14ccc7dca..fee13f8a0 100644 --- a/core/events.py +++ b/core/events.py @@ -34,11 +34,11 @@ def init_events(bot, cli_flags): if cli_flags.no_cogs is False: print("Loading packages...") failed = [] - packages = bot.db.packages() + packages = await bot.db.packages() for package in packages: try: - spec = bot.cog_mgr.find_cog(package) + spec = await bot.cog_mgr.find_cog(package) bot.load_extension(spec) except Exception as e: log.exception("Failed to load package {}".format(package), diff --git a/main.py b/main.py index 8197a7a9a..d741355ce 100644 --- a/main.py +++ b/main.py @@ -73,6 +73,17 @@ def determine_main_folder() -> Path: return Path(os.path.dirname(__file__)).resolve() +async def _get_prefix_and_token(red, indict): + """ + Again, please blame <@269933075037814786> for this. + :param indict: + :return: + """ + indict['token'] = await red.db.token() + indict['prefix'] = await red.db.prefix() + indict['enable_sentry'] = await red.db.enable_sentry() + + if __name__ == '__main__': cli_flags = parse_cli_flags() log, sentry_log = init_loggers(cli_flags) @@ -89,8 +100,13 @@ if __name__ == '__main__': if cli_flags.dev: red.add_cog(Dev()) - token = os.environ.get("RED_TOKEN", red.db.token()) - prefix = cli_flags.prefix or red.db.prefix() + loop = asyncio.get_event_loop() + tmp_data = {} + loop.run_until_complete(_get_prefix_and_token(red, tmp_data)) + + token = os.environ.get("RED_TOKEN", tmp_data['token']) + prefix = cli_flags.prefix or tmp_data['prefix'] + enable_sentry = tmp_data['enable_sentry'] if token is None or not prefix: if cli_flags.no_prompt is False: @@ -102,13 +118,14 @@ if __name__ == '__main__': log.critical("Token and prefix must be set in order to login.") sys.exit(1) - if red.db.enable_sentry() is None: + if enable_sentry is None: ask_sentry(red) - if red.db.enable_sentry(): + loop.run_until_complete(_get_prefix_and_token(red, tmp_data)) + + if tmp_data['enable_sentry']: init_sentry_logging(sentry_log) - loop = asyncio.get_event_loop() cleanup_tasks = True try: diff --git a/tests/cogs/test_alias.py b/tests/cogs/test_alias.py index f9c59b7a2..04bfd1750 100644 --- a/tests/cogs/test_alias.py +++ b/tests/cogs/test_alias.py @@ -16,12 +16,14 @@ def test_is_valid_alias_name(alias): assert alias.is_valid_alias_name("not valid name") is False -def test_empty_guild_aliases(alias, empty_guild): - assert list(alias.unloaded_aliases(empty_guild)) == [] +@pytest.mark.asyncio +async def test_empty_guild_aliases(alias, empty_guild): + assert list(await alias.unloaded_aliases(empty_guild)) == [] -def test_empty_global_aliases(alias): - assert list(alias.unloaded_global_aliases()) == [] +@pytest.mark.asyncio +async def test_empty_global_aliases(alias): + assert list(await alias.unloaded_global_aliases()) == [] async def create_test_guild_alias(alias, ctx): @@ -36,7 +38,7 @@ async def create_test_global_alias(alias, ctx): async def test_add_guild_alias(alias, ctx): await create_test_guild_alias(alias, ctx) - is_alias, alias_obj = alias.is_alias(ctx.guild, "test") + is_alias, alias_obj = await alias.is_alias(ctx.guild, "test") assert is_alias is True assert alias_obj.global_ is False @@ -44,19 +46,19 @@ async def test_add_guild_alias(alias, ctx): @pytest.mark.asyncio async def test_delete_guild_alias(alias, ctx): await create_test_guild_alias(alias, ctx) - is_alias, _ = alias.is_alias(ctx.guild, "test") + is_alias, _ = await alias.is_alias(ctx.guild, "test") assert is_alias is True await alias.delete_alias(ctx, "test") - is_alias, _ = alias.is_alias(ctx.guild, "test") + is_alias, _ = await alias.is_alias(ctx.guild, "test") assert is_alias is False @pytest.mark.asyncio async def test_add_global_alias(alias, ctx): await create_test_global_alias(alias, ctx) - is_alias, alias_obj = alias.is_alias(ctx.guild, "test") + is_alias, alias_obj = await alias.is_alias(ctx.guild, "test") assert is_alias is True assert alias_obj.global_ is True @@ -65,7 +67,7 @@ async def test_add_global_alias(alias, ctx): @pytest.mark.asyncio async def test_delete_global_alias(alias, ctx): await create_test_global_alias(alias, ctx) - is_alias, alias_obj = alias.is_alias(ctx.guild, "test") + is_alias, alias_obj = await alias.is_alias(ctx.guild, "test") assert is_alias is True assert alias_obj.global_ is True diff --git a/tests/cogs/test_economy.py b/tests/cogs/test_economy.py index 70bd6e98a..4c68a87c7 100644 --- a/tests/cogs/test_economy.py +++ b/tests/cogs/test_economy.py @@ -11,13 +11,14 @@ def bank(config): return bank -def test_bank_register(bank, ctx): - default_bal = bank.get_default_balance(ctx.guild) - assert default_bal == bank.get_account(ctx.author).balance +@pytest.mark.asyncio +async def test_bank_register(bank, ctx): + default_bal = await bank.get_default_balance(ctx.guild) + assert default_bal == (await bank.get_account(ctx.author)).balance async def has_account(member, bank): - balance = bank.get_balance(member) + balance = await bank.get_balance(member) if balance == 0: balance = 1 await bank.set_balance(member, balance) @@ -27,11 +28,11 @@ async def has_account(member, bank): async def test_bank_transfer(bank, member_factory): mbr1 = member_factory.get() mbr2 = member_factory.get() - bal1 = bank.get_account(mbr1).balance - bal2 = bank.get_account(mbr2).balance + bal1 = (await bank.get_account(mbr1)).balance + bal2 = (await bank.get_account(mbr2)).balance await bank.transfer_credits(mbr1, mbr2, 50) - newbal1 = bank.get_account(mbr1).balance - newbal2 = bank.get_account(mbr2).balance + newbal1 = (await bank.get_account(mbr1)).balance + newbal2 = (await bank.get_account(mbr2)).balance assert bal1 - 50 == newbal1 assert bal2 + 50 == newbal2 @@ -40,16 +41,16 @@ async def test_bank_transfer(bank, member_factory): async def test_bank_set(bank, member_factory): mbr = member_factory.get() await bank.set_balance(mbr, 250) - acc = bank.get_account(mbr) + acc = await bank.get_account(mbr) assert acc.balance == 250 @pytest.mark.asyncio async def test_bank_can_spend(bank, member_factory): mbr = member_factory.get() - canspend = bank.can_spend(mbr, 50) - assert canspend == (50 < bank.get_default_balance(mbr.guild)) + canspend = await bank.can_spend(mbr, 50) + assert canspend == (50 < await bank.get_default_balance(mbr.guild)) await bank.set_balance(mbr, 200) - acc = bank.get_account(mbr) - canspendnow = bank.can_spend(mbr, 100) + acc = await bank.get_account(mbr) + canspendnow = await bank.can_spend(mbr, 100) assert canspendnow diff --git a/tests/core/test_cog_manager.py b/tests/core/test_cog_manager.py index ab4470418..2af284ddf 100644 --- a/tests/core/test_cog_manager.py +++ b/tests/core/test_cog_manager.py @@ -14,16 +14,17 @@ def default_dir(red): return red.main_dir -def test_ensure_cogs_in_paths(cog_mgr, default_dir): +@pytest.mark.asyncio +async def test_ensure_cogs_in_paths(cog_mgr, default_dir): cogs_dir = default_dir / 'cogs' - assert cogs_dir in cog_mgr.paths + assert cogs_dir in await cog_mgr.paths() @pytest.mark.asyncio async def test_install_path_set(cog_mgr: cog_manager.CogManager, tmpdir): path = Path(str(tmpdir)) await cog_mgr.set_install_path(path) - assert cog_mgr.install_path == path + assert await cog_mgr.install_path() == path @pytest.mark.asyncio @@ -38,7 +39,7 @@ async def test_install_path_set_bad(cog_mgr): async def test_add_path(cog_mgr, tmpdir): path = Path(str(tmpdir)) await cog_mgr.add_path(path) - assert path in cog_mgr.paths + assert path in await cog_mgr.paths() @pytest.mark.asyncio @@ -54,4 +55,4 @@ async def test_remove_path(cog_mgr, tmpdir): path = Path(str(tmpdir)) await cog_mgr.add_path(path) await cog_mgr.remove_path(path) - assert path not in cog_mgr.paths + assert path not in await cog_mgr.paths() diff --git a/tests/core/test_config.py b/tests/core/test_config.py index 2e34218c3..40d859a52 100644 --- a/tests/core/test_config.py +++ b/tests/core/test_config.py @@ -2,10 +2,11 @@ import pytest #region Register Tests -def test_config_register_global(config): +@pytest.mark.asyncio +async def test_config_register_global(config): config.register_global(enabled=False) assert config.defaults["GLOBAL"]["enabled"] is False - assert config.enabled() is False + assert await config.enabled() is False def test_config_register_global_badvalues(config): @@ -13,61 +14,69 @@ def test_config_register_global_badvalues(config): config.register_global(**{"invalid var name": True}) -def test_config_register_guild(config, empty_guild): +@pytest.mark.asyncio +async def test_config_register_guild(config, empty_guild): config.register_guild(enabled=False, some_list=[], some_dict={}) assert config.defaults[config.GUILD]["enabled"] is False assert config.defaults[config.GUILD]["some_list"] == [] assert config.defaults[config.GUILD]["some_dict"] == {} - assert config.guild(empty_guild).enabled() is False - assert config.guild(empty_guild).some_list() == [] - assert config.guild(empty_guild).some_dict() == {} + assert await config.guild(empty_guild).enabled() is False + assert await config.guild(empty_guild).some_list() == [] + assert await config.guild(empty_guild).some_dict() == {} -def test_config_register_channel(config, empty_channel): +@pytest.mark.asyncio +async def test_config_register_channel(config, empty_channel): config.register_channel(enabled=False) assert config.defaults[config.CHANNEL]["enabled"] is False - assert config.channel(empty_channel).enabled() is False + assert await config.channel(empty_channel).enabled() is False -def test_config_register_role(config, empty_role): +@pytest.mark.asyncio +async def test_config_register_role(config, empty_role): config.register_role(enabled=False) assert config.defaults[config.ROLE]["enabled"] is False - assert config.role(empty_role).enabled() is False + assert await config.role(empty_role).enabled() is False -def test_config_register_member(config, empty_member): +@pytest.mark.asyncio +async def test_config_register_member(config, empty_member): config.register_member(some_number=-1) assert config.defaults[config.MEMBER]["some_number"] == -1 - assert config.member(empty_member).some_number() == -1 + assert await config.member(empty_member).some_number() == -1 -def test_config_register_user(config, empty_user): +@pytest.mark.asyncio +async def test_config_register_user(config, empty_user): config.register_user(some_value=None) assert config.defaults[config.USER]["some_value"] is None - assert config.user(empty_user).some_value() is None + assert await config.user(empty_user).some_value() is None -def test_config_force_register_global(config_fr): +@pytest.mark.asyncio +async def test_config_force_register_global(config_fr): with pytest.raises(AttributeError): - config_fr.enabled() + await config_fr.enabled() config_fr.register_global(enabled=True) - assert config_fr.enabled() is True + assert await config_fr.enabled() is True #endregion # Test nested registration -def test_nested_registration(config): +@pytest.mark.asyncio +async def test_nested_registration(config): config.register_global(foo__bar__baz=False) - assert config.foo.bar.baz() is False + assert await config.foo.bar.baz() is False -def test_nested_registration_asdict(config): +@pytest.mark.asyncio +async def test_nested_registration_asdict(config): defaults = {'bar': {'baz': False}} config.register_global(foo=defaults) - assert config.foo.bar.baz() is False + assert await config.foo.bar.baz() is False @pytest.mark.asyncio @@ -75,20 +84,22 @@ async def test_nested_registration_and_changing(config): defaults = {'bar': {'baz': False}} config.register_global(foo=defaults) - assert config.foo.bar.baz() is False + assert await config.foo.bar.baz() is False with pytest.raises(ValueError): await config.foo.set(True) -def test_doubleset_default(config): +@pytest.mark.asyncio +async def test_doubleset_default(config): config.register_global(foo=True) config.register_global(foo=False) - assert config.foo() is False + assert await config.foo() is False -def test_nested_registration_multidict(config): +@pytest.mark.asyncio +async def test_nested_registration_multidict(config): defaults = { "foo": { "bar": { @@ -99,8 +110,8 @@ def test_nested_registration_multidict(config): } config.register_global(**defaults) - assert config.foo.bar.baz() is True - assert config.blah() is True + assert await config.foo.bar.baz() is True + assert await config.blah() is True def test_nested_group_value_badreg(config): @@ -109,56 +120,66 @@ def test_nested_group_value_badreg(config): config.register_global(foo__bar=False) -def test_nested_toplevel_reg(config): +@pytest.mark.asyncio +async def test_nested_toplevel_reg(config): defaults = {'bar': True, 'baz': False} config.register_global(foo=defaults) - assert config.foo.bar() is True - assert config.foo.baz() is False + assert await config.foo.bar() is True + assert await config.foo.baz() is False -def test_nested_overlapping(config): +@pytest.mark.asyncio +async def test_nested_overlapping(config): config.register_global(foo__bar=True) config.register_global(foo__baz=False) - assert config.foo.bar() is True - assert config.foo.baz() is False + assert await config.foo.bar() is True + assert await config.foo.baz() is False -def test_nesting_nofr(config): +@pytest.mark.asyncio +async def test_nesting_nofr(config): config.register_global(foo={}) - assert config.foo.bar() is None - assert config.foo() == {} + assert await config.foo.bar() is None + assert await config.foo() == {} -#region Default Value Overrides -def test_global_default_override(config): - assert config.enabled(True) is True +# region Default Value Overrides +@pytest.mark.asyncio +async def test_global_default_override(config): + assert await config.enabled(True) is True -def test_global_default_nofr(config): - assert config.nofr() is None - assert config.nofr(True) is True +@pytest.mark.asyncio +async def test_global_default_nofr(config): + assert await config.nofr() is None + assert await config.nofr(True) is True -def test_guild_default_override(config, empty_guild): - assert config.guild(empty_guild).enabled(True) is True +@pytest.mark.asyncio +async def test_guild_default_override(config, empty_guild): + assert await config.guild(empty_guild).enabled(True) is True -def test_channel_default_override(config, empty_channel): - assert config.channel(empty_channel).enabled(True) is True +@pytest.mark.asyncio +async def test_channel_default_override(config, empty_channel): + assert await config.channel(empty_channel).enabled(True) is True -def test_role_default_override(config, empty_role): - assert config.role(empty_role).enabled(True) is True +@pytest.mark.asyncio +async def test_role_default_override(config, empty_role): + assert await config.role(empty_role).enabled(True) is True -def test_member_default_override(config, empty_member): - assert config.member(empty_member).enabled(True) is True +@pytest.mark.asyncio +async def test_member_default_override(config, empty_member): + assert await config.member(empty_member).enabled(True) is True -def test_user_default_override(config, empty_user): - assert config.user(empty_user).some_value(True) is True +@pytest.mark.asyncio +async def test_user_default_override(config, empty_user): + assert await config.user(empty_user).some_value(True) is True #endregion @@ -166,32 +187,32 @@ def test_user_default_override(config, empty_user): @pytest.mark.asyncio async def test_set_global(config): await config.enabled.set(True) - assert config.enabled() is True + assert await config.enabled() is True @pytest.mark.asyncio async def test_set_guild(config, empty_guild): await config.guild(empty_guild).enabled.set(True) - assert config.guild(empty_guild).enabled() is True + assert await config.guild(empty_guild).enabled() is True - curr_list = config.guild(empty_guild).some_list([1, 2, 3]) + curr_list = await config.guild(empty_guild).some_list([1, 2, 3]) assert curr_list == [1, 2, 3] curr_list.append(4) await config.guild(empty_guild).some_list.set(curr_list) - assert config.guild(empty_guild).some_list() == curr_list + assert await config.guild(empty_guild).some_list() == curr_list @pytest.mark.asyncio async def test_set_channel(config, empty_channel): await config.channel(empty_channel).enabled.set(True) - assert config.channel(empty_channel).enabled() is True + assert await config.channel(empty_channel).enabled() is True @pytest.mark.asyncio async def test_set_channel_no_register(config, empty_channel): await config.channel(empty_channel).no_register.set(True) - assert config.channel(empty_channel).no_register() is True + assert await config.channel(empty_channel).no_register() is True #endregion @@ -200,11 +221,12 @@ async def test_set_channel_no_register(config, empty_channel): async def test_set_dynamic_attr(config): await config.set_attr("foobar", True) - assert config.foobar() is True + assert await config.foobar() is True -def test_get_dynamic_attr(config): - assert config.get_attr("foobaz", True) is True +@pytest.mark.asyncio +async def test_get_dynamic_attr(config): + assert await config.get_attr("foobaz", True) is True # Member Group testing @@ -212,7 +234,7 @@ def test_get_dynamic_attr(config): async def test_membergroup_allguilds(config, empty_member): await config.member(empty_member).foo.set(False) - all_servers = config.member(empty_member).all_guilds() + all_servers = await config.member(empty_member).all_guilds() assert str(empty_member.guild.id) in all_servers @@ -220,7 +242,7 @@ async def test_membergroup_allguilds(config, empty_member): async def test_membergroup_allmembers(config, empty_member): await config.member(empty_member).foo.set(False) - all_members = config.member(empty_member).all() + all_members = await config.member(empty_member).all() assert str(empty_member.id) in all_members @@ -232,13 +254,13 @@ async def test_global_clear(config): await config.foo.set(False) await config.bar.set(True) - assert config.foo() is False - assert config.bar() is True + assert await config.foo() is False + assert await config.bar() is True await config.clear() - assert config.foo() is True - assert config.bar() is False + assert await config.foo() is True + assert await config.bar() is False @pytest.mark.asyncio @@ -247,17 +269,17 @@ async def test_member_clear(config, member_factory): m1 = member_factory.get() await config.member(m1).foo.set(False) - assert config.member(m1).foo() is False + assert await config.member(m1).foo() is False m2 = member_factory.get() await config.member(m2).foo.set(False) - assert config.member(m2).foo() is False + assert await config.member(m2).foo() is False assert m1.guild.id != m2.guild.id await config.member(m1).clear() - assert config.member(m1).foo() is True - assert config.member(m2).foo() is False + assert await config.member(m1).foo() is True + assert await config.member(m2).foo() is False @pytest.mark.asyncio @@ -269,11 +291,11 @@ async def test_member_clear_all(config, member_factory): server_ids.append(member.guild.id) member = member_factory.get() - assert len(config.member(member).all_guilds()) == len(server_ids) + assert len(await config.member(member).all_guilds()) == len(server_ids) await config.member(member).clear_all() - assert len(config.member(member).all_guilds()) == 0 + assert len(await config.member(member).all_guilds()) == 0 # Get All testing @@ -284,7 +306,7 @@ async def test_user_get_all_from_kind(config, user_factory): await config.user(user).foo.set(True) user = user_factory.get() - all_data = config.user(user).all_from_kind() + all_data = await config.user(user).all_from_kind() assert len(all_data) == 5 @@ -294,4 +316,4 @@ async def test_user_getalldata(config, user_factory): user = user_factory.get() await config.user(user).foo.set(False) - assert "foo" in config.user(user).all() + assert "foo" in await config.user(user).all()