From 435fc141aeb52dfae5605f000d16dafa2d416023 Mon Sep 17 00:00:00 2001 From: Toby Harradine Date: Mon, 11 Feb 2019 14:14:29 +1100 Subject: [PATCH] Default rules for subcommands precede supercommands (#2422) This incorporates default rules into the same resolution techniques used by concrete rules. Resolves #2313. Signed-off-by: Toby Harradine --- redbot/core/commands/commands.py | 92 ++++++++++++++++++++++++++------ redbot/core/commands/requires.py | 88 ++++++++++++++++++------------ 2 files changed, 129 insertions(+), 51 deletions(-) diff --git a/redbot/core/commands/commands.py b/redbot/core/commands/commands.py index ea37af87a..4f12d2577 100644 --- a/redbot/core/commands/commands.py +++ b/redbot/core/commands/commands.py @@ -5,7 +5,7 @@ replace those from the `discord.ext.commands` module. """ import inspect import weakref -from typing import Awaitable, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING +from typing import Awaitable, Callable, Dict, List, Optional, Tuple, Union, TYPE_CHECKING import discord from discord.ext import commands @@ -50,20 +50,54 @@ class CogCommandMixin: checks=getattr(decorated, "__requires_checks__", []), ) - def allow_for(self, model_id: int, guild_id: int) -> None: - """Actively allow this command for the given model.""" + def allow_for(self, model_id: Union[int, str], guild_id: int) -> None: + """Actively allow this command for the given model. + + Parameters + ---------- + model_id : Union[int, str] + Must be an `int` if supplying an ID. `str` is only valid + for "default". + guild_id : int + The guild ID to allow this cog or command in. For global + rules, use ``0``. + + """ self.requires.set_rule(model_id, PermState.ACTIVE_ALLOW, guild_id=guild_id) - def deny_to(self, model_id: int, guild_id: int) -> None: - """Actively deny this command to the given model.""" + def deny_to(self, model_id: Union[int, str], guild_id: int) -> None: + """Actively deny this command to the given model. + + Parameters + ---------- + model_id : Union[int, str] + Must be an `int` if supplying an ID. `str` is only valid + for "default". + guild_id : int + The guild ID to deny this cog or command in. For global + rules, use ``0``. + + """ cur_rule = self.requires.get_rule(model_id, guild_id=guild_id) if cur_rule is PermState.PASSIVE_ALLOW: self.requires.set_rule(model_id, PermState.CAUTIOUS_ALLOW, guild_id=guild_id) else: self.requires.set_rule(model_id, PermState.ACTIVE_DENY, guild_id=guild_id) - def clear_rule_for(self, model_id: int, guild_id: int) -> Tuple[PermState, PermState]: - """Clear the rule which is currently set for this model.""" + def clear_rule_for( + self, model_id: Union[int, str], guild_id: int + ) -> Tuple[PermState, PermState]: + """Clear the rule which is currently set for this model. + + Parameters + ---------- + model_id : Union[int, str] + Must be an `int` if supplying an ID. `str` is only valid + for "default". + guild_id : int + The guild ID. For global rules, use ``0``. + + """ cur_rule = self.requires.get_rule(model_id, guild_id=guild_id) if cur_rule is PermState.ACTIVE_ALLOW: new_rule = PermState.NORMAL @@ -84,15 +118,17 @@ class CogCommandMixin: rule : Optional[bool] The rule to set as default. If ``True`` for allow, ``False`` for deny and ``None`` for normal. - guild_id : Optional[int] - Specify to set the default rule for a specific guild. - When ``None``, this will set the global default rule. + guild_id : int + The guild to set the default rule in. When ``0``, this will + set the global default rule. """ - if guild_id: - self.requires.set_default_guild_rule(guild_id, PermState.from_bool(rule)) - else: - self.requires.default_global_rule = PermState.from_bool(rule) + if rule is None: + self.clear_rule_for(Requires.DEFAULT, guild_id=guild_id) + elif rule is True: + self.allow_for(Requires.DEFAULT, guild_id=guild_id) + elif rule is False: + self.deny_to(Requires.DEFAULT, guild_id=guild_id) class Command(CogCommandMixin, commands.Command): @@ -335,7 +371,7 @@ class Command(CogCommandMixin, commands.Command): else: return True - def allow_for(self, model_id: int, guild_id: int) -> None: + def allow_for(self, model_id: Union[int, str], guild_id: int) -> None: super().allow_for(model_id, guild_id=guild_id) parents = self.parents if self.instance is not None: @@ -347,7 +383,9 @@ class Command(CogCommandMixin, commands.Command): elif cur_rule is PermState.ACTIVE_DENY: parent.requires.set_rule(model_id, PermState.CAUTIOUS_ALLOW, guild_id=guild_id) - def clear_rule_for(self, model_id: int, guild_id: int) -> Tuple[PermState, PermState]: + def clear_rule_for( + self, model_id: Union[int, str], guild_id: int + ) -> Tuple[PermState, PermState]: old_rule, new_rule = super().clear_rule_for(model_id, guild_id=guild_id) if old_rule is PermState.ACTIVE_ALLOW: parents = self.parents @@ -396,8 +434,28 @@ class CogGroupMixin: all_commands: Dict[str, Command] def reevaluate_rules_for( - self, model_id: int, guild_id: Optional[int] + self, model_id: Union[str, int], guild_id: Optional[int] ) -> Tuple[PermState, bool]: + """Re-evaluate a rule by checking subcommand rules. + + This is called when a subcommand is no longer actively allowed. + + Parameters + ---------- + model_id : Union[int, str] + Must be an `int` if supplying an ID. `str` is only valid + for "default". + guild_id : int + The guild ID. For global rules, use ``0``. + + Returns + ------- + Tuple[PermState, bool] + A 2-tuple containing the new rule and a bool indicating + whether or not the rule was changed as a result of this + call. + + """ cur_rule = self.requires.get_rule(model_id, guild_id=guild_id) if cur_rule in (PermState.NORMAL, PermState.ACTIVE_ALLOW, PermState.ACTIVE_DENY): # These three states are unaffected by subcommand rules diff --git a/redbot/core/commands/requires.py b/redbot/core/commands/requires.py index 9759919ba..08b482b6c 100644 --- a/redbot/core/commands/requires.py +++ b/redbot/core/commands/requires.py @@ -19,6 +19,7 @@ from typing import ( TYPE_CHECKING, TypeVar, Tuple, + ClassVar, ) import discord @@ -284,6 +285,14 @@ class Requires: """ + DEFAULT: ClassVar[str] = "default" + """The key for the default rule in a rules dict.""" + + GLOBAL: ClassVar[int] = 0 + """Should be used in place of a guild ID when setting/getting + global rules. + """ + def __init__( self, privilege_level: Optional[PrivilegeLevel], @@ -307,10 +316,8 @@ class Requires: self.bot_perms.update(**bot_perms) else: self.bot_perms = bot_perms - self.default_global_rule: PermState = PermState.NORMAL - self._global_rules: _IntKeyDict[PermState] = _IntKeyDict() - self._default_guild_rules: _IntKeyDict[PermState] = _IntKeyDict() - self._guild_rules: _IntKeyDict[_IntKeyDict[PermState]] = _IntKeyDict() + self._global_rules: _RulesDict = _RulesDict() + self._guild_rules: _IntKeyDict[_RulesDict] = _IntKeyDict[_RulesDict]() @staticmethod def get_decorator( @@ -334,16 +341,17 @@ class Requires: return decorator - def get_rule(self, model: Union[int, PermissionModel], guild_id: int) -> PermState: + def get_rule(self, model: Union[int, str, PermissionModel], guild_id: int) -> PermState: """Get the rule for a particular model. Parameters ---------- - model : PermissionModel - The model to get the rule for. + model : Union[int, str, PermissionModel] + The model to get the rule for. `str` is only valid for + `Requires.DEFAULT`. guild_id : int - The ID of the guild for the rule's scope. Set to ``0`` - for a global rule. + The ID of the guild for the rule's scope. Set to + `Requires.GLOBAL` for a global rule. Returns ------- @@ -352,31 +360,32 @@ class Requires: for an explanation. """ - if not isinstance(model, int): + if not isinstance(model, (str, int)): model = model.id if guild_id: - rules = self._guild_rules.get(guild_id, _IntKeyDict()) + rules = self._guild_rules.get(guild_id, _RulesDict()) else: rules = self._global_rules return rules.get(model, PermState.NORMAL) - def set_rule(self, model_id: int, rule: PermState, guild_id: int) -> None: + def set_rule(self, model_id: Union[str, int], rule: PermState, guild_id: int) -> None: """Set the rule for a particular model. Parameters ---------- - model_id : PermissionModel - The model to add a rule for. + model_id : Union[str, int] + The model to add a rule for. `str` is only valid for + `Requires.DEFAULT`. rule : PermState Which state this rule should be set as. See the `PermState` class for an explanation. guild_id : int - The ID of the guild for the rule's scope. Set to ``0`` - for a global rule. + The ID of the guild for the rule's scope. Set to + `Requires.GLOBAL` for a global rule. """ if guild_id: - rules = self._guild_rules.setdefault(guild_id, _IntKeyDict()) + rules = self._guild_rules.setdefault(guild_id, _RulesDict()) else: rules = self._global_rules if rule is PermState.NORMAL: @@ -387,27 +396,24 @@ class Requires: def clear_all_rules(self, guild_id: int) -> None: """Clear all rules of a particular scope. + This will preserve the default rule, if set. + Parameters ---------- guild_id : int - The guild ID to clear rules for. If ``0``, this will - clear all global rules and leave all guild rules - untouched. + The guild ID to clear rules for. If set to + `Requires.GLOBAL`, this will clear all global rules and + leave all guild rules untouched. """ if guild_id: - rules = self._guild_rules.setdefault(guild_id, _IntKeyDict()) + rules = self._guild_rules.setdefault(guild_id, _RulesDict()) else: rules = self._global_rules + default = rules.get(self.DEFAULT, None) rules.clear() - - def get_default_guild_rule(self, guild_id: int) -> PermState: - """Get the default rule for a guild.""" - return self._default_guild_rules.get(guild_id, PermState.NORMAL) - - def set_default_guild_rule(self, guild_id: int, rule: PermState) -> None: - """Set the default rule for a guild.""" - self._default_guild_rules[guild_id] = rule + if default is not None: + rules[self.DEFAULT] = default async def verify(self, ctx: "Context") -> bool: """Check if the given context passes the requirements. @@ -470,9 +476,9 @@ class Requires: # We must check what would happen normally, if no explicit rules were set. default_rule = PermState.NORMAL if ctx.guild is not None: - default_rule = self.get_default_guild_rule(guild_id=ctx.guild.id) + default_rule = self.get_rule(self.DEFAULT, guild_id=ctx.guild.id) if default_rule is PermState.NORMAL: - default_rule = self.default_global_rule + default_rule = self.get_rule(self.DEFAULT, self.GLOBAL) if default_rule == PermState.ACTIVE_DENY: would_invoke = False @@ -510,7 +516,7 @@ class Requires: rule = self._global_rules.get(author.id) if rule is not None: return rule - return self.default_global_rule + return self.get_rule(self.DEFAULT, self.GLOBAL) rules_chain = [self._global_rules] guild_rules = self._guild_rules.get(ctx.guild.id) @@ -534,9 +540,9 @@ class Requires: return rule del model_chain[-1] # We don't check for the guild in guild rules - default_rule = self.get_default_guild_rule(guild.id) + default_rule = self.get_rule(self.DEFAULT, guild.id) if default_rule is PermState.NORMAL: - default_rule = self.default_global_rule + default_rule = self.get_rule(self.DEFAULT, self.GLOBAL) return default_rule async def _verify_checks(self, ctx: "Context") -> bool: @@ -706,6 +712,20 @@ class _IntKeyDict(Dict[int, _T]): return super().__setitem__(key, value) +class _RulesDict(Dict[Union[int, str], PermState]): + """Dict subclass which throws a KeyError when an invalid key is used.""" + + def __getitem__(self, key: Any) -> PermState: + if key != Requires.DEFAULT and not isinstance(key, int): + raise TypeError(f'Expected "{Requires.DEFAULT}" or int key, not "{key}"') + return super().__getitem__(key) + + def __setitem__(self, key: Any, value: PermState) -> None: + if key != Requires.DEFAULT and not isinstance(key, int): + raise TypeError(f'Expected "{Requires.DEFAULT}" or int key, not "{key}"') + return super().__setitem__(key, value) + + def _validate_perms_dict(perms: Dict[str, bool]) -> None: for perm, value in perms.items(): try: