diff --git a/stablehlo_coreml/converter.py b/stablehlo_coreml/converter.py
index 54e02d0..8673983 100644
--- a/stablehlo_coreml/converter.py
+++ b/stablehlo_coreml/converter.py
@@ -64,7 +64,7 @@ def pop_function(self):
         self.variables.pop(self.path())
         self._path.pop()
 
-    def add_variable(self, name: str, mil_var):
+    def add_variable(self, name: str, mil_var: mil.Var):
         path = self.path()
         if path not in self.variables:
             self.variables[path] = {}
@@ -73,6 +73,22 @@ def add_variable(self, name: str, mil_var):
             raise ValueError(f"Variable {name} is already defined in path {path}")
         self.variables[path][name] = mil_var
 
+    def add_result(self, hlo_result, result: mil.Var):
+        result_name = hlo_result.get_name()
+        self.add_variable(result_name, result)
+
+        def validate_shapes(hlo_shape: tuple, mil_shape: tuple):
+            if hlo_shape == tuple() and (mil_shape == tuple() or mil_shape == (1, )):
+                return True
+            if hlo_shape == mil_shape:
+                return True
+
+            raise ValueError(f"The HLO result shape `{hlo_shape}` is different from the actual MIL result shape `{mil_shape}`")
+
+        hlo_shape = tuple(hlo_result.type.shape)
+        mil_shape = tuple(result.shape)
+        validate_shapes(hlo_shape=hlo_shape, mil_shape=mil_shape)
+
     def __getitem__(self, name: str):
         # Walk up along the path list to find the first correctly named variable in scope
         path = self._path.copy()
@@ -204,7 +220,7 @@ def op_call(self, context: TranscriptionContext, op: CallOp):
 
         # Configure return value
         for result, output in zip(op.results, outputs):
-            context.add_variable(result.get_name(), output)
+            context.add_result(result, output)
 
     @register_stablehlo_op
     def op_return(self, context: TranscriptionContext, op: ReturnOp):
@@ -222,35 +238,35 @@ def op_add(self, context: TranscriptionContext, op: AddOp):
         lhs = context[op.lhs.get_name()]
         rhs = context[op.rhs.get_name()]
         cml_op = mb.add(x=lhs, y=rhs)
-        context.add_variable(op.result.get_name(), cml_op)
+        context.add_result(op.result, cml_op)
 
     @register_stablehlo_op
     def op_or(self, context: TranscriptionContext, op: OrOp):
         lhs = context[op.lhs.get_name()]
         rhs = context[op.rhs.get_name()]
         cml_op = mb.logical_or(x=lhs, y=rhs)
-        context.add_variable(op.result.get_name(), cml_op)
+        context.add_result(op.result, cml_op)
 
     @register_stablehlo_op
     def op_and(self, context: TranscriptionContext, op: AndOp):
         lhs = context[op.lhs.get_name()]
         rhs = context[op.rhs.get_name()]
         cml_op = mb.logical_and(x=lhs, y=rhs)
-        context.add_variable(op.result.get_name(), cml_op)
+        context.add_result(op.result, cml_op)
 
     @register_stablehlo_op
     def op_subtract(self, context: TranscriptionContext, op: SubtractOp):
         lhs = context[op.lhs.get_name()]
         rhs = context[op.rhs.get_name()]
         cml_op = mb.sub(x=lhs, y=rhs)
-        context.add_variable(op.result.get_name(), cml_op)
+        context.add_result(op.result, cml_op)
 
     @register_stablehlo_op
     def op_mul(self, context: TranscriptionContext, op: MulOp):
         lhs = context[op.lhs.get_name()]
         rhs = context[op.rhs.get_name()]
         cml_op = mb.mul(x=lhs, y=rhs)
-        context.add_variable(op.result.get_name(), cml_op)
+        context.add_result(op.result, cml_op)
 
     @register_stablehlo_op
     def op_div(self, context: TranscriptionContext, op: DivOp):
@@ -272,7 +288,7 @@ def op_div(self, context: TranscriptionContext, op: DivOp):
         else:
             raise ValueError(f"Unknown dtype {lhs_type}")
 
-        context.add_variable(op.result.get_name(), cml_op)
+        context.add_result(op.result, cml_op)
 
     @register_stablehlo_op
     def op_neg(self, context: TranscriptionContext, op: NegOp):
