Skip to content

Commit

Permalink
Replace big concatenation in op_broadcast_in_dim with MIL tiling oper…
Browse files Browse the repository at this point in the history
…ation (#18)
  • Loading branch information
kasper0406 authored Oct 30, 2024
1 parent d2135e0 commit 223cfec
Showing 1 changed file with 9 additions and 7 deletions.
16 changes: 9 additions & 7 deletions stablehlo_coreml/converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,22 +342,24 @@ def op_broadcast_in_dim(self, context: TranslationContext, op: BroadcastInDimOp)
x = context[op.operand.get_name()]

result_shape = op.result.type.shape
if result_shape == []:
if len(result_shape) == 0:
# Cast a scalar shape to a (1,) shape
result_shape = [1]
result_shape_rank = len(result_shape)

reshaped_operand_shape = [1] * len(result_shape)
reshaped_operand_shape = [1] * result_shape_rank
for i, op_shape in enumerate(op.operand.type.shape):
result_idx = op.broadcast_dimensions[i]
reshaped_operand_shape[result_idx] = op_shape

x = mb.reshape(x=x, shape=reshaped_operand_shape)

result_tiling = [1] * result_shape_rank
for result_dim, current_shape in enumerate(reshaped_operand_shape):
if current_shape != result_shape[result_dim]:
assert current_shape == 1
# Replicate data along dimension `dim` until the result dimension is filled up
values = [x] * result_shape[result_dim]
x = mb.concat(values=values, axis=result_dim)
# Replicate data along dimension `dim` until the result dimension matches
assert result_shape[result_dim] % current_shape == 0
result_tiling[result_dim] = result_shape[result_dim] // current_shape
x = mb.tile(x=x, reps=result_tiling)

context.add_result(op.result, x)

Expand Down

0 comments on commit 223cfec

Please sign in to comment.