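"""SQLite-backed storage for the Audio cog: a query cache for Lavalink,
YouTube, and Spotify lookups, plus persistent playlist storage, accessed
through APSW."""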
import asyncio
import concurrent.futures
import contextlib
import datetime
import json
import logging
import time
from dataclasses import dataclass, field
from typing import Dict, List, Mapping, MutableMapping, Optional, Tuple, TYPE_CHECKING, Union

import apsw

from redbot.core import Config
from redbot.core.bot import Red
from redbot.core.data_manager import cog_data_path

from .errors import InvalidTableError
from .sql_statements import *
from .utils import PlaylistScope

log = logging.getLogger("red.audio.database")

if TYPE_CHECKING:
    database_connection: apsw.Connection
    _bot: Red
    _config: Config
else:
    _config = None
    _bot = None
    database_connection = None


SCHEMA_VERSION = 3
SQLError = apsw.ExecutionCompleteError


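# Maps each cache table to the SQL statements (from sql_statements.py) used to
# insert, update, and query it. The nested keys ("youtube_url", "track_info",
# "data") are the query names passed to fetch_one()/fetch_all()/fetch_random().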
_PARSER: Mapping = {
    "youtube": {
        "insert": YOUTUBE_UPSERT,
        "youtube_url": {"query": YOUTUBE_QUERY},
        "update": YOUTUBE_UPDATE,
    },
    "spotify": {
        "insert": SPOTIFY_UPSERT,
        "track_info": {"query": SPOTIFY_QUERY},
        "update": SPOTIFY_UPDATE,
    },
    "lavalink": {
        "insert": LAVALINK_UPSERT,
        "data": {"query": LAVALINK_QUERY, "played": LAVALINK_QUERY_LAST_FETCHED_RANDOM},
        "update": LAVALINK_UPDATE,
    },
}


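# Hands the shared Config and bot instances to this module and opens the
# SQLite database via APSW. Each global is set at most once; repeat calls are
# no-ops. A rough usage sketch (the call order is an assumption, but the
# interfaces below dereference database_connection, so it must be set first):
#     _pass_config_to_databases(config, bot)
#     cache = CacheInterface()
#     await cache.init()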
def _pass_config_to_databases(config: Config, bot: Red):
    global _config, _bot, database_connection
    if _config is None:
        _config = config
    if _bot is None:
        _bot = bot
    if database_connection is None:
        database_connection = apsw.Connection(
            str(cog_data_path(_bot.get_cog("Audio")) / "Audio.db")
        )


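# Thin row wrapper for playlist queries; tracks are stored as a JSON string in
# the database and decoded on construction.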
@dataclass
class PlaylistFetchResult:
    playlist_id: int
    playlist_name: str
    scope_id: int
    author_id: int
    playlist_url: Optional[str] = None
    tracks: List[MutableMapping] = field(default_factory=list)

    def __post_init__(self):
        if isinstance(self.tracks, str):
            self.tracks = json.loads(self.tracks)


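# Row wrappers for the cache tables. Each decodes JSON-encoded columns on
# construction; CacheFetchResult also discards query payloads that do not look
# like a serialized Lavalink response.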
@dataclass
class CacheFetchResult:
    query: Optional[Union[str, MutableMapping]]
    last_updated: int

    def __post_init__(self):
        if isinstance(self.last_updated, int):
            self.updated_on: datetime.datetime = datetime.datetime.fromtimestamp(self.last_updated)
        if isinstance(self.query, str) and all(
            k in self.query for k in ["loadType", "playlistInfo", "isSeekable", "isStream"]
        ):
            self.query = json.loads(self.query)
        else:
            self.query = None


@dataclass
class CacheLastFetchResult:
    tracks: List[MutableMapping] = field(default_factory=list)

    def __post_init__(self):
        if isinstance(self.tracks, str):
            self.tracks = json.loads(self.tracks)


@dataclass
class CacheGetAllLavalink:
    query: str
    data: List[MutableMapping] = field(default_factory=list)

    def __post_init__(self):
        if isinstance(self.data, str):
            self.data = json.loads(self.data)


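# Read/write interface to the Lavalink/YouTube/Spotify cache tables. All
# instances share the single module-level APSW connection.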
class CacheInterface:
    def __init__(self):
        self.database = database_connection.cursor()

    @staticmethod
    def close():
        with contextlib.suppress(Exception):
            database_connection.close()

    async def init(self):
        self.database.execute(PRAGMA_SET_temp_store)
        self.database.execute(PRAGMA_SET_journal_mode)
        self.database.execute(PRAGMA_SET_read_uncommitted)
        self.maybe_migrate()

        self.database.execute(LAVALINK_CREATE_TABLE)
        self.database.execute(LAVALINK_CREATE_INDEX)
        self.database.execute(YOUTUBE_CREATE_TABLE)
        self.database.execute(YOUTUBE_CREATE_INDEX)
        self.database.execute(SPOTIFY_CREATE_TABLE)
        self.database.execute(SPOTIFY_CREATE_INDEX)

        await self.clean_up_old_entries()

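    # Deletes cache rows older than the user-configured cache age (in days).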
    async def clean_up_old_entries(self):
        max_age = await _config.cache_age()
        maxage = datetime.datetime.now(tz=datetime.timezone.utc) - datetime.timedelta(days=max_age)
        # maxage is timezone-aware, so use .timestamp() directly;
        # time.mktime(maxage.timetuple()) would wrongly apply the local UTC offset.
        maxage_int = int(maxage.timestamp())
        values = {"maxage": maxage_int}
        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
            executor.submit(self.database.execute, LAVALINK_DELETE_OLD_ENTRIES, values)
            executor.submit(self.database.execute, YOUTUBE_DELETE_OLD_ENTRIES, values)
            executor.submit(self.database.execute, SPOTIFY_DELETE_OLD_ENTRIES, values)

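    # Schema versioning via SQLite's PRAGMA user_version. There are currently
    # no per-version migrations; an out-of-date database is simply stamped
    # with the current SCHEMA_VERSION.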
    def maybe_migrate(self):
        current_version = self.database.execute(PRAGMA_FETCH_user_version).fetchone()
        if isinstance(current_version, tuple):
            current_version = current_version[0]
        if current_version == SCHEMA_VERSION:
            return
        self.database.execute(PRAGMA_SET_user_version, {"version": SCHEMA_VERSION})

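    # Batch-inserts rows inside an explicit transaction so a large
    # executemany() is committed (or discarded) atomically.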
    async def insert(self, table: str, values: List[MutableMapping]):
        try:
            query = _PARSER.get(table, {}).get("insert")
            if query is None:
                raise InvalidTableError(f"{table} is not a valid table in the database.")
            self.database.execute("BEGIN;")
            self.database.executemany(query, values)
            self.database.execute("COMMIT;")
        except Exception as err:
            # Roll back so a failed batch doesn't leave the transaction open.
            with contextlib.suppress(Exception):
                self.database.execute("ROLLBACK;")
            log.debug("Error during audio db insert", exc_info=err)

    async def update(self, table: str, values: Dict[str, Union[str, int]]):
        try:
            # Look the table up under a separate name so the error message
            # below still references the table the caller asked for.
            table_data = _PARSER.get(table, {})
            if not table_data:
                raise InvalidTableError(f"{table} is not a valid table in the database.")
            sql_query = table_data.get("update")
            time_now = int(datetime.datetime.now(datetime.timezone.utc).timestamp())
            values["last_fetched"] = time_now
            with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
                executor.submit(self.database.execute, sql_query, values)
        except Exception as err:
            log.debug("Error during audio db update", exc_info=err)

    async def fetch_one(
        self, table: str, query: str, values: Dict[str, Union[str, int]]
    ) -> Tuple[Optional[str], bool]:
        table_data = _PARSER.get(table, {})
        if not table_data:
            raise InvalidTableError(f"{table} is not a valid table in the database.")
        sql_query = table_data.get(query, {}).get("query")
        max_age = await _config.cache_age()
        maxage = datetime.datetime.now(tz=datetime.timezone.utc) - datetime.timedelta(days=max_age)
        maxage_int = int(maxage.timestamp())
        values.update({"maxage": maxage_int})
        output = self.database.execute(sql_query, values).fetchone() or (None, 0)
        result = CacheFetchResult(*output)
        return result.query, False

    async def fetch_all(
        self, table: str, query: str, values: Dict[str, Union[str, int]]
    ) -> List[CacheLastFetchResult]:
        table_data = _PARSER.get(table, {})
        if not table_data:
            raise InvalidTableError(f"{table} is not a valid table in the database.")
        sql_query = table_data.get(query, {}).get("played")

        output = []
        for index, row in enumerate(self.database.execute(sql_query, values), start=1):
            # Yield to the event loop periodically so large result sets don't
            # block the bot.
            if index % 50 == 0:
                await asyncio.sleep(0.01)
            output.append(CacheLastFetchResult(*row))
        return output

    async def fetch_random(
        self, table: str, query: str, values: Dict[str, Union[str, int]]
    ) -> Optional[CacheLastFetchResult]:
        table_data = _PARSER.get(table, {})
        if not table_data:
            raise InvalidTableError(f"{table} is not a valid table in the database.")
        sql_query = table_data.get(query, {}).get("played")

        row = None
        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
            for future in concurrent.futures.as_completed(
                [executor.submit(self.database.execute, sql_query, values)]
            ):
                try:
                    row = future.result().fetchone()
                except Exception as exc:
                    log.debug("Failed to complete random fetch from database", exc_info=exc)
        return CacheLastFetchResult(*row) if row else None


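# CRUD interface to the playlist table; rows are looked up by the combination
# of scope_type, scope_id, and playlist_id.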
class PlaylistInterface:
    def __init__(self):
        self.cursor = database_connection.cursor()
        self.cursor.execute(PRAGMA_SET_temp_store)
        self.cursor.execute(PRAGMA_SET_journal_mode)
        self.cursor.execute(PRAGMA_SET_read_uncommitted)
        self.cursor.execute(PLAYLIST_CREATE_TABLE)
        self.cursor.execute(PLAYLIST_CREATE_INDEX)

    @staticmethod
    def close():
        with contextlib.suppress(Exception):
            database_connection.close()

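    # Maps a PlaylistScope value to the integer stored in the scope_type
    # column: GLOBAL -> 1, USER -> 3, anything else (the guild scope) -> 2.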
    @staticmethod
    def get_scope_type(scope: str) -> int:
        if scope == PlaylistScope.GLOBAL.value:
            table = 1
        elif scope == PlaylistScope.USER.value:
            table = 3
        else:
            table = 2
        return table

    def fetch(self, scope: str, playlist_id: int, scope_id: int) -> Optional[PlaylistFetchResult]:
        scope_type = self.get_scope_type(scope)
        row = (
            self.cursor.execute(
                PLAYLIST_FETCH,
                {"playlist_id": playlist_id, "scope_id": scope_id, "scope_type": scope_type},
            ).fetchone()
            or []
        )

        return PlaylistFetchResult(*row) if row else None

    async def fetch_all(
        self, scope: str, scope_id: int, author_id=None
    ) -> List[PlaylistFetchResult]:
        scope_type = self.get_scope_type(scope)
        if author_id is not None:
            output = []
            for index, row in enumerate(
                self.cursor.execute(
                    PLAYLIST_FETCH_ALL_WITH_FILTER,
                    {"scope_type": scope_type, "scope_id": scope_id, "author_id": author_id},
                ),
                start=1,
            ):
                if index % 50 == 0:
                    await asyncio.sleep(0.01)
                output.append(row)
        else:
            output = []
            for index, row in enumerate(
                self.cursor.execute(
                    PLAYLIST_FETCH_ALL, {"scope_type": scope_type, "scope_id": scope_id}
                ),
                start=1,
            ):
                if index % 50 == 0:
                    await asyncio.sleep(0.01)
                output.append(row)
        return [PlaylistFetchResult(*row) for row in output] if output else []

    async def fetch_all_converter(
        self, scope: str, playlist_name, playlist_id
    ) -> List[PlaylistFetchResult]:
        scope_type = self.get_scope_type(scope)
        try:
            playlist_id = int(playlist_id)
        except Exception:
            # Not a numeric ID; fall back to -1, which no stored playlist
            # should use, so the query matches on the name instead.
            playlist_id = -1

        output = []
        for index, row in enumerate(
            self.cursor.execute(
                PLAYLIST_FETCH_ALL_CONVERTER,
                {
                    "scope_type": scope_type,
                    "playlist_name": playlist_name,
                    "playlist_id": playlist_id,
                },
            ),
            start=1,
        ):
            if index % 50 == 0:
                await asyncio.sleep(0.01)
            output.append(row)
        return [PlaylistFetchResult(*row) for row in output] if output else []

    def delete(self, scope: str, playlist_id: int, scope_id: int):
        scope_type = self.get_scope_type(scope)
        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
            executor.submit(
                self.cursor.execute,
                PLAYLIST_DELETE,
                {"playlist_id": playlist_id, "scope_id": scope_id, "scope_type": scope_type},
            )

    def delete_scheduled(self):
        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
            executor.submit(self.cursor.execute, PLAYLIST_DELETE_SCHEDULED)

    def drop(self, scope: str):
        scope_type = self.get_scope_type(scope)
        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
            executor.submit(
                self.cursor.execute, PLAYLIST_DELETE_SCOPE, {"scope_type": scope_type}
            )

    def create_table(self, scope: str):
        scope_type = self.get_scope_type(scope)
        return self.cursor.execute(PLAYLIST_CREATE_TABLE, {"scope_type": scope_type})

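    # Inserts the playlist or replaces an existing row with the same key;
    # tracks are serialized to JSON before storage.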
    def upsert(
        self,
        scope: str,
        playlist_id: int,
        playlist_name: str,
        scope_id: int,
        author_id: int,
        playlist_url: Optional[str],
        tracks: List[MutableMapping],
    ):
        scope_type = self.get_scope_type(scope)
        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
            executor.submit(
                self.cursor.execute,
                PLAYLIST_UPSERT,
                {
                    "scope_type": str(scope_type),
                    "playlist_id": int(playlist_id),
                    "playlist_name": str(playlist_name),
                    "scope_id": int(scope_id),
                    "author_id": int(author_id),
                    "playlist_url": playlist_url,
                    "tracks": json.dumps(tracks),
                },
            )