@@ -280,25 +296,25 @@ def op_neg(self, context: TranscriptionContext, op: NegOp):
         operand = context[op.operand.get_name()]
         minus_one = np.array([-1], dtype=types.nptype_from_builtin(operand.dtype))
         cml_op = mb.mul(x=minus_one, y=operand)
-        context.add_variable(op.result.get_name(), cml_op)
+        context.add_result(op.result, cml_op)
 
     @register_stablehlo_op
     def op_sign(self, context: TranscriptionContext, op: SignOp):
         operand = context[op.operand.get_name()]
         cml_op = mb.sign(x=operand)
-        context.add_variable(op.result.get_name(), cml_op)
+        context.add_result(op.result, cml_op)
 
     @register_stablehlo_op
     def op_abs(self, context: TranscriptionContext, op: AbsOp):
         operand = context[op.operand.get_name()]
         cml_op = mb.abs(x=operand)
-        context.add_variable(op.result.get_name(), cml_op)
+        context.add_result(op.result, cml_op)
 
     @register_stablehlo_op
     def op_log(self, context: TranscriptionContext, op: LogOp):
         operand = context[op.operand.get_name()]
         cml_op = mb.log(x=operand)
-        context.add_variable(op.result.get_name(), cml_op)
+        context.add_result(op.result, cml_op)
 
     @register_stablehlo_op
     def op_log1p(self, context: TranscriptionContext, op: Log1pOp):
@@ -306,38 +322,38 @@ def op_log1p(self, context: TranscriptionContext, op: Log1pOp):
         one = np.array([1], dtype=types.nptype_from_builtin(self.__resolve_type(operand)))
         x_plus_one = mb.add(x=one, y=operand)
         cml_op = mb.log(x=x_plus_one)
-        context.add_variable(op.result.get_name(), cml_op)
+        context.add_result(op.result, cml_op)
 
     @register_stablehlo_op
     def op_exp(self, context: TranscriptionContext, op: ExpOp):
         operand = context[op.operand.get_name()]
         cml_op = mb.exp(x=operand)
-        context.add_variable(op.result.get_name(), cml_op)
+        context.add_result(op.result, cml_op)
 
     @register_stablehlo_op
     def op_expm1(self, context: TranscriptionContext, op: Expm1Op):
         operand = context[op.operand.get_name()]
         cml_op = mb.add(x=mb.exp(x=operand), y=-1.0)
-        context.add_variable(op.result.get_name(), cml_op)
+        context.add_result(op.result, cml_op)
 
     @register_stablehlo_op
     def op_transpose(self, context: TranscriptionContext, op: TransposeOp):
         operand = context[op.operand.get_name()]
         perm = np.array(op.permutation, dtype=np.int32)
         cml_op = mb.transpose(x=operand, perm=perm)
-        context.add_variable(op.result.get_name(), cml_op)
+        context.add_result(op.result, cml_op)
 
     @register_stablehlo_op
     def op_sqrt(self, context: TranscriptionContext, op: SqrtOp):
         operand = context[op.operand.get_name()]
         cml_op = mb.sqrt(x=operand)
-        context.add_variable(op.result.get_name(), cml_op)
+        context.add_result(op.result, cml_op)
 
     @register_stablehlo_op
     def op_constant(self, context: TranscriptionContext, op: ConstantOp):
         constant = np.array(op.value)
         constant = np.reshape(constant, op.result.type.shape)
-        context.add_variable(op.result.get_name(), constant)
+        context.add_result(op.result, constant)
 
     @register_stablehlo_op
     def op_dot_general(self, context: TranscriptionContext, op: DotGeneralOp):
@@ -413,7 +429,13 @@ def calculate_result_index(lhs_idx, rhs_idx, acc):
 
             # If we added a fake dimension, we will make sure to squeeze it away
             if len(lhs_result_dim) == 0 and len(rhs_result_dim) == 0:
-                idx_result = mb.squeeze(x=idx_result, axes=(-1, -2))
+                if len(idx_result.shape) == 2:
+                    assert idx_result.shape == (1, 1)
+                    # This is a special case, where the result is a scalar of shape (1, 1)
+                    # In order to not end up with a 0-rank tensor, we only contract one dimension
+                    idx_result = mb.squeeze(x=idx_result, axes=(-1, ))
+                else:
+                    idx_result = mb.squeeze(x=idx_result, axes=(-1, -2))
             elif len(lhs_result_dim) == 0:
                 idx_result = mb.squeeze(x=idx_result, axes=(-2,))
             elif len(rhs_result_dim) == 0:
