mirror of
https://github.com/Cog-Creators/Red-DiscordBot.git
synced 2025-11-21 18:27:59 -05:00
Default rules for subcommands precede supercommands (#2422)
This incorporates default rules into the same resolution techniques used by concrete rules. Resolves #2313. Signed-off-by: Toby Harradine <tobyharradine@gmail.com>
This commit is contained in:
@@ -19,6 +19,7 @@ from typing import (
|
||||
TYPE_CHECKING,
|
||||
TypeVar,
|
||||
Tuple,
|
||||
ClassVar,
|
||||
)
|
||||
|
||||
import discord
|
||||
@@ -284,6 +285,14 @@ class Requires:
|
||||
|
||||
"""
|
||||
|
||||
DEFAULT: ClassVar[str] = "default"
|
||||
"""The key for the default rule in a rules dict."""
|
||||
|
||||
GLOBAL: ClassVar[int] = 0
|
||||
"""Should be used in place of a guild ID when setting/getting
|
||||
global rules.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
privilege_level: Optional[PrivilegeLevel],
|
||||
@@ -307,10 +316,8 @@ class Requires:
|
||||
self.bot_perms.update(**bot_perms)
|
||||
else:
|
||||
self.bot_perms = bot_perms
|
||||
self.default_global_rule: PermState = PermState.NORMAL
|
||||
self._global_rules: _IntKeyDict[PermState] = _IntKeyDict()
|
||||
self._default_guild_rules: _IntKeyDict[PermState] = _IntKeyDict()
|
||||
self._guild_rules: _IntKeyDict[_IntKeyDict[PermState]] = _IntKeyDict()
|
||||
self._global_rules: _RulesDict = _RulesDict()
|
||||
self._guild_rules: _IntKeyDict[_RulesDict] = _IntKeyDict[_RulesDict]()
|
||||
|
||||
@staticmethod
|
||||
def get_decorator(
|
||||
@@ -334,16 +341,17 @@ class Requires:
|
||||
|
||||
return decorator
|
||||
|
||||
def get_rule(self, model: Union[int, PermissionModel], guild_id: int) -> PermState:
|
||||
def get_rule(self, model: Union[int, str, PermissionModel], guild_id: int) -> PermState:
|
||||
"""Get the rule for a particular model.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model : PermissionModel
|
||||
The model to get the rule for.
|
||||
model : Union[int, str, PermissionModel]
|
||||
The model to get the rule for. `str` is only valid for
|
||||
`Requires.DEFAULT`.
|
||||
guild_id : int
|
||||
The ID of the guild for the rule's scope. Set to ``0``
|
||||
for a global rule.
|
||||
The ID of the guild for the rule's scope. Set to
|
||||
`Requires.GLOBAL` for a global rule.
|
||||
|
||||
Returns
|
||||
-------
|
||||
@@ -352,31 +360,32 @@ class Requires:
|
||||
for an explanation.
|
||||
|
||||
"""
|
||||
if not isinstance(model, int):
|
||||
if not isinstance(model, (str, int)):
|
||||
model = model.id
|
||||
if guild_id:
|
||||
rules = self._guild_rules.get(guild_id, _IntKeyDict())
|
||||
rules = self._guild_rules.get(guild_id, _RulesDict())
|
||||
else:
|
||||
rules = self._global_rules
|
||||
return rules.get(model, PermState.NORMAL)
|
||||
|
||||
def set_rule(self, model_id: int, rule: PermState, guild_id: int) -> None:
|
||||
def set_rule(self, model_id: Union[str, int], rule: PermState, guild_id: int) -> None:
|
||||
"""Set the rule for a particular model.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
model_id : PermissionModel
|
||||
The model to add a rule for.
|
||||
model_id : Union[str, int]
|
||||
The model to add a rule for. `str` is only valid for
|
||||
`Requires.DEFAULT`.
|
||||
rule : PermState
|
||||
Which state this rule should be set as. See the `PermState`
|
||||
class for an explanation.
|
||||
guild_id : int
|
||||
The ID of the guild for the rule's scope. Set to ``0``
|
||||
for a global rule.
|
||||
The ID of the guild for the rule's scope. Set to
|
||||
`Requires.GLOBAL` for a global rule.
|
||||
|
||||
"""
|
||||
if guild_id:
|
||||
rules = self._guild_rules.setdefault(guild_id, _IntKeyDict())
|
||||
rules = self._guild_rules.setdefault(guild_id, _RulesDict())
|
||||
else:
|
||||
rules = self._global_rules
|
||||
if rule is PermState.NORMAL:
|
||||
@@ -387,27 +396,24 @@ class Requires:
|
||||
def clear_all_rules(self, guild_id: int) -> None:
|
||||
"""Clear all rules of a particular scope.
|
||||
|
||||
This will preserve the default rule, if set.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
guild_id : int
|
||||
The guild ID to clear rules for. If ``0``, this will
|
||||
clear all global rules and leave all guild rules
|
||||
untouched.
|
||||
The guild ID to clear rules for. If set to
|
||||
`Requires.GLOBAL`, this will clear all global rules and
|
||||
leave all guild rules untouched.
|
||||
|
||||
"""
|
||||
if guild_id:
|
||||
rules = self._guild_rules.setdefault(guild_id, _IntKeyDict())
|
||||
rules = self._guild_rules.setdefault(guild_id, _RulesDict())
|
||||
else:
|
||||
rules = self._global_rules
|
||||
default = rules.get(self.DEFAULT, None)
|
||||
rules.clear()
|
||||
|
||||
def get_default_guild_rule(self, guild_id: int) -> PermState:
|
||||
"""Get the default rule for a guild."""
|
||||
return self._default_guild_rules.get(guild_id, PermState.NORMAL)
|
||||
|
||||
def set_default_guild_rule(self, guild_id: int, rule: PermState) -> None:
|
||||
"""Set the default rule for a guild."""
|
||||
self._default_guild_rules[guild_id] = rule
|
||||
if default is not None:
|
||||
rules[self.DEFAULT] = default
|
||||
|
||||
async def verify(self, ctx: "Context") -> bool:
|
||||
"""Check if the given context passes the requirements.
|
||||
@@ -470,9 +476,9 @@ class Requires:
|
||||
# We must check what would happen normally, if no explicit rules were set.
|
||||
default_rule = PermState.NORMAL
|
||||
if ctx.guild is not None:
|
||||
default_rule = self.get_default_guild_rule(guild_id=ctx.guild.id)
|
||||
default_rule = self.get_rule(self.DEFAULT, guild_id=ctx.guild.id)
|
||||
if default_rule is PermState.NORMAL:
|
||||
default_rule = self.default_global_rule
|
||||
default_rule = self.get_rule(self.DEFAULT, self.GLOBAL)
|
||||
|
||||
if default_rule == PermState.ACTIVE_DENY:
|
||||
would_invoke = False
|
||||
@@ -510,7 +516,7 @@ class Requires:
|
||||
rule = self._global_rules.get(author.id)
|
||||
if rule is not None:
|
||||
return rule
|
||||
return self.default_global_rule
|
||||
return self.get_rule(self.DEFAULT, self.GLOBAL)
|
||||
|
||||
rules_chain = [self._global_rules]
|
||||
guild_rules = self._guild_rules.get(ctx.guild.id)
|
||||
@@ -534,9 +540,9 @@ class Requires:
|
||||
return rule
|
||||
del model_chain[-1] # We don't check for the guild in guild rules
|
||||
|
||||
default_rule = self.get_default_guild_rule(guild.id)
|
||||
default_rule = self.get_rule(self.DEFAULT, guild.id)
|
||||
if default_rule is PermState.NORMAL:
|
||||
default_rule = self.default_global_rule
|
||||
default_rule = self.get_rule(self.DEFAULT, self.GLOBAL)
|
||||
return default_rule
|
||||
|
||||
async def _verify_checks(self, ctx: "Context") -> bool:
|
||||
@@ -706,6 +712,20 @@ class _IntKeyDict(Dict[int, _T]):
|
||||
return super().__setitem__(key, value)
|
||||
|
||||
|
||||
class _RulesDict(Dict[Union[int, str], PermState]):
|
||||
"""Dict subclass which throws a KeyError when an invalid key is used."""
|
||||
|
||||
def __getitem__(self, key: Any) -> PermState:
|
||||
if key != Requires.DEFAULT and not isinstance(key, int):
|
||||
raise TypeError(f'Expected "{Requires.DEFAULT}" or int key, not "{key}"')
|
||||
return super().__getitem__(key)
|
||||
|
||||
def __setitem__(self, key: Any, value: PermState) -> None:
|
||||
if key != Requires.DEFAULT and not isinstance(key, int):
|
||||
raise TypeError(f'Expected "{Requires.DEFAULT}" or int key, not "{key}"')
|
||||
return super().__setitem__(key, value)
|
||||
|
||||
|
||||
def _validate_perms_dict(perms: Dict[str, bool]) -> None:
|
||||
for perm, value in perms.items():
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user