Skip to content

Commit

Permalink
FIX: Accept Iterable of Mappings from async step functions in functio…
Browse files Browse the repository at this point in the history
…nal API (#25)

* FIX: accept iterable of mappings from async step function

* API: improve error message for invalid step return values

* API: improve log messages

* DOC: update release notes

* BUILD: manage min matrix test requirements for numpy
  • Loading branch information
j-ittner authored Jul 9, 2024
1 parent 8d596ce commit b388ace
Show file tree
Hide file tree
Showing 7 changed files with 68 additions and 26 deletions.
6 changes: 6 additions & 0 deletions RELEASE_NOTES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,12 @@ Release Notes
*fluxus* 1.0
------------

*fluxus* 1.0.2
~~~~~~~~~~~~~~

- FIX: Allow asynchronous step functions to return iterators and asynchronous iterators.


*fluxus* 1.0.1
~~~~~~~~~~~~~~

Expand Down
2 changes: 2 additions & 0 deletions condabuild/meta.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,8 @@ test:
- fluxus
requires:
- pytest ~= 8.2
# additional requirements of gamma-pytools
- numpy {{ environ.get('FLUXUS_V_NUMPY', '[False]') }}
commands:
- conda list
- python -c "import fluxus; import os; assert fluxus.__version__ == os.environ['FLUXUS_BUILD_FLUXUS_VERSION']"
Expand Down
4 changes: 4 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -63,13 +63,17 @@ pandas = "~=2.1"
python = ">=3.10.14,<3.11a"
gamma-pytools = "~=3.0.0"
typing_inspect = "~=0.7.1"
# additional minimum requirements of gamma-pytools
numpy = ">=1.23.5,<1.24a" # cannot use ~= due to conda bug

[build.matrix.max]
matplotlib = "~=3.8"
pandas = "~=2.2"
python = ">=3.12,<4a"
gamma-pytools = "~=3.0"
typing_inspect = "~=0.9"
# additional minimum requirements of gamma-pytools
numpy = ">=2.0,<3a" # cannot use ~= due to conda bug

[tool.black]
required-version = '24.4.2'
Expand Down
12 changes: 9 additions & 3 deletions src/fluxus/functional/_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,9 @@ def step(
Mapping[str, Any]
| Iterable[dict[str, Any]]
| AsyncIterable[dict[str, Any]]
| Awaitable[Mapping[str, Any]],
| Awaitable[Mapping[str, Any]]
| Awaitable[Iterable[Mapping[str, Any]]]
| Awaitable[AsyncIterable[Mapping[str, Any]]],
],
/,
**kwargs: Any,
Expand All @@ -103,7 +105,9 @@ def step( # type: ignore[misc]
Mapping[str, Any]
| Iterable[dict[str, Any]]
| AsyncIterable[dict[str, Any]]
| Awaitable[Mapping[str, Any]],
| Awaitable[Mapping[str, Any]]
| Awaitable[Iterable[Mapping[str, Any]]]
| Awaitable[AsyncIterable[Mapping[str, Any]]],
],
/,
**kwargs: Any,
Expand All @@ -130,7 +134,9 @@ def step(
Mapping[str, Any]
| Iterable[dict[str, Any]]
| AsyncIterable[dict[str, Any]]
| Awaitable[Mapping[str, Any]],
| Awaitable[Mapping[str, Any]]
| Awaitable[Iterable[Mapping[str, Any]]]
| Awaitable[AsyncIterable[Mapping[str, Any]]],
]
| Iterable[Mapping[str, Any]]
| AsyncIterable[dict[str, Any]]
Expand Down
49 changes: 28 additions & 21 deletions src/fluxus/functional/conduit/_step.py
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,9 @@ def add_one(x):
Mapping[str, Any]
| Iterable[Mapping[str, Any]]
| AsyncIterable[Mapping[str, Any]]
| Awaitable[Mapping[str, Any]],
| Awaitable[Mapping[str, Any]]
| Awaitable[Iterable[Mapping[str, Any]]]
| Awaitable[AsyncIterable[Mapping[str, Any]]],
]

