[Bank API] Add cost decorator (#2761)

This commit is contained in:
Michael H 2019-07-02 20:07:19 -04:00 committed by Toby Harradine
parent d1593b8069
commit 0eb22c84ff
4 changed files with 90 additions and 2 deletions

View File

@ -40,3 +40,7 @@ Bank
.. automodule:: redbot.core.bank .. automodule:: redbot.core.bank
:members: :members:
:exclude-members: cost
.. autofunction:: cost
:decorator:

View File

@ -1,9 +1,14 @@
import asyncio
import datetime import datetime
from typing import Union, List, Optional from typing import Union, List, Optional
from functools import wraps
import discord import discord
from . import Config, errors from . import Config, errors, commands
from .i18n import Translator
_ = Translator("Bank API", __file__)
__all__ = [ __all__ = [
"MAX_BALANCE", "MAX_BALANCE",
@ -24,6 +29,8 @@ __all__ = [
"set_currency_name", "set_currency_name",
"get_default_balance", "get_default_balance",
"set_default_balance", "set_default_balance",
"cost",
"AbortPurchase",
] ]
MAX_BALANCE = 2 ** 63 - 1 MAX_BALANCE = 2 ** 63 - 1
@ -669,3 +676,69 @@ async def set_default_balance(amount: int, guild: discord.Guild = None) -> int:
raise RuntimeError("Guild is missing and required.") raise RuntimeError("Guild is missing and required.")
return amount return amount
class AbortPurchase(Exception):
pass
def cost(amount: int):
"""
Decorates a coroutine-function or command to have a cost.
If the command raises an exception, the cost will be refunded.
You can intentionally refund by raising `AbortPurchase`
(this error will be consumed and not show to users)
Other exceptions will propogate and will be handled by Red's (and/or
any other configured) error handling.
"""
if not isinstance(amount, int) or amount < 0:
raise ValueError("This decorator requires an integer cost greater than or equal to zero")
def deco(coro_or_command):
is_command = isinstance(coro_or_command, commands.Command)
if not is_command and not asyncio.iscoroutinefunction(coro_or_command):
raise TypeError("@bank.cost() can only be used on commands or `async def` functions")
coro = coro_or_command.callback if is_command else coro_or_command
@wraps(coro)
async def wrapped(*args, **kwargs):
context: commands.Context = None
for arg in args:
if isinstance(arg, commands.Context):
context = arg
break
if not context.guild and not await is_global():
raise commands.UserFeedbackCheckFailure(
_("Can't pay for this command in DM without a global bank.")
)
try:
await withdraw_credits(context.author, amount)
except Exception:
credits_name = await get_currency_name(context.guild)
raise commands.UserFeedbackCheckFailure(
_("You need at least {cost} {currency} to use this command.").format(
cost=amount, currency=credits_name
)
)
else:
try:
return await coro(*args, **kwargs)
except AbortPurchase:
await deposit_credits(context.author, amount)
except Exception:
await deposit_credits(context.author, amount)
raise
if not is_command:
return wrapped
else:
wrapped.__module__ = coro_or_command.callback.__module__
coro_or_command.callback = wrapped
return coro_or_command
return deco

View File

@ -3,7 +3,7 @@ import inspect
import discord import discord
from discord.ext import commands from discord.ext import commands
__all__ = ["ConversionFailure", "BotMissingPermissions"] __all__ = ["ConversionFailure", "BotMissingPermissions", "UserFeedbackCheckFailure"]
class ConversionFailure(commands.BadArgument): class ConversionFailure(commands.BadArgument):
@ -22,3 +22,11 @@ class BotMissingPermissions(commands.CheckFailure):
def __init__(self, missing: discord.Permissions, *args): def __init__(self, missing: discord.Permissions, *args):
self.missing: discord.Permissions = missing self.missing: discord.Permissions = missing
super().__init__(*args) super().__init__(*args)
class UserFeedbackCheckFailure(commands.CheckFailure):
"""A version of CheckFailure which isn't suppressed."""
def __init__(self, message=None, *args):
self.message = message
super().__init__(message, *args)

View File

@ -224,6 +224,9 @@ def init_events(bot, cli_flags):
perms=format_perms_list(error.missing), plural=plural perms=format_perms_list(error.missing), plural=plural
) )
) )
elif isinstance(error, commands.UserFeedbackCheckFailure):
if error.message:
await ctx.send(error.message)
elif isinstance(error, commands.CheckFailure): elif isinstance(error, commands.CheckFailure):
pass pass
elif isinstance(error, commands.NoPrivateMessage): elif isinstance(error, commands.NoPrivateMessage):