
Convnext architecture dev #356

Open · wants to merge 9 commits into development
5 changes: 5 additions & 0 deletions tensorflow_similarity/architectures/__init__.py

@@ -16,3 +16,8 @@
from .efficientnet import EfficientNetSim # noqa
from .resnet18 import ResNet18Sim # noqa
from .resnet50 import ResNet50Sim # noqa

try:
    from .convnext import ConvNeXtSim  # noqa
except ImportError:
    print("Warning: ConvNeXtSim not imported. This requires TensorFlow 2.10 or higher.")
143 changes: 143 additions & 0 deletions tensorflow_similarity/architectures/convnext.py

@@ -0,0 +1,143 @@
# Copyright 2021 The TensorFlow Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"ConvNeXt backbone for similarity learning"
from __future__ import annotations

import re

import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras.applications import convnext

from tensorflow_similarity.layers import GeneralizedMeanPooling2D, MetricEmbedding
from tensorflow_similarity.models import SimilarityModel

CONVNEXT_ARCHITECTURE = {
    "TINY": convnext.ConvNeXtTiny,
    "SMALL": convnext.ConvNeXtSmall,
    "BASE": convnext.ConvNeXtBase,
    "LARGE": convnext.ConvNeXtLarge,
    "XLARGE": convnext.ConvNeXtXLarge,
}


def ConvNeXtSim(
    input_shape: tuple[int, int, int],
    embedding_size: int = 128,
    variant: str = "BASE",
    weights: str = "imagenet",
    trainable: str = "frozen",
    l2_norm: bool = True,
    include_top: bool = True,
    pooling: str = "gem",
    gem_p: float = 3.0,
) -> SimilarityModel:
    """Build a ConvNeXt model backbone for similarity learning.
    [A ConvNet for the 2020s](https://arxiv.org/pdf/2201.03545.pdf)

    Args:
        input_shape: Size of the input image. Must match the input size of
            the ConvNeXt variant you use.
        embedding_size: Size of the output embedding. Usually between 64
            and 512. Defaults to 128.
        variant: Which variant of the ConvNeXt to use. Defaults to "BASE".
        weights: Use pre-trained weights - the only option currently
            available is "imagenet". Defaults to "imagenet".
        trainable: Make the ConvNeXt backbone fully trainable or partially
            trainable.
            - "full" to make the entire backbone trainable,
            - "partial" to only make the last 3 blocks trainable,
            - "frozen" to make it not trainable.
        l2_norm: If True and include_top is also True, then
            tfsim.layers.MetricEmbedding is used as the last layer, otherwise
            keras.layers.Dense is used. This should be true when using cosine
            distance. Defaults to True.
        include_top: Whether to include the fully-connected layer at the top
            of the network. Defaults to True.
        pooling: Optional pooling mode for feature extraction when
            include_top is False. Defaults to "gem".
            - None means that the output of the model will be the 4D tensor
              output of the last convolutional layer.
            - "avg" means that global average pooling will be applied to the
              output of the last convolutional layer, and thus the output of
              the model will be a 2D tensor.
            - "max" means that global max pooling will be applied.
            - "gem" means that GeneralizedMeanPooling2D will be applied.
              The gem_p param sets the contrast amount on the pooling.
        gem_p: Sets the power in the GeneralizedMeanPooling2D layer. A value
            of 1.0 is equivalent to GlobalAveragePooling2D, while larger
            values increase the contrast between activations within each
            feature map, and a value of math.inf is equivalent to global max
            pooling.
    """
    inputs = layers.Input(shape=input_shape)
    x = inputs

    if variant not in CONVNEXT_ARCHITECTURE:
        raise ValueError("Unknown ConvNeXt variant. Valid variants are: TINY, SMALL, BASE, LARGE, XLARGE.")

    x = build_convnext(variant, weights, trainable)(x)

    if pooling == "gem":
        x = GeneralizedMeanPooling2D(p=gem_p, name="gem_pool")(x)
    elif pooling == "avg":
        x = layers.GlobalAveragePooling2D(name="avg_pool")(x)
    elif pooling == "max":
        x = layers.GlobalMaxPooling2D(name="max_pool")(x)

    if include_top and pooling is not None:
        if l2_norm:
            outputs = MetricEmbedding(embedding_size)(x)
        else:
            outputs = layers.Dense(embedding_size)(x)
    else:
        outputs = x

    return SimilarityModel(inputs, outputs)
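For reference, a minimal usage sketch (not part of the diff; assumes TensorFlow >= 2.10):

from tensorflow_similarity.architectures import ConvNeXtSim

# Frozen ConvNeXt-Tiny backbone with a 128-d GeM-pooled embedding head.
model = ConvNeXtSim(
    input_shape=(224, 224, 3),
    embedding_size=128,
    variant="TINY",
    trainable="frozen",
)
model.summary()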


def build_convnext(variant: str, weights: str | None = None, trainable: str = "full") -> tf.keras.Model:
    """Build the requested ConvNeXt.

    Args:
        variant: Which variant of the ConvNeXt to use.
        weights: Use pre-trained weights - the only option currently
            available is "imagenet".
        trainable: Make the ConvNeXt backbone fully trainable or partially
            trainable.
            - "full" to make the entire backbone trainable,
            - "partial" to only make the last 3 blocks trainable,
            - "frozen" to make it not trainable.

    Returns:
        The ConvNeXt model.
    """
    convnext_fn = CONVNEXT_ARCHITECTURE[variant.upper()]
    convnext = convnext_fn(weights=weights, include_top=False)

    if trainable == "full":
        convnext.trainable = True
    elif trainable == "partial":
        convnext.trainable = True
        for layer in convnext.layers:
            # freeze all layers before the last 3 blocks
            if not re.search("^block[5,6,7]|^top", layer.name):
erikreed commented on Sep 26, 2023:
I'm also trying out this architecture. But does this EfficientNetV2 layer naming apply to ConvNeXt?

model = tf.keras.applications.ConvNeXtBase()
[l.name for l in model.layers if re.search("^block[5,6,7]|^top", l.name)]
# this outputs []

The test also suggests partial is not being applied as expected since the number of trainable layers is 0 with partial.


edit: another candidate might be "convnext_base_stage_3_block_2", also unfreezing the last layer norm since it comes after the final block.

model.trainable = True
for layer in model.layers:
    # freeze all layers before the last block
    if not re.search("^convnext_base_stage_3_block_2", layer.name):
        layer.trainable = False
model.layers[-1].trainable = True

This results in about 10% of the weights being unfrozen, covering only the final block [1].

Total params: 87566464 (334.04 MB)
Trainable params: 8450048 (32.23 MB)
Non-trainable params: 79116416 (301.81 MB)

[1]: attached screenshot of the layer summary showing the unfrozen final block.

                layer.trainable = False
    elif trainable == "frozen":
        convnext.trainable = False
    else:
        raise ValueError(f"{trainable} is not a supported option for 'trainable'.")

    if weights:
        for layer in convnext.layers:
            if isinstance(layer, layers.experimental.SyncBatchNormalization):
                layer.trainable = False

    return convnext
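To sanity-check which layers end up trainable, a quick sketch (not part of the diff; layer counts depend on the Keras version):

backbone = build_convnext("tiny", weights=None, trainable="partial")
n_trainable = sum(1 for layer in backbone.layers if layer.trainable)
print(f"{n_trainable} of {len(backbone.layers)} layers trainable")
# With the regex above this prints 0 trainable layers, which is the
# mismatch erikreed's review comment points out.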
111 changes: 111 additions & 0 deletions tests/architectures/test_convnext.py

@@ -0,0 +1,111 @@
import re

import pytest
import tensorflow as tf


MIN_TF_MAJOR_VERSION = 2
MIN_TF_MINOR_VERSION = 10

major_version = tf.__version__.split(".")[0]
minor_version = tf.__version__.split(".")[1]

convneXt = pytest.importorskip("tensorflow_similarity.architectures.convnext")

TF_MAJOR_VERSION = int(tf.__version__.split(".")[0])
TF_MINOR_VERSION = int(tf.__version__.split(".")[1])


def tf_version_check(major_version, minor_version):
    if TF_MAJOR_VERSION <= major_version and TF_MINOR_VERSION < minor_version:
        return True

    return False
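The helper above is presumably intended to gate tests on older TensorFlow; a sketch of how it could be used (assumed usage, not shown in this diff):

@pytest.mark.skipif(
    tf_version_check(MIN_TF_MAJOR_VERSION, MIN_TF_MINOR_VERSION),
    reason="ConvNeXt requires TensorFlow 2.10 or higher.",
)
def test_convnext_importable():
    assert hasattr(convneXt, "ConvNeXtSim")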


def test_build_convnext_tiny_full():
    input_layer = tf.keras.layers.Input((224, 224, 3))
    output = convneXt.build_convnext("tiny", "imagenet", "full")(input_layer)

    convnext = output._keras_history.layer
    assert convnext.name == "convnext_tiny"
    assert convnext.trainable

    total_layer_count = 0
    trainable_layer_count = 0
    for layer in convnext._self_tracked_trackables:
        total_layer_count += 1
        if layer.trainable:
            trainable_layer_count += 1

    expected_total_layer_count = 151
    expected_trainable_layer_count = 151

    assert total_layer_count == expected_total_layer_count
    assert trainable_layer_count == expected_trainable_layer_count


def test_build_convnext_small_partial():
    input_layer = tf.keras.layers.Input((224, 224, 3))
    output = convneXt.build_convnext("small", "imagenet", "partial")(input_layer)

    convnext = output._keras_history.layer
    assert convnext.name == "convnext_small"
    assert convnext.trainable

    total_layer_count = 0
    trainable_layer_count = 0
    for layer in convnext._self_tracked_trackables:
        total_layer_count += 1
        if layer.trainable:
            trainable_layer_count += 1

    expected_total_layer_count = 295
    expected_trainable_layer_count = 0

    assert total_layer_count == expected_total_layer_count
    assert trainable_layer_count == expected_trainable_layer_count


def test_build_convnext_base_frozen():
    input_layer = tf.keras.layers.Input((224, 224, 3))
    output = convneXt.build_convnext("base", "imagenet", "frozen")(input_layer)

    convnext = output._keras_history.layer
    assert convnext.name == "convnext_base"
    assert not convnext.trainable

    total_layer_count = 0
    trainable_layer_count = 0
    for layer in convnext._self_tracked_trackables:
        total_layer_count += 1
        if layer.trainable:
            trainable_layer_count += 1

    expected_total_layer_count = 295
    expected_trainable_layer_count = 0

    assert total_layer_count == expected_total_layer_count
    assert trainable_layer_count == expected_trainable_layer_count


def test_build_convnext_large_full():
    input_layer = tf.keras.layers.Input((224, 224, 3))
    output = convneXt.build_convnext("large", "imagenet", "full")(input_layer)

    convnext = output._keras_history.layer
    assert convnext.name == "convnext_large"
    assert convnext.trainable

    total_layer_count = 0
    trainable_layer_count = 0
    for layer in convnext._self_tracked_trackables:
        total_layer_count += 1
        if layer.trainable:
            trainable_layer_count += 1

    expected_total_layer_count = 295
    expected_trainable_layer_count = 295

    assert total_layer_count == expected_total_layer_count
    assert trainable_layer_count == expected_trainable_layer_count