Skip to content

Commit

Permalink
Add options to restrict deserializers in ComputeSerializer
Browse files Browse the repository at this point in the history
Also raise SerdeError instead of SerializationError /
DeserializationError when a failure happens before anything is actually
serialized or deserialized, and make the ComputeSerializer enforce that
each selected serialization strategy has the for_code value its role
requires (code strategies for code, data strategies for data).
  • Loading branch information
chris-janidlo committed Jan 17, 2025
1 parent a3505a0 commit 389c33b
Show file tree
Hide file tree
Showing 4 changed files with 287 additions and 19 deletions.
31 changes: 31 additions & 0 deletions changelog.d/20250113_150119_chris_restrict_serializers.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
New Functionality
^^^^^^^^^^^^^^^^^

- The ``ComputeSerializer`` can now be told to only deserialize payloads that were
  serialized with specific serialization strategies. For example:

  .. code-block:: python

    import os

    from globus_compute_sdk.serialize import ComputeSerializer, JSONData

    class MaliciousPayload():
        def __reduce__(self):
            # this method returns a 2-tuple (callable, arguments) that dill calls to reconstruct the object
            return os.system, ("<your favorite arbitrary code execution script>",)

    evil_serializer = ComputeSerializer()  # uses DillDataBase64 by default
    payload = evil_serializer.serialize(MaliciousPayload())

    safe_deserializer = ComputeSerializer(
        allowed_data_deserializer_types=[JSONData]
    )
    safe_deserializer.deserialize(payload)
    # globus_compute_sdk.errors.error_types.DeserializationError: Deserialization failed:
    #
    #   Data serializer DillDataBase64 disabled by current configuration. The
    #   current configuration requires the *arguments* to be serialized with one
    #   of the allowed classes:
    #
    #       Allowed serializers: JSONData
    #
    #   (Hint: reserialize the arguments with JSONData and try again.)
6 changes: 6 additions & 0 deletions compute_sdk/globus_compute_sdk/errors/error_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,12 @@ def __repr__(self):
class SerdeError(ComputeError):
    """Base class for SerializationError and DeserializationError

    Stores a human-readable ``reason`` describing the failure.
    """

    def __init__(self, reason: str) -> None:
        # Plain-text explanation of the failure; also returned by __repr__.
        self.reason = reason

    def __repr__(self) -> str:
        # The repr is the bare reason text, not a constructor-style string.
        return self.reason


class SerializationError(SerdeError):
"""Something failed during serialization."""
Expand Down
120 changes: 104 additions & 16 deletions compute_sdk/globus_compute_sdk/serialize/facade.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
from __future__ import annotations

import importlib
import logging
import textwrap
import typing as t