@@ -442,14 +464,14 @@ def calculate_result_index(lhs_idx, rhs_idx, acc):
         # we will just loop through them sequentially.
         result, = iterate_indexes_in_shapes(calculate_result_index, [lhs_shape, rhs_shape], [result])
 
-        context.add_variable(op.result.get_name(), result)
+        context.add_result(op.result, result)
 
     @register_stablehlo_op
     def op_reshape(self, context: TranscriptionContext, op: ReshapeOp):
         x = context[op.operand.get_name()]
         new_shape = op.result.type.shape
         reshape_res = mb.reshape(x=x, shape=new_shape)
-        context.add_variable(op.result.get_name(), reshape_res)
+        context.add_result(op.result, reshape_res)
 
     @register_stablehlo_op
     def op_broadcast_in_dim(self, context: TranscriptionContext, op: BroadcastInDimOp):
@@ -473,7 +495,7 @@ def op_broadcast_in_dim(self, context: TranscriptionContext, op: BroadcastInDimO
                 values = [x] * result_shape[result_dim]
                 x = mb.concat(values=values, axis=result_dim)
 
-        context.add_variable(op.result.get_name(), x)
+        context.add_result(op.result, x)
 
     @register_stablehlo_op
     def op_while(self, context: TranscriptionContext, op: WhileOp):
@@ -494,7 +516,7 @@ def body(*body_args):
         while_results = mb.while_loop(_cond=cond, _body=body, loop_vars=loop_vars)
 
         for result_var, while_result in zip(op.results, while_results):
-            context.add_variable(result_var.get_name(), while_result)
+            context.add_result(result_var, while_result)
 
     @register_stablehlo_op
     def op_compare(self, context: TranscriptionContext, op: CompareOp):
@@ -511,14 +533,14 @@ def op_compare(self, context: TranscriptionContext, op: CompareOp):
         lhs = context[op.lhs.get_name()]
         rhs = context[op.rhs.get_name()]
         cml_op = cml_op_builder(x=lhs, y=rhs)
-        context.add_variable(op.result.get_name(), cml_op)
+        context.add_result(op.result, cml_op)
 
     @register_stablehlo_op
     def op_convert(self, context: TranscriptionContext, op: ConvertOp):
         x = context[op.operand.get_name()]
         new_dtype = self.__get_dtype(op.result.type.element_type)
         cml_op = mb.cast(x=x, dtype=self.__dtype_str(new_dtype))
-        context.add_variable(op.result.get_name(), cml_op)
+        context.add_result(op.result, cml_op)
 
     @register_stablehlo_op
     def op_select(self, context: TranscriptionContext, op: SelectOp):
@@ -526,7 +548,7 @@ def op_select(self, context: TranscriptionContext, op: SelectOp):
         a = context[op.on_true.get_name()]
         b = context[op.on_false.get_name()]
         cml_op = mb.select(cond=cond, a=a, b=b)
-        context.add_variable(op.result.get_name(), cml_op)
+        context.add_result(op.result, cml_op)
 
     @register_stablehlo_op
     def op_dynamic_slice(self, context: TranscriptionContext, op: DynamicSliceOp):
@@ -543,7 +565,7 @@ def op_dynamic_slice(self, context: TranscriptionContext, op: DynamicSliceOp):
 
         sizes = np.array(op.slice_sizes, dtype=np.int32)
         cml_op = mb.slice_by_size(x=x, begin=begin, size=sizes)
-        context.add_variable(op.result.get_name(), cml_op)
+        context.add_result(op.result, cml_op)
 
     @register_stablehlo_op
     def op_slice(self, context: TranscriptionContext, op: SliceOp):
@@ -559,7 +581,7 @@ def op_slice(self, context: TranscriptionContext, op: SliceOp):
             end=end,
             stride=stride,
         )
-        context.add_variable(op.result.get_name(), cml_op)
+        context.add_result(op.result, cml_op)
 
     @register_stablehlo_op
     def op_dynamic_update_slice(self, context: TranscriptionContext, op: DynamicUpdateSliceOp):
@@ -576,7 +598,7 @@ def op_dynamic_update_slice(self, context: TranscriptionContext, op: DynamicUpda
             begin=start_indices,
             end=end_indices,
         )
-        context.add_variable(op.result.get_name(), update_res)
+        context.add_result(op.result, update_res)
 
     @register_stablehlo_op
     def op_convolution(self, context: TranscriptionContext, op: ConvolutionOp):
