Merge V3/feature/audio into V3/develop (a.k.a. audio refactor) (#3459)

This commit is contained in:
Draper
2020-05-20 21:30:06 +01:00
committed by GitHub
parent ef76affd77
commit 8fa47cb789
53 changed files with 12372 additions and 10144 deletions

View File

@@ -0,0 +1,10 @@
from . import (
api_utils,
global_db,
interface,
local_db,
playlist_interface,
playlist_wrapper,
spotify,
youtube,
)

View File

@@ -0,0 +1,140 @@
import datetime
import json
import logging
from collections import namedtuple
from dataclasses import dataclass, field
from typing import List, MutableMapping, Optional, Union
import discord
from redbot.core.bot import Red
from redbot.core.utils.chat_formatting import humanize_list
from ..errors import InvalidPlaylistScope, MissingAuthor, MissingGuild
from ..utils import PlaylistScope
log = logging.getLogger("red.cogs.Audio.api.utils")
@dataclass
class YouTubeCacheFetchResult:
query: Optional[str]
last_updated: int
def __post_init__(self):
if isinstance(self.last_updated, int):
self.updated_on: datetime.datetime = datetime.datetime.fromtimestamp(self.last_updated)
@dataclass
class SpotifyCacheFetchResult:
query: Optional[str]
last_updated: int
def __post_init__(self):
if isinstance(self.last_updated, int):
self.updated_on: datetime.datetime = datetime.datetime.fromtimestamp(self.last_updated)
@dataclass
class LavalinkCacheFetchResult:
query: Optional[MutableMapping]
last_updated: int
def __post_init__(self):
if isinstance(self.last_updated, int):
self.updated_on: datetime.datetime = datetime.datetime.fromtimestamp(self.last_updated)
if isinstance(self.query, str):
self.query = json.loads(self.query)
@dataclass
class LavalinkCacheFetchForGlobalResult:
query: str
data: MutableMapping
def __post_init__(self):
if isinstance(self.data, str):
self.data_string = str(self.data)
self.data = json.loads(self.data)
@dataclass
class PlaylistFetchResult:
playlist_id: int
playlist_name: str
scope_id: int
author_id: int
playlist_url: Optional[str] = None
tracks: List[MutableMapping] = field(default_factory=lambda: [])
def __post_init__(self):
if isinstance(self.tracks, str):
self.tracks = json.loads(self.tracks)
def standardize_scope(scope: str) -> str:
"""Convert any of the used scopes into one we are expecting"""
scope = scope.upper()
valid_scopes = ["GLOBAL", "GUILD", "AUTHOR", "USER", "SERVER", "MEMBER", "BOT"]
if scope in PlaylistScope.list():
return scope
elif scope not in valid_scopes:
raise InvalidPlaylistScope(
f'"{scope}" is not a valid playlist scope.'
f" Scope needs to be one of the following: {humanize_list(valid_scopes)}"
)
if scope in ["GLOBAL", "BOT"]:
scope = PlaylistScope.GLOBAL.value
elif scope in ["GUILD", "SERVER"]:
scope = PlaylistScope.GUILD.value
elif scope in ["USER", "MEMBER", "AUTHOR"]:
scope = PlaylistScope.USER.value
return scope
def prepare_config_scope(
bot: Red,
scope,
author: Union[discord.abc.User, int] = None,
guild: Union[discord.Guild, int] = None,
):
"""Return the scope used by Playlists"""
scope = standardize_scope(scope)
if scope == PlaylistScope.GLOBAL.value:
config_scope = [PlaylistScope.GLOBAL.value, bot.user.id]
elif scope == PlaylistScope.USER.value:
if author is None:
raise MissingAuthor("Invalid author for user scope.")
config_scope = [PlaylistScope.USER.value, int(getattr(author, "id", author))]
else:
if guild is None:
raise MissingGuild("Invalid guild for guild scope.")
config_scope = [PlaylistScope.GUILD.value, int(getattr(guild, "id", guild))]
return config_scope
def prepare_config_scope_for_migration23( # TODO: remove me in a future version ?
scope, author: Union[discord.abc.User, int] = None, guild: discord.Guild = None
):
"""Return the scope used by Playlists"""
scope = standardize_scope(scope)
if scope == PlaylistScope.GLOBAL.value:
config_scope = [PlaylistScope.GLOBAL.value]
elif scope == PlaylistScope.USER.value:
if author is None:
raise MissingAuthor("Invalid author for user scope.")
config_scope = [PlaylistScope.USER.value, str(getattr(author, "id", author))]
else:
if guild is None:
raise MissingGuild("Invalid guild for guild scope.")
config_scope = [PlaylistScope.GUILD.value, str(getattr(guild, "id", guild))]
return config_scope
FakePlaylist = namedtuple("Playlist", "author scope")

View File

@@ -0,0 +1,42 @@
import asyncio
import contextlib
import logging
import urllib.parse
from typing import Mapping, Optional, TYPE_CHECKING, Union
import aiohttp
from lavalink.rest_api import LoadResult
from redbot.core import Config
from redbot.core.bot import Red
from redbot.core.commands import Cog
from ..audio_dataclasses import Query
from ..audio_logging import IS_DEBUG, debug_exc_log
if TYPE_CHECKING:
from .. import Audio
_API_URL = "https://redbot.app/"
log = logging.getLogger("red.cogs.Audio.api.GlobalDB")
class GlobalCacheWrapper:
def __init__(
self, bot: Red, config: Config, session: aiohttp.ClientSession, cog: Union["Audio", Cog]
):
# Place Holder for the Global Cache PR
self.bot = bot
self.config = config
self.session = session
self.api_key = None
self._handshake_token = ""
self.can_write = False
self._handshake_token = ""
self.has_api_key = None
self._token: Mapping[str, str] = {}
self.cog = cog
def update_token(self, new_token: Mapping[str, str]):
self._token = new_token

View File

@@ -0,0 +1,894 @@
import asyncio
import datetime
import json
import logging
import random
import time
from collections import namedtuple
from typing import Callable, List, MutableMapping, Optional, TYPE_CHECKING, Tuple, Union, cast
import aiohttp
import discord
import lavalink
from lavalink.rest_api import LoadResult
from redbot.core.utils import AsyncIter
from redbot.core import Config, commands
from redbot.core.bot import Red
from redbot.core.commands import Cog, Context
from redbot.core.i18n import Translator
from redbot.core.utils.dbtools import APSWConnectionWrapper
from ..audio_dataclasses import Query
from ..audio_logging import IS_DEBUG, debug_exc_log
from ..errors import DatabaseError, SpotifyFetchError, TrackEnqueueError
from ..utils import CacheLevel, Notifier
from .global_db import GlobalCacheWrapper
from .local_db import LocalCacheWrapper
from .playlist_interface import get_playlist
from .playlist_wrapper import PlaylistWrapper
from .spotify import SpotifyWrapper
from .youtube import YouTubeWrapper
if TYPE_CHECKING:
from .. import Audio
_ = Translator("Audio", __file__)
log = logging.getLogger("red.cogs.Audio.api.AudioAPIInterface")
_TOP_100_US = "https://www.youtube.com/playlist?list=PL4fGSI1pDJn5rWitrRWFKdm-ulaFiIyoK"
class AudioAPIInterface:
"""Handles music queries.
Always tries the Local cache first, then Global cache before making API calls.
"""
def __init__(
self,
bot: Red,
config: Config,
session: aiohttp.ClientSession,
conn: APSWConnectionWrapper,
cog: Union["Audio", Cog],
):
self.bot = bot
self.config = config
self.conn = conn
self.cog = cog
self.spotify_api: SpotifyWrapper = SpotifyWrapper(self.bot, self.config, session, self.cog)
self.youtube_api: YouTubeWrapper = YouTubeWrapper(self.bot, self.config, session, self.cog)
self.local_cache_api = LocalCacheWrapper(self.bot, self.config, self.conn, self.cog)
self.global_cache_api = GlobalCacheWrapper(self.bot, self.config, session, self.cog)
self._session: aiohttp.ClientSession = session
self._tasks: MutableMapping = {}
self._lock: asyncio.Lock = asyncio.Lock()
async def initialize(self) -> None:
"""Initialises the Local Cache connection"""
await self.local_cache_api.lavalink.init()
def close(self) -> None:
"""Closes the Local Cache connection"""
self.local_cache_api.lavalink.close()
async def get_random_track_from_db(self) -> Optional[MutableMapping]:
"""Get a random track from the local database and return it"""
track: Optional[MutableMapping] = {}
try:
query_data = {}
date = datetime.datetime.now(datetime.timezone.utc) - datetime.timedelta(days=7)
date_timestamp = int(date.timestamp())
query_data["day"] = date_timestamp
max_age = await self.config.cache_age()
maxage = datetime.datetime.now(tz=datetime.timezone.utc) - datetime.timedelta(
days=max_age
)
maxage_int = int(time.mktime(maxage.timetuple()))
query_data["maxage"] = maxage_int
track = await self.local_cache_api.lavalink.fetch_random(query_data)
if track is not None:
if track.get("loadType") == "V2_COMPACT":
track["loadType"] = "V2_COMPAT"
results = LoadResult(track)
track = random.choice(list(results.tracks))
except Exception as exc:
debug_exc_log(log, exc, "Failed to fetch a random track from database")
track = {}
if not track:
return None
return track
async def route_tasks(
self, action_type: str = None, data: Union[List[MutableMapping], MutableMapping] = None,
) -> None:
"""Separate the tasks and run them in the appropriate functions"""
if not data:
return
if action_type == "insert" and isinstance(data, list):
for table, d in data:
if table == "lavalink":
await self.local_cache_api.lavalink.insert(d)
elif table == "youtube":
await self.local_cache_api.youtube.insert(d)
elif table == "spotify":
await self.local_cache_api.spotify.insert(d)
elif action_type == "update" and isinstance(data, dict):
for table, d in data:
if table == "lavalink":
await self.local_cache_api.lavalink.update(data)
elif table == "youtube":
await self.local_cache_api.youtube.update(data)
elif table == "spotify":
await self.local_cache_api.spotify.update(data)
async def run_tasks(self, ctx: Optional[commands.Context] = None, message_id=None) -> None:
"""Run tasks for a specific context"""
if message_id is not None:
lock_id = message_id
elif ctx is not None:
lock_id = ctx.message.id
else:
return
lock_author = ctx.author if ctx else None
async with self._lock:
if lock_id in self._tasks:
if IS_DEBUG:
log.debug(f"Running database writes for {lock_id} ({lock_author})")
try:
tasks = self._tasks[lock_id]
tasks = [self.route_tasks(a, tasks[a]) for a in tasks]
await asyncio.gather(*tasks, return_exceptions=True)
del self._tasks[lock_id]
except Exception as exc:
debug_exc_log(
log, exc, f"Failed database writes for {lock_id} ({lock_author})"
)
else:
if IS_DEBUG:
log.debug(f"Completed database writes for {lock_id} ({lock_author})")
async def run_all_pending_tasks(self) -> None:
"""Run all pending tasks left in the cache, called on cog_unload"""
async with self._lock:
if IS_DEBUG:
log.debug("Running pending writes to database")
try:
tasks: MutableMapping = {"update": [], "insert": [], "global": []}
async for k, task in AsyncIter(self._tasks.items()):
async for t, args in AsyncIter(task.items()):
tasks[t].append(args)
self._tasks = {}
coro_tasks = [self.route_tasks(a, tasks[a]) for a in tasks]
await asyncio.gather(*coro_tasks, return_exceptions=True)
except Exception as exc:
debug_exc_log(log, exc, "Failed database writes")
else:
if IS_DEBUG:
log.debug("Completed pending writes to database have finished")
def append_task(self, ctx: commands.Context, event: str, task: Tuple, _id: int = None) -> None:
"""Add a task to the cache to be run later"""
lock_id = _id or ctx.message.id
if lock_id not in self._tasks:
self._tasks[lock_id] = {"update": [], "insert": [], "global": []}
self._tasks[lock_id][event].append(task)
async def fetch_spotify_query(
self,
ctx: commands.Context,
query_type: str,
uri: str,
notifier: Optional[Notifier],
skip_youtube: bool = False,
current_cache_level: CacheLevel = CacheLevel.none(),
) -> List[str]:
"""Return youtube URLS for the spotify URL provided"""
youtube_urls = []
tracks = await self.fetch_from_spotify_api(
query_type, uri, params=None, notifier=notifier, ctx=ctx
)
total_tracks = len(tracks)
database_entries = []
track_count = 0
time_now = int(datetime.datetime.now(datetime.timezone.utc).timestamp())
youtube_cache = CacheLevel.set_youtube().is_subset(current_cache_level)
async for track in AsyncIter(tracks):
if isinstance(track, str):
break
elif isinstance(track, dict) and track.get("error", {}).get("message") == "invalid id":
continue
(
song_url,
track_info,
uri,
artist_name,
track_name,
_id,
_type,
) = await self.spotify_api.get_spotify_track_info(track, ctx)
database_entries.append(
{
"id": _id,
"type": _type,
"uri": uri,
"track_name": track_name,
"artist_name": artist_name,
"song_url": song_url,
"track_info": track_info,
"last_updated": time_now,
"last_fetched": time_now,
}
)
if skip_youtube is False:
val = None
if youtube_cache:
try:
(val, last_update) = await self.local_cache_api.youtube.fetch_one(
{"track": track_info}
)
except Exception as exc:
debug_exc_log(log, exc, f"Failed to fetch {track_info} from YouTube table")
if val is None:
val = await self.fetch_youtube_query(
ctx, track_info, current_cache_level=current_cache_level
)
if youtube_cache and val:
task = ("update", ("youtube", {"track": track_info}))
self.append_task(ctx, *task)
if val:
youtube_urls.append(val)
else:
youtube_urls.append(track_info)
track_count += 1
if notifier is not None and ((track_count % 2 == 0) or (track_count == total_tracks)):
await notifier.notify_user(current=track_count, total=total_tracks, key="youtube")
if CacheLevel.set_spotify().is_subset(current_cache_level):
task = ("insert", ("spotify", database_entries))
self.append_task(ctx, *task)
return youtube_urls
async def fetch_from_spotify_api(
self,
query_type: str,
uri: str,
recursive: Union[str, bool] = False,
params: MutableMapping = None,
notifier: Optional[Notifier] = None,
ctx: Context = None,
) -> Union[List[MutableMapping], List[str]]:
"""Gets track info from spotify API"""
if recursive is False:
(call, params) = self.spotify_api.spotify_format_call(query_type, uri)
results = await self.spotify_api.make_get_call(call, params)
else:
if isinstance(recursive, str):
results = await self.spotify_api.make_get_call(recursive, params)
else:
results = {}
try:
if results["error"]["status"] == 401 and not recursive:
raise SpotifyFetchError(
_(
"The Spotify API key or client secret has not been set properly. "
"\nUse `{prefix}audioset spotifyapi` for instructions."
)
)
elif recursive:
return {"next": None}
except KeyError:
pass
if recursive:
return results
tracks = []
track_count = 0
total_tracks = results.get("tracks", results).get("total", 1)
while True:
new_tracks: List = []
if query_type == "track":
new_tracks = results
tracks.append(new_tracks)
elif query_type == "album":
tracks_raw = results.get("tracks", results).get("items", [])
if tracks_raw:
new_tracks = tracks_raw
tracks.extend(new_tracks)
else:
tracks_raw = results.get("tracks", results).get("items", [])
if tracks_raw:
new_tracks = [k["track"] for k in tracks_raw if k.get("track")]
tracks.extend(new_tracks)
track_count += len(new_tracks)
if notifier:
await notifier.notify_user(current=track_count, total=total_tracks, key="spotify")
try:
if results.get("next") is not None:
results = await self.fetch_from_spotify_api(
query_type, uri, results["next"], params, notifier=notifier
)
continue
else:
break
except KeyError:
raise SpotifyFetchError(
_("This doesn't seem to be a valid Spotify playlist/album URL or code.")
)
return tracks
async def spotify_query(
self,
ctx: commands.Context,
query_type: str,
uri: str,
skip_youtube: bool = False,
notifier: Optional[Notifier] = None,
) -> List[str]:
"""Queries the Database then falls back to Spotify and YouTube APIs.
Parameters
----------
ctx: commands.Context
The context this method is being called under.
query_type : str
Type of query to perform (Pl
uri: str
Spotify URL ID.
skip_youtube:bool
Whether or not to skip YouTube API Calls.
notifier: Notifier
A Notifier object to handle the user UI notifications while tracks are loaded.
Returns
-------
List[str]
List of Youtube URLs.
"""
current_cache_level = CacheLevel(await self.config.cache_level())
cache_enabled = CacheLevel.set_spotify().is_subset(current_cache_level)
if query_type == "track" and cache_enabled:
try:
(val, last_update) = await self.local_cache_api.spotify.fetch_one(
{"uri": f"spotify:track:{uri}"}
)
except Exception as exc:
debug_exc_log(
log, exc, f"Failed to fetch 'spotify:track:{uri}' from Spotify table"
)
val = None
else:
val = None
youtube_urls = []
if val is None:
urls = await self.fetch_spotify_query(
ctx,
query_type,
uri,
notifier,
skip_youtube,
current_cache_level=current_cache_level,
)
youtube_urls.extend(urls)
else:
if query_type == "track" and cache_enabled:
task = ("update", ("spotify", {"uri": f"spotify:track:{uri}"}))
self.append_task(ctx, *task)
youtube_urls.append(val)
return youtube_urls
async def spotify_enqueue(
self,
ctx: commands.Context,
query_type: str,
uri: str,
enqueue: bool,
player: lavalink.Player,
lock: Callable,
notifier: Optional[Notifier] = None,
forced: bool = False,
query_global: bool = False,
) -> List[lavalink.Track]:
"""Queries the Database then falls back to Spotify and YouTube APIs then Enqueued matched tracks.
Parameters
----------
ctx: commands.Context
The context this method is being called under.
query_type : str
Type of query to perform (Pl
uri: str
Spotify URL ID.
enqueue:bool
Whether or not to enqueue the tracks
player: lavalink.Player
The current Player.
notifier: Notifier
A Notifier object to handle the user UI notifications while tracks are loaded.
lock: Callable
A callable handling the Track enqueue lock while spotify tracks are being added.
query_global: bool
Whether or not to query the global API.
forced: bool
Ignore Cache and make a fetch from API.
Returns
-------
List[str]
List of Youtube URLs.
"""
# globaldb_toggle = await self.config.global_db_enabled()
track_list: List = []
has_not_allowed = False
try:
current_cache_level = CacheLevel(await self.config.cache_level())
guild_data = await self.config.guild(ctx.guild).all()
enqueued_tracks = 0
consecutive_fails = 0
queue_dur = await self.cog.queue_duration(ctx)
queue_total_duration = self.cog.format_time(queue_dur)
before_queue_length = len(player.queue)
tracks_from_spotify = await self.fetch_from_spotify_api(
query_type, uri, params=None, notifier=notifier
)
total_tracks = len(tracks_from_spotify)
if total_tracks < 1 and notifier is not None:
lock(ctx, False)
embed3 = discord.Embed(
colour=await ctx.embed_colour(),
title=_("This doesn't seem to be a supported Spotify URL or code."),
)
await notifier.update_embed(embed3)
return track_list
database_entries = []
time_now = int(datetime.datetime.now(datetime.timezone.utc).timestamp())
youtube_cache = CacheLevel.set_youtube().is_subset(current_cache_level)
spotify_cache = CacheLevel.set_spotify().is_subset(current_cache_level)
async for track_count, track in AsyncIter(tracks_from_spotify).enumerate(start=1):
(
song_url,
track_info,
uri,
artist_name,
track_name,
_id,
_type,
) = await self.spotify_api.get_spotify_track_info(track, ctx)
database_entries.append(
{
"id": _id,
"type": _type,
"uri": uri,
"track_name": track_name,
"artist_name": artist_name,
"song_url": song_url,
"track_info": track_info,
"last_updated": time_now,
"last_fetched": time_now,
}
)
val = None
llresponse = None
if youtube_cache:
try:
(val, last_updated) = await self.local_cache_api.youtube.fetch_one(
{"track": track_info}
)
except Exception as exc:
debug_exc_log(log, exc, f"Failed to fetch {track_info} from YouTube table")
if val is None:
val = await self.fetch_youtube_query(
ctx, track_info, current_cache_level=current_cache_level
)
if youtube_cache and val and llresponse is None:
task = ("update", ("youtube", {"track": track_info}))
self.append_task(ctx, *task)
if llresponse is not None:
track_object = llresponse.tracks
elif val:
try:
(result, called_api) = await self.fetch_track(
ctx,
player,
Query.process_input(val, self.cog.local_folder_current_path),
forced=forced,
)
except (RuntimeError, aiohttp.ServerDisconnectedError):
lock(ctx, False)
error_embed = discord.Embed(
colour=await ctx.embed_colour(),
title=_("The connection was reset while loading the playlist."),
)
if notifier is not None:
await notifier.update_embed(error_embed)
break
except asyncio.TimeoutError:
lock(ctx, False)
error_embed = discord.Embed(
colour=await ctx.embed_colour(),
title=_("Player timeout, skipping remaining tracks."),
)
if notifier is not None:
await notifier.update_embed(error_embed)
break
track_object = result.tracks
else:
track_object = []
if (track_count % 2 == 0) or (track_count == total_tracks):
key = "lavalink"
seconds = "???"
second_key = None
if notifier is not None:
await notifier.notify_user(
current=track_count,
total=total_tracks,
key=key,
seconds_key=second_key,
seconds=seconds,
)
if consecutive_fails >= 10:
error_embed = discord.Embed(
colour=await ctx.embed_colour(),
title=_("Failing to get tracks, skipping remaining."),
)
if notifier is not None:
await notifier.update_embed(error_embed)
break
if not track_object:
consecutive_fails += 1
continue
consecutive_fails = 0
single_track = track_object[0]
if not await self.cog.is_query_allowed(
self.config,
ctx.guild,
(
f"{single_track.title} {single_track.author} {single_track.uri} "
f"{Query.process_input(single_track, self.cog.local_folder_current_path)}"
),
):
has_not_allowed = True
if IS_DEBUG:
log.debug(f"Query is not allowed in {ctx.guild} ({ctx.guild.id})")
continue
track_list.append(single_track)
if enqueue:
if len(player.queue) >= 10000:
continue
if guild_data["maxlength"] > 0:
if self.cog.is_track_too_long(single_track, guild_data["maxlength"]):
enqueued_tracks += 1
player.add(ctx.author, single_track)
self.bot.dispatch(
"red_audio_track_enqueue",
player.channel.guild,
single_track,
ctx.author,
)
else:
enqueued_tracks += 1
player.add(ctx.author, single_track)
self.bot.dispatch(
"red_audio_track_enqueue",
player.channel.guild,
single_track,
ctx.author,
)
if not player.current:
await player.play()
if not track_list and not has_not_allowed:
raise SpotifyFetchError(
message=_(
"Nothing found.\nThe YouTube API key may be invalid "
"or you may be rate limited on YouTube's search service.\n"
"Check the YouTube API key again and follow the instructions "
"at `{prefix}audioset youtubeapi`."
)
)
player.maybe_shuffle()
if enqueue and tracks_from_spotify:
if total_tracks > enqueued_tracks:
maxlength_msg = _(" {bad_tracks} tracks cannot be queued.").format(
bad_tracks=(total_tracks - enqueued_tracks)
)
else:
maxlength_msg = ""
embed = discord.Embed(
colour=await ctx.embed_colour(),
title=_("Playlist Enqueued"),
description=_("Added {num} tracks to the queue.{maxlength_msg}").format(
num=enqueued_tracks, maxlength_msg=maxlength_msg
),
)
if not guild_data["shuffle"] and queue_dur > 0:
embed.set_footer(
text=_(
"{time} until start of playlist"
" playback: starts at #{position} in queue"
).format(time=queue_total_duration, position=before_queue_length + 1)
)
if notifier is not None:
await notifier.update_embed(embed)
lock(ctx, False)
if spotify_cache:
task = ("insert", ("spotify", database_entries))
self.append_task(ctx, *task)
except Exception as exc:
lock(ctx, False)
raise exc
finally:
lock(ctx, False)
return track_list
async def fetch_youtube_query(
self,
ctx: commands.Context,
track_info: str,
current_cache_level: CacheLevel = CacheLevel.none(),
) -> Optional[str]:
"""
Call the Youtube API and returns the youtube URL that the query matched
"""
track_url = await self.youtube_api.get_call(track_info)
if CacheLevel.set_youtube().is_subset(current_cache_level) and track_url:
time_now = int(datetime.datetime.now(datetime.timezone.utc).timestamp())
task = (
"insert",
(
"youtube",
[
{
"track_info": track_info,
"track_url": track_url,
"last_updated": time_now,
"last_fetched": time_now,
}
],
),
)
self.append_task(ctx, *task)
return track_url
async def fetch_from_youtube_api(
self, ctx: commands.Context, track_info: str
) -> Optional[str]:
"""
Gets an YouTube URL from for the query
"""
current_cache_level = CacheLevel(await self.config.cache_level())
cache_enabled = CacheLevel.set_youtube().is_subset(current_cache_level)
val = None
if cache_enabled:
try:
(val, update) = await self.local_cache_api.youtube.fetch_one({"track": track_info})
except Exception as exc:
debug_exc_log(log, exc, f"Failed to fetch {track_info} from YouTube table")
if val is None:
youtube_url = await self.fetch_youtube_query(
ctx, track_info, current_cache_level=current_cache_level
)
else:
if cache_enabled:
task = ("update", ("youtube", {"track": track_info}))
self.append_task(ctx, *task)
youtube_url = val
return youtube_url
async def fetch_track(
self,
ctx: commands.Context,
player: lavalink.Player,
query: Query,
forced: bool = False,
lazy: bool = False,
should_query_global: bool = True,
) -> Tuple[LoadResult, bool]:
"""A replacement for :code:`lavalink.Player.load_tracks`. This will try to get a valid
cached entry first if not found or if in valid it will then call the lavalink API.
Parameters
----------
ctx: commands.Context
The context this method is being called under.
player : lavalink.Player
The player who's requesting the query.
query: audio_dataclasses.Query
The Query object for the query in question.
forced:bool
Whether or not to skip cache and call API first.
lazy:bool
If set to True, it will not call the api if a track is not found.
should_query_global:bool
If the method should query the global database.
Returns
-------
Tuple[lavalink.LoadResult, bool]
Tuple with the Load result and whether or not the API was called.
"""
current_cache_level = CacheLevel(await self.config.cache_level())
cache_enabled = CacheLevel.set_lavalink().is_subset(current_cache_level)
val = None
query = Query.process_input(query, self.cog.local_folder_current_path)
query_string = str(query)
valid_global_entry = False
results = None
called_api = False
prefer_lyrics = await self.cog.get_lyrics_status(ctx)
if prefer_lyrics and query.is_youtube and query.is_search:
query_string = f"{query} - lyrics"
if cache_enabled and not forced and not query.is_local:
try:
(val, last_updated) = await self.local_cache_api.lavalink.fetch_one(
{"query": query_string}
)
except Exception as exc:
debug_exc_log(log, exc, f"Failed to fetch '{query_string}' from Lavalink table")
if val and isinstance(val, dict):
if IS_DEBUG:
log.debug(f"Updating Local Database with {query_string}")
task = ("update", ("lavalink", {"query": query_string}))
self.append_task(ctx, *task)
else:
val = None
if val and not forced and isinstance(val, dict):
valid_global_entry = False
called_api = False
else:
val = None
if valid_global_entry:
pass
elif lazy is True:
called_api = False
elif val and not forced and isinstance(val, dict):
data = val
data["query"] = query_string
if data.get("loadType") == "V2_COMPACT":
data["loadType"] = "V2_COMPAT"
results = LoadResult(data)
called_api = False
if results.has_error:
# If cached value has an invalid entry make a new call so that it gets updated
results, called_api = await self.fetch_track(ctx, player, query, forced=True)
else:
if IS_DEBUG:
log.debug(f"Querying Lavalink api for {query_string}")
called_api = True
try:
results = await player.load_tracks(query_string)
except KeyError:
results = None
except RuntimeError:
raise TrackEnqueueError
if results is None:
results = LoadResult({"loadType": "LOAD_FAILED", "playlistInfo": {}, "tracks": []})
if (
cache_enabled
and results.load_type
and not results.has_error
and not query.is_local
and results.tracks
):
try:
time_now = int(datetime.datetime.now(datetime.timezone.utc).timestamp())
data = json.dumps(results._raw)
if all(k in data for k in ["loadType", "playlistInfo", "isSeekable", "isStream"]):
task = (
"insert",
(
"lavalink",
[
{
"query": query_string,
"data": data,
"last_updated": time_now,
"last_fetched": time_now,
}
],
),
)
self.append_task(ctx, *task)
except Exception as exc:
debug_exc_log(
log,
exc,
f"Failed to enqueue write task for '{query_string}' to Lavalink table",
)
return results, called_api
async def autoplay(self, player: lavalink.Player, playlist_api: PlaylistWrapper):
"""
Enqueue a random track
"""
autoplaylist = await self.config.guild(player.channel.guild).autoplaylist()
current_cache_level = CacheLevel(await self.config.cache_level())
cache_enabled = CacheLevel.set_lavalink().is_subset(current_cache_level)
playlist = None
tracks = None
if autoplaylist["enabled"]:
try:
playlist = await get_playlist(
autoplaylist["id"],
autoplaylist["scope"],
self.bot,
playlist_api,
player.channel.guild,
player.channel.guild.me,
)
tracks = playlist.tracks_obj
except Exception as exc:
debug_exc_log(log, exc, "Failed to fetch playlist for autoplay")
if not tracks or not getattr(playlist, "tracks", None):
if cache_enabled:
track = await self.get_random_track_from_db()
tracks = [] if not track else [track]
if not tracks:
ctx = namedtuple("Context", "message guild cog")
(results, called_api) = await self.fetch_track(
cast(
commands.Context, ctx(player.channel.guild, player.channel.guild, self.cog)
),
player,
Query.process_input(_TOP_100_US, self.cog.local_folder_current_path),
)
tracks = list(results.tracks)
if tracks:
multiple = len(tracks) > 1
valid = not multiple
tries = len(tracks)
track = tracks[0]
while valid is False and multiple:
tries -= 1
if tries <= 0:
raise DatabaseError("No valid entry found")
track = random.choice(tracks)
query = Query.process_input(track, self.cog.local_folder_current_path)
await asyncio.sleep(0.001)
if not query.valid or (
query.is_local
and query.local_track_path is not None
and not query.local_track_path.exists()
):
continue
if not await self.cog.is_query_allowed(
self.config,
player.channel.guild,
(
f"{track.title} {track.author} {track.uri} "
f"{str(Query.process_input(track, self.cog.local_folder_current_path))}"
),
):
if IS_DEBUG:
log.debug(
"Query is not allowed in "
f"{player.channel.guild} ({player.channel.guild.id})"
)
continue
valid = True
track.extras["autoplay"] = True
player.add(player.channel.guild.me, track)
self.bot.dispatch(
"red_audio_track_auto_play", player.channel.guild, track, player.channel.guild.me
)
if not player.current:
await player.play()

View File

@@ -0,0 +1,372 @@
import concurrent
import contextlib
import datetime
import logging
import random
import time
from types import SimpleNamespace
from typing import Callable, List, MutableMapping, Optional, TYPE_CHECKING, Tuple, Union
from redbot.core.utils import AsyncIter
from redbot.core import Config
from redbot.core.bot import Red
from redbot.core.commands import Cog
from redbot.core.utils.dbtools import APSWConnectionWrapper
from ..audio_logging import debug_exc_log
from ..sql_statements import (
LAVALINK_CREATE_INDEX,
LAVALINK_CREATE_TABLE,
LAVALINK_DELETE_OLD_ENTRIES,
LAVALINK_FETCH_ALL_ENTRIES_GLOBAL,
LAVALINK_QUERY,
LAVALINK_QUERY_ALL,
LAVALINK_QUERY_LAST_FETCHED_RANDOM,
LAVALINK_UPDATE,
LAVALINK_UPSERT,
SPOTIFY_CREATE_INDEX,
SPOTIFY_CREATE_TABLE,
SPOTIFY_DELETE_OLD_ENTRIES,
SPOTIFY_QUERY,
SPOTIFY_QUERY_ALL,
SPOTIFY_QUERY_LAST_FETCHED_RANDOM,
SPOTIFY_UPDATE,
SPOTIFY_UPSERT,
YOUTUBE_CREATE_INDEX,
YOUTUBE_CREATE_TABLE,
YOUTUBE_DELETE_OLD_ENTRIES,
YOUTUBE_QUERY,
YOUTUBE_QUERY_ALL,
YOUTUBE_QUERY_LAST_FETCHED_RANDOM,
YOUTUBE_UPDATE,
YOUTUBE_UPSERT,
PRAGMA_FETCH_user_version,
PRAGMA_SET_journal_mode,
PRAGMA_SET_read_uncommitted,
PRAGMA_SET_temp_store,
PRAGMA_SET_user_version,
)
from .api_utils import (
LavalinkCacheFetchForGlobalResult,
LavalinkCacheFetchResult,
SpotifyCacheFetchResult,
YouTubeCacheFetchResult,
)
if TYPE_CHECKING:
from .. import Audio
log = logging.getLogger("red.cogs.Audio.api.LocalDB")
_SCHEMA_VERSION = 3
class BaseWrapper:
def __init__(
self, bot: Red, config: Config, conn: APSWConnectionWrapper, cog: Union["Audio", Cog]
):
self.bot = bot
self.config = config
self.database = conn
self.statement = SimpleNamespace()
self.statement.pragma_temp_store = PRAGMA_SET_temp_store
self.statement.pragma_journal_mode = PRAGMA_SET_journal_mode
self.statement.pragma_read_uncommitted = PRAGMA_SET_read_uncommitted
self.statement.set_user_version = PRAGMA_SET_user_version
self.statement.get_user_version = PRAGMA_FETCH_user_version
self.fetch_result: Optional[Callable] = None
self.cog = cog
async def init(self) -> None:
"""Initialize the local cache"""
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
executor.submit(self.database.cursor().execute, self.statement.pragma_temp_store)
executor.submit(self.database.cursor().execute, self.statement.pragma_journal_mode)
executor.submit(self.database.cursor().execute, self.statement.pragma_read_uncommitted)
executor.submit(self.maybe_migrate)
executor.submit(self.database.cursor().execute, LAVALINK_CREATE_TABLE)
executor.submit(self.database.cursor().execute, LAVALINK_CREATE_INDEX)
executor.submit(self.database.cursor().execute, YOUTUBE_CREATE_TABLE)
executor.submit(self.database.cursor().execute, YOUTUBE_CREATE_INDEX)
executor.submit(self.database.cursor().execute, SPOTIFY_CREATE_TABLE)
executor.submit(self.database.cursor().execute, SPOTIFY_CREATE_INDEX)
await self.clean_up_old_entries()
def close(self) -> None:
"""Close the connection with the local cache"""
with contextlib.suppress(Exception):
self.database.close()
async def clean_up_old_entries(self) -> None:
"""Delete entries older than x in the local cache tables"""
max_age = await self.config.cache_age()
maxage = datetime.datetime.now(tz=datetime.timezone.utc) - datetime.timedelta(days=max_age)
maxage_int = int(time.mktime(maxage.timetuple()))
values = {"maxage": maxage_int}
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
executor.submit(self.database.cursor().execute, LAVALINK_DELETE_OLD_ENTRIES, values)
executor.submit(self.database.cursor().execute, YOUTUBE_DELETE_OLD_ENTRIES, values)
executor.submit(self.database.cursor().execute, SPOTIFY_DELETE_OLD_ENTRIES, values)
def maybe_migrate(self) -> None:
"""Maybe migrate Database schema for the local cache"""
current_version = 0
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
for future in concurrent.futures.as_completed(
[executor.submit(self.database.cursor().execute, self.statement.get_user_version)]
):
try:
row_result = future.result()
current_version = row_result.fetchone()
break
except Exception as exc:
debug_exc_log(log, exc, "Failed to completed fetch from database")
if isinstance(current_version, tuple):
current_version = current_version[0]
if current_version == _SCHEMA_VERSION:
return
executor.submit(
self.database.cursor().execute,
self.statement.set_user_version,
{"version": _SCHEMA_VERSION},
)
async def insert(self, values: List[MutableMapping]) -> None:
"""Insert an entry into the local cache"""
try:
with self.database.transaction() as transaction:
transaction.executemany(self.statement.upsert, values)
except Exception as exc:
debug_exc_log(log, exc, "Error during table insert")
async def update(self, values: MutableMapping) -> None:
"""Update an entry of the local cache"""
try:
time_now = int(datetime.datetime.now(datetime.timezone.utc).timestamp())
values["last_fetched"] = time_now
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
executor.submit(self.database.cursor().execute, self.statement.update, values)
except Exception as exc:
debug_exc_log(log, exc, "Error during table update")
async def _fetch_one(
self, values: MutableMapping
) -> Optional[
Union[LavalinkCacheFetchResult, SpotifyCacheFetchResult, YouTubeCacheFetchResult]
]:
"""Get an entry from the local cache"""
max_age = await self.config.cache_age()
maxage = datetime.datetime.now(tz=datetime.timezone.utc) - datetime.timedelta(days=max_age)
maxage_int = int(time.mktime(maxage.timetuple()))
values.update({"maxage": maxage_int})
row = None
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
for future in concurrent.futures.as_completed(
[executor.submit(self.database.cursor().execute, self.statement.get_one, values)]
):
try:
row_result = future.result()
row = row_result.fetchone()
except Exception as exc:
debug_exc_log(log, exc, "Failed to completed fetch from database")
if not row:
return None
if self.fetch_result is None:
return None
return self.fetch_result(*row)
async def _fetch_all(
self, values: MutableMapping
) -> List[Union[LavalinkCacheFetchResult, SpotifyCacheFetchResult, YouTubeCacheFetchResult]]:
"""Get all entries from the local cache"""
output = []
row_result = []
if self.fetch_result is None:
return []
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
for future in concurrent.futures.as_completed(
[executor.submit(self.database.cursor().execute, self.statement.get_all, values)]
):
try:
row_result = future.result()
except Exception as exc:
debug_exc_log(log, exc, "Failed to completed fetch from database")
async for row in AsyncIter(row_result):
output.append(self.fetch_result(*row))
return output
async def _fetch_random(
self, values: MutableMapping
) -> Optional[
Union[LavalinkCacheFetchResult, SpotifyCacheFetchResult, YouTubeCacheFetchResult]
]:
"""Get a random entry from the local cache"""
row = None
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
for future in concurrent.futures.as_completed(
[
executor.submit(
self.database.cursor().execute, self.statement.get_random, values
)
]
):
try:
row_result = future.result()
rows = row_result.fetchall()
if rows:
row = random.choice(rows)
else:
row = None
except Exception as exc:
debug_exc_log(log, exc, "Failed to completed random fetch from database")
if not row:
return None
if self.fetch_result is None:
return None
return self.fetch_result(*row)
class YouTubeTableWrapper(BaseWrapper):
def __init__(
self, bot: Red, config: Config, conn: APSWConnectionWrapper, cog: Union["Audio", Cog]
):
super().__init__(bot, config, conn, cog)
self.statement.upsert = YOUTUBE_UPSERT
self.statement.update = YOUTUBE_UPDATE
self.statement.get_one = YOUTUBE_QUERY
self.statement.get_all = YOUTUBE_QUERY_ALL
self.statement.get_random = YOUTUBE_QUERY_LAST_FETCHED_RANDOM
self.fetch_result = YouTubeCacheFetchResult
async def fetch_one(
self, values: MutableMapping
) -> Tuple[Optional[str], Optional[datetime.datetime]]:
"""Get an entry from the Youtube table"""
result = await self._fetch_one(values)
if not result or not isinstance(result.query, str):
return None, None
return result.query, result.updated_on
async def fetch_all(self, values: MutableMapping) -> List[YouTubeCacheFetchResult]:
"""Get all entries from the Youtube table"""
result = await self._fetch_all(values)
if result and isinstance(result[0], YouTubeCacheFetchResult):
return result
return []
async def fetch_random(self, values: MutableMapping) -> Optional[str]:
"""Get a random entry from the Youtube table"""
result = await self._fetch_random(values)
if not result or not isinstance(result.query, str):
return None
return result.query
class SpotifyTableWrapper(BaseWrapper):
def __init__(
self, bot: Red, config: Config, conn: APSWConnectionWrapper, cog: Union["Audio", Cog]
):
super().__init__(bot, config, conn, cog)
self.statement.upsert = SPOTIFY_UPSERT
self.statement.update = SPOTIFY_UPDATE
self.statement.get_one = SPOTIFY_QUERY
self.statement.get_all = SPOTIFY_QUERY_ALL
self.statement.get_random = SPOTIFY_QUERY_LAST_FETCHED_RANDOM
self.fetch_result = SpotifyCacheFetchResult
async def fetch_one(
self, values: MutableMapping
) -> Tuple[Optional[str], Optional[datetime.datetime]]:
"""Get an entry from the Spotify table"""
result = await self._fetch_one(values)
if not result or not isinstance(result.query, str):
return None, None
return result.query, result.updated_on
async def fetch_all(self, values: MutableMapping) -> List[SpotifyCacheFetchResult]:
"""Get all entries from the Spotify table"""
result = await self._fetch_all(values)
if result and isinstance(result[0], SpotifyCacheFetchResult):
return result
return []
async def fetch_random(self, values: MutableMapping) -> Optional[str]:
"""Get a random entry from the Spotify table"""
result = await self._fetch_random(values)
if not result or not isinstance(result.query, str):
return None
return result.query
class LavalinkTableWrapper(BaseWrapper):
def __init__(
self, bot: Red, config: Config, conn: APSWConnectionWrapper, cog: Union["Audio", Cog]
):
super().__init__(bot, config, conn, cog)
self.statement.upsert = LAVALINK_UPSERT
self.statement.update = LAVALINK_UPDATE
self.statement.get_one = LAVALINK_QUERY
self.statement.get_all = LAVALINK_QUERY_ALL
self.statement.get_random = LAVALINK_QUERY_LAST_FETCHED_RANDOM
self.statement.get_all_global = LAVALINK_FETCH_ALL_ENTRIES_GLOBAL
self.fetch_result = LavalinkCacheFetchResult
self.fetch_for_global: Optional[Callable] = None
async def fetch_one(
self, values: MutableMapping
) -> Tuple[Optional[MutableMapping], Optional[datetime.datetime]]:
"""Get an entry from the Lavalink table"""
result = await self._fetch_one(values)
if not result or not isinstance(result.query, dict):
return None, None
return result.query, result.updated_on
async def fetch_all(self, values: MutableMapping) -> List[LavalinkCacheFetchResult]:
"""Get all entries from the Lavalink table"""
result = await self._fetch_all(values)
if result and isinstance(result[0], LavalinkCacheFetchResult):
return result
return []
async def fetch_random(self, values: MutableMapping) -> Optional[MutableMapping]:
"""Get a random entry from the Lavalink table"""
result = await self._fetch_random(values)
if not result or not isinstance(result.query, dict):
return None
return result.query
async def fetch_all_for_global(self) -> List[LavalinkCacheFetchForGlobalResult]:
"""Get all entries from the Lavalink table"""
output: List[LavalinkCacheFetchForGlobalResult] = []
row_result = []
if self.fetch_for_global is None:
return []
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
for future in concurrent.futures.as_completed(
[executor.submit(self.database.cursor().execute, self.statement.get_all_global)]
):
try:
row_result = future.result()
except Exception as exc:
debug_exc_log(log, exc, "Failed to completed fetch from database")
async for row in AsyncIter(row_result):
output.append(self.fetch_for_global(*row))
return output
class LocalCacheWrapper:
"""Wraps all table apis into 1 object representing the local cache"""
def __init__(
self, bot: Red, config: Config, conn: APSWConnectionWrapper, cog: Union["Audio", Cog]
):
self.bot = bot
self.config = config
self.database = conn
self.cog = cog
self.lavalink: LavalinkTableWrapper = LavalinkTableWrapper(bot, config, conn, self.cog)
self.spotify: SpotifyTableWrapper = SpotifyTableWrapper(bot, config, conn, self.cog)
self.youtube: YouTubeTableWrapper = YouTubeTableWrapper(bot, config, conn, self.cog)

View File

@@ -0,0 +1,647 @@
import logging
from typing import List, MutableMapping, Optional, Union
import discord
import lavalink
from redbot.core.utils import AsyncIter
from redbot.core import Config, commands
from redbot.core.bot import Red
from ..errors import NotAllowed
from ..utils import PlaylistScope
from .api_utils import PlaylistFetchResult, prepare_config_scope, standardize_scope
from .playlist_wrapper import PlaylistWrapper
log = logging.getLogger("red.cogs.Audio.api.PlaylistsInterface")
class Playlist:
"""A single playlist."""
def __init__(
self,
bot: Red,
playlist_api: PlaylistWrapper,
scope: str,
author: int,
playlist_id: int,
name: str,
playlist_url: Optional[str] = None,
tracks: Optional[List[MutableMapping]] = None,
guild: Union[discord.Guild, int, None] = None,
):
self.bot = bot
self.guild = guild
self.scope = standardize_scope(scope)
self.config_scope = prepare_config_scope(self.bot, self.scope, author, guild)
self.scope_id = self.config_scope[-1]
self.author = author
self.author_id = getattr(self.author, "id", self.author)
self.guild_id = (
getattr(guild, "id", guild) if self.scope == PlaylistScope.GLOBAL.value else None
)
self.id = playlist_id
self.name = name
self.url = playlist_url
self.tracks = tracks or []
self.tracks_obj = [lavalink.Track(data=track) for track in self.tracks]
self.playlist_api = playlist_api
def __repr__(self):
return (
f"Playlist(name={self.name}, id={self.id}, scope={self.scope}, "
f"scope_id={self.scope_id}, author={self.author_id}, "
f"tracks={len(self.tracks)}, url={self.url})"
)
async def edit(self, data: MutableMapping):
"""
Edits a Playlist.
Parameters
----------
data: dict
The attributes to change.
"""
# Disallow ID editing
if "id" in data:
raise NotAllowed("Playlist ID cannot be edited.")
for item in list(data.keys()):
setattr(self, item, data[item])
await self.save()
return self
async def save(self):
"""Saves a Playlist."""
scope, scope_id = self.config_scope
await self.playlist_api.upsert(
scope,
playlist_id=int(self.id),
playlist_name=self.name,
scope_id=scope_id,
author_id=self.author_id,
playlist_url=self.url,
tracks=self.tracks,
)
def to_json(self) -> MutableMapping:
"""Transform the object to a dict.
Returns
-------
dict
The playlist in the form of a dict.
"""
data = dict(
id=self.id,
author=self.author_id,
guild=self.guild_id,
name=self.name,
playlist_url=self.url,
tracks=self.tracks,
)
return data
@classmethod
async def from_json(
cls,
bot: Red,
playlist_api: PlaylistWrapper,
scope: str,
playlist_number: int,
data: PlaylistFetchResult,
**kwargs,
) -> "Playlist":
"""Get a Playlist object from the provided information.
Parameters
----------
bot: Red
The bot's instance. Needed to get the target user.
playlist_api: PlaylistWrapper
The Playlist API interface.
scope:str
The custom config scope. One of 'GLOBALPLAYLIST', 'GUILDPLAYLIST' or 'USERPLAYLIST'.
playlist_number: int
The playlist's number.
data: PlaylistFetchResult
The PlaylistFetchResult representation of the playlist to be gotten.
**kwargs
Extra attributes for the Playlist instance which override values
in the data dict. These should be complete objects and not
IDs, where possible.
Returns
-------
Playlist
The playlist object for the requested playlist.
Raises
------
`InvalidPlaylistScope`
Passing a scope that is not supported.
`MissingGuild`
Trying to access the Guild scope without a guild.
`MissingAuthor`
Trying to access the User scope without an user id.
"""
guild = data.scope_id if scope == PlaylistScope.GUILD.value else kwargs.get("guild")
author = data.author_id
playlist_id = data.playlist_id or playlist_number
name = data.playlist_name
playlist_url = data.playlist_url
tracks = data.tracks
return cls(
bot=bot,
playlist_api=playlist_api,
guild=guild,
scope=scope,
author=author,
playlist_id=playlist_id,
name=name,
playlist_url=playlist_url,
tracks=tracks,
)
class PlaylistCompat23:
"""A single playlist, migrating from Schema 2 to Schema 3"""
def __init__(
self,
bot: Red,
playlist_api: PlaylistWrapper,
scope: str,
author: int,
playlist_id: int,
name: str,
playlist_url: Optional[str] = None,
tracks: Optional[List[MutableMapping]] = None,
guild: Union[discord.Guild, int, None] = None,
):
self.bot = bot
self.guild = guild
self.scope = standardize_scope(scope)
self.author = author
self.id = playlist_id
self.name = name
self.url = playlist_url
self.tracks = tracks or []
self.playlist_api = playlist_api
@classmethod
async def from_json(
cls,
bot: Red,
playlist_api: PlaylistWrapper,
scope: str,
playlist_number: int,
data: MutableMapping,
**kwargs,
) -> "PlaylistCompat23":
"""Get a Playlist object from the provided information.
Parameters
----------
bot: Red
The Bot instance.
playlist_api: PlaylistWrapper
The Playlist API interface.
scope:str
The custom config scope. One of 'GLOBALPLAYLIST', 'GUILDPLAYLIST' or 'USERPLAYLIST'.
playlist_number: int
The playlist's number.
data: MutableMapping
The JSON representation of the playlist to be gotten.
**kwargs
Extra attributes for the Playlist instance which override values
in the data dict. These should be complete objects and not
IDs, where possible.
Returns
-------
Playlist
The playlist object for the requested playlist.
Raises
------
`InvalidPlaylistScope`
Passing a scope that is not supported.
`MissingGuild`
Trying to access the Guild scope without a guild.
`MissingAuthor`
Trying to access the User scope without an user id.
"""
guild = data.get("guild") or kwargs.get("guild")
author: int = data.get("author") or 0
playlist_id = data.get("id") or playlist_number
name = data.get("name", "Unnamed")
playlist_url = data.get("playlist_url", None)
tracks = data.get("tracks", [])
return cls(
bot=bot,
playlist_api=playlist_api,
guild=guild,
scope=scope,
author=author,
playlist_id=playlist_id,
name=name,
playlist_url=playlist_url,
tracks=tracks,
)
async def save(self):
"""Saves a Playlist to SQL."""
scope, scope_id = prepare_config_scope(self.bot, self.scope, self.author, self.guild)
await self.playlist_api.upsert(
scope,
playlist_id=int(self.id),
playlist_name=self.name,
scope_id=scope_id,
author_id=self.author,
playlist_url=self.url,
tracks=self.tracks,
)
async def get_all_playlist_for_migration23(
bot: Red,
playlist_api: PlaylistWrapper,
config: Config,
scope: str,
guild: Union[discord.Guild, int] = None,
) -> List[PlaylistCompat23]:
"""
Gets all playlist for the specified scope.
Parameters
----------
bot: Red
The Bot instance.
playlist_api: PlaylistWrapper
The Playlist API interface.
config: Config
The Audio cog Config instance.
scope: str
The custom config scope. One of 'GLOBALPLAYLIST', 'GUILDPLAYLIST' or 'USERPLAYLIST'.
guild: discord.Guild
The guild to get the playlist from if scope is GUILDPLAYLIST.
Returns
-------
list
A list of all playlists for the specified scope
Raises
------
`InvalidPlaylistScope`
Passing a scope that is not supported.
`MissingGuild`
Trying to access the Guild scope without a guild.
`MissingAuthor`
Trying to access the User scope without an user id.
"""
playlists = await config.custom(scope).all()
if scope == PlaylistScope.GLOBAL.value:
return [
await PlaylistCompat23.from_json(
bot,
playlist_api,
scope,
playlist_number,
playlist_data,
guild=guild,
author=int(playlist_data.get("author", 0)),
)
async for playlist_number, playlist_data in AsyncIter(playlists.items())
]
elif scope == PlaylistScope.USER.value:
return [
await PlaylistCompat23.from_json(
bot,
playlist_api,
scope,
playlist_number,
playlist_data,
guild=guild,
author=int(user_id),
)
async for user_id, scopedata in AsyncIter(playlists.items())
async for playlist_number, playlist_data in AsyncIter(scopedata.items())
]
else:
return [
await PlaylistCompat23.from_json(
bot,
playlist_api,
scope,
playlist_number,
playlist_data,
guild=int(guild_id),
author=int(playlist_data.get("author", 0)),
)
async for guild_id, scopedata in AsyncIter(playlists.items())
async for playlist_number, playlist_data in AsyncIter(scopedata.items())
]
async def get_playlist(
playlist_number: int,
scope: str,
bot: Red,
playlist_api: PlaylistWrapper,
guild: Union[discord.Guild, int] = None,
author: Union[discord.abc.User, int] = None,
) -> Playlist:
"""
Gets the playlist with the associated playlist number.
Parameters
----------
playlist_number: int
The playlist number for the playlist to get.
playlist_api: PlaylistWrapper
The Playlist API interface.
scope: str
The custom config scope. One of 'GLOBALPLAYLIST', 'GUILDPLAYLIST' or 'USERPLAYLIST'.
guild: discord.Guild
The guild to get the playlist from if scope is GUILDPLAYLIST.
author: int
The ID of the user to get the playlist from if scope is USERPLAYLIST.
bot: Red
The bot's instance.
Returns
-------
Playlist
The playlist associated with the playlist number.
Raises
------
`RuntimeError`
If there is no playlist for the specified number.
`InvalidPlaylistScope`
Passing a scope that is not supported.
`MissingGuild`
Trying to access the Guild scope without a guild.
`MissingAuthor`
Trying to access the User scope without an user id.
"""
scope_standard, scope_id = prepare_config_scope(bot, scope, author, guild)
playlist_data = await playlist_api.fetch(scope_standard, playlist_number, scope_id)
if not (playlist_data and playlist_data.playlist_id):
raise RuntimeError(f"That playlist does not exist for the following scope: {scope}")
return await Playlist.from_json(
bot,
playlist_api,
scope_standard,
playlist_number,
playlist_data,
guild=guild,
author=author,
)
async def get_all_playlist(
scope: str,
bot: Red,
playlist_api: PlaylistWrapper,
guild: Union[discord.Guild, int] = None,
author: Union[discord.abc.User, int] = None,
specified_user: bool = False,
) -> List[Playlist]:
"""
Gets all playlist for the specified scope.
Parameters
----------
scope: str
The custom config scope. One of 'GLOBALPLAYLIST', 'GUILDPLAYLIST' or 'USERPLAYLIST'.
guild: discord.Guild
The guild to get the playlist from if scope is GUILDPLAYLIST.
author: int
The ID of the user to get the playlist from if scope is USERPLAYLIST.
bot: Red
The bot's instance
playlist_api: PlaylistWrapper
The Playlist API interface.
specified_user:bool
Whether or not user ID was passed as an argparse.
Returns
-------
list
A list of all playlists for the specified scope
Raises
------
`InvalidPlaylistScope`
Passing a scope that is not supported.
`MissingGuild`
Trying to access the Guild scope without a guild.
`MissingAuthor`
Trying to access the User scope without an user id.
"""
scope_standard, scope_id = prepare_config_scope(bot, scope, author, guild)
if specified_user:
user_id = getattr(author, "id", author)
playlists = await playlist_api.fetch_all(scope_standard, scope_id, author_id=user_id)
else:
playlists = await playlist_api.fetch_all(scope_standard, scope_id)
playlist_list = []
async for playlist in AsyncIter(playlists):
playlist_list.append(
await Playlist.from_json(
bot,
playlist_api,
scope,
playlist.playlist_id,
playlist,
guild=guild,
author=author,
)
)
return playlist_list
async def get_all_playlist_converter(
scope: str,
bot: Red,
playlist_api: PlaylistWrapper,
arg: str,
guild: Union[discord.Guild, int] = None,
author: Union[discord.abc.User, int] = None,
) -> List[Playlist]:
"""
Gets all playlist for the specified scope.
Parameters
----------
scope: str
The custom config scope. One of 'GLOBALPLAYLIST', 'GUILDPLAYLIST' or 'USERPLAYLIST'.
guild: discord.Guild
The guild to get the playlist from if scope is GUILDPLAYLIST.
author: int
The ID of the user to get the playlist from if scope is USERPLAYLIST.
bot: Red
The bot's instance
arg:str
The value to lookup.
playlist_api: PlaylistWrapper
The Playlist API interface.
Returns
-------
list
A list of all playlists for the specified scope
Raises
------
`InvalidPlaylistScope`
Passing a scope that is not supported.
`MissingGuild`
Trying to access the Guild scope without a guild.
`MissingAuthor`
Trying to access the User scope without an user id.
"""
scope_standard, scope_id = prepare_config_scope(bot, scope, author, guild)
playlists = await playlist_api.fetch_all_converter(
scope_standard, playlist_name=arg, playlist_id=arg
)
playlist_list = []
async for playlist in AsyncIter(playlists):
playlist_list.append(
await Playlist.from_json(
bot,
playlist_api,
scope,
playlist.playlist_id,
playlist,
guild=guild,
author=author,
)
)
return playlist_list
async def create_playlist(
ctx: commands.Context,
playlist_api: PlaylistWrapper,
scope: str,
playlist_name: str,
playlist_url: Optional[str] = None,
tracks: Optional[List[MutableMapping]] = None,
author: Optional[discord.User] = None,
guild: Optional[discord.Guild] = None,
) -> Optional[Playlist]:
"""Creates a new Playlist.
Parameters
----------
ctx: commands.Context
The context in which the play list is being created.
scope: str
The custom config scope. One of 'GLOBALPLAYLIST', 'GUILDPLAYLIST' or 'USERPLAYLIST'.
playlist_name: str
The name of the new playlist.
playlist_url:str
the url of the new playlist.
tracks: List[MutableMapping]
A list of tracks to add to the playlist.
author: discord.User
The Author of the playlist.
If provided it will create a playlist under this user.
This is only required when creating a playlist in User scope.
guild: discord.Guild
The guild to create this playlist under.
This is only used when creating a playlist in the Guild scope
playlist_api: PlaylistWrapper
The Playlist API interface.
Raises
------
`InvalidPlaylistScope`
Passing a scope that is not supported.
`MissingGuild`
Trying to access the Guild scope without a guild.
`MissingAuthor`
Trying to access the User scope without an user id.
"""
playlist = Playlist(
ctx.bot,
playlist_api,
scope,
author.id if author else None,
ctx.message.id,
playlist_name,
playlist_url,
tracks,
guild or ctx.guild,
)
await playlist.save()
return playlist
async def reset_playlist(
bot: Red,
playlist_api: PlaylistWrapper,
scope: str,
guild: Union[discord.Guild, int] = None,
author: Union[discord.abc.User, int] = None,
) -> None:
"""Wipes all playlists for the specified scope.
Parameters
----------
bot: Red
The bot's instance
scope: str
The custom config scope. One of 'GLOBALPLAYLIST', 'GUILDPLAYLIST' or 'USERPLAYLIST'.
guild: discord.Guild
The guild to get the playlist from if scope is GUILDPLAYLIST.
author: int
The ID of the user to get the playlist from if scope is USERPLAYLIST.
playlist_api: PlaylistWrapper
The Playlist API interface.
Raises
------
`InvalidPlaylistScope`
Passing a scope that is not supported.
`MissingGuild`
Trying to access the Guild scope without a guild.
`MissingAuthor`
Trying to access the User scope without an user id.
"""
scope, scope_id = prepare_config_scope(bot, scope, author, guild)
await playlist_api.drop(scope)
await playlist_api.create_table()
async def delete_playlist(
bot: Red,
playlist_api: PlaylistWrapper,
scope: str,
playlist_id: Union[str, int],
guild: discord.Guild,
author: Union[discord.abc.User, int] = None,
) -> None:
"""Deletes the specified playlist.
Parameters
----------
bot: Red
The bot's instance
scope: str
The custom config scope. One of 'GLOBALPLAYLIST', 'GUILDPLAYLIST' or 'USERPLAYLIST'.
playlist_id: Union[str, int]
The ID of the playlist.
guild: discord.Guild
The guild to get the playlist from if scope is GUILDPLAYLIST.
author: int
The ID of the user to get the playlist from if scope is USERPLAYLIST.
playlist_api: PlaylistWrapper
The Playlist API interface.
Raises
------
`InvalidPlaylistScope`
Passing a scope that is not supported.
`MissingGuild`
Trying to access the Guild scope without a guild.
`MissingAuthor`
Trying to access the User scope without an user id.
"""
scope, scope_id = prepare_config_scope(bot, scope, author, guild)
await playlist_api.delete(scope, int(playlist_id), scope_id)

View File

@@ -0,0 +1,249 @@
import concurrent
import json
import logging
from types import SimpleNamespace
from typing import List, MutableMapping, Optional
from redbot.core.utils import AsyncIter
from redbot.core import Config
from redbot.core.bot import Red
from redbot.core.utils.dbtools import APSWConnectionWrapper
from ..audio_logging import debug_exc_log
from ..sql_statements import (
PLAYLIST_CREATE_INDEX,
PLAYLIST_CREATE_TABLE,
PLAYLIST_DELETE,
PLAYLIST_DELETE_SCHEDULED,
PLAYLIST_DELETE_SCOPE,
PLAYLIST_FETCH,
PLAYLIST_FETCH_ALL,
PLAYLIST_FETCH_ALL_CONVERTER,
PLAYLIST_FETCH_ALL_WITH_FILTER,
PLAYLIST_UPSERT,
PRAGMA_FETCH_user_version,
PRAGMA_SET_journal_mode,
PRAGMA_SET_read_uncommitted,
PRAGMA_SET_temp_store,
PRAGMA_SET_user_version,
)
from ..utils import PlaylistScope
from .api_utils import PlaylistFetchResult
log = logging.getLogger("red.cogs.Audio.api.Playlists")
class PlaylistWrapper:
def __init__(self, bot: Red, config: Config, conn: APSWConnectionWrapper):
self.bot = bot
self.database = conn
self.config = config
self.statement = SimpleNamespace()
self.statement.pragma_temp_store = PRAGMA_SET_temp_store
self.statement.pragma_journal_mode = PRAGMA_SET_journal_mode
self.statement.pragma_read_uncommitted = PRAGMA_SET_read_uncommitted
self.statement.set_user_version = PRAGMA_SET_user_version
self.statement.get_user_version = PRAGMA_FETCH_user_version
self.statement.create_table = PLAYLIST_CREATE_TABLE
self.statement.create_index = PLAYLIST_CREATE_INDEX
self.statement.upsert = PLAYLIST_UPSERT
self.statement.delete = PLAYLIST_DELETE
self.statement.delete_scope = PLAYLIST_DELETE_SCOPE
self.statement.delete_scheduled = PLAYLIST_DELETE_SCHEDULED
self.statement.get_one = PLAYLIST_FETCH
self.statement.get_all = PLAYLIST_FETCH_ALL
self.statement.get_all_with_filter = PLAYLIST_FETCH_ALL_WITH_FILTER
self.statement.get_all_converter = PLAYLIST_FETCH_ALL_CONVERTER
async def init(self) -> None:
"""Initialize the Playlist table"""
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
executor.submit(self.database.cursor().execute, self.statement.pragma_temp_store)
executor.submit(self.database.cursor().execute, self.statement.pragma_journal_mode)
executor.submit(self.database.cursor().execute, self.statement.pragma_read_uncommitted)
executor.submit(self.database.cursor().execute, self.statement.create_table)
executor.submit(self.database.cursor().execute, self.statement.create_index)
@staticmethod
def get_scope_type(scope: str) -> int:
"""Convert a scope to a numerical identifier"""
if scope == PlaylistScope.GLOBAL.value:
table = 1
elif scope == PlaylistScope.USER.value:
table = 3
else:
table = 2
return table
async def fetch(self, scope: str, playlist_id: int, scope_id: int) -> PlaylistFetchResult:
"""Fetch a single playlist"""
scope_type = self.get_scope_type(scope)
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
for future in concurrent.futures.as_completed(
[
executor.submit(
self.database.cursor().execute,
self.statement.get_one,
(
{
"playlist_id": playlist_id,
"scope_id": scope_id,
"scope_type": scope_type,
}
),
)
]
):
try:
row_result = future.result()
except Exception as exc:
debug_exc_log(log, exc, "Failed to completed playlist fetch from database")
row = row_result.fetchone()
if row:
row = PlaylistFetchResult(*row)
return row
async def fetch_all(
self, scope: str, scope_id: int, author_id=None
) -> List[PlaylistFetchResult]:
"""Fetch all playlists"""
scope_type = self.get_scope_type(scope)
output = []
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
if author_id is not None:
for future in concurrent.futures.as_completed(
[
executor.submit(
self.database.cursor().execute,
self.statement.get_all_with_filter,
(
{
"scope_type": scope_type,
"scope_id": scope_id,
"author_id": author_id,
}
),
)
]
):
try:
row_result = future.result()
except Exception as exc:
debug_exc_log(log, exc, "Failed to completed playlist fetch from database")
return []
else:
for future in concurrent.futures.as_completed(
[
executor.submit(
self.database.cursor().execute,
self.statement.get_all,
({"scope_type": scope_type, "scope_id": scope_id}),
)
]
):
try:
row_result = future.result()
except Exception as exc:
debug_exc_log(log, exc, "Failed to completed playlist fetch from database")
return []
async for row in AsyncIter(row_result):
output.append(PlaylistFetchResult(*row))
return output
async def fetch_all_converter(
self, scope: str, playlist_name, playlist_id
) -> List[PlaylistFetchResult]:
"""Fetch all playlists with the specified filter"""
scope_type = self.get_scope_type(scope)
try:
playlist_id = int(playlist_id)
except Exception as exc:
debug_exc_log(log, exc, "Failed converting playlist_id to int")
playlist_id = -1
output = []
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
for future in concurrent.futures.as_completed(
[
executor.submit(
self.database.cursor().execute,
self.statement.get_all_converter,
(
{
"scope_type": scope_type,
"playlist_name": playlist_name,
"playlist_id": playlist_id,
}
),
)
]
):
try:
row_result = future.result()
except Exception as exc:
debug_exc_log(log, exc, "Failed to completed fetch from database")
async for row in AsyncIter(row_result):
output.append(PlaylistFetchResult(*row))
return output
async def delete(self, scope: str, playlist_id: int, scope_id: int):
"""Deletes a single playlists"""
scope_type = self.get_scope_type(scope)
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
executor.submit(
self.database.cursor().execute,
self.statement.delete,
({"playlist_id": playlist_id, "scope_id": scope_id, "scope_type": scope_type}),
)
async def delete_scheduled(self):
"""Clean up database from all deleted playlists"""
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
executor.submit(self.database.cursor().execute, self.statement.delete_scheduled)
async def drop(self, scope: str):
"""Delete all playlists in a scope"""
scope_type = self.get_scope_type(scope)
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
executor.submit(
self.database.cursor().execute,
self.statement.delete_scope,
({"scope_type": scope_type}),
)
async def create_table(self):
"""Create the playlist table"""
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
executor.submit(self.database.cursor().execute, PLAYLIST_CREATE_TABLE)
async def upsert(
self,
scope: str,
playlist_id: int,
playlist_name: str,
scope_id: int,
author_id: int,
playlist_url: Optional[str],
tracks: List[MutableMapping],
):
"""Insert or update a playlist into the database"""
scope_type = self.get_scope_type(scope)
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
executor.submit(
self.database.cursor().execute,
self.statement.upsert,
{
"scope_type": str(scope_type),
"playlist_id": int(playlist_id),
"playlist_name": str(playlist_name),
"scope_id": int(scope_id),
"author_id": int(author_id),
"playlist_url": playlist_url,
"tracks": json.dumps(tracks),
},
)

View File

@@ -0,0 +1,189 @@
import base64
import contextlib
import logging
import time
from typing import List, Mapping, MutableMapping, Optional, TYPE_CHECKING, Tuple, Union
import aiohttp
from redbot.core.i18n import Translator
from redbot.core.utils import AsyncIter
from redbot.core import Config
from redbot.core.bot import Red
from redbot.core.commands import Cog, Context
from ..errors import SpotifyFetchError
if TYPE_CHECKING:
from .. import Audio
_ = Translator("Audio", __file__)
log = logging.getLogger("red.cogs.Audio.api.Spotify")
CATEGORY_ENDPOINT = "https://api.spotify.com/v1/browse/categories"
TOKEN_ENDPOINT = "https://accounts.spotify.com/api/token"
ALBUMS_ENDPOINT = "https://api.spotify.com/v1/albums"
TRACKS_ENDPOINT = "https://api.spotify.com/v1/tracks"
PLAYLISTS_ENDPOINT = "https://api.spotify.com/v1/playlists"
class SpotifyWrapper:
"""Wrapper for the Spotify API."""
def __init__(
self, bot: Red, config: Config, session: aiohttp.ClientSession, cog: Union["Audio", Cog]
):
self.bot = bot
self.config = config
self.session = session
self.spotify_token: Optional[MutableMapping] = None
self.client_id: Optional[str] = None
self.client_secret: Optional[str] = None
self._token: Mapping[str, str] = {}
self.cog = cog
@staticmethod
def spotify_format_call(query_type: str, key: str) -> Tuple[str, MutableMapping]:
"""Format the spotify endpoint"""
params: MutableMapping = {}
if query_type == "album":
query = f"{ALBUMS_ENDPOINT}/{key}/tracks"
elif query_type == "track":
query = f"{TRACKS_ENDPOINT}/{key}"
else:
query = f"{PLAYLISTS_ENDPOINT}/{key}/tracks"
return query, params
async def get_spotify_track_info(
self, track_data: MutableMapping, ctx: Context
) -> Tuple[str, ...]:
"""Extract track info from spotify response"""
prefer_lyrics = await self.cog.get_lyrics_status(ctx)
track_name = track_data["name"]
if prefer_lyrics:
track_name = f"{track_name} - lyrics"
artist_name = track_data["artists"][0]["name"]
track_info = f"{track_name} {artist_name}"
song_url = track_data.get("external_urls", {}).get("spotify")
uri = track_data["uri"]
_id = track_data["id"]
_type = track_data["type"]
return song_url, track_info, uri, artist_name, track_name, _id, _type
@staticmethod
async def is_access_token_valid(token: MutableMapping) -> bool:
"""Check if current token is not too old"""
return (token["expires_at"] - int(time.time())) < 60
@staticmethod
def make_auth_header(
client_id: Optional[str], client_secret: Optional[str]
) -> MutableMapping[str, Union[str, int]]:
"""Make Authorization header for spotify token"""
if client_id is None:
client_id = ""
if client_secret is None:
client_secret = ""
auth_header = base64.b64encode(f"{client_id}:{client_secret}".encode("ascii"))
return {"Authorization": f"Basic {auth_header.decode('ascii')}"}
async def get(
self, url: str, headers: MutableMapping = None, params: MutableMapping = None
) -> MutableMapping[str, str]:
"""Make a GET request to the spotify API"""
if params is None:
params = {}
async with self.session.request("GET", url, params=params, headers=headers) as r:
data = await r.json()
if r.status != 200:
log.debug(f"Issue making GET request to {url}: [{r.status}] {data}")
return data
def update_token(self, new_token: Mapping[str, str]):
self._token = new_token
async def get_token(self) -> None:
"""Get the stored spotify tokens"""
if not self._token:
self._token = await self.bot.get_shared_api_tokens("spotify")
self.client_id = self._token.get("client_id", "")
self.client_secret = self._token.get("client_secret", "")
async def get_country_code(self, ctx: Context = None) -> str:
return await self.config.guild(ctx.guild).country_code() if ctx else "US"
async def request_access_token(self) -> MutableMapping:
"""Make a spotify call to get the auth token"""
await self.get_token()
payload = {"grant_type": "client_credentials"}
headers = self.make_auth_header(self.client_id, self.client_secret)
r = await self.post(TOKEN_ENDPOINT, payload=payload, headers=headers)
return r
async def get_access_token(self) -> Optional[str]:
"""Get the access_token"""
if self.spotify_token and not await self.is_access_token_valid(self.spotify_token):
return self.spotify_token["access_token"]
token = await self.request_access_token()
if token is None:
log.debug("Requested a token from Spotify, did not end up getting one.")
try:
token["expires_at"] = int(time.time()) + int(token["expires_in"])
except KeyError:
return None
self.spotify_token = token
log.debug(f"Created a new access token for Spotify: {token}")
return self.spotify_token["access_token"]
async def post(
self, url: str, payload: MutableMapping, headers: MutableMapping = None
) -> MutableMapping:
"""Make a POST call to spotify"""
async with self.session.post(url, data=payload, headers=headers) as r:
data = await r.json()
if r.status != 200:
log.debug(f"Issue making POST request to {url}: [{r.status}] {data}")
return data
async def make_get_call(self, url: str, params: MutableMapping) -> MutableMapping:
"""Make a Get call to spotify"""
token = await self.get_access_token()
return await self.get(url, params=params, headers={"Authorization": f"Bearer {token}"})
async def get_categories(self, ctx: Context = None) -> List[MutableMapping]:
"""Get the spotify categories"""
country_code = await self.get_country_code(ctx=ctx)
params: MutableMapping = {"country": country_code} if country_code else {}
result = await self.make_get_call(CATEGORY_ENDPOINT, params=params)
with contextlib.suppress(KeyError):
if result["error"]["status"] == 401:
raise SpotifyFetchError(
message=_(
"The Spotify API key or client secret has not been set properly. "
"\nUse `{prefix}audioset spotifyapi` for instructions."
)
)
categories = result.get("categories", {}).get("items", [])
return [{c["name"]: c["id"]} for c in categories if c]
async def get_playlist_from_category(self, category: str, ctx: Context = None):
"""Get spotify playlists for the specified category"""
url = f"{CATEGORY_ENDPOINT}/{category}/playlists"
country_code = await self.get_country_code(ctx=ctx)
params: MutableMapping = {"country": country_code} if country_code else {}
result = await self.make_get_call(url, params=params)
playlists = result.get("playlists", {}).get("items", [])
return [
{
"name": c["name"],
"uri": c["uri"],
"url": c.get("external_urls", {}).get("spotify"),
"tracks": c.get("tracks", {}).get("total", "Unknown"),
}
async for c in AsyncIter(playlists)
if c
]

View File

@@ -0,0 +1,65 @@
import logging
from typing import Mapping, Optional, TYPE_CHECKING, Union
import aiohttp
from redbot.core import Config
from redbot.core.bot import Red
from redbot.core.commands import Cog
from ..errors import YouTubeApiError
if TYPE_CHECKING:
from .. import Audio
log = logging.getLogger("red.cogs.Audio.api.YouTube")
SEARCH_ENDPOINT = "https://www.googleapis.com/youtube/v3/search"
class YouTubeWrapper:
"""Wrapper for the YouTube Data API."""
def __init__(
self, bot: Red, config: Config, session: aiohttp.ClientSession, cog: Union["Audio", Cog]
):
self.bot = bot
self.config = config
self.session = session
self.api_key: Optional[str] = None
self._token: Mapping[str, str] = {}
self.cog = cog
def update_token(self, new_token: Mapping[str, str]):
self._token = new_token
async def _get_api_key(self,) -> str:
"""Get the stored youtube token"""
if not self._token:
self._token = await self.bot.get_shared_api_tokens("youtube")
self.api_key = self._token.get("api_key", "")
return self.api_key if self.api_key is not None else ""
async def get_call(self, query: str) -> Optional[str]:
"""Make a Get call to youtube data api"""
params = {
"q": query,
"part": "id",
"key": await self._get_api_key(),
"maxResults": 1,
"type": "video",
}
async with self.session.request("GET", SEARCH_ENDPOINT, params=params) as r:
if r.status in [400, 404]:
return None
elif r.status in [403, 429]:
if r.reason == "quotaExceeded":
raise YouTubeApiError("Your YouTube Data API quota has been reached.")
return None
else:
search_response = await r.json()
for search_result in search_response.get("items", []):
if search_result["id"]["kind"] == "youtube#video":
return f"https://www.youtube.com/watch?v={search_result['id']['videoId']}"
return None