reorder some startup to prevent heartbeat issues (#3073)

* reorder some startup to prevent heartbeat issues

* changelog

* handle startup cleanup in audio

* style

* rebased to handle conflict

* be a little smarter to prevent (some) infinite hangs

* Fix a pre-existing NoneType Error

* Migrate config before things are using it...

* another place we should ensure we're ready

* rename-toavoid-issues

* fix cache ordering and mis-use of ensure_future

* remove incorrect typehints

* style
This commit is contained in:
Michael H 2019-11-09 14:19:57 -05:00 committed by GitHub
parent 6852b7a1d1
commit b3363acf77
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
10 changed files with 192 additions and 156 deletions

View File

@ -0,0 +1 @@
Bot now handles more things prior to connecting to discord to reduce issues with initial load

View File

@ -0,0 +1 @@
``bot.wait_until_ready`` should no longer be used during extension setup

View File

@ -312,7 +312,7 @@ def main():
loop.run_until_complete(red.http.close()) loop.run_until_complete(red.http.close())
sys.exit(0) sys.exit(0)
try: try:
loop.run_until_complete(red.start(token, bot=True)) loop.run_until_complete(red.start(token, bot=True, cli_flags=cli_flags))
except discord.LoginFailure: except discord.LoginFailure:
log.critical("This token doesn't seem to be valid.") log.critical("This token doesn't seem to be valid.")
db_token = loop.run_until_complete(red._config.token()) db_token = loop.run_until_complete(red._config.token())

View File

@ -3,7 +3,6 @@ from redbot.core import commands
from .audio import Audio from .audio import Audio
async def setup(bot: commands.Bot): def setup(bot: commands.Bot):
cog = Audio(bot) cog = Audio(bot)
await cog.initialize()
bot.add_cog(cog) bot.add_cog(cog)

View File

@ -9,7 +9,7 @@ import random
import time import time
import traceback import traceback
from collections import namedtuple from collections import namedtuple
from typing import Callable, Dict, List, Mapping, NoReturn, Optional, Tuple, Union from typing import Callable, Dict, List, Mapping, Optional, Tuple, Union
try: try:
from sqlite3 import Error as SQLError from sqlite3 import Error as SQLError
@ -32,7 +32,7 @@ from lavalink.rest_api import LoadResult
from redbot.core import Config, commands from redbot.core import Config, commands
from redbot.core.bot import Red from redbot.core.bot import Red
from redbot.core.i18n import Translator, cog_i18n from redbot.core.i18n import Translator, cog_i18n
from . import dataclasses from . import audio_dataclasses
from .errors import InvalidTableError, SpotifyFetchError, YouTubeApiError from .errors import InvalidTableError, SpotifyFetchError, YouTubeApiError
from .playlists import get_playlist from .playlists import get_playlist
from .utils import CacheLevel, Notifier, is_allowed, queue_duration, track_limit from .utils import CacheLevel, Notifier, is_allowed, queue_duration, track_limit
@ -193,7 +193,7 @@ class SpotifyAPI:
) )
return await r.json() return await r.json()
async def _get_auth(self) -> NoReturn: async def _get_auth(self):
if self.client_id is None or self.client_secret is None: if self.client_id is None or self.client_secret is None:
tokens = await self.bot.get_shared_api_tokens("spotify") tokens = await self.bot.get_shared_api_tokens("spotify")
self.client_id = tokens.get("client_id", "") self.client_id = tokens.get("client_id", "")
@ -331,7 +331,7 @@ class MusicCache:
self._lock: asyncio.Lock = asyncio.Lock() self._lock: asyncio.Lock = asyncio.Lock()
self.config: Optional[Config] = None self.config: Optional[Config] = None
async def initialize(self, config: Config) -> NoReturn: async def initialize(self, config: Config):
if HAS_SQL: if HAS_SQL:
await self.database.connect() await self.database.connect()
@ -348,12 +348,12 @@ class MusicCache:
await self.database.execute(query=_CREATE_UNIQUE_INDEX_SPOTIFY_TABLE) await self.database.execute(query=_CREATE_UNIQUE_INDEX_SPOTIFY_TABLE)
self.config = config self.config = config
async def close(self) -> NoReturn: async def close(self):
if HAS_SQL: if HAS_SQL:
await self.database.execute(query="PRAGMA optimize;") await self.database.execute(query="PRAGMA optimize;")
await self.database.disconnect() await self.database.disconnect()
async def insert(self, table: str, values: List[dict]) -> NoReturn: async def insert(self, table: str, values: List[dict]):
# if table == "spotify": # if table == "spotify":
# return # return
if HAS_SQL: if HAS_SQL:
@ -363,7 +363,7 @@ class MusicCache:
await self.database.execute_many(query=query, values=values) await self.database.execute_many(query=query, values=values)
async def update(self, table: str, values: Dict[str, str]) -> NoReturn: async def update(self, table: str, values: Dict[str, str]):
# if table == "spotify": # if table == "spotify":
# return # return
if HAS_SQL: if HAS_SQL:
@ -746,7 +746,7 @@ class MusicCache:
if val: if val:
try: try:
result, called_api = await self.lavalink_query( result, called_api = await self.lavalink_query(
ctx, player, dataclasses.Query.process_input(val) ctx, player, audio_dataclasses.Query.process_input(val)
) )
except (RuntimeError, aiohttp.ServerDisconnectedError): except (RuntimeError, aiohttp.ServerDisconnectedError):
lock(ctx, False) lock(ctx, False)
@ -805,7 +805,7 @@ class MusicCache:
ctx.guild, ctx.guild,
( (
f"{single_track.title} {single_track.author} {single_track.uri} " f"{single_track.title} {single_track.author} {single_track.uri} "
f"{str(dataclasses.Query.process_input(single_track))}" f"{str(audio_dataclasses.Query.process_input(single_track))}"
), ),
): ):
has_not_allowed = True has_not_allowed = True
@ -911,7 +911,7 @@ class MusicCache:
self, self,
ctx: commands.Context, ctx: commands.Context,
player: lavalink.Player, player: lavalink.Player,
query: dataclasses.Query, query: audio_dataclasses.Query,
forced: bool = False, forced: bool = False,
) -> Tuple[LoadResult, bool]: ) -> Tuple[LoadResult, bool]:
""" """
@ -925,7 +925,7 @@ class MusicCache:
The context this method is being called under. The context this method is being called under.
player : lavalink.Player player : lavalink.Player
The player who's requesting the query. The player who's requesting the query.
query: dataclasses.Query query: audio_dataclasses.Query
The Query object for the query in question. The Query object for the query in question.
forced:bool forced:bool
Whether or not to skip cache and call API first.. Whether or not to skip cache and call API first..
@ -939,7 +939,7 @@ class MusicCache:
) )
cache_enabled = CacheLevel.set_lavalink().is_subset(current_cache_level) cache_enabled = CacheLevel.set_lavalink().is_subset(current_cache_level)
val = None val = None
_raw_query = dataclasses.Query.process_input(query) _raw_query = audio_dataclasses.Query.process_input(query)
query = str(_raw_query) query = str(_raw_query)
if cache_enabled and not forced and not _raw_query.is_local: if cache_enabled and not forced and not _raw_query.is_local:
update = True update = True
@ -1003,14 +1003,10 @@ class MusicCache:
tasks = self._tasks[ctx.message.id] tasks = self._tasks[ctx.message.id]
del self._tasks[ctx.message.id] del self._tasks[ctx.message.id]
await asyncio.gather( await asyncio.gather(
*[asyncio.ensure_future(self.insert(*a)) for a in tasks["insert"]], *[self.insert(*a) for a in tasks["insert"]], return_exceptions=True
loop=self.bot.loop,
return_exceptions=True,
) )
await asyncio.gather( await asyncio.gather(
*[asyncio.ensure_future(self.update(*a)) for a in tasks["update"]], *[self.update(*a) for a in tasks["update"]], return_exceptions=True
loop=self.bot.loop,
return_exceptions=True,
) )
log.debug(f"Completed database writes for {lock_id} " f"({lock_author})") log.debug(f"Completed database writes for {lock_id} " f"({lock_author})")
@ -1025,14 +1021,10 @@ class MusicCache:
self._tasks = {} self._tasks = {}
await asyncio.gather( await asyncio.gather(
*[asyncio.ensure_future(self.insert(*a)) for a in tasks["insert"]], *[self.insert(*a) for a in tasks["insert"]], return_exceptions=True
loop=self.bot.loop,
return_exceptions=True,
) )
await asyncio.gather( await asyncio.gather(
*[asyncio.ensure_future(self.update(*a)) for a in tasks["update"]], *[self.update(*a) for a in tasks["update"]], return_exceptions=True
loop=self.bot.loop,
return_exceptions=True,
) )
log.debug("Completed pending writes to database have finished") log.debug("Completed pending writes to database have finished")
@ -1096,7 +1088,9 @@ class MusicCache:
if not tracks: if not tracks:
ctx = namedtuple("Context", "message") ctx = namedtuple("Context", "message")
results, called_api = await self.lavalink_query( results, called_api = await self.lavalink_query(
ctx(player.channel.guild), player, dataclasses.Query.process_input(_TOP_100_US) ctx(player.channel.guild),
player,
audio_dataclasses.Query.process_input(_TOP_100_US),
) )
tracks = list(results.tracks) tracks = list(results.tracks)
if tracks: if tracks:
@ -1107,7 +1101,7 @@ class MusicCache:
while valid is False and multiple: while valid is False and multiple:
track = random.choice(tracks) track = random.choice(tracks)
query = dataclasses.Query.process_input(track) query = audio_dataclasses.Query.process_input(track)
if not query.valid: if not query.valid:
continue continue
if query.is_local and not query.track.exists(): if query.is_local and not query.track.exists():
@ -1116,7 +1110,7 @@ class MusicCache:
player.channel.guild, player.channel.guild,
( (
f"{track.title} {track.author} {track.uri} " f"{track.title} {track.author} {track.uri} "
f"{str(dataclasses.Query.process_input(track))}" f"{str(audio_dataclasses.Query.process_input(track))}"
), ),
): ):
log.debug( log.debug(

View File

@ -34,7 +34,7 @@ from redbot.core.utils.menus import (
start_adding_reactions, start_adding_reactions,
) )
from redbot.core.utils.predicates import MessagePredicate, ReactionPredicate from redbot.core.utils.predicates import MessagePredicate, ReactionPredicate
from . import dataclasses from . import audio_dataclasses
from .apis import MusicCache, HAS_SQL, _ERROR from .apis import MusicCache, HAS_SQL, _ERROR
from .checks import can_have_caching from .checks import can_have_caching
from .converters import ComplexScopeParser, ScopeParser, get_lazy_converter, get_playlist_converter from .converters import ComplexScopeParser, ScopeParser, get_lazy_converter, get_playlist_converter
@ -142,7 +142,11 @@ class Audio(commands.Cog):
self.play_lock = {} self.play_lock = {}
self._manager: Optional[ServerManager] = None self._manager: Optional[ServerManager] = None
self.bot.dispatch("red_audio_initialized", self) # These has to be a task since this requires the bot to be ready
# If it waits for ready in startup, we cause a deadlock during initial load
# as initial load happens before the bot can ever be ready.
self._init_task = self.bot.loop.create_task(self.initialize())
self._ready_event = asyncio.Event()
@property @property
def owns_autoplay(self): def owns_autoplay(self):
@ -166,9 +170,14 @@ class Audio(commands.Cog):
self._cog_id = None self._cog_id = None
async def cog_before_invoke(self, ctx: commands.Context): async def cog_before_invoke(self, ctx: commands.Context):
await self._ready_event.wait()
# check for unsupported arch
# Check on this needs refactoring at a later date
# so that we have a better way to handle the tasks
if self.llsetup in [ctx.command, ctx.command.root_parent]: if self.llsetup in [ctx.command, ctx.command.root_parent]:
pass pass
elif self._connect_task.cancelled():
elif self._connect_task and self._connect_task.cancelled():
await ctx.send( await ctx.send(
"You have attempted to run Audio's Lavalink server on an unsupported" "You have attempted to run Audio's Lavalink server on an unsupported"
" architecture. Only settings related commands will be available." " architecture. Only settings related commands will be available."
@ -176,6 +185,7 @@ class Audio(commands.Cog):
raise RuntimeError( raise RuntimeError(
"Not running audio command due to invalid machine architecture for Lavalink." "Not running audio command due to invalid machine architecture for Lavalink."
) )
dj_enabled = await self.config.guild(ctx.guild).dj_enabled() dj_enabled = await self.config.guild(ctx.guild).dj_enabled()
if dj_enabled: if dj_enabled:
dj_role_obj = ctx.guild.get_role(await self.config.guild(ctx.guild).dj_role()) dj_role_obj = ctx.guild.get_role(await self.config.guild(ctx.guild).dj_role())
@ -185,13 +195,13 @@ class Audio(commands.Cog):
await self._embed_msg(ctx, _("No DJ role found. Disabling DJ mode.")) await self._embed_msg(ctx, _("No DJ role found. Disabling DJ mode."))
async def initialize(self): async def initialize(self):
pass_config_to_dependencies(self.config, self.bot, await self.config.localpath()) await self.bot.wait_until_ready()
# Unlike most cases, we want the cache to exit before migration.
await self.music_cache.initialize(self.config) await self.music_cache.initialize(self.config)
asyncio.ensure_future( await self._migrate_config(
self._migrate_config(
from_version=await self.config.schema_version(), to_version=_SCHEMA_VERSION from_version=await self.config.schema_version(), to_version=_SCHEMA_VERSION
) )
) pass_config_to_dependencies(self.config, self.bot, await self.config.localpath())
self._restart_connect() self._restart_connect()
self._disconnect_task = self.bot.loop.create_task(self.disconnect_timer()) self._disconnect_task = self.bot.loop.create_task(self.disconnect_timer())
lavalink.register_event_listener(self.event_handler) lavalink.register_event_listener(self.event_handler)
@ -209,6 +219,9 @@ class Audio(commands.Cog):
await self.bot.send_to_owners(page) await self.bot.send_to_owners(page)
log.critical(error_message) log.critical(error_message)
self._ready_event.set()
self.bot.dispatch("red_audio_initialized", self)
async def _migrate_config(self, from_version: int, to_version: int): async def _migrate_config(self, from_version: int, to_version: int):
database_entries = [] database_entries = []
time_now = str(datetime.datetime.now(datetime.timezone.utc)) time_now = str(datetime.datetime.now(datetime.timezone.utc))
@ -253,7 +266,7 @@ class Audio(commands.Cog):
cast(discord.Guild, discord.Object(id=guild_id)) cast(discord.Guild, discord.Object(id=guild_id))
).clear_raw("playlists") ).clear_raw("playlists")
if database_entries and HAS_SQL: if database_entries and HAS_SQL:
asyncio.ensure_future(self.music_cache.insert("lavalink", database_entries)) await self.music_cache.insert("lavalink", database_entries)
def _restart_connect(self): def _restart_connect(self):
if self._connect_task: if self._connect_task:
@ -366,7 +379,9 @@ class Audio(commands.Cog):
async def _players_check(): async def _players_check():
try: try:
get_single_title = lavalink.active_players()[0].current.title get_single_title = lavalink.active_players()[0].current.title
query = dataclasses.Query.process_input(lavalink.active_players()[0].current.uri) query = audio_dataclasses.Query.process_input(
lavalink.active_players()[0].current.uri
)
if get_single_title == "Unknown title": if get_single_title == "Unknown title":
get_single_title = lavalink.active_players()[0].current.uri get_single_title = lavalink.active_players()[0].current.uri
if not get_single_title.startswith("http"): if not get_single_title.startswith("http"):
@ -463,18 +478,18 @@ class Audio(commands.Cog):
) )
await notify_channel.send(embed=embed) await notify_channel.send(embed=embed)
query = dataclasses.Query.process_input(player.current.uri) query = audio_dataclasses.Query.process_input(player.current.uri)
if query.is_local if player.current else False: if query.is_local if player.current else False:
if player.current.title != "Unknown title": if player.current.title != "Unknown title":
description = "**{} - {}**\n{}".format( description = "**{} - {}**\n{}".format(
player.current.author, player.current.author,
player.current.title, player.current.title,
dataclasses.LocalPath(player.current.uri).to_string_hidden(), audio_dataclasses.LocalPath(player.current.uri).to_string_hidden(),
) )
else: else:
description = "{}".format( description = "{}".format(
dataclasses.LocalPath(player.current.uri).to_string_hidden() audio_dataclasses.LocalPath(player.current.uri).to_string_hidden()
) )
else: else:
description = "**[{}]({})**".format(player.current.title, player.current.uri) description = "**[{}]({})**".format(player.current.title, player.current.uri)
@ -532,9 +547,9 @@ class Audio(commands.Cog):
message_channel = player.fetch("channel") message_channel = player.fetch("channel")
if message_channel: if message_channel:
message_channel = self.bot.get_channel(message_channel) message_channel = self.bot.get_channel(message_channel)
query = dataclasses.Query.process_input(player.current.uri) query = audio_dataclasses.Query.process_input(player.current.uri)
if player.current and query.is_local: if player.current and query.is_local:
query = dataclasses.Query.process_input(player.current.uri) query = audio_dataclasses.Query.process_input(player.current.uri)
if player.current.title == "Unknown title": if player.current.title == "Unknown title":
description = "{}".format(query.track.to_string_hidden()) description = "{}".format(query.track.to_string_hidden())
else: else:
@ -590,7 +605,7 @@ class Audio(commands.Cog):
player.store("channel", channel.id) player.store("channel", channel.id)
player.store("guild", guild.id) player.store("guild", guild.id)
await self._data_check(guild.me) await self._data_check(guild.me)
query = dataclasses.Query.process_input(query) query = audio_dataclasses.Query.process_input(query)
ctx = namedtuple("Context", "message") ctx = namedtuple("Context", "message")
results, called_api = await self.music_cache.lavalink_query(ctx(guild), player, query) results, called_api = await self.music_cache.lavalink_query(ctx(guild), player, query)
@ -1094,7 +1109,7 @@ class Audio(commands.Cog):
with contextlib.suppress(discord.HTTPException): with contextlib.suppress(discord.HTTPException):
await info.delete() await info.delete()
return return
temp = dataclasses.LocalPath(local_path, forced=True) temp = audio_dataclasses.LocalPath(local_path, forced=True)
if not temp.exists() or not temp.is_dir(): if not temp.exists() or not temp.is_dir():
return await self._embed_msg( return await self._embed_msg(
ctx, ctx,
@ -1536,7 +1551,7 @@ class Audio(commands.Cog):
int((datetime.datetime.utcnow() - connect_start).total_seconds()) int((datetime.datetime.utcnow() - connect_start).total_seconds())
) )
try: try:
query = dataclasses.Query.process_input(p.current.uri) query = audio_dataclasses.Query.process_input(p.current.uri)
if query.is_local: if query.is_local:
if p.current.title == "Unknown title": if p.current.title == "Unknown title":
current_title = localtracks.LocalPath(p.current.uri).to_string_hidden() current_title = localtracks.LocalPath(p.current.uri).to_string_hidden()
@ -1606,9 +1621,9 @@ class Audio(commands.Cog):
bump_song = player.queue[bump_index] bump_song = player.queue[bump_index]
player.queue.insert(0, bump_song) player.queue.insert(0, bump_song)
removed = player.queue.pop(index) removed = player.queue.pop(index)
query = dataclasses.Query.process_input(removed.uri) query = audio_dataclasses.Query.process_input(removed.uri)
if query.is_local: if query.is_local:
localtrack = dataclasses.LocalPath(removed.uri) localtrack = audio_dataclasses.LocalPath(removed.uri)
if removed.title != "Unknown title": if removed.title != "Unknown title":
description = "**{} - {}**\n{}".format( description = "**{} - {}**\n{}".format(
removed.author, removed.title, localtrack.to_string_hidden() removed.author, removed.title, localtrack.to_string_hidden()
@ -1997,12 +2012,12 @@ class Audio(commands.Cog):
await ctx.invoke(self.local_play, play_subfolders=play_subfolders) await ctx.invoke(self.local_play, play_subfolders=play_subfolders)
else: else:
folder = folder.strip() folder = folder.strip()
_dir = dataclasses.LocalPath.joinpath(folder) _dir = audio_dataclasses.LocalPath.joinpath(folder)
if not _dir.exists(): if not _dir.exists():
return await self._embed_msg( return await self._embed_msg(
ctx, _("No localtracks folder named {name}.").format(name=folder) ctx, _("No localtracks folder named {name}.").format(name=folder)
) )
query = dataclasses.Query.process_input(_dir, search_subfolders=play_subfolders) query = audio_dataclasses.Query.process_input(_dir, search_subfolders=play_subfolders)
await self._local_play_all(ctx, query, from_search=False if not folder else True) await self._local_play_all(ctx, query, from_search=False if not folder else True)
@local.command(name="play") @local.command(name="play")
@ -2064,8 +2079,8 @@ class Audio(commands.Cog):
all_tracks = await self._folder_list( all_tracks = await self._folder_list(
ctx, ctx,
( (
dataclasses.Query.process_input( audio_dataclasses.Query.process_input(
dataclasses.LocalPath( audio_dataclasses.LocalPath(
await self.config.localpath() await self.config.localpath()
).localtrack_folder.absolute(), ).localtrack_folder.absolute(),
search_subfolders=play_subfolders, search_subfolders=play_subfolders,
@ -2081,18 +2096,18 @@ class Audio(commands.Cog):
return await ctx.invoke(self.search, query=search_list) return await ctx.invoke(self.search, query=search_list)
async def _localtracks_folders(self, ctx: commands.Context, search_subfolders=False): async def _localtracks_folders(self, ctx: commands.Context, search_subfolders=False):
audio_data = dataclasses.LocalPath( audio_data = audio_dataclasses.LocalPath(
dataclasses.LocalPath(None).localtrack_folder.absolute() audio_dataclasses.LocalPath(None).localtrack_folder.absolute()
) )
if not await self._localtracks_check(ctx): if not await self._localtracks_check(ctx):
return return
return audio_data.subfolders_in_tree() if search_subfolders else audio_data.subfolders() return audio_data.subfolders_in_tree() if search_subfolders else audio_data.subfolders()
async def _folder_list(self, ctx: commands.Context, query: dataclasses.Query): async def _folder_list(self, ctx: commands.Context, query: audio_dataclasses.Query):
if not await self._localtracks_check(ctx): if not await self._localtracks_check(ctx):
return return
query = dataclasses.Query.process_input(query) query = audio_dataclasses.Query.process_input(query)
if not query.track.exists(): if not query.track.exists():
return return
return ( return (
@ -2102,12 +2117,12 @@ class Audio(commands.Cog):
) )
async def _folder_tracks( async def _folder_tracks(
self, ctx, player: lavalink.player_manager.Player, query: dataclasses.Query self, ctx, player: lavalink.player_manager.Player, query: audio_dataclasses.Query
): ):
if not await self._localtracks_check(ctx): if not await self._localtracks_check(ctx):
return return
audio_data = dataclasses.LocalPath(None) audio_data = audio_dataclasses.LocalPath(None)
try: try:
query.track.path.relative_to(audio_data.to_string()) query.track.path.relative_to(audio_data.to_string())
except ValueError: except ValueError:
@ -2120,17 +2135,17 @@ class Audio(commands.Cog):
return local_tracks return local_tracks
async def _local_play_all( async def _local_play_all(
self, ctx: commands.Context, query: dataclasses.Query, from_search=False self, ctx: commands.Context, query: audio_dataclasses.Query, from_search=False
): ):
if not await self._localtracks_check(ctx): if not await self._localtracks_check(ctx):
return return
if from_search: if from_search:
query = dataclasses.Query.process_input( query = audio_dataclasses.Query.process_input(
query.track.to_string(), invoked_from="local folder" query.track.to_string(), invoked_from="local folder"
) )
await ctx.invoke(self.search, query=query) await ctx.invoke(self.search, query=query)
async def _all_folder_tracks(self, ctx: commands.Context, query: dataclasses.Query): async def _all_folder_tracks(self, ctx: commands.Context, query: audio_dataclasses.Query):
if not await self._localtracks_check(ctx): if not await self._localtracks_check(ctx):
return return
@ -2141,7 +2156,7 @@ class Audio(commands.Cog):
) )
async def _localtracks_check(self, ctx: commands.Context): async def _localtracks_check(self, ctx: commands.Context):
folder = dataclasses.LocalPath(None) folder = audio_dataclasses.LocalPath(None)
if folder.localtrack_folder.exists(): if folder.localtrack_folder.exists():
return True return True
if ctx.invoked_with != "start": if ctx.invoked_with != "start":
@ -2177,7 +2192,7 @@ class Audio(commands.Cog):
dur = "LIVE" dur = "LIVE"
else: else:
dur = lavalink.utils.format_time(player.current.length) dur = lavalink.utils.format_time(player.current.length)
query = dataclasses.Query.process_input(player.current.uri) query = audio_dataclasses.Query.process_input(player.current.uri)
if query.is_local: if query.is_local:
if not player.current.title == "Unknown title": if not player.current.title == "Unknown title":
song = "**{track.author} - {track.title}**\n{uri}\n" song = "**{track.author} - {track.title}**\n{uri}\n"
@ -2189,8 +2204,8 @@ class Audio(commands.Cog):
song += "\n\n{arrow}`{pos}`/`{dur}`" song += "\n\n{arrow}`{pos}`/`{dur}`"
song = song.format( song = song.format(
track=player.current, track=player.current,
uri=dataclasses.LocalPath(player.current.uri).to_string_hidden() uri=audio_dataclasses.LocalPath(player.current.uri).to_string_hidden()
if dataclasses.Query.process_input(player.current.uri).is_local if audio_dataclasses.Query.process_input(player.current.uri).is_local
else player.current.uri, else player.current.uri,
arrow=arrow, arrow=arrow,
pos=pos, pos=pos,
@ -2301,9 +2316,9 @@ class Audio(commands.Cog):
if not player.current: if not player.current:
return await self._embed_msg(ctx, _("Nothing playing.")) return await self._embed_msg(ctx, _("Nothing playing."))
query = dataclasses.Query.process_input(player.current.uri) query = audio_dataclasses.Query.process_input(player.current.uri)
if query.is_local: if query.is_local:
query = dataclasses.Query.process_input(player.current.uri) query = audio_dataclasses.Query.process_input(player.current.uri)
if player.current.title == "Unknown title": if player.current.title == "Unknown title":
description = "{}".format(query.track.to_string_hidden()) description = "{}".format(query.track.to_string_hidden())
else: else:
@ -2436,7 +2451,7 @@ class Audio(commands.Cog):
) )
if not await self._currency_check(ctx, guild_data["jukebox_price"]): if not await self._currency_check(ctx, guild_data["jukebox_price"]):
return return
query = dataclasses.Query.process_input(query) query = audio_dataclasses.Query.process_input(query)
if not query.valid: if not query.valid:
return await self._embed_msg(ctx, _("No tracks to play.")) return await self._embed_msg(ctx, _("No tracks to play."))
if query.is_spotify: if query.is_spotify:
@ -2593,7 +2608,7 @@ class Audio(commands.Cog):
) )
playlists_search_page_list.append(embed) playlists_search_page_list.append(embed)
playlists_pick = await menu(ctx, playlists_search_page_list, playlist_search_controls) playlists_pick = await menu(ctx, playlists_search_page_list, playlist_search_controls)
query = dataclasses.Query.process_input(playlists_pick) query = audio_dataclasses.Query.process_input(playlists_pick)
if not query.valid: if not query.valid:
return await self._embed_msg(ctx, _("No tracks to play.")) return await self._embed_msg(ctx, _("No tracks to play."))
if not await self._currency_check(ctx, guild_data["jukebox_price"]): if not await self._currency_check(ctx, guild_data["jukebox_price"]):
@ -2728,7 +2743,7 @@ class Audio(commands.Cog):
elif player.current: elif player.current:
await self._embed_msg(ctx, _("Adding a track to queue.")) await self._embed_msg(ctx, _("Adding a track to queue."))
async def _get_spotify_tracks(self, ctx: commands.Context, query: dataclasses.Query): async def _get_spotify_tracks(self, ctx: commands.Context, query: audio_dataclasses.Query):
if ctx.invoked_with in ["play", "genre"]: if ctx.invoked_with in ["play", "genre"]:
enqueue_tracks = True enqueue_tracks = True
else: else:
@ -2771,12 +2786,12 @@ class Audio(commands.Cog):
self._play_lock(ctx, False) self._play_lock(ctx, False)
try: try:
if enqueue_tracks: if enqueue_tracks:
new_query = dataclasses.Query.process_input(res[0]) new_query = audio_dataclasses.Query.process_input(res[0])
new_query.start_time = query.start_time new_query.start_time = query.start_time
return await self._enqueue_tracks(ctx, new_query) return await self._enqueue_tracks(ctx, new_query)
else: else:
result, called_api = await self.music_cache.lavalink_query( result, called_api = await self.music_cache.lavalink_query(
ctx, player, dataclasses.Query.process_input(res[0]) ctx, player, audio_dataclasses.Query.process_input(res[0])
) )
tracks = result.tracks tracks = result.tracks
if not tracks: if not tracks:
@ -2808,7 +2823,9 @@ class Audio(commands.Cog):
ctx, _("This doesn't seem to be a supported Spotify URL or code.") ctx, _("This doesn't seem to be a supported Spotify URL or code.")
) )
async def _enqueue_tracks(self, ctx: commands.Context, query: Union[dataclasses.Query, list]): async def _enqueue_tracks(
self, ctx: commands.Context, query: Union[audio_dataclasses.Query, list]
):
player = lavalink.get_player(ctx.guild.id) player = lavalink.get_player(ctx.guild.id)
try: try:
if self.play_lock[ctx.message.guild.id]: if self.play_lock[ctx.message.guild.id]:
@ -2863,7 +2880,7 @@ class Audio(commands.Cog):
ctx.guild, ctx.guild,
( (
f"{track.title} {track.author} {track.uri} " f"{track.title} {track.author} {track.uri} "
f"{str(dataclasses.Query.process_input(track))}" f"{str(audio_dataclasses.Query.process_input(track))}"
), ),
): ):
log.debug(f"Query is not allowed in {ctx.guild} ({ctx.guild.id})") log.debug(f"Query is not allowed in {ctx.guild} ({ctx.guild.id})")
@ -2923,7 +2940,7 @@ class Audio(commands.Cog):
ctx.guild, ctx.guild,
( (
f"{single_track.title} {single_track.author} {single_track.uri} " f"{single_track.title} {single_track.author} {single_track.uri} "
f"{str(dataclasses.Query.process_input(single_track))}" f"{str(audio_dataclasses.Query.process_input(single_track))}"
), ),
): ):
log.debug(f"Query is not allowed in {ctx.guild} ({ctx.guild.id})") log.debug(f"Query is not allowed in {ctx.guild} ({ctx.guild.id})")
@ -2956,17 +2973,17 @@ class Audio(commands.Cog):
return await self._embed_msg( return await self._embed_msg(
ctx, _("Nothing found. Check your Lavalink logs for details.") ctx, _("Nothing found. Check your Lavalink logs for details.")
) )
query = dataclasses.Query.process_input(single_track.uri) query = audio_dataclasses.Query.process_input(single_track.uri)
if query.is_local: if query.is_local:
if single_track.title != "Unknown title": if single_track.title != "Unknown title":
description = "**{} - {}**\n{}".format( description = "**{} - {}**\n{}".format(
single_track.author, single_track.author,
single_track.title, single_track.title,
dataclasses.LocalPath(single_track.uri).to_string_hidden(), audio_dataclasses.LocalPath(single_track.uri).to_string_hidden(),
) )
else: else:
description = "{}".format( description = "{}".format(
dataclasses.LocalPath(single_track.uri).to_string_hidden() audio_dataclasses.LocalPath(single_track.uri).to_string_hidden()
) )
else: else:
description = "**[{}]({})**".format(single_track.title, single_track.uri) description = "**[{}]({})**".format(single_track.title, single_track.uri)
@ -2987,7 +3004,11 @@ class Audio(commands.Cog):
self._play_lock(ctx, False) self._play_lock(ctx, False)
async def _spotify_playlist( async def _spotify_playlist(
self, ctx: commands.Context, stype: str, query: dataclasses.Query, enqueue: bool = False self,
ctx: commands.Context,
stype: str,
query: audio_dataclasses.Query,
enqueue: bool = False,
): ):
player = lavalink.get_player(ctx.guild.id) player = lavalink.get_player(ctx.guild.id)
@ -3340,7 +3361,7 @@ class Audio(commands.Cog):
return return
player = lavalink.get_player(ctx.guild.id) player = lavalink.get_player(ctx.guild.id)
to_append = await self._playlist_tracks( to_append = await self._playlist_tracks(
ctx, player, dataclasses.Query.process_input(query) ctx, player, audio_dataclasses.Query.process_input(query)
) )
if not to_append: if not to_append:
return await self._embed_msg(ctx, _("Could not find a track matching your query.")) return await self._embed_msg(ctx, _("Could not find a track matching your query."))
@ -3993,7 +4014,7 @@ class Audio(commands.Cog):
spaces = "\N{EN SPACE}" * (len(str(len(playlist.tracks))) + 2) spaces = "\N{EN SPACE}" * (len(str(len(playlist.tracks))) + 2)
for track in playlist.tracks: for track in playlist.tracks:
track_idx = track_idx + 1 track_idx = track_idx + 1
query = dataclasses.Query.process_input(track["info"]["uri"]) query = audio_dataclasses.Query.process_input(track["info"]["uri"])
if query.is_local: if query.is_local:
if track["info"]["title"] != "Unknown title": if track["info"]["title"] != "Unknown title":
msg += "`{}.` **{} - {}**\n{}{}\n".format( msg += "`{}.` **{} - {}**\n{}{}\n".format(
@ -4398,7 +4419,7 @@ class Audio(commands.Cog):
return return
player = lavalink.get_player(ctx.guild.id) player = lavalink.get_player(ctx.guild.id)
tracklist = await self._playlist_tracks( tracklist = await self._playlist_tracks(
ctx, player, dataclasses.Query.process_input(playlist_url) ctx, player, audio_dataclasses.Query.process_input(playlist_url)
) )
if tracklist is not None: if tracklist is not None:
playlist = await create_playlist( playlist = await create_playlist(
@ -4488,14 +4509,14 @@ class Audio(commands.Cog):
ctx.guild, ctx.guild,
( (
f"{track.title} {track.author} {track.uri} " f"{track.title} {track.author} {track.uri} "
f"{str(dataclasses.Query.process_input(track))}" f"{str(audio_dataclasses.Query.process_input(track))}"
), ),
): ):
log.debug(f"Query is not allowed in {ctx.guild} ({ctx.guild.id})") log.debug(f"Query is not allowed in {ctx.guild} ({ctx.guild.id})")
continue continue
query = dataclasses.Query.process_input(track.uri) query = audio_dataclasses.Query.process_input(track.uri)
if query.is_local: if query.is_local:
local_path = dataclasses.LocalPath(track.uri) local_path = audio_dataclasses.LocalPath(track.uri)
if not await self._localtracks_check(ctx): if not await self._localtracks_check(ctx):
pass pass
if not local_path.exists() and not local_path.is_file(): if not local_path.exists() and not local_path.is_file():
@ -4781,7 +4802,7 @@ class Audio(commands.Cog):
or not match_yt_playlist(uploaded_playlist_url) or not match_yt_playlist(uploaded_playlist_url)
or not ( or not (
await self.music_cache.lavalink_query( await self.music_cache.lavalink_query(
ctx, player, dataclasses.Query.process_input(uploaded_playlist_url) ctx, player, audio_dataclasses.Query.process_input(uploaded_playlist_url)
) )
)[0].tracks )[0].tracks
): ):
@ -4966,7 +4987,7 @@ class Audio(commands.Cog):
} }
) )
if database_entries and HAS_SQL: if database_entries and HAS_SQL:
asyncio.ensure_future(self.music_cache.insert("lavalink", database_entries)) await self.music_cache.insert("lavalink", database_entries)
async def _load_v2_playlist( async def _load_v2_playlist(
self, self,
@ -4993,7 +5014,7 @@ class Audio(commands.Cog):
track_count += 1 track_count += 1
try: try:
result, called_api = await self.music_cache.lavalink_query( result, called_api = await self.music_cache.lavalink_query(
ctx, player, dataclasses.Query.process_input(song_url) ctx, player, audio_dataclasses.Query.process_input(song_url)
) )
track = result.tracks track = result.tracks
except Exception: except Exception:
@ -5041,7 +5062,7 @@ class Audio(commands.Cog):
return [], [], playlist return [], [], playlist
results = {} results = {}
updated_tracks = await self._playlist_tracks( updated_tracks = await self._playlist_tracks(
ctx, player, dataclasses.Query.process_input(playlist.url) ctx, player, audio_dataclasses.Query.process_input(playlist.url)
) )
if not updated_tracks: if not updated_tracks:
# No Tracks available on url Lets set it to none to avoid repeated calls here # No Tracks available on url Lets set it to none to avoid repeated calls here
@ -5106,7 +5127,7 @@ class Audio(commands.Cog):
self, self,
ctx: commands.Context, ctx: commands.Context,
player: lavalink.player_manager.Player, player: lavalink.player_manager.Player,
query: dataclasses.Query, query: audio_dataclasses.Query,
): ):
search = query.is_search search = query.is_search
tracklist = [] tracklist = []
@ -5175,7 +5196,7 @@ class Audio(commands.Cog):
player.queue.insert(0, bump_song) player.queue.insert(0, bump_song)
player.queue.pop(queue_len) player.queue.pop(queue_len)
await player.skip() await player.skip()
query = dataclasses.Query.process_input(player.current.uri) query = audio_dataclasses.Query.process_input(player.current.uri)
if query.is_local: if query.is_local:
if player.current.title == "Unknown title": if player.current.title == "Unknown title":
@ -5227,7 +5248,7 @@ class Audio(commands.Cog):
else: else:
dur = lavalink.utils.format_time(player.current.length) dur = lavalink.utils.format_time(player.current.length)
query = dataclasses.Query.process_input(player.current) query = audio_dataclasses.Query.process_input(player.current)
if query.is_local: if query.is_local:
if player.current.title != "Unknown title": if player.current.title != "Unknown title":
@ -5240,8 +5261,8 @@ class Audio(commands.Cog):
song += "\n\n{arrow}`{pos}`/`{dur}`" song += "\n\n{arrow}`{pos}`/`{dur}`"
song = song.format( song = song.format(
track=player.current, track=player.current,
uri=dataclasses.LocalPath(player.current.uri).to_string_hidden() uri=audio_dataclasses.LocalPath(player.current.uri).to_string_hidden()
if dataclasses.Query.process_input(player.current.uri).is_local if audio_dataclasses.Query.process_input(player.current.uri).is_local
else player.current.uri, else player.current.uri,
arrow=arrow, arrow=arrow,
pos=pos, pos=pos,
@ -5313,7 +5334,7 @@ class Audio(commands.Cog):
else: else:
dur = lavalink.utils.format_time(player.current.length) dur = lavalink.utils.format_time(player.current.length)
query = dataclasses.Query.process_input(player.current) query = audio_dataclasses.Query.process_input(player.current)
if query.is_stream: if query.is_stream:
queue_list += _("**Currently livestreaming:**\n") queue_list += _("**Currently livestreaming:**\n")
@ -5327,7 +5348,7 @@ class Audio(commands.Cog):
( (
_("Playing: ") _("Playing: ")
+ "**{current.author} - {current.title}**".format(current=player.current), + "**{current.author} - {current.title}**".format(current=player.current),
dataclasses.LocalPath(player.current.uri).to_string_hidden(), audio_dataclasses.LocalPath(player.current.uri).to_string_hidden(),
_("Requested by: **{user}**\n").format(user=player.current.requester), _("Requested by: **{user}**\n").format(user=player.current.requester),
f"{arrow}`{pos}`/`{dur}`\n\n", f"{arrow}`{pos}`/`{dur}`\n\n",
) )
@ -5336,7 +5357,7 @@ class Audio(commands.Cog):
queue_list += "\n".join( queue_list += "\n".join(
( (
_("Playing: ") _("Playing: ")
+ dataclasses.LocalPath(player.current.uri).to_string_hidden(), + audio_dataclasses.LocalPath(player.current.uri).to_string_hidden(),
_("Requested by: **{user}**\n").format(user=player.current.requester), _("Requested by: **{user}**\n").format(user=player.current.requester),
f"{arrow}`{pos}`/`{dur}`\n\n", f"{arrow}`{pos}`/`{dur}`\n\n",
) )
@ -5357,13 +5378,13 @@ class Audio(commands.Cog):
track_title = track.title track_title = track.title
req_user = track.requester req_user = track.requester
track_idx = i + 1 track_idx = i + 1
query = dataclasses.Query.process_input(track) query = audio_dataclasses.Query.process_input(track)
if query.is_local: if query.is_local:
if track.title == "Unknown title": if track.title == "Unknown title":
queue_list += f"`{track_idx}.` " + ", ".join( queue_list += f"`{track_idx}.` " + ", ".join(
( (
bold(dataclasses.LocalPath(track.uri).to_string_hidden()), bold(audio_dataclasses.LocalPath(track.uri).to_string_hidden()),
_("requested by **{user}**\n").format(user=req_user), _("requested by **{user}**\n").format(user=req_user),
) )
) )
@ -5420,7 +5441,7 @@ class Audio(commands.Cog):
for track in queue_list: for track in queue_list:
queue_idx = queue_idx + 1 queue_idx = queue_idx + 1
if not match_url(track.uri): if not match_url(track.uri):
query = dataclasses.Query.process_input(track) query = audio_dataclasses.Query.process_input(track)
if track.title == "Unknown title": if track.title == "Unknown title":
track_title = query.track.to_string_hidden() track_title = query.track.to_string_hidden()
else: else:
@ -5449,7 +5470,7 @@ class Audio(commands.Cog):
): ):
track_idx = i + 1 track_idx = i + 1
if type(track) is str: if type(track) is str:
track_location = dataclasses.LocalPath(track).to_string_hidden() track_location = audio_dataclasses.LocalPath(track).to_string_hidden()
track_match += "`{}.` **{}**\n".format(track_idx, track_location) track_match += "`{}.` **{}**\n".format(track_idx, track_location)
else: else:
track_match += "`{}.` **{}**\n".format(track[0], track[1]) track_match += "`{}.` **{}**\n".format(track[0], track[1])
@ -5674,9 +5695,9 @@ class Audio(commands.Cog):
) )
index -= 1 index -= 1
removed = player.queue.pop(index) removed = player.queue.pop(index)
query = dataclasses.Query.process_input(removed.uri) query = audio_dataclasses.Query.process_input(removed.uri)
if query.is_local: if query.is_local:
local_path = dataclasses.LocalPath(removed.uri).to_string_hidden() local_path = audio_dataclasses.LocalPath(removed.uri).to_string_hidden()
if removed.title == "Unknown title": if removed.title == "Unknown title":
removed_title = local_path removed_title = local_path
else: else:
@ -5762,7 +5783,7 @@ class Audio(commands.Cog):
await self._data_check(ctx) await self._data_check(ctx)
if not isinstance(query, list): if not isinstance(query, list):
query = dataclasses.Query.process_input(query) query = audio_dataclasses.Query.process_input(query)
if query.invoked_from == "search list" or query.invoked_from == "local folder": if query.invoked_from == "search list" or query.invoked_from == "local folder":
if query.invoked_from == "search list": if query.invoked_from == "search list":
result, called_api = await self.music_cache.lavalink_query(ctx, player, query) result, called_api = await self.music_cache.lavalink_query(ctx, player, query)
@ -5791,7 +5812,7 @@ class Audio(commands.Cog):
ctx.guild, ctx.guild,
( (
f"{track.title} {track.author} {track.uri} " f"{track.title} {track.author} {track.uri} "
f"{str(dataclasses.Query.process_input(track))}" f"{str(audio_dataclasses.Query.process_input(track))}"
), ),
): ):
log.debug(f"Query is not allowed in {ctx.guild} ({ctx.guild.id})") log.debug(f"Query is not allowed in {ctx.guild} ({ctx.guild.id})")
@ -5905,10 +5926,10 @@ class Audio(commands.Cog):
except IndexError: except IndexError:
search_choice = tracks[-1] search_choice = tracks[-1]
try: try:
query = dataclasses.Query.process_input(search_choice.uri) query = audio_dataclasses.Query.process_input(search_choice.uri)
if query.is_local: if query.is_local:
localtrack = dataclasses.LocalPath(search_choice.uri) localtrack = audio_dataclasses.LocalPath(search_choice.uri)
if search_choice.title != "Unknown title": if search_choice.title != "Unknown title":
description = "**{} - {}**\n{}".format( description = "**{} - {}**\n{}".format(
search_choice.author, search_choice.title, localtrack.to_string_hidden() search_choice.author, search_choice.title, localtrack.to_string_hidden()
@ -5919,7 +5940,7 @@ class Audio(commands.Cog):
description = "**[{}]({})**".format(search_choice.title, search_choice.uri) description = "**[{}]({})**".format(search_choice.title, search_choice.uri)
except AttributeError: except AttributeError:
search_choice = dataclasses.Query.process_input(search_choice) search_choice = audio_dataclasses.Query.process_input(search_choice)
if search_choice.track.exists() and search_choice.track.is_dir(): if search_choice.track.exists() and search_choice.track.is_dir():
return await ctx.invoke(self.search, query=search_choice) return await ctx.invoke(self.search, query=search_choice)
elif search_choice.track.exists() and search_choice.track.is_file(): elif search_choice.track.exists() and search_choice.track.is_file():
@ -5935,7 +5956,7 @@ class Audio(commands.Cog):
ctx.guild, ctx.guild,
( (
f"{search_choice.title} {search_choice.author} {search_choice.uri} " f"{search_choice.title} {search_choice.author} {search_choice.uri} "
f"{str(dataclasses.Query.process_input(search_choice))}" f"{str(audio_dataclasses.Query.process_input(search_choice))}"
), ),
): ):
log.debug(f"Query is not allowed in {ctx.guild} ({ctx.guild.id})") log.debug(f"Query is not allowed in {ctx.guild} ({ctx.guild.id})")
@ -5984,12 +6005,12 @@ class Audio(commands.Cog):
if search_track_num == 0: if search_track_num == 0:
search_track_num = 5 search_track_num = 5
try: try:
query = dataclasses.Query.process_input(track.uri) query = audio_dataclasses.Query.process_input(track.uri)
if query.is_local: if query.is_local:
search_list += "`{0}.` **{1}**\n[{2}]\n".format( search_list += "`{0}.` **{1}**\n[{2}]\n".format(
search_track_num, search_track_num,
track.title, track.title,
dataclasses.LocalPath(track.uri).to_string_hidden(), audio_dataclasses.LocalPath(track.uri).to_string_hidden(),
) )
else: else:
search_list += "`{0}.` **[{1}]({2})**\n".format( search_list += "`{0}.` **[{1}]({2})**\n".format(
@ -5997,7 +6018,7 @@ class Audio(commands.Cog):
) )
except AttributeError: except AttributeError:
# query = Query.process_input(track) # query = Query.process_input(track)
track = dataclasses.Query.process_input(track) track = audio_dataclasses.Query.process_input(track)
if track.is_local and command != "search": if track.is_local and command != "search":
search_list += "`{}.` **{}**\n".format( search_list += "`{}.` **{}**\n".format(
search_track_num, track.to_string_user() search_track_num, track.to_string_user()
@ -6890,6 +6911,7 @@ class Audio(commands.Cog):
async def on_voice_state_update( async def on_voice_state_update(
self, member: discord.Member, before: discord.VoiceState, after: discord.VoiceState self, member: discord.Member, before: discord.VoiceState, after: discord.VoiceState
): ):
await self._ready_event.wait()
if after.channel != before.channel: if after.channel != before.channel:
try: try:
self.skip_votes[before.channel.guild].remove(member.id) self.skip_votes[before.channel.guild].remove(member.id)
@ -6907,6 +6929,9 @@ class Audio(commands.Cog):
if self._connect_task: if self._connect_task:
self._connect_task.cancel() self._connect_task.cancel()
if self._init_task:
self._init_task.cancel()
lavalink.unregister_event_listener(self.event_handler) lavalink.unregister_event_listener(self.event_handler)
self.bot.loop.create_task(lavalink.close()) self.bot.loop.create_task(lavalink.close())
if self._manager is not None: if self._manager is not None:

View File

@ -3,7 +3,6 @@ import contextlib
import os import os
import re import re
import time import time
from typing import NoReturn
from urllib.parse import urlparse from urllib.parse import urlparse
import discord import discord
@ -11,7 +10,7 @@ import lavalink
from redbot.core import Config, commands from redbot.core import Config, commands
from redbot.core.bot import Red from redbot.core.bot import Red
from . import dataclasses from . import audio_dataclasses
from .converters import _pass_config_to_converters from .converters import _pass_config_to_converters
@ -51,7 +50,7 @@ def pass_config_to_dependencies(config: Config, bot: Red, localtracks_folder: st
_config = config _config = config
_pass_config_to_playlist(config, bot) _pass_config_to_playlist(config, bot)
_pass_config_to_converters(config, bot) _pass_config_to_converters(config, bot)
dataclasses._pass_config_to_dataclasses(config, bot, localtracks_folder) audio_dataclasses._pass_config_to_dataclasses(config, bot, localtracks_folder)
def track_limit(track, maxlength): def track_limit(track, maxlength):
@ -168,7 +167,7 @@ async def clear_react(bot: Red, message: discord.Message, emoji: dict = None):
async def get_description(track): async def get_description(track):
if any(x in track.uri for x in [f"{os.sep}localtracks", f"localtracks{os.sep}"]): if any(x in track.uri for x in [f"{os.sep}localtracks", f"localtracks{os.sep}"]):
local_track = dataclasses.LocalPath(track.uri) local_track = audio_dataclasses.LocalPath(track.uri)
if track.title != "Unknown title": if track.title != "Unknown title":
return "**{} - {}**\n{}".format( return "**{} - {}**\n{}".format(
track.author, track.title, local_track.to_string_hidden() track.author, track.title, local_track.to_string_hidden()
@ -389,7 +388,7 @@ class Notifier:
key: str = None, key: str = None,
seconds_key: str = None, seconds_key: str = None,
seconds: str = None, seconds: str = None,
) -> NoReturn: ):
""" """
This updates an existing message. This updates an existing message.
Based on the message found in :variable:`Notifier.updates` as per the `key` param Based on the message found in :variable:`Notifier.updates` as per the `key` param
@ -410,14 +409,14 @@ class Notifier:
except discord.errors.NotFound: except discord.errors.NotFound:
pass pass
async def update_text(self, text: str) -> NoReturn: async def update_text(self, text: str):
embed2 = discord.Embed(colour=self.color, title=text) embed2 = discord.Embed(colour=self.color, title=text)
try: try:
await self.message.edit(embed=embed2) await self.message.edit(embed=embed2)
except discord.errors.NotFound: except discord.errors.NotFound:
pass pass
async def update_embed(self, embed: discord.Embed) -> NoReturn: async def update_embed(self, embed: discord.Embed):
try: try:
await self.message.edit(embed=embed) await self.message.edit(embed=embed)
self.last_msg_time = time.time() self.last_msg_time = time.time()

View File

@ -132,7 +132,6 @@ class RedBase(commands.GroupMixin, commands.bot.BotBase, RPCMixin): # pylint: d
self._main_dir = bot_dir self._main_dir = bot_dir
self._cog_mgr = CogManager() self._cog_mgr = CogManager()
super().__init__(*args, help_command=None, **kwargs) super().__init__(*args, help_command=None, **kwargs)
# Do not manually use the help formatter attribute here, see `send_help_for`, # Do not manually use the help formatter attribute here, see `send_help_for`,
# for a documented API. The internals of this object are still subject to change. # for a documented API. The internals of this object are still subject to change.
@ -325,6 +324,7 @@ class RedBase(commands.GroupMixin, commands.bot.BotBase, RPCMixin): # pylint: d
get_embed_colour = get_embed_color get_embed_colour = get_embed_color
# start config migrations
async def _maybe_update_config(self): async def _maybe_update_config(self):
""" """
This should be run prior to loading cogs or connecting to discord. This should be run prior to loading cogs or connecting to discord.
@ -375,6 +375,57 @@ class RedBase(commands.GroupMixin, commands.bot.BotBase, RPCMixin): # pylint: d
await self._config.guild(guild_obj).admin_role.set(admin_roles) await self._config.guild(guild_obj).admin_role.set(admin_roles)
log.info("Done updating guild configs to support multiple mod/admin roles") log.info("Done updating guild configs to support multiple mod/admin roles")
# end Config migrations
async def pre_flight(self, cli_flags):
"""
This should only be run once, prior to connecting to discord.
"""
await self._maybe_update_config()
packages = []
if cli_flags.no_cogs is False:
packages.extend(await self._config.packages())
if cli_flags.load_cogs:
packages.extend(cli_flags.load_cogs)
if packages:
# Load permissions first, for security reasons
try:
packages.remove("permissions")
except ValueError:
pass
else:
packages.insert(0, "permissions")
to_remove = []
print("Loading packages...")
for package in packages:
try:
spec = await self._cog_mgr.find_cog(package)
await asyncio.wait_for(self.load_extension(spec), 30)
except asyncio.TimeoutError:
log.exception("Failed to load package %s (timeout)", package)
to_remove.append(package)
except Exception as e:
log.exception("Failed to load package {}".format(package), exc_info=e)
await self.remove_loaded_package(package)
to_remove.append(package)
for package in to_remove:
packages.remove(package)
if packages:
print("Loaded packages: " + ", ".join(packages))
if self.rpc_enabled:
await self.rpc.initialize(self.rpc_port)
async def start(self, *args, **kwargs):
cli_flags = kwargs.pop("cli_flags")
await self.pre_flight(cli_flags=cli_flags)
return await super().start(*args, **kwargs)
async def send_help_for( async def send_help_for(
self, ctx: commands.Context, help_for: Union[commands.Command, commands.GroupMixin, str] self, ctx: commands.Context, help_for: Union[commands.Command, commands.GroupMixin, str]
): ):

View File

@ -46,40 +46,6 @@ def init_events(bot, cli_flags):
return return
bot._uptime = datetime.datetime.utcnow() bot._uptime = datetime.datetime.utcnow()
packages = []
if cli_flags.no_cogs is False:
packages.extend(await bot._config.packages())
if cli_flags.load_cogs:
packages.extend(cli_flags.load_cogs)
if packages:
# Load permissions first, for security reasons
try:
packages.remove("permissions")
except ValueError:
pass
else:
packages.insert(0, "permissions")
to_remove = []
print("Loading packages...")
for package in packages:
try:
spec = await bot._cog_mgr.find_cog(package)
await bot.load_extension(spec)
except Exception as e:
log.exception("Failed to load package {}".format(package), exc_info=e)
await bot.remove_loaded_package(package)
to_remove.append(package)
for package in to_remove:
packages.remove(package)
if packages:
print("Loaded packages: " + ", ".join(packages))
if bot.rpc_enabled:
await bot.rpc.initialize(bot.rpc_port)
guilds = len(bot.guilds) guilds = len(bot.guilds)
users = len(set([m for m in bot.get_all_members()])) users = len(set([m for m in bot.get_all_members()]))