@@ -693,51 +715,51 @@ def op_convolution(self, context: TranscriptionContext, op: ConvolutionOp):
         ])
         cml_conv = mb.transpose(x=cml_conv, perm=output_permutation)
 
-        context.add_variable(op.result.get_name(), cml_conv)
+        context.add_result(op.result, cml_conv)
 
     @register_stablehlo_op
     def op_max(self, context: TranscriptionContext, op: MaxOp):
         lhs = context[op.lhs.get_name()]
         rhs = context[op.rhs.get_name()]
         cml_res = mb.maximum(x=lhs, y=rhs)
-        context.add_variable(op.result.get_name(), cml_res)
+        context.add_result(op.result, cml_res)
 
     @register_stablehlo_op
     def op_min(self, context: TranscriptionContext, op: MinOp):
         lhs = context[op.lhs.get_name()]
         rhs = context[op.rhs.get_name()]
         cml_res = mb.minimum(x=lhs, y=rhs)
-        context.add_variable(op.result.get_name(), cml_res)
+        context.add_result(op.result, cml_res)
 
     @register_stablehlo_op
     def op_rsqrt(self, context: TranscriptionContext, op: RsqrtOp):
         x = context[op.operand.get_name()]
         mil_res = mb.rsqrt(x=x)
-        context.add_variable(op.result.get_name(), mil_res)
+        context.add_result(op.result, mil_res)
 
     @register_stablehlo_op
     def op_tanh(self, context: TranscriptionContext, op: TanhOp):
         x = context[op.operand.get_name()]
         mil_res = mb.tanh(x=x)
-        context.add_variable(op.result.get_name(), mil_res)
+        context.add_result(op.result, mil_res)
 
     @register_stablehlo_op
     def op_sine(self, context: TranscriptionContext, op: SineOp):
         x = context[op.operand.get_name()]
         mil_res = mb.sin(x=x)
-        context.add_variable(op.result.get_name(), mil_res)
+        context.add_result(op.result, mil_res)
 
     @register_stablehlo_op
     def op_cosine(self, context: TranscriptionContext, op: CosineOp):
         x = context[op.operand.get_name()]
         mil_res = mb.cos(x=x)
-        context.add_variable(op.result.get_name(), mil_res)
+        context.add_result(op.result, mil_res)
 
     @register_stablehlo_op
     def op_tan(self, context: TranscriptionContext, op: TanOp):
         x = context[op.operand.get_name()]
         mil_res = mb.tan(x=x)
-        context.add_variable(op.result.get_name(), mil_res)
+        context.add_result(op.result, mil_res)
 
     @register_stablehlo_op
     def op_atan2(self, context: TranscriptionContext, op: Atan2Op):
@@ -753,20 +775,20 @@ def op_atan2(self, context: TranscriptionContext, op: Atan2Op):
             a=atan2_res_adjusted,
             b=atan2_res,
         )
-        context.add_variable(op.result.get_name(), atan2_res)
+        context.add_result(op.result, atan2_res)
 
     @register_stablehlo_op
     def op_concatenate(self, context: TranscriptionContext, op: ConcatenateOp):
         values = [context[input.get_name()] for input in op.inputs]
         values = promote_input_dtypes(values)
         mil_res = mb.concat(values=values, axis=op.dimension.value)
-        context.add_variable(op.result.get_name(), mil_res)
+        context.add_result(op.result, mil_res)
 
     @register_stablehlo_op
     def op_reverse(self, context: TranscriptionContext, op: ReverseOp):
         x = context[op.operand.get_name()]
         mil_res = mb.reverse(x=x, axes=np.array(op.dimensions, dtype=np.int32))
-        context.add_variable(op.result.get_name(), mil_res)
+        context.add_result(op.result, mil_res)
 
     @register_stablehlo_op
     def op_isfinite(self, context: TranscriptionContext, op: IsFiniteOp):
@@ -774,7 +796,7 @@ def op_isfinite(self, context: TranscriptionContext, op: IsFiniteOp):
         # All finite numbers will have abs(x) < inf
         infinity = np.array(np.inf, dtype=types.nptype_from_builtin(self.__resolve_type(x)))
         mil_res = mb.less(x=mb.abs(x=x), y=infinity)
-        context.add_variable(op.result.get_name(), mil_res)
+        context.add_result(op.result, mil_res)
 
     @register_stablehlo_op
     def op_reduce(self, context: TranscriptionContext, op: ReduceOp):
