import concurrent.futures
import contextlib
import datetime
import random
import time
from pathlib import Path
from types import SimpleNamespace
from typing import TYPE_CHECKING, Callable, List, MutableMapping, Optional, Tuple, Union

from red_commons.logging import getLogger

from redbot.core import Config
from redbot.core.bot import Red
from redbot.core.commands import Cog
from redbot.core.i18n import Translator
from redbot.core.utils import AsyncIter
from redbot.core.utils.dbtools import APSWConnectionWrapper

from ..sql_statements import (
    LAVALINK_CREATE_INDEX,
    LAVALINK_CREATE_TABLE,
    LAVALINK_DELETE_OLD_ENTRIES,
    LAVALINK_FETCH_ALL_ENTRIES_GLOBAL,
    LAVALINK_QUERY,
    LAVALINK_QUERY_ALL,
    LAVALINK_QUERY_LAST_FETCHED_RANDOM,
    LAVALINK_UPDATE,
    LAVALINK_UPSERT,
    SPOTIFY_CREATE_INDEX,
    SPOTIFY_CREATE_TABLE,
    SPOTIFY_DELETE_OLD_ENTRIES,
    SPOTIFY_QUERY,
    SPOTIFY_QUERY_ALL,
    SPOTIFY_QUERY_LAST_FETCHED_RANDOM,
    SPOTIFY_UPDATE,
    SPOTIFY_UPSERT,
    YOUTUBE_CREATE_INDEX,
    YOUTUBE_CREATE_TABLE,
    YOUTUBE_DELETE_OLD_ENTRIES,
    YOUTUBE_QUERY,
    YOUTUBE_QUERY_ALL,
    YOUTUBE_QUERY_LAST_FETCHED_RANDOM,
    YOUTUBE_UPDATE,
    YOUTUBE_UPSERT,
    PRAGMA_FETCH_user_version,
    PRAGMA_SET_journal_mode,
    PRAGMA_SET_read_uncommitted,
    PRAGMA_SET_temp_store,
    PRAGMA_SET_user_version,
)
from .api_utils import (
    LavalinkCacheFetchForGlobalResult,
    LavalinkCacheFetchResult,
    SpotifyCacheFetchResult,
    YouTubeCacheFetchResult,
)

if TYPE_CHECKING:
    from .. import Audio

log = getLogger("red.cogs.Audio.api.LocalDB")
_ = Translator("Audio", Path(__file__))
# Value written to / compared against the database's ``user_version`` pragma.
_SCHEMA_VERSION = 3


