reorder some startup to prevent heartbeat issues (#3073)

* reorder some startup to prevent heartbeat issues

* changelog

* handle startup cleanup in audio

* style

* rebased to handle conflict

* be a little smarter to prevent (some) infinite hangs

* Fix a pre-existing NoneType Error

* Migrate config before things are using it...

* another place we should ensure we're ready

* rename-toavoid-issues

* fix cache ordering and mis-use of ensure_future

* remove incorrect typehints

* style
This commit is contained in:
Michael H 2019-11-09 14:19:57 -05:00 committed by GitHub
parent 6852b7a1d1
commit b3363acf77
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 192 additions and 156 deletions

View File

@ -0,0 +1 @@
Bot now handles more things prior to connecting to discord to reduce issues with initial load

View File

@ -0,0 +1 @@
``bot.wait_until_ready`` should no longer be used during extension setup

View File

@ -312,7 +312,7 @@ def main():
loop.run_until_complete(red.http.close())
sys.exit(0)
try:
loop.run_until_complete(red.start(token, bot=True))
loop.run_until_complete(red.start(token, bot=True, cli_flags=cli_flags))
except discord.LoginFailure:
log.critical("This token doesn't seem to be valid.")
db_token = loop.run_until_complete(red._config.token())

View File

@ -3,7 +3,6 @@ from redbot.core import commands
from .audio import Audio
async def setup(bot: commands.Bot):
def setup(bot: commands.Bot):
cog = Audio(bot)
await cog.initialize()
bot.add_cog(cog)

View File

@ -9,7 +9,7 @@ import random
import time
import traceback
from collections import namedtuple
from typing import Callable, Dict, List, Mapping, NoReturn, Optional, Tuple, Union
from typing import Callable, Dict, List, Mapping, Optional, Tuple, Union
try:
from sqlite3 import Error as SQLError
@ -32,7 +32,7 @@ from lavalink.rest_api import LoadResult
from redbot.core import Config, commands
from redbot.core.bot import Red
from redbot.core.i18n import Translator, cog_i18n
from . import dataclasses
from . import audio_dataclasses
from .errors import InvalidTableError, SpotifyFetchError, YouTubeApiError
from .playlists import get_playlist
from .utils import CacheLevel, Notifier, is_allowed, queue_duration, track_limit
@ -193,7 +193,7 @@ class SpotifyAPI:
)
return await r.json()
async def _get_auth(self) -> NoReturn:
async def _get_auth(self):
if self.client_id is None or self.client_secret is None:
tokens = await self.bot.get_shared_api_tokens("spotify")
self.client_id = tokens.get("client_id", "")
@ -331,7 +331,7 @@ class MusicCache:
self._lock: asyncio.Lock = asyncio.Lock()
self.config: Optional[Config] = None
async def initialize(self, config: Config) -> NoReturn:
async def initialize(self, config: Config):
if HAS_SQL:
await self.database.connect()
@ -348,12 +348,12 @@ class MusicCache:
await self.database.execute(query=_CREATE_UNIQUE_INDEX_SPOTIFY_TABLE)
self.config = config
async def close(self) -> NoReturn:
async def close(self):
if HAS_SQL:
await self.database.execute(query="PRAGMA optimize;")
await self.database.disconnect()
async def insert(self, table: str, values: List[dict]) -> NoReturn:
async def insert(self, table: str, values: List[dict]):
# if table == "spotify":
# return
if HAS_SQL:
@ -363,7 +363,7 @@ class MusicCache:
await self.database.execute_many(query=query, values=values)
async def update(self, table: str, values: Dict[str, str]) -> NoReturn:
async def update(self, table: str, values: Dict[str, str]):
# if table == "spotify":
# return
if HAS_SQL:
@ -746,7 +746,7 @@ class MusicCache:
if val:
try:
result, called_api = await self.lavalink_query(
ctx, player, dataclasses.Query.process_input(val)
ctx, player, audio_dataclasses.Query.process_input(val)
)
except (RuntimeError, aiohttp.ServerDisconnectedError):
lock(ctx, False)
@ -805,7 +805,7 @@ class MusicCache:
ctx.guild,
(
f"{single_track.title} {single_track.author} {single_track.uri} "
f"{str(dataclasses.Query.process_input(single_track))}"
f"{str(audio_dataclasses.Query.process_input(single_track))}"
),
):
has_not_allowed = True
@ -911,7 +911,7 @@ class MusicCache:
self,
ctx: commands.Context,
player: lavalink.Player,
query: dataclasses.Query,
query: audio_dataclasses.Query,
forced: bool = False,
) -> Tuple[LoadResult, bool]:
"""
@ -925,7 +925,7 @@ class MusicCache:
The context this method is being called under.
player : lavalink.Player
The player who's requesting the query.
query: dataclasses.Query
query: audio_dataclasses.Query
The Query object for the query in question.
forced:bool
Whether or not to skip cache and call API first..
@ -939,7 +939,7 @@ class MusicCache:
)
cache_enabled = CacheLevel.set_lavalink().is_subset(current_cache_level)
val = None
_raw_query = dataclasses.Query.process_input(query)
_raw_query = audio_dataclasses.Query.process_input(query)
query = str(_raw_query)
if cache_enabled and not forced and not _raw_query.is_local:
update = True
@ -1003,14 +1003,10 @@ class MusicCache:
tasks = self._tasks[ctx.message.id]
del self._tasks[ctx.message.id]
await asyncio.gather(
*[asyncio.ensure_future(self.insert(*a)) for a in tasks["insert"]],
loop=self.bot.loop,
return_exceptions=True,
*[self.insert(*a) for a in tasks["insert"]], return_exceptions=True
)
await asyncio.gather(
*[asyncio.ensure_future(self.update(*a)) for a in tasks["update"]],
loop=self.bot.loop,
return_exceptions=True,
*[self.update(*a) for a in tasks["update"]], return_exceptions=True
)
log.debug(f"Completed database writes for {lock_id} " f"({lock_author})")
@ -1025,14 +1021,10 @@ class MusicCache:
self._tasks = {}
await asyncio.gather(
*[asyncio.ensure_future(self.insert(*a)) for a in tasks["insert"]],
loop=self.bot.loop,
return_exceptions=True,
*[self.insert(*a) for a in tasks["insert"]], return_exceptions=True
)
await asyncio.gather(
*[asyncio.ensure_future(self.update(*a)) for a in tasks["update"]],
loop=self.bot.loop,
return_exceptions=True,
*[self.update(*a) for a in tasks["update"]], return_exceptions=True
)
log.debug("Completed pending writes to database have finished")
@ -1096,7 +1088,9 @@ class MusicCache:
if not tracks:
ctx = namedtuple("Context", "message")
results, called_api = await self.lavalink_query(
ctx(player.channel.guild), player, dataclasses.Query.process_input(_TOP_100_US)
ctx(player.channel.guild),
player,
audio_dataclasses.Query.process_input(_TOP_100_US),
)
tracks = list(results.tracks)
if tracks:
@ -1107,7 +1101,7 @@ class MusicCache:
while valid is False and multiple:
track = random.choice(tracks)
query = dataclasses.Query.process_input(track)
query = audio_dataclasses.Query.process_input(track)
if not query.valid:
continue
if query.is_local and not query.track.exists():
@ -1116,7 +1110,7 @@ class MusicCache:
player.channel.guild,
(
f"{track.title} {track.author} {track.uri} "
f"{str(dataclasses.Query.process_input(track))}"
f"{str(audio_dataclasses.Query.process_input(track))}"
),
):
log.debug(

View File

@ -34,7 +34,7 @@ from redbot.core.utils.menus import (
start_adding_reactions,
)
from redbot.core.utils.predicates import MessagePredicate, ReactionPredicate
from . import dataclasses
from . import audio_dataclasses
from .apis import MusicCache, HAS_SQL, _ERROR
from .checks import can_have_caching
from .converters import ComplexScopeParser, ScopeParser, get_lazy_converter, get_playlist_converter
@ -142,7 +142,11 @@ class Audio(commands.Cog):
self.play_lock = {}
self._manager: Optional[ServerManager] = None
self.bot.dispatch("red_audio_initialized", self)
# These has to be a task since this requires the bot to be ready
# If it waits for ready in startup, we cause a deadlock during initial load
# as initial load happens before the bot can ever be ready.
self._init_task = self.bot.loop.create_task(self.initialize())
self._ready_event = asyncio.Event()
@property
def owns_autoplay(self):
@ -166,9 +170,14 @@ class Audio(commands.Cog):
self._cog_id = None
async def cog_before_invoke(self, ctx: commands.Context):
await self._ready_event.wait()
# check for unsupported arch
# Check on this needs refactoring at a later date
# so that we have a better way to handle the tasks
if self.llsetup in [ctx.command, ctx.command.root_parent]:
pass
elif self._connect_task.cancelled():
elif self._connect_task and self._connect_task.cancelled():
await ctx.send(
"You have attempted to run Audio's Lavalink server on an unsupported"
" architecture. Only settings related commands will be available."
@ -176,6 +185,7 @@ class Audio(commands.Cog):
raise RuntimeError(
"Not running audio command due to invalid machine architecture for Lavalink."
)
dj_enabled = await self.config.guild(ctx.guild).dj_enabled()
if dj_enabled:
dj_role_obj = ctx.guild.get_role(await self.config.guild(ctx.guild).dj_role())
@ -185,13 +195,13 @@ class Audio(commands.Cog):
await self._embed_msg(ctx, _("No DJ role found. Disabling DJ mode."))
async def initialize(self):
pass_config_to_dependencies(self.config, self.bot, await self.config.localpath())
await self.bot.wait_until_ready()
# Unlike most cases, we want the cache to exit before migration.
await self.music_cache.initialize(self.config)
asyncio.ensure_future(
self._migrate_config(
from_version=await self.config.schema_version(), to_version=_SCHEMA_VERSION
)
await self._migrate_config(
from_version=await self.config.schema_version(), to_version=_SCHEMA_VERSION
)
pass_config_to_dependencies(self.config, self.bot, await self.config.localpath())
self._restart_connect()
self._disconnect_task = self.bot.loop.create_task(self.disconnect_timer())
lavalink.register_event_listener(self.event_handler)
@ -209,6 +219,9 @@ class Audio(commands.Cog):
await self.bot.send_to_owners(page)
log.critical(error_message)
self._ready_event.set()
self.bot.dispatch("red_audio_initialized", self)
async def _migrate_config(self, from_version: int, to_version: int):
database_entries = []
time_now = str(datetime.datetime.now(datetime.timezone.utc))
@ -253,7 +266,7 @@ class Audio(commands.Cog):
cast(discord.Guild, discord.Object(id=guild_id))
).clear_raw("playlists")
if database_entries and HAS_SQL:
asyncio.ensure_future(self.music_cache.insert("lavalink", database_entries))
await self.music_cache.insert("lavalink", database_entries)
def _restart_connect(self):
if self._connect_task:
@ -366,7 +379,9 @@ class Audio(commands.Cog):
async def _players_check():
try:
get_single_title = lavalink.active_players()[0].current.title
query = dataclasses.Query.process_input(lavalink.active_players()[0].current.uri)
query = audio_dataclasses.Query.process_input(
lavalink.active_players()[0].current.uri
)
if get_single_title == "Unknown title":
get_single_title = lavalink.active_players()[0].current.uri
if not get_single_title.startswith("http"):
@ -463,18 +478,18 @@ class Audio(commands.Cog):
)
await notify_channel.send(embed=embed)
query = dataclasses.Query.process_input(player.current.uri)
query = audio_dataclasses.Query.process_input(player.current.uri)
if query.is_local if player.current else False:
if player.current.title != "Unknown title":
description = "**{} - {}**\n{}".format(
player.current.author,
player.current.title,
dataclasses.LocalPath(player.current.uri).to_string_hidden(),
audio_dataclasses.LocalPath(player.current.uri).to_string_hidden(),
)
else:
description = "{}".format(
dataclasses.LocalPath(player.current.uri).to_string_hidden()
audio_dataclasses.LocalPath(player.current.uri).to_string_hidden()
)
else:
description = "**[{}]({})**".format(player.current.title, player.current.uri)
@ -532,9 +547,9 @@ class Audio(commands.Cog):
message_channel = player.fetch("channel")
if message_channel:
message_channel = self.bot.get_channel(message_channel)
query = dataclasses.Query.process_input(player.current.uri)
query = audio_dataclasses.Query.process_input(player.current.uri)
if player.current and query.is_local:
query = dataclasses.Query.process_input(player.current.uri)
query = audio_dataclasses.Query.process_input(player.current.uri)
if player.current.title == "Unknown title":
description = "{}".format(query.track.to_string_hidden())
else:
@ -590,7 +605,7 @@ class Audio(commands.Cog):
player.store("channel", channel.id)
player.store("guild", guild.id)
await self._data_check(guild.me)
query = dataclasses.Query.process_input(query)
query = audio_dataclasses.Query.process_input(query)
ctx = namedtuple("Context", "message")
results, called_api = await self.music_cache.lavalink_query(ctx(guild), player, query)
@ -1094,7 +1109,7 @@ class Audio(commands.Cog):
with contextlib.suppress(discord.HTTPException):
await info.delete()
return
temp = dataclasses.LocalPath(local_path, forced=True)
temp = audio_dataclasses.LocalPath(local_path, forced=True)
if not temp.exists() or not temp.is_dir():
return await self._embed_msg(
ctx,
@ -1536,7 +1551,7 @@ class Audio(commands.Cog):
int((datetime.datetime.utcnow() - connect_start).total_seconds())
)
try:
query = dataclasses.Query.process_input(p.current.uri)
query = audio_dataclasses.Query.process_input(p.current.uri)
if query.is_local:
if p.current.title == "Unknown title":
current_title = localtracks.LocalPath(p.current.uri).to_string_hidden()
@ -1606,9 +1621,9 @@ class Audio(commands.Cog):
bump_song = player.queue[bump_index]
player.queue.insert(0, bump_song)
removed = player.queue.pop(index)
query = dataclasses.Query.process_input(removed.uri)
query = audio_dataclasses.Query.process_input(removed.uri)
if query.is_local:
localtrack = dataclasses.LocalPath(removed.uri)
localtrack = audio_dataclasses.LocalPath(removed.uri)
if removed.title != "Unknown title":
description = "**{} - {}**\n{}".format(
removed.author, removed.title, localtrack.to_string_hidden()
@ -1997,12 +2012,12 @@ class Audio(commands.Cog):
await ctx.invoke(self.local_play, play_subfolders=play_subfolders)
else:
folder = folder.strip()
_dir = dataclasses.LocalPath.joinpath(folder)
_dir = audio_dataclasses.LocalPath.joinpath(folder)
if not _dir.exists():
return await self._embed_msg(
ctx, _("No localtracks folder named {name}.").format(name=folder)
)
query = dataclasses.Query.process_input(_dir, search_subfolders=play_subfolders)
query = audio_dataclasses.Query.process_input(_dir, search_subfolders=play_subfolders)
await self._local_play_all(ctx, query, from_search=False if not folder else True)
@local.command(name="play")
@ -2064,8 +2079,8 @@ class Audio(commands.Cog):
all_tracks = await self._folder_list(
ctx,
(
dataclasses.Query.process_input(
dataclasses.LocalPath(
audio_dataclasses.Query.process_input(
audio_dataclasses.LocalPath(
await self.config.localpath()
).localtrack_folder.absolute(),
search_subfolders=play_subfolders,
@ -2081,18 +2096,18 @@ class Audio(commands.Cog):
return await ctx.invoke(self.search, query=search_list)
async def _localtracks_folders(self, ctx: commands.Context, search_subfolders=False):
audio_data = dataclasses.LocalPath(
dataclasses.LocalPath(None).localtrack_folder.absolute()
audio_data = audio_dataclasses.LocalPath(
audio_dataclasses.LocalPath(None).localtrack_folder.absolute()
)
if not await self._localtracks_check(ctx):
return
return audio_data.subfolders_in_tree() if search_subfolders else audio_data.subfolders()
async def _folder_list(self, ctx: commands.Context, query: dataclasses.Query):
async def _folder_list(self, ctx: commands.Context, query: audio_dataclasses.Query):
if not await self._localtracks_check(ctx):
return
query = dataclasses.Query.process_input(query)
query = audio_dataclasses.Query.process_input(query)
if not query.track.exists():
return
return (
@ -2102,12 +2117,12 @@ class Audio(commands.Cog):
)
async def _folder_tracks(
self, ctx, player: lavalink.player_manager.Player, query: dataclasses.Query
self, ctx, player: lavalink.player_manager.Player, query: audio_dataclasses.Query
):
if not await self._localtracks_check(ctx):
return
audio_data = dataclasses.LocalPath(None)
audio_data = audio_dataclasses.LocalPath(None)
try:
query.track.path.relative_to(audio_data.to_string())
except ValueError:
@ -2120,17 +2135,17 @@ class Audio(commands.Cog):
return local_tracks
async def _local_play_all(
self, ctx: commands.Context, query: dataclasses.Query, from_search=False
self, ctx: commands.Context, query: audio_dataclasses.Query, from_search=False
):
if not await self._localtracks_check(ctx):
return
if from_search:
query = dataclasses.Query.process_input(
query = audio_dataclasses.Query.process_input(
query.track.to_string(), invoked_from="local folder"
)
await ctx.invoke(self.search, query=query)
async def _all_folder_tracks(self, ctx: commands.Context, query: dataclasses.Query):
async def _all_folder_tracks(self, ctx: commands.Context, query: audio_dataclasses.Query):
if not await self._localtracks_check(ctx):
return
@ -2141,7 +2156,7 @@ class Audio(commands.Cog):
)
async def _localtracks_check(self, ctx: commands.Context):
folder = dataclasses.LocalPath(None)
folder = audio_dataclasses.LocalPath(None)
if folder.localtrack_folder.exists():
return True
if ctx.invoked_with != "start":
@ -2177,7 +2192,7 @@ class Audio(commands.Cog):
dur = "LIVE"
else:
dur = lavalink.utils.format_time(player.current.length)
query = dataclasses.Query.process_input(player.current.uri)
query = audio_dataclasses.Query.process_input(player.current.uri)
if query.is_local:
if not player.current.title == "Unknown title":
song = "**{track.author} - {track.title}**\n{uri}\n"
@ -2189,8 +2204,8 @@ class Audio(commands.Cog):
song += "\n\n{arrow}`{pos}`/`{dur}`"
song = song.format(
track=player.current,
uri=dataclasses.LocalPath(player.current.uri).to_string_hidden()
if dataclasses.Query.process_input(player.current.uri).is_local
uri=audio_dataclasses.LocalPath(player.current.uri).to_string_hidden()
if audio_dataclasses.Query.process_input(player.current.uri).is_local
else player.current.uri,
arrow=arrow,
pos=pos,
@ -2301,9 +2316,9 @@ class Audio(commands.Cog):
if not player.current:
return await self._embed_msg(ctx, _("Nothing playing."))
query = dataclasses.Query.process_input(player.current.uri)
query = audio_dataclasses.Query.process_input(player.current.uri)
if query.is_local:
query = dataclasses.Query.process_input(player.current.uri)
query = audio_dataclasses.Query.process_input(player.current.uri)
if player.current.title == "Unknown title":
description = "{}".format(query.track.to_string_hidden())
else:
@ -2436,7 +2451,7 @@ class Audio(commands.Cog):
)
if not await self._currency_check(ctx, guild_data["jukebox_price"]):
return
query = dataclasses.Query.process_input(query)
query = audio_dataclasses.Query.process_input(query)
if not query.valid:
return await self._embed_msg(ctx, _("No tracks to play."))
if query.is_spotify:
@ -2593,7 +2608,7 @@ class Audio(commands.Cog):
)
playlists_search_page_list.append(embed)
playlists_pick = await menu(ctx, playlists_search_page_list, playlist_search_controls)
query = dataclasses.Query.process_input(playlists_pick)
query = audio_dataclasses.Query.process_input(playlists_pick)
if not query.valid:
return await self._embed_msg(ctx, _("No tracks to play."))
if not await self._currency_check(ctx, guild_data["jukebox_price"]):
@ -2728,7 +2743,7 @@ class Audio(commands.Cog):
elif player.current:
await self._embed_msg(ctx, _("Adding a track to queue."))
async def _get_spotify_tracks(self, ctx: commands.Context, query: dataclasses.Query):
async def _get_spotify_tracks(self, ctx: commands.Context, query: audio_dataclasses.Query):
if ctx.invoked_with in ["play", "genre"]:
enqueue_tracks = True
else:
@ -2771,12 +2786,12 @@ class Audio(commands.Cog):
self._play_lock(ctx, False)
try:
if enqueue_tracks:
new_query = dataclasses.Query.process_input(res[0])
new_query = audio_dataclasses.Query.process_input(res[0])
new_query.start_time = query.start_time
return await self._enqueue_tracks(ctx, new_query)
else:
result, called_api = await self.music_cache.lavalink_query(
ctx, player, dataclasses.Query.process_input(res[0])
ctx, player, audio_dataclasses.Query.process_input(res[0])
)
tracks = result.tracks
if not tracks:
@ -2808,7 +2823,9 @@ class Audio(commands.Cog):
ctx, _("This doesn't seem to be a supported Spotify URL or code.")
)
async def _enqueue_tracks(self, ctx: commands.Context, query: Union[dataclasses.Query, list]):
async def _enqueue_tracks(
self, ctx: commands.Context, query: Union[audio_dataclasses.Query, list]
):
player = lavalink.get_player(ctx.guild.id)
try:
if self.play_lock[ctx.message.guild.id]:
@ -2863,7 +2880,7 @@ class Audio(commands.Cog):
ctx.guild,
(
f"{track.title} {track.author} {track.uri} "
f"{str(dataclasses.Query.process_input(track))}"
f"{str(audio_dataclasses.Query.process_input(track))}"
),
):
log.debug(f"Query is not allowed in {ctx.guild} ({ctx.guild.id})")
@ -2923,7 +2940,7 @@ class Audio(commands.Cog):
ctx.guild,
(
f"{single_track.title} {single_track.author} {single_track.uri} "
f"{str(dataclasses.Query.process_input(single_track))}"
f"{str(audio_dataclasses.Query.process_input(single_track))}"
),
):
log.debug(f"Query is not allowed in {ctx.guild} ({ctx.guild.id})")
@ -2956,17 +2973,17 @@ class Audio(commands.Cog):
return await self._embed_msg(
ctx, _("Nothing found. Check your Lavalink logs for details.")
)
query = dataclasses.Query.process_input(single_track.uri)
query = audio_dataclasses.Query.process_input(single_track.uri)
if query.is_local:
if single_track.title != "Unknown title":
description = "**{} - {}**\n{}".format(
single_track.author,
single_track.title,
dataclasses.LocalPath(single_track.uri).to_string_hidden(),
audio_dataclasses.LocalPath(single_track.uri).to_string_hidden(),
)
else:
description = "{}".format(
dataclasses.LocalPath(single_track.uri).to_string_hidden()
audio_dataclasses.LocalPath(single_track.uri).to_string_hidden()
)
else:
description = "**[{}]({})**".format(single_track.title, single_track.uri)
@ -2987,7 +3004,11 @@ class Audio(commands.Cog):
self._play_lock(ctx, False)
async def _spotify_playlist(
self, ctx: commands.Context, stype: str, query: dataclasses.Query, enqueue: bool = False
self,
ctx: commands.Context,
stype: str,
query: audio_dataclasses.Query,
enqueue: bool = False,
):
player = lavalink.get_player(ctx.guild.id)
@ -3340,7 +3361,7 @@ class Audio(commands.Cog):
return
player = lavalink.get_player(ctx.guild.id)
to_append = await self._playlist_tracks(
ctx, player, dataclasses.Query.process_input(query)
ctx, player, audio_dataclasses.Query.process_input(query)
)
if not to_append:
return await self._embed_msg(ctx, _("Could not find a track matching your query."))
@ -3993,7 +4014,7 @@ class Audio(commands.Cog):
spaces = "\N{EN SPACE}" * (len(str(len(playlist.tracks))) + 2)
for track in playlist.tracks:
track_idx = track_idx + 1
query = dataclasses.Query.process_input(track["info"]["uri"])
query = audio_dataclasses.Query.process_input(track["info"]["uri"])
if query.is_local:
if track["info"]["title"] != "Unknown title":
msg += "`{}.` **{} - {}**\n{}{}\n".format(
@ -4398,7 +4419,7 @@ class Audio(commands.Cog):
return
player = lavalink.get_player(ctx.guild.id)
tracklist = await self._playlist_tracks(
ctx, player, dataclasses.Query.process_input(playlist_url)
ctx, player, audio_dataclasses.Query.process_input(playlist_url)
)
if tracklist is not None:
playlist = await create_playlist(
@ -4488,14 +4509,14 @@ class Audio(commands.Cog):
ctx.guild,
(
f"{track.title} {track.author} {track.uri} "
f"{str(dataclasses.Query.process_input(track))}"
f"{str(audio_dataclasses.Query.process_input(track))}"
),
):
log.debug(f"Query is not allowed in {ctx.guild} ({ctx.guild.id})")
continue
query = dataclasses.Query.process_input(track.uri)
query = audio_dataclasses.Query.process_input(track.uri)
if query.is_local:
local_path = dataclasses.LocalPath(track.uri)
local_path = audio_dataclasses.LocalPath(track.uri)
if not await self._localtracks_check(ctx):
pass
if not local_path.exists() and not local_path.is_file():
@ -4781,7 +4802,7 @@ class Audio(commands.Cog):
or not match_yt_playlist(uploaded_playlist_url)
or not (
await self.music_cache.lavalink_query(
ctx, player, dataclasses.Query.process_input(uploaded_playlist_url)
ctx, player, audio_dataclasses.Query.process_input(uploaded_playlist_url)
)
)[0].tracks
):
@ -4966,7 +4987,7 @@ class Audio(commands.Cog):
}
)
if database_entries and HAS_SQL:
asyncio.ensure_future(self.music_cache.insert("lavalink", database_entries))
await self.music_cache.insert("lavalink", database_entries)
async def _load_v2_playlist(
self,
@ -4993,7 +5014,7 @@ class Audio(commands.Cog):
track_count += 1
try:
result, called_api = await self.music_cache.lavalink_query(
ctx, player, dataclasses.Query.process_input(song_url)
ctx, player, audio_dataclasses.Query.process_input(song_url)
)
track = result.tracks
except Exception:
@ -5041,7 +5062,7 @@ class Audio(commands.Cog):
return [], [], playlist
results = {}
updated_tracks = await self._playlist_tracks(
ctx, player, dataclasses.Query.process_input(playlist.url)
ctx, player, audio_dataclasses.Query.process_input(playlist.url)
)
if not updated_tracks:
# No Tracks available on url Lets set it to none to avoid repeated calls here
@ -5106,7 +5127,7 @@ class Audio(commands.Cog):
self,
ctx: commands.Context,
player: lavalink.player_manager.Player,
query: dataclasses.Query,
query: audio_dataclasses.Query,
):
search = query.is_search
tracklist = []
@ -5175,7 +5196,7 @@ class Audio(commands.Cog):
player.queue.insert(0, bump_song)
player.queue.pop(queue_len)
await player.skip()
query = dataclasses.Query.process_input(player.current.uri)
query = audio_dataclasses.Query.process_input(player.current.uri)
if query.is_local:
if player.current.title == "Unknown title":
@ -5227,7 +5248,7 @@ class Audio(commands.Cog):
else:
dur = lavalink.utils.format_time(player.current.length)
query = dataclasses.Query.process_input(player.current)
query = audio_dataclasses.Query.process_input(player.current)
if query.is_local:
if player.current.title != "Unknown title":
@ -5240,8 +5261,8 @@ class Audio(commands.Cog):
song += "\n\n{arrow}`{pos}`/`{dur}`"
song = song.format(
track=player.current,
uri=dataclasses.LocalPath(player.current.uri).to_string_hidden()
if dataclasses.Query.process_input(player.current.uri).is_local
uri=audio_dataclasses.LocalPath(player.current.uri).to_string_hidden()
if audio_dataclasses.Query.process_input(player.current.uri).is_local
else player.current.uri,
arrow=arrow,
pos=pos,
@ -5313,7 +5334,7 @@ class Audio(commands.Cog):
else:
dur = lavalink.utils.format_time(player.current.length)
query = dataclasses.Query.process_input(player.current)
query = audio_dataclasses.Query.process_input(player.current)
if query.is_stream:
queue_list += _("**Currently livestreaming:**\n")
@ -5327,7 +5348,7 @@ class Audio(commands.Cog):
(
_("Playing: ")
+ "**{current.author} - {current.title}**".format(current=player.current),
dataclasses.LocalPath(player.current.uri).to_string_hidden(),
audio_dataclasses.LocalPath(player.current.uri).to_string_hidden(),
_("Requested by: **{user}**\n").format(user=player.current.requester),
f"{arrow}`{pos}`/`{dur}`\n\n",
)
@ -5336,7 +5357,7 @@ class Audio(commands.Cog):
queue_list += "\n".join(
(
_("Playing: ")
+ dataclasses.LocalPath(player.current.uri).to_string_hidden(),
+ audio_dataclasses.LocalPath(player.current.uri).to_string_hidden(),
_("Requested by: **{user}**\n").format(user=player.current.requester),
f"{arrow}`{pos}`/`{dur}`\n\n",
)
@ -5357,13 +5378,13 @@ class Audio(commands.Cog):
track_title = track.title
req_user = track.requester
track_idx = i + 1
query = dataclasses.Query.process_input(track)
query = audio_dataclasses.Query.process_input(track)
if query.is_local:
if track.title == "Unknown title":
queue_list += f"`{track_idx}.` " + ", ".join(
(
bold(dataclasses.LocalPath(track.uri).to_string_hidden()),
bold(audio_dataclasses.LocalPath(track.uri).to_string_hidden()),
_("requested by **{user}**\n").format(user=req_user),
)
)
@ -5420,7 +5441,7 @@ class Audio(commands.Cog):
for track in queue_list:
queue_idx = queue_idx + 1
if not match_url(track.uri):
query = dataclasses.Query.process_input(track)
query = audio_dataclasses.Query.process_input(track)
if track.title == "Unknown title":
track_title = query.track.to_string_hidden()
else:
@ -5449,7 +5470,7 @@ class Audio(commands.Cog):
):
track_idx = i + 1
if type(track) is str:
track_location = dataclasses.LocalPath(track).to_string_hidden()
track_location = audio_dataclasses.LocalPath(track).to_string_hidden()
track_match += "`{}.` **{}**\n".format(track_idx, track_location)
else:
track_match += "`{}.` **{}**\n".format(track[0], track[1])
@ -5674,9 +5695,9 @@ class Audio(commands.Cog):
)
index -= 1
removed = player.queue.pop(index)
query = dataclasses.Query.process_input(removed.uri)
query = audio_dataclasses.Query.process_input(removed.uri)
if query.is_local:
local_path = dataclasses.LocalPath(removed.uri).to_string_hidden()
local_path = audio_dataclasses.LocalPath(removed.uri).to_string_hidden()
if removed.title == "Unknown title":
removed_title = local_path
else:
@ -5762,7 +5783,7 @@ class Audio(commands.Cog):
await self._data_check(ctx)
if not isinstance(query, list):
query = dataclasses.Query.process_input(query)
query = audio_dataclasses.Query.process_input(query)
if query.invoked_from == "search list" or query.invoked_from == "local folder":
if query.invoked_from == "search list":
result, called_api = await self.music_cache.lavalink_query(ctx, player, query)
@ -5791,7 +5812,7 @@ class Audio(commands.Cog):
ctx.guild,
(
f"{track.title} {track.author} {track.uri} "
f"{str(dataclasses.Query.process_input(track))}"
f"{str(audio_dataclasses.Query.process_input(track))}"
),
):
log.debug(f"Query is not allowed in {ctx.guild} ({ctx.guild.id})")
@ -5905,10 +5926,10 @@ class Audio(commands.Cog):
except IndexError:
search_choice = tracks[-1]
try:
query = dataclasses.Query.process_input(search_choice.uri)
query = audio_dataclasses.Query.process_input(search_choice.uri)
if query.is_local:
localtrack = dataclasses.LocalPath(search_choice.uri)
localtrack = audio_dataclasses.LocalPath(search_choice.uri)
if search_choice.title != "Unknown title":
description = "**{} - {}**\n{}".format(
search_choice.author, search_choice.title, localtrack.to_string_hidden()
@ -5919,7 +5940,7 @@ class Audio(commands.Cog):
description = "**[{}]({})**".format(search_choice.title, search_choice.uri)
except AttributeError:
search_choice = dataclasses.Query.process_input(search_choice)
search_choice = audio_dataclasses.Query.process_input(search_choice)
if search_choice.track.exists() and search_choice.track.is_dir():
return await ctx.invoke(self.search, query=search_choice)
elif search_choice.track.exists() and search_choice.track.is_file():
@ -5935,7 +5956,7 @@ class Audio(commands.Cog):
ctx.guild,
(
f"{search_choice.title} {search_choice.author} {search_choice.uri} "
f"{str(dataclasses.Query.process_input(search_choice))}"
f"{str(audio_dataclasses.Query.process_input(search_choice))}"
),
):
log.debug(f"Query is not allowed in {ctx.guild} ({ctx.guild.id})")
@ -5984,12 +6005,12 @@ class Audio(commands.Cog):
if search_track_num == 0:
search_track_num = 5
try:
query = dataclasses.Query.process_input(track.uri)
query = audio_dataclasses.Query.process_input(track.uri)
if query.is_local:
search_list += "`{0}.` **{1}**\n[{2}]\n".format(
search_track_num,
track.title,
dataclasses.LocalPath(track.uri).to_string_hidden(),
audio_dataclasses.LocalPath(track.uri).to_string_hidden(),
)
else:
search_list += "`{0}.` **[{1}]({2})**\n".format(
@ -5997,7 +6018,7 @@ class Audio(commands.Cog):
)
except AttributeError:
# query = Query.process_input(track)
track = dataclasses.Query.process_input(track)
track = audio_dataclasses.Query.process_input(track)
if track.is_local and command != "search":
search_list += "`{}.` **{}**\n".format(
search_track_num, track.to_string_user()
@ -6890,6 +6911,7 @@ class Audio(commands.Cog):
async def on_voice_state_update(
self, member: discord.Member, before: discord.VoiceState, after: discord.VoiceState
):
await self._ready_event.wait()
if after.channel != before.channel:
try:
self.skip_votes[before.channel.guild].remove(member.id)
@ -6907,6 +6929,9 @@ class Audio(commands.Cog):
if self._connect_task:
self._connect_task.cancel()
if self._init_task:
self._init_task.cancel()
lavalink.unregister_event_listener(self.event_handler)
self.bot.loop.create_task(lavalink.close())
if self._manager is not None:

View File

@ -3,7 +3,6 @@ import contextlib
import os
import re
import time
from typing import NoReturn
from urllib.parse import urlparse
import discord
@ -11,7 +10,7 @@ import lavalink
from redbot.core import Config, commands
from redbot.core.bot import Red
from . import dataclasses
from . import audio_dataclasses
from .converters import _pass_config_to_converters
@ -51,7 +50,7 @@ def pass_config_to_dependencies(config: Config, bot: Red, localtracks_folder: st
_config = config
_pass_config_to_playlist(config, bot)
_pass_config_to_converters(config, bot)
dataclasses._pass_config_to_dataclasses(config, bot, localtracks_folder)
audio_dataclasses._pass_config_to_dataclasses(config, bot, localtracks_folder)
def track_limit(track, maxlength):
@ -168,7 +167,7 @@ async def clear_react(bot: Red, message: discord.Message, emoji: dict = None):
async def get_description(track):
if any(x in track.uri for x in [f"{os.sep}localtracks", f"localtracks{os.sep}"]):
local_track = dataclasses.LocalPath(track.uri)
local_track = audio_dataclasses.LocalPath(track.uri)
if track.title != "Unknown title":
return "**{} - {}**\n{}".format(
track.author, track.title, local_track.to_string_hidden()
@ -389,7 +388,7 @@ class Notifier:
key: str = None,
seconds_key: str = None,
seconds: str = None,
) -> NoReturn:
):
"""
This updates an existing message.
Based on the message found in :variable:`Notifier.updates` as per the `key` param
@ -410,14 +409,14 @@ class Notifier:
except discord.errors.NotFound:
pass
async def update_text(self, text: str) -> NoReturn:
async def update_text(self, text: str):
embed2 = discord.Embed(colour=self.color, title=text)
try:
await self.message.edit(embed=embed2)
except discord.errors.NotFound:
pass
async def update_embed(self, embed: discord.Embed) -> NoReturn:
async def update_embed(self, embed: discord.Embed):
try:
await self.message.edit(embed=embed)
self.last_msg_time = time.time()

