mirror of
https://github.com/Cog-Creators/Red-DiscordBot.git
synced 2025-11-06 11:18:54 -05:00
* Limit Playlists Signed-off-by: Drapersniper <27962761+drapersniper@users.noreply.github.com> * logging improvements Signed-off-by: Drapersniper <27962761+drapersniper@users.noreply.github.com> * logging improvements Signed-off-by: Drapersniper <27962761+drapersniper@users.noreply.github.com> * sigh Signed-off-by: Drapersniper <27962761+drapersniper@users.noreply.github.com> * chore Signed-off-by: Drapersniper <27962761+drapersniper@users.noreply.github.com>
943 lines
36 KiB
Python
943 lines
36 KiB
Python
import asyncio
|
|
import base64
|
|
import contextlib
|
|
import datetime
|
|
import json
|
|
import logging
|
|
import random
|
|
import time
|
|
from collections import namedtuple
|
|
from typing import Callable, List, MutableMapping, Optional, TYPE_CHECKING, Tuple, Union, NoReturn
|
|
|
|
import aiohttp
|
|
import discord
|
|
import lavalink
|
|
from lavalink.rest_api import LoadResult
|
|
|
|
from redbot.core import Config, commands
|
|
from redbot.core.bot import Red
|
|
from redbot.core.i18n import Translator, cog_i18n
|
|
|
|
from . import audio_dataclasses
|
|
from .databases import CacheInterface, SQLError
|
|
from .errors import DatabaseError, SpotifyFetchError, YouTubeApiError, TrackEnqueueError
|
|
from .playlists import get_playlist
|
|
from .utils import CacheLevel, Notifier, is_allowed, queue_duration, track_limit
|
|
|
|
# Logger used by every helper in this module.
log = logging.getLogger("red.audio.cache")
# Translation function for user-facing strings, bound to the "Audio" domain.
_ = Translator("Audio", __file__)

# YouTube chart playlists; ``_TOP_100_US`` is the last-resort source for autoplay.
_TOP_100_GLOBALS = "https://www.youtube.com/playlist?list=PL4fGSI1pDJn6puJdseH2Rt9sMvt9E2M4i"
_TOP_100_US = "https://www.youtube.com/playlist?list=PL4fGSI1pDJn5rWitrRWFKdm-ulaFiIyoK"

# Module-level handles, filled in once by ``_pass_config_to_apis``.
# At runtime they start out as ``None``; the annotated declarations exist only
# so static type checkers see the types they will eventually hold.
if not TYPE_CHECKING:
    _database = None
    _bot = None
    _config = None
else:
    _database: CacheInterface
    _bot: Red
    _config: Config
|
|
|
|
|
|
def _pass_config_to_apis(config: Config, bot: Red):
    """Fill in the module-level ``_config``, ``_bot`` and ``_database`` handles.

    Every handle is assigned only while it is still ``None``, so repeated calls
    never replace a value set by an earlier call.

    Parameters
    ----------
    config: Config
        Stored as the module-level ``_config``.
    bot: Red
        Stored as the module-level ``_bot``.
    """
    global _database, _config, _bot
    _config = config if _config is None else _config
    _bot = bot if _bot is None else _bot
    if _database is None:
        # Constructed on the first call only.
        _database = CacheInterface()