class BaseWrapper:
    """Common query logic shared by the per-table local cache wrappers.

    Subclasses fill in ``self.statement.upsert``, ``update``, ``get_one``,
    ``get_all`` and ``get_random`` plus ``self.fetch_result`` (the callable
    used to turn a raw row into a result object).

    Every statement is run through a throw-away single-worker
    ``ThreadPoolExecutor``.  Leaving the ``with`` block waits for the submitted
    work, so statements run one at a time, in submission order.
    """

    def __init__(
        self, bot: Red, config: Config, conn: APSWConnectionWrapper, cog: Union["Audio", Cog]
    ):
        self.bot = bot
        self.config = config
        self.database = conn
        # Namespace holding the SQL statements used by this wrapper.
        self.statement = SimpleNamespace()
        self.statement.pragma_temp_store = PRAGMA_SET_temp_store
        self.statement.pragma_journal_mode = PRAGMA_SET_journal_mode
        self.statement.pragma_read_uncommitted = PRAGMA_SET_read_uncommitted
        self.statement.set_user_version = PRAGMA_SET_user_version
        self.statement.get_user_version = PRAGMA_FETCH_user_version
        # Row -> result object factory; set by subclasses.
        self.fetch_result: Optional[Callable] = None
        self.cog = cog

    async def init(self) -> None:
        """Initialize the local cache.

        Applies the pragmas, runs the schema-version check, creates the three
        tables and their indexes, then purges expired entries.

        Failures of the individual statements are not reported: the futures
        returned by ``submit`` are never inspected.
        """
        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
            executor.submit(self.database.cursor().execute, self.statement.pragma_temp_store)
            executor.submit(self.database.cursor().execute, self.statement.pragma_journal_mode)
            executor.submit(self.database.cursor().execute, self.statement.pragma_read_uncommitted)
            executor.submit(self.maybe_migrate)
            executor.submit(self.database.cursor().execute, LAVALINK_CREATE_TABLE)
            executor.submit(self.database.cursor().execute, LAVALINK_CREATE_INDEX)
            executor.submit(self.database.cursor().execute, YOUTUBE_CREATE_TABLE)
            executor.submit(self.database.cursor().execute, YOUTUBE_CREATE_INDEX)
            executor.submit(self.database.cursor().execute, SPOTIFY_CREATE_TABLE)
            executor.submit(self.database.cursor().execute, SPOTIFY_CREATE_INDEX)
        # Run the clean-up only after the executor above has finished, so the
        # DELETE statements (which use their own executor) can never be executed
        # before the tables have been created.
        await self.clean_up_old_entries()

    def close(self) -> None:
        """Close the connection with the local cache, ignoring any error."""
        with contextlib.suppress(Exception):
            self.database.close()

    async def _max_age_timestamp(self) -> int:
        """Return the expiry cutoff as a POSIX timestamp.

        The cutoff is "now" (UTC) minus ``config.cache_age()`` days.

        ``datetime.timestamp()`` is used instead of
        ``time.mktime(dt.timetuple())``: the latter interprets the UTC fields
        as local time and shifts the result by the host's UTC offset, which
        would also disagree with the ``last_fetched`` values written by
        :meth:`update`.
        """
        max_age = await self.config.cache_age()
        cutoff = datetime.datetime.now(tz=datetime.timezone.utc) - datetime.timedelta(
            days=max_age
        )
        return int(cutoff.timestamp())

    async def clean_up_old_entries(self) -> None:
        """Delete entries older than the configured cache age from all three tables."""
        values = {"maxage": await self._max_age_timestamp()}
        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
            executor.submit(self.database.cursor().execute, LAVALINK_DELETE_OLD_ENTRIES, values)
            executor.submit(self.database.cursor().execute, YOUTUBE_DELETE_OLD_ENTRIES, values)
            executor.submit(self.database.cursor().execute, SPOTIFY_DELETE_OLD_ENTRIES, values)

    def maybe_migrate(self) -> None:
        """Set the database ``user_version`` to ``_SCHEMA_VERSION`` if it differs.

        The stored version is read first; when the read fails the version is
        treated as ``0`` and therefore gets (re)written.
        """
        current_version = 0
        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
            for future in concurrent.futures.as_completed(
                [executor.submit(self.database.cursor().execute, self.statement.get_user_version)]
            ):
                try:
                    row_result = future.result()
                    current_version = row_result.fetchone()
                    break
                except Exception as exc:
                    log.verbose("Failed to completed fetch from database", exc_info=exc)
            # ``fetchone`` yields a row tuple; the version is its first column.
            if isinstance(current_version, tuple):
                current_version = current_version[0]
            if current_version == _SCHEMA_VERSION:
                return
            # Must be submitted while the executor is still open.
            executor.submit(
                self.database.cursor().execute,
                self.statement.set_user_version,
                {"version": _SCHEMA_VERSION},
            )

    async def insert(self, values: List[MutableMapping]) -> None:
        """Insert (upsert) a batch of entries into the local cache.

        Runs ``self.statement.upsert`` for every mapping in *values* inside a
        single transaction.  Any error is logged and swallowed.
        """
        try:
            with self.database.transaction() as transaction:
                transaction.executemany(self.statement.upsert, values)
        except Exception as exc:
            log.trace("Error during table insert", exc_info=exc)

    async def update(self, values: MutableMapping) -> None:
        """Update an entry of the local cache.

        *values* is modified in place: its ``"last_fetched"`` key is set to the
        current UTC POSIX timestamp before ``self.statement.update`` is run.
        """
        try:
            time_now = int(datetime.datetime.now(datetime.timezone.utc).timestamp())
            values["last_fetched"] = time_now
            with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
                executor.submit(self.database.cursor().execute, self.statement.update, values)
        except Exception as exc:
            log.verbose("Error during table update", exc_info=exc)

    async def _fetch_one(
        self, values: MutableMapping
    ) -> Optional[
        Union[LavalinkCacheFetchResult, SpotifyCacheFetchResult, YouTubeCacheFetchResult]
    ]:
        """Get an entry from the local cache.

        *values* is modified in place: a ``"maxage"`` key holding the expiry
        cutoff is added before ``self.statement.get_one`` is run.

        Returns ``None`` when the query fails, yields no row, or no
        ``fetch_result`` factory is configured.
        """
        values["maxage"] = await self._max_age_timestamp()
        row = None
        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
            for future in concurrent.futures.as_completed(
                [executor.submit(self.database.cursor().execute, self.statement.get_one, values)]
            ):
                try:
                    row_result = future.result()
                    row = row_result.fetchone()
                except Exception as exc:
                    log.verbose("Failed to completed fetch from database", exc_info=exc)
        if not row:
            return None
        if self.fetch_result is None:
            return None
        return self.fetch_result(*row)

    async def _fetch_all(
        self, values: MutableMapping
    ) -> List[Union[LavalinkCacheFetchResult, SpotifyCacheFetchResult, YouTubeCacheFetchResult]]:
        """Get all entries from the local cache.

        Runs ``self.statement.get_all`` and converts every returned row with
        ``self.fetch_result``.  Returns an empty list when no factory is
        configured or the query fails.
        """
        output = []
        # Stays an empty list when the query fails, so the loop below is a no-op.
        row_result = []
        if self.fetch_result is None:
            return []
        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
            for future in concurrent.futures.as_completed(
                [executor.submit(self.database.cursor().execute, self.statement.get_all, values)]
            ):
                try:
                    row_result = future.result()
                except Exception as exc:
                    log.verbose("Failed to completed fetch from database", exc_info=exc)
        async for row in AsyncIter(row_result):
            output.append(self.fetch_result(*row))
        return output

    async def _fetch_random(
        self, values: MutableMapping
    ) -> Optional[
        Union[LavalinkCacheFetchResult, SpotifyCacheFetchResult, YouTubeCacheFetchResult]
    ]:
        """Get a random entry from the local cache.

        Runs ``self.statement.get_random``, fetches all resulting rows and
        picks one of them with ``random.choice``.  Returns ``None`` when the
        query fails, yields no rows, or no ``fetch_result`` factory is set.
        """
        row = None
        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
            for future in concurrent.futures.as_completed(
                [
                    executor.submit(
                        self.database.cursor().execute, self.statement.get_random, values
                    )
                ]
            ):
                try:
                    row_result = future.result()
                    rows = row_result.fetchall()
                    row = random.choice(rows) if rows else None
                except Exception as exc:
                    log.verbose("Failed to completed random fetch from database", exc_info=exc)
        if not row:
            return None
        if self.fetch_result is None:
            return None
        return self.fetch_result(*row)


