
Commit

Vectorize make_vector
ricardoV94 committed Jul 5, 2024
1 parent 31bf682 commit 4a9077b
Showing 2 changed files with 53 additions and 0 deletions.
23 changes: 23 additions & 0 deletions pytensor/tensor/basic.py
@@ -1888,6 +1888,25 @@ def _get_vector_length_MakeVector(op, var):
    return len(var.owner.inputs)


@_vectorize_node.register
def vectorize_make_vector(op: MakeVector, node, *batch_inputs):
    # We vectorize make_vector as a join along the last axis of the broadcasted inputs
    from pytensor.tensor.extra_ops import broadcast_arrays

    # Check if we need to broadcast at all
    bcast_pattern = batch_inputs[0].type.broadcastable
    if not all(
        batch_input.type.broadcastable == bcast_pattern for batch_input in batch_inputs
    ):
        batch_inputs = broadcast_arrays(*batch_inputs)

    # Join along the last axis
    new_out = join(
        -1, *[expand_dims(batch_input, axis=-1) for batch_input in batch_inputs]
    )
    return new_out.owner


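The comment at the top of the new function describes the transformation: each batched input gains a length-1 trailing axis and the results are joined along it. A minimal NumPy sketch of the same equivalence (the batch size of 5 and the random inputs are illustrative only, not part of the commit):

import numpy as np

# Three scalar inputs to MakeVector, each batched to shape (5,).
a, b, c = (np.random.normal(size=(5,)) for _ in range(3))

# expand_dims(..., -1) turns each (5,) input into (5, 1); joining on the last
# axis yields a (5, 3) result: one length-3 vector per batch entry.
batched = np.concatenate([x[..., None] for x in (a, b, c)], axis=-1)
assert batched.shape == (5, 3)

# Row i matches what MakeVector would produce for the i-th batch entry.
np.testing.assert_allclose(batched[0], np.array([a[0], b[0], c[0]]))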
def transfer(var, target):
    """
    Return a version of `var` transferred to `target`.
@@ -2687,6 +2706,10 @@ def vectorize_join(op: Join, node, batch_axis, *batch_inputs):
    # We can vectorize join as a shifted axis on the batch inputs if:
    # 1. The batch axis is a constant and has not changed
    # 2. All inputs are batched with the same broadcastable pattern

    # TODO: We can relax the second condition by broadcasting the batch dimensions.
    # This can be done with `broadcast_arrays` if the tensors' shapes already match at the join axis,
    # or otherwise by calling `broadcast_to` for each tensor that needs it.
    if (
        original_axis.type.ndim == 0
        and isinstance(original_axis, Constant)
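The TODO in `vectorize_join` above suggests relaxing the matching-broadcastable-pattern condition by broadcasting the batch dimensions. A minimal NumPy sketch of that idea, with hypothetical shapes (not code from this commit):

import numpy as np

# x carries a batch dimension of 5; y does not (hypothetical shapes).
x = np.zeros((5, 2, 3))  # batch (5,), core shape (2, 3)
y = np.zeros((2, 4))     # core shape (2, 4), joined with x along core axis 1

# Broadcast y's batch dimensions to match x, then join on the shifted axis:
# core axis 1 becomes axis 2 once the batch dimension is prepended.
y_b = np.broadcast_to(y, (5, *y.shape))
out = np.concatenate([x, y_b], axis=2)
assert out.shape == (5, 2, 7)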
30 changes: 30 additions & 0 deletions tests/tensor/test_basic.py
@@ -4577,6 +4577,36 @@ def core_np(x):
    )


@pytest.mark.parametrize("requires_broadcasting", [False, True])
def test_vectorize_make_vector(requires_broadcasting):
    signature = "(),(),()->(3)"

    def core_pt(a, b, c):
        return ptb.stack([a, b, c])

    def core_np(a, b, c):
        return np.stack([a, b, c])

    a, b, c = (vector(shape=(3,)) for _ in range(3))
    if requires_broadcasting:
        b = matrix(shape=(5, 3))

    vectorize_pt = function([a, b, c], vectorize(core_pt, signature=signature)(a, b, c))
    assert not any(
        isinstance(node.op, Blockwise) for node in vectorize_pt.maker.fgraph.apply_nodes
    )

    a_test = np.random.normal(size=a.type.shape).astype(a.type.dtype)
    b_test = np.random.normal(size=b.type.shape).astype(b.type.dtype)
    c_test = np.random.normal(size=c.type.shape).astype(c.type.dtype)

    vectorize_np = np.vectorize(core_np, signature=signature)
    np.testing.assert_allclose(
        vectorize_pt(a_test, b_test, c_test),
        vectorize_np(a_test, b_test, c_test),
    )

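For reference, the behaviour this test exercises looks roughly as follows at the user level, assuming `pytensor.tensor.vectorize` is the `vectorize` helper the test imports (variable names, seed, and shapes are illustrative):

import numpy as np
import pytensor
import pytensor.tensor as pt

a = pt.vector("a", shape=(3,))
b = pt.matrix("b", shape=(5, 3))  # forces the broadcasting path
c = pt.vector("c", shape=(3,))

# Vectorizing a stack of three scalars should lower to expand_dims + join on
# the batched inputs rather than wrapping the core graph in a Blockwise op.
stack3 = pt.vectorize(lambda x, y, z: pt.stack([x, y, z]), signature="(),(),()->(3)")
fn = pytensor.function([a, b, c], stack3(a, b, c))

rng = np.random.default_rng(0)
a_val = rng.normal(size=(3,)).astype(a.type.dtype)
b_val = rng.normal(size=(5, 3)).astype(b.type.dtype)
c_val = rng.normal(size=(3,)).astype(c.type.dtype)
print(fn(a_val, b_val, c_val).shape)  # (5, 3, 3): batch shape (5, 3), core shape (3,)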

@pytest.mark.parametrize("axis", [constant(1), constant(-2), shared(1)])
@pytest.mark.parametrize("broadcasting_y", ["none", "implicit", "explicit"])
@config.change_flags(cxx="") # C code not needed