#: Additional keyword arguments to pass to the function.
Expand All @@ -149,7 +151,9 @@ def __init__(
Mapping[str, Any]
| Iterable[Mapping[str, Any]]
| AsyncIterable[Mapping[str, Any]]
| Awaitable[Mapping[str, Any]],
| Awaitable[Mapping[str, Any]]
| Awaitable[Iterable[Mapping[str, Any]]]
| Awaitable[AsyncIterable[Mapping[str, Any]]],
],
/,
**kwargs: Any,
Expand Down Expand Up @@ -193,7 +197,9 @@ def function(self) -> Callable[
Mapping[str, Any]
| Iterable[Mapping[str, Any]]
| AsyncIterable[Mapping[str, Any]]
| Awaitable[Mapping[str, Any]],
| Awaitable[Mapping[str, Any]]
| Awaitable[Iterable[Mapping[str, Any]]]
| Awaitable[AsyncIterable[Mapping[str, Any]]],
]:
"""
The function that this step applies to the source product.
Expand Down Expand Up @@ -221,14 +227,15 @@ async def atransform(
# source.
shadowed_attributes = source_product_attributes.keys() & kwargs.keys()
if shadowed_attributes:
logging.warning(
f"Fixed keyword arguments of step {self.name!r} shadow attributes of "
f"the source product: "
+ ", ".join(
log.warning(
"Fixed keyword arguments of step %r shadow attributes of the source "
"product: %s",
self.name,
", ".join(
f"{attr}={kwargs[attr]} shadows {attr}="
f"{source_product_attributes[attr]}"
for attr in sorted(shadowed_attributes)
)
),
)

# Input arguments are the union of the source product attributes and the fixed
Expand All @@ -247,10 +254,11 @@ async def atransform(
# actual result, an iterable of results, or an async iterable of results.
attribute_iterable = self._function(**input_args)

if isinstance(attribute_iterable, Awaitable):
attribute_iterable = await attribute_iterable

if isinstance(attribute_iterable, Mapping):
attribute_iterable = iter_sync_to_async([attribute_iterable])
elif isinstance(attribute_iterable, Awaitable):
attribute_iterable = _awaitable_to_async_iter(attribute_iterable)
elif isinstance(attribute_iterable, Iterable):
attribute_iterable = iter_sync_to_async(
cast(Iterable[Mapping[str, Any]], attribute_iterable)
Expand All @@ -269,20 +277,19 @@ async def atransform(
if not isinstance(attributes, Mapping):
raise TypeError(
f"Expected function {self._function.__name__}() of step "
f"{self.name!r} to return a Mapping or dict, but got: "
f"{attributes!r}"
f"{self.name!r} to return one or more instances of Mapping or "
f"dict, but got: {attributes!r}"
)

log.debug(
f"Completed step {self.name!r} in {end_time - start_time:g} "
f"seconds:\n"
+ str(
BinaryOperation(
BinaryOperator.ASSIGN,
Id(self._function)(**input_args),
DictLiteral(**attributes),
)
)
"Completed step %r in %g seconds:\n%s",
self.name,
end_time - start_time,
BinaryOperation(
BinaryOperator.ASSIGN,
Id(self._function)(**input_args),
DictLiteral(**attributes),
),
)

yield DictProduct(
Expand Down
17 changes: 16 additions & 1 deletion test/fluxus_test/test_functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,6 @@ def test_flow() -> None:
)

# The steps that process the input data
# noinspection PyTypeChecker
steps = chain(
parallel(
step(
Expand Down Expand Up @@ -179,9 +178,11 @@ def test_flow() -> None:
]

# run once with input as part of the chain
# noinspection PyTypeChecker
assert _sort_nested(run(chain(input_step, steps))) == result_expected

# run again with input as a separate argument
# noinspection PyTypeChecker
assert _sort_nested(run(steps, input=input_data)) == result_expected

# construct the same flow, with parallel steps as an iterable
Expand Down Expand Up @@ -212,6 +213,7 @@ def test_flow() -> None:
)

# run again with input as a separate argument
# noinspection PyTypeChecker
assert _sort_nested(run(steps, input=input_data)) == result_expected


Expand Down Expand Up @@ -302,6 +304,7 @@ def test_parallel_inputs() -> None:
parallel(passthrough()) # type: ignore[call-overload]

inc = step("increment", lambda a: dict(a=a + 1))
# noinspection PyTypeChecker
assert run(parallel([inc], inc), input=[dict(a=2)]) == RunResult(
[{"input": {"a": 2}, "increment": {"a": 3}}],
[{"input": {"a": 2}, "increment": {"a": 3}}],
Expand All @@ -313,6 +316,7 @@ def test_parallel_inputs() -> None:

def test_passthrough() -> None:

# noinspection PyTypeChecker
flow = chain(
# Create a producer step that produces a single dictionary
step(
Expand Down Expand Up @@ -598,6 +602,17 @@ def test_implicit_input() -> None:
)


def test_sync_from_async() -> None:
async def toy(a: int) -> list[dict[str, Any]]:
return [dict(a=i, a_previous=a) for i in range(3)]

pipeline = chain(
step("toy", toy),
step("toy", toy),
)
run(pipeline, input=dict(a=3))


#
# Auxiliary functions
#
Expand Down
4 changes: 3 additions & 1 deletion tox.ini
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@ deps =
pandas{env:FLUXUS_V_PANDAS}
gamma-pytools{env:FLUXUS_V_GAMMA_PYTOOLS}
typing_inspect{env:FLUXUS_V_TYPING_INSPECT}
# additional requirements of gamma-pytools
numpy{env:FLUXUS_V_NUMPY}

[flake8]

Expand Down Expand Up @@ -96,7 +98,7 @@ profile=black
src_paths=src,test
known_local_folder=test
known_first_party=pytools
known_third_party=numpy,pandas,joblib,matplot
known_third_party=numpy,pandas,matplot
case_sensitive = True

[pytest]
Expand Down

0 comments on commit b388ace

Please sign in to comment.