|
|
|
|
|
|
class SpotifyAPI:
    """Wrapper for the Spotify API."""

    def __init__(self, bot: Red, session: aiohttp.ClientSession):
        self.bot = bot
        self.session = session
        # Last token payload received from Spotify, extended with an
        # "expires_at" epoch (seconds) by ``_get_spotify_token``.
        self.spotify_token: Optional[MutableMapping[str, Union[str, int]]] = None
        # Credentials are loaded lazily from the bot's shared API tokens.
        self.client_id = None
        self.client_secret = None

    @staticmethod
    async def _check_token(token: MutableMapping):
        """Return ``True`` when ``token`` expires in less than 60 seconds."""
        now = int(time.time())
        return token["expires_at"] - now < 60

    @staticmethod
    def _make_token_auth(
        client_id: Optional[str], client_secret: Optional[str]
    ) -> MutableMapping[str, Union[str, int]]:
        """Build the HTTP Basic ``Authorization`` header for the token endpoint.

        ``None`` credentials are treated as empty strings.
        """
        if client_id is None:
            client_id = ""
        if client_secret is None:
            client_secret = ""

        auth_header = base64.b64encode((client_id + ":" + client_secret).encode("ascii"))
        return {"Authorization": "Basic %s" % auth_header.decode("ascii")}

    async def _make_get(
        self, url: str, headers: MutableMapping = None, params: MutableMapping = None
    ) -> MutableMapping[str, str]:
        """Send a GET request and return the decoded JSON body.

        Non-200 responses are logged at debug level and their body is still returned.
        """
        if params is None:
            params = {}
        async with self.session.request("GET", url, params=params, headers=headers) as r:
            # Decode once and reuse the payload for both the log line and the result.
            data = await r.json()
            if r.status != 200:
                log.debug(
                    "Issue making GET request to {0}: [{1.status}] {2}".format(url, r, data)
                )
            return data

    async def _get_auth(self) -> None:
        """Refresh ``client_id`` and ``client_secret`` from the bot's shared "spotify" tokens."""
        tokens = await self.bot.get_shared_api_tokens("spotify")
        self.client_id = tokens.get("client_id", "")
        self.client_secret = tokens.get("client_secret", "")

    async def _request_token(self) -> MutableMapping[str, Union[str, int]]:
        """Request a token from Spotify using the client-credentials grant."""
        await self._get_auth()

        payload = {"grant_type": "client_credentials"}
        headers = self._make_token_auth(self.client_id, self.client_secret)
        return await self.post_call(
            "https://accounts.spotify.com/api/token", payload=payload, headers=headers
        )

    async def _get_spotify_token(self) -> Optional[str]:
        """Return a usable access token, requesting a new one when needed.

        Returns
        -------
        Optional[str]
            The access token, or ``None`` when Spotify did not return a token
            payload containing ``expires_in``.
        """
        if self.spotify_token and not await self._check_token(self.spotify_token):
            return self.spotify_token["access_token"]
        token = await self._request_token()
        if token is None:
            log.debug("Requested a token from Spotify, did not end up getting one.")
            # Nothing to store; subscripting ``None`` below would raise TypeError.
            return None
        try:
            token["expires_at"] = int(time.time()) + token["expires_in"]
        except KeyError:
            # Error payloads carry no "expires_in"; treat them as "no token".
            return None
        self.spotify_token = token
        # NOTE(review): this writes the whole token payload, access token included,
        # to the debug log.
        log.debug("Created a new access token for Spotify: {0}".format(token))
        return self.spotify_token["access_token"]

    async def post_call(
        self, url: str, payload: MutableMapping, headers: MutableMapping = None
    ) -> MutableMapping[str, Union[str, int]]:
        """Send a form-encoded POST request and return the decoded JSON body.

        Non-200 responses are logged at debug level and their body is still returned.
        """
        async with self.session.post(url, data=payload, headers=headers) as r:
            data = await r.json()
            if r.status != 200:
                log.debug(
                    "Issue making POST request to {0}: [{1.status}] {2}".format(url, r, data)
                )
            return data

    async def get_call(
        self, url: str, params: MutableMapping
    ) -> MutableMapping[str, Union[str, int]]:
        """Send an authenticated GET request to the Spotify API and return its JSON body."""
        token = await self._get_spotify_token()
        return await self._make_get(
            url, params=params, headers={"Authorization": "Bearer {0}".format(token)}
        )

    async def get_categories(self) -> List[MutableMapping]:
        """Return the browse categories as a list of ``{name: id}`` mappings.

        Raises
        ------
        SpotifyFetchError
            When Spotify answers with a 401 error payload.
        """
        url = "https://api.spotify.com/v1/browse/categories"
        result = await self.get_call(url, params={})
        # A successful payload has no "error" key; the KeyError just means "no error".
        with contextlib.suppress(KeyError):
            if result["error"]["status"] == 401:
                raise SpotifyFetchError(
                    message=(
                        "The Spotify API key or client secret has not been set properly. "
                        "\nUse `{prefix}audioset spotifyapi` for instructions."
                    )
                )
        categories = result.get("categories", {}).get("items", [])
        return [{c["name"]: c["id"]} for c in categories]

    async def get_playlist_from_category(self, category: str):
        """Return the playlists of ``category`` as dicts with name, uri, url and track count."""
        url = f"https://api.spotify.com/v1/browse/categories/{category}/playlists"
        result = await self.get_call(url, params={})
        playlists = result.get("playlists", {}).get("items", [])
        return [
            {
                "name": c["name"],
                "uri": c["uri"],
                "url": c.get("external_urls", {}).get("spotify"),
                "tracks": c.get("tracks", {}).get("total", "Unknown"),
            }
            for c in playlists
        ]
|
|
|
|
|
|
class YouTubeAPI:
    """Wrapper for the YouTube Data API."""

    def __init__(self, bot: Red, session: aiohttp.ClientSession):
        self.bot = bot
        self.session = session
        # Loaded from the bot's shared API tokens on every search.
        self.api_key = None

    async def _get_api_key(self) -> str:
        """Refresh and return the API key from the bot's shared "youtube" tokens."""
        tokens = await self.bot.get_shared_api_tokens("youtube")
        self.api_key = tokens.get("api_key", "")
        return self.api_key

    @staticmethod
    async def _quota_exceeded(response) -> bool:
        """Tell whether an error response reports an exhausted API quota.

        The HTTP reason phrase is checked first (the original behaviour). The
        JSON error body is then inspected, because the API reports
        ``quotaExceeded`` inside ``error.errors[].reason``, not in the status line.
        """
        if response.reason == "quotaExceeded":
            return True
        # Best effort: a missing, non-JSON or oddly shaped body means "unknown".
        with contextlib.suppress(aiohttp.ClientError, ValueError, AttributeError, TypeError):
            payload = await response.json()
            errors = payload.get("error", {}).get("errors", [])
            return any(error.get("reason") == "quotaExceeded" for error in errors)
        return False

    async def get_call(self, query: str) -> Optional[str]:
        """Search YouTube for ``query`` and return the URL of the first video.

        Parameters
        ----------
        query: str
            Free-text search string.

        Returns
        -------
        Optional[str]
            A ``https://www.youtube.com/watch?v=...`` URL, or ``None`` when the
            request failed (400/403/404/429) or returned no video.

        Raises
        ------
        YouTubeApiError
            When a 403/429 response reports that the quota has been exceeded.
        """
        params = {
            "q": query,
            "part": "id",
            "key": await self._get_api_key(),
            "maxResults": 1,
            "type": "video",
        }
        yt_url = "https://www.googleapis.com/youtube/v3/search"
        async with self.session.request("GET", yt_url, params=params) as r:
            if r.status in (400, 404):
                return None
            if r.status in (403, 429):
                if await self._quota_exceeded(r):
                    raise YouTubeApiError("Your YouTube Data API quota has been reached.")
                return None
            search_response = await r.json()
            for search_result in search_response.get("items", []):
                if search_result["id"]["kind"] == "youtube#video":
                    return f"https://www.youtube.com/watch?v={search_result['id']['videoId']}"
        return None
