From 236e2ed7b3f6bf760c98326719dc5825ac20c02a Mon Sep 17 00:00:00 2001
From: sgligorijevicTT <189116645+sgligorijevicTT@users.noreply.github.com>
Date: Wed, 8 Jan 2025 11:47:46 +0000
Subject: [PATCH 1/4] Add MLPMixer test

---
 requirements.txt                            |   2 +
 tests/jax/models/mlpmixer/__init__.py       |   0
 .../models/mlpmixer/model_implementation.py |  81 ++++++++++++++
 tests/jax/models/mlpmixer/test_mlpmixer.py  | 101 ++++++++++++++++++
 tests/jax/models/mlpmixer/util.py           |  32 ++++++
 5 files changed, 216 insertions(+)
 create mode 100644 tests/jax/models/mlpmixer/__init__.py
 create mode 100644 tests/jax/models/mlpmixer/model_implementation.py
 create mode 100644 tests/jax/models/mlpmixer/test_mlpmixer.py
 create mode 100644 tests/jax/models/mlpmixer/util.py

diff --git a/requirements.txt b/requirements.txt
index 38d4414..3202e7c 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -13,3 +13,5 @@ lit
 pybind11
 pytest
 transformers
+fsspec
+einops
diff --git a/tests/jax/models/mlpmixer/__init__.py b/tests/jax/models/mlpmixer/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/tests/jax/models/mlpmixer/model_implementation.py b/tests/jax/models/mlpmixer/model_implementation.py
new file mode 100644
index 0000000..f41fc62
--- /dev/null
+++ b/tests/jax/models/mlpmixer/model_implementation.py
@@ -0,0 +1,81 @@
+# SPDX-FileCopyrightText: (c) 2024 Google LLC
+#
+# SPDX-License-Identifier: Apache-2.0
+
+# Taken from https://github.com/google-research/vision_transformer/blob/c6de1e5378c9831a8477feb30994971bdc409e46/vit_jax/models_mixer.py
+
+# Copyright 2024 Google LLC.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import Any, Optional
+
+import einops
+import flax.linen as nn
+import jax.numpy as jnp
+
+
+class MlpBlock(nn.Module):
+    mlp_dim: int
+
+    @nn.compact
+    def __call__(self, x):
+        y = nn.Dense(self.mlp_dim)(x)
+        y = nn.gelu(y)
+        return nn.Dense(x.shape[-1])(y)
+
+
+class MixerBlock(nn.Module):
+    """Mixer block layer."""
+
+    tokens_mlp_dim: int
+    channels_mlp_dim: int
+
+    @nn.compact
+    def __call__(self, x):
+        y = nn.LayerNorm()(x)
+        y = jnp.swapaxes(y, 1, 2)
+        y = MlpBlock(self.tokens_mlp_dim, name="token_mixing")(y)
+        y = jnp.swapaxes(y, 1, 2)
+        x = x + y
+        y = nn.LayerNorm()(x)
+        return x + MlpBlock(self.channels_mlp_dim, name="channel_mixing")(y)
+
+
+class MlpMixer(nn.Module):
+    """Mixer architecture."""
+
+    patches: Any
+    num_classes: int
+    num_blocks: int
+    hidden_dim: int
+    tokens_mlp_dim: int
+    channels_mlp_dim: int
+    model_name: Optional[str] = None
+
+    @nn.compact
+    def __call__(self, inputs, train):
+        del train
+        x = nn.Conv(
+            self.hidden_dim, self.patches.size, strides=self.patches.size, name="stem"
+        )(inputs)
+        x = einops.rearrange(x, "n h w c -> n (h w) c")
+        for _ in range(self.num_blocks):
+            x = MixerBlock(self.tokens_mlp_dim, self.channels_mlp_dim)(x)
+        x = nn.LayerNorm(name="pre_head_layer_norm")(x)
+        x = jnp.mean(x, axis=1)
+        if self.num_classes:
+            x = nn.Dense(
+                self.num_classes, kernel_init=nn.initializers.zeros, name="head"
+            )(x)
+        return x
diff --git a/tests/jax/models/mlpmixer/test_mlpmixer.py b/tests/jax/models/mlpmixer/test_mlpmixer.py
new file mode 100644
index 0000000..79406b7
--- /dev/null
+++ b/tests/jax/models/mlpmixer/test_mlpmixer.py
@@ -0,0 +1,101 @@
+# SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC
+#
+# SPDX-License-Identifier: Apache-2.0
+from typing import Dict, Sequence
+
+
+import jax
+import jax.numpy as jnp
+import numpy
+import pytest
+import fsspec
+from flax import linen as nn
+from infra import ModelTester, RunMode
+from .model_implementation import MlpMixer
+from .util import build_pytree_from_npy
+
+
+# hypers
+patch_size = 16
+num_classes = 21843
+num_blocks = 12
+hidden_dim = 768
+token_mlp_dim = 384
+channel_mlp_dim = 3072
+
+
+def Mixer_B_16_pretrained():
+    # TODO(stefan): Discuss how weights should be handled org wide
+    link = "https://storage.googleapis.com/mixer_models/imagenet21k/Mixer-B_16.npz"
+    with fsspec.open("filecache::" + link, cache_storage="/tmp/files/") as f:
+        weights = numpy.load(f, encoding="bytes")
+        pytree = build_pytree_from_npy(weights)
+    return pytree
+
+
+class MlpMixerTester(ModelTester):
+    """Tester for MlpMixer model."""
+
+    # @override
+    def _get_model(self) -> nn.Module:
+        patch = jnp.ones((patch_size, patch_size))
+        return MlpMixer(
+            patches=patch,
+            num_classes=num_classes,
+            num_blocks=num_blocks,
+            hidden_dim=hidden_dim,
+            tokens_mlp_dim=token_mlp_dim,
+            channels_mlp_dim=channel_mlp_dim,
+        )
+
+    # @override
+    def _get_forward_method_name(self) -> str:
+        return "apply"
+
+    # @override
+    def _get_input_activations(self) -> Sequence[jax.Array]:
+        key = jax.random.PRNGKey(42)
+        random_image = jax.random.normal(key, (1, 196, 196, 3))
+        return random_image
+
+    # @override
+    def _get_forward_method_args(self):
+        ins = self._get_input_activations()
+        weights = Mixer_B_16_pretrained()
+        # Required to bypass "Initializer expected to generate shape (16, 16, 3, 768) but got shape (256, 3, 768)"
+        kernel = weights["params"]["stem"]["kernel"]
+        kernel = kernel.reshape(-1, 3, hidden_dim)
+        weights["params"]["stem"]["kernel"] = kernel
+
+        # Alternatively, weights could be randomly initialized like this:
+        # weights = self._model.init(jax.random.PRNGKey(42), ins, train=False)
+
+        return [weights, ins]
+
+    # @override
+    def _get_forward_method_kwargs(self) -> Dict[str, jax.Array]:
+        return {"train": False}
+
+
+# ----- Fixtures -----
+@pytest.fixture
+def inference_tester() -> MlpMixerTester:
+    return MlpMixerTester()
+
+
+@pytest.fixture
+def training_tester() -> MlpMixerTester:
+    return MlpMixerTester(RunMode.TRAINING)
+
+
+# ----- Tests -----
+@pytest.mark.skip(
+    reason="error: failed to legalize operation 'ttir.convolution' that was explicitly marked illegal"
+)
+def test_mlpmixer(inference_tester: MlpMixerTester):
+    inference_tester.test()
+
+
+@pytest.mark.skip(reason="Support for training not implemented")
+def test_mlpmixer_training(training_tester: MlpMixerTester):
+    training_tester.test()
diff --git a/tests/jax/models/mlpmixer/util.py b/tests/jax/models/mlpmixer/util.py
new file mode 100644
index 0000000..62f79a6
--- /dev/null
+++ b/tests/jax/models/mlpmixer/util.py
@@ -0,0 +1,32 @@
+# SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC
+#
+# SPDX-License-Identifier: Apache-2.0
+from collections import defaultdict
+from jax import numpy as jnp
+
+
+def ddict():
+    return defaultdict(ddict)
+
+
+def defaultdict_to_dict(d):
+    """Recursively convert defaultdicts to dicts."""
+    if isinstance(d, defaultdict):
+        d = {k: defaultdict_to_dict(v) for k, v in d.items()}
+    elif isinstance(d, dict):
+        d = {k: defaultdict_to_dict(v) for k, v in d.items()}
+    return d
+
+
+# TODO(stefan): Similar logic might be needed in other places later
+# generalize and move to infra once it's needed
+def build_pytree_from_npy(npfile):
+    """Convert a file from numpy.load with keys of form a/b/c... into a pytree"""
+    weights = ddict()
+    for name, w in npfile.items():
+        keys = list(name.split("/"))
+        subdict = weights
+        for key in keys[:-1]:
+            subdict = subdict[key]
+        subdict[keys[-1]] = jnp.array(w)
+    return {"params": defaultdict_to_dict(weights)}

From 64ce76b542622c46724be1b1bb95e5a60a5fbf1c Mon Sep 17 00:00:00 2001
From: sgligorijevicTT <189116645+sgligorijevicTT@users.noreply.github.com>
Date: Thu, 9 Jan 2025 12:42:26 +0000
Subject: [PATCH 2/4] Don't reinvent the wheel

---
 tests/jax/models/mlpmixer/test_mlpmixer.py |  7 +++--
 tests/jax/models/mlpmixer/util.py          | 32 ----------------------
 2 files changed, 4 insertions(+), 35 deletions(-)
 delete mode 100644 tests/jax/models/mlpmixer/util.py

diff --git a/tests/jax/models/mlpmixer/test_mlpmixer.py b/tests/jax/models/mlpmixer/test_mlpmixer.py
index 79406b7..0c625a1 100644
--- a/tests/jax/models/mlpmixer/test_mlpmixer.py
+++ b/tests/jax/models/mlpmixer/test_mlpmixer.py
@@ -10,9 +10,9 @@ import pytest
 import fsspec
 from flax import linen as nn
+import flax.traverse_util
 from infra import ModelTester, RunMode
 from .model_implementation import MlpMixer
-from .util import build_pytree_from_npy
 
 
 # hypers
@@ -29,8 +29,9 @@ def Mixer_B_16_pretrained():
     link = "https://storage.googleapis.com/mixer_models/imagenet21k/Mixer-B_16.npz"
     with fsspec.open("filecache::" + link, cache_storage="/tmp/files/") as f:
         weights = numpy.load(f, encoding="bytes")
-        pytree = build_pytree_from_npy(weights)
-    return pytree
+        state_dict = {k: v for k, v in weights.items()}
+        pytree = flax.traverse_util.unflatten_dict(state_dict, sep="/")
+    return {"params": pytree}
 
 
 class MlpMixerTester(ModelTester):
diff --git a/tests/jax/models/mlpmixer/util.py b/tests/jax/models/mlpmixer/util.py
deleted file mode 100644
index 62f79a6..0000000
--- a/tests/jax/models/mlpmixer/util.py
+++ /dev/null
@@ -1,32 +0,0 @@
-# SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC
-#
-# SPDX-License-Identifier: Apache-2.0
-from collections import defaultdict
-from jax import numpy as jnp
-
-
-def ddict():
-    return defaultdict(ddict)
-
-
-def defaultdict_to_dict(d):
-    """Recursively convert defaultdicts to dicts."""
-    if isinstance(d, defaultdict):
-        d = {k: defaultdict_to_dict(v) for k, v in d.items()}
-    elif isinstance(d, dict):
-        d = {k: defaultdict_to_dict(v) for k, v in d.items()}
-    return d
-
-
-# TODO(stefan): Similar logic might be needed in other places later
-# generalize and move to infra once it's needed
-def build_pytree_from_npy(npfile):
-    """Convert a file from numpy.load with keys of form a/b/c... into a pytree"""
-    weights = ddict()
-    for name, w in npfile.items():
-        keys = list(name.split("/"))
-        subdict = weights
-        for key in keys[:-1]:
-            subdict = subdict[key]
-        subdict[keys[-1]] = jnp.array(w)
-    return {"params": defaultdict_to_dict(weights)}

From 0cc94aaa19c246f21cfe86929ee25d5ee018f44a Mon Sep 17 00:00:00 2001
From: sgligorijevicTT <189116645+sgligorijevicTT@users.noreply.github.com>
Date: Fri, 10 Jan 2025 15:05:07 +0000
Subject: [PATCH 3/4] Address PR comments

---
 .../models/mlpmixer/model_implementation.py | 43 ++++++++++---------
 tests/jax/models/mlpmixer/test_mlpmixer.py  | 42 +++++++++---------
 2 files changed, 44 insertions(+), 41 deletions(-)

diff --git a/tests/jax/models/mlpmixer/model_implementation.py b/tests/jax/models/mlpmixer/model_implementation.py
index f41fc62..03679e4 100644
--- a/tests/jax/models/mlpmixer/model_implementation.py
+++ b/tests/jax/models/mlpmixer/model_implementation.py
@@ -1,35 +1,27 @@
-# SPDX-FileCopyrightText: (c) 2024 Google LLC
+# SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC
 #
 # SPDX-License-Identifier: Apache-2.0
 
-# Taken from https://github.com/google-research/vision_transformer/blob/c6de1e5378c9831a8477feb30994971bdc409e46/vit_jax/models_mixer.py
+# This file incorporates work covered by the following copyright and permission
+# notice:
+# SPDX-FileCopyrightText: Copyright 2024 Google LLC.
+# SPDX-License-Identifier: Apache-2.0
 
-# Copyright 2024 Google LLC.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
+# This code is based on google-research/vision_transformer
 
 from typing import Any, Optional
 
 import einops
 import flax.linen as nn
 import jax.numpy as jnp
+import jax
 
 
 class MlpBlock(nn.Module):
     mlp_dim: int
 
     @nn.compact
-    def __call__(self, x):
+    def __call__(self, x: jax.Array) -> jax.Array:
         y = nn.Dense(self.mlp_dim)(x)
         y = nn.gelu(y)
         return nn.Dense(x.shape[-1])(y)
@@ -42,14 +34,18 @@ class MixerBlock(nn.Module):
     channels_mlp_dim: int
 
     @nn.compact
-    def __call__(self, x):
+    def __call__(self, x: jax.Array) -> jax.Array:
         y = nn.LayerNorm()(x)
         y = jnp.swapaxes(y, 1, 2)
         y = MlpBlock(self.tokens_mlp_dim, name="token_mixing")(y)
         y = jnp.swapaxes(y, 1, 2)
         x = x + y
+
         y = nn.LayerNorm()(x)
-        return x + MlpBlock(self.channels_mlp_dim, name="channel_mixing")(y)
+        y = MlpBlock(self.channels_mlp_dim, name="channel_mixing")(y)
+        y = x + y
+
+        return y
 
 
 class MlpMixer(nn.Module):
@@ -64,18 +60,23 @@ class MlpMixer(nn.Module):
     model_name: Optional[str] = None
 
     @nn.compact
-    def __call__(self, inputs, train):
-        del train
+    def __call__(self, inputs: jax.Array) -> jax.Array:
         x = nn.Conv(
             self.hidden_dim, self.patches.size, strides=self.patches.size, name="stem"
-        )(inputs)
+        )(
+            inputs
+        )  # Patch embedding
         x = einops.rearrange(x, "n h w c -> n (h w) c")
+
         for _ in range(self.num_blocks):
             x = MixerBlock(self.tokens_mlp_dim, self.channels_mlp_dim)(x)
+
         x = nn.LayerNorm(name="pre_head_layer_norm")(x)
         x = jnp.mean(x, axis=1)
+
         if self.num_classes:
             x = nn.Dense(
                 self.num_classes, kernel_init=nn.initializers.zeros, name="head"
             )(x)
+
         return x
diff --git a/tests/jax/models/mlpmixer/test_mlpmixer.py b/tests/jax/models/mlpmixer/test_mlpmixer.py
index 0c625a1..9f7cd6d 100644
--- a/tests/jax/models/mlpmixer/test_mlpmixer.py
+++ b/tests/jax/models/mlpmixer/test_mlpmixer.py
@@ -1,8 +1,8 @@
 # SPDX-FileCopyrightText: (c) 2025 Tenstorrent AI ULC
 #
 # SPDX-License-Identifier: Apache-2.0
-from typing import Dict, Sequence
+from typing import Dict, Sequence, Any
 
 
 import jax
 import jax.numpy as jnp
@@ -11,11 +11,12 @@
 import fsspec
 from flax import linen as nn
 import flax.traverse_util
+
 from infra import ModelTester, RunMode
 from .model_implementation import MlpMixer
 
 
-# hypers
+# Hyperparameters for Mixer-B/16
 patch_size = 16
 num_classes = 21843
 num_blocks = 12
@@ -24,16 +25,6 @@ channel_mlp_dim = 3072
 
 
-def Mixer_B_16_pretrained():
-    # TODO(stefan): Discuss how weights should be handled org wide
-    link = "https://storage.googleapis.com/mixer_models/imagenet21k/Mixer-B_16.npz"
-    with fsspec.open("filecache::" + link, cache_storage="/tmp/files/") as f:
-        weights = numpy.load(f, encoding="bytes")
-        state_dict = {k: v for k, v in weights.items()}
-        pytree = flax.traverse_util.unflatten_dict(state_dict, sep="/")
-    return {"params": pytree}
-
-
 class MlpMixerTester(ModelTester):
     """Tester for MlpMixer model."""
 
@@ -49,36 +40,45 @@ def _get_model(self) -> nn.Module:
             channels_mlp_dim=channel_mlp_dim,
         )
 
+    @staticmethod
+    def _retrieve_pretrained_weights() -> Dict:
+        # TODO(stefan): Discuss how weights should be handled org wide
+        link = "https://storage.googleapis.com/mixer_models/imagenet21k/Mixer-B_16.npz"
+        with fsspec.open("filecache::" + link, cache_storage="/tmp/files/") as f:
+            weights = numpy.load(f, encoding="bytes")
+            state_dict = {k: v for k, v in weights.items()}
+            pytree = flax.traverse_util.unflatten_dict(state_dict, sep="/")
+        return {"params": pytree}
+
     # @override
     def _get_forward_method_name(self) -> str:
         return "apply"
 
     # @override
-    def _get_input_activations(self) -> Sequence[jax.Array]:
+    def _get_input_activations(self) -> jax.Array:
        key = jax.random.PRNGKey(42)
         random_image = jax.random.normal(key, (1, 196, 196, 3))
         return random_image
 
     # @override
-    def _get_forward_method_args(self):
+    def _get_forward_method_args(self) -> Sequence[Any]:
         ins = self._get_input_activations()
-        weights = Mixer_B_16_pretrained()
+        weights = self._retrieve_pretrained_weights()
+
         # Required to bypass "Initializer expected to generate shape (16, 16, 3, 768) but got shape (256, 3, 768)"
         kernel = weights["params"]["stem"]["kernel"]
         kernel = kernel.reshape(-1, 3, hidden_dim)
         weights["params"]["stem"]["kernel"] = kernel
 
         # Alternatively, weights could be randomly initialized like this:
-        # weights = self._model.init(jax.random.PRNGKey(42), ins, train=False)
+        # weights = self._model.init(jax.random.PRNGKey(42), ins)
 
         return [weights, ins]
-
-    # @override
-    def _get_forward_method_kwargs(self) -> Dict[str, jax.Array]:
-        return {"train": False}
 
 
 # ----- Fixtures -----
+
+
 @pytest.fixture
 def inference_tester() -> MlpMixerTester:
     return MlpMixerTester()
@@ -90,6 +90,8 @@ def training_tester() -> MlpMixerTester:
 
 
 # ----- Tests -----
+
+
 @pytest.mark.skip(
     reason="error: failed to legalize operation 'ttir.convolution' that was explicitly marked illegal"
 )

From 297c2bdf762a8ccad1dc7bb99f559d713083aa19 Mon Sep 17 00:00:00 2001
From: sgligorijevicTT <189116645+sgligorijevicTT@users.noreply.github.com>
Date: Mon, 13 Jan 2025 13:01:56 +0000
Subject: [PATCH 4/4] Add comment

---
 tests/jax/models/mlpmixer/test_mlpmixer.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/tests/jax/models/mlpmixer/test_mlpmixer.py b/tests/jax/models/mlpmixer/test_mlpmixer.py
index 9f7cd6d..fd8c1fa 100644
--- a/tests/jax/models/mlpmixer/test_mlpmixer.py
+++ b/tests/jax/models/mlpmixer/test_mlpmixer.py
@@ -2,19 +2,18 @@
 #
 # SPDX-License-Identifier: Apache-2.0
-from typing import Dict, Sequence, Any
+from typing import Any, Dict, Sequence
 
 
+import flax.traverse_util
+import fsspec
 import jax
 import jax.numpy as jnp
 import numpy
 import pytest
-import fsspec
 from flax import linen as nn
-import flax.traverse_util
-
 from infra import ModelTester, RunMode
-from .model_implementation import MlpMixer
 
+from .model_implementation import MlpMixer
 
 # Hyperparameters for Mixer-B/16
 patch_size = 16
@@ -73,6 +72,7 @@ def _get_forward_method_args(self) -> Sequence[Any]:
 
         # Alternatively, weights could be randomly initialized like this:
         # weights = self._model.init(jax.random.PRNGKey(42), ins)
 
+        # JAX frameworks have a convention of passing weights as the first argument
         return [weights, ins]
 
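
For reference, patch 2 replaces the hand-rolled defaultdict builder from util.py with flax.traverse_util.unflatten_dict. A minimal sketch of the behaviour the test now relies on; the two keys below are illustrative stand-ins in the checkpoint's '/'-joined naming scheme, not a dump of the real file:

import flax.traverse_util

# Arrays in the .npz checkpoint are stored under '/'-joined names;
# unflatten_dict rebuilds the nesting Flax expects in a params pytree.
flat = {
    "stem/kernel": 0,
    "MixerBlock_0/token_mixing/Dense_0/kernel": 1,
}
nested = flax.traverse_util.unflatten_dict(flat, sep="/")
assert nested == {
    "stem": {"kernel": 0},
    "MixerBlock_0": {"token_mixing": {"Dense_0": {"kernel": 1}}},
}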
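
The MixerBlock from patch 3 alternates token mixing (an MLP applied across the patch axis after a transpose) with channel mixing (an MLP applied across features), each preceded by a LayerNorm and wrapped in a residual add. A shape trace at the Mixer-B/16 dimensions the test uses, assuming model_implementation.py is importable on the path:

import jax
import jax.numpy as jnp

from model_implementation import MixerBlock

# (1, 196, 768): batch of one, 196 patch tokens, 768 channels.
# Token mixing: swapaxes to (1, 768, 196), MLP maps 196 -> 384 -> 196, swap back.
# Channel mixing: MLP maps 768 -> 3072 -> 768. Shapes are preserved end to end.
block = MixerBlock(tokens_mlp_dim=384, channels_mlp_dim=3072)
tokens = jnp.ones((1, 196, 768))
variables = block.init(jax.random.PRNGKey(0), tokens)
out = block.apply(variables, tokens)
assert out.shape == (1, 196, 768)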
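
Finally, a standalone sketch of how the final state of the series loads the pretrained checkpoint and runs a forward pass outside the ModelTester infrastructure. Two details here are assumptions for illustration and differ from the test: patches is a SimpleNamespace whose .size is the tuple (16, 16), so the stem Conv expects the checkpoint's (16, 16, 3, 768) kernel directly (the test's jnp.ones((16, 16)) placeholder has .size == 256, hence its reshape workaround), and the input is 224x224, which yields the 14x14 = 196 tokens the pretrained token-mixing weights were trained with:

from types import SimpleNamespace

import flax.traverse_util
import fsspec
import jax
import numpy

from model_implementation import MlpMixer

# Checkpoint URL and caching scheme taken from the test; needs network access.
link = "https://storage.googleapis.com/mixer_models/imagenet21k/Mixer-B_16.npz"
with fsspec.open("filecache::" + link, cache_storage="/tmp/files/") as f:
    npz = numpy.load(f, encoding="bytes")
    flat = {k: v for k, v in npz.items()}
# Checkpoint names line up with the Flax module names, as the test itself assumes.
params = flax.traverse_util.unflatten_dict(flat, sep="/")

model = MlpMixer(
    patches=SimpleNamespace(size=(16, 16)),
    num_classes=21843,
    num_blocks=12,
    hidden_dim=768,
    tokens_mlp_dim=384,
    channels_mlp_dim=3072,
)

image = jax.random.normal(jax.random.PRNGKey(0), (1, 224, 224, 3))
logits = model.apply({"params": params}, image)  # expected shape: (1, 21843)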