Use a SnowflakeList for Command disable instead of checks (#5552)

This commit is contained in:
Predä 2023-05-17 22:33:04 +02:00 committed by GitHub
parent edb3369169
commit f47d1dffb3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -92,7 +92,6 @@ RESERVED_COMMAND_NAMES = (
) )
_ = Translator("commands.commands", __file__) _ = Translator("commands.commands", __file__)
DisablerDictType = MutableMapping[discord.Guild, Callable[["Context"], Awaitable[bool]]]
class RedUnhandledAPI(Exception): class RedUnhandledAPI(Exception):
@ -311,6 +310,7 @@ class Command(CogCommandMixin, DPYCommand):
def __init__(self, *args, **kwargs): def __init__(self, *args, **kwargs):
self.ignore_optional_for_conversion = kwargs.pop("ignore_optional_for_conversion", False) self.ignore_optional_for_conversion = kwargs.pop("ignore_optional_for_conversion", False)
self._disabled_in: discord.utils.SnowflakeList = discord.utils.SnowflakeList([])
self._help_override = kwargs.pop("help_override", None) self._help_override = kwargs.pop("help_override", None)
self.translator = kwargs.pop("i18n", None) self.translator = kwargs.pop("i18n", None)
super().__init__(*args, **kwargs) super().__init__(*args, **kwargs)
@ -461,10 +461,20 @@ class Command(CogCommandMixin, DPYCommand):
if not change_permission_state: if not change_permission_state:
ctx.permission_state = original_state ctx.permission_state = original_state
def is_enabled(self, ctx) -> bool:
if not self.enabled:
return False
if ctx.guild:
if self._disabled_in.has(ctx.guild.id):
return False
return True
async def prepare(self, ctx, /): async def prepare(self, ctx, /):
ctx.command = self ctx.command = self
if not self.enabled: cmd_enabled = self.is_enabled(ctx)
if not cmd_enabled:
raise DisabledCommand(f"{self.name} command is disabled") raise DisabledCommand(f"{self.name} command is disabled")
if not await self.can_run(ctx, change_permission_state=True): if not await self.can_run(ctx, change_permission_state=True):
@ -533,11 +543,10 @@ class Command(CogCommandMixin, DPYCommand):
``True`` if the command wasn't already disabled. ``True`` if the command wasn't already disabled.
""" """
disabler = get_command_disabler(guild) if self._disabled_in.has(guild.id):
if disabler in self.checks:
return False return False
else:
self.checks.append(disabler) self._disabled_in.add(guild.id)
return True return True
def enable_in(self, guild: discord.Guild) -> bool: def enable_in(self, guild: discord.Guild) -> bool:
@ -554,12 +563,11 @@ class Command(CogCommandMixin, DPYCommand):
``True`` if the command wasn't already enabled. ``True`` if the command wasn't already enabled.
""" """
disabler = get_command_disabler(guild)
try: try:
self.checks.remove(disabler) self._disabled_in.remove(guild.id)
except ValueError: except ValueError:
return False return False
else:
return True return True
def allow_for(self, model_id: Union[int, str], guild_id: int) -> None: def allow_for(self, model_id: Union[int, str], guild_id: int) -> None:
@ -1151,28 +1159,6 @@ def group(name=None, cls=Group, **attrs):
return dpy_command_deco(name, cls, **attrs) return dpy_command_deco(name, cls, **attrs)
__command_disablers: DisablerDictType = weakref.WeakValueDictionary()
def get_command_disabler(guild: discord.Guild) -> Callable[["Context"], Awaitable[bool]]:
"""Get the command disabler for a guild.
A command disabler is a simple check predicate which returns
``False`` if the context is within the given guild.
"""
try:
return __command_disablers[guild.id]
except KeyError:
async def disabler(ctx: "Context") -> bool:
if ctx.guild is not None and ctx.guild.id == guild.id:
raise DisabledCommand()
return True
__command_disablers[guild.id] = disabler
return disabler
# The below are intentionally left out of `__all__` # The below are intentionally left out of `__all__`
# as they are not intended for general use # as they are not intended for general use
class _AlwaysAvailableMixin: class _AlwaysAvailableMixin: