train_net.py
import os
import re
import weakref
import logging
import shutil
import time
import signal
import subprocess
from pathlib import Path
from argparse import Namespace
from typing import Callable, Dict, Optional, Set, List
import torch
from torch.utils.data import DataLoader
from fvcore.nn.precise_bn import get_bn_modules
import detectron2.utils.comm as comm
from detectron2.engine.defaults import create_ddp_model, DefaultTrainer, TrainerBase
from detectron2.engine.defaults import hooks as d2_hooks
from detectron2.checkpoint import DetectionCheckpointer
from detectron2.evaluation import DatasetEvaluators, DatasetEvaluator, verify_results
from detectron2.evaluation import COCOPanopticEvaluator
from detectron2.utils.logger import setup_logger
from detectron2.utils.events import EventStorage
from detectron2.config import CfgNode
from mask2former import (
COCOInstanceNewBaselineDatasetMapper,
COCOPanopticNewBaselineDatasetMapper,
MaskFormerInstanceDatasetMapper,
MaskFormerPanopticDatasetMapper,
MaskFormerSemanticDatasetMapper,
)
from .events import CustomJSONWriter, CustomCommonMetricPrinter
from .config import (
update_config_epochs,
get_config_value,
)
from .train_loop import (
CustomAMPTrainer,
CustomSimpleTrainer,
parse_dataloader,
)
from .build import (
build_detection_train_loader,
build_detection_test_loader,
)
from .panoptic_evaluation import PartialCOCOPanopticEvaluator
from .hooks import CustomLRScheduler, CustomBestCheckpointer, ETAHook, CustomEvalHook, EpochIterHook
from .wandb import CustomWandbWriter
logger = logging.getLogger(__name__)
class CustomTrainerMixin:
_dataloaders: Dict[str, DataLoader] = {}
_test_eval_fp16: bool = False
def __init__(self, cfg):
"""Same as train_net.py but uses CustomAMPTrainer instead of AMPTrainer and
CustomSimpleTrainer instead of SimpleTrainer. Intended to be used as a mixin like:
```
class CustomTrainer(CustomTrainerMixin, Trainer):
pass
```
"""
# This init overrides DefaultTrainer.__init__(), so call TrainerBase.__init__() directly
TrainerBase.__init__(self)
cfg = DefaultTrainer.auto_scale_workers(cfg, comm.get_world_size())
# Update: Build data_loader before model/optimizer so we can set SOLVER.MAX_ITER
# The data_loader should not depend on the model or optimizer
data_loader = self.build_train_loader(cfg)
# Update: Convert epoch-based to iter-based metrics
pytorch_data_loader, steps_per_epoch, batch_size = parse_dataloader(data_loader)
update_config_epochs(cfg=cfg, steps_per_epoch=steps_per_epoch)
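# Illustrative arithmetic (hypothetical numbers): with 8,000 training images and a total
# batch size of 16, steps_per_epoch == 500, so an epoch-based setting such as an early exit
# after 20 epochs becomes 20 * 500 = 10,000 iterations; the same multiply-by-steps_per_epoch
# conversion is applied below for SOLVER.EARLY_EXIT_EPOCHS and SOLVER.SLURM_REQUEUE_NUM_EPOCHS.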
# Assume these objects must be constructed in this order.
model = self.build_model(cfg)
optimizer = self.build_optimizer(cfg, model)
model = create_ddp_model(model, broadcast_buffers=False)
# Update: Use CustomAMPTrainer and CustomSimpleTrainer to log epoch / epoch_float
trainer_cls = CustomAMPTrainer if cfg.SOLVER.AMP.ENABLED else CustomSimpleTrainer
self._trainer = trainer_cls(
steps_per_epoch=steps_per_epoch,
model=model,
data_loader=data_loader,
optimizer=optimizer,
)
# Update: Define test precision if set
test_eval_fp16: Optional[bool] = get_config_value(cfg=cfg, key="TEST.EVAL_FP16")
if test_eval_fp16:
if not torch.cuda.is_available():
raise RuntimeError("CUDA is required if TEST.EVAL_FP16=True")
CustomTrainerMixin._test_eval_fp16 = True
# Update: Add early_exit_iter and other metrics we print out in self._train()
early_exit_iter: Optional[int] = get_config_value(cfg=cfg, key="SOLVER.EARLY_EXIT_ITER")
early_exit_epochs: Optional[int] = get_config_value(cfg=cfg, key="SOLVER.EARLY_EXIT_EPOCHS")
slurm_requeue_num_epochs: Optional[int] = get_config_value(
cfg=cfg, key="SOLVER.SLURM_REQUEUE_NUM_EPOCHS"
)
slurm_job_id: Optional[str] = None
if early_exit_iter is not None and early_exit_epochs is not None:
raise RuntimeError(
f"Found both SOLVER.EARLY_EXIT_ITER={early_exit_iter} and"
f" SOLVER.EARLY_EXIT_EPOCHS={early_exit_epochs}. Expected only one to be set."
)
elif early_exit_epochs is not None:
early_exit_iter = early_exit_epochs * steps_per_epoch
slurm_requeue_num_iter: Optional[int] = None
if slurm_requeue_num_epochs is not None:
slurm_requeue_num_iter = slurm_requeue_num_epochs * steps_per_epoch
slurm_job_id = os.environ.get("SLURM_JOB_ID")
if slurm_job_id is None:
raise RuntimeError(
f"Expected SLURM_JOB_ID in environment if using SOLVER.SLURM_REQUEUE_NUM_EPOCHS"
)
self._early_exit_iter = early_exit_iter
self._slurm_requeue_num_iter = slurm_requeue_num_iter
self._slurm_job_id = slurm_job_id
self._per_gpu_batch_size = batch_size
self._total_batch_size = batch_size * comm.get_world_size()
self.scheduler = self.build_lr_scheduler(cfg, optimizer)
self.checkpointer = DetectionCheckpointer(
# Assume you want to save checkpoints together with logs/statistics
model,
cfg.OUTPUT_DIR,
trainer=weakref.proxy(self),
)
self.start_iter = 0
self.max_iter = cfg.SOLVER.MAX_ITER
self.cfg = cfg
self.register_hooks(self.build_hooks())
self._optimizer_named_params: Optional[Dict[str, Dict]] = None
CustomTrainerMixin._dataloaders["train"] = pytorch_data_loader
@property
def optimizer_named_params(self) -> Dict[str, dict]:
if self._optimizer_named_params is None:
# PyTorch optimizers still have no named param groups, see https://github.com/pytorch/pytorch/issues/1489
# Iterate over named params just like mask2former.Trainer.build_optimizer()
param_idx = 0
optimizer_named_params = {}
for module_name, module in self._trainer.model.named_modules():
for module_param_name, value in module.named_parameters(recurse=False):
full_name = f"{module_name}.{module_param_name}"
params_dict = self._trainer.optimizer.param_groups[param_idx]
assert len(params_dict["params"]) == 1 and id(params_dict["params"][0]) == id(
value
), "Params and module tensor mismatch"
optimizer_named_params[full_name] = params_dict # Store orig params dict
param_idx += 1
assert len(optimizer_named_params) == len(
self._trainer.optimizer.param_groups
), "Missing named params"
self._optimizer_named_params = optimizer_named_params
return self._optimizer_named_params
def compute_unique_lr_groups(self) -> Dict[str, float]:
"""
Iterate over optimizer params and return a dictionary of top-most keys and their current LRs
This is re-computed each iteration so the LR values are "current"
"""
# Construct map for lr to named params
groups_to_lrs: Dict[str, Set[str]] = {}
for name, params in self.optimizer_named_params.items():
lr_str = str(params["lr"]) # Use string as key
split_name = name.split(".")
for idx in range(len(split_name)):
group_name = ".".join(split_name[: idx + 1])
if group_name not in groups_to_lrs:
groups_to_lrs[group_name] = set()
groups_to_lrs[group_name].add(lr_str)
# Mark names to be removed if they are redundant (the "parent" prefix has only one lr value)
remove_names = []
for name in list(groups_to_lrs.keys()):
parent_name = ".".join(name.split(".")[:-1])
if parent_name in groups_to_lrs and len(groups_to_lrs[parent_name]) == 1:
remove_names.append(name) # This param is "covered" by parent
unique_lr_groups: Dict[str, float] = {
name: float(next(iter(lr_vals))) # Use next(iter(s)) to get single element in set s
for name, lr_vals in groups_to_lrs.items()
if name not in remove_names and len(lr_vals) == 1
}
return unique_lr_groups # e.g. {'backbone': 1e-05, 'sem_seg_head': 0.0001}
def resume_or_load(self, resume=True):
"""
Same as detectron2's resume_or_load, but if cfg.MODEL.WEIGHTS is a missing local path we try to resolve it from the torch hub cache dir
"""
weights_path = self.cfg.MODEL.WEIGHTS
if "://" not in self.cfg.MODEL.WEIGHTS: # Local filepath, try to find in torch cache dir
weights_path = Path(weights_path)
if not weights_path.exists():
cache_dir = torch.hub.get_dir()
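# torch.hub.get_dir() is typically ~/.cache/torch/hub unless TORCH_HOME / XDG_CACHE_HOME is
# set, so a checkpoint such as model_final.pth (hypothetical name) is matched anywhere under it.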
matching_files = [
filepath
for filepath in Path(cache_dir).rglob(f"*{weights_path.suffix}")
if filepath.name == weights_path.name
]
if len(matching_files) == 1:
logger.info(
f"Weights file {weights_path} not found. Found one matching filename in"
f" torchhub cache dir: {matching_files[0]}. Using matching filepath."
)
weights_path = str(matching_files[0])
elif not resume:
raise RuntimeError(
f"Failed to find weights file {weights_path}. Tried to use torchhub cache dir,"
f" found {len(matching_files)} in cache dir (must find exactly one to use"
f" cache dir path)"
)
# Same as detectron2.DefaultTrainer now
self.checkpointer.resume_or_load(weights_path, resume=resume)
if resume and self.checkpointer.has_checkpoint():
# The checkpoint stores the training iteration that just finished, thus we start
# at the next iteration
self.start_iter = self.iter + 1
def build_writers(self):
"""Same as OneFormer.train_net.py but uses cfg.LOGGER.INTERVAL and cfg.WANDB.ENABLED"""
log_interval = self.cfg.get("LOGGER", {}).get("INTERVAL", 20)
json_file = os.path.join(self.cfg.OUTPUT_DIR, "metrics.json")
return [
CustomCommonMetricPrinter(
early_exit_iter=self._early_exit_iter,
max_iter=self.max_iter,
window_size=log_interval,
),
CustomJSONWriter(
steps_per_epoch=self._trainer.steps_per_epoch, # For EventWriterMixin
json_file=json_file,
window_size=log_interval,
),
# Initialize wandb via `setup_wandb` from self._train (only want to init for training)
CustomWandbWriter(
steps_per_epoch=self._trainer.steps_per_epoch # For EventWriterMixin
),
]
def build_hooks(self):
"""Same as detectron2 DefaultTrainer.build_hooks with a few additions:
- Update LRScheduler -> CustomLRScheduler
- Update PeriodicWriter -> PeriodicWriter with log interval
- Add BestCheckpointer after EvalHook
"""
log_interval = get_config_value(cfg=self.cfg, key="LOGGER.INTERVAL", default=20)
checkpoint_max_keep = get_config_value(cfg=self.cfg, key="SOLVER.CHECKPOINT_MAX_KEEP")
if checkpoint_max_keep is not None and (
not float(checkpoint_max_keep).is_integer() or int(checkpoint_max_keep) < 1
):
raise RuntimeError(
f"Expected SOLVER.CHECKPOINT_MAX_KEEP to be an integer >= 1,"
f" found SOLVER.CHECKPOINT_MAX_KEEP={checkpoint_max_keep}"
)
cfg = self.cfg.clone()
cfg.defrost()
cfg.DATALOADER.NUM_WORKERS = 0 # save some memory and time for PreciseBN
ret = [
d2_hooks.IterationTimer(),
CustomLRScheduler(), # Updated: CustomLRScheduler
(
d2_hooks.PreciseBN(
# Run at the same freq as (but before) evaluation.
cfg.TEST.EVAL_PERIOD,
self.model,
# Build a new data loader to not affect training
self.build_train_loader(cfg),
cfg.TEST.PRECISE_BN.NUM_ITER,
)
if cfg.TEST.PRECISE_BN.ENABLED and get_bn_modules(self.model)
else None
),
]
# Do PreciseBN before checkpointer, because it updates the model and needs to
# be saved by checkpointer.
# This is not always the best: if checkpointing has a different frequency,
# some checkpoints may have more precise statistics than others.
if comm.is_main_process(): # Updated: Add checkpoint_max_keep
ret.append(
d2_hooks.PeriodicCheckpointer(
checkpointer=self.checkpointer,
period=cfg.SOLVER.CHECKPOINT_PERIOD,
max_to_keep=checkpoint_max_keep,
)
)
def test_and_save_results():
self._last_eval_results = self.test(self.cfg, self.model)
return self._last_eval_results
# Do evaluation after checkpointer, because then if it fails,
# we can use the saved checkpoint to debug.
ret.append(
CustomEvalHook(
eval_iters=cfg.TEST.EVAL_EXPLICIT_ITERS,
eval_period=cfg.TEST.EVAL_PERIOD,
eval_function=test_and_save_results,
eval_after_train=True,
)
)
# Updated: Add BestCheckpointer
best_metrics = cfg.SOLVER.get("CHECKPOINT_BEST_METRICS", [])
if isinstance(best_metrics, str):
best_metrics = [best_metrics]
wandb_save = cfg.SOLVER.get("CHECKPOINT_BEST_METRICS_WANDB_SAVE", [])
if isinstance(wandb_save, str):
wandb_save = [wandb_save]
if any([metric not in best_metrics for metric in wandb_save]):
raise RuntimeError(
f"Expected all metrics in SOLVER.CHECKPOINT_BEST_METRICS_WANDB_SAVE to be in"
f" SOLVER.CHECKPOINT_BEST_METRICS"
)
for val_metric in best_metrics:
ret.append(
CustomBestCheckpointer(
eval_period=cfg.TEST.EVAL_PERIOD,
checkpointer=self.checkpointer,
val_metric=val_metric,
save_wandb=val_metric in wandb_save,
)
)
if comm.is_main_process():
# Run writers at the end, so that evaluation metrics are written.
# Update: Add a separate ETA hook and pass log_interval as the PeriodicWriter period
# (instead of each writer's default print/log frequency).
ret.append(ETAHook(max_iter=self.max_iter, early_exit_iter=self._early_exit_iter))
ret.append(EpochIterHook())
ret.append(d2_hooks.PeriodicWriter(self.build_writers(), period=log_interval))
return ret
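# Resulting hook order: IterationTimer, CustomLRScheduler, optional PreciseBN,
# PeriodicCheckpointer (main process only), CustomEvalHook, one CustomBestCheckpointer per
# best metric, then ETAHook, EpochIterHook and PeriodicWriter (main process only).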
@classmethod
def build_train_loader(cls, cfg):
"""Same as Trainer.build_train_loader but uses our build_detection_train_loader which
supports cfg.DATALOADER.SAMPLER_TRAIN == "EpochTrainingSampler" or
"RandomSubsetEpochTrainingSampler".
"""
# Semantic segmentation dataset mapper
if cfg.INPUT.DATASET_MAPPER_NAME == "mask_former_semantic":
mapper = MaskFormerSemanticDatasetMapper(cfg, True)
return build_detection_train_loader(cfg, mapper=mapper)
# Panoptic segmentation dataset mapper
elif cfg.INPUT.DATASET_MAPPER_NAME == "mask_former_panoptic":
mapper = MaskFormerPanopticDatasetMapper(cfg, True)
return build_detection_train_loader(cfg, mapper=mapper)
# Instance segmentation dataset mapper
elif cfg.INPUT.DATASET_MAPPER_NAME == "mask_former_instance":
mapper = MaskFormerInstanceDatasetMapper(cfg, True)
return build_detection_train_loader(cfg, mapper=mapper)
# coco instance segmentation lsj new baseline
elif cfg.INPUT.DATASET_MAPPER_NAME == "coco_instance_lsj":
mapper = COCOInstanceNewBaselineDatasetMapper(cfg, True)
return build_detection_train_loader(cfg, mapper=mapper)
# coco panoptic segmentation lsj new baseline
elif cfg.INPUT.DATASET_MAPPER_NAME == "coco_panoptic_lsj":
mapper = COCOPanopticNewBaselineDatasetMapper(cfg, True)
return build_detection_train_loader(cfg, mapper=mapper)
else:
mapper = None
return build_detection_train_loader(cfg, mapper=mapper)
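# The mapper is selected by INPUT.DATASET_MAPPER_NAME in the YAML config; an illustrative
# snippet (assuming a panoptic setup) would be:
#   INPUT:
#     DATASET_MAPPER_NAME: "mask_former_panoptic"
# Any other value falls through with mapper=None, typically the default detectron2
# DatasetMapper in detectron2-style train loaders.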
@classmethod
def build_test_loader(cls, cfg, dataset_name):
"""Use our own build_detection_test_loader to support RandomSubsetInferenceSampler"""
data_loader = build_detection_test_loader(cfg, dataset_name)
pytorch_data_loader, _steps_per_epoch, _batch_size = parse_dataloader(data_loader)
cls._dataloaders["test"] = pytorch_data_loader
return data_loader # Original dataloader, not inner pytorch dataloader
@classmethod
def test(cls, cfg, model, evaluators=None):
# if torch.cuda.is_available():
# torch.cuda.empty_cache() # Empty from last train iter, to avoid OOM
if cls._test_eval_fp16:
from torch.cuda.amp import autocast
with autocast(dtype=torch.float16):
results = super().test(cfg=cfg, model=model, evaluators=evaluators)
else:
results = super().test(cfg=cfg, model=model, evaluators=evaluators)
cls._dataloaders.pop("test")
# if torch.cuda.is_available():
# torch.cuda.empty_cache() # Empty from test, again to avoid OOM
return results
@classmethod
def build_evaluator(cls, cfg, dataset_name, output_folder=None):
evaluators: DatasetEvaluators = super().build_evaluator(
cfg=cfg, dataset_name=dataset_name, output_folder=output_folder
)
# Replace COCOPanopticEvaluator with PartialCOCOPanopticEvaluator if using a test subset.
# This is required to avoid a runtime error that will prevent PQ from being calculated
# And we also update how the multiprocessing pool is closed for multi-core PQ results
test_subset_ratio = get_config_value(cfg=cfg, key="DATALOADER.TEST_RANDOM_SUBSET_RATIO")
test_subset_size = get_config_value(cfg=cfg, key="DATALOADER.TEST_RANDOM_SUBSET_SIZE")
if test_subset_ratio is not None or test_subset_size is not None:
final_evaluators: List[DatasetEvaluator] = []
if isinstance(evaluators, DatasetEvaluators): # Multiple evaluators
for evaluator in evaluators._evaluators:
if type(evaluator) == COCOPanopticEvaluator:
final_evaluators.append(
PartialCOCOPanopticEvaluator(
dataset_name=dataset_name, output_dir=evaluator._output_dir
)
)
else:
logger.info(
f"Dropping unsupported evaluator {type(evaluator).__name__} when"
f" testing on a subset with DATALOADER.TEST_RANDOM_SUBSET_SIZE or"
f" DATALOADER.TEST_RANDOM_SUBSET_RATIO."
)
elif type(evaluators) == COCOPanopticEvaluator: # Single evaluator
final_evaluators.append(
PartialCOCOPanopticEvaluator(
dataset_name=dataset_name, output_dir=evaluators._output_dir  # single evaluator case
)
)
else:
raise RuntimeError(
f"Only COCOPanopticEvaluator is supported with"
f" DATALOADER.TEST_RANDOM_SUBSET_RATIO or DATALOADER.TEST_RANDOM_SUBSET_SIZE"
)
evaluators = DatasetEvaluators(final_evaluators)
return evaluators
def train(self):
"""Same as DefaultTrainer.train() but calls our _train() instead of super().train().
Also handles SIGTERM to close Wandb and mark the run as preempting (because we cancelled it)."""
self._train()
if len(self.cfg.TEST.EXPECTED_RESULTS) and comm.is_main_process():
assert hasattr(
self, "_last_eval_results"
), "No evaluation results obtained during training!"
verify_results(self.cfg, self._last_eval_results)
return self._last_eval_results
def _train(self):
"""Same as TrainerBase.train() but uses a sigterm handler to close wandb correctly, and
supports self._early_exit_iter"""
orig_signal_handlers: Dict[int, Callable] = {}
def _sigterm_handler(signal_num, _frame):
signal_descr: Optional[str] = signal.strsignal(signal_num)
# Need exit_code=1 when using preempting=True to show 'preempted'
exit_code = 1
preempting = True
# If a print statement was interrupted by SIGTERM, we can't print to the console while handling it
# Use "signal_safe" logger, from CustomRunner.build_logger(), only prints to file
# See https://stackoverflow.com/questions/45680378/how-to-explain-the-reentrant-runtimeerror-caused-by-printing-in-signal-handlers
# and https://stackoverflow.com/questions/64147017/logging-signals-in-python
logger = logging.getLogger("signal_safe") # Hard-coded in CustomRunner.build_logger()
logger.info(
f"Caught signal {signal_descr} ({signal_num}) during training."
f" Closing wandb backend with exit_code={exit_code}, preempting={preempting}."
)
# Use quiet=True for the same reasons as the signal_safe logger
# Wandb still not completely silent but the minimal console output doesn't cause issues
CustomWandbWriter.close_wandb(
exit_code=exit_code,
preempting=preempting,
quiet=True,
dataloaders=list(CustomTrainerMixin._dataloaders.values()),
)
logger.info("Signal handling finished. Re-raising signal with default signal handler")
signal.signal(signal_num, orig_signal_handlers[signal_num])
signal.raise_signal(signal_num)
if comm.is_main_process():
orig_signal_handlers[signal.SIGCONT] = signal.signal(signal.SIGCONT, _sigterm_handler)
orig_signal_handlers[signal.SIGTERM] = signal.signal(signal.SIGTERM, _sigterm_handler)
def _register_prev_handlers():
signal.signal(signal.SIGCONT, orig_signal_handlers[signal.SIGCONT])
signal.signal(signal.SIGTERM, orig_signal_handlers[signal.SIGTERM])
# Child processes such as those in pq_compute_multi_core will all call this handler
# but only want this current process to do the cleanup; need to re-register handlers
# From https://stackoverflow.com/a/74688726/12422298
os.register_at_fork(after_in_child=lambda: _register_prev_handlers())
start_iter = self.start_iter
max_iter = self.max_iter
self.iter = start_iter
self.max_iter = max_iter
# Use early_exit_iter only for range in loop below, not for self.max_iter
# We only want to break out early, don't want to impact any other mechanisms
slurm_requeue_on_finish = False
early_exit = False
if self._early_exit_iter is not None and self._early_exit_iter < max_iter:
max_iter = self._early_exit_iter
early_exit = True
if self._slurm_requeue_num_iter is not None:
slurm_max_iter = start_iter + self._slurm_requeue_num_iter
if slurm_max_iter < max_iter:
max_iter = slurm_max_iter
slurm_requeue_on_finish = True
early_exit = True
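# e.g. SOLVER.SLURM_REQUEUE_NUM_EPOCHS=5 with steps_per_epoch=500 (hypothetical numbers)
# caps this launch at start_iter + 2500 iterations, after which the job requeues itself
# via scontrol at the end of _train().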
logger.info(
f"Starting training with start_iter={start_iter}, max_iter={max_iter},"
f" steps_per_epoch={self._trainer.steps_per_epoch},"
f" per_gpu_batch_size={self._per_gpu_batch_size},"
f" total_batch_size={self._total_batch_size}"
)
success = False
with EventStorage(start_iter) as self.storage:
try:
self.before_train()
for self.iter in range(start_iter, max_iter):
self.before_step()
self.run_step()
self.after_step()
# self.iter == max_iter can be used by `after_train` to
# tell whether the training successfully finished or failed
# due to exceptions.
self.iter += 1
success = True
except Exception as e:
# Sleep for a few seconds to see if a SIGTERM arrives; if it does, the exception was
# likely a side effect of preemption (e.g. a timeout that crashed the dataloader).
time.sleep(5)
# If we reach this point, it was a real exception, close with exit_code=1
msg = f": {str(e)}" if len(str(e)) > 0 else ""
logger.info(
f"Caught {type(e).__name__}{msg}. Closing wandb (exit_code=1) and re-raising."
)
CustomWandbWriter.close_wandb(exit_code=1)
raise
finally:
if success:
logger.info(
f"Succesfully finished training for {max_iter} iter. Calling after_train()."
)
exit_code = 0
preempting = early_exit
else:
exit_code = 1
preempting = False
CustomWandbWriter.close_wandb(exit_code=exit_code, preempting=preempting)
self.after_train()
# comm.synchronize() # If non-main process leaves early, torchrun may terminate main?
if slurm_requeue_on_finish and comm.is_main_process():
assert self._slurm_job_id is not None, "Expected SLURM_JOB_ID if requeueing"
logger.info(
f"Requeuing slurm_job_id={self._slurm_job_id} after"
f" {self._slurm_requeue_num_iter} iterations"
)
subprocess.run(f"scontrol requeue {self._slurm_job_id}", shell=True)
def setup_loggers(cfg: CfgNode) -> None:
# Update: Setup additional logger for detectron2_plugin and this script, and a 'signal_safe'
# version which can be safely called during SIGTERM handling (can't print to stdout)
for name, abbrev in [
("mask2former", "mask2former"), # Originally only this
("detectron2_plugin", "d2_plugin"),
("__main__", "train_net_custom"),
("signal_safe", "signal_safe"),
]:
plugin_logger = setup_logger(
output=cfg.OUTPUT_DIR,
distributed_rank=comm.get_rank(),
name=name,
abbrev_name=abbrev,
configure_stdout=(name != "signal_safe"),
)
plugin_logger.setLevel(logging.INFO)
for handler in plugin_logger.handlers:
handler.setLevel(logging.INFO)
def maybe_restart_run(args: Namespace, cfg: CfgNode):
if cfg.get("RESTART_RUN", False):
args.resume = False # Don't resume
if comm.is_main_process():
# Don't backup `wandb` dir, wandb already initialized with resume=False
backup_dir_names_regex = ["inference"]
backup_file_names_regex = [
r"log\..+",
r".+\.json",
r".+\.pth",
"last_checkpoint",
] # Top-level dir only
logger.info(
f"Found cfg.RESTART_RUN=True, backing up directories matching"
f" {backup_dir_names_regex} and files matching {backup_file_names_regex}"
)
backup_dest_dir = Path(cfg.OUTPUT_DIR, "prev_run")
if backup_dest_dir.exists():
logger.info(
f"Found previous backup dir {backup_dest_dir}. Deleting previous backup."
)
shutil.rmtree(backup_dest_dir, ignore_errors=True) # ignore_errors suppresses failures on a partial delete
backup_dest_dir.mkdir(parents=True, exist_ok=True)
for filepath in Path(cfg.OUTPUT_DIR).glob("*"): # Not recursive
match_dir = any(
[
re.search(pattern=regex, string=filepath.name) is not None
for regex in backup_dir_names_regex
]
)
match_file = any(
[
re.search(pattern=regex, string=filepath.name) is not None
for regex in backup_file_names_regex
]
)
if (filepath.is_dir() and match_dir) or (filepath.is_file() and match_file):
dest_filepath = backup_dest_dir.joinpath(filepath.name)
shutil.move(src=str(filepath), dst=str(dest_filepath))
comm.synchronize()
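# Illustrative usage (mirrors the CustomTrainerMixin docstring; `Trainer` is assumed to be
# the Mask2Former Trainer from its train_net.py, and `cfg` / `args` come from the usual
# detectron2 setup() / argument parsing):
#
#   class CustomTrainer(CustomTrainerMixin, Trainer):
#       pass
#
#   setup_loggers(cfg)
#   maybe_restart_run(args, cfg)
#   trainer = CustomTrainer(cfg)
#   trainer.resume_or_load(resume=args.resume)
#   trainer.train()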