From 42a23277cd4f9249f1a552208e921a27893a976f Mon Sep 17 00:00:00 2001 From: zephyrkul Date: Thu, 13 Feb 2020 10:29:10 -0700 Subject: [PATCH] [Dev] Allow top-level await in code statements (#3508) * [dev] allow top-level await in code statements * style * use staticmethod, cls is unneeded * add asyncio and aiohttp to env * fix repl * add __builtins__ to repl env * style... * fix debug with no coro * add `optimize=0` to eval --- redbot/core/dev_commands.py | 58 +++++++++++++++++++++++++------------ 1 file changed, 40 insertions(+), 18 deletions(-) diff --git a/redbot/core/dev_commands.py b/redbot/core/dev_commands.py index 729685d37..b4474e2b8 100644 --- a/redbot/core/dev_commands.py +++ b/redbot/core/dev_commands.py @@ -1,8 +1,11 @@ +import ast import asyncio +import aiohttp import inspect import io import textwrap import traceback +import types import re from contextlib import redirect_stdout from copy import copy @@ -35,6 +38,19 @@ class Dev(commands.Cog): self._last_result = None self.sessions = set() + @staticmethod + def async_compile(source, filename, mode): + return compile(source, filename, mode, flags=ast.PyCF_ALLOW_TOP_LEVEL_AWAIT, optimize=0) + + @staticmethod + async def maybe_await(coro): + for i in range(2): + if inspect.isawaitable(coro): + coro = await coro + else: + return coro + return coro + @staticmethod def cleanup_code(content): """Automatically removes code blocks from the code.""" @@ -53,7 +69,9 @@ class Dev(commands.Cog): """ if e.text is None: return box("{0.__class__.__name__}: {0}".format(e), lang="py") - return box("{0.text}{1:>{0.offset}}\n{2}: {0}".format(e, "^", type(e).__name__), lang="py") + return box( + "{0.text}\n{1:>{0.offset}}\n{2}: {0}".format(e, "^", type(e).__name__), lang="py" + ) @staticmethod def get_pages(msg: str): @@ -75,8 +93,8 @@ class Dev(commands.Cog): If the return value of the code is a coroutine, it will be awaited, and the result of that will be the bot's response. - Note: Only one statement may be evaluated. Using await, yield or - similar restricted keywords will result in a syntax error. For multiple + Note: Only one statement may be evaluated. Using certain restricted + keywords, e.g. yield, will result in a syntax error. For multiple lines or asynchronous code, see [p]repl or [p]eval. Environment Variables: @@ -96,6 +114,8 @@ class Dev(commands.Cog): "author": ctx.author, "guild": ctx.guild, "message": ctx.message, + "asyncio": asyncio, + "aiohttp": aiohttp, "discord": discord, "commands": commands, "_": self._last_result, @@ -104,7 +124,8 @@ class Dev(commands.Cog): code = self.cleanup_code(code) try: - result = eval(code, env) + compiled = self.async_compile(code, "", "eval") + result = await self.maybe_await(eval(compiled, env)) except SyntaxError as e: await ctx.send(self.get_syntax_error(e)) return @@ -112,9 +133,6 @@ class Dev(commands.Cog): await ctx.send(box("{}: {!s}".format(type(e).__name__, e), lang="py")) return - if inspect.isawaitable(result): - result = await result - self._last_result = result result = self.sanitize_output(ctx, str(result)) @@ -149,6 +167,8 @@ class Dev(commands.Cog): "author": ctx.author, "guild": ctx.guild, "message": ctx.message, + "asyncio": asyncio, + "aiohttp": aiohttp, "discord": discord, "commands": commands, "_": self._last_result, @@ -160,7 +180,8 @@ class Dev(commands.Cog): to_compile = "async def func():\n%s" % textwrap.indent(body, " ") try: - exec(to_compile, env) + compiled = self.async_compile(to_compile, "", "exec") + exec(compiled, env) except SyntaxError as e: return await ctx.send(self.get_syntax_error(e)) @@ -192,9 +213,6 @@ class Dev(commands.Cog): The REPL will only recognise code as messages which start with a backtick. This includes codeblocks, and as such multiple lines can be evaluated. - - You may not await any code in this REPL unless you define it inside an - async function. """ variables = { "ctx": ctx, @@ -203,7 +221,9 @@ class Dev(commands.Cog): "guild": ctx.guild, "channel": ctx.channel, "author": ctx.author, + "asyncio": asyncio, "_": None, + "__builtins__": __builtins__, } if ctx.channel.id in self.sessions: @@ -225,19 +245,19 @@ class Dev(commands.Cog): self.sessions.remove(ctx.channel.id) return - executor = exec + executor = None if cleaned.count("\n") == 0: # single statement, potentially 'eval' try: - code = compile(cleaned, "", "eval") + code = self.async_compile(cleaned, "", "eval") except SyntaxError: pass else: executor = eval - if executor is exec: + if executor is None: try: - code = compile(cleaned, "", "exec") + code = self.async_compile(cleaned, "", "exec") except SyntaxError as e: await ctx.send(self.get_syntax_error(e)) continue @@ -250,9 +270,11 @@ class Dev(commands.Cog): try: with redirect_stdout(stdout): - result = executor(code, variables) - if inspect.isawaitable(result): - result = await result + if executor is None: + result = types.FunctionType(code, variables)() + else: + result = executor(code, variables) + result = await self.maybe_await(result) except: value = stdout.getvalue() msg = "{}{}".format(value, traceback.format_exc())