Skip to content

Commit

Permalink
Rename heartbeat_period_s to heartbeat_period
Browse files Browse the repository at this point in the history
Renaming to maintain parity with the `HighThroughputEngine` and Parsl's
``HighThroughputExecutor``.
  • Loading branch information
rjmello committed Oct 23, 2023
1 parent 3e75b90 commit 9be18ae
Show file tree
Hide file tree
Showing 8 changed files with 35 additions and 24 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
Bug Fixes
^^^^^^^

- The ``GlobusComputeEngine``, ``ProcessPoolEngine``, and ``ThreadPoolEngine``
now respect the ``heartbeat_period`` variable, as defined in ``config.yaml``.

Changed
^^^^^^^

- Renamed the ``heartbeat_period_s`` attribute to ``heartbeat_period`` for
``GlobusComputeEngine``, ``ProcessPoolEngine``, and ``ThreadPoolEngine``
to maintain parity with the ``HighThroughputEngine`` and Parsl's
``HighThroughputExecutor``.
10 changes: 4 additions & 6 deletions compute_endpoint/globus_compute_endpoint/engines/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,12 +75,12 @@ class GlobusComputeEngineBase(ABC):
def __init__(
self,
*args: object,
heartbeat_period_s: float = 30.0,
heartbeat_period: float = 30.0,
endpoint_id: t.Optional[uuid.UUID] = None,
**kwargs: object,
):
self._shutdown_event = threading.Event()
self._heartbeat_period_s = heartbeat_period_s
self._heartbeat_period = heartbeat_period
self.endpoint_id = endpoint_id

# remove these unused vars that we are adding to just keep
Expand Down Expand Up @@ -111,10 +111,8 @@ def report_status(self) -> None:
packed: bytes = messagepack.pack(status_report)
self.results_passthrough.put({"message": packed})