|
|
|
|
|
|
@cog_i18n(_)
|
|
class MusicCache:
|
|
"""Handles music queries to the Spotify and Youtube Data API.
|
|
|
|
Always tries the Cache first.
|
|
"""
|
|
|
|
def __init__(self, bot: Red, session: aiohttp.ClientSession):
    """Wire up the API wrappers and the bookkeeping for deferred database writes.

    Parameters
    ----------
    bot: Red
        Bot instance, also handed to both API wrappers.
    session: aiohttp.ClientSession
        HTTP session shared by both API wrappers.
    """
    self.bot = bot
    self._session: aiohttp.ClientSession = session
    self.spotify_api: SpotifyAPI = SpotifyAPI(bot, session)
    self.youtube_api: YouTubeAPI = YouTubeAPI(bot, session)
    # Snapshot of the module-level database handle at construction time.
    self.database = _database

    # Set later by ``initialize``.
    self.config: Optional[Config] = None
    # Guards the flushing of ``_tasks``.
    self._lock: asyncio.Lock = asyncio.Lock()
    # Pending writes, keyed by id: {"update": [...], "insert": [...]}.
    self._tasks: MutableMapping = {}
|
|
|
|
async def initialize(self, config: Config):
    """Remember ``config`` and run the module-level database's ``init()``.

    Parameters
    ----------
    config: Config
        Stored on the instance as ``self.config``.
    """
    self.config = config
    # Uses the module-level handle, not ``self.database``.
    await _database.init()
|
|
|
|
@staticmethod
def _spotify_format_call(qtype: str, key: str) -> Tuple[str, MutableMapping]:
    """Build the Spotify Web API URL for ``key`` together with its query parameters.

    Parameters
    ----------
    qtype: str
        ``"album"`` or ``"track"``; any other value is handled as a playlist.
    key: str
        Spotify ID inserted into the URL.

    Returns
    -------
    Tuple[str, MutableMapping]
        The endpoint URL and a fresh, empty parameter mapping.
    """
    if qtype == "album":
        path = f"albums/{key}/tracks"
    elif qtype == "track":
        path = f"tracks/{key}"
    else:
        # Everything that is neither an album nor a track is a playlist.
        path = f"playlists/{key}/tracks"
    return f"https://api.spotify.com/v1/{path}", {}
|
|
|
|
@staticmethod
def _get_spotify_track_info(track_data: MutableMapping) -> Tuple[str, ...]:
    """Pull the fields this module needs out of a Spotify track object.

    Parameters
    ----------
    track_data: MutableMapping
        Track payload; must contain ``artists``, ``name``, ``uri``, ``id`` and
        ``type``. ``external_urls`` is optional.

    Returns
    -------
    Tuple[str, ...]
        ``(song_url, track_info, uri, artist_name, track_name, id, type)`` where
        ``track_info`` is ``"<track name> <first artist name>"`` and ``song_url``
        is ``None`` when no Spotify URL is present.

    Raises
    ------
    KeyError
        When one of the required keys is missing.
    """
    # Only the first listed artist is used.
    primary_artist = track_data["artists"][0]["name"]
    title = track_data["name"]
    links = track_data.get("external_urls", {})
    return (
        links.get("spotify"),
        f"{title} {primary_artist}",
        track_data["uri"],
        primary_artist,
        title,
        track_data["id"],
        track_data["type"],
    )
