diff --git a/README.md b/README.md index ef91d93..fdfb8e7 100644 --- a/README.md +++ b/README.md @@ -8,7 +8,8 @@ sparsity pruning depends on special algorithms and hardware to achieve accelerat Adlik pruning focuses on channel pruning and filter pruning, which can really reduce the number of parameters and flops. In terms of quantization, Adlik focuses on 8-bit quantization that is easier to accelerate on specific hardware. After testing, it is found that running a small batch of datasets can obtain a quantitative model with little loss of -accuracy, so Adlik focuses on this method. +accuracy, so Adlik focuses on this method. Knowledge distillation is another way to improve the performance of deep +learning algorithm. It is possible to compress the knowledge in the big model into a smaller model. The proposed framework mainly consists of two categories of algorithm components, i.e. pruner and quantizer. The pruner is mainly composed of five modules:core, scheduler, models, dataset and learner. The core module defines @@ -35,6 +36,14 @@ The following table is the size of the above model files: | LeNet-5 | 1176KB | 499KB(59% pruned) | 120KB | 1154KB (pb) | | ResNet-50 | 99MB | 67MB(31.9% pruned) | 18MB | 138MB(pb) | +Knowledge distillation is an effective way to imporve the performance of model. + +The following table shows the distillation result of ResNet-50 as the student network where ResNet-101 as the teacher network. + +| student model | ResNet-101 distilled | accuracy change | +| ------------- | -------------------- | --------------- | +| ResNet-50 | 77.14% | +0.97% | + ## 1. Pruning and quantization principle ### 1.1 Filter pruning @@ -63,6 +72,16 @@ quantization, only need to have inference model and very little calibration data of quantization is very small, and even some models will rise. Adlik only needs 100 sample images to complete the quantification of ResNet-50 in less than one minute. +### 1.3 Knowledge Distillation + +Knowledge distillation is a compression technique by which the knowledge of a larger model(teacher) is transfered into +a smaller one(student). During distillation, a student model learns from a teacher model to generalize well by raise +the temperature of the final softmax of the teacher model as the soft set of targets. + +![Distillation](imgs/distillation.png) + +Refer to the paper [Distilling the Knowledge in a Neural Network](https://arxiv.org/pdf/1503.02531.pdf) + ## 2. Installation These instructions will help get Adlik optimizer up and running on your local machine. @@ -102,7 +121,7 @@ rm -rf /tmp/openmpi #### 2.2.2 Install python package ```shell -pip install tensorflow-gpu==2.1.0 +pip install tensorflow-gpu==2.3.0 pip install horovod==0.19.1 pip install mpi4py pip install networkx diff --git a/doc/ResNet-50-Knowledge-Distillation.md b/doc/ResNet-50-Knowledge-Distillation.md new file mode 100644 index 0000000..a5f0f7e --- /dev/null +++ b/doc/ResNet-50-Knowledge-Distillation.md @@ -0,0 +1,46 @@ +# ResNet-50 Knowledge Distillation + +The following uses ResNet-101 on the ImageNet data set as teacher model to illustrate how to use the model optimizer to +improve the preformance of ResNet-50 by knowledge distillation. + +## 1 Prepare data + +### 1.1 Generate training and test data sets + +You may follow the data preparation guide [here](https://github.com/tensorflow/models/tree/v1.13.0/research/inception) +to download the full data set and convert it into TFRecord files. By default, when the script finishes, you will find +1024 training files and 128 validation files in the DATA_DIR. The files will match the patterns train-?????-of-01024 +and validation-?????-of-00128, respectively. + +### 2 Train the teacher model + +Enter the examples directory and execute + +```shell +cd examples +horovodrun -np 8 -H localhost:8 python resnet_101_imagenet_train.py +``` + +After execution, the default checkpoint file will be generated in ./models_ckpt/resnet_101_imagenet, and the inference +checkpoint file will be generated in ./models_eval_ckpt/resnet_101_imagenet. You can also modify the checkpoint_path +and checkpoint_eval_path of the resnet_101_imagenet_train.py file to change the generated file path. + +### 3 Distill + +Enter the examples directory and execute + +```shell +horovodrun -np 8 -H localhost:8 python resnet_50_imagenet_distill.py +``` + +After execution, the default checkpoint file will be generated in ./models_ckpt/resnet_50_imagenet_distill, and +the inference checkpoint file will be generated in ./models_eval_ckpt/resnet_50_imagenet_distill. You can also +modify the checkpoint_path and checkpoint_eval_path of the resnet_50_imagenet_distill.py file to change the generated +file path. + +> Note +> +> > i. The model in the checkpoint_path is not the pure ResNet-50 model. It's the hybird of ResNet-50(student) and +> > ResNet-101(teacher) +> > +> > ii. The model in the checkpoint_eval_path is the distilled model, i.e. pure ResNet-50 model diff --git a/docker/Dockerfile.cpu b/docker/Dockerfile.cpu index 8a24000..1698428 100644 --- a/docker/Dockerfile.cpu +++ b/docker/Dockerfile.cpu @@ -3,7 +3,8 @@ FROM ubuntu:18.04 RUN apt-get update && \ apt-get install -y software-properties-common && \ apt-get update -y && \ - apt-get install -y --no-install-recommends build-essential python3.6 python3.6-dev python3-distutils \ + apt-get install -y --no-install-recommends --allow-downgrades --allow-change-held-packages \ + build-essential python3.6 python3.6-dev python3-distutils \ curl git openssh-client openssh-server && \ mkdir -p /var/run/sshd && \ mkdir -p /root/work && \ @@ -26,7 +27,7 @@ RUN mkdir /tmp/openmpi && \ rm -rf /tmp/openmpi # Install Tensorflow and Horovod -RUN pip install --no-cache-dir tensorflow==2.1.0 +RUN pip install --no-cache-dir tensorflow==2.3.0 RUN HOROVOD_WITH_TENSORFLOW=1 pip install --no-cache-dir horovod==0.19.1 diff --git a/docker/Dockerfile.gpu b/docker/Dockerfile.gpu index dbc1659..52d2dae 100644 --- a/docker/Dockerfile.gpu +++ b/docker/Dockerfile.gpu @@ -3,7 +3,8 @@ FROM nvidia/cuda:10.1-devel-ubuntu18.04 RUN apt-get update && \ apt-get install -y software-properties-common && \ apt-get update -y && \ - apt-get install -y --no-install-recommends build-essential python3.6 python3.6-dev python3-distutils \ + apt-get install -y --no-install-recommends --allow-downgrades --allow-change-held-packages \ + build-essential python3.6 python3.6-dev python3-distutils \ curl vim git openssh-client openssh-server \ libcudnn7=7.6.5.32-1+cuda10.1 \ libcudnn7-dev=7.6.5.32-1+cuda10.1 \ @@ -38,7 +39,7 @@ RUN mkdir /tmp/openmpi && \ rm -rf /tmp/openmpi # Install Tensorflow and Horovod -RUN pip install --no-cache-dir tensorflow-gpu==2.1.0 +RUN pip install --no-cache-dir tensorflow-gpu==2.3.0 RUN HOROVOD_WITH_TENSORFLOW=1 pip install --no-cache-dir horovod==0.19.1 diff --git a/examples/resnet_101_imagenet_train.py b/examples/resnet_101_imagenet_train.py new file mode 100644 index 0000000..a74832e --- /dev/null +++ b/examples/resnet_101_imagenet_train.py @@ -0,0 +1,37 @@ +# Copyright 2019 ZTE corporation. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +Train a ResNet_101 model on the ImageNet dataset +""" +import os +# If you did not execute the setup.py, uncomment the following four lines +# import sys +# from os.path import abspath, join, dirname +# sys.path.insert(0, join(abspath(dirname(__file__)), '../src')) +# print(sys.path) + +from model_optimizer import prune_model # noqa: E402 + + +def _main(): + base_dir = os.path.dirname(__file__) + request = { + "dataset": "imagenet", + "model_name": "resnet_101", + "data_dir": os.path.join(base_dir, "./data/imagenet"), + "batch_size": 128, + "batch_size_val": 100, + "learning_rate": 0.1, + "epochs": 120, + "checkpoint_path": os.path.join(base_dir, "./models_ckpt/resnet_101_imagenet"), + "checkpoint_save_period": 5, # save a checkpoint every 5 epoch + "checkpoint_eval_path": os.path.join(base_dir, "./models_eval_ckpt/resnet_101_imagenet"), + "scheduler": "train", + "classifier_activation": None # None or "softmax", default is softmax + } + prune_model(request) + + +if __name__ == "__main__": + _main() diff --git a/examples/resnet_50_imagenet_distill.py b/examples/resnet_50_imagenet_distill.py new file mode 100644 index 0000000..82b5269 --- /dev/null +++ b/examples/resnet_50_imagenet_distill.py @@ -0,0 +1,38 @@ +# Copyright 2020 ZTE corporation. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +Distill a ResNet_50 model from a trained ResNet_101 model on the ImageNet dataset +""" +import os +# If you did not execute the setup.py, uncomment the following four lines +# import sys +# from os.path import abspath, join, dirname +# sys.path.insert(0, join(abspath(dirname(__file__)), '../src')) +# print(sys.path) + +from model_optimizer import prune_model # noqa: E402 + + +def _main(): + base_dir = os.path.dirname(__file__) + request = { + "dataset": "imagenet", + "model_name": "resnet_50", + "data_dir": os.path.join(base_dir, "./data/imagenet"), + "batch_size": 256, + "batch_size_val": 100, + "learning_rate": 0.1, + "epochs": 90, + "checkpoint_path": os.path.join(base_dir, "./models_ckpt/resnet_50_imagenet_distill"), + "checkpoint_save_period": 5, # save a checkpoint every 5 epoch + "checkpoint_eval_path": os.path.join(base_dir, "./models_eval_ckpt/resnet_50_imagenet_distill"), + "scheduler": "distill", + "scheduler_file_name": "resnet_50_imagenet_0.3.yaml", + "classifier_activation": None # None or "softmax", default is softmax + } + prune_model(request) + + +if __name__ == "__main__": + _main() diff --git a/examples/resnet_50_imagenet_train.py b/examples/resnet_50_imagenet_train.py index 0361d60..ec122aa 100644 --- a/examples/resnet_50_imagenet_train.py +++ b/examples/resnet_50_imagenet_train.py @@ -5,12 +5,13 @@ Train a ResNet_50 model on the ImageNet dataset """ import os -# If you did not execute the setup.py, uncomment the following four lines + # import sys # from os.path import abspath, join, dirname # sys.path.insert(0, join(abspath(dirname(__file__)), '../src')) # print(sys.path) + from model_optimizer import prune_model # noqa: E402 @@ -27,7 +28,8 @@ def _main(): "checkpoint_path": os.path.join(base_dir, "./models_ckpt/resnet_50_imagenet"), "checkpoint_save_period": 5, # save a checkpoint every 5 epoch "checkpoint_eval_path": os.path.join(base_dir, "./models_eval_ckpt/resnet_50_imagenet"), - "scheduler": "train" + "scheduler": "train", + "classifier_activation": None # None or "softmax", default is softmax } prune_model(request) diff --git a/imgs/distillation.png b/imgs/distillation.png new file mode 100644 index 0000000..7b57ed4 Binary files /dev/null and b/imgs/distillation.png differ diff --git a/setup.py b/setup.py index 9d2cc8a..6d11194 100644 --- a/setup.py +++ b/setup.py @@ -13,7 +13,7 @@ _REQUIRED_PACKAGES = [ 'requests', - 'tensorflow==2.1.0', + 'tensorflow==2.3.0', 'jsonschema==3.1.1', 'networkx==2.4', 'mpi4py==3.0.3', @@ -43,7 +43,7 @@ def get_dist(pkgname): if get_dist('tensorflow') is None and get_dist('tensorflow-gpu') is not None: - _REQUIRED_PACKAGES.remove('tensorflow==2.1.0') + _REQUIRED_PACKAGES.remove('tensorflow==2.3.0') setup( name="model_optimizer", @@ -60,7 +60,8 @@ def get_dist(pkgname): package_data={ 'model_optimizer': ['**/*.json', 'pruner/scheduler/uniform_auto/*.yaml', - 'pruner/scheduler/uniform_specified_layer/*.yaml'] + 'pruner/scheduler/uniform_specified_layer/*.yaml', + 'pruner/scheduler/distill/*.yaml'] }, ) diff --git a/src/model_optimizer/pruner/dataset/cifar10.py b/src/model_optimizer/pruner/dataset/cifar10.py index 64f0a73..556c48a 100644 --- a/src/model_optimizer/pruner/dataset/cifar10.py +++ b/src/model_optimizer/pruner/dataset/cifar10.py @@ -20,7 +20,7 @@ def __init__(self, config, is_training): :param is_training: whether to construct the training subset :return: """ - super(Cifar10Dataset, self).__init__(config, is_training) + super().__init__(config, is_training) if is_training: self.file_pattern = os.path.join(self.data_dir, 'train.tfrecords') self.batch_size = self.batch_size @@ -32,6 +32,7 @@ def __init__(self, config, is_training): self.num_samples_of_train = 50000 self.num_samples_of_val = 10000 + # pylint: disable=no-value-for-parameter,unexpected-keyword-arg def parse_fn(self, example_serialized): """ Parse features from the serialized data diff --git a/src/model_optimizer/pruner/dataset/dataset_base.py b/src/model_optimizer/pruner/dataset/dataset_base.py index 8cd1c42..a05dff4 100644 --- a/src/model_optimizer/pruner/dataset/dataset_base.py +++ b/src/model_optimizer/pruner/dataset/dataset_base.py @@ -62,9 +62,10 @@ def num_samples(self): else: return self.num_samples_of_val - def build(self): + def build(self, is_distill=False): """ Build dataset + :param is_distill: is distilling or not :return: batch of a dataset """ dataset = tf.data.Dataset.list_files(self.file_pattern, shuffle=True) @@ -73,7 +74,10 @@ def build(self): dataset = dataset.interleave(self.dataset_fn, cycle_length=10, num_parallel_calls=tf.data.experimental.AUTOTUNE) if self.is_training: dataset = dataset.shuffle(buffer_size=self.buffer_size).repeat() - dataset = dataset.map(self.parse_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE) + if is_distill: + dataset = dataset.map(self.parse_fn_distill, num_parallel_calls=tf.data.experimental.AUTOTUNE) + else: + dataset = dataset.map(self.parse_fn, num_parallel_calls=tf.data.experimental.AUTOTUNE) return self.__build_batch(dataset) def __build_batch(self, dataset): diff --git a/src/model_optimizer/pruner/dataset/imagenet.py b/src/model_optimizer/pruner/dataset/imagenet.py index fa3717e..295b596 100644 --- a/src/model_optimizer/pruner/dataset/imagenet.py +++ b/src/model_optimizer/pruner/dataset/imagenet.py @@ -21,7 +21,7 @@ def __init__(self, config, is_training, num_shards=1, shard_index=0): :param is_training: whether to construct the training subset :return: """ - super(ImagenetDataset, self).__init__(config, is_training, num_shards, shard_index) + super().__init__(config, is_training, num_shards, shard_index) if is_training: self.file_pattern = os.path.join(self.data_dir, 'train-*-of-*') self.batch_size = self.batch_size @@ -33,6 +33,7 @@ def __init__(self, config, is_training, num_shards=1, shard_index=0): self.num_samples_of_train = 1281167 self.num_samples_of_val = 50000 + # pylint: disable=no-value-for-parameter,unexpected-keyword-arg def parse_fn(self, example_serialized): """ Parse features from the serialized data @@ -77,3 +78,14 @@ def parse_fn(self, example_serialized): num_channels=3, is_training=self.is_training) return image, label + + def parse_fn_distill(self, example_serialized): + """ + Parse features from the serialized data for distillation + :param example_serialized: serialized data + :return: {image, label},{} + """ + image, label = self.parse_fn(example_serialized) + inputs = {"image": image, "label": label} + targets = {} + return inputs, targets diff --git a/src/model_optimizer/pruner/dataset/mnist.py b/src/model_optimizer/pruner/dataset/mnist.py index 0ddc63d..d251c1f 100644 --- a/src/model_optimizer/pruner/dataset/mnist.py +++ b/src/model_optimizer/pruner/dataset/mnist.py @@ -20,7 +20,7 @@ def __init__(self, config, is_training): :param is_training: whether to construct the training subset :return: """ - super(MnistDataset, self).__init__(config, is_training) + super().__init__(config, is_training) if is_training: self.file_pattern = os.path.join(self.data_dir, 'train.tfrecords') self.batch_size = self.batch_size @@ -33,6 +33,7 @@ def __init__(self, config, is_training): self.num_samples_of_val = 10000 # pylint: disable=R0201 + # pylint: disable=no-value-for-parameter,unexpected-keyword-arg def parse_fn(self, example_serialized): """ Parse features from the serialized data diff --git a/src/model_optimizer/pruner/distill/__init__.py b/src/model_optimizer/pruner/distill/__init__.py new file mode 100644 index 0000000..e18d67c --- /dev/null +++ b/src/model_optimizer/pruner/distill/__init__.py @@ -0,0 +1,2 @@ +# Copyright 2021 ZTE corporation. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 diff --git a/src/model_optimizer/pruner/distill/distill_loss.py b/src/model_optimizer/pruner/distill/distill_loss.py new file mode 100644 index 0000000..a51112d --- /dev/null +++ b/src/model_optimizer/pruner/distill/distill_loss.py @@ -0,0 +1,75 @@ +# Copyright 2021 ZTE corporation. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +Distilling Loss Layer +""" +import tensorflow as tf + + +class DistillLossLayer(tf.keras.layers.Layer): + """ + Layer to compute the loss for distillation. + the total loss = the student loss + the distillation loss + + Arguments: + alpha: a float between [0.0, 1.0]. It corresponds to the importance between the student loss and the + distillation loss. + temperature: the temperature of distillation. Defaults to 10. + teacher_path: the model path of teacher. The format of the model is h5. + name: String, name to use for this layer. Defaults to 'DistillLoss'. + + Call arguments: + inputs: inputs of the layer. It corresponds to [input, y_true, y_prediction] + """ + def __init__(self, teacher_path, alpha=1.0, temperature=10, name="DistillLoss", **kwargs): + """ + :param teacher_path: the model path of teacher. The format of the model is h5. + :param alpha: a float between [0.0, 1.0]. It corresponds to the importance between the student loss and the + distillation loss. + :param temperature: the temperature of distillation. Defaults to 10. + :param name: String, name to use for this layer. Defaults to 'DistillLoss'. + """ + super().__init__(name=name, **kwargs) + self.alpha = alpha + self.temperature = temperature + self.teacher_path = teacher_path + self.accuracy_fn = tf.keras.metrics.SparseCategoricalAccuracy(name="accuracy") + self.teacher = tf.keras.models.load_model(self.teacher_path) + + # pylint: disable=unused-argument + def call(self, inputs, **kwargs): + """ + :param inputs: inputs of the layer. It corresponds to [input, y_true, y_prediction] + :return: the total loss of the distiller model + """ + x, y_true, y_pred = inputs + rtn_loss = None + if y_true is not None: + student_loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)(y_true, y_pred) + self.teacher.trainable = False + teacher_predictions = self.teacher(x) + distillation_loss = tf.keras.losses.KLDivergence()( + tf.nn.softmax(y_pred / self.temperature, axis=1), + tf.nn.softmax(teacher_predictions / self.temperature, axis=1) + ) + stu_loss = self.alpha * student_loss + dis_loss = (1 - self.alpha) * self.temperature * self.temperature * distillation_loss + rtn_loss = stu_loss + dis_loss + + self.add_loss(rtn_loss) + self.add_metric(student_loss, aggregation="mean", name="stu_loss") + self.add_metric(dis_loss, aggregation="mean", name="dis_loss") + + self.add_metric(self.accuracy_fn(y_true, y_pred)) + return rtn_loss + + def get_config(self): + """ + Implement get_config to enable serialization. + """ + config = super().get_config() + config.update({"teacher_path": self.teacher_path}) + config.update({"alpha": self.alpha}) + config.update({"temperature": self.temperature}) + return config diff --git a/src/model_optimizer/pruner/distill/distiller.py b/src/model_optimizer/pruner/distill/distiller.py new file mode 100644 index 0000000..6eb3a27 --- /dev/null +++ b/src/model_optimizer/pruner/distill/distiller.py @@ -0,0 +1,27 @@ +# Copyright 2021 ZTE corporation. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +get distiller model +""" +import tensorflow as tf + +from .distill_loss import DistillLossLayer + + +def get_distiller(student_model, scheduler_config): + """ + Get distiller model + :param student_model: student model function + :param scheduler_config: scheduler config object + :return: keras model of distiller + """ + input_img = tf.keras.layers.Input(shape=(224, 224, 3), name='image') + input_lbl = tf.keras.layers.Input((), name="label", dtype='int32') + student = student_model + _, logits = student(input_img) + total_loss = DistillLossLayer(scheduler_config['teacher_path'], scheduler_config['alpha'], + scheduler_config['temperature'], )([input_img, input_lbl, logits]) + distill_model = tf.keras.Model(inputs=[input_img, input_lbl], outputs=[logits, total_loss]) + + return distill_model diff --git a/src/model_optimizer/pruner/learner/__init__.py b/src/model_optimizer/pruner/learner/__init__.py index b689e27..d0f542b 100644 --- a/src/model_optimizer/pruner/learner/__init__.py +++ b/src/model_optimizer/pruner/learner/__init__.py @@ -20,6 +20,9 @@ def get_learner(config): elif model_name == 'resnet_50' and dataset_name == 'imagenet': from .resnet_50_imagenet import Learner return Learner(config) + elif model_name == 'resnet_101' and dataset_name == 'imagenet': + from .resnet_101_imagenet import Learner + return Learner(config) elif model_name == 'mobilenet_v1' and dataset_name == 'imagenet': from .mobilenet_v1_imagenet import Learner return Learner(config) diff --git a/src/model_optimizer/pruner/learner/learner_base.py b/src/model_optimizer/pruner/learner/learner_base.py index 98e5159..5dcafb5 100644 --- a/src/model_optimizer/pruner/learner/learner_base.py +++ b/src/model_optimizer/pruner/learner/learner_base.py @@ -12,6 +12,7 @@ from ..models import get_model from .utils import get_call_backs from ...stat import print_keras_model_summary, print_keras_model_params_flops +from ..distill.distill_loss import DistillLossLayer class LearnerBase(metaclass=abc.ABCMeta): @@ -48,7 +49,8 @@ def __init__(self, config): eval_model = tf.keras.models.clone_model(origin_eval_model) self.models_train.append(train_model) self.models_eval.append(eval_model) - self.train_dataset, self.eval_dataset = self.build_dataset() + self.train_dataset, self.eval_dataset, self.train_dataset_distill, self.eval_dataset_distill = \ + self.build_dataset() self.build_train() self.build_eval() self.load_model() @@ -64,7 +66,7 @@ def resume_epoch(self): return self.resume_from_epoch @abc.abstractmethod - def get_losses(self): + def get_losses(self, is_training=True): """ Model compile losses :return: Return model compile losses @@ -80,7 +82,7 @@ def get_optimizer(self): pass @abc.abstractmethod - def get_metrics(self): + def get_metrics(self, is_training=True): """ Model compile metrics :return: Return model compile metrics @@ -99,7 +101,15 @@ def build_dataset(self): ds_eval = get_dataset(self.config, is_training=False) self.eval_steps_per_epoch = ds_eval.steps_per_epoch eval_dataset = ds_eval.build() - return train_dataset, eval_dataset + train_dataset_distill = None + eval_dataset_distill = None + if self.config.get_attribute("scheduler") == "distill": + ds_train_distill = get_dataset(self.config, is_training=True, num_shards=hvd.size(), shard_index=hvd.rank()) + train_dataset_distill = ds_train_distill.build(True) + ds_eval_distill = get_dataset(self.config, is_training=False) + eval_dataset_distill = ds_eval_distill.build(True) + + return train_dataset, eval_dataset, train_dataset_distill, eval_dataset_distill def build_train(self): """ @@ -120,9 +130,9 @@ def build_eval(self): Model compile for eval model :return: """ - loss = self.get_losses() + loss = self.get_losses(False) optimizer = self.get_optimizer() - metrics = self.get_metrics() + metrics = self.get_metrics(False) eval_model = self.models_eval[-1] eval_model.compile(loss=loss, optimizer=optimizer, @@ -142,21 +152,34 @@ def train(self, initial_epoch=0, epochs=1, lr_schedulers=None): self.callbacks.append(tf.keras.callbacks.ModelCheckpoint(os.path.join(self.checkpoint_path, './checkpoint-{epoch}.h5'), period=self.checkpoint_save_period)) - train_model.fit(self.train_dataset, initial_epoch=initial_epoch, steps_per_epoch=self.train_steps_per_epoch, + if self.config.get_attribute('scheduler') == 'distill': + train_dataset = self.train_dataset_distill + else: + train_dataset = self.train_dataset + train_model.fit(train_dataset, initial_epoch=initial_epoch, steps_per_epoch=self.train_steps_per_epoch, epochs=epochs, verbose=self.verbose, callbacks=self.callbacks) self.cur_epoch += epochs-initial_epoch def eval(self): """ Model eval process, only evaluate on rank 0 + the format of score is like as follows: + {loss: 7.6969 dense1_loss: 5.4490 softmax_1_sparse_categorical_accuracy: 0.0665 + dense1_sparse_categorical_accuracy: 0.0665} :return: """ if hvd.rank() != 0: return eval_model = self.models_eval[-1] score = eval_model.evaluate(self.eval_dataset, steps=self.eval_steps_per_epoch) - print('Test loss:', score[0]) - print('Test accuracy:', score[1]) + loss = score[0] + if self.config.get_attribute("classifier_activation", "softmax") == "softmax": + accuracy = score[2] + else: + accuracy = score[3] + + print('Test loss:', loss) + print('Test accuracy:', accuracy) def get_latest_train_model(self): """ @@ -209,6 +232,14 @@ def load_model(self): Load checkpoint and update cur_epoch resume_from_epoch train_model :return: """ + _custom_objects = { + 'DistillLossLayer': DistillLossLayer + } + if self.config.get_attribute('scheduler') == 'distill': + custom_objects = _custom_objects + else: + custom_objects = None + self.resume_from_epoch = 0 for try_epoch in range(self.epochs, 0, -1): if os.path.exists(os.path.join(self.checkpoint_path, self.checkpoint_format.format(epoch=try_epoch))): @@ -218,7 +249,8 @@ def load_model(self): self.cur_epoch = self.resume_from_epoch model = tf.keras.models.load_model( os.path.join(self.checkpoint_path, - self.checkpoint_format.format(epoch=self.resume_from_epoch))) + self.checkpoint_format.format(epoch=self.resume_from_epoch)), + custom_objects=custom_objects) self.train_models_update(model) def save_eval_model(self): @@ -230,17 +262,28 @@ def save_eval_model(self): return train_model = self.models_train[-1] eval_model = self.models_eval[-1] - clone_model = tf.keras.models.clone_model(eval_model) - for i, layer in enumerate(clone_model.layers): - if 'Conv2D' in str(type(layer)): - clone_model.layers[i].filters = train_model.get_layer(layer.name).filters - elif 'Dense' in str(type(layer)): - clone_model.layers[i].units = train_model.get_layer(layer.name).units - pruned_eval_model = tf.keras.models.model_from_json(clone_model.to_json()) - pruned_eval_model.set_weights(train_model.get_weights()) - save_model_path = os.path.join(self.save_model_path, 'checkpoint-') + str(self.cur_epoch)+'.h5' - pruned_eval_model.save(save_model_path) - self.eval_models_update(pruned_eval_model) + save_model_path = os.path.join(self.save_model_path, 'checkpoint-') + str(self.cur_epoch) + '.h5' + if self.config.get_attribute('scheduler') == 'distill': + model_name = self.config.get_attribute('model_name') + for layer_eval in eval_model.layers: + for layer in train_model.layers: + if layer.name == model_name and layer_eval.name == model_name: + layer_eval.set_weights(layer.get_weights()) + student_eval = layer_eval + break + student_eval.save(save_model_path) + self.eval_models_update(student_eval) + else: + clone_model = tf.keras.models.clone_model(eval_model) + for i, layer in enumerate(clone_model.layers): + if 'Conv2D' in str(type(layer)): + clone_model.layers[i].filters = train_model.get_layer(layer.name).filters + elif 'Dense' in str(type(layer)): + clone_model.layers[i].units = train_model.get_layer(layer.name).units + pruned_eval_model = tf.keras.models.model_from_json(clone_model.to_json()) + pruned_eval_model.set_weights(train_model.get_weights()) + pruned_eval_model.save(save_model_path) + self.eval_models_update(pruned_eval_model) def print_model_summary(self): """ diff --git a/src/model_optimizer/pruner/learner/lenet_mnist.py b/src/model_optimizer/pruner/learner/lenet_mnist.py index 1e973f1..0ee7cd2 100644 --- a/src/model_optimizer/pruner/learner/lenet_mnist.py +++ b/src/model_optimizer/pruner/learner/lenet_mnist.py @@ -15,7 +15,7 @@ class Learner(LearnerBase): Lenet on mnist Learner """ def __init__(self, config): - super(Learner, self).__init__(config) + super().__init__(config) self.callbacks = [ # Horovod: broadcast initial variable states from rank 0 to all other processes. # This is necessary to ensure consistent initialization of all workers when @@ -42,16 +42,20 @@ def get_optimizer(self): opt = hvd.DistributedOptimizer(opt) return opt - def get_losses(self): + def get_losses(self, is_training=True): """ Model compile losses + :param is_training: is training or not :return: Return model compile losses """ return 'sparse_categorical_crossentropy' - def get_metrics(self): + def get_metrics(self, is_training=True): """ Model compile metrics + :param is_training: is training or not :return: Return model compile metrics """ + if self.config.get_attribute('scheduler') == 'distill' and is_training: + return None return ['sparse_categorical_accuracy'] diff --git a/src/model_optimizer/pruner/learner/mobilenet_v1_imagenet.py b/src/model_optimizer/pruner/learner/mobilenet_v1_imagenet.py index 12120a2..464a9b0 100644 --- a/src/model_optimizer/pruner/learner/mobilenet_v1_imagenet.py +++ b/src/model_optimizer/pruner/learner/mobilenet_v1_imagenet.py @@ -15,7 +15,7 @@ class Learner(LearnerBase): Resnet-50 on imagenet Learner """ def __init__(self, config): - super(Learner, self).__init__(config) + super().__init__(config) self.callbacks = [ # Horovod: broadcast initial variable states from rank 0 to all other processes. # This is necessary to ensure consistent initialization of all workers when @@ -52,16 +52,20 @@ def get_optimizer(self): opt = hvd.DistributedOptimizer(opt) return opt - def get_losses(self): + def get_losses(self, is_training=True): """ Model compile losses + :param is_training: is training or not :return: Return model compile losses """ return 'sparse_categorical_crossentropy' - def get_metrics(self): + def get_metrics(self, is_training=True): """ Model compile metrics + :param is_training: is training or not :return: Return model compile metrics """ + if self.config.get_attribute('scheduler') == 'distill' and is_training: + return None return ['sparse_categorical_accuracy'] diff --git a/src/model_optimizer/pruner/learner/mobilenet_v2_imagenet.py b/src/model_optimizer/pruner/learner/mobilenet_v2_imagenet.py index bcc3407..b29cbfc 100644 --- a/src/model_optimizer/pruner/learner/mobilenet_v2_imagenet.py +++ b/src/model_optimizer/pruner/learner/mobilenet_v2_imagenet.py @@ -16,7 +16,7 @@ class Learner(LearnerBase): Resnet-50 on imagenet Learner """ def __init__(self, config): - super(Learner, self).__init__(config) + super().__init__(config) self.callbacks = [ # Horovod: broadcast initial variable states from rank 0 to all other processes. # This is necessary to ensure consistent initialization of all workers when @@ -51,16 +51,20 @@ def get_optimizer(self): opt = hvd.DistributedOptimizer(opt) return opt - def get_losses(self): + def get_losses(self, is_training=True): """ Model compile losses + :param is_training: is training or not :return: Return model compile losses """ return 'sparse_categorical_crossentropy' - def get_metrics(self): + def get_metrics(self, is_training=True): """ Model compile metrics + :param is_training: is training or not :return: Return model compile metrics """ + if self.config.get_attribute('scheduler') == 'distill' and is_training: + return None return ['sparse_categorical_accuracy'] diff --git a/src/model_optimizer/pruner/learner/resnet_101_imagenet.py b/src/model_optimizer/pruner/learner/resnet_101_imagenet.py new file mode 100644 index 0000000..62a3e80 --- /dev/null +++ b/src/model_optimizer/pruner/learner/resnet_101_imagenet.py @@ -0,0 +1,78 @@ +# Copyright 2019 ZTE corporation. All Rights Reserved. +# SPDX-License-Identifier: Apache-2.0 + +""" +Resnet-101 on imagenet Learner definition +""" +import os +import tensorflow as tf +import horovod.tensorflow.keras as hvd +from .learner_base import LearnerBase + + +class Learner(LearnerBase): + """ + Resnet-101 on imagenet Learner + """ + def __init__(self, config): + super().__init__(config) + self.callbacks = [ + # Horovod: broadcast initial variable states from rank 0 to all other processes. + # This is necessary to ensure consistent initialization of all workers when + # training is started with random weights or restored from a checkpoint. + hvd.callbacks.BroadcastGlobalVariablesCallback(0), + # Horovod: average metrics among workers at the end of every epoch. + # + # Note: This callback must be in the list before the ReduceLROnPlateau, + # TensorBoard or other metrics-based callbacks. + hvd.callbacks.MetricAverageCallback(), + # Horovod: using `lr = 1.0 * hvd.size()` from the very beginning leads to worse final + # accuracy. Scale the learning rate `lr = 1.0` ---> `lr = 1.0 * hvd.size()` during + # the first five epochs. See https://arxiv.org/abs/1706.02677 for details. + hvd.callbacks.LearningRateWarmupCallback(warmup_epochs=5, verbose=0), + # Horovod: after the warmup reduce learning rate by 10 on the 30th, 60th and 90th epochs. + hvd.callbacks.LearningRateScheduleCallback(start_epoch=5, end_epoch=30, multiplier=1.), + hvd.callbacks.LearningRateScheduleCallback(start_epoch=30, end_epoch=60, multiplier=1e-1), + hvd.callbacks.LearningRateScheduleCallback(start_epoch=60, end_epoch=90, multiplier=1e-2), + hvd.callbacks.LearningRateScheduleCallback(start_epoch=90, multiplier=1e-3), + ] + # Horovod: save checkpoints only on worker 0 to prevent other workers from corrupting them. + if hvd.rank() == 0: + self.callbacks.append(tf.keras.callbacks.ModelCheckpoint(os.path.join(self.checkpoint_path, + './checkpoint-{epoch}.h5'), + period=self.checkpoint_save_period)) + + def get_optimizer(self): + """ + Model compile optimizer + :return: Return model compile optimizer + """ + opt = tf.keras.optimizers.SGD(learning_rate=self.learning_rate*hvd.size(), momentum=0.9) + opt = hvd.DistributedOptimizer(opt) + return opt + + def get_losses(self, is_training=True): + """ + Model compile losses + :param: is_training: is training of not + :return: Return model compile losses + """ + softmax_loss = tf.keras.losses.SparseCategoricalCrossentropy() + logits_loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True) + if self.config.get_attribute('scheduler') == 'distill' and is_training: + return None + else: + if self.config.get_attribute("classifier_activation", "softmax") == "softmax": + return [softmax_loss, None] + else: + return [None, logits_loss] + + def get_metrics(self, is_training=True): + """ + Model compile metrics + :param: is_training: is training of not + :return: Return model compile metrics + """ + if self.config.get_attribute('scheduler') == 'distill' and is_training: + return None + return ['sparse_categorical_accuracy'] diff --git a/src/model_optimizer/pruner/learner/resnet_50_imagenet.py b/src/model_optimizer/pruner/learner/resnet_50_imagenet.py index 0c9651c..b92905c 100644 --- a/src/model_optimizer/pruner/learner/resnet_50_imagenet.py +++ b/src/model_optimizer/pruner/learner/resnet_50_imagenet.py @@ -15,7 +15,7 @@ class Learner(LearnerBase): Resnet-50 on imagenet Learner """ def __init__(self, config): - super(Learner, self).__init__(config) + super().__init__(config) self.callbacks = [ # Horovod: broadcast initial variable states from rank 0 to all other processes. # This is necessary to ensure consistent initialization of all workers when @@ -51,16 +51,28 @@ def get_optimizer(self): opt = hvd.DistributedOptimizer(opt) return opt - def get_losses(self): + def get_losses(self, is_training=True): """ Model compile losses + :param is_training: is training or not :return: Return model compile losses """ - return 'sparse_categorical_crossentropy' + softmax_loss = tf.keras.losses.SparseCategoricalCrossentropy() + logits_loss = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True) + if self.config.get_attribute('scheduler') == 'distill' and is_training: + return None + else: + if self.config.get_attribute("classifier_activation", "softmax") == "softmax": + return [softmax_loss, None] + else: + return [None, logits_loss] - def get_metrics(self): + def get_metrics(self, is_training=True): """ Model compile metrics + :param is_training: is training or not :return: Return model compile metrics """ + if self.config.get_attribute('scheduler') == 'distill' and is_training: + return None return ['sparse_categorical_accuracy'] diff --git a/src/model_optimizer/pruner/learner/vgg_m_16_cifar10.py b/src/model_optimizer/pruner/learner/vgg_m_16_cifar10.py index 94e4be1..326e9ee 100644 --- a/src/model_optimizer/pruner/learner/vgg_m_16_cifar10.py +++ b/src/model_optimizer/pruner/learner/vgg_m_16_cifar10.py @@ -15,7 +15,7 @@ class Learner(LearnerBase): VGG_m_16 on cifar10 Learner """ def __init__(self, config): - super(Learner, self).__init__(config) + super().__init__(config) self.callbacks = [ # Horovod: broadcast initial variable states from rank 0 to all other processes. # This is necessary to ensure consistent initialization of all workers when @@ -50,16 +50,20 @@ def get_optimizer(self): opt = hvd.DistributedOptimizer(opt) return opt - def get_losses(self): + def get_losses(self, is_training=True): """ Model compile losses + :param is_training: is training or not :return: Return model compile losses """ return 'sparse_categorical_crossentropy' - def get_metrics(self): + def get_metrics(self, is_training=True): """ Model compile metrics + :param is_training: is training or not :return: Return model compile metrics """ + if self.config.get_attribute('scheduler') == 'distill' and is_training: + return None return ['sparse_categorical_accuracy'] diff --git a/src/model_optimizer/pruner/models/__init__.py b/src/model_optimizer/pruner/models/__init__.py index ae0f003..16d806c 100644 --- a/src/model_optimizer/pruner/models/__init__.py +++ b/src/model_optimizer/pruner/models/__init__.py @@ -4,8 +4,11 @@ """ Get model """ +from ..scheduler.common import get_scheduler +from ..distill.distiller import get_distiller +# pylint: disable=too-many-return-statements def get_model(config, is_training=True): """ Get model @@ -14,26 +17,35 @@ def get_model(config, is_training=True): :return: class of keras Model """ model_name = config.get_attribute('model_name') - if model_name not in ['lenet', 'resnet_18', 'vgg_m_16', 'resnet_50', + scheduler_config = get_scheduler(config) + if model_name not in ['lenet', 'resnet_18', 'vgg_m_16', 'resnet_50', 'resnet_101', 'mobilenet_v1', 'mobilenet_v2']: raise Exception('Not support model %s' % model_name) if model_name == 'lenet': from .lenet import lenet - return lenet(is_training) + return lenet(model_name, is_training) elif model_name == 'vgg_m_16': from .vgg import vgg_m_16 - return vgg_m_16(is_training) + return vgg_m_16(is_training, model_name) elif model_name == 'resnet_18': from .resnet import resnet_18 - return resnet_18(is_training) + return resnet_18(is_training, model_name) elif model_name == 'resnet_50': from .resnet import resnet_50 - return resnet_50(is_training) + student_model = resnet_50(is_training, model_name) + if config.get_attribute('scheduler') == 'distill': + distill_model = get_distiller(student_model, scheduler_config) + return distill_model + else: + return student_model + elif model_name == 'resnet_101': + from .resnet import resnet_101 + return resnet_101(is_training, model_name) elif model_name == 'mobilenet_v1': from .mobilenet_v1 import mobilenet_v1_1 - return mobilenet_v1_1(is_training=is_training) + return mobilenet_v1_1(is_training=is_training, name=model_name) elif model_name == 'mobilenet_v2': from .mobilenet_v2 import mobilenet_v2_1 - return mobilenet_v2_1(is_training=is_training) + return mobilenet_v2_1(is_training=is_training, name=model_name) else: raise Exception('Not support model {}'.format(model_name)) diff --git a/src/model_optimizer/pruner/models/lenet.py b/src/model_optimizer/pruner/models/lenet.py index 9eba187..44767e3 100644 --- a/src/model_optimizer/pruner/models/lenet.py +++ b/src/model_optimizer/pruner/models/lenet.py @@ -7,9 +7,12 @@ import tensorflow as tf -def lenet(is_training=True): +def lenet(name, is_training=True): """ This implements a slightly modified LeNet-5 [LeCun et al., 1998a] + :param name: the model name + :param is_training: if training or not + :return: LeNet model """ input_ = tf.keras.layers.Input(shape=(28, 28, 1), name='input') x = tf.keras.layers.Conv2D(filters=6, @@ -31,5 +34,5 @@ def lenet(is_training=True): x = tf.keras.layers.Dense(120, activation='relu', name='dense_1')(x) x = tf.keras.layers.Dense(84, activation='relu', name='dense_2')(x) output_ = tf.keras.layers.Dense(10, activation='softmax', name='dense_3')(x) - model = tf.keras.Model(input_, output_) + model = tf.keras.Model(input_, output_, name=name) return model diff --git a/src/model_optimizer/pruner/models/mobilenet_v1.py b/src/model_optimizer/pruner/models/mobilenet_v1.py index 8024309..6ace82d 100644 --- a/src/model_optimizer/pruner/models/mobilenet_v1.py +++ b/src/model_optimizer/pruner/models/mobilenet_v1.py @@ -68,28 +68,30 @@ def mobilenet_v1_0_75(num_classes=1001, return _mobilenet_v1(num_classes, dropout_prob, is_training, scale=0.75, depth_multiplier=depth_multiplier) -def mobilenet_v1_1(num_classes=1001, +def mobilenet_v1_1(name, num_classes=1001, dropout_prob=1e-3, is_training=True, depth_multiplier=1): """ Build mobilenet_v1_1.0 model + :param name: the model name :param num_classes: :param dropout_prob: :param is_training: :param depth_multiplier: :return: """ - return _mobilenet_v1(num_classes, dropout_prob, is_training, scale=1.0, depth_multiplier=depth_multiplier) + return _mobilenet_v1(name, num_classes, dropout_prob, is_training, scale=1.0, depth_multiplier=depth_multiplier) -def _mobilenet_v1(num_classes=1000, +def _mobilenet_v1(name, num_classes=1000, dropout_prob=1e-3, is_training=True, scale=1.0, depth_multiplier=1): """ Build mobilenet_v1 model + :param name: the model name :param num_classes: :param dropout_prob: :param is_training: @@ -131,7 +133,7 @@ def _mobilenet_v1(num_classes=1000, name='conv_preds')(x) x = tf.keras.layers.Reshape((num_classes,), name='reshape_2')(x) outputs = tf.keras.layers.Activation('softmax', name='act_softmax')(x) - model = tf.keras.Model(inputs, outputs) + model = tf.keras.Model(inputs, outputs, name=name) return model diff --git a/src/model_optimizer/pruner/models/mobilenet_v2.py b/src/model_optimizer/pruner/models/mobilenet_v2.py index 23eeef0..dab3a9c 100644 --- a/src/model_optimizer/pruner/models/mobilenet_v2.py +++ b/src/model_optimizer/pruner/models/mobilenet_v2.py @@ -61,30 +61,32 @@ def mobilenet_v2_0_75(num_classes=1001, return _mobilenet_v2(num_classes, dropout_prob, is_training, scale=0.75) -def mobilenet_v2_1(num_classes=1001, +def mobilenet_v2_1(name, num_classes=1001, dropout_prob=1e-3, is_training=True): """ Build mobilenet_v2_1.0 model + :param name: the model name :param num_classes: :param dropout_prob: :param is_training: :return: """ - return _mobilenet_v2(num_classes, dropout_prob, is_training, scale=1.0) + return _mobilenet_v2(name, num_classes, dropout_prob, is_training, scale=1.0) -def mobilenet_v2_1_3(num_classes=1001, +def mobilenet_v2_1_3(name, num_classes=1001, dropout_prob=1e-3, is_training=True): """ Build mobilenet_v2_1.3 model + :param name: the model name :param num_classes: :param dropout_prob: :param is_training: :return: """ - return _mobilenet_v2(num_classes, dropout_prob, is_training, scale=1.3) + return _mobilenet_v2(name, num_classes, dropout_prob, is_training, scale=1.3) def mobilenet_v2_1_4(num_classes=1001, @@ -100,12 +102,13 @@ def mobilenet_v2_1_4(num_classes=1001, return _mobilenet_v2(num_classes, dropout_prob, is_training, scale=1.4) -def _mobilenet_v2(num_classes=1001, +def _mobilenet_v2(name, num_classes=1001, dropout_prob=1e-3, is_training=True, scale=1.0): """ Build mobilenet_v2 model + :param name: the model name :param num_classes: :param dropout_prob: :param is_training: @@ -193,7 +196,7 @@ def _mobilenet_v2(num_classes=1001, name='conv_preds')(x) x = tf.keras.layers.Reshape((num_classes,), name='reshape_2')(x) outputs = tf.keras.layers.Activation('softmax', name='act_softmax')(x) - model = tf.keras.Model(inputs, outputs) + model = tf.keras.Model(inputs, outputs, name=name) return model diff --git a/src/model_optimizer/pruner/models/resnet.py b/src/model_optimizer/pruner/models/resnet.py index 6fad6df..93e1590 100644 --- a/src/model_optimizer/pruner/models/resnet.py +++ b/src/model_optimizer/pruner/models/resnet.py @@ -20,10 +20,11 @@ def _gen_l2_regularizer(use_l2_regularizer=True): return tf.keras.regularizers.l2(L2_WEIGHT_DECAY) if use_l2_regularizer else None -def resnet(layer_num, num_classes=1001, use_l2_regularizer=True, is_training=True): +def resnet(layer_num, name, num_classes=1001, use_l2_regularizer=True, is_training=True): """ Build resnet-18 resnet-34 resnet-50 resnet-101 resnet-152 model :param layer_num: 18, 34, 50, 101, 152 + :param name: the model name :param num_classes: classification class :param use_l2_regularizer: if use l2_regularizer :param is_training: if training or not @@ -58,40 +59,52 @@ def resnet(layer_num, num_classes=1001, use_l2_regularizer=True, is_training=Tru num_filters *= 2 x = tf.keras.layers.AveragePooling2D(pool_size=7, name='avg1')(x) x = tf.keras.layers.Flatten(name='flat1')(x) - outputs = tf.keras.layers.Dense(num_classes, activation='softmax', - kernel_initializer=tf.keras.initializers.RandomNormal(stddev=0.01), - kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer), - bias_regularizer=_gen_l2_regularizer(use_l2_regularizer), - name='dense1')(x) - model = tf.keras.Model(inputs, outputs) + logits = tf.keras.layers.Dense(num_classes, kernel_initializer=tf.keras.initializers.RandomNormal(stddev=0.01), + kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer), + bias_regularizer=_gen_l2_regularizer(use_l2_regularizer), name='dense1')(x) + outputs = tf.keras.layers.Softmax()(logits) + model = tf.keras.Model(inputs, [outputs, logits], name=name) return model -def resnet_18(is_training): +def resnet_18(is_training, name): """ Build resnet-18 model :param is_training: if training or not + :param name: the model name :return: resnet-18 model """ - return resnet(18, is_training=is_training) + return resnet(18, is_training=is_training, name=name) -def resnet_34(is_training): +def resnet_34(is_training, name): """ Build resnet-34 model :param is_training: if training or not + :param name: the model name :return: resnet-34 model """ - return resnet(34, is_training=is_training) + return resnet(34, is_training=is_training, name=name) -def resnet_50(is_training): +def resnet_50(is_training, name): """ Build resnet-50 model :param is_training: if training or not + :param name: the model name :return: resnet-50 model """ - return resnet(50, is_training=is_training) + return resnet(50, is_training=is_training, name=name) + + +def resnet_101(is_training, name): + """ + Build resnet-101 model + :param is_training: if training or not + :param name: the model name + :return: resnet-101 model + """ + return resnet(101, is_training=is_training, name=name) def residual_block(stage, block_num, input_data, filters, kernel_size, is_training): diff --git a/src/model_optimizer/pruner/models/vgg.py b/src/model_optimizer/pruner/models/vgg.py index dafd55e..3430f2b 100644 --- a/src/model_optimizer/pruner/models/vgg.py +++ b/src/model_optimizer/pruner/models/vgg.py @@ -11,48 +11,56 @@ BATCH_NORM_EPSILON = 1e-5 -def vgg_16(is_training, num_classes=1001, use_l2_regularizer=True): +def vgg_16(is_training, name, num_classes=1001, use_l2_regularizer=True): """ VGG-16 model :param is_training: if training or not + :param name: the model name :param num_classes: classification class :param use_l2_regularizer: if use l2 regularizer or not :return: """ - return vgg(ver='D', is_training=is_training, num_classes=num_classes, use_l2_regularizer=use_l2_regularizer) + return vgg(ver='D', is_training=is_training, name=name, num_classes=num_classes, + use_l2_regularizer=use_l2_regularizer) -def vgg_19(is_training, num_classes=1001, use_l2_regularizer=True): +def vgg_19(is_training, name, num_classes=1001, use_l2_regularizer=True): """ VGG-19 model :param is_training: if training or not + :param name: the model name :param num_classes: classification class :param use_l2_regularizer: if use l2 regularizer or not :return: """ - return vgg(ver='E', is_training=is_training, num_classes=num_classes, use_l2_regularizer=use_l2_regularizer) + return vgg(ver='E', is_training=is_training, name=name, num_classes=num_classes, + use_l2_regularizer=use_l2_regularizer) -def vgg_m_16(is_training, num_classes=10, use_l2_regularizer=True): +def vgg_m_16(is_training, name, num_classes=10, use_l2_regularizer=True): """ VGG-M-16 model :param is_training: if training or not + :param name: the model name :param num_classes: classification class :param use_l2_regularizer: if use l2 regularizer or not :return: """ - return vgg_m(ver='D', is_training=is_training, num_classes=num_classes, use_l2_regularizer=use_l2_regularizer) + return vgg_m(ver='D', is_training=is_training, name=name, num_classes=num_classes, + use_l2_regularizer=use_l2_regularizer) -def vgg_m_19(is_training, num_classes=10, use_l2_regularizer=True): +def vgg_m_19(is_training, name, num_classes=10, use_l2_regularizer=True): """ VGG-M-19 model :param is_training: if training or not + :param name: the model name :param num_classes: classification class :param use_l2_regularizer: if use l2 regularizer or not :return: """ - return vgg_m(ver='E', is_training=is_training, num_classes=num_classes, use_l2_regularizer=use_l2_regularizer) + return vgg_m(ver='E', is_training=is_training, name=name, num_classes=num_classes, + use_l2_regularizer=use_l2_regularizer) def _gen_l2_regularizer(use_l2_regularizer=True): @@ -73,11 +81,12 @@ def _vgg_blocks(block, conv_num, filters, x, is_training, use_l2_regularizer=Tru return x -def vgg(ver, is_training, num_classes=1001, use_l2_regularizer=True): +def vgg(ver, is_training, name, num_classes=1001, use_l2_regularizer=True): """ VGG models :param ver: 'D' or 'E' :param is_training: if training or not + :param name: the model name :param num_classes: classification class :param use_l2_regularizer: if use l2 regularizer or not :return: @@ -113,15 +122,16 @@ def vgg(ver, is_training, num_classes=1001, use_l2_regularizer=True): kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer), bias_regularizer=_gen_l2_regularizer(use_l2_regularizer), name='fc3')(x) - model = tf.keras.Model(inputs, outputs) + model = tf.keras.Model(inputs, outputs, name=name) return model -def vgg_m(ver, is_training, num_classes=10, use_l2_regularizer=True): +def vgg_m(ver, is_training, name, num_classes=10, use_l2_regularizer=True): """ VGG-M models :param ver: 'D' or 'E' :param is_training: if training or not + :param name: the model name :param num_classes: classification class :param use_l2_regularizer: if use l2 regularizer or not :return: @@ -148,5 +158,5 @@ def vgg_m(ver, is_training, num_classes=10, use_l2_regularizer=True): kernel_regularizer=_gen_l2_regularizer(use_l2_regularizer), bias_regularizer=_gen_l2_regularizer(use_l2_regularizer), name='fc2')(x) - model = tf.keras.Model(inputs, outputs) + model = tf.keras.Model(inputs, outputs, name=name) return model diff --git a/src/model_optimizer/pruner/scheduler/common.py b/src/model_optimizer/pruner/scheduler/common.py index 2614cc1..0dafbad 100644 --- a/src/model_optimizer/pruner/scheduler/common.py +++ b/src/model_optimizer/pruner/scheduler/common.py @@ -26,7 +26,7 @@ def config_get_epochs_to_train(config): :return: """ scheduler_config = get_scheduler(config) - if scheduler_config is None: + if scheduler_config is None or config.get_attribute('scheduler') == 'distill': return 0, [], None return get_epochs_lr_to_train(scheduler_config) diff --git a/src/model_optimizer/pruner/scheduler/distill/resnet_50_imagenet_0.3.yaml b/src/model_optimizer/pruner/scheduler/distill/resnet_50_imagenet_0.3.yaml new file mode 100644 index 0000000..442efed --- /dev/null +++ b/src/model_optimizer/pruner/scheduler/distill/resnet_50_imagenet_0.3.yaml @@ -0,0 +1,5 @@ +version: 1 +alpha: 0.3 +temperature: 10 +student_name: "resnet_50" +teacher_path: "/root/work/examples/models_ckpt/resnet_101_imagenet_120e_logits/checkpoint-120.h5" \ No newline at end of file diff --git a/src/model_optimizer/quantizer/calib_dataset/cifar10.py b/src/model_optimizer/quantizer/calib_dataset/cifar10.py index 98b3900..05bb6ba 100644 --- a/src/model_optimizer/quantizer/calib_dataset/cifar10.py +++ b/src/model_optimizer/quantizer/calib_dataset/cifar10.py @@ -19,10 +19,11 @@ def __init__(self, data_path): :param data_path: tfrecord data path :return: """ - super(Cifar10Dataset, self).__init__(data_path) + super().__init__(data_path) self.dataset_fn = tf.data.TFRecordDataset # pylint: disable=R0201 + # pylint: disable=no-value-for-parameter,unexpected-keyword-arg def parse_fn(self, example_serialized): """ Parse features from the serialized data diff --git a/src/model_optimizer/quantizer/calib_dataset/imagenet.py b/src/model_optimizer/quantizer/calib_dataset/imagenet.py index ced525c..9c91f04 100644 --- a/src/model_optimizer/quantizer/calib_dataset/imagenet.py +++ b/src/model_optimizer/quantizer/calib_dataset/imagenet.py @@ -20,10 +20,11 @@ def __init__(self, data_path): :param data_path: tfrecord data path :return: """ - super(ImagenetDataset, self).__init__(data_path) + super().__init__(data_path) self.dataset_fn = tf.data.TFRecordDataset # pylint: disable=R0201 + # pylint: disable=no-value-for-parameter,unexpected-keyword-arg def parse_fn(self, example_serialized): """ Parse features from the serialized data diff --git a/src/model_optimizer/quantizer/calib_dataset/mnist.py b/src/model_optimizer/quantizer/calib_dataset/mnist.py index 59f6ee0..6fe1f38 100644 --- a/src/model_optimizer/quantizer/calib_dataset/mnist.py +++ b/src/model_optimizer/quantizer/calib_dataset/mnist.py @@ -20,10 +20,11 @@ def __init__(self, data_path): :param data_path: tfrecord data path :return: """ - super(MnistDataset, self).__init__(data_path) + super().__init__(data_path) self.dataset_fn = tf.data.TFRecordDataset # pylint: disable=R0201 + # pylint: disable=no-value-for-parameter,unexpected-keyword-arg def parse_fn(self, example_serialized): """ Parse features from the serialized data diff --git a/src/model_optimizer/quantizer/tflite/optimizer.py b/src/model_optimizer/quantizer/tflite/optimizer.py index 18a769f..06dec1e 100644 --- a/src/model_optimizer/quantizer/tflite/optimizer.py +++ b/src/model_optimizer/quantizer/tflite/optimizer.py @@ -18,7 +18,7 @@ class Quantizer(BaseQuantizer): """ def __init__(self, config, calibration_input_fn): - super(Quantizer, self).__init__(config) + super().__init__(config) self.calibration_input_fn = calibration_input_fn def _do_quantize(self): diff --git a/src/model_optimizer/quantizer/tftrt/optimizer.py b/src/model_optimizer/quantizer/tftrt/optimizer.py index 05211ae..0107094 100644 --- a/src/model_optimizer/quantizer/tftrt/optimizer.py +++ b/src/model_optimizer/quantizer/tftrt/optimizer.py @@ -19,7 +19,7 @@ class Quantizer(BaseQuantizer): """ def __init__(self, config, calibration_input_fn): - super(Quantizer, self).__init__(config) + super().__init__(config) self.calibration_input_fn = calibration_input_fn def _do_quantize(self): diff --git a/src/model_optimizer/stat.py b/src/model_optimizer/stat.py index 67af166..8b5afa9 100644 --- a/src/model_optimizer/stat.py +++ b/src/model_optimizer/stat.py @@ -8,6 +8,7 @@ import numpy as np +# pylint: disable=not-context-manager def get_keras_model_flops(model_h5_path): """ Get keras model FLOPs diff --git a/tests/test_model.py b/tests/test_model.py new file mode 100644 index 0000000..62f7dc8 --- /dev/null +++ b/tests/test_model.py @@ -0,0 +1,36 @@ +""" +Tests for the model_optimizer.models.get_model method. +""" +import os +# If you did not execute the setup.py, uncomment the following four lines +from model_optimizer.pruner.config import create_config_from_obj as prune_conf_from_obj +from model_optimizer.pruner.models import get_model + + +def test_get_model_distill(): + """ + test get_model function for distillation + """ + base_dir = os.path.dirname(__file__) + request = { + "dataset": "imagenet", + "model_name": "resnet_50", + "data_dir": "", + "batch_size": 256, + "batch_size_val": 100, + "learning_rate": 0.1, + "epochs": 90, + "checkpoint_path": os.path.join(base_dir, "./models_ckpt/resnet_50_imagenet_distill"), + "checkpoint_save_period": 5, # save a checkpoint every 5 epoch + "checkpoint_eval_path": os.path.join(base_dir, "./models_eval_ckpt/resnet_50_imagenet_distill"), + "scheduler": "train", + "scheduler_file_name": "resnet_50_imagenet_0.3.yaml", + "classifier_activation": None # None or "softmax", default is softmax + } + + config = prune_conf_from_obj(request) + train_model = get_model(config, is_training=True) + for layer in train_model.layers: + if layer.name == "DistillLoss": + assert False + break diff --git a/tests/test_pruner.py b/tests/test_pruner.py index 7337582..3d2b359 100644 --- a/tests/test_pruner.py +++ b/tests/test_pruner.py @@ -2,11 +2,12 @@ Tests for the model_optimizer package. """ import tensorflow as tf +import numpy as np from model_optimizer.pruner.core import AutoPruner from model_optimizer.pruner.core import SpecifiedLayersPruner -import numpy as np +# noqa: ignore=C901 def test_uniform_auto_prune(): """ Test the AutoPruner prune function. diff --git a/tools/common/model_predict.py b/tools/common/model_predict.py index 11adab0..fca160b 100644 --- a/tools/common/model_predict.py +++ b/tools/common/model_predict.py @@ -38,11 +38,12 @@ def _get_from_saved_model(graph_func, input_data, print_result=False): return output_data -def keras_model_predict(request, file_path): +def keras_model_predict(request, file_path, is_multi_output=False): """ Keras model predict :param request: dict, must match pruner config_schema.json :param file_path: file path + :param is_multi_output: the flag of multiple output of the model :return: """ ds_val = get_dataset(prune_conf_from_obj(request), is_training=False) @@ -53,7 +54,10 @@ def keras_model_predict(request, file_path): cur_steps = 0 start = time.time() for x_test, y_test in val_dataset: - result = keras_model.predict(x_test) + if is_multi_output: + result, _ = keras_model.predict(x_test) + else: + result = keras_model.predict(x_test) output_data = tf.keras.backend.argmax(result) for j in range(y_test.shape[0]): if int(output_data[j]) == int(y_test[j]): diff --git a/tools/keras_model_predict_resnet_50_imagenet.py b/tools/keras_model_predict_resnet_50_imagenet.py index 7892060..ccc9fe8 100644 --- a/tools/keras_model_predict_resnet_50_imagenet.py +++ b/tools/keras_model_predict_resnet_50_imagenet.py @@ -18,4 +18,4 @@ "batch_size_val": 64 } model_path = os.path.join(base_dir, '../examples/models_eval_ckpt/resnet_50_imagenet_pruned/checkpoint-120.h5') - keras_model_predict(request, model_path) + keras_model_predict(request, model_path, True)