diff --git a/aioconsole/__init__.py b/aioconsole/__init__.py index 1c8f441..75fb335 100644 --- a/aioconsole/__init__.py +++ b/aioconsole/__init__.py @@ -3,7 +3,7 @@ It also includes an interactive event loop, and a command line interface. """ -from .execute import aexec +from .execute import aexec, aeval from .console import AsynchronousConsole, interact from .stream import ainput, aprint, get_standard_streams from .events import InteractiveEventLoop, InteractiveEventLoopPolicy @@ -17,6 +17,7 @@ __all__ = [ "aexec", + "aeval", "ainput", "aprint", "AsynchronousConsole", diff --git a/aioconsole/execute.py b/aioconsole/execute.py index c9ca50c..8e2e2c3 100644 --- a/aioconsole/execute.py +++ b/aioconsole/execute.py @@ -148,3 +148,32 @@ async def aexec(source, local=None, stream=None, filename=""): if isinstance(tree, ast.Interactive): exec_single_result(result, new_local, stream) full_update(local, new_local) + + +async def aeval(source, local=None): + """Asynchronous equivalent to *eval*.""" + if local is None: + local = {} + + if not isinstance(local, dict): + raise TypeError("globals must be a dict") + + # Ensure that the result key is unique within the local namespace + key = "__aeval_result__" + while key in local: + key += "_" + + # Perform syntax check to ensure the input is a valid eval expression + try: + ast.parse(source, mode="eval") + except SyntaxError: + raise + + # Assign the result of the expression to a known variable + wrapped_code = f"{key} = {source}" + + # Use aexec to evaluate the wrapped code within the given local namespace + await aexec(wrapped_code, local=local) + + # Return the result from the local namespace + return local.pop(key) diff --git a/tests/test_execute.py b/tests/test_execute.py index 269ba87..e7d574d 100644 --- a/tests/test_execute.py +++ b/tests/test_execute.py @@ -2,7 +2,7 @@ import asyncio import pytest -from aioconsole import aexec +from aioconsole import aexec, aeval from aioconsole.execute import compile_for_aexec @@ -124,3 +124,130 @@ async def test_correct(): await aexec("async def x(): return") await aexec("def x(): yield") await aexec("async def x(): yield") + + +def echo(x): + """Sync function for aeval parameterized test.""" + return x + + +async def aecho(x): + """Async function for aeval parameterized test.""" + return echo(x) + + +# Parametrized test with a variety of expressions +@pytest.mark.asyncio +@pytest.mark.parametrize( + "expression, local", + [ + # Valid Simple Expressions + ("1 + 2", None), + ("sum([i * i for i in range(10)])", None), + # Invalid Expressions + ("def foo(): return 42", None), + ("x = 1", None), + ("x = 1\nx + 1", None), + ("for i in range(10): pass", None), + ("if True: pass", None), + ("while True: break", None), + ("try: pass\nexcept: pass", None), + # Expressions Involving Undefined Variables + ("undefined_variable", None), + ("undefined_function()", None), + # Expressions with Deliberate Errors + ("1/0", None), + ("open('nonexistent_file.txt')", None), + # Lambda and Anonymous Functions + ("(lambda x: x * 2)(5)", None), + # Expressions with Built-in Functions + ("len('test')", None), + ("min([3, 1, 4, 1, 5, 9])", None), + ("max([x * x for x in range(10)])", None), + # Boolean and Conditional Expressions + ("True and False", None), + ("not True", None), # Boolean negation + ("5 if True else 10", None), + # String Manipulation + ("'hello' + ' ' + 'world'", None), + ("f'hello {42}'", None), + # Complex List Comprehensions + ("[x for x in range(5)]", None), + ("[x * x for x in range(10) if x % 2 == 0]", None), + # Expressions with Syntax Errors + ("return 42", None), + ("yield 5", None), + # Test with await + ("await aecho(5)", {"aecho": aecho, "echo": echo}), + # Test invalid local + ("...", []), + ("...", "string_instead_of_dict"), + ("...", 42), + ("...", set()), + ("...", ...), + ("...", 1.5), + ("...", object()), + ("...", asyncio), + ("...", lambda: ...), + ("...", {"__result__": 99}), + # Invalid expressions + ("", None), + (None, None), + (0, None), + ({}, None), + (object(), None), + (asyncio, None), + (..., None), + (lambda: ..., None), + # Conflicting name in local + ("x", {"x": 1, "__aeval_result__": 99}), + ], +) +async def test_aeval(expression, local): + + async def capture(func, *args, **kwargs): + try: + if asyncio.iscoroutinefunction(func): + result = await func(*args, **kwargs) + else: + result = func(*args, **kwargs) + except Exception as exc: + return (type(exc), None) + else: + return (None, result) + + # Remove the await keyword from the expression for the synchronous evaluation + sync_expression = ( + expression.replace("await a", "") if isinstance(expression, str) else expression + ) + + # Capture and compare the results of the synchronous and asynchronous evaluations + sync_capture = await capture(eval, sync_expression, local) + async_capture = await capture(aeval, expression, local) + assert sync_capture == async_capture + + +# Test calling an async function without awaiting it +@pytest.mark.asyncio +async def test_aeval_async_func_without_await(): + expression = "asyncio.sleep(0)" + local = {"asyncio": asyncio} + result = await aeval(expression, local) + assert asyncio.iscoroutine(result) + await result + + +@pytest.mark.asyncio +async def test_aeval_valid_await_syntax(): + expression = "await aecho(10)" + local = {"aecho": aecho} + result = await aeval(expression, local) + assert result == 10 + + +@pytest.mark.asyncio +async def test_aeval_coro_in_local(): + expression = "await coro" + local = {"coro": aecho(10)} + result = await aeval(expression, local) + assert result == 10