|
|
|
|
async def _spotify_first_time_query(
    self,
    ctx: commands.Context,
    query_type: str,
    uri: str,
    notifier: Notifier,
    skip_youtube: bool = False,
    current_cache_level: CacheLevel = CacheLevel.none(),
) -> List[str]:
    """Fetch tracks from Spotify and resolve each of them to a playable query.

    Parameters
    ----------
    ctx: commands.Context
        Context the deferred database writes are queued under.
    query_type: str
        Kind of Spotify object, forwarded to ``_spotify_fetch_tracks``.
    uri: str
        Spotify ID, forwarded to ``_spotify_fetch_tracks``.
    notifier: Notifier
        Progress reporter; a falsy value disables progress updates.
    skip_youtube: bool
        When ``False`` (the default) each track is looked up on YouTube (cache
        first). Otherwise the ``"<name> <artist>"`` search string is returned as is.
    current_cache_level: CacheLevel
        Decides whether the YouTube and Spotify caches are read and written.

    Returns
    -------
    List[str]
        YouTube URLs, or track search strings when ``skip_youtube`` is set.
        Tracks that could not be resolved on YouTube are omitted.
    """
    resolved = []

    spotify_tracks = await self._spotify_fetch_tracks(
        query_type, uri, params=None, notifier=notifier
    )
    total = len(spotify_tracks)
    spotify_rows = []
    processed = 0
    fetched_at = int(datetime.datetime.now(datetime.timezone.utc).timestamp())
    use_youtube_cache = CacheLevel.set_youtube().is_subset(current_cache_level)
    for entry in spotify_tracks:
        # Error payloads for unknown IDs are skipped without counting as processed.
        if entry.get("error", {}).get("message") == "invalid id":
            continue
        (
            song_url,
            track_info,
            track_uri,
            artist_name,
            track_name,
            _id,
            _type,
        ) = self._get_spotify_track_info(entry)

        spotify_rows.append(
            {
                "id": _id,
                "type": _type,
                "uri": track_uri,
                "track_name": track_name,
                "artist_name": artist_name,
                "song_url": song_url,
                "track_info": track_info,
                "last_updated": fetched_at,
                "last_fetched": fetched_at,
            }
        )
        if skip_youtube is not False:
            resolved.append(track_info)
        else:
            url = None
            if use_youtube_cache:
                # Stays True when the lookup fails, which discards any cached value.
                stale = True
                with contextlib.suppress(SQLError):
                    (url, stale) = await self.database.fetch_one(
                        "youtube", "youtube_url", {"track": track_info}
                    )
                if stale:
                    url = None
            if url is None:
                url = await self._youtube_first_time_query(
                    ctx, track_info, current_cache_level=current_cache_level
                )
            if use_youtube_cache and url:
                self.append_task(ctx, "update", ("youtube", {"track": track_info}))

            if url:
                resolved.append(url)
        # Yield to the event loop between tracks.
        await asyncio.sleep(0)
        processed += 1
        if notifier and (processed % 2 == 0 or processed == total):
            await notifier.notify_user(current=processed, total=total, key="youtube")
    if CacheLevel.set_spotify().is_subset(current_cache_level):
        self.append_task(ctx, "insert", ("spotify", spotify_rows))
    return resolved
|
|
|
|
async def _youtube_first_time_query(
    self,
    ctx: commands.Context,
    track_info: str,
    current_cache_level: CacheLevel = CacheLevel.none(),
) -> str:
    """Search the YouTube API for ``track_info`` and queue the hit for caching.

    Parameters
    ----------
    ctx: commands.Context
        Context the deferred database write is queued under.
    track_info: str
        Search string passed to the YouTube API wrapper.
    current_cache_level: CacheLevel
        An insert into the ``youtube`` table is queued only when this level
        includes the YouTube cache and a URL was found.

    Returns
    -------
    str
        Whatever the YouTube API wrapper returned (``None`` when nothing was found).
    """
    found_url = await self.youtube_api.get_call(track_info)
    caching = CacheLevel.set_youtube().is_subset(current_cache_level)
    if caching and found_url:
        stamp = int(datetime.datetime.now(datetime.timezone.utc).timestamp())
        row = {
            "track_info": track_info,
            "track_url": found_url,
            "last_updated": stamp,
            "last_fetched": stamp,
        }
        self.append_task(ctx, "insert", ("youtube", [row]))
    return found_url
|
|
|
|
async def _spotify_fetch_tracks(
    self,
    query_type: str,
    uri: str,
    recursive: Union[str, bool] = False,
    params: MutableMapping = None,
    notifier: Optional[Notifier] = None,
) -> Union[MutableMapping, List[str]]:
    """Fetch every track of a Spotify track, album or playlist, following pagination.

    Parameters
    ----------
    query_type: str
        ``"track"``, ``"album"`` or anything else for a playlist.
    uri: str
        Spotify ID of the object.
    recursive: Union[str, bool]
        ``False`` for the initial call. Internally set to the URL of the next
        page, in which case the raw page payload is returned.
    params: MutableMapping
        Query parameters for a follow-up page request; ignored on the initial call.
    notifier: Optional[Notifier]
        Progress reporter, updated after every page.

    Returns
    -------
    Union[MutableMapping, List[str]]
        The list of track payloads on the initial call. On a follow-up page
        call, the raw page payload, or ``{"next": None}`` when that page returned
        an error, which stops the pagination.

    Raises
    ------
    SpotifyFetchError
        When the initial request is rejected with a 401 error payload.
    """
    if recursive is False:
        (call, params) = self._spotify_format_call(query_type, uri)
        results = await self.spotify_api.get_call(call, params)
    else:
        results = await self.spotify_api.get_call(recursive, params)
    try:
        if results["error"]["status"] == 401 and not recursive:
            raise SpotifyFetchError(
                (
                    "The Spotify API key or client secret has not been set properly. "
                    "\nUse `{prefix}audioset spotifyapi` for instructions."
                )
            )
        elif recursive:
            # A failed follow-up page ends pagination instead of aborting the load.
            return {"next": None}
    except KeyError:
        # No "error" key: the request succeeded.
        pass
    if recursive:
        return results
    tracks = []
    track_count = 0
    total_tracks = results.get("tracks", results).get("total", 1)
    while True:
        if query_type == "track":
            # The payload is the track itself; wrap it so it counts as one track
            # rather than contributing the number of keys in the payload.
            new_tracks = [results]
        else:
            tracks_raw = results.get("tracks", results).get("items", []) or []
            if query_type == "album":
                new_tracks = list(tracks_raw)
            else:
                # Playlist items wrap the track; entries without one are dropped.
                new_tracks = [k["track"] for k in tracks_raw if k.get("track")]
        tracks.extend(new_tracks)
        track_count += len(new_tracks)
        if notifier:
            await notifier.notify_user(current=track_count, total=total_tracks, key="spotify")

        try:
            if results.get("next") is not None:
                results = await self._spotify_fetch_tracks(
                    query_type, uri, results["next"], params, notifier=notifier
                )
                continue
            else:
                break
        except KeyError as exc:
            raise SpotifyFetchError(
                "This doesn't seem to be a valid Spotify playlist/album URL or code."
            ) from exc

    return tracks
