mirror of
https://github.com/Cog-Creators/Red-DiscordBot.git
synced 2025-11-20 18:06:08 -05:00
Audio Cog - v2.3.0 (#4446)
* First commit - Bring everything from dev cog minus NSFW support
* Add a toggle for auto deafen
* Add a one off Send to Owners
* aaaaaaa
* Update this to ensure `get_perms` is not called if the API is disabled (see the sketch after this list)
* Apply suggestions from code review
Co-authored-by: Vuks <51289041+Vuks69@users.noreply.github.com>
* silence any errors here (in case the API is down, so it doesn't affect audio)
* update the message to tell them to join the Official Red server.
* remove useless stuff, and change DJ check order to ensure the bot doesn't join VC for non-DJs
* ffs
* Update redbot/cogs/audio/core/tasks/startup.py
Co-authored-by: Twentysix <Twentysix26@users.noreply.github.com>
* Aika's review
* Add #3995 in here
* update
* *sigh*
* lock behind owner
* to help with debugging
* Revert "to help with debugging"
This reverts commit 8cbf17be
* resolve last review
Co-authored-by: Vuks <51289041+Vuks69@users.noreply.github.com>
Co-authored-by: Twentysix <Twentysix26@users.noreply.github.com>
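The bullets about `get_perms` and error-silencing above describe a guard rather than a new API behaviour: when the global database is disabled, or no API key is configured, the cog should skip the remote permission lookup entirely and fall back to the cached permissions. A minimal sketch of that guard, assuming hypothetical `config.global_db_enabled()` and `wrapper.get_api_key()` helpers rather than the cog's real attributes:

    async def maybe_fetch_perms(config, wrapper):
        # Skip the remote call when the global API is switched off or no key is set;
        # fall back to whatever permission mapping is already cached.
        if not await config.global_db_enabled() or await wrapper.get_api_key() is None:
            return wrapper.cached_perms
        return await wrapper.get_perms()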
@@ -6,6 +6,7 @@ from dataclasses import dataclass, field
from typing import List, MutableMapping, Optional, Union

import discord
import lavalink

from redbot.core.bot import Red
from redbot.core.utils.chat_formatting import humanize_list
@@ -74,8 +75,22 @@ class PlaylistFetchResult:
            self.tracks = json.loads(self.tracks)


@dataclass
class QueueFetchResult:
    guild_id: int
    room_id: int
    track: dict = field(default_factory=lambda: {})
    track_object: lavalink.Track = None

    def __post_init__(self):
        if isinstance(self.track, str):
            self.track = json.loads(self.track)
        if self.track:
            self.track_object = lavalink.Track(self.track)


def standardize_scope(scope: str) -> str:
    """Convert any of the used scopes into one we are expecting"""
    """Convert any of the used scopes into one we are expecting."""
    scope = scope.upper()
    valid_scopes = ["GLOBAL", "GUILD", "AUTHOR", "USER", "SERVER", "MEMBER", "BOT"]

@@ -103,7 +118,7 @@ def prepare_config_scope(
    author: Union[discord.abc.User, int] = None,
    guild: Union[discord.Guild, int] = None,
):
    """Return the scope used by Playlists"""
    """Return the scope used by Playlists."""
    scope = standardize_scope(scope)
    if scope == PlaylistScope.GLOBAL.value:
        config_scope = [PlaylistScope.GLOBAL.value, bot.user.id]
@@ -121,7 +136,7 @@ def prepare_config_scope(
def prepare_config_scope_for_migration23( # TODO: remove me in a future version ?
    scope, author: Union[discord.abc.User, int] = None, guild: discord.Guild = None
):
    """Return the scope used by Playlists"""
    """Return the scope used by Playlists."""
    scope = standardize_scope(scope)

    if scope == PlaylistScope.GLOBAL.value:

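For reference, the new QueueFetchResult above mirrors how rows come back from the persistent-queue table: `track` may arrive as a JSON string and is decoded in `__post_init__`, and a `lavalink.Track` is only built when the decoded payload is non-empty. A small illustration (the values are made up; an empty payload simply leaves `track_object` as None):

    row = (133251234567890, 987654321, "{}")   # guild_id, room_id, track JSON from the DB
    result = QueueFetchResult(*row)
    assert result.track == {} and result.track_object is None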
@@ -1,8 +1,10 @@
import asyncio
import contextlib
import json
import logging
import urllib.parse
from typing import Mapping, Optional, TYPE_CHECKING, Union

from copy import copy
from typing import TYPE_CHECKING, Mapping, Optional, Union

import aiohttp
from lavalink.rest_api import LoadResult
@@ -17,7 +19,7 @@ from ..audio_logging import IS_DEBUG, debug_exc_log
if TYPE_CHECKING:
    from .. import Audio

_API_URL = "https://redbot.app/"
_API_URL = "https://api.redbot.app/"

log = logging.getLogger("red.cogs.Audio.api.GlobalDB")

@@ -32,11 +34,150 @@ class GlobalCacheWrapper:
        self.session = session
        self.api_key = None
        self._handshake_token = ""
        self.can_write = False
        self._handshake_token = ""
        self.has_api_key = None
        self._token: Mapping[str, str] = {}
        self.cog = cog

    def update_token(self, new_token: Mapping[str, str]):
        self._token = new_token

    async def _get_api_key(
        self,
    ) -> Optional[str]:
        if not self._token:
            self._token = await self.bot.get_shared_api_tokens("audiodb")
        self.api_key = self._token.get("api_key", None)
        self.has_api_key = self.cog.global_api_user.get("can_post")
        id_list = list(self.bot.owner_ids)
        self._handshake_token = "||".join(map(str, id_list))
        return self.api_key

    async def get_call(self, query: Optional[Query] = None) -> dict:
        api_url = f"{_API_URL}api/v2/queries"
        if not self.cog.global_api_user.get("can_read"):
            return {}
        try:
            query = Query.process_input(query, self.cog.local_folder_current_path)
            if any([not query or not query.valid or query.is_spotify or query.is_local]):
                return {}
            await self._get_api_key()
            if self.api_key is None:
                return {}
            search_response = "error"
            query = query.lavalink_query
            with contextlib.suppress(aiohttp.ContentTypeError, asyncio.TimeoutError):
                async with self.session.get(
                    api_url,
                    timeout=aiohttp.ClientTimeout(total=await self.config.global_db_get_timeout()),
                    headers={"Authorization": self.api_key, "X-Token": self._handshake_token},
                    params={"query": query},
                ) as r:
                    search_response = await r.json(loads=json.loads)
                    if IS_DEBUG and "x-process-time" in r.headers:
                        log.debug(
                            f"GET || Ping {r.headers.get('x-process-time')} || "
                            f"Status code {r.status} || {query}"
                        )
            if "tracks" not in search_response:
                return {}
            return search_response
        except Exception as err:
            debug_exc_log(log, err, f"Failed to Get query: {api_url}/{query}")
        return {}

    async def get_spotify(self, title: str, author: Optional[str]) -> dict:
        if not self.cog.global_api_user.get("can_read"):
            return {}
        api_url = f"{_API_URL}api/v2/queries/spotify"
        try:
            search_response = "error"
            params = {"title": title, "author": author}
            await self._get_api_key()
            if self.api_key is None:
                return {}
            with contextlib.suppress(aiohttp.ContentTypeError, asyncio.TimeoutError):
                async with self.session.get(
                    api_url,
                    timeout=aiohttp.ClientTimeout(total=await self.config.global_db_get_timeout()),
                    headers={"Authorization": self.api_key, "X-Token": self._handshake_token},
                    params=params,
                ) as r:
                    search_response = await r.json(loads=json.loads)
                    if IS_DEBUG and "x-process-time" in r.headers:
                        log.debug(
                            f"GET/spotify || Ping {r.headers.get('x-process-time')} || "
                            f"Status code {r.status} || {title} - {author}"
                        )
            if "tracks" not in search_response:
                return {}
            return search_response
        except Exception as err:
            debug_exc_log(log, err, f"Failed to Get query: {api_url}")
        return {}

    async def post_call(self, llresponse: LoadResult, query: Optional[Query]) -> None:
        try:
            if not self.cog.global_api_user.get("can_post"):
                return
            query = Query.process_input(query, self.cog.local_folder_current_path)
            if llresponse.has_error or llresponse.load_type.value in ["NO_MATCHES", "LOAD_FAILED"]:
                return
            if query and query.valid and query.is_youtube:
                query = query.lavalink_query
            else:
                return None
            await self._get_api_key()
            if self.api_key is None:
                return None
            api_url = f"{_API_URL}api/v2/queries"
            async with self.session.post(
                api_url,
                json=llresponse._raw,
                headers={"Authorization": self.api_key, "X-Token": self._handshake_token},
                params={"query": query},
            ) as r:
                await r.read()
                if IS_DEBUG and "x-process-time" in r.headers:
                    log.debug(
                        f"POST || Ping {r.headers.get('x-process-time')} ||"
                        f" Status code {r.status} || {query}"
                    )
        except Exception as err:
            debug_exc_log(log, err, f"Failed to post query: {query}")
        await asyncio.sleep(0)

    async def update_global(self, llresponse: LoadResult, query: Optional[Query] = None):
        await self.post_call(llresponse=llresponse, query=query)

    async def report_invalid(self, id: str) -> None:
        if not self.cog.global_api_user.get("can_delete"):
            return
        api_url = f"{_API_URL}api/v2/queries/es/id"
        with contextlib.suppress(Exception):
            async with self.session.delete(
                api_url,
                headers={"Authorization": self.api_key, "X-Token": self._handshake_token},
                params={"id": id},
            ) as r:
                await r.read()

    async def get_perms(self):
        global_api_user = copy(self.cog.global_api_user)
        await self._get_api_key()
        is_enabled = await self.config.global_db_enabled()
        await self._get_api_key()
        if (not is_enabled) or self.api_key is None:
            return global_api_user
        with contextlib.suppress(Exception):
            async with aiohttp.ClientSession(json_serialize=json.dumps) as session:
                async with session.get(
                    f"{_API_URL}api/v2/users/me",
                    headers={"Authorization": self.api_key, "X-Token": self._handshake_token},
                ) as resp:
                    if resp.status == 200:
                        search_response = await resp.json(loads=json.loads)
                        global_api_user["fetched"] = True
                        global_api_user["can_read"] = search_response.get("can_read", False)
                        global_api_user["can_post"] = search_response.get("can_post", False)
                        global_api_user["can_delete"] = search_response.get("can_delete", False)
        return global_api_user

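Note how each wrapper method above checks a different flag on `global_api_user` before touching the API: `can_read` gates `get_call`/`get_spotify`, `can_post` gates `post_call`, and `can_delete` gates `report_invalid`, with `get_perms` refreshing those flags from `/api/v2/users/me`. A condensed sketch of that gating pattern (the permission dict shape is taken from `get_perms`; the dispatch helper is illustrative, not part of the cog):

    perms = {"fetched": False, "can_read": False, "can_post": False, "can_delete": False}

    def allowed(action: str) -> bool:
        # Map an intended call onto the permission flag it requires.
        required = {"get": "can_read", "post": "can_post", "delete": "can_delete"}[action]
        return perms.get(required, False)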
@@ -1,30 +1,34 @@
import asyncio
import contextlib
import datetime
import json
import logging
import random
import time

from collections import namedtuple
from typing import Callable, List, MutableMapping, Optional, TYPE_CHECKING, Tuple, Union, cast
from typing import TYPE_CHECKING, Callable, List, MutableMapping, Optional, Tuple, Union, cast

import aiohttp
import discord
import lavalink
from lavalink.rest_api import LoadResult
from redbot.core.utils import AsyncIter

from lavalink.rest_api import LoadResult, LoadType
from redbot.core import Config, commands
from redbot.core.bot import Red
from redbot.core.commands import Cog, Context
from redbot.core.i18n import Translator
from redbot.core.utils import AsyncIter
from redbot.core.utils.dbtools import APSWConnectionWrapper

from ..audio_dataclasses import Query
from ..audio_logging import IS_DEBUG, debug_exc_log
from ..errors import DatabaseError, SpotifyFetchError, TrackEnqueueError
from ..utils import CacheLevel, Notifier
from .api_utils import LavalinkCacheFetchForGlobalResult
from .global_db import GlobalCacheWrapper
from .local_db import LocalCacheWrapper
from .persist_queue_wrapper import QueueInterface
from .playlist_interface import get_playlist
from .playlist_wrapper import PlaylistWrapper
from .spotify import SpotifyWrapper
@@ -36,6 +40,7 @@ if TYPE_CHECKING:
_ = Translator("Audio", __file__)
log = logging.getLogger("red.cogs.Audio.api.AudioAPIInterface")
_TOP_100_US = "https://www.youtube.com/playlist?list=PL4fGSI1pDJn5rWitrRWFKdm-ulaFiIyoK"
# TODO: Get random from global Cache


class AudioAPIInterface:
@@ -60,20 +65,22 @@ class AudioAPIInterface:
        self.youtube_api: YouTubeWrapper = YouTubeWrapper(self.bot, self.config, session, self.cog)
        self.local_cache_api = LocalCacheWrapper(self.bot, self.config, self.conn, self.cog)
        self.global_cache_api = GlobalCacheWrapper(self.bot, self.config, session, self.cog)
        self.persistent_queue_api = QueueInterface(self.bot, self.config, self.conn, self.cog)
        self._session: aiohttp.ClientSession = session
        self._tasks: MutableMapping = {}
        self._lock: asyncio.Lock = asyncio.Lock()

    async def initialize(self) -> None:
        """Initialises the Local Cache connection"""
        """Initialises the Local Cache connection."""
        await self.local_cache_api.lavalink.init()
        await self.persistent_queue_api.init()

    def close(self) -> None:
        """Closes the Local Cache connection"""
        """Closes the Local Cache connection."""
        self.local_cache_api.lavalink.close()

    async def get_random_track_from_db(self) -> Optional[MutableMapping]:
        """Get a random track from the local database and return it"""
    async def get_random_track_from_db(self, tries=0) -> Optional[MutableMapping]:
        """Get a random track from the local database and return it."""
        track: Optional[MutableMapping] = {}
        try:
            query_data = {}
@@ -106,7 +113,7 @@ class AudioAPIInterface:
        action_type: str = None,
        data: Union[List[MutableMapping], MutableMapping] = None,
    ) -> None:
        """Separate the tasks and run them in the appropriate functions"""
        """Separate the tasks and run them in the appropriate functions."""

        if not data:
            return
@@ -126,9 +133,11 @@ class AudioAPIInterface:
                    await self.local_cache_api.youtube.update(data)
                elif table == "spotify":
                    await self.local_cache_api.spotify.update(data)
        elif action_type == "global" and isinstance(data, list):
            await asyncio.gather(*[self.global_cache_api.update_global(**d) for d in data])

    async def run_tasks(self, ctx: Optional[commands.Context] = None, message_id=None) -> None:
        """Run tasks for a specific context"""
        """Run tasks for a specific context."""
        if message_id is not None:
            lock_id = message_id
        elif ctx is not None:
@@ -143,7 +152,7 @@ class AudioAPIInterface:
                try:
                    tasks = self._tasks[lock_id]
                    tasks = [self.route_tasks(a, tasks[a]) for a in tasks]
                    await asyncio.gather(*tasks, return_exceptions=True)
                    await asyncio.gather(*tasks, return_exceptions=False)
                    del self._tasks[lock_id]
                except Exception as exc:
                    debug_exc_log(
@@ -154,7 +163,7 @@ class AudioAPIInterface:
                    log.debug(f"Completed database writes for {lock_id} ({lock_author})")

    async def run_all_pending_tasks(self) -> None:
        """Run all pending tasks left in the cache, called on cog_unload"""
        """Run all pending tasks left in the cache, called on cog_unload."""
        async with self._lock:
            if IS_DEBUG:
                log.debug("Running pending writes to database")
@@ -166,7 +175,7 @@ class AudioAPIInterface:
                self._tasks = {}
                coro_tasks = [self.route_tasks(a, tasks[a]) for a in tasks]

                await asyncio.gather(*coro_tasks, return_exceptions=True)
                await asyncio.gather(*coro_tasks, return_exceptions=False)

            except Exception as exc:
                debug_exc_log(log, exc, "Failed database writes")
@@ -175,7 +184,7 @@ class AudioAPIInterface:
                log.debug("Completed pending writes to database have finished")

    def append_task(self, ctx: commands.Context, event: str, task: Tuple, _id: int = None) -> None:
        """Add a task to the cache to be run later"""
        """Add a task to the cache to be run later."""
        lock_id = _id or ctx.message.id
        if lock_id not in self._tasks:
            self._tasks[lock_id] = {"update": [], "insert": [], "global": []}
@@ -190,7 +199,7 @@ class AudioAPIInterface:
        skip_youtube: bool = False,
        current_cache_level: CacheLevel = CacheLevel.none(),
    ) -> List[str]:
        """Return youtube URLS for the spotify URL provided"""
        """Return youtube URLS for the spotify URL provided."""
        youtube_urls = []
        tracks = await self.fetch_from_spotify_api(
            query_type, uri, params=None, notifier=notifier, ctx=ctx
@@ -266,7 +275,7 @@ class AudioAPIInterface:
        notifier: Optional[Notifier] = None,
        ctx: Context = None,
    ) -> Union[List[MutableMapping], List[str]]:
        """Gets track info from spotify API"""
        """Gets track info from spotify API."""

        if recursive is False:
            (call, params) = self.spotify_api.spotify_format_call(query_type, uri)
@@ -394,9 +403,10 @@ class AudioAPIInterface:
        lock: Callable,
        notifier: Optional[Notifier] = None,
        forced: bool = False,
        query_global: bool = False,
        query_global: bool = True,
    ) -> List[lavalink.Track]:
        """Queries the Database then falls back to Spotify and YouTube APIs then Enqueued matched tracks.
        """Queries the Database then falls back to Spotify and YouTube APIs then Enqueued matched
        tracks.

        Parameters
        ----------
@@ -423,7 +433,9 @@ class AudioAPIInterface:
        List[str]
            List of Youtube URLs.
        """
        # globaldb_toggle = await self.config.global_db_enabled()
        await self.global_cache_api._get_api_key()
        globaldb_toggle = await self.config.global_db_enabled()
        global_entry = globaldb_toggle and query_global
        track_list: List = []
        has_not_allowed = False
        try:
@@ -485,7 +497,14 @@ class AudioAPIInterface:
                            )
                        except Exception as exc:
                            debug_exc_log(log, exc, f"Failed to fetch {track_info} from YouTube table")

                    should_query_global = globaldb_toggle and query_global and val is None
                    if should_query_global:
                        llresponse = await self.global_cache_api.get_spotify(track_name, artist_name)
                        if llresponse:
                            if llresponse.get("loadType") == "V2_COMPACT":
                                llresponse["loadType"] = "V2_COMPAT"
                            llresponse = LoadResult(llresponse)
                        val = llresponse or None
                    if val is None:
                        val = await self.fetch_youtube_query(
                            ctx, track_info, current_cache_level=current_cache_level
@@ -494,34 +513,44 @@ class AudioAPIInterface:
                            task = ("update", ("youtube", {"track": track_info}))
                            self.append_task(ctx, *task)

                    if llresponse is not None:
                    if isinstance(llresponse, LoadResult):
                        track_object = llresponse.tracks
                    elif val:
                        try:
                            (result, called_api) = await self.fetch_track(
                                ctx,
                                player,
                                Query.process_input(val, self.cog.local_folder_current_path),
                                forced=forced,
                            )
                        except (RuntimeError, aiohttp.ServerDisconnectedError):
                            lock(ctx, False)
                            error_embed = discord.Embed(
                                colour=await ctx.embed_colour(),
                                title=_("The connection was reset while loading the playlist."),
                            )
                            if notifier is not None:
                                await notifier.update_embed(error_embed)
                            break
                        except asyncio.TimeoutError:
                            lock(ctx, False)
                            error_embed = discord.Embed(
                                colour=await ctx.embed_colour(),
                                title=_("Player timeout, skipping remaining tracks."),
                            )
                            if notifier is not None:
                                await notifier.update_embed(error_embed)
                            break
                        result = None
                        if should_query_global:
                            llresponse = await self.global_cache_api.get_call(val)
                            if llresponse:
                                if llresponse.get("loadType") == "V2_COMPACT":
                                    llresponse["loadType"] = "V2_COMPAT"
                                llresponse = LoadResult(llresponse)
                            result = llresponse or None
                        if not result:
                            try:
                                (result, called_api) = await self.fetch_track(
                                    ctx,
                                    player,
                                    Query.process_input(val, self.cog.local_folder_current_path),
                                    forced=forced,
                                    should_query_global=not should_query_global,
                                )
                            except (RuntimeError, aiohttp.ServerDisconnectedError):
                                lock(ctx, False)
                                error_embed = discord.Embed(
                                    colour=await ctx.embed_colour(),
                                    title=_("The connection was reset while loading the playlist."),
                                )
                                if notifier is not None:
                                    await notifier.update_embed(error_embed)
                                break
                            except asyncio.TimeoutError:
                                lock(ctx, False)
                                error_embed = discord.Embed(
                                    colour=await ctx.embed_colour(),
                                    title=_("Player timeout, skipping remaining tracks."),
                                )
                                if notifier is not None:
                                    await notifier.update_embed(error_embed)
                                break
                        track_object = result.tracks
                    else:
                        track_object = []
@@ -538,7 +567,7 @@ class AudioAPIInterface:
                            seconds=seconds,
                        )

                    if consecutive_fails >= 10:
                    if consecutive_fails >= (100 if global_entry else 10):
                        error_embed = discord.Embed(
                            colour=await ctx.embed_colour(),
                            title=_("Failing to get tracks, skipping remaining."),
@@ -551,13 +580,12 @@ class AudioAPIInterface:
                        continue
                    consecutive_fails = 0
                    single_track = track_object[0]
                    query = Query.process_input(single_track, self.cog.local_folder_current_path)
                    if not await self.cog.is_query_allowed(
                        self.config,
                        ctx.guild,
                        (
                            f"{single_track.title} {single_track.author} {single_track.uri} "
                            f"{Query.process_input(single_track, self.cog.local_folder_current_path)}"
                        ),
                        ctx,
                        f"{single_track.title} {single_track.author} {single_track.uri} {query}",
                        query_obj=query,
                    ):
                        has_not_allowed = True
                        if IS_DEBUG:
@@ -570,6 +598,13 @@ class AudioAPIInterface:
                    if guild_data["maxlength"] > 0:
                        if self.cog.is_track_length_allowed(single_track, guild_data["maxlength"]):
                            enqueued_tracks += 1
                            single_track.extras.update(
                                {
                                    "enqueue_time": int(time.time()),
                                    "vc": player.channel.id,
                                    "requester": ctx.author.id,
                                }
                            )
                            player.add(ctx.author, single_track)
                            self.bot.dispatch(
                                "red_audio_track_enqueue",
@@ -579,6 +614,13 @@ class AudioAPIInterface:
                            )
                    else:
                        enqueued_tracks += 1
                        single_track.extras.update(
                            {
                                "enqueue_time": int(time.time()),
                                "vc": player.channel.id,
                                "requester": ctx.author.id,
                            }
                        )
                        player.add(ctx.author, single_track)
                        self.bot.dispatch(
                            "red_audio_track_enqueue",
@@ -642,9 +684,7 @@ class AudioAPIInterface:
        track_info: str,
        current_cache_level: CacheLevel = CacheLevel.none(),
    ) -> Optional[str]:
        """
        Call the Youtube API and returns the youtube URL that the query matched
        """
        """Call the Youtube API and returns the youtube URL that the query matched."""
        track_url = await self.youtube_api.get_call(track_info)
        if CacheLevel.set_youtube().is_subset(current_cache_level) and track_url:
            time_now = int(datetime.datetime.now(datetime.timezone.utc).timestamp())
@@ -668,9 +708,7 @@ class AudioAPIInterface:
    async def fetch_from_youtube_api(
        self, ctx: commands.Context, track_info: str
    ) -> Optional[str]:
        """
        Gets an YouTube URL from for the query
        """
        """Gets an YouTube URL from for the query."""
        current_cache_level = CacheLevel(await self.config.cache_level())
        cache_enabled = CacheLevel.set_youtube().is_subset(current_cache_level)
        val = None
@@ -727,6 +765,7 @@ class AudioAPIInterface:
        val = None
        query = Query.process_input(query, self.cog.local_folder_current_path)
        query_string = str(query)
        globaldb_toggle = await self.config.global_db_enabled()
        valid_global_entry = False
        results = None
        called_api = False
@@ -754,7 +793,31 @@ class AudioAPIInterface:
                called_api = False
            else:
                val = None

        if (
            globaldb_toggle
            and not val
            and should_query_global
            and not forced
            and not query.is_local
            and not query.is_spotify
        ):
            valid_global_entry = False
            with contextlib.suppress(Exception):
                global_entry = await self.global_cache_api.get_call(query=query)
                if global_entry.get("loadType") == "V2_COMPACT":
                    global_entry["loadType"] = "V2_COMPAT"
                results = LoadResult(global_entry)
                if results.load_type in [
                    LoadType.PLAYLIST_LOADED,
                    LoadType.TRACK_LOADED,
                    LoadType.SEARCH_RESULT,
                    LoadType.V2_COMPAT,
                ]:
                    valid_global_entry = True
            if valid_global_entry:
                if IS_DEBUG:
                    log.debug(f"Querying Global DB api for {query}")
                results, called_api = results, False
        if valid_global_entry:
            pass
        elif lazy is True:
@@ -769,6 +832,7 @@ class AudioAPIInterface:
            if results.has_error:
                # If cached value has an invalid entry make a new call so that it gets updated
                results, called_api = await self.fetch_track(ctx, player, query, forced=True)
                valid_global_entry = False
        else:
            if IS_DEBUG:
                log.debug(f"Querying Lavalink api for {query_string}")
@@ -781,7 +845,19 @@ class AudioAPIInterface:
                raise TrackEnqueueError
        if results is None:
            results = LoadResult({"loadType": "LOAD_FAILED", "playlistInfo": {}, "tracks": []})

        valid_global_entry = False
        update_global = (
            globaldb_toggle and not valid_global_entry and self.global_cache_api.has_api_key
        )
        with contextlib.suppress(Exception):
            if (
                update_global
                and not query.is_local
                and not results.has_error
                and len(results.tracks) >= 1
            ):
                global_task = ("global", dict(llresponse=results, query=query))
                self.append_task(ctx, *global_task)
        if (
            cache_enabled
            and results.load_type
@@ -817,9 +893,7 @@ class AudioAPIInterface:
        return results, called_api

    async def autoplay(self, player: lavalink.Player, playlist_api: PlaylistWrapper):
        """
        Enqueue a random track
        """
        """Enqueue a random track."""
        autoplaylist = await self.config.guild(player.channel.guild).autoplaylist()
        current_cache_level = CacheLevel(await self.config.cache_level())
        cache_enabled = CacheLevel.set_lavalink().is_subset(current_cache_level)
@@ -865,19 +939,18 @@ class AudioAPIInterface:
                track = random.choice(tracks)
                query = Query.process_input(track, self.cog.local_folder_current_path)
                await asyncio.sleep(0.001)
                if not query.valid or (
                if (not query.valid) or (
                    query.is_local
                    and query.local_track_path is not None
                    and not query.local_track_path.exists()
                ):
                    continue
                notify_channel = self.bot.get_channel(player.fetch("channel"))
                if not await self.cog.is_query_allowed(
                    self.config,
                    player.channel.guild,
                    (
                        f"{track.title} {track.author} {track.uri} "
                        f"{str(Query.process_input(track, self.cog.local_folder_current_path))}"
                    ),
                    notify_channel,
                    f"{track.title} {track.author} {track.uri} {query}",
                    query_obj=query,
                ):
                    if IS_DEBUG:
                        log.debug(
@@ -886,11 +959,20 @@ class AudioAPIInterface:
                        )
                    continue
                valid = True

            track.extras["autoplay"] = True
            track.extras.update(
                {
                    "autoplay": True,
                    "enqueue_time": int(time.time()),
                    "vc": player.channel.id,
                    "requester": player.channel.guild.me.id,
                }
            )
            player.add(player.channel.guild.me, track)
            self.bot.dispatch(
                "red_audio_track_auto_play", player.channel.guild, track, player.channel.guild.me
            )
            if not player.current:
                await player.play()

    async def fetch_all_contribute(self) -> List[LavalinkCacheFetchForGlobalResult]:
        return await self.local_cache_api.lavalink.fetch_all_for_global()

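The `fetch_track` changes above establish a three-step lookup order: the local cache is consulted first, then the global cache API (only when the global DB is enabled, the query is neither local nor Spotify, and the call was not forced), and Lavalink is only queried when neither cache produced a usable LoadResult; successful Lavalink results are then scheduled for upload through the "global" task. A schematic sketch of that order, with `local_lookup`, `global_lookup`, and `lavalink_lookup` standing in for the real coroutines rather than naming them exactly:

    async def resolve(query, *, global_enabled: bool, forced: bool):
        result = None if forced else await local_lookup(query)
        if result is None and global_enabled and not forced:
            result = await global_lookup(query)    # roughly GlobalCacheWrapper.get_call
        if result is None:
            result = await lavalink_lookup(query)  # the cog queries Lavalink at this point
        return result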
@@ -4,14 +4,14 @@ import datetime
import logging
import random
import time
from types import SimpleNamespace
from typing import Callable, List, MutableMapping, Optional, TYPE_CHECKING, Tuple, Union

from redbot.core.utils import AsyncIter
from types import SimpleNamespace
from typing import TYPE_CHECKING, Callable, List, MutableMapping, Optional, Tuple, Union

from redbot.core import Config
from redbot.core.bot import Red
from redbot.core.commands import Cog
from redbot.core.utils import AsyncIter
from redbot.core.utils.dbtools import APSWConnectionWrapper

from ..audio_logging import debug_exc_log
@@ -313,7 +313,7 @@ class LavalinkTableWrapper(BaseWrapper):
        self.statement.get_random = LAVALINK_QUERY_LAST_FETCHED_RANDOM
        self.statement.get_all_global = LAVALINK_FETCH_ALL_ENTRIES_GLOBAL
        self.fetch_result = LavalinkCacheFetchResult
        self.fetch_for_global: Optional[Callable] = None
        self.fetch_for_global: Optional[Callable] = LavalinkCacheFetchForGlobalResult

    async def fetch_one(
        self, values: MutableMapping

redbot/cogs/audio/apis/persist_queue_wrapper.py (new file, 133 lines)
@@ -0,0 +1,133 @@
import concurrent
import json
import logging
import time

from types import SimpleNamespace
from typing import TYPE_CHECKING, List, Union

import lavalink

from redbot.core import Config
from redbot.core.bot import Red
from redbot.core.commands import Cog
from redbot.core.utils import AsyncIter
from redbot.core.utils.dbtools import APSWConnectionWrapper

from ..audio_logging import debug_exc_log
from ..sql_statements import (
    PERSIST_QUEUE_BULK_PLAYED,
    PERSIST_QUEUE_CREATE_INDEX,
    PERSIST_QUEUE_CREATE_TABLE,
    PERSIST_QUEUE_DELETE_SCHEDULED,
    PERSIST_QUEUE_DROP_TABLE,
    PERSIST_QUEUE_FETCH_ALL,
    PERSIST_QUEUE_PLAYED,
    PERSIST_QUEUE_UPSERT,
    PRAGMA_FETCH_user_version,
    PRAGMA_SET_journal_mode,
    PRAGMA_SET_read_uncommitted,
    PRAGMA_SET_temp_store,
    PRAGMA_SET_user_version,
)
from .api_utils import QueueFetchResult

log = logging.getLogger("red.cogs.Audio.api.PersistQueueWrapper")

if TYPE_CHECKING:
    from .. import Audio


class QueueInterface:
    def __init__(
        self, bot: Red, config: Config, conn: APSWConnectionWrapper, cog: Union["Audio", Cog]
    ):
        self.bot = bot
        self.database = conn
        self.config = config
        self.cog = cog
        self.statement = SimpleNamespace()
        self.statement.pragma_temp_store = PRAGMA_SET_temp_store
        self.statement.pragma_journal_mode = PRAGMA_SET_journal_mode
        self.statement.pragma_read_uncommitted = PRAGMA_SET_read_uncommitted
        self.statement.set_user_version = PRAGMA_SET_user_version
        self.statement.get_user_version = PRAGMA_FETCH_user_version
        self.statement.create_table = PERSIST_QUEUE_CREATE_TABLE
        self.statement.create_index = PERSIST_QUEUE_CREATE_INDEX

        self.statement.upsert = PERSIST_QUEUE_UPSERT
        self.statement.update_bulk_player = PERSIST_QUEUE_BULK_PLAYED
        self.statement.delete_scheduled = PERSIST_QUEUE_DELETE_SCHEDULED
        self.statement.drop_table = PERSIST_QUEUE_DROP_TABLE

        self.statement.get_all = PERSIST_QUEUE_FETCH_ALL
        self.statement.get_player = PERSIST_QUEUE_PLAYED

    async def init(self) -> None:
        """Initialize the PersistQueue table"""
        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
            executor.submit(self.database.cursor().execute, self.statement.pragma_temp_store)
            executor.submit(self.database.cursor().execute, self.statement.pragma_journal_mode)
            executor.submit(self.database.cursor().execute, self.statement.pragma_read_uncommitted)
            executor.submit(self.database.cursor().execute, self.statement.create_table)
            executor.submit(self.database.cursor().execute, self.statement.create_index)

    async def fetch_all(self) -> List[QueueFetchResult]:
        """Fetch all playlists"""
        output = []
        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
            for future in concurrent.futures.as_completed(
                [
                    executor.submit(
                        self.database.cursor().execute,
                        self.statement.get_all,
                    )
                ]
            ):
                try:
                    row_result = future.result()
                except Exception as exc:
                    debug_exc_log(log, exc, "Failed to complete playlist fetch from database")
                    return []

        async for index, row in AsyncIter(row_result).enumerate(start=1):
            output.append(QueueFetchResult(*row))
        return output

    async def played(self, guild_id: int, track_id: str) -> None:
        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
            executor.submit(
                self.database.cursor().execute,
                PERSIST_QUEUE_PLAYED,
                {"guild_id": guild_id, "track_id": track_id},
            )

    async def delete_scheduled(self):
        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
            executor.submit(self.database.cursor().execute, PERSIST_QUEUE_DELETE_SCHEDULED)

    async def drop(self, guild_id: int):
        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
            executor.submit(
                self.database.cursor().execute, PERSIST_QUEUE_BULK_PLAYED, ({"guild_id": guild_id})
            )

    async def enqueued(self, guild_id: int, room_id: int, track: lavalink.Track):
        enqueue_time = track.extras.get("enqueue_time", 0)
        if enqueue_time == 0:
            track.extras["enqueue_time"] = int(time.time())
        track_identifier = track.track_identifier
        track = self.cog.track_to_json(track)
        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
            executor.submit(
                self.database.cursor().execute,
                PERSIST_QUEUE_UPSERT,
                {
                    "guild_id": int(guild_id),
                    "room_id": int(room_id),
                    "played": False,
                    "time": enqueue_time,
                    "track": json.dumps(track),
                    "track_id": track_identifier,
                },
            )
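Taken together, the new QueueInterface gives the cog a persisted queue per guild: `init` creates the table and index, `enqueued` upserts a track (stamping `enqueue_time` if missing), `fetch_all` rehydrates rows into QueueFetchResult objects, and `played`/`delete_scheduled`/`drop` clear entries back out. A hedged usage sketch, to be run inside a coroutine and assuming a `bot`, `config`, APSW `conn`, `cog`, `ctx`, `voice_channel`, and `track` are already available, much as in the AudioAPIInterface constructor:

    queue_api = QueueInterface(bot, config, conn, cog)
    await queue_api.init()
    await queue_api.enqueued(guild_id=ctx.guild.id, room_id=voice_channel.id, track=track)
    restored = await queue_api.fetch_all()   # -> List[QueueFetchResult]
    await queue_api.played(guild_id=ctx.guild.id, track_id=track.track_identifier)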
@@ -1,12 +1,13 @@
import logging

from typing import List, MutableMapping, Optional, Union

import discord
import lavalink
from redbot.core.utils import AsyncIter

from redbot.core import Config, commands
from redbot.core.bot import Red
from redbot.core.utils import AsyncIter

from ..errors import NotAllowed
from ..utils import PlaylistScope

@@ -1,17 +1,18 @@
import concurrent
import json
import logging

from types import SimpleNamespace
from typing import List, MutableMapping, Optional

from redbot.core.utils import AsyncIter

from redbot.core import Config
from redbot.core.bot import Red
from redbot.core.utils import AsyncIter
from redbot.core.utils.dbtools import APSWConnectionWrapper

from ..audio_logging import debug_exc_log
from ..sql_statements import (
    HANDLE_DISCORD_DATA_DELETION_QUERY,
    PLAYLIST_CREATE_INDEX,
    PLAYLIST_CREATE_TABLE,
    PLAYLIST_DELETE,
@@ -27,7 +28,6 @@ from ..sql_statements import (
    PRAGMA_SET_read_uncommitted,
    PRAGMA_SET_temp_store,
    PRAGMA_SET_user_version,
    HANDLE_DISCORD_DATA_DELETION_QUERY,
)
from ..utils import PlaylistScope
from .api_utils import PlaylistFetchResult
@@ -62,7 +62,7 @@ class PlaylistWrapper:
        self.statement.drop_user_playlists = HANDLE_DISCORD_DATA_DELETION_QUERY

    async def init(self) -> None:
        """Initialize the Playlist table"""
        """Initialize the Playlist table."""
        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
            executor.submit(self.database.cursor().execute, self.statement.pragma_temp_store)
            executor.submit(self.database.cursor().execute, self.statement.pragma_journal_mode)
@@ -72,7 +72,7 @@ class PlaylistWrapper:

    @staticmethod
    def get_scope_type(scope: str) -> int:
        """Convert a scope to a numerical identifier"""
        """Convert a scope to a numerical identifier."""
        if scope == PlaylistScope.GLOBAL.value:
            table = 1
        elif scope == PlaylistScope.USER.value:
@@ -82,7 +82,7 @@ class PlaylistWrapper:
        return table

    async def fetch(self, scope: str, playlist_id: int, scope_id: int) -> PlaylistFetchResult:
        """Fetch a single playlist"""
        """Fetch a single playlist."""
        scope_type = self.get_scope_type(scope)

        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
@@ -113,7 +113,7 @@ class PlaylistWrapper:
    async def fetch_all(
        self, scope: str, scope_id: int, author_id=None
    ) -> List[PlaylistFetchResult]:
        """Fetch all playlists"""
        """Fetch all playlists."""
        scope_type = self.get_scope_type(scope)
        output = []
        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
@@ -160,7 +160,7 @@ class PlaylistWrapper:
    async def fetch_all_converter(
        self, scope: str, playlist_name, playlist_id
    ) -> List[PlaylistFetchResult]:
        """Fetch all playlists with the specified filter"""
        """Fetch all playlists with the specified filter."""
        scope_type = self.get_scope_type(scope)
        try:
            playlist_id = int(playlist_id)
@@ -195,7 +195,7 @@ class PlaylistWrapper:
        return output

    async def delete(self, scope: str, playlist_id: int, scope_id: int):
        """Deletes a single playlists"""
        """Deletes a single playlists."""
        scope_type = self.get_scope_type(scope)
        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
            executor.submit(
@@ -205,12 +205,12 @@ class PlaylistWrapper:
            )

    async def delete_scheduled(self):
        """Clean up database from all deleted playlists"""
        """Clean up database from all deleted playlists."""
        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
            executor.submit(self.database.cursor().execute, self.statement.delete_scheduled)

    async def drop(self, scope: str):
        """Delete all playlists in a scope"""
        """Delete all playlists in a scope."""
        scope_type = self.get_scope_type(scope)
        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
            executor.submit(
@@ -220,7 +220,7 @@ class PlaylistWrapper:
            )

    async def create_table(self):
        """Create the playlist table"""
        """Create the playlist table."""
        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
            executor.submit(self.database.cursor().execute, PLAYLIST_CREATE_TABLE)

@@ -234,7 +234,7 @@ class PlaylistWrapper:
        playlist_url: Optional[str],
        tracks: List[MutableMapping],
    ):
        """Insert or update a playlist into the database"""
        """Insert or update a playlist into the database."""
        scope_type = self.get_scope_type(scope)
        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
            executor.submit(

@@ -1,16 +1,18 @@
import base64
import contextlib
import json
import logging
import time
from typing import List, Mapping, MutableMapping, Optional, TYPE_CHECKING, Tuple, Union

from typing import TYPE_CHECKING, List, Mapping, MutableMapping, Optional, Tuple, Union

import aiohttp
from redbot.core.i18n import Translator
from redbot.core.utils import AsyncIter

from redbot.core import Config
from redbot.core.bot import Red
from redbot.core.commands import Cog, Context
from redbot.core.i18n import Translator
from redbot.core.utils import AsyncIter

from ..errors import SpotifyFetchError

@@ -46,7 +48,7 @@ class SpotifyWrapper:

    @staticmethod
    def spotify_format_call(query_type: str, key: str) -> Tuple[str, MutableMapping]:
        """Format the spotify endpoint"""
        """Format the spotify endpoint."""
        params: MutableMapping = {}
        if query_type == "album":
            query = f"{ALBUMS_ENDPOINT}/{key}/tracks"
@@ -59,7 +61,7 @@ class SpotifyWrapper:
    async def get_spotify_track_info(
        self, track_data: MutableMapping, ctx: Context
    ) -> Tuple[str, ...]:
        """Extract track info from spotify response"""
        """Extract track info from spotify response."""
        prefer_lyrics = await self.cog.get_lyrics_status(ctx)
        track_name = track_data["name"]
        if prefer_lyrics:
@@ -75,14 +77,14 @@ class SpotifyWrapper:

    @staticmethod
    async def is_access_token_valid(token: MutableMapping) -> bool:
        """Check if current token is not too old"""
        """Check if current token is not too old."""
        return (token["expires_at"] - int(time.time())) < 60

    @staticmethod
    def make_auth_header(
        client_id: Optional[str], client_secret: Optional[str]
    ) -> MutableMapping[str, Union[str, int]]:
        """Make Authorization header for spotify token"""
        """Make Authorization header for spotify token."""
        if client_id is None:
            client_id = ""
        if client_secret is None:
@@ -93,11 +95,11 @@ class SpotifyWrapper:
    async def get(
        self, url: str, headers: MutableMapping = None, params: MutableMapping = None
    ) -> MutableMapping[str, str]:
        """Make a GET request to the spotify API"""
        """Make a GET request to the spotify API."""
        if params is None:
            params = {}
        async with self.session.request("GET", url, params=params, headers=headers) as r:
            data = await r.json()
            data = await r.json(loads=json.loads)
            if r.status != 200:
                log.debug(f"Issue making GET request to {url}: [{r.status}] {data}")
            return data
@@ -106,7 +108,7 @@ class SpotifyWrapper:
        self._token = new_token

    async def get_token(self) -> None:
        """Get the stored spotify tokens"""
        """Get the stored spotify tokens."""
        if not self._token:
            self._token = await self.bot.get_shared_api_tokens("spotify")

@@ -114,10 +116,17 @@ class SpotifyWrapper:
        self.client_secret = self._token.get("client_secret", "")

    async def get_country_code(self, ctx: Context = None) -> str:
        return await self.config.guild(ctx.guild).country_code() if ctx else "US"
        return (
            (
                await self.config.user(ctx.author).country_code()
                or await self.config.guild(ctx.guild).country_code()
            )
            if ctx
            else "US"
        )

    async def request_access_token(self) -> MutableMapping:
        """Make a spotify call to get the auth token"""
        """Make a spotify call to get the auth token."""
        await self.get_token()
        payload = {"grant_type": "client_credentials"}
        headers = self.make_auth_header(self.client_id, self.client_secret)
@@ -125,7 +134,7 @@ class SpotifyWrapper:
        return r

    async def get_access_token(self) -> Optional[str]:
        """Get the access_token"""
        """Get the access_token."""
        if self.spotify_token and not await self.is_access_token_valid(self.spotify_token):
            return self.spotify_token["access_token"]
        token = await self.request_access_token()
@@ -142,20 +151,20 @@ class SpotifyWrapper:
    async def post(
        self, url: str, payload: MutableMapping, headers: MutableMapping = None
    ) -> MutableMapping:
        """Make a POST call to spotify"""
        """Make a POST call to spotify."""
        async with self.session.post(url, data=payload, headers=headers) as r:
            data = await r.json()
            data = await r.json(loads=json.loads)
            if r.status != 200:
                log.debug(f"Issue making POST request to {url}: [{r.status}] {data}")
            return data

    async def make_get_call(self, url: str, params: MutableMapping) -> MutableMapping:
        """Make a Get call to spotify"""
        """Make a Get call to spotify."""
        token = await self.get_access_token()
        return await self.get(url, params=params, headers={"Authorization": f"Bearer {token}"})

    async def get_categories(self, ctx: Context = None) -> List[MutableMapping]:
        """Get the spotify categories"""
        """Get the spotify categories."""
        country_code = await self.get_country_code(ctx=ctx)
        params: MutableMapping = {"country": country_code} if country_code else {}
        result = await self.make_get_call(CATEGORY_ENDPOINT, params=params)
@@ -171,7 +180,7 @@ class SpotifyWrapper:
        return [{c["name"]: c["id"]} for c in categories if c]

    async def get_playlist_from_category(self, category: str, ctx: Context = None):
        """Get spotify playlists for the specified category"""
        """Get spotify playlists for the specified category."""
        url = f"{CATEGORY_ENDPOINT}/{category}/playlists"
        country_code = await self.get_country_code(ctx=ctx)
        params: MutableMapping = {"country": country_code} if country_code else {}

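The reworked `get_country_code` above now prefers a per-user country code and only falls back to the guild setting, keeping "US" as the default when no context is available. The precedence can be read as a simple or-chain (variable names here are illustrative):

    code = (await user_country() or await guild_country()) if ctx else "US"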
@@ -1,5 +1,7 @@
import json
import logging
from typing import Mapping, Optional, TYPE_CHECKING, Union

from typing import TYPE_CHECKING, Mapping, Optional, Union

import aiohttp

@@ -33,15 +35,17 @@ class YouTubeWrapper:
    def update_token(self, new_token: Mapping[str, str]):
        self._token = new_token

    async def _get_api_key(self) -> str:
        """Get the stored youtube token"""
    async def _get_api_key(
        self,
    ) -> str:
        """Get the stored youtube token."""
        if not self._token:
            self._token = await self.bot.get_shared_api_tokens("youtube")
        self.api_key = self._token.get("api_key", "")
        return self.api_key if self.api_key is not None else ""

    async def get_call(self, query: str) -> Optional[str]:
        """Make a Get call to youtube data api"""
        """Make a Get call to youtube data api."""
        params = {
            "q": query,
            "part": "id",
@@ -57,7 +61,7 @@ class YouTubeWrapper:
                raise YouTubeApiError("Your YouTube Data API quota has been reached.")
            return None
        else:
            search_response = await r.json()
            search_response = await r.json(loads=json.loads)
            for search_result in search_response.get("items", []):
                if search_result["id"]["kind"] == "youtube#video":
                    return f"https://www.youtube.com/watch?v={search_result['id']['videoId']}"
