diff --git a/changelog.d/3045.enhance.rst b/changelog.d/3045.enhance.rst new file mode 100644 index 000000000..65bc0182f --- /dev/null +++ b/changelog.d/3045.enhance.rst @@ -0,0 +1 @@ +Bot now handles more things prior to connecting to discord to reduce issues with initial load \ No newline at end of file diff --git a/changelog.d/3073.breaking.rst b/changelog.d/3073.breaking.rst new file mode 100644 index 000000000..d12ae4859 --- /dev/null +++ b/changelog.d/3073.breaking.rst @@ -0,0 +1 @@ +``bot.wait_until_ready`` should no longer be used during extension setup \ No newline at end of file diff --git a/redbot/__main__.py b/redbot/__main__.py index b9a87cecc..e947758d0 100644 --- a/redbot/__main__.py +++ b/redbot/__main__.py @@ -312,7 +312,7 @@ def main(): loop.run_until_complete(red.http.close()) sys.exit(0) try: - loop.run_until_complete(red.start(token, bot=True)) + loop.run_until_complete(red.start(token, bot=True, cli_flags=cli_flags)) except discord.LoginFailure: log.critical("This token doesn't seem to be valid.") db_token = loop.run_until_complete(red._config.token()) diff --git a/redbot/cogs/audio/__init__.py b/redbot/cogs/audio/__init__.py index d36ffc7e3..e69258734 100644 --- a/redbot/cogs/audio/__init__.py +++ b/redbot/cogs/audio/__init__.py @@ -3,7 +3,6 @@ from redbot.core import commands from .audio import Audio -async def setup(bot: commands.Bot): +def setup(bot: commands.Bot): cog = Audio(bot) - await cog.initialize() bot.add_cog(cog) diff --git a/redbot/cogs/audio/apis.py b/redbot/cogs/audio/apis.py index 31f208da2..fa254ffc2 100644 --- a/redbot/cogs/audio/apis.py +++ b/redbot/cogs/audio/apis.py @@ -9,7 +9,7 @@ import random import time import traceback from collections import namedtuple -from typing import Callable, Dict, List, Mapping, NoReturn, Optional, Tuple, Union +from typing import Callable, Dict, List, Mapping, Optional, Tuple, Union try: from sqlite3 import Error as SQLError @@ -32,7 +32,7 @@ from lavalink.rest_api import LoadResult from redbot.core import Config, commands from redbot.core.bot import Red from redbot.core.i18n import Translator, cog_i18n -from . import dataclasses +from . import audio_dataclasses from .errors import InvalidTableError, SpotifyFetchError, YouTubeApiError from .playlists import get_playlist from .utils import CacheLevel, Notifier, is_allowed, queue_duration, track_limit @@ -193,7 +193,7 @@ class SpotifyAPI: ) return await r.json() - async def _get_auth(self) -> NoReturn: + async def _get_auth(self): if self.client_id is None or self.client_secret is None: tokens = await self.bot.get_shared_api_tokens("spotify") self.client_id = tokens.get("client_id", "") @@ -331,7 +331,7 @@ class MusicCache: self._lock: asyncio.Lock = asyncio.Lock() self.config: Optional[Config] = None - async def initialize(self, config: Config) -> NoReturn: + async def initialize(self, config: Config): if HAS_SQL: await self.database.connect() @@ -348,12 +348,12 @@ class MusicCache: await self.database.execute(query=_CREATE_UNIQUE_INDEX_SPOTIFY_TABLE) self.config = config - async def close(self) -> NoReturn: + async def close(self): if HAS_SQL: await self.database.execute(query="PRAGMA optimize;") await self.database.disconnect() - async def insert(self, table: str, values: List[dict]) -> NoReturn: + async def insert(self, table: str, values: List[dict]): # if table == "spotify": # return if HAS_SQL: @@ -363,7 +363,7 @@ class MusicCache: await self.database.execute_many(query=query, values=values) - async def update(self, table: str, values: Dict[str, str]) -> NoReturn: + async def update(self, table: str, values: Dict[str, str]): # if table == "spotify": # return if HAS_SQL: @@ -746,7 +746,7 @@ class MusicCache: if val: try: result, called_api = await self.lavalink_query( - ctx, player, dataclasses.Query.process_input(val) + ctx, player, audio_dataclasses.Query.process_input(val) ) except (RuntimeError, aiohttp.ServerDisconnectedError): lock(ctx, False) @@ -805,7 +805,7 @@ class MusicCache: ctx.guild, ( f"{single_track.title} {single_track.author} {single_track.uri} " - f"{str(dataclasses.Query.process_input(single_track))}" + f"{str(audio_dataclasses.Query.process_input(single_track))}" ), ): has_not_allowed = True @@ -911,7 +911,7 @@ class MusicCache: self, ctx: commands.Context, player: lavalink.Player, - query: dataclasses.Query, + query: audio_dataclasses.Query, forced: bool = False, ) -> Tuple[LoadResult, bool]: """ @@ -925,7 +925,7 @@ class MusicCache: The context this method is being called under. player : lavalink.Player The player who's requesting the query. - query: dataclasses.Query + query: audio_dataclasses.Query The Query object for the query in question. forced:bool Whether or not to skip cache and call API first.. @@ -939,7 +939,7 @@ class MusicCache: ) cache_enabled = CacheLevel.set_lavalink().is_subset(current_cache_level) val = None - _raw_query = dataclasses.Query.process_input(query) + _raw_query = audio_dataclasses.Query.process_input(query) query = str(_raw_query) if cache_enabled and not forced and not _raw_query.is_local: update = True @@ -1003,14 +1003,10 @@ class MusicCache: tasks = self._tasks[ctx.message.id] del self._tasks[ctx.message.id] await asyncio.gather( - *[asyncio.ensure_future(self.insert(*a)) for a in tasks["insert"]], - loop=self.bot.loop, - return_exceptions=True, + *[self.insert(*a) for a in tasks["insert"]], return_exceptions=True ) await asyncio.gather( - *[asyncio.ensure_future(self.update(*a)) for a in tasks["update"]], - loop=self.bot.loop, - return_exceptions=True, + *[self.update(*a) for a in tasks["update"]], return_exceptions=True ) log.debug(f"Completed database writes for {lock_id} " f"({lock_author})") @@ -1025,14 +1021,10 @@ class MusicCache: self._tasks = {} await asyncio.gather( - *[asyncio.ensure_future(self.insert(*a)) for a in tasks["insert"]], - loop=self.bot.loop, - return_exceptions=True, + *[self.insert(*a) for a in tasks["insert"]], return_exceptions=True ) await asyncio.gather( - *[asyncio.ensure_future(self.update(*a)) for a in tasks["update"]], - loop=self.bot.loop, - return_exceptions=True, + *[self.update(*a) for a in tasks["update"]], return_exceptions=True ) log.debug("Completed pending writes to database have finished") @@ -1096,7 +1088,9 @@ class MusicCache: if not tracks: ctx = namedtuple("Context", "message") results, called_api = await self.lavalink_query( - ctx(player.channel.guild), player, dataclasses.Query.process_input(_TOP_100_US) + ctx(player.channel.guild), + player, + audio_dataclasses.Query.process_input(_TOP_100_US), ) tracks = list(results.tracks) if tracks: @@ -1107,7 +1101,7 @@ class MusicCache: while valid is False and multiple: track = random.choice(tracks) - query = dataclasses.Query.process_input(track) + query = audio_dataclasses.Query.process_input(track) if not query.valid: continue if query.is_local and not query.track.exists(): @@ -1116,7 +1110,7 @@ class MusicCache: player.channel.guild, ( f"{track.title} {track.author} {track.uri} " - f"{str(dataclasses.Query.process_input(track))}" + f"{str(audio_dataclasses.Query.process_input(track))}" ), ): log.debug( diff --git a/redbot/cogs/audio/audio.py b/redbot/cogs/audio/audio.py index e442b80af..970de7653 100644 --- a/redbot/cogs/audio/audio.py +++ b/redbot/cogs/audio/audio.py @@ -34,7 +34,7 @@ from redbot.core.utils.menus import ( start_adding_reactions, ) from redbot.core.utils.predicates import MessagePredicate, ReactionPredicate -from . import dataclasses +from . import audio_dataclasses from .apis import MusicCache, HAS_SQL, _ERROR from .checks import can_have_caching from .converters import ComplexScopeParser, ScopeParser, get_lazy_converter, get_playlist_converter @@ -142,7 +142,11 @@ class Audio(commands.Cog): self.play_lock = {} self._manager: Optional[ServerManager] = None - self.bot.dispatch("red_audio_initialized", self) + # These has to be a task since this requires the bot to be ready + # If it waits for ready in startup, we cause a deadlock during initial load + # as initial load happens before the bot can ever be ready. + self._init_task = self.bot.loop.create_task(self.initialize()) + self._ready_event = asyncio.Event() @property def owns_autoplay(self): @@ -166,9 +170,14 @@ class Audio(commands.Cog): self._cog_id = None async def cog_before_invoke(self, ctx: commands.Context): + await self._ready_event.wait() + # check for unsupported arch + # Check on this needs refactoring at a later date + # so that we have a better way to handle the tasks if self.llsetup in [ctx.command, ctx.command.root_parent]: pass - elif self._connect_task.cancelled(): + + elif self._connect_task and self._connect_task.cancelled(): await ctx.send( "You have attempted to run Audio's Lavalink server on an unsupported" " architecture. Only settings related commands will be available." @@ -176,6 +185,7 @@ class Audio(commands.Cog): raise RuntimeError( "Not running audio command due to invalid machine architecture for Lavalink." ) + dj_enabled = await self.config.guild(ctx.guild).dj_enabled() if dj_enabled: dj_role_obj = ctx.guild.get_role(await self.config.guild(ctx.guild).dj_role()) @@ -185,13 +195,13 @@ class Audio(commands.Cog): await self._embed_msg(ctx, _("No DJ role found. Disabling DJ mode.")) async def initialize(self): - pass_config_to_dependencies(self.config, self.bot, await self.config.localpath()) + await self.bot.wait_until_ready() + # Unlike most cases, we want the cache to exit before migration. await self.music_cache.initialize(self.config) - asyncio.ensure_future( - self._migrate_config( - from_version=await self.config.schema_version(), to_version=_SCHEMA_VERSION - ) + await self._migrate_config( + from_version=await self.config.schema_version(), to_version=_SCHEMA_VERSION ) + pass_config_to_dependencies(self.config, self.bot, await self.config.localpath()) self._restart_connect() self._disconnect_task = self.bot.loop.create_task(self.disconnect_timer()) lavalink.register_event_listener(self.event_handler) @@ -209,6 +219,9 @@ class Audio(commands.Cog): await self.bot.send_to_owners(page) log.critical(error_message) + self._ready_event.set() + self.bot.dispatch("red_audio_initialized", self) + async def _migrate_config(self, from_version: int, to_version: int): database_entries = [] time_now = str(datetime.datetime.now(datetime.timezone.utc)) @@ -253,7 +266,7 @@ class Audio(commands.Cog): cast(discord.Guild, discord.Object(id=guild_id)) ).clear_raw("playlists") if database_entries and HAS_SQL: - asyncio.ensure_future(self.music_cache.insert("lavalink", database_entries)) + await self.music_cache.insert("lavalink", database_entries) def _restart_connect(self): if self._connect_task: @@ -366,7 +379,9 @@ class Audio(commands.Cog): async def _players_check(): try: get_single_title = lavalink.active_players()[0].current.title - query = dataclasses.Query.process_input(lavalink.active_players()[0].current.uri) + query = audio_dataclasses.Query.process_input( + lavalink.active_players()[0].current.uri + ) if get_single_title == "Unknown title": get_single_title = lavalink.active_players()[0].current.uri if not get_single_title.startswith("http"): @@ -463,18 +478,18 @@ class Audio(commands.Cog): ) await notify_channel.send(embed=embed) - query = dataclasses.Query.process_input(player.current.uri) + query = audio_dataclasses.Query.process_input(player.current.uri) if query.is_local if player.current else False: if player.current.title != "Unknown title": description = "**{} - {}**\n{}".format( player.current.author, player.current.title, - dataclasses.LocalPath(player.current.uri).to_string_hidden(), + audio_dataclasses.LocalPath(player.current.uri).to_string_hidden(), ) else: description = "{}".format( - dataclasses.LocalPath(player.current.uri).to_string_hidden() + audio_dataclasses.LocalPath(player.current.uri).to_string_hidden() ) else: description = "**[{}]({})**".format(player.current.title, player.current.uri) @@ -532,9 +547,9 @@ class Audio(commands.Cog): message_channel = player.fetch("channel") if message_channel: message_channel = self.bot.get_channel(message_channel) - query = dataclasses.Query.process_input(player.current.uri) + query = audio_dataclasses.Query.process_input(player.current.uri) if player.current and query.is_local: - query = dataclasses.Query.process_input(player.current.uri) + query = audio_dataclasses.Query.process_input(player.current.uri) if player.current.title == "Unknown title": description = "{}".format(query.track.to_string_hidden()) else: @@ -590,7 +605,7 @@ class Audio(commands.Cog): player.store("channel", channel.id) player.store("guild", guild.id) await self._data_check(guild.me) - query = dataclasses.Query.process_input(query) + query = audio_dataclasses.Query.process_input(query) ctx = namedtuple("Context", "message") results, called_api = await self.music_cache.lavalink_query(ctx(guild), player, query) @@ -1094,7 +1109,7 @@ class Audio(commands.Cog): with contextlib.suppress(discord.HTTPException): await info.delete() return - temp = dataclasses.LocalPath(local_path, forced=True) + temp = audio_dataclasses.LocalPath(local_path, forced=True) if not temp.exists() or not temp.is_dir(): return await self._embed_msg( ctx, @@ -1536,7 +1551,7 @@ class Audio(commands.Cog): int((datetime.datetime.utcnow() - connect_start).total_seconds()) ) try: - query = dataclasses.Query.process_input(p.current.uri) + query = audio_dataclasses.Query.process_input(p.current.uri) if query.is_local: if p.current.title == "Unknown title": current_title = localtracks.LocalPath(p.current.uri).to_string_hidden() @@ -1606,9 +1621,9 @@ class Audio(commands.Cog): bump_song = player.queue[bump_index] player.queue.insert(0, bump_song) removed = player.queue.pop(index) - query = dataclasses.Query.process_input(removed.uri) + query = audio_dataclasses.Query.process_input(removed.uri) if query.is_local: - localtrack = dataclasses.LocalPath(removed.uri) + localtrack = audio_dataclasses.LocalPath(removed.uri) if removed.title != "Unknown title": description = "**{} - {}**\n{}".format( removed.author, removed.title, localtrack.to_string_hidden() @@ -1997,12 +2012,12 @@ class Audio(commands.Cog): await ctx.invoke(self.local_play, play_subfolders=play_subfolders) else: folder = folder.strip() - _dir = dataclasses.LocalPath.joinpath(folder) + _dir = audio_dataclasses.LocalPath.joinpath(folder) if not _dir.exists(): return await self._embed_msg( ctx, _("No localtracks folder named {name}.").format(name=folder) ) - query = dataclasses.Query.process_input(_dir, search_subfolders=play_subfolders) + query = audio_dataclasses.Query.process_input(_dir, search_subfolders=play_subfolders) await self._local_play_all(ctx, query, from_search=False if not folder else True) @local.command(name="play") @@ -2064,8 +2079,8 @@ class Audio(commands.Cog): all_tracks = await self._folder_list( ctx, ( - dataclasses.Query.process_input( - dataclasses.LocalPath( + audio_dataclasses.Query.process_input( + audio_dataclasses.LocalPath( await self.config.localpath() ).localtrack_folder.absolute(), search_subfolders=play_subfolders, @@ -2081,18 +2096,18 @@ class Audio(commands.Cog): return await ctx.invoke(self.search, query=search_list) async def _localtracks_folders(self, ctx: commands.Context, search_subfolders=False): - audio_data = dataclasses.LocalPath( - dataclasses.LocalPath(None).localtrack_folder.absolute() + audio_data = audio_dataclasses.LocalPath( + audio_dataclasses.LocalPath(None).localtrack_folder.absolute() ) if not await self._localtracks_check(ctx): return return audio_data.subfolders_in_tree() if search_subfolders else audio_data.subfolders() - async def _folder_list(self, ctx: commands.Context, query: dataclasses.Query): + async def _folder_list(self, ctx: commands.Context, query: audio_dataclasses.Query): if not await self._localtracks_check(ctx): return - query = dataclasses.Query.process_input(query) + query = audio_dataclasses.Query.process_input(query) if not query.track.exists(): return return ( @@ -2102,12 +2117,12 @@ class Audio(commands.Cog): ) async def _folder_tracks( - self, ctx, player: lavalink.player_manager.Player, query: dataclasses.Query + self, ctx, player: lavalink.player_manager.Player, query: audio_dataclasses.Query ): if not await self._localtracks_check(ctx): return - audio_data = dataclasses.LocalPath(None) + audio_data = audio_dataclasses.LocalPath(None) try: query.track.path.relative_to(audio_data.to_string()) except ValueError: @@ -2120,17 +2135,17 @@ class Audio(commands.Cog): return local_tracks async def _local_play_all( - self, ctx: commands.Context, query: dataclasses.Query, from_search=False + self, ctx: commands.Context, query: audio_dataclasses.Query, from_search=False ): if not await self._localtracks_check(ctx): return if from_search: - query = dataclasses.Query.process_input( + query = audio_dataclasses.Query.process_input( query.track.to_string(), invoked_from="local folder" ) await ctx.invoke(self.search, query=query) - async def _all_folder_tracks(self, ctx: commands.Context, query: dataclasses.Query): + async def _all_folder_tracks(self, ctx: commands.Context, query: audio_dataclasses.Query): if not await self._localtracks_check(ctx): return @@ -2141,7 +2156,7 @@ class Audio(commands.Cog): ) async def _localtracks_check(self, ctx: commands.Context): - folder = dataclasses.LocalPath(None) + folder = audio_dataclasses.LocalPath(None) if folder.localtrack_folder.exists(): return True if ctx.invoked_with != "start": @@ -2177,7 +2192,7 @@ class Audio(commands.Cog): dur = "LIVE" else: dur = lavalink.utils.format_time(player.current.length) - query = dataclasses.Query.process_input(player.current.uri) + query = audio_dataclasses.Query.process_input(player.current.uri) if query.is_local: if not player.current.title == "Unknown title": song = "**{track.author} - {track.title}**\n{uri}\n" @@ -2189,8 +2204,8 @@ class Audio(commands.Cog): song += "\n\n{arrow}`{pos}`/`{dur}`" song = song.format( track=player.current, - uri=dataclasses.LocalPath(player.current.uri).to_string_hidden() - if dataclasses.Query.process_input(player.current.uri).is_local + uri=audio_dataclasses.LocalPath(player.current.uri).to_string_hidden() + if audio_dataclasses.Query.process_input(player.current.uri).is_local else player.current.uri, arrow=arrow, pos=pos, @@ -2301,9 +2316,9 @@ class Audio(commands.Cog): if not player.current: return await self._embed_msg(ctx, _("Nothing playing.")) - query = dataclasses.Query.process_input(player.current.uri) + query = audio_dataclasses.Query.process_input(player.current.uri) if query.is_local: - query = dataclasses.Query.process_input(player.current.uri) + query = audio_dataclasses.Query.process_input(player.current.uri) if player.current.title == "Unknown title": description = "{}".format(query.track.to_string_hidden()) else: @@ -2436,7 +2451,7 @@ class Audio(commands.Cog): ) if not await self._currency_check(ctx, guild_data["jukebox_price"]): return - query = dataclasses.Query.process_input(query) + query = audio_dataclasses.Query.process_input(query) if not query.valid: return await self._embed_msg(ctx, _("No tracks to play.")) if query.is_spotify: @@ -2593,7 +2608,7 @@ class Audio(commands.Cog): ) playlists_search_page_list.append(embed) playlists_pick = await menu(ctx, playlists_search_page_list, playlist_search_controls) - query = dataclasses.Query.process_input(playlists_pick) + query = audio_dataclasses.Query.process_input(playlists_pick) if not query.valid: return await self._embed_msg(ctx, _("No tracks to play.")) if not await self._currency_check(ctx, guild_data["jukebox_price"]): @@ -2728,7 +2743,7 @@ class Audio(commands.Cog): elif player.current: await self._embed_msg(ctx, _("Adding a track to queue.")) - async def _get_spotify_tracks(self, ctx: commands.Context, query: dataclasses.Query): + async def _get_spotify_tracks(self, ctx: commands.Context, query: audio_dataclasses.Query): if ctx.invoked_with in ["play", "genre"]: enqueue_tracks = True else: @@ -2771,12 +2786,12 @@ class Audio(commands.Cog): self._play_lock(ctx, False) try: if enqueue_tracks: - new_query = dataclasses.Query.process_input(res[0]) + new_query = audio_dataclasses.Query.process_input(res[0]) new_query.start_time = query.start_time return await self._enqueue_tracks(ctx, new_query) else: result, called_api = await self.music_cache.lavalink_query( - ctx, player, dataclasses.Query.process_input(res[0]) + ctx, player, audio_dataclasses.Query.process_input(res[0]) ) tracks = result.tracks if not tracks: @@ -2808,7 +2823,9 @@ class Audio(commands.Cog): ctx, _("This doesn't seem to be a supported Spotify URL or code.") ) - async def _enqueue_tracks(self, ctx: commands.Context, query: Union[dataclasses.Query, list]): + async def _enqueue_tracks( + self, ctx: commands.Context, query: Union[audio_dataclasses.Query, list] + ): player = lavalink.get_player(ctx.guild.id) try: if self.play_lock[ctx.message.guild.id]: @@ -2863,7 +2880,7 @@ class Audio(commands.Cog): ctx.guild, ( f"{track.title} {track.author} {track.uri} " - f"{str(dataclasses.Query.process_input(track))}" + f"{str(audio_dataclasses.Query.process_input(track))}" ), ): log.debug(f"Query is not allowed in {ctx.guild} ({ctx.guild.id})") @@ -2923,7 +2940,7 @@ class Audio(commands.Cog): ctx.guild, ( f"{single_track.title} {single_track.author} {single_track.uri} " - f"{str(dataclasses.Query.process_input(single_track))}" + f"{str(audio_dataclasses.Query.process_input(single_track))}" ), ): log.debug(f"Query is not allowed in {ctx.guild} ({ctx.guild.id})") @@ -2956,17 +2973,17 @@ class Audio(commands.Cog): return await self._embed_msg( ctx, _("Nothing found. Check your Lavalink logs for details.") ) - query = dataclasses.Query.process_input(single_track.uri) + query = audio_dataclasses.Query.process_input(single_track.uri) if query.is_local: if single_track.title != "Unknown title": description = "**{} - {}**\n{}".format( single_track.author, single_track.title, - dataclasses.LocalPath(single_track.uri).to_string_hidden(), + audio_dataclasses.LocalPath(single_track.uri).to_string_hidden(), ) else: description = "{}".format( - dataclasses.LocalPath(single_track.uri).to_string_hidden() + audio_dataclasses.LocalPath(single_track.uri).to_string_hidden() ) else: description = "**[{}]({})**".format(single_track.title, single_track.uri) @@ -2987,7 +3004,11 @@ class Audio(commands.Cog): self._play_lock(ctx, False) async def _spotify_playlist( - self, ctx: commands.Context, stype: str, query: dataclasses.Query, enqueue: bool = False + self, + ctx: commands.Context, + stype: str, + query: audio_dataclasses.Query, + enqueue: bool = False, ): player = lavalink.get_player(ctx.guild.id) @@ -3340,7 +3361,7 @@ class Audio(commands.Cog): return player = lavalink.get_player(ctx.guild.id) to_append = await self._playlist_tracks( - ctx, player, dataclasses.Query.process_input(query) + ctx, player, audio_dataclasses.Query.process_input(query) ) if not to_append: return await self._embed_msg(ctx, _("Could not find a track matching your query.")) @@ -3993,7 +4014,7 @@ class Audio(commands.Cog): spaces = "\N{EN SPACE}" * (len(str(len(playlist.tracks))) + 2) for track in playlist.tracks: track_idx = track_idx + 1 - query = dataclasses.Query.process_input(track["info"]["uri"]) + query = audio_dataclasses.Query.process_input(track["info"]["uri"]) if query.is_local: if track["info"]["title"] != "Unknown title": msg += "`{}.` **{} - {}**\n{}{}\n".format( @@ -4398,7 +4419,7 @@ class Audio(commands.Cog): return player = lavalink.get_player(ctx.guild.id) tracklist = await self._playlist_tracks( - ctx, player, dataclasses.Query.process_input(playlist_url) + ctx, player, audio_dataclasses.Query.process_input(playlist_url) ) if tracklist is not None: playlist = await create_playlist( @@ -4488,14 +4509,14 @@ class Audio(commands.Cog): ctx.guild, ( f"{track.title} {track.author} {track.uri} " - f"{str(dataclasses.Query.process_input(track))}" + f"{str(audio_dataclasses.Query.process_input(track))}" ), ): log.debug(f"Query is not allowed in {ctx.guild} ({ctx.guild.id})") continue - query = dataclasses.Query.process_input(track.uri) + query = audio_dataclasses.Query.process_input(track.uri) if query.is_local: - local_path = dataclasses.LocalPath(track.uri) + local_path = audio_dataclasses.LocalPath(track.uri) if not await self._localtracks_check(ctx): pass if not local_path.exists() and not local_path.is_file(): @@ -4781,7 +4802,7 @@ class Audio(commands.Cog): or not match_yt_playlist(uploaded_playlist_url) or not ( await self.music_cache.lavalink_query( - ctx, player, dataclasses.Query.process_input(uploaded_playlist_url) + ctx, player, audio_dataclasses.Query.process_input(uploaded_playlist_url) ) )[0].tracks ): @@ -4966,7 +4987,7 @@ class Audio(commands.Cog): } ) if database_entries and HAS_SQL: - asyncio.ensure_future(self.music_cache.insert("lavalink", database_entries)) + await self.music_cache.insert("lavalink", database_entries) async def _load_v2_playlist( self, @@ -4993,7 +5014,7 @@ class Audio(commands.Cog): track_count += 1 try: result, called_api = await self.music_cache.lavalink_query( - ctx, player, dataclasses.Query.process_input(song_url) + ctx, player, audio_dataclasses.Query.process_input(song_url) ) track = result.tracks except Exception: @@ -5041,7 +5062,7 @@ class Audio(commands.Cog): return [], [], playlist results = {} updated_tracks = await self._playlist_tracks( - ctx, player, dataclasses.Query.process_input(playlist.url) + ctx, player, audio_dataclasses.Query.process_input(playlist.url) ) if not updated_tracks: # No Tracks available on url Lets set it to none to avoid repeated calls here @@ -5106,7 +5127,7 @@ class Audio(commands.Cog): self, ctx: commands.Context, player: lavalink.player_manager.Player, - query: dataclasses.Query, + query: audio_dataclasses.Query, ): search = query.is_search tracklist = [] @@ -5175,7 +5196,7 @@ class Audio(commands.Cog): player.queue.insert(0, bump_song) player.queue.pop(queue_len) await player.skip() - query = dataclasses.Query.process_input(player.current.uri) + query = audio_dataclasses.Query.process_input(player.current.uri) if query.is_local: if player.current.title == "Unknown title": @@ -5227,7 +5248,7 @@ class Audio(commands.Cog): else: dur = lavalink.utils.format_time(player.current.length) - query = dataclasses.Query.process_input(player.current) + query = audio_dataclasses.Query.process_input(player.current) if query.is_local: if player.current.title != "Unknown title": @@ -5240,8 +5261,8 @@ class Audio(commands.Cog): song += "\n\n{arrow}`{pos}`/`{dur}`" song = song.format( track=player.current, - uri=dataclasses.LocalPath(player.current.uri).to_string_hidden() - if dataclasses.Query.process_input(player.current.uri).is_local + uri=audio_dataclasses.LocalPath(player.current.uri).to_string_hidden() + if audio_dataclasses.Query.process_input(player.current.uri).is_local else player.current.uri, arrow=arrow, pos=pos, @@ -5313,7 +5334,7 @@ class Audio(commands.Cog): else: dur = lavalink.utils.format_time(player.current.length) - query = dataclasses.Query.process_input(player.current) + query = audio_dataclasses.Query.process_input(player.current) if query.is_stream: queue_list += _("**Currently livestreaming:**\n") @@ -5327,7 +5348,7 @@ class Audio(commands.Cog): ( _("Playing: ") + "**{current.author} - {current.title}**".format(current=player.current), - dataclasses.LocalPath(player.current.uri).to_string_hidden(), + audio_dataclasses.LocalPath(player.current.uri).to_string_hidden(), _("Requested by: **{user}**\n").format(user=player.current.requester), f"{arrow}`{pos}`/`{dur}`\n\n", ) @@ -5336,7 +5357,7 @@ class Audio(commands.Cog): queue_list += "\n".join( ( _("Playing: ") - + dataclasses.LocalPath(player.current.uri).to_string_hidden(), + + audio_dataclasses.LocalPath(player.current.uri).to_string_hidden(), _("Requested by: **{user}**\n").format(user=player.current.requester), f"{arrow}`{pos}`/`{dur}`\n\n", ) @@ -5357,13 +5378,13 @@ class Audio(commands.Cog): track_title = track.title req_user = track.requester track_idx = i + 1 - query = dataclasses.Query.process_input(track) + query = audio_dataclasses.Query.process_input(track) if query.is_local: if track.title == "Unknown title": queue_list += f"`{track_idx}.` " + ", ".join( ( - bold(dataclasses.LocalPath(track.uri).to_string_hidden()), + bold(audio_dataclasses.LocalPath(track.uri).to_string_hidden()), _("requested by **{user}**\n").format(user=req_user), ) ) @@ -5420,7 +5441,7 @@ class Audio(commands.Cog): for track in queue_list: queue_idx = queue_idx + 1 if not match_url(track.uri): - query = dataclasses.Query.process_input(track) + query = audio_dataclasses.Query.process_input(track) if track.title == "Unknown title": track_title = query.track.to_string_hidden() else: @@ -5449,7 +5470,7 @@ class Audio(commands.Cog): ): track_idx = i + 1 if type(track) is str: - track_location = dataclasses.LocalPath(track).to_string_hidden() + track_location = audio_dataclasses.LocalPath(track).to_string_hidden() track_match += "`{}.` **{}**\n".format(track_idx, track_location) else: track_match += "`{}.` **{}**\n".format(track[0], track[1]) @@ -5674,9 +5695,9 @@ class Audio(commands.Cog): ) index -= 1 removed = player.queue.pop(index) - query = dataclasses.Query.process_input(removed.uri) + query = audio_dataclasses.Query.process_input(removed.uri) if query.is_local: - local_path = dataclasses.LocalPath(removed.uri).to_string_hidden() + local_path = audio_dataclasses.LocalPath(removed.uri).to_string_hidden() if removed.title == "Unknown title": removed_title = local_path else: @@ -5762,7 +5783,7 @@ class Audio(commands.Cog): await self._data_check(ctx) if not isinstance(query, list): - query = dataclasses.Query.process_input(query) + query = audio_dataclasses.Query.process_input(query) if query.invoked_from == "search list" or query.invoked_from == "local folder": if query.invoked_from == "search list": result, called_api = await self.music_cache.lavalink_query(ctx, player, query) @@ -5791,7 +5812,7 @@ class Audio(commands.Cog): ctx.guild, ( f"{track.title} {track.author} {track.uri} " - f"{str(dataclasses.Query.process_input(track))}" + f"{str(audio_dataclasses.Query.process_input(track))}" ), ): log.debug(f"Query is not allowed in {ctx.guild} ({ctx.guild.id})") @@ -5905,10 +5926,10 @@ class Audio(commands.Cog): except IndexError: search_choice = tracks[-1] try: - query = dataclasses.Query.process_input(search_choice.uri) + query = audio_dataclasses.Query.process_input(search_choice.uri) if query.is_local: - localtrack = dataclasses.LocalPath(search_choice.uri) + localtrack = audio_dataclasses.LocalPath(search_choice.uri) if search_choice.title != "Unknown title": description = "**{} - {}**\n{}".format( search_choice.author, search_choice.title, localtrack.to_string_hidden() @@ -5919,7 +5940,7 @@ class Audio(commands.Cog): description = "**[{}]({})**".format(search_choice.title, search_choice.uri) except AttributeError: - search_choice = dataclasses.Query.process_input(search_choice) + search_choice = audio_dataclasses.Query.process_input(search_choice) if search_choice.track.exists() and search_choice.track.is_dir(): return await ctx.invoke(self.search, query=search_choice) elif search_choice.track.exists() and search_choice.track.is_file(): @@ -5935,7 +5956,7 @@ class Audio(commands.Cog): ctx.guild, ( f"{search_choice.title} {search_choice.author} {search_choice.uri} " - f"{str(dataclasses.Query.process_input(search_choice))}" + f"{str(audio_dataclasses.Query.process_input(search_choice))}" ), ): log.debug(f"Query is not allowed in {ctx.guild} ({ctx.guild.id})") @@ -5984,12 +6005,12 @@ class Audio(commands.Cog): if search_track_num == 0: search_track_num = 5 try: - query = dataclasses.Query.process_input(track.uri) + query = audio_dataclasses.Query.process_input(track.uri) if query.is_local: search_list += "`{0}.` **{1}**\n[{2}]\n".format( search_track_num, track.title, - dataclasses.LocalPath(track.uri).to_string_hidden(), + audio_dataclasses.LocalPath(track.uri).to_string_hidden(), ) else: search_list += "`{0}.` **[{1}]({2})**\n".format( @@ -5997,7 +6018,7 @@ class Audio(commands.Cog): ) except AttributeError: # query = Query.process_input(track) - track = dataclasses.Query.process_input(track) + track = audio_dataclasses.Query.process_input(track) if track.is_local and command != "search": search_list += "`{}.` **{}**\n".format( search_track_num, track.to_string_user() @@ -6890,6 +6911,7 @@ class Audio(commands.Cog): async def on_voice_state_update( self, member: discord.Member, before: discord.VoiceState, after: discord.VoiceState ): + await self._ready_event.wait() if after.channel != before.channel: try: self.skip_votes[before.channel.guild].remove(member.id) @@ -6907,6 +6929,9 @@ class Audio(commands.Cog): if self._connect_task: self._connect_task.cancel() + if self._init_task: + self._init_task.cancel() + lavalink.unregister_event_listener(self.event_handler) self.bot.loop.create_task(lavalink.close()) if self._manager is not None: diff --git a/redbot/cogs/audio/dataclasses.py b/redbot/cogs/audio/audio_dataclasses.py similarity index 100% rename from redbot/cogs/audio/dataclasses.py rename to redbot/cogs/audio/audio_dataclasses.py diff --git a/redbot/cogs/audio/utils.py b/redbot/cogs/audio/utils.py index 3f0d9972a..682d36d50 100644 --- a/redbot/cogs/audio/utils.py +++ b/redbot/cogs/audio/utils.py @@ -3,7 +3,6 @@ import contextlib import os import re import time -from typing import NoReturn from urllib.parse import urlparse import discord @@ -11,7 +10,7 @@ import lavalink from redbot.core import Config, commands from redbot.core.bot import Red -from . import dataclasses +from . import audio_dataclasses from .converters import _pass_config_to_converters @@ -51,7 +50,7 @@ def pass_config_to_dependencies(config: Config, bot: Red, localtracks_folder: st _config = config _pass_config_to_playlist(config, bot) _pass_config_to_converters(config, bot) - dataclasses._pass_config_to_dataclasses(config, bot, localtracks_folder) + audio_dataclasses._pass_config_to_dataclasses(config, bot, localtracks_folder) def track_limit(track, maxlength): @@ -168,7 +167,7 @@ async def clear_react(bot: Red, message: discord.Message, emoji: dict = None): async def get_description(track): if any(x in track.uri for x in [f"{os.sep}localtracks", f"localtracks{os.sep}"]): - local_track = dataclasses.LocalPath(track.uri) + local_track = audio_dataclasses.LocalPath(track.uri) if track.title != "Unknown title": return "**{} - {}**\n{}".format( track.author, track.title, local_track.to_string_hidden() @@ -389,7 +388,7 @@ class Notifier: key: str = None, seconds_key: str = None, seconds: str = None, - ) -> NoReturn: + ): """ This updates an existing message. Based on the message found in :variable:`Notifier.updates` as per the `key` param @@ -410,14 +409,14 @@ class Notifier: except discord.errors.NotFound: pass - async def update_text(self, text: str) -> NoReturn: + async def update_text(self, text: str): embed2 = discord.Embed(colour=self.color, title=text) try: await self.message.edit(embed=embed2) except discord.errors.NotFound: pass - async def update_embed(self, embed: discord.Embed) -> NoReturn: + async def update_embed(self, embed: discord.Embed): try: await self.message.edit(embed=embed) self.last_msg_time = time.time() diff --git a/redbot/core/bot.py b/redbot/core/bot.py index a7becdf73..d2d26cd70 100644 --- a/redbot/core/bot.py +++ b/redbot/core/bot.py @@ -132,7 +132,6 @@ class RedBase(commands.GroupMixin, commands.bot.BotBase, RPCMixin): # pylint: d self._main_dir = bot_dir self._cog_mgr = CogManager() - super().__init__(*args, help_command=None, **kwargs) # Do not manually use the help formatter attribute here, see `send_help_for`, # for a documented API. The internals of this object are still subject to change. @@ -325,6 +324,7 @@ class RedBase(commands.GroupMixin, commands.bot.BotBase, RPCMixin): # pylint: d get_embed_colour = get_embed_color + # start config migrations async def _maybe_update_config(self): """ This should be run prior to loading cogs or connecting to discord. @@ -375,6 +375,57 @@ class RedBase(commands.GroupMixin, commands.bot.BotBase, RPCMixin): # pylint: d await self._config.guild(guild_obj).admin_role.set(admin_roles) log.info("Done updating guild configs to support multiple mod/admin roles") + # end Config migrations + + async def pre_flight(self, cli_flags): + """ + This should only be run once, prior to connecting to discord. + """ + await self._maybe_update_config() + + packages = [] + + if cli_flags.no_cogs is False: + packages.extend(await self._config.packages()) + + if cli_flags.load_cogs: + packages.extend(cli_flags.load_cogs) + + if packages: + # Load permissions first, for security reasons + try: + packages.remove("permissions") + except ValueError: + pass + else: + packages.insert(0, "permissions") + + to_remove = [] + print("Loading packages...") + for package in packages: + try: + spec = await self._cog_mgr.find_cog(package) + await asyncio.wait_for(self.load_extension(spec), 30) + except asyncio.TimeoutError: + log.exception("Failed to load package %s (timeout)", package) + to_remove.append(package) + except Exception as e: + log.exception("Failed to load package {}".format(package), exc_info=e) + await self.remove_loaded_package(package) + to_remove.append(package) + for package in to_remove: + packages.remove(package) + if packages: + print("Loaded packages: " + ", ".join(packages)) + + if self.rpc_enabled: + await self.rpc.initialize(self.rpc_port) + + async def start(self, *args, **kwargs): + cli_flags = kwargs.pop("cli_flags") + await self.pre_flight(cli_flags=cli_flags) + return await super().start(*args, **kwargs) + async def send_help_for( self, ctx: commands.Context, help_for: Union[commands.Command, commands.GroupMixin, str] ): diff --git a/redbot/core/events.py b/redbot/core/events.py index 20970e670..5fd418d36 100644 --- a/redbot/core/events.py +++ b/redbot/core/events.py @@ -46,40 +46,6 @@ def init_events(bot, cli_flags): return bot._uptime = datetime.datetime.utcnow() - packages = [] - - if cli_flags.no_cogs is False: - packages.extend(await bot._config.packages()) - - if cli_flags.load_cogs: - packages.extend(cli_flags.load_cogs) - - if packages: - # Load permissions first, for security reasons - try: - packages.remove("permissions") - except ValueError: - pass - else: - packages.insert(0, "permissions") - - to_remove = [] - print("Loading packages...") - for package in packages: - try: - spec = await bot._cog_mgr.find_cog(package) - await bot.load_extension(spec) - except Exception as e: - log.exception("Failed to load package {}".format(package), exc_info=e) - await bot.remove_loaded_package(package) - to_remove.append(package) - for package in to_remove: - packages.remove(package) - if packages: - print("Loaded packages: " + ", ".join(packages)) - - if bot.rpc_enabled: - await bot.rpc.initialize(bot.rpc_port) guilds = len(bot.guilds) users = len(set([m for m in bot.get_all_members()]))