|
|
|
|
async def spotify_query(
    self,
    ctx: commands.Context,
    query_type: str,
    uri: str,
    skip_youtube: bool = False,
    notifier: Optional[Notifier] = None,
) -> List[str]:
    """Queries the Database then falls back to Spotify and YouTube APIs.

    Parameters
    ----------
    ctx: commands.Context
        The context this method is being called under.
    query_type : str
        Type of Spotify object to resolve. Only ``"track"`` queries are looked
        up in the local cache.
    uri: str
        Spotify ID of the object.
    skip_youtube:bool
        Whether or not to skip YouTube API Calls.
    notifier: Notifier
        A Notifier object to handle the user UI notifications while tracks are loaded.
    Returns
    -------
    List[str]
        On a cache hit, a one-element list holding the cached ``track_info``
        value. Otherwise the result of the Spotify/YouTube lookup.
    """
    current_cache_level = CacheLevel(await self.config.cache_level())
    use_cache = CacheLevel.set_spotify().is_subset(current_cache_level)
    cacheable = query_type == "track" and use_cache

    cached = None
    if cacheable:
        # Stays True when the lookup fails, which discards any cached value.
        stale = True
        with contextlib.suppress(SQLError):
            (cached, stale) = await self.database.fetch_one(
                "spotify", "track_info", {"uri": f"spotify:track:{uri}"}
            )
        if stale:
            cached = None

    if cached is None:
        fetched = await self._spotify_first_time_query(
            ctx,
            query_type,
            uri,
            notifier,
            skip_youtube,
            current_cache_level=current_cache_level,
        )
        return list(fetched)

    # A hit is only possible on the cacheable path, so the row is always refreshed.
    self.append_task(ctx, "update", ("spotify", {"uri": f"spotify:track:{uri}"}))
    return [cached]
