Skip to content

Commit

Permalink
Rely on web service to create endpoint UUID
Browse files Browse the repository at this point in the history
  • Loading branch information
rjmello committed Nov 14, 2023
1 parent e783850 commit 766eddb
Show file tree
Hide file tree
Showing 8 changed files with 184 additions and 78 deletions.
22 changes: 5 additions & 17 deletions compute_endpoint/globus_compute_endpoint/endpoint/endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@
import subprocess
import sys
import typing as t
import uuid
from http import HTTPStatus

import daemon
Expand Down Expand Up @@ -302,15 +301,6 @@ def get_endpoint_dir_by_uuid(
return ep_path
return None

@staticmethod
def get_or_create_endpoint_uuid(
endpoint_dir: pathlib.Path, endpoint_uuid: str | None
) -> str:
ep_id = Endpoint.get_endpoint_id(endpoint_dir)
if not ep_id:
ep_id = endpoint_uuid or str(uuid.uuid4())
return ep_id

@staticmethod
def get_funcx_client(config: Config | None) -> Client:
if config:
Expand Down Expand Up @@ -399,15 +389,13 @@ def start_endpoint(
# place registration after everything else so that the endpoint will
# only be registered if everything else has been set up successfully
if not reg_info:
endpoint_uuid = Endpoint.get_or_create_endpoint_uuid(
endpoint_dir, endpoint_uuid
)
log.debug("Attempting registration; trying with eid: %s", endpoint_uuid)
endpoint_uuid = Endpoint.get_endpoint_id(endpoint_dir) or endpoint_uuid
log.debug("Attempting endpoint registration")
try:
fx_client = Endpoint.get_funcx_client(endpoint_config)
reg_info = fx_client.register_endpoint(
endpoint_dir.name,
endpoint_uuid,
name=endpoint_dir.name,
endpoint_id=endpoint_uuid,
metadata=Endpoint.get_metadata(endpoint_config),
multi_user=False,
display_name=endpoint_config.display_name,
Expand Down Expand Up @@ -449,7 +437,7 @@ def start_endpoint(
exit(os.EX_TEMPFAIL)

ret_ep_uuid = reg_info.get("endpoint_id")
if ret_ep_uuid != endpoint_uuid:
if endpoint_uuid and ret_ep_uuid != endpoint_uuid:
log.error(
"Unexpected response from server: mismatched endpoint id."
f"\n Expected: {endpoint_uuid}, received: {ret_ep_uuid}"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ def __init__(
maxsize=32768, ttl=config.mu_child_ep_grace_period_s
)

endpoint_uuid = Endpoint.get_or_create_endpoint_uuid(conf_dir, endpoint_uuid)
endpoint_uuid = Endpoint.get_endpoint_id(conf_dir) or endpoint_uuid

if not config.identity_mapping_config_path:
msg = (
Expand All @@ -133,8 +133,8 @@ def __init__(

gcc = GC.Client(**client_options)
reg_info = gcc.register_endpoint(
conf_dir.name,
endpoint_uuid,
name=conf_dir.name,
endpoint_id=endpoint_uuid,
metadata=EndpointManager.get_metadata(config, conf_dir),
multi_user=True,
)
Expand Down Expand Up @@ -165,14 +165,14 @@ def __init__(
exit(os.EX_TEMPFAIL)

upstream_ep_uuid = reg_info.get("endpoint_id")
if upstream_ep_uuid != endpoint_uuid:
if endpoint_uuid and upstream_ep_uuid != endpoint_uuid:
log.error(
"Unexpected response from server: mismatched endpoint id."
f"\n Expected: {endpoint_uuid}, received: {upstream_ep_uuid}"
)
exit(os.EX_SOFTWARE)

self._endpoint_uuid_str = upstream_ep_uuid
self._endpoint_uuid_str = str(upstream_ep_uuid)

try:
cq_info = reg_info["command_queue_info"]
Expand Down
2 changes: 1 addition & 1 deletion compute_endpoint/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def get_auth_client(self) -> globus_sdk.AuthClient:

def get_web_client(self, *, base_url: str | None = None) -> WebClient:
return WebClient(
base_url="https://compute.api.globus.org/v2/",
base_url="https://compute.api.globus.org",
authorizer=globus_sdk.NullAuthorizer(),
)

Expand Down
13 changes: 12 additions & 1 deletion compute_endpoint/tests/integration/endpoint/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,18 @@ def resp(endpoint_uuid: str):

responses.add(
method=responses.POST,
url="https://compute.api.globus.org/v2/endpoints",
url="https://compute.api.globus.org/v3/endpoints",
headers={"Content-Type": "application/json"},
json={
"endpoint_id": endpoint_uuid,
"task_queue_info": task_queue_info,
"result_queue_info": rq_info,
},
)

responses.add(
method=responses.PUT,
url=f"https://compute.api.globus.org/v3/endpoints/{endpoint_uuid}",
headers={"Content-Type": "application/json"},
json={
"endpoint_id": endpoint_uuid,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def test_non_configured_endpoint(mocker):
)
def test_start_endpoint_display_name(mocker, fs, display_name):
responses.add( # 404 == we are verifying the POST, not the response
responses.POST, _SVC_ADDY + "/v2/endpoints", json={}, status=404
responses.POST, _SVC_ADDY + "/v3/endpoints", json={}, status=404
)

ep = endpoint.Endpoint()
Expand All @@ -70,7 +70,7 @@ def test_start_endpoint_display_name(mocker, fs, display_name):
ep_conf.display_name = display_name

with pytest.raises(SystemExit) as pyt_exc:
ep.start_endpoint(ep_dir, str(uuid.uuid4()), ep_conf, False, True, reg_info={})
ep.start_endpoint(ep_dir, None, ep_conf, False, True, reg_info={})
assert int(str(pyt_exc.value)) == os.EX_UNAVAILABLE, "Verify exit due to test 404"

req = pyt_exc.value.__cause__._underlying_response.request
Expand All @@ -83,7 +83,7 @@ def test_start_endpoint_display_name(mocker, fs, display_name):

def test_start_endpoint_allowlist_passthrough(mocker, fs):
responses.add( # 404 == we are verifying the POST, not the response
responses.POST, _SVC_ADDY + "/v2/endpoints", json={}, status=404
responses.POST, _SVC_ADDY + "/v3/endpoints", json={}, status=404
)

ep = endpoint.Endpoint()
Expand All @@ -93,7 +93,7 @@ def test_start_endpoint_allowlist_passthrough(mocker, fs):
ep_conf.allowed_functions = [str(uuid.uuid4()), str(uuid.uuid4())]

with pytest.raises(SystemExit) as pyt_exc:
ep.start_endpoint(ep_dir, str(uuid.uuid4()), ep_conf, False, True, reg_info={})
ep.start_endpoint(ep_dir, None, ep_conf, False, True, reg_info={})
assert int(str(pyt_exc.value)) == os.EX_UNAVAILABLE, "Verify exit due to test 404"

req = pyt_exc.value.__cause__._underlying_response.request
Expand All @@ -104,7 +104,7 @@ def test_start_endpoint_allowlist_passthrough(mocker, fs):

def test_start_endpoint_auth_policy_passthrough(mocker, fs):
responses.add( # 404 == we are verifying the POST, not the response
responses.POST, _SVC_ADDY + "/v2/endpoints", json={}, status=404
responses.POST, _SVC_ADDY + "/v3/endpoints", json={}, status=404
)

ep_dir = pathlib.Path("/some/path/some_endpoint_name")
Expand All @@ -115,7 +115,7 @@ def test_start_endpoint_auth_policy_passthrough(mocker, fs):
ep_conf.authentication_policy = str(uuid.uuid4())

with pytest.raises(SystemExit) as pyt_exc:
ep.start_endpoint(ep_dir, str(uuid.uuid4()), ep_conf, False, True, reg_info={})
ep.start_endpoint(ep_dir, None, ep_conf, False, True, reg_info={})
assert int(str(pyt_exc.value)) == os.EX_UNAVAILABLE, "Verify exit due to test 404"

req = pyt_exc.value.__cause__._underlying_response.request
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import json
import logging
import os
import pathlib
Expand Down Expand Up @@ -444,32 +443,6 @@ def test_with_funcx_config(self, mocker):
**mock_optionals,
)

def test_get_or_create_endpoint_uuid_no_json_no_uuid(self, mocker):
mock_uuid = mocker.patch(f"{_MOCK_BASE}uuid.uuid4")
mock_uuid.return_value = 123456

config_dir = pathlib.Path("/some/path/mock_endpoint")
manager = Endpoint()
manager.configure_endpoint(config_dir, None)
assert "123456" == manager.get_or_create_endpoint_uuid(config_dir, None)

def test_get_or_create_endpoint_uuid_no_json_given_uuid(self):
config_dir = pathlib.Path("/some/path/mock_endpoint")
manager = Endpoint()
manager.configure_endpoint(config_dir, None)
assert "234567" == manager.get_or_create_endpoint_uuid(config_dir, "234567")

def test_get_or_create_endpoint_uuid_given_json(self):
config_dir = pathlib.Path("/some/path/mock_endpoint")
manager = Endpoint()
manager.configure_endpoint(config_dir, None)

mock_dict = {"endpoint_id": "abcde12345"}
with open(os.path.join(config_dir, "endpoint.json"), "w") as fd:
json.dump(mock_dict, fd)

assert "abcde12345" == manager.get_or_create_endpoint_uuid(config_dir, "234567")

@pytest.mark.parametrize("dir_exists", (True, False))
@pytest.mark.parametrize("web_svc_ok", (True, False))
@pytest.mark.parametrize("force", (True, False))
Expand Down
127 changes: 110 additions & 17 deletions compute_endpoint/tests/unit/test_endpoint_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,24 +53,33 @@ def create_response(
creds = ""
if username and password:
creds = f"{username}:{password}@"

res_body = {
"endpoint_id": endpoint_id,
"task_queue_info": {
"exchange_name": "tasks",
"connection_url": f"amqp://{creds}{rmq_fqdn}",
"args": queue_kwargs,
},
"result_queue_info": {
"exchange_name": "results",
"connection_url": f"amqp://{creds}{rmq_fqdn}",
"args": queue_kwargs,
"routing_key": f"{endpoint_uuid}.results",
},
}

responses.add(
method=responses.POST,
url="https://compute.api.globus.org/v2/endpoints",
url="https://compute.api.globus.org/v3/endpoints",
headers={"Content-Type": "application/json"},
json={
"endpoint_id": endpoint_id,
"task_queue_info": {
"exchange_name": "tasks",
"connection_url": f"amqp://{creds}{rmq_fqdn}",
"args": queue_kwargs,
},
"result_queue_info": {
"exchange_name": "results",
"connection_url": f"amqp://{creds}{rmq_fqdn}",
"args": queue_kwargs,
"routing_key": f"{endpoint_uuid}.results",
},
},
json=res_body,
)
responses.add(
method=responses.PUT,
url=f"https://compute.api.globus.org/v3/endpoints/{endpoint_id}",
headers={"Content-Type": "application/json"},
json=res_body,
)

return create_response
Expand All @@ -81,7 +90,14 @@ def register_endpoint_failure_response(endpoint_uuid):
def create_response(endpoint_id=endpoint_uuid, status_code=200, msg="Error Msg"):
responses.add(
method=responses.POST,
url="https://compute.api.globus.org/v2/endpoints",
url="https://compute.api.globus.org/v3/endpoints",
headers={"Content-Type": "application/json"},
json={"error": msg},
status=status_code,
)
responses.add(
method=responses.PUT,
url=f"https://compute.api.globus.org/v3/endpoints/{endpoint_id}",
headers={"Content-Type": "application/json"},
json={"error": msg},
status=status_code,
Expand Down Expand Up @@ -261,7 +277,13 @@ def test_register_endpoint_invalid_response(

ep, ep_dir, log_to_console, no_color, ep_conf = mock_ep_data

register_endpoint_response(endpoint_id=other_endpoint_id)
responses.add(
method=responses.PUT,
url=f"https://compute.api.globus.org/v3/endpoints/{endpoint_uuid}",
headers={"Content-Type": "application/json"},
json={"endpoint_id": other_endpoint_id},
)

with pytest.raises(SystemExit) as pytest_exc:
ep.start_endpoint(
ep_dir, endpoint_uuid, ep_conf, log_to_console, no_color, reg_info={}
Expand Down Expand Up @@ -768,3 +790,74 @@ def test_get_endpoint_dir_by_uuid(tmp_path, name, uuid, exists):
assert result is not None
else:
assert result is None


@pytest.mark.parametrize("json_exists", [True, False])
def test_get_endpoint_id(tmp_path: pathlib.Path, json_exists: bool):
ep_uuid_str = str(uuid.uuid4())
if json_exists:
ep_json = tmp_path / "endpoint.json"
ep_json.write_text(json.dumps({"endpoint_id": ep_uuid_str}))

ret = Endpoint.get_endpoint_id(endpoint_dir=tmp_path)

if json_exists:
assert ret == ep_uuid_str
else:
assert ret is None


def test_handles_provided_endpoint_id_no_json(
mocker: MockFixture,
mock_ep_data: tuple[Endpoint, pathlib.Path, bool, bool, Config],
mock_reg_info: dict,
):
ep, ep_dir, log_to_console, no_color, ep_conf = mock_ep_data
ep_uuid_str = str(uuid.uuid4())

mocker.patch(f"{_mock_base}daemon")
mocker.patch(f"{_mock_base}EndpointInterchange")

mock_gcc = mocker.Mock()
mock_gcc.register_endpoint.return_value = {
**mock_reg_info,
"endpoint_id": ep_uuid_str,
}
mocker.patch(f"{_mock_base}Endpoint.get_funcx_client").return_value = mock_gcc

ep.start_endpoint(
ep_dir, ep_uuid_str, ep_conf, log_to_console, no_color, reg_info={}
)

_a, k = mock_gcc.register_endpoint.call_args
assert k["endpoint_id"] == ep_uuid_str


def test_handles_provided_endpoint_id_with_json(
mocker: MockFixture,
mock_ep_data: tuple[Endpoint, pathlib.Path, bool, bool, Config],
mock_reg_info: dict,
):
ep, ep_dir, log_to_console, no_color, ep_conf = mock_ep_data
ep_uuid_str = str(uuid.uuid4())
provided_ep_uuid_str = str(uuid.uuid4())

ep_json = ep_dir / "endpoint.json"
ep_json.write_text(json.dumps({"endpoint_id": ep_uuid_str}))

mocker.patch(f"{_mock_base}daemon")
mocker.patch(f"{_mock_base}EndpointInterchange")

mock_gcc = mocker.Mock()
mock_gcc.register_endpoint.return_value = {
**mock_reg_info,
"endpoint_id": ep_uuid_str,
}
mocker.patch(f"{_mock_base}Endpoint.get_funcx_client").return_value = mock_gcc

ep.start_endpoint(
ep_dir, provided_ep_uuid_str, ep_conf, log_to_console, no_color, reg_info={}
)

_a, k = mock_gcc.register_endpoint.call_args
assert k["endpoint_id"] == ep_uuid_str
Loading

0 comments on commit 766eddb

Please sign in to comment.