diff --git a/src/train.py b/src/train.py index 9e1da97..fb55096 100644 --- a/src/train.py +++ b/src/train.py @@ -189,7 +189,7 @@ def train(train_file, validation_file, batch_size, epoch_limit, file_name, gpu_m + ' AT EPOCH: ' + str(start_epoch) + "\n" + TextColor.END) if gpu_mode: - model = torch.nn.DataParallel(model).cuda() + model = torch.nn.DistributedDataParallel(model).cuda() # Train the Model sys.stderr.write(TextColor.PURPLE + 'Training starting\n' + TextColor.END)