From 0eb22c84ff554aa1a898e0a87e8e235afc1b1a61 Mon Sep 17 00:00:00 2001 From: Michael H Date: Tue, 2 Jul 2019 20:07:19 -0400 Subject: [PATCH] [Bank API] Add cost decorator (#2761) --- docs/framework_bank.rst | 4 ++ redbot/core/bank.py | 75 +++++++++++++++++++++++++++++++++- redbot/core/commands/errors.py | 10 ++++- redbot/core/events.py | 3 ++ 4 files changed, 90 insertions(+), 2 deletions(-) diff --git a/docs/framework_bank.rst b/docs/framework_bank.rst index 393fff631..b6cfb4ec4 100644 --- a/docs/framework_bank.rst +++ b/docs/framework_bank.rst @@ -40,3 +40,7 @@ Bank .. automodule:: redbot.core.bank :members: + :exclude-members: cost + + .. autofunction:: cost + :decorator: diff --git a/redbot/core/bank.py b/redbot/core/bank.py index 442af8026..3b08a5083 100644 --- a/redbot/core/bank.py +++ b/redbot/core/bank.py @@ -1,9 +1,14 @@ +import asyncio import datetime from typing import Union, List, Optional +from functools import wraps import discord -from . import Config, errors +from . import Config, errors, commands +from .i18n import Translator + +_ = Translator("Bank API", __file__) __all__ = [ "MAX_BALANCE", @@ -24,6 +29,8 @@ __all__ = [ "set_currency_name", "get_default_balance", "set_default_balance", + "cost", + "AbortPurchase", ] MAX_BALANCE = 2 ** 63 - 1 @@ -669,3 +676,69 @@ async def set_default_balance(amount: int, guild: discord.Guild = None) -> int: raise RuntimeError("Guild is missing and required.") return amount + + +class AbortPurchase(Exception): + pass + + +def cost(amount: int): + """ + Decorates a coroutine-function or command to have a cost. + + If the command raises an exception, the cost will be refunded. + + You can intentionally refund by raising `AbortPurchase` + (this error will be consumed and not show to users) + + Other exceptions will propogate and will be handled by Red's (and/or + any other configured) error handling. + """ + if not isinstance(amount, int) or amount < 0: + raise ValueError("This decorator requires an integer cost greater than or equal to zero") + + def deco(coro_or_command): + is_command = isinstance(coro_or_command, commands.Command) + if not is_command and not asyncio.iscoroutinefunction(coro_or_command): + raise TypeError("@bank.cost() can only be used on commands or `async def` functions") + + coro = coro_or_command.callback if is_command else coro_or_command + + @wraps(coro) + async def wrapped(*args, **kwargs): + context: commands.Context = None + for arg in args: + if isinstance(arg, commands.Context): + context = arg + break + + if not context.guild and not await is_global(): + raise commands.UserFeedbackCheckFailure( + _("Can't pay for this command in DM without a global bank.") + ) + try: + await withdraw_credits(context.author, amount) + except Exception: + credits_name = await get_currency_name(context.guild) + raise commands.UserFeedbackCheckFailure( + _("You need at least {cost} {currency} to use this command.").format( + cost=amount, currency=credits_name + ) + ) + else: + try: + return await coro(*args, **kwargs) + except AbortPurchase: + await deposit_credits(context.author, amount) + except Exception: + await deposit_credits(context.author, amount) + raise + + if not is_command: + return wrapped + else: + wrapped.__module__ = coro_or_command.callback.__module__ + coro_or_command.callback = wrapped + return coro_or_command + + return deco diff --git a/redbot/core/commands/errors.py b/redbot/core/commands/errors.py index 5eb0b70bb..5c264c83e 100644 --- a/redbot/core/commands/errors.py +++ b/redbot/core/commands/errors.py @@ -3,7 +3,7 @@ import inspect import discord from discord.ext import commands -__all__ = ["ConversionFailure", "BotMissingPermissions"] +__all__ = ["ConversionFailure", "BotMissingPermissions", "UserFeedbackCheckFailure"] class ConversionFailure(commands.BadArgument): @@ -22,3 +22,11 @@ class BotMissingPermissions(commands.CheckFailure): def __init__(self, missing: discord.Permissions, *args): self.missing: discord.Permissions = missing super().__init__(*args) + + +class UserFeedbackCheckFailure(commands.CheckFailure): + """A version of CheckFailure which isn't suppressed.""" + + def __init__(self, message=None, *args): + self.message = message + super().__init__(message, *args) diff --git a/redbot/core/events.py b/redbot/core/events.py index 33f025789..62f53af56 100644 --- a/redbot/core/events.py +++ b/redbot/core/events.py @@ -224,6 +224,9 @@ def init_events(bot, cli_flags): perms=format_perms_list(error.missing), plural=plural ) ) + elif isinstance(error, commands.UserFeedbackCheckFailure): + if error.message: + await ctx.send(error.message) elif isinstance(error, commands.CheckFailure): pass elif isinstance(error, commands.NoPrivateMessage):