Skip to content

Commit

Permalink
Add serializer allow lists to engine configs
Browse files Browse the repository at this point in the history
  • Loading branch information
chris-janidlo committed Jan 22, 2025
1 parent 2d8625f commit e39d2ab
Show file tree
Hide file tree
Showing 8 changed files with 76 additions and 8 deletions.
17 changes: 17 additions & 0 deletions changelog.d/20250113_150119_chris_restrict_serializers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -30,3 +30,20 @@ New Functionality
# The only allowed data serializer is JSONData.
#
# (Hint: reserialize the arguments with JSONData and try again.)
- Compute Endpoints can be configured to only deserialize and execute submissions that
were serialized with specific serialization strategies. For example, with the
following config:

.. code-block:: yaml
engine:
allowed_serializers:
- globus_compute_sdk.serialize.DillCodeSource
- globus_compute_sdk.serialize.DillCodeTextInspect
- globus_compute_sdk.serialize.JSONData
type: ThreadPoolEngine
any submissions that used the default serialization strategies (``DillCode``,
``DillDataBase64``) would be rejected, and users would be informed to use one of the
allowed strategies.
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ class EngineModel(BaseConfigModel):
worker_port_range: t.Optional[t.Tuple[int, int]]
interchange_port_range: t.Optional[t.Tuple[int, int]]
max_retries_on_system_failure: t.Optional[int]
allowed_serializers: t.Optional[t.List[str]]

_validate_type = _validate_import("type", engines)
_validate_provider = _validate_params("provider")
Expand Down
10 changes: 10 additions & 0 deletions compute_endpoint/globus_compute_endpoint/engines/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
get_error_string,
get_result_error_details,
)
from globus_compute_sdk.serialize.facade import ComputeSerializer, DeserializerAllowlist
from parsl.utils import RepresentationMixin

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -80,6 +81,7 @@ def __init__(
endpoint_id: uuid.UUID | None = None,
max_retries_on_system_failure: int = 0,
working_dir: str | os.PathLike = "tasks_working_dir",
allowed_serializers: DeserializerAllowlist | None = None,
**kwargs: object,
):
"""
Expand All @@ -102,10 +104,17 @@ def __init__(
to the endpoint.run_dir. If an absolute path is supplied, it is
used as is. default="tasks_working_dir"
allowed_serializers: DeserializerAllowlist | None
A list of serialization strategy types or import paths to such
types, which the engine's serializer will check against whenever
deserializing user submissions. If falsy, every serializer is
allowed. See ComputeSerializer for more details. default=None
kwargs
"""
self._shutdown_event = threading.Event()
self.endpoint_id = endpoint_id
self.serde = ComputeSerializer(allowed_deserializer_types=allowed_serializers)
self.max_retries_on_system_failure = max_retries_on_system_failure
self._retry_table: dict[str, dict] = {}
# remove these unused vars that we are adding to just keep
Expand Down Expand Up @@ -273,6 +282,7 @@ def submit(
self.endpoint_id,
run_dir=self.working_dir,
run_in_sandbox=self.run_in_sandbox,
serde=self.serde,
)
except Exception as e:
future = Future()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
GlobusComputeEngineBase,
ReportingThread,
)
from globus_compute_sdk.serialize.facade import DeserializerAllowlist
from parsl.executors.high_throughput.executor import HighThroughputExecutor
from parsl.jobs.job_status_poller import JobStatusPoller

