mirror of
https://github.com/Cog-Creators/Red-DiscordBot.git
synced 2025-11-22 02:37:57 -05:00
Audio Cog - v2.3.0 (#4446)
* First commit - Bring everything from dev cog minus NSFW support
* Add a toggle for auto deafen
* Add a one off Send to Owners
* aaaaaaa
* Update this to ensure `get_perms` is not called if the API is disabled
* Apply suggestions from code review
Co-authored-by: Vuks <51289041+Vuks69@users.noreply.github.com>
* silence any errors here (in case API is down so it doesnt affect audio)
* update the message to tell the mto join the Official Red server.
* remove useless sutff, and change dj check order to ensure bot doesnt join VC for non DJ's
* ffs
* Update redbot/cogs/audio/core/tasks/startup.py
Co-authored-by: Twentysix <Twentysix26@users.noreply.github.com>
* Aikas Review
* Add #3995 in here
* update
* *sigh*
* lock behind owner
* to help with debugging
* Revert "to help with debugging"
This reverts commit 8cbf17be
* resolve last review
Co-authored-by: Vuks <51289041+Vuks69@users.noreply.github.com>
Co-authored-by: Twentysix <Twentysix26@users.noreply.github.com>
This commit is contained in:
@@ -1,30 +1,34 @@
|
||||
import asyncio
|
||||
import contextlib
|
||||
import datetime
|
||||
import json
|
||||
import logging
|
||||
import random
|
||||
import time
|
||||
|
||||
from collections import namedtuple
|
||||
from typing import Callable, List, MutableMapping, Optional, TYPE_CHECKING, Tuple, Union, cast
|
||||
from typing import TYPE_CHECKING, Callable, List, MutableMapping, Optional, Tuple, Union, cast
|
||||
|
||||
import aiohttp
|
||||
import discord
|
||||
import lavalink
|
||||
from lavalink.rest_api import LoadResult
|
||||
from redbot.core.utils import AsyncIter
|
||||
|
||||
from lavalink.rest_api import LoadResult, LoadType
|
||||
from redbot.core import Config, commands
|
||||
from redbot.core.bot import Red
|
||||
from redbot.core.commands import Cog, Context
|
||||
from redbot.core.i18n import Translator
|
||||
from redbot.core.utils import AsyncIter
|
||||
from redbot.core.utils.dbtools import APSWConnectionWrapper
|
||||
|
||||
from ..audio_dataclasses import Query
|
||||
from ..audio_logging import IS_DEBUG, debug_exc_log
|
||||
from ..errors import DatabaseError, SpotifyFetchError, TrackEnqueueError
|
||||
from ..utils import CacheLevel, Notifier
|
||||
from .api_utils import LavalinkCacheFetchForGlobalResult
|
||||
from .global_db import GlobalCacheWrapper
|
||||
from .local_db import LocalCacheWrapper
|
||||
from .persist_queue_wrapper import QueueInterface
|
||||
from .playlist_interface import get_playlist
|
||||
from .playlist_wrapper import PlaylistWrapper
|
||||
from .spotify import SpotifyWrapper
|
||||
@@ -36,6 +40,7 @@ if TYPE_CHECKING:
|
||||
_ = Translator("Audio", __file__)
|
||||
log = logging.getLogger("red.cogs.Audio.api.AudioAPIInterface")
|
||||
_TOP_100_US = "https://www.youtube.com/playlist?list=PL4fGSI1pDJn5rWitrRWFKdm-ulaFiIyoK"
|
||||
# TODO: Get random from global Cache
|
||||
|
||||
|
||||
class AudioAPIInterface:
|
||||
@@ -60,20 +65,22 @@ class AudioAPIInterface:
|
||||
self.youtube_api: YouTubeWrapper = YouTubeWrapper(self.bot, self.config, session, self.cog)
|
||||
self.local_cache_api = LocalCacheWrapper(self.bot, self.config, self.conn, self.cog)
|
||||
self.global_cache_api = GlobalCacheWrapper(self.bot, self.config, session, self.cog)
|
||||
self.persistent_queue_api = QueueInterface(self.bot, self.config, self.conn, self.cog)
|
||||
self._session: aiohttp.ClientSession = session
|
||||
self._tasks: MutableMapping = {}
|
||||
self._lock: asyncio.Lock = asyncio.Lock()
|
||||
|
||||
async def initialize(self) -> None:
|
||||
"""Initialises the Local Cache connection"""
|
||||
"""Initialises the Local Cache connection."""
|
||||
await self.local_cache_api.lavalink.init()
|
||||
await self.persistent_queue_api.init()
|
||||
|
||||
def close(self) -> None:
|
||||
"""Closes the Local Cache connection"""
|
||||
"""Closes the Local Cache connection."""
|
||||
self.local_cache_api.lavalink.close()
|
||||
|
||||
async def get_random_track_from_db(self) -> Optional[MutableMapping]:
|
||||
"""Get a random track from the local database and return it"""
|
||||
async def get_random_track_from_db(self, tries=0) -> Optional[MutableMapping]:
|
||||
"""Get a random track from the local database and return it."""
|
||||
track: Optional[MutableMapping] = {}
|
||||
try:
|
||||
query_data = {}
|
||||
@@ -106,7 +113,7 @@ class AudioAPIInterface:
|
||||
action_type: str = None,
|
||||
data: Union[List[MutableMapping], MutableMapping] = None,
|
||||
) -> None:
|
||||
"""Separate the tasks and run them in the appropriate functions"""
|
||||
"""Separate the tasks and run them in the appropriate functions."""
|
||||
|
||||
if not data:
|
||||
return
|
||||
@@ -126,9 +133,11 @@ class AudioAPIInterface:
|
||||
await self.local_cache_api.youtube.update(data)
|
||||
elif table == "spotify":
|
||||
await self.local_cache_api.spotify.update(data)
|
||||
elif action_type == "global" and isinstance(data, list):
|
||||
await asyncio.gather(*[self.global_cache_api.update_global(**d) for d in data])
|
||||
|
||||
async def run_tasks(self, ctx: Optional[commands.Context] = None, message_id=None) -> None:
|
||||
"""Run tasks for a specific context"""
|
||||
"""Run tasks for a specific context."""
|
||||
if message_id is not None:
|
||||
lock_id = message_id
|
||||
elif ctx is not None:
|
||||
@@ -143,7 +152,7 @@ class AudioAPIInterface:
|
||||
try:
|
||||
tasks = self._tasks[lock_id]
|
||||
tasks = [self.route_tasks(a, tasks[a]) for a in tasks]
|
||||
await asyncio.gather(*tasks, return_exceptions=True)
|
||||
await asyncio.gather(*tasks, return_exceptions=False)
|
||||
del self._tasks[lock_id]
|
||||
except Exception as exc:
|
||||
debug_exc_log(
|
||||
@@ -154,7 +163,7 @@ class AudioAPIInterface:
|
||||
log.debug(f"Completed database writes for {lock_id} ({lock_author})")
|
||||
|
||||
async def run_all_pending_tasks(self) -> None:
|
||||
"""Run all pending tasks left in the cache, called on cog_unload"""
|
||||
"""Run all pending tasks left in the cache, called on cog_unload."""
|
||||
async with self._lock:
|
||||
if IS_DEBUG:
|
||||
log.debug("Running pending writes to database")
|
||||
@@ -166,7 +175,7 @@ class AudioAPIInterface:
|
||||
self._tasks = {}
|
||||
coro_tasks = [self.route_tasks(a, tasks[a]) for a in tasks]
|
||||
|
||||
await asyncio.gather(*coro_tasks, return_exceptions=True)
|
||||
await asyncio.gather(*coro_tasks, return_exceptions=False)
|
||||
|
||||
except Exception as exc:
|
||||
debug_exc_log(log, exc, "Failed database writes")
|
||||
@@ -175,7 +184,7 @@ class AudioAPIInterface:
|
||||
log.debug("Completed pending writes to database have finished")
|
||||
|
||||
def append_task(self, ctx: commands.Context, event: str, task: Tuple, _id: int = None) -> None:
|
||||
"""Add a task to the cache to be run later"""
|
||||
"""Add a task to the cache to be run later."""
|
||||
lock_id = _id or ctx.message.id
|
||||
if lock_id not in self._tasks:
|
||||
self._tasks[lock_id] = {"update": [], "insert": [], "global": []}
|
||||
@@ -190,7 +199,7 @@ class AudioAPIInterface:
|
||||
skip_youtube: bool = False,
|
||||
current_cache_level: CacheLevel = CacheLevel.none(),
|
||||
) -> List[str]:
|
||||
"""Return youtube URLS for the spotify URL provided"""
|
||||
"""Return youtube URLS for the spotify URL provided."""
|
||||
youtube_urls = []
|
||||
tracks = await self.fetch_from_spotify_api(
|
||||
query_type, uri, params=None, notifier=notifier, ctx=ctx
|
||||
@@ -266,7 +275,7 @@ class AudioAPIInterface:
|
||||
notifier: Optional[Notifier] = None,
|
||||
ctx: Context = None,
|
||||
) -> Union[List[MutableMapping], List[str]]:
|
||||
"""Gets track info from spotify API"""
|
||||
"""Gets track info from spotify API."""
|
||||
|
||||
if recursive is False:
|
||||
(call, params) = self.spotify_api.spotify_format_call(query_type, uri)
|
||||
@@ -394,9 +403,10 @@ class AudioAPIInterface:
|
||||
lock: Callable,
|
||||
notifier: Optional[Notifier] = None,
|
||||
forced: bool = False,
|
||||
query_global: bool = False,
|
||||
query_global: bool = True,
|
||||
) -> List[lavalink.Track]:
|
||||
"""Queries the Database then falls back to Spotify and YouTube APIs then Enqueued matched tracks.
|
||||
"""Queries the Database then falls back to Spotify and YouTube APIs then Enqueued matched
|
||||
tracks.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@@ -423,7 +433,9 @@ class AudioAPIInterface:
|
||||
List[str]
|
||||
List of Youtube URLs.
|
||||
"""
|
||||
# globaldb_toggle = await self.config.global_db_enabled()
|
||||
await self.global_cache_api._get_api_key()
|
||||
globaldb_toggle = await self.config.global_db_enabled()
|
||||
global_entry = globaldb_toggle and query_global
|
||||
track_list: List = []
|
||||
has_not_allowed = False
|
||||
try:
|
||||
@@ -485,7 +497,14 @@ class AudioAPIInterface:
|
||||
)
|
||||
except Exception as exc:
|
||||
debug_exc_log(log, exc, f"Failed to fetch {track_info} from YouTube table")
|
||||
|
||||
should_query_global = globaldb_toggle and query_global and val is None
|
||||
if should_query_global:
|
||||
llresponse = await self.global_cache_api.get_spotify(track_name, artist_name)
|
||||
if llresponse:
|
||||
if llresponse.get("loadType") == "V2_COMPACT":
|
||||
llresponse["loadType"] = "V2_COMPAT"
|
||||
llresponse = LoadResult(llresponse)
|
||||
val = llresponse or None
|
||||
if val is None:
|
||||
val = await self.fetch_youtube_query(
|
||||
ctx, track_info, current_cache_level=current_cache_level
|
||||
@@ -494,34 +513,44 @@ class AudioAPIInterface:
|
||||
task = ("update", ("youtube", {"track": track_info}))
|
||||
self.append_task(ctx, *task)
|
||||
|
||||
if llresponse is not None:
|
||||
if isinstance(llresponse, LoadResult):
|
||||
track_object = llresponse.tracks
|
||||
elif val:
|
||||
try:
|
||||
(result, called_api) = await self.fetch_track(
|
||||
ctx,
|
||||
player,
|
||||
Query.process_input(val, self.cog.local_folder_current_path),
|
||||
forced=forced,
|
||||
)
|
||||
except (RuntimeError, aiohttp.ServerDisconnectedError):
|
||||
lock(ctx, False)
|
||||
error_embed = discord.Embed(
|
||||
colour=await ctx.embed_colour(),
|
||||
title=_("The connection was reset while loading the playlist."),
|
||||
)
|
||||
if notifier is not None:
|
||||
await notifier.update_embed(error_embed)
|
||||
break
|
||||
except asyncio.TimeoutError:
|
||||
lock(ctx, False)
|
||||
error_embed = discord.Embed(
|
||||
colour=await ctx.embed_colour(),
|
||||
title=_("Player timeout, skipping remaining tracks."),
|
||||
)
|
||||
if notifier is not None:
|
||||
await notifier.update_embed(error_embed)
|
||||
break
|
||||
result = None
|
||||
if should_query_global:
|
||||
llresponse = await self.global_cache_api.get_call(val)
|
||||
if llresponse:
|
||||
if llresponse.get("loadType") == "V2_COMPACT":
|
||||
llresponse["loadType"] = "V2_COMPAT"
|
||||
llresponse = LoadResult(llresponse)
|
||||
result = llresponse or None
|
||||
if not result:
|
||||
try:
|
||||
(result, called_api) = await self.fetch_track(
|
||||
ctx,
|
||||
player,
|
||||
Query.process_input(val, self.cog.local_folder_current_path),
|
||||
forced=forced,
|
||||
should_query_global=not should_query_global,
|
||||
)
|
||||
except (RuntimeError, aiohttp.ServerDisconnectedError):
|
||||
lock(ctx, False)
|
||||
error_embed = discord.Embed(
|
||||
colour=await ctx.embed_colour(),
|
||||
title=_("The connection was reset while loading the playlist."),
|
||||
)
|
||||
if notifier is not None:
|
||||
await notifier.update_embed(error_embed)
|
||||
break
|
||||
except asyncio.TimeoutError:
|
||||
lock(ctx, False)
|
||||
error_embed = discord.Embed(
|
||||
colour=await ctx.embed_colour(),
|
||||
title=_("Player timeout, skipping remaining tracks."),
|
||||
)
|
||||
if notifier is not None:
|
||||
await notifier.update_embed(error_embed)
|
||||
break
|
||||
track_object = result.tracks
|
||||
else:
|
||||
track_object = []
|
||||
@@ -538,7 +567,7 @@ class AudioAPIInterface:
|
||||
seconds=seconds,
|
||||
)
|
||||
|
||||
if consecutive_fails >= 10:
|
||||
if consecutive_fails >= (100 if global_entry else 10):
|
||||
error_embed = discord.Embed(
|
||||
colour=await ctx.embed_colour(),
|
||||
title=_("Failing to get tracks, skipping remaining."),
|
||||
@@ -551,13 +580,12 @@ class AudioAPIInterface:
|
||||
continue
|
||||
consecutive_fails = 0
|
||||
single_track = track_object[0]
|
||||
query = Query.process_input(single_track, self.cog.local_folder_current_path)
|
||||
if not await self.cog.is_query_allowed(
|
||||
self.config,
|
||||
ctx.guild,
|
||||
(
|
||||
f"{single_track.title} {single_track.author} {single_track.uri} "
|
||||
f"{Query.process_input(single_track, self.cog.local_folder_current_path)}"
|
||||
),
|
||||
ctx,
|
||||
f"{single_track.title} {single_track.author} {single_track.uri} {query}",
|
||||
query_obj=query,
|
||||
):
|
||||
has_not_allowed = True
|
||||
if IS_DEBUG:
|
||||
@@ -570,6 +598,13 @@ class AudioAPIInterface:
|
||||
if guild_data["maxlength"] > 0:
|
||||
if self.cog.is_track_length_allowed(single_track, guild_data["maxlength"]):
|
||||
enqueued_tracks += 1
|
||||
single_track.extras.update(
|
||||
{
|
||||
"enqueue_time": int(time.time()),
|
||||
"vc": player.channel.id,
|
||||
"requester": ctx.author.id,
|
||||
}
|
||||
)
|
||||
player.add(ctx.author, single_track)
|
||||
self.bot.dispatch(
|
||||
"red_audio_track_enqueue",
|
||||
@@ -579,6 +614,13 @@ class AudioAPIInterface:
|
||||
)
|
||||
else:
|
||||
enqueued_tracks += 1
|
||||
single_track.extras.update(
|
||||
{
|
||||
"enqueue_time": int(time.time()),
|
||||
"vc": player.channel.id,
|
||||
"requester": ctx.author.id,
|
||||
}
|
||||
)
|
||||
player.add(ctx.author, single_track)
|
||||
self.bot.dispatch(
|
||||
"red_audio_track_enqueue",
|
||||
@@ -642,9 +684,7 @@ class AudioAPIInterface:
|
||||
track_info: str,
|
||||
current_cache_level: CacheLevel = CacheLevel.none(),
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Call the Youtube API and returns the youtube URL that the query matched
|
||||
"""
|
||||
"""Call the Youtube API and returns the youtube URL that the query matched."""
|
||||
track_url = await self.youtube_api.get_call(track_info)
|
||||
if CacheLevel.set_youtube().is_subset(current_cache_level) and track_url:
|
||||
time_now = int(datetime.datetime.now(datetime.timezone.utc).timestamp())
|
||||
@@ -668,9 +708,7 @@ class AudioAPIInterface:
|
||||
async def fetch_from_youtube_api(
|
||||
self, ctx: commands.Context, track_info: str
|
||||
) -> Optional[str]:
|
||||
"""
|
||||
Gets an YouTube URL from for the query
|
||||
"""
|
||||
"""Gets an YouTube URL from for the query."""
|
||||
current_cache_level = CacheLevel(await self.config.cache_level())
|
||||
cache_enabled = CacheLevel.set_youtube().is_subset(current_cache_level)
|
||||
val = None
|
||||
@@ -727,6 +765,7 @@ class AudioAPIInterface:
|
||||
val = None
|
||||
query = Query.process_input(query, self.cog.local_folder_current_path)
|
||||
query_string = str(query)
|
||||
globaldb_toggle = await self.config.global_db_enabled()
|
||||
valid_global_entry = False
|
||||
results = None
|
||||
called_api = False
|
||||
@@ -754,7 +793,31 @@ class AudioAPIInterface:
|
||||
called_api = False
|
||||
else:
|
||||
val = None
|
||||
|
||||
if (
|
||||
globaldb_toggle
|
||||
and not val
|
||||
and should_query_global
|
||||
and not forced
|
||||
and not query.is_local
|
||||
and not query.is_spotify
|
||||
):
|
||||
valid_global_entry = False
|
||||
with contextlib.suppress(Exception):
|
||||
global_entry = await self.global_cache_api.get_call(query=query)
|
||||
if global_entry.get("loadType") == "V2_COMPACT":
|
||||
global_entry["loadType"] = "V2_COMPAT"
|
||||
results = LoadResult(global_entry)
|
||||
if results.load_type in [
|
||||
LoadType.PLAYLIST_LOADED,
|
||||
LoadType.TRACK_LOADED,
|
||||
LoadType.SEARCH_RESULT,
|
||||
LoadType.V2_COMPAT,
|
||||
]:
|
||||
valid_global_entry = True
|
||||
if valid_global_entry:
|
||||
if IS_DEBUG:
|
||||
log.debug(f"Querying Global DB api for {query}")
|
||||
results, called_api = results, False
|
||||
if valid_global_entry:
|
||||
pass
|
||||
elif lazy is True:
|
||||
@@ -769,6 +832,7 @@ class AudioAPIInterface:
|
||||
if results.has_error:
|
||||
# If cached value has an invalid entry make a new call so that it gets updated
|
||||
results, called_api = await self.fetch_track(ctx, player, query, forced=True)
|
||||
valid_global_entry = False
|
||||
else:
|
||||
if IS_DEBUG:
|
||||
log.debug(f"Querying Lavalink api for {query_string}")
|
||||
@@ -781,7 +845,19 @@ class AudioAPIInterface:
|
||||
raise TrackEnqueueError
|
||||
if results is None:
|
||||
results = LoadResult({"loadType": "LOAD_FAILED", "playlistInfo": {}, "tracks": []})
|
||||
|
||||
valid_global_entry = False
|
||||
update_global = (
|
||||
globaldb_toggle and not valid_global_entry and self.global_cache_api.has_api_key
|
||||
)
|
||||
with contextlib.suppress(Exception):
|
||||
if (
|
||||
update_global
|
||||
and not query.is_local
|
||||
and not results.has_error
|
||||
and len(results.tracks) >= 1
|
||||
):
|
||||
global_task = ("global", dict(llresponse=results, query=query))
|
||||
self.append_task(ctx, *global_task)
|
||||
if (
|
||||
cache_enabled
|
||||
and results.load_type
|
||||
@@ -817,9 +893,7 @@ class AudioAPIInterface:
|
||||
return results, called_api
|
||||
|
||||
async def autoplay(self, player: lavalink.Player, playlist_api: PlaylistWrapper):
|
||||
"""
|
||||
Enqueue a random track
|
||||
"""
|
||||
"""Enqueue a random track."""
|
||||
autoplaylist = await self.config.guild(player.channel.guild).autoplaylist()
|
||||
current_cache_level = CacheLevel(await self.config.cache_level())
|
||||
cache_enabled = CacheLevel.set_lavalink().is_subset(current_cache_level)
|
||||
@@ -865,19 +939,18 @@ class AudioAPIInterface:
|
||||
track = random.choice(tracks)
|
||||
query = Query.process_input(track, self.cog.local_folder_current_path)
|
||||
await asyncio.sleep(0.001)
|
||||
if not query.valid or (
|
||||
if (not query.valid) or (
|
||||
query.is_local
|
||||
and query.local_track_path is not None
|
||||
and not query.local_track_path.exists()
|
||||
):
|
||||
continue
|
||||
notify_channel = self.bot.get_channel(player.fetch("channel"))
|
||||
if not await self.cog.is_query_allowed(
|
||||
self.config,
|
||||
player.channel.guild,
|
||||
(
|
||||
f"{track.title} {track.author} {track.uri} "
|
||||
f"{str(Query.process_input(track, self.cog.local_folder_current_path))}"
|
||||
),
|
||||
notify_channel,
|
||||
f"{track.title} {track.author} {track.uri} {query}",
|
||||
query_obj=query,
|
||||
):
|
||||
if IS_DEBUG:
|
||||
log.debug(
|
||||
@@ -886,11 +959,20 @@ class AudioAPIInterface:
|
||||
)
|
||||
continue
|
||||
valid = True
|
||||
|
||||
track.extras["autoplay"] = True
|
||||
track.extras.update(
|
||||
{
|
||||
"autoplay": True,
|
||||
"enqueue_time": int(time.time()),
|
||||
"vc": player.channel.id,
|
||||
"requester": player.channel.guild.me.id,
|
||||
}
|
||||
)
|
||||
player.add(player.channel.guild.me, track)
|
||||
self.bot.dispatch(
|
||||
"red_audio_track_auto_play", player.channel.guild, track, player.channel.guild.me
|
||||
)
|
||||
if not player.current:
|
||||
await player.play()
|
||||
|
||||
async def fetch_all_contribute(self) -> List[LavalinkCacheFetchForGlobalResult]:
|
||||
return await self.local_cache_api.lavalink.fetch_all_for_global()
|
||||
|
||||
Reference in New Issue
Block a user