From 34b250f0ff49f0f61ef821f927199073d91f1f6c Mon Sep 17 00:00:00 2001 From: Reid Mello <30907815+rjmello@users.noreply.github.com> Date: Tue, 24 Oct 2023 18:32:16 -0400 Subject: [PATCH] Pass `heartbeat_period` to Parsl HTEX --- ...0907815+rjmello_gcengine_executor_heartbeat.rst | 5 +++++ .../engines/globus_compute.py | 2 +- compute_endpoint/tests/conftest.py | 4 ++-- compute_endpoint/tests/unit/test_engines.py | 14 ++++++++++++++ .../tests/unit/test_status_reporting.py | 4 ++-- 5 files changed, 24 insertions(+), 5 deletions(-) create mode 100644 changelog.d/20231024_183134_30907815+rjmello_gcengine_executor_heartbeat.rst diff --git a/changelog.d/20231024_183134_30907815+rjmello_gcengine_executor_heartbeat.rst b/changelog.d/20231024_183134_30907815+rjmello_gcengine_executor_heartbeat.rst new file mode 100644 index 000000000..bcfcd292f --- /dev/null +++ b/changelog.d/20231024_183134_30907815+rjmello_gcengine_executor_heartbeat.rst @@ -0,0 +1,5 @@ +Bug Fixes +^^^^^^^^^ + +- The ``GlobusComputeEngine`` has been updated to fully support the + ``heartbeat_period`` parameter. \ No newline at end of file diff --git a/compute_endpoint/globus_compute_endpoint/engines/globus_compute.py b/compute_endpoint/globus_compute_endpoint/engines/globus_compute.py index 0d6c2ff31..79910482e 100644 --- a/compute_endpoint/globus_compute_endpoint/engines/globus_compute.py +++ b/compute_endpoint/globus_compute_endpoint/engines/globus_compute.py @@ -41,7 +41,7 @@ def __init__( self.max_workers_per_node = 1 if executor is None: executor = HighThroughputExecutor( # type: ignore - *args, address=address, **kwargs + *args, address=address, heartbeat_period=heartbeat_period, **kwargs ) self.executor = executor diff --git a/compute_endpoint/tests/conftest.py b/compute_endpoint/tests/conftest.py index a244e3954..483342cb1 100644 --- a/compute_endpoint/tests/conftest.py +++ b/compute_endpoint/tests/conftest.py @@ -106,8 +106,8 @@ def func(): @pytest.fixture -def engine_heartbeat() -> float: - return 0.1 +def engine_heartbeat() -> int: + return 1 @pytest.fixture diff --git a/compute_endpoint/tests/unit/test_engines.py b/compute_endpoint/tests/unit/test_engines.py index 42868ef5b..b0b4523b0 100644 --- a/compute_endpoint/tests/unit/test_engines.py +++ b/compute_endpoint/tests/unit/test_engines.py @@ -4,6 +4,7 @@ import time import uuid +import parsl import pytest from globus_compute_common import messagepack from globus_compute_common.messagepack.message_types import TaskTransition @@ -19,6 +20,7 @@ from globus_compute_endpoint.engines.base import GlobusComputeEngineBase from globus_compute_sdk.serialize import ComputeSerializer from parsl.executors.high_throughput.interchange import ManagerLost +from pytest_mock import MockFixture from tests.utils import double, ez_pack_function, slow_double logger = logging.getLogger(__name__) @@ -185,3 +187,15 @@ def test_serialized_engine_config_has_provider(engine_type: GlobusComputeEngineB executor = res["executors"][0].get("executor") or res["executors"][0] assert executor.get("provider") + + +def test_gcengine_pass_through_to_executor(mocker: MockFixture): + mock_executor = mocker.patch.object(parsl.HighThroughputExecutor, "__new__") + + args = (1, "blah") + kwargs = {"address": "127.0.0.1", "heartbeat_period": 10, "foo": "bar"} + GlobusComputeEngine(*args, **kwargs) + + a, k = mock_executor.call_args + assert a[1:] == args + assert kwargs == k diff --git a/compute_endpoint/tests/unit/test_status_reporting.py b/compute_endpoint/tests/unit/test_status_reporting.py index 493332ff1..19b90f8a3 100644 --- a/compute_endpoint/tests/unit/test_status_reporting.py +++ b/compute_endpoint/tests/unit/test_status_reporting.py @@ -12,7 +12,7 @@ "engine_type", (engines.ProcessPoolEngine, engines.ThreadPoolEngine, engines.GlobusComputeEngine), ) -def test_status_reporting(engine_type, engine_runner, engine_heartbeat: float): +def test_status_reporting(engine_type, engine_runner, engine_heartbeat: int): engine = engine_runner(engine_type) report = engine.get_status_report() @@ -28,7 +28,7 @@ def test_status_reporting(engine_type, engine_runner, engine_heartbeat: float): # Confirm heartbeats in regular intervals for _i in range(3): - q_msg = results_q.get(timeout=1) + q_msg = results_q.get(timeout=1.1) assert isinstance(q_msg, dict) message = q_msg["message"]