mirror of
https://github.com/Cog-Creators/Red-DiscordBot.git
synced 2025-11-06 11:18:54 -05:00
[Bank API] Add cost decorator (#2761)
This commit is contained in:
parent
d1593b8069
commit
0eb22c84ff
@ -40,3 +40,7 @@ Bank
|
|||||||
|
|
||||||
.. automodule:: redbot.core.bank
|
.. automodule:: redbot.core.bank
|
||||||
:members:
|
:members:
|
||||||
|
:exclude-members: cost
|
||||||
|
|
||||||
|
.. autofunction:: cost
|
||||||
|
:decorator:
|
||||||
|
|||||||
@ -1,9 +1,14 @@
|
|||||||
|
import asyncio
|
||||||
import datetime
|
import datetime
|
||||||
from typing import Union, List, Optional
|
from typing import Union, List, Optional
|
||||||
|
from functools import wraps
|
||||||
|
|
||||||
import discord
|
import discord
|
||||||
|
|
||||||
from . import Config, errors
|
from . import Config, errors, commands
|
||||||
|
from .i18n import Translator
|
||||||
|
|
||||||
|
_ = Translator("Bank API", __file__)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"MAX_BALANCE",
|
"MAX_BALANCE",
|
||||||
@ -24,6 +29,8 @@ __all__ = [
|
|||||||
"set_currency_name",
|
"set_currency_name",
|
||||||
"get_default_balance",
|
"get_default_balance",
|
||||||
"set_default_balance",
|
"set_default_balance",
|
||||||
|
"cost",
|
||||||
|
"AbortPurchase",
|
||||||
]
|
]
|
||||||
|
|
||||||
MAX_BALANCE = 2 ** 63 - 1
|
MAX_BALANCE = 2 ** 63 - 1
|
||||||
@ -669,3 +676,69 @@ async def set_default_balance(amount: int, guild: discord.Guild = None) -> int:
|
|||||||
raise RuntimeError("Guild is missing and required.")
|
raise RuntimeError("Guild is missing and required.")
|
||||||
|
|
||||||
return amount
|
return amount
|
||||||
|
|
||||||
|
|
||||||
|
class AbortPurchase(Exception):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
def cost(amount: int):
|
||||||
|
"""
|
||||||
|
Decorates a coroutine-function or command to have a cost.
|
||||||
|
|
||||||
|
If the command raises an exception, the cost will be refunded.
|
||||||
|
|
||||||
|
You can intentionally refund by raising `AbortPurchase`
|
||||||
|
(this error will be consumed and not show to users)
|
||||||
|
|
||||||
|
Other exceptions will propogate and will be handled by Red's (and/or
|
||||||
|
any other configured) error handling.
|
||||||
|
"""
|
||||||
|
if not isinstance(amount, int) or amount < 0:
|
||||||
|
raise ValueError("This decorator requires an integer cost greater than or equal to zero")
|
||||||
|
|
||||||
|
def deco(coro_or_command):
|
||||||
|
is_command = isinstance(coro_or_command, commands.Command)
|
||||||
|
if not is_command and not asyncio.iscoroutinefunction(coro_or_command):
|
||||||
|
raise TypeError("@bank.cost() can only be used on commands or `async def` functions")
|
||||||
|
|
||||||
|
coro = coro_or_command.callback if is_command else coro_or_command
|
||||||
|
|
||||||
|
@wraps(coro)
|
||||||
|
async def wrapped(*args, **kwargs):
|
||||||
|
context: commands.Context = None
|
||||||
|
for arg in args:
|
||||||
|
if isinstance(arg, commands.Context):
|
||||||
|
context = arg
|
||||||
|
break
|
||||||
|
|
||||||
|
if not context.guild and not await is_global():
|
||||||
|
raise commands.UserFeedbackCheckFailure(
|
||||||
|
_("Can't pay for this command in DM without a global bank.")
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
await withdraw_credits(context.author, amount)
|
||||||
|
except Exception:
|
||||||
|
credits_name = await get_currency_name(context.guild)
|
||||||
|
raise commands.UserFeedbackCheckFailure(
|
||||||
|
_("You need at least {cost} {currency} to use this command.").format(
|
||||||
|
cost=amount, currency=credits_name
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
return await coro(*args, **kwargs)
|
||||||
|
except AbortPurchase:
|
||||||
|
await deposit_credits(context.author, amount)
|
||||||
|
except Exception:
|
||||||
|
await deposit_credits(context.author, amount)
|
||||||
|
raise
|
||||||
|
|
||||||
|
if not is_command:
|
||||||
|
return wrapped
|
||||||
|
else:
|
||||||
|
wrapped.__module__ = coro_or_command.callback.__module__
|
||||||
|
coro_or_command.callback = wrapped
|
||||||
|
return coro_or_command
|
||||||
|
|
||||||
|
return deco
|
||||||
|
|||||||
@ -3,7 +3,7 @@ import inspect
|
|||||||
import discord
|
import discord
|
||||||
from discord.ext import commands
|
from discord.ext import commands
|
||||||
|
|
||||||
__all__ = ["ConversionFailure", "BotMissingPermissions"]
|
__all__ = ["ConversionFailure", "BotMissingPermissions", "UserFeedbackCheckFailure"]
|
||||||
|
|
||||||
|
|
||||||
class ConversionFailure(commands.BadArgument):
|
class ConversionFailure(commands.BadArgument):
|
||||||
@ -22,3 +22,11 @@ class BotMissingPermissions(commands.CheckFailure):
|
|||||||
def __init__(self, missing: discord.Permissions, *args):
|
def __init__(self, missing: discord.Permissions, *args):
|
||||||
self.missing: discord.Permissions = missing
|
self.missing: discord.Permissions = missing
|
||||||
super().__init__(*args)
|
super().__init__(*args)
|
||||||
|
|
||||||
|
|
||||||
|
class UserFeedbackCheckFailure(commands.CheckFailure):
|
||||||
|
"""A version of CheckFailure which isn't suppressed."""
|
||||||
|
|
||||||
|
def __init__(self, message=None, *args):
|
||||||
|
self.message = message
|
||||||
|
super().__init__(message, *args)
|
||||||
|
|||||||
@ -224,6 +224,9 @@ def init_events(bot, cli_flags):
|
|||||||
perms=format_perms_list(error.missing), plural=plural
|
perms=format_perms_list(error.missing), plural=plural
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
elif isinstance(error, commands.UserFeedbackCheckFailure):
|
||||||
|
if error.message:
|
||||||
|
await ctx.send(error.message)
|
||||||
elif isinstance(error, commands.CheckFailure):
|
elif isinstance(error, commands.CheckFailure):
|
||||||
pass
|
pass
|
||||||
elif isinstance(error, commands.NoPrivateMessage):
|
elif isinstance(error, commands.NoPrivateMessage):
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user