Use partial messages in Streams cog to avoid potential leakage (#4742)

* Use partial messages in Streams cog to avoid leakage

* Stop trying to save bot object to Config...

* Put guild id as part of message data

* Fix AttributeError

* Pass bot object to stream classes in commands

* ugh

* Another place we use this class in

* more...
This commit is contained in:
jack1142 2021-04-05 21:39:33 +02:00 committed by GitHub
parent c25095ba2d
commit 67fa735555
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 70 additions and 42 deletions

View File

@ -209,6 +209,7 @@ class Streams(commands.Cog):
await self.maybe_renew_twitch_bearer_token() await self.maybe_renew_twitch_bearer_token()
token = (await self.bot.get_shared_api_tokens("twitch")).get("client_id") token = (await self.bot.get_shared_api_tokens("twitch")).get("client_id")
stream = TwitchStream( stream = TwitchStream(
_bot=self.bot,
name=channel_name, name=channel_name,
token=token, token=token,
bearer=self.ttv_bearer_cache.get("access_token", None), bearer=self.ttv_bearer_cache.get("access_token", None),
@ -224,21 +225,25 @@ class Streams(commands.Cog):
apikey = await self.bot.get_shared_api_tokens("youtube") apikey = await self.bot.get_shared_api_tokens("youtube")
is_name = self.check_name_or_id(channel_id_or_name) is_name = self.check_name_or_id(channel_id_or_name)
if is_name: if is_name:
stream = YoutubeStream(name=channel_id_or_name, token=apikey, config=self.config) stream = YoutubeStream(
_bot=self.bot, name=channel_id_or_name, token=apikey, config=self.config
)
else: else:
stream = YoutubeStream(id=channel_id_or_name, token=apikey, config=self.config) stream = YoutubeStream(
_bot=self.bot, id=channel_id_or_name, token=apikey, config=self.config
)
await self.check_online(ctx, stream) await self.check_online(ctx, stream)
@commands.command() @commands.command()
async def smashcast(self, ctx: commands.Context, channel_name: str): async def smashcast(self, ctx: commands.Context, channel_name: str):
"""Check if a smashcast channel is live.""" """Check if a smashcast channel is live."""
stream = HitboxStream(name=channel_name) stream = HitboxStream(_bot=self.bot, name=channel_name)
await self.check_online(ctx, stream) await self.check_online(ctx, stream)
@commands.command() @commands.command()
async def picarto(self, ctx: commands.Context, channel_name: str): async def picarto(self, ctx: commands.Context, channel_name: str):
"""Check if a Picarto channel is live.""" """Check if a Picarto channel is live."""
stream = PicartoStream(name=channel_name) stream = PicartoStream(_bot=self.bot, name=channel_name)
await self.check_online(ctx, stream) await self.check_online(ctx, stream)
async def check_online( async def check_online(
@ -396,19 +401,22 @@ class Streams(commands.Cog):
is_yt = _class.__name__ == "YoutubeStream" is_yt = _class.__name__ == "YoutubeStream"
is_twitch = _class.__name__ == "TwitchStream" is_twitch = _class.__name__ == "TwitchStream"
if is_yt and not self.check_name_or_id(channel_name): if is_yt and not self.check_name_or_id(channel_name):
stream = _class(id=channel_name, token=token, config=self.config) stream = _class(_bot=self.bot, id=channel_name, token=token, config=self.config)
elif is_twitch: elif is_twitch:
await self.maybe_renew_twitch_bearer_token() await self.maybe_renew_twitch_bearer_token()
stream = _class( stream = _class(
_bot=self.bot,
name=channel_name, name=channel_name,
token=token.get("client_id"), token=token.get("client_id"),
bearer=self.ttv_bearer_cache.get("access_token", None), bearer=self.ttv_bearer_cache.get("access_token", None),
) )
else: else:
if is_yt: if is_yt:
stream = _class(name=channel_name, token=token, config=self.config) stream = _class(
_bot=self.bot, name=channel_name, token=token, config=self.config
)
else: else:
stream = _class(name=channel_name, token=token) stream = _class(_bot=self.bot, name=channel_name, token=token)
try: try:
exists = await self.check_exists(stream) exists = await self.check_exists(stream)
except InvalidTwitchCredentials: except InvalidTwitchCredentials:
@ -714,14 +722,23 @@ class Streams(commands.Cog):
await asyncio.sleep(await self.config.refresh_timer()) await asyncio.sleep(await self.config.refresh_timer())
async def _send_stream_alert( async def _send_stream_alert(
self, stream, channel: discord.TextChannel, embed: discord.Embed, content: str = None self,
stream,
channel: discord.TextChannel,
embed: discord.Embed,
content: str = None,
*,
is_schedule: bool = False,
): ):
m = await channel.send( m = await channel.send(
content, content,
embed=embed, embed=embed,
allowed_mentions=discord.AllowedMentions(roles=True, everyone=True), allowed_mentions=discord.AllowedMentions(roles=True, everyone=True),
) )
stream._messages_cache.append(m) message_data = {"guild": m.guild.id, "channel": m.channel.id, "message": m.id}
if is_schedule:
message_data["is_schedule"] = True
stream.messages.append(message_data)
async def check_streams(self): async def check_streams(self):
to_remove = [] to_remove = []
@ -744,19 +761,25 @@ class Streams(commands.Cog):
to_remove.append(stream) to_remove.append(stream)
continue continue
except OfflineStream: except OfflineStream:
if not stream._messages_cache: if not stream.messages:
continue continue
for message in stream._messages_cache:
if await self.bot.cog_disabled_in_guild(self, message.guild): for msg_data in stream.iter_messages():
partial_msg = msg_data["partial_message"]
if partial_msg is None:
continue continue
autodelete = await self.config.guild(message.guild).autodelete() if await self.bot.cog_disabled_in_guild(self, partial_msg.guild):
if autodelete: continue
with contextlib.suppress(discord.NotFound): if not await self.config.guild(partial_msg.guild).autodelete():
await message.delete() continue
stream._messages_cache.clear()
with contextlib.suppress(discord.NotFound):
await partial_msg.delete()
stream.messages.clear()
await self.save_streams() await self.save_streams()
else: else:
if stream._messages_cache: if stream.messages:
continue continue
for channel_id in stream.channels: for channel_id in stream.channels:
channel = self.bot.get_channel(channel_id) channel = self.bot.get_channel(channel_id)
@ -772,7 +795,7 @@ class Streams(commands.Cog):
continue continue
if is_schedule: if is_schedule:
# skip messages and mentions # skip messages and mentions
await self._send_stream_alert(stream, channel, embed) await self._send_stream_alert(stream, channel, embed, is_schedule=True)
await self.save_streams() await self.save_streams()
continue continue
await set_contextual_locales_from_guild(self.bot, channel.guild) await set_contextual_locales_from_guild(self.bot, channel.guild)
@ -874,17 +897,6 @@ class Streams(commands.Cog):
_class = getattr(_streamtypes, raw_stream["type"], None) _class = getattr(_streamtypes, raw_stream["type"], None)
if not _class: if not _class:
continue continue
raw_msg_cache = raw_stream["messages"]
raw_stream["_messages_cache"] = []
for raw_msg in raw_msg_cache:
chn = self.bot.get_channel(raw_msg["channel"])
if chn is not None:
try:
msg = await chn.fetch_message(raw_msg["message"])
except discord.HTTPException:
pass
else:
raw_stream["_messages_cache"].append(msg)
token = await self.bot.get_shared_api_tokens(_class.token_name) token = await self.bot.get_shared_api_tokens(_class.token_name)
if token: if token:
if _class.__name__ == "TwitchStream": if _class.__name__ == "TwitchStream":
@ -894,6 +906,7 @@ class Streams(commands.Cog):
if _class.__name__ == "YoutubeStream": if _class.__name__ == "YoutubeStream":
raw_stream["config"] = self.config raw_stream["config"] = self.config
raw_stream["token"] = token raw_stream["token"] = token
raw_stream["_bot"] = self.bot
streams.append(_class(**raw_stream)) streams.append(_class(**raw_stream))
return streams return streams

View File

@ -58,10 +58,11 @@ class Stream:
token_name: ClassVar[Optional[str]] = None token_name: ClassVar[Optional[str]] = None
def __init__(self, **kwargs): def __init__(self, **kwargs):
self._bot = kwargs.pop("_bot")
self.name = kwargs.pop("name", None) self.name = kwargs.pop("name", None)
self.channels = kwargs.pop("channels", []) self.channels = kwargs.pop("channels", [])
# self.already_online = kwargs.pop("already_online", False) # self.already_online = kwargs.pop("already_online", False)
self._messages_cache = kwargs.pop("_messages_cache", []) self.messages = kwargs.pop("messages", [])
self.type = self.__class__.__name__ self.type = self.__class__.__name__
async def is_online(self): async def is_online(self):
@ -70,14 +71,24 @@ class Stream:
def make_embed(self): def make_embed(self):
raise NotImplementedError() raise NotImplementedError()
def iter_messages(self):
for msg_data in self.messages:
data = msg_data.copy()
# "guild" key might not exist for old config data (available since GH-4742)
if guild_id := msg_data.get("guild"):
guild = self._bot.get_guild(guild_id)
channel = guild and guild.get_channel(msg_data["channel"])
else:
channel = self._bot.get_channel(msg_data["channel"])
if channel is not None:
data["partial_message"] = channel.get_partial_message(data["message"])
yield data
def export(self): def export(self):
data = {} data = {}
for k, v in self.__dict__.items(): for k, v in self.__dict__.items():
if not k.startswith("_"): if not k.startswith("_"):
data[k] = v data[k] = v
data["messages"] = []
for m in self._messages_cache:
data["messages"].append({"channel": m.channel.id, "message": m.id})
return data return data
def __repr__(self): def __repr__(self):
@ -211,17 +222,21 @@ class YoutubeStream(Stream):
embed.timestamp = start_time embed.timestamp = start_time
is_schedule = True is_schedule = True
else: else:
# repost message # delete the message(s) about the stream schedule
to_remove = [] to_remove = []
for message in self._messages_cache: for msg_data in self.iter_messages():
if message.embeds[0].description is discord.Embed.Empty: if not msg_data.get("is_schedule", False):
continue continue
with contextlib.suppress(Exception): partial_msg = msg_data["partial_message"]
autodelete = await self._config.guild(message.guild).autodelete() if partial_msg is not None:
autodelete = await self._config.guild(partial_msg.guild).autodelete()
if autodelete: if autodelete:
await message.delete() with contextlib.suppress(discord.NotFound):
to_remove.append(message.id) await partial_msg.delete()
self._messages_cache = [x for x in self._messages_cache if x.id not in to_remove] to_remove.append(msg_data["message"])
self.messages = [
data for data in self.messages if data["message"] not in to_remove
]
embed.set_author(name=channel_title) embed.set_author(name=channel_title)
embed.set_image(url=rnd(thumbnail)) embed.set_image(url=rnd(thumbnail))
embed.colour = 0x9255A5 embed.colour = 0x9255A5