diff --git a/requirements.txt b/requirements.txt
index 38d4414..3202e7c 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -13,3 +13,5 @@ lit
 pybind11
 pytest
 transformers
+fsspec
+einops
diff --git a/tests/jax/models/mlpmixer/__init__.py b/tests/jax/models/mlpmixer/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/jax/models/mlpmixer/model_implementation.py b/tests/jax/models/mlpmixer/model_implementation.py
new file mode 100644
index 0000000..03679e4
--- /dev/null
+++ b/tests/jax/models/mlpmixer/model_implementation.py
@@ -0,0 +1,82 @@
+# SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC
+#
+# SPDX-License-Identifier: Apache-2.0
+
+# This file incorporates work covered by the following copyright and permission
+# notice:
+# SPDX-FileCopyrightText: Copyright 2024 Google LLC.
+# SPDX-License-Identifier: Apache-2.0
+
+# This code is based on google-research/vision_transformer
+
+from typing import Any, Optional
+
+import einops
+import flax.linen as nn
+import jax.numpy as jnp
+import jax
+
+
+class MlpBlock(nn.Module):
+    mlp_dim: int
+
+    @nn.compact
+    def __call__(self, x: jax.Array) -> jax.Array:
+        y = nn.Dense(self.mlp_dim)(x)
+        y = nn.gelu(y)
+        return nn.Dense(x.shape[-1])(y)
+
+
+class MixerBlock(nn.Module):
+    """Mixer block layer."""
+
+    tokens_mlp_dim: int
+    channels_mlp_dim: int
+
+    @nn.compact
+    def __call__(self, x: jax.Array) -> jax.Array:
+        y = nn.LayerNorm()(x)
+        y = jnp.swapaxes(y, 1, 2)
+        y = MlpBlock(self.tokens_mlp_dim, name="token_mixing")(y)
+        y = jnp.swapaxes(y, 1, 2)
+        x = x + y
+
+        y = nn.LayerNorm()(x)
+        y = MlpBlock(self.channels_mlp_dim, name="channel_mixing")(y)
+        y = x + y
+
+        return y
+
+
+class MlpMixer(nn.Module):
+    """Mixer architecture."""
+
+    patches: Any
+    num_classes: int
+    num_blocks: int
+    hidden_dim: int
+    tokens_mlp_dim: int
+    channels_mlp_dim: int
+    model_name: Optional[str] = None
+
+    @nn.compact
+    def __call__(self, inputs: jax.Array) -> jax.Array:
+        x = nn.Conv(
+            self.hidden_dim, self.patches.size, strides=self.patches.size, name="stem"
+        )(
+            inputs
+        )  # Patch embedding
+        x = einops.rearrange(x, "n h w c -> n (h w) c")
+
+        for _ in range(self.num_blocks):
+            x = MixerBlock(self.tokens_mlp_dim, self.channels_mlp_dim)(x)
+
+        x = nn.LayerNorm(name="pre_head_layer_norm")(x)
+        x = jnp.mean(x, axis=1)
+
+        if self.num_classes:
+            x = nn.Dense(
+                self.num_classes, kernel_init=nn.initializers.zeros, name="head"
+            )(x)
+
+        return x
diff --git a/tests/jax/models/mlpmixer/test_mlpmixer.py b/tests/jax/models/mlpmixer/test_mlpmixer.py
new file mode 100644
index 0000000..fd8c1fa
--- /dev/null
+++ b/tests/jax/models/mlpmixer/test_mlpmixer.py
@@ -0,0 +1,104 @@
+# SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC
+#
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Any, Dict, Sequence
+
+import flax.traverse_util
+import fsspec
+import jax
+import jax.numpy as jnp
+import numpy
+import pytest
+from flax import linen as nn
+from infra import ModelTester, RunMode
+
+from .model_implementation import MlpMixer
+
+# Hyperparameters for Mixer-B/16
+patch_size = 16
+num_classes = 21843
+num_blocks = 12
+hidden_dim = 768
+token_mlp_dim = 384
+channel_mlp_dim = 3072
+
+
+class MlpMixerTester(ModelTester):
+    """Tester for MlpMixer model."""
+
+    # @override
+    def _get_model(self) -> nn.Module:
+        patch = jnp.ones((patch_size, patch_size))
+        return MlpMixer(
+            patches=patch,
+            num_classes=num_classes,
+            num_blocks=num_blocks,
+            hidden_dim=hidden_dim,
+            tokens_mlp_dim=token_mlp_dim,
+            channels_mlp_dim=channel_mlp_dim,
+        )
+
+    @staticmethod
+    def _retrieve_pretrained_weights() -> Dict:
+        # TODO(stefan): Discuss how weights should be handled org wide
+        link = "https://storage.googleapis.com/mixer_models/imagenet21k/Mixer-B_16.npz"
+        with fsspec.open("filecache::" + link, cache_storage="/tmp/files/") as f:
+            weights = numpy.load(f, encoding="bytes")
+            state_dict = {k: v for k, v in weights.items()}
+            pytree = flax.traverse_util.unflatten_dict(state_dict, sep="/")
+        return {"params": pytree}
+
+    # @override
+    def _get_forward_method_name(self) -> str:
+        return "apply"
+
+    # @override
+    def _get_input_activations(self) -> jax.Array:
+        key = jax.random.PRNGKey(42)
+        random_image = jax.random.normal(key, (1, 196, 196, 3))
+        return random_image
+
+    # @override
+    def _get_forward_method_args(self) -> Sequence[Any]:
+        ins = self._get_input_activations()
+        weights = self._retrieve_pretrained_weights()
+
+        # Required to bypass "Initializer expected to generate shape (16, 16, 3, 768) but got shape (256, 3, 768)"
+        kernel = weights["params"]["stem"]["kernel"]
+        kernel = kernel.reshape(-1, 3, hidden_dim)
+        weights["params"]["stem"]["kernel"] = kernel
+
+        # Alternatively, weights could be randomly initialized like this:
+        # weights = self._model.init(jax.random.PRNGKey(42), ins)
+
+        # JAX frameworks have a convention of passing weights as the first argument
+        return [weights, ins]
+
+
+# ----- Fixtures -----
+
+
+@pytest.fixture
+def inference_tester() -> MlpMixerTester:
+    return MlpMixerTester()
+
+
+@pytest.fixture
+def training_tester() -> MlpMixerTester:
+    return MlpMixerTester(RunMode.TRAINING)
+
+
+# ----- Tests -----
+
+
+@pytest.mark.skip(
+    reason="error: failed to legalize operation 'ttir.convolution' that was explicitly marked illegal"
+)
+def test_mlpmixer(inference_tester: MlpMixerTester):
+    inference_tester.test()
+
+
+@pytest.mark.skip(reason="Support for training not implemented")
+def test_mlpmixer_training(training_tester: MlpMixerTester):
+    training_tester.test()
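
For anyone who wants to exercise the model outside the ModelTester harness, below is a minimal standalone sketch of a forward pass. It is not part of the diff: it assumes model_implementation.py is importable from the working directory, assumes a Flax version that accepts an int kernel_size (which the stem Conv here relies on via patches.size), and uses randomly initialized weights instead of the pretrained Mixer-B/16 checkpoint, mirroring the hyperparameters and input shape hard-coded in test_mlpmixer.py.

import jax
import jax.numpy as jnp

from model_implementation import MlpMixer  # assumes the new file is on PYTHONPATH

# Mixer-B/16 configuration, copied from the constants in test_mlpmixer.py
model = MlpMixer(
    patches=jnp.ones((16, 16)),  # .size (256) is what the stem Conv actually consumes
    num_classes=21843,
    num_blocks=12,
    hidden_dim=768,
    tokens_mlp_dim=384,
    channels_mlp_dim=3072,
)

key = jax.random.PRNGKey(42)
image = jax.random.normal(key, (1, 196, 196, 3))  # same input shape the tester feeds

# Random initialization instead of the pretrained .npz checkpoint used by the tester
variables = model.init(key, image)
logits = model.apply(variables, image)
print(logits.shape)  # expected: (1, 21843)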