[Dev] Allow top-level await in code statements (#3508)

* [dev] allow top-level await in code statements

* style

* use staticmethod, cls is unneeded

* add asyncio and aiohttp to env

* fix repl

* add __builtins__ to repl env

* style...

* fix debug with no coro

* add `optimize=0` to eval
This commit is contained in:
zephyrkul 2020-02-13 10:29:10 -07:00 committed by GitHub
parent cc30726ab6
commit 42a23277cd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,8 +1,11 @@
import ast
import asyncio
import aiohttp
import inspect
import io
import textwrap
import traceback
import types
import re
from contextlib import redirect_stdout
from copy import copy
@ -35,6 +38,19 @@ class Dev(commands.Cog):
self._last_result = None
self.sessions = set()
@staticmethod
def async_compile(source, filename, mode):
return compile(source, filename, mode, flags=ast.PyCF_ALLOW_TOP_LEVEL_AWAIT, optimize=0)
@staticmethod
async def maybe_await(coro):
for i in range(2):
if inspect.isawaitable(coro):
coro = await coro
else:
return coro
return coro
@staticmethod
def cleanup_code(content):
"""Automatically removes code blocks from the code."""
@ -53,7 +69,9 @@ class Dev(commands.Cog):
"""
if e.text is None:
return box("{0.__class__.__name__}: {0}".format(e), lang="py")
return box("{0.text}{1:>{0.offset}}\n{2}: {0}".format(e, "^", type(e).__name__), lang="py")
return box(
"{0.text}\n{1:>{0.offset}}\n{2}: {0}".format(e, "^", type(e).__name__), lang="py"
)
@staticmethod
def get_pages(msg: str):
@ -75,8 +93,8 @@ class Dev(commands.Cog):
If the return value of the code is a coroutine, it will be awaited,
and the result of that will be the bot's response.
Note: Only one statement may be evaluated. Using await, yield or
similar restricted keywords will result in a syntax error. For multiple
Note: Only one statement may be evaluated. Using certain restricted
keywords, e.g. yield, will result in a syntax error. For multiple
lines or asynchronous code, see [p]repl or [p]eval.
Environment Variables:
@ -96,6 +114,8 @@ class Dev(commands.Cog):
"author": ctx.author,
"guild": ctx.guild,
"message": ctx.message,
"asyncio": asyncio,
"aiohttp": aiohttp,
"discord": discord,
"commands": commands,
"_": self._last_result,
@ -104,7 +124,8 @@ class Dev(commands.Cog):
code = self.cleanup_code(code)
try:
result = eval(code, env)
compiled = self.async_compile(code, "<string>", "eval")
result = await self.maybe_await(eval(compiled, env))
except SyntaxError as e:
await ctx.send(self.get_syntax_error(e))
return
@ -112,9 +133,6 @@ class Dev(commands.Cog):
await ctx.send(box("{}: {!s}".format(type(e).__name__, e), lang="py"))
return
if inspect.isawaitable(result):
result = await result
self._last_result = result
result = self.sanitize_output(ctx, str(result))
@ -149,6 +167,8 @@ class Dev(commands.Cog):
"author": ctx.author,
"guild": ctx.guild,
"message": ctx.message,
"asyncio": asyncio,
"aiohttp": aiohttp,
"discord": discord,
"commands": commands,
"_": self._last_result,
@ -160,7 +180,8 @@ class Dev(commands.Cog):
to_compile = "async def func():\n%s" % textwrap.indent(body, " ")
try:
exec(to_compile, env)
compiled = self.async_compile(to_compile, "<string>", "exec")
exec(compiled, env)
except SyntaxError as e:
return await ctx.send(self.get_syntax_error(e))
@ -192,9 +213,6 @@ class Dev(commands.Cog):
The REPL will only recognise code as messages which start with a
backtick. This includes codeblocks, and as such multiple lines can be
evaluated.
You may not await any code in this REPL unless you define it inside an
async function.
"""
variables = {
"ctx": ctx,
@ -203,7 +221,9 @@ class Dev(commands.Cog):
"guild": ctx.guild,
"channel": ctx.channel,
"author": ctx.author,
"asyncio": asyncio,
"_": None,
"__builtins__": __builtins__,
}
if ctx.channel.id in self.sessions:
@ -225,19 +245,19 @@ class Dev(commands.Cog):
self.sessions.remove(ctx.channel.id)
return
executor = exec
executor = None
if cleaned.count("\n") == 0:
# single statement, potentially 'eval'
try:
code = compile(cleaned, "<repl session>", "eval")
code = self.async_compile(cleaned, "<repl session>", "eval")
except SyntaxError:
pass
else:
executor = eval
if executor is exec:
if executor is None:
try:
code = compile(cleaned, "<repl session>", "exec")
code = self.async_compile(cleaned, "<repl session>", "exec")
except SyntaxError as e:
await ctx.send(self.get_syntax_error(e))
continue
@ -250,9 +270,11 @@ class Dev(commands.Cog):
try:
with redirect_stdout(stdout):
result = executor(code, variables)
if inspect.isawaitable(result):
result = await result
if executor is None:
result = types.FunctionType(code, variables)()
else:
result = executor(code, variables)
result = await self.maybe_await(result)
except:
value = stdout.getvalue()
msg = "{}{}".format(value, traceback.format_exc())