Default rules for subcommands precede supercommands (#2422)

This incorporates default rules into the same resolution techniques used by concrete rules.

Resolves #2313.

Signed-off-by: Toby Harradine <tobyharradine@gmail.com>
This commit is contained in:
Toby Harradine 2019-02-11 14:14:29 +11:00 committed by GitHub
parent 889fa63aff
commit 435fc141ae
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 129 additions and 51 deletions

View File

@ -5,7 +5,7 @@ replace those from the `discord.ext.commands` module.
""" """
import inspect import inspect
import weakref import weakref
from typing import Awaitable, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING from typing import Awaitable, Callable, Dict, List, Optional, Tuple, Union, TYPE_CHECKING
import discord import discord
from discord.ext import commands from discord.ext import commands
@ -50,20 +50,54 @@ class CogCommandMixin:
checks=getattr(decorated, "__requires_checks__", []), checks=getattr(decorated, "__requires_checks__", []),
) )
def allow_for(self, model_id: int, guild_id: int) -> None: def allow_for(self, model_id: Union[int, str], guild_id: int) -> None:
"""Actively allow this command for the given model.""" """Actively allow this command for the given model.
Parameters
----------
model_id : Union[int, str]
Must be an `int` if supplying an ID. `str` is only valid
for "default".
guild_id : int
The guild ID to allow this cog or command in. For global
rules, use ``0``.
"""
self.requires.set_rule(model_id, PermState.ACTIVE_ALLOW, guild_id=guild_id) self.requires.set_rule(model_id, PermState.ACTIVE_ALLOW, guild_id=guild_id)
def deny_to(self, model_id: int, guild_id: int) -> None: def deny_to(self, model_id: Union[int, str], guild_id: int) -> None:
"""Actively deny this command to the given model.""" """Actively deny this command to the given model.
Parameters
----------
model_id : Union[int, str]
Must be an `int` if supplying an ID. `str` is only valid
for "default".
guild_id : int
The guild ID to deny this cog or command in. For global
rules, use ``0``.
"""
cur_rule = self.requires.get_rule(model_id, guild_id=guild_id) cur_rule = self.requires.get_rule(model_id, guild_id=guild_id)
if cur_rule is PermState.PASSIVE_ALLOW: if cur_rule is PermState.PASSIVE_ALLOW:
self.requires.set_rule(model_id, PermState.CAUTIOUS_ALLOW, guild_id=guild_id) self.requires.set_rule(model_id, PermState.CAUTIOUS_ALLOW, guild_id=guild_id)
else: else:
self.requires.set_rule(model_id, PermState.ACTIVE_DENY, guild_id=guild_id) self.requires.set_rule(model_id, PermState.ACTIVE_DENY, guild_id=guild_id)
def clear_rule_for(self, model_id: int, guild_id: int) -> Tuple[PermState, PermState]: def clear_rule_for(
"""Clear the rule which is currently set for this model.""" self, model_id: Union[int, str], guild_id: int
) -> Tuple[PermState, PermState]:
"""Clear the rule which is currently set for this model.
Parameters
----------
model_id : Union[int, str]
Must be an `int` if supplying an ID. `str` is only valid
for "default".
guild_id : int
The guild ID. For global rules, use ``0``.
"""
cur_rule = self.requires.get_rule(model_id, guild_id=guild_id) cur_rule = self.requires.get_rule(model_id, guild_id=guild_id)
if cur_rule is PermState.ACTIVE_ALLOW: if cur_rule is PermState.ACTIVE_ALLOW:
new_rule = PermState.NORMAL new_rule = PermState.NORMAL
@ -84,15 +118,17 @@ class CogCommandMixin:
rule : Optional[bool] rule : Optional[bool]
The rule to set as default. If ``True`` for allow, The rule to set as default. If ``True`` for allow,
``False`` for deny and ``None`` for normal. ``False`` for deny and ``None`` for normal.
guild_id : Optional[int] guild_id : int
Specify to set the default rule for a specific guild. The guild to set the default rule in. When ``0``, this will
When ``None``, this will set the global default rule. set the global default rule.
""" """
if guild_id: if rule is None:
self.requires.set_default_guild_rule(guild_id, PermState.from_bool(rule)) self.clear_rule_for(Requires.DEFAULT, guild_id=guild_id)
else: elif rule is True:
self.requires.default_global_rule = PermState.from_bool(rule) self.allow_for(Requires.DEFAULT, guild_id=guild_id)
elif rule is False:
self.deny_to(Requires.DEFAULT, guild_id=guild_id)
class Command(CogCommandMixin, commands.Command): class Command(CogCommandMixin, commands.Command):
@ -335,7 +371,7 @@ class Command(CogCommandMixin, commands.Command):
else: else:
return True return True
def allow_for(self, model_id: int, guild_id: int) -> None: def allow_for(self, model_id: Union[int, str], guild_id: int) -> None:
super().allow_for(model_id, guild_id=guild_id) super().allow_for(model_id, guild_id=guild_id)
parents = self.parents parents = self.parents
if self.instance is not None: if self.instance is not None:
@ -347,7 +383,9 @@ class Command(CogCommandMixin, commands.Command):
elif cur_rule is PermState.ACTIVE_DENY: elif cur_rule is PermState.ACTIVE_DENY:
parent.requires.set_rule(model_id, PermState.CAUTIOUS_ALLOW, guild_id=guild_id) parent.requires.set_rule(model_id, PermState.CAUTIOUS_ALLOW, guild_id=guild_id)
def clear_rule_for(self, model_id: int, guild_id: int) -> Tuple[PermState, PermState]: def clear_rule_for(
self, model_id: Union[int, str], guild_id: int
) -> Tuple[PermState, PermState]:
old_rule, new_rule = super().clear_rule_for(model_id, guild_id=guild_id) old_rule, new_rule = super().clear_rule_for(model_id, guild_id=guild_id)
if old_rule is PermState.ACTIVE_ALLOW: if old_rule is PermState.ACTIVE_ALLOW:
parents = self.parents parents = self.parents
@ -396,8 +434,28 @@ class CogGroupMixin:
all_commands: Dict[str, Command] all_commands: Dict[str, Command]
def reevaluate_rules_for( def reevaluate_rules_for(
self, model_id: int, guild_id: Optional[int] self, model_id: Union[str, int], guild_id: Optional[int]
) -> Tuple[PermState, bool]: ) -> Tuple[PermState, bool]:
"""Re-evaluate a rule by checking subcommand rules.
This is called when a subcommand is no longer actively allowed.
Parameters
----------
model_id : Union[int, str]
Must be an `int` if supplying an ID. `str` is only valid
for "default".
guild_id : int
The guild ID. For global rules, use ``0``.
Returns
-------
Tuple[PermState, bool]
A 2-tuple containing the new rule and a bool indicating
whether or not the rule was changed as a result of this
call.
"""
cur_rule = self.requires.get_rule(model_id, guild_id=guild_id) cur_rule = self.requires.get_rule(model_id, guild_id=guild_id)
if cur_rule in (PermState.NORMAL, PermState.ACTIVE_ALLOW, PermState.ACTIVE_DENY): if cur_rule in (PermState.NORMAL, PermState.ACTIVE_ALLOW, PermState.ACTIVE_DENY):
# These three states are unaffected by subcommand rules # These three states are unaffected by subcommand rules

