From 70ca8ff1f4c81852647aad924074991898398595 Mon Sep 17 00:00:00 2001 From: Jakub Kuczys Date: Fri, 12 May 2023 00:27:19 +0200 Subject: [PATCH] Add snippet numbers to filenames in the Dev cog to fix exception formatting (#6135) --- redbot/core/dev_commands.py | 172 +++++++++++++++++++++++--------- tests/core/test_dev_commands.py | 75 +++++++++++--- 2 files changed, 184 insertions(+), 63 deletions(-) diff --git a/redbot/core/dev_commands.py b/redbot/core/dev_commands.py index 4e32feede..5033dcd63 100644 --- a/redbot/core/dev_commands.py +++ b/redbot/core/dev_commands.py @@ -23,7 +23,7 @@ import types import re import sys from copy import copy -from typing import Any, Awaitable, Dict, Iterator, Literal, Type, TypeVar, Union +from typing import Any, Awaitable, Dict, Iterator, List, Literal, Tuple, Type, TypeVar, Union from types import CodeType, TracebackType import discord @@ -83,13 +83,55 @@ def cleanup_code(content: str) -> str: return content.strip("` \n") +class SourceCache: + MAX_SIZE = 1000 + + def __init__(self) -> None: + # estimated to take less than 100 kB + self._data: Dict[str, Tuple[str, int]] = {} + # this just keeps going up until the bot is restarted, shouldn't really be an issue + self._next_index = 0 + + def take_next_index(self) -> int: + next_index = self._next_index + self._next_index += 1 + return next_index + + def __getitem__(self, key: str) -> Tuple[List[str], int]: + value = self._data.pop(key) # pop to put it at the end as most recent + self._data[key] = value + # To mimic linecache module's behavior, + # all lines (including the last one) should end with \n. + source_lines = [f"{line}\n" for line in value[0].splitlines()] + # Note: while it might seem like a waste of time to always calculate the list of source lines, + # this is a necessary memory optimization. If all of the data in `self._data` were list, + # it could theoretically take up to 1000x as much memory. + return source_lines, value[1] + + def __setitem__(self, key: str, value: Tuple[str, int]) -> None: + self._data.pop(key, None) + self._data[key] = value + if len(self._data) > self.MAX_SIZE: + del self._data[next(iter(self._data))] + + class DevOutput: def __init__( - self, ctx: commands.Context, *, source: str, filename: str, env: Dict[str, Any] + self, + ctx: commands.Context, + *, + source_cache: SourceCache, + filename: str, + source: str, + env: Dict[str, Any], ) -> None: self.ctx = ctx - self.source = source + self.source_cache = source_cache self.filename = filename + self.source_line_offset = 0 + #: raw source - as received from the command after stripping the code block + self.raw_source = source + self.set_compilable_source(source) self.env = env self.always_include_result = False self._stream = io.StringIO() @@ -98,12 +140,14 @@ class DevOutput: self._old_streams = [] @property - def source(self) -> str: - return self._original_source + def compilable_source(self) -> str: + """Source string that we pass to async_compile().""" + return self._compilable_source - @source.setter - def source(self, value: str) -> None: - self._source = self._original_source = value + def set_compilable_source(self, compilable_source: str, *, line_offset: int = 0) -> None: + self._compilable_source = compilable_source + self.source_line_offset = line_offset + self.source_cache[self.filename] = (compilable_source, line_offset) def __str__(self) -> str: output = [] @@ -124,10 +168,8 @@ class DevOutput: if tick and not self.formatted_exc: await self.ctx.tick() - def set_exception(self, exc: Exception, *, line_offset: int = 0, skip_frames: int = 1) -> None: - self.formatted_exc = self.format_exception( - exc, line_offset=line_offset, skip_frames=skip_frames - ) + def set_exception(self, exc: Exception, *, skip_frames: int = 1) -> None: + self.formatted_exc = self.format_exception(exc, skip_frames=skip_frames) def __enter__(self) -> None: self._old_streams.append(sys.stdout) @@ -144,31 +186,49 @@ class DevOutput: @classmethod async def from_debug( - cls, ctx: commands.Context, *, source: str, env: Dict[str, Any] + cls, ctx: commands.Context, *, source: str, source_cache: SourceCache, env: Dict[str, Any] ) -> DevOutput: - output = cls(ctx, source=source, filename="", env=env) + output = cls( + ctx, + source=source, + source_cache=source_cache, + filename=f"", + env=env, + ) await output.run_debug() return output @classmethod async def from_eval( - cls, ctx: commands.Context, *, source: str, env: Dict[str, Any] + cls, ctx: commands.Context, *, source: str, source_cache: SourceCache, env: Dict[str, Any] ) -> DevOutput: - output = cls(ctx, source=source, filename="", env=env) + output = cls( + ctx, + source=source, + source_cache=source_cache, + filename=f"", + env=env, + ) await output.run_eval() return output @classmethod async def from_repl( - cls, ctx: commands.Context, *, source: str, env: Dict[str, Any] + cls, ctx: commands.Context, *, source: str, source_cache: SourceCache, env: Dict[str, Any] ) -> DevOutput: - output = cls(ctx, source=source, filename="", env=env) + output = cls( + ctx, + source=source, + source_cache=source_cache, + filename=f"", + env=env, + ) await output.run_repl() return output async def run_debug(self) -> None: self.always_include_result = True - self._source = self.source + self.set_compilable_source(self.raw_source) try: compiled = self.async_compile_with_eval() except SyntaxError as exc: @@ -182,12 +242,14 @@ class DevOutput: async def run_eval(self) -> None: self.always_include_result = False - self._source = "async def func():\n%s" % textwrap.indent(self.source, " ") + self.set_compilable_source( + "async def func():\n%s" % textwrap.indent(self.raw_source, " "), line_offset=1 + ) try: compiled = self.async_compile_with_exec() exec(compiled, self.env) except SyntaxError as exc: - self.set_exception(exc, line_offset=1, skip_frames=3) + self.set_exception(exc, skip_frames=3) return func = self.env["func"] @@ -195,13 +257,13 @@ class DevOutput: with self: self.result = await func() except Exception as exc: - self.set_exception(exc, line_offset=1) + self.set_exception(exc) async def run_repl(self) -> None: self.always_include_result = False - self._source = self.source + self.set_compilable_source(self.raw_source) executor = None - if self.source.count("\n") == 0: + if self.raw_source.count("\n") == 0: # single statement, potentially 'eval' try: code = self.async_compile_with_eval() @@ -231,14 +293,12 @@ class DevOutput: self.env["_"] = self.result def async_compile_with_exec(self) -> CodeType: - return async_compile(self._source, self.filename, "exec") + return async_compile(self.compilable_source, self.filename, "exec") def async_compile_with_eval(self) -> CodeType: - return async_compile(self._source, self.filename, "eval") + return async_compile(self.compilable_source, self.filename, "eval") - def format_exception( - self, exc: Exception, *, line_offset: int = 0, skip_frames: int = 1 - ) -> str: + def format_exception(self, exc: Exception, *, skip_frames: int = 1) -> str: """ Format an exception to send to the user. @@ -260,33 +320,44 @@ class DevOutput: break tb = tb.tb_next - # To mimic linecache module's behavior, - # all lines (including the last one) should end with \n. - source_lines = [f"{line}\n" for line in self._source.splitlines()] filename = self.filename # sometimes SyntaxError.text is None, sometimes it isn't - if ( - issubclass(exc_type, SyntaxError) - and exc.filename == filename - and exc.lineno is not None - ): - if exc.text is None: - # line numbers are 1-based, the list indexes are 0-based - exc.text = source_lines[exc.lineno - 1] - exc.lineno -= line_offset + if issubclass(exc_type, SyntaxError) and exc.lineno is not None: + try: + source_lines, line_offset = self.source_cache[exc.filename] + except KeyError: + pass + else: + if exc.text is None: + try: + # line numbers are 1-based, the list indexes are 0-based + exc.text = source_lines[exc.lineno - 1] + except IndexError: + # the frame might be pointing at a different source code, ignore... + pass + else: + exc.lineno -= line_offset + else: + exc.lineno -= line_offset traceback_exc = traceback.TracebackException(exc_type, exc, tb) py311_or_above = sys.version_info >= (3, 11) stack_summary = traceback_exc.stack for idx, frame_summary in enumerate(stack_summary): - if frame_summary.filename != filename: + try: + source_lines, line_offset = self.source_cache[frame_summary.filename] + except KeyError: continue lineno = frame_summary.lineno if lineno is None: continue - # line numbers are 1-based, the list indexes are 0-based - line = source_lines[lineno - 1] + try: + # line numbers are 1-based, the list indexes are 0-based + line = source_lines[lineno - 1] + except IndexError: + # the frame might be pointing at a different source code, ignore... + continue lineno -= line_offset # support for enhanced error locations in tracebacks if py311_or_above: @@ -327,6 +398,7 @@ class Dev(commands.Cog): self._last_result = None self.sessions = {} self.env_extensions = {} + self.source_cache = SourceCache() def get_environment(self, ctx: commands.Context) -> dict: env = { @@ -382,7 +454,9 @@ class Dev(commands.Cog): env = self.get_environment(ctx) source = cleanup_code(code) - output = await DevOutput.from_debug(ctx, source=source, env=env) + output = await DevOutput.from_debug( + ctx, source=source, source_cache=self.source_cache, env=env + ) self._last_result = output.result await output.send() @@ -415,7 +489,9 @@ class Dev(commands.Cog): env = self.get_environment(ctx) source = cleanup_code(body) - output = await DevOutput.from_eval(ctx, source=source, env=env) + output = await DevOutput.from_eval( + ctx, source=source, source_cache=self.source_cache, env=env + ) if output.result is not None: self._last_result = output.result await output.send() @@ -483,7 +559,9 @@ class Dev(commands.Cog): del self.sessions[ctx.channel.id] return - output = await DevOutput.from_repl(ctx, source=source, env=env) + output = await DevOutput.from_repl( + ctx, source=source, source_cache=self.source_cache, env=env + ) try: await output.send(tick=False) except discord.Forbidden: diff --git a/tests/core/test_dev_commands.py b/tests/core/test_dev_commands.py index 508c7e569..1cf82825c 100644 --- a/tests/core/test_dev_commands.py +++ b/tests/core/test_dev_commands.py @@ -1,11 +1,12 @@ import sys import textwrap +from typing import Any, Dict, Optional from unittest.mock import MagicMock import pytest from redbot.core import commands -from redbot.core.dev_commands import DevOutput, cleanup_code +from redbot.core.dev_commands import DevOutput, SourceCache, cleanup_code # the examples are based on how the markdown ends up being rendered by Discord @@ -134,12 +135,20 @@ def test_cleanup_code(content: str, source: str) -> None: assert cleanup_code(content) == source -def _get_dev_output(source: str) -> DevOutput: +def _get_dev_output( + source: str, + *, + source_cache: Optional[SourceCache] = None, + env: Optional[Dict[str, Any]] = None, +) -> DevOutput: + if source_cache is None: + source_cache = SourceCache() return DevOutput( MagicMock(spec=commands.Context), + source_cache=source_cache, + filename=f"", source=source, - filename="", - env={"__builtins__": __builtins__, "__name__": "__main__", "_": None}, + env={"__builtins__": __builtins__, "__name__": "__main__", "_": None, **(env or {})}, ) @@ -184,7 +193,7 @@ EXPRESSION_TESTS = { ( lambda v: v < (3, 10), """\ - File "", line 1 + File "", line 1 12x ^ SyntaxError: invalid syntax @@ -193,7 +202,7 @@ EXPRESSION_TESTS = { ( lambda v: v >= (3, 10), """\ - File "", line 1 + File "", line 1 12x ^ SyntaxError: invalid decimal literal @@ -204,7 +213,7 @@ EXPRESSION_TESTS = { ( lambda v: v < (3, 10), """\ - File "", line 1 + File "", line 1 foo(x, z for z in range(10), t, w) ^ SyntaxError: Generator expression must be parenthesized @@ -213,7 +222,7 @@ EXPRESSION_TESTS = { ( lambda v: v >= (3, 10), """\ - File "", line 1 + File "", line 1 foo(x, z for z in range(10), t, w) ^^^^^^^^^^^^^^^^^^^^ SyntaxError: Generator expression must be parenthesized @@ -226,7 +235,7 @@ EXPRESSION_TESTS = { lambda v: v < (3, 11), """\ Traceback (most recent call last): - File "", line 1, in + File "", line 1, in abs(1 / 0) ZeroDivisionError: division by zero """, @@ -235,7 +244,7 @@ EXPRESSION_TESTS = { lambda v: v >= (3, 11), """\ Traceback (most recent call last): - File "", line 1, in + File "", line 1, in abs(1 / 0) ~~^~~ ZeroDivisionError: division by zero @@ -252,7 +261,7 @@ STATEMENT_TESTS = { ( lambda v: v < (3, 10), """\ - File "", line 2 + File "", line 2 12x ^ SyntaxError: invalid syntax @@ -261,7 +270,7 @@ STATEMENT_TESTS = { ( lambda v: v >= (3, 10), """\ - File "", line 2 + File "", line 2 12x ^ SyntaxError: invalid decimal literal @@ -275,7 +284,7 @@ STATEMENT_TESTS = { ( lambda v: v < (3, 10), """\ - File "", line 2 + File "", line 2 foo(x, z for z in range(10), t, w) ^ SyntaxError: Generator expression must be parenthesized @@ -284,7 +293,7 @@ STATEMENT_TESTS = { ( lambda v: v >= (3, 10), """\ - File "", line 2 + File "", line 2 foo(x, z for z in range(10), t, w) ^^^^^^^^^^^^^^^^^^^^ SyntaxError: Generator expression must be parenthesized @@ -304,7 +313,7 @@ STATEMENT_TESTS = { """\ 123 Traceback (most recent call last): - File "", line 3, in + File "", line 3, in abs(1 / 0) ZeroDivisionError: division by zero """, @@ -314,7 +323,7 @@ STATEMENT_TESTS = { """\ 123 Traceback (most recent call last): - File "", line 3, in + File "", line 3, in abs(1 / 0) ~~^~~ ZeroDivisionError: division by zero @@ -389,3 +398,37 @@ async def test_successful_run_repl_exec(monkeypatch: pytest.MonkeyPatch) -> None world """ await _run_dev_output(monkeypatch, source, result, repl=True) + + +async def test_regression_format_exception_from_previous_snippet( + monkeypatch: pytest.MonkeyPatch, +) -> None: + snippet_0 = textwrap.dedent( + """\ + def repro(): + raise Exception("this is an error!") + + return repro + """ + ) + snippet_1 = "_()" + result = textwrap.dedent( + """\ + Traceback (most recent call last): + File "", line 1, in func + _() + File "", line 2, in repro + raise Exception("this is an error!") + Exception: this is an error! + """ + ) + monkeypatch.setattr("redbot.core.dev_commands.sanitize_output", lambda ctx, s: s) + + source_cache = SourceCache() + output = _get_dev_output(snippet_0, source_cache=source_cache) + await output.run_eval() + output = _get_dev_output(snippet_1, source_cache=source_cache, env={"_": output.result}) + await output.run_eval() + assert str(output) == result + # ensure that our Context mock is never actually used by anything + assert not output.ctx.mock_calls