From 7a4a12acd2e791e3b1980832c3d2a4e9daa95f8d Mon Sep 17 00:00:00 2001 From: Mayank Lunayach Date: Tue, 13 Sep 2022 21:46:41 -0400 Subject: [PATCH 1/3] ignore wandb --- .gitignore | 1 + 1 file changed, 1 insertion(+) create mode 100644 .gitignore diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..4e5f96e --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +wandb/ From 0377b96144faf980ec1c8a73dba78c567791522d Mon Sep 17 00:00:00 2001 From: Mayank Lunayach Date: Tue, 13 Sep 2022 21:48:49 -0400 Subject: [PATCH 2/3] ignore more --- .gitignore | 3 +++ 1 file changed, 3 insertions(+) diff --git a/.gitignore b/.gitignore index 4e5f96e..8130b04 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,4 @@ wandb/ +*.sh +*.txt + From 718ea89c7ac0d5063645b9a6d4065e172a559048 Mon Sep 17 00:00:00 2001 From: Mayank Lunayach Date: Tue, 13 Sep 2022 21:51:12 -0400 Subject: [PATCH 3/3] push changes --- net_train.py | 5 +++++ simnet/lib/net/dataset.py | 2 +- 2 files changed, 6 insertions(+), 1 deletion(-) diff --git a/net_train.py b/net_train.py index 598f484..1619da5 100755 --- a/net_train.py +++ b/net_train.py @@ -20,6 +20,8 @@ import pytorch_lightning as pl from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning import loggers +from pytorch_lightning.profiler import SimpleProfiler + from simnet.lib.net import common from simnet.lib import datapoint @@ -69,6 +71,8 @@ def draw_detections( model = PanopticModel(hparams, epochs, train_ds, EvalMethod()) model_checkpoint = ModelCheckpoint(filepath=hparams.output, save_top_k=-1, period=1, mode='max') wandb_logger = loggers.WandbLogger(name=hparams.wandb_name, project='CenterSnap') + + profiler = SimpleProfiler() if hparams.finetune_real: trainer = pl.Trainer( @@ -94,6 +98,7 @@ def draw_detections( default_save_path=hparams.output, use_amp=False, print_nan_grads=True, + profiler=profiler ) trainer.fit(model) diff --git a/simnet/lib/net/dataset.py b/simnet/lib/net/dataset.py index 1984bfd..7a5ddc3 100755 --- a/simnet/lib/net/dataset.py +++ b/simnet/lib/net/dataset.py @@ -49,7 +49,7 @@ def __init__(self, dataset_uri, hparams, preprocess_image_func=None, datapoint_d super().__init__() if datapoint_dataset is None: datapoint_dataset = datapoint.make_dataset(dataset_uri) - self.datapoint_handles = datapoint_dataset.list() + self.datapoint_handles = datapoint_dataset.list()[:200] # No need to shuffle, already shufled based on random uids self.hparams = hparams if preprocess_image_func is None: