Draper e52c20b9e7 [Audio] Hotfix an edge case where an attribute error can be raised (#3328)
* Limit Playlists

Signed-off-by: Drapersniper <27962761+drapersniper@users.noreply.github.com>

* Hotfix

Signed-off-by: Drapersniper <27962761+drapersniper@users.noreply.github.com>

* Hotfix

Signed-off-by: Drapersniper <27962761+drapersniper@users.noreply.github.com>
2020-01-12 16:20:46 -05:00

375 lines
13 KiB
Python

import asyncio
import concurrent.futures
import contextlib
import datetime
import json
import logging
import time
from dataclasses import dataclass, field
from typing import Dict, List, Optional, TYPE_CHECKING, Tuple, Union, MutableMapping, Mapping
import apsw
from redbot.core import Config
from redbot.core.bot import Red
from redbot.core.data_manager import cog_data_path
from .errors import InvalidTableError
from .sql_statements import *
from .utils import PlaylistScope
# Logger shared by every database helper in this module.
log = logging.getLogger("red.audio.database")

if not TYPE_CHECKING:
    # At runtime these stay unset until _pass_config_to_databases() fills them in.
    database_connection = None
    _bot = None
    _config = None
else:
    # Declarations for static analysis only; this branch never executes.
    _config: Config
    _bot: Red
    database_connection: apsw.Connection

# Value written to the database's ``user_version`` pragma by CacheInterface.maybe_migrate().
SCHEMA_VERSION = 3
# Alias for the apsw exception type, exposed under a shorter module-level name.
SQLError = apsw.ExecutionCompleteError
# Lookup table: cache table name -> SQL statements usable against it.
#   "insert" / "update"  -> statement text
#   any other key        -> the name of a query group, mapping a query kind
#                           ("query", "played") to its statement text
_PARSER: Mapping = dict(
    youtube=dict(
        insert=YOUTUBE_UPSERT,
        youtube_url=dict(query=YOUTUBE_QUERY),
        update=YOUTUBE_UPDATE,
    ),
    spotify=dict(
        insert=SPOTIFY_UPSERT,
        track_info=dict(query=SPOTIFY_QUERY),
        update=SPOTIFY_UPDATE,
    ),
    lavalink=dict(
        insert=LAVALINK_UPSERT,
        data=dict(query=LAVALINK_QUERY, played=LAVALINK_QUERY_LAST_FETCHED_RANDOM),
        update=LAVALINK_UPDATE,
    ),
)
def _pass_config_to_databases(config: Config, bot: Red):
    """Populate the module-level config, bot and database connection.

    Each global is only assigned while it is still ``None``, so calling this
    again keeps whatever was stored by the first call. The connection is
    opened on ``Audio.db`` inside the data path of the "Audio" cog.
    """
    global _config, _bot, database_connection
    _config = config if _config is None else _config
    _bot = bot if _bot is None else _bot
    if database_connection is not None:
        return
    db_file = cog_data_path(_bot.get_cog("Audio")) / "Audio.db"
    database_connection = apsw.Connection(str(db_file))
@dataclass
class PlaylistFetchResult:
    """A single playlist record.

    ``tracks`` may be supplied as a JSON string; it is decoded into Python
    objects right after construction. Any non-string value is kept as given.
    """

    playlist_id: int
    playlist_name: str
    scope_id: int
    author_id: int
    playlist_url: Optional[str] = None
    tracks: List[MutableMapping] = field(default_factory=list)

    def __post_init__(self):
        raw_tracks = self.tracks
        if not isinstance(raw_tracks, str):
            return
        self.tracks = json.loads(raw_tracks)
@dataclass
class CacheFetchResult:
    """A cached query value together with the time it was last updated.

    After construction:

    * ``updated_on`` is set from ``last_updated`` when that is an ``int``
      (otherwise the attribute is not created);
    * ``query`` is JSON-decoded when it is a string containing all of
      "loadType", "playlistInfo", "isSeekable" and "isStream"; every other
      value is replaced by ``None``.

    NOTE(review): the second rule also discards plain strings that are not
    such a JSON payload - confirm this is intended for every caller.
    """

    query: Optional[Union[str, MutableMapping]]
    last_updated: int

    def __post_init__(self):
        if isinstance(self.last_updated, int):
            self.updated_on: datetime.datetime = datetime.datetime.fromtimestamp(self.last_updated)

        required_markers = ("loadType", "playlistInfo", "isSeekable", "isStream")
        raw = self.query
        # Substring test on the raw text, performed before any decoding.
        is_payload = isinstance(raw, str) and all(marker in raw for marker in required_markers)
        self.query = json.loads(raw) if is_payload else None
@dataclass
class CacheLastFetchResult:
    """Wrapper around a ``tracks`` value.

    A string passed as ``tracks`` is JSON-decoded after construction; any
    other value is stored untouched. Defaults to a fresh empty list.
    """

    tracks: List[MutableMapping] = field(default_factory=list)

    def __post_init__(self):
        raw_tracks = self.tracks
        if not isinstance(raw_tracks, str):
            return
        self.tracks = json.loads(raw_tracks)
@dataclass
class CacheGetAllLavalink:
    """A ``query`` string paired with its ``data``.

    ``data`` given as a string is JSON-decoded after construction; any other
    value is stored untouched. Defaults to a fresh empty list.
    """

    query: str
    data: List[MutableMapping] = field(default_factory=list)

    def __post_init__(self):
        raw_data = self.data
        if not isinstance(raw_data, str):
            return
        self.data = json.loads(raw_data)
class CacheInterface:
    """Access layer for the youtube / spotify / lavalink cache tables.

    Statements run on a cursor taken from the module-level
    ``database_connection``, which must already be set (see
    ``_pass_config_to_databases``) when this class is instantiated.
    """

    def __init__(self):
        self.database = database_connection.cursor()

    @staticmethod
    def close():
        """Close the shared connection, ignoring any error raised while closing."""
        with contextlib.suppress(Exception):
            database_connection.close()

    @staticmethod
    def _resolve_table(table: str) -> Mapping:
        """Return the statement mapping registered for ``table``.

        Raises:
            InvalidTableError: if ``table`` has no entry in ``_PARSER``.
        """
        statements = _PARSER.get(table, {})
        if not statements:
            # Report the requested name, not the (empty) mapping that was looked up.
            raise InvalidTableError(f"{table} is not a valid table in the database.")
        return statements

    @staticmethod
    async def _cache_cutoff() -> int:
        """Return the POSIX timestamp ``cache_age`` days before now.

        ``datetime.timestamp()`` is used because the datetime is timezone
        aware; feeding its UTC ``timetuple()`` to ``time.mktime`` would treat
        the UTC wall clock as local time and shift the result by the host's
        UTC offset.
        """
        max_age = await _config.cache_age()
        cutoff = datetime.datetime.now(tz=datetime.timezone.utc) - datetime.timedelta(days=max_age)
        return int(cutoff.timestamp())

    async def init(self):
        """Apply the PRAGMAs, run the schema-version check, create the cache
        tables and indexes, then purge entries older than the configured age."""
        self.database.execute(PRAGMA_SET_temp_store)
        self.database.execute(PRAGMA_SET_journal_mode)
        self.database.execute(PRAGMA_SET_read_uncommitted)
        self.maybe_migrate()
        self.database.execute(LAVALINK_CREATE_TABLE)
        self.database.execute(LAVALINK_CREATE_INDEX)
        self.database.execute(YOUTUBE_CREATE_TABLE)
        self.database.execute(YOUTUBE_CREATE_INDEX)
        self.database.execute(SPOTIFY_CREATE_TABLE)
        self.database.execute(SPOTIFY_CREATE_INDEX)
        await self.clean_up_old_entries()

    async def clean_up_old_entries(self):
        """Run the three DELETE_OLD_ENTRIES statements with the current cutoff.

        Failures are logged at debug level and never raised.
        """
        values = {"maxage": await self._cache_cutoff()}
        statements = (
            LAVALINK_DELETE_OLD_ENTRIES,
            YOUTUBE_DELETE_OLD_ENTRIES,
            SPOTIFY_DELETE_OLD_ENTRIES,
        )
        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
            futures = [
                executor.submit(self.database.execute, statement, values)
                for statement in statements
            ]
        # Leaving the ``with`` block waits for the worker, so every future is done here.
        for future in futures:
            exc = future.exception()
            if exc is not None:
                log.debug("Error during audio db cleanup", exc_info=exc)

    def maybe_migrate(self):
        """Stamp the database with ``SCHEMA_VERSION`` if it carries another version.

        No data migration is performed here; only the ``user_version`` pragma
        is updated.
        """
        current_version = self.database.execute(PRAGMA_FETCH_user_version).fetchone()
        if isinstance(current_version, tuple):
            current_version = current_version[0]
        if current_version == SCHEMA_VERSION:
            return
        self.database.execute(PRAGMA_SET_user_version, {"version": SCHEMA_VERSION})

    async def insert(self, table: str, values: List[MutableMapping]):
        """Run the table's "insert" statement once per mapping in ``values``.

        All rows are written inside one transaction. Any error - including an
        unknown ``table`` - is logged at debug level and not raised.
        """
        try:
            query = _PARSER.get(table, {}).get("insert")
            if query is None:
                raise InvalidTableError(f"{table} is not a valid table in the database.")
            self.database.execute("BEGIN;")
            try:
                self.database.executemany(query, values)
                self.database.execute("COMMIT;")
            except Exception:
                # Do not leave the transaction open, or the next BEGIN would fail.
                with contextlib.suppress(Exception):
                    self.database.execute("ROLLBACK;")
                raise
        except Exception as err:
            log.debug("Error during audio db insert", exc_info=err)

    async def update(self, table: str, values: Dict[str, Union[str, int]]):
        """Run the table's "update" statement with ``values``.

        ``values`` is modified in place: ``"last_fetched"`` is set to the
        current POSIX timestamp. Any error - including an unknown ``table`` -
        is logged at debug level and not raised.
        """
        try:
            time_now = int(datetime.datetime.now(datetime.timezone.utc).timestamp())
            values["last_fetched"] = time_now
            sql_query = self._resolve_table(table).get("update")
            with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
                # result() re-raises a worker-thread failure so the handler below logs it.
                executor.submit(self.database.execute, sql_query, values).result()
        except Exception as err:
            log.debug("Error during audio db update", exc_info=err)

    async def fetch_one(
        self, table: str, query: str, values: Dict[str, Union[str, int]]
    ) -> Tuple[Optional[str], bool]:
        """Fetch the first row of ``table``'s ``query`` group "query" statement.

        ``values`` is modified in place: ``"maxage"`` is set to the cache cutoff.

        Returns:
            ``(value, False)`` where ``value`` is the ``query`` attribute of a
            ``CacheFetchResult`` built from the row, or from ``(None, 0)`` if
            no row matched. The second element is always ``False``.

        Raises:
            InvalidTableError: if ``table`` is unknown.
        """
        statements = self._resolve_table(table)
        sql_query = statements.get(query, {}).get("query")
        values.update({"maxage": await self._cache_cutoff()})
        output = self.database.execute(sql_query, values).fetchone() or (None, 0)
        result = CacheFetchResult(*output)
        return result.query, False

    async def fetch_all(
        self, table: str, query: str, values: Dict[str, Union[str, int]]
    ) -> List[CacheLastFetchResult]:
        """Return every row of ``table``'s ``query`` group "played" statement.

        Control is yielded to the event loop every 50 rows.

        Raises:
            InvalidTableError: if ``table`` is unknown.
        """
        sql_query = self._resolve_table(table).get(query, {}).get("played")
        # Iterate on a dedicated cursor: this loop awaits between rows, and an
        # execute() issued meanwhile on the shared cursor would replace the
        # statement being iterated.
        cursor = database_connection.cursor()
        output = []
        for index, row in enumerate(cursor.execute(sql_query, values), start=1):
            if index % 50 == 0:
                await asyncio.sleep(0.01)
            output.append(CacheLastFetchResult(*row))
        return output

    async def fetch_random(
        self, table: str, query: str, values: Dict[str, Union[str, int]]
    ) -> CacheLastFetchResult:
        """Return the first row of ``table``'s ``query`` group "played" statement.

        If the statement fails (logged at debug level) or yields no row, an
        empty ``CacheLastFetchResult`` is returned instead of raising.

        Raises:
            InvalidTableError: if ``table`` is unknown.
        """
        sql_query = self._resolve_table(table).get(query, {}).get("played")
        row = None
        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
            future = executor.submit(self.database.execute, sql_query, values)
            try:
                row = future.result().fetchone()
            except Exception as exc:
                log.debug("Failed to completed random fetch from database", exc_info=exc)
        if not row:
            return CacheLastFetchResult()
        return CacheLastFetchResult(*row)
class PlaylistInterface:
    """Access layer for the playlist table.

    Uses a cursor taken from the module-level ``database_connection``, which
    must already be set (see ``_pass_config_to_databases``) when this class is
    instantiated. Construction applies the PRAGMAs and creates the playlist
    table and index.
    """

    def __init__(self):
        self.cursor = database_connection.cursor()
        self.cursor.execute(PRAGMA_SET_temp_store)
        self.cursor.execute(PRAGMA_SET_journal_mode)
        self.cursor.execute(PRAGMA_SET_read_uncommitted)
        self.cursor.execute(PLAYLIST_CREATE_TABLE)
        self.cursor.execute(PLAYLIST_CREATE_INDEX)

    @staticmethod
    def close():
        """Close the shared connection, ignoring any error raised while closing."""
        with contextlib.suppress(Exception):
            database_connection.close()

    @staticmethod
    def get_scope_type(scope: str) -> int:
        """Map a scope value to its numeric code: global -> 1, user -> 3, anything else -> 2."""
        if scope == PlaylistScope.GLOBAL.value:
            table = 1
        elif scope == PlaylistScope.USER.value:
            table = 3
        else:
            table = 2
        return table

    def _execute_in_worker(self, statement, bindings=None):
        """Run ``statement`` on the shared cursor in a one-off worker thread.

        The call returns once the statement has finished. A failure is logged
        at debug level instead of being dropped with the discarded future; it
        is never raised.
        """
        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
            if bindings is None:
                future = executor.submit(self.cursor.execute, statement)
            else:
                future = executor.submit(self.cursor.execute, statement, bindings)
        # Leaving the ``with`` block waits for the worker, so the future is done here.
        exc = future.exception()
        if exc is not None:
            log.debug("Error during audio db playlist write", exc_info=exc)

    @staticmethod
    async def _collect_playlists(statement, bindings) -> List[PlaylistFetchResult]:
        """Execute ``statement`` and wrap each row in a ``PlaylistFetchResult``.

        Control is yielded to the event loop every 50 rows. A dedicated cursor
        is used because an execute() issued on the shared cursor during one of
        those pauses would replace the statement being iterated.
        """
        cursor = database_connection.cursor()
        playlists = []
        for index, row in enumerate(cursor.execute(statement, bindings), start=1):
            if index % 50 == 0:
                await asyncio.sleep(0.01)
            playlists.append(PlaylistFetchResult(*row))
        return playlists

    def fetch(self, scope: str, playlist_id: int, scope_id: int) -> Optional[PlaylistFetchResult]:
        """Return the playlist matching id, scope id and scope, or ``None`` if there is none."""
        scope_type = self.get_scope_type(scope)
        row = self.cursor.execute(
            PLAYLIST_FETCH,
            {"playlist_id": playlist_id, "scope_id": scope_id, "scope_type": scope_type},
        ).fetchone()
        return PlaylistFetchResult(*row) if row else None

    async def fetch_all(
        self, scope: str, scope_id: int, author_id=None
    ) -> List[PlaylistFetchResult]:
        """Return every playlist in the given scope and scope id.

        When ``author_id`` is not ``None`` the author-filtered statement is
        used instead. Returns an empty list when nothing matches.
        """
        bindings = {"scope_type": self.get_scope_type(scope), "scope_id": scope_id}
        if author_id is not None:
            bindings["author_id"] = author_id
            statement = PLAYLIST_FETCH_ALL_WITH_FILTER
        else:
            statement = PLAYLIST_FETCH_ALL
        return await self._collect_playlists(statement, bindings)

    async def fetch_all_converter(
        self, scope: str, playlist_name, playlist_id
    ) -> List[PlaylistFetchResult]:
        """Return the playlists in ``scope`` selected by the converter statement.

        ``playlist_id`` is coerced to ``int``; a value that cannot be converted
        is replaced by ``-1``. Returns an empty list when nothing matches.
        """
        scope_type = self.get_scope_type(scope)
        try:
            playlist_id = int(playlist_id)
        except (TypeError, ValueError, OverflowError):
            playlist_id = -1
        return await self._collect_playlists(
            PLAYLIST_FETCH_ALL_CONVERTER,
            {
                "scope_type": scope_type,
                "playlist_name": playlist_name,
                "playlist_id": playlist_id,
            },
        )

    def delete(self, scope: str, playlist_id: int, scope_id: int):
        """Run the playlist delete statement for one playlist."""
        scope_type = self.get_scope_type(scope)
        self._execute_in_worker(
            PLAYLIST_DELETE,
            {"playlist_id": playlist_id, "scope_id": scope_id, "scope_type": scope_type},
        )

    def delete_scheduled(self):
        """Run the scheduled-delete statement (takes no bindings)."""
        self._execute_in_worker(PLAYLIST_DELETE_SCHEDULED)

    def drop(self, scope: str):
        """Run the delete-by-scope statement for ``scope``."""
        scope_type = self.get_scope_type(scope)
        self._execute_in_worker(PLAYLIST_DELETE_SCOPE, {"scope_type": scope_type})

    def create_table(self, scope: str):
        """Execute the playlist CREATE TABLE statement and return the cursor's result."""
        scope_type = self.get_scope_type(scope)
        return self.cursor.execute(PLAYLIST_CREATE_TABLE, {"scope_type": scope_type})

    def upsert(
        self,
        scope: str,
        playlist_id: int,
        playlist_name: str,
        scope_id: int,
        author_id: int,
        playlist_url: Optional[str],
        tracks: List[MutableMapping],
    ):
        """Insert or update one playlist; ``tracks`` is stored as JSON text.

        A failure is logged at debug level and not raised.
        """
        scope_type = self.get_scope_type(scope)
        self._execute_in_worker(
            PLAYLIST_UPSERT,
            {
                # NOTE(review): bound as text here but as an int everywhere
                # else in this class - kept as-is, confirm against the schema.
                "scope_type": str(scope_type),
                "playlist_id": int(playlist_id),
                "playlist_name": str(playlist_name),
                "scope_id": int(scope_id),
                "author_id": int(author_id),
                "playlist_url": playlist_url,
                "tracks": json.dumps(tracks),
            },
        )