Implement schema migration

This commit is contained in:
Draper
2019-12-18 11:20:16 +00:00
parent 4af33cff17
commit 6946970fcc
2 changed files with 404 additions and 127 deletions

View File

@@ -5,7 +5,6 @@ import datetime
import heapq
import json
import logging
import os
import random
import re
import time
@@ -56,7 +55,8 @@ from .playlists import (
get_all_playlist,
get_playlist,
humanize_scope,
)
database,
get_all_playlist_for_migration23)
from .utils import *
@@ -67,7 +67,7 @@ __author__ = ["aikaterna", "Draper"]
log = logging.getLogger("red.audio")
_SCHEMA_VERSION = 2
_SCHEMA_VERSION = 3
LazyGreedyConverter = get_lazy_converter("--")
PlaylistConverter = get_playlist_converter()
@@ -204,10 +204,10 @@ class Audio(commands.Cog):
await self.bot.wait_until_ready()
# Unlike most cases, we want the cache to exit before migration.
await self.music_cache.initialize(self.config)
pass_config_to_dependencies(self.config, self.bot, await self.config.localpath())
await self._migrate_config(
from_version=await self.config.schema_version(), to_version=_SCHEMA_VERSION
)
pass_config_to_dependencies(self.config, self.bot, await self.config.localpath())
self._restart_connect()
self._disconnect_task = self.bot.loop.create_task(self.disconnect_timer())
lavalink.register_event_listener(self.event_handler)
@@ -233,7 +233,7 @@ class Audio(commands.Cog):
time_now = str(datetime.datetime.now(datetime.timezone.utc))
if from_version == to_version:
return
elif from_version < to_version:
if from_version < 2 <= to_version:
all_guild_data = await self.config.all_guilds()
all_playlist = {}
for guild_id, guild_data in all_guild_data.items():
@@ -271,6 +271,14 @@ class Audio(commands.Cog):
await self.config.guild(
cast(discord.Guild, discord.Object(id=guild_id))
).clear_raw("playlists")
if from_version < 3 <= to_version:
for scope in PlaylistScope.list():
scope_playlist = await get_all_playlist_for_migration23(scope)
for p in scope_playlist:
await p.save()
await self.config.custom(scope).clear()
await self.config.schema_version.set(_SCHEMA_VERSION)
if database_entries and HAS_SQL:
await self.music_cache.insert("lavalink", database_entries)
@@ -3947,9 +3955,7 @@ class Audio(commands.Cog):
]
playlist_data[
"playlist"
] = (
playlist_songs_backwards_compatible
) # TODO: Keep new playlists backwards compatible, Remove me in a few releases
] = playlist_songs_backwards_compatible # TODO: Keep new playlists backwards compatible, Remove me in a few releases
playlist_data[
"link"
] = (
@@ -7020,5 +7026,7 @@ class Audio(commands.Cog):
async def _close_database(self):
await self.music_cache.run_all_pending_tasks()
await self.music_cache.close()
if database:
database.close()
__del__ = cog_unload

View File

@@ -1,8 +1,8 @@
import json
import os
from collections import namedtuple
from dataclasses import dataclass
from enum import Enum, unique
from typing import List, Optional, Union, Tuple
from typing import List, Optional, Union
import apsw
import discord
@@ -17,7 +17,7 @@ from .errors import InvalidPlaylistScope, MissingAuthor, MissingGuild, NotAllowe
_config: Config = None
_bot: Red = None
_database: "Database" = None
database: "Database" = None
__all__ = [
"Playlist",
@@ -30,6 +30,8 @@ __all__ = [
"humanize_scope",
"standardize_scope",
"FakePlaylist",
"get_all_playlist_for_migration23",
"database",
]
FakePlaylist = namedtuple("Playlist", "author scope")
@@ -53,28 +55,37 @@ PRAGMA optimize = 1;
"""
_CREATE_TABLE = """
CREATE TABLE IF NOT EXISTS GLOBAL (
playlist_id INTEGER PRIMARY KEY,
playlist_name TEXT NOT NULL,
scope_id INTEGER NOT NULL,
author_id INTEGER NOT NULL,
playlist_url TEXT,
tracks BLOB);
CREATE TABLE IF NOT EXISTS playlists (
scope_type INTEGER NOT NULL,
playlist_id INTEGER NOT NULL,
playlist_name TEXT NOT NULL,
scope_id INTEGER NOT NULL,
author_id INTEGER NOT NULL,
playlist_url TEXT,
tracks BLOB,
PRIMARY KEY (playlist_id, scope_id, scope_type)
);
"""
_DROP = """
DROP TABLE {table};
"""
_DELETE = """
DELETE FROM {table}
DELETE FROM playlists
WHERE
(
scope_type = :scope_type
AND
playlist_id = :playlist_id
AND
scope_id = :scope_id
)
;
"""
_DELETE_SCOPE = """
DELETE FROM playlists
WHERE
scope_type = :scope_type
;
"""
_FETCH_ALL = """
SELECT
playlist_id,
@@ -83,7 +94,28 @@ scope_id,
author_id,
playlist_url,
tracks
FROM {table};
FROM playlists
WHERE
scope_type = :scope_type
;
"""
_FETCH_ALL_WITH_FILTER = """
SELECT
playlist_id,
playlist_name,
scope_id,
author_id,
playlist_url,
tracks
FROM playlists
WHERE
(
scope_type = :scope_type
AND
author_id = :author_id
)
;
"""
_FETCH = """
@@ -94,9 +126,11 @@ scope_id,
author_id,
playlist_url,
tracks
FROM {table}
FROM playlists
WHERE
(
scope_type = :scope_type
AND
playlist_id = :playlist_id
AND
scope_id = :scope_id
@@ -104,8 +138,9 @@ WHERE
"""
_UPSET = """INSERT INTO
{table}
playlists
(
scope_type
playlist_id,
playlist_name,
scope_id,
@@ -115,6 +150,7 @@ _UPSET = """INSERT INTO
)
VALUES
(
:scope_type,
:playlist_id,
:playlist_name,
:scope_id,
@@ -124,18 +160,29 @@ VALUES
)
ON CONFLICT
(
scope_type,
playlist_id,
scope_id
)
DO UPDATE
SET
playlist_name = :playlist_name,
playlist_url = :playlist_url,
tracks = :tracks
playlist_name = excluded.playlist_name,
playlist_url = excluded.playlist_url,
tracks = excluded.tracks
;
"""
@dataclass
class SQLFetchResult:
playlist_id: int
playlist_name: str
scope_id: int
author_id: int
playlist_url: Optional[str] = None
tracks: str = "[]"
@unique
class PlaylistScope(Enum):
GLOBAL = "GLOBALPLAYLIST"
@@ -152,56 +199,77 @@ class PlaylistScope(Enum):
class Database:
def __init__(self):
self._database = apsw.Connection(
str(cog_data_path(_bot.get_cog("Audio")) / "playlists.db")
)
self._database = apsw.Connection(str(cog_data_path(_bot.get_cog("Audio")) / "Audio.db"))
self.cursor = self._database.cursor()
self.cursor.execute(_PRAGMA_UPDATE_temp_store)
self.cursor.execute(_PRAGMA_UPDATE_journal_mode)
self.cursor.execute(_PRAGMA_UPDATE_wal_autocheckpoint)
self.cursor.execute(_PRAGMA_UPDATE_read_uncommitted)
for t in ["GLOBAL", "GUILD", "USER"]:
self.cursor.execute(_CREATE_TABLE.format(table=t))
self.cursor.execute(_CREATE_TABLE)
def close(self):
self.cursor.execute(_PRAGMA_UPDATE_optimize)
self._database.close()
@staticmethod
def parse_query(scope: PlaylistScope, query: str):
def get_scope_type(scope: str) -> int:
if scope == PlaylistScope.GLOBAL.value:
table = "GLOBAL"
table = 1
elif scope == PlaylistScope.GUILD.value:
table = "GUILD"
table = 2
elif scope == PlaylistScope.USER.value:
table = "USER"
table = 3
else:
raise
return query.format(table=table)
return table
def fetch(
self, scope: PlaylistScope, playlist_id: int, scope_id: int
) -> Tuple[int, str, int, int, str, str]:
query = self.parse_query(scope, _FETCH)
return self.cursor.execute(
query, ({"playlist_id": playlist_id, "scope_id": scope_id})
def fetch(self, scope: str, playlist_id: int, scope_id: int) -> SQLFetchResult:
scope_type = self.get_scope_type(scope)
row = (
self.cursor.execute(
_FETCH,
({"playlist_id": playlist_id, "scope_id": scope_id, "scope_type": scope_type}),
).fetchone()
or []
)
def delete(self, scope: PlaylistScope, playlist_id: int, scope_id: int):
query = self.parse_query(scope, _DELETE)
return self.cursor.execute(query, ({"playlist_id": playlist_id, "scope_id": scope_id}))
return SQLFetchResult(*row) if row else None
def fetch_all(self, scope: PlaylistScope) -> List[Tuple[int, str, int, int, str, str]]:
query = self.parse_query(scope, _FETCH_ALL)
return self.cursor.execute(query).fetchall()
def fetch_all(self, scope: str, author_id=None) -> List[SQLFetchResult]:
scope_type = self.get_scope_type(scope)
if author_id is not None:
output = (
self.cursor.execute(
_FETCH_ALL, ({"scope_type": scope_type, "author_id": author_id})
).fetchall()
or []
)
else:
output = (
self.cursor.execute(
_FETCH_ALL_WITH_FILTER, ({"scope_type": scope_type})
).fetchall()
or []
)
return [SQLFetchResult(*row) for row in output] if output else []
def drop(self, scope: PlaylistScope):
query = self.parse_query(scope, _DROP)
return self.cursor.execute(query)
def delete(self, scope: str, playlist_id: int, scope_id: int):
scope_type = self.get_scope_type(scope)
return self.cursor.execute(
_DELETE, ({"playlist_id": playlist_id, "scope_id": scope_id, "scope_type": scope_type})
)
def create_table(self, scope: PlaylistScope):
query = self.parse_query(scope, _CREATE_TABLE)
return self.cursor.execute(query)
def drop(self, scope: str):
scope_type = self.get_scope_type(scope)
return self.cursor.execute(_DELETE_SCOPE, ({"scope_type": scope_type}))
def create_table(self, scope: str):
scope_type = self.get_scope_type(scope)
return self.cursor.execute(_CREATE_TABLE, ({"scope_type": scope_type}))
def upsert(
self,
scope: PlaylistScope,
scope: str,
playlist_id: int,
playlist_name: str,
scope_id: int,
@@ -209,11 +277,12 @@ class Database:
playlist_url: str,
tracks: List[dict],
):
query = self.parse_query(scope, _UPSET)
scope_type = self.get_scope_type(scope)
self.cursor.execute(
query,
_UPSET,
(
{
"scope_type": scope_type,
"playlist_id": playlist_id,
"playlist_name": playlist_name,
"scope_id": scope_id,
@@ -226,13 +295,13 @@ class Database:
def _pass_config_to_playlist(config: Config, bot: Red):
global _config, _bot, _database
global _config, _bot, database
if _config is None:
_config = config
if _bot is None:
_bot = bot
if _database is None:
_database = Database()
if database is None:
database = Database()
def standardize_scope(scope) -> str:
@@ -260,11 +329,11 @@ def standardize_scope(scope) -> str:
def humanize_scope(scope, ctx=None, the=None):
if scope == PlaylistScope.GLOBAL.value:
return ctx or _("the ") if the else "" + "Global"
return ctx or _("the ") if the else "" + _("Global")
elif scope == PlaylistScope.GUILD.value:
return ctx.name if ctx else _("the ") if the else "" + "Server"
return ctx.name if ctx else _("the ") if the else "" + _("Server")
elif scope == PlaylistScope.USER.value:
return str(ctx) if ctx else _("the ") if the else "" + "User"
return str(ctx) if ctx else _("the ") if the else "" + _("User")
def _prepare_config_scope(
@@ -285,7 +354,25 @@ def _prepare_config_scope(
return config_scope
class Playlist:
def _prepare_config_scope_for_migration23( # TODO: remove me in a future version ?
scope, author: Union[discord.abc.User, int] = None, guild: discord.Guild = None
):
scope = standardize_scope(scope)
if scope == PlaylistScope.GLOBAL.value:
config_scope = [PlaylistScope.GLOBAL.value]
elif scope == PlaylistScope.USER.value:
if author is None:
raise MissingAuthor("Invalid author for user scope.")
config_scope = [PlaylistScope.USER.value, str(getattr(author, "id", author))]
else:
if guild is None:
raise MissingGuild("Invalid guild for guild scope.")
config_scope = [PlaylistScope.GUILD.value, str(getattr(guild, "id", guild))]
return config_scope
class PlaylistMigration23: # TODO: remove me in a future version ?
"""A single playlist."""
def __init__(
@@ -302,7 +389,7 @@ class Playlist:
self.bot = bot
self.guild = guild
self.scope = standardize_scope(scope)
self.config_scope = _prepare_config_scope(self.scope, author, guild)
self.config_scope = _prepare_config_scope_for_migration23(self.scope, author, guild)
self.author = author
self.guild_id = (
getattr(guild, "id", guild) if self.scope == PlaylistScope.GLOBAL.value else None
@@ -313,14 +400,6 @@ class Playlist:
self.tracks = tracks or []
self.tracks_obj = [lavalink.Track(data=track) for track in self.tracks]
def _get_scope_id(self):
if self.scope == PlaylistScope.GLOBAL.value:
return self.bot.user.id
elif self.scope == PlaylistScope.USER.value:
return self.author
else:
return self.guild.id
async def edit(self, data: dict):
"""
Edits a Playlist.
@@ -335,22 +414,8 @@ class Playlist:
for item in list(data.keys()):
setattr(self, item, data[item])
await self.save()
async def save(self):
"""
Saves a Playlist.
"""
scope, scope_id = self.config_scope
_database.upsert(
scope,
playlist_id=int(self.id),
playlist_name=self.name,
scope_id=scope_id,
author_id=self.author,
playlist_url=self.url,
tracks=self.tracks,
)
await _config.custom(*self.config_scope, str(self.id)).set(self.to_json())
def to_json(self) -> dict:
"""Transform the object to a dict.
@@ -418,8 +483,208 @@ class Playlist:
tracks=tracks,
)
async def save(self):
"""
Saves a Playlist to SQL.
"""
scope, scope_id = _prepare_config_scope(self.scope, self.author, self.guild)
database.upsert(
scope,
playlist_id=int(self.id),
playlist_name=self.name,
scope_id=scope_id,
author_id=self.author,
playlist_url=self.url,
tracks=self.tracks,
)
async def get_playlist( # TODO: convert to SQL
async def get_all_playlist_for_migration23( # TODO: remove me in a future version ?
scope: str,
bot: Red,
guild: Union[discord.Guild, int] = None,
author: Union[discord.abc.User, int] = None,
) -> List[PlaylistMigration23]:
"""
Gets all playlist for the specified scope.
Parameters
----------
scope: str
The custom config scope. One of 'GLOBALPLAYLIST', 'GUILDPLAYLIST' or 'USERPLAYLIST'.
guild: discord.Guild
The guild to get the playlist from if scope is GUILDPLAYLIST.
author: int
The ID of the user to get the playlist from if scope is USERPLAYLIST.
bot: Red
The bot's instance
specified_user:bool
Whether or not user ID was passed as an argparse.
Returns
-------
list
A list of all playlists for the specified scope
Raises
------
`InvalidPlaylistScope`
Passing a scope that is not supported.
`MissingGuild`
Trying to access the Guild scope without a guild.
`MissingAuthor`
Trying to access the User scope without an user id.
"""
playlists = await _config.custom(scope).all()
if scope == PlaylistScope.GLOBAL.value:
return [
await PlaylistMigration23.from_json(
bot, scope, playlist_number, playlist_data, guild=guild, author=author
)
for playlist_number, playlist_data in playlists.items()
]
elif scope == PlaylistScope.USER.value:
return [
await PlaylistMigration23.from_json(
bot, scope, playlist_number, playlist_data, guild=guild, author=author
)
for user_id, scopedata in playlists.items()
for playlist_number, playlist_data in scopedata.items()
]
else:
return [
await PlaylistMigration23.from_json(
bot, scope, playlist_number, playlist_data, guild=guild, author=author
)
for guild_id, scopedata in playlists.items()
for playlist_number, playlist_data in scopedata.items()
]
class Playlist:
"""A single playlist."""
def __init__(
self,
bot: Red,
scope: str,
author: int,
playlist_id: int,
name: str,
playlist_url: Optional[str] = None,
tracks: Optional[List[dict]] = None,
guild: Union[discord.Guild, int, None] = None,
):
self.bot = bot
self.guild = guild
self.scope = standardize_scope(scope)
self.config_scope = _prepare_config_scope(self.scope, author, guild)
self.author = author
self.guild_id = (
getattr(guild, "id", guild) if self.scope == PlaylistScope.GLOBAL.value else None
)
self.id = playlist_id
self.name = name
self.url = playlist_url
self.tracks = tracks or []
self.tracks_obj = [lavalink.Track(data=track) for track in self.tracks]
async def edit(self, data: dict):
"""
Edits a Playlist.
Parameters
----------
data: dict
The attributes to change.
"""
# Disallow ID editing
if "id" in data:
raise NotAllowed("Playlist ID cannot be edited.")
for item in list(data.keys()):
setattr(self, item, data[item])
await self.save()
async def save(self):
"""
Saves a Playlist.
"""
scope, scope_id = self.config_scope
database.upsert(
scope,
playlist_id=int(self.id),
playlist_name=self.name,
scope_id=scope_id,
author_id=self.author,
playlist_url=self.url,
tracks=self.tracks,
)
def to_json(self) -> dict:
"""Transform the object to a dict.
Returns
-------
dict
The playlist in the form of a dict.
"""
data = dict(
id=self.id,
author=self.author,
guild=self.guild_id,
name=self.name,
playlist_url=self.url,
tracks=self.tracks,
)
return data
@classmethod
async def from_json(cls, bot: Red, scope: str, playlist_number: int, data: dict, **kwargs):
"""Get a Playlist object from the provided information.
Parameters
----------
bot: Red
The bot's instance. Needed to get the target user.
scope:str
The custom config scope. One of 'GLOBALPLAYLIST', 'GUILDPLAYLIST' or 'USERPLAYLIST'.
playlist_number: int
The playlist's number.
data: dict
The JSON representation of the playlist to be gotten.
**kwargs
Extra attributes for the Playlist instance which override values
in the data dict. These should be complete objects and not
IDs, where possible.
Returns
-------
Playlist
The playlist object for the requested playlist.
Raises
------
`InvalidPlaylistScope`
Passing a scope that is not supported.
`MissingGuild`
Trying to access the Guild scope without a guild.
`MissingAuthor`
Trying to access the User scope without an user id.
"""
guild = data.scope_id if scope == PlaylistScope.GUILD.value else kwargs.get("guild")
author = data.author_id
playlist_id = data.playlist_id or playlist_number
name = data.playlist_name
playlist_url = data.playlist_url
tracks = json.loads(data.tracks)
return cls(
bot=bot,
guild=guild,
scope=scope,
author=author,
playlist_id=playlist_id,
name=name,
playlist_url=playlist_url,
tracks=tracks,
)
async def get_playlist(
playlist_number: int,
scope: str,
bot: Red,
@@ -455,17 +720,17 @@ async def get_playlist( # TODO: convert to SQL
`MissingAuthor`
Trying to access the User scope without an user id.
"""
playlist_data = await _config.custom(
*_prepare_config_scope(scope, author, guild), str(playlist_number)
).all()
if not playlist_data["id"]:
scope_standard, scope_id = _prepare_config_scope(scope, author, guild)
playlist_data = database.fetch(scope_standard, playlist_number, scope_id)
if not playlist_data.playlist_id:
raise RuntimeError(f"That playlist does not exist for the following scope: {scope}")
return await Playlist.from_json(
bot, scope, playlist_number, playlist_data, guild=guild, author=author
bot, scope_standard, playlist_number, playlist_data, guild=guild, author=author
)
async def get_all_playlist( # TODO: convert to SQL
async def get_all_playlist(
scope: str,
bot: Red,
guild: Union[discord.Guild, int] = None,
@@ -499,22 +764,19 @@ async def get_all_playlist( # TODO: convert to SQL
`MissingAuthor`
Trying to access the User scope without an user id.
"""
playlists = await _config.custom(*_prepare_config_scope(scope, author, guild)).all()
scope_standard, scope_id = _prepare_config_scope(scope, author, guild)
if specified_user:
user_id = getattr(author, "id", author)
return [
await Playlist.from_json(
bot, scope, playlist_number, playlist_data, guild=guild, author=author
)
for playlist_number, playlist_data in playlists.items()
if user_id == playlist_data.get("author")
]
playlists = database.fetch_all(scope_standard, author_id=user_id)
else:
playlists = database.fetch_all(scope_standard)
return [
await Playlist.from_json(
bot, scope, playlist_number, playlist_data, guild=guild, author=author
bot, scope, playlist.playlist_id, playlist, guild=guild, author=author
)
for playlist_number, playlist_data in playlists.items()
for playlist in playlists
]
@@ -561,7 +823,14 @@ async def create_playlist(
"""
playlist = Playlist(
ctx.bot, scope, author.id, ctx.message.id, playlist_name, playlist_url, tracks, ctx.guild
ctx.bot,
scope,
author.id,
ctx.message.id,
playlist_name,
playlist_url,
tracks,
guild or ctx.guild,
)
await playlist.save()
return playlist
@@ -594,8 +863,8 @@ async def reset_playlist(
Trying to access the User scope without an user id.
"""
scope, scope_id = _prepare_config_scope(scope, author, guild)
_database.drop(scope)
_database.create_table(scope)
database.drop(scope)
database.create_table(scope)
async def delete_playlist(
@@ -628,4 +897,4 @@ async def delete_playlist(
Trying to access the User scope without an user id.
"""
scope, scope_id = _prepare_config_scope(scope, author, guild)
_database.delete(scope, int(playlist_id), scope_id)
database.delete(scope, int(playlist_id), scope_id)