Skip to content

Commit

Permalink
changes
Browse files Browse the repository at this point in the history
  • Loading branch information
LeiGlobus committed Nov 22, 2024
1 parent 3200583 commit c2927de
Show file tree
Hide file tree
Showing 3 changed files with 74 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import multiprocessing
import os
import queue
import socket
import threading
import time
import typing as t
Expand Down Expand Up @@ -243,7 +244,7 @@ def __init__(
# Tuning info
prefetch_capacity=10,
provider=LocalProvider(),
address="localhost",
address="127.0.0.1",
worker_ports=None,
worker_port_range=(54000, 55000),
interchange_port_range=(55000, 56000),
Expand Down Expand Up @@ -299,17 +300,14 @@ def __init__(
self.endpoint_id = endpoint_id
self._task_counter = 0

try:
# if address != "localhost":
# 'localhost' works for both v4 and v6 in some circumstances
# where an actual numeric IP isn't required
ipaddress.ip_address(address=address)
except Exception:
log.critical(
f"Invalid address supplied: {address}. "
"Please use a valid IPv4 or IPv6 address"
if not HighThroughputEngine.is_hostname_or_ip(address):
err_msg = (
f"Invalid address supplied: ({address}) "
"Please use a valid hostname or IPv4/IPv6 address"
)
raise
log.critical(err_msg)
raise ValueError(err_msg)

self.address = address
self.worker_ports = worker_ports
self.worker_port_range = worker_port_range
Expand Down Expand Up @@ -422,6 +420,28 @@ def start(

return self.outgoing_q.port, self.incoming_q.port, self.command_client.port

@staticmethod
def is_hostname_or_ip(hostname_or_ip: str) -> bool:
"""
Utility method to verify that the input is a valid hostname or
IP address. This is potentially useful elsewhere, if used,
should we move it to another module?
"""
if not hostname_or_ip:
return False
else:
try:
socket.gethostbyname(hostname_or_ip)
return True
except socket.gaierror:
# Not a hostname, now check IP
pass
try:
ipaddress.ip_address(address=hostname_or_ip)
except ValueError:
return False
return True

def _start_local_interchange_process(self):
"""Starts the interchange process locally
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
#!/usr/bin/env python3

from __future__ import annotations

import ipaddress
import logging
import time

Expand All @@ -10,11 +13,22 @@
log = logging.getLogger(__name__)


def remap_ipv6_loopback(ip_address):
# Special case for being compatible with ipv4 and v6
if ip_address == "::1":
ip_address = "localhost"
return f"tcp://{ip_address}"
def _zmq_canonicalize_address(addr: str | int) -> str:
try:
ip = ipaddress.ip_address(addr)
except ValueError:
# Not a valid IPv4 or IPv6 address
if isinstance(addr, int):
# If it was an integer, then it's just plain invalid
raise

# Otherwise, it was likely a hostname; let another layer deal with it
return addr

if ip.version == 4:
return str(ip) # like "12.34.56.78"
elif ip.version == 6:
return f"[{ip}]" # like "[::1]"


class CommandClient:
Expand All @@ -36,7 +50,7 @@ def __init__(self, ip_address, port_range):
self.zmq_socket = self.context.socket(zmq.DEALER)
self.zmq_socket.set_hwm(0)
self.port = self.zmq_socket.bind_to_random_port(
remap_ipv6_loopback(ip_address),
f"tcp://{_zmq_canonicalize_address(ip_address)}",
min_port=port_range[0],
max_port=port_range[1],
)
Expand Down Expand Up @@ -78,7 +92,7 @@ def __init__(self, ip_address, port_range):
self.zmq_socket.set_hwm(0)

self.port = self.zmq_socket.bind_to_random_port(
remap_ipv6_loopback(ip_address),
f"tcp://{_zmq_canonicalize_address(ip_address)}",
min_port=port_range[0],
max_port=port_range[1],
)
Expand Down Expand Up @@ -153,7 +167,7 @@ def __init__(self, ip_address, port_range):
self.results_receiver = self.context.socket(zmq.DEALER)
self.results_receiver.set_hwm(0)
self.port = self.results_receiver.bind_to_random_port(
remap_ipv6_loopback(ip_address),
f"tcp://{_zmq_canonicalize_address(ip_address)}",
min_port=port_range[0],
max_port=port_range[1],
)
Expand Down
21 changes: 21 additions & 0 deletions compute_endpoint/tests/unit/test_engines.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,3 +395,24 @@ def test_gcmpiengine_accepts_resource_specification(task_uuid, randomstring):

a, _k = engine.executor.submit.call_args
assert spec in a


@pytest.mark.parametrize(
("input", "is_valid"),
(
[None, False],
["", False],
["localhost.localhost", False],
["localhost", True],
["1.2.3.4.5", False],
["127.0.0.1", True],
["example.com", True],
["0:0:0:0:0:0:0:1", True],
["11111:0:0:0:0:0:0:1", False],
["::1", True],
["abc", False],
),
)
def test_hostname_or_ip_validation(input, is_valid):
result = HighThroughputEngine.is_hostname_or_ip(input)
assert is_valid == result

0 comments on commit c2927de

Please sign in to comment.