From 6946970fcc20055114ab49d66313a5ab3af37ba5 Mon Sep 17 00:00:00 2001 From: Draper <27962761+Drapersniper@users.noreply.github.com> Date: Wed, 18 Dec 2019 11:20:16 +0000 Subject: [PATCH] Implement schema migration --- redbot/cogs/audio/audio.py | 24 +- redbot/cogs/audio/playlists.py | 507 +++++++++++++++++++++++++-------- 2 files changed, 404 insertions(+), 127 deletions(-) diff --git a/redbot/cogs/audio/audio.py b/redbot/cogs/audio/audio.py index 209a571bb..948405b21 100644 --- a/redbot/cogs/audio/audio.py +++ b/redbot/cogs/audio/audio.py @@ -5,7 +5,6 @@ import datetime import heapq import json import logging -import os import random import re import time @@ -56,7 +55,8 @@ from .playlists import ( get_all_playlist, get_playlist, humanize_scope, -) + database, + get_all_playlist_for_migration23) from .utils import * @@ -67,7 +67,7 @@ __author__ = ["aikaterna", "Draper"] log = logging.getLogger("red.audio") -_SCHEMA_VERSION = 2 +_SCHEMA_VERSION = 3 LazyGreedyConverter = get_lazy_converter("--") PlaylistConverter = get_playlist_converter() @@ -204,10 +204,10 @@ class Audio(commands.Cog): await self.bot.wait_until_ready() # Unlike most cases, we want the cache to exit before migration. await self.music_cache.initialize(self.config) + pass_config_to_dependencies(self.config, self.bot, await self.config.localpath()) await self._migrate_config( from_version=await self.config.schema_version(), to_version=_SCHEMA_VERSION ) - pass_config_to_dependencies(self.config, self.bot, await self.config.localpath()) self._restart_connect() self._disconnect_task = self.bot.loop.create_task(self.disconnect_timer()) lavalink.register_event_listener(self.event_handler) @@ -233,7 +233,7 @@ class Audio(commands.Cog): time_now = str(datetime.datetime.now(datetime.timezone.utc)) if from_version == to_version: return - elif from_version < to_version: + if from_version < 2 <= to_version: all_guild_data = await self.config.all_guilds() all_playlist = {} for guild_id, guild_data in all_guild_data.items(): @@ -271,6 +271,14 @@ class Audio(commands.Cog): await self.config.guild( cast(discord.Guild, discord.Object(id=guild_id)) ).clear_raw("playlists") + if from_version < 3 <= to_version: + for scope in PlaylistScope.list(): + scope_playlist = await get_all_playlist_for_migration23(scope) + for p in scope_playlist: + await p.save() + await self.config.custom(scope).clear() + await self.config.schema_version.set(_SCHEMA_VERSION) + if database_entries and HAS_SQL: await self.music_cache.insert("lavalink", database_entries) @@ -3947,9 +3955,7 @@ class Audio(commands.Cog): ] playlist_data[ "playlist" - ] = ( - playlist_songs_backwards_compatible - ) # TODO: Keep new playlists backwards compatible, Remove me in a few releases + ] = playlist_songs_backwards_compatible # TODO: Keep new playlists backwards compatible, Remove me in a few releases playlist_data[ "link" ] = ( @@ -7020,5 +7026,7 @@ class Audio(commands.Cog): async def _close_database(self): await self.music_cache.run_all_pending_tasks() await self.music_cache.close() + if database: + database.close() __del__ = cog_unload diff --git a/redbot/cogs/audio/playlists.py b/redbot/cogs/audio/playlists.py index a96a87637..2bf155dc5 100644 --- a/redbot/cogs/audio/playlists.py +++ b/redbot/cogs/audio/playlists.py @@ -1,8 +1,8 @@ import json -import os from collections import namedtuple +from dataclasses import dataclass from enum import Enum, unique -from typing import List, Optional, Union, Tuple +from typing import List, Optional, Union import apsw import discord @@ -17,7 +17,7 @@ from .errors import InvalidPlaylistScope, MissingAuthor, MissingGuild, NotAllowe _config: Config = None _bot: Red = None -_database: "Database" = None +database: "Database" = None __all__ = [ "Playlist", @@ -30,6 +30,8 @@ __all__ = [ "humanize_scope", "standardize_scope", "FakePlaylist", + "get_all_playlist_for_migration23", + "database", ] FakePlaylist = namedtuple("Playlist", "author scope") @@ -53,28 +55,37 @@ PRAGMA optimize = 1; """ _CREATE_TABLE = """ -CREATE TABLE IF NOT EXISTS GLOBAL ( -playlist_id INTEGER PRIMARY KEY, -playlist_name TEXT NOT NULL, -scope_id INTEGER NOT NULL, -author_id INTEGER NOT NULL, -playlist_url TEXT, -tracks BLOB); +CREATE TABLE IF NOT EXISTS playlists ( + scope_type INTEGER NOT NULL, + playlist_id INTEGER NOT NULL, + playlist_name TEXT NOT NULL, + scope_id INTEGER NOT NULL, + author_id INTEGER NOT NULL, + playlist_url TEXT, + tracks BLOB, + PRIMARY KEY (playlist_id, scope_id, scope_type) +); """ -_DROP = """ -DROP TABLE {table}; -""" _DELETE = """ -DELETE FROM {table} +DELETE FROM playlists WHERE - ( - playlist_id = :playlist_id - AND - scope_id = :scope_id - ) + ( + scope_type = :scope_type + AND + playlist_id = :playlist_id + AND + scope_id = :scope_id + ) ; """ +_DELETE_SCOPE = """ +DELETE FROM playlists +WHERE + scope_type = :scope_type +; +""" + _FETCH_ALL = """ SELECT playlist_id, @@ -83,7 +94,28 @@ scope_id, author_id, playlist_url, tracks -FROM {table}; +FROM playlists +WHERE + scope_type = :scope_type +; +""" + +_FETCH_ALL_WITH_FILTER = """ +SELECT +playlist_id, +playlist_name, +scope_id, +author_id, +playlist_url, +tracks +FROM playlists +WHERE + ( + scope_type = :scope_type + AND + author_id = :author_id + ) +; """ _FETCH = """ @@ -94,18 +126,21 @@ scope_id, author_id, playlist_url, tracks -FROM {table} +FROM playlists WHERE - ( - playlist_id = :playlist_id - AND - scope_id = :scope_id - ) + ( + scope_type = :scope_type + AND + playlist_id = :playlist_id + AND + scope_id = :scope_id + ) """ _UPSET = """INSERT INTO -{table} +playlists ( + scope_type playlist_id, playlist_name, scope_id, @@ -115,6 +150,7 @@ _UPSET = """INSERT INTO ) VALUES ( + :scope_type, :playlist_id, :playlist_name, :scope_id, @@ -124,18 +160,29 @@ VALUES ) ON CONFLICT ( + scope_type, playlist_id, scope_id ) DO UPDATE SET - playlist_name = :playlist_name, - playlist_url = :playlist_url, - tracks = :tracks + playlist_name = excluded.playlist_name, + playlist_url = excluded.playlist_url, + tracks = excluded.tracks ; """ +@dataclass +class SQLFetchResult: + playlist_id: int + playlist_name: str + scope_id: int + author_id: int + playlist_url: Optional[str] = None + tracks: str = "[]" + + @unique class PlaylistScope(Enum): GLOBAL = "GLOBALPLAYLIST" @@ -152,56 +199,77 @@ class PlaylistScope(Enum): class Database: def __init__(self): - self._database = apsw.Connection( - str(cog_data_path(_bot.get_cog("Audio")) / "playlists.db") - ) + self._database = apsw.Connection(str(cog_data_path(_bot.get_cog("Audio")) / "Audio.db")) self.cursor = self._database.cursor() self.cursor.execute(_PRAGMA_UPDATE_temp_store) self.cursor.execute(_PRAGMA_UPDATE_journal_mode) self.cursor.execute(_PRAGMA_UPDATE_wal_autocheckpoint) self.cursor.execute(_PRAGMA_UPDATE_read_uncommitted) - for t in ["GLOBAL", "GUILD", "USER"]: - self.cursor.execute(_CREATE_TABLE.format(table=t)) + self.cursor.execute(_CREATE_TABLE) + + def close(self): + self.cursor.execute(_PRAGMA_UPDATE_optimize) + self._database.close() @staticmethod - def parse_query(scope: PlaylistScope, query: str): + def get_scope_type(scope: str) -> int: if scope == PlaylistScope.GLOBAL.value: - table = "GLOBAL" + table = 1 elif scope == PlaylistScope.GUILD.value: - table = "GUILD" + table = 2 elif scope == PlaylistScope.USER.value: - table = "USER" + table = 3 else: raise - return query.format(table=table) + return table - def fetch( - self, scope: PlaylistScope, playlist_id: int, scope_id: int - ) -> Tuple[int, str, int, int, str, str]: - query = self.parse_query(scope, _FETCH) + def fetch(self, scope: str, playlist_id: int, scope_id: int) -> SQLFetchResult: + scope_type = self.get_scope_type(scope) + row = ( + self.cursor.execute( + _FETCH, + ({"playlist_id": playlist_id, "scope_id": scope_id, "scope_type": scope_type}), + ).fetchone() + or [] + ) + + return SQLFetchResult(*row) if row else None + + def fetch_all(self, scope: str, author_id=None) -> List[SQLFetchResult]: + scope_type = self.get_scope_type(scope) + if author_id is not None: + output = ( + self.cursor.execute( + _FETCH_ALL, ({"scope_type": scope_type, "author_id": author_id}) + ).fetchall() + or [] + ) + else: + output = ( + self.cursor.execute( + _FETCH_ALL_WITH_FILTER, ({"scope_type": scope_type}) + ).fetchall() + or [] + ) + return [SQLFetchResult(*row) for row in output] if output else [] + + def delete(self, scope: str, playlist_id: int, scope_id: int): + scope_type = self.get_scope_type(scope) return self.cursor.execute( - query, ({"playlist_id": playlist_id, "scope_id": scope_id}) - ).fetchone() + _DELETE, ({"playlist_id": playlist_id, "scope_id": scope_id, "scope_type": scope_type}) + ) - def delete(self, scope: PlaylistScope, playlist_id: int, scope_id: int): - query = self.parse_query(scope, _DELETE) - return self.cursor.execute(query, ({"playlist_id": playlist_id, "scope_id": scope_id})) + def drop(self, scope: str): + scope_type = self.get_scope_type(scope) + return self.cursor.execute(_DELETE_SCOPE, ({"scope_type": scope_type})) - def fetch_all(self, scope: PlaylistScope) -> List[Tuple[int, str, int, int, str, str]]: - query = self.parse_query(scope, _FETCH_ALL) - return self.cursor.execute(query).fetchall() - - def drop(self, scope: PlaylistScope): - query = self.parse_query(scope, _DROP) - return self.cursor.execute(query) - - def create_table(self, scope: PlaylistScope): - query = self.parse_query(scope, _CREATE_TABLE) - return self.cursor.execute(query) + def create_table(self, scope: str): + scope_type = self.get_scope_type(scope) + return self.cursor.execute(_CREATE_TABLE, ({"scope_type": scope_type})) def upsert( self, - scope: PlaylistScope, + scope: str, playlist_id: int, playlist_name: str, scope_id: int, @@ -209,11 +277,12 @@ class Database: playlist_url: str, tracks: List[dict], ): - query = self.parse_query(scope, _UPSET) + scope_type = self.get_scope_type(scope) self.cursor.execute( - query, + _UPSET, ( { + "scope_type": scope_type, "playlist_id": playlist_id, "playlist_name": playlist_name, "scope_id": scope_id, @@ -226,13 +295,13 @@ class Database: def _pass_config_to_playlist(config: Config, bot: Red): - global _config, _bot, _database + global _config, _bot, database if _config is None: _config = config if _bot is None: _bot = bot - if _database is None: - _database = Database() + if database is None: + database = Database() def standardize_scope(scope) -> str: @@ -260,11 +329,11 @@ def standardize_scope(scope) -> str: def humanize_scope(scope, ctx=None, the=None): if scope == PlaylistScope.GLOBAL.value: - return ctx or _("the ") if the else "" + "Global" + return ctx or _("the ") if the else "" + _("Global") elif scope == PlaylistScope.GUILD.value: - return ctx.name if ctx else _("the ") if the else "" + "Server" + return ctx.name if ctx else _("the ") if the else "" + _("Server") elif scope == PlaylistScope.USER.value: - return str(ctx) if ctx else _("the ") if the else "" + "User" + return str(ctx) if ctx else _("the ") if the else "" + _("User") def _prepare_config_scope( @@ -285,7 +354,25 @@ def _prepare_config_scope( return config_scope -class Playlist: +def _prepare_config_scope_for_migration23( # TODO: remove me in a future version ? + scope, author: Union[discord.abc.User, int] = None, guild: discord.Guild = None +): + scope = standardize_scope(scope) + + if scope == PlaylistScope.GLOBAL.value: + config_scope = [PlaylistScope.GLOBAL.value] + elif scope == PlaylistScope.USER.value: + if author is None: + raise MissingAuthor("Invalid author for user scope.") + config_scope = [PlaylistScope.USER.value, str(getattr(author, "id", author))] + else: + if guild is None: + raise MissingGuild("Invalid guild for guild scope.") + config_scope = [PlaylistScope.GUILD.value, str(getattr(guild, "id", guild))] + return config_scope + + +class PlaylistMigration23: # TODO: remove me in a future version ? """A single playlist.""" def __init__( @@ -302,7 +389,7 @@ class Playlist: self.bot = bot self.guild = guild self.scope = standardize_scope(scope) - self.config_scope = _prepare_config_scope(self.scope, author, guild) + self.config_scope = _prepare_config_scope_for_migration23(self.scope, author, guild) self.author = author self.guild_id = ( getattr(guild, "id", guild) if self.scope == PlaylistScope.GLOBAL.value else None @@ -313,14 +400,6 @@ class Playlist: self.tracks = tracks or [] self.tracks_obj = [lavalink.Track(data=track) for track in self.tracks] - def _get_scope_id(self): - if self.scope == PlaylistScope.GLOBAL.value: - return self.bot.user.id - elif self.scope == PlaylistScope.USER.value: - return self.author - else: - return self.guild.id - async def edit(self, data: dict): """ Edits a Playlist. @@ -335,22 +414,8 @@ class Playlist: for item in list(data.keys()): setattr(self, item, data[item]) - await self.save() - async def save(self): - """ - Saves a Playlist. - """ - scope, scope_id = self.config_scope - _database.upsert( - scope, - playlist_id=int(self.id), - playlist_name=self.name, - scope_id=scope_id, - author_id=self.author, - playlist_url=self.url, - tracks=self.tracks, - ) + await _config.custom(*self.config_scope, str(self.id)).set(self.to_json()) def to_json(self) -> dict: """Transform the object to a dict. @@ -418,8 +483,208 @@ class Playlist: tracks=tracks, ) + async def save(self): + """ + Saves a Playlist to SQL. + """ + scope, scope_id = _prepare_config_scope(self.scope, self.author, self.guild) + database.upsert( + scope, + playlist_id=int(self.id), + playlist_name=self.name, + scope_id=scope_id, + author_id=self.author, + playlist_url=self.url, + tracks=self.tracks, + ) -async def get_playlist( # TODO: convert to SQL + +async def get_all_playlist_for_migration23( # TODO: remove me in a future version ? + scope: str, + bot: Red, + guild: Union[discord.Guild, int] = None, + author: Union[discord.abc.User, int] = None, +) -> List[PlaylistMigration23]: + """ + Gets all playlist for the specified scope. + Parameters + ---------- + scope: str + The custom config scope. One of 'GLOBALPLAYLIST', 'GUILDPLAYLIST' or 'USERPLAYLIST'. + guild: discord.Guild + The guild to get the playlist from if scope is GUILDPLAYLIST. + author: int + The ID of the user to get the playlist from if scope is USERPLAYLIST. + bot: Red + The bot's instance + specified_user:bool + Whether or not user ID was passed as an argparse. + Returns + ------- + list + A list of all playlists for the specified scope + Raises + ------ + `InvalidPlaylistScope` + Passing a scope that is not supported. + `MissingGuild` + Trying to access the Guild scope without a guild. + `MissingAuthor` + Trying to access the User scope without an user id. + """ + playlists = await _config.custom(scope).all() + if scope == PlaylistScope.GLOBAL.value: + return [ + await PlaylistMigration23.from_json( + bot, scope, playlist_number, playlist_data, guild=guild, author=author + ) + for playlist_number, playlist_data in playlists.items() + ] + elif scope == PlaylistScope.USER.value: + return [ + await PlaylistMigration23.from_json( + bot, scope, playlist_number, playlist_data, guild=guild, author=author + ) + for user_id, scopedata in playlists.items() + for playlist_number, playlist_data in scopedata.items() + ] + else: + return [ + await PlaylistMigration23.from_json( + bot, scope, playlist_number, playlist_data, guild=guild, author=author + ) + for guild_id, scopedata in playlists.items() + for playlist_number, playlist_data in scopedata.items() + ] + + +class Playlist: + """A single playlist.""" + + def __init__( + self, + bot: Red, + scope: str, + author: int, + playlist_id: int, + name: str, + playlist_url: Optional[str] = None, + tracks: Optional[List[dict]] = None, + guild: Union[discord.Guild, int, None] = None, + ): + self.bot = bot + self.guild = guild + self.scope = standardize_scope(scope) + self.config_scope = _prepare_config_scope(self.scope, author, guild) + self.author = author + self.guild_id = ( + getattr(guild, "id", guild) if self.scope == PlaylistScope.GLOBAL.value else None + ) + self.id = playlist_id + self.name = name + self.url = playlist_url + self.tracks = tracks or [] + self.tracks_obj = [lavalink.Track(data=track) for track in self.tracks] + + async def edit(self, data: dict): + """ + Edits a Playlist. + Parameters + ---------- + data: dict + The attributes to change. + """ + # Disallow ID editing + if "id" in data: + raise NotAllowed("Playlist ID cannot be edited.") + + for item in list(data.keys()): + setattr(self, item, data[item]) + await self.save() + + async def save(self): + """ + Saves a Playlist. + """ + scope, scope_id = self.config_scope + database.upsert( + scope, + playlist_id=int(self.id), + playlist_name=self.name, + scope_id=scope_id, + author_id=self.author, + playlist_url=self.url, + tracks=self.tracks, + ) + + def to_json(self) -> dict: + """Transform the object to a dict. + Returns + ------- + dict + The playlist in the form of a dict. + """ + data = dict( + id=self.id, + author=self.author, + guild=self.guild_id, + name=self.name, + playlist_url=self.url, + tracks=self.tracks, + ) + + return data + + @classmethod + async def from_json(cls, bot: Red, scope: str, playlist_number: int, data: dict, **kwargs): + """Get a Playlist object from the provided information. + Parameters + ---------- + bot: Red + The bot's instance. Needed to get the target user. + scope:str + The custom config scope. One of 'GLOBALPLAYLIST', 'GUILDPLAYLIST' or 'USERPLAYLIST'. + playlist_number: int + The playlist's number. + data: dict + The JSON representation of the playlist to be gotten. + **kwargs + Extra attributes for the Playlist instance which override values + in the data dict. These should be complete objects and not + IDs, where possible. + Returns + ------- + Playlist + The playlist object for the requested playlist. + Raises + ------ + `InvalidPlaylistScope` + Passing a scope that is not supported. + `MissingGuild` + Trying to access the Guild scope without a guild. + `MissingAuthor` + Trying to access the User scope without an user id. + """ + guild = data.scope_id if scope == PlaylistScope.GUILD.value else kwargs.get("guild") + author = data.author_id + playlist_id = data.playlist_id or playlist_number + name = data.playlist_name + playlist_url = data.playlist_url + tracks = json.loads(data.tracks) + + return cls( + bot=bot, + guild=guild, + scope=scope, + author=author, + playlist_id=playlist_id, + name=name, + playlist_url=playlist_url, + tracks=tracks, + ) + + +async def get_playlist( playlist_number: int, scope: str, bot: Red, @@ -455,17 +720,17 @@ async def get_playlist( # TODO: convert to SQL `MissingAuthor` Trying to access the User scope without an user id. """ - playlist_data = await _config.custom( - *_prepare_config_scope(scope, author, guild), str(playlist_number) - ).all() - if not playlist_data["id"]: + scope_standard, scope_id = _prepare_config_scope(scope, author, guild) + playlist_data = database.fetch(scope_standard, playlist_number, scope_id) + + if not playlist_data.playlist_id: raise RuntimeError(f"That playlist does not exist for the following scope: {scope}") return await Playlist.from_json( - bot, scope, playlist_number, playlist_data, guild=guild, author=author + bot, scope_standard, playlist_number, playlist_data, guild=guild, author=author ) -async def get_all_playlist( # TODO: convert to SQL +async def get_all_playlist( scope: str, bot: Red, guild: Union[discord.Guild, int] = None, @@ -499,23 +764,20 @@ async def get_all_playlist( # TODO: convert to SQL `MissingAuthor` Trying to access the User scope without an user id. """ - playlists = await _config.custom(*_prepare_config_scope(scope, author, guild)).all() + scope_standard, scope_id = _prepare_config_scope(scope, author, guild) + if specified_user: user_id = getattr(author, "id", author) - return [ - await Playlist.from_json( - bot, scope, playlist_number, playlist_data, guild=guild, author=author - ) - for playlist_number, playlist_data in playlists.items() - if user_id == playlist_data.get("author") - ] + playlists = database.fetch_all(scope_standard, author_id=user_id) else: - return [ - await Playlist.from_json( - bot, scope, playlist_number, playlist_data, guild=guild, author=author - ) - for playlist_number, playlist_data in playlists.items() - ] + playlists = database.fetch_all(scope_standard) + + return [ + await Playlist.from_json( + bot, scope, playlist.playlist_id, playlist, guild=guild, author=author + ) + for playlist in playlists + ] async def create_playlist( @@ -561,7 +823,14 @@ async def create_playlist( """ playlist = Playlist( - ctx.bot, scope, author.id, ctx.message.id, playlist_name, playlist_url, tracks, ctx.guild + ctx.bot, + scope, + author.id, + ctx.message.id, + playlist_name, + playlist_url, + tracks, + guild or ctx.guild, ) await playlist.save() return playlist @@ -594,8 +863,8 @@ async def reset_playlist( Trying to access the User scope without an user id. """ scope, scope_id = _prepare_config_scope(scope, author, guild) - _database.drop(scope) - _database.create_table(scope) + database.drop(scope) + database.create_table(scope) async def delete_playlist( @@ -628,4 +897,4 @@ async def delete_playlist( Trying to access the User scope without an user id. """ scope, scope_id = _prepare_config_scope(scope, author, guild) - _database.delete(scope, int(playlist_id), scope_id) + database.delete(scope, int(playlist_id), scope_id)