From dfe639377434b8b658c2686ac33db5c66b7c974d Mon Sep 17 00:00:00 2001 From: Zeyu Qin Date: Mon, 3 Apr 2023 22:13:54 +0800 Subject: [PATCH] add new code about FedRep and Simple_tuning (#564) * add new code * updating * updating * re-formmat code * conflict resolve * updating * updating * Update test_femnist_simple_tuning.py * Update test_femnist_simple_tuning.py format change - (rerun-unitest) * Update __init__.py fix the import problem * Update __init__.py * re-formatted code --------- Co-authored-by: yuexiang.xyx Co-authored-by: Osier-Yi Co-authored-by: Daoyuan Chen <67475544+yxdyc@users.noreply.github.com> --- .../core/auxiliaries/trainer_builder.py | 9 ++ federatedscope/core/configs/cfg_fl_algo.py | 7 ++ federatedscope/core/configs/cfg_training.py | 7 ++ federatedscope/core/trainers/__init__.py | 6 +- .../core/trainers/trainer_FedRep.py | 99 +++++++++++++++++++ .../core/trainers/trainer_simple_tuning.py | 75 ++++++++++++++ tests/test_femnist_fedrep.py | 91 +++++++++++++++++ tests/test_femnist_simple_tuning.py | 89 +++++++++++++++++ 8 files changed, 382 insertions(+), 1 deletion(-) create mode 100644 federatedscope/core/trainers/trainer_FedRep.py create mode 100644 federatedscope/core/trainers/trainer_simple_tuning.py create mode 100644 tests/test_femnist_fedrep.py create mode 100644 tests/test_femnist_simple_tuning.py diff --git a/federatedscope/core/auxiliaries/trainer_builder.py b/federatedscope/core/auxiliaries/trainer_builder.py index 98feef7f6..b32baf74e 100644 --- a/federatedscope/core/auxiliaries/trainer_builder.py +++ b/federatedscope/core/auxiliaries/trainer_builder.py @@ -222,6 +222,10 @@ def get_trainer(model=None, # copy construct style: instance a (class A) -> instance b (class B) trainer = FedEMTrainer(model_nums=config.model.model_num_per_trainer, base_trainer=trainer) + elif config.federate.method.lower() == "fedrep": + from federatedscope.core.trainers import wrap_FedRepTrainer + # wrap style: instance a (class A) -> instance a (class A) + trainer = wrap_FedRepTrainer(trainer) # attacker plug-in if 'backdoor' in config.attack.attack_method: @@ -246,4 +250,9 @@ def get_trainer(model=None, from federatedscope.core.trainers import wrap_fedprox_trainer trainer = wrap_fedprox_trainer(trainer) + # different fine-tuning + if config.finetune.before_eval and config.finetune.simple_tuning: + from federatedscope.core.trainers import wrap_Simple_tuning_Trainer + trainer = wrap_Simple_tuning_Trainer(trainer) + return trainer diff --git a/federatedscope/core/configs/cfg_fl_algo.py b/federatedscope/core/configs/cfg_fl_algo.py index 7a4f9e7bf..244c9df73 100644 --- a/federatedscope/core/configs/cfg_fl_algo.py +++ b/federatedscope/core/configs/cfg_fl_algo.py @@ -63,6 +63,13 @@ def extend_fl_algo_cfg(cfg): cfg.personalization.K = 5 # the local approximation steps for pFedMe cfg.personalization.beta = 1.0 # the average moving parameter for pFedMe + # parameters for FedRep: + cfg.personalization.lr_feature = 0.1 # learning rate: feature extractors + cfg.personalization.lr_linear = 0.1 # learning rate: linear head + cfg.personalization.epoch_feature = 1 # training epoch number + cfg.personalization.epoch_linear = 2 # training epoch number + cfg.personalization.weight_decay = 0.0 + # ---------------------------------------------------------------------- # # FedSage+ related options, gfl # ---------------------------------------------------------------------- # diff --git a/federatedscope/core/configs/cfg_training.py b/federatedscope/core/configs/cfg_training.py index c896d29d1..6e98c3623 
100644 --- a/federatedscope/core/configs/cfg_training.py +++ b/federatedscope/core/configs/cfg_training.py @@ -60,6 +60,13 @@ def extend_training_cfg(cfg): cfg.finetune.scheduler.type = '' cfg.finetune.scheduler.warmup_ratio = 0.0 + # simple-tuning + cfg.finetune.simple_tuning = False # use simple tuning, default: False + cfg.finetune.epoch_linear = 10 # training epoch number, default: 10 + cfg.finetune.lr_linear = 0.005 # learning rate for training linear head + cfg.finetune.weight_decay = 0.0 + cfg.finetune.local_param = [] # tuning parameters list + # ---------------------------------------------------------------------- # # Gradient related options # ---------------------------------------------------------------------- # diff --git a/federatedscope/core/trainers/__init__.py b/federatedscope/core/trainers/__init__.py index 072d6369d..11fede0d3 100644 --- a/federatedscope/core/trainers/__init__.py +++ b/federatedscope/core/trainers/__init__.py @@ -7,6 +7,9 @@ from federatedscope.core.trainers.trainer_pFedMe import wrap_pFedMeTrainer from federatedscope.core.trainers.trainer_Ditto import wrap_DittoTrainer from federatedscope.core.trainers.trainer_FedEM import FedEMTrainer +from federatedscope.core.trainers.trainer_FedRep import wrap_FedRepTrainer +from federatedscope.core.trainers.trainer_simple_tuning import \ + wrap_Simple_tuning_Trainer from federatedscope.core.trainers.context import Context from federatedscope.core.trainers.trainer_fedprox import wrap_fedprox_trainer from federatedscope.core.trainers.trainer_nbafl import wrap_nbafl_trainer, \ @@ -16,5 +19,6 @@ 'Trainer', 'Context', 'GeneralTorchTrainer', 'GeneralMultiModelTrainer', 'wrap_pFedMeTrainer', 'wrap_DittoTrainer', 'FedEMTrainer', 'wrap_fedprox_trainer', 'wrap_nbafl_trainer', 'wrap_nbafl_server', - 'BaseTrainer', 'GeneralTFTrainer' + 'wrap_Simple_tuning_Trainer', 'wrap_FedRepTrainer', 'BaseTrainer', + 'GeneralTFTrainer' ] diff --git a/federatedscope/core/trainers/trainer_FedRep.py b/federatedscope/core/trainers/trainer_FedRep.py new file mode 100644 index 000000000..0b577491d --- /dev/null +++ b/federatedscope/core/trainers/trainer_FedRep.py @@ -0,0 +1,99 @@ +import copy +import torch +import logging + +from federatedscope.core.auxiliaries.optimizer_builder import get_optimizer +from federatedscope.core.trainers.torch_trainer import GeneralTorchTrainer + +from typing import Type + +logger = logging.getLogger(__name__) + + +def wrap_FedRepTrainer( + base_trainer: Type[GeneralTorchTrainer]) -> Type[GeneralTorchTrainer]: + # ---------------------------------------------------------------------- # + # FedRep method: + # https://arxiv.org/abs/2102.07078 + # First training linear classifier and then feature extractor + # Linear classifier: local_param; feature extractor: global_param + # ---------------------------------------------------------------------- # + init_FedRep_ctx(base_trainer) + + base_trainer.register_hook_in_train(new_hook=hook_on_fit_start_fedrep, + trigger="on_fit_start", + insert_pos=-1) + + base_trainer.register_hook_in_train(new_hook=hook_on_epoch_start_fedrep, + trigger="on_epoch_start", + insert_pos=-1) + + return base_trainer + + +def init_FedRep_ctx(base_trainer): + + ctx = base_trainer.ctx + cfg = base_trainer.cfg + + ctx.epoch_feature = cfg.personalization.epoch_feature + ctx.epoch_linear = cfg.personalization.epoch_linear + + ctx.num_train_epoch = ctx.epoch_feature + ctx.epoch_linear + + ctx.epoch_number = 0 + + ctx.lr_feature = cfg.personalization.lr_feature + ctx.lr_linear = 
cfg.personalization.lr_linear + ctx.weight_decay = cfg.personalization.weight_decay + + ctx.local_param = cfg.personalization.local_param + + ctx.local_update_param = [] + ctx.global_update_param = [] + + for name, param in ctx.model.named_parameters(): + if name.split(".")[0] in ctx.local_param: + ctx.local_update_param.append(param) + else: + ctx.global_update_param.append(param) + + +def hook_on_fit_start_fedrep(ctx): + + ctx.num_train_epoch = ctx.epoch_feature + ctx.epoch_linear + ctx.epoch_number = 0 + + ctx.optimizer_for_feature = torch.optim.SGD(ctx.global_update_param, + lr=ctx.lr_feature, + momentum=0, + weight_decay=ctx.weight_decay) + ctx.optimizer_for_linear = torch.optim.SGD(ctx.local_update_param, + lr=ctx.lr_linear, + momentum=0, + weight_decay=ctx.weight_decay) + + for name, param in ctx.model.named_parameters(): + + if name.split(".")[0] in ctx.local_param: + param.requires_grad = True + else: + param.requires_grad = False + + ctx.optimizer = ctx.optimizer_for_linear + + +def hook_on_epoch_start_fedrep(ctx): + + ctx.epoch_number += 1 + + if ctx.epoch_number == ctx.epoch_linear + 1: + + for name, param in ctx.model.named_parameters(): + + if name.split(".")[0] in ctx.local_param: + param.requires_grad = False + else: + param.requires_grad = True + + ctx.optimizer = ctx.optimizer_for_feature diff --git a/federatedscope/core/trainers/trainer_simple_tuning.py b/federatedscope/core/trainers/trainer_simple_tuning.py new file mode 100644 index 000000000..b0eb70765 --- /dev/null +++ b/federatedscope/core/trainers/trainer_simple_tuning.py @@ -0,0 +1,75 @@ +import copy +import torch +import logging +import math + +from federatedscope.core.auxiliaries.optimizer_builder import get_optimizer +from federatedscope.core.trainers.torch_trainer import GeneralTorchTrainer + +from typing import Type + +logger = logging.getLogger(__name__) + + +def wrap_Simple_tuning_Trainer( + base_trainer: Type[GeneralTorchTrainer]) -> Type[GeneralTorchTrainer]: + # ---------------------------------------------------------------------- # + # Simple_tuning method: + # https://arxiv.org/abs/2302.01677 + # Only tuning the linear classifier and freeze the feature extractor + # the key is to reinitialize the linear classifier + # ---------------------------------------------------------------------- # + init_Simple_tuning_ctx(base_trainer) + + base_trainer.register_hook_in_ft(new_hook=hook_on_fit_start_simple_tuning, + trigger="on_fit_start", + insert_pos=-1) + + return base_trainer + + +def init_Simple_tuning_ctx(base_trainer): + + ctx = base_trainer.ctx + cfg = base_trainer.cfg + + ctx.epoch_linear = cfg.finetune.epoch_linear + + ctx.num_train_epoch = ctx.epoch_linear + + ctx.epoch_number = 0 + + ctx.lr_linear = cfg.finetune.lr_linear + ctx.weight_decay = cfg.finetune.weight_decay + + ctx.local_param = cfg.finetune.local_param + + ctx.local_update_param = [] + + for name, param in ctx.model.named_parameters(): + if name.split(".")[0] in ctx.local_param: + ctx.local_update_param.append(param) + + +def hook_on_fit_start_simple_tuning(ctx): + + ctx.num_train_epoch = ctx.epoch_linear + ctx.epoch_number = 0 + + ctx.optimizer_for_linear = torch.optim.SGD(ctx.local_update_param, + lr=ctx.lr_linear, + momentum=0, + weight_decay=ctx.weight_decay) + + for name, param in ctx.model.named_parameters(): + if name.split(".")[0] in ctx.local_param: + if name.split(".")[1] == 'weight': + stdv = 1. 
/ math.sqrt(param.size(-1)) + param.data.uniform_(-stdv, stdv) + else: + param.data.uniform_(-stdv, stdv) + param.requires_grad = True + else: + param.requires_grad = False + + ctx.optimizer = ctx.optimizer_for_linear diff --git a/tests/test_femnist_fedrep.py b/tests/test_femnist_fedrep.py new file mode 100644 index 000000000..bcb880d5d --- /dev/null +++ b/tests/test_femnist_fedrep.py @@ -0,0 +1,91 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. +import unittest + +from federatedscope.core.auxiliaries.data_builder import get_data +from federatedscope.core.auxiliaries.utils import setup_seed +from federatedscope.core.auxiliaries.logging import update_logger +from federatedscope.core.configs.config import global_cfg +from federatedscope.core.auxiliaries.runner_builder import get_runner +from federatedscope.core.auxiliaries.worker_builder import get_server_cls, get_client_cls + +SAMPLE_CLIENT_NUM = 5 + + +class FedRep_Testing(unittest.TestCase): + def setUp(self): + print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) + + def set_config_Fedrep_femnist(self, cfg): + backup_cfg = cfg.clone() + + import torch + cfg.use_gpu = torch.cuda.is_available() + cfg.eval.freq = 10 + cfg.eval.metrics = ['acc', 'loss_regular'] + + cfg.federate.mode = 'standalone' + cfg.train.local_update_steps = 5 + cfg.federate.total_round_num = 20 + cfg.federate.sample_client_num = SAMPLE_CLIENT_NUM + + cfg.data.root = 'test_data/' + cfg.data.type = 'femnist' + cfg.data.splits = [0.6, 0.2, 0.2] + cfg.data.batch_size = 10 + cfg.data.subsample = 0.05 + cfg.data.transform = [['ToTensor'], + [ + 'Normalize', { + 'mean': [0.9637], + 'std': [0.1592] + } + ]] + + cfg.model.type = 'convnet2' + cfg.model.hidden = 2048 + cfg.model.out_channels = 62 + + cfg.train.optimizer.lr = 0.001 + cfg.train.optimizer.weight_decay = 0.0 + cfg.grad.grad_clip = 5.0 + + cfg.criterion.type = 'CrossEntropyLoss' + cfg.trainer.type = 'cvtrainer' + cfg.seed = 123 + cfg.personalization.local_param = ['fc2'] + cfg.personalization.local_update_steps = 2 + cfg.personalization.regular_weight = 0.1 + cfg.personalization.epoch_feature = 2 + cfg.personalization.epoch_linear = 1 + cfg.personalization.lr_feature = 0.1 + cfg.personalization.lr_linear = 0.1 + + return backup_cfg + + def test_femnist_standalone(self): + init_cfg = global_cfg.clone() + backup_cfg = self.set_config_Fedrep_femnist(init_cfg) + setup_seed(init_cfg.seed) + update_logger(init_cfg, True) + + data, modified_cfg = get_data(init_cfg.clone()) + init_cfg.merge_from_other_cfg(modified_cfg) + self.assertIsNotNone(data) + self.assertEqual(init_cfg.federate.sample_client_num, + SAMPLE_CLIENT_NUM) + + Fed_runner = get_runner(data=data, + server_class=get_server_cls(init_cfg), + client_class=get_client_cls(init_cfg), + config=init_cfg.clone()) + self.assertIsNotNone(Fed_runner) + test_best_results = Fed_runner.run() + print(test_best_results) + init_cfg.merge_from_other_cfg(backup_cfg) + self.assertLess( + test_best_results["client_summarized_weighted_avg"]['test_loss'], + 1200) + + +if __name__ == '__main__': + unittest.main() diff --git a/tests/test_femnist_simple_tuning.py b/tests/test_femnist_simple_tuning.py new file mode 100644 index 000000000..ccc915d15 --- /dev/null +++ b/tests/test_femnist_simple_tuning.py @@ -0,0 +1,89 @@ +# Copyright (c) Alibaba, Inc. and its affiliates. 
+import unittest + +from federatedscope.core.auxiliaries.data_builder import get_data +from federatedscope.core.auxiliaries.utils import setup_seed +from federatedscope.core.auxiliaries.logging import update_logger +from federatedscope.core.configs.config import global_cfg +from federatedscope.core.auxiliaries.runner_builder import get_runner +from federatedscope.core.auxiliaries.worker_builder import get_server_cls, get_client_cls + +SAMPLE_CLIENT_NUM = 5 + + +class Simple_tuning_Testing(unittest.TestCase): + def setUp(self): + print(('Testing %s.%s' % (type(self).__name__, self._testMethodName))) + + def set_config_femnist_simple_tuning(self, cfg): + backup_cfg = cfg.clone() + + import torch + cfg.use_gpu = torch.cuda.is_available() + cfg.eval.freq = 10 + cfg.eval.metrics = ['acc', 'loss_regular'] + + cfg.federate.mode = 'standalone' + cfg.train.local_update_steps = 5 + cfg.federate.total_round_num = 20 + cfg.federate.sample_client_num = SAMPLE_CLIENT_NUM + + cfg.data.root = 'test_data/' + cfg.data.type = 'femnist' + cfg.data.splits = [0.6, 0.2, 0.2] + cfg.data.batch_size = 10 + cfg.data.subsample = 0.05 + cfg.data.transform = [['ToTensor'], + [ + 'Normalize', { + 'mean': [0.9637], + 'std': [0.1592] + } + ]] + + cfg.model.type = 'convnet2' + cfg.model.hidden = 2048 + cfg.model.out_channels = 62 + + cfg.train.optimizer.lr = 0.001 + cfg.train.optimizer.weight_decay = 0.0 + cfg.grad.grad_clip = 5.0 + + cfg.criterion.type = 'CrossEntropyLoss' + cfg.trainer.type = 'cvtrainer' + cfg.seed = 123 + cfg.finetune.before_eval = True + cfg.finetune.simple_tuning = True + cfg.finetune.local_update_steps = 5 + cfg.finetune.local_param = ['fc2'] + cfg.finetune.lr_linear = 0.005 + + return backup_cfg + + def test_femnist_standalone(self): + init_cfg = global_cfg.clone() + backup_cfg = self.set_config_femnist_simple_tuning(init_cfg) + setup_seed(init_cfg.seed) + update_logger(init_cfg, True) + + data, modified_cfg = get_data(init_cfg.clone()) + init_cfg.merge_from_other_cfg(modified_cfg) + self.assertIsNotNone(data) + self.assertEqual(init_cfg.federate.sample_client_num, + SAMPLE_CLIENT_NUM) + + Fed_runner = get_runner(data=data, + server_class=get_server_cls(init_cfg), + client_class=get_client_cls(init_cfg), + config=init_cfg.clone()) + self.assertIsNotNone(Fed_runner) + test_best_results = Fed_runner.run() + print(test_best_results) + init_cfg.merge_from_other_cfg(backup_cfg) + self.assertLess( + test_best_results["client_summarized_weighted_avg"]['test_loss'], + 1200) + + +if __name__ == '__main__': + unittest.main()
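
Usage sketch (illustrative; appended after the patch, not part of it). It shows how the configuration keys introduced above map onto the two new trainer wrappers, mirroring the unit tests in this patch. All values are placeholders taken from the tests and defaults, 'fc2' is the classifier layer of the 'convnet2' model used in the tests, and the two features are shown in a single snippet only for brevity (the tests enable them separately).

    # Minimal configuration sketch for the options added in this patch.
    from federatedscope.core.configs.config import global_cfg

    cfg = global_cfg.clone()

    # FedRep: trainer_builder wraps the trainer when federate.method == 'fedrep'.
    cfg.federate.method = 'fedrep'
    cfg.personalization.local_param = ['fc2']  # linear head, trained first each round
    cfg.personalization.epoch_linear = 2       # epochs on the linear head per round
    cfg.personalization.epoch_feature = 1      # epochs on the feature extractor per round
    cfg.personalization.lr_linear = 0.1
    cfg.personalization.lr_feature = 0.1
    cfg.personalization.weight_decay = 0.0

    # Simple tuning: applied when both finetune.before_eval and finetune.simple_tuning are set.
    cfg.finetune.before_eval = True
    cfg.finetune.simple_tuning = True
    cfg.finetune.local_param = ['fc2']         # layers re-initialized and fine-tuned before eval
    cfg.finetune.epoch_linear = 10
    cfg.finetune.lr_linear = 0.005
    cfg.finetune.weight_decay = 0.0

    # get_trainer() in federatedscope/core/auxiliaries/trainer_builder.py then applies
    # wrap_FedRepTrainer / wrap_Simple_tuning_Trainer based on the flags above.

Note that both wrappers identify the linear head by matching the first component of each parameter name (name.split(".")[0]) against local_param, so the listed names must be top-level module names of the model.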