[Dev] Allow top-level await in code statements (#3508)

* [dev] allow top-level await in code statements

* style

* use staticmethod, cls is unneeded

* add asyncio and aiohttp to env

* fix repl

* add __builtins__ to repl env

* style...

* fix debug with no coro

* add `optimize=0` to eval
This commit is contained in:
zephyrkul 2020-02-13 10:29:10 -07:00 committed by GitHub
parent cc30726ab6
commit 42a23277cd
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -1,8 +1,11 @@
import ast
import asyncio import asyncio
import aiohttp
import inspect import inspect
import io import io
import textwrap import textwrap
import traceback import traceback
import types
import re import re
from contextlib import redirect_stdout from contextlib import redirect_stdout
from copy import copy from copy import copy
@ -35,6 +38,19 @@ class Dev(commands.Cog):
self._last_result = None self._last_result = None
self.sessions = set() self.sessions = set()
@staticmethod
def async_compile(source, filename, mode):
return compile(source, filename, mode, flags=ast.PyCF_ALLOW_TOP_LEVEL_AWAIT, optimize=0)
@staticmethod
async def maybe_await(coro):
for i in range(2):
if inspect.isawaitable(coro):
coro = await coro
else:
return coro
return coro
@staticmethod @staticmethod
def cleanup_code(content): def cleanup_code(content):
"""Automatically removes code blocks from the code.""" """Automatically removes code blocks from the code."""
@ -53,7 +69,9 @@ class Dev(commands.Cog):
""" """
if e.text is None: if e.text is None:
return box("{0.__class__.__name__}: {0}".format(e), lang="py") return box("{0.__class__.__name__}: {0}".format(e), lang="py")
return box("{0.text}{1:>{0.offset}}\n{2}: {0}".format(e, "^", type(e).__name__), lang="py") return box(
"{0.text}\n{1:>{0.offset}}\n{2}: {0}".format(e, "^", type(e).__name__), lang="py"
)
@staticmethod @staticmethod
def get_pages(msg: str): def get_pages(msg: str):
@ -75,8 +93,8 @@ class Dev(commands.Cog):
If the return value of the code is a coroutine, it will be awaited, If the return value of the code is a coroutine, it will be awaited,
and the result of that will be the bot's response. and the result of that will be the bot's response.
Note: Only one statement may be evaluated. Using await, yield or Note: Only one statement may be evaluated. Using certain restricted
similar restricted keywords will result in a syntax error. For multiple keywords, e.g. yield, will result in a syntax error. For multiple
lines or asynchronous code, see [p]repl or [p]eval. lines or asynchronous code, see [p]repl or [p]eval.
Environment Variables: Environment Variables:
@ -96,6 +114,8 @@ class Dev(commands.Cog):
"author": ctx.author, "author": ctx.author,
"guild": ctx.guild, "guild": ctx.guild,
"message": ctx.message, "message": ctx.message,
"asyncio": asyncio,
"aiohttp": aiohttp,
"discord": discord, "discord": discord,
"commands": commands, "commands": commands,
"_": self._last_result, "_": self._last_result,
@ -104,7 +124,8 @@ class Dev(commands.Cog):
code = self.cleanup_code(code) code = self.cleanup_code(code)
try: try:
result = eval(code, env) compiled = self.async_compile(code, "<string>", "eval")
result = await self.maybe_await(eval(compiled, env))
except SyntaxError as e: except SyntaxError as e:
await ctx.send(self.get_syntax_error(e)) await ctx.send(self.get_syntax_error(e))
return return
@ -112,9 +133,6 @@ class Dev(commands.Cog):
await ctx.send(box("{}: {!s}".format(type(e).__name__, e), lang="py")) await ctx.send(box("{}: {!s}".format(type(e).__name__, e), lang="py"))
return return
if inspect.isawaitable(result):
result = await result
self._last_result = result self._last_result = result
result = self.sanitize_output(ctx, str(result)) result = self.sanitize_output(ctx, str(result))
@ -149,6 +167,8 @@ class Dev(commands.Cog):
"author": ctx.author, "author": ctx.author,
"guild": ctx.guild, "guild": ctx.guild,
"message": ctx.message, "message": ctx.message,
"asyncio": asyncio,
"aiohttp": aiohttp,
"discord": discord, "discord": discord,
"commands": commands, "commands": commands,
"_": self._last_result, "_": self._last_result,
@ -160,7 +180,8 @@ class Dev(commands.Cog):
to_compile = "async def func():\n%s" % textwrap.indent(body, " ") to_compile = "async def func():\n%s" % textwrap.indent(body, " ")
try: try:
exec(to_compile, env) compiled = self.async_compile(to_compile, "<string>", "exec")
exec(compiled, env)
except SyntaxError as e: except SyntaxError as e:
return await ctx.send(self.get_syntax_error(e)) return await ctx.send(self.get_syntax_error(e))
@ -192,9 +213,6 @@ class Dev(commands.Cog):
The REPL will only recognise code as messages which start with a The REPL will only recognise code as messages which start with a
backtick. This includes codeblocks, and as such multiple lines can be backtick. This includes codeblocks, and as such multiple lines can be
evaluated. evaluated.
You may not await any code in this REPL unless you define it inside an
async function.
""" """
variables = { variables = {
"ctx": ctx, "ctx": ctx,
@ -203,7 +221,9 @@ class Dev(commands.Cog):
"guild": ctx.guild, "guild": ctx.guild,
"channel": ctx.channel, "channel": ctx.channel,
"author": ctx.author, "author": ctx.author,
"asyncio": asyncio,
"_": None, "_": None,
"__builtins__": __builtins__,
} }
if ctx.channel.id in self.sessions: if ctx.channel.id in self.sessions:
@ -225,19 +245,19 @@ class Dev(commands.Cog):
self.sessions.remove(ctx.channel.id) self.sessions.remove(ctx.channel.id)
return return
executor = exec executor = None
if cleaned.count("\n") == 0: if cleaned.count("\n") == 0:
# single statement, potentially 'eval' # single statement, potentially 'eval'
try: try:
code = compile(cleaned, "<repl session>", "eval") code = self.async_compile(cleaned, "<repl session>", "eval")
except SyntaxError: except SyntaxError:
pass pass
else: else:
executor = eval executor = eval
if executor is exec: if executor is None:
try: try:
code = compile(cleaned, "<repl session>", "exec") code = self.async_compile(cleaned, "<repl session>", "exec")
except SyntaxError as e: except SyntaxError as e:
await ctx.send(self.get_syntax_error(e)) await ctx.send(self.get_syntax_error(e))
continue continue
@ -250,9 +270,11 @@ class Dev(commands.Cog):
try: try:
with redirect_stdout(stdout): with redirect_stdout(stdout):
if executor is None:
result = types.FunctionType(code, variables)()
else:
result = executor(code, variables) result = executor(code, variables)
if inspect.isawaitable(result): result = await self.maybe_await(result)
result = await result
except: except:
value = stdout.getvalue() value = stdout.getvalue()
msg = "{}{}".format(value, traceback.format_exc()) msg = "{}{}".format(value, traceback.format_exc())