|
|
|
|
async def spotify_enqueue(
    self,
    ctx: commands.Context,
    query_type: str,
    uri: str,
    enqueue: bool,
    player: lavalink.Player,
    lock: Callable,
    notifier: Optional[Notifier] = None,
) -> List[lavalink.Track]:
    """Resolve a Spotify track/album/playlist and optionally add it to the queue.

    Parameters
    ----------
    ctx: commands.Context
        The context this method is being called under.
    query_type: str
        Kind of Spotify object, forwarded to ``_spotify_fetch_tracks``.
    uri: str
        Spotify ID of the object.
    enqueue: bool
        Whether resolved tracks are added to ``player``'s queue.
    player: lavalink.Player
        Player used to load the tracks and, when ``enqueue`` is set, to queue them.
    lock: Callable
        Called as ``lock(ctx, False)`` whenever loading stops, including on error.
    notifier: Optional[Notifier]
        Progress/result reporter; when ``None`` no UI updates are sent.

    Returns
    -------
    List[lavalink.Track]
        The first Lavalink result of every track that resolved and is allowed.

    Raises
    ------
    SpotifyFetchError
        When nothing could be resolved and no track was rejected by the
        allow/deny check, or when fetching from Spotify fails.
    """

    async def show(title: str) -> None:
        # Replace the notifier embed with a title-only embed, if there is a notifier.
        if notifier:
            await notifier.update_embed(
                discord.Embed(colour=await ctx.embed_colour(), title=title)
            )

    track_list = []
    has_not_allowed = False
    try:
        current_cache_level = CacheLevel(await self.config.cache_level())
        guild_data = await self.config.guild(ctx.guild).all()

        enqueued_tracks = 0
        consecutive_fails = 0
        queue_dur = await queue_duration(ctx)
        queue_total_duration = lavalink.utils.format_time(queue_dur)
        before_queue_length = len(player.queue)
        tracks_from_spotify = await self._spotify_fetch_tracks(
            query_type, uri, params=None, notifier=notifier
        )
        total_tracks = len(tracks_from_spotify)
        if total_tracks < 1:
            lock(ctx, False)
            await show(_("This doesn't seem to be a supported Spotify URL or code."))
            return track_list
        database_entries = []
        time_now = int(datetime.datetime.now(datetime.timezone.utc).timestamp())

        youtube_cache = CacheLevel.set_youtube().is_subset(current_cache_level)
        spotify_cache = CacheLevel.set_spotify().is_subset(current_cache_level)
        # 1-based so the last track satisfies ``track_count == total_tracks``.
        for track_count, track in enumerate(tracks_from_spotify, start=1):
            # Same guard as ``_spotify_first_time_query``: error payloads for
            # unknown IDs lack the track fields and would raise KeyError below.
            if track.get("error", {}).get("message") == "invalid id":
                continue
            (
                song_url,
                track_info,
                track_uri,
                artist_name,
                track_name,
                _id,
                _type,
            ) = self._get_spotify_track_info(track)

            database_entries.append(
                {
                    "id": _id,
                    "type": _type,
                    "uri": track_uri,
                    "track_name": track_name,
                    "artist_name": artist_name,
                    "song_url": song_url,
                    "track_info": track_info,
                    "last_updated": time_now,
                    "last_fetched": time_now,
                }
            )
            val = None
            if youtube_cache:
                # Stays True when the lookup fails, which discards any cached value.
                update = True
                with contextlib.suppress(SQLError):
                    (val, update) = await self.database.fetch_one(
                        "youtube", "youtube_url", {"track": track_info}
                    )
                if update:
                    val = None
            if val is None:
                val = await self._youtube_first_time_query(
                    ctx, track_info, current_cache_level=current_cache_level
                )
            if youtube_cache and val:
                self.append_task(ctx, "update", ("youtube", {"track": track_info}))

            if val:
                try:
                    (result, _called_api) = await self.lavalink_query(
                        ctx, player, audio_dataclasses.Query.process_input(val)
                    )
                except (RuntimeError, aiohttp.ServerDisconnectedError):
                    lock(ctx, False)
                    await show(_("The connection was reset while loading the playlist."))
                    break
                except asyncio.TimeoutError:
                    lock(ctx, False)
                    await show(_("Player timeout, skipping remaining tracks."))
                    break
                track_object = result.tracks
            else:
                track_object = []
            if notifier and ((track_count % 2 == 0) or (track_count == total_tracks)):
                await notifier.notify_user(
                    current=track_count,
                    total=total_tracks,
                    key="lavalink",
                    seconds_key=None,
                    seconds="???",
                )

            if consecutive_fails >= 10:
                await show(_("Failing to get tracks, skipping remaining."))
                break
            if not track_object:
                consecutive_fails += 1
                continue
            consecutive_fails = 0
            single_track = track_object[0]
            if not await is_allowed(
                ctx.guild,
                (
                    f"{single_track.title} {single_track.author} {single_track.uri} "
                    f"{str(audio_dataclasses.Query.process_input(single_track))}"
                ),
            ):
                has_not_allowed = True
                log.debug(f"Query is not allowed in {ctx.guild} ({ctx.guild.id})")
                continue
            track_list.append(single_track)
            if enqueue:
                # Hard cap on the queue size.
                if len(player.queue) >= 10000:
                    continue
                max_length = guild_data["maxlength"]
                # A non-positive maxlength disables the per-track length limit.
                if max_length <= 0 or track_limit(single_track, max_length):
                    enqueued_tracks += 1
                    player.add(ctx.author, single_track)
                    self.bot.dispatch(
                        "red_audio_track_enqueue",
                        player.channel.guild,
                        single_track,
                        ctx.author,
                    )

                if not player.current:
                    await player.play()
        if not track_list and not has_not_allowed:
            raise SpotifyFetchError(
                message=_(
                    "Nothing found.\nThe YouTube API key may be invalid "
                    "or you may be rate limited on YouTube's search service.\n"
                    "Check the YouTube API key again and follow the instructions "
                    "at `{prefix}audioset youtubeapi`."
                ).format(prefix=ctx.prefix)
            )
        player.maybe_shuffle()
        if enqueue and tracks_from_spotify:
            if total_tracks > enqueued_tracks:
                maxlength_msg = " {bad_tracks} tracks cannot be queued.".format(
                    bad_tracks=(total_tracks - enqueued_tracks)
                )
            else:
                maxlength_msg = ""

            if notifier:
                embed = discord.Embed(
                    colour=await ctx.embed_colour(),
                    title=_("Playlist Enqueued"),
                    description=_("Added {num} tracks to the queue.{maxlength_msg}").format(
                        num=enqueued_tracks, maxlength_msg=maxlength_msg
                    ),
                )
                if not guild_data["shuffle"] and queue_dur > 0:
                    embed.set_footer(
                        text=_(
                            "{time} until start of playlist"
                            " playback: starts at #{position} in queue"
                        ).format(time=queue_total_duration, position=before_queue_length + 1)
                    )

                await notifier.update_embed(embed)
        lock(ctx, False)

        if spotify_cache:
            self.append_task(ctx, "insert", ("spotify", database_entries))
    finally:
        # Runs on success, early return and any exception alike.
        lock(ctx, False)
    return track_list
|
|
|
|
async def youtube_query(self, ctx: commands.Context, track_info: str) -> str:
    """Resolve ``track_info`` to a YouTube URL, trying the local cache first.

    Parameters
    ----------
    ctx: commands.Context
        Context the deferred database writes are queued under.
    track_info: str
        Search string, also used as the cache key.

    Returns
    -------
    str
        The cached URL when a fresh entry exists, otherwise the result of a
        YouTube API search.
    """
    current_cache_level = CacheLevel(await self.config.cache_level())
    use_cache = CacheLevel.set_youtube().is_subset(current_cache_level)

    cached_url = None
    if use_cache:
        # Stays True when the lookup fails, which discards any cached value.
        stale = True
        with contextlib.suppress(SQLError):
            (cached_url, stale) = await self.database.fetch_one(
                "youtube", "youtube_url", {"track": track_info}
            )
        if stale:
            cached_url = None

    if cached_url is not None:
        # A hit implies the cache is enabled, so the row is always refreshed.
        self.append_task(ctx, "update", ("youtube", {"track": track_info}))
        return cached_url

    return await self._youtube_first_time_query(
        ctx, track_info, current_cache_level=current_cache_level
    )