from globus_compute_sdk.errors import DeserializationError, SerializationError
from globus_compute_sdk.errors import (
DeserializationError,
SerdeError,
SerializationError,
)
from globus_compute_sdk.serialize.base import IDENTIFIER_LENGTH, SerializationStrategy
from globus_compute_sdk.serialize.concretes import (
DEFAULT_STRATEGY_CODE,
Expand All @@ -14,6 +20,61 @@

logger = logging.getLogger(__name__)

# Accepted shape of a deserializer allowlist: an iterable whose entries are
# either SerializationStrategy subclasses or strings. String entries are
# treated as dotted import paths ("package.module.ClassName") when validated.
DeserializerAllowlist = t.Iterable[t.Union[type[SerializationStrategy], str]]


def assert_strategy_type_valid(
    strategy_type: type[SerializationStrategy], for_code: bool
) -> None:
    """Ensure a strategy class is selectable and of the requested kind.

    :param strategy_type: the strategy class (not an instance) to check
    :param for_code: ``True`` when a code strategy is required, ``False`` when
        a data strategy is required
    :raises SerdeError: if the class is not in ``SELECTABLE_STRATEGIES``, or if
        its ``for_code`` attribute differs from ``for_code``
    """
    if strategy_type not in SELECTABLE_STRATEGIES:
        unknown_msg = (
            f"{strategy_type.__name__} is not a known serialization strategy"
            f" (must be one of {SELECTABLE_STRATEGIES})"
        )
        raise SerdeError(unknown_msg)

    # Known strategy of the right kind: nothing more to check.
    if strategy_type.for_code == for_code:
        return

    expected_kind = "code" if for_code else "data"
    actual_kind = "code" if strategy_type.for_code else "data"
    raise SerdeError(
        f"{strategy_type.__name__} is a {actual_kind} serialization strategy,"
        f" expected a {expected_kind} strategy"
    )


def validate_strategy(
    strategy: SerializationStrategy, for_code: bool
) -> SerializationStrategy:
    """Validate a strategy instance and hand it back unchanged.

    :param strategy: the strategy instance whose class is checked
    :param for_code: ``True`` if a code strategy is expected, ``False`` for data
    :returns: the same ``strategy`` object that was passed in
    :raises SerdeError: propagated from ``assert_strategy_type_valid`` when the
        strategy's class is unknown or of the wrong kind
    """
    strategy_cls = type(strategy)
    assert_strategy_type_valid(strategy_cls, for_code)
    return strategy


def validate_allowlist(
    unvalidated: DeserializerAllowlist, for_code: bool
) -> set[type[SerializationStrategy]]:
    """Resolve and validate every entry of a deserializer allowlist.

    Each entry is either a strategy class or a string.  Strings are split on
    their last ``.`` into a module path and an attribute name, the module is
    imported, and the attribute is looked up on it.

    :param unvalidated: iterable of strategy classes and/or dotted import paths
    :param for_code: ``True`` if the allowlist is for code strategies, ``False``
        if it is for data strategies
    :returns: the set of resolved strategy classes
    :raises SerdeError: if a string entry cannot be imported/resolved, if an
        entry is not a ``SerializationStrategy`` subclass, or if
        ``assert_strategy_type_valid`` rejects it
    """
    validated: set[type[SerializationStrategy]] = set()
    for value in unvalidated:
        if isinstance(value, str):
            try:
                mod_name, class_name = value.rsplit(".", 1)
                mod = importlib.import_module(mod_name)
                resolved_strategy_class = getattr(mod, class_name)
            except Exception as e:
                # Deliberately broad: a path without a dot, a missing module,
                # a missing attribute, or an error raised while importing all
                # mean the same thing to the caller.
                raise SerdeError(f"`{value}` is not a valid path to a strategy") from e
        else:
            resolved_strategy_class = value

        # issubclass() raises TypeError when its first argument is not a class
        # (e.g. a strategy *instance*, or a path resolving to a module or
        # function), so check for a class first to report a SerdeError instead.
        if not (
            isinstance(resolved_strategy_class, type)
            and issubclass(resolved_strategy_class, SerializationStrategy)
        ):
            raise SerdeError(
                "Allowed deserializers must either be SerializationStrategies"
                f" or valid paths to them (got {value})"
            )

        assert_strategy_type_valid(resolved_strategy_class, for_code)
        validated.add(resolved_strategy_class)

    return validated


class ComputeSerializer:
"""Provides uniform interface to underlying serialization strategies"""
Expand All @@ -22,23 +83,23 @@ def __init__(
self,
strategy_code: SerializationStrategy | None = None,
strategy_data: SerializationStrategy | None = None,
*,
allowed_code_deserializer_types: DeserializerAllowlist | None = None,
allowed_data_deserializer_types: DeserializerAllowlist | None = None,
):
"""Instantiate the appropriate classes"""

def validate(strategy: SerializationStrategy) -> SerializationStrategy:
if type(strategy) not in SELECTABLE_STRATEGIES:
raise SerializationError(
f"{strategy} is not a known serialization strategy "
f"(must be one of {SELECTABLE_STRATEGIES})"
)

return strategy

self.strategy_code = (
validate(strategy_code) if strategy_code else DEFAULT_STRATEGY_CODE
self.code_serializer = validate_strategy(
strategy_code or DEFAULT_STRATEGY_CODE, True
)
self.strategy_data = (
validate(strategy_data) if strategy_data else DEFAULT_STRATEGY_DATA
self.data_serializer = validate_strategy(
strategy_data or DEFAULT_STRATEGY_DATA, False
)
self.allowed_code_deserializer_types = validate_allowlist(
allowed_code_deserializer_types or [], True
)
self.allowed_data_deserializer_types = validate_allowlist(
allowed_data_deserializer_types or [], False
)

self.strategies = {
Expand All @@ -48,9 +109,9 @@ def validate(strategy: SerializationStrategy) -> SerializationStrategy:

def serialize(self, data):
if callable(data):
stype, strategy = "Code", self.strategy_code
stype, strategy = "Code", self.code_serializer
else:
stype, strategy = "Data", self.strategy_data
stype, strategy = "Data", self.data_serializer

try:
return strategy.serialize(data)
Expand All @@ -72,6 +133,8 @@ def deserialize(self, payload):
if not strategy:
raise DeserializationError(f"Invalid header: {header} in data payload")

self.assert_deserializer_allowed(strategy)

return strategy.deserialize(payload)

@staticmethod
Expand Down Expand Up @@ -147,3 +210,28 @@ def check_strategies(self, function: t.Callable, *args, **kwargs):
return self.unpack_and_deserialize(packed)
except Exception as e:
raise DeserializationError("check_strategies failed to deserialize") from e

def assert_deserializer_allowed(self, strategy: SerializationStrategy) -> None:
    """Reject a strategy that the configured allowlists do not permit.

    The code allowlist is consulted for strategies whose ``for_code`` is
    truthy, the data allowlist otherwise.  An empty allowlist places no
    restriction on deserialization.

    :param strategy: the strategy instance selected for a payload
    :raises DeserializationError: if the applicable allowlist is non-empty and
        does not contain the strategy's class
    """
    if strategy.for_code:
        permitted = self.allowed_code_deserializer_types
    else:
        permitted = self.allowed_data_deserializer_types

    if permitted and type(strategy) not in permitted:
        if strategy.for_code:
            kind_label, payload_type = "Code", "function"
        else:
            kind_label, payload_type = "Data", "arguments"

        allowed_names = ", ".join(sorted(cls.__name__ for cls in permitted))

        # note that there is (intentionally) no link to the documentation in
        # this error message - that's because the SDK appends its own hint to
        # any apparent serialization errors coming back from the endpoint. see
        # https://github.com/globus/globus-compute/blob/112dc3ae9d9986f36618976f8806f0bd48702460/compute_sdk/globus_compute_sdk/errors/error_types.py#L74 # noqa: E501
        reason = (
            f"{kind_label} serializer {type(strategy).__name__} disabled by current"
            f" configuration. The current configuration requires the *{payload_type}*"
            f" to be serialized with one of the allowed classes:\n\n"
            f" Allowed serializers: {allowed_names}"
        )

        raise DeserializationError(textwrap.indent(reason, " "))
149 changes: 146 additions & 3 deletions compute_sdk/tests/integration/test_serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,11 @@

import globus_compute_sdk.serialize.concretes as concretes
import pytest
from globus_compute_sdk.errors import SerializationError
from globus_compute_sdk.errors import (
DeserializationError,
SerdeError,
SerializationError,
)
from globus_compute_sdk.serialize.base import SerializationStrategy
from globus_compute_sdk.serialize.facade import ComputeSerializer

Expand Down Expand Up @@ -337,6 +341,18 @@ def test_selectable_serialization(strategy):
assert ser_data[:ID_LEN] == strategy.identifier


@pytest.mark.parametrize("strategy", concretes.SELECTABLE_STRATEGIES)
def test_selectable_serialization_enforces_for_code(strategy):
    """Using one strategy for both the code and data roles must be rejected."""
    with pytest.raises(SerdeError) as pyt_exc:
        ComputeSerializer(strategy_code=strategy(), strategy_data=strategy())

    if strategy.for_code:
        e = "is a code serialization strategy, expected a data strategy"
    else:
        e = "is a data serialization strategy, expected a code strategy"

    # Inspect the raised exception itself rather than the ExceptionInfo
    # wrapper, whose string form is a debugging repr of the wrapper.
    assert e in str(pyt_exc.value)


def test_serializer_errors_on_unknown_strategy():
class NewStrategy(SerializationStrategy):
identifier = "aa\n"
Expand All @@ -350,12 +366,12 @@ def deserialize(self, payload):

strategy = NewStrategy()

with pytest.raises(SerializationError):
with pytest.raises(SerdeError):
ComputeSerializer(strategy_code=strategy)

NewStrategy.for_code = False

with pytest.raises(SerializationError):
with pytest.raises(SerdeError):
ComputeSerializer(strategy_data=strategy)


Expand All @@ -382,3 +398,130 @@ def test_check_strategies(strategy_code, strategy_data, function, args, kwargs):
new_result = new_fn(*new_args, **new_kwargs)

assert original_result == new_result


@pytest.mark.parametrize(
    "disallowed_strategy", (s for s in concretes.SELECTABLE_STRATEGIES if s.for_code)
)
def test_allowed_deserializers_code(disallowed_strategy):
    """A code payload from a strategy outside the allowlist is refused."""
    allowlist = [
        strategy
        for strategy in concretes.SELECTABLE_STRATEGIES
        if strategy.for_code and strategy != disallowed_strategy
    ]

    assert allowlist, "expect to have at least one allowed deserializer"
    assert disallowed_strategy not in allowlist, "sanity check"

    serializer = ComputeSerializer(allowed_code_deserializer_types=allowlist)
    payload = disallowed_strategy().serialize(foo)

    with pytest.raises(DeserializationError) as pyt_exc:
        serializer.deserialize(payload)

    # Inspect the raised exception itself rather than the ExceptionInfo
    # wrapper, whose string form is a debugging repr of the wrapper.
    assert f"serializer {disallowed_strategy.__name__} disabled" in str(pyt_exc.value)


@pytest.mark.parametrize(
    "disallowed_strategy",
    (s for s in concretes.SELECTABLE_STRATEGIES if not s.for_code),
)
def test_allowed_deserializers_data(disallowed_strategy):
    """A data payload from a strategy outside the allowlist is refused."""
    allowlist = [
        strategy
        for strategy in concretes.SELECTABLE_STRATEGIES
        if not strategy.for_code and strategy != disallowed_strategy
    ]

    assert allowlist, "expect to have at least one allowed deserializer"
    assert disallowed_strategy not in allowlist, "sanity check"

    serializer = ComputeSerializer(allowed_data_deserializer_types=allowlist)
    payload = disallowed_strategy().serialize("foo")

    with pytest.raises(DeserializationError) as pyt_exc:
        serializer.deserialize(payload)

    # Inspect the raised exception itself rather than the ExceptionInfo
    # wrapper, whose string form is a debugging repr of the wrapper.
    assert f"serializer {disallowed_strategy.__name__} disabled" in str(pyt_exc.value)


@pytest.mark.parametrize(
    "allowlist",
    [
        ["globus_compute_sdk.serialize.concretes.DillCode"],
        [
            "globus_compute_sdk.serialize.concretes.DillCode",
            "globus_compute_sdk.serialize.concretes.DillCodeTextInspect",
            "globus_compute_sdk.serialize.concretes.DillCodeSource",
            "globus_compute_sdk.serialize.concretes.CombinedCode",
        ],
    ],
)
def test_allowed_serializers_imports_from_path(allowlist):
    """Dotted import paths in the code allowlist each resolve to one entry."""
    expected_count = len(allowlist)
    serializer = ComputeSerializer(allowed_code_deserializer_types=allowlist)
    resolved = serializer.allowed_code_deserializer_types
    assert len(resolved) == expected_count


@pytest.mark.parametrize(
    "allowlist",
    [
        ["my_malicious_package.my_malicious_serializer"],
        ["invalid_path_1"],
        ["invalid path 2"],
        [""],
        [
            "globus_compute_sdk.serialize.concretes.DillCode",
            "my_malicious_package.my_malicious_serializer",
        ],
        [
            "globus_compute_sdk.serialize.concretes.DillCode",
            "invalid path",
        ],
        [
            "globus_compute_sdk.serialize.concretes.DillCode",
            "",
        ],
    ],
)
def test_allowed_serializers_errors_on_invalid_import_path(allowlist):
    """Any unresolvable path in the allowlist makes construction fail."""
    with pytest.raises(SerdeError) as pyt_exc:
        ComputeSerializer(allowed_code_deserializer_types=allowlist)

    # Inspect the raised exception itself rather than the ExceptionInfo
    # wrapper, whose string form is a debugging repr of the wrapper.
    assert "is not a valid path to a strategy" in str(pyt_exc.value)


@pytest.mark.parametrize(
    "types",
    (
        [s for s in concretes.SELECTABLE_STRATEGIES if s.for_code],
        [s for s in concretes.SELECTABLE_STRATEGIES if not s.for_code],
    ),
)
def test_allowed_deserializers_enforces_for_code(types):
    """One allowlist cannot serve as both the code and the data allowlist.

    The check is run twice: once with the strategy classes themselves and once
    with their dotted import paths.
    """
    if types[0].for_code:
        kinds = ("code", "data")
    else:
        kinds = ("data", "code")
    exp_exc = "%s serialization strategy, expected a %s" % kinds

    as_paths = [f"{s.__module__}.{s.__qualname__}" for s in types]

    for allowed in (types, as_paths):
        with pytest.raises(SerdeError) as pyt_exc:
            ComputeSerializer(
                allowed_code_deserializer_types=allowed,
                allowed_data_deserializer_types=allowed,
            )
        raised_text = str(pyt_exc.value)
        assert exp_exc in raised_text


def test_allowed_deserializers_errors_on_unknown_strategy():
    """A strategy class defined locally in the test is refused in either allowlist."""

    class NewStrategy(SerializationStrategy):
        identifier = "aa\n"
        for_code = True

        def serialize(self, data):
            pass

        def deserialize(self, payload):
            pass

    # Try the class first as a code strategy, then flipped to a data strategy.
    attempts = (
        (True, "allowed_code_deserializer_types"),
        (False, "allowed_data_deserializer_types"),
    )
    for code_flag, allowlist_kwarg in attempts:
        NewStrategy.for_code = code_flag

        with pytest.raises(SerdeError):
            ComputeSerializer(**{allowlist_kwarg: [NewStrategy]})

0 comments on commit 389c33b

Please sign in to comment.