mirror of
https://github.com/Cog-Creators/Red-DiscordBot.git
synced 2025-11-06 11:18:54 -05:00
Default rules for subcommands precede supercommands (#2422)
This incorporates default rules into the same resolution techniques used by concrete rules. Resolves #2313. Signed-off-by: Toby Harradine <tobyharradine@gmail.com>
This commit is contained in:
parent
889fa63aff
commit
435fc141ae
@ -5,7 +5,7 @@ replace those from the `discord.ext.commands` module.
|
|||||||
"""
|
"""
|
||||||
import inspect
|
import inspect
|
||||||
import weakref
|
import weakref
|
||||||
from typing import Awaitable, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING
|
from typing import Awaitable, Callable, Dict, List, Optional, Tuple, Union, TYPE_CHECKING
|
||||||
|
|
||||||
import discord
|
import discord
|
||||||
from discord.ext import commands
|
from discord.ext import commands
|
||||||
@ -50,20 +50,54 @@ class CogCommandMixin:
|
|||||||
checks=getattr(decorated, "__requires_checks__", []),
|
checks=getattr(decorated, "__requires_checks__", []),
|
||||||
)
|
)
|
||||||
|
|
||||||
def allow_for(self, model_id: int, guild_id: int) -> None:
|
def allow_for(self, model_id: Union[int, str], guild_id: int) -> None:
|
||||||
"""Actively allow this command for the given model."""
|
"""Actively allow this command for the given model.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
model_id : Union[int, str]
|
||||||
|
Must be an `int` if supplying an ID. `str` is only valid
|
||||||
|
for "default".
|
||||||
|
guild_id : int
|
||||||
|
The guild ID to allow this cog or command in. For global
|
||||||
|
rules, use ``0``.
|
||||||
|
|
||||||
|
"""
|
||||||
self.requires.set_rule(model_id, PermState.ACTIVE_ALLOW, guild_id=guild_id)
|
self.requires.set_rule(model_id, PermState.ACTIVE_ALLOW, guild_id=guild_id)
|
||||||
|
|
||||||
def deny_to(self, model_id: int, guild_id: int) -> None:
|
def deny_to(self, model_id: Union[int, str], guild_id: int) -> None:
|
||||||
"""Actively deny this command to the given model."""
|
"""Actively deny this command to the given model.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
model_id : Union[int, str]
|
||||||
|
Must be an `int` if supplying an ID. `str` is only valid
|
||||||
|
for "default".
|
||||||
|
guild_id : int
|
||||||
|
The guild ID to deny this cog or command in. For global
|
||||||
|
rules, use ``0``.
|
||||||
|
|
||||||
|
"""
|
||||||
cur_rule = self.requires.get_rule(model_id, guild_id=guild_id)
|
cur_rule = self.requires.get_rule(model_id, guild_id=guild_id)
|
||||||
if cur_rule is PermState.PASSIVE_ALLOW:
|
if cur_rule is PermState.PASSIVE_ALLOW:
|
||||||
self.requires.set_rule(model_id, PermState.CAUTIOUS_ALLOW, guild_id=guild_id)
|
self.requires.set_rule(model_id, PermState.CAUTIOUS_ALLOW, guild_id=guild_id)
|
||||||
else:
|
else:
|
||||||
self.requires.set_rule(model_id, PermState.ACTIVE_DENY, guild_id=guild_id)
|
self.requires.set_rule(model_id, PermState.ACTIVE_DENY, guild_id=guild_id)
|
||||||
|
|
||||||
def clear_rule_for(self, model_id: int, guild_id: int) -> Tuple[PermState, PermState]:
|
def clear_rule_for(
|
||||||
"""Clear the rule which is currently set for this model."""
|
self, model_id: Union[int, str], guild_id: int
|
||||||
|
) -> Tuple[PermState, PermState]:
|
||||||
|
"""Clear the rule which is currently set for this model.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
model_id : Union[int, str]
|
||||||
|
Must be an `int` if supplying an ID. `str` is only valid
|
||||||
|
for "default".
|
||||||
|
guild_id : int
|
||||||
|
The guild ID. For global rules, use ``0``.
|
||||||
|
|
||||||
|
"""
|
||||||
cur_rule = self.requires.get_rule(model_id, guild_id=guild_id)
|
cur_rule = self.requires.get_rule(model_id, guild_id=guild_id)
|
||||||
if cur_rule is PermState.ACTIVE_ALLOW:
|
if cur_rule is PermState.ACTIVE_ALLOW:
|
||||||
new_rule = PermState.NORMAL
|
new_rule = PermState.NORMAL
|
||||||
@ -84,15 +118,17 @@ class CogCommandMixin:
|
|||||||
rule : Optional[bool]
|
rule : Optional[bool]
|
||||||
The rule to set as default. If ``True`` for allow,
|
The rule to set as default. If ``True`` for allow,
|
||||||
``False`` for deny and ``None`` for normal.
|
``False`` for deny and ``None`` for normal.
|
||||||
guild_id : Optional[int]
|
guild_id : int
|
||||||
Specify to set the default rule for a specific guild.
|
The guild to set the default rule in. When ``0``, this will
|
||||||
When ``None``, this will set the global default rule.
|
set the global default rule.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if guild_id:
|
if rule is None:
|
||||||
self.requires.set_default_guild_rule(guild_id, PermState.from_bool(rule))
|
self.clear_rule_for(Requires.DEFAULT, guild_id=guild_id)
|
||||||
else:
|
elif rule is True:
|
||||||
self.requires.default_global_rule = PermState.from_bool(rule)
|
self.allow_for(Requires.DEFAULT, guild_id=guild_id)
|
||||||
|
elif rule is False:
|
||||||
|
self.deny_to(Requires.DEFAULT, guild_id=guild_id)
|
||||||
|
|
||||||
|
|
||||||
class Command(CogCommandMixin, commands.Command):
|
class Command(CogCommandMixin, commands.Command):
|
||||||
@ -335,7 +371,7 @@ class Command(CogCommandMixin, commands.Command):
|
|||||||
else:
|
else:
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def allow_for(self, model_id: int, guild_id: int) -> None:
|
def allow_for(self, model_id: Union[int, str], guild_id: int) -> None:
|
||||||
super().allow_for(model_id, guild_id=guild_id)
|
super().allow_for(model_id, guild_id=guild_id)
|
||||||
parents = self.parents
|
parents = self.parents
|
||||||
if self.instance is not None:
|
if self.instance is not None:
|
||||||
@ -347,7 +383,9 @@ class Command(CogCommandMixin, commands.Command):
|
|||||||
elif cur_rule is PermState.ACTIVE_DENY:
|
elif cur_rule is PermState.ACTIVE_DENY:
|
||||||
parent.requires.set_rule(model_id, PermState.CAUTIOUS_ALLOW, guild_id=guild_id)
|
parent.requires.set_rule(model_id, PermState.CAUTIOUS_ALLOW, guild_id=guild_id)
|
||||||
|
|
||||||
def clear_rule_for(self, model_id: int, guild_id: int) -> Tuple[PermState, PermState]:
|
def clear_rule_for(
|
||||||
|
self, model_id: Union[int, str], guild_id: int
|
||||||
|
) -> Tuple[PermState, PermState]:
|
||||||
old_rule, new_rule = super().clear_rule_for(model_id, guild_id=guild_id)
|
old_rule, new_rule = super().clear_rule_for(model_id, guild_id=guild_id)
|
||||||
if old_rule is PermState.ACTIVE_ALLOW:
|
if old_rule is PermState.ACTIVE_ALLOW:
|
||||||
parents = self.parents
|
parents = self.parents
|
||||||
@ -396,8 +434,28 @@ class CogGroupMixin:
|
|||||||
all_commands: Dict[str, Command]
|
all_commands: Dict[str, Command]
|
||||||
|
|
||||||
def reevaluate_rules_for(
|
def reevaluate_rules_for(
|
||||||
self, model_id: int, guild_id: Optional[int]
|
self, model_id: Union[str, int], guild_id: Optional[int]
|
||||||
) -> Tuple[PermState, bool]:
|
) -> Tuple[PermState, bool]:
|
||||||
|
"""Re-evaluate a rule by checking subcommand rules.
|
||||||
|
|
||||||
|
This is called when a subcommand is no longer actively allowed.
|
||||||
|
|
||||||
|
Parameters
|
||||||
|
----------
|
||||||
|
model_id : Union[int, str]
|
||||||
|
Must be an `int` if supplying an ID. `str` is only valid
|
||||||
|
for "default".
|
||||||
|
guild_id : int
|
||||||
|
The guild ID. For global rules, use ``0``.
|
||||||
|
|
||||||
|
Returns
|
||||||
|
-------
|
||||||
|
Tuple[PermState, bool]
|
||||||
|
A 2-tuple containing the new rule and a bool indicating
|
||||||
|
whether or not the rule was changed as a result of this
|
||||||
|
call.
|
||||||
|
|
||||||
|
"""
|
||||||
cur_rule = self.requires.get_rule(model_id, guild_id=guild_id)
|
cur_rule = self.requires.get_rule(model_id, guild_id=guild_id)
|
||||||
if cur_rule in (PermState.NORMAL, PermState.ACTIVE_ALLOW, PermState.ACTIVE_DENY):
|
if cur_rule in (PermState.NORMAL, PermState.ACTIVE_ALLOW, PermState.ACTIVE_DENY):
|
||||||
# These three states are unaffected by subcommand rules
|
# These three states are unaffected by subcommand rules
|
||||||
|
|||||||
@ -19,6 +19,7 @@ from typing import (
|
|||||||
TYPE_CHECKING,
|
TYPE_CHECKING,
|
||||||
TypeVar,
|
TypeVar,
|
||||||
Tuple,
|
Tuple,
|
||||||
|
ClassVar,
|
||||||
)
|
)
|
||||||
|
|
||||||
import discord
|
import discord
|
||||||
@ -284,6 +285,14 @@ class Requires:
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
DEFAULT: ClassVar[str] = "default"
|
||||||
|
"""The key for the default rule in a rules dict."""
|
||||||
|
|
||||||
|
GLOBAL: ClassVar[int] = 0
|
||||||
|
"""Should be used in place of a guild ID when setting/getting
|
||||||
|
global rules.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
privilege_level: Optional[PrivilegeLevel],
|
privilege_level: Optional[PrivilegeLevel],
|
||||||
@ -307,10 +316,8 @@ class Requires:
|
|||||||
self.bot_perms.update(**bot_perms)
|
self.bot_perms.update(**bot_perms)
|
||||||
else:
|
else:
|
||||||
self.bot_perms = bot_perms
|
self.bot_perms = bot_perms
|
||||||
self.default_global_rule: PermState = PermState.NORMAL
|
self._global_rules: _RulesDict = _RulesDict()
|
||||||
self._global_rules: _IntKeyDict[PermState] = _IntKeyDict()
|
self._guild_rules: _IntKeyDict[_RulesDict] = _IntKeyDict[_RulesDict]()
|
||||||
self._default_guild_rules: _IntKeyDict[PermState] = _IntKeyDict()
|
|
||||||
self._guild_rules: _IntKeyDict[_IntKeyDict[PermState]] = _IntKeyDict()
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def get_decorator(
|
def get_decorator(
|
||||||
@ -334,16 +341,17 @@ class Requires:
|
|||||||
|
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
def get_rule(self, model: Union[int, PermissionModel], guild_id: int) -> PermState:
|
def get_rule(self, model: Union[int, str, PermissionModel], guild_id: int) -> PermState:
|
||||||
"""Get the rule for a particular model.
|
"""Get the rule for a particular model.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
model : PermissionModel
|
model : Union[int, str, PermissionModel]
|
||||||
The model to get the rule for.
|
The model to get the rule for. `str` is only valid for
|
||||||
|
`Requires.DEFAULT`.
|
||||||
guild_id : int
|
guild_id : int
|
||||||
The ID of the guild for the rule's scope. Set to ``0``
|
The ID of the guild for the rule's scope. Set to
|
||||||
for a global rule.
|
`Requires.GLOBAL` for a global rule.
|
||||||
|
|
||||||
Returns
|
Returns
|
||||||
-------
|
-------
|
||||||
@ -352,31 +360,32 @@ class Requires:
|
|||||||
for an explanation.
|
for an explanation.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if not isinstance(model, int):
|
if not isinstance(model, (str, int)):
|
||||||
model = model.id
|
model = model.id
|
||||||
if guild_id:
|
if guild_id:
|
||||||
rules = self._guild_rules.get(guild_id, _IntKeyDict())
|
rules = self._guild_rules.get(guild_id, _RulesDict())
|
||||||
else:
|
else:
|
||||||
rules = self._global_rules
|
rules = self._global_rules
|
||||||
return rules.get(model, PermState.NORMAL)
|
return rules.get(model, PermState.NORMAL)
|
||||||
|
|
||||||
def set_rule(self, model_id: int, rule: PermState, guild_id: int) -> None:
|
def set_rule(self, model_id: Union[str, int], rule: PermState, guild_id: int) -> None:
|
||||||
"""Set the rule for a particular model.
|
"""Set the rule for a particular model.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
model_id : PermissionModel
|
model_id : Union[str, int]
|
||||||
The model to add a rule for.
|
The model to add a rule for. `str` is only valid for
|
||||||
|
`Requires.DEFAULT`.
|
||||||
rule : PermState
|
rule : PermState
|
||||||
Which state this rule should be set as. See the `PermState`
|
Which state this rule should be set as. See the `PermState`
|
||||||
class for an explanation.
|
class for an explanation.
|
||||||
guild_id : int
|
guild_id : int
|
||||||
The ID of the guild for the rule's scope. Set to ``0``
|
The ID of the guild for the rule's scope. Set to
|
||||||
for a global rule.
|
`Requires.GLOBAL` for a global rule.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if guild_id:
|
if guild_id:
|
||||||
rules = self._guild_rules.setdefault(guild_id, _IntKeyDict())
|
rules = self._guild_rules.setdefault(guild_id, _RulesDict())
|
||||||
else:
|
else:
|
||||||
rules = self._global_rules
|
rules = self._global_rules
|
||||||
if rule is PermState.NORMAL:
|
if rule is PermState.NORMAL:
|
||||||
@ -387,27 +396,24 @@ class Requires:
|
|||||||
def clear_all_rules(self, guild_id: int) -> None:
|
def clear_all_rules(self, guild_id: int) -> None:
|
||||||
"""Clear all rules of a particular scope.
|
"""Clear all rules of a particular scope.
|
||||||
|
|
||||||
|
This will preserve the default rule, if set.
|
||||||
|
|
||||||
Parameters
|
Parameters
|
||||||
----------
|
----------
|
||||||
guild_id : int
|
guild_id : int
|
||||||
The guild ID to clear rules for. If ``0``, this will
|
The guild ID to clear rules for. If set to
|
||||||
clear all global rules and leave all guild rules
|
`Requires.GLOBAL`, this will clear all global rules and
|
||||||
untouched.
|
leave all guild rules untouched.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
if guild_id:
|
if guild_id:
|
||||||
rules = self._guild_rules.setdefault(guild_id, _IntKeyDict())
|
rules = self._guild_rules.setdefault(guild_id, _RulesDict())
|
||||||
else:
|
else:
|
||||||
rules = self._global_rules
|
rules = self._global_rules
|
||||||
|
default = rules.get(self.DEFAULT, None)
|
||||||
rules.clear()
|
rules.clear()
|
||||||
|
if default is not None:
|
||||||
def get_default_guild_rule(self, guild_id: int) -> PermState:
|
rules[self.DEFAULT] = default
|
||||||
"""Get the default rule for a guild."""
|
|
||||||
return self._default_guild_rules.get(guild_id, PermState.NORMAL)
|
|
||||||
|
|
||||||
def set_default_guild_rule(self, guild_id: int, rule: PermState) -> None:
|
|
||||||
"""Set the default rule for a guild."""
|
|
||||||
self._default_guild_rules[guild_id] = rule
|
|
||||||
|
|
||||||
async def verify(self, ctx: "Context") -> bool:
|
async def verify(self, ctx: "Context") -> bool:
|
||||||
"""Check if the given context passes the requirements.
|
"""Check if the given context passes the requirements.
|
||||||
@ -470,9 +476,9 @@ class Requires:
|
|||||||
# We must check what would happen normally, if no explicit rules were set.
|
# We must check what would happen normally, if no explicit rules were set.
|
||||||
default_rule = PermState.NORMAL
|
default_rule = PermState.NORMAL
|
||||||
if ctx.guild is not None:
|
if ctx.guild is not None:
|
||||||
default_rule = self.get_default_guild_rule(guild_id=ctx.guild.id)
|
default_rule = self.get_rule(self.DEFAULT, guild_id=ctx.guild.id)
|
||||||
if default_rule is PermState.NORMAL:
|
if default_rule is PermState.NORMAL:
|
||||||
default_rule = self.default_global_rule
|
default_rule = self.get_rule(self.DEFAULT, self.GLOBAL)
|
||||||
|
|
||||||
if default_rule == PermState.ACTIVE_DENY:
|
if default_rule == PermState.ACTIVE_DENY:
|
||||||
would_invoke = False
|
would_invoke = False
|
||||||
@ -510,7 +516,7 @@ class Requires:
|
|||||||
rule = self._global_rules.get(author.id)
|
rule = self._global_rules.get(author.id)
|
||||||
if rule is not None:
|
if rule is not None:
|
||||||
return rule
|
return rule
|
||||||
return self.default_global_rule
|
return self.get_rule(self.DEFAULT, self.GLOBAL)
|
||||||
|
|
||||||
rules_chain = [self._global_rules]
|
rules_chain = [self._global_rules]
|
||||||
guild_rules = self._guild_rules.get(ctx.guild.id)
|
guild_rules = self._guild_rules.get(ctx.guild.id)
|
||||||
@ -534,9 +540,9 @@ class Requires:
|
|||||||
return rule
|
return rule
|
||||||
del model_chain[-1] # We don't check for the guild in guild rules
|
del model_chain[-1] # We don't check for the guild in guild rules
|
||||||
|
|
||||||
default_rule = self.get_default_guild_rule(guild.id)
|
default_rule = self.get_rule(self.DEFAULT, guild.id)
|
||||||
if default_rule is PermState.NORMAL:
|
if default_rule is PermState.NORMAL:
|
||||||
default_rule = self.default_global_rule
|
default_rule = self.get_rule(self.DEFAULT, self.GLOBAL)
|
||||||
return default_rule
|
return default_rule
|
||||||
|
|
||||||
async def _verify_checks(self, ctx: "Context") -> bool:
|
async def _verify_checks(self, ctx: "Context") -> bool:
|
||||||
@ -706,6 +712,20 @@ class _IntKeyDict(Dict[int, _T]):
|
|||||||
return super().__setitem__(key, value)
|
return super().__setitem__(key, value)
|
||||||
|
|
||||||
|
|
||||||
|
class _RulesDict(Dict[Union[int, str], PermState]):
|
||||||
|
"""Dict subclass which throws a KeyError when an invalid key is used."""
|
||||||
|
|
||||||
|
def __getitem__(self, key: Any) -> PermState:
|
||||||
|
if key != Requires.DEFAULT and not isinstance(key, int):
|
||||||
|
raise TypeError(f'Expected "{Requires.DEFAULT}" or int key, not "{key}"')
|
||||||
|
return super().__getitem__(key)
|
||||||
|
|
||||||
|
def __setitem__(self, key: Any, value: PermState) -> None:
|
||||||
|
if key != Requires.DEFAULT and not isinstance(key, int):
|
||||||
|
raise TypeError(f'Expected "{Requires.DEFAULT}" or int key, not "{key}"')
|
||||||
|
return super().__setitem__(key, value)
|
||||||
|
|
||||||
|
|
||||||
def _validate_perms_dict(perms: Dict[str, bool]) -> None:
|
def _validate_perms_dict(perms: Dict[str, bool]) -> None:
|
||||||
for perm, value in perms.items():
|
for perm, value in perms.items():
|
||||||
try:
|
try:
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user