diff --git a/src/anyio/pytest_plugin.py b/src/anyio/pytest_plugin.py index 2d57c09d..23528445 100644 --- a/src/anyio/pytest_plugin.py +++ b/src/anyio/pytest_plugin.py @@ -2,7 +2,7 @@ from collections.abc import Iterator from contextlib import contextmanager -from contextvars import Context, copy_context +from contextvars import Context, ContextVar, copy_context from inspect import isasyncgenfunction, iscoroutinefunction from typing import Any, Dict, Tuple, cast @@ -15,17 +15,35 @@ from .abc import TestRunner _current_runner: TestRunner | None = None +_current_reentrancy_token = ContextVar[object]("anyio.pytest_plugin.reentrancy_token") contextvars_context_key: StashKey[Context] = StashKey() _test_context_like_key: StashKey[ContextLike] = StashKey() class _TestContext(ContextLike): - """This class manages transmission of sniffio.current_async_library_cvar""" + """Manages reentrancy and transmission of sniffio.current_async_library_cvar""" def __init__(self, context: Context): self._context = context + self._reentrancy_token = object() + + def _is_already_in_context(self) -> bool: + # if context var is not set to the token, we are in another context + if _current_reentrancy_token.get(None) is not self._reentrancy_token: + return False + + # Token value is the same, but we may be in a copy of self._context + test_value = object() + reset_reentrancy = _current_reentrancy_token.set(test_value) + try: + return self._context[_current_reentrancy_token] is test_value + finally: + _current_reentrancy_token.reset(reset_reentrancy) def run(self, func: Any, /, *args: Any, **kwargs: Any) -> Any: + if self._is_already_in_context(): + return func(*args, **kwargs) + return self._context.run( self._set_context_and_run, sniffio.current_async_library_cvar.get(None), @@ -37,6 +55,7 @@ def run(self, func: Any, /, *args: Any, **kwargs: Any) -> Any: def _set_context_and_run( self, current_async_library: str | None, func: Any, /, *args: Any, **kwargs: Any ) -> Any: + reset_reentrancy = _current_reentrancy_token.set(self._reentrancy_token) reset_sniffio = None if current_async_library is not None: reset_sniffio = sniffio.current_async_library_cvar.set( @@ -46,6 +65,7 @@ def _set_context_and_run( try: return func(*args, **kwargs) finally: + _current_reentrancy_token.reset(reset_reentrancy) if reset_sniffio is not None: sniffio.current_async_library_cvar.reset(reset_sniffio) diff --git a/tests/test_pytest_plugin.py b/tests/test_pytest_plugin.py index 006c1e56..e986d79b 100644 --- a/tests/test_pytest_plugin.py +++ b/tests/test_pytest_plugin.py @@ -439,3 +439,61 @@ def test_sync_func_sync_then_async_fixture( result = testdir.runpytest(*pytest_args) result.assert_outcomes(passed=4 * len(get_all_backends())) + + +def test_sync_getfixturevalue(testdir: Pytester) -> None: + testdir.makepyfile( + """ + from __future__ import annotations + + from contextvars import ContextVar + + import pytest + + + var = ContextVar("var") + + + @pytest.fixture + def function_fixture(): + return "function" + + + @pytest.fixture + def generator_fixture(): + yield "generator" + + + @pytest.fixture + def set_var(): + value = object() + reset = var.set(value) + yield value + var.reset(reset) + + + @pytest.mark.parametrize("prefix", ["function", "generator"]) + def test_getfixturevalue_from_sync(request, prefix): + assert request.getfixturevalue(f"{prefix}_fixture") == prefix + + + @pytest.mark.anyio + @pytest.mark.parametrize("prefix", ["function", "generator"]) + async def test_getfixturevalue_from_async(request, prefix): + assert request.getfixturevalue(f"{prefix}_fixture") == prefix + + + def test_getfixturevalue_with_context_from_sync(request): + value = request.getfixturevalue("set_var") + assert var.get(None) is value + + + @pytest.mark.anyio + async def test_getfixturevalue_with_context_from_async(request): + value = request.getfixturevalue("set_var") + assert var.get(None) is value + """ + ) + + result = testdir.runpytest(*pytest_args) + result.assert_outcomes(passed=3 * len(get_all_backends()) + 3)