mirror of
https://github.com/Cog-Creators/Red-DiscordBot.git
synced 2025-11-07 11:48:55 -05:00
Use partial messages in Streams cog to avoid potential leakage (#4742)
* Use partial messages in Streams cog to avoid leakage * Stop trying to save bot object to Config... * Put guild id as part of message data * Fix AttributeError * Pass bot object to stream classes in commands * ugh * Another place we use this class in * more...
This commit is contained in:
parent
c25095ba2d
commit
67fa735555
@ -209,6 +209,7 @@ class Streams(commands.Cog):
|
|||||||
await self.maybe_renew_twitch_bearer_token()
|
await self.maybe_renew_twitch_bearer_token()
|
||||||
token = (await self.bot.get_shared_api_tokens("twitch")).get("client_id")
|
token = (await self.bot.get_shared_api_tokens("twitch")).get("client_id")
|
||||||
stream = TwitchStream(
|
stream = TwitchStream(
|
||||||
|
_bot=self.bot,
|
||||||
name=channel_name,
|
name=channel_name,
|
||||||
token=token,
|
token=token,
|
||||||
bearer=self.ttv_bearer_cache.get("access_token", None),
|
bearer=self.ttv_bearer_cache.get("access_token", None),
|
||||||
@ -224,21 +225,25 @@ class Streams(commands.Cog):
|
|||||||
apikey = await self.bot.get_shared_api_tokens("youtube")
|
apikey = await self.bot.get_shared_api_tokens("youtube")
|
||||||
is_name = self.check_name_or_id(channel_id_or_name)
|
is_name = self.check_name_or_id(channel_id_or_name)
|
||||||
if is_name:
|
if is_name:
|
||||||
stream = YoutubeStream(name=channel_id_or_name, token=apikey, config=self.config)
|
stream = YoutubeStream(
|
||||||
|
_bot=self.bot, name=channel_id_or_name, token=apikey, config=self.config
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
stream = YoutubeStream(id=channel_id_or_name, token=apikey, config=self.config)
|
stream = YoutubeStream(
|
||||||
|
_bot=self.bot, id=channel_id_or_name, token=apikey, config=self.config
|
||||||
|
)
|
||||||
await self.check_online(ctx, stream)
|
await self.check_online(ctx, stream)
|
||||||
|
|
||||||
@commands.command()
|
@commands.command()
|
||||||
async def smashcast(self, ctx: commands.Context, channel_name: str):
|
async def smashcast(self, ctx: commands.Context, channel_name: str):
|
||||||
"""Check if a smashcast channel is live."""
|
"""Check if a smashcast channel is live."""
|
||||||
stream = HitboxStream(name=channel_name)
|
stream = HitboxStream(_bot=self.bot, name=channel_name)
|
||||||
await self.check_online(ctx, stream)
|
await self.check_online(ctx, stream)
|
||||||
|
|
||||||
@commands.command()
|
@commands.command()
|
||||||
async def picarto(self, ctx: commands.Context, channel_name: str):
|
async def picarto(self, ctx: commands.Context, channel_name: str):
|
||||||
"""Check if a Picarto channel is live."""
|
"""Check if a Picarto channel is live."""
|
||||||
stream = PicartoStream(name=channel_name)
|
stream = PicartoStream(_bot=self.bot, name=channel_name)
|
||||||
await self.check_online(ctx, stream)
|
await self.check_online(ctx, stream)
|
||||||
|
|
||||||
async def check_online(
|
async def check_online(
|
||||||
@ -396,19 +401,22 @@ class Streams(commands.Cog):
|
|||||||
is_yt = _class.__name__ == "YoutubeStream"
|
is_yt = _class.__name__ == "YoutubeStream"
|
||||||
is_twitch = _class.__name__ == "TwitchStream"
|
is_twitch = _class.__name__ == "TwitchStream"
|
||||||
if is_yt and not self.check_name_or_id(channel_name):
|
if is_yt and not self.check_name_or_id(channel_name):
|
||||||
stream = _class(id=channel_name, token=token, config=self.config)
|
stream = _class(_bot=self.bot, id=channel_name, token=token, config=self.config)
|
||||||
elif is_twitch:
|
elif is_twitch:
|
||||||
await self.maybe_renew_twitch_bearer_token()
|
await self.maybe_renew_twitch_bearer_token()
|
||||||
stream = _class(
|
stream = _class(
|
||||||
|
_bot=self.bot,
|
||||||
name=channel_name,
|
name=channel_name,
|
||||||
token=token.get("client_id"),
|
token=token.get("client_id"),
|
||||||
bearer=self.ttv_bearer_cache.get("access_token", None),
|
bearer=self.ttv_bearer_cache.get("access_token", None),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
if is_yt:
|
if is_yt:
|
||||||
stream = _class(name=channel_name, token=token, config=self.config)
|
stream = _class(
|
||||||
|
_bot=self.bot, name=channel_name, token=token, config=self.config
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
stream = _class(name=channel_name, token=token)
|
stream = _class(_bot=self.bot, name=channel_name, token=token)
|
||||||
try:
|
try:
|
||||||
exists = await self.check_exists(stream)
|
exists = await self.check_exists(stream)
|
||||||
except InvalidTwitchCredentials:
|
except InvalidTwitchCredentials:
|
||||||
@ -714,14 +722,23 @@ class Streams(commands.Cog):
|
|||||||
await asyncio.sleep(await self.config.refresh_timer())
|
await asyncio.sleep(await self.config.refresh_timer())
|
||||||
|
|
||||||
async def _send_stream_alert(
|
async def _send_stream_alert(
|
||||||
self, stream, channel: discord.TextChannel, embed: discord.Embed, content: str = None
|
self,
|
||||||
|
stream,
|
||||||
|
channel: discord.TextChannel,
|
||||||
|
embed: discord.Embed,
|
||||||
|
content: str = None,
|
||||||
|
*,
|
||||||
|
is_schedule: bool = False,
|
||||||
):
|
):
|
||||||
m = await channel.send(
|
m = await channel.send(
|
||||||
content,
|
content,
|
||||||
embed=embed,
|
embed=embed,
|
||||||
allowed_mentions=discord.AllowedMentions(roles=True, everyone=True),
|
allowed_mentions=discord.AllowedMentions(roles=True, everyone=True),
|
||||||
)
|
)
|
||||||
stream._messages_cache.append(m)
|
message_data = {"guild": m.guild.id, "channel": m.channel.id, "message": m.id}
|
||||||
|
if is_schedule:
|
||||||
|
message_data["is_schedule"] = True
|
||||||
|
stream.messages.append(message_data)
|
||||||
|
|
||||||
async def check_streams(self):
|
async def check_streams(self):
|
||||||
to_remove = []
|
to_remove = []
|
||||||
@ -744,19 +761,25 @@ class Streams(commands.Cog):
|
|||||||
to_remove.append(stream)
|
to_remove.append(stream)
|
||||||
continue
|
continue
|
||||||
except OfflineStream:
|
except OfflineStream:
|
||||||
if not stream._messages_cache:
|
if not stream.messages:
|
||||||
continue
|
continue
|
||||||
for message in stream._messages_cache:
|
|
||||||
if await self.bot.cog_disabled_in_guild(self, message.guild):
|
for msg_data in stream.iter_messages():
|
||||||
|
partial_msg = msg_data["partial_message"]
|
||||||
|
if partial_msg is None:
|
||||||
continue
|
continue
|
||||||
autodelete = await self.config.guild(message.guild).autodelete()
|
if await self.bot.cog_disabled_in_guild(self, partial_msg.guild):
|
||||||
if autodelete:
|
continue
|
||||||
with contextlib.suppress(discord.NotFound):
|
if not await self.config.guild(partial_msg.guild).autodelete():
|
||||||
await message.delete()
|
continue
|
||||||
stream._messages_cache.clear()
|
|
||||||
|
with contextlib.suppress(discord.NotFound):
|
||||||
|
await partial_msg.delete()
|
||||||
|
|
||||||
|
stream.messages.clear()
|
||||||
await self.save_streams()
|
await self.save_streams()
|
||||||
else:
|
else:
|
||||||
if stream._messages_cache:
|
if stream.messages:
|
||||||
continue
|
continue
|
||||||
for channel_id in stream.channels:
|
for channel_id in stream.channels:
|
||||||
channel = self.bot.get_channel(channel_id)
|
channel = self.bot.get_channel(channel_id)
|
||||||
@ -772,7 +795,7 @@ class Streams(commands.Cog):
|
|||||||
continue
|
continue
|
||||||
if is_schedule:
|
if is_schedule:
|
||||||
# skip messages and mentions
|
# skip messages and mentions
|
||||||
await self._send_stream_alert(stream, channel, embed)
|
await self._send_stream_alert(stream, channel, embed, is_schedule=True)
|
||||||
await self.save_streams()
|
await self.save_streams()
|
||||||
continue
|
continue
|
||||||
await set_contextual_locales_from_guild(self.bot, channel.guild)
|
await set_contextual_locales_from_guild(self.bot, channel.guild)
|
||||||
@ -874,17 +897,6 @@ class Streams(commands.Cog):
|
|||||||
_class = getattr(_streamtypes, raw_stream["type"], None)
|
_class = getattr(_streamtypes, raw_stream["type"], None)
|
||||||
if not _class:
|
if not _class:
|
||||||
continue
|
continue
|
||||||
raw_msg_cache = raw_stream["messages"]
|
|
||||||
raw_stream["_messages_cache"] = []
|
|
||||||
for raw_msg in raw_msg_cache:
|
|
||||||
chn = self.bot.get_channel(raw_msg["channel"])
|
|
||||||
if chn is not None:
|
|
||||||
try:
|
|
||||||
msg = await chn.fetch_message(raw_msg["message"])
|
|
||||||
except discord.HTTPException:
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
raw_stream["_messages_cache"].append(msg)
|
|
||||||
token = await self.bot.get_shared_api_tokens(_class.token_name)
|
token = await self.bot.get_shared_api_tokens(_class.token_name)
|
||||||
if token:
|
if token:
|
||||||
if _class.__name__ == "TwitchStream":
|
if _class.__name__ == "TwitchStream":
|
||||||
@ -894,6 +906,7 @@ class Streams(commands.Cog):
|
|||||||
if _class.__name__ == "YoutubeStream":
|
if _class.__name__ == "YoutubeStream":
|
||||||
raw_stream["config"] = self.config
|
raw_stream["config"] = self.config
|
||||||
raw_stream["token"] = token
|
raw_stream["token"] = token
|
||||||
|
raw_stream["_bot"] = self.bot
|
||||||
streams.append(_class(**raw_stream))
|
streams.append(_class(**raw_stream))
|
||||||
|
|
||||||
return streams
|
return streams
|
||||||
|
|||||||
@ -58,10 +58,11 @@ class Stream:
|
|||||||
token_name: ClassVar[Optional[str]] = None
|
token_name: ClassVar[Optional[str]] = None
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
|
self._bot = kwargs.pop("_bot")
|
||||||
self.name = kwargs.pop("name", None)
|
self.name = kwargs.pop("name", None)
|
||||||
self.channels = kwargs.pop("channels", [])
|
self.channels = kwargs.pop("channels", [])
|
||||||
# self.already_online = kwargs.pop("already_online", False)
|
# self.already_online = kwargs.pop("already_online", False)
|
||||||
self._messages_cache = kwargs.pop("_messages_cache", [])
|
self.messages = kwargs.pop("messages", [])
|
||||||
self.type = self.__class__.__name__
|
self.type = self.__class__.__name__
|
||||||
|
|
||||||
async def is_online(self):
|
async def is_online(self):
|
||||||
@ -70,14 +71,24 @@ class Stream:
|
|||||||
def make_embed(self):
|
def make_embed(self):
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
def iter_messages(self):
|
||||||
|
for msg_data in self.messages:
|
||||||
|
data = msg_data.copy()
|
||||||
|
# "guild" key might not exist for old config data (available since GH-4742)
|
||||||
|
if guild_id := msg_data.get("guild"):
|
||||||
|
guild = self._bot.get_guild(guild_id)
|
||||||
|
channel = guild and guild.get_channel(msg_data["channel"])
|
||||||
|
else:
|
||||||
|
channel = self._bot.get_channel(msg_data["channel"])
|
||||||
|
if channel is not None:
|
||||||
|
data["partial_message"] = channel.get_partial_message(data["message"])
|
||||||
|
yield data
|
||||||
|
|
||||||
def export(self):
|
def export(self):
|
||||||
data = {}
|
data = {}
|
||||||
for k, v in self.__dict__.items():
|
for k, v in self.__dict__.items():
|
||||||
if not k.startswith("_"):
|
if not k.startswith("_"):
|
||||||
data[k] = v
|
data[k] = v
|
||||||
data["messages"] = []
|
|
||||||
for m in self._messages_cache:
|
|
||||||
data["messages"].append({"channel": m.channel.id, "message": m.id})
|
|
||||||
return data
|
return data
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
@ -211,17 +222,21 @@ class YoutubeStream(Stream):
|
|||||||
embed.timestamp = start_time
|
embed.timestamp = start_time
|
||||||
is_schedule = True
|
is_schedule = True
|
||||||
else:
|
else:
|
||||||
# repost message
|
# delete the message(s) about the stream schedule
|
||||||
to_remove = []
|
to_remove = []
|
||||||
for message in self._messages_cache:
|
for msg_data in self.iter_messages():
|
||||||
if message.embeds[0].description is discord.Embed.Empty:
|
if not msg_data.get("is_schedule", False):
|
||||||
continue
|
continue
|
||||||
with contextlib.suppress(Exception):
|
partial_msg = msg_data["partial_message"]
|
||||||
autodelete = await self._config.guild(message.guild).autodelete()
|
if partial_msg is not None:
|
||||||
|
autodelete = await self._config.guild(partial_msg.guild).autodelete()
|
||||||
if autodelete:
|
if autodelete:
|
||||||
await message.delete()
|
with contextlib.suppress(discord.NotFound):
|
||||||
to_remove.append(message.id)
|
await partial_msg.delete()
|
||||||
self._messages_cache = [x for x in self._messages_cache if x.id not in to_remove]
|
to_remove.append(msg_data["message"])
|
||||||
|
self.messages = [
|
||||||
|
data for data in self.messages if data["message"] not in to_remove
|
||||||
|
]
|
||||||
embed.set_author(name=channel_title)
|
embed.set_author(name=channel_title)
|
||||||
embed.set_image(url=rnd(thumbnail))
|
embed.set_image(url=rnd(thumbnail))
|
||||||
embed.colour = 0x9255A5
|
embed.colour = 0x9255A5
|
||||||
|
|||||||
Loading…
x
Reference in New Issue
Block a user