class YouTubeTableWrapper(BaseWrapper):
    """Cache wrapper bound to the YouTube SQL statements."""

    def __init__(
        self, bot: Red, config: Config, conn: APSWConnectionWrapper, cog: Union["Audio", Cog]
    ):
        super().__init__(bot, config, conn, cog)
        self.statement.upsert = YOUTUBE_UPSERT
        self.statement.update = YOUTUBE_UPDATE
        self.statement.get_one = YOUTUBE_QUERY
        self.statement.get_all = YOUTUBE_QUERY_ALL
        self.statement.get_random = YOUTUBE_QUERY_LAST_FETCHED_RANDOM
        self.fetch_result = YouTubeCacheFetchResult

    async def fetch_one(
        self, values: MutableMapping
    ) -> Tuple[Optional[str], Optional[datetime.datetime]]:
        """Get an entry from the Youtube table.

        Returns ``(query, updated_on)``, or ``(None, None)`` when nothing was
        found or the stored query is not a string.
        """
        result = await self._fetch_one(values)
        if not result or not isinstance(result.query, str):
            return None, None
        return result.query, result.updated_on

    async def fetch_all(self, values: MutableMapping) -> List[YouTubeCacheFetchResult]:
        """Get all entries from the Youtube table (empty list when there are none)."""
        result = await self._fetch_all(values)
        if result and isinstance(result[0], YouTubeCacheFetchResult):
            return result
        return []

    async def fetch_random(self, values: MutableMapping) -> Optional[str]:
        """Get the query of a random entry from the Youtube table, or ``None``."""
        result = await self._fetch_random(values)
        if not result or not isinstance(result.query, str):
            return None
        return result.query


