From 85a1057af68beb8380fcb6213bd6739bc29ffcf2 Mon Sep 17 00:00:00 2001 From: Drapersniper <27962761+drapersniper@users.noreply.github.com> Date: Wed, 18 Dec 2019 20:09:09 +0000 Subject: [PATCH] working Signed-off-by: Drapersniper <27962761+drapersniper@users.noreply.github.com> --- redbot/cogs/audio/audio.py | 100 +++++++++++----------- redbot/cogs/audio/converters.py | 37 +++------ redbot/cogs/audio/playlists.py | 141 +++++++++++++++++++++++++++----- 3 files changed, 181 insertions(+), 97 deletions(-) diff --git a/redbot/cogs/audio/audio.py b/redbot/cogs/audio/audio.py index a472556cd..9cc4a5881 100644 --- a/redbot/cogs/audio/audio.py +++ b/redbot/cogs/audio/audio.py @@ -202,29 +202,33 @@ class Audio(commands.Cog): await self._embed_msg(ctx, _("No DJ role found. Disabling DJ mode.")) async def initialize(self): - await self.bot.wait_until_ready() - # Unlike most cases, we want the cache to exit before migration. - await self.music_cache.initialize(self.config) - pass_config_to_dependencies(self.config, self.bot, await self.config.localpath()) - await self._migrate_config( - from_version=await self.config.schema_version(), to_version=_SCHEMA_VERSION - ) - self._restart_connect() - self._disconnect_task = self.bot.loop.create_task(self.disconnect_timer()) - lavalink.register_event_listener(self.event_handler) - if not HAS_SQL: - error_message = ( - "Audio version: {version}\nThis version requires some SQL dependencies to " - "access the caching features, " - "your Python install is missing some of them.\n\n" - "For instructions on how to fix it Google " - f"`{_ERROR}`.\n" - "You will need to install the missing SQL dependency.\n\n" - ).format(version=__version__) - with contextlib.suppress(discord.HTTPException): - for page in pagify(error_message): - await self.bot.send_to_owners(page) - log.critical(error_message) + try: + await self.bot.wait_until_ready() + # Unlike most cases, we want the cache to exit before migration. + await self.music_cache.initialize(self.config) + pass_config_to_dependencies(self.config, self.bot, await self.config.localpath()) + await self._migrate_config( + from_version=await self.config.schema_version(), to_version=_SCHEMA_VERSION + ) + self._restart_connect() + self._disconnect_task = self.bot.loop.create_task(self.disconnect_timer()) + lavalink.register_event_listener(self.event_handler) + if not HAS_SQL: + error_message = ( + "Audio version: {version}\nThis version requires some SQL dependencies to " + "access the caching features, " + "your Python install is missing some of them.\n\n" + "For instructions on how to fix it Google " + f"`{_ERROR}`.\n" + "You will need to install the missing SQL dependency.\n\n" + ).format(version=__version__) + with contextlib.suppress(discord.HTTPException): + for page in pagify(error_message): + await self.bot.send_to_owners(page) + log.critical(error_message) + except Exception as e: + log.exception("Error on audio init", exc_info=e) + raise e self._ready_event.set() self.bot.dispatch("red_audio_initialized", self) @@ -3196,6 +3200,7 @@ class Audio(commands.Cog): When multiple matches are found but none is selected. """ + correct_scope_matches: List[Playlist] original_input = matches.get("arg") correct_scope_matches = matches.get(scope) guild_to_query = guild.id @@ -3204,50 +3209,40 @@ class Audio(commands.Cog): return None, original_input if scope == PlaylistScope.USER.value: correct_scope_matches = [ - (i[2]["id"], i[2]["name"], len(i[2]["tracks"]), i[2]["author"]) - for i in correct_scope_matches - if str(user_to_query) == i[0] + p for p in correct_scope_matches if user_to_query == p.scope_id ] elif scope == PlaylistScope.GUILD.value: if specified_user: correct_scope_matches = [ - (i[2]["id"], i[2]["name"], len(i[2]["tracks"]), i[2]["author"]) - for i in correct_scope_matches - if str(guild_to_query) == i[0] and i[2]["author"] == user_to_query + p + for p in correct_scope_matches + if guild_to_query == p.scope_id and p.author == user_to_query ] else: correct_scope_matches = [ - (i[2]["id"], i[2]["name"], len(i[2]["tracks"]), i[2]["author"]) - for i in correct_scope_matches - if str(guild_to_query) == i[0] + p for p in correct_scope_matches if guild_to_query == p.scope_id ] else: if specified_user: correct_scope_matches = [ - (i[2]["id"], i[2]["name"], len(i[2]["tracks"]), i[2]["author"]) - for i in correct_scope_matches - if i[2]["author"] == user_to_query + p for p in correct_scope_matches if p.author == user_to_query ] else: - correct_scope_matches = [ - (i[2]["id"], i[2]["name"], len(i[2]["tracks"]), i[2]["author"]) - for i in correct_scope_matches - ] + correct_scope_matches = [p for p in correct_scope_matches] match_count = len(correct_scope_matches) # We done all the trimming we can with the info available time to ask the user + print("correct_scope_matches", correct_scope_matches) if match_count > 10: if original_input.isnumeric(): arg = int(original_input) - correct_scope_matches = [ - (i, n, t, a) for i, n, t, a in correct_scope_matches if i == arg - ] + correct_scope_matches = [p for p in correct_scope_matches if p.id == arg] if match_count > 10: raise TooManyMatches( f"{match_count} playlists match {original_input}: " f"Please try to be more specific, or use the playlist ID." ) elif match_count == 1: - return correct_scope_matches[0][0], original_input + return correct_scope_matches[0].id, original_input elif match_count == 0: return None, original_input @@ -3255,14 +3250,14 @@ class Audio(commands.Cog): pos_len = 3 playlists = f"{'#':{pos_len}}\n" - for number, (pid, pname, ptracks, pauthor) in enumerate(correct_scope_matches, 1): - author = self.bot.get_user(pauthor) or "Unknown" + for number, playlist in enumerate(correct_scope_matches, 1): + author = self.bot.get_user(playlist.author) or "Unknown" line = ( f"{number}." - f" <{pname}>\n" + f" <{playlist.name}>\n" f" - Scope: < {humanize_scope(scope)} >\n" - f" - ID: < {pid} >\n" - f" - Tracks: < {ptracks} >\n" + f" - ID: < {playlist.id} >\n" + f" - Tracks: < {len(playlist.tracks)} >\n" f" - Author: < {author} >\n\n" ) playlists += line @@ -3295,7 +3290,7 @@ class Audio(commands.Cog): ) with contextlib.suppress(discord.HTTPException): await msg.delete() - return correct_scope_matches[pred.result][0], original_input + return correct_scope_matches[pred.result].id, original_input @commands.group() @commands.guild_only() @@ -3954,14 +3949,17 @@ class Audio(commands.Cog): playlist_songs_backwards_compatible = [ track["info"]["uri"] for track in playlist.tracks ] + # TODO: Keep new playlists backwards compatible, Remove me in a few releases playlist_data[ "playlist" - ] = playlist_songs_backwards_compatible # TODO: Keep new playlists backwards compatible, Remove me in a few releases + ] = ( + playlist_songs_backwards_compatible + ) playlist_data[ "link" ] = ( playlist.url - ) # TODO: Keep new playlists backwards compatible, Remove me in a few releases + ) file_name = playlist.id playlist_data.update({"schema": schema, "version": version}) playlist_data = json.dumps(playlist_data) diff --git a/redbot/cogs/audio/converters.py b/redbot/cogs/audio/converters.py index 05736cf8b..857f3e887 100644 --- a/redbot/cogs/audio/converters.py +++ b/redbot/cogs/audio/converters.py @@ -8,7 +8,7 @@ from redbot.core import Config, commands from redbot.core.bot import Red from redbot.core.i18n import Translator -from .playlists import PlaylistScope, standardize_scope +from .playlists import PlaylistScope, standardize_scope, get_all_playlist_converter _ = Translator("Audio", __file__) @@ -22,8 +22,8 @@ __all__ = [ "get_playlist_converter", ] -_config = None -_bot = None +_config: Config = None +_bot: Red = None _SCOPE_HELP = """ Scope must be a valid version of one of the following: @@ -54,29 +54,18 @@ def _pass_config_to_converters(config: Config, bot: Red): class PlaylistConverter(commands.Converter): async def convert(self, ctx: commands.Context, arg: str) -> dict: - global_scope = await _config.custom(PlaylistScope.GLOBAL.value).all() - guild_scope = await _config.custom(PlaylistScope.GUILD.value).all() - user_scope = await _config.custom(PlaylistScope.USER.value).all() - user_matches = [ - (uid, pid, pdata) - for uid, data in user_scope.items() - for pid, pdata in data.items() - if arg == pid or arg.lower() in pdata.get("name", "").lower() - ] - guild_matches = [ - (gid, pid, pdata) - for gid, data in guild_scope.items() - for pid, pdata in data.items() - if arg == pid or arg.lower() in pdata.get("name", "").lower() - ] - global_matches = [ - (None, pid, pdata) - for pid, pdata in global_scope.items() - if arg == pid or arg.lower() in pdata.get("name", "").lower() - ] + global_matches = await get_all_playlist_converter( + PlaylistScope.GLOBAL.value, _bot, arg, guild=ctx.guild, author=ctx.author + ) + guild_matches = await get_all_playlist_converter( + PlaylistScope.GUILD.value, _bot, arg, guild=ctx.guild, author=ctx.author + ) + user_matches = await get_all_playlist_converter( + PlaylistScope.USER.value, _bot, arg, guild=ctx.guild, author=ctx.author + ) + if not user_matches and not guild_matches and not global_matches: raise commands.BadArgument(_("Could not match '{}' to a playlist.").format(arg)) - return { PlaylistScope.GLOBAL.value: global_matches, PlaylistScope.GUILD.value: guild_matches, diff --git a/redbot/cogs/audio/playlists.py b/redbot/cogs/audio/playlists.py index 2624a628a..65c62d136 100644 --- a/redbot/cogs/audio/playlists.py +++ b/redbot/cogs/audio/playlists.py @@ -98,7 +98,9 @@ SELECT FROM playlists WHERE - scope_type = :scope_type ; + scope_type = :scope_type + AND scope_id = :scope_id + ; """ _FETCH_ALL_WITH_FILTER = """ @@ -114,11 +116,35 @@ FROM WHERE ( scope_type = :scope_type + AND scope_id = :scope_id AND author_id = :author_id ) ; """ +_FETCH_ALL_CONVERTER = """ +SELECT + playlist_id, + playlist_name, + scope_id, + author_id, + playlist_url, + tracks +FROM + playlists +WHERE + ( + scope_type = :scope_type + AND + ( + playlist_id = :playlist_id + OR + playlist_name = :playlist_name + ) + ) +; +""" + _FETCH = """ SELECT playlist_id, @@ -211,22 +237,38 @@ class Database: return SQLFetchResult(*row) if row else None - def fetch_all(self, scope: str, author_id=None) -> List[SQLFetchResult]: + def fetch_all(self, scope: str, scope_id: int, author_id=None) -> List[SQLFetchResult]: scope_type = self.get_scope_type(scope) if author_id is not None: - output = ( - self.cursor.execute( - _FETCH_ALL, ({"scope_type": scope_type, "author_id": author_id}) - ).fetchall() - or [] - ) + output = self.cursor.execute( + _FETCH_ALL_WITH_FILTER, + ({"scope_type": scope_type, "scope_id": scope_id, "author_id": author_id}), + ).fetchall() else: - output = ( - self.cursor.execute( - _FETCH_ALL_WITH_FILTER, ({"scope_type": scope_type}) - ).fetchall() - or [] - ) + output = self.cursor.execute( + _FETCH_ALL, ({"scope_type": scope_type, "scope_id": scope_id}) + ).fetchall() + return [SQLFetchResult(*row) for row in output] if output else [] + + def fetch_all_converter(self, scope: str, playlist_name, playlist_id) -> List[SQLFetchResult]: + scope_type = self.get_scope_type(scope) + try: + playlist_id = int(playlist_id) + except: + playlist_id = -1 + output = ( + self.cursor.execute( + _FETCH_ALL_CONVERTER, + ( + { + "scope_type": scope_type, + "playlist_name": playlist_name, + "playlist_id": playlist_id, + } + ), + ).fetchall() + or [] + ) return [SQLFetchResult(*row) for row in output] if output else [] def delete(self, scope: str, playlist_id: int, scope_id: int): @@ -369,9 +411,10 @@ class PlaylistMigration23: # TODO: remove me in a future version ? self.url = playlist_url self.tracks = tracks or [] - @classmethod - async def from_json(cls, scope: str, playlist_number: int, data: dict, **kwargs): + async def from_json( + cls, scope: str, playlist_number: int, data: dict, **kwargs + ) -> "PlaylistMigration23": """Get a Playlist object from the provided information. Parameters ---------- @@ -405,7 +448,7 @@ class PlaylistMigration23: # TODO: remove me in a future version ? playlist_id = data.get("id") or playlist_number name = data.get("name", "Unnamed") playlist_url = data.get("playlist_url", None) - tracks = json.loads(data.get("tracks", "[]")) + tracks = data.get("tracks", []) return cls( guild=guild, @@ -469,7 +512,11 @@ async def get_all_playlist_for_migration23( # TODO: remove me in a future versi if scope == PlaylistScope.GLOBAL.value: return [ await PlaylistMigration23.from_json( - scope, playlist_number, playlist_data, guild=guild, author=int(playlist_data.get("author", 0)) + scope, + playlist_number, + playlist_data, + guild=guild, + author=int(playlist_data.get("author", 0)), ) for playlist_number, playlist_data in playlists.items() ] @@ -484,7 +531,11 @@ async def get_all_playlist_for_migration23( # TODO: remove me in a future versi else: return [ await PlaylistMigration23.from_json( - scope, playlist_number, playlist_data, guild=int(guild_id), author=int(playlist_data.get("author", 0)) + scope, + playlist_number, + playlist_data, + guild=int(guild_id), + author=int(playlist_data.get("author", 0)), ) for guild_id, scopedata in playlists.items() for playlist_number, playlist_data in scopedata.items() @@ -509,6 +560,7 @@ class Playlist: self.guild = guild self.scope = standardize_scope(scope) self.config_scope = _prepare_config_scope(self.scope, author, guild) + self.scope_id = self.config_scope[-1] self.author = author self.guild_id = ( getattr(guild, "id", guild) if self.scope == PlaylistScope.GLOBAL.value else None @@ -569,7 +621,9 @@ class Playlist: return data @classmethod - async def from_json(cls, bot: Red, scope: str, playlist_number: int, data: dict, **kwargs): + async def from_json( + cls, bot: Red, scope: str, playlist_number: int, data: SQLFetchResult, **kwargs + ): """Get a Playlist object from the provided information. Parameters ---------- @@ -701,10 +755,53 @@ async def get_all_playlist( if specified_user: user_id = getattr(author, "id", author) - playlists = database.fetch_all(scope_standard, author_id=user_id) + playlists = database.fetch_all(scope_standard, scope_id, author_id=user_id) else: - playlists = database.fetch_all(scope_standard) + playlists = database.fetch_all(scope_standard, scope_id) + return [ + await Playlist.from_json( + bot, scope, playlist.playlist_id, playlist, guild=guild, author=author + ) + for playlist in playlists + ] + +async def get_all_playlist_converter( + scope: str, + bot: Red, + arg: str, + guild: Union[discord.Guild, int] = None, + author: Union[discord.abc.User, int] = None, +) -> List[Playlist]: + """ + Gets all playlist for the specified scope. + Parameters + ---------- + scope: str + The custom config scope. One of 'GLOBALPLAYLIST', 'GUILDPLAYLIST' or 'USERPLAYLIST'. + guild: discord.Guild + The guild to get the playlist from if scope is GUILDPLAYLIST. + author: int + The ID of the user to get the playlist from if scope is USERPLAYLIST. + bot: Red + The bot's instance + specified_user:bool + Whether or not user ID was passed as an argparse. + Returns + ------- + list + A list of all playlists for the specified scope + Raises + ------ + `InvalidPlaylistScope` + Passing a scope that is not supported. + `MissingGuild` + Trying to access the Guild scope without a guild. + `MissingAuthor` + Trying to access the User scope without an user id. + """ + scope_standard, scope_id = _prepare_config_scope(scope, author, guild) + playlists = database.fetch_all_converter(scope_standard, playlist_name=arg, playlist_id=arg) return [ await Playlist.from_json( bot, scope, playlist.playlist_id, playlist, guild=guild, author=author