def _status_report(
self, shutdown_event: threading.Event, heartbeat_period_s: float
):
while not shutdown_event.wait(timeout=heartbeat_period_s):
def _status_report(self, shutdown_event: threading.Event, heartbeat_period: float):
while not shutdown_event.wait(timeout=heartbeat_period):
status_report = self.get_status_report()
packed = messagepack.pack(status_report)
self.results_passthrough.put({"message": packed})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ def __init__(
*args,
label: str = "GlobusComputeEngine",
address: t.Optional[str] = None,
heartbeat_period_s: float = 30.0,
heartbeat_period: float = 30.0,
strategy: t.Optional[SimpleStrategy] = SimpleStrategy(),
executor: t.Optional[HighThroughputExecutor] = None,
**kwargs,
Expand All @@ -34,9 +34,9 @@ def __init__(
self.run_dir = os.getcwd()
self.label = label
self._status_report_thread = ReportingThread(
target=self.report_status, args=[], reporting_period=heartbeat_period_s
target=self.report_status, args=[], reporting_period=heartbeat_period
)
super().__init__(*args, heartbeat_period_s=heartbeat_period_s, **kwargs)
super().__init__(*args, heartbeat_period=heartbeat_period, **kwargs)
self.strategy = strategy
self.max_workers_per_node = 1
if executor is None:
Expand Down Expand Up @@ -167,7 +167,7 @@ def get_status_report(self) -> EPStatusReport:
"min_blocks": 1,
"max_workers_per_node": 0,
"nodes_per_block": 1,
"heartbeat_period": self._heartbeat_period_s,
"heartbeat_period": self._heartbeat_period,
},
}
task_status_deltas: t.Dict[str, t.List[TaskTransition]] = {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,17 +26,17 @@ def __init__(
self,
*args,
label: str = "ProcessPoolEngine",
heartbeat_period_s: float = 30.0,
heartbeat_period: float = 30.0,
**kwargs,
):
self.label = label
self.executor: t.Optional[NativeExecutor] = None
self._executor_args = args
self._executor_kwargs = kwargs
self._status_report_thread = ReportingThread(
target=self.report_status, args=[], reporting_period=heartbeat_period_s
target=self.report_status, args=[], reporting_period=heartbeat_period
)
super().__init__(*args, heartbeat_period_s=heartbeat_period_s, **kwargs)
super().__init__(*args, heartbeat_period=heartbeat_period, **kwargs)

def start(
self,
Expand Down Expand Up @@ -86,7 +86,7 @@ def get_status_report(self) -> EPStatusReport:
"min_blocks": 1,
"max_workers_per_node": self.executor._max_workers, # type: ignore
"nodes_per_block": 1,
"heartbeat_period": self._heartbeat_period_s,
"heartbeat_period": self._heartbeat_period,
},
}
task_status_deltas: t.Dict[str, t.List[TaskTransition]] = {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,15 +26,15 @@ def __init__(
self,
*args,
label: str = "ThreadPoolEngine",
heartbeat_period_s: float = 30.0,
heartbeat_period: float = 30.0,
**kwargs,
):
self.label = label
self.executor = NativeExecutor(*args, **kwargs)
self._status_report_thread = ReportingThread(
target=self.report_status, args=[], reporting_period=heartbeat_period_s
target=self.report_status, args=[], reporting_period=heartbeat_period
)
super().__init__(*args, heartbeat_period_s=heartbeat_period_s, **kwargs)
super().__init__(*args, heartbeat_period=heartbeat_period, **kwargs)

def start(
self,
Expand Down Expand Up @@ -83,7 +83,7 @@ def get_status_report(self) -> EPStatusReport:
"min_blocks": 1,
"max_workers_per_node": self.executor._max_workers, # type: ignore
"nodes_per_block": 1,
"heartbeat_period": self._heartbeat_period_s,
"heartbeat_period": self._heartbeat_period,
},
}
task_status_deltas: t.Dict[str, t.List[TaskTransition]] = {}
Expand Down
6 changes: 3 additions & 3 deletions compute_endpoint/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,13 +118,13 @@ def _runner(engine_type: t.Type[GlobusComputeEngineBase]):
ep_id = uuid.uuid4()
queue = Queue()
if engine_type is engines.ProcessPoolEngine:
k = dict(heartbeat_period_s=engine_heartbeat, max_workers=2)
k = dict(heartbeat_period=engine_heartbeat, max_workers=2)
elif engine_type is engines.ThreadPoolEngine:
k = dict(heartbeat_period_s=engine_heartbeat, max_workers=2)
k = dict(heartbeat_period=engine_heartbeat, max_workers=2)
elif engine_type is engines.GlobusComputeEngine:
k = dict(
address="127.0.0.1",
heartbeat_period_s=engine_heartbeat,
heartbeat_period=engine_heartbeat,
heartbeat_threshold=1,
)
else:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def gc_engine_scaling(tmp_path):
ep_id = uuid.uuid4()
engine = GlobusComputeEngine(
address="127.0.0.1",
heartbeat_period_s=1,
heartbeat_period=1,
heartbeat_threshold=1,
provider=LocalProvider(
init_blocks=0,
Expand All @@ -40,7 +40,7 @@ def gc_engine_non_scaling(tmp_path):
ep_id = uuid.uuid4()
engine = GlobusComputeEngine(
address="127.0.0.1",
heartbeat_period_s=1,
heartbeat_period=1,
heartbeat_threshold=1,
provider=LocalProvider(
init_blocks=1,
Expand Down
2 changes: 1 addition & 1 deletion compute_endpoint/tests/unit/test_engines.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ def test_engine_submit_internal(engine_type: GlobusComputeEngineBase, engine_run


def test_proc_pool_engine_not_started():
engine = ProcessPoolEngine(heartbeat_period_s=1, max_workers=2)
engine = ProcessPoolEngine(heartbeat_period=1, max_workers=2)

with pytest.raises(AssertionError) as pyt_exc:
engine.submit(double, 10)
Expand Down

0 comments on commit 9be18ae

Please sign in to comment.