Skip to content

Commit

Permalink
Add asynchronous equivalent to aeval (PR #120)
Browse files Browse the repository at this point in the history
  • Loading branch information
vxgmichel authored Aug 31, 2024
2 parents 7837517 + 43306a0 commit b279f7c
Show file tree
Hide file tree
Showing 3 changed files with 159 additions and 2 deletions.
3 changes: 2 additions & 1 deletion aioconsole/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
It also includes an interactive event loop, and a command line interface.
"""

from .execute import aexec
from .execute import aexec, aeval
from .console import AsynchronousConsole, interact
from .stream import ainput, aprint, get_standard_streams
from .events import InteractiveEventLoop, InteractiveEventLoopPolicy
Expand All @@ -17,6 +17,7 @@

__all__ = [
"aexec",
"aeval",
"ainput",
"aprint",
"AsynchronousConsole",
Expand Down
29 changes: 29 additions & 0 deletions aioconsole/execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,3 +148,32 @@ async def aexec(source, local=None, stream=None, filename="<aexec>"):
if isinstance(tree, ast.Interactive):
exec_single_result(result, new_local, stream)
full_update(local, new_local)


async def aeval(source, local=None):
"""Asynchronous equivalent to *eval*."""
if local is None:
local = {}

if not isinstance(local, dict):
raise TypeError("globals must be a dict")

# Ensure that the result key is unique within the local namespace
key = "__aeval_result__"
while key in local:
key += "_"

# Perform syntax check to ensure the input is a valid eval expression
try:
ast.parse(source, mode="eval")
except SyntaxError:
raise

# Assign the result of the expression to a known variable
wrapped_code = f"{key} = {source}"

# Use aexec to evaluate the wrapped code within the given local namespace
await aexec(wrapped_code, local=local)

# Return the result from the local namespace
return local.pop(key)
129 changes: 128 additions & 1 deletion tests/test_execute.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import asyncio

import pytest
from aioconsole import aexec
from aioconsole import aexec, aeval
from aioconsole.execute import compile_for_aexec


Expand Down Expand Up @@ -124,3 +124,130 @@ async def test_correct():
await aexec("async def x(): return")
await aexec("def x(): yield")
await aexec("async def x(): yield")


def echo(x):
"""Sync function for aeval parameterized test."""
return x


async def aecho(x):
"""Async function for aeval parameterized test."""
return echo(x)


# Parametrized test with a variety of expressions
@pytest.mark.asyncio
@pytest.mark.parametrize(
"expression, local",
[
# Valid Simple Expressions
("1 + 2", None),
("sum([i * i for i in range(10)])", None),
# Invalid Expressions
("def foo(): return 42", None),
("x = 1", None),
("x = 1\nx + 1", None),
("for i in range(10): pass", None),
("if True: pass", None),
("while True: break", None),
("try: pass\nexcept: pass", None),
# Expressions Involving Undefined Variables
("undefined_variable", None),
("undefined_function()", None),
# Expressions with Deliberate Errors
("1/0", None),
("open('nonexistent_file.txt')", None),
# Lambda and Anonymous Functions
("(lambda x: x * 2)(5)", None),
# Expressions with Built-in Functions
("len('test')", None),
("min([3, 1, 4, 1, 5, 9])", None),
("max([x * x for x in range(10)])", None),
# Boolean and Conditional Expressions
("True and False", None),
("not True", None), # Boolean negation
("5 if True else 10", None),
# String Manipulation
("'hello' + ' ' + 'world'", None),
("f'hello {42}'", None),
# Complex List Comprehensions
("[x for x in range(5)]", None),
("[x * x for x in range(10) if x % 2 == 0]", None),
# Expressions with Syntax Errors
("return 42", None),
("yield 5", None),
# Test with await
("await aecho(5)", {"aecho": aecho, "echo": echo}),
# Test invalid local
("...", []),
("...", "string_instead_of_dict"),
("...", 42),
("...", set()),
("...", ...),
("...", 1.5),
("...", object()),
("...", asyncio),
("...", lambda: ...),
("...", {"__result__": 99}),
# Invalid expressions
("", None),
(None, None),
(0, None),
({}, None),
(object(), None),
(asyncio, None),
(..., None),
(lambda: ..., None),
# Conflicting name in local
("x", {"x": 1, "__aeval_result__": 99}),
],
)
async def test_aeval(expression, local):

async def capture(func, *args, **kwargs):
try:
if asyncio.iscoroutinefunction(func):
result = await func(*args, **kwargs)
else:
result = func(*args, **kwargs)
except Exception as exc:
return (type(exc), None)
else:
return (None, result)

# Remove the await keyword from the expression for the synchronous evaluation
sync_expression = (
expression.replace("await a", "") if isinstance(expression, str) else expression
)

# Capture and compare the results of the synchronous and asynchronous evaluations
sync_capture = await capture(eval, sync_expression, local)
async_capture = await capture(aeval, expression, local)
assert sync_capture == async_capture


# Test calling an async function without awaiting it
@pytest.mark.asyncio
async def test_aeval_async_func_without_await():
expression = "asyncio.sleep(0)"
local = {"asyncio": asyncio}
result = await aeval(expression, local)
assert asyncio.iscoroutine(result)
await result


@pytest.mark.asyncio
async def test_aeval_valid_await_syntax():
expression = "await aecho(10)"
local = {"aecho": aecho}
result = await aeval(expression, local)
assert result == 10


@pytest.mark.asyncio
async def test_aeval_coro_in_local():
expression = "await coro"
local = {"coro": aecho(10)}
result = await aeval(expression, local)
assert result == 10

0 comments on commit b279f7c

Please sign in to comment.