Audio Cog - v2.3.0 (#4446)

* First commit - Bring everything from dev cog minus NSFW support

* Add a toggle for auto deafen

* Add a one off Send to Owners

* aaaaaaa

* Update this to ensure `get_perms` is not called if the API is disabled

* Apply suggestions from code review

Co-authored-by: Vuks <51289041+Vuks69@users.noreply.github.com>

* silence any errors here (in case API is down so it doesnt affect audio)

* update the message to tell the mto join the Official Red server.

* remove useless sutff, and change dj check order to ensure bot doesnt join VC for non DJ's

* ffs

* Update redbot/cogs/audio/core/tasks/startup.py

Co-authored-by: Twentysix <Twentysix26@users.noreply.github.com>

* Aikas Review

* Add #3995 in here

* update

* *sigh*

* lock behind owner

* to help with debugging

* Revert "to help with debugging"

This reverts commit 8cbf17be

* resolve last review

Co-authored-by: Vuks <51289041+Vuks69@users.noreply.github.com>
Co-authored-by: Twentysix <Twentysix26@users.noreply.github.com>
This commit is contained in:
Draper
2020-10-12 19:39:39 +01:00
committed by GitHub
parent 29ebf0f060
commit 2da9b502d8
41 changed files with 1553 additions and 331 deletions

View File

@@ -1,30 +1,34 @@
import asyncio
import contextlib
import datetime
import json
import logging
import random
import time
from collections import namedtuple
from typing import Callable, List, MutableMapping, Optional, TYPE_CHECKING, Tuple, Union, cast
from typing import TYPE_CHECKING, Callable, List, MutableMapping, Optional, Tuple, Union, cast
import aiohttp
import discord
import lavalink
from lavalink.rest_api import LoadResult
from redbot.core.utils import AsyncIter
from lavalink.rest_api import LoadResult, LoadType
from redbot.core import Config, commands
from redbot.core.bot import Red
from redbot.core.commands import Cog, Context
from redbot.core.i18n import Translator
from redbot.core.utils import AsyncIter
from redbot.core.utils.dbtools import APSWConnectionWrapper
from ..audio_dataclasses import Query
from ..audio_logging import IS_DEBUG, debug_exc_log
from ..errors import DatabaseError, SpotifyFetchError, TrackEnqueueError
from ..utils import CacheLevel, Notifier
from .api_utils import LavalinkCacheFetchForGlobalResult
from .global_db import GlobalCacheWrapper
from .local_db import LocalCacheWrapper
from .persist_queue_wrapper import QueueInterface
from .playlist_interface import get_playlist
from .playlist_wrapper import PlaylistWrapper
from .spotify import SpotifyWrapper
@@ -36,6 +40,7 @@ if TYPE_CHECKING:
_ = Translator("Audio", __file__)
log = logging.getLogger("red.cogs.Audio.api.AudioAPIInterface")
_TOP_100_US = "https://www.youtube.com/playlist?list=PL4fGSI1pDJn5rWitrRWFKdm-ulaFiIyoK"
# TODO: Get random from global Cache
class AudioAPIInterface:
@@ -60,20 +65,22 @@ class AudioAPIInterface:
self.youtube_api: YouTubeWrapper = YouTubeWrapper(self.bot, self.config, session, self.cog)
self.local_cache_api = LocalCacheWrapper(self.bot, self.config, self.conn, self.cog)
self.global_cache_api = GlobalCacheWrapper(self.bot, self.config, session, self.cog)
self.persistent_queue_api = QueueInterface(self.bot, self.config, self.conn, self.cog)
self._session: aiohttp.ClientSession = session
self._tasks: MutableMapping = {}
self._lock: asyncio.Lock = asyncio.Lock()
async def initialize(self) -> None:
"""Initialises the Local Cache connection"""
"""Initialises the Local Cache connection."""
await self.local_cache_api.lavalink.init()
await self.persistent_queue_api.init()
def close(self) -> None:
"""Closes the Local Cache connection"""
"""Closes the Local Cache connection."""
self.local_cache_api.lavalink.close()
async def get_random_track_from_db(self) -> Optional[MutableMapping]:
"""Get a random track from the local database and return it"""
async def get_random_track_from_db(self, tries=0) -> Optional[MutableMapping]:
"""Get a random track from the local database and return it."""
track: Optional[MutableMapping] = {}
try:
query_data = {}
@@ -106,7 +113,7 @@ class AudioAPIInterface:
action_type: str = None,
data: Union[List[MutableMapping], MutableMapping] = None,
) -> None:
"""Separate the tasks and run them in the appropriate functions"""
"""Separate the tasks and run them in the appropriate functions."""
if not data:
return
@@ -126,9 +133,11 @@ class AudioAPIInterface:
await self.local_cache_api.youtube.update(data)
elif table == "spotify":
await self.local_cache_api.spotify.update(data)
elif action_type == "global" and isinstance(data, list):
await asyncio.gather(*[self.global_cache_api.update_global(**d) for d in data])
async def run_tasks(self, ctx: Optional[commands.Context] = None, message_id=None) -> None:
"""Run tasks for a specific context"""
"""Run tasks for a specific context."""
if message_id is not None:
lock_id = message_id
elif ctx is not None:
@@ -143,7 +152,7 @@ class AudioAPIInterface:
try:
tasks = self._tasks[lock_id]
tasks = [self.route_tasks(a, tasks[a]) for a in tasks]
await asyncio.gather(*tasks, return_exceptions=True)
await asyncio.gather(*tasks, return_exceptions=False)
del self._tasks[lock_id]
except Exception as exc:
debug_exc_log(
@@ -154,7 +163,7 @@ class AudioAPIInterface:
log.debug(f"Completed database writes for {lock_id} ({lock_author})")
async def run_all_pending_tasks(self) -> None:
"""Run all pending tasks left in the cache, called on cog_unload"""
"""Run all pending tasks left in the cache, called on cog_unload."""
async with self._lock:
if IS_DEBUG:
log.debug("Running pending writes to database")
@@ -166,7 +175,7 @@ class AudioAPIInterface:
self._tasks = {}
coro_tasks = [self.route_tasks(a, tasks[a]) for a in tasks]
await asyncio.gather(*coro_tasks, return_exceptions=True)
await asyncio.gather(*coro_tasks, return_exceptions=False)
except Exception as exc:
debug_exc_log(log, exc, "Failed database writes")
@@ -175,7 +184,7 @@ class AudioAPIInterface:
log.debug("Completed pending writes to database have finished")
def append_task(self, ctx: commands.Context, event: str, task: Tuple, _id: int = None) -> None:
"""Add a task to the cache to be run later"""
"""Add a task to the cache to be run later."""
lock_id = _id or ctx.message.id
if lock_id not in self._tasks:
self._tasks[lock_id] = {"update": [], "insert": [], "global": []}
@@ -190,7 +199,7 @@ class AudioAPIInterface:
skip_youtube: bool = False,
current_cache_level: CacheLevel = CacheLevel.none(),
) -> List[str]:
"""Return youtube URLS for the spotify URL provided"""
"""Return youtube URLS for the spotify URL provided."""
youtube_urls = []
tracks = await self.fetch_from_spotify_api(
query_type, uri, params=None, notifier=notifier, ctx=ctx
@@ -266,7 +275,7 @@ class AudioAPIInterface:
notifier: Optional[Notifier] = None,
ctx: Context = None,
) -> Union[List[MutableMapping], List[str]]:
"""Gets track info from spotify API"""
"""Gets track info from spotify API."""
if recursive is False:
(call, params) = self.spotify_api.spotify_format_call(query_type, uri)
@@ -394,9 +403,10 @@ class AudioAPIInterface:
lock: Callable,
notifier: Optional[Notifier] = None,
forced: bool = False,
query_global: bool = False,
query_global: bool = True,
) -> List[lavalink.Track]:
"""Queries the Database then falls back to Spotify and YouTube APIs then Enqueued matched tracks.
"""Queries the Database then falls back to Spotify and YouTube APIs then Enqueued matched
tracks.
Parameters
----------
@@ -423,7 +433,9 @@ class AudioAPIInterface:
List[str]
List of Youtube URLs.
"""
# globaldb_toggle = await self.config.global_db_enabled()
await self.global_cache_api._get_api_key()
globaldb_toggle = await self.config.global_db_enabled()
global_entry = globaldb_toggle and query_global
track_list: List = []
has_not_allowed = False
try:
@@ -485,7 +497,14 @@ class AudioAPIInterface:
)
except Exception as exc:
debug_exc_log(log, exc, f"Failed to fetch {track_info} from YouTube table")
should_query_global = globaldb_toggle and query_global and val is None
if should_query_global:
llresponse = await self.global_cache_api.get_spotify(track_name, artist_name)
if llresponse:
if llresponse.get("loadType") == "V2_COMPACT":
llresponse["loadType"] = "V2_COMPAT"
llresponse = LoadResult(llresponse)
val = llresponse or None
if val is None:
val = await self.fetch_youtube_query(
ctx, track_info, current_cache_level=current_cache_level
@@ -494,34 +513,44 @@ class AudioAPIInterface:
task = ("update", ("youtube", {"track": track_info}))
self.append_task(ctx, *task)
if llresponse is not None:
if isinstance(llresponse, LoadResult):
track_object = llresponse.tracks
elif val:
try:
(result, called_api) = await self.fetch_track(
ctx,
player,
Query.process_input(val, self.cog.local_folder_current_path),
forced=forced,
)
except (RuntimeError, aiohttp.ServerDisconnectedError):
lock(ctx, False)
error_embed = discord.Embed(
colour=await ctx.embed_colour(),
title=_("The connection was reset while loading the playlist."),
)
if notifier is not None:
await notifier.update_embed(error_embed)
break
except asyncio.TimeoutError:
lock(ctx, False)
error_embed = discord.Embed(
colour=await ctx.embed_colour(),
title=_("Player timeout, skipping remaining tracks."),
)
if notifier is not None:
await notifier.update_embed(error_embed)
break
result = None
if should_query_global:
llresponse = await self.global_cache_api.get_call(val)
if llresponse:
if llresponse.get("loadType") == "V2_COMPACT":
llresponse["loadType"] = "V2_COMPAT"
llresponse = LoadResult(llresponse)
result = llresponse or None
if not result:
try:
(result, called_api) = await self.fetch_track(
ctx,
player,
Query.process_input(val, self.cog.local_folder_current_path),
forced=forced,
should_query_global=not should_query_global,
)
except (RuntimeError, aiohttp.ServerDisconnectedError):
lock(ctx, False)
error_embed = discord.Embed(
colour=await ctx.embed_colour(),
title=_("The connection was reset while loading the playlist."),
)
if notifier is not None:
await notifier.update_embed(error_embed)
break
except asyncio.TimeoutError:
lock(ctx, False)
error_embed = discord.Embed(
colour=await ctx.embed_colour(),
title=_("Player timeout, skipping remaining tracks."),
)
if notifier is not None:
await notifier.update_embed(error_embed)
break
track_object = result.tracks
else:
track_object = []
@@ -538,7 +567,7 @@ class AudioAPIInterface:
seconds=seconds,
)
if consecutive_fails >= 10:
if consecutive_fails >= (100 if global_entry else 10):
error_embed = discord.Embed(
colour=await ctx.embed_colour(),
title=_("Failing to get tracks, skipping remaining."),
@@ -551,13 +580,12 @@ class AudioAPIInterface:
continue
consecutive_fails = 0
single_track = track_object[0]
query = Query.process_input(single_track, self.cog.local_folder_current_path)
if not await self.cog.is_query_allowed(
self.config,
ctx.guild,
(
f"{single_track.title} {single_track.author} {single_track.uri} "
f"{Query.process_input(single_track, self.cog.local_folder_current_path)}"
),
ctx,
f"{single_track.title} {single_track.author} {single_track.uri} {query}",
query_obj=query,
):
has_not_allowed = True
if IS_DEBUG:
@@ -570,6 +598,13 @@ class AudioAPIInterface:
if guild_data["maxlength"] > 0:
if self.cog.is_track_length_allowed(single_track, guild_data["maxlength"]):
enqueued_tracks += 1
single_track.extras.update(
{
"enqueue_time": int(time.time()),
"vc": player.channel.id,
"requester": ctx.author.id,
}
)
player.add(ctx.author, single_track)
self.bot.dispatch(
"red_audio_track_enqueue",
@@ -579,6 +614,13 @@ class AudioAPIInterface:
)
else:
enqueued_tracks += 1
single_track.extras.update(
{
"enqueue_time": int(time.time()),
"vc": player.channel.id,
"requester": ctx.author.id,
}
)
player.add(ctx.author, single_track)
self.bot.dispatch(
"red_audio_track_enqueue",
@@ -642,9 +684,7 @@ class AudioAPIInterface:
track_info: str,
current_cache_level: CacheLevel = CacheLevel.none(),
) -> Optional[str]:
"""
Call the Youtube API and returns the youtube URL that the query matched
"""
"""Call the Youtube API and returns the youtube URL that the query matched."""
track_url = await self.youtube_api.get_call(track_info)
if CacheLevel.set_youtube().is_subset(current_cache_level) and track_url:
time_now = int(datetime.datetime.now(datetime.timezone.utc).timestamp())
@@ -668,9 +708,7 @@ class AudioAPIInterface:
async def fetch_from_youtube_api(
self, ctx: commands.Context, track_info: str
) -> Optional[str]:
"""
Gets an YouTube URL from for the query
"""
"""Gets an YouTube URL from for the query."""
current_cache_level = CacheLevel(await self.config.cache_level())
cache_enabled = CacheLevel.set_youtube().is_subset(current_cache_level)
val = None
@@ -727,6 +765,7 @@ class AudioAPIInterface:
val = None
query = Query.process_input(query, self.cog.local_folder_current_path)
query_string = str(query)
globaldb_toggle = await self.config.global_db_enabled()
valid_global_entry = False
results = None
called_api = False
@@ -754,7 +793,31 @@ class AudioAPIInterface:
called_api = False
else:
val = None
if (
globaldb_toggle
and not val
and should_query_global
and not forced
and not query.is_local
and not query.is_spotify
):
valid_global_entry = False
with contextlib.suppress(Exception):
global_entry = await self.global_cache_api.get_call(query=query)
if global_entry.get("loadType") == "V2_COMPACT":
global_entry["loadType"] = "V2_COMPAT"
results = LoadResult(global_entry)
if results.load_type in [
LoadType.PLAYLIST_LOADED,
LoadType.TRACK_LOADED,
LoadType.SEARCH_RESULT,
LoadType.V2_COMPAT,
]:
valid_global_entry = True
if valid_global_entry:
if IS_DEBUG:
log.debug(f"Querying Global DB api for {query}")
results, called_api = results, False
if valid_global_entry:
pass
elif lazy is True:
@@ -769,6 +832,7 @@ class AudioAPIInterface:
if results.has_error:
# If cached value has an invalid entry make a new call so that it gets updated
results, called_api = await self.fetch_track(ctx, player, query, forced=True)
valid_global_entry = False
else:
if IS_DEBUG:
log.debug(f"Querying Lavalink api for {query_string}")
@@ -781,7 +845,19 @@ class AudioAPIInterface:
raise TrackEnqueueError
if results is None:
results = LoadResult({"loadType": "LOAD_FAILED", "playlistInfo": {}, "tracks": []})
valid_global_entry = False
update_global = (
globaldb_toggle and not valid_global_entry and self.global_cache_api.has_api_key
)
with contextlib.suppress(Exception):
if (
update_global
and not query.is_local
and not results.has_error
and len(results.tracks) >= 1
):
global_task = ("global", dict(llresponse=results, query=query))
self.append_task(ctx, *global_task)
if (
cache_enabled
and results.load_type
@@ -817,9 +893,7 @@ class AudioAPIInterface:
return results, called_api
async def autoplay(self, player: lavalink.Player, playlist_api: PlaylistWrapper):
"""
Enqueue a random track
"""
"""Enqueue a random track."""
autoplaylist = await self.config.guild(player.channel.guild).autoplaylist()
current_cache_level = CacheLevel(await self.config.cache_level())
cache_enabled = CacheLevel.set_lavalink().is_subset(current_cache_level)
@@ -865,19 +939,18 @@ class AudioAPIInterface:
track = random.choice(tracks)
query = Query.process_input(track, self.cog.local_folder_current_path)
await asyncio.sleep(0.001)
if not query.valid or (
if (not query.valid) or (
query.is_local
and query.local_track_path is not None
and not query.local_track_path.exists()
):
continue
notify_channel = self.bot.get_channel(player.fetch("channel"))
if not await self.cog.is_query_allowed(
self.config,
player.channel.guild,
(
f"{track.title} {track.author} {track.uri} "
f"{str(Query.process_input(track, self.cog.local_folder_current_path))}"
),
notify_channel,
f"{track.title} {track.author} {track.uri} {query}",
query_obj=query,
):
if IS_DEBUG:
log.debug(
@@ -886,11 +959,20 @@ class AudioAPIInterface:
)
continue
valid = True
track.extras["autoplay"] = True
track.extras.update(
{
"autoplay": True,
"enqueue_time": int(time.time()),
"vc": player.channel.id,
"requester": player.channel.guild.me.id,
}
)
player.add(player.channel.guild.me, track)
self.bot.dispatch(
"red_audio_track_auto_play", player.channel.guild, track, player.channel.guild.me
)
if not player.current:
await player.play()
async def fetch_all_contribute(self) -> List[LavalinkCacheFetchForGlobalResult]:
return await self.local_cache_api.lavalink.fetch_all_for_global()