From 17051e3cb4afd3519da9521ccb55736b7a612e9d Mon Sep 17 00:00:00 2001
From: balin
Date: Tue, 25 Apr 2023 17:55:06 +0000
Subject: [PATCH] Added num_groups option to config used by Horovod distributed optimizer

---
 src/config/framework.py                      | 2 ++
 src/utils/tensorflow2/distributed_trainer.py | 2 +-
 src/utils/torch/distributed_trainer.py       | 3 ++-
 3 files changed, 5 insertions(+), 2 deletions(-)

diff --git a/src/config/framework.py b/src/config/framework.py
index daddf6f7..e17015a0 100644
--- a/src/config/framework.py
+++ b/src/config/framework.py
@@ -18,12 +18,14 @@ class Tensorflow(Framework):
     name: str = "tensorflow"
     inter_op_parallelism_threads: int = 2
     intra_op_parallelism_threads: int = 24
+    num_groups: int = 0

 @dataclass
 class Torch(Framework):
     name: str = "torch"
     sparse: bool = False
     distributed_mode: DistributedMode = DistributedMode.DDP
+    num_groups: int = 0

 cs = ConfigStore.instance()
 cs.store(group="framework", name="tensorflow", node=Tensorflow)
diff --git a/src/utils/tensorflow2/distributed_trainer.py b/src/utils/tensorflow2/distributed_trainer.py
index 38a77133..f80d6b82 100644
--- a/src/utils/tensorflow2/distributed_trainer.py
+++ b/src/utils/tensorflow2/distributed_trainer.py
@@ -66,7 +66,7 @@ def init_optimizer(self):

         # Wrap the optimizer it in horovod:
         # self._opt = hvd.DistributedOptimizer(self._opt)
-        self.tape = hvd.DistributedGradientTape(self.tape, num_groups=1)
+        self.tape = hvd.DistributedGradientTape(self.tape, num_groups=self.args.framework.num_groups)

     def init_saver(self):
         if hvd.rank() == 0:
diff --git a/src/utils/torch/distributed_trainer.py b/src/utils/torch/distributed_trainer.py
index e8bfe10a..29423d89 100644
--- a/src/utils/torch/distributed_trainer.py
+++ b/src/utils/torch/distributed_trainer.py
@@ -195,7 +195,8 @@ def init_optimizer(self):
         torch_trainer.init_optimizer(self)

         if self.args.framework.distributed_mode == DistributedMode.horovod:
-            self._opt = hvd.DistributedOptimizer(self._opt, named_parameters=self._net.named_parameters())
+            self._opt = hvd.DistributedOptimizer(self._opt, named_parameters=self._net.named_parameters(),
+                                                 num_groups=self.args.framework.num_groups)

         # self._opt.param_groups[0]['capturable'] = True
         self.lr_scheduler = torch.optim.lr_scheduler.LambdaLR(self._opt, self.lr_calculator, last_epoch=-1)
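
Usage sketch (illustrative, not part of the patch): a minimal example of how a num_groups value like the one wired through the config above is consumed by Horovod's PyTorch DistributedOptimizer, assuming Horovod >= 0.21 with PyTorch support installed. The model, optimizer, and the literal value 2 are stand-ins for the trainer's own self._net, self._opt, and self.args.framework.num_groups.

    # Sketch only: shows the num_groups keyword of hvd.DistributedOptimizer,
    # which fuses gradient allreduces into that many groups (0 leaves grouping
    # to Horovod's default tensor fusion).
    import torch
    import horovod.torch as hvd

    hvd.init()

    model = torch.nn.Linear(16, 4)                        # stand-in for self._net
    opt = torch.optim.SGD(model.parameters(), lr=0.01)    # stand-in for self._opt

    num_groups = 2                                        # stand-in for self.args.framework.num_groups
    opt = hvd.DistributedOptimizer(opt,
                                   named_parameters=model.named_parameters(),
                                   num_groups=num_groups)

Since the dataclasses above are registered with Hydra's ConfigStore under the "framework" group, the new field should be overridable from the command line as framework.num_groups=2 (the exact entrypoint script depends on the project).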