@@ -789,7 +811,7 @@ def op_reduce(self, context: TranscriptionContext, op: ReduceOp):
         mil_results = self.__compute_reduction(context, inputs, op.dimensions, op.body, init_values, result_types)
 
         for (res, mil_res) in zip(op.results, mil_results):
-            context.add_variable(res.get_name(), mil_res)
+            context.add_result(res, mil_res)
 
     @register_stablehlo_op
     def op_reduce_window(self, context: TranscriptionContext, op: ReduceWindowOp):
@@ -893,7 +915,7 @@ def compute_reduction(result_idx, *partial_results):
         ]
 
         for (res, mil_res) in zip(op.results, reduction_results):
-            context.add_variable(res.get_name(), mil_res)
+            context.add_result(res, mil_res)
 
     @register_stablehlo_op
     def op_iota(self, context: TranscriptionContext, op: IotaOp):
@@ -902,7 +924,7 @@ def op_iota(self, context: TranscriptionContext, op: IotaOp):
         vec_shape = [tensor_shape[dim] if dim == iota_dim else 1 for dim in range(len(tensor_shape))]
         dtype = types.nptype_from_builtin(self.__get_dtype(op.result.type.element_type))
         res = np.reshape(np.arange(tensor_shape[iota_dim], dtype=dtype), vec_shape) * np.ones(tensor_shape, dtype=dtype)
-        context.add_variable(op.result.get_name(), res)
+        context.add_result(op.result, res)
 
     @register_stablehlo_op
     def op_custom_call(self, context: TranscriptionContext, op: CustomCallOp):
@@ -926,7 +948,7 @@ def op_custom_call(self, context: TranscriptionContext, op: CustomCallOp):
             # the `delegate_op` results according to the custom call results
            mil_results = op_impl(context, delegate_op)
             for (custom_call_result, mil_result) in zip(op.results, mil_results):
-                context.add_variable(custom_call_result.get_name(), mil_result)
+                context.add_result(custom_call_result, mil_result)
 
             return
 