class SpotifyTableWrapper(BaseWrapper):
    """Cache wrapper bound to the Spotify SQL statements."""

    def __init__(
        self, bot: Red, config: Config, conn: APSWConnectionWrapper, cog: Union["Audio", Cog]
    ):
        super().__init__(bot, config, conn, cog)
        self.statement.upsert = SPOTIFY_UPSERT
        self.statement.update = SPOTIFY_UPDATE
        self.statement.get_one = SPOTIFY_QUERY
        self.statement.get_all = SPOTIFY_QUERY_ALL
        self.statement.get_random = SPOTIFY_QUERY_LAST_FETCHED_RANDOM
        self.fetch_result = SpotifyCacheFetchResult

    async def fetch_one(
        self, values: MutableMapping
    ) -> Tuple[Optional[str], Optional[datetime.datetime]]:
        """Get an entry from the Spotify table.

        Returns ``(query, updated_on)``, or ``(None, None)`` when nothing was
        found or the stored query is not a string.
        """
        result = await self._fetch_one(values)
        if not result or not isinstance(result.query, str):
            return None, None
        return result.query, result.updated_on

    async def fetch_all(self, values: MutableMapping) -> List[SpotifyCacheFetchResult]:
        """Get all entries from the Spotify table (empty list when there are none)."""
        result = await self._fetch_all(values)
        if result and isinstance(result[0], SpotifyCacheFetchResult):
            return result
        return []

    async def fetch_random(self, values: MutableMapping) -> Optional[str]:
        """Get the query of a random entry from the Spotify table, or ``None``."""
        result = await self._fetch_random(values)
        if not result or not isinstance(result.query, str):
            return None
        return result.query


class LavalinkTableWrapper(BaseWrapper):
    """Cache wrapper bound to the Lavalink SQL statements."""

    def __init__(
        self, bot: Red, config: Config, conn: APSWConnectionWrapper, cog: Union["Audio", Cog]
    ):
        super().__init__(bot, config, conn, cog)
        self.statement.upsert = LAVALINK_UPSERT
        self.statement.update = LAVALINK_UPDATE
        self.statement.get_one = LAVALINK_QUERY
        self.statement.get_all = LAVALINK_QUERY_ALL
        self.statement.get_random = LAVALINK_QUERY_LAST_FETCHED_RANDOM
        self.statement.get_all_global = LAVALINK_FETCH_ALL_ENTRIES_GLOBAL
        self.fetch_result = LavalinkCacheFetchResult
        # Row -> result object factory used by ``fetch_all_for_global``.
        self.fetch_for_global: Optional[Callable] = LavalinkCacheFetchForGlobalResult

    async def fetch_one(
        self, values: MutableMapping
    ) -> Tuple[Optional[MutableMapping], Optional[datetime.datetime]]:
        """Get an entry from the Lavalink table.

        Returns ``(query, updated_on)``, or ``(None, None)`` when nothing was
        found or the stored query is not a dict.
        """
        result = await self._fetch_one(values)
        if not result or not isinstance(result.query, dict):
            return None, None
        return result.query, result.updated_on

    async def fetch_all(self, values: MutableMapping) -> List[LavalinkCacheFetchResult]:
        """Get all entries from the Lavalink table (empty list when there are none)."""
        result = await self._fetch_all(values)
        if result and isinstance(result[0], LavalinkCacheFetchResult):
            return result
        return []

    async def fetch_random(self, values: MutableMapping) -> Optional[MutableMapping]:
        """Get the query of a random entry from the Lavalink table, or ``None``."""
        result = await self._fetch_random(values)
        if not result or not isinstance(result.query, dict):
            return None
        return result.query

    async def fetch_all_for_global(self) -> List[LavalinkCacheFetchForGlobalResult]:
        """Get all Lavalink entries using the ``get_all_global`` statement.

        Every returned row is converted with ``self.fetch_for_global``.
        Returns an empty list when that factory is unset or the query fails.
        """
        output: List[LavalinkCacheFetchForGlobalResult] = []
        # Stays an empty list when the query fails, so the loop below is a no-op.
        row_result = []
        if self.fetch_for_global is None:
            return []
        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
            for future in concurrent.futures.as_completed(
                [executor.submit(self.database.cursor().execute, self.statement.get_all_global)]
            ):
                try:
                    row_result = future.result()
                except Exception as exc:
                    log.verbose("Failed to completed fetch from database", exc_info=exc)
        async for row in AsyncIter(row_result):
            output.append(self.fetch_for_global(*row))
        return output


class LocalCacheWrapper:
    """Wraps all table apis into 1 object representing the local cache"""

    def __init__(
        self, bot: Red, config: Config, conn: APSWConnectionWrapper, cog: Union["Audio", Cog]
    ):
        self.bot = bot
        self.config = config
        self.database = conn
        self.cog = cog
        # All three table wrappers share the same connection object.
        self.lavalink: LavalinkTableWrapper = LavalinkTableWrapper(bot, config, conn, self.cog)
        self.spotify: SpotifyTableWrapper = SpotifyTableWrapper(bot, config, conn, self.cog)
        self.youtube: YouTubeTableWrapper = YouTubeTableWrapper(bot, config, conn, self.cog)