diff --git a/changelog.d/20230828_165626_LeiGlobus_unknown_task_id_sc_24535.rst b/changelog.d/20230828_165626_LeiGlobus_unknown_task_id_sc_24535.rst new file mode 100644 index 000000000..d688288cb --- /dev/null +++ b/changelog.d/20230828_165626_LeiGlobus_unknown_task_id_sc_24535.rst @@ -0,0 +1,4 @@ +Bug Fixes +^^^^^^^^^ + +- Expired or unknown tasks queried using Client.get_batch_result() method will display the appropriate unknown response instead of producing a stack trace diff --git a/compute_sdk/globus_compute_sdk/sdk/client.py b/compute_sdk/globus_compute_sdk/sdk/client.py index 387e1339f..8a37b7e13 100644 --- a/compute_sdk/globus_compute_sdk/sdk/client.py +++ b/compute_sdk/globus_compute_sdk/sdk/client.py @@ -187,10 +187,16 @@ def _update_task_table(self, return_msg: str | t.Dict, task_id: str): status = {"pending": pending, "status": r_status} if not pending: + # We are tolerant on the other fields but task_id should be there + if task_id != r_dict.get("task_id"): + err_msg = f"Task {task_id} returned invalid response: ({r_dict})" + logger.error(err_msg) + raise ValueError(err_msg) + + completion_t = r_dict.get("completion_t", "unknown") if "result" not in r_dict and "exception" not in r_dict: - raise ValueError("non-pending result is missing result data") - completion_t = r_dict["completion_t"] - if "result" in r_dict: + status["reason"] = r_dict.get("reason", "unknown") + elif "result" in r_dict: try: r_obj = self.fx_serializer.deserialize(r_dict["result"]) except Exception: diff --git a/compute_sdk/tests/unit/test_client.py b/compute_sdk/tests/unit/test_client.py index 87e5fae48..bae779172 100644 --- a/compute_sdk/tests/unit/test_client.py +++ b/compute_sdk/tests/unit/test_client.py @@ -101,7 +101,12 @@ def test_update_task_table_on_invalid_data(api_data): def test_update_task_table_on_exception(): - api_data = {"status": "success", "exception": "foo-bar-baz", "completion_t": "1.1"} + api_data = { + "status": "success", + "exception": "foo-bar-baz", + "completion_t": "1.1", + "task_id": "task-id-foo", + } gcc = gc.Client(do_version_check=False, login_manager=mock.Mock()) with pytest.raises(TaskExecutionFailed) as excinfo: @@ -115,7 +120,7 @@ def test_update_task_table_simple_object(randomstring): task_id = "some_task_id" payload = randomstring() - data = {"status": "success", "completion_t": "1.1"} + data = {"task_id": task_id, "status": "success", "completion_t": "1.1"} data["result"] = serde.serialize(payload) st = gcc._update_task_table(data, task_id) @@ -132,7 +137,10 @@ def test_pending_tasks_always_fetched(): gcc = gc.Client(do_version_check=False, login_manager=mock.Mock()) gcc.web_client = mock.MagicMock() gcc._task_status_table.update( - {should_fetch_01: {"pending": True}, no_fetch: {"pending": False}} + { + should_fetch_01: {"pending": True, "task_id": should_fetch_01}, + no_fetch: {"pending": False, "task_id": no_fetch}, + } ) task_id_list = [no_fetch, should_fetch_01, should_fetch_02] @@ -279,3 +287,35 @@ def test_delete_function(): gcc.delete_function(func_uuid_str) assert gcc.web_client.delete_function.called_with(func_uuid_str) + + +def test_missing_task_info(mocker, login_manager): + tid1 = str(uuid.uuid4()) + tid1_reason = "XYZ tid1" + tid2 = str(uuid.uuid4()) + + mock_resp = { + "response": "batch", + "results": { + tid1: { + "task_id": tid1, + "status": "failed", + "reason": tid1_reason, + }, + tid2: { + "task_id": tid2, + "status": "failed", + }, + }, + } + login_manager.get_web_client.get_batch_status = mocker.Mock(return_value=mock_resp) + gcc = gc.Client(do_version_check=False, login_manager=login_manager) + + gcc.web_client.base_url = "https://a.g.org" + res = gcc.get_batch_result([tid1, tid2]) + + assert tid1 in res + assert res[tid1]["pending"] is False + assert res[tid1]["reason"] == tid1_reason + assert tid2 in res + assert res[tid2]["reason"] == "unknown"