Add snippet numbers to filenames in the Dev cog to fix exception formatting (#6135)

This commit is contained in:
Jakub Kuczys 2023-05-12 00:27:19 +02:00 committed by GitHub
parent e7d7eba68f
commit 70ca8ff1f4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 184 additions and 63 deletions

View File

@ -23,7 +23,7 @@ import types
import re import re
import sys import sys
from copy import copy from copy import copy
from typing import Any, Awaitable, Dict, Iterator, Literal, Type, TypeVar, Union from typing import Any, Awaitable, Dict, Iterator, List, Literal, Tuple, Type, TypeVar, Union
from types import CodeType, TracebackType from types import CodeType, TracebackType
import discord import discord
@ -83,13 +83,55 @@ def cleanup_code(content: str) -> str:
return content.strip("` \n") return content.strip("` \n")
class SourceCache:
MAX_SIZE = 1000
def __init__(self) -> None:
# estimated to take less than 100 kB
self._data: Dict[str, Tuple[str, int]] = {}
# this just keeps going up until the bot is restarted, shouldn't really be an issue
self._next_index = 0
def take_next_index(self) -> int:
next_index = self._next_index
self._next_index += 1
return next_index
def __getitem__(self, key: str) -> Tuple[List[str], int]:
value = self._data.pop(key) # pop to put it at the end as most recent
self._data[key] = value
# To mimic linecache module's behavior,
# all lines (including the last one) should end with \n.
source_lines = [f"{line}\n" for line in value[0].splitlines()]
# Note: while it might seem like a waste of time to always calculate the list of source lines,
# this is a necessary memory optimization. If all of the data in `self._data` were list,
# it could theoretically take up to 1000x as much memory.
return source_lines, value[1]
def __setitem__(self, key: str, value: Tuple[str, int]) -> None:
self._data.pop(key, None)
self._data[key] = value
if len(self._data) > self.MAX_SIZE:
del self._data[next(iter(self._data))]
class DevOutput: class DevOutput:
def __init__( def __init__(
self, ctx: commands.Context, *, source: str, filename: str, env: Dict[str, Any] self,
ctx: commands.Context,
*,
source_cache: SourceCache,
filename: str,
source: str,
env: Dict[str, Any],
) -> None: ) -> None:
self.ctx = ctx self.ctx = ctx
self.source = source self.source_cache = source_cache
self.filename = filename self.filename = filename
self.source_line_offset = 0
#: raw source - as received from the command after stripping the code block
self.raw_source = source
self.set_compilable_source(source)
self.env = env self.env = env
self.always_include_result = False self.always_include_result = False
self._stream = io.StringIO() self._stream = io.StringIO()
@ -98,12 +140,14 @@ class DevOutput:
self._old_streams = [] self._old_streams = []
@property @property
def source(self) -> str: def compilable_source(self) -> str:
return self._original_source """Source string that we pass to async_compile()."""
return self._compilable_source
@source.setter def set_compilable_source(self, compilable_source: str, *, line_offset: int = 0) -> None:
def source(self, value: str) -> None: self._compilable_source = compilable_source
self._source = self._original_source = value self.source_line_offset = line_offset
self.source_cache[self.filename] = (compilable_source, line_offset)
def __str__(self) -> str: def __str__(self) -> str:
output = [] output = []
@ -124,10 +168,8 @@ class DevOutput:
if tick and not self.formatted_exc: if tick and not self.formatted_exc:
await self.ctx.tick() await self.ctx.tick()
def set_exception(self, exc: Exception, *, line_offset: int = 0, skip_frames: int = 1) -> None: def set_exception(self, exc: Exception, *, skip_frames: int = 1) -> None:
self.formatted_exc = self.format_exception( self.formatted_exc = self.format_exception(exc, skip_frames=skip_frames)
exc, line_offset=line_offset, skip_frames=skip_frames
)
def __enter__(self) -> None: def __enter__(self) -> None:
self._old_streams.append(sys.stdout) self._old_streams.append(sys.stdout)
@ -144,31 +186,49 @@ class DevOutput:
@classmethod @classmethod
async def from_debug( async def from_debug(
cls, ctx: commands.Context, *, source: str, env: Dict[str, Any] cls, ctx: commands.Context, *, source: str, source_cache: SourceCache, env: Dict[str, Any]
) -> DevOutput: ) -> DevOutput:
output = cls(ctx, source=source, filename="<debug command>", env=env) output = cls(
ctx,
source=source,
source_cache=source_cache,
filename=f"<debug command - snippet #{source_cache.take_next_index()}>",
env=env,
)
await output.run_debug() await output.run_debug()
return output return output
@classmethod @classmethod
async def from_eval( async def from_eval(
cls, ctx: commands.Context, *, source: str, env: Dict[str, Any] cls, ctx: commands.Context, *, source: str, source_cache: SourceCache, env: Dict[str, Any]
) -> DevOutput: ) -> DevOutput:
output = cls(ctx, source=source, filename="<eval command>", env=env) output = cls(
ctx,
source=source,
source_cache=source_cache,
filename=f"<eval command - snippet #{source_cache.take_next_index()}>",
env=env,
)
await output.run_eval() await output.run_eval()
return output return output
@classmethod @classmethod
async def from_repl( async def from_repl(
cls, ctx: commands.Context, *, source: str, env: Dict[str, Any] cls, ctx: commands.Context, *, source: str, source_cache: SourceCache, env: Dict[str, Any]
) -> DevOutput: ) -> DevOutput:
output = cls(ctx, source=source, filename="<repl session>", env=env) output = cls(
ctx,
source=source,
source_cache=source_cache,
filename=f"<repl session - snippet #{source_cache.take_next_index()}>",
env=env,
)
await output.run_repl() await output.run_repl()
return output return output
async def run_debug(self) -> None: async def run_debug(self) -> None:
self.always_include_result = True self.always_include_result = True
self._source = self.source self.set_compilable_source(self.raw_source)
try: try:
compiled = self.async_compile_with_eval() compiled = self.async_compile_with_eval()
except SyntaxError as exc: except SyntaxError as exc:
@ -182,12 +242,14 @@ class DevOutput:
async def run_eval(self) -> None: async def run_eval(self) -> None:
self.always_include_result = False self.always_include_result = False
self._source = "async def func():\n%s" % textwrap.indent(self.source, " ") self.set_compilable_source(
"async def func():\n%s" % textwrap.indent(self.raw_source, " "), line_offset=1
)
try: try:
compiled = self.async_compile_with_exec() compiled = self.async_compile_with_exec()
exec(compiled, self.env) exec(compiled, self.env)
except SyntaxError as exc: except SyntaxError as exc:
self.set_exception(exc, line_offset=1, skip_frames=3) self.set_exception(exc, skip_frames=3)
return return
func = self.env["func"] func = self.env["func"]
@ -195,13 +257,13 @@ class DevOutput:
with self: with self:
self.result = await func() self.result = await func()
except Exception as exc: except Exception as exc:
self.set_exception(exc, line_offset=1) self.set_exception(exc)
async def run_repl(self) -> None: async def run_repl(self) -> None:
self.always_include_result = False self.always_include_result = False
self._source = self.source self.set_compilable_source(self.raw_source)
executor = None executor = None
if self.source.count("\n") == 0: if self.raw_source.count("\n") == 0:
# single statement, potentially 'eval' # single statement, potentially 'eval'
try: try:
code = self.async_compile_with_eval() code = self.async_compile_with_eval()
@ -231,14 +293,12 @@ class DevOutput:
self.env["_"] = self.result self.env["_"] = self.result
def async_compile_with_exec(self) -> CodeType: def async_compile_with_exec(self) -> CodeType:
return async_compile(self._source, self.filename, "exec") return async_compile(self.compilable_source, self.filename, "exec")
def async_compile_with_eval(self) -> CodeType: def async_compile_with_eval(self) -> CodeType:
return async_compile(self._source, self.filename, "eval") return async_compile(self.compilable_source, self.filename, "eval")
def format_exception( def format_exception(self, exc: Exception, *, skip_frames: int = 1) -> str:
self, exc: Exception, *, line_offset: int = 0, skip_frames: int = 1
) -> str:
""" """
Format an exception to send to the user. Format an exception to send to the user.
@ -260,33 +320,44 @@ class DevOutput:
break break
tb = tb.tb_next tb = tb.tb_next
# To mimic linecache module's behavior,
# all lines (including the last one) should end with \n.
source_lines = [f"{line}\n" for line in self._source.splitlines()]
filename = self.filename filename = self.filename
# sometimes SyntaxError.text is None, sometimes it isn't # sometimes SyntaxError.text is None, sometimes it isn't
if ( if issubclass(exc_type, SyntaxError) and exc.lineno is not None:
issubclass(exc_type, SyntaxError) try:
and exc.filename == filename source_lines, line_offset = self.source_cache[exc.filename]
and exc.lineno is not None except KeyError:
): pass
else:
if exc.text is None: if exc.text is None:
try:
# line numbers are 1-based, the list indexes are 0-based # line numbers are 1-based, the list indexes are 0-based
exc.text = source_lines[exc.lineno - 1] exc.text = source_lines[exc.lineno - 1]
except IndexError:
# the frame might be pointing at a different source code, ignore...
pass
else:
exc.lineno -= line_offset
else:
exc.lineno -= line_offset exc.lineno -= line_offset
traceback_exc = traceback.TracebackException(exc_type, exc, tb) traceback_exc = traceback.TracebackException(exc_type, exc, tb)
py311_or_above = sys.version_info >= (3, 11) py311_or_above = sys.version_info >= (3, 11)
stack_summary = traceback_exc.stack stack_summary = traceback_exc.stack
for idx, frame_summary in enumerate(stack_summary): for idx, frame_summary in enumerate(stack_summary):
if frame_summary.filename != filename: try:
source_lines, line_offset = self.source_cache[frame_summary.filename]
except KeyError:
continue continue
lineno = frame_summary.lineno lineno = frame_summary.lineno
if lineno is None: if lineno is None:
continue continue
try:
# line numbers are 1-based, the list indexes are 0-based # line numbers are 1-based, the list indexes are 0-based
line = source_lines[lineno - 1] line = source_lines[lineno - 1]
except IndexError:
# the frame might be pointing at a different source code, ignore...
continue
lineno -= line_offset lineno -= line_offset
# support for enhanced error locations in tracebacks # support for enhanced error locations in tracebacks
if py311_or_above: if py311_or_above:
@ -327,6 +398,7 @@ class Dev(commands.Cog):
self._last_result = None self._last_result = None
self.sessions = {} self.sessions = {}
self.env_extensions = {} self.env_extensions = {}
self.source_cache = SourceCache()
def get_environment(self, ctx: commands.Context) -> dict: def get_environment(self, ctx: commands.Context) -> dict:
env = { env = {
@ -382,7 +454,9 @@ class Dev(commands.Cog):
env = self.get_environment(ctx) env = self.get_environment(ctx)
source = cleanup_code(code) source = cleanup_code(code)
output = await DevOutput.from_debug(ctx, source=source, env=env) output = await DevOutput.from_debug(
ctx, source=source, source_cache=self.source_cache, env=env
)
self._last_result = output.result self._last_result = output.result
await output.send() await output.send()
@ -415,7 +489,9 @@ class Dev(commands.Cog):
env = self.get_environment(ctx) env = self.get_environment(ctx)
source = cleanup_code(body) source = cleanup_code(body)
output = await DevOutput.from_eval(ctx, source=source, env=env) output = await DevOutput.from_eval(
ctx, source=source, source_cache=self.source_cache, env=env
)
if output.result is not None: if output.result is not None:
self._last_result = output.result self._last_result = output.result
await output.send() await output.send()
@ -483,7 +559,9 @@ class Dev(commands.Cog):
del self.sessions[ctx.channel.id] del self.sessions[ctx.channel.id]
return return
output = await DevOutput.from_repl(ctx, source=source, env=env) output = await DevOutput.from_repl(
ctx, source=source, source_cache=self.source_cache, env=env
)
try: try:
await output.send(tick=False) await output.send(tick=False)
except discord.Forbidden: except discord.Forbidden:

View File

@ -1,11 +1,12 @@
import sys import sys
import textwrap import textwrap
from typing import Any, Dict, Optional
from unittest.mock import MagicMock from unittest.mock import MagicMock
import pytest import pytest
from redbot.core import commands from redbot.core import commands
from redbot.core.dev_commands import DevOutput, cleanup_code from redbot.core.dev_commands import DevOutput, SourceCache, cleanup_code
# the examples are based on how the markdown ends up being rendered by Discord # the examples are based on how the markdown ends up being rendered by Discord
@ -134,12 +135,20 @@ def test_cleanup_code(content: str, source: str) -> None:
assert cleanup_code(content) == source assert cleanup_code(content) == source
def _get_dev_output(source: str) -> DevOutput: def _get_dev_output(
source: str,
*,
source_cache: Optional[SourceCache] = None,
env: Optional[Dict[str, Any]] = None,
) -> DevOutput:
if source_cache is None:
source_cache = SourceCache()
return DevOutput( return DevOutput(
MagicMock(spec=commands.Context), MagicMock(spec=commands.Context),
source_cache=source_cache,
filename=f"<test run - snippet #{source_cache.take_next_index()}>",
source=source, source=source,
filename="<test run>", env={"__builtins__": __builtins__, "__name__": "__main__", "_": None, **(env or {})},
env={"__builtins__": __builtins__, "__name__": "__main__", "_": None},
) )
@ -184,7 +193,7 @@ EXPRESSION_TESTS = {
( (
lambda v: v < (3, 10), lambda v: v < (3, 10),
"""\ """\
File "<test run>", line 1 File "<test run - snippet #0>", line 1
12x 12x
^ ^
SyntaxError: invalid syntax SyntaxError: invalid syntax
@ -193,7 +202,7 @@ EXPRESSION_TESTS = {
( (
lambda v: v >= (3, 10), lambda v: v >= (3, 10),
"""\ """\
File "<test run>", line 1 File "<test run - snippet #0>", line 1
12x 12x
^ ^
SyntaxError: invalid decimal literal SyntaxError: invalid decimal literal
@ -204,7 +213,7 @@ EXPRESSION_TESTS = {
( (
lambda v: v < (3, 10), lambda v: v < (3, 10),
"""\ """\
File "<test run>", line 1 File "<test run - snippet #0>", line 1
foo(x, z for z in range(10), t, w) foo(x, z for z in range(10), t, w)
^ ^
SyntaxError: Generator expression must be parenthesized SyntaxError: Generator expression must be parenthesized
@ -213,7 +222,7 @@ EXPRESSION_TESTS = {
( (
lambda v: v >= (3, 10), lambda v: v >= (3, 10),
"""\ """\
File "<test run>", line 1 File "<test run - snippet #0>", line 1
foo(x, z for z in range(10), t, w) foo(x, z for z in range(10), t, w)
^^^^^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^^^^^
SyntaxError: Generator expression must be parenthesized SyntaxError: Generator expression must be parenthesized
@ -226,7 +235,7 @@ EXPRESSION_TESTS = {
lambda v: v < (3, 11), lambda v: v < (3, 11),
"""\ """\
Traceback (most recent call last): Traceback (most recent call last):
File "<test run>", line 1, in <module> File "<test run - snippet #0>", line 1, in <module>
abs(1 / 0) abs(1 / 0)
ZeroDivisionError: division by zero ZeroDivisionError: division by zero
""", """,
@ -235,7 +244,7 @@ EXPRESSION_TESTS = {
lambda v: v >= (3, 11), lambda v: v >= (3, 11),
"""\ """\
Traceback (most recent call last): Traceback (most recent call last):
File "<test run>", line 1, in <module> File "<test run - snippet #0>", line 1, in <module>
abs(1 / 0) abs(1 / 0)
~~^~~ ~~^~~
ZeroDivisionError: division by zero ZeroDivisionError: division by zero
@ -252,7 +261,7 @@ STATEMENT_TESTS = {
( (
lambda v: v < (3, 10), lambda v: v < (3, 10),
"""\ """\
File "<test run>", line 2 File "<test run - snippet #0>", line 2
12x 12x
^ ^
SyntaxError: invalid syntax SyntaxError: invalid syntax
@ -261,7 +270,7 @@ STATEMENT_TESTS = {
( (
lambda v: v >= (3, 10), lambda v: v >= (3, 10),
"""\ """\
File "<test run>", line 2 File "<test run - snippet #0>", line 2
12x 12x
^ ^
SyntaxError: invalid decimal literal SyntaxError: invalid decimal literal
@ -275,7 +284,7 @@ STATEMENT_TESTS = {
( (
lambda v: v < (3, 10), lambda v: v < (3, 10),
"""\ """\
File "<test run>", line 2 File "<test run - snippet #0>", line 2
foo(x, z for z in range(10), t, w) foo(x, z for z in range(10), t, w)
^ ^
SyntaxError: Generator expression must be parenthesized SyntaxError: Generator expression must be parenthesized
@ -284,7 +293,7 @@ STATEMENT_TESTS = {
( (
lambda v: v >= (3, 10), lambda v: v >= (3, 10),
"""\ """\
File "<test run>", line 2 File "<test run - snippet #0>", line 2
foo(x, z for z in range(10), t, w) foo(x, z for z in range(10), t, w)
^^^^^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^^^^^
SyntaxError: Generator expression must be parenthesized SyntaxError: Generator expression must be parenthesized
@ -304,7 +313,7 @@ STATEMENT_TESTS = {
"""\ """\
123 123
Traceback (most recent call last): Traceback (most recent call last):
File "<test run>", line 3, in <module> File "<test run - snippet #0>", line 3, in <module>
abs(1 / 0) abs(1 / 0)
ZeroDivisionError: division by zero ZeroDivisionError: division by zero
""", """,
@ -314,7 +323,7 @@ STATEMENT_TESTS = {
"""\ """\
123 123
Traceback (most recent call last): Traceback (most recent call last):
File "<test run>", line 3, in <module> File "<test run - snippet #0>", line 3, in <module>
abs(1 / 0) abs(1 / 0)
~~^~~ ~~^~~
ZeroDivisionError: division by zero ZeroDivisionError: division by zero
@ -389,3 +398,37 @@ async def test_successful_run_repl_exec(monkeypatch: pytest.MonkeyPatch) -> None
world world
""" """
await _run_dev_output(monkeypatch, source, result, repl=True) await _run_dev_output(monkeypatch, source, result, repl=True)
async def test_regression_format_exception_from_previous_snippet(
monkeypatch: pytest.MonkeyPatch,
) -> None:
snippet_0 = textwrap.dedent(
"""\
def repro():
raise Exception("this is an error!")
return repro
"""
)
snippet_1 = "_()"
result = textwrap.dedent(
"""\
Traceback (most recent call last):
File "<test run - snippet #1>", line 1, in func
_()
File "<test run - snippet #0>", line 2, in repro
raise Exception("this is an error!")
Exception: this is an error!
"""
)
monkeypatch.setattr("redbot.core.dev_commands.sanitize_output", lambda ctx, s: s)
source_cache = SourceCache()
output = _get_dev_output(snippet_0, source_cache=source_cache)
await output.run_eval()
output = _get_dev_output(snippet_1, source_cache=source_cache, env={"_": output.result})
await output.run_eval()
assert str(output) == result
# ensure that our Context mock is never actually used by anything
assert not output.ctx.mock_calls