diff --git a/compute_sdk/globus_compute_sdk/sdk/client.py b/compute_sdk/globus_compute_sdk/sdk/client.py index 56e938e3c..389f64cfe 100644 --- a/compute_sdk/globus_compute_sdk/sdk/client.py +++ b/compute_sdk/globus_compute_sdk/sdk/client.py @@ -416,7 +416,7 @@ def batch_run( def register_endpoint( self, name, - endpoint_id: UUID_LIKE_T, + endpoint_id: UUID_LIKE_T | None, metadata: dict | None = None, multi_user: bool | None = None, display_name: str | None = None, @@ -429,7 +429,7 @@ def register_endpoint( ---------- name : str Name of the endpoint - endpoint_id : str | UUID + endpoint_id : str | UUID | None The uuid of the endpoint metadata : dict | None Endpoint metadata diff --git a/compute_sdk/globus_compute_sdk/sdk/web_client.py b/compute_sdk/globus_compute_sdk/sdk/web_client.py index e0ba34d75..409aa1c31 100644 --- a/compute_sdk/globus_compute_sdk/sdk/web_client.py +++ b/compute_sdk/globus_compute_sdk/sdk/web_client.py @@ -167,7 +167,7 @@ def submit( def register_endpoint( self, endpoint_name: str, - endpoint_id: UUID_LIKE_T, + endpoint_id: t.Optional[UUID_LIKE_T] = None, *, metadata: t.Optional[dict] = None, multi_user: t.Optional[bool] = None, @@ -176,10 +176,7 @@ def register_endpoint( auth_policy: t.Optional[UUID_LIKE_T] = None, additional_fields: t.Optional[t.Dict[str, t.Any]] = None, ) -> globus_sdk.GlobusHTTPResponse: - data: t.Dict[str, t.Any] = { - "endpoint_name": endpoint_name, - "endpoint_uuid": str(endpoint_id), - } + data: t.Dict[str, t.Any] = {"endpoint_name": endpoint_name} # Only populate if not None. "" is valid and will be included # No value or a 'None' on an existing endpoint will leave @@ -200,7 +197,11 @@ def register_endpoint( data["authentication_policy"] = auth_policy if additional_fields is not None: data.update(additional_fields) - return self.post("/v2/endpoints", data=data) + + if endpoint_id: + return self.put(f"/v3/endpoints/{endpoint_id}", data=data) + else: + return self.post("/v3/endpoints", data=data) def get_result_amqp_url(self) -> globus_sdk.GlobusHTTPResponse: return self.get("/v2/get_amqp_result_connection_url") diff --git a/compute_sdk/tests/unit/test_web_client.py b/compute_sdk/tests/unit/test_web_client.py index fb738885a..bd2995bf3 100644 --- a/compute_sdk/tests/unit/test_web_client.py +++ b/compute_sdk/tests/unit/test_web_client.py @@ -4,6 +4,7 @@ import pytest import responses +from globus_compute_sdk.sdk.client import Client from globus_compute_sdk.sdk.web_client import WebClient from globus_compute_sdk.version import __version__ @@ -103,9 +104,21 @@ def test_get_amqp_url(client, randomstring): @pytest.mark.parametrize("multi_user", [None, True, False]) -def test_multi_user_post(client, multi_user): - responses.post(url="https://api.funcx/v2/endpoints") - resp = client.register_endpoint("ep_name", "ep_id", multi_user=multi_user) +def test_multi_user_post(client: Client, multi_user): + responses.post(url="https://api.funcx/v3/endpoints") + resp = client.register_endpoint("ep_name", None, multi_user=multi_user) + req_body = json.loads(resp._response.request.body) + if multi_user: + assert req_body["multi_user"] == multi_user + else: + assert "multi_user" not in req_body + + +@pytest.mark.parametrize("multi_user", [None, True, False]) +def test_multi_user_put(client: Client, multi_user): + ep_uuid = uuid.uuid4() + responses.put(url=f"https://api.funcx/v3/endpoints/{ep_uuid}") + resp = client.register_endpoint("ep_name", ep_uuid, multi_user=multi_user) req_body = json.loads(resp._response.request.body) if multi_user: assert req_body["multi_user"] == multi_user @@ -139,3 +152,30 @@ def test_delete_function(client: WebClient, randomstring: t.Callable): res = client.delete_function(func_uuid_str) assert res["some_key"] == expected_response + + +@pytest.mark.parametrize("ep_uuid", [uuid.uuid4(), None]) +def test_register_endpoint_post_put( + client: WebClient, randomstring: t.Callable, ep_uuid: t.Optional[uuid.UUID] +): + post_response = {"foo": randomstring()} + put_response = {"foo": randomstring()} + if ep_uuid: + responses.add( + responses.PUT, + f"https://api.funcx/v3/endpoints/{ep_uuid}", + json=put_response, + ) + else: + responses.add( + responses.POST, + "https://api.funcx/v3/endpoints", + json=post_response, + ) + + res = client.register_endpoint("MyEP", ep_uuid) + + if ep_uuid: + assert res.data == put_response + else: + assert res.data == post_response