Skip to content

Commit

Permalink
feat(orch): add adapters for dask and local
Browse files Browse the repository at this point in the history
  • Loading branch information
kreczko committed Oct 15, 2024
1 parent 5d35ca2 commit b8a9931
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 25 deletions.
2 changes: 1 addition & 1 deletion mypy.ini
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ implicit_reexport = true
ignore_missing_imports = true
implicit_reexport = true

[mypy-prefect_dask.*]
[mypy-dask.*]
ignore_missing_imports = true
implicit_reexport = true

Expand Down
1 change: 1 addition & 0 deletions src/fasthep_flow/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def execute(
final_vars=workflow.task_names,
output_file_path=Path(workflow.save_path) / "dag.png",
)
# TODO: if specified, run a specific task/node with execute_node
results = dag.execute(workflow.task_names, inputs={})
typer.echo(f"Results: {results}")

Expand Down
67 changes: 43 additions & 24 deletions src/fasthep_flow/orchestration.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,49 +6,68 @@
from __future__ import annotations

import logging
from functools import partial
from pathlib import Path
from typing import Any

from hamilton import driver, telemetry
from dask.distributed import Client, LocalCluster
from hamilton import base, driver, telemetry

from .workflow import Workflow, load_tasks_module

telemetry.disable_telemetry()  # opt out of sending Hamilton usage telemetry

# Module-level logger, stdlib convention (one logger per module).
logger = logging.getLogger(__name__)

# def get_runner(runner: str) -> Any:
# """Get the task runner for the given name."""
# from prefect.task_runners import ConcurrentTaskRunner, SequentialTaskRunner
# from prefect_dask import DaskTaskRunner

# runners: dict[str, Any] = {
# "Dask": DaskTaskRunner,
# "Sequential": SequentialTaskRunner,
# "Concurrent": ConcurrentTaskRunner,
# }
def create_dask_cluster() -> Any:
    """Start a local Dask cluster and return a client connected to it.

    Returns:
        A ``dask.distributed.Client`` bound to a freshly created
        ``LocalCluster`` (single-machine, default worker settings).
    """
    cluster = LocalCluster()
    client = Client(cluster)
    # Log cluster details (dashboard link, workers) to aid debugging.
    logger.info(client.cluster)
    return client


# def create_dask_cluster() -> Any:
# cluster = LocalCluster()
# client = Client(cluster)
# logger.info(client.cluster)
# logging.basicConfig(stream=sys.stdout, level=logging.INFO)
# Registry of Dask client factories, keyed by cluster flavour.
# Currently only a single-machine ("local") cluster is supported.
DASK_CLIENTS = {
    "local": create_dask_cluster,
}

# return client

def create_dask_adapter(client_type: str) -> Any:
    """Build a Hamilton graph adapter that executes the DAG on Dask.

    Args:
        client_type: Key into ``DASK_CLIENTS`` selecting how the Dask
            client is created (currently only ``"local"``).

    Returns:
        A configured ``h_dask.DaskGraphAdapter`` using delayed execution
        and a dict result builder.

    Raises:
        KeyError: If ``client_type`` is not a known ``DASK_CLIENTS`` entry.
    """
    # Imported lazily so the hamilton Dask plugin stays an optional dependency.
    from hamilton.plugins import h_dask

    client = DASK_CLIENTS[client_type]()

    return h_dask.DaskGraphAdapter(
        client,
        base.DictResult(),
        visualize_kwargs={"filename": "run_with_delayed", "format": "png"},
        use_delayed=True,
        compute_at_end=True,
    )


def create_local_adapter() -> Any:
    """Return a Hamilton adapter that runs the DAG sequentially in-process."""
    result_builder = base.DictResult()
    return base.SimplePythonGraphAdapter(result_builder)


# Ready-made execution-adapter factories keyed by method name:
# "dask:local" runs the DAG on a local Dask cluster, "local" runs it
# sequentially in the current process.
PRECONFIGURED_ADAPTERS = {
    "dask:local": partial(create_dask_adapter, client_type="local"),
    "local": create_local_adapter,
}


def workflow_to_hamilton_dag(
    workflow: Workflow,
    output_path: str,
    # method: str = "local"  # TODO: reinstate to pick from PRECONFIGURED_ADAPTERS
) -> Any:
    """Convert a workflow into a Hamilton driver (the executable DAG).

    Args:
        workflow: The fasthep-flow workflow whose tasks define the DAG nodes.
        output_path: Directory under which the Hamilton result cache is stored.

    Returns:
        A built ``hamilton.driver`` for the workflow's task module.
    """
    task_functions = load_tasks_module(workflow)
    # TODO: wire an execution adapter into the builder once selectable:
    # adapter = PRECONFIGURED_ADAPTERS[method]()
    # Cache intermediate node results next to the workflow output.
    cache_dir = Path(output_path) / ".hamilton_cache"

    return driver.Builder().with_modules(task_functions).with_cache(cache_dir).build()


# def gitlab_ci_workflow(workflow: Workflow):
Expand Down

0 comments on commit b8a9931

Please sign in to comment.