mirror of
https://github.com/Cog-Creators/Red-DiscordBot.git
synced 2025-11-20 18:06:08 -05:00
Merge V3/feature/audio into V3/develop (a.k.a. audio refactor) (#3459)
This commit is contained in:
10
redbot/cogs/audio/apis/__init__.py
Normal file
10
redbot/cogs/audio/apis/__init__.py
Normal file
@@ -0,0 +1,10 @@
|
||||
from . import (
|
||||
api_utils,
|
||||
global_db,
|
||||
interface,
|
||||
local_db,
|
||||
playlist_interface,
|
||||
playlist_wrapper,
|
||||
spotify,
|
||||
youtube,
|
||||
)
|
||||
140
redbot/cogs/audio/apis/api_utils.py
Normal file
140
redbot/cogs/audio/apis/api_utils.py
Normal file
@@ -0,0 +1,140 @@
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
from collections import namedtuple
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, MutableMapping, Optional, Union
|
||||
|
||||
import discord
|
||||
|
||||
from redbot.core.bot import Red
|
||||
from redbot.core.utils.chat_formatting import humanize_list
|
||||
|
||||
from ..errors import InvalidPlaylistScope, MissingAuthor, MissingGuild
|
||||
from ..utils import PlaylistScope
|
||||
|
||||
log = logging.getLogger("red.cogs.Audio.api.utils")
|
||||
|
||||
|
||||
@dataclass
|
||||
class YouTubeCacheFetchResult:
|
||||
query: Optional[str]
|
||||
last_updated: int
|
||||
|
||||
def __post_init__(self):
|
||||
if isinstance(self.last_updated, int):
|
||||
self.updated_on: datetime.datetime = datetime.datetime.fromtimestamp(self.last_updated)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SpotifyCacheFetchResult:
|
||||
query: Optional[str]
|
||||
last_updated: int
|
||||
|
||||
def __post_init__(self):
|
||||
if isinstance(self.last_updated, int):
|
||||
self.updated_on: datetime.datetime = datetime.datetime.fromtimestamp(self.last_updated)
|
||||
|
||||
|
||||
@dataclass
|
||||
class LavalinkCacheFetchResult:
|
||||
query: Optional[MutableMapping]
|
||||
last_updated: int
|
||||
|
||||
def __post_init__(self):
|
||||
if isinstance(self.last_updated, int):
|
||||
self.updated_on: datetime.datetime = datetime.datetime.fromtimestamp(self.last_updated)
|
||||
|
||||
if isinstance(self.query, str):
|
||||
self.query = json.loads(self.query)
|
||||
|
||||
|
||||
@dataclass
|
||||
class LavalinkCacheFetchForGlobalResult:
|
||||
query: str
|
||||
data: MutableMapping
|
||||
|
||||
def __post_init__(self):
|
||||
if isinstance(self.data, str):
|
||||
self.data_string = str(self.data)
|
||||
self.data = json.loads(self.data)
|
||||
|
||||
|
||||
@dataclass
|
||||
class PlaylistFetchResult:
|
||||
playlist_id: int
|
||||
playlist_name: str
|
||||
scope_id: int
|
||||
author_id: int
|
||||
playlist_url: Optional[str] = None
|
||||
tracks: List[MutableMapping] = field(default_factory=lambda: [])
|
||||
|
||||
def __post_init__(self):
|
||||
if isinstance(self.tracks, str):
|
||||
self.tracks = json.loads(self.tracks)
|
||||
|
||||
|
||||
def standardize_scope(scope: str) -> str:
|
||||
"""Convert any of the used scopes into one we are expecting"""
|
||||
scope = scope.upper()
|
||||
valid_scopes = ["GLOBAL", "GUILD", "AUTHOR", "USER", "SERVER", "MEMBER", "BOT"]
|
||||
|
||||
if scope in PlaylistScope.list():
|
||||
return scope
|
||||
elif scope not in valid_scopes:
|
||||
raise InvalidPlaylistScope(
|
||||
f'"{scope}" is not a valid playlist scope.'
|
||||
f" Scope needs to be one of the following: {humanize_list(valid_scopes)}"
|
||||
)
|
||||
|
||||
if scope in ["GLOBAL", "BOT"]:
|
||||
scope = PlaylistScope.GLOBAL.value
|
||||
elif scope in ["GUILD", "SERVER"]:
|
||||
scope = PlaylistScope.GUILD.value
|
||||
elif scope in ["USER", "MEMBER", "AUTHOR"]:
|
||||
scope = PlaylistScope.USER.value
|
||||
|
||||
return scope
|
||||
|
||||
|
||||
def prepare_config_scope(
|
||||
bot: Red,
|
||||
scope,
|
||||
author: Union[discord.abc.User, int] = None,
|
||||
guild: Union[discord.Guild, int] = None,
|
||||
):
|
||||
"""Return the scope used by Playlists"""
|
||||
scope = standardize_scope(scope)
|
||||
if scope == PlaylistScope.GLOBAL.value:
|
||||
config_scope = [PlaylistScope.GLOBAL.value, bot.user.id]
|
||||
elif scope == PlaylistScope.USER.value:
|
||||
if author is None:
|
||||
raise MissingAuthor("Invalid author for user scope.")
|
||||
config_scope = [PlaylistScope.USER.value, int(getattr(author, "id", author))]
|
||||
else:
|
||||
if guild is None:
|
||||
raise MissingGuild("Invalid guild for guild scope.")
|
||||
config_scope = [PlaylistScope.GUILD.value, int(getattr(guild, "id", guild))]
|
||||
return config_scope
|
||||
|
||||
|
||||
def prepare_config_scope_for_migration23( # TODO: remove me in a future version ?
|
||||
scope, author: Union[discord.abc.User, int] = None, guild: discord.Guild = None
|
||||
):
|
||||
"""Return the scope used by Playlists"""
|
||||
scope = standardize_scope(scope)
|
||||
|
||||
if scope == PlaylistScope.GLOBAL.value:
|
||||
config_scope = [PlaylistScope.GLOBAL.value]
|
||||
elif scope == PlaylistScope.USER.value:
|
||||
if author is None:
|
||||
raise MissingAuthor("Invalid author for user scope.")
|
||||
config_scope = [PlaylistScope.USER.value, str(getattr(author, "id", author))]
|
||||
else:
|
||||
if guild is None:
|
||||
raise MissingGuild("Invalid guild for guild scope.")
|
||||
config_scope = [PlaylistScope.GUILD.value, str(getattr(guild, "id", guild))]
|
||||
return config_scope
|
||||
|
||||
|
||||
FakePlaylist = namedtuple("Playlist", "author scope")
|
||||
42
redbot/cogs/audio/apis/global_db.py
Normal file
42
redbot/cogs/audio/apis/global_db.py
Normal file
@@ -0,0 +1,42 @@
|
||||
import asyncio
|
||||
import contextlib
|
||||
import logging
|
||||
import urllib.parse
|
||||
from typing import Mapping, Optional, TYPE_CHECKING, Union
|
||||
|
||||
import aiohttp
|
||||
from lavalink.rest_api import LoadResult
|
||||
|
||||
from redbot.core import Config
|
||||
from redbot.core.bot import Red
|
||||
from redbot.core.commands import Cog
|
||||
|
||||
from ..audio_dataclasses import Query
|
||||
from ..audio_logging import IS_DEBUG, debug_exc_log
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .. import Audio
|
||||
|
||||
_API_URL = "https://redbot.app/"
|
||||
|
||||
log = logging.getLogger("red.cogs.Audio.api.GlobalDB")
|
||||
|
||||
|
||||
class GlobalCacheWrapper:
|
||||
def __init__(
|
||||
self, bot: Red, config: Config, session: aiohttp.ClientSession, cog: Union["Audio", Cog]
|
||||
):
|
||||
# Place Holder for the Global Cache PR
|
||||
self.bot = bot
|
||||
self.config = config
|
||||
self.session = session
|
||||
self.api_key = None
|
||||
self._handshake_token = ""
|
||||
self.can_write = False
|
||||
self._handshake_token = ""
|
||||
self.has_api_key = None
|
||||
self._token: Mapping[str, str] = {}
|
||||
self.cog = cog
|
||||
|
||||
def update_token(self, new_token: Mapping[str, str]):
|
||||
self._token = new_token
|
||||
894
redbot/cogs/audio/apis/interface.py
Normal file
894
redbot/cogs/audio/apis/interface.py
Normal file
@@ -0,0 +1,894 @@
|
||||
import asyncio
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import random
|
||||
import time
|
||||
from collections import namedtuple
|
||||
from typing import Callable, List, MutableMapping, Optional, TYPE_CHECKING, Tuple, Union, cast
|
||||
|
||||
import aiohttp
|
||||
import discord
|
||||
import lavalink
|
||||
from lavalink.rest_api import LoadResult
|
||||
from redbot.core.utils import AsyncIter
|
||||
|
||||
from redbot.core import Config, commands
|
||||
from redbot.core.bot import Red
|
||||
from redbot.core.commands import Cog, Context
|
||||
from redbot.core.i18n import Translator
|
||||
from redbot.core.utils.dbtools import APSWConnectionWrapper
|
||||
|
||||
from ..audio_dataclasses import Query
|
||||
from ..audio_logging import IS_DEBUG, debug_exc_log
|
||||
from ..errors import DatabaseError, SpotifyFetchError, TrackEnqueueError
|
||||
from ..utils import CacheLevel, Notifier
|
||||
from .global_db import GlobalCacheWrapper
|
||||
from .local_db import LocalCacheWrapper
|
||||
from .playlist_interface import get_playlist
|
||||
from .playlist_wrapper import PlaylistWrapper
|
||||
from .spotify import SpotifyWrapper
|
||||
from .youtube import YouTubeWrapper
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .. import Audio
|
||||
|
||||
_ = Translator("Audio", __file__)
|
||||
log = logging.getLogger("red.cogs.Audio.api.AudioAPIInterface")
|
||||
_TOP_100_US = "https://www.youtube.com/playlist?list=PL4fGSI1pDJn5rWitrRWFKdm-ulaFiIyoK"
|
||||
|
||||
|
||||
class AudioAPIInterface:
|
||||
"""Handles music queries.
|
||||
|
||||
Always tries the Local cache first, then Global cache before making API calls.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
bot: Red,
|
||||
config: Config,
|
||||
session: aiohttp.ClientSession,
|
||||
conn: APSWConnectionWrapper,
|
||||
cog: Union["Audio", Cog],
|
||||
):
|
||||
self.bot = bot
|
||||
self.config = config
|
||||
self.conn = conn
|
||||
self.cog = cog
|
||||
self.spotify_api: SpotifyWrapper = SpotifyWrapper(self.bot, self.config, session, self.cog)
|
||||
self.youtube_api: YouTubeWrapper = YouTubeWrapper(self.bot, self.config, session, self.cog)
|
||||
self.local_cache_api = LocalCacheWrapper(self.bot, self.config, self.conn, self.cog)
|
||||
self.global_cache_api = GlobalCacheWrapper(self.bot, self.config, session, self.cog)
|
||||
self._session: aiohttp.ClientSession = session
|
||||
self._tasks: MutableMapping = {}
|
||||
self._lock: asyncio.Lock = asyncio.Lock()
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""Initialises the Local Cache connection"""
|
||||
await self.local_cache_api.lavalink.init()
|
||||
|
||||
def close(self) -> None:
|
||||
"""Closes the Local Cache connection"""
|
||||
self.local_cache_api.lavalink.close()
|
||||
|
||||
async def get_random_track_from_db(self) -> Optional[MutableMapping]:
|
||||
"""Get a random track from the local database and return it"""
|
||||
track: Optional[MutableMapping] = {}
|
||||
try:
|
||||
query_data = {}
|
||||
date = datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(days=7)
|
||||
date_timestamp = int(date.timestamp())
|
||||
query_data["day"] = date_timestamp
|
||||
max_age = await self.config.cache_age()
|
||||
maxage = datetime.datetime.now(tz=datetime.timezone.utc) - datetime.timedelta(
|
||||
days=max_age
|
||||
)
|
||||
maxage_int = int(time.mktime(maxage.timetuple()))
|
||||
query_data["maxage"] = maxage_int
|
||||
track = await self.local_cache_api.lavalink.fetch_random(query_data)
|
||||
if track is not None:
|
||||
if track.get("loadType") == "V2_COMPACT":
|
||||
track["loadType"] = "V2_COMPAT"
|
||||
results = LoadResult(track)
|
||||
track = random.choice(list(results.tracks))
|
||||
except Exception as exc:
|
||||
debug_exc_log(log, exc, "Failed to fetch a random track from database")
|
||||
track = {}
|
||||
|
||||
if not track:
|
||||
return None
|
||||
|
||||
return track
|
||||
|
||||
async def route_tasks(
|
||||
self, action_type: str = None, data: Union[List[MutableMapping], MutableMapping] = None,
|
||||
) -> None:
|
||||
"""Separate the tasks and run them in the appropriate functions"""
|
||||
|
||||
if not data:
|
||||
return
|
||||
if action_type == "insert" and isinstance(data, list):
|
||||
for table, d in data:
|
||||
if table == "lavalink":
|
||||
await self.local_cache_api.lavalink.insert(d)
|
||||
elif table == "youtube":
|
||||
await self.local_cache_api.youtube.insert(d)
|
||||
elif table == "spotify":
|
||||
await self.local_cache_api.spotify.insert(d)
|
||||
elif action_type == "update" and isinstance(data, dict):
|
||||
for table, d in data:
|
||||
if table == "lavalink":
|
||||
await self.local_cache_api.lavalink.update(data)
|
||||
elif table == "youtube":
|
||||
await self.local_cache_api.youtube.update(data)
|
||||
elif table == "spotify":
|
||||
await self.local_cache_api.spotify.update(data)
|
||||
|
||||
async def run_tasks(self, ctx: Optional[commands.Context] = None, message_id=None) -> None:
|
||||
"""Run tasks for a specific context"""
|
||||
if message_id is not None:
|
||||
lock_id = message_id
|
||||
elif ctx is not None:
|
||||
lock_id = ctx.message.id
|
||||
else:
|
||||
return
|
||||
lock_author = ctx.author if ctx else None
|
||||
async with self._lock:
|
||||
if lock_id in self._tasks:
|
||||
if IS_DEBUG:
|
||||
log.debug(f"Running database writes for {lock_id} ({lock_author})")
|
||||
try:
|
||||
tasks = self._tasks[lock_id]
|
||||
tasks = [self.route_tasks(a, tasks[a]) for a in tasks]
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
del self._tasks[lock_id]
|
||||
except Exception as exc:
|
||||
debug_exc_log(
|
||||
log, exc, f"Failed database writes for {lock_id} ({lock_author})"
|
||||
)
|
||||
else:
|
||||
if IS_DEBUG:
|
||||
log.debug(f"Completed database writes for {lock_id} ({lock_author})")
|
||||
|
||||
async def run_all_pending_tasks(self) -> None:
|
||||
"""Run all pending tasks left in the cache, called on cog_unload"""
|
||||
async with self._lock:
|
||||
if IS_DEBUG:
|
||||
log.debug("Running pending writes to database")
|
||||
try:
|
||||
tasks: MutableMapping = {"update": [], "insert": [], "global": []}
|
||||
async for k, task in AsyncIter(self._tasks.items()):
|
||||
async for t, args in AsyncIter(task.items()):
|
||||
tasks[t].append(args)
|
||||
self._tasks = {}
|
||||
coro_tasks = [self.route_tasks(a, tasks[a]) for a in tasks]
|
||||
|
||||
await asyncio.gather(*coro_tasks, return_exceptions=True)
|
||||
|
||||
except Exception as exc:
|
||||
debug_exc_log(log, exc, "Failed database writes")
|
||||
else:
|
||||
if IS_DEBUG:
|
||||
log.debug("Completed pending writes to database have finished")
|
||||
|
||||
def append_task(self, ctx: commands.Context, event: str, task: Tuple, _id: int = None) -> None:
|
||||
"""Add a task to the cache to be run later"""
|
||||
lock_id = _id or ctx.message.id
|
||||
if lock_id not in self._tasks:
|
||||
self._tasks[lock_id] = {"update": [], "insert": [], "global": []}
|
||||
self._tasks[lock_id][event].append(task)
|
||||
|
||||
async def fetch_spotify_query(
|
||||
self,
|
||||
ctx: commands.Context,
|
||||
query_type: str,
|
||||
uri: str,
|
||||
notifier: Optional[Notifier],
|
||||
skip_youtube: bool = False,
|
||||
current_cache_level: CacheLevel = CacheLevel.none(),
|
||||
) -> List[str]:
|
||||
"""Return youtube URLS for the spotify URL provided"""
|
||||
youtube_urls = []
|
||||
tracks = await self.fetch_from_spotify_api(
|
||||
query_type, uri, params=None, notifier=notifier, ctx=ctx
|
||||
)
|
||||
total_tracks = len(tracks)
|
||||
database_entries = []
|
||||
track_count = 0
|
||||
time_now = int(datetime.datetime.now(datetime.timezone.utc).timestamp())
|
||||
youtube_cache = CacheLevel.set_youtube().is_subset(current_cache_level)
|
||||
async for track in AsyncIter(tracks):
|
||||
if isinstance(track, str):
|
||||
break
|
||||
elif isinstance(track, dict) and track.get("error", {}).get("message") == "invalid id":
|
||||
continue
|
||||
(
|
||||
song_url,
|
||||
track_info,
|
||||
uri,
|
||||
artist_name,
|
||||
track_name,
|
||||
_id,
|
||||
_type,
|
||||
) = await self.spotify_api.get_spotify_track_info(track, ctx)
|
||||
|
||||
database_entries.append(
|
||||
{
|
||||
"id": _id,
|
||||
"type": _type,
|
||||
"uri": uri,
|
||||
"track_name": track_name,
|
||||
"artist_name": artist_name,
|
||||
"song_url": song_url,
|
||||
"track_info": track_info,
|
||||
"last_updated": time_now,
|
||||
"last_fetched": time_now,
|
||||
}
|
||||
)
|
||||
if skip_youtube is False:
|
||||
val = None
|
||||
if youtube_cache:
|
||||
try:
|
||||
(val, last_update) = await self.local_cache_api.youtube.fetch_one(
|
||||
{"track": track_info}
|
||||
)
|
||||
except Exception as exc:
|
||||
debug_exc_log(log, exc, f"Failed to fetch {track_info} from YouTube table")
|
||||
|
||||
if val is None:
|
||||
val = await self.fetch_youtube_query(
|
||||
ctx, track_info, current_cache_level=current_cache_level
|
||||
)
|
||||
if youtube_cache and val:
|
||||
task = ("update", ("youtube", {"track": track_info}))
|
||||
self.append_task(ctx, *task)
|
||||
if val:
|
||||
youtube_urls.append(val)
|
||||
else:
|
||||
youtube_urls.append(track_info)
|
||||
track_count += 1
|
||||
if notifier is not None and ((track_count % 2 == 0) or (track_count == total_tracks)):
|
||||
await notifier.notify_user(current=track_count, total=total_tracks, key="youtube")
|
||||
if CacheLevel.set_spotify().is_subset(current_cache_level):
|
||||
task = ("insert", ("spotify", database_entries))
|
||||
self.append_task(ctx, *task)
|
||||
return youtube_urls
|
||||
|
||||
async def fetch_from_spotify_api(
|
||||
self,
|
||||
query_type: str,
|
||||
uri: str,
|
||||
recursive: Union[str, bool] = False,
|
||||
params: MutableMapping = None,
|
||||
notifier: Optional[Notifier] = None,
|
||||
ctx: Context = None,
|
||||
) -> Union[List[MutableMapping], List[str]]:
|
||||
"""Gets track info from spotify API"""
|
||||
|
||||
if recursive is False:
|
||||
(call, params) = self.spotify_api.spotify_format_call(query_type, uri)
|
||||
results = await self.spotify_api.make_get_call(call, params)
|
||||
else:
|
||||
if isinstance(recursive, str):
|
||||
results = await self.spotify_api.make_get_call(recursive, params)
|
||||
else:
|
||||
results = {}
|
||||
try:
|
||||
if results["error"]["status"] == 401 and not recursive:
|
||||
raise SpotifyFetchError(
|
||||
_(
|
||||
"The Spotify API key or client secret has not been set properly. "
|
||||
"\nUse `{prefix}audioset spotifyapi` for instructions."
|
||||
)
|
||||
)
|
||||
elif recursive:
|
||||
return {"next": None}
|
||||
except KeyError:
|
||||
pass
|
||||
if recursive:
|
||||
return results
|
||||
tracks = []
|
||||
track_count = 0
|
||||
total_tracks = results.get("tracks", results).get("total", 1)
|
||||
while True:
|
||||
new_tracks: List = []
|
||||
if query_type == "track":
|
||||
new_tracks = results
|
||||
tracks.append(new_tracks)
|
||||
elif query_type == "album":
|
||||
tracks_raw = results.get("tracks", results).get("items", [])
|
||||
if tracks_raw:
|
||||
new_tracks = tracks_raw
|
||||
tracks.extend(new_tracks)
|
||||
else:
|
||||
tracks_raw = results.get("tracks", results).get("items", [])
|
||||
if tracks_raw:
|
||||
new_tracks = [k["track"] for k in tracks_raw if k.get("track")]
|
||||
tracks.extend(new_tracks)
|
||||
track_count += len(new_tracks)
|
||||
if notifier:
|
||||
await notifier.notify_user(current=track_count, total=total_tracks, key="spotify")
|
||||
try:
|
||||
if results.get("next") is not None:
|
||||
results = await self.fetch_from_spotify_api(
|
||||
query_type, uri, results["next"], params, notifier=notifier
|
||||
)
|
||||
continue
|
||||
else:
|
||||
break
|
||||
except KeyError:
|
||||
raise SpotifyFetchError(
|
||||
_("This doesn't seem to be a valid Spotify playlist/album URL or code.")
|
||||
)
|
||||
return tracks
|
||||
|
||||
async def spotify_query(
|
||||
self,
|
||||
ctx: commands.Context,
|
||||
query_type: str,
|
||||
uri: str,
|
||||
skip_youtube: bool = False,
|
||||
notifier: Optional[Notifier] = None,
|
||||
) -> List[str]:
|
||||
"""Queries the Database then falls back to Spotify and YouTube APIs.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
ctx: commands.Context
|
||||
The context this method is being called under.
|
||||
query_type : str
|
||||
Type of query to perform (Pl
|
||||
uri: str
|
||||
Spotify URL ID.
|
||||
skip_youtube:bool
|
||||
Whether or not to skip YouTube API Calls.
|
||||
notifier: Notifier
|
||||
A Notifier object to handle the user UI notifications while tracks are loaded.
|
||||
Returns
|
||||
-------
|
||||
List[str]
|
||||
List of Youtube URLs.
|
||||
"""
|
||||
current_cache_level = CacheLevel(await self.config.cache_level())
|
||||
cache_enabled = CacheLevel.set_spotify().is_subset(current_cache_level)
|
||||
if query_type == "track" and cache_enabled:
|
||||
try:
|
||||
(val, last_update) = await self.local_cache_api.spotify.fetch_one(
|
||||
{"uri": f"spotify:track:{uri}"}
|
||||
)
|
||||
except Exception as exc:
|
||||
debug_exc_log(
|
||||
log, exc, f"Failed to fetch 'spotify:track:{uri}' from Spotify table"
|
||||
)
|
||||
val = None
|
||||
else:
|
||||
val = None
|
||||
youtube_urls = []
|
||||
if val is None:
|
||||
urls = await self.fetch_spotify_query(
|
||||
ctx,
|
||||
query_type,
|
||||
uri,
|
||||
notifier,
|
||||
skip_youtube,
|
||||
current_cache_level=current_cache_level,
|
||||
)
|
||||
youtube_urls.extend(urls)
|
||||
else:
|
||||
if query_type == "track" and cache_enabled:
|
||||
task = ("update", ("spotify", {"uri": f"spotify:track:{uri}"}))
|
||||
self.append_task(ctx, *task)
|
||||
youtube_urls.append(val)
|
||||
return youtube_urls
|
||||
|
||||
async def spotify_enqueue(
|
||||
self,
|
||||
ctx: commands.Context,
|
||||
query_type: str,
|
||||
uri: str,
|
||||
enqueue: bool,
|
||||
player: lavalink.Player,
|
||||
lock: Callable,
|
||||
notifier: Optional[Notifier] = None,
|
||||
forced: bool = False,
|
||||
query_global: bool = False,
|
||||
) -> List[lavalink.Track]:
|
||||
"""Queries the Database then falls back to Spotify and YouTube APIs then Enqueued matched tracks.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
ctx: commands.Context
|
||||
The context this method is being called under.
|
||||
query_type : str
|
||||
Type of query to perform (Pl
|
||||
uri: str
|
||||
Spotify URL ID.
|
||||
enqueue:bool
|
||||
Whether or not to enqueue the tracks
|
||||
player: lavalink.Player
|
||||
The current Player.
|
||||
notifier: Notifier
|
||||
A Notifier object to handle the user UI notifications while tracks are loaded.
|
||||
lock: Callable
|
||||
A callable handling the Track enqueue lock while spotify tracks are being added.
|
||||
query_global: bool
|
||||
Whether or not to query the global API.
|
||||
forced: bool
|
||||
Ignore Cache and make a fetch from API.
|
||||
Returns
|
||||
-------
|
||||
List[str]
|
||||
List of Youtube URLs.
|
||||
"""
|
||||
# globaldb_toggle = await self.config.global_db_enabled()
|
||||
track_list: List = []
|
||||
has_not_allowed = False
|
||||
try:
|
||||
current_cache_level = CacheLevel(await self.config.cache_level())
|
||||
guild_data = await self.config.guild(ctx.guild).all()
|
||||
enqueued_tracks = 0
|
||||
consecutive_fails = 0
|
||||
queue_dur = await self.cog.queue_duration(ctx)
|
||||
queue_total_duration = self.cog.format_time(queue_dur)
|
||||
before_queue_length = len(player.queue)
|
||||
tracks_from_spotify = await self.fetch_from_spotify_api(
|
||||
query_type, uri, params=None, notifier=notifier
|
||||
)
|
||||
total_tracks = len(tracks_from_spotify)
|
||||
if total_tracks < 1 and notifier is not None:
|
||||
lock(ctx, False)
|
||||
embed3 = discord.Embed(
|
||||
colour=await ctx.embed_colour(),
|
||||
title=_("This doesn't seem to be a supported Spotify URL or code."),
|
||||
)
|
||||
await notifier.update_embed(embed3)
|
||||
|
||||
return track_list
|
||||
database_entries = []
|
||||
time_now = int(datetime.datetime.now(datetime.timezone.utc).timestamp())
|
||||
|
||||
youtube_cache = CacheLevel.set_youtube().is_subset(current_cache_level)
|
||||
spotify_cache = CacheLevel.set_spotify().is_subset(current_cache_level)
|
||||
async for track_count, track in AsyncIter(tracks_from_spotify).enumerate(start=1):
|
||||
(
|
||||
song_url,
|
||||
track_info,
|
||||
uri,
|
||||
artist_name,
|
||||
track_name,
|
||||
_id,
|
||||
_type,
|
||||
) = await self.spotify_api.get_spotify_track_info(track, ctx)
|
||||
|
||||
database_entries.append(
|
||||
{
|
||||
"id": _id,
|
||||
"type": _type,
|
||||
"uri": uri,
|
||||
"track_name": track_name,
|
||||
"artist_name": artist_name,
|
||||
"song_url": song_url,
|
||||
"track_info": track_info,
|
||||
"last_updated": time_now,
|
||||
"last_fetched": time_now,
|
||||
}
|
||||
)
|
||||
val = None
|
||||
llresponse = None
|
||||
if youtube_cache:
|
||||
try:
|
||||
(val, last_updated) = await self.local_cache_api.youtube.fetch_one(
|
||||
{"track": track_info}
|
||||
)
|
||||
except Exception as exc:
|
||||
debug_exc_log(log, exc, f"Failed to fetch {track_info} from YouTube table")
|
||||
|
||||
if val is None:
|
||||
val = await self.fetch_youtube_query(
|
||||
ctx, track_info, current_cache_level=current_cache_level
|
||||
)
|
||||
if youtube_cache and val and llresponse is None:
|
||||
task = ("update", ("youtube", {"track": track_info}))
|
||||
self.append_task(ctx, *task)
|
||||
|
||||
if llresponse is not None:
|
||||
track_object = llresponse.tracks
|
||||
elif val:
|
||||
try:
|
||||
(result, called_api) = await self.fetch_track(
|
||||
ctx,
|
||||
player,
|
||||
Query.process_input(val, self.cog.local_folder_current_path),
|
||||
forced=forced,
|
||||
)
|
||||
except (RuntimeError, aiohttp.ServerDisconnectedError):
|
||||
lock(ctx, False)
|
||||
error_embed = discord.Embed(
|
||||
colour=await ctx.embed_colour(),
|
||||
title=_("The connection was reset while loading the playlist."),
|
||||
)
|
||||
if notifier is not None:
|
||||
await notifier.update_embed(error_embed)
|
||||
break
|
||||
except asyncio.TimeoutError:
|
||||
lock(ctx, False)
|
||||
error_embed = discord.Embed(
|
||||
colour=await ctx.embed_colour(),
|
||||
title=_("Player timeout, skipping remaining tracks."),
|
||||
)
|
||||
if notifier is not None:
|
||||
await notifier.update_embed(error_embed)
|
||||
break
|
||||
track_object = result.tracks
|
||||
else:
|
||||
track_object = []
|
||||
if (track_count % 2 == 0) or (track_count == total_tracks):
|
||||
key = "lavalink"
|
||||
seconds = "???"
|
||||
second_key = None
|
||||
if notifier is not None:
|
||||
await notifier.notify_user(
|
||||
current=track_count,
|
||||
total=total_tracks,
|
||||
key=key,
|
||||
seconds_key=second_key,
|
||||
seconds=seconds,
|
||||
)
|
||||
|
||||
if consecutive_fails >= 10:
|
||||
error_embed = discord.Embed(
|
||||
colour=await ctx.embed_colour(),
|
||||
title=_("Failing to get tracks, skipping remaining."),
|
||||
)
|
||||
if notifier is not None:
|
||||
await notifier.update_embed(error_embed)
|
||||
break
|
||||
if not track_object:
|
||||
consecutive_fails += 1
|
||||
continue
|
||||
consecutive_fails = 0
|
||||
single_track = track_object[0]
|
||||
if not await self.cog.is_query_allowed(
|
||||
self.config,
|
||||
ctx.guild,
|
||||
(
|
||||
f"{single_track.title} {single_track.author} {single_track.uri} "
|
||||
f"{Query.process_input(single_track, self.cog.local_folder_current_path)}"
|
||||
),
|
||||
):
|
||||
has_not_allowed = True
|
||||
if IS_DEBUG:
|
||||
log.debug(f"Query is not allowed in {ctx.guild} ({ctx.guild.id})")
|
||||
continue
|
||||
track_list.append(single_track)
|
||||
if enqueue:
|
||||
if len(player.queue) >= 10000:
|
||||
continue
|
||||
if guild_data["maxlength"] > 0:
|
||||
if self.cog.is_track_too_long(single_track, guild_data["maxlength"]):
|
||||
enqueued_tracks += 1
|
||||
player.add(ctx.author, single_track)
|
||||
self.bot.dispatch(
|
||||
"red_audio_track_enqueue",
|
||||
player.channel.guild,
|
||||
single_track,
|
||||
ctx.author,
|
||||
)
|
||||
else:
|
||||
enqueued_tracks += 1
|
||||
player.add(ctx.author, single_track)
|
||||
self.bot.dispatch(
|
||||
"red_audio_track_enqueue",
|
||||
player.channel.guild,
|
||||
single_track,
|
||||
ctx.author,
|
||||
)
|
||||
|
||||
if not player.current:
|
||||
await player.play()
|
||||
if not track_list and not has_not_allowed:
|
||||
raise SpotifyFetchError(
|
||||
message=_(
|
||||
"Nothing found.\nThe YouTube API key may be invalid "
|
||||
"or you may be rate limited on YouTube's search service.\n"
|
||||
"Check the YouTube API key again and follow the instructions "
|
||||
"at `{prefix}audioset youtubeapi`."
|
||||
)
|
||||
)
|
||||
player.maybe_shuffle()
|
||||
if enqueue and tracks_from_spotify:
|
||||
if total_tracks > enqueued_tracks:
|
||||
maxlength_msg = _(" {bad_tracks} tracks cannot be queued.").format(
|
||||
bad_tracks=(total_tracks - enqueued_tracks)
|
||||
)
|
||||
else:
|
||||
maxlength_msg = ""
|
||||
|
||||
embed = discord.Embed(
|
||||
colour=await ctx.embed_colour(),
|
||||
title=_("Playlist Enqueued"),
|
||||
description=_("Added {num} tracks to the queue.{maxlength_msg}").format(
|
||||
num=enqueued_tracks, maxlength_msg=maxlength_msg
|
||||
),
|
||||
)
|
||||
if not guild_data["shuffle"] and queue_dur > 0:
|
||||
embed.set_footer(
|
||||
text=_(
|
||||
"{time} until start of playlist"
|
||||
" playback: starts at #{position} in queue"
|
||||
).format(time=queue_total_duration, position=before_queue_length + 1)
|
||||
)
|
||||
|
||||
if notifier is not None:
|
||||
await notifier.update_embed(embed)
|
||||
lock(ctx, False)
|
||||
|
||||
if spotify_cache:
|
||||
task = ("insert", ("spotify", database_entries))
|
||||
self.append_task(ctx, *task)
|
||||
except Exception as exc:
|
||||
lock(ctx, False)
|
||||
raise exc
|
||||
finally:
|
||||
lock(ctx, False)
|
||||
return track_list
|
||||
|
||||
async def fetch_youtube_query(
|
||||
self,
|
||||
ctx: commands.Context,
|
||||
track_info: str,
|
||||
current_cache_level: CacheLevel = CacheLevel.none(),
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Call the Youtube API and returns the youtube URL that the query matched
|
||||
"""
|
||||
track_url = await self.youtube_api.get_call(track_info)
|
||||
if CacheLevel.set_youtube().is_subset(current_cache_level) and track_url:
|
||||
time_now = int(datetime.datetime.now(datetime.timezone.utc).timestamp())
|
||||
task = (
|
||||
"insert",
|
||||
(
|
||||
"youtube",
|
||||
[
|
||||
{
|
||||
"track_info": track_info,
|
||||
"track_url": track_url,
|
||||
"last_updated": time_now,
|
||||
"last_fetched": time_now,
|
||||
}
|
||||
],
|
||||
),
|
||||
)
|
||||
self.append_task(ctx, *task)
|
||||
return track_url
|
||||
|
||||
async def fetch_from_youtube_api(
|
||||
self, ctx: commands.Context, track_info: str
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Gets an YouTube URL from for the query
|
||||
"""
|
||||
current_cache_level = CacheLevel(await self.config.cache_level())
|
||||
cache_enabled = CacheLevel.set_youtube().is_subset(current_cache_level)
|
||||
val = None
|
||||
if cache_enabled:
|
||||
try:
|
||||
(val, update) = await self.local_cache_api.youtube.fetch_one({"track": track_info})
|
||||
except Exception as exc:
|
||||
debug_exc_log(log, exc, f"Failed to fetch {track_info} from YouTube table")
|
||||
if val is None:
|
||||
youtube_url = await self.fetch_youtube_query(
|
||||
ctx, track_info, current_cache_level=current_cache_level
|
||||
)
|
||||
else:
|
||||
if cache_enabled:
|
||||
task = ("update", ("youtube", {"track": track_info}))
|
||||
self.append_task(ctx, *task)
|
||||
youtube_url = val
|
||||
return youtube_url
|
||||
|
||||
async def fetch_track(
|
||||
self,
|
||||
ctx: commands.Context,
|
||||
player: lavalink.Player,
|
||||
query: Query,
|
||||
forced: bool = False,
|
||||
lazy: bool = False,
|
||||
should_query_global: bool = True,
|
||||
) -> Tuple[LoadResult, bool]:
|
||||
"""A replacement for :code:`lavalink.Player.load_tracks`. This will try to get a valid
|
||||
cached entry first if not found or if in valid it will then call the lavalink API.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
ctx: commands.Context
|
||||
The context this method is being called under.
|
||||
player : lavalink.Player
|
||||
The player who's requesting the query.
|
||||
query: audio_dataclasses.Query
|
||||
The Query object for the query in question.
|
||||
forced:bool
|
||||
Whether or not to skip cache and call API first.
|
||||
lazy:bool
|
||||
If set to True, it will not call the api if a track is not found.
|
||||
should_query_global:bool
|
||||
If the method should query the global database.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Tuple[lavalink.LoadResult, bool]
|
||||
Tuple with the Load result and whether or not the API was called.
|
||||
"""
|
||||
current_cache_level = CacheLevel(await self.config.cache_level())
|
||||
cache_enabled = CacheLevel.set_lavalink().is_subset(current_cache_level)
|
||||
val = None
|
||||
query = Query.process_input(query, self.cog.local_folder_current_path)
|
||||
query_string = str(query)
|
||||
valid_global_entry = False
|
||||
results = None
|
||||
called_api = False
|
||||
prefer_lyrics = await self.cog.get_lyrics_status(ctx)
|
||||
if prefer_lyrics and query.is_youtube and query.is_search:
|
||||
query_string = f"{query} - lyrics"
|
||||
if cache_enabled and not forced and not query.is_local:
|
||||
try:
|
||||
(val, last_updated) = await self.local_cache_api.lavalink.fetch_one(
|
||||
{"query": query_string}
|
||||
)
|
||||
except Exception as exc:
|
||||
debug_exc_log(log, exc, f"Failed to fetch '{query_string}' from Lavalink table")
|
||||
|
||||
if val and isinstance(val, dict):
|
||||
if IS_DEBUG:
|
||||
log.debug(f"Updating Local Database with {query_string}")
|
||||
task = ("update", ("lavalink", {"query": query_string}))
|
||||
self.append_task(ctx, *task)
|
||||
else:
|
||||
val = None
|
||||
|
||||
if val and not forced and isinstance(val, dict):
|
||||
valid_global_entry = False
|
||||
called_api = False
|
||||
else:
|
||||
val = None
|
||||
|
||||
if valid_global_entry:
|
||||
pass
|
||||
elif lazy is True:
|
||||
called_api = False
|
||||
elif val and not forced and isinstance(val, dict):
|
||||
data = val
|
||||
data["query"] = query_string
|
||||
if data.get("loadType") == "V2_COMPACT":
|
||||
data["loadType"] = "V2_COMPAT"
|
||||
results = LoadResult(data)
|
||||
called_api = False
|
||||
if results.has_error:
|
||||
# If cached value has an invalid entry make a new call so that it gets updated
|
||||
results, called_api = await self.fetch_track(ctx, player, query, forced=True)
|
||||
else:
|
||||
if IS_DEBUG:
|
||||
log.debug(f"Querying Lavalink api for {query_string}")
|
||||
called_api = True
|
||||
try:
|
||||
results = await player.load_tracks(query_string)
|
||||
except KeyError:
|
||||
results = None
|
||||
except RuntimeError:
|
||||
raise TrackEnqueueError
|
||||
if results is None:
|
||||
results = LoadResult({"loadType": "LOAD_FAILED", "playlistInfo": {}, "tracks": []})
|
||||
|
||||
if (
|
||||
cache_enabled
|
||||
and results.load_type
|
||||
and not results.has_error
|
||||
and not query.is_local
|
||||
and results.tracks
|
||||
):
|
||||
try:
|
||||
time_now = int(datetime.datetime.now(datetime.timezone.utc).timestamp())
|
||||
data = json.dumps(results._raw)
|
||||
if all(k in data for k in ["loadType", "playlistInfo", "isSeekable", "isStream"]):
|
||||
task = (
|
||||
"insert",
|
||||
(
|
||||
"lavalink",
|
||||
[
|
||||
{
|
||||
"query": query_string,
|
||||
"data": data,
|
||||
"last_updated": time_now,
|
||||
"last_fetched": time_now,
|
||||
}
|
||||
],
|
||||
),
|
||||
)
|
||||
self.append_task(ctx, *task)
|
||||
except Exception as exc:
|
||||
debug_exc_log(
|
||||
log,
|
||||
exc,
|
||||
f"Failed to enqueue write task for '{query_string}' to Lavalink table",
|
||||
)
|
||||
return results, called_api
|
||||
|
||||
async def autoplay(self, player: lavalink.Player, playlist_api: PlaylistWrapper):
|
||||
"""
|
||||
Enqueue a random track
|
||||
"""
|
||||
autoplaylist = await self.config.guild(player.channel.guild).autoplaylist()
|
||||
current_cache_level = CacheLevel(await self.config.cache_level())
|
||||
cache_enabled = CacheLevel.set_lavalink().is_subset(current_cache_level)
|
||||
playlist = None
|
||||
tracks = None
|
||||
if autoplaylist["enabled"]:
|
||||
try:
|
||||
playlist = await get_playlist(
|
||||
autoplaylist["id"],
|
||||
autoplaylist["scope"],
|
||||
self.bot,
|
||||
playlist_api,
|
||||
player.channel.guild,
|
||||
player.channel.guild.me,
|
||||
)
|
||||
tracks = playlist.tracks_obj
|
||||
except Exception as exc:
|
||||
debug_exc_log(log, exc, "Failed to fetch playlist for autoplay")
|
||||
|
||||
if not tracks or not getattr(playlist, "tracks", None):
|
||||
if cache_enabled:
|
||||
track = await self.get_random_track_from_db()
|
||||
tracks = [] if not track else [track]
|
||||
if not tracks:
|
||||
ctx = namedtuple("Context", "message guild cog")
|
||||
(results, called_api) = await self.fetch_track(
|
||||
cast(
|
||||
commands.Context, ctx(player.channel.guild, player.channel.guild, self.cog)
|
||||
),
|
||||
player,
|
||||
Query.process_input(_TOP_100_US, self.cog.local_folder_current_path),
|
||||
)
|
||||
tracks = list(results.tracks)
|
||||
if tracks:
|
||||
multiple = len(tracks) > 1
|
||||
valid = not multiple
|
||||
tries = len(tracks)
|
||||
track = tracks[0]
|
||||
while valid is False and multiple:
|
||||
tries -= 1
|
||||
if tries <= 0:
|
||||
raise DatabaseError("No valid entry found")
|
||||
track = random.choice(tracks)
|
||||
query = Query.process_input(track, self.cog.local_folder_current_path)
|
||||
await asyncio.sleep(0.001)
|
||||
if not query.valid or (
|
||||
query.is_local
|
||||
and query.local_track_path is not None
|
||||
and not query.local_track_path.exists()
|
||||
):
|
||||
continue
|
||||
if not await self.cog.is_query_allowed(
|
||||
self.config,
|
||||
player.channel.guild,
|
||||
(
|
||||
f"{track.title} {track.author} {track.uri} "
|
||||
f"{str(Query.process_input(track, self.cog.local_folder_current_path))}"
|
||||
),
|
||||
):
|
||||
if IS_DEBUG:
|
||||
log.debug(
|
||||
"Query is not allowed in "
|
||||
f"{player.channel.guild} ({player.channel.guild.id})"
|
||||
)
|
||||
continue
|
||||
valid = True
|
||||
|
||||
track.extras["autoplay"] = True
|
||||
player.add(player.channel.guild.me, track)
|
||||
self.bot.dispatch(
|
||||
"red_audio_track_auto_play", player.channel.guild, track, player.channel.guild.me
|
||||
)
|
||||
if not player.current:
|
||||
await player.play()
|
||||
372
redbot/cogs/audio/apis/local_db.py
Normal file
372
redbot/cogs/audio/apis/local_db.py
Normal file
@@ -0,0 +1,372 @@
|
||||
import concurrent
|
||||
import contextlib
|
||||
import datetime
|
||||
import logging
|
||||
import random
|
||||
import time
|
||||
from types import SimpleNamespace
|
||||
from typing import Callable, List, MutableMapping, Optional, TYPE_CHECKING, Tuple, Union
|
||||
|
||||
from redbot.core.utils import AsyncIter
|
||||
|
||||
from redbot.core import Config
|
||||
from redbot.core.bot import Red
|
||||
from redbot.core.commands import Cog
|
||||
from redbot.core.utils.dbtools import APSWConnectionWrapper
|
||||
|
||||
from ..audio_logging import debug_exc_log
|
||||
from ..sql_statements import (
|
||||
LAVALINK_CREATE_INDEX,
|
||||
LAVALINK_CREATE_TABLE,
|
||||
LAVALINK_DELETE_OLD_ENTRIES,
|
||||
LAVALINK_FETCH_ALL_ENTRIES_GLOBAL,
|
||||
LAVALINK_QUERY,
|
||||
LAVALINK_QUERY_ALL,
|
||||
LAVALINK_QUERY_LAST_FETCHED_RANDOM,
|
||||
LAVALINK_UPDATE,
|
||||
LAVALINK_UPSERT,
|
||||
SPOTIFY_CREATE_INDEX,
|
||||
SPOTIFY_CREATE_TABLE,
|
||||
SPOTIFY_DELETE_OLD_ENTRIES,
|
||||
SPOTIFY_QUERY,
|
||||
SPOTIFY_QUERY_ALL,
|
||||
SPOTIFY_QUERY_LAST_FETCHED_RANDOM,
|
||||
SPOTIFY_UPDATE,
|
||||
SPOTIFY_UPSERT,
|
||||
YOUTUBE_CREATE_INDEX,
|
||||
YOUTUBE_CREATE_TABLE,
|
||||
YOUTUBE_DELETE_OLD_ENTRIES,
|
||||
YOUTUBE_QUERY,
|
||||
YOUTUBE_QUERY_ALL,
|
||||
YOUTUBE_QUERY_LAST_FETCHED_RANDOM,
|
||||
YOUTUBE_UPDATE,
|
||||
YOUTUBE_UPSERT,
|
||||
PRAGMA_FETCH_user_version,
|
||||
PRAGMA_SET_journal_mode,
|
||||
PRAGMA_SET_read_uncommitted,
|
||||
PRAGMA_SET_temp_store,
|
||||
PRAGMA_SET_user_version,
|
||||
)
|
||||
from .api_utils import (
|
||||
LavalinkCacheFetchForGlobalResult,
|
||||
LavalinkCacheFetchResult,
|
||||
SpotifyCacheFetchResult,
|
||||
YouTubeCacheFetchResult,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .. import Audio
|
||||
|
||||
|
||||
log = logging.getLogger("red.cogs.Audio.api.LocalDB")
|
||||
|
||||
_SCHEMA_VERSION = 3
|
||||
|
||||
|
||||
class BaseWrapper:
|
||||
def __init__(
|
||||
self, bot: Red, config: Config, conn: APSWConnectionWrapper, cog: Union["Audio", Cog]
|
||||
):
|
||||
self.bot = bot
|
||||
self.config = config
|
||||
self.database = conn
|
||||
self.statement = SimpleNamespace()
|
||||
self.statement.pragma_temp_store = PRAGMA_SET_temp_store
|
||||
self.statement.pragma_journal_mode = PRAGMA_SET_journal_mode
|
||||
self.statement.pragma_read_uncommitted = PRAGMA_SET_read_uncommitted
|
||||
self.statement.set_user_version = PRAGMA_SET_user_version
|
||||
self.statement.get_user_version = PRAGMA_FETCH_user_version
|
||||
self.fetch_result: Optional[Callable] = None
|
||||
self.cog = cog
|
||||
|
||||
async def init(self) -> None:
|
||||
"""Initialize the local cache"""
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
|
||||
executor.submit(self.database.cursor().execute, self.statement.pragma_temp_store)
|
||||
executor.submit(self.database.cursor().execute, self.statement.pragma_journal_mode)
|
||||
executor.submit(self.database.cursor().execute, self.statement.pragma_read_uncommitted)
|
||||
executor.submit(self.maybe_migrate)
|
||||
executor.submit(self.database.cursor().execute, LAVALINK_CREATE_TABLE)
|
||||
executor.submit(self.database.cursor().execute, LAVALINK_CREATE_INDEX)
|
||||
executor.submit(self.database.cursor().execute, YOUTUBE_CREATE_TABLE)
|
||||
executor.submit(self.database.cursor().execute, YOUTUBE_CREATE_INDEX)
|
||||
executor.submit(self.database.cursor().execute, SPOTIFY_CREATE_TABLE)
|
||||
executor.submit(self.database.cursor().execute, SPOTIFY_CREATE_INDEX)
|
||||
await self.clean_up_old_entries()
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the connection with the local cache"""
|
||||
with contextlib.suppress(Exception):
|
||||
self.database.close()
|
||||
|
||||
async def clean_up_old_entries(self) -> None:
|
||||
"""Delete entries older than x in the local cache tables"""
|
||||
max_age = await self.config.cache_age()
|
||||
maxage = datetime.datetime.now(tz=datetime.timezone.utc) - datetime.timedelta(days=max_age)
|
||||
maxage_int = int(time.mktime(maxage.timetuple()))
|
||||
values = {"maxage": maxage_int}
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
|
||||
executor.submit(self.database.cursor().execute, LAVALINK_DELETE_OLD_ENTRIES, values)
|
||||
executor.submit(self.database.cursor().execute, YOUTUBE_DELETE_OLD_ENTRIES, values)
|
||||
executor.submit(self.database.cursor().execute, SPOTIFY_DELETE_OLD_ENTRIES, values)
|
||||
|
||||
def maybe_migrate(self) -> None:
|
||||
"""Maybe migrate Database schema for the local cache"""
|
||||
current_version = 0
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
|
||||
for future in concurrent.futures.as_completed(
|
||||
[executor.submit(self.database.cursor().execute, self.statement.get_user_version)]
|
||||
):
|
||||
try:
|
||||
row_result = future.result()
|
||||
current_version = row_result.fetchone()
|
||||
break
|
||||
except Exception as exc:
|
||||
debug_exc_log(log, exc, "Failed to completed fetch from database")
|
||||
if isinstance(current_version, tuple):
|
||||
current_version = current_version[0]
|
||||
if current_version == _SCHEMA_VERSION:
|
||||
return
|
||||
executor.submit(
|
||||
self.database.cursor().execute,
|
||||
self.statement.set_user_version,
|
||||
{"version": _SCHEMA_VERSION},
|
||||
)
|
||||
|
||||
async def insert(self, values: List[MutableMapping]) -> None:
|
||||
"""Insert an entry into the local cache"""
|
||||
try:
|
||||
with self.database.transaction() as transaction:
|
||||
transaction.executemany(self.statement.upsert, values)
|
||||
except Exception as exc:
|
||||
debug_exc_log(log, exc, "Error during table insert")
|
||||
|
||||
async def update(self, values: MutableMapping) -> None:
|
||||
"""Update an entry of the local cache"""
|
||||
|
||||
try:
|
||||
time_now = int(datetime.datetime.now(datetime.timezone.utc).timestamp())
|
||||
values["last_fetched"] = time_now
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
|
||||
executor.submit(self.database.cursor().execute, self.statement.update, values)
|
||||
except Exception as exc:
|
||||
debug_exc_log(log, exc, "Error during table update")
|
||||
|
||||
async def _fetch_one(
|
||||
self, values: MutableMapping
|
||||
) -> Optional[
|
||||
Union[LavalinkCacheFetchResult, SpotifyCacheFetchResult, YouTubeCacheFetchResult]
|
||||
]:
|
||||
"""Get an entry from the local cache"""
|
||||
max_age = await self.config.cache_age()
|
||||
maxage = datetime.datetime.now(tz=datetime.timezone.utc) - datetime.timedelta(days=max_age)
|
||||
maxage_int = int(time.mktime(maxage.timetuple()))
|
||||
values.update({"maxage": maxage_int})
|
||||
row = None
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
|
||||
for future in concurrent.futures.as_completed(
|
||||
[executor.submit(self.database.cursor().execute, self.statement.get_one, values)]
|
||||
):
|
||||
try:
|
||||
row_result = future.result()
|
||||
row = row_result.fetchone()
|
||||
except Exception as exc:
|
||||
debug_exc_log(log, exc, "Failed to completed fetch from database")
|
||||
if not row:
|
||||
return None
|
||||
if self.fetch_result is None:
|
||||
return None
|
||||
return self.fetch_result(*row)
|
||||
|
||||
async def _fetch_all(
|
||||
self, values: MutableMapping
|
||||
) -> List[Union[LavalinkCacheFetchResult, SpotifyCacheFetchResult, YouTubeCacheFetchResult]]:
|
||||
"""Get all entries from the local cache"""
|
||||
output = []
|
||||
row_result = []
|
||||
if self.fetch_result is None:
|
||||
return []
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
|
||||
for future in concurrent.futures.as_completed(
|
||||
[executor.submit(self.database.cursor().execute, self.statement.get_all, values)]
|
||||
):
|
||||
try:
|
||||
row_result = future.result()
|
||||
except Exception as exc:
|
||||
debug_exc_log(log, exc, "Failed to completed fetch from database")
|
||||
async for row in AsyncIter(row_result):
|
||||
output.append(self.fetch_result(*row))
|
||||
return output
|
||||
|
||||
async def _fetch_random(
|
||||
self, values: MutableMapping
|
||||
) -> Optional[
|
||||
Union[LavalinkCacheFetchResult, SpotifyCacheFetchResult, YouTubeCacheFetchResult]
|
||||
]:
|
||||
"""Get a random entry from the local cache"""
|
||||
row = None
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
|
||||
for future in concurrent.futures.as_completed(
|
||||
[
|
||||
executor.submit(
|
||||
self.database.cursor().execute, self.statement.get_random, values
|
||||
)
|
||||
]
|
||||
):
|
||||
try:
|
||||
row_result = future.result()
|
||||
rows = row_result.fetchall()
|
||||
if rows:
|
||||
row = random.choice(rows)
|
||||
else:
|
||||
row = None
|
||||
except Exception as exc:
|
||||
debug_exc_log(log, exc, "Failed to completed random fetch from database")
|
||||
if not row:
|
||||
return None
|
||||
if self.fetch_result is None:
|
||||
return None
|
||||
return self.fetch_result(*row)
|
||||
|
||||
|
||||
class YouTubeTableWrapper(BaseWrapper):
|
||||
def __init__(
|
||||
self, bot: Red, config: Config, conn: APSWConnectionWrapper, cog: Union["Audio", Cog]
|
||||
):
|
||||
super().__init__(bot, config, conn, cog)
|
||||
self.statement.upsert = YOUTUBE_UPSERT
|
||||
self.statement.update = YOUTUBE_UPDATE
|
||||
self.statement.get_one = YOUTUBE_QUERY
|
||||
self.statement.get_all = YOUTUBE_QUERY_ALL
|
||||
self.statement.get_random = YOUTUBE_QUERY_LAST_FETCHED_RANDOM
|
||||
self.fetch_result = YouTubeCacheFetchResult
|
||||
|
||||
async def fetch_one(
|
||||
self, values: MutableMapping
|
||||
) -> Tuple[Optional[str], Optional[datetime.datetime]]:
|
||||
"""Get an entry from the Youtube table"""
|
||||
result = await self._fetch_one(values)
|
||||
if not result or not isinstance(result.query, str):
|
||||
return None, None
|
||||
return result.query, result.updated_on
|
||||
|
||||
async def fetch_all(self, values: MutableMapping) -> List[YouTubeCacheFetchResult]:
|
||||
"""Get all entries from the Youtube table"""
|
||||
result = await self._fetch_all(values)
|
||||
if result and isinstance(result[0], YouTubeCacheFetchResult):
|
||||
return result
|
||||
return []
|
||||
|
||||
async def fetch_random(self, values: MutableMapping) -> Optional[str]:
|
||||
"""Get a random entry from the Youtube table"""
|
||||
result = await self._fetch_random(values)
|
||||
if not result or not isinstance(result.query, str):
|
||||
return None
|
||||
return result.query
|
||||
|
||||
|
||||
class SpotifyTableWrapper(BaseWrapper):
|
||||
def __init__(
|
||||
self, bot: Red, config: Config, conn: APSWConnectionWrapper, cog: Union["Audio", Cog]
|
||||
):
|
||||
super().__init__(bot, config, conn, cog)
|
||||
self.statement.upsert = SPOTIFY_UPSERT
|
||||
self.statement.update = SPOTIFY_UPDATE
|
||||
self.statement.get_one = SPOTIFY_QUERY
|
||||
self.statement.get_all = SPOTIFY_QUERY_ALL
|
||||
self.statement.get_random = SPOTIFY_QUERY_LAST_FETCHED_RANDOM
|
||||
self.fetch_result = SpotifyCacheFetchResult
|
||||
|
||||
async def fetch_one(
|
||||
self, values: MutableMapping
|
||||
) -> Tuple[Optional[str], Optional[datetime.datetime]]:
|
||||
"""Get an entry from the Spotify table"""
|
||||
result = await self._fetch_one(values)
|
||||
if not result or not isinstance(result.query, str):
|
||||
return None, None
|
||||
return result.query, result.updated_on
|
||||
|
||||
async def fetch_all(self, values: MutableMapping) -> List[SpotifyCacheFetchResult]:
|
||||
"""Get all entries from the Spotify table"""
|
||||
result = await self._fetch_all(values)
|
||||
if result and isinstance(result[0], SpotifyCacheFetchResult):
|
||||
return result
|
||||
return []
|
||||
|
||||
async def fetch_random(self, values: MutableMapping) -> Optional[str]:
|
||||
"""Get a random entry from the Spotify table"""
|
||||
result = await self._fetch_random(values)
|
||||
if not result or not isinstance(result.query, str):
|
||||
return None
|
||||
return result.query
|
||||
|
||||
|
||||
class LavalinkTableWrapper(BaseWrapper):
|
||||
def __init__(
|
||||
self, bot: Red, config: Config, conn: APSWConnectionWrapper, cog: Union["Audio", Cog]
|
||||
):
|
||||
super().__init__(bot, config, conn, cog)
|
||||
self.statement.upsert = LAVALINK_UPSERT
|
||||
self.statement.update = LAVALINK_UPDATE
|
||||
self.statement.get_one = LAVALINK_QUERY
|
||||
self.statement.get_all = LAVALINK_QUERY_ALL
|
||||
self.statement.get_random = LAVALINK_QUERY_LAST_FETCHED_RANDOM
|
||||
self.statement.get_all_global = LAVALINK_FETCH_ALL_ENTRIES_GLOBAL
|
||||
self.fetch_result = LavalinkCacheFetchResult
|
||||
self.fetch_for_global: Optional[Callable] = None
|
||||
|
||||
async def fetch_one(
|
||||
self, values: MutableMapping
|
||||
) -> Tuple[Optional[MutableMapping], Optional[datetime.datetime]]:
|
||||
"""Get an entry from the Lavalink table"""
|
||||
result = await self._fetch_one(values)
|
||||
if not result or not isinstance(result.query, dict):
|
||||
return None, None
|
||||
return result.query, result.updated_on
|
||||
|
||||
async def fetch_all(self, values: MutableMapping) -> List[LavalinkCacheFetchResult]:
|
||||
"""Get all entries from the Lavalink table"""
|
||||
result = await self._fetch_all(values)
|
||||
if result and isinstance(result[0], LavalinkCacheFetchResult):
|
||||
return result
|
||||
return []
|
||||
|
||||
async def fetch_random(self, values: MutableMapping) -> Optional[MutableMapping]:
|
||||
"""Get a random entry from the Lavalink table"""
|
||||
result = await self._fetch_random(values)
|
||||
if not result or not isinstance(result.query, dict):
|
||||
return None
|
||||
return result.query
|
||||
|
||||
async def fetch_all_for_global(self) -> List[LavalinkCacheFetchForGlobalResult]:
|
||||
"""Get all entries from the Lavalink table"""
|
||||
output: List[LavalinkCacheFetchForGlobalResult] = []
|
||||
row_result = []
|
||||
if self.fetch_for_global is None:
|
||||
return []
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
|
||||
for future in concurrent.futures.as_completed(
|
||||
[executor.submit(self.database.cursor().execute, self.statement.get_all_global)]
|
||||
):
|
||||
try:
|
||||
row_result = future.result()
|
||||
except Exception as exc:
|
||||
debug_exc_log(log, exc, "Failed to completed fetch from database")
|
||||
async for row in AsyncIter(row_result):
|
||||
output.append(self.fetch_for_global(*row))
|
||||
return output
|
||||
|
||||
|
||||
class LocalCacheWrapper:
|
||||
"""Wraps all table apis into 1 object representing the local cache"""
|
||||
|
||||
def __init__(
|
||||
self, bot: Red, config: Config, conn: APSWConnectionWrapper, cog: Union["Audio", Cog]
|
||||
):
|
||||
self.bot = bot
|
||||
self.config = config
|
||||
self.database = conn
|
||||
self.cog = cog
|
||||
self.lavalink: LavalinkTableWrapper = LavalinkTableWrapper(bot, config, conn, self.cog)
|
||||
self.spotify: SpotifyTableWrapper = SpotifyTableWrapper(bot, config, conn, self.cog)
|
||||
self.youtube: YouTubeTableWrapper = YouTubeTableWrapper(bot, config, conn, self.cog)
|
||||
647
redbot/cogs/audio/apis/playlist_interface.py
Normal file
647
redbot/cogs/audio/apis/playlist_interface.py
Normal file
@@ -0,0 +1,647 @@
|
||||
import logging
|
||||
from typing import List, MutableMapping, Optional, Union
|
||||
|
||||
import discord
|
||||
import lavalink
|
||||
from redbot.core.utils import AsyncIter
|
||||
|
||||
from redbot.core import Config, commands
|
||||
from redbot.core.bot import Red
|
||||
|
||||
from ..errors import NotAllowed
|
||||
from ..utils import PlaylistScope
|
||||
from .api_utils import PlaylistFetchResult, prepare_config_scope, standardize_scope
|
||||
from .playlist_wrapper import PlaylistWrapper
|
||||
|
||||
log = logging.getLogger("red.cogs.Audio.api.PlaylistsInterface")
|
||||
|
||||
|
||||
class Playlist:
|
||||
"""A single playlist."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
bot: Red,
|
||||
playlist_api: PlaylistWrapper,
|
||||
scope: str,
|
||||
author: int,
|
||||
playlist_id: int,
|
||||
name: str,
|
||||
playlist_url: Optional[str] = None,
|
||||
tracks: Optional[List[MutableMapping]] = None,
|
||||
guild: Union[discord.Guild, int, None] = None,
|
||||
):
|
||||
self.bot = bot
|
||||
self.guild = guild
|
||||
self.scope = standardize_scope(scope)
|
||||
self.config_scope = prepare_config_scope(self.bot, self.scope, author, guild)
|
||||
self.scope_id = self.config_scope[-1]
|
||||
self.author = author
|
||||
self.author_id = getattr(self.author, "id", self.author)
|
||||
self.guild_id = (
|
||||
getattr(guild, "id", guild) if self.scope == PlaylistScope.GLOBAL.value else None
|
||||
)
|
||||
self.id = playlist_id
|
||||
self.name = name
|
||||
self.url = playlist_url
|
||||
self.tracks = tracks or []
|
||||
self.tracks_obj = [lavalink.Track(data=track) for track in self.tracks]
|
||||
self.playlist_api = playlist_api
|
||||
|
||||
def __repr__(self):
|
||||
return (
|
||||
f"Playlist(name={self.name}, id={self.id}, scope={self.scope}, "
|
||||
f"scope_id={self.scope_id}, author={self.author_id}, "
|
||||
f"tracks={len(self.tracks)}, url={self.url})"
|
||||
)
|
||||
|
||||
async def edit(self, data: MutableMapping):
|
||||
"""
|
||||
Edits a Playlist.
|
||||
Parameters
|
||||
----------
|
||||
data: dict
|
||||
The attributes to change.
|
||||
"""
|
||||
# Disallow ID editing
|
||||
if "id" in data:
|
||||
raise NotAllowed("Playlist ID cannot be edited.")
|
||||
|
||||
for item in list(data.keys()):
|
||||
setattr(self, item, data[item])
|
||||
await self.save()
|
||||
return self
|
||||
|
||||
async def save(self):
|
||||
"""Saves a Playlist."""
|
||||
scope, scope_id = self.config_scope
|
||||
await self.playlist_api.upsert(
|
||||
scope,
|
||||
playlist_id=int(self.id),
|
||||
playlist_name=self.name,
|
||||
scope_id=scope_id,
|
||||
author_id=self.author_id,
|
||||
playlist_url=self.url,
|
||||
tracks=self.tracks,
|
||||
)
|
||||
|
||||
def to_json(self) -> MutableMapping:
|
||||
"""Transform the object to a dict.
|
||||
Returns
|
||||
-------
|
||||
dict
|
||||
The playlist in the form of a dict.
|
||||
"""
|
||||
data = dict(
|
||||
id=self.id,
|
||||
author=self.author_id,
|
||||
guild=self.guild_id,
|
||||
name=self.name,
|
||||
playlist_url=self.url,
|
||||
tracks=self.tracks,
|
||||
)
|
||||
|
||||
return data
|
||||
|
||||
@classmethod
|
||||
async def from_json(
|
||||
cls,
|
||||
bot: Red,
|
||||
playlist_api: PlaylistWrapper,
|
||||
scope: str,
|
||||
playlist_number: int,
|
||||
data: PlaylistFetchResult,
|
||||
**kwargs,
|
||||
) -> "Playlist":
|
||||
"""Get a Playlist object from the provided information.
|
||||
Parameters
|
||||
----------
|
||||
bot: Red
|
||||
The bot's instance. Needed to get the target user.
|
||||
playlist_api: PlaylistWrapper
|
||||
The Playlist API interface.
|
||||
scope:str
|
||||
The custom config scope. One of 'GLOBALPLAYLIST', 'GUILDPLAYLIST' or 'USERPLAYLIST'.
|
||||
playlist_number: int
|
||||
The playlist's number.
|
||||
data: PlaylistFetchResult
|
||||
The PlaylistFetchResult representation of the playlist to be gotten.
|
||||
**kwargs
|
||||
Extra attributes for the Playlist instance which override values
|
||||
in the data dict. These should be complete objects and not
|
||||
IDs, where possible.
|
||||
Returns
|
||||
-------
|
||||
Playlist
|
||||
The playlist object for the requested playlist.
|
||||
Raises
|
||||
------
|
||||
`InvalidPlaylistScope`
|
||||
Passing a scope that is not supported.
|
||||
`MissingGuild`
|
||||
Trying to access the Guild scope without a guild.
|
||||
`MissingAuthor`
|
||||
Trying to access the User scope without an user id.
|
||||
"""
|
||||
guild = data.scope_id if scope == PlaylistScope.GUILD.value else kwargs.get("guild")
|
||||
author = data.author_id
|
||||
playlist_id = data.playlist_id or playlist_number
|
||||
name = data.playlist_name
|
||||
playlist_url = data.playlist_url
|
||||
tracks = data.tracks
|
||||
|
||||
return cls(
|
||||
bot=bot,
|
||||
playlist_api=playlist_api,
|
||||
guild=guild,
|
||||
scope=scope,
|
||||
author=author,
|
||||
playlist_id=playlist_id,
|
||||
name=name,
|
||||
playlist_url=playlist_url,
|
||||
tracks=tracks,
|
||||
)
|
||||
|
||||
|
||||
class PlaylistCompat23:
|
||||
"""A single playlist, migrating from Schema 2 to Schema 3"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
bot: Red,
|
||||
playlist_api: PlaylistWrapper,
|
||||
scope: str,
|
||||
author: int,
|
||||
playlist_id: int,
|
||||
name: str,
|
||||
playlist_url: Optional[str] = None,
|
||||
tracks: Optional[List[MutableMapping]] = None,
|
||||
guild: Union[discord.Guild, int, None] = None,
|
||||
):
|
||||
|
||||
self.bot = bot
|
||||
self.guild = guild
|
||||
self.scope = standardize_scope(scope)
|
||||
self.author = author
|
||||
self.id = playlist_id
|
||||
self.name = name
|
||||
self.url = playlist_url
|
||||
self.tracks = tracks or []
|
||||
|
||||
self.playlist_api = playlist_api
|
||||
|
||||
@classmethod
|
||||
async def from_json(
|
||||
cls,
|
||||
bot: Red,
|
||||
playlist_api: PlaylistWrapper,
|
||||
scope: str,
|
||||
playlist_number: int,
|
||||
data: MutableMapping,
|
||||
**kwargs,
|
||||
) -> "PlaylistCompat23":
|
||||
"""Get a Playlist object from the provided information.
|
||||
Parameters
|
||||
----------
|
||||
bot: Red
|
||||
The Bot instance.
|
||||
playlist_api: PlaylistWrapper
|
||||
The Playlist API interface.
|
||||
scope:str
|
||||
The custom config scope. One of 'GLOBALPLAYLIST', 'GUILDPLAYLIST' or 'USERPLAYLIST'.
|
||||
playlist_number: int
|
||||
The playlist's number.
|
||||
data: MutableMapping
|
||||
The JSON representation of the playlist to be gotten.
|
||||
**kwargs
|
||||
Extra attributes for the Playlist instance which override values
|
||||
in the data dict. These should be complete objects and not
|
||||
IDs, where possible.
|
||||
Returns
|
||||
-------
|
||||
Playlist
|
||||
The playlist object for the requested playlist.
|
||||
Raises
|
||||
------
|
||||
`InvalidPlaylistScope`
|
||||
Passing a scope that is not supported.
|
||||
`MissingGuild`
|
||||
Trying to access the Guild scope without a guild.
|
||||
`MissingAuthor`
|
||||
Trying to access the User scope without an user id.
|
||||
"""
|
||||
guild = data.get("guild") or kwargs.get("guild")
|
||||
author: int = data.get("author") or 0
|
||||
playlist_id = data.get("id") or playlist_number
|
||||
name = data.get("name", "Unnamed")
|
||||
playlist_url = data.get("playlist_url", None)
|
||||
tracks = data.get("tracks", [])
|
||||
|
||||
return cls(
|
||||
bot=bot,
|
||||
playlist_api=playlist_api,
|
||||
guild=guild,
|
||||
scope=scope,
|
||||
author=author,
|
||||
playlist_id=playlist_id,
|
||||
name=name,
|
||||
playlist_url=playlist_url,
|
||||
tracks=tracks,
|
||||
)
|
||||
|
||||
async def save(self):
|
||||
"""Saves a Playlist to SQL."""
|
||||
scope, scope_id = prepare_config_scope(self.bot, self.scope, self.author, self.guild)
|
||||
await self.playlist_api.upsert(
|
||||
scope,
|
||||
playlist_id=int(self.id),
|
||||
playlist_name=self.name,
|
||||
scope_id=scope_id,
|
||||
author_id=self.author,
|
||||
playlist_url=self.url,
|
||||
tracks=self.tracks,
|
||||
)
|
||||
|
||||
|
||||
async def get_all_playlist_for_migration23(
|
||||
bot: Red,
|
||||
playlist_api: PlaylistWrapper,
|
||||
config: Config,
|
||||
scope: str,
|
||||
guild: Union[discord.Guild, int] = None,
|
||||
) -> List[PlaylistCompat23]:
|
||||
"""
|
||||
Gets all playlist for the specified scope.
|
||||
Parameters
|
||||
----------
|
||||
bot: Red
|
||||
The Bot instance.
|
||||
playlist_api: PlaylistWrapper
|
||||
The Playlist API interface.
|
||||
config: Config
|
||||
The Audio cog Config instance.
|
||||
scope: str
|
||||
The custom config scope. One of 'GLOBALPLAYLIST', 'GUILDPLAYLIST' or 'USERPLAYLIST'.
|
||||
guild: discord.Guild
|
||||
The guild to get the playlist from if scope is GUILDPLAYLIST.
|
||||
Returns
|
||||
-------
|
||||
list
|
||||
A list of all playlists for the specified scope
|
||||
Raises
|
||||
------
|
||||
`InvalidPlaylistScope`
|
||||
Passing a scope that is not supported.
|
||||
`MissingGuild`
|
||||
Trying to access the Guild scope without a guild.
|
||||
`MissingAuthor`
|
||||
Trying to access the User scope without an user id.
|
||||
"""
|
||||
playlists = await config.custom(scope).all()
|
||||
if scope == PlaylistScope.GLOBAL.value:
|
||||
return [
|
||||
await PlaylistCompat23.from_json(
|
||||
bot,
|
||||
playlist_api,
|
||||
scope,
|
||||
playlist_number,
|
||||
playlist_data,
|
||||
guild=guild,
|
||||
author=int(playlist_data.get("author", 0)),
|
||||
)
|
||||
async for playlist_number, playlist_data in AsyncIter(playlists.items())
|
||||
]
|
||||
elif scope == PlaylistScope.USER.value:
|
||||
return [
|
||||
await PlaylistCompat23.from_json(
|
||||
bot,
|
||||
playlist_api,
|
||||
scope,
|
||||
playlist_number,
|
||||
playlist_data,
|
||||
guild=guild,
|
||||
author=int(user_id),
|
||||
)
|
||||
async for user_id, scopedata in AsyncIter(playlists.items())
|
||||
async for playlist_number, playlist_data in AsyncIter(scopedata.items())
|
||||
]
|
||||
else:
|
||||
return [
|
||||
await PlaylistCompat23.from_json(
|
||||
bot,
|
||||
playlist_api,
|
||||
scope,
|
||||
playlist_number,
|
||||
playlist_data,
|
||||
guild=int(guild_id),
|
||||
author=int(playlist_data.get("author", 0)),
|
||||
)
|
||||
async for guild_id, scopedata in AsyncIter(playlists.items())
|
||||
async for playlist_number, playlist_data in AsyncIter(scopedata.items())
|
||||
]
|
||||
|
||||
|
||||
async def get_playlist(
|
||||
playlist_number: int,
|
||||
scope: str,
|
||||
bot: Red,
|
||||
playlist_api: PlaylistWrapper,
|
||||
guild: Union[discord.Guild, int] = None,
|
||||
author: Union[discord.abc.User, int] = None,
|
||||
) -> Playlist:
|
||||
"""
|
||||
Gets the playlist with the associated playlist number.
|
||||
Parameters
|
||||
----------
|
||||
playlist_number: int
|
||||
The playlist number for the playlist to get.
|
||||
playlist_api: PlaylistWrapper
|
||||
The Playlist API interface.
|
||||
scope: str
|
||||
The custom config scope. One of 'GLOBALPLAYLIST', 'GUILDPLAYLIST' or 'USERPLAYLIST'.
|
||||
guild: discord.Guild
|
||||
The guild to get the playlist from if scope is GUILDPLAYLIST.
|
||||
author: int
|
||||
The ID of the user to get the playlist from if scope is USERPLAYLIST.
|
||||
bot: Red
|
||||
The bot's instance.
|
||||
Returns
|
||||
-------
|
||||
Playlist
|
||||
The playlist associated with the playlist number.
|
||||
Raises
|
||||
------
|
||||
`RuntimeError`
|
||||
If there is no playlist for the specified number.
|
||||
`InvalidPlaylistScope`
|
||||
Passing a scope that is not supported.
|
||||
`MissingGuild`
|
||||
Trying to access the Guild scope without a guild.
|
||||
`MissingAuthor`
|
||||
Trying to access the User scope without an user id.
|
||||
"""
|
||||
scope_standard, scope_id = prepare_config_scope(bot, scope, author, guild)
|
||||
playlist_data = await playlist_api.fetch(scope_standard, playlist_number, scope_id)
|
||||
|
||||
if not (playlist_data and playlist_data.playlist_id):
|
||||
raise RuntimeError(f"That playlist does not exist for the following scope: {scope}")
|
||||
return await Playlist.from_json(
|
||||
bot,
|
||||
playlist_api,
|
||||
scope_standard,
|
||||
playlist_number,
|
||||
playlist_data,
|
||||
guild=guild,
|
||||
author=author,
|
||||
)
|
||||
|
||||
|
||||
async def get_all_playlist(
|
||||
scope: str,
|
||||
bot: Red,
|
||||
playlist_api: PlaylistWrapper,
|
||||
guild: Union[discord.Guild, int] = None,
|
||||
author: Union[discord.abc.User, int] = None,
|
||||
specified_user: bool = False,
|
||||
) -> List[Playlist]:
|
||||
"""
|
||||
Gets all playlist for the specified scope.
|
||||
Parameters
|
||||
----------
|
||||
scope: str
|
||||
The custom config scope. One of 'GLOBALPLAYLIST', 'GUILDPLAYLIST' or 'USERPLAYLIST'.
|
||||
guild: discord.Guild
|
||||
The guild to get the playlist from if scope is GUILDPLAYLIST.
|
||||
author: int
|
||||
The ID of the user to get the playlist from if scope is USERPLAYLIST.
|
||||
bot: Red
|
||||
The bot's instance
|
||||
playlist_api: PlaylistWrapper
|
||||
The Playlist API interface.
|
||||
specified_user:bool
|
||||
Whether or not user ID was passed as an argparse.
|
||||
Returns
|
||||
-------
|
||||
list
|
||||
A list of all playlists for the specified scope
|
||||
Raises
|
||||
------
|
||||
`InvalidPlaylistScope`
|
||||
Passing a scope that is not supported.
|
||||
`MissingGuild`
|
||||
Trying to access the Guild scope without a guild.
|
||||
`MissingAuthor`
|
||||
Trying to access the User scope without an user id.
|
||||
"""
|
||||
scope_standard, scope_id = prepare_config_scope(bot, scope, author, guild)
|
||||
|
||||
if specified_user:
|
||||
user_id = getattr(author, "id", author)
|
||||
playlists = await playlist_api.fetch_all(scope_standard, scope_id, author_id=user_id)
|
||||
else:
|
||||
playlists = await playlist_api.fetch_all(scope_standard, scope_id)
|
||||
|
||||
playlist_list = []
|
||||
async for playlist in AsyncIter(playlists):
|
||||
playlist_list.append(
|
||||
await Playlist.from_json(
|
||||
bot,
|
||||
playlist_api,
|
||||
scope,
|
||||
playlist.playlist_id,
|
||||
playlist,
|
||||
guild=guild,
|
||||
author=author,
|
||||
)
|
||||
)
|
||||
return playlist_list
|
||||
|
||||
|
||||
async def get_all_playlist_converter(
|
||||
scope: str,
|
||||
bot: Red,
|
||||
playlist_api: PlaylistWrapper,
|
||||
arg: str,
|
||||
guild: Union[discord.Guild, int] = None,
|
||||
author: Union[discord.abc.User, int] = None,
|
||||
) -> List[Playlist]:
|
||||
"""
|
||||
Gets all playlist for the specified scope.
|
||||
Parameters
|
||||
----------
|
||||
scope: str
|
||||
The custom config scope. One of 'GLOBALPLAYLIST', 'GUILDPLAYLIST' or 'USERPLAYLIST'.
|
||||
guild: discord.Guild
|
||||
The guild to get the playlist from if scope is GUILDPLAYLIST.
|
||||
author: int
|
||||
The ID of the user to get the playlist from if scope is USERPLAYLIST.
|
||||
bot: Red
|
||||
The bot's instance
|
||||
arg:str
|
||||
The value to lookup.
|
||||
playlist_api: PlaylistWrapper
|
||||
The Playlist API interface.
|
||||
Returns
|
||||
-------
|
||||
list
|
||||
A list of all playlists for the specified scope
|
||||
Raises
|
||||
------
|
||||
`InvalidPlaylistScope`
|
||||
Passing a scope that is not supported.
|
||||
`MissingGuild`
|
||||
Trying to access the Guild scope without a guild.
|
||||
`MissingAuthor`
|
||||
Trying to access the User scope without an user id.
|
||||
"""
|
||||
scope_standard, scope_id = prepare_config_scope(bot, scope, author, guild)
|
||||
playlists = await playlist_api.fetch_all_converter(
|
||||
scope_standard, playlist_name=arg, playlist_id=arg
|
||||
)
|
||||
playlist_list = []
|
||||
async for playlist in AsyncIter(playlists):
|
||||
playlist_list.append(
|
||||
await Playlist.from_json(
|
||||
bot,
|
||||
playlist_api,
|
||||
scope,
|
||||
playlist.playlist_id,
|
||||
playlist,
|
||||
guild=guild,
|
||||
author=author,
|
||||
)
|
||||
)
|
||||
return playlist_list
|
||||
|
||||
|
||||
async def create_playlist(
|
||||
ctx: commands.Context,
|
||||
playlist_api: PlaylistWrapper,
|
||||
scope: str,
|
||||
playlist_name: str,
|
||||
playlist_url: Optional[str] = None,
|
||||
tracks: Optional[List[MutableMapping]] = None,
|
||||
author: Optional[discord.User] = None,
|
||||
guild: Optional[discord.Guild] = None,
|
||||
) -> Optional[Playlist]:
|
||||
"""Creates a new Playlist.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
ctx: commands.Context
|
||||
The context in which the play list is being created.
|
||||
scope: str
|
||||
The custom config scope. One of 'GLOBALPLAYLIST', 'GUILDPLAYLIST' or 'USERPLAYLIST'.
|
||||
playlist_name: str
|
||||
The name of the new playlist.
|
||||
playlist_url:str
|
||||
the url of the new playlist.
|
||||
tracks: List[MutableMapping]
|
||||
A list of tracks to add to the playlist.
|
||||
author: discord.User
|
||||
The Author of the playlist.
|
||||
If provided it will create a playlist under this user.
|
||||
This is only required when creating a playlist in User scope.
|
||||
guild: discord.Guild
|
||||
The guild to create this playlist under.
|
||||
This is only used when creating a playlist in the Guild scope
|
||||
playlist_api: PlaylistWrapper
|
||||
The Playlist API interface.
|
||||
|
||||
Raises
|
||||
------
|
||||
`InvalidPlaylistScope`
|
||||
Passing a scope that is not supported.
|
||||
`MissingGuild`
|
||||
Trying to access the Guild scope without a guild.
|
||||
`MissingAuthor`
|
||||
Trying to access the User scope without an user id.
|
||||
"""
|
||||
|
||||
playlist = Playlist(
|
||||
ctx.bot,
|
||||
playlist_api,
|
||||
scope,
|
||||
author.id if author else None,
|
||||
ctx.message.id,
|
||||
playlist_name,
|
||||
playlist_url,
|
||||
tracks,
|
||||
guild or ctx.guild,
|
||||
)
|
||||
await playlist.save()
|
||||
return playlist
|
||||
|
||||
|
||||
async def reset_playlist(
|
||||
bot: Red,
|
||||
playlist_api: PlaylistWrapper,
|
||||
scope: str,
|
||||
guild: Union[discord.Guild, int] = None,
|
||||
author: Union[discord.abc.User, int] = None,
|
||||
) -> None:
|
||||
"""Wipes all playlists for the specified scope.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
bot: Red
|
||||
The bot's instance
|
||||
scope: str
|
||||
The custom config scope. One of 'GLOBALPLAYLIST', 'GUILDPLAYLIST' or 'USERPLAYLIST'.
|
||||
guild: discord.Guild
|
||||
The guild to get the playlist from if scope is GUILDPLAYLIST.
|
||||
author: int
|
||||
The ID of the user to get the playlist from if scope is USERPLAYLIST.
|
||||
playlist_api: PlaylistWrapper
|
||||
The Playlist API interface.
|
||||
|
||||
Raises
|
||||
------
|
||||
`InvalidPlaylistScope`
|
||||
Passing a scope that is not supported.
|
||||
`MissingGuild`
|
||||
Trying to access the Guild scope without a guild.
|
||||
`MissingAuthor`
|
||||
Trying to access the User scope without an user id.
|
||||
"""
|
||||
scope, scope_id = prepare_config_scope(bot, scope, author, guild)
|
||||
await playlist_api.drop(scope)
|
||||
await playlist_api.create_table()
|
||||
|
||||
|
||||
async def delete_playlist(
|
||||
bot: Red,
|
||||
playlist_api: PlaylistWrapper,
|
||||
scope: str,
|
||||
playlist_id: Union[str, int],
|
||||
guild: discord.Guild,
|
||||
author: Union[discord.abc.User, int] = None,
|
||||
) -> None:
|
||||
"""Deletes the specified playlist.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
bot: Red
|
||||
The bot's instance
|
||||
scope: str
|
||||
The custom config scope. One of 'GLOBALPLAYLIST', 'GUILDPLAYLIST' or 'USERPLAYLIST'.
|
||||
playlist_id: Union[str, int]
|
||||
The ID of the playlist.
|
||||
guild: discord.Guild
|
||||
The guild to get the playlist from if scope is GUILDPLAYLIST.
|
||||
author: int
|
||||
The ID of the user to get the playlist from if scope is USERPLAYLIST.
|
||||
playlist_api: PlaylistWrapper
|
||||
The Playlist API interface.
|
||||
|
||||
Raises
|
||||
------
|
||||
`InvalidPlaylistScope`
|
||||
Passing a scope that is not supported.
|
||||
`MissingGuild`
|
||||
Trying to access the Guild scope without a guild.
|
||||
`MissingAuthor`
|
||||
Trying to access the User scope without an user id.
|
||||
"""
|
||||
scope, scope_id = prepare_config_scope(bot, scope, author, guild)
|
||||
await playlist_api.delete(scope, int(playlist_id), scope_id)
|
||||
249
redbot/cogs/audio/apis/playlist_wrapper.py
Normal file
249
redbot/cogs/audio/apis/playlist_wrapper.py
Normal file
@@ -0,0 +1,249 @@
|
||||
import concurrent
|
||||
import json
|
||||
import logging
|
||||
from types import SimpleNamespace
|
||||
from typing import List, MutableMapping, Optional
|
||||
|
||||
from redbot.core.utils import AsyncIter
|
||||
|
||||
from redbot.core import Config
|
||||
from redbot.core.bot import Red
|
||||
from redbot.core.utils.dbtools import APSWConnectionWrapper
|
||||
|
||||
from ..audio_logging import debug_exc_log
|
||||
from ..sql_statements import (
|
||||
PLAYLIST_CREATE_INDEX,
|
||||
PLAYLIST_CREATE_TABLE,
|
||||
PLAYLIST_DELETE,
|
||||
PLAYLIST_DELETE_SCHEDULED,
|
||||
PLAYLIST_DELETE_SCOPE,
|
||||
PLAYLIST_FETCH,
|
||||
PLAYLIST_FETCH_ALL,
|
||||
PLAYLIST_FETCH_ALL_CONVERTER,
|
||||
PLAYLIST_FETCH_ALL_WITH_FILTER,
|
||||
PLAYLIST_UPSERT,
|
||||
PRAGMA_FETCH_user_version,
|
||||
PRAGMA_SET_journal_mode,
|
||||
PRAGMA_SET_read_uncommitted,
|
||||
PRAGMA_SET_temp_store,
|
||||
PRAGMA_SET_user_version,
|
||||
)
|
||||
from ..utils import PlaylistScope
|
||||
from .api_utils import PlaylistFetchResult
|
||||
|
||||
log = logging.getLogger("red.cogs.Audio.api.Playlists")
|
||||
|
||||
|
||||
class PlaylistWrapper:
|
||||
def __init__(self, bot: Red, config: Config, conn: APSWConnectionWrapper):
|
||||
self.bot = bot
|
||||
self.database = conn
|
||||
self.config = config
|
||||
self.statement = SimpleNamespace()
|
||||
self.statement.pragma_temp_store = PRAGMA_SET_temp_store
|
||||
self.statement.pragma_journal_mode = PRAGMA_SET_journal_mode
|
||||
self.statement.pragma_read_uncommitted = PRAGMA_SET_read_uncommitted
|
||||
self.statement.set_user_version = PRAGMA_SET_user_version
|
||||
self.statement.get_user_version = PRAGMA_FETCH_user_version
|
||||
self.statement.create_table = PLAYLIST_CREATE_TABLE
|
||||
self.statement.create_index = PLAYLIST_CREATE_INDEX
|
||||
|
||||
self.statement.upsert = PLAYLIST_UPSERT
|
||||
self.statement.delete = PLAYLIST_DELETE
|
||||
self.statement.delete_scope = PLAYLIST_DELETE_SCOPE
|
||||
self.statement.delete_scheduled = PLAYLIST_DELETE_SCHEDULED
|
||||
|
||||
self.statement.get_one = PLAYLIST_FETCH
|
||||
self.statement.get_all = PLAYLIST_FETCH_ALL
|
||||
self.statement.get_all_with_filter = PLAYLIST_FETCH_ALL_WITH_FILTER
|
||||
self.statement.get_all_converter = PLAYLIST_FETCH_ALL_CONVERTER
|
||||
|
||||
async def init(self) -> None:
|
||||
"""Initialize the Playlist table"""
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
|
||||
executor.submit(self.database.cursor().execute, self.statement.pragma_temp_store)
|
||||
executor.submit(self.database.cursor().execute, self.statement.pragma_journal_mode)
|
||||
executor.submit(self.database.cursor().execute, self.statement.pragma_read_uncommitted)
|
||||
executor.submit(self.database.cursor().execute, self.statement.create_table)
|
||||
executor.submit(self.database.cursor().execute, self.statement.create_index)
|
||||
|
||||
@staticmethod
|
||||
def get_scope_type(scope: str) -> int:
|
||||
"""Convert a scope to a numerical identifier"""
|
||||
if scope == PlaylistScope.GLOBAL.value:
|
||||
table = 1
|
||||
elif scope == PlaylistScope.USER.value:
|
||||
table = 3
|
||||
else:
|
||||
table = 2
|
||||
return table
|
||||
|
||||
async def fetch(self, scope: str, playlist_id: int, scope_id: int) -> PlaylistFetchResult:
|
||||
"""Fetch a single playlist"""
|
||||
scope_type = self.get_scope_type(scope)
|
||||
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
|
||||
for future in concurrent.futures.as_completed(
|
||||
[
|
||||
executor.submit(
|
||||
self.database.cursor().execute,
|
||||
self.statement.get_one,
|
||||
(
|
||||
{
|
||||
"playlist_id": playlist_id,
|
||||
"scope_id": scope_id,
|
||||
"scope_type": scope_type,
|
||||
}
|
||||
),
|
||||
)
|
||||
]
|
||||
):
|
||||
try:
|
||||
row_result = future.result()
|
||||
except Exception as exc:
|
||||
debug_exc_log(log, exc, "Failed to completed playlist fetch from database")
|
||||
row = row_result.fetchone()
|
||||
if row:
|
||||
row = PlaylistFetchResult(*row)
|
||||
return row
|
||||
|
||||
async def fetch_all(
|
||||
self, scope: str, scope_id: int, author_id=None
|
||||
) -> List[PlaylistFetchResult]:
|
||||
"""Fetch all playlists"""
|
||||
scope_type = self.get_scope_type(scope)
|
||||
output = []
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
|
||||
if author_id is not None:
|
||||
for future in concurrent.futures.as_completed(
|
||||
[
|
||||
executor.submit(
|
||||
self.database.cursor().execute,
|
||||
self.statement.get_all_with_filter,
|
||||
(
|
||||
{
|
||||
"scope_type": scope_type,
|
||||
"scope_id": scope_id,
|
||||
"author_id": author_id,
|
||||
}
|
||||
),
|
||||
)
|
||||
]
|
||||
):
|
||||
try:
|
||||
row_result = future.result()
|
||||
except Exception as exc:
|
||||
debug_exc_log(log, exc, "Failed to completed playlist fetch from database")
|
||||
return []
|
||||
else:
|
||||
for future in concurrent.futures.as_completed(
|
||||
[
|
||||
executor.submit(
|
||||
self.database.cursor().execute,
|
||||
self.statement.get_all,
|
||||
({"scope_type": scope_type, "scope_id": scope_id}),
|
||||
)
|
||||
]
|
||||
):
|
||||
try:
|
||||
row_result = future.result()
|
||||
except Exception as exc:
|
||||
debug_exc_log(log, exc, "Failed to completed playlist fetch from database")
|
||||
return []
|
||||
async for row in AsyncIter(row_result):
|
||||
output.append(PlaylistFetchResult(*row))
|
||||
return output
|
||||
|
||||
async def fetch_all_converter(
|
||||
self, scope: str, playlist_name, playlist_id
|
||||
) -> List[PlaylistFetchResult]:
|
||||
"""Fetch all playlists with the specified filter"""
|
||||
scope_type = self.get_scope_type(scope)
|
||||
try:
|
||||
playlist_id = int(playlist_id)
|
||||
except Exception as exc:
|
||||
debug_exc_log(log, exc, "Failed converting playlist_id to int")
|
||||
playlist_id = -1
|
||||
|
||||
output = []
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
|
||||
for future in concurrent.futures.as_completed(
|
||||
[
|
||||
executor.submit(
|
||||
self.database.cursor().execute,
|
||||
self.statement.get_all_converter,
|
||||
(
|
||||
{
|
||||
"scope_type": scope_type,
|
||||
"playlist_name": playlist_name,
|
||||
"playlist_id": playlist_id,
|
||||
}
|
||||
),
|
||||
)
|
||||
]
|
||||
):
|
||||
try:
|
||||
row_result = future.result()
|
||||
except Exception as exc:
|
||||
debug_exc_log(log, exc, "Failed to completed fetch from database")
|
||||
|
||||
async for row in AsyncIter(row_result):
|
||||
output.append(PlaylistFetchResult(*row))
|
||||
return output
|
||||
|
||||
async def delete(self, scope: str, playlist_id: int, scope_id: int):
|
||||
"""Deletes a single playlists"""
|
||||
scope_type = self.get_scope_type(scope)
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
|
||||
executor.submit(
|
||||
self.database.cursor().execute,
|
||||
self.statement.delete,
|
||||
({"playlist_id": playlist_id, "scope_id": scope_id, "scope_type": scope_type}),
|
||||
)
|
||||
|
||||
async def delete_scheduled(self):
|
||||
"""Clean up database from all deleted playlists"""
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
|
||||
executor.submit(self.database.cursor().execute, self.statement.delete_scheduled)
|
||||
|
||||
async def drop(self, scope: str):
|
||||
"""Delete all playlists in a scope"""
|
||||
scope_type = self.get_scope_type(scope)
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
|
||||
executor.submit(
|
||||
self.database.cursor().execute,
|
||||
self.statement.delete_scope,
|
||||
({"scope_type": scope_type}),
|
||||
)
|
||||
|
||||
async def create_table(self):
|
||||
"""Create the playlist table"""
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
|
||||
executor.submit(self.database.cursor().execute, PLAYLIST_CREATE_TABLE)
|
||||
|
||||
async def upsert(
|
||||
self,
|
||||
scope: str,
|
||||
playlist_id: int,
|
||||
playlist_name: str,
|
||||
scope_id: int,
|
||||
author_id: int,
|
||||
playlist_url: Optional[str],
|
||||
tracks: List[MutableMapping],
|
||||
):
|
||||
"""Insert or update a playlist into the database"""
|
||||
scope_type = self.get_scope_type(scope)
|
||||
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
|
||||
executor.submit(
|
||||
self.database.cursor().execute,
|
||||
self.statement.upsert,
|
||||
{
|
||||
"scope_type": str(scope_type),
|
||||
"playlist_id": int(playlist_id),
|
||||
"playlist_name": str(playlist_name),
|
||||
"scope_id": int(scope_id),
|
||||
"author_id": int(author_id),
|
||||
"playlist_url": playlist_url,
|
||||
"tracks": json.dumps(tracks),
|
||||
},
|
||||
)
|
||||
189
redbot/cogs/audio/apis/spotify.py
Normal file
189
redbot/cogs/audio/apis/spotify.py
Normal file
@@ -0,0 +1,189 @@
|
||||
import base64
|
||||
import contextlib
|
||||
import logging
|
||||
import time
|
||||
from typing import List, Mapping, MutableMapping, Optional, TYPE_CHECKING, Tuple, Union
|
||||
|
||||
import aiohttp
|
||||
from redbot.core.i18n import Translator
|
||||
from redbot.core.utils import AsyncIter
|
||||
|
||||
from redbot.core import Config
|
||||
from redbot.core.bot import Red
|
||||
from redbot.core.commands import Cog, Context
|
||||
|
||||
from ..errors import SpotifyFetchError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .. import Audio
|
||||
|
||||
_ = Translator("Audio", __file__)
|
||||
|
||||
log = logging.getLogger("red.cogs.Audio.api.Spotify")
|
||||
|
||||
|
||||
CATEGORY_ENDPOINT = "https://api.spotify.com/v1/browse/categories"
|
||||
TOKEN_ENDPOINT = "https://accounts.spotify.com/api/token"
|
||||
ALBUMS_ENDPOINT = "https://api.spotify.com/v1/albums"
|
||||
TRACKS_ENDPOINT = "https://api.spotify.com/v1/tracks"
|
||||
PLAYLISTS_ENDPOINT = "https://api.spotify.com/v1/playlists"
|
||||
|
||||
|
||||
class SpotifyWrapper:
|
||||
"""Wrapper for the Spotify API."""
|
||||
|
||||
def __init__(
|
||||
self, bot: Red, config: Config, session: aiohttp.ClientSession, cog: Union["Audio", Cog]
|
||||
):
|
||||
self.bot = bot
|
||||
self.config = config
|
||||
self.session = session
|
||||
self.spotify_token: Optional[MutableMapping] = None
|
||||
self.client_id: Optional[str] = None
|
||||
self.client_secret: Optional[str] = None
|
||||
self._token: Mapping[str, str] = {}
|
||||
self.cog = cog
|
||||
|
||||
@staticmethod
|
||||
def spotify_format_call(query_type: str, key: str) -> Tuple[str, MutableMapping]:
|
||||
"""Format the spotify endpoint"""
|
||||
params: MutableMapping = {}
|
||||
if query_type == "album":
|
||||
query = f"{ALBUMS_ENDPOINT}/{key}/tracks"
|
||||
elif query_type == "track":
|
||||
query = f"{TRACKS_ENDPOINT}/{key}"
|
||||
else:
|
||||
query = f"{PLAYLISTS_ENDPOINT}/{key}/tracks"
|
||||
return query, params
|
||||
|
||||
async def get_spotify_track_info(
|
||||
self, track_data: MutableMapping, ctx: Context
|
||||
) -> Tuple[str, ...]:
|
||||
"""Extract track info from spotify response"""
|
||||
prefer_lyrics = await self.cog.get_lyrics_status(ctx)
|
||||
track_name = track_data["name"]
|
||||
if prefer_lyrics:
|
||||
track_name = f"{track_name} - lyrics"
|
||||
artist_name = track_data["artists"][0]["name"]
|
||||
track_info = f"{track_name} {artist_name}"
|
||||
song_url = track_data.get("external_urls", {}).get("spotify")
|
||||
uri = track_data["uri"]
|
||||
_id = track_data["id"]
|
||||
_type = track_data["type"]
|
||||
|
||||
return song_url, track_info, uri, artist_name, track_name, _id, _type
|
||||
|
||||
@staticmethod
|
||||
async def is_access_token_valid(token: MutableMapping) -> bool:
|
||||
"""Check if current token is not too old"""
|
||||
return (token["expires_at"] - int(time.time())) < 60
|
||||
|
||||
@staticmethod
|
||||
def make_auth_header(
|
||||
client_id: Optional[str], client_secret: Optional[str]
|
||||
) -> MutableMapping[str, Union[str, int]]:
|
||||
"""Make Authorization header for spotify token"""
|
||||
if client_id is None:
|
||||
client_id = ""
|
||||
if client_secret is None:
|
||||
client_secret = ""
|
||||
auth_header = base64.b64encode(f"{client_id}:{client_secret}".encode("ascii"))
|
||||
return {"Authorization": f"Basic {auth_header.decode('ascii')}"}
|
||||
|
||||
async def get(
|
||||
self, url: str, headers: MutableMapping = None, params: MutableMapping = None
|
||||
) -> MutableMapping[str, str]:
|
||||
"""Make a GET request to the spotify API"""
|
||||
if params is None:
|
||||
params = {}
|
||||
async with self.session.request("GET", url, params=params, headers=headers) as r:
|
||||
data = await r.json()
|
||||
if r.status != 200:
|
||||
log.debug(f"Issue making GET request to {url}: [{r.status}] {data}")
|
||||
return data
|
||||
|
||||
def update_token(self, new_token: Mapping[str, str]):
|
||||
self._token = new_token
|
||||
|
||||
async def get_token(self) -> None:
|
||||
"""Get the stored spotify tokens"""
|
||||
if not self._token:
|
||||
self._token = await self.bot.get_shared_api_tokens("spotify")
|
||||
|
||||
self.client_id = self._token.get("client_id", "")
|
||||
self.client_secret = self._token.get("client_secret", "")
|
||||
|
||||
async def get_country_code(self, ctx: Context = None) -> str:
|
||||
return await self.config.guild(ctx.guild).country_code() if ctx else "US"
|
||||
|
||||
async def request_access_token(self) -> MutableMapping:
|
||||
"""Make a spotify call to get the auth token"""
|
||||
await self.get_token()
|
||||
payload = {"grant_type": "client_credentials"}
|
||||
headers = self.make_auth_header(self.client_id, self.client_secret)
|
||||
r = await self.post(TOKEN_ENDPOINT, payload=payload, headers=headers)
|
||||
return r
|
||||
|
||||
async def get_access_token(self) -> Optional[str]:
|
||||
"""Get the access_token"""
|
||||
if self.spotify_token and not await self.is_access_token_valid(self.spotify_token):
|
||||
return self.spotify_token["access_token"]
|
||||
token = await self.request_access_token()
|
||||
if token is None:
|
||||
log.debug("Requested a token from Spotify, did not end up getting one.")
|
||||
try:
|
||||
token["expires_at"] = int(time.time()) + int(token["expires_in"])
|
||||
except KeyError:
|
||||
return None
|
||||
self.spotify_token = token
|
||||
log.debug(f"Created a new access token for Spotify: {token}")
|
||||
return self.spotify_token["access_token"]
|
||||
|
||||
async def post(
|
||||
self, url: str, payload: MutableMapping, headers: MutableMapping = None
|
||||
) -> MutableMapping:
|
||||
"""Make a POST call to spotify"""
|
||||
async with self.session.post(url, data=payload, headers=headers) as r:
|
||||
data = await r.json()
|
||||
if r.status != 200:
|
||||
log.debug(f"Issue making POST request to {url}: [{r.status}] {data}")
|
||||
return data
|
||||
|
||||
async def make_get_call(self, url: str, params: MutableMapping) -> MutableMapping:
|
||||
"""Make a Get call to spotify"""
|
||||
token = await self.get_access_token()
|
||||
return await self.get(url, params=params, headers={"Authorization": f"Bearer {token}"})
|
||||
|
||||
async def get_categories(self, ctx: Context = None) -> List[MutableMapping]:
|
||||
"""Get the spotify categories"""
|
||||
country_code = await self.get_country_code(ctx=ctx)
|
||||
params: MutableMapping = {"country": country_code} if country_code else {}
|
||||
result = await self.make_get_call(CATEGORY_ENDPOINT, params=params)
|
||||
with contextlib.suppress(KeyError):
|
||||
if result["error"]["status"] == 401:
|
||||
raise SpotifyFetchError(
|
||||
message=_(
|
||||
"The Spotify API key or client secret has not been set properly. "
|
||||
"\nUse `{prefix}audioset spotifyapi` for instructions."
|
||||
)
|
||||
)
|
||||
categories = result.get("categories", {}).get("items", [])
|
||||
return [{c["name"]: c["id"]} for c in categories if c]
|
||||
|
||||
async def get_playlist_from_category(self, category: str, ctx: Context = None):
|
||||
"""Get spotify playlists for the specified category"""
|
||||
url = f"{CATEGORY_ENDPOINT}/{category}/playlists"
|
||||
country_code = await self.get_country_code(ctx=ctx)
|
||||
params: MutableMapping = {"country": country_code} if country_code else {}
|
||||
result = await self.make_get_call(url, params=params)
|
||||
playlists = result.get("playlists", {}).get("items", [])
|
||||
return [
|
||||
{
|
||||
"name": c["name"],
|
||||
"uri": c["uri"],
|
||||
"url": c.get("external_urls", {}).get("spotify"),
|
||||
"tracks": c.get("tracks", {}).get("total", "Unknown"),
|
||||
}
|
||||
async for c in AsyncIter(playlists)
|
||||
if c
|
||||
]
|
||||
65
redbot/cogs/audio/apis/youtube.py
Normal file
65
redbot/cogs/audio/apis/youtube.py
Normal file
@@ -0,0 +1,65 @@
|
||||
import logging
|
||||
from typing import Mapping, Optional, TYPE_CHECKING, Union
|
||||
|
||||
import aiohttp
|
||||
|
||||
from redbot.core import Config
|
||||
from redbot.core.bot import Red
|
||||
from redbot.core.commands import Cog
|
||||
|
||||
from ..errors import YouTubeApiError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .. import Audio
|
||||
|
||||
log = logging.getLogger("red.cogs.Audio.api.YouTube")
|
||||
|
||||
SEARCH_ENDPOINT = "https://www.googleapis.com/youtube/v3/search"
|
||||
|
||||
|
||||
class YouTubeWrapper:
|
||||
"""Wrapper for the YouTube Data API."""
|
||||
|
||||
def __init__(
|
||||
self, bot: Red, config: Config, session: aiohttp.ClientSession, cog: Union["Audio", Cog]
|
||||
):
|
||||
self.bot = bot
|
||||
self.config = config
|
||||
self.session = session
|
||||
self.api_key: Optional[str] = None
|
||||
self._token: Mapping[str, str] = {}
|
||||
self.cog = cog
|
||||
|
||||
def update_token(self, new_token: Mapping[str, str]):
|
||||
self._token = new_token
|
||||
|
||||
async def _get_api_key(self,) -> str:
|
||||
"""Get the stored youtube token"""
|
||||
if not self._token:
|
||||
self._token = await self.bot.get_shared_api_tokens("youtube")
|
||||
self.api_key = self._token.get("api_key", "")
|
||||
return self.api_key if self.api_key is not None else ""
|
||||
|
||||
async def get_call(self, query: str) -> Optional[str]:
|
||||
"""Make a Get call to youtube data api"""
|
||||
params = {
|
||||
"q": query,
|
||||
"part": "id",
|
||||
"key": await self._get_api_key(),
|
||||
"maxResults": 1,
|
||||
"type": "video",
|
||||
}
|
||||
async with self.session.request("GET", SEARCH_ENDPOINT, params=params) as r:
|
||||
if r.status in [400, 404]:
|
||||
return None
|
||||
elif r.status in [403, 429]:
|
||||
if r.reason == "quotaExceeded":
|
||||
raise YouTubeApiError("Your YouTube Data API quota has been reached.")
|
||||
return None
|
||||
else:
|
||||
search_response = await r.json()
|
||||
for search_result in search_response.get("items", []):
|
||||
if search_result["id"]["kind"] == "youtube#video":
|
||||
return f"https://www.youtube.com/watch?v={search_result['id']['videoId']}"
|
||||
|
||||
return None
|
||||
Reference in New Issue
Block a user