|
|
|
|
async def lavalink_query(
    self,
    ctx: commands.Context,
    player: lavalink.Player,
    query: audio_dataclasses.Query,
    forced: bool = False,
) -> Tuple[LoadResult, bool]:
    """A replacement for :code:`lavalink.Player.load_tracks`. This will try to get a valid
    cached entry first; if none is found or it is invalid, the Lavalink API is called.

    Parameters
    ----------
    ctx: commands.Context
        The context this method is being called under.
    player : lavalink.Player
        The player who's requesting the query.
    query: audio_dataclasses.Query
        The Query object for the query in question.
    forced:bool
        Whether or not to skip the cache and call the API first.
    Returns
    -------
    Tuple[lavalink.LoadResult, bool]
        Tuple with the Load result and whether or not the API was called.

    Raises
    ------
    TrackEnqueueError
        When the player raises ``RuntimeError`` while loading tracks.
    """
    current_cache_level = CacheLevel(await self.config.cache_level())
    use_cache = CacheLevel.set_lavalink().is_subset(current_cache_level)
    cached = None
    _raw_query = audio_dataclasses.Query.process_input(query)
    query = str(_raw_query)
    if use_cache and not forced and not _raw_query.is_local:
        # Stays True when the lookup fails, which discards any cached value.
        stale = True
        with contextlib.suppress(SQLError):
            (cached, stale) = await self.database.fetch_one("lavalink", "data", {"query": query})
        if stale:
            cached = None
        if cached and isinstance(cached, dict):
            log.debug(f"Querying Local Database for {query}")
            self.append_task(ctx, "update", ("lavalink", {"query": query}))
        else:
            cached = None

    if cached and not forced and isinstance(cached, dict):
        cached["query"] = query
        # Rewrites the "V2_COMPACT" load type to "V2_COMPAT" before building the result.
        if cached.get("loadType") == "V2_COMPACT":
            cached["loadType"] = "V2_COMPAT"
        results = LoadResult(cached)
        if results.has_error:
            # If cached value has an invalid entry make a new call so that it gets updated
            return await self.lavalink_query(ctx, player, _raw_query, forced=True)
        return results, False

    try:
        results = await player.load_tracks(query)
    except KeyError:
        results = None
    except RuntimeError:
        raise TrackEnqueueError
    if results is None:
        results = LoadResult({"loadType": "LOAD_FAILED", "playlistInfo": {}, "tracks": []})
    worth_caching = (
        use_cache
        and results.load_type
        and not results.has_error
        and not _raw_query.is_local
        and results.tracks
    )
    if worth_caching:
        with contextlib.suppress(SQLError):
            stamp = int(datetime.datetime.now(datetime.timezone.utc).timestamp())
            serialized = json.dumps(results._raw)
            # Substring checks on the serialised payload: only cache complete results.
            required = ("loadType", "playlistInfo", "isSeekable", "isStream")
            if all(marker in serialized for marker in required):
                row = {
                    "query": query,
                    "data": serialized,
                    "last_updated": stamp,
                    "last_fetched": stamp,
                }
                self.append_task(ctx, "insert", ("lavalink", [row]))
    return results, True
|
|
|
|
async def run_tasks(self, ctx: Optional[commands.Context] = None, _id=None):
    """Flush the database writes queued under one id.

    Parameters
    ----------
    ctx: Optional[commands.Context]
        Context whose message id identifies the queued writes when ``_id`` is
        not given.
    _id
        Explicit id of the queued writes; takes precedence over ``ctx``.

    Inserts run before updates. Failures of individual writes are collected by
    ``asyncio.gather`` and any other error is suppressed, so this never raises
    once the id has been resolved.
    """
    lock_id = _id or ctx.message.id
    lock_author = ctx.author if ctx else None
    async with self._lock:
        if lock_id in self._tasks:
            log.debug(f"Running database writes for {lock_id} ({lock_author})")
            with contextlib.suppress(Exception):
                # Key by ``lock_id`` so calls made with ``_id`` and no ``ctx`` work too.
                tasks = self._tasks.pop(lock_id)
                await asyncio.gather(
                    *[self.database.insert(*a) for a in tasks["insert"]],
                    return_exceptions=True,
                )
                await asyncio.gather(
                    *[self.database.update(*a) for a in tasks["update"]],
                    return_exceptions=True,
                )
            log.debug(f"Completed database writes for {lock_id} ({lock_author})")
|
|
|
|
async def run_all_pending_tasks(self):
    """Flush every queued database write, regardless of the id it was queued under.

    All queued inserts run first, then all queued updates. The queue is emptied
    before the writes start. Failures of individual writes are collected by
    ``asyncio.gather`` and any other error is suppressed.
    """
    async with self._lock:
        log.debug("Running pending writes to database")
        with contextlib.suppress(Exception):
            tasks = {"update": [], "insert": []}
            for pending in self._tasks.values():
                for event, queued in pending.items():
                    # Each bucket is a list of ``(table, payload)`` tasks; flatten
                    # them so every task is unpacked into exactly one write call.
                    tasks[event].extend(queued)
            self._tasks = {}

            await asyncio.gather(
                *[self.database.insert(*a) for a in tasks["insert"]], return_exceptions=True
            )
            await asyncio.gather(
                *[self.database.update(*a) for a in tasks["update"]], return_exceptions=True
            )
        log.debug("Completed pending writes to database have finished")
