[Config] Asynchronous getters (#907)

* Make config get async

* Asyncify alias

* Asyncify bank

* Asyncify cog manager

* IT BOOTS

* Asyncify core commands

* Asyncify repo manager

* Asyncify downloader

* Asyncify economy

* Asyncify alias TESTS

* Asyncify economy TESTS

* Asyncify downloader TESTS

* Asyncify config TESTS

* A bank thing

* Asyncify Bank cog

* Warning message in docs

* Update docs with await syntax

* Update docs with await syntax
This commit is contained in:
Will 2017-08-11 21:43:21 -04:00 committed by GitHub
parent cf8e11238c
commit de912a3cfb
18 changed files with 371 additions and 296 deletions

View File

@ -38,26 +38,26 @@ class Alias:
self._aliases.register_global(**self.default_global_settings)
self._aliases.register_guild(**self.default_guild_settings)
def unloaded_aliases(self, guild: discord.Guild) -> Generator[AliasEntry, None, None]:
return (AliasEntry.from_json(d) for d in self._aliases.guild(guild).entries())
async def unloaded_aliases(self, guild: discord.Guild) -> Generator[AliasEntry, None, None]:
return (AliasEntry.from_json(d) for d in (await self._aliases.guild(guild).entries()))
def unloaded_global_aliases(self) -> Generator[AliasEntry, None, None]:
return (AliasEntry.from_json(d) for d in self._aliases.entries())
async def unloaded_global_aliases(self) -> Generator[AliasEntry, None, None]:
return (AliasEntry.from_json(d) for d in (await self._aliases.entries()))
def loaded_aliases(self, guild: discord.Guild) -> Generator[AliasEntry, None, None]:
async def loaded_aliases(self, guild: discord.Guild) -> Generator[AliasEntry, None, None]:
return (AliasEntry.from_json(d, bot=self.bot)
for d in self._aliases.guild(guild).entries())
for d in (await self._aliases.guild(guild).entries()))
def loaded_global_aliases(self) -> Generator[AliasEntry, None, None]:
return (AliasEntry.from_json(d, bot=self.bot) for d in self._aliases.entries())
async def loaded_global_aliases(self) -> Generator[AliasEntry, None, None]:
return (AliasEntry.from_json(d, bot=self.bot) for d in (await self._aliases.entries()))
def is_alias(self, guild: discord.Guild, alias_name: str,
async def is_alias(self, guild: discord.Guild, alias_name: str,
server_aliases: Iterable[AliasEntry]=()) -> (bool, AliasEntry):
if not server_aliases:
server_aliases = self.unloaded_aliases(guild)
server_aliases = await self.unloaded_aliases(guild)
global_aliases = self.unloaded_global_aliases()
global_aliases = await self.unloaded_global_aliases()
for aliases in (server_aliases, global_aliases):
for alias in aliases:
@ -79,11 +79,11 @@ class Alias:
alias = AliasEntry(alias_name, command, ctx.author, global_=global_)
if global_:
curr_aliases = self._aliases.entries()
curr_aliases = await self._aliases.entries()
curr_aliases.append(alias.to_json())
await self._aliases.entries.set(curr_aliases)
else:
curr_aliases = self._aliases.guild(ctx.guild).entries()
curr_aliases = await self._aliases.guild(ctx.guild).entries()
curr_aliases.append(alias.to_json())
await self._aliases.guild(ctx.guild).entries.set(curr_aliases)
@ -94,10 +94,10 @@ class Alias:
async def delete_alias(self, ctx: commands.Context, alias_name: str,
global_: bool=False) -> bool:
if global_:
aliases = self.unloaded_global_aliases()
aliases = await self.unloaded_global_aliases()
setter_func = self._aliases.entries.set
else:
aliases = self.unloaded_aliases(ctx.guild)
aliases = await self.unloaded_aliases(ctx.guild)
setter_func = self._aliases.guild(ctx.guild).entries.set
did_delete_alias = False
@ -161,7 +161,7 @@ class Alias:
except IndexError:
return False
is_alias, alias = self.is_alias(message.guild, potential_alias, server_aliases=aliases)
is_alias, alias = await self.is_alias(message.guild, potential_alias, server_aliases=aliases)
if is_alias:
await self.call_alias(message, prefix, alias)
@ -206,7 +206,7 @@ class Alias:
" name is already a command on this bot.").format(alias_name))
return
is_alias, _ = self.is_alias(ctx.guild, alias_name)
is_alias, _ = await self.is_alias(ctx.guild, alias_name)
if is_alias:
await ctx.send(("You attempted to create a new alias"
" with the name {} but that"
@ -285,7 +285,7 @@ class Alias:
@commands.guild_only()
async def _show_alias(self, ctx: commands.Context, alias_name: str):
"""Shows what command the alias executes."""
is_alias, alias = self.is_alias(ctx.guild, alias_name)
is_alias, alias = await self.is_alias(ctx.guild, alias_name)
if is_alias:
await ctx.send(("The `{}` alias will execute the"
@ -299,7 +299,7 @@ class Alias:
"""
Deletes an existing alias on this server.
"""
aliases = self.unloaded_aliases(ctx.guild)
aliases = await self.unloaded_aliases(ctx.guild)
try:
next(aliases)
except StopIteration:
@ -317,7 +317,7 @@ class Alias:
"""
Deletes an existing global alias.
"""
aliases = self.unloaded_global_aliases()
aliases = await self.unloaded_global_aliases()
try:
next(aliases)
except StopIteration:
@ -336,7 +336,7 @@ class Alias:
"""
Lists the available aliases on this server.
"""
names = ["Aliases:", ] + sorted(["+ " + a.name for a in self.unloaded_aliases(ctx.guild)])
names = ["Aliases:", ] + sorted(["+ " + a.name for a in (await self.unloaded_aliases(ctx.guild))])
if len(names) == 0:
await ctx.send("There are no aliases on this server.")
else:
@ -347,16 +347,16 @@ class Alias:
"""
Lists the available global aliases on this bot.
"""
names = ["Aliases:", ] + sorted(["+ " + a.name for a in self.unloaded_global_aliases()])
names = ["Aliases:", ] + sorted(["+ " + a.name for a in await self.unloaded_global_aliases()])
if len(names) == 0:
await ctx.send("There are no aliases on this server.")
else:
await ctx.send(box("\n".join(names), "diff"))
async def on_message(self, message: discord.Message):
aliases = list(self.unloaded_global_aliases())
aliases = list(await self.unloaded_global_aliases())
if message.guild is not None:
aliases = aliases + list(self.unloaded_aliases(message.guild))
aliases = aliases + list(await self.unloaded_aliases(message.guild))
if len(aliases) == 0:
return

View File

@ -6,7 +6,7 @@ from core.bot import Red # Only used for type hints
def check_global_setting_guildowner():
async def pred(ctx: commands.Context):
if bank.is_global():
if await bank.is_global():
return checks.is_owner()
else:
return checks.guildowner_or_permissions(administrator=True)
@ -15,7 +15,7 @@ def check_global_setting_guildowner():
def check_global_setting_admin():
async def pred(ctx: commands.Context):
if bank.is_global():
if await bank.is_global():
return checks.is_owner()
else:
return checks.admin_or_permissions(manage_guild=True)
@ -43,7 +43,7 @@ class Bank:
"""Toggles whether the bank is global or not
If the bank is global, it will become per-guild
If the bank is per-guild, it will become global"""
cur_setting = bank.is_global()
cur_setting = await bank.is_global()
await bank.set_global(not cur_setting, ctx.author)
word = "per-guild" if cur_setting else "global"

View File

@ -49,22 +49,20 @@ class Downloader:
self._repo_manager = RepoManager(self.conf)
@property
def cog_install_path(self):
async def cog_install_path(self):
"""
Returns the current cog install path.
:return:
"""
return self.bot.cog_mgr.install_path
return await self.bot.cog_mgr.install_path()
@property
def installed_cogs(self) -> Tuple[Installable]:
async def installed_cogs(self) -> Tuple[Installable]:
"""
Returns the dictionary mapping cog name to install location
and repo name.
:return:
"""
installed = self.conf.installed()
installed = await self.conf.installed()
# noinspection PyTypeChecker
return tuple(Installable.from_json(v) for v in installed)
@ -74,7 +72,7 @@ class Downloader:
:param cog:
:return:
"""
installed = self.conf.installed()
installed = await self.conf.installed()
cog_json = cog.to_json()
if cog_json not in installed:
@ -87,7 +85,7 @@ class Downloader:
:param cog:
:return:
"""
installed = self.conf.installed()
installed = await self.conf.installed()
cog_json = cog.to_json()
if cog_json in installed:
@ -102,7 +100,7 @@ class Downloader:
"""
failed = []
for cog in cogs:
if not await cog.copy_to(self.cog_install_path):
if not await cog.copy_to(await self.cog_install_path()):
failed.append(cog)
# noinspection PyTypeChecker
@ -249,7 +247,7 @@ class Downloader:
" `{}`: `{}`".format(cog.name, cog.requirements))
return
await repo_name.install_cog(cog, self.cog_install_path)
await repo_name.install_cog(cog, await self.cog_install_path())
await self._add_to_installed(cog)
@ -266,7 +264,7 @@ class Downloader:
# noinspection PyUnresolvedReferences,PyProtectedMember
real_name = cog_name.name
poss_installed_path = self.cog_install_path / real_name
poss_installed_path = (await self.cog_install_path()) / real_name
if poss_installed_path.exists():
await self._delete_cog(poss_installed_path)
# noinspection PyTypeChecker
@ -284,7 +282,7 @@ class Downloader:
"""
if cog_name is None:
updated = await self._repo_manager.update_all_repos()
installed_cogs = set(self.installed_cogs)
installed_cogs = set(await self.installed_cogs())
updated_cogs = set(cog for repo in updated.keys() for cog in repo.available_cogs)
installed_and_updated = updated_cogs & installed_cogs
@ -325,14 +323,14 @@ class Downloader:
msg = "Information on {}:\n{}".format(cog.name, cog.description or "")
await ctx.send(box(msg))
def is_installed(self, cog_name: str) -> (bool, Union[Installable, None]):
async def is_installed(self, cog_name: str) -> (bool, Union[Installable, None]):
"""
Checks to see if a cog with the given name was installed
through Downloader.
:param cog_name:
:return: is_installed, Installable
"""
for installable in self.installed_cogs:
for installable in await self.installed_cogs():
if installable.name == cog_name:
return True, installable
return False, None
@ -384,7 +382,7 @@ class Downloader:
# Check if in installed cogs
cog_name = self.cog_name_from_instance(command.instance)
installed, cog_installable = self.is_installed(cog_name)
installed, cog_installable = await self.is_installed(cog_name)
if installed:
msg = self.format_findcog_info(command_name, cog_installable)
else:

View File

@ -430,7 +430,10 @@ class RepoManager:
self.repos_folder = Path(__file__).parent / 'repos'
self._repos = self._load_repos() # str_name: Repo
self._repos = {}
loop = asyncio.get_event_loop()
loop.run_until_complete(self._load_repos(set=True)) # str_name: Repo
def does_repo_exist(self, name: str) -> bool:
return name in self._repos
@ -494,7 +497,6 @@ class RepoManager:
shutil.rmtree(str(repo.folder_path))
repos = self.downloader_config.repos()
try:
del self._repos[name]
except KeyError:
@ -518,11 +520,14 @@ class RepoManager:
await self._save_repos()
return ret
def _load_repos(self) -> MutableMapping[str, Repo]:
return {
async def _load_repos(self, set=False) -> MutableMapping[str, Repo]:
ret = {
name: Repo.from_json(data) for name, data in
self.downloader_config.repos().items()
(await self.downloader_config.repos()).items()
}
if set:
self._repos = ret
return ret
async def _save_repos(self):
repo_json_info = {name: r.to_json() for name, r in self._repos.items()}

View File

@ -72,9 +72,9 @@ SLOT_PAYOUTS_MSG = ("Slot machine payouts:\n"
def guild_only_check():
async def pred(ctx: commands.Context):
if bank.is_global():
if await bank.is_global():
return True
elif not bank.is_global() and ctx.guild is not None:
elif not await bank.is_global() and ctx.guild is not None:
return True
else:
return False
@ -146,8 +146,8 @@ class Economy:
if user is None:
user = ctx.author
bal = bank.get_balance(user)
currency = bank.get_currency_name(ctx.guild)
bal = await bank.get_balance(user)
currency = await bank.get_currency_name(ctx.guild)
await ctx.send("{}'s balance is {} {}".format(
user.display_name, bal, currency))
@ -156,7 +156,7 @@ class Economy:
async def transfer(self, ctx: commands.Context, to: discord.Member, amount: int):
"""Transfer currency to other users"""
from_ = ctx.author
currency = bank.get_currency_name(ctx.guild)
currency = await bank.get_currency_name(ctx.guild)
try:
await bank.transfer_credits(from_, to, amount)
@ -206,12 +206,12 @@ class Economy:
await ctx.send(
"This will delete all bank accounts for {}.\nIf you're sure, type "
"{}bank reset yes".format(
self.bot.user.name if bank.is_global() else "this guild",
self.bot.user.name if await bank.is_global() else "this guild",
ctx.prefix
)
)
else:
if bank.is_global():
if await bank.is_global():
# Bank being global means that the check would cause only
# the owner and any co-owners to be able to run the command
# so if we're in the function, it's safe to assume that the
@ -232,18 +232,18 @@ class Economy:
guild = ctx.guild
cur_time = calendar.timegm(ctx.message.created_at.utctimetuple())
credits_name = bank.get_currency_name(ctx.guild)
if bank.is_global():
next_payday = self.config.user(author).next_payday()
credits_name = await bank.get_currency_name(ctx.guild)
if await bank.is_global():
next_payday = await self.config.user(author).next_payday()
if cur_time >= next_payday:
await bank.deposit_credits(author, self.config.PAYDAY_CREDITS())
next_payday = cur_time + self.config.PAYDAY_TIME()
await bank.deposit_credits(author, await self.config.PAYDAY_CREDITS())
next_payday = cur_time + await self.config.PAYDAY_TIME()
await self.config.user(author).next_payday.set(next_payday)
await ctx.send(
"{} Here, take some {}. Enjoy! (+{}"
" {}!)".format(
author.mention, credits_name,
str(self.config.PAYDAY_CREDITS()),
str(await self.config.PAYDAY_CREDITS()),
credits_name
)
)
@ -254,16 +254,16 @@ class Economy:
" wait {}.".format(author.mention, dtime)
)
else:
next_payday = self.config.member(author).next_payday()
next_payday = await self.config.member(author).next_payday()
if cur_time >= next_payday:
await bank.deposit_credits(author, self.config.guild(guild).PAYDAY_CREDITS())
next_payday = cur_time + self.config.guild(guild).PAYDAY_TIME()
await bank.deposit_credits(author, await self.config.guild(guild).PAYDAY_CREDITS())
next_payday = cur_time + await self.config.guild(guild).PAYDAY_TIME()
await self.config.member(author).next_payday.set(next_payday)
await ctx.send(
"{} Here, take some {}. Enjoy! (+{}"
" {}!)".format(
author.mention, credits_name,
str(self.config.guild(guild).PAYDAY_CREDITS()),
str(await self.config.guild(guild).PAYDAY_CREDITS()),
credits_name))
else:
dtime = self.display_time(next_payday - cur_time)
@ -282,10 +282,10 @@ class Economy:
if top < 1:
top = 10
if bank.is_global():
bank_sorted = sorted(bank.get_global_accounts(ctx.author),
bank_sorted = sorted(await bank.get_global_accounts(ctx.author),
key=lambda x: x.balance, reverse=True)
else:
bank_sorted = sorted(bank.get_guild_accounts(guild),
bank_sorted = sorted(await bank.get_guild_accounts(guild),
key=lambda x: x.balance, reverse=True)
if len(bank_sorted) < top:
top = len(bank_sorted)
@ -320,14 +320,14 @@ class Economy:
author = ctx.author
guild = ctx.guild
channel = ctx.channel
if bank.is_global():
valid_bid = self.config.SLOT_MIN() <= bid <= self.config.SLOT_MAX()
slot_time = self.config.SLOT_TIME()
last_slot = self.config.user(author).last_slot()
if await bank.is_global():
valid_bid = await self.config.SLOT_MIN() <= bid <= await self.config.SLOT_MAX()
slot_time = await self.config.SLOT_TIME()
last_slot = await self.config.user(author).last_slot()
else:
valid_bid = self.config.guild(guild).SLOT_MIN() <= bid <= self.config.guild(guild).SLOT_MAX()
slot_time = self.config.guild(guild).SLOT_TIME()
last_slot = self.config.member(author).last_slot()
valid_bid = await self.config.guild(guild).SLOT_MIN() <= bid <= await self.config.guild(guild).SLOT_MAX()
slot_time = await self.config.guild(guild).SLOT_TIME()
last_slot = await self.config.member(author).last_slot()
now = calendar.timegm(ctx.message.created_at.utctimetuple())
if (now - last_slot) < slot_time:
@ -336,10 +336,10 @@ class Economy:
if not valid_bid:
await ctx.send("That's an invalid bid amount, sorry :/")
return
if not bank.can_spend(author, bid):
if not await bank.can_spend(author, bid):
await ctx.send("You ain't got enough money, friend.")
return
if bank.is_global():
if await bank.is_global():
await self.config.user(author).last_slot.set(now)
else:
await self.config.member(author).last_slot.set(now)
@ -379,7 +379,7 @@ class Economy:
payout = PAYOUTS["2 symbols"]
if payout:
then = bank.get_balance(author)
then = await bank.get_balance(author)
pay = payout["payout"](bid)
now = then - bid + pay
await bank.set_balance(author, now)
@ -387,7 +387,7 @@ class Economy:
"".format(slot, author.mention,
payout["phrase"], bid, then, now))
else:
then = bank.get_balance(author)
then = await bank.get_balance(author)
await bank.withdraw_credits(author, bid)
now = then - bid
await channel.send("{}\n{} Nothing!\nYour bid: {}\n{}{}!"
@ -402,18 +402,18 @@ class Economy:
if ctx.invoked_subcommand is None:
await self.bot.send_cmd_help(ctx)
if bank.is_global():
slot_min = self.config.SLOT_MIN()
slot_max = self.config.SLOT_MAX()
slot_time = self.config.SLOT_TIME()
payday_time = self.config.PAYDAY_TIME()
payday_amount = self.config.PAYDAY_CREDITS()
slot_min = await self.config.SLOT_MIN()
slot_max = await self.config.SLOT_MAX()
slot_time = await self.config.SLOT_TIME()
payday_time = await self.config.PAYDAY_TIME()
payday_amount = await self.config.PAYDAY_CREDITS()
else:
slot_min = self.config.guild(guild).SLOT_MIN()
slot_max = self.config.guild(guild).SLOT_MAX()
slot_time = self.config.guild(guild).SLOT_TIME()
payday_time = self.config.guild(guild).PAYDAY_TIME()
payday_amount = self.config.guild(guild).PAYDAY_CREDITS()
register_amount = bank.get_default_balance(guild)
slot_min = await self.config.guild(guild).SLOT_MIN()
slot_max = await self.config.guild(guild).SLOT_MAX()
slot_time = await self.config.guild(guild).SLOT_TIME()
payday_time = await self.config.guild(guild).PAYDAY_TIME()
payday_amount = await self.config.guild(guild).PAYDAY_CREDITS()
register_amount = await bank.get_default_balance(guild)
msg = box(
"Minimum slot bid: {}\n"
"Maximum slot bid: {}\n"
@ -436,24 +436,24 @@ class Economy:
await ctx.send('Invalid min bid amount.')
return
guild = ctx.guild
if bank.is_global():
if await bank.is_global():
await self.config.SLOT_MIN.set(bid)
else:
await self.config.guild(guild).SLOT_MIN.set(bid)
credits_name = bank.get_currency_name(guild)
credits_name = await bank.get_currency_name(guild)
await ctx.send("Minimum bid is now {} {}.".format(bid, credits_name))
@economyset.command()
async def slotmax(self, ctx: commands.Context, bid: int):
"""Maximum slot machine bid"""
slot_min = self.config.SLOT_MIN()
slot_min = await self.config.SLOT_MIN()
if bid < 1 or bid < slot_min:
await ctx.send('Invalid slotmax bid amount. Must be greater'
' than slotmin.')
return
guild = ctx.guild
credits_name = bank.get_currency_name(guild)
if bank.is_global():
credits_name = await bank.get_currency_name(guild)
if await bank.is_global():
await self.config.SLOT_MAX.set(bid)
else:
await self.config.guild(guild).SLOT_MAX.set(bid)
@ -463,7 +463,7 @@ class Economy:
async def slottime(self, ctx: commands.Context, seconds: int):
"""Seconds between each slots use"""
guild = ctx.guild
if bank.is_global():
if await bank.is_global():
await self.config.SLOT_TIME.set(seconds)
else:
await self.config.guild(guild).SLOT_TIME.set(seconds)
@ -473,7 +473,7 @@ class Economy:
async def paydaytime(self, ctx: commands.Context, seconds: int):
"""Seconds between each payday"""
guild = ctx.guild
if bank.is_global():
if await bank.is_global():
await self.config.PAYDAY_TIME.set(seconds)
else:
await self.config.guild(guild).PAYDAY_TIME.set(seconds)
@ -484,11 +484,11 @@ class Economy:
async def paydayamount(self, ctx: commands.Context, creds: int):
"""Amount earned each payday"""
guild = ctx.guild
credits_name = bank.get_currency_name(guild)
credits_name = await bank.get_currency_name(guild)
if creds <= 0:
await ctx.send("Har har so funny.")
return
if bank.is_global():
if await bank.is_global():
await self.config.PAYDAY_CREDITS.set(creds)
else:
await self.config.guild(guild).PAYDAY_CREDITS.set(creds)
@ -501,7 +501,7 @@ class Economy:
guild = ctx.guild
if creds < 0:
creds = 0
credits_name = bank.get_currency_name(guild)
credits_name = await bank.get_currency_name(guild)
await bank.set_default_balance(creds, guild)
await ctx.send("Registering an account will now give {} {}."
"".format(creds, credits_name))

View File

@ -1,6 +1,6 @@
import datetime
from collections import namedtuple
from typing import Tuple, Generator, Union
from typing import Tuple, Generator, Union, List
import discord
from copy import deepcopy
@ -78,17 +78,17 @@ def _decode_time(time: int) -> datetime.datetime:
return datetime.datetime.utcfromtimestamp(time)
def get_balance(member: discord.Member) -> int:
async def get_balance(member: discord.Member) -> int:
"""
Gets the current balance of a member.
:param member:
:return:
"""
acc = get_account(member)
acc = await get_account(member)
return acc.balance
def can_spend(member: discord.Member, amount: int) -> bool:
async def can_spend(member: discord.Member, amount: int) -> bool:
"""
Determines if a member can spend the given amount.
:param member:
@ -97,7 +97,7 @@ def can_spend(member: discord.Member, amount: int) -> bool:
"""
if _invalid_amount(amount):
return False
return get_balance(member) > amount
return await get_balance(member) > amount
async def set_balance(member: discord.Member, amount: int) -> int:
@ -111,17 +111,17 @@ async def set_balance(member: discord.Member, amount: int) -> int:
"""
if amount < 0:
raise ValueError("Not allowed to have negative balance.")
if is_global():
if await is_global():
group = _conf.user(member)
else:
group = _conf.member(member)
await group.balance.set(amount)
if group.created_at() == 0:
if await group.created_at() == 0:
time = _encoded_current_time()
await group.created_at.set(time)
if group.name() == "":
if await group.name() == "":
await group.name.set(member.display_name)
return amount
@ -144,7 +144,7 @@ async def withdraw_credits(member: discord.Member, amount: int) -> int:
if _invalid_amount(amount):
raise ValueError("Invalid withdrawal amount {} <= 0".format(amount))
bal = get_balance(member)
bal = await get_balance(member)
if amount > bal:
raise ValueError("Insufficient funds {} > {}".format(amount, bal))
@ -163,7 +163,7 @@ async def deposit_credits(member: discord.Member, amount: int) -> int:
if _invalid_amount(amount):
raise ValueError("Invalid withdrawal amount {} <= 0".format(amount))
bal = get_balance(member)
bal = await get_balance(member)
return await set_balance(member, amount + bal)
@ -190,13 +190,13 @@ async def wipe_bank(user: Union[discord.User, discord.Member]):
Deletes all accounts from the bank.
:return:
"""
if is_global():
if await is_global():
await _conf.user(user).clear()
else:
await _conf.member(user).clear()
def get_guild_accounts(guild: discord.Guild) -> Generator[Account, None, None]:
async def get_guild_accounts(guild: discord.Guild) -> List[Account]:
"""
Gets all account data for the given guild.
@ -207,14 +207,16 @@ def get_guild_accounts(guild: discord.Guild) -> Generator[Account, None, None]:
if is_global():
raise RuntimeError("The bank is currently global.")
accs = _conf.member(guild.owner).all_from_kind()
ret = []
accs = await _conf.member(guild.owner).all_from_kind()
for user_id, acc in accs.items():
acc_data = acc.copy() # There ya go kowlin
acc_data['created_at'] = _decode_time(acc_data['created_at'])
yield Account(**acc_data)
ret.append(Account(**acc_data))
return ret
def get_global_accounts(user: discord.User) -> Generator[Account, None, None]:
async def get_global_accounts(user: discord.User) -> List[Account]:
"""
Gets all global account data.
@ -225,44 +227,47 @@ def get_global_accounts(user: discord.User) -> Generator[Account, None, None]:
if not is_global():
raise RuntimeError("The bank is not currently global.")
accs = _conf.user(user).all_from_kind() # this is a dict of user -> acc
ret = []
accs = await _conf.user(user).all_from_kind() # this is a dict of user -> acc
for user_id, acc in accs.items():
acc_data = acc.copy()
acc_data['created_at'] = _decode_time(acc_data['created_at'])
yield Account(**acc_data)
ret.append(Account(**acc_data))
return ret
def get_account(member: Union[discord.Member, discord.User]) -> Account:
async def get_account(member: Union[discord.Member, discord.User]) -> Account:
"""
Gets the appropriate account for the given member.
:param member:
:return:
"""
if is_global():
acc_data = _conf.user(member)().copy()
if await is_global():
acc_data = (await _conf.user(member)()).copy()
default = _DEFAULT_USER.copy()
else:
acc_data = _conf.member(member)().copy()
acc_data = (await _conf.member(member)()).copy()
default = _DEFAULT_MEMBER.copy()
if acc_data == {}:
acc_data = default
acc_data['name'] = member.display_name
try:
acc_data['balance'] = get_default_balance(member.guild)
acc_data['balance'] = await get_default_balance(member.guild)
except AttributeError:
acc_data['balance'] = get_default_balance()
acc_data['balance'] = await get_default_balance()
acc_data['created_at'] = _decode_time(acc_data['created_at'])
return Account(**acc_data)
def is_global() -> bool:
async def is_global() -> bool:
"""
Determines if the bank is currently global.
:return:
"""
return _conf.is_global()
return await _conf.is_global()
async def set_global(global_: bool, user: Union[discord.User, discord.Member]) -> bool:
@ -272,7 +277,7 @@ async def set_global(global_: bool, user: Union[discord.User, discord.Member]) -
:param user: Must be a Member object if changing TO global mode.
:return: New bank mode, True is global.
"""
if is_global() is global_:
if (await is_global()) is global_:
return global_
if is_global():
@ -287,7 +292,7 @@ async def set_global(global_: bool, user: Union[discord.User, discord.Member]) -
return global_
def get_bank_name(guild: discord.Guild=None) -> str:
async def get_bank_name(guild: discord.Guild=None) -> str:
"""
Gets the current bank name. If the bank is guild-specific the
guild parameter is required.
@ -296,10 +301,10 @@ def get_bank_name(guild: discord.Guild=None) -> str:
:param guild:
:return:
"""
if is_global():
return _conf.bank_name()
if await is_global():
return await _conf.bank_name()
elif guild is not None:
return _conf.guild(guild).bank_name()
return await _conf.guild(guild).bank_name()
else:
raise RuntimeError("Guild parameter is required and missing.")
@ -314,7 +319,7 @@ async def set_bank_name(name: str, guild: discord.Guild=None) -> str:
:param guild:
:return:
"""
if is_global():
if await is_global():
await _conf.bank_name.set(name)
elif guild is not None:
await _conf.guild(guild).bank_name.set(name)
@ -324,7 +329,7 @@ async def set_bank_name(name: str, guild: discord.Guild=None) -> str:
return name
def get_currency_name(guild: discord.Guild=None) -> str:
async def get_currency_name(guild: discord.Guild=None) -> str:
"""
Gets the currency name of the bank. The guild parameter is required if
the bank is guild-specific.
@ -333,10 +338,10 @@ def get_currency_name(guild: discord.Guild=None) -> str:
:param guild:
:return:
"""
if is_global():
return _conf.currency()
if await is_global():
return await _conf.currency()
elif guild is not None:
return _conf.guild(guild).currency()
return await _conf.guild(guild).currency()
else:
raise RuntimeError("Guild must be provided.")
@ -351,7 +356,7 @@ async def set_currency_name(name: str, guild: discord.Guild=None) -> str:
:param guild:
:return:
"""
if is_global():
if await is_global():
await _conf.currency.set(name)
elif guild is not None:
await _conf.guild(guild).currency.set(name)
@ -361,7 +366,7 @@ async def set_currency_name(name: str, guild: discord.Guild=None) -> str:
return name
def get_default_balance(guild: discord.Guild=None) -> int:
async def get_default_balance(guild: discord.Guild=None) -> int:
"""
Gets the current default balance amount. If the bank is guild-specific
you must pass guild.
@ -370,10 +375,10 @@ def get_default_balance(guild: discord.Guild=None) -> int:
:param guild:
:return:
"""
if is_global():
return _conf.default_balance()
if await is_global():
return await _conf.default_balance()
elif guild is not None:
return _conf.guild(guild).default_balance()
return await _conf.guild(guild).default_balance()
else:
raise RuntimeError("Guild is missing and required!")
@ -393,9 +398,11 @@ async def set_default_balance(amount: int, guild: discord.Guild=None) -> int:
if amount < 0:
raise ValueError("Amount must be greater than zero.")
if is_global():
if await is_global():
await _conf.default_balance.set(amount)
elif guild is not None:
await _conf.guild(guild).default_balance.set(amount)
else:
raise RuntimeError("Guild is missing and required.")
return amount

View File

@ -1,3 +1,4 @@
import asyncio
import importlib.util
from importlib.machinery import ModuleSpec
@ -39,14 +40,14 @@ class Red(commands.Bot):
mod_role=None
)
def prefix_manager(bot, message):
async def prefix_manager(bot, message):
if not cli_flags.prefix:
global_prefix = self.db.prefix()
global_prefix = await bot.db.prefix()
else:
global_prefix = cli_flags.prefix
if message.guild is None:
return global_prefix
server_prefix = self.db.guild(message.guild).prefix()
server_prefix = await bot.db.guild(message.guild).prefix()
return server_prefix if server_prefix else global_prefix
if "command_prefix" not in kwargs:
@ -56,7 +57,8 @@ class Red(commands.Bot):
kwargs["owner_id"] = cli_flags.owner
if "owner_id" not in kwargs:
kwargs["owner_id"] = self.db.owner()
loop = asyncio.get_event_loop()
loop.run_until_complete(self._dict_abuse(kwargs))
self.counter = Counter()
self.uptime = None
@ -68,6 +70,15 @@ class Red(commands.Bot):
super().__init__(**kwargs)
async def _dict_abuse(self, indict):
"""
Please blame <@269933075037814786> for this.
:param indict:
:return:
"""
indict['owner_id'] = await self.db.owner()
async def is_owner(self, user):
if user.id in self._co_owners:
return True
@ -103,13 +114,13 @@ class Red(commands.Bot):
await self.db.packages.set(packages)
async def add_loaded_package(self, pkg_name: str):
curr_pkgs = self.db.packages()
curr_pkgs = await self.db.packages()
if pkg_name not in curr_pkgs:
curr_pkgs.append(pkg_name)
await self.save_packages_status(curr_pkgs)
async def remove_loaded_package(self, pkg_name: str):
curr_pkgs = self.db.packages()
curr_pkgs = await self.db.packages()
if pkg_name in curr_pkgs:
await self.save_packages_status([p for p in curr_pkgs if p != pkg_name])

View File

@ -31,26 +31,29 @@ class CogManager:
install_path=str(bot_dir.resolve() / "cogs")
)
self._paths = set(list(self.conf.paths()) + list(paths))
self._paths = list(paths)
@property
def paths(self) -> Tuple[Path, ...]:
async def paths(self) -> Tuple[Path, ...]:
"""
This will return all currently valid path directories.
:return:
"""
paths = [Path(p) for p in self._paths]
conf_paths = await self.conf.paths()
other_paths = self._paths
all_paths = set(list(conf_paths) + list(other_paths))
paths = [Path(p) for p in all_paths]
if self.install_path not in paths:
paths.insert(0, self.install_path)
paths.insert(0, await self.install_path())
return tuple(p.resolve() for p in paths if p.is_dir())
@property
def install_path(self) -> Path:
async def install_path(self) -> Path:
"""
Returns the install path for 3rd party cogs.
:return:
"""
p = Path(self.conf.install_path())
p = Path(await self.conf.install_path())
return p.resolve()
async def set_install_path(self, path: Path) -> Path:
@ -99,10 +102,10 @@ class CogManager:
if not path.is_dir():
raise InvalidPath("'{}' is not a valid directory.".format(path))
if path == self.install_path:
if path == await self.install_path():
raise ValueError("Cannot add the install path as an additional path.")
all_paths = set(self.paths + (path, ))
all_paths = set(await self.paths() + (path, ))
# noinspection PyTypeChecker
await self.set_paths(all_paths)
@ -113,7 +116,7 @@ class CogManager:
:return:
"""
path = self._ensure_path_obj(path)
all_paths = list(self.paths)
all_paths = list(await self.paths())
if path in all_paths:
all_paths.remove(path) # Modifies in place
await self.set_paths(all_paths)
@ -125,11 +128,10 @@ class CogManager:
:param paths_:
:return:
"""
self._paths = paths_
str_paths = [str(p) for p in paths_]
await self.conf.paths.set(str_paths)
def find_cog(self, name: str) -> ModuleSpec:
async def find_cog(self, name: str) -> ModuleSpec:
"""
Finds a cog in the list of available path.
@ -137,7 +139,7 @@ class CogManager:
:param name:
:return:
"""
resolved_paths = [str(p.resolve()) for p in self.paths]
resolved_paths = [str(p.resolve()) for p in await self.paths()]
for finder, module_name, _ in pkgutil.iter_modules(resolved_paths):
if name == module_name:
spec = finder.find_spec(name)
@ -166,7 +168,7 @@ class CogManagerUI:
"""
Lists current cog paths in order of priority.
"""
install_path = ctx.bot.cog_mgr.install_path
install_path = await ctx.bot.cog_mgr.install_path()
cog_paths = ctx.bot.cog_mgr.paths
cog_paths = [p for p in cog_paths if p != install_path]
@ -204,7 +206,7 @@ class CogManagerUI:
Removes a path from the available cog paths given the path_number
from !paths
"""
cog_paths = ctx.bot.cog_mgr.paths
cog_paths = await ctx.bot.cog_mgr.paths()
try:
to_remove = cog_paths[path_number]
except IndexError:
@ -224,7 +226,7 @@ class CogManagerUI:
from_ -= 1
to -= 1
all_paths = list(ctx.bot.cog_mgr.paths)
all_paths = list(await ctx.bot.cog_mgr.paths())
try:
to_move = all_paths.pop(from_)
except IndexError:
@ -257,6 +259,6 @@ class CogManagerUI:
await ctx.send("That path does not exist.")
return
install_path = ctx.bot.cog_mgr.install_path
install_path = await ctx.bot.cog_mgr.install_path()
await ctx.send("The bot will install new cogs to the `{}`"
" directory.".format(install_path))

View File

@ -38,6 +38,14 @@ class Value:
def identifiers(self):
return tuple(str(i) for i in self._identifiers)
async def _get(self, default):
driver = self.spawner.get_driver()
try:
ret = await driver.get(self.identifiers)
except KeyError:
return default or self.default
return ret
def __call__(self, default=None):
"""
Each :py:class:`Value` object is created by the :py:meth:`Group.__getattr__` method.
@ -46,25 +54,26 @@ class Value:
For example::
foo = conf.guild(some_guild).foo()
foo = await conf.guild(some_guild).foo()
# Is equivalent to this
group_obj = conf.guild(some_guild)
value_obj = conf.foo
foo = value_obj()
foo = await value_obj()
.. important::
This is now, for all intents and purposes, a coroutine.
:param default:
This argument acts as an override for the registered default provided by :py:attr:`default`. This argument
is ignored if its value is :python:`None`.
:type default: Optional[object]
:return:
A coroutine object that must be awaited.
"""
driver = self.spawner.get_driver()
try:
ret = driver.get(self.identifiers)
except KeyError:
return default or self.default
return ret
return self._get(default)
async def set(self, value):
"""
@ -182,7 +191,7 @@ class Group(Value):
return not isinstance(default, dict)
def get_attr(self, item: str, default=None, resolve=True):
async def get_attr(self, item: str, default=None, resolve=True):
"""
This is available to use as an alternative to using normal Python attribute access. It is required if you find
a need for dynamic attribute access.
@ -198,7 +207,7 @@ class Group(Value):
user = ctx.author
# Where the value of item is the name of the data field in Config
await ctx.send(self.conf.user(user).get_attr(item))
await ctx.send(await self.conf.user(user).get_attr(item))
:param str item:
The name of the data field in :py:class:`.Config`.
@ -211,20 +220,20 @@ class Group(Value):
"""
value = getattr(self, item)
if resolve:
return value(default=default)
return await value(default=default)
else:
return value
def all(self) -> dict:
async def all(self) -> dict:
"""
This method allows you to get "all" of a particular group of data. It will return the dictionary of all data
for a particular Guild/Channel/Role/User/Member etc.
:rtype: dict
"""
return self()
return await self()
def all_from_kind(self) -> dict:
async def all_from_kind(self) -> dict:
"""
This method allows you to get all data from all entries in a given Kind. It will return a dictionary of Kind
ID's -> data.
@ -232,7 +241,7 @@ class Group(Value):
:rtype: dict
"""
# noinspection PyTypeChecker
return self._super_group()
return await self._super_group()
async def set(self, value):
if not isinstance(value, dict):
@ -292,18 +301,18 @@ class MemberGroup(Group):
)
return group_obj
def all_guilds(self) -> dict:
async def all_guilds(self) -> dict:
"""
Returns a dict of :code:`GUILD_ID -> MEMBER_ID -> data`.
:rtype: dict
"""
# noinspection PyTypeChecker
return self._super_group()
return await self._super_group()
def all(self) -> dict:
async def all(self) -> dict:
# noinspection PyTypeChecker
return self._guild_group()
return await self._guild_group()
class Config:
@ -315,7 +324,7 @@ class Config:
however the process for accessing global data is a bit different. There is no :python:`global` method
because global data is accessed by normal attribute access::
conf.foo()
await conf.foo()
.. py:attribute:: cog_name

View File

@ -29,7 +29,7 @@ class Core:
async def load(self, ctx, *, cog_name: str):
"""Loads a package"""
try:
spec = ctx.bot.cog_mgr.find_cog(cog_name)
spec = await ctx.bot.cog_mgr.find_cog(cog_name)
except NoModuleFound:
await ctx.send("No module by that name was found in any"
" cog path.")
@ -63,7 +63,7 @@ class Core:
ctx.bot.unload_extension(cog_name)
self.cleanup_and_refresh_modules(cog_name)
try:
spec = ctx.bot.cog_mgr.find_cog(cog_name)
spec = await ctx.bot.cog_mgr.find_cog(cog_name)
ctx.bot.load_extension(spec)
except Exception as e:
log.exception("Package reloading failed", exc_info=e)

View File

@ -5,7 +5,7 @@ class BaseDriver:
def get_driver(self):
raise NotImplementedError
def get(self, identifiers: Tuple[str]):
async def get(self, identifiers: Tuple[str]):
raise NotImplementedError
async def set(self, identifiers: Tuple[str], value):

View File

@ -32,7 +32,7 @@ class JSON(BaseDriver):
def get_driver(self):
return self
def get(self, identifiers: Tuple[str]):
async def get(self, identifiers: Tuple[str]):
partial = self.data
for i in identifiers:
partial = partial[i]

View File

@ -34,11 +34,11 @@ def init_events(bot, cli_flags):
if cli_flags.no_cogs is False:
print("Loading packages...")
failed = []
packages = bot.db.packages()
packages = await bot.db.packages()
for package in packages:
try:
spec = bot.cog_mgr.find_cog(package)
spec = await bot.cog_mgr.find_cog(package)
bot.load_extension(spec)
except Exception as e:
log.exception("Failed to load package {}".format(package),

27
main.py
View File

@ -73,6 +73,17 @@ def determine_main_folder() -> Path:
return Path(os.path.dirname(__file__)).resolve()
async def _get_prefix_and_token(red, indict):
"""
Again, please blame <@269933075037814786> for this.
:param indict:
:return:
"""
indict['token'] = await red.db.token()
indict['prefix'] = await red.db.prefix()
indict['enable_sentry'] = await red.db.enable_sentry()
if __name__ == '__main__':
cli_flags = parse_cli_flags()
log, sentry_log = init_loggers(cli_flags)
@ -89,8 +100,13 @@ if __name__ == '__main__':
if cli_flags.dev:
red.add_cog(Dev())
token = os.environ.get("RED_TOKEN", red.db.token())
prefix = cli_flags.prefix or red.db.prefix()
loop = asyncio.get_event_loop()
tmp_data = {}
loop.run_until_complete(_get_prefix_and_token(red, tmp_data))
token = os.environ.get("RED_TOKEN", tmp_data['token'])
prefix = cli_flags.prefix or tmp_data['prefix']
enable_sentry = tmp_data['enable_sentry']
if token is None or not prefix:
if cli_flags.no_prompt is False:
@ -102,13 +118,14 @@ if __name__ == '__main__':
log.critical("Token and prefix must be set in order to login.")
sys.exit(1)
if red.db.enable_sentry() is None:
if enable_sentry is None:
ask_sentry(red)
if red.db.enable_sentry():
loop.run_until_complete(_get_prefix_and_token(red, tmp_data))
if tmp_data['enable_sentry']:
init_sentry_logging(sentry_log)
loop = asyncio.get_event_loop()
cleanup_tasks = True
try:

View File

@ -16,12 +16,14 @@ def test_is_valid_alias_name(alias):
assert alias.is_valid_alias_name("not valid name") is False
def test_empty_guild_aliases(alias, empty_guild):
assert list(alias.unloaded_aliases(empty_guild)) == []
@pytest.mark.asyncio
async def test_empty_guild_aliases(alias, empty_guild):
assert list(await alias.unloaded_aliases(empty_guild)) == []
def test_empty_global_aliases(alias):
assert list(alias.unloaded_global_aliases()) == []
@pytest.mark.asyncio
async def test_empty_global_aliases(alias):
assert list(await alias.unloaded_global_aliases()) == []
async def create_test_guild_alias(alias, ctx):
@ -36,7 +38,7 @@ async def create_test_global_alias(alias, ctx):
async def test_add_guild_alias(alias, ctx):
await create_test_guild_alias(alias, ctx)
is_alias, alias_obj = alias.is_alias(ctx.guild, "test")
is_alias, alias_obj = await alias.is_alias(ctx.guild, "test")
assert is_alias is True
assert alias_obj.global_ is False
@ -44,19 +46,19 @@ async def test_add_guild_alias(alias, ctx):
@pytest.mark.asyncio
async def test_delete_guild_alias(alias, ctx):
await create_test_guild_alias(alias, ctx)
is_alias, _ = alias.is_alias(ctx.guild, "test")
is_alias, _ = await alias.is_alias(ctx.guild, "test")
assert is_alias is True
await alias.delete_alias(ctx, "test")
is_alias, _ = alias.is_alias(ctx.guild, "test")
is_alias, _ = await alias.is_alias(ctx.guild, "test")
assert is_alias is False
@pytest.mark.asyncio
async def test_add_global_alias(alias, ctx):
await create_test_global_alias(alias, ctx)
is_alias, alias_obj = alias.is_alias(ctx.guild, "test")
is_alias, alias_obj = await alias.is_alias(ctx.guild, "test")
assert is_alias is True
assert alias_obj.global_ is True
@ -65,7 +67,7 @@ async def test_add_global_alias(alias, ctx):
@pytest.mark.asyncio
async def test_delete_global_alias(alias, ctx):
await create_test_global_alias(alias, ctx)
is_alias, alias_obj = alias.is_alias(ctx.guild, "test")
is_alias, alias_obj = await alias.is_alias(ctx.guild, "test")
assert is_alias is True
assert alias_obj.global_ is True

View File

@ -11,13 +11,14 @@ def bank(config):
return bank
def test_bank_register(bank, ctx):
default_bal = bank.get_default_balance(ctx.guild)
assert default_bal == bank.get_account(ctx.author).balance
@pytest.mark.asyncio
async def test_bank_register(bank, ctx):
default_bal = await bank.get_default_balance(ctx.guild)
assert default_bal == (await bank.get_account(ctx.author)).balance
async def has_account(member, bank):
balance = bank.get_balance(member)
balance = await bank.get_balance(member)
if balance == 0:
balance = 1
await bank.set_balance(member, balance)
@ -27,11 +28,11 @@ async def has_account(member, bank):
async def test_bank_transfer(bank, member_factory):
mbr1 = member_factory.get()
mbr2 = member_factory.get()
bal1 = bank.get_account(mbr1).balance
bal2 = bank.get_account(mbr2).balance
bal1 = (await bank.get_account(mbr1)).balance
bal2 = (await bank.get_account(mbr2)).balance
await bank.transfer_credits(mbr1, mbr2, 50)
newbal1 = bank.get_account(mbr1).balance
newbal2 = bank.get_account(mbr2).balance
newbal1 = (await bank.get_account(mbr1)).balance
newbal2 = (await bank.get_account(mbr2)).balance
assert bal1 - 50 == newbal1
assert bal2 + 50 == newbal2
@ -40,16 +41,16 @@ async def test_bank_transfer(bank, member_factory):
async def test_bank_set(bank, member_factory):
mbr = member_factory.get()
await bank.set_balance(mbr, 250)
acc = bank.get_account(mbr)
acc = await bank.get_account(mbr)
assert acc.balance == 250
@pytest.mark.asyncio
async def test_bank_can_spend(bank, member_factory):
mbr = member_factory.get()
canspend = bank.can_spend(mbr, 50)
assert canspend == (50 < bank.get_default_balance(mbr.guild))
canspend = await bank.can_spend(mbr, 50)
assert canspend == (50 < await bank.get_default_balance(mbr.guild))
await bank.set_balance(mbr, 200)
acc = bank.get_account(mbr)
canspendnow = bank.can_spend(mbr, 100)
acc = await bank.get_account(mbr)
canspendnow = await bank.can_spend(mbr, 100)
assert canspendnow

View File

@ -14,16 +14,17 @@ def default_dir(red):
return red.main_dir
def test_ensure_cogs_in_paths(cog_mgr, default_dir):
@pytest.mark.asyncio
async def test_ensure_cogs_in_paths(cog_mgr, default_dir):
cogs_dir = default_dir / 'cogs'
assert cogs_dir in cog_mgr.paths
assert cogs_dir in await cog_mgr.paths()
@pytest.mark.asyncio
async def test_install_path_set(cog_mgr: cog_manager.CogManager, tmpdir):
path = Path(str(tmpdir))
await cog_mgr.set_install_path(path)
assert cog_mgr.install_path == path
assert await cog_mgr.install_path() == path
@pytest.mark.asyncio
@ -38,7 +39,7 @@ async def test_install_path_set_bad(cog_mgr):
async def test_add_path(cog_mgr, tmpdir):
path = Path(str(tmpdir))
await cog_mgr.add_path(path)
assert path in cog_mgr.paths
assert path in await cog_mgr.paths()
@pytest.mark.asyncio
@ -54,4 +55,4 @@ async def test_remove_path(cog_mgr, tmpdir):
path = Path(str(tmpdir))
await cog_mgr.add_path(path)
await cog_mgr.remove_path(path)
assert path not in cog_mgr.paths
assert path not in await cog_mgr.paths()

View File

@ -2,10 +2,11 @@ import pytest
#region Register Tests
def test_config_register_global(config):
@pytest.mark.asyncio
async def test_config_register_global(config):
config.register_global(enabled=False)
assert config.defaults["GLOBAL"]["enabled"] is False
assert config.enabled() is False
assert await config.enabled() is False
def test_config_register_global_badvalues(config):
@ -13,61 +14,69 @@ def test_config_register_global_badvalues(config):
config.register_global(**{"invalid var name": True})
def test_config_register_guild(config, empty_guild):
@pytest.mark.asyncio
async def test_config_register_guild(config, empty_guild):
config.register_guild(enabled=False, some_list=[], some_dict={})
assert config.defaults[config.GUILD]["enabled"] is False
assert config.defaults[config.GUILD]["some_list"] == []
assert config.defaults[config.GUILD]["some_dict"] == {}
assert config.guild(empty_guild).enabled() is False
assert config.guild(empty_guild).some_list() == []
assert config.guild(empty_guild).some_dict() == {}
assert await config.guild(empty_guild).enabled() is False
assert await config.guild(empty_guild).some_list() == []
assert await config.guild(empty_guild).some_dict() == {}
def test_config_register_channel(config, empty_channel):
@pytest.mark.asyncio
async def test_config_register_channel(config, empty_channel):
config.register_channel(enabled=False)
assert config.defaults[config.CHANNEL]["enabled"] is False
assert config.channel(empty_channel).enabled() is False
assert await config.channel(empty_channel).enabled() is False
def test_config_register_role(config, empty_role):
@pytest.mark.asyncio
async def test_config_register_role(config, empty_role):
config.register_role(enabled=False)
assert config.defaults[config.ROLE]["enabled"] is False
assert config.role(empty_role).enabled() is False
assert await config.role(empty_role).enabled() is False
def test_config_register_member(config, empty_member):
@pytest.mark.asyncio
async def test_config_register_member(config, empty_member):
config.register_member(some_number=-1)
assert config.defaults[config.MEMBER]["some_number"] == -1
assert config.member(empty_member).some_number() == -1
assert await config.member(empty_member).some_number() == -1
def test_config_register_user(config, empty_user):
@pytest.mark.asyncio
async def test_config_register_user(config, empty_user):
config.register_user(some_value=None)
assert config.defaults[config.USER]["some_value"] is None
assert config.user(empty_user).some_value() is None
assert await config.user(empty_user).some_value() is None
def test_config_force_register_global(config_fr):
@pytest.mark.asyncio
async def test_config_force_register_global(config_fr):
with pytest.raises(AttributeError):
config_fr.enabled()
await config_fr.enabled()
config_fr.register_global(enabled=True)
assert config_fr.enabled() is True
assert await config_fr.enabled() is True
#endregion
# Test nested registration
def test_nested_registration(config):
@pytest.mark.asyncio
async def test_nested_registration(config):
config.register_global(foo__bar__baz=False)
assert config.foo.bar.baz() is False
assert await config.foo.bar.baz() is False
def test_nested_registration_asdict(config):
@pytest.mark.asyncio
async def test_nested_registration_asdict(config):
defaults = {'bar': {'baz': False}}
config.register_global(foo=defaults)
assert config.foo.bar.baz() is False
assert await config.foo.bar.baz() is False
@pytest.mark.asyncio
@ -75,20 +84,22 @@ async def test_nested_registration_and_changing(config):
defaults = {'bar': {'baz': False}}
config.register_global(foo=defaults)
assert config.foo.bar.baz() is False
assert await config.foo.bar.baz() is False
with pytest.raises(ValueError):
await config.foo.set(True)
def test_doubleset_default(config):
@pytest.mark.asyncio
async def test_doubleset_default(config):
config.register_global(foo=True)
config.register_global(foo=False)
assert config.foo() is False
assert await config.foo() is False
def test_nested_registration_multidict(config):
@pytest.mark.asyncio
async def test_nested_registration_multidict(config):
defaults = {
"foo": {
"bar": {
@ -99,8 +110,8 @@ def test_nested_registration_multidict(config):
}
config.register_global(**defaults)
assert config.foo.bar.baz() is True
assert config.blah() is True
assert await config.foo.bar.baz() is True
assert await config.blah() is True
def test_nested_group_value_badreg(config):
@ -109,56 +120,66 @@ def test_nested_group_value_badreg(config):
config.register_global(foo__bar=False)
def test_nested_toplevel_reg(config):
@pytest.mark.asyncio
async def test_nested_toplevel_reg(config):
defaults = {'bar': True, 'baz': False}
config.register_global(foo=defaults)
assert config.foo.bar() is True
assert config.foo.baz() is False
assert await config.foo.bar() is True
assert await config.foo.baz() is False
def test_nested_overlapping(config):
@pytest.mark.asyncio
async def test_nested_overlapping(config):
config.register_global(foo__bar=True)
config.register_global(foo__baz=False)
assert config.foo.bar() is True
assert config.foo.baz() is False
assert await config.foo.bar() is True
assert await config.foo.baz() is False
def test_nesting_nofr(config):
@pytest.mark.asyncio
async def test_nesting_nofr(config):
config.register_global(foo={})
assert config.foo.bar() is None
assert config.foo() == {}
assert await config.foo.bar() is None
assert await config.foo() == {}
#region Default Value Overrides
def test_global_default_override(config):
assert config.enabled(True) is True
# region Default Value Overrides
@pytest.mark.asyncio
async def test_global_default_override(config):
assert await config.enabled(True) is True
def test_global_default_nofr(config):
assert config.nofr() is None
assert config.nofr(True) is True
@pytest.mark.asyncio
async def test_global_default_nofr(config):
assert await config.nofr() is None
assert await config.nofr(True) is True
def test_guild_default_override(config, empty_guild):
assert config.guild(empty_guild).enabled(True) is True
@pytest.mark.asyncio
async def test_guild_default_override(config, empty_guild):
assert await config.guild(empty_guild).enabled(True) is True
def test_channel_default_override(config, empty_channel):
assert config.channel(empty_channel).enabled(True) is True
@pytest.mark.asyncio
async def test_channel_default_override(config, empty_channel):
assert await config.channel(empty_channel).enabled(True) is True
def test_role_default_override(config, empty_role):
assert config.role(empty_role).enabled(True) is True
@pytest.mark.asyncio
async def test_role_default_override(config, empty_role):
assert await config.role(empty_role).enabled(True) is True
def test_member_default_override(config, empty_member):
assert config.member(empty_member).enabled(True) is True
@pytest.mark.asyncio
async def test_member_default_override(config, empty_member):
assert await config.member(empty_member).enabled(True) is True
def test_user_default_override(config, empty_user):
assert config.user(empty_user).some_value(True) is True
@pytest.mark.asyncio
async def test_user_default_override(config, empty_user):
assert await config.user(empty_user).some_value(True) is True
#endregion
@ -166,32 +187,32 @@ def test_user_default_override(config, empty_user):
@pytest.mark.asyncio
async def test_set_global(config):
await config.enabled.set(True)
assert config.enabled() is True
assert await config.enabled() is True
@pytest.mark.asyncio
async def test_set_guild(config, empty_guild):
await config.guild(empty_guild).enabled.set(True)
assert config.guild(empty_guild).enabled() is True
assert await config.guild(empty_guild).enabled() is True
curr_list = config.guild(empty_guild).some_list([1, 2, 3])
curr_list = await config.guild(empty_guild).some_list([1, 2, 3])
assert curr_list == [1, 2, 3]
curr_list.append(4)
await config.guild(empty_guild).some_list.set(curr_list)
assert config.guild(empty_guild).some_list() == curr_list
assert await config.guild(empty_guild).some_list() == curr_list
@pytest.mark.asyncio
async def test_set_channel(config, empty_channel):
await config.channel(empty_channel).enabled.set(True)
assert config.channel(empty_channel).enabled() is True
assert await config.channel(empty_channel).enabled() is True
@pytest.mark.asyncio
async def test_set_channel_no_register(config, empty_channel):
await config.channel(empty_channel).no_register.set(True)
assert config.channel(empty_channel).no_register() is True
assert await config.channel(empty_channel).no_register() is True
#endregion
@ -200,11 +221,12 @@ async def test_set_channel_no_register(config, empty_channel):
async def test_set_dynamic_attr(config):
await config.set_attr("foobar", True)
assert config.foobar() is True
assert await config.foobar() is True
def test_get_dynamic_attr(config):
assert config.get_attr("foobaz", True) is True
@pytest.mark.asyncio
async def test_get_dynamic_attr(config):
assert await config.get_attr("foobaz", True) is True
# Member Group testing
@ -212,7 +234,7 @@ def test_get_dynamic_attr(config):
async def test_membergroup_allguilds(config, empty_member):
await config.member(empty_member).foo.set(False)
all_servers = config.member(empty_member).all_guilds()
all_servers = await config.member(empty_member).all_guilds()
assert str(empty_member.guild.id) in all_servers
@ -220,7 +242,7 @@ async def test_membergroup_allguilds(config, empty_member):
async def test_membergroup_allmembers(config, empty_member):
await config.member(empty_member).foo.set(False)
all_members = config.member(empty_member).all()
all_members = await config.member(empty_member).all()
assert str(empty_member.id) in all_members
@ -232,13 +254,13 @@ async def test_global_clear(config):
await config.foo.set(False)
await config.bar.set(True)
assert config.foo() is False
assert config.bar() is True
assert await config.foo() is False
assert await config.bar() is True
await config.clear()
assert config.foo() is True
assert config.bar() is False
assert await config.foo() is True
assert await config.bar() is False
@pytest.mark.asyncio
@ -247,17 +269,17 @@ async def test_member_clear(config, member_factory):
m1 = member_factory.get()
await config.member(m1).foo.set(False)
assert config.member(m1).foo() is False
assert await config.member(m1).foo() is False
m2 = member_factory.get()
await config.member(m2).foo.set(False)
assert config.member(m2).foo() is False
assert await config.member(m2).foo() is False
assert m1.guild.id != m2.guild.id
await config.member(m1).clear()
assert config.member(m1).foo() is True
assert config.member(m2).foo() is False
assert await config.member(m1).foo() is True
assert await config.member(m2).foo() is False
@pytest.mark.asyncio
@ -269,11 +291,11 @@ async def test_member_clear_all(config, member_factory):
server_ids.append(member.guild.id)
member = member_factory.get()
assert len(config.member(member).all_guilds()) == len(server_ids)
assert len(await config.member(member).all_guilds()) == len(server_ids)
await config.member(member).clear_all()
assert len(config.member(member).all_guilds()) == 0
assert len(await config.member(member).all_guilds()) == 0
# Get All testing
@ -284,7 +306,7 @@ async def test_user_get_all_from_kind(config, user_factory):
await config.user(user).foo.set(True)
user = user_factory.get()
all_data = config.user(user).all_from_kind()
all_data = await config.user(user).all_from_kind()
assert len(all_data) == 5
@ -294,4 +316,4 @@ async def test_user_getalldata(config, user_factory):
user = user_factory.get()
await config.user(user).foo.set(False)
assert "foo" in config.user(user).all()
assert "foo" in await config.user(user).all()