[V3 Audio] Restrict toggle for commercial sites (#2245)

* [V3 Audio] Restrict toggle for commercial sites

* Different url parsing

* Allow local tracks

* No self needed

* Change Twitch url
This commit is contained in:
aikaterna 2018-12-13 09:31:24 -08:00 committed by Kowlin
parent 30c3a4c7c1
commit 3b50ed8192

View File

@ -29,7 +29,7 @@ from .manager import shutdown_lavalink_server
_ = Translator("Audio", __file__) _ = Translator("Audio", __file__)
__version__ = "0.0.7" __version__ = "0.0.8"
__author__ = ["aikaterna", "billy/bollo/ati"] __author__ = ["aikaterna", "billy/bollo/ati"]
@ -50,6 +50,7 @@ class Audio(commands.Cog):
"status": False, "status": False,
"current_version": redbot.core.VersionInfo.from_str("3.0.0a0").to_json(), "current_version": redbot.core.VersionInfo.from_str("3.0.0a0").to_json(),
"use_external_lavalink": False, "use_external_lavalink": False,
"restrict": True,
} }
default_guild = { default_guild = {
@ -319,6 +320,19 @@ class Audio(commands.Cog):
ctx, _("Verbose mode on: {true_or_false}.").format(true_or_false=not notify) ctx, _("Verbose mode on: {true_or_false}.").format(true_or_false=not notify)
) )
@audioset.command()
@checks.is_owner()
async def restrict(self, ctx):
"""Toggle the domain restriction on Audio.
When toggled off, users will be able to play songs from non-commercial websites and links.
When toggled on, users are restricted to YouTube, SoundCloud, Mixer, Vimeo, Twitch, and Bandcamp links."""
restrict = await self.config.restrict()
await self.config.restrict.set(not restrict)
await self._embed_msg(
ctx, _("Commercial links only: {true_or_false}.").format(true_or_false=not restrict)
)
@audioset.command() @audioset.command()
async def settings(self, ctx): async def settings(self, ctx):
"""Show the current settings.""" """Show the current settings."""
@ -870,6 +884,12 @@ class Audio(commands.Cog):
dj_enabled = await self.config.guild(ctx.guild).dj_enabled() dj_enabled = await self.config.guild(ctx.guild).dj_enabled()
jukebox_price = await self.config.guild(ctx.guild).jukebox_price() jukebox_price = await self.config.guild(ctx.guild).jukebox_price()
shuffle = await self.config.guild(ctx.guild).shuffle() shuffle = await self.config.guild(ctx.guild).shuffle()
restrict = await self.config.restrict()
if restrict:
if self._match_url(query):
url_check = self._url_check(query)
if not url_check:
return await self._embed_msg(ctx, _("That URL is not allowed."))
if not self._player_check(ctx): if not self._player_check(ctx):
try: try:
if not ctx.author.voice.channel.permissions_for(ctx.me).connect or self._userlimit( if not ctx.author.voice.channel.permissions_for(ctx.me).connect or self._userlimit(
@ -1252,6 +1272,7 @@ class Audio(commands.Cog):
@playlist.command(name="start") @playlist.command(name="start")
async def _playlist_start(self, ctx, playlist_name=None): async def _playlist_start(self, ctx, playlist_name=None):
"""Load a playlist into the queue.""" """Load a playlist into the queue."""
restrict = await self.config.restrict()
if not await self._playlist_check(ctx): if not await self._playlist_check(ctx):
return return
playlists = await self.config.guild(ctx.guild).playlists.get_raw() playlists = await self.config.guild(ctx.guild).playlists.get_raw()
@ -1260,6 +1281,10 @@ class Audio(commands.Cog):
try: try:
player = lavalink.get_player(ctx.guild.id) player = lavalink.get_player(ctx.guild.id)
for track in playlists[playlist_name]["tracks"]: for track in playlists[playlist_name]["tracks"]:
if restrict:
url_check = self._url_check(track["info"]["uri"])
if not url_check:
continue
player.add(author_obj, lavalink.rest_api.Track(data=track)) player.add(author_obj, lavalink.rest_api.Track(data=track))
track_count = track_count + 1 track_count = track_count + 1
embed = discord.Embed( embed = discord.Embed(
@ -2586,6 +2611,24 @@ class Audio(commands.Cog):
track_obj[key] = value track_obj[key] = value
return track_obj return track_obj
@staticmethod
def _url_check(url):
valid_tld = [
"youtube.com",
"youtu.be",
"soundcloud.com",
"bandcamp.com",
"vimeo.com",
"mixer.com",
"twitch.tv",
"localtracks",
]
query_url = urlparse(url)
url_domain = ".".join(query_url.netloc.split(".")[-2:])
if not query_url.netloc:
url_domain = ".".join(query_url.path.split("/")[0].split(".")[-2:])
return True if url_domain in valid_tld else False
@staticmethod @staticmethod
def _userlimit(channel): def _userlimit(channel):
if channel.user_limit == 0: if channel.user_limit == 0: