Signed-off-by: Drapersniper <27962761+drapersniper@users.noreply.github.com>
This commit is contained in:
Drapersniper
2019-12-18 20:09:09 +00:00
parent 9f50d83545
commit 85a1057af6
3 changed files with 181 additions and 97 deletions

View File

@@ -202,29 +202,33 @@ class Audio(commands.Cog):
await self._embed_msg(ctx, _("No DJ role found. Disabling DJ mode.")) await self._embed_msg(ctx, _("No DJ role found. Disabling DJ mode."))
async def initialize(self): async def initialize(self):
await self.bot.wait_until_ready() try:
# Unlike most cases, we want the cache to exit before migration. await self.bot.wait_until_ready()
await self.music_cache.initialize(self.config) # Unlike most cases, we want the cache to exit before migration.
pass_config_to_dependencies(self.config, self.bot, await self.config.localpath()) await self.music_cache.initialize(self.config)
await self._migrate_config( pass_config_to_dependencies(self.config, self.bot, await self.config.localpath())
from_version=await self.config.schema_version(), to_version=_SCHEMA_VERSION await self._migrate_config(
) from_version=await self.config.schema_version(), to_version=_SCHEMA_VERSION
self._restart_connect() )
self._disconnect_task = self.bot.loop.create_task(self.disconnect_timer()) self._restart_connect()
lavalink.register_event_listener(self.event_handler) self._disconnect_task = self.bot.loop.create_task(self.disconnect_timer())
if not HAS_SQL: lavalink.register_event_listener(self.event_handler)
error_message = ( if not HAS_SQL:
"Audio version: {version}\nThis version requires some SQL dependencies to " error_message = (
"access the caching features, " "Audio version: {version}\nThis version requires some SQL dependencies to "
"your Python install is missing some of them.\n\n" "access the caching features, "
"For instructions on how to fix it Google " "your Python install is missing some of them.\n\n"
f"`{_ERROR}`.\n" "For instructions on how to fix it Google "
"You will need to install the missing SQL dependency.\n\n" f"`{_ERROR}`.\n"
).format(version=__version__) "You will need to install the missing SQL dependency.\n\n"
with contextlib.suppress(discord.HTTPException): ).format(version=__version__)
for page in pagify(error_message): with contextlib.suppress(discord.HTTPException):
await self.bot.send_to_owners(page) for page in pagify(error_message):
log.critical(error_message) await self.bot.send_to_owners(page)
log.critical(error_message)
except Exception as e:
log.exception("Error on audio init", exc_info=e)
raise e
self._ready_event.set() self._ready_event.set()
self.bot.dispatch("red_audio_initialized", self) self.bot.dispatch("red_audio_initialized", self)
@@ -3196,6 +3200,7 @@ class Audio(commands.Cog):
When multiple matches are found but none is selected. When multiple matches are found but none is selected.
""" """
correct_scope_matches: List[Playlist]
original_input = matches.get("arg") original_input = matches.get("arg")
correct_scope_matches = matches.get(scope) correct_scope_matches = matches.get(scope)
guild_to_query = guild.id guild_to_query = guild.id
@@ -3204,50 +3209,40 @@ class Audio(commands.Cog):
return None, original_input return None, original_input
if scope == PlaylistScope.USER.value: if scope == PlaylistScope.USER.value:
correct_scope_matches = [ correct_scope_matches = [
(i[2]["id"], i[2]["name"], len(i[2]["tracks"]), i[2]["author"]) p for p in correct_scope_matches if user_to_query == p.scope_id
for i in correct_scope_matches
if str(user_to_query) == i[0]
] ]
elif scope == PlaylistScope.GUILD.value: elif scope == PlaylistScope.GUILD.value:
if specified_user: if specified_user:
correct_scope_matches = [ correct_scope_matches = [
(i[2]["id"], i[2]["name"], len(i[2]["tracks"]), i[2]["author"]) p
for i in correct_scope_matches for p in correct_scope_matches
if str(guild_to_query) == i[0] and i[2]["author"] == user_to_query if guild_to_query == p.scope_id and p.author == user_to_query
] ]
else: else:
correct_scope_matches = [ correct_scope_matches = [
(i[2]["id"], i[2]["name"], len(i[2]["tracks"]), i[2]["author"]) p for p in correct_scope_matches if guild_to_query == p.scope_id
for i in correct_scope_matches
if str(guild_to_query) == i[0]
] ]
else: else:
if specified_user: if specified_user:
correct_scope_matches = [ correct_scope_matches = [
(i[2]["id"], i[2]["name"], len(i[2]["tracks"]), i[2]["author"]) p for p in correct_scope_matches if p.author == user_to_query
for i in correct_scope_matches
if i[2]["author"] == user_to_query
] ]
else: else:
correct_scope_matches = [ correct_scope_matches = [p for p in correct_scope_matches]
(i[2]["id"], i[2]["name"], len(i[2]["tracks"]), i[2]["author"])
for i in correct_scope_matches
]
match_count = len(correct_scope_matches) match_count = len(correct_scope_matches)
# We done all the trimming we can with the info available time to ask the user # We done all the trimming we can with the info available time to ask the user
print("correct_scope_matches", correct_scope_matches)
if match_count > 10: if match_count > 10:
if original_input.isnumeric(): if original_input.isnumeric():
arg = int(original_input) arg = int(original_input)
correct_scope_matches = [ correct_scope_matches = [p for p in correct_scope_matches if p.id == arg]
(i, n, t, a) for i, n, t, a in correct_scope_matches if i == arg
]
if match_count > 10: if match_count > 10:
raise TooManyMatches( raise TooManyMatches(
f"{match_count} playlists match {original_input}: " f"{match_count} playlists match {original_input}: "
f"Please try to be more specific, or use the playlist ID." f"Please try to be more specific, or use the playlist ID."
) )
elif match_count == 1: elif match_count == 1:
return correct_scope_matches[0][0], original_input return correct_scope_matches[0].id, original_input
elif match_count == 0: elif match_count == 0:
return None, original_input return None, original_input
@@ -3255,14 +3250,14 @@ class Audio(commands.Cog):
pos_len = 3 pos_len = 3
playlists = f"{'#':{pos_len}}\n" playlists = f"{'#':{pos_len}}\n"
for number, (pid, pname, ptracks, pauthor) in enumerate(correct_scope_matches, 1): for number, playlist in enumerate(correct_scope_matches, 1):
author = self.bot.get_user(pauthor) or "Unknown" author = self.bot.get_user(playlist.author) or "Unknown"
line = ( line = (
f"{number}." f"{number}."
f" <{pname}>\n" f" <{playlist.name}>\n"
f" - Scope: < {humanize_scope(scope)} >\n" f" - Scope: < {humanize_scope(scope)} >\n"
f" - ID: < {pid} >\n" f" - ID: < {playlist.id} >\n"
f" - Tracks: < {ptracks} >\n" f" - Tracks: < {len(playlist.tracks)} >\n"
f" - Author: < {author} >\n\n" f" - Author: < {author} >\n\n"
) )
playlists += line playlists += line
@@ -3295,7 +3290,7 @@ class Audio(commands.Cog):
) )
with contextlib.suppress(discord.HTTPException): with contextlib.suppress(discord.HTTPException):
await msg.delete() await msg.delete()
return correct_scope_matches[pred.result][0], original_input return correct_scope_matches[pred.result].id, original_input
@commands.group() @commands.group()
@commands.guild_only() @commands.guild_only()
@@ -3954,14 +3949,17 @@ class Audio(commands.Cog):
playlist_songs_backwards_compatible = [ playlist_songs_backwards_compatible = [
track["info"]["uri"] for track in playlist.tracks track["info"]["uri"] for track in playlist.tracks
] ]
# TODO: Keep new playlists backwards compatible, Remove me in a few releases
playlist_data[ playlist_data[
"playlist" "playlist"
] = playlist_songs_backwards_compatible # TODO: Keep new playlists backwards compatible, Remove me in a few releases ] = (
playlist_songs_backwards_compatible
)
playlist_data[ playlist_data[
"link" "link"
] = ( ] = (
playlist.url playlist.url
) # TODO: Keep new playlists backwards compatible, Remove me in a few releases )
file_name = playlist.id file_name = playlist.id
playlist_data.update({"schema": schema, "version": version}) playlist_data.update({"schema": schema, "version": version})
playlist_data = json.dumps(playlist_data) playlist_data = json.dumps(playlist_data)

View File

@@ -8,7 +8,7 @@ from redbot.core import Config, commands
from redbot.core.bot import Red from redbot.core.bot import Red
from redbot.core.i18n import Translator from redbot.core.i18n import Translator
from .playlists import PlaylistScope, standardize_scope from .playlists import PlaylistScope, standardize_scope, get_all_playlist_converter
_ = Translator("Audio", __file__) _ = Translator("Audio", __file__)
@@ -22,8 +22,8 @@ __all__ = [
"get_playlist_converter", "get_playlist_converter",
] ]
_config = None _config: Config = None
_bot = None _bot: Red = None
_SCOPE_HELP = """ _SCOPE_HELP = """
Scope must be a valid version of one of the following: Scope must be a valid version of one of the following:
@@ -54,29 +54,18 @@ def _pass_config_to_converters(config: Config, bot: Red):
class PlaylistConverter(commands.Converter): class PlaylistConverter(commands.Converter):
async def convert(self, ctx: commands.Context, arg: str) -> dict: async def convert(self, ctx: commands.Context, arg: str) -> dict:
global_scope = await _config.custom(PlaylistScope.GLOBAL.value).all() global_matches = await get_all_playlist_converter(
guild_scope = await _config.custom(PlaylistScope.GUILD.value).all() PlaylistScope.GLOBAL.value, _bot, arg, guild=ctx.guild, author=ctx.author
user_scope = await _config.custom(PlaylistScope.USER.value).all() )
user_matches = [ guild_matches = await get_all_playlist_converter(
(uid, pid, pdata) PlaylistScope.GUILD.value, _bot, arg, guild=ctx.guild, author=ctx.author
for uid, data in user_scope.items() )
for pid, pdata in data.items() user_matches = await get_all_playlist_converter(
if arg == pid or arg.lower() in pdata.get("name", "").lower() PlaylistScope.USER.value, _bot, arg, guild=ctx.guild, author=ctx.author
] )
guild_matches = [
(gid, pid, pdata)
for gid, data in guild_scope.items()
for pid, pdata in data.items()
if arg == pid or arg.lower() in pdata.get("name", "").lower()
]
global_matches = [
(None, pid, pdata)
for pid, pdata in global_scope.items()
if arg == pid or arg.lower() in pdata.get("name", "").lower()
]
if not user_matches and not guild_matches and not global_matches: if not user_matches and not guild_matches and not global_matches:
raise commands.BadArgument(_("Could not match '{}' to a playlist.").format(arg)) raise commands.BadArgument(_("Could not match '{}' to a playlist.").format(arg))
return { return {
PlaylistScope.GLOBAL.value: global_matches, PlaylistScope.GLOBAL.value: global_matches,
PlaylistScope.GUILD.value: guild_matches, PlaylistScope.GUILD.value: guild_matches,

View File

@@ -98,7 +98,9 @@ SELECT
FROM FROM
playlists playlists
WHERE WHERE
scope_type = :scope_type ; scope_type = :scope_type
AND scope_id = :scope_id
;
""" """
_FETCH_ALL_WITH_FILTER = """ _FETCH_ALL_WITH_FILTER = """
@@ -114,11 +116,35 @@ FROM
WHERE WHERE
( (
scope_type = :scope_type scope_type = :scope_type
AND scope_id = :scope_id
AND author_id = :author_id AND author_id = :author_id
) )
; ;
""" """
_FETCH_ALL_CONVERTER = """
SELECT
playlist_id,
playlist_name,
scope_id,
author_id,
playlist_url,
tracks
FROM
playlists
WHERE
(
scope_type = :scope_type
AND
(
playlist_id = :playlist_id
OR
playlist_name = :playlist_name
)
)
;
"""
_FETCH = """ _FETCH = """
SELECT SELECT
playlist_id, playlist_id,
@@ -211,22 +237,38 @@ class Database:
return SQLFetchResult(*row) if row else None return SQLFetchResult(*row) if row else None
def fetch_all(self, scope: str, author_id=None) -> List[SQLFetchResult]: def fetch_all(self, scope: str, scope_id: int, author_id=None) -> List[SQLFetchResult]:
scope_type = self.get_scope_type(scope) scope_type = self.get_scope_type(scope)
if author_id is not None: if author_id is not None:
output = ( output = self.cursor.execute(
self.cursor.execute( _FETCH_ALL_WITH_FILTER,
_FETCH_ALL, ({"scope_type": scope_type, "author_id": author_id}) ({"scope_type": scope_type, "scope_id": scope_id, "author_id": author_id}),
).fetchall() ).fetchall()
or []
)
else: else:
output = ( output = self.cursor.execute(
self.cursor.execute( _FETCH_ALL, ({"scope_type": scope_type, "scope_id": scope_id})
_FETCH_ALL_WITH_FILTER, ({"scope_type": scope_type}) ).fetchall()
).fetchall() return [SQLFetchResult(*row) for row in output] if output else []
or []
) def fetch_all_converter(self, scope: str, playlist_name, playlist_id) -> List[SQLFetchResult]:
scope_type = self.get_scope_type(scope)
try:
playlist_id = int(playlist_id)
except:
playlist_id = -1
output = (
self.cursor.execute(
_FETCH_ALL_CONVERTER,
(
{
"scope_type": scope_type,
"playlist_name": playlist_name,
"playlist_id": playlist_id,
}
),
).fetchall()
or []
)
return [SQLFetchResult(*row) for row in output] if output else [] return [SQLFetchResult(*row) for row in output] if output else []
def delete(self, scope: str, playlist_id: int, scope_id: int): def delete(self, scope: str, playlist_id: int, scope_id: int):
@@ -369,9 +411,10 @@ class PlaylistMigration23: # TODO: remove me in a future version ?
self.url = playlist_url self.url = playlist_url
self.tracks = tracks or [] self.tracks = tracks or []
@classmethod @classmethod
async def from_json(cls, scope: str, playlist_number: int, data: dict, **kwargs): async def from_json(
cls, scope: str, playlist_number: int, data: dict, **kwargs
) -> "PlaylistMigration23":
"""Get a Playlist object from the provided information. """Get a Playlist object from the provided information.
Parameters Parameters
---------- ----------
@@ -405,7 +448,7 @@ class PlaylistMigration23: # TODO: remove me in a future version ?
playlist_id = data.get("id") or playlist_number playlist_id = data.get("id") or playlist_number
name = data.get("name", "Unnamed") name = data.get("name", "Unnamed")
playlist_url = data.get("playlist_url", None) playlist_url = data.get("playlist_url", None)
tracks = json.loads(data.get("tracks", "[]")) tracks = data.get("tracks", [])
return cls( return cls(
guild=guild, guild=guild,
@@ -469,7 +512,11 @@ async def get_all_playlist_for_migration23( # TODO: remove me in a future versi
if scope == PlaylistScope.GLOBAL.value: if scope == PlaylistScope.GLOBAL.value:
return [ return [
await PlaylistMigration23.from_json( await PlaylistMigration23.from_json(
scope, playlist_number, playlist_data, guild=guild, author=int(playlist_data.get("author", 0)) scope,
playlist_number,
playlist_data,
guild=guild,
author=int(playlist_data.get("author", 0)),
) )
for playlist_number, playlist_data in playlists.items() for playlist_number, playlist_data in playlists.items()
] ]
@@ -484,7 +531,11 @@ async def get_all_playlist_for_migration23( # TODO: remove me in a future versi
else: else:
return [ return [
await PlaylistMigration23.from_json( await PlaylistMigration23.from_json(
scope, playlist_number, playlist_data, guild=int(guild_id), author=int(playlist_data.get("author", 0)) scope,
playlist_number,
playlist_data,
guild=int(guild_id),
author=int(playlist_data.get("author", 0)),
) )
for guild_id, scopedata in playlists.items() for guild_id, scopedata in playlists.items()
for playlist_number, playlist_data in scopedata.items() for playlist_number, playlist_data in scopedata.items()
@@ -509,6 +560,7 @@ class Playlist:
self.guild = guild self.guild = guild
self.scope = standardize_scope(scope) self.scope = standardize_scope(scope)
self.config_scope = _prepare_config_scope(self.scope, author, guild) self.config_scope = _prepare_config_scope(self.scope, author, guild)
self.scope_id = self.config_scope[-1]
self.author = author self.author = author
self.guild_id = ( self.guild_id = (
getattr(guild, "id", guild) if self.scope == PlaylistScope.GLOBAL.value else None getattr(guild, "id", guild) if self.scope == PlaylistScope.GLOBAL.value else None
@@ -569,7 +621,9 @@ class Playlist:
return data return data
@classmethod @classmethod
async def from_json(cls, bot: Red, scope: str, playlist_number: int, data: dict, **kwargs): async def from_json(
cls, bot: Red, scope: str, playlist_number: int, data: SQLFetchResult, **kwargs
):
"""Get a Playlist object from the provided information. """Get a Playlist object from the provided information.
Parameters Parameters
---------- ----------
@@ -701,10 +755,53 @@ async def get_all_playlist(
if specified_user: if specified_user:
user_id = getattr(author, "id", author) user_id = getattr(author, "id", author)
playlists = database.fetch_all(scope_standard, author_id=user_id) playlists = database.fetch_all(scope_standard, scope_id, author_id=user_id)
else: else:
playlists = database.fetch_all(scope_standard) playlists = database.fetch_all(scope_standard, scope_id)
return [
await Playlist.from_json(
bot, scope, playlist.playlist_id, playlist, guild=guild, author=author
)
for playlist in playlists
]
async def get_all_playlist_converter(
scope: str,
bot: Red,
arg: str,
guild: Union[discord.Guild, int] = None,
author: Union[discord.abc.User, int] = None,
) -> List[Playlist]:
"""
Gets all playlist for the specified scope.
Parameters
----------
scope: str
The custom config scope. One of 'GLOBALPLAYLIST', 'GUILDPLAYLIST' or 'USERPLAYLIST'.
guild: discord.Guild
The guild to get the playlist from if scope is GUILDPLAYLIST.
author: int
The ID of the user to get the playlist from if scope is USERPLAYLIST.
bot: Red
The bot's instance
specified_user:bool
Whether or not user ID was passed as an argparse.
Returns
-------
list
A list of all playlists for the specified scope
Raises
------
`InvalidPlaylistScope`
Passing a scope that is not supported.
`MissingGuild`
Trying to access the Guild scope without a guild.
`MissingAuthor`
Trying to access the User scope without an user id.
"""
scope_standard, scope_id = _prepare_config_scope(scope, author, guild)
playlists = database.fetch_all_converter(scope_standard, playlist_name=arg, playlist_id=arg)
return [ return [
await Playlist.from_json( await Playlist.from_json(
bot, scope, playlist.playlist_id, playlist, guild=guild, author=author bot, scope, playlist.playlist_id, playlist, guild=guild, author=author