Expand Down Expand Up @@ -51,6 +52,7 @@ def __init__(
strategy: str | None = None,
job_status_kwargs: t.Optional[JobStatusPollerKwargs] = None,
run_in_sandbox: bool = False,
allowed_serializers: DeserializerAllowlist | None = None,
**kwargs,
):
"""``GlobusComputeEngine`` is a shim over `Parsl's HighThroughputExecutor
Expand Down Expand Up @@ -108,7 +110,10 @@ def __init__(
self.label = label or type(self).__name__
self._status_report_thread = ReportingThread(target=self.report_status, args=[])
super().__init__(
*args, max_retries_on_system_failure=max_retries_on_system_failure, **kwargs
*args,
max_retries_on_system_failure=max_retries_on_system_failure,
allowed_serializers=allowed_serializers,
**kwargs,
)
self.strategy = strategy

Expand Down
7 changes: 4 additions & 3 deletions compute_endpoint/globus_compute_endpoint/engines/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@

log = logging.getLogger(__name__)

_serde = ComputeSerializer()
_RESULT_SIZE_LIMIT = 10 * 1024 * 1024 # 10 MiB


Expand All @@ -33,6 +32,7 @@ def execute_task(
run_dir: str | os.PathLike,
result_size_limit: int = _RESULT_SIZE_LIMIT,
run_in_sandbox: bool = False,
serde: ComputeSerializer = ComputeSerializer(),
) -> bytes:
"""Execute task is designed to enable any executor to execute a Task payload
and return a Result payload, where the payload follows the globus-compute protocols
Expand All @@ -45,6 +45,7 @@ def execute_task(
result_size_limit: result size in bytes
run_dir: directory to run function in
run_in_sandbox: if enabled run task under run_dir/<task_uuid>
serde: serializer for deserializing user submissions and serializing results
Returns
-------
Expand Down Expand Up @@ -84,7 +85,7 @@ def execute_task(
try:
_task, task_buffer = _unpack_messagebody(task_body)
log.debug("executing task task_id='%s'", task_id)
result = _call_user_function(task_buffer)
result = _call_user_function(task_buffer, serde)

res_len = len(result)
if res_len > result_size_limit:
Expand Down Expand Up @@ -142,7 +143,7 @@ def _unpack_messagebody(message: bytes) -> tuple[Task, str]:
return task, task.task_buffer


def _call_user_function(task_buffer: str, serde: ComputeSerializer = _serde) -> str:
def _call_user_function(task_buffer: str, serde: ComputeSerializer) -> str:
"""Deserialize the buffer and execute the task.
Parameters
----------
Expand Down
15 changes: 13 additions & 2 deletions compute_endpoint/globus_compute_endpoint/engines/process_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,29 @@
GlobusComputeEngineBase,
ReportingThread,
)
from globus_compute_sdk.serialize.facade import DeserializerAllowlist

logger = logging.getLogger(__name__)


class ProcessPoolEngine(GlobusComputeEngineBase):
def __init__(self, *args, label: str = "ProcessPoolEngine", **kwargs):
def __init__(
self,
*args,
label: str = "ProcessPoolEngine",
allowed_serializers: DeserializerAllowlist | None = None,
**kwargs,
):
self.label = label
self.executor: t.Optional[NativeExecutor] = None
self._executor_args = args
self._executor_kwargs = kwargs
self._status_report_thread = ReportingThread(target=self.report_status, args=[])
super().__init__(*args, **kwargs)
super().__init__(
*args,
**kwargs,
allowed_serializers=allowed_serializers,
)

def start(
self,
Expand Down
15 changes: 13 additions & 2 deletions compute_endpoint/globus_compute_endpoint/engines/thread_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,27 @@
GlobusComputeEngineBase,
ReportingThread,
)
from globus_compute_sdk.serialize.facade import DeserializerAllowlist

logger = logging.getLogger(__name__)


class ThreadPoolEngine(GlobusComputeEngineBase):
def __init__(self, *args, label: str = "ThreadPoolEngine", **kwargs):
def __init__(
self,
*args,
label: str = "ThreadPoolEngine",
allowed_serializers: DeserializerAllowlist | None = None,
**kwargs,
):
self.label = label
self.executor = NativeExecutor(*args, **kwargs)
self._status_report_thread = ReportingThread(target=self.report_status, args=[])
super().__init__(*args, **kwargs)
super().__init__(
*args,
**kwargs,
allowed_serializers=allowed_serializers,
)

def start(
self,
Expand Down
12 changes: 12 additions & 0 deletions compute_endpoint/tests/unit/test_engines.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
ThreadPoolEngine,
)
from globus_compute_endpoint.engines.base import GlobusComputeEngineBase
from globus_compute_sdk.serialize.concretes import SELECTABLE_STRATEGIES
from parsl import HighThroughputExecutor
from parsl.executors.high_throughput.interchange import ManagerLost
from parsl.providers import KubernetesProvider
Expand Down Expand Up @@ -169,6 +170,17 @@ def test_engine_submit_internal(
break


@pytest.mark.parametrize(
"engine_type",
(ProcessPoolEngine, ThreadPoolEngine, GlobusComputeEngine),
)
def test_allowed_serializers_passthrough_to_serde(engine_type, engine_runner):
engine = engine_runner(engine_type, allowed_serializers=SELECTABLE_STRATEGIES)

assert engine.serde is not None
assert engine.serde.allowed_deserializer_types == set(SELECTABLE_STRATEGIES)


def test_gc_engine_system_failure(ez_pack_task, task_uuid, engine_runner):
"""Test behavior of engine failure killing task"""
engine = engine_runner(GlobusComputeEngine, max_retries_on_system_failure=0)
Expand Down

0 comments on commit e39d2ab

Please sign in to comment.