diff --git a/src/fasthep_flow/cli.py b/src/fasthep_flow/cli.py index 7d9f639..f9cb909 100644 --- a/src/fasthep_flow/cli.py +++ b/src/fasthep_flow/cli.py @@ -61,7 +61,7 @@ def execute( cfg = init_config(config, overrides) workflow = Workflow(config=cfg) save_path = workflow.save(Path(save_path)) - dag = workflow_to_hamilton_dag(workflow) + dag = workflow_to_hamilton_dag(workflow, save_path) dag.visualize_execution( final_vars=workflow.task_names, output_file_path=Path(workflow.save_path) / "dag.png", diff --git a/src/fasthep_flow/orchestration.py b/src/fasthep_flow/orchestration.py index 024632c..b6d44a8 100644 --- a/src/fasthep_flow/orchestration.py +++ b/src/fasthep_flow/orchestration.py @@ -6,6 +6,7 @@ from __future__ import annotations import logging +from pathlib import Path from typing import Any from hamilton import driver, telemetry @@ -39,10 +40,15 @@ # return client -def workflow_to_hamilton_dag(workflow: Workflow) -> Any: +def workflow_to_hamilton_dag(workflow: Workflow, output_path: str) -> Any: """Convert a workflow into a Hamilton flow.""" task_functions = load_tasks_module(workflow) - return driver.Builder().with_modules(task_functions).with_cache().build() + return ( + driver.Builder() + .with_modules(task_functions) + .with_cache(Path(output_path) / ".hamilton_cache") + .build() + ) # def gitlab_ci_workflow(workflow: Workflow):