|
|
|
|
def append_task(self, ctx: commands.Context, event: str, task: tuple, _id=None):
    """Queue a database write to be flushed later.

    Parameters
    ----------
    ctx: commands.Context
        Context whose message id is the queue key when ``_id`` is not given.
    event: str
        ``"update"`` or ``"insert"``.
    task: tuple
        Arguments later unpacked into the matching database call.
    _id
        Explicit queue key; takes precedence over ``ctx``.
    """
    bucket_id = _id or ctx.message.id
    bucket = self._tasks.setdefault(bucket_id, {"update": [], "insert": []})
    bucket[event].append(task)
|
|
|
|
async def get_random_from_db(self):
    """Pick a random cached Lavalink result and return its tracks.

    The cache is queried with two UTC epoch bounds: ``day`` (seven days ago) and
    ``maxage`` (``cache_age`` days ago, read from the config).

    Returns
    -------
    list
        The tracks of one randomly chosen cached result, or an empty list when
        nothing usable is cached or any error occurs.
    """
    tracks = []
    try:
        now = datetime.datetime.now(datetime.timezone.utc)
        max_age = await self.config.cache_age()
        # ``timestamp()`` on an aware datetime is timezone-correct, whereas
        # ``time.mktime(timetuple())`` would reinterpret the UTC fields as local time.
        query_data = {
            "day": int((now - datetime.timedelta(days=7)).timestamp()),
            "maxage": int((now - datetime.timedelta(days=max_age)).timestamp()),
        }

        vals = await self.database.fetch_all("lavalink", "data", query_data)
        # Keep only rows whose payload is a dict, as required by the code below.
        recently_played = [r.tracks for r in vals if r and isinstance(r.tracks, dict)]

        if recently_played:
            track = random.choice(recently_played)
            # Rewrites the "V2_COMPACT" load type to "V2_COMPAT" before building the result.
            if track.get("loadType") == "V2_COMPACT":
                track["loadType"] = "V2_COMPAT"
            results = LoadResult(track)
            tracks = list(results.tracks)
    except Exception:
        # Best effort: callers fall back to other sources on an empty result.
        log.debug("Failed to fetch a random track from the local cache", exc_info=True)
        tracks = []

    return tracks
|
|
|
|
async def autoplay(self, player: lavalink.Player):
    """Pick a track automatically and add it to ``player``'s queue.

    Sources are tried in this order: the guild's configured auto-playlist (when
    enabled), a random entry from the local Lavalink cache (when that cache
    level is enabled), and finally the ``_TOP_100_US`` playlist. When more than
    one candidate exists, random candidates are tried until one is valid, exists
    (for local tracks) and is allowed in the guild.

    Parameters
    ----------
    player: lavalink.Player
        Player that receives the track; playback is started when nothing is
        currently playing.

    Raises
    ------
    DatabaseError
        When the attempts run out without finding an acceptable candidate.
    """
    autoplaylist = await self.config.guild(player.channel.guild).autoplaylist()
    current_cache_level = CacheLevel(await self.config.cache_level())
    cache_enabled = CacheLevel.set_lavalink().is_subset(current_cache_level)
    playlist = None
    tracks = None
    if autoplaylist["enabled"]:
        # Any failure loading the playlist simply moves on to the fallbacks.
        with contextlib.suppress(Exception):
            playlist = await get_playlist(
                autoplaylist["id"],
                autoplaylist["scope"],
                self.bot,
                player.channel.guild,
                player.channel.guild.me,
            )
            tracks = playlist.tracks_obj

    if not (tracks and getattr(playlist, "tracks", None)):
        if cache_enabled:
            tracks = await self.get_random_from_db()
        if not tracks:
            # Minimal stand-in context: ``message`` is the guild, so queued
            # database writes are keyed by the guild's id.
            fake_ctx = namedtuple("Context", "message")(player.channel.guild)
            (results, _called_api) = await self.lavalink_query(
                fake_ctx,
                player,
                audio_dataclasses.Query.process_input(_TOP_100_US),
            )
            tracks = list(results.tracks)

    if not tracks:
        return

    # A single candidate is used as is, without validation.
    track = tracks[0]
    if len(tracks) > 1:
        tries = len(tracks)
        while True:
            tries -= 1
            if tries <= 0:
                raise DatabaseError("No valid entry found")
            track = random.choice(tracks)
            query = audio_dataclasses.Query.process_input(track)
            await asyncio.sleep(0.001)
            if not query.valid:
                continue
            if query.is_local and not query.track.exists():
                continue
            if not await is_allowed(
                player.channel.guild,
                (
                    f"{track.title} {track.author} {track.uri} "
                    f"{str(audio_dataclasses.Query.process_input(track))}"
                ),
            ):
                log.debug(
                    "Query is not allowed in "
                    f"{player.channel.guild} ({player.channel.guild.id})"
                )
                continue
            break

    track.extras["autoplay"] = True
    player.add(player.channel.guild.me, track)
    self.bot.dispatch(
        "red_audio_track_auto_play", player.channel.guild, track, player.channel.guild.me
    )
    if not player.current:
        await player.play()
|