diff --git a/compute_endpoint/globus_compute_endpoint/engines/high_throughput/engine.py b/compute_endpoint/globus_compute_endpoint/engines/high_throughput/engine.py index f78999475..2a5412e92 100644 --- a/compute_endpoint/globus_compute_endpoint/engines/high_throughput/engine.py +++ b/compute_endpoint/globus_compute_endpoint/engines/high_throughput/engine.py @@ -10,6 +10,7 @@ import multiprocessing import os import queue +import socket import threading import time import typing as t @@ -243,7 +244,7 @@ def __init__( # Tuning info prefetch_capacity=10, provider=LocalProvider(), - address="localhost", + address="127.0.0.1", worker_ports=None, worker_port_range=(54000, 55000), interchange_port_range=(55000, 56000), @@ -299,17 +300,14 @@ def __init__( self.endpoint_id = endpoint_id self._task_counter = 0 - try: - # if address != "localhost": - # 'localhost' works for both v4 and v6 in some circumstances - # where an actual numeric IP isn't required - ipaddress.ip_address(address=address) - except Exception: - log.critical( - f"Invalid address supplied: {address}. " - "Please use a valid IPv4 or IPv6 address" + if not HighThroughputEngine.is_hostname_or_ip(address): + err_msg = ( + f"Invalid address supplied: ({address}) " + "Please use a valid hostname or IPv4/IPv6 address" ) - raise + log.critical(err_msg) + raise ValueError(err_msg) + self.address = address self.worker_ports = worker_ports self.worker_port_range = worker_port_range @@ -422,6 +420,28 @@ def start( return self.outgoing_q.port, self.incoming_q.port, self.command_client.port + @staticmethod + def is_hostname_or_ip(hostname_or_ip: str) -> bool: + """ + Utility method to verify that the input is a valid hostname or + IP address. This is potentially useful elsewhere, if used, + should we move it to another module? 
+ """ + if not hostname_or_ip: + return False + else: + try: + socket.gethostbyname(hostname_or_ip) + return True + except socket.gaierror: + # Not a hostname, now check IP + pass + try: + ipaddress.ip_address(address=hostname_or_ip) + except ValueError: + return False + return True + def _start_local_interchange_process(self): """Starts the interchange process locally diff --git a/compute_endpoint/globus_compute_endpoint/engines/high_throughput/zmq_pipes.py b/compute_endpoint/globus_compute_endpoint/engines/high_throughput/zmq_pipes.py index 6557c7126..31b09c7ba 100644 --- a/compute_endpoint/globus_compute_endpoint/engines/high_throughput/zmq_pipes.py +++ b/compute_endpoint/globus_compute_endpoint/engines/high_throughput/zmq_pipes.py @@ -1,5 +1,8 @@ #!/usr/bin/env python3 +from __future__ import annotations + +import ipaddress import logging import time @@ -10,11 +13,22 @@ log = logging.getLogger(__name__) -def remap_ipv6_loopback(ip_address): - # Special case for being compatible with ipv4 and v6 - if ip_address == "::1": - ip_address = "localhost" - return f"tcp://{ip_address}" +def _zmq_canonicalize_address(addr: str | int) -> str: + try: + ip = ipaddress.ip_address(addr) + except ValueError: + # Not a valid IPv4 or IPv6 address + if isinstance(addr, int): + # If it was an integer, then it's just plain invalid + raise + + # Otherwise, it was likely a hostname; let another layer deal with it + return addr + + if ip.version == 4: + return str(ip) # like "12.34.56.78" + elif ip.version == 6: + return f"[{ip}]" # like "[::1]" class CommandClient: @@ -36,7 +50,7 @@ def __init__(self, ip_address, port_range): self.zmq_socket = self.context.socket(zmq.DEALER) self.zmq_socket.set_hwm(0) self.port = self.zmq_socket.bind_to_random_port( - remap_ipv6_loopback(ip_address), + f"tcp://{_zmq_canonicalize_address(ip_address)}", min_port=port_range[0], max_port=port_range[1], ) @@ -78,7 +92,7 @@ def __init__(self, ip_address, port_range): self.zmq_socket.set_hwm(0) self.port = 
# --- tests/unit/test_engines.py ---
@pytest.mark.parametrize(
    ("address", "is_valid"),
    (
        [None, False],
        ["", False],
        ["localhost.localhost", False],
        ["localhost", True],
        ["1.2.3.4.5", False],
        ["127.0.0.1", True],
        ["example.com", True],
        ["0:0:0:0:0:0:0:1", True],
        ["11111:0:0:0:0:0:0:1", False],
        ["::1", True],
        ["abc", False],
    ),
)
def test_hostname_or_ip_validation(address, is_valid):
    """Check HighThroughputEngine.is_hostname_or_ip against known inputs.

    NOTE(review): the validator calls ``socket.gethostbyname``, so the
    hostname cases (e.g. "example.com", "abc") presumably depend on the
    resolver of the machine running the test -- confirm this is acceptable
    for the unit suite.
    """
    result = HighThroughputEngine.is_hostname_or_ip(address)
    assert is_valid == result