View File

@ -132,7 +132,6 @@ class RedBase(commands.GroupMixin, commands.bot.BotBase, RPCMixin): # pylint: d
self._main_dir = bot_dir
self._cog_mgr = CogManager()
super().__init__(*args, help_command=None, **kwargs)
# Do not manually use the help formatter attribute here, see `send_help_for`,
# for a documented API. The internals of this object are still subject to change.
@ -325,6 +324,7 @@ class RedBase(commands.GroupMixin, commands.bot.BotBase, RPCMixin): # pylint: d
get_embed_colour = get_embed_color
# start config migrations
async def _maybe_update_config(self):
"""
This should be run prior to loading cogs or connecting to discord.
@ -375,6 +375,57 @@ class RedBase(commands.GroupMixin, commands.bot.BotBase, RPCMixin): # pylint: d
await self._config.guild(guild_obj).admin_role.set(admin_roles)
log.info("Done updating guild configs to support multiple mod/admin roles")
# end Config migrations
async def pre_flight(self, cli_flags):
"""
This should only be run once, prior to connecting to discord.
"""
await self._maybe_update_config()
packages = []
if cli_flags.no_cogs is False:
packages.extend(await self._config.packages())
if cli_flags.load_cogs:
packages.extend(cli_flags.load_cogs)
if packages:
# Load permissions first, for security reasons
try:
packages.remove("permissions")
except ValueError:
pass
else:
packages.insert(0, "permissions")
to_remove = []
print("Loading packages...")
for package in packages:
try:
spec = await self._cog_mgr.find_cog(package)
await asyncio.wait_for(self.load_extension(spec), 30)
except asyncio.TimeoutError:
log.exception("Failed to load package %s (timeout)", package)
to_remove.append(package)
except Exception as e:
log.exception("Failed to load package {}".format(package), exc_info=e)
await self.remove_loaded_package(package)
to_remove.append(package)
for package in to_remove:
packages.remove(package)
if packages:
print("Loaded packages: " + ", ".join(packages))
if self.rpc_enabled:
await self.rpc.initialize(self.rpc_port)
async def start(self, *args, **kwargs):
cli_flags = kwargs.pop("cli_flags")
await self.pre_flight(cli_flags=cli_flags)
return await super().start(*args, **kwargs)
async def send_help_for(
self, ctx: commands.Context, help_for: Union[commands.Command, commands.GroupMixin, str]
):

View File

@ -46,40 +46,6 @@ def init_events(bot, cli_flags):
return
bot._uptime = datetime.datetime.utcnow()
packages = []
if cli_flags.no_cogs is False:
packages.extend(await bot._config.packages())
if cli_flags.load_cogs:
packages.extend(cli_flags.load_cogs)
if packages:
# Load permissions first, for security reasons
try:
packages.remove("permissions")
except ValueError:
pass
else:
packages.insert(0, "permissions")
to_remove = []
print("Loading packages...")
for package in packages:
try:
spec = await bot._cog_mgr.find_cog(package)
await bot.load_extension(spec)
except Exception as e:
log.exception("Failed to load package {}".format(package), exc_info=e)
await bot.remove_loaded_package(package)
to_remove.append(package)
for package in to_remove:
packages.remove(package)
if packages:
print("Loaded packages: " + ", ".join(packages))
if bot.rpc_enabled:
await bot.rpc.initialize(bot.rpc_port)
guilds = len(bot.guilds)
users = len(set([m for m in bot.get_all_members()]))