[Bank API] Add cost decorator (#2761)

This commit is contained in:
Michael H
2019-07-02 20:07:19 -04:00
committed by Toby Harradine
parent d1593b8069
commit 0eb22c84ff
4 changed files with 90 additions and 2 deletions

View File

@@ -1,9 +1,14 @@
import asyncio
import datetime
from typing import Union, List, Optional
from functools import wraps
import discord
from . import Config, errors
from . import Config, errors, commands
from .i18n import Translator
_ = Translator("Bank API", __file__)
__all__ = [
"MAX_BALANCE",
@@ -24,6 +29,8 @@ __all__ = [
"set_currency_name",
"get_default_balance",
"set_default_balance",
"cost",
"AbortPurchase",
]
MAX_BALANCE = 2 ** 63 - 1
@@ -669,3 +676,69 @@ async def set_default_balance(amount: int, guild: discord.Guild = None) -> int:
raise RuntimeError("Guild is missing and required.")
return amount
class AbortPurchase(Exception):
pass
def cost(amount: int):
"""
Decorates a coroutine-function or command to have a cost.
If the command raises an exception, the cost will be refunded.
You can intentionally refund by raising `AbortPurchase`
(this error will be consumed and not show to users)
Other exceptions will propogate and will be handled by Red's (and/or
any other configured) error handling.
"""
if not isinstance(amount, int) or amount < 0:
raise ValueError("This decorator requires an integer cost greater than or equal to zero")
def deco(coro_or_command):
is_command = isinstance(coro_or_command, commands.Command)
if not is_command and not asyncio.iscoroutinefunction(coro_or_command):
raise TypeError("@bank.cost() can only be used on commands or `async def` functions")
coro = coro_or_command.callback if is_command else coro_or_command
@wraps(coro)
async def wrapped(*args, **kwargs):
context: commands.Context = None
for arg in args:
if isinstance(arg, commands.Context):
context = arg
break
if not context.guild and not await is_global():
raise commands.UserFeedbackCheckFailure(
_("Can't pay for this command in DM without a global bank.")
)
try:
await withdraw_credits(context.author, amount)
except Exception:
credits_name = await get_currency_name(context.guild)
raise commands.UserFeedbackCheckFailure(
_("You need at least {cost} {currency} to use this command.").format(
cost=amount, currency=credits_name
)
)
else:
try:
return await coro(*args, **kwargs)
except AbortPurchase:
await deposit_credits(context.author, amount)
except Exception:
await deposit_credits(context.author, amount)
raise
if not is_command:
return wrapped
else:
wrapped.__module__ = coro_or_command.callback.__module__
coro_or_command.callback = wrapped
return coro_or_command
return deco