Create cog disabling API (#4043)

* create cog disbale base

* Because defaults...

* lol

* announcer needs to respect this

* defaultdict mishap

* Allow None as guild

- Mostly for interop with with ctx.guild

* a whitespace issue

* Apparently, I broke this too

* Apply suggestions from code review

Co-authored-by: jack1142 <6032823+jack1142@users.noreply.github.com>

* This can probably be more optimized later, but since this is a cached value, it's not a large issue

* Report tunnel closing

* mod too

* whitespace issue

* Fix Artifact of prior method naming

* these 3 places should have the check if i understood it correctly

* Announce the closed tunnels

* tunnel oversight

* Make the player stop at next track

* added where draper said to put it

* Apply suggestions from code review

Co-authored-by: jack1142 <6032823+jack1142@users.noreply.github.com>

Co-authored-by: jack1142 <6032823+jack1142@users.noreply.github.com>
Co-authored-by: Drapersniper <27962761+drapersniper@users.noreply.github.com>
This commit is contained in:
Michael H 2020-07-28 14:52:36 -04:00 committed by GitHub
parent 97379afe6d
commit 1d80fe9aec
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 329 additions and 5 deletions

View File

@ -229,3 +229,12 @@ Not all of these are strict requirements (some are) but are all generally advisa
- We announce this in advance. - We announce this in advance.
- If you need help, ask. - If you need help, ask.
14. Check events against ``bot.cog_disabled_in_guild``
- Not all events need to be checked, only those that interact with a guild.
- Some discretion may apply, for example,
a cog which logs command invocation errors could choose to ignore this
but a cog which takes actions based on messages should not.
15. Respect settings when treating non command messages as commands.

View File

@ -40,6 +40,8 @@ class Announcer:
self.active = False self.active = False
async def _get_announce_channel(self, guild: discord.Guild) -> Optional[discord.TextChannel]: async def _get_announce_channel(self, guild: discord.Guild) -> Optional[discord.TextChannel]:
if await self.ctx.bot.cog_disabled_in_guild_raw("Admin", guild.id):
return
channel_id = await self.config.guild(guild).announce_channel() channel_id = await self.config.guild(guild).announce_channel()
return guild.get_channel(channel_id) return guild.get_channel(channel_id)

View File

@ -326,6 +326,11 @@ class Alias(commands.Cog):
@commands.Cog.listener() @commands.Cog.listener()
async def on_message_without_command(self, message: discord.Message): async def on_message_without_command(self, message: discord.Message):
if message.guild is not None:
if await self.bot.cog_disabled_in_guild(self, message.guild):
return
try: try:
prefix = await self.get_prefix(message) prefix = await self.get_prefix(message)
except ValueError: except ValueError:

View File

@ -25,6 +25,13 @@ class AudioEvents(MixinMeta, metaclass=CompositeMetaClass):
): ):
if not (track and guild): if not (track and guild):
return return
if await self.bot.cog_disabled_in_guild(self, guild):
player = lavalink.get_player(guild.id)
await player.stop()
await player.disconnect()
return
track_identifier = track.track_identifier track_identifier = track.track_identifier
if self.playlist_api is not None: if self.playlist_api is not None:
daily_cache = self._daily_playlist_cache.setdefault( daily_cache = self._daily_playlist_cache.setdefault(

View File

@ -178,6 +178,8 @@ class DpyEvents(MixinMeta, metaclass=CompositeMetaClass):
async def on_voice_state_update( async def on_voice_state_update(
self, member: discord.Member, before: discord.VoiceState, after: discord.VoiceState self, member: discord.Member, before: discord.VoiceState, after: discord.VoiceState
) -> None: ) -> None:
if await self.bot.cog_disabled_in_guild(self, member.guild):
return
await self.cog_ready_event.wait() await self.cog_ready_event.wait()
if after.channel != before.channel: if after.channel != before.channel:
try: try:

View File

@ -19,7 +19,13 @@ class LavalinkEvents(MixinMeta, metaclass=CompositeMetaClass):
current_track = player.current current_track = player.current
current_channel = player.channel current_channel = player.channel
guild = self.rgetattr(current_channel, "guild", None) guild = self.rgetattr(current_channel, "guild", None)
if await self.bot.cog_disabled_in_guild(self, guild):
await player.stop()
await player.disconnect()
return
guild_id = self.rgetattr(guild, "id", None) guild_id = self.rgetattr(guild, "id", None)
if not guild:
return
current_requester = self.rgetattr(current_track, "requester", None) current_requester = self.rgetattr(current_track, "requester", None)
current_stream = self.rgetattr(current_track, "is_stream", None) current_stream = self.rgetattr(current_track, "is_stream", None)
current_length = self.rgetattr(current_track, "length", None) current_length = self.rgetattr(current_track, "length", None)

View File

@ -20,6 +20,8 @@ class PlayerTasks(MixinMeta, metaclass=CompositeMetaClass):
while True: while True:
async for p in AsyncIter(lavalink.all_players()): async for p in AsyncIter(lavalink.all_players()):
server = p.channel.guild server = p.channel.guild
if await self.bot.cog_disabled_in_guild(self, server):
continue
if [self.bot.user] == p.channel.members: if [self.bot.user] == p.channel.members:
stop_times.setdefault(server.id, time.time()) stop_times.setdefault(server.id, time.time())

View File

@ -516,6 +516,9 @@ class CustomCommands(commands.Cog):
if len(message.content) < 2 or is_private or not user_allowed or message.author.bot: if len(message.content) < 2 or is_private or not user_allowed or message.author.bot:
return return
if await self.bot.cog_disabled_in_guild(self, message.guild):
return
ctx = await self.bot.get_context(message) ctx = await self.bot.get_context(message)
if ctx.prefix is None: if ctx.prefix is None:

View File

@ -369,6 +369,10 @@ class Filter(commands.Cog):
async def on_message(self, message: discord.Message): async def on_message(self, message: discord.Message):
if isinstance(message.channel, discord.abc.PrivateChannel): if isinstance(message.channel, discord.abc.PrivateChannel):
return return
if await self.bot.cog_disabled_in_guild(self, message.guild):
return
author = message.author author = message.author
valid_user = isinstance(author, discord.Member) and not author.bot valid_user = isinstance(author, discord.Member) and not author.bot
if not valid_user: if not valid_user:
@ -395,6 +399,11 @@ class Filter(commands.Cog):
await self.maybe_filter_name(member) await self.maybe_filter_name(member)
async def maybe_filter_name(self, member: discord.Member): async def maybe_filter_name(self, member: discord.Member):
guild = member.guild
if (not guild) or await self.bot.cog_disabled_in_guild(self, guild):
return
if not member.guild.me.guild_permissions.manage_nicknames: if not member.guild.me.guild_permissions.manage_nicknames:
return # No permissions to manage nicknames, so can't do anything return # No permissions to manage nicknames, so can't do anything
if member.top_role >= member.guild.me.top_role: if member.top_role >= member.guild.me.top_role:

View File

@ -79,6 +79,10 @@ class Events(MixinMeta):
author = message.author author = message.author
if message.guild is None or self.bot.user == author: if message.guild is None or self.bot.user == author:
return return
if await self.bot.cog_disabled_in_guild(self, message.guild):
return
valid_user = isinstance(author, discord.Member) and not author.bot valid_user = isinstance(author, discord.Member) and not author.bot
if not valid_user: if not valid_user:
return return
@ -110,6 +114,9 @@ class Events(MixinMeta):
@commands.Cog.listener() @commands.Cog.listener()
async def on_member_update(self, before: discord.Member, after: discord.Member): async def on_member_update(self, before: discord.Member, after: discord.Member):
if before.nick != after.nick and after.nick is not None: if before.nick != after.nick and after.nick is not None:
guild = after.guild
if (not guild) or await self.bot.cog_disabled_in_guild(self, guild):
return
async with self.config.member(before).past_nicks() as nick_list: async with self.config.member(before).past_nicks() as nick_list:
while None in nick_list: # clean out null entries from a bug while None in nick_list: # clean out null entries from a bug
nick_list.remove(None) nick_list.remove(None)

View File

@ -142,6 +142,9 @@ class KickBanMixin(MixinMeta):
if not guild.me.guild_permissions.ban_members: if not guild.me.guild_permissions.ban_members:
continue continue
if await self.bot.cog_disabled_in_guild(self, guild):
continue
async with self.config.guild(guild).current_tempbans() as guild_tempbans: async with self.config.guild(guild).current_tempbans() as guild_tempbans:
for uid in guild_tempbans.copy(): for uid in guild_tempbans.copy():
unban_time = datetime.utcfromtimestamp( unban_time = datetime.utcfromtimestamp(

View File

@ -293,10 +293,11 @@ class Reports(commands.Cog):
pass pass
@commands.Cog.listener() @commands.Cog.listener()
async def on_raw_reaction_add(self, payload): async def on_raw_reaction_add(self, payload: discord.RawReactionActionEvent):
""" """
oh dear.... oh dear....
""" """
if not str(payload.emoji) == "\N{NEGATIVE SQUARED CROSS MARK}": if not str(payload.emoji) == "\N{NEGATIVE SQUARED CROSS MARK}":
return return
@ -314,13 +315,35 @@ class Reports(commands.Cog):
@commands.Cog.listener() @commands.Cog.listener()
async def on_message(self, message: discord.Message): async def on_message(self, message: discord.Message):
to_remove = []
for k, v in self.tunnel_store.items(): for k, v in self.tunnel_store.items():
topic = _("Re: ticket# {1} in {0.name}").format(*k)
guild, ticket_number = k
if await self.bot.cog_disabled_in_guild(self, guild):
to_remove.append(k)
continue
topic = _("Re: ticket# {ticket_number} in {guild.name}").format(
ticket_number=ticket_number, guild=guild
)
# Tunnels won't forward unintended messages, this is safe # Tunnels won't forward unintended messages, this is safe
msgs = await v["tun"].communicate(message=message, topic=topic) msgs = await v["tun"].communicate(message=message, topic=topic)
if msgs: if msgs:
self.tunnel_store[k]["msgs"] = msgs self.tunnel_store[k]["msgs"] = msgs
for key in to_remove:
if tun := self.tunnel_store.pop(key, None):
guild, ticket = key
await tun["tun"].close_because_disabled(
_(
"Correspondence about ticket# {ticket_number} in "
"{guild.name} has been ended due "
"to reports being disabled in that server."
).format(ticket_number=ticket, guild=guild)
)
@commands.guild_only() @commands.guild_only()
@checks.mod_or_permissions(manage_roles=True) @checks.mod_or_permissions(manage_roles=True)
@report.command(name="interact") @report.command(name="interact")

View File

@ -702,6 +702,8 @@ class Streams(commands.Cog):
continue continue
for message in stream._messages_cache: for message in stream._messages_cache:
with contextlib.suppress(Exception): with contextlib.suppress(Exception):
if await self.bot.cog_disabled_in_guild(self, message.guild):
continue
autodelete = await self.config.guild(message.guild).autodelete() autodelete = await self.config.guild(message.guild).autodelete()
if autodelete: if autodelete:
await message.delete() await message.delete()
@ -714,6 +716,8 @@ class Streams(commands.Cog):
channel = self.bot.get_channel(channel_id) channel = self.bot.get_channel(channel_id)
if not channel: if not channel:
continue continue
if await self.bot.cog_disabled_in_guild(self, channel.guild):
continue
ignore_reruns = await self.config.guild(channel.guild).ignore_reruns() ignore_reruns = await self.config.guild(channel.guild).ignore_reruns()
if ignore_reruns and is_rerun: if ignore_reruns and is_rerun:
continue continue

View File

@ -37,7 +37,12 @@ from .dev_commands import Dev
from .events import init_events from .events import init_events
from .global_checks import init_global_checks from .global_checks import init_global_checks
from .settings_caches import PrefixManager, IgnoreManager, WhitelistBlacklistManager from .settings_caches import (
PrefixManager,
IgnoreManager,
WhitelistBlacklistManager,
DisabledCogCache,
)
from .rpc import RPCMixin from .rpc import RPCMixin
from .utils import common_filters from .utils import common_filters
@ -132,12 +137,16 @@ class RedBase(
self._config.register_channel(embeds=None, ignored=False) self._config.register_channel(embeds=None, ignored=False)
self._config.register_user(embeds=None) self._config.register_user(embeds=None)
self._config.init_custom("COG_DISABLE_SETTINGS", 2)
self._config.register_custom("COG_DISABLE_SETTINGS", disabled=None)
self._config.init_custom(CUSTOM_GROUPS, 2) self._config.init_custom(CUSTOM_GROUPS, 2)
self._config.register_custom(CUSTOM_GROUPS) self._config.register_custom(CUSTOM_GROUPS)
self._config.init_custom(SHARED_API_TOKENS, 2) self._config.init_custom(SHARED_API_TOKENS, 2)
self._config.register_custom(SHARED_API_TOKENS) self._config.register_custom(SHARED_API_TOKENS)
self._prefix_cache = PrefixManager(self._config, cli_flags) self._prefix_cache = PrefixManager(self._config, cli_flags)
self._disabled_cog_cache = DisabledCogCache(self._config)
self._ignored_cache = IgnoreManager(self._config) self._ignored_cache = IgnoreManager(self._config)
self._whiteblacklist_cache = WhitelistBlacklistManager(self._config) self._whiteblacklist_cache = WhitelistBlacklistManager(self._config)
@ -217,6 +226,41 @@ class RedBase(
return_exceptions=return_exceptions, return_exceptions=return_exceptions,
) )
async def cog_disabled_in_guild(
self, cog: commands.Cog, guild: Optional[discord.Guild]
) -> bool:
"""
Check if a cog is disabled in a guild
Parameters
----------
cog: commands.Cog
guild: Optional[discord.Guild]
Returns
-------
bool
"""
if guild is None:
return False
return await self._disabled_cog_cache.cog_disabled_in_guild(cog.qualified_name, guild.id)
async def cog_disabled_in_guild_raw(self, cog_name: str, guild_id: int) -> bool:
"""
Check if a cog is disabled in a guild without the cog or guild object
Parameters
----------
cog_name: str
This should be the cog's qualified name, not neccessarily the classname
guild_id: int
Returns
-------
bool
"""
return await self._disabled_cog_cache.cog_disabled_in_guild(cog_name, guild_id)
def remove_before_invoke_hook(self, coro: PreInvokeCoroutine) -> None: def remove_before_invoke_hook(self, coro: PreInvokeCoroutine) -> None:
""" """
Functional method to remove a `before_invoke` hook. Functional method to remove a `before_invoke` hook.

View File

@ -511,6 +511,10 @@ class Requires:
bot_user = ctx.bot.user bot_user = ctx.bot.user
else: else:
bot_user = ctx.guild.me bot_user = ctx.guild.me
cog = ctx.cog
if cog and await ctx.bot.cog_disabled_in_guild(cog, ctx.guild):
raise discord.ext.commands.DisabledCommand()
bot_perms = ctx.channel.permissions_for(bot_user) bot_perms = ctx.channel.permissions_for(bot_user)
if not (bot_perms.administrator or bot_perms >= self.bot_perms): if not (bot_perms.administrator or bot_perms >= self.bot_perms):
raise BotMissingPermissions(missing=self._missing_perms(self.bot_perms, bot_perms)) raise BotMissingPermissions(missing=self._missing_perms(self.bot_perms, bot_perms))

View File

@ -2174,9 +2174,83 @@ class Core(commands.Cog, CoreLogic):
@checks.guildowner_or_permissions(administrator=True) @checks.guildowner_or_permissions(administrator=True)
@commands.group(name="command") @commands.group(name="command")
async def command_manager(self, ctx: commands.Context): async def command_manager(self, ctx: commands.Context):
"""Manage the bot's commands.""" """Manage the bot's commands and cogs."""
pass pass
@checks.is_owner()
@command_manager.command(name="defaultdisablecog")
async def command_default_disable_cog(self, ctx: commands.Context, *, cogname: str):
"""Set the default state for a cog as disabled."""
cog = self.bot.get_cog(cogname)
if not cog:
return await ctx.send(_("Cog with the given name doesn't exist."))
if cog == self:
return await ctx.send(_("You can't disable this cog by default."))
await self.bot._disabled_cog_cache.default_disable(cogname)
await ctx.send(_("{cogname} has been set as disabled by default.").format(cogname=cogname))
@checks.is_owner()
@command_manager.command(name="defaultenablecog")
async def command_default_enable_cog(self, ctx: commands.Context, *, cogname: str):
"""Set the default state for a cog as enabled."""
cog = self.bot.get_cog(cogname)
if not cog:
return await ctx.send(_("Cog with the given name doesn't exist."))
await self.bot._disabled_cog_cache.default_enable(cogname)
await ctx.send(_("{cogname} has been set as enabled by default.").format(cogname=cogname))
@commands.guild_only()
@command_manager.command(name="disablecog")
async def command_disable_cog(self, ctx: commands.Context, *, cogname: str):
"""Disable a cog in this guild."""
cog = self.bot.get_cog(cogname)
if not cog:
return await ctx.send(_("Cog with the given name doesn't exist."))
if cog == self:
return await ctx.send(_("You can't disable this cog as you would lock yourself out."))
if await self.bot._disabled_cog_cache.disable_cog_in_guild(cogname, ctx.guild.id):
await ctx.send(_("{cogname} has been disabled in this guild.").format(cogname=cogname))
else:
await ctx.send(
_("{cogname} was already disabled (nothing to do).").format(cogname=cogname)
)
@commands.guild_only()
@command_manager.command(name="enablecog")
async def command_enable_cog(self, ctx: commands.Context, *, cogname: str):
"""Enable a cog in this guild."""
if await self.bot._disabled_cog_cache.enable_cog_in_guild(cogname, ctx.guild.id):
await ctx.send(_("{cogname} has been enabled in this guild.").format(cogname=cogname))
else:
# putting this here allows enabling a cog that isn't loaded but was disabled.
cog = self.bot.get_cog(cogname)
if not cog:
return await ctx.send(_("Cog with the given name doesn't exist."))
await ctx.send(
_("{cogname} was not disabled (nothing to do).").format(cogname=cogname)
)
@commands.guild_only()
@command_manager.command(name="listdisabledcogs")
async def command_list_disabled_cogs(self, ctx: commands.Context):
"""List the cogs which are disabled in this guild."""
disabled = [
cog.qualified_name
for cog in self.bot.cogs.values()
if await self.bot._disabled_cog_cache.cog_disabled_in_guild(
cog.qualified_name, ctx.guild.id
)
]
if disabled:
output = _("The following cogs are disabled in this guild:\n")
output += humanize_list(disabled)
for page in pagify(output):
await ctx.send(page)
else:
await ctx.send(_("There are no disabled cogs in this guild."))
@command_manager.group(name="listdisabled", invoke_without_command=True) @command_manager.group(name="listdisabled", invoke_without_command=True)
async def list_disabled(self, ctx: commands.Context): async def list_disabled(self, ctx: commands.Context):
""" """

View File

@ -1,7 +1,8 @@
from __future__ import annotations from __future__ import annotations
from typing import Dict, List, Optional, Union, Set, Iterable from typing import Dict, List, Optional, Union, Set, Iterable, Tuple
from argparse import Namespace from argparse import Namespace
from collections import defaultdict
import discord import discord
@ -254,3 +255,108 @@ class WhitelistBlacklistManager:
) )
self._cached_blacklist[gid].difference_update(role_or_user) self._cached_blacklist[gid].difference_update(role_or_user)
await self._config.guild_from_id(gid).blacklist.set(list(self._cached_blacklist[gid])) await self._config.guild_from_id(gid).blacklist.set(list(self._cached_blacklist[gid]))
class DisabledCogCache:
def __init__(self, config: Config):
self._config = config
self._disable_map: Dict[str, Dict[int, bool]] = defaultdict(dict)
async def cog_disabled_in_guild(self, cog_name: str, guild_id: int) -> bool:
"""
Check if a cog is disabled in a guild
Parameters
----------
cog_name: str
This should be the cog's qualified name, not neccessarily the classname
guild_id: int
Returns
-------
bool
"""
if guild_id in self._disable_map[cog_name]:
return self._disable_map[cog_name][guild_id]
gset = await self._config.custom("COG_DISABLE_SETTINGS", cog_name, guild_id).disabled()
if gset is None:
gset = await self._config.custom("COG_DISABLE_SETTINGS", cog_name, 0).disabled()
if gset is None:
gset = False
self._disable_map[cog_name][guild_id] = gset
return gset
async def default_disable(self, cog_name: str):
"""
Sets the default for a cog as disabled.
Parameters
----------
cog_name: str
This should be the cog's qualified name, not neccessarily the classname
"""
await self._config.custom("COG_DISABLE_SETTINGS", cog_name, 0).disabled.set(True)
del self._disable_map[cog_name]
async def default_enable(self, cog_name: str):
"""
Sets the default for a cog as enabled.
Parameters
----------
cog_name: str
This should be the cog's qualified name, not neccessarily the classname
"""
await self._config.custom("COG_DISABLE_SETTINGS", cog_name, 0).disabled.clear()
del self._disable_map[cog_name]
async def disable_cog_in_guild(self, cog_name: str, guild_id: int) -> bool:
"""
Disable a cog in a guild.
Parameters
----------
cog_name: str
This should be the cog's qualified name, not neccessarily the classname
guild_id: int
Returns
-------
bool
Whether or not any change was made.
This may be useful for settings commands.
"""
if await self.cog_disabled_in_guild(cog_name, guild_id):
return False
self._disable_map[cog_name][guild_id] = True
await self._config.custom("COG_DISABLE_SETTINGS", cog_name, guild_id).disabled.set(True)
return True
async def enable_cog_in_guild(self, cog_name: str, guild_id: int) -> bool:
"""
Enable a cog in a guild.
Parameters
----------
cog_name: str
This should be the cog's qualified name, not neccessarily the classname
guild_id: int
Returns
-------
bool
Whether or not any change was made.
This may be useful for settings commands.
"""
if not await self.cog_disabled_in_guild(cog_name, guild_id):
return False
self._disable_map[cog_name][guild_id] = False
await self._config.custom("COG_DISABLE_SETTINGS", cog_name, guild_id).disabled.set(False)
return True

View File

@ -1,3 +1,4 @@
import asyncio
import discord import discord
from datetime import datetime from datetime import datetime
from redbot.core.utils.chat_formatting import pagify from redbot.core.utils.chat_formatting import pagify
@ -175,6 +176,19 @@ class Tunnel(metaclass=TunnelMeta):
# Backwards-compatible typo fix (GH-2496) # Backwards-compatible typo fix (GH-2496)
files_from_attatch = files_from_attach files_from_attatch = files_from_attach
async def close_because_disabled(self, close_message: str):
"""
Sends a mesage to both ends of the tunnel that the tunnel is now closed.
Parameters
----------
close_message: str
The message to send to both ends of the tunnel.
"""
tasks = [destination.send(close_message) for destination in (self.recipient, self.origin)]
await asyncio.gather(*tasks, return_exceptions=True)
async def communicate( async def communicate(
self, *, message: discord.Message, topic: str = None, skip_message_content: bool = False self, *, message: discord.Message, topic: str = None, skip_message_content: bool = False
): ):