@@ -948,7 +970,7 @@ def __invoke_hlo_function(self, context: TranscriptionContext, func_name: str, h
 
         # Setup arguments for the function
         for hlo_func_param, actual_arg in zip(hlo_params, cml_args):
-            context.add_variable(hlo_func_param.get_name(), actual_arg)
+            context.add_result(hlo_func_param, actual_arg)
 
         # Process the function
         if len(hlo_func_body.blocks) != 1:
diff --git a/tests/test_equinox.py b/tests/test_equinox.py
index bb3adc5..84c244f 100644
--- a/tests/test_equinox.py
+++ b/tests/test_equinox.py
@@ -5,9 +5,11 @@
 
 from tests.test_jax import run_and_compare
 
+from functools import partial
+
 
 def run_and_compare_eqx(model, input_spec):
-    return run_and_compare(eqxi.finalise_fn(model), input_spec)
+    return run_and_compare(eqxi.finalise_fn(eqx.nn.inference_mode(model)), input_spec)
 
 
 def test_conv_1d():
@@ -243,3 +245,127 @@ def test_3d_polling():
 
     # Due to the CoreML rank <= 5 condition, the result can unfortunately not fit in a tensor
     # run_and_compare_eqx(jax.vmap(eqx.nn.AvgPool3d(kernel_size=(5, 4, 3))), (jnp.zeros((10, channels, 41, 21, 10)), ))
+
+
+def test_layernorm():
+    batch_size = 3
+    input_shape = (10, 3)
+    run_and_compare_eqx(jax.vmap(eqx.nn.LayerNorm(shape=input_shape)), (jnp.zeros((batch_size, *input_shape)), ))
+
+
+def test_rmsnorm():
+    batch_size = 3
+    input_shape = (10, 3)
+    run_and_compare_eqx(jax.vmap(eqx.nn.RMSNorm(shape=input_shape)), (jnp.zeros((batch_size, *input_shape)), ))
+
+
+def test_groupnorm():
+    batch_size = 3
+    input_shape = (4, 12)
+    run_and_compare_eqx(
+        jax.vmap(eqx.nn.GroupNorm(groups=4, channelwise_affine=False)),
+        (jnp.zeros((batch_size, *input_shape)), )
+    )
+    run_and_compare_eqx(
+        jax.vmap(eqx.nn.GroupNorm(groups=2, channels=4)),
+        (jnp.zeros((batch_size, *input_shape)), )
+    )
+
+
+# Unfortunately this test currently fails due to https://github.com/llvm/llvm-project/pull/113064
+# def test_batchnorm():
+#     batch_size = 3
+#     input_shape = (4, 12)
+
+#     class Model(eqx.Module):
+#         batch_norm: eqx.nn.BatchNorm
+
+#         def __init__(self, wrapping_layer: eqx.Module, key: jax.random.PRNGKey):
+#             self.batch_norm = eqx.nn.BatchNorm(input_size=4, axis_name="batch")
+
+#         def __call__(self, x, state):
+#             out, _state = self.batch_norm(x, state)
+#             return out
+
+#     model, state = eqx.nn.make_with_state(eqx.nn.BatchNorm)(input_size=4, axis_name="batch")
+#     batched_model = jax.vmap(partial(model, state=state), axis_name="batch")
+#     run_and_compare_eqx(batched_model, (jnp.zeros((batch_size, *input_shape)), ))
+
+
+def test_spectralnorm():
+    batch_size = 5
+
+    class Model(eqx.Module):
+        spectral_norm: eqx.nn.SpectralNorm[eqx.Module]
+
+        def __init__(self, wrapping_layer: eqx.Module, key: jax.random.PRNGKey):
+            self.spectral_norm = eqx.nn.SpectralNorm(
+                layer=wrapping_layer,
+                weight_name="weight",
+                key=key,
+            )
+
+        def __call__(self, x, state):
+            out, _state = self.spectral_norm(x, state)
+            return out
+
+    wrapping_key, model_key = jax.random.split(jax.random.PRNGKey(0), 2)
+
+    # Linear wrapping layer
+    model, state = eqx.nn.make_with_state(Model)(
+        wrapping_layer=eqx.nn.Linear(in_features=12, out_features=24, key=wrapping_key),
+        key=model_key,
+    )
+    batched_model = jax.vmap(partial(model, state=state))
+    run_and_compare_eqx(batched_model, (jnp.zeros((batch_size, 12)), ))
+
+    # Convolutional 1d wrapping layer
+    model, state = eqx.nn.make_with_state(Model)(
+        wrapping_layer=eqx.nn.Conv1d(in_channels=12, out_channels=24, kernel_size=3, key=wrapping_key),
+        key=model_key,
+    )
+    batched_model = jax.vmap(partial(model, state=state))
+    run_and_compare_eqx(batched_model, (jnp.zeros((batch_size, 12, 31)), ))
+
+    # Convolutional 2d wrapping layer
+    model, state = eqx.nn.make_with_state(Model)(
+        wrapping_layer=eqx.nn.Conv2d(in_channels=12, out_channels=24, kernel_size=3, key=wrapping_key),
+        key=model_key,
+    )
+    batched_model = jax.vmap(partial(model, state=state))
+    run_and_compare_eqx(batched_model, (jnp.zeros((batch_size, 12, 31, 15)), ))
+
+
+def test_weightnorm():
+    batch_size = 5
+    key = jax.random.PRNGKey(0)
+
+    class Model(eqx.Module):
+        weight_norm: eqx.nn.WeightNorm[eqx.Module]
+
+        def __init__(self, wrapping_layer: eqx.Module):
+            self.weight_norm = eqx.nn.WeightNorm(
+                layer=wrapping_layer,
+                weight_name="weight",
+            )
+
+        def __call__(self, x):
+            return self.weight_norm(x)
+
+    # Linear wrapping layer
+    model = jax.vmap(Model(
+        wrapping_layer=eqx.nn.Linear(in_features=12, out_features=24, key=key),
+    ))
+    run_and_compare_eqx(model, (jnp.zeros((batch_size, 12)), ))
+
+    # Convolutional 1d wrapping layer
+    model = jax.vmap(Model(
+        wrapping_layer=eqx.nn.Conv1d(in_channels=12, out_channels=24, kernel_size=3, key=key),
+    ))
+    run_and_compare_eqx(model, (jnp.zeros((batch_size, 12, 31)), ))
+
+    # Convolutional 2d wrapping layer
+    model = jax.vmap(Model(
+        wrapping_layer=eqx.nn.Conv2d(in_channels=12, out_channels=24, kernel_size=3, key=key),
+    ))
+    run_and_compare_eqx(model, (jnp.zeros((batch_size, 12, 31, 15)), ))