Add snippet numbers to filenames in the Dev cog to fix exception formatting (#6135)

This commit is contained in:
Jakub Kuczys 2023-05-12 00:27:19 +02:00 committed by GitHub
parent e7d7eba68f
commit 70ca8ff1f4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 184 additions and 63 deletions

View File

@ -23,7 +23,7 @@ import types
import re
import sys
from copy import copy
from typing import Any, Awaitable, Dict, Iterator, Literal, Type, TypeVar, Union
from typing import Any, Awaitable, Dict, Iterator, List, Literal, Tuple, Type, TypeVar, Union
from types import CodeType, TracebackType
import discord
@ -83,13 +83,55 @@ def cleanup_code(content: str) -> str:
return content.strip("` \n")
class SourceCache:
MAX_SIZE = 1000
def __init__(self) -> None:
# estimated to take less than 100 kB
self._data: Dict[str, Tuple[str, int]] = {}
# this just keeps going up until the bot is restarted, shouldn't really be an issue
self._next_index = 0
def take_next_index(self) -> int:
next_index = self._next_index
self._next_index += 1
return next_index
def __getitem__(self, key: str) -> Tuple[List[str], int]:
value = self._data.pop(key) # pop to put it at the end as most recent
self._data[key] = value
# To mimic linecache module's behavior,
# all lines (including the last one) should end with \n.
source_lines = [f"{line}\n" for line in value[0].splitlines()]
# Note: while it might seem like a waste of time to always calculate the list of source lines,
# this is a necessary memory optimization. If all of the data in `self._data` were list,
# it could theoretically take up to 1000x as much memory.
return source_lines, value[1]
def __setitem__(self, key: str, value: Tuple[str, int]) -> None:
self._data.pop(key, None)
self._data[key] = value
if len(self._data) > self.MAX_SIZE:
del self._data[next(iter(self._data))]
class DevOutput:
def __init__(
self, ctx: commands.Context, *, source: str, filename: str, env: Dict[str, Any]
self,
ctx: commands.Context,
*,
source_cache: SourceCache,
filename: str,
source: str,
env: Dict[str, Any],
) -> None:
self.ctx = ctx
self.source = source
self.source_cache = source_cache
self.filename = filename
self.source_line_offset = 0
#: raw source - as received from the command after stripping the code block
self.raw_source = source
self.set_compilable_source(source)
self.env = env
self.always_include_result = False
self._stream = io.StringIO()
@ -98,12 +140,14 @@ class DevOutput:
self._old_streams = []
@property
def source(self) -> str:
return self._original_source
def compilable_source(self) -> str:
"""Source string that we pass to async_compile()."""
return self._compilable_source
@source.setter
def source(self, value: str) -> None:
self._source = self._original_source = value
def set_compilable_source(self, compilable_source: str, *, line_offset: int = 0) -> None:
self._compilable_source = compilable_source
self.source_line_offset = line_offset
self.source_cache[self.filename] = (compilable_source, line_offset)
def __str__(self) -> str:
output = []
@ -124,10 +168,8 @@ class DevOutput:
if tick and not self.formatted_exc:
await self.ctx.tick()
def set_exception(self, exc: Exception, *, line_offset: int = 0, skip_frames: int = 1) -> None:
self.formatted_exc = self.format_exception(
exc, line_offset=line_offset, skip_frames=skip_frames
)
def set_exception(self, exc: Exception, *, skip_frames: int = 1) -> None:
self.formatted_exc = self.format_exception(exc, skip_frames=skip_frames)
def __enter__(self) -> None:
self._old_streams.append(sys.stdout)
@ -144,31 +186,49 @@ class DevOutput:
@classmethod
async def from_debug(
cls, ctx: commands.Context, *, source: str, env: Dict[str, Any]
cls, ctx: commands.Context, *, source: str, source_cache: SourceCache, env: Dict[str, Any]
) -> DevOutput:
output = cls(ctx, source=source, filename="<debug command>", env=env)
output = cls(
ctx,
source=source,
source_cache=source_cache,
filename=f"<debug command - snippet #{source_cache.take_next_index()}>",
env=env,
)
await output.run_debug()
return output
@classmethod
async def from_eval(
cls, ctx: commands.Context, *, source: str, env: Dict[str, Any]
cls, ctx: commands.Context, *, source: str, source_cache: SourceCache, env: Dict[str, Any]
) -> DevOutput:
output = cls(ctx, source=source, filename="<eval command>", env=env)
output = cls(
ctx,
source=source,
source_cache=source_cache,
filename=f"<eval command - snippet #{source_cache.take_next_index()}>",
env=env,
)
await output.run_eval()
return output
@classmethod
async def from_repl(
cls, ctx: commands.Context, *, source: str, env: Dict[str, Any]
cls, ctx: commands.Context, *, source: str, source_cache: SourceCache, env: Dict[str, Any]
) -> DevOutput:
output = cls(ctx, source=source, filename="<repl session>", env=env)
output = cls(
ctx,
source=source,
source_cache=source_cache,
filename=f"<repl session - snippet #{source_cache.take_next_index()}>",
env=env,
)
await output.run_repl()
return output
async def run_debug(self) -> None:
self.always_include_result = True
self._source = self.source
self.set_compilable_source(self.raw_source)
try:
compiled = self.async_compile_with_eval()
except SyntaxError as exc:
@ -182,12 +242,14 @@ class DevOutput:
async def run_eval(self) -> None:
self.always_include_result = False
self._source = "async def func():\n%s" % textwrap.indent(self.source, " ")
self.set_compilable_source(
"async def func():\n%s" % textwrap.indent(self.raw_source, " "), line_offset=1
)
try:
compiled = self.async_compile_with_exec()
exec(compiled, self.env)
except SyntaxError as exc:
self.set_exception(exc, line_offset=1, skip_frames=3)
self.set_exception(exc, skip_frames=3)
return
func = self.env["func"]
@ -195,13 +257,13 @@ class DevOutput:
with self:
self.result = await func()
except Exception as exc:
self.set_exception(exc, line_offset=1)
self.set_exception(exc)
async def run_repl(self) -> None:
self.always_include_result = False
self._source = self.source
self.set_compilable_source(self.raw_source)
executor = None
if self.source.count("\n") == 0:
if self.raw_source.count("\n") == 0:
# single statement, potentially 'eval'
try:
code = self.async_compile_with_eval()
@ -231,14 +293,12 @@ class DevOutput:
self.env["_"] = self.result
def async_compile_with_exec(self) -> CodeType:
return async_compile(self._source, self.filename, "exec")
return async_compile(self.compilable_source, self.filename, "exec")
def async_compile_with_eval(self) -> CodeType:
return async_compile(self._source, self.filename, "eval")
return async_compile(self.compilable_source, self.filename, "eval")
def format_exception(
self, exc: Exception, *, line_offset: int = 0, skip_frames: int = 1
) -> str:
def format_exception(self, exc: Exception, *, skip_frames: int = 1) -> str:
"""
Format an exception to send to the user.
@ -260,33 +320,44 @@ class DevOutput:
break
tb = tb.tb_next
# To mimic linecache module's behavior,
# all lines (including the last one) should end with \n.
source_lines = [f"{line}\n" for line in self._source.splitlines()]
filename = self.filename
# sometimes SyntaxError.text is None, sometimes it isn't
if (
issubclass(exc_type, SyntaxError)
and exc.filename == filename
and exc.lineno is not None
):
if exc.text is None:
# line numbers are 1-based, the list indexes are 0-based
exc.text = source_lines[exc.lineno - 1]
exc.lineno -= line_offset
if issubclass(exc_type, SyntaxError) and exc.lineno is not None:
try:
source_lines, line_offset = self.source_cache[exc.filename]
except KeyError:
pass
else:
if exc.text is None:
try:
# line numbers are 1-based, the list indexes are 0-based
exc.text = source_lines[exc.lineno - 1]
except IndexError:
# the frame might be pointing at a different source code, ignore...
pass
else:
exc.lineno -= line_offset
else:
exc.lineno -= line_offset
traceback_exc = traceback.TracebackException(exc_type, exc, tb)
py311_or_above = sys.version_info >= (3, 11)
stack_summary = traceback_exc.stack
for idx, frame_summary in enumerate(stack_summary):
if frame_summary.filename != filename:
try:
source_lines, line_offset = self.source_cache[frame_summary.filename]
except KeyError:
continue
lineno = frame_summary.lineno
if lineno is None:
continue
# line numbers are 1-based, the list indexes are 0-based
line = source_lines[lineno - 1]
try:
# line numbers are 1-based, the list indexes are 0-based
line = source_lines[lineno - 1]
except IndexError:
# the frame might be pointing at a different source code, ignore...
continue
lineno -= line_offset
# support for enhanced error locations in tracebacks
if py311_or_above:
@ -327,6 +398,7 @@ class Dev(commands.Cog):
self._last_result = None
self.sessions = {}
self.env_extensions = {}
self.source_cache = SourceCache()
def get_environment(self, ctx: commands.Context) -> dict:
env = {
@ -382,7 +454,9 @@ class Dev(commands.Cog):
env = self.get_environment(ctx)
source = cleanup_code(code)
output = await DevOutput.from_debug(ctx, source=source, env=env)
output = await DevOutput.from_debug(
ctx, source=source, source_cache=self.source_cache, env=env
)
self._last_result = output.result
await output.send()
@ -415,7 +489,9 @@ class Dev(commands.Cog):
env = self.get_environment(ctx)
source = cleanup_code(body)
output = await DevOutput.from_eval(ctx, source=source, env=env)
output = await DevOutput.from_eval(
ctx, source=source, source_cache=self.source_cache, env=env
)
if output.result is not None:
self._last_result = output.result
await output.send()
@ -483,7 +559,9 @@ class Dev(commands.Cog):
del self.sessions[ctx.channel.id]
return
output = await DevOutput.from_repl(ctx, source=source, env=env)
output = await DevOutput.from_repl(
ctx, source=source, source_cache=self.source_cache, env=env
)
try:
await output.send(tick=False)
except discord.Forbidden:

View File

@ -1,11 +1,12 @@
import sys
import textwrap
from typing import Any, Dict, Optional
from unittest.mock import MagicMock
import pytest
from redbot.core import commands
from redbot.core.dev_commands import DevOutput, cleanup_code
from redbot.core.dev_commands import DevOutput, SourceCache, cleanup_code
# the examples are based on how the markdown ends up being rendered by Discord
@ -134,12 +135,20 @@ def test_cleanup_code(content: str, source: str) -> None:
assert cleanup_code(content) == source
def _get_dev_output(source: str) -> DevOutput:
def _get_dev_output(
source: str,
*,
source_cache: Optional[SourceCache] = None,
env: Optional[Dict[str, Any]] = None,
) -> DevOutput:
if source_cache is None:
source_cache = SourceCache()
return DevOutput(
MagicMock(spec=commands.Context),
source_cache=source_cache,
filename=f"<test run - snippet #{source_cache.take_next_index()}>",
source=source,
filename="<test run>",
env={"__builtins__": __builtins__, "__name__": "__main__", "_": None},
env={"__builtins__": __builtins__, "__name__": "__main__", "_": None, **(env or {})},
)
@ -184,7 +193,7 @@ EXPRESSION_TESTS = {
(
lambda v: v < (3, 10),
"""\
File "<test run>", line 1
File "<test run - snippet #0>", line 1
12x
^
SyntaxError: invalid syntax
@ -193,7 +202,7 @@ EXPRESSION_TESTS = {
(
lambda v: v >= (3, 10),
"""\
File "<test run>", line 1
File "<test run - snippet #0>", line 1
12x
^
SyntaxError: invalid decimal literal
@ -204,7 +213,7 @@ EXPRESSION_TESTS = {
(
lambda v: v < (3, 10),
"""\
File "<test run>", line 1
File "<test run - snippet #0>", line 1
foo(x, z for z in range(10), t, w)
^
SyntaxError: Generator expression must be parenthesized
@ -213,7 +222,7 @@ EXPRESSION_TESTS = {
(
lambda v: v >= (3, 10),
"""\
File "<test run>", line 1
File "<test run - snippet #0>", line 1
foo(x, z for z in range(10), t, w)
^^^^^^^^^^^^^^^^^^^^
SyntaxError: Generator expression must be parenthesized
@ -226,7 +235,7 @@ EXPRESSION_TESTS = {
lambda v: v < (3, 11),
"""\
Traceback (most recent call last):
File "<test run>", line 1, in <module>
File "<test run - snippet #0>", line 1, in <module>
abs(1 / 0)
ZeroDivisionError: division by zero
""",
@ -235,7 +244,7 @@ EXPRESSION_TESTS = {
lambda v: v >= (3, 11),
"""\
Traceback (most recent call last):
File "<test run>", line 1, in <module>
File "<test run - snippet #0>", line 1, in <module>
abs(1 / 0)
~~^~~
ZeroDivisionError: division by zero
@ -252,7 +261,7 @@ STATEMENT_TESTS = {
(
lambda v: v < (3, 10),
"""\
File "<test run>", line 2
File "<test run - snippet #0>", line 2
12x
^
SyntaxError: invalid syntax
@ -261,7 +270,7 @@ STATEMENT_TESTS = {
(
lambda v: v >= (3, 10),
"""\
File "<test run>", line 2
File "<test run - snippet #0>", line 2
12x
^
SyntaxError: invalid decimal literal
@ -275,7 +284,7 @@ STATEMENT_TESTS = {
(
lambda v: v < (3, 10),
"""\
File "<test run>", line 2
File "<test run - snippet #0>", line 2
foo(x, z for z in range(10), t, w)
^
SyntaxError: Generator expression must be parenthesized
@ -284,7 +293,7 @@ STATEMENT_TESTS = {
(
lambda v: v >= (3, 10),
"""\
File "<test run>", line 2
File "<test run - snippet #0>", line 2
foo(x, z for z in range(10), t, w)
^^^^^^^^^^^^^^^^^^^^
SyntaxError: Generator expression must be parenthesized
@ -304,7 +313,7 @@ STATEMENT_TESTS = {
"""\
123
Traceback (most recent call last):
File "<test run>", line 3, in <module>
File "<test run - snippet #0>", line 3, in <module>
abs(1 / 0)
ZeroDivisionError: division by zero
""",
@ -314,7 +323,7 @@ STATEMENT_TESTS = {
"""\
123
Traceback (most recent call last):
File "<test run>", line 3, in <module>
File "<test run - snippet #0>", line 3, in <module>
abs(1 / 0)
~~^~~
ZeroDivisionError: division by zero
@ -389,3 +398,37 @@ async def test_successful_run_repl_exec(monkeypatch: pytest.MonkeyPatch) -> None
world
"""
await _run_dev_output(monkeypatch, source, result, repl=True)
async def test_regression_format_exception_from_previous_snippet(
monkeypatch: pytest.MonkeyPatch,
) -> None:
snippet_0 = textwrap.dedent(
"""\
def repro():
raise Exception("this is an error!")
return repro
"""
)
snippet_1 = "_()"
result = textwrap.dedent(
"""\
Traceback (most recent call last):
File "<test run - snippet #1>", line 1, in func
_()
File "<test run - snippet #0>", line 2, in repro
raise Exception("this is an error!")
Exception: this is an error!
"""
)
monkeypatch.setattr("redbot.core.dev_commands.sanitize_output", lambda ctx, s: s)
source_cache = SourceCache()
output = _get_dev_output(snippet_0, source_cache=source_cache)
await output.run_eval()
output = _get_dev_output(snippet_1, source_cache=source_cache, env={"_": output.result})
await output.run_eval()
assert str(output) == result
# ensure that our Context mock is never actually used by anything
assert not output.ctx.mock_calls