mirror of
https://github.com/Cog-Creators/Red-DiscordBot.git
synced 2025-11-07 11:48:55 -05:00
Use partial messages in Streams cog to avoid potential leakage (#4742)
* Use partial messages in Streams cog to avoid leakage * Stop trying to save bot object to Config... * Put guild id as part of message data * Fix AttributeError * Pass bot object to stream classes in commands * ugh * Another place we use this class in * more...
This commit is contained in:
parent
c25095ba2d
commit
67fa735555
@ -209,6 +209,7 @@ class Streams(commands.Cog):
|
||||
await self.maybe_renew_twitch_bearer_token()
|
||||
token = (await self.bot.get_shared_api_tokens("twitch")).get("client_id")
|
||||
stream = TwitchStream(
|
||||
_bot=self.bot,
|
||||
name=channel_name,
|
||||
token=token,
|
||||
bearer=self.ttv_bearer_cache.get("access_token", None),
|
||||
@ -224,21 +225,25 @@ class Streams(commands.Cog):
|
||||
apikey = await self.bot.get_shared_api_tokens("youtube")
|
||||
is_name = self.check_name_or_id(channel_id_or_name)
|
||||
if is_name:
|
||||
stream = YoutubeStream(name=channel_id_or_name, token=apikey, config=self.config)
|
||||
stream = YoutubeStream(
|
||||
_bot=self.bot, name=channel_id_or_name, token=apikey, config=self.config
|
||||
)
|
||||
else:
|
||||
stream = YoutubeStream(id=channel_id_or_name, token=apikey, config=self.config)
|
||||
stream = YoutubeStream(
|
||||
_bot=self.bot, id=channel_id_or_name, token=apikey, config=self.config
|
||||
)
|
||||
await self.check_online(ctx, stream)
|
||||
|
||||
@commands.command()
|
||||
async def smashcast(self, ctx: commands.Context, channel_name: str):
|
||||
"""Check if a smashcast channel is live."""
|
||||
stream = HitboxStream(name=channel_name)
|
||||
stream = HitboxStream(_bot=self.bot, name=channel_name)
|
||||
await self.check_online(ctx, stream)
|
||||
|
||||
@commands.command()
|
||||
async def picarto(self, ctx: commands.Context, channel_name: str):
|
||||
"""Check if a Picarto channel is live."""
|
||||
stream = PicartoStream(name=channel_name)
|
||||
stream = PicartoStream(_bot=self.bot, name=channel_name)
|
||||
await self.check_online(ctx, stream)
|
||||
|
||||
async def check_online(
|
||||
@ -396,19 +401,22 @@ class Streams(commands.Cog):
|
||||
is_yt = _class.__name__ == "YoutubeStream"
|
||||
is_twitch = _class.__name__ == "TwitchStream"
|
||||
if is_yt and not self.check_name_or_id(channel_name):
|
||||
stream = _class(id=channel_name, token=token, config=self.config)
|
||||
stream = _class(_bot=self.bot, id=channel_name, token=token, config=self.config)
|
||||
elif is_twitch:
|
||||
await self.maybe_renew_twitch_bearer_token()
|
||||
stream = _class(
|
||||
_bot=self.bot,
|
||||
name=channel_name,
|
||||
token=token.get("client_id"),
|
||||
bearer=self.ttv_bearer_cache.get("access_token", None),
|
||||
)
|
||||
else:
|
||||
if is_yt:
|
||||
stream = _class(name=channel_name, token=token, config=self.config)
|
||||
stream = _class(
|
||||
_bot=self.bot, name=channel_name, token=token, config=self.config
|
||||
)
|
||||
else:
|
||||
stream = _class(name=channel_name, token=token)
|
||||
stream = _class(_bot=self.bot, name=channel_name, token=token)
|
||||
try:
|
||||
exists = await self.check_exists(stream)
|
||||
except InvalidTwitchCredentials:
|
||||
@ -714,14 +722,23 @@ class Streams(commands.Cog):
|
||||
await asyncio.sleep(await self.config.refresh_timer())
|
||||
|
||||
async def _send_stream_alert(
|
||||
self, stream, channel: discord.TextChannel, embed: discord.Embed, content: str = None
|
||||
self,
|
||||
stream,
|
||||
channel: discord.TextChannel,
|
||||
embed: discord.Embed,
|
||||
content: str = None,
|
||||
*,
|
||||
is_schedule: bool = False,
|
||||
):
|
||||
m = await channel.send(
|
||||
content,
|
||||
embed=embed,
|
||||
allowed_mentions=discord.AllowedMentions(roles=True, everyone=True),
|
||||
)
|
||||
stream._messages_cache.append(m)
|
||||
message_data = {"guild": m.guild.id, "channel": m.channel.id, "message": m.id}
|
||||
if is_schedule:
|
||||
message_data["is_schedule"] = True
|
||||
stream.messages.append(message_data)
|
||||
|
||||
async def check_streams(self):
|
||||
to_remove = []
|
||||
@ -744,19 +761,25 @@ class Streams(commands.Cog):
|
||||
to_remove.append(stream)
|
||||
continue
|
||||
except OfflineStream:
|
||||
if not stream._messages_cache:
|
||||
if not stream.messages:
|
||||
continue
|
||||
for message in stream._messages_cache:
|
||||
if await self.bot.cog_disabled_in_guild(self, message.guild):
|
||||
|
||||
for msg_data in stream.iter_messages():
|
||||
partial_msg = msg_data["partial_message"]
|
||||
if partial_msg is None:
|
||||
continue
|
||||
autodelete = await self.config.guild(message.guild).autodelete()
|
||||
if autodelete:
|
||||
if await self.bot.cog_disabled_in_guild(self, partial_msg.guild):
|
||||
continue
|
||||
if not await self.config.guild(partial_msg.guild).autodelete():
|
||||
continue
|
||||
|
||||
with contextlib.suppress(discord.NotFound):
|
||||
await message.delete()
|
||||
stream._messages_cache.clear()
|
||||
await partial_msg.delete()
|
||||
|
||||
stream.messages.clear()
|
||||
await self.save_streams()
|
||||
else:
|
||||
if stream._messages_cache:
|
||||
if stream.messages:
|
||||
continue
|
||||
for channel_id in stream.channels:
|
||||
channel = self.bot.get_channel(channel_id)
|
||||
@ -772,7 +795,7 @@ class Streams(commands.Cog):
|
||||
continue
|
||||
if is_schedule:
|
||||
# skip messages and mentions
|
||||
await self._send_stream_alert(stream, channel, embed)
|
||||
await self._send_stream_alert(stream, channel, embed, is_schedule=True)
|
||||
await self.save_streams()
|
||||
continue
|
||||
await set_contextual_locales_from_guild(self.bot, channel.guild)
|
||||
@ -874,17 +897,6 @@ class Streams(commands.Cog):
|
||||
_class = getattr(_streamtypes, raw_stream["type"], None)
|
||||
if not _class:
|
||||
continue
|
||||
raw_msg_cache = raw_stream["messages"]
|
||||
raw_stream["_messages_cache"] = []
|
||||
for raw_msg in raw_msg_cache:
|
||||
chn = self.bot.get_channel(raw_msg["channel"])
|
||||
if chn is not None:
|
||||
try:
|
||||
msg = await chn.fetch_message(raw_msg["message"])
|
||||
except discord.HTTPException:
|
||||
pass
|
||||
else:
|
||||
raw_stream["_messages_cache"].append(msg)
|
||||
token = await self.bot.get_shared_api_tokens(_class.token_name)
|
||||
if token:
|
||||
if _class.__name__ == "TwitchStream":
|
||||
@ -894,6 +906,7 @@ class Streams(commands.Cog):
|
||||
if _class.__name__ == "YoutubeStream":
|
||||
raw_stream["config"] = self.config
|
||||
raw_stream["token"] = token
|
||||
raw_stream["_bot"] = self.bot
|
||||
streams.append(_class(**raw_stream))
|
||||
|
||||
return streams
|
||||
|
||||
@ -58,10 +58,11 @@ class Stream:
|
||||
token_name: ClassVar[Optional[str]] = None
|
||||
|
||||
def __init__(self, **kwargs):
|
||||
self._bot = kwargs.pop("_bot")
|
||||
self.name = kwargs.pop("name", None)
|
||||
self.channels = kwargs.pop("channels", [])
|
||||
# self.already_online = kwargs.pop("already_online", False)
|
||||
self._messages_cache = kwargs.pop("_messages_cache", [])
|
||||
self.messages = kwargs.pop("messages", [])
|
||||
self.type = self.__class__.__name__
|
||||
|
||||
async def is_online(self):
|
||||
@ -70,14 +71,24 @@ class Stream:
|
||||
def make_embed(self):
|
||||
raise NotImplementedError()
|
||||
|
||||
def iter_messages(self):
|
||||
for msg_data in self.messages:
|
||||
data = msg_data.copy()
|
||||
# "guild" key might not exist for old config data (available since GH-4742)
|
||||
if guild_id := msg_data.get("guild"):
|
||||
guild = self._bot.get_guild(guild_id)
|
||||
channel = guild and guild.get_channel(msg_data["channel"])
|
||||
else:
|
||||
channel = self._bot.get_channel(msg_data["channel"])
|
||||
if channel is not None:
|
||||
data["partial_message"] = channel.get_partial_message(data["message"])
|
||||
yield data
|
||||
|
||||
def export(self):
|
||||
data = {}
|
||||
for k, v in self.__dict__.items():
|
||||
if not k.startswith("_"):
|
||||
data[k] = v
|
||||
data["messages"] = []
|
||||
for m in self._messages_cache:
|
||||
data["messages"].append({"channel": m.channel.id, "message": m.id})
|
||||
return data
|
||||
|
||||
def __repr__(self):
|
||||
@ -211,17 +222,21 @@ class YoutubeStream(Stream):
|
||||
embed.timestamp = start_time
|
||||
is_schedule = True
|
||||
else:
|
||||
# repost message
|
||||
# delete the message(s) about the stream schedule
|
||||
to_remove = []
|
||||
for message in self._messages_cache:
|
||||
if message.embeds[0].description is discord.Embed.Empty:
|
||||
for msg_data in self.iter_messages():
|
||||
if not msg_data.get("is_schedule", False):
|
||||
continue
|
||||
with contextlib.suppress(Exception):
|
||||
autodelete = await self._config.guild(message.guild).autodelete()
|
||||
partial_msg = msg_data["partial_message"]
|
||||
if partial_msg is not None:
|
||||
autodelete = await self._config.guild(partial_msg.guild).autodelete()
|
||||
if autodelete:
|
||||
await message.delete()
|
||||
to_remove.append(message.id)
|
||||
self._messages_cache = [x for x in self._messages_cache if x.id not in to_remove]
|
||||
with contextlib.suppress(discord.NotFound):
|
||||
await partial_msg.delete()
|
||||
to_remove.append(msg_data["message"])
|
||||
self.messages = [
|
||||
data for data in self.messages if data["message"] not in to_remove
|
||||
]
|
||||
embed.set_author(name=channel_title)
|
||||
embed.set_image(url=rnd(thumbnail))
|
||||
embed.colour = 0x9255A5
|
||||
|
||||
Loading…
x
Reference in New Issue
Block a user