diff --git a/redbot/__main__.py b/redbot/__main__.py index 8509def4a..e64de887a 100644 --- a/redbot/__main__.py +++ b/redbot/__main__.py @@ -117,7 +117,7 @@ def main(): if cli_flags.dev: red.add_cog(Dev()) # noinspection PyProtectedMember - modlog._init() + loop.run_until_complete(modlog._init()) # noinspection PyProtectedMember bank._init() diff --git a/redbot/core/modlog.py b/redbot/core/modlog.py index 9b95e2abe..7a6db696b 100644 --- a/redbot/core/modlog.py +++ b/redbot/core/modlog.py @@ -1,5 +1,5 @@ from datetime import datetime -from typing import List, Union +from typing import List, Union, Optional, cast import discord @@ -30,18 +30,66 @@ __all__ = [ "reset_cases", ] -_DEFAULT_GLOBAL = {"casetypes": {}} +_conf: Optional[Config] = None -_DEFAULT_GUILD = {"mod_log": None, "cases": {}, "casetypes": {}} - -_conf: Config = None +_CASETYPES = "CASETYPES" +_CASES = "CASES" +_SCHEMA_VERSION = 2 -def _init(): +async def _init(): global _conf _conf = Config.get_conf(None, 1354799444, cog_name="ModLog") - _conf.register_global(**_DEFAULT_GLOBAL) - _conf.register_guild(**_DEFAULT_GUILD) + _conf.register_global(schema_version=1) + _conf.register_guild(mod_log=None, casetypes={}) + _conf.init_custom(_CASETYPES, 1) + _conf.init_custom(_CASES, 2) + _conf.register_custom( + _CASETYPES, default_setting=None, image=None, case_str=None, audit_type=None + ) + _conf.register_custom( + _CASES, + case_number=None, + action_type=None, + guild=None, + created_at=None, + user=None, + moderator=None, + reason=None, + until=None, + channel=None, + amended_by=None, + modified_at=None, + message=None, + ) + await _migrate_config(from_version=await _conf.schema_version(), to_version=_SCHEMA_VERSION) + + +async def _migrate_config(from_version: int, to_version: int): + if from_version == to_version: + return + elif from_version < to_version: + # casetypes go from GLOBAL -> casetypes to CASETYPES + all_casetypes = await _conf.get_raw("casetypes", default={}) + if all_casetypes: + await _conf.custom(_CASETYPES).set(all_casetypes) + + # cases go from GUILD -> guild_id -> cases to CASES -> guild_id -> cases + all_guild_data = await _conf.all_guilds() + all_cases = {} + for guild_id, guild_data in all_guild_data.items(): + guild_cases = guild_data.pop("cases", None) + if guild_cases: + all_cases[str(guild_id)] = guild_cases + await _conf.custom(_CASES).set(all_cases) + + # new schema is now in place + await _conf.schema_version.set(_SCHEMA_VERSION) + + # migration done, now let's delete all the old stuff + await _conf.clear_raw("casetypes") + for guild_id in all_guild_data: + await _conf.guild(cast(discord.Guild, discord.Object(id=guild_id))).clear_raw("cases") class Case: @@ -53,15 +101,15 @@ class Case: guild: discord.Guild, created_at: int, action_type: str, - user: discord.User, - moderator: discord.Member, + user: Union[discord.User, int], + moderator: discord.User, case_number: int, reason: str = None, until: int = None, - channel: discord.TextChannel = None, - amended_by: discord.Member = None, - modified_at: int = None, - message: discord.Message = None, + channel: Optional[Union[discord.TextChannel, discord.VoiceChannel, int]] = None, + amended_by: Optional[discord.User] = None, + modified_at: Optional[int] = None, + message: Optional[discord.Message] = None, ): self.bot = bot self.guild = guild @@ -90,7 +138,7 @@ class Case: for item in list(data.keys()): setattr(self, item, data[item]) - await _conf.guild(self.guild).cases.set_raw(str(self.case_number), value=self.to_json()) + await _conf.custom(_CASES, str(self.guild.id), str(self.case_number)).set(self.to_json()) self.bot.dispatch("modlog_case_edit", self) async def message_content(self, embed: bool = True): @@ -119,11 +167,7 @@ class Case: reason = "**Reason:** Use the `reason` command to add it" if self.moderator is not None: - moderator = escape_spoilers( - "{}#{} ({})\n".format( - self.moderator.name, self.moderator.discriminator, self.moderator.id - ) - ) + moderator = escape_spoilers(f"{self.moderator} ({self.moderator.id})") else: moderator = "Unknown" until = None @@ -151,21 +195,28 @@ class Case: datetime.fromtimestamp(self.modified_at).strftime("%Y-%m-%d %H:%M:%S") ) - user = escape_spoilers( - filter_invites( - "{}#{} ({})\n".format(self.user.name, self.user.discriminator, self.user.id) - ) - ) # Invites and spoilers get rendered even in embeds. + if isinstance(self.user, int): + user = f"Deleted User#0000 ({self.user})" + avatar_url = None + else: + user = escape_spoilers( + filter_invites(f"{self.user} ({self.user.id})") + ) # Invites and spoilers get rendered even in embeds. + avatar_url = self.user.avatar_url + if embed: emb = discord.Embed(title=title, description=reason) - emb.set_author(name=user, icon_url=self.user.avatar_url) + if avatar_url is not None: + emb.set_author(name=user, icon_url=avatar_url) emb.add_field(name="Moderator", value=moderator, inline=False) if until and duration: emb.add_field(name="Until", value=until) emb.add_field(name="Duration", value=duration) - if self.channel: + if isinstance(self.channel, int): + emb.add_field(name="Channel", value=f"{self.channel} (deleted)", inline=False) + elif self.channel is not None: emb.add_field(name="Channel", value=self.channel.name, inline=False) if amended_by: emb.add_field(name="Amended by", value=amended_by) @@ -203,12 +254,15 @@ class Case: mod = self.moderator.id else: mod = None + if isinstance(self.user, int): + user_id = self.user + else: + user_id = self.user.id data = { - "case_number": self.case_number, "action_type": self.action_type, "guild": self.guild.id, "created_at": self.created_at, - "user": self.user.id, + "user": user_id, "moderator": mod, "reason": self.reason, "until": self.until, @@ -220,7 +274,9 @@ class Case: return data @classmethod - async def from_json(cls, mod_channel: discord.TextChannel, bot: Red, data: dict): + async def from_json( + cls, mod_channel: discord.TextChannel, bot: Red, case_number: int, data: dict, **kwargs + ): """Get a Case object from the provided information Parameters @@ -229,8 +285,14 @@ class Case: The mod log channel for the guild bot: Red The bot's instance. Needed to get the target user + case_number: int + The case's number. data: dict The JSON representation of the case to be gotten + **kwargs + Extra attributes for the Case instance which override values + in the data dict. These should be complete objects and not + IDs, where possible. Returns ------- @@ -246,31 +308,55 @@ class Case: `discord.HTTPException` A generic API issue """ - guild = mod_channel.guild - if data["message"]: - try: - message = await mod_channel.fetch_message(data["message"]) - except discord.NotFound: + guild = kwargs.get("guild") or mod_channel.guild + + message = kwargs.get("message") + if message is None: + message_id = data.get("message") + if message_id is not None: + try: + message = discord.utils.get(bot.cached_messages, id=message_id) + except AttributeError: + # bot.cached_messages didn't exist prior to discord.py 1.1.0 + message = None + if message is None: + try: + message = await mod_channel.fetch_message(message_id) + except (discord.NotFound, AttributeError): + message = None + else: message = None - user = await bot.fetch_user(data["user"]) - moderator = guild.get_member(data["moderator"]) - channel = guild.get_channel(data["channel"]) - amended_by = guild.get_member(data["amended_by"]) - case_guild = bot.get_guild(data["guild"]) + + user_objects = {"user": None, "moderator": None, "amended_by": None} + for user_key in tuple(user_objects): + user_object = kwargs.get(user_key) + if user_object is None: + user_id = data.get(user_key) + if user_id is None: + user_object = None + else: + user_object = bot.get_user(user_id) + if user_object is None: + try: + user_object = await bot.fetch_user(user_id) + except discord.NotFound: + user_object = user_id + user_objects[user_key] = user_object + + channel = kwargs.get("channel") or guild.get_channel(data["channel"]) or data["channel"] + case_guild = kwargs.get("guild") or bot.get_guild(data["guild"]) return cls( bot=bot, guild=case_guild, created_at=data["created_at"], action_type=data["action_type"], - user=user, - moderator=moderator, - case_number=data["case_number"], + case_number=case_number, reason=data["reason"], until=data["until"], channel=channel, - amended_by=amended_by, modified_at=data["modified_at"], message=message, + **user_objects, ) @@ -300,8 +386,8 @@ class CaseType: default_setting: bool, image: str, case_str: str, - audit_type: str = None, - guild: discord.Guild = None, + audit_type: Optional[str] = None, + guild: Optional[discord.Guild] = None, ): self.name = name self.default_setting = default_setting @@ -318,7 +404,7 @@ class CaseType: "case_str": self.case_str, "audit_type": self.audit_type, } - await _conf.casetypes.set_raw(self.name, value=data) + await _conf.custom(_CASETYPES, self.name).set(data) async def is_enabled(self) -> bool: """ @@ -352,23 +438,27 @@ class CaseType: await _conf.guild(self.guild).casetypes.set_raw(self.name, value=enabled) @classmethod - def from_json(cls, data: dict): + def from_json(cls, name: str, data: dict, **kwargs): """ Parameters ---------- - data: dict - The data to create an instance from + name : str + The casetype's name. + data : dict + The JSON data to create an instance from + **kwargs + Values for other attributes of the instance Returns ------- CaseType """ - return cls(**data) + return cls(name=name, **data, **kwargs) -async def get_next_case_number(guild: discord.Guild) -> str: +async def get_next_case_number(guild: discord.Guild) -> int: """ Gets the next case number @@ -379,12 +469,15 @@ async def get_next_case_number(guild: discord.Guild) -> str: Returns ------- - str + int The next case number """ - cases = sorted((await _conf.guild(guild).get_raw("cases")), key=lambda x: int(x), reverse=True) - return str(int(cases[0]) + 1) if cases else "1" + case_numbers = (await _conf.custom(_CASES, guild.id).all()).keys() + if not case_numbers: + return 1 + else: + return max(map(int, case_numbers)) + 1 async def get_case(case_number: int, guild: discord.Guild, bot: Red) -> Case: @@ -412,11 +505,11 @@ async def get_case(case_number: int, guild: discord.Guild, bot: Red) -> Case: """ try: - case = await _conf.guild(guild).cases.get_raw(str(case_number)) + case = await _conf.custom(_CASES, str(guild.id), str(case_number)).all() except KeyError as e: raise RuntimeError("That case does not exist for guild {}".format(guild.name)) from e mod_channel = await get_modlog_channel(guild) - return await Case.from_json(mod_channel, bot, case) + return await Case.from_json(mod_channel, bot, case_number, case) async def get_all_cases(guild: discord.Guild, bot: Red) -> List[Case]: @@ -436,12 +529,12 @@ async def get_all_cases(guild: discord.Guild, bot: Red) -> List[Case]: A list of all cases for the guild """ - cases = await _conf.guild(guild).get_raw("cases") - case_numbers = list(cases.keys()) - case_list = [] - for case in case_numbers: - case_list.append(await get_case(case, guild, bot)) - return case_list + cases = await _conf.custom(_CASES, str(guild.id)).all() + mod_channel = await get_modlog_channel(guild) + return [ + await Case.from_json(mod_channel, bot, case_number, case_data) + for case_number, case_data in cases.items() + ] async def get_cases_for_member( @@ -470,15 +563,13 @@ async def get_cases_for_member( ------ ValueError If at least one of member or member_id is not provided - `discord.NotFound` - A user with this ID does not exist. `discord.Forbidden` The bot does not have permission to fetch the modlog message which was sent. `discord.HTTPException` Fetching the user failed. """ - cases = await _conf.guild(guild).get_raw("cases") + cases = await _conf.custom(_CASES, str(guild.id)).all() if not (member_id or member): raise ValueError("Expected a member or a member id to be provided.") from None @@ -487,43 +578,21 @@ async def get_cases_for_member( member_id = member.id if not member: - member = guild.get_member(member_id) + member = bot.get_user(member_id) if not member: - member = await bot.fetch_user(member_id) + try: + member = await bot.fetch_user(member_id) + except discord.NotFound: + member = member_id try: - mod_channel = await get_modlog_channel(guild) + modlog_channel = await get_modlog_channel(guild) except RuntimeError: - mod_channel = None - - async def make_case(data: dict) -> Case: - - message = None - if data["message"] and mod_channel: - try: - message = await mod_channel.fetch_message(data["message"]) - except discord.NotFound: - pass - - return Case( - bot=bot, - guild=bot.get_guild(data["guild"]), - created_at=data["created_at"], - action_type=data["action_type"], - user=member, - moderator=guild.get_member(data["moderator"]), - case_number=data["case_number"], - reason=data["reason"], - until=data["until"], - channel=guild.get_channel(data["channel"]), - amended_by=guild.get_member(data["amended_by"]), - modified_at=data["modified_at"], - message=message, - ) + modlog_channel = None cases = [ - await make_case(case_data) - for case_data in cases.values() + await Case.from_json(modlog_channel, bot, case_number, case_data, user=member, guild=guild) + for case_number, case_data in cases.items() if case_data["user"] == member_id ] @@ -536,11 +605,11 @@ async def create_case( created_at: datetime, action_type: str, user: Union[discord.User, discord.Member], - moderator: discord.Member = None, - reason: str = None, - until: datetime = None, - channel: discord.TextChannel = None, -) -> Union[Case, None]: + moderator: Optional[Union[discord.User, discord.Member]] = None, + reason: Optional[str] = None, + until: Optional[datetime] = None, + channel: Optional[discord.TextChannel] = None, +) -> Optional[Case]: """ Creates a new case. @@ -548,36 +617,36 @@ async def create_case( Parameters ---------- - bot: `Red` + bot: Red The bot object - guild: `discord.Guild` + guild: discord.Guild The guild the action was taken in created_at: datetime The time the action occurred at action_type: str The type of action that was taken - user: `discord.User` or `discord.Member` + user: Union[discord.User, discord.Member] The user target by the action - moderator: `discord.Member` + moderator: Optional[Union[discord.User, discord.Member]] The moderator who took the action - reason: str + reason: Optional[str] The reason the action was taken - until: datetime + until: Optional[datetime] The time the action is in effect until - channel: `discord.TextChannel` or `discord.VoiceChannel` + channel: Optional[discord.TextChannel] The channel the action was taken in """ case_type = await get_casetype(action_type, guild) if case_type is None: - return None + return if not await case_type.is_enabled(): - return None + return if user == bot.user: - return None + return - next_case_number = int(await get_next_case_number(guild)) + next_case_number = await get_next_case_number(guild) case = Case( bot, @@ -594,12 +663,12 @@ async def create_case( modified_at=None, message=None, ) - await _conf.guild(guild).cases.set_raw(str(next_case_number), value=case.to_json()) + await _conf.custom(_CASES, str(guild.id), str(next_case_number)).set(case.to_json()) bot.dispatch("modlog_case_create", case) return case -async def get_casetype(name: str, guild: discord.Guild = None) -> Union[CaseType, None]: +async def get_casetype(name: str, guild: Optional[discord.Guild] = None) -> Optional[CaseType]: """ Gets the case type @@ -607,22 +676,21 @@ async def get_casetype(name: str, guild: discord.Guild = None) -> Union[CaseType ---------- name: str The name of the case type to get - guild: discord.Guild + guild: Optional[discord.Guild] If provided, sets the case type's guild attribute to this guild Returns ------- - CaseType or None + Optional[CaseType] """ - casetypes = await _conf.get_raw("casetypes") - if name in casetypes: - data = casetypes[name] - data["name"] = name - casetype = CaseType.from_json(data) + try: + data = await _conf.custom(_CASETYPES, name).all() + except KeyError: + return + else: + casetype = CaseType.from_json(name, data) casetype.guild = guild return casetype - else: - return None async def get_all_casetypes(guild: discord.Guild = None) -> List[CaseType]: @@ -635,15 +703,10 @@ async def get_all_casetypes(guild: discord.Guild = None) -> List[CaseType]: A list of case types """ - casetypes = await _conf.get_raw("casetypes", default={}) - typelist = [] - for ct in casetypes.keys(): - data = casetypes[ct] - data["name"] = ct - casetype = CaseType.from_json(data) - casetype.guild = guild - typelist.append(casetype) - return typelist + return [ + CaseType.from_json(name, data, guild=guild) + for name, data in await _conf.custom(_CASETYPES).all() + ] async def register_casetype( @@ -822,7 +885,7 @@ async def set_modlog_channel( return True -async def reset_cases(guild: discord.Guild) -> bool: +async def reset_cases(guild: discord.Guild) -> None: """ Wipes all modlog cases for the specified guild @@ -831,14 +894,8 @@ async def reset_cases(guild: discord.Guild) -> bool: guild: `discord.Guild` The guild to reset cases for - Returns - ------- - bool - `True` if successful - """ - await _conf.guild(guild).cases.set({}) - return True + await _conf.custom(_CASES, str(guild.id)).clear() def _strfdelta(delta): diff --git a/redbot/pytest/mod.py b/redbot/pytest/mod.py index 30b2fcda3..243af525b 100644 --- a/redbot/pytest/mod.py +++ b/redbot/pytest/mod.py @@ -5,11 +5,11 @@ __all__ = ["mod"] @pytest.fixture -def mod(config, monkeypatch): +async def mod(config, monkeypatch): from redbot.core import Config with monkeypatch.context() as m: m.setattr(Config, "get_conf", lambda *args, **kwargs: config) - modlog._init() + await modlog._init() return modlog