Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Process EP heartbeats separately from results #1762

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -274,9 +274,9 @@ def __init__(
cq_info = reg_info["command_queue_info"]
_ = cq_info["connection_url"], cq_info["queue"]

rq_info = reg_info["result_queue_info"]
_ = rq_info["connection_url"], rq_info["queue"]
_ = rq_info["queue_publish_kwargs"]
hbq_info = reg_info["heartbeat_queue_info"]
_ = hbq_info["connection_url"], hbq_info["queue"]
_ = hbq_info["queue_publish_kwargs"]
except Exception as e:
log_reg_info = _redact_url_creds(str(reg_info))
log.debug("%s", log_reg_info)
Expand Down Expand Up @@ -316,7 +316,7 @@ def __init__(
stop_event=self._command_stop_event,
thread_name="CQS",
)
self._heartbeat_publisher = ResultPublisher(queue_info=rq_info)
self._heartbeat_publisher = ResultPublisher(queue_info=hbq_info)

@staticmethod
def get_metadata(config: ManagerEndpointConfig, conf_dir: pathlib.Path) -> dict:
Expand Down
92 changes: 75 additions & 17 deletions compute_endpoint/globus_compute_endpoint/endpoint/interchange.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,11 @@

class _ResultPassthroughType(t.TypedDict):
message: bytes
task_id: str | None
task_id: str


class _HeartbeatPassthroughType(t.TypedDict):
message: bytes


class EndpointInterchange:
Expand Down Expand Up @@ -74,8 +78,8 @@ def __init__(
Globus Compute config object describing how compute should be provisioned

reg_info : dict[str, dict]
Dictionary containing connection information for both the task and
result queues. The required data structure is returned from the
Dictionary containing connection information for the task, result and
heartbeat queues. The required data structure is returned from the
Endpoint registration API call, encapsulated in the SDK by
`Client.register_endpoint()`.

Expand All @@ -100,6 +104,7 @@ def __init__(

self.task_q_info = reg_info["task_queue_info"]
self.result_q_info = reg_info["result_queue_info"]
self.heartbeat_q_info = reg_info["heartbeat_queue_info"]

self.time_to_quit = False
self.heartbeat_period = self.config.heartbeat_period
Expand Down Expand Up @@ -132,6 +137,9 @@ def __init__(
log.info(f"Platform info: {self.current_platform}")

self.results_passthrough: queue.Queue[_ResultPassthroughType] = queue.Queue()
self.heartbeats_passthrough: queue.Queue[_HeartbeatPassthroughType] = (
queue.Queue()
)
# Rename self.executor -> self.engine in second round
self.executor: GlobusComputeEngineBase = self.config.executors[0]
self._test_start = False
Expand All @@ -140,6 +148,7 @@ def start_engine(self):
log.info("Starting Engine")
self.executor.start(
results_passthrough=self.results_passthrough,
heartbeats_passthrough=self.heartbeats_passthrough,
endpoint_id=self.endpoint_id,
run_dir=self.logdir,
)
Expand Down Expand Up @@ -294,6 +303,9 @@ def _main_loop(self):
results_publisher = ResultPublisher(queue_info=self.result_q_info)
results_publisher.start()

heartbeat_publisher = ResultPublisher(queue_info=self.heartbeat_q_info)
heartbeat_publisher.start()

executor = self.executor

num_tasks_forwarded = 0
Expand Down Expand Up @@ -338,8 +350,8 @@ def process_pending_tasks() -> None:
d_tag, prop_headers, body = self.pending_task_queue.get(timeout=1)
task_q_subscriber.ack(d_tag)

fid: str = prop_headers.get("function_uuid")
tid: str = prop_headers.get("task_uuid")
fid: str | None = prop_headers.get("function_uuid")
tid: str | None = prop_headers.get("task_uuid")

if not fid or not tid:
raise InvalidMessageError(
Expand Down Expand Up @@ -407,15 +419,14 @@ def process_pending_results() -> None:
# iterating the loop regardless.
nonlocal num_results_forwarded

def _create_done_cb(mq_msg: bytes, tid: str | None):
def _create_done_cb(mq_msg: bytes, tid: str):
def _done_cb(pub_fut: Future):
_exc = pub_fut.exception()
if _exc:
# Publishing didn't work -- quiesce and see if a simple
# restart fixes the issue.
if tid:
log.info(f"Storing result for later: {tid}")
self.result_store[tid] = mq_msg
log.info(f"Storing result for later: {tid}")
self.result_store[tid] = mq_msg

self._quiesce_event.set()
log.error("Failed to publish results", exc_info=_exc)
Expand All @@ -426,7 +437,7 @@ def _done_cb(pub_fut: Future):
try:
msg = self.results_passthrough.get(timeout=1)
packed_message: bytes = msg["message"]
task_id: str | None = msg.get("task_id")
task_id: str = msg["task_id"]

except queue.Empty:
continue
Expand All @@ -438,9 +449,8 @@ def _done_cb(pub_fut: Future):
)
continue

if task_id:
num_results_forwarded += 1
log.debug("Forwarding result for task: %s", task_id)
num_results_forwarded += 1
log.debug("Forwarding result for task: %s", task_id)

try:
f = results_publisher.publish(packed_message)
Expand All @@ -455,13 +465,53 @@ def _done_cb(pub_fut: Future):
"Something broke while forwarding results; setting quiesce"
" event"
)
if task_id:
log.info("Storing result for later: %s", task_id)
self.result_store[task_id] = packed_message
log.info("Storing result for later: %s", task_id)
self.result_store[task_id] = packed_message
continue # just be explicit

log.debug("Exit process-pending-results thread.")

def process_pending_heartbeats() -> None:
    """Drain the heartbeats passthrough queue into the heartbeat publisher.

    Loops until ``self._quiesce_event`` is set.  Each iteration waits up to
    one second for an item, extracts its ``"message"`` bytes, and hands them
    to ``heartbeat_publisher.publish()``.  Malformed items are logged and
    dropped.  Any publish failure (raised immediately or reported later via
    the returned future) sets the quiesce event; heartbeats are not stored
    for a later retry.
    """

    def _on_publish_done(publish_future: Future):
        failure = publish_future.exception()
        if not failure:
            return

        # The publish did not go through -- quiesce and see whether a
        # simple restart clears the problem.
        self._quiesce_event.set()
        log.error("Failed to publish heartbeat", exc_info=failure)

    while not self._quiesce_event.is_set():
        try:
            # Bounded wait so the quiesce event is re-checked regularly.
            heartbeat = self.heartbeats_passthrough.get(timeout=1)
            payload: bytes = heartbeat["message"]
        except queue.Empty:
            continue
        except Exception as err:
            log.warning(
                "Invalid message received. Ignoring."
                f" ([{type(err).__name__}] {err})"
            )
            continue

        try:
            publish_future = heartbeat_publisher.publish(payload)
            publish_future.add_done_callback(_on_publish_done)
        except Exception:
            # The publish did not go through -- quiesce and see whether a
            # simple restart clears the problem.
            self._quiesce_event.set()
            log.exception(
                "Something broke while forwarding heartbeats; setting quiesce"
                " event"
            )

    log.debug("Exit process-pending-heartbeats thread.")

stored_processor_thread = threading.Thread(
target=process_stored_results, daemon=True, name="Stored Result Handler"
)
Expand All @@ -471,9 +521,15 @@ def _done_cb(pub_fut: Future):
result_processor_thread = threading.Thread(
target=process_pending_results, daemon=True, name="Pending Result Handler"
)
heartbeat_processor_thread = threading.Thread(
target=process_pending_heartbeats,
daemon=True,
name="Pending Heartbeat Handler",
)
stored_processor_thread.start()
task_processor_thread.start()
result_processor_thread.start()
heartbeat_processor_thread.start()

connection_stable_hearbeats = 0
last_t, last_r = 0, 0
Expand Down Expand Up @@ -574,6 +630,7 @@ def _done_cb(pub_fut: Future):
stored_processor_thread.join(timeout=5)
task_processor_thread.join(timeout=5)
result_processor_thread.join(timeout=5)
heartbeat_processor_thread.join(timeout=5)

# let higher-level error handling take over if the following excepts
message = EPStatusReport(
Expand All @@ -589,7 +646,7 @@ def _done_cb(pub_fut: Future):
task_statuses={},
)
try:
f = results_publisher.publish(pack(message))
f = heartbeat_publisher.publish(pack(message))
f.result(timeout=5)
except concurrent.futures.TimeoutError:
log.warning(
Expand All @@ -598,5 +655,6 @@ def _done_cb(pub_fut: Future):

task_q_subscriber.stop()
results_publisher.stop()
heartbeat_publisher.stop()

log.debug("_main_loop exits")
3 changes: 2 additions & 1 deletion compute_endpoint/globus_compute_endpoint/engines/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ def __init__(
self.results_passthrough: queue.Queue[dict[str, bytes | str | None]] = (
queue.Queue()
)
self.heartbeats_passthrough: queue.Queue[dict[str, bytes]] = queue.Queue()
self._engine_ready: bool = False

@abstractmethod
Expand All @@ -141,7 +142,7 @@ def set_working_dir(self, run_dir: str | None = None):
def report_status(self) -> None:
status_report = self.get_status_report()
packed: bytes = messagepack.pack(status_report)
self.results_passthrough.put({"message": packed})
self.heartbeats_passthrough.put({"message": packed})

def _handle_task_exception(
self,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,7 @@ def start(
endpoint_id: t.Optional[uuid.UUID] = None,
run_dir: t.Optional[str] = None,
results_passthrough: t.Optional[queue.Queue] = None,
heartbeats_passthrough: t.Optional[queue.Queue] = None,
**kwargs,
):
assert endpoint_id, "GCExecutor requires kwarg:endpoint_id at start"
Expand All @@ -258,6 +259,8 @@ def start(
# Only update the default queue in GCExecutorBase if
# a queue is passed in
self.results_passthrough = results_passthrough
if heartbeats_passthrough:
self.heartbeats_passthrough = heartbeats_passthrough
self.executor.start()
self._status_report_thread.start()
# Add executor to poller *after* executor has started
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ def start(
endpoint_id: t.Optional[uuid.UUID] = None,
run_dir: t.Optional[str] = None,
results_passthrough: t.Optional[queue.Queue] = None,
heartbeats_passthrough: t.Optional[queue.Queue] = None,
**kwargs,
) -> None:
"""
Expand All @@ -44,6 +45,7 @@ def start(
endpoint_id: Endpoint UUID
run_dir: endpoint run directory
results_passthrough: Queue to which packed results will be posted
heartbeats_passthrough: Queue to which packed status reports are posted
Returns
-------
"""
Expand All @@ -60,6 +62,8 @@ def start(
self.endpoint_id = endpoint_id
if results_passthrough:
self.results_passthrough = results_passthrough
if heartbeats_passthrough:
self.heartbeats_passthrough = heartbeats_passthrough
assert self.results_passthrough
self.set_working_dir(run_dir=run_dir)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def start(
endpoint_id: t.Optional[uuid.UUID] = None,
run_dir: t.Optional[str] = None,
results_passthrough: t.Optional[queue.Queue] = None,
heartbeats_passthrough: t.Optional[queue.Queue] = None,
**kwargs,
) -> None:
"""
Expand All @@ -42,13 +43,16 @@ def start(
endpoint_id: Endpoint UUID
run_dir: endpoint run directory
results_passthrough: Queue to which packed results will be posted
heartbeats_passthrough: Queue to which packed status reports are posted
Returns
-------
"""
assert endpoint_id, "ThreadPoolEngine requires kwarg:endpoint_id at start"
self.endpoint_id = endpoint_id
if results_passthrough:
self.results_passthrough = results_passthrough
if heartbeats_passthrough:
self.heartbeats_passthrough = heartbeats_passthrough
assert self.results_passthrough

self.set_working_dir(run_dir=run_dir)
Expand Down
Loading
Loading