Skip to content

Commit

Permalink
Add serializer allow lists to engine configs
Browse files Browse the repository at this point in the history
  • Loading branch information
chris-janidlo committed Jan 17, 2025
1 parent 389c33b commit 9113665
Show file tree
Hide file tree
Showing 8 changed files with 90 additions and 8 deletions.
18 changes: 18 additions & 0 deletions changelog.d/20250113_150119_chris_restrict_serializers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,21 @@ New Functionality
# The only allowed data serializer is JSONData.
#
# (Hint: reserialize the arguments with JSONData and try again.)
- Compute Endpoints can be configured to only deserialize and execute submissions that
  were serialized with specific serialization strategies. For example, with the
  following config:

  .. code-block:: yaml

     engine:
       allowed_code_serializers:
         - globus_compute_sdk.serialize.DillCodeSource
         - globus_compute_sdk.serialize.DillCodeTextInspect
       allowed_data_serializers:
         - globus_compute_sdk.serialize.JSONData
       type: ThreadPoolEngine

  any submissions that used the default serialization strategies (``DillCode``,
  ``DillDataBase64``) would be rejected, and users would be informed to use one of the
  allowed strategies.
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,8 @@ class EngineModel(BaseConfigModel):
worker_port_range: t.Optional[t.Tuple[int, int]]
interchange_port_range: t.Optional[t.Tuple[int, int]]
max_retries_on_system_failure: t.Optional[int]
allowed_code_serializers: t.Optional[t.List[str]]
allowed_data_serializers: t.Optional[t.List[str]]

_validate_type = _validate_import("type", engines)
_validate_provider = _validate_params("provider")
Expand Down
8 changes: 8 additions & 0 deletions compute_endpoint/globus_compute_endpoint/engines/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
get_error_string,
get_result_error_details,
)
from globus_compute_sdk.serialize.facade import ComputeSerializer, DeserializerAllowlist
from parsl.utils import RepresentationMixin

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -80,6 +81,8 @@ def __init__(
endpoint_id: uuid.UUID | None = None,
max_retries_on_system_failure: int = 0,
working_dir: str | os.PathLike = "tasks_working_dir",
allowed_code_serializers: DeserializerAllowlist | None = None,
allowed_data_serializers: DeserializerAllowlist | None = None,
**kwargs: object,
):
"""
Expand All @@ -106,6 +109,10 @@ def __init__(
"""
self._shutdown_event = threading.Event()
self.endpoint_id = endpoint_id
self.serde = ComputeSerializer(
allowed_code_deserializer_types=allowed_code_serializers,
allowed_data_deserializer_types=allowed_data_serializers,
)
self.max_retries_on_system_failure = max_retries_on_system_failure
self._retry_table: dict[str, dict] = {}
# remove these unused vars that we are adding to just keep
Expand Down Expand Up @@ -273,6 +280,7 @@ def submit(
self.endpoint_id,
run_dir=self.working_dir,
run_in_sandbox=self.run_in_sandbox,
serde=self.serde,
)
except Exception as e:
future = Future()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
GlobusComputeEngineBase,
ReportingThread,
)
from globus_compute_sdk.serialize.facade import DeserializerAllowlist
from parsl.executors.high_throughput.executor import HighThroughputExecutor
from parsl.jobs.job_status_poller import JobStatusPoller

