Audio Cog - v2.3.0 (#4446)

* First commit - Bring everything from dev cog minus NSFW support

* Add a toggle for auto deafen

* Add a one off Send to Owners

* aaaaaaa

* Update this to ensure `get_perms` is not called if the API is disabled

* Apply suggestions from code review

Co-authored-by: Vuks <51289041+Vuks69@users.noreply.github.com>

* silence any errors here (in case API is down so it doesnt affect audio)

* update the message to tell the mto join the Official Red server.

* remove useless sutff, and change dj check order to ensure bot doesnt join VC for non DJ's

* ffs

* Update redbot/cogs/audio/core/tasks/startup.py

Co-authored-by: Twentysix <Twentysix26@users.noreply.github.com>

* Aikas Review

* Add #3995 in here

* update

* *sigh*

* lock behind owner

* to help with debugging

* Revert "to help with debugging"

This reverts commit 8cbf17be

* resolve last review

Co-authored-by: Vuks <51289041+Vuks69@users.noreply.github.com>
Co-authored-by: Twentysix <Twentysix26@users.noreply.github.com>
This commit is contained in:
Draper
2020-10-12 19:39:39 +01:00
committed by GitHub
parent 29ebf0f060
commit 2da9b502d8
41 changed files with 1553 additions and 331 deletions

View File

@@ -6,6 +6,7 @@ from dataclasses import dataclass, field
from typing import List, MutableMapping, Optional, Union
import discord
import lavalink
from redbot.core.bot import Red
from redbot.core.utils.chat_formatting import humanize_list
@@ -74,8 +75,22 @@ class PlaylistFetchResult:
self.tracks = json.loads(self.tracks)
@dataclass
class QueueFetchResult:
guild_id: int
room_id: int
track: dict = field(default_factory=lambda: {})
track_object: lavalink.Track = None
def __post_init__(self):
if isinstance(self.track, str):
self.track = json.loads(self.track)
if self.track:
self.track_object = lavalink.Track(self.track)
def standardize_scope(scope: str) -> str:
"""Convert any of the used scopes into one we are expecting"""
"""Convert any of the used scopes into one we are expecting."""
scope = scope.upper()
valid_scopes = ["GLOBAL", "GUILD", "AUTHOR", "USER", "SERVER", "MEMBER", "BOT"]
@@ -103,7 +118,7 @@ def prepare_config_scope(
author: Union[discord.abc.User, int] = None,
guild: Union[discord.Guild, int] = None,
):
"""Return the scope used by Playlists"""
"""Return the scope used by Playlists."""
scope = standardize_scope(scope)
if scope == PlaylistScope.GLOBAL.value:
config_scope = [PlaylistScope.GLOBAL.value, bot.user.id]
@@ -121,7 +136,7 @@ def prepare_config_scope(
def prepare_config_scope_for_migration23( # TODO: remove me in a future version ?
scope, author: Union[discord.abc.User, int] = None, guild: discord.Guild = None
):
"""Return the scope used by Playlists"""
"""Return the scope used by Playlists."""
scope = standardize_scope(scope)
if scope == PlaylistScope.GLOBAL.value:

View File

@@ -1,8 +1,10 @@
import asyncio
import contextlib
import json
import logging
import urllib.parse
from typing import Mapping, Optional, TYPE_CHECKING, Union
from copy import copy
from typing import TYPE_CHECKING, Mapping, Optional, Union
import aiohttp
from lavalink.rest_api import LoadResult
@@ -17,7 +19,7 @@ from ..audio_logging import IS_DEBUG, debug_exc_log
if TYPE_CHECKING:
from .. import Audio
_API_URL = "https://redbot.app/"
_API_URL = "https://api.redbot.app/"
log = logging.getLogger("red.cogs.Audio.api.GlobalDB")
@@ -32,11 +34,150 @@ class GlobalCacheWrapper:
self.session = session
self.api_key = None
self._handshake_token = ""
self.can_write = False
self._handshake_token = ""
self.has_api_key = None
self._token: Mapping[str, str] = {}
self.cog = cog
def update_token(self, new_token: Mapping[str, str]):
self._token = new_token
async def _get_api_key(
self,
) -> Optional[str]:
if not self._token:
self._token = await self.bot.get_shared_api_tokens("audiodb")
self.api_key = self._token.get("api_key", None)
self.has_api_key = self.cog.global_api_user.get("can_post")
id_list = list(self.bot.owner_ids)
self._handshake_token = "||".join(map(str, id_list))
return self.api_key
async def get_call(self, query: Optional[Query] = None) -> dict:
api_url = f"{_API_URL}api/v2/queries"
if not self.cog.global_api_user.get("can_read"):
return {}
try:
query = Query.process_input(query, self.cog.local_folder_current_path)
if any([not query or not query.valid or query.is_spotify or query.is_local]):
return {}
await self._get_api_key()
if self.api_key is None:
return {}
search_response = "error"
query = query.lavalink_query
with contextlib.suppress(aiohttp.ContentTypeError, asyncio.TimeoutError):
async with self.session.get(
api_url,
timeout=aiohttp.ClientTimeout(total=await self.config.global_db_get_timeout()),
headers={"Authorization": self.api_key, "X-Token": self._handshake_token},
params={"query": query},
) as r:
search_response = await r.json(loads=json.loads)
if IS_DEBUG and "x-process-time" in r.headers:
log.debug(
f"GET || Ping {r.headers.get('x-process-time')} || "
f"Status code {r.status} || {query}"
)
if "tracks" not in search_response:
return {}
return search_response
except Exception as err:
debug_exc_log(log, err, f"Failed to Get query: {api_url}/{query}")
return {}
async def get_spotify(self, title: str, author: Optional[str]) -> dict:
if not self.cog.global_api_user.get("can_read"):
return {}
api_url = f"{_API_URL}api/v2/queries/spotify"
try:
search_response = "error"
params = {"title": title, "author": author}
await self._get_api_key()
if self.api_key is None:
return {}
with contextlib.suppress(aiohttp.ContentTypeError, asyncio.TimeoutError):
async with self.session.get(
api_url,
timeout=aiohttp.ClientTimeout(total=await self.config.global_db_get_timeout()),
headers={"Authorization": self.api_key, "X-Token": self._handshake_token},
params=params,
) as r:
search_response = await r.json(loads=json.loads)
if IS_DEBUG and "x-process-time" in r.headers:
log.debug(
f"GET/spotify || Ping {r.headers.get('x-process-time')} || "
f"Status code {r.status} || {title} - {author}"
)
if "tracks" not in search_response:
return {}
return search_response
except Exception as err:
debug_exc_log(log, err, f"Failed to Get query: {api_url}")
return {}
async def post_call(self, llresponse: LoadResult, query: Optional[Query]) -> None:
try:
if not self.cog.global_api_user.get("can_post"):
return
query = Query.process_input(query, self.cog.local_folder_current_path)
if llresponse.has_error or llresponse.load_type.value in ["NO_MATCHES", "LOAD_FAILED"]:
return
if query and query.valid and query.is_youtube:
query = query.lavalink_query
else:
return None
await self._get_api_key()
if self.api_key is None:
return None
api_url = f"{_API_URL}api/v2/queries"
async with self.session.post(
api_url,
json=llresponse._raw,
headers={"Authorization": self.api_key, "X-Token": self._handshake_token},
params={"query": query},
) as r:
await r.read()
if IS_DEBUG and "x-process-time" in r.headers:
log.debug(
f"POST || Ping {r.headers.get('x-process-time')} ||"
f" Status code {r.status} || {query}"
)
except Exception as err:
debug_exc_log(log, err, f"Failed to post query: {query}")
await asyncio.sleep(0)
async def update_global(self, llresponse: LoadResult, query: Optional[Query] = None):
await self.post_call(llresponse=llresponse, query=query)
async def report_invalid(self, id: str) -> None:
if not self.cog.global_api_user.get("can_delete"):
return
api_url = f"{_API_URL}api/v2/queries/es/id"
with contextlib.suppress(Exception):
async with self.session.delete(
api_url,
headers={"Authorization": self.api_key, "X-Token": self._handshake_token},
params={"id": id},
) as r:
await r.read()
async def get_perms(self):
global_api_user = copy(self.cog.global_api_user)
await self._get_api_key()
is_enabled = await self.config.global_db_enabled()
await self._get_api_key()
if (not is_enabled) or self.api_key is None:
return global_api_user
with contextlib.suppress(Exception):
async with aiohttp.ClientSession(json_serialize=json.dumps) as session:
async with session.get(
f"{_API_URL}api/v2/users/me",
headers={"Authorization": self.api_key, "X-Token": self._handshake_token},
) as resp:
if resp.status == 200:
search_response = await resp.json(loads=json.loads)
global_api_user["fetched"] = True
global_api_user["can_read"] = search_response.get("can_read", False)
global_api_user["can_post"] = search_response.get("can_post", False)
global_api_user["can_delete"] = search_response.get("can_delete", False)
return global_api_user

View File

@@ -1,30 +1,34 @@
import asyncio
import contextlib
import datetime
import json
import logging
import random
import time
from collections import namedtuple
from typing import Callable, List, MutableMapping, Optional, TYPE_CHECKING, Tuple, Union, cast
from typing import TYPE_CHECKING, Callable, List, MutableMapping, Optional, Tuple, Union, cast
import aiohttp
import discord
import lavalink
from lavalink.rest_api import LoadResult
from redbot.core.utils import AsyncIter
from lavalink.rest_api import LoadResult, LoadType
from redbot.core import Config, commands
from redbot.core.bot import Red
from redbot.core.commands import Cog, Context
from redbot.core.i18n import Translator
from redbot.core.utils import AsyncIter
from redbot.core.utils.dbtools import APSWConnectionWrapper
from ..audio_dataclasses import Query
from ..audio_logging import IS_DEBUG, debug_exc_log
from ..errors import DatabaseError, SpotifyFetchError, TrackEnqueueError
from ..utils import CacheLevel, Notifier
from .api_utils import LavalinkCacheFetchForGlobalResult
from .global_db import GlobalCacheWrapper
from .local_db import LocalCacheWrapper
from .persist_queue_wrapper import QueueInterface
from .playlist_interface import get_playlist
from .playlist_wrapper import PlaylistWrapper
from .spotify import SpotifyWrapper
@@ -36,6 +40,7 @@ if TYPE_CHECKING:
_ = Translator("Audio", __file__)
log = logging.getLogger("red.cogs.Audio.api.AudioAPIInterface")
_TOP_100_US = "https://www.youtube.com/playlist?list=PL4fGSI1pDJn5rWitrRWFKdm-ulaFiIyoK"
# TODO: Get random from global Cache
class AudioAPIInterface:
@@ -60,20 +65,22 @@ class AudioAPIInterface:
self.youtube_api: YouTubeWrapper = YouTubeWrapper(self.bot, self.config, session, self.cog)
self.local_cache_api = LocalCacheWrapper(self.bot, self.config, self.conn, self.cog)
self.global_cache_api = GlobalCacheWrapper(self.bot, self.config, session, self.cog)
self.persistent_queue_api = QueueInterface(self.bot, self.config, self.conn, self.cog)
self._session: aiohttp.ClientSession = session
self._tasks: MutableMapping = {}
self._lock: asyncio.Lock = asyncio.Lock()
async def initialize(self) -> None:
"""Initialises the Local Cache connection"""
"""Initialises the Local Cache connection."""
await self.local_cache_api.lavalink.init()
await self.persistent_queue_api.init()
def close(self) -> None:
"""Closes the Local Cache connection"""
"""Closes the Local Cache connection."""
self.local_cache_api.lavalink.close()
async def get_random_track_from_db(self) -> Optional[MutableMapping]:
"""Get a random track from the local database and return it"""
async def get_random_track_from_db(self, tries=0) -> Optional[MutableMapping]:
"""Get a random track from the local database and return it."""
track: Optional[MutableMapping] = {}
try:
query_data = {}
@@ -106,7 +113,7 @@ class AudioAPIInterface:
action_type: str = None,
data: Union[List[MutableMapping], MutableMapping] = None,
) -> None:
"""Separate the tasks and run them in the appropriate functions"""
"""Separate the tasks and run them in the appropriate functions."""
if not data:
return
@@ -126,9 +133,11 @@ class AudioAPIInterface:
await self.local_cache_api.youtube.update(data)
elif table == "spotify":
await self.local_cache_api.spotify.update(data)
elif action_type == "global" and isinstance(data, list):
await asyncio.gather(*[self.global_cache_api.update_global(**d) for d in data])
async def run_tasks(self, ctx: Optional[commands.Context] = None, message_id=None) -> None:
"""Run tasks for a specific context"""
"""Run tasks for a specific context."""
if message_id is not None:
lock_id = message_id
elif ctx is not None:
@@ -143,7 +152,7 @@ class AudioAPIInterface:
try:
tasks = self._tasks[lock_id]
tasks = [self.route_tasks(a, tasks[a]) for a in tasks]
await asyncio.gather(*tasks, return_exceptions=True)
await asyncio.gather(*tasks, return_exceptions=False)
del self._tasks[lock_id]
except Exception as exc:
debug_exc_log(
@@ -154,7 +163,7 @@ class AudioAPIInterface:
log.debug(f"Completed database writes for {lock_id} ({lock_author})")
async def run_all_pending_tasks(self) -> None:
"""Run all pending tasks left in the cache, called on cog_unload"""
"""Run all pending tasks left in the cache, called on cog_unload."""
async with self._lock:
if IS_DEBUG:
log.debug("Running pending writes to database")
@@ -166,7 +175,7 @@ class AudioAPIInterface:
self._tasks = {}
coro_tasks = [self.route_tasks(a, tasks[a]) for a in tasks]
await asyncio.gather(*coro_tasks, return_exceptions=True)
await asyncio.gather(*coro_tasks, return_exceptions=False)
except Exception as exc:
debug_exc_log(log, exc, "Failed database writes")
@@ -175,7 +184,7 @@ class AudioAPIInterface:
log.debug("Completed pending writes to database have finished")
def append_task(self, ctx: commands.Context, event: str, task: Tuple, _id: int = None) -> None:
"""Add a task to the cache to be run later"""
"""Add a task to the cache to be run later."""
lock_id = _id or ctx.message.id
if lock_id not in self._tasks:
self._tasks[lock_id] = {"update": [], "insert": [], "global": []}
@@ -190,7 +199,7 @@ class AudioAPIInterface:
skip_youtube: bool = False,
current_cache_level: CacheLevel = CacheLevel.none(),
) -> List[str]:
"""Return youtube URLS for the spotify URL provided"""
"""Return youtube URLS for the spotify URL provided."""
youtube_urls = []
tracks = await self.fetch_from_spotify_api(
query_type, uri, params=None, notifier=notifier, ctx=ctx
@@ -266,7 +275,7 @@ class AudioAPIInterface:
notifier: Optional[Notifier] = None,
ctx: Context = None,
) -> Union[List[MutableMapping], List[str]]:
"""Gets track info from spotify API"""
"""Gets track info from spotify API."""
if recursive is False:
(call, params) = self.spotify_api.spotify_format_call(query_type, uri)
@@ -394,9 +403,10 @@ class AudioAPIInterface:
lock: Callable,
notifier: Optional[Notifier] = None,
forced: bool = False,
query_global: bool = False,
query_global: bool = True,
) -> List[lavalink.Track]:
"""Queries the Database then falls back to Spotify and YouTube APIs then Enqueued matched tracks.
"""Queries the Database then falls back to Spotify and YouTube APIs then Enqueued matched
tracks.
Parameters
----------
@@ -423,7 +433,9 @@ class AudioAPIInterface:
List[str]
List of Youtube URLs.
"""
# globaldb_toggle = await self.config.global_db_enabled()
await self.global_cache_api._get_api_key()
globaldb_toggle = await self.config.global_db_enabled()
global_entry = globaldb_toggle and query_global
track_list: List = []
has_not_allowed = False
try:
@@ -485,7 +497,14 @@ class AudioAPIInterface:
)
except Exception as exc:
debug_exc_log(log, exc, f"Failed to fetch {track_info} from YouTube table")
should_query_global = globaldb_toggle and query_global and val is None
if should_query_global:
llresponse = await self.global_cache_api.get_spotify(track_name, artist_name)
if llresponse:
if llresponse.get("loadType") == "V2_COMPACT":
llresponse["loadType"] = "V2_COMPAT"
llresponse = LoadResult(llresponse)
val = llresponse or None
if val is None:
val = await self.fetch_youtube_query(
ctx, track_info, current_cache_level=current_cache_level
@@ -494,34 +513,44 @@ class AudioAPIInterface:
task = ("update", ("youtube", {"track": track_info}))
self.append_task(ctx, *task)
if llresponse is not None:
if isinstance(llresponse, LoadResult):
track_object = llresponse.tracks
elif val:
try:
(result, called_api) = await self.fetch_track(
ctx,
player,
Query.process_input(val, self.cog.local_folder_current_path),
forced=forced,
)
except (RuntimeError, aiohttp.ServerDisconnectedError):
lock(ctx, False)
error_embed = discord.Embed(
colour=await ctx.embed_colour(),
title=_("The connection was reset while loading the playlist."),
)
if notifier is not None:
await notifier.update_embed(error_embed)
break
except asyncio.TimeoutError:
lock(ctx, False)
error_embed = discord.Embed(
colour=await ctx.embed_colour(),
title=_("Player timeout, skipping remaining tracks."),
)
if notifier is not None:
await notifier.update_embed(error_embed)
break
result = None
if should_query_global:
llresponse = await self.global_cache_api.get_call(val)
if llresponse:
if llresponse.get("loadType") == "V2_COMPACT":
llresponse["loadType"] = "V2_COMPAT"
llresponse = LoadResult(llresponse)
result = llresponse or None
if not result:
try:
(result, called_api) = await self.fetch_track(
ctx,
player,
Query.process_input(val, self.cog.local_folder_current_path),
forced=forced,
should_query_global=not should_query_global,
)
except (RuntimeError, aiohttp.ServerDisconnectedError):
lock(ctx, False)
error_embed = discord.Embed(
colour=await ctx.embed_colour(),
title=_("The connection was reset while loading the playlist."),
)
if notifier is not None:
await notifier.update_embed(error_embed)
break
except asyncio.TimeoutError:
lock(ctx, False)
error_embed = discord.Embed(
colour=await ctx.embed_colour(),
title=_("Player timeout, skipping remaining tracks."),
)
if notifier is not None:
await notifier.update_embed(error_embed)
break
track_object = result.tracks
else:
track_object = []
@@ -538,7 +567,7 @@ class AudioAPIInterface:
seconds=seconds,
)
if consecutive_fails >= 10:
if consecutive_fails >= (100 if global_entry else 10):
error_embed = discord.Embed(
colour=await ctx.embed_colour(),
title=_("Failing to get tracks, skipping remaining."),
@@ -551,13 +580,12 @@ class AudioAPIInterface:
continue
consecutive_fails = 0
single_track = track_object[0]
query = Query.process_input(single_track, self.cog.local_folder_current_path)
if not await self.cog.is_query_allowed(
self.config,
ctx.guild,
(
f"{single_track.title} {single_track.author} {single_track.uri} "
f"{Query.process_input(single_track, self.cog.local_folder_current_path)}"
),
ctx,
f"{single_track.title} {single_track.author} {single_track.uri} {query}",
query_obj=query,
):
has_not_allowed = True
if IS_DEBUG:
@@ -570,6 +598,13 @@ class AudioAPIInterface:
if guild_data["maxlength"] > 0:
if self.cog.is_track_length_allowed(single_track, guild_data["maxlength"]):
enqueued_tracks += 1
single_track.extras.update(
{
"enqueue_time": int(time.time()),
"vc": player.channel.id,
"requester": ctx.author.id,
}
)
player.add(ctx.author, single_track)
self.bot.dispatch(
"red_audio_track_enqueue",
@@ -579,6 +614,13 @@ class AudioAPIInterface:
)
else:
enqueued_tracks += 1
single_track.extras.update(
{
"enqueue_time": int(time.time()),
"vc": player.channel.id,
"requester": ctx.author.id,
}
)
player.add(ctx.author, single_track)
self.bot.dispatch(
"red_audio_track_enqueue",
@@ -642,9 +684,7 @@ class AudioAPIInterface:
track_info: str,
current_cache_level: CacheLevel = CacheLevel.none(),
) -> Optional[str]:
"""
Call the Youtube API and returns the youtube URL that the query matched
"""
"""Call the Youtube API and returns the youtube URL that the query matched."""
track_url = await self.youtube_api.get_call(track_info)
if CacheLevel.set_youtube().is_subset(current_cache_level) and track_url:
time_now = int(datetime.datetime.now(datetime.timezone.utc).timestamp())
@@ -668,9 +708,7 @@ class AudioAPIInterface:
async def fetch_from_youtube_api(
self, ctx: commands.Context, track_info: str
) -> Optional[str]:
"""
Gets an YouTube URL from for the query
"""
"""Gets an YouTube URL from for the query."""
current_cache_level = CacheLevel(await self.config.cache_level())
cache_enabled = CacheLevel.set_youtube().is_subset(current_cache_level)
val = None
@@ -727,6 +765,7 @@ class AudioAPIInterface:
val = None
query = Query.process_input(query, self.cog.local_folder_current_path)
query_string = str(query)
globaldb_toggle = await self.config.global_db_enabled()
valid_global_entry = False
results = None
called_api = False
@@ -754,7 +793,31 @@ class AudioAPIInterface:
called_api = False
else:
val = None
if (
globaldb_toggle
and not val
and should_query_global
and not forced
and not query.is_local
and not query.is_spotify
):
valid_global_entry = False
with contextlib.suppress(Exception):
global_entry = await self.global_cache_api.get_call(query=query)
if global_entry.get("loadType") == "V2_COMPACT":
global_entry["loadType"] = "V2_COMPAT"
results = LoadResult(global_entry)
if results.load_type in [
LoadType.PLAYLIST_LOADED,
LoadType.TRACK_LOADED,
LoadType.SEARCH_RESULT,
LoadType.V2_COMPAT,
]:
valid_global_entry = True
if valid_global_entry:
if IS_DEBUG:
log.debug(f"Querying Global DB api for {query}")
results, called_api = results, False
if valid_global_entry:
pass
elif lazy is True:
@@ -769,6 +832,7 @@ class AudioAPIInterface:
if results.has_error:
# If cached value has an invalid entry make a new call so that it gets updated
results, called_api = await self.fetch_track(ctx, player, query, forced=True)
valid_global_entry = False
else:
if IS_DEBUG:
log.debug(f"Querying Lavalink api for {query_string}")
@@ -781,7 +845,19 @@ class AudioAPIInterface:
raise TrackEnqueueError
if results is None:
results = LoadResult({"loadType": "LOAD_FAILED", "playlistInfo": {}, "tracks": []})
valid_global_entry = False
update_global = (
globaldb_toggle and not valid_global_entry and self.global_cache_api.has_api_key
)
with contextlib.suppress(Exception):
if (
update_global
and not query.is_local
and not results.has_error
and len(results.tracks) >= 1
):
global_task = ("global", dict(llresponse=results, query=query))
self.append_task(ctx, *global_task)
if (
cache_enabled
and results.load_type
@@ -817,9 +893,7 @@ class AudioAPIInterface:
return results, called_api
async def autoplay(self, player: lavalink.Player, playlist_api: PlaylistWrapper):
"""
Enqueue a random track
"""
"""Enqueue a random track."""
autoplaylist = await self.config.guild(player.channel.guild).autoplaylist()
current_cache_level = CacheLevel(await self.config.cache_level())
cache_enabled = CacheLevel.set_lavalink().is_subset(current_cache_level)
@@ -865,19 +939,18 @@ class AudioAPIInterface:
track = random.choice(tracks)
query = Query.process_input(track, self.cog.local_folder_current_path)
await asyncio.sleep(0.001)
if not query.valid or (
if (not query.valid) or (
query.is_local
and query.local_track_path is not None
and not query.local_track_path.exists()
):
continue
notify_channel = self.bot.get_channel(player.fetch("channel"))
if not await self.cog.is_query_allowed(
self.config,
player.channel.guild,
(
f"{track.title} {track.author} {track.uri} "
f"{str(Query.process_input(track, self.cog.local_folder_current_path))}"
),
notify_channel,
f"{track.title} {track.author} {track.uri} {query}",
query_obj=query,
):
if IS_DEBUG:
log.debug(
@@ -886,11 +959,20 @@ class AudioAPIInterface:
)
continue
valid = True
track.extras["autoplay"] = True
track.extras.update(
{
"autoplay": True,
"enqueue_time": int(time.time()),
"vc": player.channel.id,
"requester": player.channel.guild.me.id,
}
)
player.add(player.channel.guild.me, track)
self.bot.dispatch(
"red_audio_track_auto_play", player.channel.guild, track, player.channel.guild.me
)
if not player.current:
await player.play()
async def fetch_all_contribute(self) -> List[LavalinkCacheFetchForGlobalResult]:
return await self.local_cache_api.lavalink.fetch_all_for_global()

View File

@@ -4,14 +4,14 @@ import datetime
import logging
import random
import time
from types import SimpleNamespace
from typing import Callable, List, MutableMapping, Optional, TYPE_CHECKING, Tuple, Union
from redbot.core.utils import AsyncIter
from types import SimpleNamespace
from typing import TYPE_CHECKING, Callable, List, MutableMapping, Optional, Tuple, Union
from redbot.core import Config
from redbot.core.bot import Red
from redbot.core.commands import Cog
from redbot.core.utils import AsyncIter
from redbot.core.utils.dbtools import APSWConnectionWrapper
from ..audio_logging import debug_exc_log
@@ -313,7 +313,7 @@ class LavalinkTableWrapper(BaseWrapper):
self.statement.get_random = LAVALINK_QUERY_LAST_FETCHED_RANDOM
self.statement.get_all_global = LAVALINK_FETCH_ALL_ENTRIES_GLOBAL
self.fetch_result = LavalinkCacheFetchResult
self.fetch_for_global: Optional[Callable] = None
self.fetch_for_global: Optional[Callable] = LavalinkCacheFetchForGlobalResult
async def fetch_one(
self, values: MutableMapping

View File

@@ -0,0 +1,133 @@
import concurrent
import json
import logging
import time
from types import SimpleNamespace
from typing import TYPE_CHECKING, List, Union
import lavalink
from redbot.core import Config
from redbot.core.bot import Red
from redbot.core.commands import Cog
from redbot.core.utils import AsyncIter
from redbot.core.utils.dbtools import APSWConnectionWrapper
from ..audio_logging import debug_exc_log
from ..sql_statements import (
PERSIST_QUEUE_BULK_PLAYED,
PERSIST_QUEUE_CREATE_INDEX,
PERSIST_QUEUE_CREATE_TABLE,
PERSIST_QUEUE_DELETE_SCHEDULED,
PERSIST_QUEUE_DROP_TABLE,
PERSIST_QUEUE_FETCH_ALL,
PERSIST_QUEUE_PLAYED,
PERSIST_QUEUE_UPSERT,
PRAGMA_FETCH_user_version,
PRAGMA_SET_journal_mode,
PRAGMA_SET_read_uncommitted,
PRAGMA_SET_temp_store,
PRAGMA_SET_user_version,
)
from .api_utils import QueueFetchResult
log = logging.getLogger("red.cogs.Audio.api.PersistQueueWrapper")
if TYPE_CHECKING:
from .. import Audio
class QueueInterface:
def __init__(
self, bot: Red, config: Config, conn: APSWConnectionWrapper, cog: Union["Audio", Cog]
):
self.bot = bot
self.database = conn
self.config = config
self.cog = cog
self.statement = SimpleNamespace()
self.statement.pragma_temp_store = PRAGMA_SET_temp_store
self.statement.pragma_journal_mode = PRAGMA_SET_journal_mode
self.statement.pragma_read_uncommitted = PRAGMA_SET_read_uncommitted
self.statement.set_user_version = PRAGMA_SET_user_version
self.statement.get_user_version = PRAGMA_FETCH_user_version
self.statement.create_table = PERSIST_QUEUE_CREATE_TABLE
self.statement.create_index = PERSIST_QUEUE_CREATE_INDEX
self.statement.upsert = PERSIST_QUEUE_UPSERT
self.statement.update_bulk_player = PERSIST_QUEUE_BULK_PLAYED
self.statement.delete_scheduled = PERSIST_QUEUE_DELETE_SCHEDULED
self.statement.drop_table = PERSIST_QUEUE_DROP_TABLE
self.statement.get_all = PERSIST_QUEUE_FETCH_ALL
self.statement.get_player = PERSIST_QUEUE_PLAYED
async def init(self) -> None:
"""Initialize the PersistQueue table"""
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
executor.submit(self.database.cursor().execute, self.statement.pragma_temp_store)
executor.submit(self.database.cursor().execute, self.statement.pragma_journal_mode)
executor.submit(self.database.cursor().execute, self.statement.pragma_read_uncommitted)
executor.submit(self.database.cursor().execute, self.statement.create_table)
executor.submit(self.database.cursor().execute, self.statement.create_index)
async def fetch_all(self) -> List[QueueFetchResult]:
"""Fetch all playlists"""
output = []
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
for future in concurrent.futures.as_completed(
[
executor.submit(
self.database.cursor().execute,
self.statement.get_all,
)
]
):
try:
row_result = future.result()
except Exception as exc:
debug_exc_log(log, exc, "Failed to complete playlist fetch from database")
return []
async for index, row in AsyncIter(row_result).enumerate(start=1):
output.append(QueueFetchResult(*row))
return output
async def played(self, guild_id: int, track_id: str) -> None:
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
executor.submit(
self.database.cursor().execute,
PERSIST_QUEUE_PLAYED,
{"guild_id": guild_id, "track_id": track_id},
)
async def delete_scheduled(self):
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
executor.submit(self.database.cursor().execute, PERSIST_QUEUE_DELETE_SCHEDULED)
async def drop(self, guild_id: int):
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
executor.submit(
self.database.cursor().execute, PERSIST_QUEUE_BULK_PLAYED, ({"guild_id": guild_id})
)
async def enqueued(self, guild_id: int, room_id: int, track: lavalink.Track):
enqueue_time = track.extras.get("enqueue_time", 0)
if enqueue_time == 0:
track.extras["enqueue_time"] = int(time.time())
track_identifier = track.track_identifier
track = self.cog.track_to_json(track)
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
executor.submit(
self.database.cursor().execute,
PERSIST_QUEUE_UPSERT,
{
"guild_id": int(guild_id),
"room_id": int(room_id),
"played": False,
"time": enqueue_time,
"track": json.dumps(track),
"track_id": track_identifier,
},
)

View File

@@ -1,12 +1,13 @@
import logging
from typing import List, MutableMapping, Optional, Union
import discord
import lavalink
from redbot.core.utils import AsyncIter
from redbot.core import Config, commands
from redbot.core.bot import Red
from redbot.core.utils import AsyncIter
from ..errors import NotAllowed
from ..utils import PlaylistScope

View File

@@ -1,17 +1,18 @@
import concurrent
import json
import logging
from types import SimpleNamespace
from typing import List, MutableMapping, Optional
from redbot.core.utils import AsyncIter
from redbot.core import Config
from redbot.core.bot import Red
from redbot.core.utils import AsyncIter
from redbot.core.utils.dbtools import APSWConnectionWrapper
from ..audio_logging import debug_exc_log
from ..sql_statements import (
HANDLE_DISCORD_DATA_DELETION_QUERY,
PLAYLIST_CREATE_INDEX,
PLAYLIST_CREATE_TABLE,
PLAYLIST_DELETE,
@@ -27,7 +28,6 @@ from ..sql_statements import (
PRAGMA_SET_read_uncommitted,
PRAGMA_SET_temp_store,
PRAGMA_SET_user_version,
HANDLE_DISCORD_DATA_DELETION_QUERY,
)
from ..utils import PlaylistScope
from .api_utils import PlaylistFetchResult
@@ -62,7 +62,7 @@ class PlaylistWrapper:
self.statement.drop_user_playlists = HANDLE_DISCORD_DATA_DELETION_QUERY
async def init(self) -> None:
"""Initialize the Playlist table"""
"""Initialize the Playlist table."""
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
executor.submit(self.database.cursor().execute, self.statement.pragma_temp_store)
executor.submit(self.database.cursor().execute, self.statement.pragma_journal_mode)
@@ -72,7 +72,7 @@ class PlaylistWrapper:
@staticmethod
def get_scope_type(scope: str) -> int:
"""Convert a scope to a numerical identifier"""
"""Convert a scope to a numerical identifier."""
if scope == PlaylistScope.GLOBAL.value:
table = 1
elif scope == PlaylistScope.USER.value:
@@ -82,7 +82,7 @@ class PlaylistWrapper:
return table
async def fetch(self, scope: str, playlist_id: int, scope_id: int) -> PlaylistFetchResult:
"""Fetch a single playlist"""
"""Fetch a single playlist."""
scope_type = self.get_scope_type(scope)
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
@@ -113,7 +113,7 @@ class PlaylistWrapper:
async def fetch_all(
self, scope: str, scope_id: int, author_id=None
) -> List[PlaylistFetchResult]:
"""Fetch all playlists"""
"""Fetch all playlists."""
scope_type = self.get_scope_type(scope)
output = []
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
@@ -160,7 +160,7 @@ class PlaylistWrapper:
async def fetch_all_converter(
self, scope: str, playlist_name, playlist_id
) -> List[PlaylistFetchResult]:
"""Fetch all playlists with the specified filter"""
"""Fetch all playlists with the specified filter."""
scope_type = self.get_scope_type(scope)
try:
playlist_id = int(playlist_id)
@@ -195,7 +195,7 @@ class PlaylistWrapper:
return output
async def delete(self, scope: str, playlist_id: int, scope_id: int):
"""Deletes a single playlists"""
"""Deletes a single playlists."""
scope_type = self.get_scope_type(scope)
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
executor.submit(
@@ -205,12 +205,12 @@ class PlaylistWrapper:
)
async def delete_scheduled(self):
"""Clean up database from all deleted playlists"""
"""Clean up database from all deleted playlists."""
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
executor.submit(self.database.cursor().execute, self.statement.delete_scheduled)
async def drop(self, scope: str):
"""Delete all playlists in a scope"""
"""Delete all playlists in a scope."""
scope_type = self.get_scope_type(scope)
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
executor.submit(
@@ -220,7 +220,7 @@ class PlaylistWrapper:
)
async def create_table(self):
"""Create the playlist table"""
"""Create the playlist table."""
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
executor.submit(self.database.cursor().execute, PLAYLIST_CREATE_TABLE)
@@ -234,7 +234,7 @@ class PlaylistWrapper:
playlist_url: Optional[str],
tracks: List[MutableMapping],
):
"""Insert or update a playlist into the database"""
"""Insert or update a playlist into the database."""
scope_type = self.get_scope_type(scope)
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
executor.submit(

View File

@@ -1,16 +1,18 @@
import base64
import contextlib
import json
import logging
import time
from typing import List, Mapping, MutableMapping, Optional, TYPE_CHECKING, Tuple, Union
from typing import TYPE_CHECKING, List, Mapping, MutableMapping, Optional, Tuple, Union
import aiohttp
from redbot.core.i18n import Translator
from redbot.core.utils import AsyncIter
from redbot.core import Config
from redbot.core.bot import Red
from redbot.core.commands import Cog, Context
from redbot.core.i18n import Translator
from redbot.core.utils import AsyncIter
from ..errors import SpotifyFetchError
@@ -46,7 +48,7 @@ class SpotifyWrapper:
@staticmethod
def spotify_format_call(query_type: str, key: str) -> Tuple[str, MutableMapping]:
"""Format the spotify endpoint"""
"""Format the spotify endpoint."""
params: MutableMapping = {}
if query_type == "album":
query = f"{ALBUMS_ENDPOINT}/{key}/tracks"
@@ -59,7 +61,7 @@ class SpotifyWrapper:
async def get_spotify_track_info(
self, track_data: MutableMapping, ctx: Context
) -> Tuple[str, ...]:
"""Extract track info from spotify response"""
"""Extract track info from spotify response."""
prefer_lyrics = await self.cog.get_lyrics_status(ctx)
track_name = track_data["name"]
if prefer_lyrics:
@@ -75,14 +77,14 @@ class SpotifyWrapper:
@staticmethod
async def is_access_token_valid(token: MutableMapping) -> bool:
"""Check if current token is not too old"""
"""Check if current token is not too old."""
return (token["expires_at"] - int(time.time())) < 60
@staticmethod
def make_auth_header(
client_id: Optional[str], client_secret: Optional[str]
) -> MutableMapping[str, Union[str, int]]:
"""Make Authorization header for spotify token"""
"""Make Authorization header for spotify token."""
if client_id is None:
client_id = ""
if client_secret is None:
@@ -93,11 +95,11 @@ class SpotifyWrapper:
async def get(
self, url: str, headers: MutableMapping = None, params: MutableMapping = None
) -> MutableMapping[str, str]:
"""Make a GET request to the spotify API"""
"""Make a GET request to the spotify API."""
if params is None:
params = {}
async with self.session.request("GET", url, params=params, headers=headers) as r:
data = await r.json()
data = await r.json(loads=json.loads)
if r.status != 200:
log.debug(f"Issue making GET request to {url}: [{r.status}] {data}")
return data
@@ -106,7 +108,7 @@ class SpotifyWrapper:
self._token = new_token
async def get_token(self) -> None:
"""Get the stored spotify tokens"""
"""Get the stored spotify tokens."""
if not self._token:
self._token = await self.bot.get_shared_api_tokens("spotify")
@@ -114,10 +116,17 @@ class SpotifyWrapper:
self.client_secret = self._token.get("client_secret", "")
async def get_country_code(self, ctx: Context = None) -> str:
return await self.config.guild(ctx.guild).country_code() if ctx else "US"
return (
(
await self.config.user(ctx.author).country_code()
or await self.config.guild(ctx.guild).country_code()
)
if ctx
else "US"
)
async def request_access_token(self) -> MutableMapping:
"""Make a spotify call to get the auth token"""
"""Make a spotify call to get the auth token."""
await self.get_token()
payload = {"grant_type": "client_credentials"}
headers = self.make_auth_header(self.client_id, self.client_secret)
@@ -125,7 +134,7 @@ class SpotifyWrapper:
return r
async def get_access_token(self) -> Optional[str]:
"""Get the access_token"""
"""Get the access_token."""
if self.spotify_token and not await self.is_access_token_valid(self.spotify_token):
return self.spotify_token["access_token"]
token = await self.request_access_token()
@@ -142,20 +151,20 @@ class SpotifyWrapper:
async def post(
self, url: str, payload: MutableMapping, headers: MutableMapping = None
) -> MutableMapping:
"""Make a POST call to spotify"""
"""Make a POST call to spotify."""
async with self.session.post(url, data=payload, headers=headers) as r:
data = await r.json()
data = await r.json(loads=json.loads)
if r.status != 200:
log.debug(f"Issue making POST request to {url}: [{r.status}] {data}")
return data
async def make_get_call(self, url: str, params: MutableMapping) -> MutableMapping:
"""Make a Get call to spotify"""
"""Make a Get call to spotify."""
token = await self.get_access_token()
return await self.get(url, params=params, headers={"Authorization": f"Bearer {token}"})
async def get_categories(self, ctx: Context = None) -> List[MutableMapping]:
"""Get the spotify categories"""
"""Get the spotify categories."""
country_code = await self.get_country_code(ctx=ctx)
params: MutableMapping = {"country": country_code} if country_code else {}
result = await self.make_get_call(CATEGORY_ENDPOINT, params=params)
@@ -171,7 +180,7 @@ class SpotifyWrapper:
return [{c["name"]: c["id"]} for c in categories if c]
async def get_playlist_from_category(self, category: str, ctx: Context = None):
"""Get spotify playlists for the specified category"""
"""Get spotify playlists for the specified category."""
url = f"{CATEGORY_ENDPOINT}/{category}/playlists"
country_code = await self.get_country_code(ctx=ctx)
params: MutableMapping = {"country": country_code} if country_code else {}

View File

@@ -1,5 +1,7 @@
import json
import logging
from typing import Mapping, Optional, TYPE_CHECKING, Union
from typing import TYPE_CHECKING, Mapping, Optional, Union
import aiohttp
@@ -33,15 +35,17 @@ class YouTubeWrapper:
def update_token(self, new_token: Mapping[str, str]):
self._token = new_token
async def _get_api_key(self) -> str:
"""Get the stored youtube token"""
async def _get_api_key(
self,
) -> str:
"""Get the stored youtube token."""
if not self._token:
self._token = await self.bot.get_shared_api_tokens("youtube")
self.api_key = self._token.get("api_key", "")
return self.api_key if self.api_key is not None else ""
async def get_call(self, query: str) -> Optional[str]:
"""Make a Get call to youtube data api"""
"""Make a Get call to youtube data api."""
params = {
"q": query,
"part": "id",
@@ -57,7 +61,7 @@ class YouTubeWrapper:
raise YouTubeApiError("Your YouTube Data API quota has been reached.")
return None
else:
search_response = await r.json()
search_response = await r.json(loads=json.loads)
for search_result in search_response.get("items", []):
if search_result["id"]["kind"] == "youtube#video":
return f"https://www.youtube.com/watch?v={search_result['id']['videoId']}"