mirror of
https://github.com/Cog-Creators/Red-DiscordBot.git
synced 2025-11-06 03:08:55 -05:00
[Bank API] Add cost decorator (#2761)
This commit is contained in:
parent
d1593b8069
commit
0eb22c84ff
@ -40,3 +40,7 @@ Bank
|
||||
|
||||
.. automodule:: redbot.core.bank
|
||||
:members:
|
||||
:exclude-members: cost
|
||||
|
||||
.. autofunction:: cost
|
||||
:decorator:
|
||||
|
||||
@ -1,9 +1,14 @@
|
||||
import asyncio
|
||||
import datetime
|
||||
from typing import Union, List, Optional
|
||||
from functools import wraps
|
||||
|
||||
import discord
|
||||
|
||||
from . import Config, errors
|
||||
from . import Config, errors, commands
|
||||
from .i18n import Translator
|
||||
|
||||
_ = Translator("Bank API", __file__)
|
||||
|
||||
__all__ = [
|
||||
"MAX_BALANCE",
|
||||
@ -24,6 +29,8 @@ __all__ = [
|
||||
"set_currency_name",
|
||||
"get_default_balance",
|
||||
"set_default_balance",
|
||||
"cost",
|
||||
"AbortPurchase",
|
||||
]
|
||||
|
||||
MAX_BALANCE = 2 ** 63 - 1
|
||||
@ -669,3 +676,69 @@ async def set_default_balance(amount: int, guild: discord.Guild = None) -> int:
|
||||
raise RuntimeError("Guild is missing and required.")
|
||||
|
||||
return amount
|
||||
|
||||
|
||||
class AbortPurchase(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def cost(amount: int):
|
||||
"""
|
||||
Decorates a coroutine-function or command to have a cost.
|
||||
|
||||
If the command raises an exception, the cost will be refunded.
|
||||
|
||||
You can intentionally refund by raising `AbortPurchase`
|
||||
(this error will be consumed and not show to users)
|
||||
|
||||
Other exceptions will propogate and will be handled by Red's (and/or
|
||||
any other configured) error handling.
|
||||
"""
|
||||
if not isinstance(amount, int) or amount < 0:
|
||||
raise ValueError("This decorator requires an integer cost greater than or equal to zero")
|
||||
|
||||
def deco(coro_or_command):
|
||||
is_command = isinstance(coro_or_command, commands.Command)
|
||||
if not is_command and not asyncio.iscoroutinefunction(coro_or_command):
|
||||
raise TypeError("@bank.cost() can only be used on commands or `async def` functions")
|
||||
|
||||
coro = coro_or_command.callback if is_command else coro_or_command
|
||||
|
||||
@wraps(coro)
|
||||
async def wrapped(*args, **kwargs):
|
||||
context: commands.Context = None
|
||||
for arg in args:
|
||||
if isinstance(arg, commands.Context):
|
||||
context = arg
|
||||
break
|
||||
|
||||
if not context.guild and not await is_global():
|
||||
raise commands.UserFeedbackCheckFailure(
|
||||
_("Can't pay for this command in DM without a global bank.")
|
||||
)
|
||||
try:
|
||||
await withdraw_credits(context.author, amount)
|
||||
except Exception:
|
||||
credits_name = await get_currency_name(context.guild)
|
||||
raise commands.UserFeedbackCheckFailure(
|
||||
_("You need at least {cost} {currency} to use this command.").format(
|
||||
cost=amount, currency=credits_name
|
||||
)
|
||||
)
|
||||
else:
|
||||
try:
|
||||
return await coro(*args, **kwargs)
|
||||
except AbortPurchase:
|
||||
await deposit_credits(context.author, amount)
|
||||
except Exception:
|
||||
await deposit_credits(context.author, amount)
|
||||
raise
|
||||
|
||||
if not is_command:
|
||||
return wrapped
|
||||
else:
|
||||
wrapped.__module__ = coro_or_command.callback.__module__
|
||||
coro_or_command.callback = wrapped
|
||||
return coro_or_command
|
||||
|
||||
return deco
|
||||
|
||||
@ -3,7 +3,7 @@ import inspect
|
||||
import discord
|
||||
from discord.ext import commands
|
||||
|
||||
__all__ = ["ConversionFailure", "BotMissingPermissions"]
|
||||
__all__ = ["ConversionFailure", "BotMissingPermissions", "UserFeedbackCheckFailure"]
|
||||
|
||||
|
||||
class ConversionFailure(commands.BadArgument):
|
||||
@ -22,3 +22,11 @@ class BotMissingPermissions(commands.CheckFailure):
|
||||
def __init__(self, missing: discord.Permissions, *args):
|
||||
self.missing: discord.Permissions = missing
|
||||
super().__init__(*args)
|
||||
|
||||
|
||||
class UserFeedbackCheckFailure(commands.CheckFailure):
|
||||
"""A version of CheckFailure which isn't suppressed."""
|
||||
|
||||
def __init__(self, message=None, *args):
|
||||
self.message = message
|
||||
super().__init__(message, *args)
|
||||
|
||||
@ -224,6 +224,9 @@ def init_events(bot, cli_flags):
|
||||
perms=format_perms_list(error.missing), plural=plural
|
||||
)
|
||||
)
|
||||
elif isinstance(error, commands.UserFeedbackCheckFailure):
|
||||
if error.message:
|
||||
await ctx.send(error.message)
|
||||
elif isinstance(error, commands.CheckFailure):
|
||||
pass
|
||||
elif isinstance(error, commands.NoPrivateMessage):
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user