Expand Down Expand Up @@ -51,6 +52,8 @@ def __init__(
strategy: str | None = None,
job_status_kwargs: t.Optional[JobStatusPollerKwargs] = None,
run_in_sandbox: bool = False,
allowed_code_serializers: DeserializerAllowlist | None = None,
allowed_data_serializers: DeserializerAllowlist | None = None,
**kwargs,
):
"""``GlobusComputeEngine`` is a shim over `Parsl's HighThroughputExecutor
Expand Down Expand Up @@ -108,7 +111,11 @@ def __init__(
self.label = label or type(self).__name__
self._status_report_thread = ReportingThread(target=self.report_status, args=[])
super().__init__(
*args, max_retries_on_system_failure=max_retries_on_system_failure, **kwargs
*args,
max_retries_on_system_failure=max_retries_on_system_failure,
allowed_code_serializers=allowed_code_serializers,
allowed_data_serializers=allowed_data_serializers,
**kwargs,
)
self.strategy = strategy

Expand Down
6 changes: 3 additions & 3 deletions compute_endpoint/globus_compute_endpoint/engines/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@

log = logging.getLogger(__name__)

_serde = ComputeSerializer()
_RESULT_SIZE_LIMIT = 10 * 1024 * 1024 # 10 MiB


Expand All @@ -33,6 +32,7 @@ def execute_task(
run_dir: str | os.PathLike,
result_size_limit: int = _RESULT_SIZE_LIMIT,
run_in_sandbox: bool = False,
serde: ComputeSerializer = ComputeSerializer(),
) -> bytes:
"""Execute task is designed to enable any executor to execute a Task payload
and return a Result payload, where the payload follows the globus-compute protocols
Expand Down Expand Up @@ -84,7 +84,7 @@ def execute_task(
try:
_task, task_buffer = _unpack_messagebody(task_body)
log.debug("executing task task_id='%s'", task_id)
result = _call_user_function(task_buffer)
result = _call_user_function(task_buffer, serde)

res_len = len(result)
if res_len > result_size_limit:
Expand Down Expand Up @@ -142,7 +142,7 @@ def _unpack_messagebody(message: bytes) -> tuple[Task, str]:
return task, task.task_buffer


def _call_user_function(task_buffer: str, serde: ComputeSerializer = _serde) -> str:
def _call_user_function(task_buffer: str, serde: ComputeSerializer) -> str:
"""Deserialize the buffer and execute the task.
Parameters
----------
Expand Down
17 changes: 15 additions & 2 deletions compute_endpoint/globus_compute_endpoint/engines/process_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,31 @@
GlobusComputeEngineBase,
ReportingThread,
)
from globus_compute_sdk.serialize.facade import DeserializerAllowlist

logger = logging.getLogger(__name__)


class ProcessPoolEngine(GlobusComputeEngineBase):
def __init__(self, *args, label: str = "ProcessPoolEngine", **kwargs):
def __init__(
self,
*args,
label: str = "ProcessPoolEngine",
allowed_code_serializers: DeserializerAllowlist | None = None,
allowed_data_serializers: DeserializerAllowlist | None = None,
**kwargs,
):
self.label = label
self.executor: t.Optional[NativeExecutor] = None
self._executor_args = args
self._executor_kwargs = kwargs
self._status_report_thread = ReportingThread(target=self.report_status, args=[])
super().__init__(*args, **kwargs)
super().__init__(
*args,
**kwargs,
allowed_code_serializers=allowed_code_serializers,
allowed_data_serializers=allowed_data_serializers,
)

def start(
self,
Expand Down
17 changes: 15 additions & 2 deletions compute_endpoint/globus_compute_endpoint/engines/thread_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,29 @@
GlobusComputeEngineBase,
ReportingThread,
)
from globus_compute_sdk.serialize.facade import DeserializerAllowlist

logger = logging.getLogger(__name__)


class ThreadPoolEngine(GlobusComputeEngineBase):
def __init__(self, *args, label: str = "ThreadPoolEngine", **kwargs):
def __init__(
self,
*args,
label: str = "ThreadPoolEngine",
allowed_code_serializers: DeserializerAllowlist | None = None,
allowed_data_serializers: DeserializerAllowlist | None = None,
**kwargs,
):
self.label = label
self.executor = NativeExecutor(*args, **kwargs)
self._status_report_thread = ReportingThread(target=self.report_status, args=[])
super().__init__(*args, **kwargs)
super().__init__(
*args,
**kwargs,
allowed_code_serializers=allowed_code_serializers,
allowed_data_serializers=allowed_data_serializers,
)

def start(
self,
Expand Down
21 changes: 21 additions & 0 deletions compute_endpoint/tests/unit/test_engines.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
ThreadPoolEngine,
)
from globus_compute_endpoint.engines.base import GlobusComputeEngineBase
from globus_compute_sdk.serialize.concretes import SELECTABLE_STRATEGIES
from parsl import HighThroughputExecutor
from parsl.executors.high_throughput.interchange import ManagerLost
from parsl.providers import KubernetesProvider
Expand Down Expand Up @@ -169,6 +170,26 @@ def test_engine_submit_internal(
break


@pytest.mark.parametrize(
    "engine_type",
    (ProcessPoolEngine, ThreadPoolEngine, GlobusComputeEngine),
)
def test_allowed_serializers_passthrough_to_serde(engine_type, engine_runner):
    """Check that serializer allow lists given to an engine reach its serde.

    Each selectable strategy is placed on the code or the data allow list
    according to its ``for_code`` flag. The engine built by ``engine_runner``
    must then expose a serializer whose allowed deserializer types equal
    those lists, compared as sets.
    """
    code_strategies = list(
        filter(lambda strategy: strategy.for_code, SELECTABLE_STRATEGIES)
    )
    data_strategies = list(
        filter(lambda strategy: not strategy.for_code, SELECTABLE_STRATEGIES)
    )

    engine = engine_runner(
        engine_type,
        allowed_code_serializers=code_strategies,
        allowed_data_serializers=data_strategies,
    )

    # The engine must always carry a serializer, configured or not.
    assert engine.serde is not None

    expected_code_types = set(code_strategies)
    expected_data_types = set(data_strategies)
    assert engine.serde.allowed_code_deserializer_types == expected_code_types
    assert engine.serde.allowed_data_deserializer_types == expected_data_types


def test_gc_engine_system_failure(ez_pack_task, task_uuid, engine_runner):
"""Test behavior of engine failure killing task"""
engine = engine_runner(GlobusComputeEngine, max_retries_on_system_failure=0)
Expand Down

0 comments on commit 9113665

Please sign in to comment.