View File

@ -19,6 +19,7 @@ from typing import (
TYPE_CHECKING, TYPE_CHECKING,
TypeVar, TypeVar,
Tuple, Tuple,
ClassVar,
) )
import discord import discord
@ -284,6 +285,14 @@ class Requires:
""" """
DEFAULT: ClassVar[str] = "default"
"""The key for the default rule in a rules dict."""
GLOBAL: ClassVar[int] = 0
"""Should be used in place of a guild ID when setting/getting
global rules.
"""
def __init__( def __init__(
self, self,
privilege_level: Optional[PrivilegeLevel], privilege_level: Optional[PrivilegeLevel],
@ -307,10 +316,8 @@ class Requires:
self.bot_perms.update(**bot_perms) self.bot_perms.update(**bot_perms)
else: else:
self.bot_perms = bot_perms self.bot_perms = bot_perms
self.default_global_rule: PermState = PermState.NORMAL self._global_rules: _RulesDict = _RulesDict()
self._global_rules: _IntKeyDict[PermState] = _IntKeyDict() self._guild_rules: _IntKeyDict[_RulesDict] = _IntKeyDict[_RulesDict]()
self._default_guild_rules: _IntKeyDict[PermState] = _IntKeyDict()
self._guild_rules: _IntKeyDict[_IntKeyDict[PermState]] = _IntKeyDict()
@staticmethod @staticmethod
def get_decorator( def get_decorator(
@ -334,16 +341,17 @@ class Requires:
return decorator return decorator
def get_rule(self, model: Union[int, PermissionModel], guild_id: int) -> PermState: def get_rule(self, model: Union[int, str, PermissionModel], guild_id: int) -> PermState:
"""Get the rule for a particular model. """Get the rule for a particular model.
Parameters Parameters
---------- ----------
model : PermissionModel model : Union[int, str, PermissionModel]
The model to get the rule for. The model to get the rule for. `str` is only valid for
`Requires.DEFAULT`.
guild_id : int guild_id : int
The ID of the guild for the rule's scope. Set to ``0`` The ID of the guild for the rule's scope. Set to
for a global rule. `Requires.GLOBAL` for a global rule.
Returns Returns
------- -------
@ -352,31 +360,32 @@ class Requires:
for an explanation. for an explanation.
""" """
if not isinstance(model, int): if not isinstance(model, (str, int)):
model = model.id model = model.id
if guild_id: if guild_id:
rules = self._guild_rules.get(guild_id, _IntKeyDict()) rules = self._guild_rules.get(guild_id, _RulesDict())
else: else:
rules = self._global_rules rules = self._global_rules
return rules.get(model, PermState.NORMAL) return rules.get(model, PermState.NORMAL)
def set_rule(self, model_id: int, rule: PermState, guild_id: int) -> None: def set_rule(self, model_id: Union[str, int], rule: PermState, guild_id: int) -> None:
"""Set the rule for a particular model. """Set the rule for a particular model.
Parameters Parameters
---------- ----------
model_id : PermissionModel model_id : Union[str, int]
The model to add a rule for. The model to add a rule for. `str` is only valid for
`Requires.DEFAULT`.
rule : PermState rule : PermState
Which state this rule should be set as. See the `PermState` Which state this rule should be set as. See the `PermState`
class for an explanation. class for an explanation.
guild_id : int guild_id : int
The ID of the guild for the rule's scope. Set to ``0`` The ID of the guild for the rule's scope. Set to
for a global rule. `Requires.GLOBAL` for a global rule.
""" """
if guild_id: if guild_id:
rules = self._guild_rules.setdefault(guild_id, _IntKeyDict()) rules = self._guild_rules.setdefault(guild_id, _RulesDict())
else: else:
rules = self._global_rules rules = self._global_rules
if rule is PermState.NORMAL: if rule is PermState.NORMAL:
@ -387,27 +396,24 @@ class Requires:
def clear_all_rules(self, guild_id: int) -> None: def clear_all_rules(self, guild_id: int) -> None:
"""Clear all rules of a particular scope. """Clear all rules of a particular scope.
This will preserve the default rule, if set.
Parameters Parameters
---------- ----------
guild_id : int guild_id : int
The guild ID to clear rules for. If ``0``, this will The guild ID to clear rules for. If set to
clear all global rules and leave all guild rules `Requires.GLOBAL`, this will clear all global rules and
untouched. leave all guild rules untouched.
""" """
if guild_id: if guild_id:
rules = self._guild_rules.setdefault(guild_id, _IntKeyDict()) rules = self._guild_rules.setdefault(guild_id, _RulesDict())
else: else:
rules = self._global_rules rules = self._global_rules
default = rules.get(self.DEFAULT, None)
rules.clear() rules.clear()
if default is not None:
def get_default_guild_rule(self, guild_id: int) -> PermState: rules[self.DEFAULT] = default
"""Get the default rule for a guild."""
return self._default_guild_rules.get(guild_id, PermState.NORMAL)
def set_default_guild_rule(self, guild_id: int, rule: PermState) -> None:
"""Set the default rule for a guild."""
self._default_guild_rules[guild_id] = rule
async def verify(self, ctx: "Context") -> bool: async def verify(self, ctx: "Context") -> bool:
"""Check if the given context passes the requirements. """Check if the given context passes the requirements.
@ -470,9 +476,9 @@ class Requires:
# We must check what would happen normally, if no explicit rules were set. # We must check what would happen normally, if no explicit rules were set.
default_rule = PermState.NORMAL default_rule = PermState.NORMAL
if ctx.guild is not None: if ctx.guild is not None:
default_rule = self.get_default_guild_rule(guild_id=ctx.guild.id) default_rule = self.get_rule(self.DEFAULT, guild_id=ctx.guild.id)
if default_rule is PermState.NORMAL: if default_rule is PermState.NORMAL:
default_rule = self.default_global_rule default_rule = self.get_rule(self.DEFAULT, self.GLOBAL)
if default_rule == PermState.ACTIVE_DENY: if default_rule == PermState.ACTIVE_DENY:
would_invoke = False would_invoke = False
@ -510,7 +516,7 @@ class Requires:
rule = self._global_rules.get(author.id) rule = self._global_rules.get(author.id)
if rule is not None: if rule is not None:
return rule return rule
return self.default_global_rule return self.get_rule(self.DEFAULT, self.GLOBAL)
rules_chain = [self._global_rules] rules_chain = [self._global_rules]
guild_rules = self._guild_rules.get(ctx.guild.id) guild_rules = self._guild_rules.get(ctx.guild.id)
@ -534,9 +540,9 @@ class Requires:
return rule return rule
del model_chain[-1] # We don't check for the guild in guild rules del model_chain[-1] # We don't check for the guild in guild rules
default_rule = self.get_default_guild_rule(guild.id) default_rule = self.get_rule(self.DEFAULT, guild.id)
if default_rule is PermState.NORMAL: if default_rule is PermState.NORMAL:
default_rule = self.default_global_rule default_rule = self.get_rule(self.DEFAULT, self.GLOBAL)
return default_rule return default_rule
async def _verify_checks(self, ctx: "Context") -> bool: async def _verify_checks(self, ctx: "Context") -> bool:
@ -706,6 +712,20 @@ class _IntKeyDict(Dict[int, _T]):
return super().__setitem__(key, value) return super().__setitem__(key, value)
class _RulesDict(Dict[Union[int, str], PermState]):
"""Dict subclass which throws a KeyError when an invalid key is used."""
def __getitem__(self, key: Any) -> PermState:
if key != Requires.DEFAULT and not isinstance(key, int):
raise TypeError(f'Expected "{Requires.DEFAULT}" or int key, not "{key}"')
return super().__getitem__(key)
def __setitem__(self, key: Any, value: PermState) -> None:
if key != Requires.DEFAULT and not isinstance(key, int):
raise TypeError(f'Expected "{Requires.DEFAULT}" or int key, not "{key}"')
return super().__setitem__(key, value)
def _validate_perms_dict(perms: Dict[str, bool]) -> None: def _validate_perms_dict(perms: Dict[str, bool]) -> None:
for perm, value in perms.items(): for perm, value in perms.items():
try: try: