
Commit: Vectorize make_vector
ricardoV94 committed Jul 8, 2024
1 parent 5fd729d commit 7f623fe
Showing 2 changed files with 61 additions and 0 deletions.
21 changes: 21 additions & 0 deletions pytensor/tensor/basic.py
@@ -1890,6 +1890,23 @@ def _get_vector_length_MakeVector(op, var):
    return len(var.owner.inputs)


@_vectorize_node.register
def vectorize_make_vector(op: MakeVector, node, *batch_inputs):
    # We vectorize make_vector by stacking the broadcasted batch inputs along a new last axis
    from pytensor.tensor.extra_ops import broadcast_arrays

    # Check if we need to broadcast at all
    bcast_pattern = batch_inputs[0].type.broadcastable
    if not all(
        batch_input.type.broadcastable == bcast_pattern for batch_input in batch_inputs
    ):
        batch_inputs = broadcast_arrays(*batch_inputs)

    # Stack along a new last axis
    new_out = stack(batch_inputs, axis=-1)
    return new_out.owner


def transfer(var, target):
"""
Return a version of `var` transferred to `target`.
@@ -2690,6 +2707,10 @@ def vectorize_join(op: Join, node, batch_axis, *batch_inputs):
    # We can vectorize join as a shifted axis on the batch inputs if:
    # 1. The batch axis is a constant and has not changed
    # 2. All inputs are batched with the same broadcastable pattern

    # TODO: We can relax the second condition by broadcasting the batch dimensions
    # This can be done with `broadcast_arrays` if the tensor shapes match at the join axis,
    # or otherwise by calling `broadcast_to` for each tensor that needs it
    if (
        original_axis.type.ndim == 0
        and isinstance(original_axis, Constant)
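A minimal usage sketch (editorial, not part of the commit) of what the new `vectorize_make_vector` dispatch does: when the scalar inputs of a `MakeVector` are replaced by batched tensors, the vectorized graph broadcasts them against each other and stacks them along a new trailing axis. The sketch assumes `vectorize_graph` from `pytensor.graph.replace`; the variable names are illustrative.

# Sketch only: assumes `vectorize_graph` is available in this PyTensor version.
import numpy as np
import pytensor.tensor as pt
from pytensor.graph.replace import vectorize_graph

# Core graph: a vector built from three scalars (a MakeVector under the hood)
x, y, z = pt.scalar("x"), pt.scalar("y"), pt.scalar("z")
vec = pt.stack([x, y, z])

# Replace each scalar with a batched counterpart of batch shape (5,)
bx, by, bz = pt.vector("bx"), pt.vector("by"), pt.vector("bz")
batched_vec = vectorize_graph(vec, replace={x: bx, y: by, z: bz})

# The batch dimension comes first, the stacked (core) dimension last
out = batched_vec.eval({bx: np.zeros(5), by: np.ones(5), bz: np.full(5, 2.0)})
print(out.shape)  # (5, 3)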
40 changes: 40 additions & 0 deletions tests/tensor/test_basic.py
@@ -4577,6 +4577,46 @@ def core_np(x):
    )


@pytest.mark.parametrize(
    "batch_shapes",
    [
        ((3,),),  # edge case of make_vector with a single input
        ((), (), ()),  # Useless case: no batch dimensions
        ((3,), (3,), (3,)),  # No broadcasting needed
        ((3,), (5, 3), ()),  # Broadcasting needed
    ],
)
def test_vectorize_make_vector(batch_shapes):
    n_inputs = len(batch_shapes)
    input_sig = ",".join(["()"] * n_inputs)
    signature = f"{input_sig}->({n_inputs})"  # Something like "(),(),()->(3)"

    def core_pt(*scalars):
        return stack(scalars)

    def core_np(*scalars):
        return np.stack(scalars)

    tensors = [tensor(shape=shape) for shape in batch_shapes]

    vectorize_pt = function(tensors, vectorize(core_pt, signature=signature)(*tensors))
    assert not any(
        isinstance(node.op, Blockwise) for node in vectorize_pt.maker.fgraph.apply_nodes
    )

    test_values = [
        np.random.normal(size=tensor.type.shape).astype(tensor.type.dtype)
        for tensor in tensors
    ]

    np.testing.assert_allclose(
        vectorize_pt(*test_values),
        np.vectorize(core_np, signature=signature)(*test_values),
    )


@pytest.mark.parametrize("axis", [constant(1), constant(-2), shared(1)])
@pytest.mark.parametrize("broadcasting_y", ["none", "implicit", "explicit"])
@config.change_flags(cxx="") # C code not needed
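As an editorial aside (not part of the commit), the semantics the test checks can be reproduced directly in NumPy: the batch inputs are broadcast against each other and then stacked along a new trailing axis, which is exactly what np.vectorize with signature "(),(),()->(3)" computes.

import numpy as np

a = np.random.normal(size=(3,))    # batch shape (3,)
b = np.random.normal(size=(5, 3))  # batch shape (5, 3)
c = np.random.normal(size=())      # scalar batch, broadcasts against the rest

expected = np.stack(np.broadcast_arrays(a, b, c), axis=-1)
via_vectorize = np.vectorize(lambda *s: np.stack(s), signature="(),(),()->(3)")(a, b, c)

assert expected.shape == (5, 3, 3)  # broadcast batch (5, 3) plus core dim of length 3
np.testing.assert_allclose(expected, via_vectorize)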
