From 2f0e72b73e0dece2fe2c3bafd264cc8f1e55ad15 Mon Sep 17 00:00:00 2001 From: Sanskar Agrawal Date: Thu, 12 Aug 2021 18:25:23 +0530 Subject: [PATCH 01/50] add dataparallel --- ml3d/torch/dataloaders/concat_batcher.py | 21 +++++++++++++++++++ ml3d/torch/pipelines/base_pipeline.py | 12 ++++++++--- ml3d/torch/pipelines/semantic_segmentation.py | 10 +++++++-- ml3d/utils/builder.py | 11 +++++++--- scripts/run_pipeline.py | 13 +++++++++--- 5 files changed, 56 insertions(+), 11 deletions(-) diff --git a/ml3d/torch/dataloaders/concat_batcher.py b/ml3d/torch/dataloaders/concat_batcher.py index 9efa70064..ce531509e 100644 --- a/ml3d/torch/dataloaders/concat_batcher.py +++ b/ml3d/torch/dataloaders/concat_batcher.py @@ -4,6 +4,7 @@ import pickle import torch import yaml +import math from os import listdir from os.path import exists, join, isdir @@ -434,6 +435,26 @@ def to(self, device): self.feat = [feat.to(device) for feat in self.feat] self.label = [label.to(device) for label in self.label] + @staticmethod + def scatter(batch, num_gpu): + batch_size = len(batch.batch_lengths) + assert num_gpu <= batch_size, "batch size must be greater than number of cuda devices" + + new_batch_size = math.ceil(batch_size / num_gpu) + batches = [SparseConvUnetBatch([]) for _ in range(num_gpu)] + splits = [0] + for length in batch.batch_lengths: + splits.append(splits[-1] + length) + for i in range(num_gpu): + start = splits[new_batch_size * i] + end = splits[min(new_batch_size * (i + 1), len(splits) - 1)] + batches[i].point = batch.point[start:end] + batches[i].feat = batch.feat[start:end] + batches[i].label = batch.label[start:end] + batches[i].batch_lengths = batch.batch_lengths[start:end] + + return batches + class ObjectDetectBatch: diff --git a/ml3d/torch/pipelines/base_pipeline.py b/ml3d/torch/pipelines/base_pipeline.py index 97c6f746b..4e56968c4 100644 --- a/ml3d/torch/pipelines/base_pipeline.py +++ b/ml3d/torch/pipelines/base_pipeline.py @@ -12,7 +12,12 @@ class BasePipeline(ABC): """Base pipeline class.""" - def __init__(self, model, dataset=None, device='gpu', **kwargs): + def __init__(self, + model, + dataset=None, + device='cuda', + device_ids=[0], + **kwargs): """Initialize. Args: @@ -42,9 +47,10 @@ def __init__(self, model, dataset=None, device='gpu', **kwargs): if device == 'cpu' or not torch.cuda.is_available(): self.device = torch.device('cpu') + self.device_ids = [-1] else: - self.device = torch.device('cuda' if len(device.split(':')) == - 1 else 'cuda:' + device.split(':')[1]) + self.device = torch.device('cuda') + self.device_ids = device_ids @abstractmethod def run_inference(self, data): diff --git a/ml3d/torch/pipelines/semantic_segmentation.py b/ml3d/torch/pipelines/semantic_segmentation.py index 9d857f9f4..78cc5c7b0 100644 --- a/ml3d/torch/pipelines/semantic_segmentation.py +++ b/ml3d/torch/pipelines/semantic_segmentation.py @@ -14,6 +14,7 @@ from os.path import exists, join, isfile, dirname, abspath from .base_pipeline import BasePipeline +from .dataparallel import CustomDataParallel from ..dataloaders import get_sampler, TorchDataloader, DefaultBatcher, ConcatBatcher from ..utils import latest_torch_ckpt from ..modules.losses import SemSegLoss @@ -102,7 +103,8 @@ def __init__( scheduler_gamma=0.95, momentum=0.98, main_log_dir='./logs/', - device='gpu', + device='cuda', + device_ids=[0], split='train', train_sum_dir='train_log', **kwargs): @@ -122,6 +124,7 @@ def __init__( momentum=momentum, main_log_dir=main_log_dir, device=device, + device_ids=device_ids, split=split, train_sum_dir=train_sum_dir, **kwargs) @@ -308,7 +311,6 @@ def run_train(self): dataset = self.dataset cfg = self.cfg - model.to(device) log.info("DEVICE : {}".format(device)) timestamp = datetime.now().strftime('%Y-%m-%d_%H:%M:%S') @@ -379,6 +381,10 @@ def run_train(self): writer = SummaryWriter(self.tensorboard_dir) self.save_config(writer) + + model = CustomDataParallel(model, device_ids=self.device_ids) + # model.to(device) + log.info("Writing summary in {}.".format(self.tensorboard_dir)) log.info("Started training") diff --git a/ml3d/utils/builder.py b/ml3d/utils/builder.py index f2c02afc8..f56b12097 100644 --- a/ml3d/utils/builder.py +++ b/ml3d/utils/builder.py @@ -14,17 +14,22 @@ def build_network(cfg): return build(cfg, NETWORK) -def convert_device_name(framework): +def convert_device_name(framework, device_ids): """Convert device to either cpu or cuda.""" gpu_names = ["gpu", "cuda"] cpu_names = ["cpu"] if framework not in cpu_names + gpu_names: raise KeyError("the device shoule either " "be cuda or cpu but got {}".format(framework)) + assert type(device_ids) is list + device_ids_new = [] + for device in device_ids: + device_ids_new.append(int(device)) + if framework in gpu_names: - return "cuda" + return "cuda", device_ids_new else: - return "cpu" + return "cpu", device_ids_new def convert_framework_name(framework): diff --git a/scripts/run_pipeline.py b/scripts/run_pipeline.py index 753352cd4..c56296b81 100644 --- a/scripts/run_pipeline.py +++ b/scripts/run_pipeline.py @@ -28,8 +28,12 @@ def parse_args(): parser.add_argument('--dataset_path', help='path to the dataset') parser.add_argument('--ckpt_path', help='path to the checkpoint') parser.add_argument('--device', - help='device to run the pipeline', - default='gpu') + help='devices to run the pipeline', + default='cuda') + parser.add_argument('--device_ids', + nargs='+', + help='cuda device list', + default=['0']) parser.add_argument('--split', help='train or test', default='train') parser.add_argument('--mode', help='additional mode', default=None) parser.add_argument('--max_epochs', help='number of epochs', default=None) @@ -62,7 +66,8 @@ def main(): args, extra_dict = parse_args() framework = _ml3d.utils.convert_framework_name(args.framework) - args.device = _ml3d.utils.convert_device_name(args.device) + args.device, args.device_ids = _ml3d.utils.convert_device_name( + args.device, args.device_ids) if framework == 'torch': import open3d.ml.torch as ml3d else: @@ -107,6 +112,8 @@ def main(): cfg_dict_pipeline["max_epochs"] = args.max_epochs if args.batch_size is not None: cfg_dict_pipeline["batch_size"] = args.batch_size + cfg_dict_pipeline["device"] = args.device + cfg_dict_pipeline["device_ids"] = args.device_ids pipeline = Pipeline(model, dataset, **cfg_dict_pipeline) else: if (args.pipeline and args.model and args.dataset) is None: From 306de5bda59be3b0bd6c834b883faf734adf054e Mon Sep 17 00:00:00 2001 From: Sanskar Agrawal Date: Thu, 12 Aug 2021 18:52:15 +0530 Subject: [PATCH 02/50] add dataparallel class --- ml3d/torch/pipelines/dataparallel.py | 36 ++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) create mode 100644 ml3d/torch/pipelines/dataparallel.py diff --git a/ml3d/torch/pipelines/dataparallel.py b/ml3d/torch/pipelines/dataparallel.py new file mode 100644 index 000000000..086857db2 --- /dev/null +++ b/ml3d/torch/pipelines/dataparallel.py @@ -0,0 +1,36 @@ +import torch +import numpy as np +from torch.nn.parallel import DataParallel + + +class CustomDataParallel(DataParallel): + """Custom DataParallel method for performing scatter operation + outside of torch's DataParallel. + """ + + def __init__(self, module, **kwargs): + super(CustomDataParallel, self).__init__(module, **kwargs) + self.get_loss = self.module.get_loss + self.cfg = self.module.cfg + + def forward(self, *inputs, **kwargs): + if not self.device_ids: + return self.module(*inputs, **kwargs) + inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids) + # self._sync_params() + if len(self.device_ids) == 1: + return self.module(*inputs[0], **kwargs[0]) + + replicas = self.replicate(self.module, self.device_ids[:len(inputs)]) + outputs = self.parallel_apply(replicas, inputs, kwargs) + + return self.gather(outputs, self.output_device) + + def scatter(self, inputs, kwargs, device_ids): + if not hasattr(inputs[0], 'scatter'): + raise NotImplementedError( + f"Please implement scatter for {inputs[0]} for multi gpu execution." + ) + inputs = inputs[0].scatter(inputs[0], len(self.device_ids)) + + return inputs, [kwargs for _ in range(len(inputs))] From b953254ef12023b33864639cb9532a624a84fc09 Mon Sep 17 00:00:00 2001 From: Sanskar Agrawal Date: Fri, 13 Aug 2021 19:20:23 +0530 Subject: [PATCH 03/50] fix bugs --- ml3d/torch/dataloaders/concat_batcher.py | 10 +++---- ml3d/torch/pipelines/dataparallel.py | 27 ++++++++++++++++--- ml3d/torch/pipelines/semantic_segmentation.py | 4 --- 3 files changed, 26 insertions(+), 15 deletions(-) diff --git a/ml3d/torch/dataloaders/concat_batcher.py b/ml3d/torch/dataloaders/concat_batcher.py index ce531509e..5497ced96 100644 --- a/ml3d/torch/dataloaders/concat_batcher.py +++ b/ml3d/torch/dataloaders/concat_batcher.py @@ -438,22 +438,18 @@ def to(self, device): @staticmethod def scatter(batch, num_gpu): batch_size = len(batch.batch_lengths) - assert num_gpu <= batch_size, "batch size must be greater than number of cuda devices" new_batch_size = math.ceil(batch_size / num_gpu) batches = [SparseConvUnetBatch([]) for _ in range(num_gpu)] - splits = [0] - for length in batch.batch_lengths: - splits.append(splits[-1] + length) for i in range(num_gpu): - start = splits[new_batch_size * i] - end = splits[min(new_batch_size * (i + 1), len(splits) - 1)] + start = new_batch_size * i + end = min(new_batch_size * (i + 1), batch_size) batches[i].point = batch.point[start:end] batches[i].feat = batch.feat[start:end] batches[i].label = batch.label[start:end] batches[i].batch_lengths = batch.batch_lengths[start:end] - return batches + return [b for b in batches if len(b.point)] # filter empty batch class ObjectDetectBatch: diff --git a/ml3d/torch/pipelines/dataparallel.py b/ml3d/torch/pipelines/dataparallel.py index 086857db2..11738b5a0 100644 --- a/ml3d/torch/pipelines/dataparallel.py +++ b/ml3d/torch/pipelines/dataparallel.py @@ -16,21 +16,40 @@ def __init__(self, module, **kwargs): def forward(self, *inputs, **kwargs): if not self.device_ids: return self.module(*inputs, **kwargs) - inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids) - # self._sync_params() + if len(self.device_ids) == 1: - return self.module(*inputs[0], **kwargs[0]) + if hasattr(inputs[0], 'to'): + inputs[0].to(self.device_ids[0]) + return self.module(inputs[0], **kwargs) + + inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids) + self.module.cuda() replicas = self.replicate(self.module, self.device_ids[:len(inputs)]) outputs = self.parallel_apply(replicas, inputs, kwargs) return self.gather(outputs, self.output_device) def scatter(self, inputs, kwargs, device_ids): + """Custom scatter method to override default method. + Scatter batch dimension based on custom scatter implemented + in custom batcher. + + Agrs: + inputs: Object of type custom batcher. + kwargs: Optional keyword arguments. + device_ids: List of device ids. + + Returns: + Returns a list of inputs of length num_devices. + Each input is transfered to different device id. + """ if not hasattr(inputs[0], 'scatter'): raise NotImplementedError( f"Please implement scatter for {inputs[0]} for multi gpu execution." ) - inputs = inputs[0].scatter(inputs[0], len(self.device_ids)) + inputs = inputs[0].scatter(inputs[0], len(device_ids)) + for i in range(len(inputs)): + inputs[i].to(torch.device(device_ids[i])) return inputs, [kwargs for _ in range(len(inputs))] diff --git a/ml3d/torch/pipelines/semantic_segmentation.py b/ml3d/torch/pipelines/semantic_segmentation.py index 78cc5c7b0..f5089c405 100644 --- a/ml3d/torch/pipelines/semantic_segmentation.py +++ b/ml3d/torch/pipelines/semantic_segmentation.py @@ -399,8 +399,6 @@ def run_train(self): model.trans_point_sampler = train_sampler.get_point_sampler() for step, inputs in enumerate(tqdm(train_loader, desc='training')): - if hasattr(inputs['data'], 'to'): - inputs['data'].to(device) self.optimizer.zero_grad() results = model(inputs['data']) loss, gt_labels, predict_scores = model.get_loss( @@ -429,8 +427,6 @@ def run_train(self): with torch.no_grad(): for step, inputs in enumerate( tqdm(valid_loader, desc='validation')): - if hasattr(inputs['data'], 'to'): - inputs['data'].to(device) results = model(inputs['data']) loss, gt_labels, predict_scores = model.get_loss( From 4dccfc03dc3d2a68b2ae3f945ca972b3c11cb30f Mon Sep 17 00:00:00 2001 From: Sanskar Agrawal Date: Tue, 24 Aug 2021 15:04:18 +0530 Subject: [PATCH 04/50] objdet multigpu --- ml3d/torch/dataloaders/concat_batcher.py | 46 +++++++++++++++++++++++- ml3d/torch/pipelines/object_detection.py | 9 +++-- 2 files changed, 52 insertions(+), 3 deletions(-) diff --git a/ml3d/torch/dataloaders/concat_batcher.py b/ml3d/torch/dataloaders/concat_batcher.py index 5497ced96..9b53a2a18 100644 --- a/ml3d/torch/dataloaders/concat_batcher.py +++ b/ml3d/torch/dataloaders/concat_batcher.py @@ -340,6 +340,33 @@ def to(self, device): return self + @staticmethod + def scatter(batch, num_gpu): + batch_size = len(batch.points) + + new_batch_size = math.ceil(batch_size / num_gpu) + batches = [KPConvBatch([]) for _ in range(num_gpu)] + for i in range(num_gpu): + start = new_batch_size * i + end = min(new_batch_size * (i + 1), batch_size) + batches[i].points = batch.points[start:end] + batches[i].neighbors = batch.neighbors[start:end] + batches[i].pools = batch.pools[start:end] + batches[i].upsamples = batch.upsamples[start:end] + batches[i].lengths = batch.lengths[start: + end] # TODO : verify lengths + + return [b for b in batches if len(b.points)] # filter empty batch + + def print(self): + print(self.points) + print(self.neighbors) + print(self.pools) + print(self.upsamples) + print(self.lengths) + print(self.features) + exit(0) + def unstack_points(self, layer=None): """Unstack the points.""" return self.unstack_elements('points', layer) @@ -471,13 +498,13 @@ def __init__(self, batches): self.attr = [] for batch in batches: - self.attr.append(batch['attr']) data = batch['data'] attr = batch['attr'] if 'test' not in attr['split'] and len( data['bboxes'] ) == 0: # Skip training batch with no bounding box. continue + self.attr.append(attr) self.point.append(torch.tensor(data['point'], dtype=torch.float32)) self.labels.append( torch.tensor(data['labels'], dtype=torch.int64) if 'labels' in @@ -506,6 +533,23 @@ def to(self, device): if self.bboxes[i] is not None: self.bboxes[i] = self.bboxes[i].to(device) + @staticmethod + def scatter(batch, num_gpu): + batch_size = len(batch.point) + + new_batch_size = math.ceil(batch_size / num_gpu) + batches = [ObjectDetectBatch([]) for _ in range(num_gpu)] + for i in range(num_gpu): + start = new_batch_size * i + end = min(new_batch_size * (i + 1), batch_size) + batches[i].point = batch.point[start:end] + batches[i].labels = batch.labels[start:end] + batches[i].bboxes = batch.bboxes[start:end] + batches[i].bbox_objs = batch.bbox_objs[start:end] + batches[i].attr = batch.attr[start:end] + + return [b for b in batches if len(b.point)] # filter empty batch + class ConcatBatcher(object): """ConcatBatcher for KPConv.""" diff --git a/ml3d/torch/pipelines/object_detection.py b/ml3d/torch/pipelines/object_detection.py index 43dd2a445..fe2a6ed05 100644 --- a/ml3d/torch/pipelines/object_detection.py +++ b/ml3d/torch/pipelines/object_detection.py @@ -11,6 +11,7 @@ from pathlib import Path from .base_pipeline import BasePipeline +from .dataparallel import CustomDataParallel from ..dataloaders import TorchDataloader, ConcatBatcher from torch.utils.tensorboard import SummaryWriter from ..utils import latest_torch_ckpt @@ -171,7 +172,7 @@ def run_valid(self): gt = [] with torch.no_grad(): for data in tqdm(valid_loader, desc='validation'): - data.to(device) + # data.to(device) results = model(data) loss = model.loss(results, data) for l, v in loss.items(): @@ -280,6 +281,10 @@ def run_train(self): writer = SummaryWriter(self.tensorboard_dir) self.save_config(writer) + + # wrap model for multiple GPU + model = CustomDataParallel(model, device_ids=self.device_ids) + log.info("Writing summary in {}.".format(self.tensorboard_dir)) log.info("Started training") @@ -291,7 +296,7 @@ def run_train(self): process_bar = tqdm(train_loader, desc='training') for data in process_bar: - data.to(device) + # data.to(device) results = model(data) loss = model.loss(results, data) loss_sum = sum(loss.values()) From 27793394385d2deab2e618fde904a63828080f7b Mon Sep 17 00:00:00 2001 From: Sanskar Agrawal Date: Fri, 27 Aug 2021 13:33:42 +0530 Subject: [PATCH 05/50] rename scatter --- ml3d/torch/pipelines/dataparallel.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/ml3d/torch/pipelines/dataparallel.py b/ml3d/torch/pipelines/dataparallel.py index 11738b5a0..ae7a77859 100644 --- a/ml3d/torch/pipelines/dataparallel.py +++ b/ml3d/torch/pipelines/dataparallel.py @@ -22,7 +22,7 @@ def forward(self, *inputs, **kwargs): inputs[0].to(self.device_ids[0]) return self.module(inputs[0], **kwargs) - inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids) + inputs, kwargs = self.customscatter(inputs, kwargs, self.device_ids) self.module.cuda() replicas = self.replicate(self.module, self.device_ids[:len(inputs)]) @@ -30,7 +30,7 @@ def forward(self, *inputs, **kwargs): return self.gather(outputs, self.output_device) - def scatter(self, inputs, kwargs, device_ids): + def customscatter(self, inputs, kwargs, device_ids): """Custom scatter method to override default method. Scatter batch dimension based on custom scatter implemented in custom batcher. @@ -45,9 +45,12 @@ def scatter(self, inputs, kwargs, device_ids): Each input is transfered to different device id. """ if not hasattr(inputs[0], 'scatter'): - raise NotImplementedError( - f"Please implement scatter for {inputs[0]} for multi gpu execution." - ) + try: + self.scatter(inputs, kwargs, device_ids) + except: + raise NotImplementedError( + f"Please implement scatter for {inputs[0]} for multi gpu execution." + ) inputs = inputs[0].scatter(inputs[0], len(device_ids)) for i in range(len(inputs)): inputs[i].to(torch.device(device_ids[i])) From 48a101f35cb0ec388b7959a1b14b174515ab5fc2 Mon Sep 17 00:00:00 2001 From: Sanskar Agrawal Date: Fri, 24 Sep 2021 16:33:46 +0530 Subject: [PATCH 06/50] fix objdet --- ml3d/datasets/utils/operations.py | 2 +- ml3d/torch/dataloaders/concat_batcher.py | 27 ------------------------ ml3d/torch/models/base_model_objdet.py | 2 +- ml3d/torch/models/point_pillars.py | 2 +- ml3d/torch/models/point_rcnn.py | 2 +- ml3d/torch/pipelines/dataparallel.py | 2 +- ml3d/torch/pipelines/object_detection.py | 6 +++--- 7 files changed, 8 insertions(+), 35 deletions(-) diff --git a/ml3d/datasets/utils/operations.py b/ml3d/datasets/utils/operations.py index 45147f2c3..3d252d481 100644 --- a/ml3d/datasets/utils/operations.py +++ b/ml3d/datasets/utils/operations.py @@ -4,7 +4,7 @@ import math from scipy.spatial import ConvexHull -from ...metrics import iou_bev +from open3d.ml.contrib import iou_bev_cpu as iou_bev def create_3D_rotations(axis, angle): diff --git a/ml3d/torch/dataloaders/concat_batcher.py b/ml3d/torch/dataloaders/concat_batcher.py index 9b53a2a18..6a7dda437 100644 --- a/ml3d/torch/dataloaders/concat_batcher.py +++ b/ml3d/torch/dataloaders/concat_batcher.py @@ -340,33 +340,6 @@ def to(self, device): return self - @staticmethod - def scatter(batch, num_gpu): - batch_size = len(batch.points) - - new_batch_size = math.ceil(batch_size / num_gpu) - batches = [KPConvBatch([]) for _ in range(num_gpu)] - for i in range(num_gpu): - start = new_batch_size * i - end = min(new_batch_size * (i + 1), batch_size) - batches[i].points = batch.points[start:end] - batches[i].neighbors = batch.neighbors[start:end] - batches[i].pools = batch.pools[start:end] - batches[i].upsamples = batch.upsamples[start:end] - batches[i].lengths = batch.lengths[start: - end] # TODO : verify lengths - - return [b for b in batches if len(b.points)] # filter empty batch - - def print(self): - print(self.points) - print(self.neighbors) - print(self.pools) - print(self.upsamples) - print(self.lengths) - print(self.features) - exit(0) - def unstack_points(self, layer=None): """Unstack the points.""" return self.unstack_elements('points', layer) diff --git a/ml3d/torch/models/base_model_objdet.py b/ml3d/torch/models/base_model_objdet.py index e63176c0b..7e665955a 100644 --- a/ml3d/torch/models/base_model_objdet.py +++ b/ml3d/torch/models/base_model_objdet.py @@ -24,7 +24,7 @@ def __init__(self, **kwargs): self.cfg = Config(kwargs) @abstractmethod - def loss(self, results, inputs): + def get_loss(self, results, inputs): """Computes the loss given the network input and outputs. Args: diff --git a/ml3d/torch/models/point_pillars.py b/ml3d/torch/models/point_pillars.py index 47344ecd8..da1d49b27 100644 --- a/ml3d/torch/models/point_pillars.py +++ b/ml3d/torch/models/point_pillars.py @@ -138,7 +138,7 @@ def get_optimizer(self, cfg): optimizer = torch.optim.AdamW(self.parameters(), **cfg) return optimizer, None - def loss(self, results, inputs): + def get_loss(self, results, inputs): scores, bboxes, dirs = results gt_labels = inputs.labels gt_bboxes = inputs.bboxes diff --git a/ml3d/torch/models/point_rcnn.py b/ml3d/torch/models/point_rcnn.py index f06c3c217..ea542e045 100644 --- a/ml3d/torch/models/point_rcnn.py +++ b/ml3d/torch/models/point_rcnn.py @@ -183,7 +183,7 @@ def step(self): return optimizer, scheduler - def loss(self, results, inputs): + def get_loss(self, results, inputs): if self.mode == "RPN": return self.rpn.loss(results, inputs) else: diff --git a/ml3d/torch/pipelines/dataparallel.py b/ml3d/torch/pipelines/dataparallel.py index ae7a77859..1b36c431a 100644 --- a/ml3d/torch/pipelines/dataparallel.py +++ b/ml3d/torch/pipelines/dataparallel.py @@ -46,7 +46,7 @@ def customscatter(self, inputs, kwargs, device_ids): """ if not hasattr(inputs[0], 'scatter'): try: - self.scatter(inputs, kwargs, device_ids) + return self.scatter(inputs, kwargs, device_ids) except: raise NotImplementedError( f"Please implement scatter for {inputs[0]} for multi gpu execution." diff --git a/ml3d/torch/pipelines/object_detection.py b/ml3d/torch/pipelines/object_detection.py index fe2a6ed05..b31a210da 100644 --- a/ml3d/torch/pipelines/object_detection.py +++ b/ml3d/torch/pipelines/object_detection.py @@ -174,7 +174,7 @@ def run_valid(self): for data in tqdm(valid_loader, desc='validation'): # data.to(device) results = model(data) - loss = model.loss(results, data) + loss = model.get_loss(results, data) for l, v in loss.items(): if not l in self.valid_losses: self.valid_losses[l] = [] @@ -296,9 +296,9 @@ def run_train(self): process_bar = tqdm(train_loader, desc='training') for data in process_bar: - # data.to(device) + data.to(device) results = model(data) - loss = model.loss(results, data) + loss = model.get_loss(results, data) loss_sum = sum(loss.values()) self.optimizer.zero_grad() From 53479c2838abce31cf832a3ba441fd9332ca6570 Mon Sep 17 00:00:00 2001 From: Sanskar Agrawal Date: Fri, 24 Sep 2021 16:36:16 +0530 Subject: [PATCH 07/50] remove comments --- ml3d/torch/pipelines/object_detection.py | 2 +- ml3d/torch/pipelines/semantic_segmentation.py | 1 - 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/ml3d/torch/pipelines/object_detection.py b/ml3d/torch/pipelines/object_detection.py index b31a210da..9d4332338 100644 --- a/ml3d/torch/pipelines/object_detection.py +++ b/ml3d/torch/pipelines/object_detection.py @@ -172,7 +172,7 @@ def run_valid(self): gt = [] with torch.no_grad(): for data in tqdm(valid_loader, desc='validation'): - # data.to(device) + data.to(device) results = model(data) loss = model.get_loss(results, data) for l, v in loss.items(): diff --git a/ml3d/torch/pipelines/semantic_segmentation.py b/ml3d/torch/pipelines/semantic_segmentation.py index f5089c405..26d2c966d 100644 --- a/ml3d/torch/pipelines/semantic_segmentation.py +++ b/ml3d/torch/pipelines/semantic_segmentation.py @@ -383,7 +383,6 @@ def run_train(self): self.save_config(writer) model = CustomDataParallel(model, device_ids=self.device_ids) - # model.to(device) log.info("Writing summary in {}.".format(self.tensorboard_dir)) From 6139e66449f559f0f7ffe33bfb6d37a580c2cb18 Mon Sep 17 00:00:00 2001 From: Sanskar Agrawal Date: Mon, 10 Jan 2022 17:47:41 +0530 Subject: [PATCH 08/50] fix cam_img matrix --- ml3d/datasets/waymo.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ml3d/datasets/waymo.py b/ml3d/datasets/waymo.py index 2e1545b47..8d1690b91 100644 --- a/ml3d/datasets/waymo.py +++ b/ml3d/datasets/waymo.py @@ -166,7 +166,7 @@ def read_calib(path): Tr_velo_to_cam = Waymo._extend_matrix(Tr_velo_to_cam) world_cam = np.transpose(rect_4x4 @ Tr_velo_to_cam) - cam_img = np.transpose(P2) + cam_img = np.transpose(np.vstack((P2.reshape(3, 4), [0, 0, 0, 1]))) return {'world_cam': world_cam, 'cam_img': cam_img} From d9e156488d53f55d7894f404185d478e4b0ddf54 Mon Sep 17 00:00:00 2001 From: Sanskar Agrawal Date: Mon, 17 Jan 2022 21:40:59 +0530 Subject: [PATCH 09/50] add distributed training --- ml3d/configs/pointpillars_waymo.yml | 28 +++---- ml3d/torch/pipelines/base_pipeline.py | 18 ++++- ml3d/torch/pipelines/object_detection.py | 80 ++++++++++++-------- ml3d/utils/config.py | 6 ++ scripts/run_pipeline.py | 94 ++++++++++++++++++++---- 5 files changed, 162 insertions(+), 64 deletions(-) diff --git a/ml3d/configs/pointpillars_waymo.yml b/ml3d/configs/pointpillars_waymo.yml index 43534e834..4e18a5b4c 100644 --- a/ml3d/configs/pointpillars_waymo.yml +++ b/ml3d/configs/pointpillars_waymo.yml @@ -10,7 +10,7 @@ model: batcher: "ignore" - point_cloud_range: [-74.88, -74.88, -2, 74.88, 74.88, 4] + point_cloud_range: [-80, -80, -6, 85, 85, 7] classes: ['VEHICLE', 'PEDESTRIAN', 'CYCLIST'] loss: @@ -31,7 +31,7 @@ model: max_voxels: [32000, 32000] voxel_encoder: - in_channels: 5 + in_channels: 4 feat_channels: [64] voxel_size: *vsize @@ -70,18 +70,18 @@ model: rotations: [0, 1.57] iou_thr: [[0.4, 0.55], [0.3, 0.5], [0.3, 0.5]] - augment: - PointShuffle: True - ObjectRangeFilter: True - ObjectSample: - min_points_dict: - VEHICLE: 5 - PEDESTRIAN: 10 - CYCLIST: 10 - sample_dict: - VEHICLE: 15 - PEDESTRIAN: 10 - CYCLIST: 10 + augment: {} + # PointShuffle: True + # ObjectRangeFilter: True + # ObjectSample: + # min_points_dict: + # VEHICLE: 5 + # PEDESTRIAN: 10 + # CYCLIST: 10 + # sample_dict: + # VEHICLE: 15 + # PEDESTRIAN: 10 + # CYCLIST: 10 pipeline: diff --git a/ml3d/torch/pipelines/base_pipeline.py b/ml3d/torch/pipelines/base_pipeline.py index 13ff7be86..e76efca8e 100644 --- a/ml3d/torch/pipelines/base_pipeline.py +++ b/ml3d/torch/pipelines/base_pipeline.py @@ -16,7 +16,7 @@ def __init__(self, model, dataset=None, device='cuda', - device_ids=[0], + distributed=False, **kwargs): """Initialize. @@ -46,12 +46,22 @@ def __init__(self, model.__class__.__name__ + '_' + dataset_name + '_torch') make_dir(self.cfg.logs_dir) + self.distributed = distributed + + self.rank = kwargs.get('rank', 0) + if device == 'cpu' or not torch.cuda.is_available(): + if distributed: + raise ValueError( + "Distributed training is ON, but CUDA not available.") self.device = torch.device('cpu') - self.device_ids = [-1] else: - self.device = torch.device('cuda') - self.device_ids = device_ids + if distributed: + self.device = torch.device(device) + print("Using device", self.device) + torch.cuda.set_device(self.device) + else: + self.device = torch.device('cuda') self.summary = {} self.cfg.setdefault('summary', {}) diff --git a/ml3d/torch/pipelines/object_detection.py b/ml3d/torch/pipelines/object_detection.py index 2927c19ef..9228e3a4c 100644 --- a/ml3d/torch/pipelines/object_detection.py +++ b/ml3d/torch/pipelines/object_detection.py @@ -1,12 +1,13 @@ import logging import re +import numpy as np +import torch +import torch.distributed as dist + from datetime import datetime from os.path import exists, join from pathlib import Path - from tqdm import tqdm -import numpy as np -import torch from torch.utils.data import DataLoader from .base_pipeline import BasePipeline @@ -255,18 +256,21 @@ def run_train(self): """Run training with train data split.""" torch.manual_seed(self.rng.integers(np.iinfo( np.int32).max)) # Random reproducible seed for torch + rank = self.rank # Rank for distributed training model = self.model device = self.device dataset = self.dataset cfg = self.cfg - log.info("DEVICE : {}".format(device)) - timestamp = datetime.now().strftime('%Y-%m-%d_%H:%M:%S') + if rank == 0: + log.info("DEVICE : {}".format(device)) + timestamp = datetime.now().strftime('%Y-%m-%d_%H:%M:%S') - log_file_path = join(cfg.logs_dir, 'log_train_' + timestamp + '.txt') - log.info("Logging in file : {}".format(log_file_path)) - log.addHandler(logging.FileHandler(log_file_path)) + log_file_path = join(cfg.logs_dir, + 'log_train_' + timestamp + '.txt') + log.info("Logging in file : {}".format(log_file_path)) + log.addHandler(logging.FileHandler(log_file_path)) batcher = ConcatBatcher(device, model.cfg.name) @@ -277,15 +281,22 @@ def run_train(self): use_cache=dataset.cfg.use_cache, steps_per_epoch=dataset.cfg.get( 'steps_per_epoch_train', None)) - train_loader = DataLoader( - train_split, - batch_size=cfg.batch_size, - num_workers=cfg.get('num_workers', 4), - pin_memory=cfg.get('pin_memory', False), - collate_fn=batcher.collate_fn, - worker_init_fn=lambda x: np.random.seed(x + np.uint32( - torch.utils.data.get_worker_info().seed)) - ) # numpy expects np.uint32, whereas torch returns np.uint64. + + if self.distributed: + train_sampler = torch.utils.data.distributed.DistributedSampler( + train_split) + else: + train_sampler = None + + train_loader = DataLoader(train_split, + batch_size=cfg.batch_size, + num_workers=cfg.get('num_workers', 0), + pin_memory=cfg.get('pin_memory', False), + collate_fn=batcher.collate_fn, + sampler=train_sampler) + # worker_init_fn=lambda x: np.random.seed(x + np.uint32( + # torch.utils.data.get_worker_info().seed)) + # ) # numpy expects np.uint32, whereas torch returns np.uint64. self.optimizer, self.scheduler = model.get_optimizer(cfg.optimizer) @@ -301,23 +312,31 @@ def run_train(self): runid + '_' + Path(tensorboard_dir).name) writer = SummaryWriter(self.tensorboard_dir) - self.save_config(writer) + if rank == 0: + self.save_config(writer) + log.info("Writing summary in {}.".format(self.tensorboard_dir)) # wrap model for multiple GPU - model = CustomDataParallel(model, device_ids=self.device_ids) + if self.distributed: + model.cuda(self.device) + model = torch.nn.parallel.DistributedDataParallel( + model, device_ids=[self.device]) - log.info("Writing summary in {}.".format(self.tensorboard_dir)) record_summary = 'train' in cfg.get('summary').get('record_for', []) - log.info("Started training") + if rank == 0: + log.info("Started training") + for epoch in range(start_ep, cfg.max_epoch + 1): log.info(f'=== EPOCH {epoch:d}/{cfg.max_epoch:d} ===') - model.train() + if self.distributed: + train_sampler.set_epoch(epoch) + model.train() self.losses = {} process_bar = tqdm(train_loader, desc='training') - for data in process_bar: + for data in train_loader: data.to(device) results = model(data) loss = model.get_loss(results, data) @@ -331,7 +350,7 @@ def run_train(self): self.optimizer.step() # Record visualization for the last iteration - if record_summary and process_bar.n == process_bar.total - 1: + if rank == 0 and record_summary and process_bar.n == process_bar.total - 1: boxes = model.inference_end(results, data) self.summary['train'] = self.get_3d_summary(boxes, data, @@ -351,13 +370,13 @@ def run_train(self): self.scheduler.step() # --------------------- validation - if (epoch % cfg.get("validation_freq", 1)) == 0: + if rank == 0 and (epoch % cfg.get("validation_freq", 1)) == 0: self.run_valid() - self.save_logs(writer, epoch) - - if epoch % cfg.save_ckpt_freq == 0: - self.save_ckpt(epoch) + if rank == 0: + self.save_logs(writer, epoch) + if epoch % cfg.save_ckpt_freq == 0: + self.save_ckpt(epoch) def get_3d_summary(self, infer_bboxes_batch, @@ -480,7 +499,8 @@ def save_logs(self, writer, epoch): def load_ckpt(self, ckpt_path=None, is_resume=True): train_ckpt_dir = join(self.cfg.logs_dir, 'checkpoint') - make_dir(train_ckpt_dir) + if self.rank == 0: + make_dir(train_ckpt_dir) epoch = 0 if ckpt_path is None: diff --git a/ml3d/utils/config.py b/ml3d/utils/config.py index 51b10829e..2686c958b 100644 --- a/ml3d/utils/config.py +++ b/ml3d/utils/config.py @@ -272,3 +272,9 @@ def __getattr__(self, name): def __getitem__(self, name): return self._cfg_dict.__getitem__(name) + + def __getstate__(self): + return self.cfg_dict + + def __setstate__(self, state): + self.cfg_dict = state diff --git a/scripts/run_pipeline.py b/scripts/run_pipeline.py index 6b18e02a3..8a9e34564 100644 --- a/scripts/run_pipeline.py +++ b/scripts/run_pipeline.py @@ -3,6 +3,8 @@ import sys import yaml import pprint +import os +import torch.distributed as dist from pathlib import Path @@ -70,6 +72,8 @@ def main(): rng = np.random.default_rng(args.seed) if framework == 'torch': import open3d.ml.torch as ml3d + import torch.multiprocessing as mp + import torch.distributed as dist else: import tensorflow as tf import open3d.ml.tf as ml3d @@ -101,24 +105,20 @@ def main(): cfg_dict_dataset, cfg_dict_pipeline, cfg_dict_model = \ _ml3d.utils.Config.merge_cfg_file(cfg, args, extra_dict) - cfg_dict_dataset['seed'] = rng - cfg_dict_model['seed'] = rng - cfg_dict_pipeline['seed'] = rng - - dataset = Dataset(cfg_dict_dataset.pop('dataset_path', None), - **cfg_dict_dataset) - if args.mode is not None: cfg_dict_model["mode"] = args.mode - model = Model(**cfg_dict_model) - if args.max_epochs is not None: cfg_dict_pipeline["max_epochs"] = args.max_epochs if args.batch_size is not None: cfg_dict_pipeline["batch_size"] = args.batch_size + + cfg_dict_dataset['seed'] = rng + cfg_dict_model['seed'] = rng + cfg_dict_pipeline['seed'] = rng + cfg_dict_pipeline["device"] = args.device cfg_dict_pipeline["device_ids"] = args.device_ids - pipeline = Pipeline(model, dataset, **cfg_dict_pipeline) + else: if (args.pipeline and args.model and args.dataset) is None: raise ValueError("Please specify pipeline, model, and dataset " + @@ -136,25 +136,87 @@ def main(): cfg_dict_model['seed'] = rng cfg_dict_pipeline['seed'] = rng - dataset = Dataset(**cfg_dict_dataset) - model = Model(**cfg_dict_model, mode=args.mode) - pipeline = Pipeline(model, dataset, **cfg_dict_pipeline) - with open(Path(__file__).parent / 'README.md', 'r') as f: readme = f.read() - pipeline.cfg_tb = { + + cfg_tb = { 'readme': readme, 'cmd_line': cmd_line, 'dataset': pprint.pformat(cfg_dict_dataset, indent=2), 'model': pprint.pformat(cfg_dict_model, indent=2), 'pipeline': pprint.pformat(cfg_dict_pipeline, indent=2) } + args.cfg_tb = cfg_tb + args.distributed = framework == 'torch' and args.device != 'cpu' and len( + args.device_ids) > 1 + + if not args.distributed: + # print("not distr : ") + # exit(0) + dataset = Dataset(**cfg_dict_dataset) + model = Model(**cfg_dict_model, mode=args.mode) + pipeline = Pipeline(model, dataset, **cfg_dict_pipeline) + + pipeline.cfg_tb = cfg_tb + + if args.split == 'test': + pipeline.run_test() + else: + pipeline.run_train() + + else: + mp.spawn(main_worker, + args=(Dataset, Model, Pipeline, cfg_dict_dataset, + cfg_dict_model, cfg_dict_pipeline, args), + nprocs=len(args.device_ids)) + + +def setup(rank, world_size): + os.environ['MASTER_ADDR'] = 'localhost' + os.environ['MASTER_PORT'] = '12355' + + # initialize the process group + dist.init_process_group("gloo", rank=rank, world_size=world_size) + + +def cleanup(): + dist.destroy_process_group() + + +def main_worker(rank, Dataset, Model, Pipeline, cfg_dict_dataset, + cfg_dict_model, cfg_dict_pipeline, args): + world_size = len(args.device_ids) + setup(rank, world_size) + + cfg_dict_dataset['rank'] = rank + cfg_dict_model['rank'] = rank + cfg_dict_pipeline['rank'] = rank + + device = f"cuda:{args.device_ids[rank]}" + print(f"rank = {rank}, world_size = {world_size}, gpu = {device}") + + cfg_dict_model['device'] = device + cfg_dict_pipeline['device'] = device + + dataset = Dataset(**cfg_dict_dataset) + model = Model(**cfg_dict_model, mode=args.mode) + pipeline = Pipeline(model, + dataset, + distributed=args.distributed, + **cfg_dict_pipeline) + + with open(Path(__file__).parent / 'README.md', 'r') as f: + readme = f.read() + pipeline.cfg_tb = args.cfg_tb if args.split == 'test': - pipeline.run_test() + if rank == 0: + pipeline.run_test() else: pipeline.run_train() + cleanup() + if __name__ == '__main__': main() From ef0e44053adfc7ddaddd25b7f78513d653dbca30 Mon Sep 17 00:00:00 2001 From: Sanskar Agrawal Date: Tue, 18 Jan 2022 18:51:56 +0530 Subject: [PATCH 10/50] parllel validation --- ml3d/torch/pipelines/object_detection.py | 65 +++++++++++++++++------- scripts/run_pipeline.py | 2 + 2 files changed, 48 insertions(+), 19 deletions(-) diff --git a/ml3d/torch/pipelines/object_detection.py b/ml3d/torch/pipelines/object_detection.py index 9228e3a4c..ebbfe09b7 100644 --- a/ml3d/torch/pipelines/object_detection.py +++ b/ml3d/torch/pipelines/object_detection.py @@ -156,7 +156,8 @@ def run_valid(self, epoch=0): log.info("DEVICE : {}".format(device)) log_file_path = join(cfg.logs_dir, 'log_valid_' + timestamp + '.txt') log.info("Logging in file : {}".format(log_file_path)) - log.addHandler(logging.FileHandler(log_file_path)) + if self.rank == 0: + log.addHandler(logging.FileHandler(log_file_path)) batcher = ConcatBatcher(device, model.cfg.name) @@ -168,16 +169,24 @@ def run_valid(self, epoch=0): shuffle=True, steps_per_epoch=dataset.cfg.get( 'steps_per_epoch_valid', None)) - valid_loader = DataLoader( - valid_split, - batch_size=cfg.val_batch_size, - num_workers=cfg.get('num_workers', 4), - pin_memory=cfg.get('pin_memory', False), - collate_fn=batcher.collate_fn, - worker_init_fn=lambda x: np.random.seed(x + np.uint32( - torch.utils.data.get_worker_info().seed))) - record_summary = 'valid' in cfg.get('summary').get('record_for', []) + if self.distributed: + valid_sampler = torch.utils.data.distributed.DistributedSampler( + valid_split) + else: + valid_sampler = None + + valid_loader = DataLoader(valid_split, + batch_size=cfg.val_batch_size, + num_workers=cfg.get('num_workers', 0), + pin_memory=cfg.get('pin_memory', False), + collate_fn=batcher.collate_fn, + sampler=valid_sampler) + # worker_init_fn=lambda x: np.random.seed(x + np.uint32( + # torch.utils.data.get_worker_info().seed))) + + record_summary = self.rank == 0 and 'valid' in cfg.get('summary').get( + 'record_for', []) log.info("Started validation") self.valid_losses = {} @@ -322,7 +331,8 @@ def run_train(self): model = torch.nn.parallel.DistributedDataParallel( model, device_ids=[self.device]) - record_summary = 'train' in cfg.get('summary').get('record_for', []) + record_summary = self.rank == 0 and 'train' in cfg.get('summary').get( + 'record_for', []) if rank == 0: log.info("Started training") @@ -336,22 +346,35 @@ def run_train(self): self.losses = {} process_bar = tqdm(train_loader, desc='training') - for data in train_loader: + for data in process_bar: data.to(device) results = model(data) - loss = model.get_loss(results, data) + if self.distributed: + loss = model.module.get_loss(results, data) + else: + loss = model.get_loss(results, data) loss_sum = sum(loss.values()) self.optimizer.zero_grad() loss_sum.backward() - if model.cfg.get('grad_clip_norm', -1) > 0: - torch.nn.utils.clip_grad_value_(model.parameters(), - model.cfg.grad_clip_norm) + if self.distributed: + if model.module.cfg.get('grad_clip_norm', -1) > 0: + torch.nn.utils.clip_grad_value_( + model.module.parameters(), + model.module.cfg.grad_clip_norm) + else: + if model.cfg.get('grad_clip_norm', -1) > 0: + torch.nn.utils.clip_grad_value_( + model.parameters(), model.cfg.grad_clip_norm) + self.optimizer.step() # Record visualization for the last iteration - if rank == 0 and record_summary and process_bar.n == process_bar.total - 1: - boxes = model.inference_end(results, data) + if record_summary and process_bar.n == process_bar.total - 1: + if self.distributed: + boxes = model.module.inference_end(results, data) + else: + boxes = model.inference_end(results, data) self.summary['train'] = self.get_3d_summary(boxes, data, epoch, @@ -366,11 +389,15 @@ def run_train(self): process_bar.set_description(desc) process_bar.refresh() + if self.distributed: + dist.barrier() + if self.scheduler is not None: self.scheduler.step() # --------------------- validation - if rank == 0 and (epoch % cfg.get("validation_freq", 1)) == 0: + # if rank == 0 and (epoch % cfg.get("validation_freq", 1)) == 0: + if epoch % cfg.get("validation_freq", 1) == 0: self.run_valid() if rank == 0: diff --git a/scripts/run_pipeline.py b/scripts/run_pipeline.py index 8a9e34564..84ffa2919 100644 --- a/scripts/run_pipeline.py +++ b/scripts/run_pipeline.py @@ -5,6 +5,7 @@ import pprint import os import torch.distributed as dist +from torch import multiprocessing from pathlib import Path @@ -219,4 +220,5 @@ def main_worker(rank, Dataset, Model, Pipeline, cfg_dict_dataset, if __name__ == '__main__': + multiprocessing.set_start_method('spawn') main() From 29b972906a099996541e0c897737cbc688b65010 Mon Sep 17 00:00:00 2001 From: Sanskar Agrawal Date: Mon, 7 Feb 2022 17:18:34 +0530 Subject: [PATCH 11/50] update config --- ml3d/configs/pointpillars_waymo.yml | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/ml3d/configs/pointpillars_waymo.yml b/ml3d/configs/pointpillars_waymo.yml index 4e18a5b4c..cd35e21f6 100644 --- a/ml3d/configs/pointpillars_waymo.yml +++ b/ml3d/configs/pointpillars_waymo.yml @@ -37,7 +37,7 @@ model: scatter: in_channels: 64 - output_shape: [468, 468] + output_shape: [520, 520] backbone: in_channels: 64 @@ -57,14 +57,14 @@ model: nms_pre: 4096 score_thr: 0.1 ranges: [ - [-74.88, -74.88, -0.0345, 74.88, 74.88, -0.0345], - [-74.88, -74.88, -0.1188, 74.88, 74.88, -0.1188], - [-74.88, -74.88, 0, 74.88, 74.88, 0], + [-80, -80, 1.142, 85, 85, 1.142], + [-80, -80, 1.139, 85, 85, 1.139], + [-80, -80, 1.149, 85, 85, 1.149], ] sizes: [ - [2.08, 4.73, 1.77], # car - [0.84, 1.81, 1.77], # cyclist - [0.84, 0.91, 1.74] # pedestrian + [1.98, 4.50, 1.96], # VEHICLE + [0.91, 1.94, 1.78], # CYCLIST + [0.84, 0.91, 1.70] # PEDESTRIAN ] dir_offset: 0.7854 rotations: [0, 1.57] From 672248d56638f62aa50770c569e1a5258ed5f8ad Mon Sep 17 00:00:00 2001 From: Sanskar Agrawal Date: Tue, 8 Feb 2022 00:41:42 -0800 Subject: [PATCH 12/50] gather in run_valid --- ml3d/torch/pipelines/object_detection.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/ml3d/torch/pipelines/object_detection.py b/ml3d/torch/pipelines/object_detection.py index ebbfe09b7..cff670f92 100644 --- a/ml3d/torch/pipelines/object_detection.py +++ b/ml3d/torch/pipelines/object_detection.py @@ -227,6 +227,22 @@ def run_valid(self, epoch=0): similar_classes = cfg.get("similar_classes", {}) difficulties = cfg.get("difficulties", [0]) + if self.distributed: + gt_gather = [None for _ in range(dist.get_world_size())] + pred_gather = [None for _ in range(dist.get_world_size())] + + dist.gather_object(gt, gt_gather if self.rank == 0 else None, dst=0) + dist.gather_object(pred, + pred_gather if self.rank == 0 else None, + dst=0) + + if self.rank == 0: + gt = sum(gt_gather, []) + pred = sum(pred_gather, []) + + if self.rank != 0: + return + ap = mAP(pred, gt, model.classes, @@ -399,6 +415,8 @@ def run_train(self): # if rank == 0 and (epoch % cfg.get("validation_freq", 1)) == 0: if epoch % cfg.get("validation_freq", 1) == 0: self.run_valid() + if self.distributed: + dist.barrier() if rank == 0: self.save_logs(writer, epoch) @@ -528,6 +546,8 @@ def load_ckpt(self, ckpt_path=None, is_resume=True): train_ckpt_dir = join(self.cfg.logs_dir, 'checkpoint') if self.rank == 0: make_dir(train_ckpt_dir) + if self.distributed: + dist.barrier() epoch = 0 if ckpt_path is None: From 620b35c9fdb600bddaa0662e29ad24b070169d28 Mon Sep 17 00:00:00 2001 From: Sanskar Agrawal Date: Wed, 9 Feb 2022 11:39:58 -0800 Subject: [PATCH 13/50] fix preprocessing --- ml3d/datasets/augment/augmentation.py | 2 +- ml3d/datasets/waymo.py | 17 +++++++---------- ml3d/torch/pipelines/base_pipeline.py | 11 ++++++----- scripts/collect_bboxes.py | 22 +++++++++++++++++++--- scripts/preprocess_waymo.py | 22 +++++++++++----------- 5 files changed, 44 insertions(+), 30 deletions(-) diff --git a/ml3d/datasets/augment/augmentation.py b/ml3d/datasets/augment/augmentation.py index c11f13452..c7ca1244f 100644 --- a/ml3d/datasets/augment/augmentation.py +++ b/ml3d/datasets/augment/augmentation.py @@ -484,7 +484,7 @@ def ObjectSample(self, data, db_boxes_dict, sample_dict): sampled_points = np.concatenate( [box.points_inside_box for box in sampled], axis=0) points = remove_points_in_boxes(points, sampled) - points = np.concatenate([sampled_points, points], axis=0) + points = np.concatenate([sampled_points[:, :4], points], axis=0) return { 'point': points, diff --git a/ml3d/datasets/waymo.py b/ml3d/datasets/waymo.py index 8d1690b91..13cc098b2 100644 --- a/ml3d/datasets/waymo.py +++ b/ml3d/datasets/waymo.py @@ -29,7 +29,6 @@ def __init__(self, name='Waymo', cache_dir='./logs/cache', use_cache=False, - val_split=3, **kwargs): """Initialize the function by passing the dataset and other details. @@ -38,7 +37,6 @@ def __init__(self, name: The name of the dataset (Waymo in this case). cache_dir: The directory where the cache is stored. use_cache: Indicates if the dataset should be cached. - val_split: The split value to get a set of images for training, validation, for testing. Returns: class: The corresponding class. @@ -47,7 +45,6 @@ def __init__(self, name=name, cache_dir=cache_dir, use_cache=use_cache, - val_split=val_split, **kwargs) cfg = self.cfg @@ -63,15 +60,15 @@ def __init__(self, self.val_files = [] for f in self.all_files: - idx = Path(f).name.replace('.bin', '')[:3] - idx = int(idx) - if idx < cfg.val_split: + if 'train' in f: self.train_files.append(f) - else: + elif 'val' in f: self.val_files.append(f) - - self.test_files = glob( - join(cfg.dataset_path, 'testing', 'velodyne', '*.bin')) + elif 'test' in f: + self.test_files.append(f) + else: + log.warning( + f"Skipping {f}, prefix must be one of train, test or val.") @staticmethod def get_label_to_names(): diff --git a/ml3d/torch/pipelines/base_pipeline.py b/ml3d/torch/pipelines/base_pipeline.py index e76efca8e..2c633c761 100644 --- a/ml3d/torch/pipelines/base_pipeline.py +++ b/ml3d/torch/pipelines/base_pipeline.py @@ -39,16 +39,17 @@ def __init__(self, self.dataset = dataset self.rng = np.random.default_rng(kwargs.get('seed', None)) - make_dir(self.cfg.main_log_dir) + self.distributed = distributed + self.rank = kwargs.get('rank', 0) + dataset_name = dataset.name if dataset is not None else '' self.cfg.logs_dir = join( self.cfg.main_log_dir, model.__class__.__name__ + '_' + dataset_name + '_torch') - make_dir(self.cfg.logs_dir) - self.distributed = distributed - - self.rank = kwargs.get('rank', 0) + if self.rank == 0: + make_dir(self.cfg.main_log_dir) + make_dir(self.cfg.logs_dir) if device == 'cpu' or not torch.cuda.is_available(): if distributed: diff --git a/scripts/collect_bboxes.py b/scripts/collect_bboxes.py index 6865cbcd8..fd73f6b7a 100644 --- a/scripts/collect_bboxes.py +++ b/scripts/collect_bboxes.py @@ -1,6 +1,8 @@ from os.path import join import argparse import pickle +import random +from tqdm import tqdm from open3d.ml.datasets import utils from open3d.ml import datasets import multiprocessing @@ -26,6 +28,13 @@ def parse_args(): type=int, default=multiprocessing.cpu_count(), required=False) + parser.add_argument( + '--max_pc', + help= + 'Boxes from random N pointclouds will be saved. Default None(save from whole dataset).', + type=int, + default=None, + required=False) args = parser.parse_args() @@ -77,11 +86,18 @@ def process_boxes(i): classname = getattr(datasets, args.dataset_type) dataset = classname(args.dataset_path) train = dataset.get_split('train') + max_pc = len(train) if args.max_pc is None else args.max_pc + + query_pc = range(len(train)) if max_pc >= len(train) else random.sample( + range(len(train)), max_pc) - print("Found", len(train), "traning samples") - print("This may take a few minutes...") + print(f"Found {len(train)} traning samples, Using {max_pc}") + print( + f"Using {args.num_cpus} number of cpus, This may take a few minutes...") with multiprocessing.Pool(args.num_cpus) as p: - bboxes = p.map(process_boxes, range(len(train))) + bboxes = list(tqdm(p.imap(process_boxes, query_pc), + total=len(query_pc))) bboxes = [e for l in bboxes for e in l] file = open(join(out_path, 'bboxes.pkl'), 'wb') pickle.dump(bboxes, file) + print(f"Saved {len(bboxes)} boxes.") diff --git a/scripts/preprocess_waymo.py b/scripts/preprocess_waymo.py index 290c1768f..d502d8e3d 100644 --- a/scripts/preprocess_waymo.py +++ b/scripts/preprocess_waymo.py @@ -37,10 +37,10 @@ def parse_args(): default=16, type=int) - parser.add_argument('--is_test', - help='True for processing test data (default False)', - default=False, - type=bool) + parser.add_argument('--split', + help='One of {train, val, test} (default train)', + default='train', + type=str) args = parser.parse_args() @@ -65,9 +65,9 @@ class Waymo2KITTI(): is_test (bool): Whether in the test_mode. Default: False. """ - def __init__(self, dataset_path, save_dir='', workers=8, is_test=False): + def __init__(self, dataset_path, save_dir='', workers=8, split='train'): - self.write_image = True + self.write_image = False self.filter_empty_3dboxes = True self.filter_no_label_zone_points = True @@ -85,8 +85,8 @@ def __init__(self, dataset_path, save_dir='', workers=8, is_test=False): self.dataset_path = dataset_path self.save_dir = save_dir self.workers = int(workers) - self.is_test = is_test - self.prefix = '' + self.is_test = split == 'test' + self.prefix = split + '_' self.save_track_id = False self.tfrecord_files = sorted( @@ -152,8 +152,6 @@ def __len__(self): return len(self.tfrecord_files) def save_image(self, frame, file_idx, frame_idx): - self.prefix = '' - for img in frame.images: img_path = Path(self.image_save_dir + str(img.name - 1)) / ( self.prefix + str(file_idx).zfill(3) + str(frame_idx).zfill(3) + @@ -453,6 +451,8 @@ def cart_to_homo(mat): out_path = args.out_path if out_path is None: args.out_path = args.dataset_path + if args.split not in ['train', 'val', 'test']: + raise ValueError("split must be one of {train, val, test}") converter = Waymo2KITTI(args.dataset_path, args.out_path, args.workers, - args.is_test) + args.split) converter.convert() From 1c64d652a38990ecd280a223f1a5bc71a2eca5a7 Mon Sep 17 00:00:00 2001 From: Sanskar Agrawal Date: Fri, 18 Feb 2022 02:48:44 -0800 Subject: [PATCH 14/50] add shuffle --- ml3d/datasets/waymo.py | 4 ++++ ml3d/torch/pipelines/object_detection.py | 6 +++--- scripts/run_pipeline.py | 2 -- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/ml3d/datasets/waymo.py b/ml3d/datasets/waymo.py index 13cc098b2..d64331211 100644 --- a/ml3d/datasets/waymo.py +++ b/ml3d/datasets/waymo.py @@ -53,6 +53,7 @@ def __init__(self, self.dataset_path = cfg.dataset_path self.num_classes = 4 self.label_to_names = self.get_label_to_names() + self.shuffle = kwargs.get('shuffle', False) self.all_files = sorted( glob(join(cfg.dataset_path, 'velodyne', '*.bin'))) @@ -69,6 +70,9 @@ def __init__(self, else: log.warning( f"Skipping {f}, prefix must be one of train, test or val.") + if self.shuffle: + log.info("Shuffling training files...") + self.rng.shuffle(self.train_files) @staticmethod def get_label_to_names(): diff --git a/ml3d/torch/pipelines/object_detection.py b/ml3d/torch/pipelines/object_detection.py index cff670f92..60e6d7ec1 100644 --- a/ml3d/torch/pipelines/object_detection.py +++ b/ml3d/torch/pipelines/object_detection.py @@ -155,8 +155,8 @@ def run_valid(self, epoch=0): log.info("DEVICE : {}".format(device)) log_file_path = join(cfg.logs_dir, 'log_valid_' + timestamp + '.txt') - log.info("Logging in file : {}".format(log_file_path)) if self.rank == 0: + log.info("Logging in file : {}".format(log_file_path)) log.addHandler(logging.FileHandler(log_file_path)) batcher = ConcatBatcher(device, model.cfg.name) @@ -207,12 +207,12 @@ def run_valid(self, epoch=0): boxes = model.inference_end(results, data) pred.extend([BEVBox3D.to_dicts(b) for b in boxes]) gt.extend([BEVBox3D.to_dicts(b) for b in data.bbox_objs]) - # Save only for the first batch - if record_summary and 'valid' not in self.summary: + if record_summary: self.summary['valid'] = self.get_3d_summary(boxes, data, epoch, results=results) + record_summary = False # Save only for the first batch sum_loss = 0 desc = "validation - " diff --git a/scripts/run_pipeline.py b/scripts/run_pipeline.py index 84ffa2919..0495fedd9 100644 --- a/scripts/run_pipeline.py +++ b/scripts/run_pipeline.py @@ -152,8 +152,6 @@ def main(): args.device_ids) > 1 if not args.distributed: - # print("not distr : ") - # exit(0) dataset = Dataset(**cfg_dict_dataset) model = Model(**cfg_dict_model, mode=args.mode) pipeline = Pipeline(model, dataset, **cfg_dict_pipeline) From a22a8dce59175103977da4971481c748a03f0440 Mon Sep 17 00:00:00 2001 From: Sanskar Agrawal Date: Fri, 18 Feb 2022 17:29:16 +0530 Subject: [PATCH 15/50] fix rng --- ml3d/torch/pipelines/base_pipeline.py | 3 ++- scripts/run_pipeline.py | 7 ++++++- 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/ml3d/torch/pipelines/base_pipeline.py b/ml3d/torch/pipelines/base_pipeline.py index 2c633c761..a032714a9 100644 --- a/ml3d/torch/pipelines/base_pipeline.py +++ b/ml3d/torch/pipelines/base_pipeline.py @@ -24,6 +24,7 @@ def __init__(self, model: A network model. dataset: A dataset, or None for inference model. device: 'gpu' or 'cpu'. + distributed: Whether to use multiple gpus. kwargs: Returns: @@ -59,7 +60,7 @@ def __init__(self, else: if distributed: self.device = torch.device(device) - print("Using device", self.device) + print(f"Rank : {self.rank} using device : {self.device}") torch.cuda.set_device(self.device) else: self.device = torch.device('cuda') diff --git a/scripts/run_pipeline.py b/scripts/run_pipeline.py index 0495fedd9..0bcf31e30 100644 --- a/scripts/run_pipeline.py +++ b/scripts/run_pipeline.py @@ -191,6 +191,11 @@ def main_worker(rank, Dataset, Model, Pipeline, cfg_dict_dataset, cfg_dict_model['rank'] = rank cfg_dict_pipeline['rank'] = rank + rng = np.random.default_rng(args.seed + rank) + cfg_dict_dataset['seed'] = rng + cfg_dict_model['seed'] = rng + cfg_dict_pipeline['seed'] = rng + device = f"cuda:{args.device_ids[rank]}" print(f"rank = {rank}, world_size = {world_size}, gpu = {device}") @@ -219,4 +224,4 @@ def main_worker(rank, Dataset, Model, Pipeline, cfg_dict_dataset, if __name__ == '__main__': multiprocessing.set_start_method('spawn') - main() + sys.exit(main()) From 685dd3aa3993a9a65f019153adb4f2b3933c4f01 Mon Sep 17 00:00:00 2001 From: Sanskar Agrawal Date: Fri, 18 Feb 2022 17:31:18 +0530 Subject: [PATCH 16/50] remove customparallel --- ml3d/torch/pipelines/object_detection.py | 1 - 1 file changed, 1 deletion(-) diff --git a/ml3d/torch/pipelines/object_detection.py b/ml3d/torch/pipelines/object_detection.py index 60e6d7ec1..d8df902bc 100644 --- a/ml3d/torch/pipelines/object_detection.py +++ b/ml3d/torch/pipelines/object_detection.py @@ -11,7 +11,6 @@ from torch.utils.data import DataLoader from .base_pipeline import BasePipeline -from .dataparallel import CustomDataParallel from ..dataloaders import TorchDataloader, ConcatBatcher from torch.utils.tensorboard import SummaryWriter # pylint: disable-next=unused-import From bfaa4a2dec9968347e20f0e71e783b7db6147c97 Mon Sep 17 00:00:00 2001 From: Sanskar Agrawal Date: Fri, 18 Feb 2022 17:44:42 +0530 Subject: [PATCH 17/50] reset semseg distributed training --- ml3d/torch/pipelines/base_pipeline.py | 5 +++++ ml3d/torch/pipelines/object_detection.py | 19 ++++++++++--------- ml3d/torch/pipelines/semantic_segmentation.py | 11 +++++------ 3 files changed, 20 insertions(+), 15 deletions(-) diff --git a/ml3d/torch/pipelines/base_pipeline.py b/ml3d/torch/pipelines/base_pipeline.py index a032714a9..aed40e36f 100644 --- a/ml3d/torch/pipelines/base_pipeline.py +++ b/ml3d/torch/pipelines/base_pipeline.py @@ -41,6 +41,11 @@ def __init__(self, self.rng = np.random.default_rng(kwargs.get('seed', None)) self.distributed = distributed + if self.distributed and self.name == "SemanticSegmentation": + raise NotImplementedError( + "Distributed training not implemented for SemanticSegmentation!" + ) + self.rank = kwargs.get('rank', 0) dataset_name = dataset.name if dataset is not None else '' diff --git a/ml3d/torch/pipelines/object_detection.py b/ml3d/torch/pipelines/object_detection.py index d8df902bc..91a570676 100644 --- a/ml3d/torch/pipelines/object_detection.py +++ b/ml3d/torch/pipelines/object_detection.py @@ -312,15 +312,16 @@ def run_train(self): else: train_sampler = None - train_loader = DataLoader(train_split, - batch_size=cfg.batch_size, - num_workers=cfg.get('num_workers', 0), - pin_memory=cfg.get('pin_memory', False), - collate_fn=batcher.collate_fn, - sampler=train_sampler) - # worker_init_fn=lambda x: np.random.seed(x + np.uint32( - # torch.utils.data.get_worker_info().seed)) - # ) # numpy expects np.uint32, whereas torch returns np.uint64. + train_loader = DataLoader( + train_split, + batch_size=cfg.batch_size, + num_workers=cfg.get('num_workers', 0), + pin_memory=cfg.get('pin_memory', False), + collate_fn=batcher.collate_fn, + sampler=train_sampler, + worker_init_fn=lambda x: np.random.seed(x + np.uint32( + torch.utils.data.get_worker_info().seed)) + ) # numpy expects np.uint32, whereas torch returns np.uint64. self.optimizer, self.scheduler = model.get_optimizer(cfg.optimizer) diff --git a/ml3d/torch/pipelines/semantic_segmentation.py b/ml3d/torch/pipelines/semantic_segmentation.py index d56ddc48b..227b77722 100644 --- a/ml3d/torch/pipelines/semantic_segmentation.py +++ b/ml3d/torch/pipelines/semantic_segmentation.py @@ -12,7 +12,6 @@ # pylint: disable-next=unused-import from open3d.visualization.tensorboard_plugin import summary from .base_pipeline import BasePipeline -from .dataparallel import CustomDataParallel from ..dataloaders import get_sampler, TorchDataloader, DefaultBatcher, ConcatBatcher from ..utils import latest_torch_ckpt from ..modules.losses import SemSegLoss, filter_valid_label @@ -102,7 +101,6 @@ def __init__( momentum=0.98, main_log_dir='./logs/', device='cuda', - device_ids=[0], split='train', train_sum_dir='train_log', **kwargs): @@ -122,7 +120,6 @@ def __init__( momentum=momentum, main_log_dir=main_log_dir, device=device, - device_ids=device_ids, split=split, train_sum_dir=train_sum_dir, **kwargs) @@ -309,6 +306,7 @@ def run_train(self): dataset = self.dataset cfg = self.cfg + model.to(device) log.info("DEVICE : {}".format(device)) timestamp = datetime.now().strftime('%Y-%m-%d_%H:%M:%S') @@ -379,9 +377,6 @@ def run_train(self): writer = SummaryWriter(self.tensorboard_dir) self.save_config(writer) - - model = CustomDataParallel(model, device_ids=self.device_ids) - log.info("Writing summary in {}.".format(self.tensorboard_dir)) record_summary = cfg.get('summary').get('record_for', []) @@ -397,6 +392,8 @@ def run_train(self): model.trans_point_sampler = train_sampler.get_point_sampler() for step, inputs in enumerate(tqdm(train_loader, desc='training')): + if hasattr(inputs['data'], 'to'): + inputs['data'].to(device) self.optimizer.zero_grad() results = model(inputs['data']) loss, gt_labels, predict_scores = model.get_loss( @@ -429,6 +426,8 @@ def run_train(self): with torch.no_grad(): for step, inputs in enumerate( tqdm(valid_loader, desc='validation')): + if hasattr(inputs['data'], 'to'): + inputs['data'].to(device) results = model(inputs['data']) loss, gt_labels, predict_scores = model.get_loss( From 6d58cd12faee716feb0385dd020e3e85880d1ae1 Mon Sep 17 00:00:00 2001 From: Sanskar Agrawal Date: Tue, 22 Feb 2022 01:47:35 -0800 Subject: [PATCH 18/50] change config --- ml3d/configs/pointpillars_waymo.yml | 51 +++++++++++++++-------------- 1 file changed, 26 insertions(+), 25 deletions(-) diff --git a/ml3d/configs/pointpillars_waymo.yml b/ml3d/configs/pointpillars_waymo.yml index cd35e21f6..13b35fdd9 100644 --- a/ml3d/configs/pointpillars_waymo.yml +++ b/ml3d/configs/pointpillars_waymo.yml @@ -2,7 +2,7 @@ dataset: name: Waymo dataset_path: # path/to/your/dataset cache_dir: ./logs/cache - steps_per_epoch_train: 5000 + steps_per_epoch_train: 4000 model: name: PointPillars @@ -10,7 +10,7 @@ model: batcher: "ignore" - point_cloud_range: [-80, -80, -6, 85, 85, 7] + point_cloud_range: [-74.88, -74.88, -2, 74.88, 74.88, 4] classes: ['VEHICLE', 'PEDESTRIAN', 'CYCLIST'] loss: @@ -37,13 +37,13 @@ model: scatter: in_channels: 64 - output_shape: [520, 520] + output_shape: [468, 468] backbone: in_channels: 64 out_channels: [64, 128, 256] layer_nums: [3, 5, 5] - layer_strides: [2, 2, 2] + layer_strides: [1, 2, 2] neck: in_channels: [64, 128, 256] @@ -57,38 +57,39 @@ model: nms_pre: 4096 score_thr: 0.1 ranges: [ - [-80, -80, 1.142, 85, 85, 1.142], - [-80, -80, 1.139, 85, 85, 1.139], - [-80, -80, 1.149, 85, 85, 1.149], + [-74.88, -74.88, -0.0345, 74.88, 74.88, -0.0345], + [-74.88, -74.88, -0.1188, 74.88, 74.88, -0.1188], + [-74.88, -74.88, 0, 74.88, 74.88, 0], ] sizes: [ - [1.98, 4.50, 1.96], # VEHICLE - [0.91, 1.94, 1.78], # CYCLIST - [0.84, 0.91, 1.70] # PEDESTRIAN + [2.08, 4.73, 1.77], # VEHICLE + [0.84, 1.81, 1.77], # CYCLIST + [0.84, 0.91, 1.74] # PEDESTRIAN ] dir_offset: 0.7854 rotations: [0, 1.57] iou_thr: [[0.4, 0.55], [0.3, 0.5], [0.3, 0.5]] - augment: {} - # PointShuffle: True - # ObjectRangeFilter: True - # ObjectSample: - # min_points_dict: - # VEHICLE: 5 - # PEDESTRIAN: 10 - # CYCLIST: 10 - # sample_dict: - # VEHICLE: 15 - # PEDESTRIAN: 10 - # CYCLIST: 10 + augment: + PointShuffle: True + ObjectRangeFilter: + point_cloud_range: [-74.88, -74.88, -2, 74.88, 74.88, 4] + ObjectSample: + min_points_dict: + VEHICLE: 5 + PEDESTRIAN: 10 + CYCLIST: 10 + sample_dict: + VEHICLE: 15 + PEDESTRIAN: 10 + CYCLIST: 10 pipeline: name: ObjectDetection test_compute_metric: true batch_size: 6 - val_batch_size: 1 + val_batch_size: 6 test_batch_size: 1 save_ckpt_freq: 5 max_epoch: 200 @@ -102,10 +103,10 @@ pipeline: weight_decay: 0.01 # evaluation properties - overlaps: [0.5, 0.5, 0.7] + overlaps: [0.5, 0.5, 0.5] difficulties: [0, 1, 2] summary: - record_for: [] + record_for: [train, valid] max_pts: use_reference: false max_outputs: 1 From 51a16c376259ab7d4410e4f4590a2a1090b55bb7 Mon Sep 17 00:00:00 2001 From: Sanskar Agrawal Date: Tue, 22 Feb 2022 19:32:46 +0530 Subject: [PATCH 19/50] fix lgtm --- ml3d/torch/pipelines/dataparallel.py | 1 - scripts/run_pipeline.py | 2 -- 2 files changed, 3 deletions(-) diff --git a/ml3d/torch/pipelines/dataparallel.py b/ml3d/torch/pipelines/dataparallel.py index 1b36c431a..eee68138d 100644 --- a/ml3d/torch/pipelines/dataparallel.py +++ b/ml3d/torch/pipelines/dataparallel.py @@ -1,5 +1,4 @@ import torch -import numpy as np from torch.nn.parallel import DataParallel diff --git a/scripts/run_pipeline.py b/scripts/run_pipeline.py index 504d21e97..a1b7311f4 100644 --- a/scripts/run_pipeline.py +++ b/scripts/run_pipeline.py @@ -212,8 +212,6 @@ def main_worker(rank, Dataset, Model, Pipeline, cfg_dict_dataset, distributed=args.distributed, **cfg_dict_pipeline) - with open(Path(__file__).parent / 'README.md', 'r') as f: - readme = f.read() pipeline.cfg_tb = args.cfg_tb if args.split == 'test': From 50133bb8db163f48997c8b3461e73dc037d8681f Mon Sep 17 00:00:00 2001 From: Sanskar Agrawal Date: Fri, 25 Mar 2022 20:23:09 +0530 Subject: [PATCH 20/50] address reviews (1) --- ml3d/configs/pointpillars_waymo.yml | 2 +- ml3d/datasets/waymo.py | 28 ++++++++--------- ml3d/torch/pipelines/base_pipeline.py | 6 ++-- ml3d/torch/pipelines/object_detection.py | 3 +- scripts/collect_bboxes.py | 14 +++++---- scripts/preprocess_waymo.py | 38 ++++++++++++++---------- scripts/run_pipeline.py | 24 +++++++++++---- 7 files changed, 67 insertions(+), 48 deletions(-) diff --git a/ml3d/configs/pointpillars_waymo.yml b/ml3d/configs/pointpillars_waymo.yml index 13b35fdd9..23900ed15 100644 --- a/ml3d/configs/pointpillars_waymo.yml +++ b/ml3d/configs/pointpillars_waymo.yml @@ -106,7 +106,7 @@ pipeline: overlaps: [0.5, 0.5, 0.5] difficulties: [0, 1, 2] summary: - record_for: [train, valid] + record_for: [] max_pts: use_reference: false max_outputs: 1 diff --git a/ml3d/datasets/waymo.py b/ml3d/datasets/waymo.py index 6a01cf418..bceacd23f 100644 --- a/ml3d/datasets/waymo.py +++ b/ml3d/datasets/waymo.py @@ -91,18 +91,17 @@ def read_lidar(path): """Reads lidar data from the path provided. Returns: - A data object with lidar information. + pc: pointcloud data with shape [N, 6], where + the format is xyzRGB. """ - assert Path(path).exists() - return np.fromfile(path, dtype=np.float32).reshape(-1, 6) @staticmethod def read_label(path, calib): - """Reads labels of bound boxes. + """Reads labels of bounding boxes. Returns: - The data objects with bound boxes information. + The data objects with bounding boxes information. """ if not Path(path).exists(): return None @@ -132,24 +131,22 @@ def read_calib(path): Returns: The camera and the camera image used in calibration. """ - assert Path(path).exists() - with open(path, 'r') as f: lines = f.readlines() obj = lines[0].strip().split(' ')[1:] - P0 = np.array(obj, dtype=np.float32) + unused_P0 = np.array(obj, dtype=np.float32) obj = lines[1].strip().split(' ')[1:] - P1 = np.array(obj, dtype=np.float32) + unused_P1 = np.array(obj, dtype=np.float32) obj = lines[2].strip().split(' ')[1:] P2 = np.array(obj, dtype=np.float32) obj = lines[3].strip().split(' ')[1:] - P3 = np.array(obj, dtype=np.float32) + unused_P3 = np.array(obj, dtype=np.float32) obj = lines[4].strip().split(' ')[1:] - P4 = np.array(obj, dtype=np.float32) + unused_P4 = np.array(obj, dtype=np.float32) obj = lines[5].strip().split(' ')[1:] R0 = np.array(obj, dtype=np.float32).reshape(3, 3) @@ -210,7 +207,7 @@ def get_split_list(self, split): else: raise ValueError("Invalid split {}".format(split)) - def is_tested(): + def is_tested(attr): """Checks if a datum in the dataset has been tested. Args: @@ -220,16 +217,16 @@ def is_tested(): If the datum attribute is tested, then return the path where the attribute is stored; else, returns false. """ - pass + raise NotImplementedError() - def save_test_result(): + def save_test_result(results, attr): """Saves the output of a model. Args: results: The output of a model for the datum associated with the attribute passed. attr: The attributes that correspond to the outputs passed in results. """ - pass + raise NotImplementedError() class WaymoSplit(): @@ -279,6 +276,7 @@ class Object3d(BEVBox3D): """ def __init__(self, center, size, label, calib): + # ground truth files doesn't have confidence value. confidence = float(label[15]) if label.__len__() == 16 else -1.0 world_cam = calib['world_cam'] diff --git a/ml3d/torch/pipelines/base_pipeline.py b/ml3d/torch/pipelines/base_pipeline.py index aed40e36f..f466868b9 100644 --- a/ml3d/torch/pipelines/base_pipeline.py +++ b/ml3d/torch/pipelines/base_pipeline.py @@ -23,7 +23,7 @@ def __init__(self, Args: model: A network model. dataset: A dataset, or None for inference model. - device: 'gpu' or 'cpu'. + device: 'cuda' or 'cpu'. distributed: Whether to use multiple gpus. kwargs: @@ -59,8 +59,8 @@ def __init__(self, if device == 'cpu' or not torch.cuda.is_available(): if distributed: - raise ValueError( - "Distributed training is ON, but CUDA not available.") + raise NotImplementedError( + "Distributed training for CPU is not supported yet.") self.device = torch.device('cpu') else: if distributed: diff --git a/ml3d/torch/pipelines/object_detection.py b/ml3d/torch/pipelines/object_detection.py index 11bb87c99..6646ac225 100644 --- a/ml3d/torch/pipelines/object_detection.py +++ b/ml3d/torch/pipelines/object_detection.py @@ -176,8 +176,6 @@ def run_valid(self, epoch=0): pin_memory=cfg.get('pin_memory', False), collate_fn=batcher.collate_fn, sampler=valid_sampler) - # worker_init_fn=lambda x: np.random.seed(x + np.uint32( - # torch.utils.data.get_worker_info().seed))) record_summary = self.rank == 0 and 'valid' in cfg.get('summary').get( 'record_for', []) @@ -341,6 +339,7 @@ def run_train(self): model.cuda(self.device) model = torch.nn.parallel.DistributedDataParallel( model, device_ids=[self.device]) + # model.get_loss = model.module.get_loss record_summary = self.rank == 0 and 'train' in cfg.get('summary').get( 'record_for', []) diff --git a/scripts/collect_bboxes.py b/scripts/collect_bboxes.py index 609020851..f8c6a5f33 100644 --- a/scripts/collect_bboxes.py +++ b/scripts/collect_bboxes.py @@ -1,12 +1,13 @@ import logging -from os.path import join import argparse import pickle -import random +import numpy as np +import multiprocessing + from tqdm import tqdm +from os.path import join from open3d.ml.datasets import utils from open3d.ml import datasets -import multiprocessing def parse_args(): @@ -25,7 +26,7 @@ def parse_args(): default="KITTI", required=False) parser.add_argument('--num_cpus', - help='Name of dataset class', + help='Number of threads to use.', type=int, default=multiprocessing.cpu_count(), required=False) @@ -95,8 +96,9 @@ def process_boxes(i): train = dataset.get_split('train') max_pc = len(train) if args.max_pc is None else args.max_pc - query_pc = range(len(train)) if max_pc >= len(train) else random.sample( - range(len(train)), max_pc) + rng = np.random.default_rng() + query_pc = range(len(train)) if max_pc >= len(train) else rng.choice( + range(len(train)), max_pc, replace=False) print(f"Found {len(train)} traning samples, Using {max_pc}") print( diff --git a/scripts/preprocess_waymo.py b/scripts/preprocess_waymo.py index bf5afb24f..6b44a61fd 100644 --- a/scripts/preprocess_waymo.py +++ b/scripts/preprocess_waymo.py @@ -8,13 +8,13 @@ import logging import numpy as np import os, sys, glob, pickle -from pathlib import Path -from os.path import join, exists, dirname, abspath -from os import makedirs -import random import argparse import tensorflow as tf import matplotlib.image as mpimg + +from pathlib import Path +from os.path import join, exists, dirname, abspath +from os import makedirs from multiprocessing import Pool from tqdm import tqdm from waymo_open_dataset.utils import range_image_utils, transform_utils @@ -58,6 +58,25 @@ class Waymo2KITTI(): """Waymo to KITTI converter. This class converts tfrecord files from Waymo dataset to KITTI format. + KITTI format : (type, truncated, occluded, alpha, bbox, dimensions(3), location(3), + rotation_y(1), score(1, optional)) + type (string): Describes the type of object. + truncated (float): Ranges from 0(non-truncated) to 1(truncated). + occluded (int): Integer(0, 1, 2, 3) signifies state fully visible, partly + occluded, largely occluded, unknown. + alpha (float): Observation angle of object, ranging [-pi..pi]. + bbox (float): 2d bounding box of object in the image. + dimensions (float): 3D object dimensions: h, w, l in meters. + location (float): 3D object location: x,y,z in camera coordinates (in meters). + rotation_y (float): rotation around Y-axis in camera coordinates [-pi..pi]. + score (float): Only for predictions, indicating confidence in detection. + + Conversion writes following files: + pointcloud(np.float32) : pointcloud data with shape [N, 6]. Consists of + (x, y, z, intensity, elongation, timestamp). + images(np.uint8): camera images are saved if `write_image` is True. + calibrations(np.float32): Intinsic and Extrinsic matrix for all cameras. + label(np.float32): Bounding box information in KITTI format. Args: dataset_path (str): Directory to load waymo raw data. @@ -137,7 +156,6 @@ def process_one(self, file_idx): if (self.selected_waymo_locations is not None and frame.context.stats.location not in self.selected_waymo_locations): - print("continue") continue if self.write_image: @@ -208,7 +226,6 @@ def save_calib(self, frame, file_idx, frame_idx): f'{str(file_idx).zfill(3)}{str(frame_idx).zfill(3)}.txt', 'w+') as fp_calib: fp_calib.write(calib_context) - fp_calib.close() def save_pose(self, frame, file_idx, frame_idx): pose = np.array(frame.pose.transform).reshape(4, 4) @@ -226,7 +243,6 @@ def save_label(self, frame, file_idx, frame_idx): for labels in frame.projected_lidar_labels: name = labels.name for label in labels.labels: - # TODO: need a workaround as bbox may not belong to front cam bbox = [ label.box.center_x - label.box.length / 2, label.box.center_y - label.box.width / 2, @@ -255,9 +271,6 @@ def save_label(self, frame, file_idx, frame_idx): if my_type not in self.selected_waymo_classes: continue - # if self.filter_empty_3dboxes and obj.num_lidar_points_in_box < 1: - # continue - height = obj.box.height width = obj.box.width length = obj.box.length @@ -266,11 +279,6 @@ def save_label(self, frame, file_idx, frame_idx): y = obj.box.center_y z = obj.box.center_z - # # project bounding box to the virtual reference frame - # pt_ref = self.T_velo_to_front_cam @ \ - # np.array([x, y, z, 1]).reshape((4, 1)) - # x, y, z, _ = pt_ref.flatten().tolist() - rotation_y = -obj.box.heading - np.pi / 2 track_id = obj.id diff --git a/scripts/run_pipeline.py b/scripts/run_pipeline.py index a1b7311f4..68564e1b9 100644 --- a/scripts/run_pipeline.py +++ b/scripts/run_pipeline.py @@ -43,6 +43,18 @@ def parse_args(): parser.add_argument('--main_log_dir', help='the dir to save logs and models') parser.add_argument('--seed', help='random seed', default=0) + parser.add_argument( + '--host', + help='Host for distributed training, default: localhost', + default='localhost') + parser.add_argument('--port', + help='port for distributed training, default: 12355', + default='12355') + parser.add_argument( + '--backend', + help= + 'backend for distributed training. One of (nccl, gloo)}, default: gloo', + default='gloo') args, unknown = parser.parse_known_args() @@ -173,12 +185,12 @@ def main(): nprocs=len(args.device_ids)) -def setup(rank, world_size): - os.environ['MASTER_ADDR'] = 'localhost' - os.environ['MASTER_PORT'] = '12355' +def setup(rank, world_size, args): + os.environ['MASTER_ADDR'] = args.host + os.environ['MASTER_PORT'] = args.port # initialize the process group - dist.init_process_group("gloo", rank=rank, world_size=world_size) + dist.init_process_group(args.backend, rank=rank, world_size=world_size) def cleanup(): @@ -188,7 +200,7 @@ def cleanup(): def main_worker(rank, Dataset, Model, Pipeline, cfg_dict_dataset, cfg_dict_model, cfg_dict_pipeline, args): world_size = len(args.device_ids) - setup(rank, world_size) + setup(rank, world_size, args) cfg_dict_dataset['rank'] = rank cfg_dict_model['rank'] = rank @@ -229,5 +241,5 @@ def main_worker(rank, Dataset, Model, Pipeline, cfg_dict_dataset, format='%(levelname)s - %(asctime)s - %(module)s - %(message)s', ) - multiprocessing.set_start_method('spawn') + multiprocessing.set_start_method('forkserver') sys.exit(main()) From 1ed11c8841bbb4423dc82f553a4029c03fcec66e Mon Sep 17 00:00:00 2001 From: Sanskar Agrawal Date: Fri, 25 Mar 2022 20:32:02 +0530 Subject: [PATCH 21/50] fix model.module.... --- ml3d/torch/pipelines/object_detection.py | 14 +++++--------- 1 file changed, 5 insertions(+), 9 deletions(-) diff --git a/ml3d/torch/pipelines/object_detection.py b/ml3d/torch/pipelines/object_detection.py index 6646ac225..0ba26f116 100644 --- a/ml3d/torch/pipelines/object_detection.py +++ b/ml3d/torch/pipelines/object_detection.py @@ -339,7 +339,9 @@ def run_train(self): model.cuda(self.device) model = torch.nn.parallel.DistributedDataParallel( model, device_ids=[self.device]) - # model.get_loss = model.module.get_loss + model.get_loss = model.module.get_loss + model.cfg = model.module.cfg + model.inference_end = model.module.inference_end record_summary = self.rank == 0 and 'train' in cfg.get('summary').get( 'record_for', []) @@ -359,10 +361,7 @@ def run_train(self): for data in process_bar: data.to(device) results = model(data) - if self.distributed: - loss = model.module.get_loss(results, data) - else: - loss = model.get_loss(results, data) + loss = model.get_loss(results, data) loss_sum = sum(loss.values()) self.optimizer.zero_grad() @@ -381,10 +380,7 @@ def run_train(self): # Record visualization for the last iteration if record_summary and process_bar.n == process_bar.total - 1: - if self.distributed: - boxes = model.module.inference_end(results, data) - else: - boxes = model.inference_end(results, data) + boxes = model.inference_end(results, data) self.summary['train'] = self.get_3d_summary(boxes, data, epoch, From 39b31171b2b54de426ce264ae8ccc412730ae053 Mon Sep 17 00:00:00 2001 From: Sanskar Agrawal Date: Mon, 28 Mar 2022 21:12:13 +0530 Subject: [PATCH 22/50] add semseg labels --- scripts/preprocess_waymo.py | 83 ++++++++++++++++++++++--------------- 1 file changed, 50 insertions(+), 33 deletions(-) diff --git a/scripts/preprocess_waymo.py b/scripts/preprocess_waymo.py index 6a32dde7d..bc0b604a6 100644 --- a/scripts/preprocess_waymo.py +++ b/scripts/preprocess_waymo.py @@ -1,25 +1,26 @@ try: from waymo_open_dataset import dataset_pb2 + from waymo_open_dataset.utils import range_image_utils, transform_utils + from waymo_open_dataset.utils.frame_utils import \ + parse_range_image_and_camera_projection except ImportError: raise ImportError( - 'Please run "pip install waymo-open-dataset-tf-2-1-0==1.2.0" ' + 'Please clone "https://github.com/waymo-research/waymo-open-dataset.git" ' + 'checkout branch "r1.3", and append its path to PYTHONPATH ' 'to install the official devkit first.') import logging import numpy as np import os, sys, glob, pickle -from pathlib import Path -from os.path import join, exists, dirname, abspath -from os import makedirs -import random import argparse import tensorflow as tf import matplotlib.image as mpimg + +from pathlib import Path +from os.path import join, exists, dirname, abspath +from os import makedirs from multiprocessing import Pool from tqdm import tqdm -from waymo_open_dataset.utils import range_image_utils, transform_utils -from waymo_open_dataset.utils.frame_utils import \ - parse_range_image_and_camera_projection def parse_args(): @@ -38,10 +39,10 @@ def parse_args(): default=16, type=int) - parser.add_argument('--is_test', - help='True for processing test data (default False)', - default=False, - type=bool) + parser.add_argument('--split', + help='One of {train, val, test} (default train)', + default='train', + type=str) args = parser.parse_args() @@ -58,6 +59,25 @@ class Waymo2KITTI(): """Waymo to KITTI converter. This class converts tfrecord files from Waymo dataset to KITTI format. + KITTI format : (type, truncated, occluded, alpha, bbox, dimensions(3), location(3), + rotation_y(1), score(1, optional)) + type (string): Describes the type of object. + truncated (float): Ranges from 0(non-truncated) to 1(truncated). + occluded (int): Integer(0, 1, 2, 3) signifies state fully visible, partly + occluded, largely occluded, unknown. + alpha (float): Observation angle of object, ranging [-pi..pi]. + bbox (float): 2d bounding box of object in the image. + dimensions (float): 3D object dimensions: h, w, l in meters. + location (float): 3D object location: x,y,z in camera coordinates (in meters). + rotation_y (float): rotation around Y-axis in camera coordinates [-pi..pi]. + score (float): Only for predictions, indicating confidence in detection. + + Conversion writes following files: + pointcloud(np.float32) : pointcloud data with shape [N, 6]. Consists of + (x, y, z, intensity, elongation, timestamp). + images(np.uint8): camera images are saved if `write_image` is True. + calibrations(np.float32): Intinsic and Extrinsic matrix for all cameras. + label(np.float32): Bounding box information in KITTI format. Args: dataset_path (str): Directory to load waymo raw data. @@ -66,11 +86,11 @@ class Waymo2KITTI(): is_test (bool): Whether in the test_mode. Default: False. """ - def __init__(self, dataset_path, save_dir='', workers=8, is_test=False): + def __init__(self, dataset_path, save_dir='', workers=8, split='train'): - self.write_image = True + self.write_image = False self.filter_empty_3dboxes = True - self.filter_no_label_zone_points = True + self.filter_no_label_zone_points = False self.classes = ['VEHICLE', 'PEDESTRIAN', 'SIGN', 'CYCLIST'] @@ -86,8 +106,8 @@ def __init__(self, dataset_path, save_dir='', workers=8, is_test=False): self.dataset_path = dataset_path self.save_dir = save_dir self.workers = int(workers) - self.is_test = is_test - self.prefix = '' + self.is_test = split == 'test' + self.prefix = split + '_' self.save_track_id = False self.tfrecord_files = sorted( @@ -137,7 +157,6 @@ def process_one(self, file_idx): if (self.selected_waymo_locations is not None and frame.context.stats.location not in self.selected_waymo_locations): - print("continue") continue if self.write_image: @@ -153,8 +172,6 @@ def __len__(self): return len(self.tfrecord_files) def save_image(self, frame, file_idx, frame_idx): - self.prefix = '' - for img in frame.images: img_path = Path(self.image_save_dir + str(img.name - 1)) / ( self.prefix + str(file_idx).zfill(3) + str(frame_idx).zfill(3) + @@ -210,7 +227,6 @@ def save_calib(self, frame, file_idx, frame_idx): f'{str(file_idx).zfill(3)}{str(frame_idx).zfill(3)}.txt', 'w+') as fp_calib: fp_calib.write(calib_context) - fp_calib.close() def save_pose(self, frame, file_idx, frame_idx): pose = np.array(frame.pose.transform).reshape(4, 4) @@ -228,7 +244,6 @@ def save_label(self, frame, file_idx, frame_idx): for labels in frame.projected_lidar_labels: name = labels.name for label in labels.labels: - # TODO: need a workaround as bbox may not belong to front cam bbox = [ label.box.center_x - label.box.length / 2, label.box.center_y - label.box.width / 2, @@ -257,9 +272,6 @@ def save_label(self, frame, file_idx, frame_idx): if my_type not in self.selected_waymo_classes: continue - # if self.filter_empty_3dboxes and obj.num_lidar_points_in_box < 1: - # continue - height = obj.box.height width = obj.box.width length = obj.box.length @@ -268,11 +280,6 @@ def save_label(self, frame, file_idx, frame_idx): y = obj.box.center_y z = obj.box.center_z - # # project bounding box to the virtual reference frame - # pt_ref = self.T_velo_to_front_cam @ \ - # np.array([x, y, z, 1]).reshape((4, 1)) - # x, y, z, _ = pt_ref.flatten().tolist() - rotation_y = -obj.box.heading - np.pi / 2 track_id = obj.id @@ -306,7 +313,7 @@ def save_label(self, frame, file_idx, frame_idx): fp_label_all.close() def save_lidar(self, frame, file_idx, frame_idx): - range_images, camera_projections, range_image_top_pose = parse_range_image_and_camera_projection( + range_images, camera_projections, seg_labels, range_image_top_pose = parse_range_image_and_camera_projection( frame) # First return @@ -314,6 +321,7 @@ def save_lidar(self, frame, file_idx, frame_idx): self.convert_range_image_to_point_cloud( frame, range_images, + seg_labels, camera_projections, range_image_top_pose, ri_index=0 @@ -327,6 +335,7 @@ def save_lidar(self, frame, file_idx, frame_idx): self.convert_range_image_to_point_cloud( frame, range_images, + seg_labels, camera_projections, range_image_top_pose, ri_index=1 @@ -351,6 +360,7 @@ def save_lidar(self, frame, file_idx, frame_idx): def convert_range_image_to_point_cloud(self, frame, range_images, + seg_labels, camera_projections, range_image_top_pose, ri_index=0): @@ -360,6 +370,7 @@ def convert_range_image_to_point_cloud(self, cp_points = [] intensity = [] elongation = [] + semseg_labels = [] frame_pose = tf.convert_to_tensor( value=np.reshape(np.array(frame.pose.transform), [4, 4])) @@ -380,6 +391,7 @@ def convert_range_image_to_point_cloud(self, range_image_top_pose_tensor_translation) for c in calibrations: range_image = range_images[c.name][ri_index] + seg_label = seg_labels[c.name][ri_index] if len(c.beam_inclinations) == 0: beam_inclinations = range_image_utils.compute_inclination( tf.constant( @@ -416,9 +428,12 @@ def convert_range_image_to_point_cloud(self, frame_pose=frame_pose_local) range_image_cartesian = tf.squeeze(range_image_cartesian, axis=0) + print(range_image_cartesian.shape) points_tensor = tf.gather_nd(range_image_cartesian, tf.compat.v1.where(range_image_mask)) - + print(points_tensor.shape) + print(seg_label.shape) + exit(0) cp = camera_projections[c.name][ri_index] cp_tensor = tf.reshape(tf.convert_to_tensor(value=cp.data), cp.shape.dims) @@ -460,6 +475,8 @@ def cart_to_homo(mat): out_path = args.out_path if out_path is None: args.out_path = args.dataset_path + if args.split not in ['train', 'val', 'test']: + raise ValueError("split must be one of {train, val, test}") converter = Waymo2KITTI(args.dataset_path, args.out_path, args.workers, - args.is_test) + args.split) converter.convert() From 142df86cee0ba29b6abee1fedebddfabc1b99426 Mon Sep 17 00:00:00 2001 From: Sanskar Agrawal Date: Tue, 29 Mar 2022 02:42:33 +0530 Subject: [PATCH 23/50] add waymo semseg preprocessing --- scripts/preprocess_waymo_semseg.py | 425 +++++++++++++++++++++++++++++ 1 file changed, 425 insertions(+) create mode 100644 scripts/preprocess_waymo_semseg.py diff --git a/scripts/preprocess_waymo_semseg.py b/scripts/preprocess_waymo_semseg.py new file mode 100644 index 000000000..464ec91f3 --- /dev/null +++ b/scripts/preprocess_waymo_semseg.py @@ -0,0 +1,425 @@ +try: + from waymo_open_dataset import dataset_pb2 + from waymo_open_dataset.utils import range_image_utils, transform_utils + from waymo_open_dataset.utils.frame_utils import \ + parse_range_image_and_camera_projection +except ImportError: + raise ImportError( + 'Please clone "https://github.com/waymo-research/waymo-open-dataset.git" ' + 'checkout branch "r1.3", and install the official devkit first') + +import logging +import numpy as np +import os, sys, glob, pickle +import argparse +import tensorflow as tf +import matplotlib.image as mpimg + +from pathlib import Path +from os.path import join, exists, dirname, abspath +from os import makedirs +from multiprocessing import Pool +from tqdm import tqdm + +gpus = tf.config.experimental.list_physical_devices('GPU') +if gpus: + try: + for gpu in gpus: + tf.config.experimental.set_memory_growth(gpu, True) + except RuntimeError as e: + print(e) + + +def parse_args(): + parser = argparse.ArgumentParser(description='Preprocess Waymo Dataset.') + parser.add_argument('--dataset_path', + help='path to Waymo tfrecord files', + required=True) + parser.add_argument( + '--out_path', + help='Output path to store pickle (default to dataet_path)', + default=None, + required=False) + + parser.add_argument('--workers', + help='Number of workers.', + default=16, + type=int) + + parser.add_argument('--split', + help='One of {train, val, test} (default train)', + default='train', + type=str) + + args = parser.parse_args() + + dict_args = vars(args) + for k in dict_args: + v = dict_args[k] + print("{}: {}".format(k, v) if v is not None else "{} not given". + format(k)) + + return args + + +class Waymo2KITTI(): + """Waymo to KITTI converter. + + This class converts tfrecord files from Waymo dataset to KITTI format. + KITTI format : (type, truncated, occluded, alpha, bbox, dimensions(3), location(3), + rotation_y(1), score(1, optional)) + type (string): Describes the type of object. + truncated (float): Ranges from 0(non-truncated) to 1(truncated). + occluded (int): Integer(0, 1, 2, 3) signifies state fully visible, partly + occluded, largely occluded, unknown. + alpha (float): Observation angle of object, ranging [-pi..pi]. + bbox (float): 2d bounding box of object in the image. + dimensions (float): 3D object dimensions: h, w, l in meters. + location (float): 3D object location: x,y,z in camera coordinates (in meters). + rotation_y (float): rotation around Y-axis in camera coordinates [-pi..pi]. + score (float): Only for predictions, indicating confidence in detection. + + Conversion writes following files: + pointcloud(np.float32) : pointcloud data with shape [N, 6]. Consists of + (x, y, z, intensity, elongation, timestamp). + images(np.uint8): camera images are saved if `write_image` is True. + calibrations(np.float32): Intinsic and Extrinsic matrix for all cameras. + label(np.float32): Bounding box information in KITTI format. + + Args: + dataset_path (str): Directory to load waymo raw data. + save_dir (str): Directory to save data in KITTI format. + workers (str): Number of workers for the parallel process. + is_test (bool): Whether in the test_mode. Default: False. + """ + + def __init__(self, dataset_path, save_dir='', workers=8, split='train'): + + self.write_image = False + self.filter_empty_3dboxes = True + self.filter_no_label_zone_points = False + + self.classes = ['VEHICLE', 'PEDESTRIAN', 'SIGN', 'CYCLIST'] + + self.lidar_list = [ + '_FRONT', '_FRONT_RIGHT', '_FRONT_LEFT', '_SIDE_RIGHT', '_SIDE_LEFT' + ] + self.type_list = ['UNKNOWN', 'VEHICLE', 'PEDESTRIAN', 'SIGN', 'CYCLIST'] + + self.selected_waymo_classes = self.classes + + self.selected_waymo_locations = None + + self.dataset_path = dataset_path + self.save_dir = save_dir + self.workers = int(workers) + self.is_test = split == 'test' + self.prefix = split + '_' + self.save_track_id = False + + self.tfrecord_files = sorted( + glob.glob(join(self.dataset_path, "*.tfrecord"))) + + self.label_save_dir = f'{self.save_dir}/label_' + self.label_all_save_dir = f'{self.save_dir}/label_all' + self.image_save_dir = f'{self.save_dir}/image_' + self.calib_save_dir = f'{self.save_dir}/calib' + self.point_cloud_save_dir = f'{self.save_dir}/velodyne' + self.pose_save_dir = f'{self.save_dir}/pose' + + self.create_folder() + + def create_folder(self): + if not self.is_test: + dir_list1 = [ + self.label_all_save_dir, self.calib_save_dir, + self.point_cloud_save_dir, self.pose_save_dir + ] + dir_list2 = [self.label_save_dir, self.image_save_dir] + else: + dir_list1 = [ + self.calib_save_dir, self.point_cloud_save_dir, + self.pose_save_dir + ] + dir_list2 = [self.image_save_dir] + for d in dir_list1: + makedirs(d, exist_ok=True) + for d in dir_list2: + for i in range(5): + makedirs(f'{d}{str(i)}', exist_ok=True) + + def convert(self): + print(f"Start converting {len(self)} files ...") + # for i in tqdm(range(len(self))): + # self.process_one(i) + with Pool(self.workers) as p: + p.map(self.process_one, [i for i in range(len(self))]) + + def process_one(self, file_idx): + print(f"Converting : {file_idx}") + path = self.tfrecord_files[file_idx] + dataset = tf.data.TFRecordDataset(path, compression_type='') + + for frame_idx, data in enumerate(dataset): + frame = dataset_pb2.Frame() + frame.ParseFromString(bytearray(data.numpy())) + if (not frame.lasers[0].ri_return1.segmentation_label_compressed + ) and (not self.is_test): + continue + + if (self.selected_waymo_locations is not None and + frame.context.stats.location + not in self.selected_waymo_locations): + continue + + if self.write_image: + self.save_image(frame, file_idx, frame_idx) + self.save_calib(frame, file_idx, frame_idx) + self.save_lidar(frame, file_idx, frame_idx) + self.save_pose(frame, file_idx, frame_idx) + + def __len__(self): + return len(self.tfrecord_files) + + def save_image(self, frame, file_idx, frame_idx): + for img in frame.images: + img_path = Path(self.image_save_dir + str(img.name - 1)) / ( + self.prefix + str(file_idx).zfill(4) + str(frame_idx).zfill(4) + + '.npy') + image = tf.io.decode_jpeg(img.image).numpy() + + np.save(img_path, image) + + def save_calib(self, frame, file_idx, frame_idx): + # waymo front camera to kitti reference camera + T_front_cam_to_ref = np.array([[0.0, -1.0, 0.0], [0.0, 0.0, -1.0], + [1.0, 0.0, 0.0]]) + camera_calibs = [] + R0_rect = [f'{i:e}' for i in np.eye(3).flatten()] + Tr_velo_to_cams = [] + calib_context = '' + + for camera in frame.context.camera_calibrations: + # extrinsic parameters + T_cam_to_vehicle = np.array(camera.extrinsic.transform).reshape( + 4, 4) + T_vehicle_to_cam = np.linalg.inv(T_cam_to_vehicle) + Tr_velo_to_cam = \ + self.cart_to_homo(T_front_cam_to_ref) @ T_vehicle_to_cam + if camera.name == 1: # FRONT = 1, see dataset.proto for details + self.T_velo_to_front_cam = Tr_velo_to_cam.copy() + Tr_velo_to_cam = Tr_velo_to_cam[:3, :].reshape((12,)) + Tr_velo_to_cams.append([f'{i:e}' for i in Tr_velo_to_cam]) + + # intrinsic parameters + camera_calib = np.zeros((3, 4)) + camera_calib[0, 0] = camera.intrinsic[0] + camera_calib[1, 1] = camera.intrinsic[1] + camera_calib[0, 2] = camera.intrinsic[2] + camera_calib[1, 2] = camera.intrinsic[3] + camera_calib[2, 2] = 1 + camera_calib = list(camera_calib.reshape(12)) + camera_calib = [f'{i:e}' for i in camera_calib] + camera_calibs.append(camera_calib) + + # all camera ids are saved as id-1 in the result because + # camera 0 is unknown in the proto + for i in range(5): + calib_context += 'P' + str(i) + ': ' + \ + ' '.join(camera_calibs[i]) + '\n' + calib_context += 'R0_rect' + ': ' + ' '.join(R0_rect) + '\n' + for i in range(5): + calib_context += 'Tr_velo_to_cam_' + str(i) + ': ' + \ + ' '.join(Tr_velo_to_cams[i]) + '\n' + + with open( + f'{self.calib_save_dir}/{self.prefix}' + + f'{str(file_idx).zfill(4)}{str(frame_idx).zfill(4)}.txt', + 'w+') as fp_calib: + fp_calib.write(calib_context) + + def save_pose(self, frame, file_idx, frame_idx): + pose = np.array(frame.pose.transform).reshape(4, 4) + np.savetxt( + join(f'{self.pose_save_dir}/{self.prefix}' + + f'{str(file_idx).zfill(4)}{str(frame_idx).zfill(4)}.txt'), + pose) + + def save_lidar(self, frame, file_idx, frame_idx): + range_images, camera_projections, seg_labels, range_image_top_pose = parse_range_image_and_camera_projection( + frame) + + # First return + points_0, cp_points_0, intensity_0, elongation_0, seg_label_0 = \ + self.convert_range_image_to_point_cloud( + frame, + range_images, + seg_labels, + camera_projections, + range_image_top_pose, + ri_index=0 + ) + points_0 = np.concatenate(points_0, axis=0) + intensity_0 = np.concatenate(intensity_0, axis=0) + elongation_0 = np.concatenate(elongation_0, axis=0) + seg_label_0 = np.concatenate(seg_label_0, axis=0) + + # Second return + points_1, cp_points_1, intensity_1, elongation_1, seg_label_1 = \ + self.convert_range_image_to_point_cloud( + frame, + range_images, + seg_labels, + camera_projections, + range_image_top_pose, + ri_index=1 + ) + points_1 = np.concatenate(points_1, axis=0) + intensity_1 = np.concatenate(intensity_1, axis=0) + elongation_1 = np.concatenate(elongation_1, axis=0) + seg_label_1 = np.concatenate(seg_label_1, axis=0) + + points = np.concatenate([points_0, points_1], axis=0) + intensity = np.concatenate([intensity_0, intensity_1], axis=0) + elongation = np.concatenate([elongation_0, elongation_1], axis=0) + semseg_labels = np.concatenate([seg_label_0, seg_label_1], axis=0) + timestamp = frame.timestamp_micros * np.ones_like(intensity) + + # concatenate x,y,z, intensity, elongation, timestamp (6-dim) + point_cloud = np.column_stack( + (points, intensity, elongation, timestamp, semseg_labels)) + + pc_path = f'{self.point_cloud_save_dir}/{self.prefix}' + \ + f'{str(file_idx).zfill(4)}{str(frame_idx).zfill(4)}.bin' + point_cloud.astype(np.float32).tofile(pc_path) + + def convert_range_image_to_point_cloud(self, + frame, + range_images, + segmentation_labels, + camera_projections, + range_image_top_pose, + ri_index=0): + calibrations = sorted(frame.context.laser_calibrations, + key=lambda c: c.name) + points = [] + cp_points = [] + intensity = [] + elongation = [] + semseg_labels = [] + + frame_pose = tf.convert_to_tensor( + value=np.reshape(np.array(frame.pose.transform), [4, 4])) + # [H, W, 6] + range_image_top_pose_tensor = tf.reshape( + tf.convert_to_tensor(value=range_image_top_pose.data), + range_image_top_pose.shape.dims) + # [H, W, 3, 3] + range_image_top_pose_tensor_rotation = \ + transform_utils.get_rotation_matrix( + range_image_top_pose_tensor[..., 0], + range_image_top_pose_tensor[..., 1], + range_image_top_pose_tensor[..., 2]) + range_image_top_pose_tensor_translation = \ + range_image_top_pose_tensor[..., 3:] + range_image_top_pose_tensor = transform_utils.get_transform( + range_image_top_pose_tensor_rotation, + range_image_top_pose_tensor_translation) + for c in calibrations: + range_image = range_images[c.name][ri_index] + if len(c.beam_inclinations) == 0: + beam_inclinations = range_image_utils.compute_inclination( + tf.constant( + [c.beam_inclination_min, c.beam_inclination_max]), + height=range_image.shape.dims[0]) + else: + beam_inclinations = tf.constant(c.beam_inclinations) + + beam_inclinations = tf.reverse(beam_inclinations, axis=[-1]) + extrinsic = np.reshape(np.array(c.extrinsic.transform), [4, 4]) + + range_image_tensor = tf.reshape( + tf.convert_to_tensor(value=range_image.data), + range_image.shape.dims) + pixel_pose_local = None + frame_pose_local = None + if c.name == dataset_pb2.LaserName.TOP: + pixel_pose_local = range_image_top_pose_tensor + pixel_pose_local = tf.expand_dims(pixel_pose_local, axis=0) + frame_pose_local = tf.expand_dims(frame_pose, axis=0) + range_image_mask = range_image_tensor[..., 0] > 0 + + if self.filter_no_label_zone_points: + nlz_mask = range_image_tensor[..., 3] != 1.0 # 1.0: in NLZ + range_image_mask = range_image_mask & nlz_mask + + range_image_cartesian = \ + range_image_utils.extract_point_cloud_from_range_image( + tf.expand_dims(range_image_tensor[..., 0], axis=0), + tf.expand_dims(extrinsic, axis=0), + tf.expand_dims(tf.convert_to_tensor( + value=beam_inclinations), axis=0), + pixel_pose=pixel_pose_local, + frame_pose=frame_pose_local) + + range_image_cartesian = tf.squeeze(range_image_cartesian, axis=0) + points_tensor = tf.gather_nd(range_image_cartesian, + tf.compat.v1.where(range_image_mask)) + cp = camera_projections[c.name][ri_index] + cp_tensor = tf.reshape(tf.convert_to_tensor(value=cp.data), + cp.shape.dims) + cp_points_tensor = tf.gather_nd( + cp_tensor, tf.compat.v1.where(range_image_mask)) + points.append(points_tensor.numpy()) + cp_points.append(cp_points_tensor.numpy()) + + intensity_tensor = tf.gather_nd(range_image_tensor[..., 1], + tf.where(range_image_mask)) + intensity.append(intensity_tensor.numpy()) + + elongation_tensor = tf.gather_nd(range_image_tensor[..., 2], + tf.where(range_image_mask)) + elongation.append(elongation_tensor.numpy()) + + if c.name in segmentation_labels: + sl = segmentation_labels[c.name][ri_index] + sl_tensor = tf.reshape(tf.convert_to_tensor(sl.data), + sl.shape.dims) + sl_points_tensor = tf.gather_nd(sl_tensor, + tf.where(range_image_mask)) + else: + sl_points_tensor = tf.zeros([points_tensor.shape[0], 2], + dtype=tf.int32) + + semseg_labels.append(sl_points_tensor.numpy()) + + return points, cp_points, intensity, elongation, semseg_labels + + @staticmethod + def cart_to_homo(mat): + ret = np.eye(4) + if mat.shape == (3, 3): + ret[:3, :3] = mat + elif mat.shape == (3, 4): + ret[:3, :] = mat + else: + raise ValueError(mat.shape) + return ret + + +if __name__ == '__main__': + + logging.basicConfig( + level=logging.INFO, + format='%(levelname)s - %(asctime)s - %(module)s - %(message)s', + ) + + args = parse_args() + out_path = args.out_path + if out_path is None: + args.out_path = args.dataset_path + if args.split not in ['train', 'val', 'test']: + raise ValueError("split must be one of {train, val, test}") + converter = Waymo2KITTI(args.dataset_path, args.out_path, args.workers, + args.split) + converter.convert() From 5391b1c33dd83b1ad27f74423dc07286709730d9 Mon Sep 17 00:00:00 2001 From: Sanskar Agrawal Date: Thu, 31 Mar 2022 09:34:35 +0530 Subject: [PATCH 24/50] add waymo semseg dataset --- ml3d/datasets/__init__.py | 3 +- ml3d/datasets/waymo_semseg.py | 202 ++++++++++++++++++++++++++++++++++ 2 files changed, 204 insertions(+), 1 deletion(-) create mode 100644 ml3d/datasets/waymo_semseg.py diff --git a/ml3d/datasets/__init__.py b/ml3d/datasets/__init__.py index 056b8b9f4..05b4b7fe7 100644 --- a/ml3d/datasets/__init__.py +++ b/ml3d/datasets/__init__.py @@ -15,6 +15,7 @@ from .kitti import KITTI from .nuscenes import NuScenes from .waymo import Waymo +from .waymo_semseg import WaymoSemSeg from .lyft import Lyft from .shapenet import ShapeNet from .argoverse import Argoverse @@ -27,5 +28,5 @@ 'Custom3D', 'utils', 'augment', 'samplers', 'KITTI', 'Waymo', 'NuScenes', 'Lyft', 'ShapeNet', 'SemSegRandomSampler', 'InferenceDummySplit', 'SemSegSpatiallyRegularSampler', 'Argoverse', 'Scannet', 'SunRGBD', - 'MatterportObjects' + 'MatterportObjects', 'WaymoSemSeg' ] diff --git a/ml3d/datasets/waymo_semseg.py b/ml3d/datasets/waymo_semseg.py new file mode 100644 index 000000000..e8418440c --- /dev/null +++ b/ml3d/datasets/waymo_semseg.py @@ -0,0 +1,202 @@ +import numpy as np +import os, argparse, pickle, sys +from os.path import exists, join, isfile, dirname, abspath, split +from pathlib import Path +from glob import glob +import logging +import yaml + +from .base_dataset import BaseDataset +from ..utils import Config, make_dir, DATASET +from .utils import BEVBox3D + +log = logging.getLogger(__name__) + + +class WaymoSemSeg(BaseDataset): + """This class is used to create a dataset based on the Waymo 3D dataset, and + used in object detection, visualizer, training, or testing. + + The Waymo 3D dataset is best suited for autonomous driving applications. + """ + + def __init__(self, + dataset_path, + name='Waymo', + cache_dir='./logs/cache', + use_cache=False, + **kwargs): + """Initialize the function by passing the dataset and other details. + + Args: + dataset_path: The path to the dataset to use. + name: The name of the dataset (Waymo in this case). + cache_dir: The directory where the cache is stored. + use_cache: Indicates if the dataset should be cached. + + Returns: + class: The corresponding class. + """ + super().__init__(dataset_path=dataset_path, + name=name, + cache_dir=cache_dir, + use_cache=use_cache, + **kwargs) + + cfg = self.cfg + + self.name = cfg.name + self.dataset_path = cfg.dataset_path + self.num_classes = 23 + self.label_to_names = self.get_label_to_names() + self.shuffle = kwargs.get('shuffle', False) + + self.all_files = sorted( + glob(join(cfg.dataset_path, 'velodyne', '*.bin'))) + self.train_files = [] + self.val_files = [] + + for f in self.all_files: + if 'train' in f: + self.train_files.append(f) + elif 'val' in f: + self.val_files.append(f) + elif 'test' in f: + self.test_files.append(f) + else: + log.warning( + f"Skipping {f}, prefix must be one of train, test or val.") + if self.shuffle: + log.info("Shuffling training files...") + self.rng.shuffle(self.train_files) + + @staticmethod + def get_label_to_names(): + """Returns a label to names dictionary object. + + Returns: + A dict where keys are label numbers and + values are the corresponding names. + """ + classes = "Undefined, Car, Truck, Bus, Other Vehicle, Motorcyclist, Bicyclist, Pedestrian, Sign, Traffic Light, Pole, Construction Cone, Bicycle, Motorcycle, Building, Vegetation, Tree Trunk, Curb, Road, Lane Marker, Other Ground, Walkable, Sidewalk" + classes = classes.replace(', ', ',').split(',') + label_to_names = {} + for i in range(len(classes)): + label_to_names[i] = classes[i] + + return label_to_names + + @staticmethod + def read_lidar(path): + """Reads lidar data from the path provided. + + Returns: + pc: pointcloud data with shape [N, 8], where + the format is x,y,z,intensity,elongation,timestamp,instance, semantic_label. + """ + return np.fromfile(path, dtype=np.float32).reshape(-1, 8) + + def get_split(self, split): + """Returns a dataset split. + + Args: + split: A string identifying the dataset split that is usually one of + 'training', 'test', 'validation', or 'all'. + + Returns: + A dataset split object providing the requested subset of the data. + """ + return WaymoSplit(self, split=split) + + def get_split_list(self, split): + """Returns the list of data splits available. + + Args: + split: A string identifying the dataset split that is usually one of + 'training', 'test', 'validation', or 'all'. + + Returns: + A dataset split object providing the requested subset of the data. + + Raises: + ValueError: Indicates that the split name passed is incorrect. The + split name should be one of 'training', 'test', 'validation', or + 'all'. + """ + cfg = self.cfg + dataset_path = cfg.dataset_path + file_list = [] + + if split in ['train', 'training']: + return self.train_files + seq_list = cfg.training_split + elif split in ['test', 'testing']: + return self.test_files + elif split in ['val', 'validation']: + return self.val_files + elif split in ['all']: + return self.train_files + self.val_files + self.test_files + else: + raise ValueError("Invalid split {}".format(split)) + + def is_tested(attr): + """Checks if a datum in the dataset has been tested. + + Args: + attr: The attribute that needs to be checked. + + Returns: + If the datum attribute is tested, then return the path where the + attribute is stored; else, returns false. + """ + raise NotImplementedError() + + def save_test_result(results, attr): + """Saves the output of a model. + + Args: + results: The output of a model for the datum associated with the attribute passed. + attr: The attributes that correspond to the outputs passed in results. + """ + raise NotImplementedError() + + +class WaymoSplit(): + + def __init__(self, dataset, split='train'): + self.cfg = dataset.cfg + path_list = dataset.get_split_list(split) + log.info("Found {} pointclouds for {}".format(len(path_list), split)) + + self.path_list = path_list + self.split = split + self.dataset = dataset + + def __len__(self): + return len(self.path_list) + + def get_data(self, idx): + pc_path = self.path_list[idx] + + pc = self.dataset.read_lidar(pc_path) + feat = pc[:, 3:6] + label = pc[:, 7].astype(np.int32) + pc = pc[:, :3] + + data = { + 'point': pc, + 'feat': feat, + 'label': label, + } + + return data + + def get_attr(self, idx): + pc_path = self.path_list[idx] + name = Path(pc_path).name.split('.')[0] + + attr = {'name': name, 'path': pc_path, 'split': self.split} + return attr + + +DATASET._register_module(WaymoSemSeg) \ No newline at end of file From eb3b5518e6fc4a4e0d531da1905afc773d041d19 Mon Sep 17 00:00:00 2001 From: Sanskar Agrawal Date: Tue, 5 Apr 2022 17:13:03 +0530 Subject: [PATCH 25/50] remove dataparallel --- ml3d/torch/pipelines/dataparallel.py | 57 ---------------------------- 1 file changed, 57 deletions(-) delete mode 100644 ml3d/torch/pipelines/dataparallel.py diff --git a/ml3d/torch/pipelines/dataparallel.py b/ml3d/torch/pipelines/dataparallel.py deleted file mode 100644 index eee68138d..000000000 --- a/ml3d/torch/pipelines/dataparallel.py +++ /dev/null @@ -1,57 +0,0 @@ -import torch -from torch.nn.parallel import DataParallel - - -class CustomDataParallel(DataParallel): - """Custom DataParallel method for performing scatter operation - outside of torch's DataParallel. - """ - - def __init__(self, module, **kwargs): - super(CustomDataParallel, self).__init__(module, **kwargs) - self.get_loss = self.module.get_loss - self.cfg = self.module.cfg - - def forward(self, *inputs, **kwargs): - if not self.device_ids: - return self.module(*inputs, **kwargs) - - if len(self.device_ids) == 1: - if hasattr(inputs[0], 'to'): - inputs[0].to(self.device_ids[0]) - return self.module(inputs[0], **kwargs) - - inputs, kwargs = self.customscatter(inputs, kwargs, self.device_ids) - - self.module.cuda() - replicas = self.replicate(self.module, self.device_ids[:len(inputs)]) - outputs = self.parallel_apply(replicas, inputs, kwargs) - - return self.gather(outputs, self.output_device) - - def customscatter(self, inputs, kwargs, device_ids): - """Custom scatter method to override default method. - Scatter batch dimension based on custom scatter implemented - in custom batcher. - - Agrs: - inputs: Object of type custom batcher. - kwargs: Optional keyword arguments. - device_ids: List of device ids. - - Returns: - Returns a list of inputs of length num_devices. - Each input is transfered to different device id. - """ - if not hasattr(inputs[0], 'scatter'): - try: - return self.scatter(inputs, kwargs, device_ids) - except: - raise NotImplementedError( - f"Please implement scatter for {inputs[0]} for multi gpu execution." - ) - inputs = inputs[0].scatter(inputs[0], len(device_ids)) - for i in range(len(inputs)): - inputs[i].to(torch.device(device_ids[i])) - - return inputs, [kwargs for _ in range(len(inputs))] From bb7a715d77651c636bd7eabccd693ca2ec23761f Mon Sep 17 00:00:00 2001 From: Sanskar Agrawal Date: Tue, 12 Apr 2022 11:30:53 +0530 Subject: [PATCH 26/50] change sampler --- ml3d/configs/sparseconvunet_scannet.yml | 2 +- ml3d/datasets/base_dataset.py | 12 +++-- ml3d/torch/pipelines/semantic_segmentation.py | 49 +++++++++++++------ ml3d/utils/builder.py | 2 + 4 files changed, 47 insertions(+), 18 deletions(-) diff --git a/ml3d/configs/sparseconvunet_scannet.yml b/ml3d/configs/sparseconvunet_scannet.yml index 465d88ed5..e23635ca9 100644 --- a/ml3d/configs/sparseconvunet_scannet.yml +++ b/ml3d/configs/sparseconvunet_scannet.yml @@ -6,7 +6,7 @@ dataset: test_result_folder: ./test use_cache: False sampler: - name: 'SemSegRandomSampler' + name: None model: name: SparseConvUnet batcher: ConcatBatcher diff --git a/ml3d/datasets/base_dataset.py b/ml3d/datasets/base_dataset.py index bcc62e244..a97a890c3 100644 --- a/ml3d/datasets/base_dataset.py +++ b/ml3d/datasets/base_dataset.py @@ -127,10 +127,16 @@ def __init__(self, dataset, split='training'): if split in ['test']: sampler_cls = get_module('sampler', 'SemSegSpatiallyRegularSampler') else: - sampler_cfg = self.cfg.get('sampler', - {'name': 'SemSegRandomSampler'}) + # sampler_cfg = self.cfg.get('sampler', + # {'name': 'SemSegRandomSampler'}) + sampler_cfg = self.cfg.get('sampler', {'name': None}) + if sampler_cfg['name'] == "None": + sampler_cfg['name'] = None sampler_cls = get_module('sampler', sampler_cfg['name']) - self.sampler = sampler_cls(self) + if sampler_cls is None: + self.sampler = None + else: + self.sampler = sampler_cls(self) @abstractmethod def __len__(self): diff --git a/ml3d/torch/pipelines/semantic_segmentation.py b/ml3d/torch/pipelines/semantic_segmentation.py index d1f7a77bb..ec9cf81d0 100644 --- a/ml3d/torch/pipelines/semantic_segmentation.py +++ b/ml3d/torch/pipelines/semantic_segmentation.py @@ -312,20 +312,20 @@ def update_tests(self, sampler, inputs, results): def run_train(self): torch.manual_seed(self.rng.integers(np.iinfo( np.int32).max)) # Random reproducible seed for torch + rank = self.rank # Rank for distributed training model = self.model device = self.device - model.device = device dataset = self.dataset cfg = self.cfg - model.to(device) - - log.info("DEVICE : {}".format(device)) - timestamp = datetime.now().strftime('%Y-%m-%d_%H:%M:%S') + if rank == 0: + log.info("DEVICE : {}".format(device)) + timestamp = datetime.now().strftime('%Y-%m-%d_%H:%M:%S') - log_file_path = join(cfg.logs_dir, 'log_train_' + timestamp + '.txt') - log.info("Logging in file : {}".format(log_file_path)) - log.addHandler(logging.FileHandler(log_file_path)) + log_file_path = join(cfg.logs_dir, + 'log_train_' + timestamp + '.txt') + log.info("Logging in file : {}".format(log_file_path)) + log.addHandler(logging.FileHandler(log_file_path)) Loss = SemSegLoss(self, model, dataset, device) self.metric_train = SemSegMetric() @@ -335,6 +335,7 @@ def run_train(self): train_dataset = dataset.get_split('train') train_sampler = train_dataset.sampler + train_split = TorchDataloader(dataset=train_dataset, preprocess=model.preprocess, transform=model.transform, @@ -343,12 +344,20 @@ def run_train(self): steps_per_epoch=dataset.cfg.get( 'steps_per_epoch_train', None)) + if self.distributed: + if train_sampler is not None: + raise NotImplementedError( + "Distributed training with sampler is not supported yet!") + train_sampler = torch.utils.data.distributed.DistributedSampler( + train_split) + train_loader = DataLoader( train_split, batch_size=cfg.batch_size, - sampler=get_sampler(train_sampler), - num_workers=cfg.get('num_workers', 2), - pin_memory=cfg.get('pin_memory', True), + sampler=train_sampler + if self.distributed else get_sampler(train_sampler), + num_workers=cfg.get('num_workers', 0), + pin_memory=cfg.get('pin_memory', False), collate_fn=self.batcher.collate_fn, worker_init_fn=lambda x: np.random.seed(x + np.uint32( torch.utils.data.get_worker_info().seed)) @@ -388,9 +397,21 @@ def run_train(self): runid + '_' + Path(tensorboard_dir).name) writer = SummaryWriter(self.tensorboard_dir) - self.save_config(writer) - log.info("Writing summary in {}.".format(self.tensorboard_dir)) - record_summary = cfg.get('summary').get('record_for', []) + if rank == 0: + self.save_config(writer) + log.info("Writing summary in {}.".format(self.tensorboard_dir)) + + # wrap model for multiple GPU + if self.distributed: + model.cuda(self.device) + model.device = self.device + model = torch.nn.parallel.DistributedDataParallel( + model, device_ids=[self.device]) + model.get_loss = model.module.get_loss + model.cfg = model.module.cfg + + record_summary = cfg.get('summary').get('record_for', + []) if self.rank == 0 else [] log.info("Started training") diff --git a/ml3d/utils/builder.py b/ml3d/utils/builder.py index f087d501c..6c2ac7ebb 100644 --- a/ml3d/utils/builder.py +++ b/ml3d/utils/builder.py @@ -55,6 +55,8 @@ def get_module(module_type, module_name, framework=None, **kwargs): return get_from_name(module_name, DATASET, framework) elif module_type == "sampler": + if module_name is None: + return None return get_from_name(module_name, SAMPLER, framework) elif module_type == "model": From d35d2f13c66d57691b9e3bdba6e2502c54b51421 Mon Sep 17 00:00:00 2001 From: Sanskar Agrawal Date: Tue, 12 Apr 2022 15:14:41 +0530 Subject: [PATCH 27/50] add val sampler --- ml3d/datasets/waymo_semseg.py | 5 +- ml3d/torch/dataloaders/torch_sampler.py | 2 + ml3d/torch/models/point_transformer.py | 7 +- ml3d/torch/pipelines/semantic_segmentation.py | 64 ++++++++++++------- 4 files changed, 52 insertions(+), 26 deletions(-) diff --git a/ml3d/datasets/waymo_semseg.py b/ml3d/datasets/waymo_semseg.py index e8418440c..94c079767 100644 --- a/ml3d/datasets/waymo_semseg.py +++ b/ml3d/datasets/waymo_semseg.py @@ -6,7 +6,7 @@ import logging import yaml -from .base_dataset import BaseDataset +from .base_dataset import BaseDataset, BaseDatasetSplit from ..utils import Config, make_dir, DATASET from .utils import BEVBox3D @@ -161,9 +161,10 @@ def save_test_result(results, attr): raise NotImplementedError() -class WaymoSplit(): +class WaymoSplit(BaseDatasetSplit): def __init__(self, dataset, split='train'): + super().__init__(dataset, split=split) self.cfg = dataset.cfg path_list = dataset.get_split_list(split) log.info("Found {} pointclouds for {}".format(len(path_list), split)) diff --git a/ml3d/torch/dataloaders/torch_sampler.py b/ml3d/torch/dataloaders/torch_sampler.py index 142f01fbc..78e0ebc45 100644 --- a/ml3d/torch/dataloaders/torch_sampler.py +++ b/ml3d/torch/dataloaders/torch_sampler.py @@ -15,4 +15,6 @@ def __len__(self): def get_sampler(sampler): + if sampler is None: + return None return TorchSamplerWrapper(sampler) diff --git a/ml3d/torch/models/point_transformer.py b/ml3d/torch/models/point_transformer.py index 804387628..7169c757a 100644 --- a/ml3d/torch/models/point_transformer.py +++ b/ml3d/torch/models/point_transformer.py @@ -718,6 +718,7 @@ def knn_batch(points, if points_row_splits.shape[0] != queries_row_splits.shape[0]: raise ValueError("KNN(points and queries must have same batch size)") + device = points.device points = points.cpu() queries = queries.cpu() @@ -730,9 +731,11 @@ def knn_batch(points, return_distances=True) if return_distances: return ans.neighbors_index.reshape( - -1, k).long().cuda(), ans.neighbors_distance.reshape(-1, k).cuda() + -1, + k).long().to(device), ans.neighbors_distance.reshape(-1, + k).to(device) else: - return ans.neighbors_index.reshape(-1, k).long().cuda() + return ans.neighbors_index.reshape(-1, k).long().to(device) def interpolation(points, diff --git a/ml3d/torch/pipelines/semantic_segmentation.py b/ml3d/torch/pipelines/semantic_segmentation.py index ec9cf81d0..a7a457558 100644 --- a/ml3d/torch/pipelines/semantic_segmentation.py +++ b/ml3d/torch/pipelines/semantic_segmentation.py @@ -336,6 +336,9 @@ def run_train(self): train_dataset = dataset.get_split('train') train_sampler = train_dataset.sampler + valid_dataset = dataset.get_split('validation') + valid_sampler = valid_dataset.sampler + train_split = TorchDataloader(dataset=train_dataset, preprocess=model.preprocess, transform=model.transform, @@ -344,12 +347,22 @@ def run_train(self): steps_per_epoch=dataset.cfg.get( 'steps_per_epoch_train', None)) + valid_split = TorchDataloader(dataset=valid_dataset, + preprocess=model.preprocess, + transform=model.transform, + sampler=valid_sampler, + use_cache=dataset.cfg.use_cache, + steps_per_epoch=dataset.cfg.get( + 'steps_per_epoch_valid', None)) + if self.distributed: - if train_sampler is not None: + if train_sampler is not None or valid_sampler is not None: raise NotImplementedError( "Distributed training with sampler is not supported yet!") train_sampler = torch.utils.data.distributed.DistributedSampler( train_split) + valid_sampler = torch.utild.data.distributed.DistributedSampler( + valid_split) train_loader = DataLoader( train_split, @@ -363,20 +376,11 @@ def run_train(self): torch.utils.data.get_worker_info().seed)) ) # numpy expects np.uint32, whereas torch returns np.uint64. - valid_dataset = dataset.get_split('validation') - valid_sampler = valid_dataset.sampler - valid_split = TorchDataloader(dataset=valid_dataset, - preprocess=model.preprocess, - transform=model.transform, - sampler=valid_sampler, - use_cache=dataset.cfg.use_cache, - steps_per_epoch=dataset.cfg.get( - 'steps_per_epoch_valid', None)) - valid_loader = DataLoader( valid_split, batch_size=cfg.val_batch_size, - sampler=get_sampler(valid_sampler), + sampler=valid_sampler + if self.distributed else get_sampler(valid_sampler), num_workers=cfg.get('num_workers', 2), pin_memory=cfg.get('pin_memory', True), collate_fn=self.batcher.collate_fn, @@ -413,16 +417,19 @@ def run_train(self): record_summary = cfg.get('summary').get('record_for', []) if self.rank == 0 else [] - log.info("Started training") + if rank == 0: + log.info("Started training") for epoch in range(0, cfg.max_epoch + 1): - log.info(f'=== EPOCH {epoch:d}/{cfg.max_epoch:d} ===') + if self.distributed: + train_sampler.set_epoch(epoch) + model.train() self.metric_train.reset() self.metric_val.reset() self.losses = [] - model.trans_point_sampler = train_sampler.get_point_sampler() + # model.trans_point_sampler = train_sampler.get_point_sampler() for step, inputs in enumerate(tqdm(train_loader, desc='training')): if hasattr(inputs['data'], 'to'): @@ -437,8 +444,13 @@ def run_train(self): loss.backward() if model.cfg.get('grad_clip_norm', -1) > 0: - torch.nn.utils.clip_grad_value_(model.parameters(), - model.cfg.grad_clip_norm) + if self.distributed: + torch.nn.utils.clip_grad_value_( + model.module.parameters(), model.cfg.grad_clip_norm) + else: + torch.nn.utils.clip_grad_value_( + model.parameters(), model.cfg.grad_clip_norm) + self.optimizer.step() self.metric_train.update(predict_scores, gt_labels) @@ -449,12 +461,16 @@ def run_train(self): self.summary['train'] = self.get_3d_summary( results, inputs['data'], epoch) - self.scheduler.step() + if self.distributed: + dist.barrier() + + if self.scheduler is not None: + self.scheduler.step() # --------------------- validation model.eval() self.valid_losses = [] - model.trans_point_sampler = valid_sampler.get_point_sampler() + # model.trans_point_sampler = valid_sampler.get_point_sampler() with torch.no_grad(): for step, inputs in enumerate( @@ -477,10 +493,14 @@ def run_train(self): self.summary['valid'] = self.get_3d_summary( results, inputs['data'], epoch) - self.save_logs(writer, epoch) + if self.distributed: + # TODO (sanskar): accumulate confusion matrix for all process. + dist.barrier() - if epoch % cfg.save_ckpt_freq == 0: - self.save_ckpt(epoch) + if rank == 0: + self.save_logs(writer, epoch) + if epoch % cfg.save_ckpt_freq == 0: + self.save_ckpt(epoch) def get_batcher(self, device, split='training'): """Get the batcher to be used based on the device and split.""" From db1128c142e5fbf7c804cfd0513dac9539f58a84 Mon Sep 17 00:00:00 2001 From: Sanskar Agrawal Date: Tue, 12 Apr 2022 15:22:45 +0530 Subject: [PATCH 28/50] add config files --- ml3d/configs/pointtransformer_waymo.yml | 56 ++++++++++++++++++ ml3d/configs/sparseconvunet_waymo.yml | 57 +++++++++++++++++++ ml3d/torch/pipelines/semantic_segmentation.py | 2 +- 3 files changed, 114 insertions(+), 1 deletion(-) create mode 100644 ml3d/configs/pointtransformer_waymo.yml create mode 100644 ml3d/configs/sparseconvunet_waymo.yml diff --git a/ml3d/configs/pointtransformer_waymo.yml b/ml3d/configs/pointtransformer_waymo.yml new file mode 100644 index 000000000..355a3b7d1 --- /dev/null +++ b/ml3d/configs/pointtransformer_waymo.yml @@ -0,0 +1,56 @@ +dataset: + name: WaymoSemSeg + dataset_path: # path/to/your/dataset + cache_dir: ./logs/cache + class_weights: [] + ignored_label_inds: [0] + test_result_folder: ./test + use_cache: false +model: + name: PointTransformer + batcher: ConcatBatcher + ckpt_path: # path/to/your/checkpoint + in_channels: 3 + blocks: [2, 3, 4, 6, 3] + num_classes: 23 + voxel_size: 0.04 + max_voxels: 50000 + ignored_label_inds: [-1] + augment: {} + # rotate: + # method: vertical + # scale: + # min_s: 0.95 + # max_s: 1.05 + # noise: + # noise_std: 0.005 + # ChromaticAutoContrast: + # randomize_blend_factor: True + # blend_factor: 0.2 + # ChromaticTranslation: + # trans_range_ratio: 0.05 + # ChromaticJitter: + # std: 0.01 + # HueSaturationTranslation: + # hue_max: 0.5 + # saturation_max: 0.2 +pipeline: + name: SemanticSegmentation + optimizer: + lr: 0.02 + momentum: 0.9 + weight_decay: 0.0001 + batch_size: 2 + learning_rate: 0.01 + main_log_dir: ./logs + max_epoch: 512 + save_ckpt_freq: 3 + scheduler_gamma: 0.99 + test_batch_size: 1 + train_sum_dir: train_log + val_batch_size: 2 + summary: + record_for: [] + max_pts: + use_reference: false + max_outputs: 1 diff --git a/ml3d/configs/sparseconvunet_waymo.yml b/ml3d/configs/sparseconvunet_waymo.yml new file mode 100644 index 000000000..3ca34d2c7 --- /dev/null +++ b/ml3d/configs/sparseconvunet_waymo.yml @@ -0,0 +1,57 @@ +dataset: + name: WaymoSemSeg + dataset_path: # path/to/your/dataset + cache_dir: ./logs/cache + class_weights: [] + ignored_label_inds: [0] + test_result_folder: ./test + use_cache: false +model: + name: SparseConvUnet + batcher: ConcatBatcher + ckpt_path: # path/to/your/checkpoint + multiplier: 32 + voxel_size: 0.02 + residual_blocks: True + conv_block_reps: 1 + in_channels: 3 + num_classes: 23 + grid_size: 4096 + ignored_label_inds: [0] + augment: {} + # rotate: + # method: vertical + # scale: + # min_s: 0.9 + # max_s: 1.1 + # noise: + # noise_std: 0.01 + # RandomDropout: + # dropout_ratio: 0.2 + # RandomHorizontalFlip: + # axes: [0, 1] + # ChromaticAutoContrast: + # randomize_blend_factor: True + # blend_factor: 0.5 + # ChromaticTranslation: + # trans_range_ratio: 0.1 + # ChromaticJitter: + # std: 0.05 +pipeline: + name: SemanticSegmentation + optimizer: + lr: 0.001 + betas: [0.9, 0.999] + batch_size: 2 + main_log_dir: ./logs + max_epoch: 256 + save_ckpt_freq: 3 + scheduler_gamma: 0.99 + test_batch_size: 1 + train_sum_dir: train_log + val_batch_size: 2 + summary: + record_for: [] + max_pts: + use_reference: false + max_outputs: 1 diff --git a/ml3d/torch/pipelines/semantic_segmentation.py b/ml3d/torch/pipelines/semantic_segmentation.py index a7a457558..eee795a99 100644 --- a/ml3d/torch/pipelines/semantic_segmentation.py +++ b/ml3d/torch/pipelines/semantic_segmentation.py @@ -381,7 +381,7 @@ def run_train(self): batch_size=cfg.val_batch_size, sampler=valid_sampler if self.distributed else get_sampler(valid_sampler), - num_workers=cfg.get('num_workers', 2), + num_workers=cfg.get('num_workers', 0), pin_memory=cfg.get('pin_memory', True), collate_fn=self.batcher.collate_fn, worker_init_fn=lambda x: np.random.seed(x + np.uint32( From d7180805184f855bb2851135824416921b4b0366 Mon Sep 17 00:00:00 2001 From: Sanskar Agrawal Date: Mon, 18 Apr 2022 03:45:09 -0700 Subject: [PATCH 29/50] fix waymo semseg --- ml3d/datasets/waymo_semseg.py | 1 + scripts/preprocess_waymo_semseg.py | 7 +++++-- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/ml3d/datasets/waymo_semseg.py b/ml3d/datasets/waymo_semseg.py index 94c079767..00fa29345 100644 --- a/ml3d/datasets/waymo_semseg.py +++ b/ml3d/datasets/waymo_semseg.py @@ -55,6 +55,7 @@ def __init__(self, glob(join(cfg.dataset_path, 'velodyne', '*.bin'))) self.train_files = [] self.val_files = [] + self.test_files = [] for f in self.all_files: if 'train' in f: diff --git a/scripts/preprocess_waymo_semseg.py b/scripts/preprocess_waymo_semseg.py index 464ec91f3..002de1ce3 100644 --- a/scripts/preprocess_waymo_semseg.py +++ b/scripts/preprocess_waymo_semseg.py @@ -20,6 +20,7 @@ from os import makedirs from multiprocessing import Pool from tqdm import tqdm +from tqdm.contrib.concurrent import process_map # or thread_map gpus = tf.config.experimental.list_physical_devices('GPU') if gpus: @@ -152,8 +153,10 @@ def convert(self): print(f"Start converting {len(self)} files ...") # for i in tqdm(range(len(self))): # self.process_one(i) - with Pool(self.workers) as p: - p.map(self.process_one, [i for i in range(len(self))]) + # with Pool(self.workers) as p: + # tqdm(p.imap(self.process_one, [i for i in range(len(self))]), total=len(self)) + # p.map(self.process_one, [i for i in range(len(self))]) + process_map(self.process_one, range(len(self)), max_workers=self.workers) def process_one(self, file_idx): print(f"Converting : {file_idx}") From 8e8d4f896a20fcbaa8f0c9c79c417594bc1f20b2 Mon Sep 17 00:00:00 2001 From: Sanskar Agrawal Date: Mon, 18 Apr 2022 03:45:29 -0700 Subject: [PATCH 30/50] fix optim --- ml3d/torch/pipelines/base_pipeline.py | 8 +++---- ml3d/torch/pipelines/semantic_segmentation.py | 21 ++++++++++++------- requirements-torch-cuda.txt | 8 +++---- 3 files changed, 22 insertions(+), 15 deletions(-) diff --git a/ml3d/torch/pipelines/base_pipeline.py b/ml3d/torch/pipelines/base_pipeline.py index f466868b9..58196c891 100644 --- a/ml3d/torch/pipelines/base_pipeline.py +++ b/ml3d/torch/pipelines/base_pipeline.py @@ -41,10 +41,10 @@ def __init__(self, self.rng = np.random.default_rng(kwargs.get('seed', None)) self.distributed = distributed - if self.distributed and self.name == "SemanticSegmentation": - raise NotImplementedError( - "Distributed training not implemented for SemanticSegmentation!" - ) + # if self.distributed and self.name == "SemanticSegmentation": + # raise NotImplementedError( + # "Distributed training not implemented for SemanticSegmentation!" + # ) self.rank = kwargs.get('rank', 0) diff --git a/ml3d/torch/pipelines/semantic_segmentation.py b/ml3d/torch/pipelines/semantic_segmentation.py index eee795a99..d3e1ec813 100644 --- a/ml3d/torch/pipelines/semantic_segmentation.py +++ b/ml3d/torch/pipelines/semantic_segmentation.py @@ -1,11 +1,12 @@ import logging +import numpy as np +import torch +import torch.distributed as dist + from os.path import exists, join from pathlib import Path from datetime import datetime - -import numpy as np from tqdm import tqdm -import torch from torch.utils.tensorboard import SummaryWriter from torch.utils.data import DataLoader @@ -361,7 +362,7 @@ def run_train(self): "Distributed training with sampler is not supported yet!") train_sampler = torch.utils.data.distributed.DistributedSampler( train_split) - valid_sampler = torch.utild.data.distributed.DistributedSampler( + valid_sampler = torch.utils.data.distributed.DistributedSampler( valid_split) train_loader = DataLoader( @@ -387,6 +388,10 @@ def run_train(self): worker_init_fn=lambda x: np.random.seed(x + np.uint32( torch.utils.data.get_worker_info().seed))) + # Optimizer must be created after moving model to specific device. + model.cuda(self.device) + model.device = self.device + self.optimizer, self.scheduler = model.get_optimizer(cfg) is_resume = model.cfg.get('is_resume', True) @@ -407,8 +412,6 @@ def run_train(self): # wrap model for multiple GPU if self.distributed: - model.cuda(self.device) - model.device = self.device model = torch.nn.parallel.DistributedDataParallel( model, device_ids=[self.device]) model.get_loss = model.module.get_loss @@ -702,7 +705,10 @@ def load_ckpt(self, ckpt_path=None, is_resume=True): want to resume. """ train_ckpt_dir = join(self.cfg.logs_dir, 'checkpoint') - make_dir(train_ckpt_dir) + if self.rank == 0: + make_dir(train_ckpt_dir) + if self.distributed: + dist.barrier() if ckpt_path is None: ckpt_path = latest_torch_ckpt(train_ckpt_dir) @@ -717,6 +723,7 @@ def load_ckpt(self, ckpt_path=None, is_resume=True): log.info(f'Loading checkpoint {ckpt_path}') ckpt = torch.load(ckpt_path, map_location=self.device) + self.model.load_state_dict(ckpt['model_state_dict']) if 'optimizer_state_dict' in ckpt and hasattr(self, 'optimizer'): log.info(f'Loading checkpoint optimizer_state_dict') diff --git a/requirements-torch-cuda.txt b/requirements-torch-cuda.txt index 99ed63688..2ebb50433 100644 --- a/requirements-torch-cuda.txt +++ b/requirements-torch-cuda.txt @@ -1,7 +1,7 @@ -https://s3.us-west-1.wasabisys.com/open3d-downloads/torch-1.8.2-cp36-cp36m-linux_x86_64.whl ; python_version == '3.6' -https://s3.us-west-1.wasabisys.com/open3d-downloads/torch-1.8.2-cp37-cp37m-linux_x86_64.whl ; python_version == '3.7' -https://s3.us-west-1.wasabisys.com/open3d-downloads/torch-1.8.2-cp38-cp38-linux_x86_64.whl ; python_version == '3.8' -https://s3.us-west-1.wasabisys.com/open3d-downloads/torch-1.8.2-cp39-cp39-linux_x86_64.whl ; python_version == '3.9' +https://open3d-downloads.b-cdn.net/torch-1.8.2-cp36-cp36m-linux_x86_64.whl ; python_version == '3.6' +https://open3d-downloads.b-cdn.net/torch-1.8.2-cp37-cp37m-linux_x86_64.whl ; python_version == '3.7' +https://open3d-downloads.b-cdn.net/torch-1.8.2-cp38-cp38-linux_x86_64.whl ; python_version == '3.8' +https://open3d-downloads.b-cdn.net/torch-1.8.2-cp39-cp39-linux_x86_64.whl ; python_version == '3.9' -f https://download.pytorch.org/whl/lts/1.8/torch_lts.html torchvision==0.9.2+cu111 tensorboard From 39978331ba87ba5b17b76dd05c923d886619c5eb Mon Sep 17 00:00:00 2001 From: Sanskar Agrawal Date: Mon, 18 Apr 2022 05:20:58 -0700 Subject: [PATCH 31/50] improve tqdm --- ml3d/torch/pipelines/semantic_segmentation.py | 30 +++++++++++++------ scripts/preprocess_waymo_semseg.py | 8 +++-- 2 files changed, 26 insertions(+), 12 deletions(-) diff --git a/ml3d/torch/pipelines/semantic_segmentation.py b/ml3d/torch/pipelines/semantic_segmentation.py index d3e1ec813..be00965d3 100644 --- a/ml3d/torch/pipelines/semantic_segmentation.py +++ b/ml3d/torch/pipelines/semantic_segmentation.py @@ -395,7 +395,7 @@ def run_train(self): self.optimizer, self.scheduler = model.get_optimizer(cfg) is_resume = model.cfg.get('is_resume', True) - self.load_ckpt(model.cfg.ckpt_path, is_resume=is_resume) + start_ep = self.load_ckpt(model.cfg.ckpt_path, is_resume=is_resume) dataset_name = dataset.name if dataset is not None else '' tensorboard_dir = join( @@ -423,7 +423,7 @@ def run_train(self): if rank == 0: log.info("Started training") - for epoch in range(0, cfg.max_epoch + 1): + for epoch in range(start_ep, cfg.max_epoch + 1): log.info(f'=== EPOCH {epoch:d}/{cfg.max_epoch:d} ===') if self.distributed: train_sampler.set_epoch(epoch) @@ -432,9 +432,10 @@ def run_train(self): self.metric_train.reset() self.metric_val.reset() self.losses = [] - # model.trans_point_sampler = train_sampler.get_point_sampler() + # model.trans_point_sampler = train_sampler.get_point_sampler() # TODO: fix this for model with samplers. - for step, inputs in enumerate(tqdm(train_loader, desc='training')): + progress_bar = tqdm(train_loader, desc='training') + for inputs in progress_bar: if hasattr(inputs['data'], 'to'): inputs['data'].to(device) self.optimizer.zero_grad() @@ -459,11 +460,18 @@ def run_train(self): self.metric_train.update(predict_scores, gt_labels) self.losses.append(loss.cpu().item()) + # Save only for the first pcd in batch - if 'train' in record_summary and step == 0: + if 'train' in record_summary and progress_bar.n == 0: self.summary['train'] = self.get_3d_summary( results, inputs['data'], epoch) + desc = "training - Epoch: %d, loss: %.3f" % (epoch, + loss.cpu().item()) + # if rank == 0: + progress_bar.set_description(desc) + progress_bar.refresh() + if self.distributed: dist.barrier() @@ -713,10 +721,10 @@ def load_ckpt(self, ckpt_path=None, is_resume=True): if ckpt_path is None: ckpt_path = latest_torch_ckpt(train_ckpt_dir) if ckpt_path is not None and is_resume: - log.info('ckpt_path not given. Restore from the latest ckpt') + log.info("ckpt_path not given. Restore from the latest ckpt") else: log.info('Initializing from scratch.') - return + return 0 if not exists(ckpt_path): raise FileNotFoundError(f' ckpt {ckpt_path} not found') @@ -726,12 +734,16 @@ def load_ckpt(self, ckpt_path=None, is_resume=True): self.model.load_state_dict(ckpt['model_state_dict']) if 'optimizer_state_dict' in ckpt and hasattr(self, 'optimizer'): - log.info(f'Loading checkpoint optimizer_state_dict') + log.info('Loading checkpoint optimizer_state_dict') self.optimizer.load_state_dict(ckpt['optimizer_state_dict']) if 'scheduler_state_dict' in ckpt and hasattr(self, 'scheduler'): - log.info(f'Loading checkpoint scheduler_state_dict') + log.info('Loading checkpoint scheduler_state_dict') self.scheduler.load_state_dict(ckpt['scheduler_state_dict']) + epoch = 0 if 'epoch' not in ckpt else ckpt['epoch'] + + return epoch + def save_ckpt(self, epoch): """Save a checkpoint at the passed epoch.""" path_ckpt = join(self.cfg.logs_dir, 'checkpoint') diff --git a/scripts/preprocess_waymo_semseg.py b/scripts/preprocess_waymo_semseg.py index 002de1ce3..70b7f92be 100644 --- a/scripts/preprocess_waymo_semseg.py +++ b/scripts/preprocess_waymo_semseg.py @@ -154,9 +154,11 @@ def convert(self): # for i in tqdm(range(len(self))): # self.process_one(i) # with Pool(self.workers) as p: - # tqdm(p.imap(self.process_one, [i for i in range(len(self))]), total=len(self)) - # p.map(self.process_one, [i for i in range(len(self))]) - process_map(self.process_one, range(len(self)), max_workers=self.workers) + # tqdm(p.imap(self.process_one, [i for i in range(len(self))]), total=len(self)) + # p.map(self.process_one, [i for i in range(len(self))]) + process_map(self.process_one, + range(len(self)), + max_workers=self.workers) def process_one(self, file_idx): print(f"Converting : {file_idx}") From 5a4d385f903ed75eb3246956875b0c105e3c87d3 Mon Sep 17 00:00:00 2001 From: Sanskar Agrawal Date: Mon, 18 Apr 2022 11:16:41 -0700 Subject: [PATCH 32/50] gather metric among process --- ml3d/datasets/waymo_semseg.py | 2 +- ml3d/torch/models/sparseconvnet.py | 2 ++ ml3d/torch/modules/metrics/semseg_metric.py | 11 +++++++++++ ml3d/torch/pipelines/semantic_segmentation.py | 9 ++++++++- 4 files changed, 22 insertions(+), 2 deletions(-) diff --git a/ml3d/datasets/waymo_semseg.py b/ml3d/datasets/waymo_semseg.py index 00fa29345..3c3725fe4 100644 --- a/ml3d/datasets/waymo_semseg.py +++ b/ml3d/datasets/waymo_semseg.py @@ -181,7 +181,7 @@ def get_data(self, idx): pc_path = self.path_list[idx] pc = self.dataset.read_lidar(pc_path) - feat = pc[:, 3:6] + feat = pc[:, 3:5] # intensity, elongation label = pc[:, 7].astype(np.int32) pc = pc[:, :3] diff --git a/ml3d/torch/models/sparseconvnet.py b/ml3d/torch/models/sparseconvnet.py index 86c555f33..8282bd1b6 100644 --- a/ml3d/torch/models/sparseconvnet.py +++ b/ml3d/torch/models/sparseconvnet.py @@ -115,6 +115,8 @@ def preprocess(self, data, attr): "SparseConvnet doesn't work without feature values.") feat = np.array(data['feat'], dtype=np.float32) + if feat.shape[1] < 3: + feat = np.concatenate([feat, np.ones([feat.shape[0], 1])], 1) # Scale to voxel size. points *= 1. / self.cfg.voxel_size # Scale = 1/voxel_size diff --git a/ml3d/torch/modules/metrics/semseg_metric.py b/ml3d/torch/modules/metrics/semseg_metric.py index 763f31fdf..20f578c22 100644 --- a/ml3d/torch/modules/metrics/semseg_metric.py +++ b/ml3d/torch/modules/metrics/semseg_metric.py @@ -23,6 +23,17 @@ def update(self, scores, labels): assert self.confusion_matrix.shape == conf.shape self.confusion_matrix += conf + def __iadd__(self, otherMetric): + if self.confusion_matrix is None and otherMetric.confusion_matrix is None: + pass + elif self.confusion_matrix is None: + self.confusion_matrix = otherMetric.confusion_matrix + elif otherMetric.confusion_matrix is None: + pass + else: + self.confusion_matrix += otherMetric.confusion_matrix + return self + def acc(self): """Compute the per-class accuracies and the overall accuracy. diff --git a/ml3d/torch/pipelines/semantic_segmentation.py b/ml3d/torch/pipelines/semantic_segmentation.py index be00965d3..723439b05 100644 --- a/ml3d/torch/pipelines/semantic_segmentation.py +++ b/ml3d/torch/pipelines/semantic_segmentation.py @@ -505,7 +505,14 @@ def run_train(self): results, inputs['data'], epoch) if self.distributed: - # TODO (sanskar): accumulate confusion matrix for all process. + metric_gather = [None for _ in range(dist.get_world_size())] + dist.gather_object(self.metric_val, + metric_gather if rank == 0 else None, + dst=0) + if rank == 0: + for m in metric_gather[1:]: + self.metric_val += m + dist.barrier() if rank == 0: From df76487805a77f34f6f50f242d10e742f3072eea Mon Sep 17 00:00:00 2001 From: Sanskar Agrawal Date: Wed, 27 Apr 2022 08:09:17 -0700 Subject: [PATCH 33/50] enable multi node training --- scripts/run_pipeline.py | 16 ++++++++++++---- 1 file changed, 12 insertions(+), 4 deletions(-) diff --git a/scripts/run_pipeline.py b/scripts/run_pipeline.py index 68564e1b9..64718fdaf 100644 --- a/scripts/run_pipeline.py +++ b/scripts/run_pipeline.py @@ -43,6 +43,11 @@ def parse_args(): parser.add_argument('--main_log_dir', help='the dir to save logs and models') parser.add_argument('--seed', help='random seed', default=0) + parser.add_argument('--nodes', help='number of nodes', default=1, type=int) + parser.add_argument('--node_rank', + help='ranking within the nodes, default: 0', + default=0, + type=int) parser.add_argument( '--host', help='Host for distributed training, default: localhost', @@ -197,9 +202,10 @@ def cleanup(): dist.destroy_process_group() -def main_worker(rank, Dataset, Model, Pipeline, cfg_dict_dataset, +def main_worker(local_rank, Dataset, Model, Pipeline, cfg_dict_dataset, cfg_dict_model, cfg_dict_pipeline, args): - world_size = len(args.device_ids) + rank = args.node_rank * len(args.device_ids) + local_rank + world_size = args.nodes * len(args.device_ids) setup(rank, world_size, args) cfg_dict_dataset['rank'] = rank @@ -211,8 +217,10 @@ def main_worker(rank, Dataset, Model, Pipeline, cfg_dict_dataset, cfg_dict_model['seed'] = rng cfg_dict_pipeline['seed'] = rng - device = f"cuda:{args.device_ids[rank]}" - print(f"rank = {rank}, world_size = {world_size}, gpu = {device}") + device = f"cuda:{args.device_ids[local_rank]}" + print( + f"local_rank = {local_rank}, rank = {rank}, world_size = {world_size}, gpu = {device}" + ) cfg_dict_model['device'] = device cfg_dict_pipeline['device'] = device From 6fc34b15541683589e1682bdd6fe4a0e4144eaf2 Mon Sep 17 00:00:00 2001 From: Sanskar Agrawal Date: Thu, 5 May 2022 01:30:13 +0530 Subject: [PATCH 34/50] add nuscenes semantic --- ml3d/datasets/nuscenes.py | 14 ++++++++++++++ scripts/preprocess_nuscenes.py | 10 ++++++++++ 2 files changed, 24 insertions(+) diff --git a/ml3d/datasets/nuscenes.py b/ml3d/datasets/nuscenes.py index 44362e22a..2ef9dbc96 100644 --- a/ml3d/datasets/nuscenes.py +++ b/ml3d/datasets/nuscenes.py @@ -108,6 +108,17 @@ def read_lidar(path): return np.fromfile(path, dtype=np.float32).reshape(-1, 5) + @staticmethod + def read_lidarseg(path): + """Reads semantic data from the path provided. + + Returns: + A data object with semantic information. + """ + assert Path(path).exists() + + return np.fromfile(path, dtype=np.uint8).reshape(-1,).astype(np.int32) + @staticmethod def read_label(info, calib): """Reads labels of bound boxes. @@ -256,6 +267,7 @@ def __len__(self): def get_data(self, idx): info = self.infos[idx] lidar_path = info['lidar_path'] + lidarseg_path = info['lidarseg_path'] world_cam = np.eye(4) world_cam[:3, :3] = R.from_quat(info['lidar2ego_rot']).as_matrix() @@ -264,12 +276,14 @@ def get_data(self, idx): pc = self.dataset.read_lidar(lidar_path) label = self.dataset.read_label(info, calib) + lidarseg = self.dataset.read_lidarseg(lidarseg_path) data = { 'point': pc, 'feat': None, 'calib': calib, 'bounding_boxes': label, + 'label': lidarseg } if 'cams' in info: diff --git a/scripts/preprocess_nuscenes.py b/scripts/preprocess_nuscenes.py index 582240416..ec1ba09e5 100644 --- a/scripts/preprocess_nuscenes.py +++ b/scripts/preprocess_nuscenes.py @@ -58,11 +58,17 @@ def __init__(self, dataset_path, out_path, version='v1.0-trainval'): assert version in ['v1.0-trainval', 'v1.0-test', 'v1.0-mini'] self.is_test = 'test' in version self.out_path = out_path + self.dataset_path = dataset_path self.nusc = NuScenes(version=version, dataroot=dataset_path, verbose=True) + ## Get semantic label stats + # nusc = self.nusc + # print(nusc.list_lidarseg_categories(sort_by='count')) + # print(nusc.lidarseg_idx2name_mapping) + if version == 'v1.0-trainval': train_scenes = splits.train val_scenes = splits.val @@ -213,8 +219,12 @@ def process_scenes(self): lidar_path = os.path.abspath(lidar_path) assert os.path.exists(lidar_path) + lidarseg_path = nusc.get('lidarseg', lidar_token)['filename'] + lidarseg_path = os.path.abspath(os.path.join(self.dataset_path, lidarseg_path)) + data = { 'lidar_path': lidar_path, + 'lidarseg_path': lidarseg_path, 'token': sample['token'], 'cams': dict(), 'lidar2ego_tr': calib_rec['translation'], From 80d0040b71ad4965f85709f40822c8c5997b4523 Mon Sep 17 00:00:00 2001 From: Sanskar Agrawal Date: Thu, 5 May 2022 02:13:45 +0530 Subject: [PATCH 35/50] fix lidarseg --- scripts/preprocess_nuscenes.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/scripts/preprocess_nuscenes.py b/scripts/preprocess_nuscenes.py index ec1ba09e5..f2bbd7bf3 100644 --- a/scripts/preprocess_nuscenes.py +++ b/scripts/preprocess_nuscenes.py @@ -219,12 +219,8 @@ def process_scenes(self): lidar_path = os.path.abspath(lidar_path) assert os.path.exists(lidar_path) - lidarseg_path = nusc.get('lidarseg', lidar_token)['filename'] - lidarseg_path = os.path.abspath(os.path.join(self.dataset_path, lidarseg_path)) - data = { 'lidar_path': lidar_path, - 'lidarseg_path': lidarseg_path, 'token': sample['token'], 'cams': dict(), 'lidar2ego_tr': calib_rec['translation'], @@ -259,6 +255,11 @@ def process_scenes(self): data['cams'].update({cam: cam_info}) if not self.is_test: + lidarseg_path = nusc.get('lidarseg', lidar_token)['filename'] + lidarseg_path = os.path.abspath(os.path.join(self.dataset_path, lidarseg_path)) + assert os.path.exists(lidarseg_path) + data['lidarseg_path'] = lidarseg_path + annotations = [ nusc.get('sample_annotation', token) for token in sample['anns'] From 670ef483d091968c5fd8811fb6f7363db5f5e967 Mon Sep 17 00:00:00 2001 From: Sanskar Agrawal Date: Tue, 10 May 2022 21:11:03 +0530 Subject: [PATCH 36/50] add megaloader --- ml3d/datasets/__init__.py | 4 +- ml3d/datasets/megaloader.py | 160 +++++++++++++++++++++++++++++++++ scripts/preprocess_nuscenes.py | 3 +- 3 files changed, 165 insertions(+), 2 deletions(-) create mode 100644 ml3d/datasets/megaloader.py diff --git a/ml3d/datasets/__init__.py b/ml3d/datasets/__init__.py index 05b4b7fe7..befae57ee 100644 --- a/ml3d/datasets/__init__.py +++ b/ml3d/datasets/__init__.py @@ -22,11 +22,13 @@ from .scannet import Scannet from .sunrgbd import SunRGBD from .matterport_objects import MatterportObjects +from .kitti360 import KITTI360 +from .megaloader import MegaLoader __all__ = [ 'SemanticKITTI', 'S3DIS', 'Toronto3D', 'ParisLille3D', 'Semantic3D', 'Custom3D', 'utils', 'augment', 'samplers', 'KITTI', 'Waymo', 'NuScenes', 'Lyft', 'ShapeNet', 'SemSegRandomSampler', 'InferenceDummySplit', 'SemSegSpatiallyRegularSampler', 'Argoverse', 'Scannet', 'SunRGBD', - 'MatterportObjects', 'WaymoSemSeg' + 'MatterportObjects', 'WaymoSemSeg', 'KITTI360', 'MegaLoader' ] diff --git a/ml3d/datasets/megaloader.py b/ml3d/datasets/megaloader.py new file mode 100644 index 000000000..5dac617bd --- /dev/null +++ b/ml3d/datasets/megaloader.py @@ -0,0 +1,160 @@ +import numpy as np +import pandas as pd +import os, glob, pickle +from pathlib import Path +from os.path import join, exists, dirname, abspath, isdir +from sklearn.neighbors import KDTree +from tqdm import tqdm +import logging + +from .utils import DataProcessing, get_min_bbox, BEVBox3D +from .base_dataset import BaseDataset, BaseDatasetSplit +from ..utils import make_dir, DATASET, Config, get_module + +log = logging.getLogger(__name__) + + +class MegaLoader(): + """This class is used to create a combination of multiple datasets, + and sample data among them uniformly. + """ + + def __init__(self, + config_paths, + name='MegaLoader', + cache_dir='./logs/cache', + use_cache=False, + ignored_label_inds=[], + test_result_folder='./test', + **kwargs): + """Initialize the function by passing the dataset and other details. + + Args: + config_paths: List of dataset config files to use. + dataset_path: The path to the dataset to use (parent directory of data_3d_semantics). + name: The name of the dataset (MegaLoader in this case). + cache_dir: The directory where the cache is stored. + use_cache: Indicates if the dataset should be cached. + ignored_label_inds: A list of labels that should be ignored in the dataset. + test_result_folder: The folder where the test results should be stored. + """ + + self.cfg = Config(kwargs) + self.name = self.cfg.name + self.rng = np.random.default_rng(kwargs.get('seed', None)) + self.ignored_labels = np.array([]) + + self.num_datasets = len(config_paths) + self.configs = [ + Config.load_from_file(cfg_path) for cfg_path in config_paths + ] + self.datasets = [get_module('dataset', cfg.name) for cfg in configs] + + def get_split(self, split): + """Returns a dataset split. + + Args: + split: A string identifying the dataset split that is usually one of + 'training', 'test', 'validation', or 'all'. + + Returns: + A dataset split object providing the requested subset of the data. + """ + return MegaLoaderSplit(self, split=split) + + def get_split_list(self, split): + if split in ['train', 'training']: + return self.train_files + elif split in ['val', 'validation']: + return self.val_files + elif split in ['test', 'testing']: + return test_files + elif split == 'all': + return self.train_files + self.val_files + self.test_files + else: + raise ValueError("Invalid split {}".format(split)) + + def is_tested(self, attr): + cfg = self.cfg + name = attr['name'] + path = cfg.test_result_folder + store_path = join(path, self.name, name + '.npy') + if exists(store_path): + print("{} already exists.".format(store_path)) + return True + else: + return False + + def save_test_result(self, results, attr): + """Saves the output of a model. + + Args: + results: The output of a model for the datum associated with the attribute passed. + attr: The attributes that correspond to the outputs passed in results. + """ + cfg = self.cfg + name = attr['name'].split('.')[0] + path = cfg.test_result_folder + make_dir(path) + + pred = results['predict_labels'] + pred = np.array(pred) + + for ign in cfg.ignored_label_inds: + pred[pred >= ign] += 1 + + store_path = join(path, self.name, name + '.npy') + make_dir(Path(store_path).parent) + np.save(store_path, pred) + log.info("Saved {} in {}.".format(name, store_path)) + + +class MegaLoaderSplit(): + """This class is used to create a split for MegaLoader dataset. + + Initialize the class. + + Args: + dataset: The dataset to split. + split: A string identifying the dataset split that is usually one of + 'training', 'test', 'validation', or 'all'. + **kwargs: The configuration of the model as keyword arguments. + + Returns: + A dataset split object providing the requested subset of the data. + """ + + def __init__(self, dataset, split='training'): + self.cfg = dataset.cfg + self.split = split + self.dataset = dataset + self.dataset_splits = [ + a.get_split(split) for a in self.dataset.datasets + ] + self.num_datasets = dataset.num_datasets + + log.info("Found {} pointclouds for {}".format(len(self.path_list), + split)) + + def __len__(self): + lens = [len(a) for a in self.dataset_splits] + return max(lens) * self.num_datasets + + def get_data(self, idx): + dataset_idx = idx % self.num_datasets + idx = (idx // self.num_datasets) % len(self.dataset_splits[dataset_idx]) + + data = self.dataset_splits[dataset_idx].get_data(idx) + + return data + + def get_attr(self, idx): + dataset_idx = idx % self.num_datasets + idx = (idx // self.num_datasets) % len(self.dataset_splits[dataset_idx]) + attr = self.dataset_splits[dataset_idx].get_attr(idx) + attr['dataset_idx'] = dataset_idx + + return attr + + +DATASET._register_module(MegaLoader) diff --git a/scripts/preprocess_nuscenes.py b/scripts/preprocess_nuscenes.py index f2bbd7bf3..2e1bce21a 100644 --- a/scripts/preprocess_nuscenes.py +++ b/scripts/preprocess_nuscenes.py @@ -256,7 +256,8 @@ def process_scenes(self): if not self.is_test: lidarseg_path = nusc.get('lidarseg', lidar_token)['filename'] - lidarseg_path = os.path.abspath(os.path.join(self.dataset_path, lidarseg_path)) + lidarseg_path = os.path.abspath( + os.path.join(self.dataset_path, lidarseg_path)) assert os.path.exists(lidarseg_path) data['lidarseg_path'] = lidarseg_path From 1fdf0d17cdda75c9dcd9f0ff99b334f6b735f9c8 Mon Sep 17 00:00:00 2001 From: Sanskar Agrawal Date: Wed, 11 May 2022 18:01:14 +0530 Subject: [PATCH 37/50] add megamodel --- ml3d/torch/dataloaders/concat_batcher.py | 32 + ml3d/torch/models/sparseconvnet_megamodel.py | 690 +++++++++++++++++++ 2 files changed, 722 insertions(+) create mode 100644 ml3d/torch/models/sparseconvnet_megamodel.py diff --git a/ml3d/torch/dataloaders/concat_batcher.py b/ml3d/torch/dataloaders/concat_batcher.py index a10684663..295e056c1 100644 --- a/ml3d/torch/dataloaders/concat_batcher.py +++ b/ml3d/torch/dataloaders/concat_batcher.py @@ -452,6 +452,38 @@ def scatter(batch, num_gpu): return [b for b in batches if len(b.point)] # filter empty batch +class SparseConvUnetMegaLoaderBatch: + + def __init__(self, batches): + pc = [] + feat = [] + label = [] + lengths = [] + + for batch in batches: + data = batch['data'] + pc.append(data['point']) + feat.append(data['feat']) + label.append(data['label']) + lengths.append(data['point'].shape[0]) + + self.point = pc + self.feat = feat + self.label = label + self.batch_lengths = lengths + + def pin_memory(self): + self.point = [pc.pin_memory() for pc in self.point] + self.feat = [feat.pin_memory() for feat in self.feat] + self.label = [label.pin_memory() for label in self.label] + return self + + def to(self, device): + self.point = [pc.to(device) for pc in self.point] + self.feat = [feat.to(device) for feat in self.feat] + self.label = [label.to(device) for label in self.label] + + class PointTransformerBatch: def __init__(self, batches): diff --git a/ml3d/torch/models/sparseconvnet_megamodel.py b/ml3d/torch/models/sparseconvnet_megamodel.py new file mode 100644 index 000000000..4332cf3ea --- /dev/null +++ b/ml3d/torch/models/sparseconvnet_megamodel.py @@ -0,0 +1,690 @@ +import numpy as np +import torch +import torch.nn as nn + +from .base_model import BaseModel +from ...utils import MODEL +from ..modules.losses import filter_valid_label +from ...datasets.augment import SemsegAugmentation +from open3d.ml.torch.layers import SparseConv, SparseConvTranspose +from open3d.ml.torch.ops import voxelize, reduce_subarrays_sum + + +class SparseConvUnetMegaModel(BaseModel): + """Semantic Segmentation model. + + Uses UNet architecture replacing convolutions with Sparse Convolutions. + + Attributes: + name: Name of model. + Default to "SparseConvUnet". + device: Which device to use (cpu or cuda). + voxel_size: Voxel length for subsampling. + multiplier: min length of feature length in each layer. + conv_block_reps: repetition of Unet Blocks. + residual_blocks: Whether to use Residual Blocks. + in_channels: Number of features(default 3 for color). + num_classes: Number of classes. + """ + + def __init__( + self, + name="SparseConvUnet", + device="cuda", + num_heads=1, # number of segmentation heads. + multiplier=16, # Proportional to number of neurons in each layer. + voxel_size=0.05, + conv_block_reps=1, # Conv block repetitions. + residual_blocks=False, + in_channels=3, + num_classes=[20], + grid_size=4096, + batcher='ConcatBatcher', + augment=None, + **kwargs): + super(SparseConvUnet, self).__init__(name=name, + device=device, + num_heads=num_heads, + multiplier=multiplier, + voxel_size=voxel_size, + conv_block_reps=conv_block_reps, + residual_blocks=residual_blocks, + in_channels=in_channels, + num_classes=num_classes, + grid_size=grid_size, + batcher=batcher, + augment=augment, + **kwargs) + cfg = self.cfg + self.device = device + self.augmenter = SemsegAugmentation(cfg.augment, seed=self.rng) + self.multiplier = cfg.multiplier + self.input_layer = InputLayer() + self.sub_sparse_conv = SubmanifoldSparseConv(in_channels=in_channels, + filters=multiplier, + kernel_size=[3, 3, 3]) + self.unet = UNet(conv_block_reps, [ + multiplier, 2 * multiplier, 3 * multiplier, 4 * multiplier, + 5 * multiplier, 6 * multiplier, 7 * multiplier + ], residual_blocks) + self.batch_norm = BatchNormBlock(multiplier) + self.relu = ReLUBlock() + + if len(num_classes) != num_heads: + raise ValueError("Pass num_classes for each segmentation head.") + + self.linear = [] + for i in range(num_heads): + self.linear.append( + nn.Sequential(LinearBlock(multiplier, 2 * multiplier), + LinearBlock(2 * multiplier, num_classes[i]))) + + self.output_layer = OutputLayer() + + def forward(self, inputs): + pos_list = [] + feat_list = [] + index_map_list = [] + + for i in range(len(inputs.batch_lengths)): + pos = inputs.point[i] + feat = inputs.feat[i] + feat, pos, index_map = self.input_layer(feat, pos) + pos_list.append(pos) + feat_list.append(feat) + index_map_list.append(index_map) + + feat_list = self.sub_sparse_conv(feat_list, pos_list, voxel_size=1.0) + feat_list = self.unet(pos_list, feat_list) + feat_list = self.batch_norm(feat_list) + feat_list = self.relu(feat_list) + feat_list = self.linear(feat_list) + output = self.output_layer(feat_list, index_map_list) + + return output + + def preprocess(self, data, attr): + # If num_workers > 0, use new RNG with unique seed for each thread. + # Else, use default RNG. + if torch.utils.data.get_worker_info(): + seedseq = np.random.SeedSequence( + torch.utils.data.get_worker_info().seed + + torch.utils.data.get_worker_info().id) + rng = np.random.default_rng(seedseq.spawn(1)[0]) + else: + rng = self.rng + + points = np.array(data['point'], dtype=np.float32) + + if 'label' not in data or data['label'] is None: + labels = np.zeros((points.shape[0],), dtype=np.int32) + else: + labels = np.array(data['label'], dtype=np.int32).reshape((-1,)) + + if 'feat' not in data or data['feat'] is None: + raise Exception( + "SparseConvnet doesn't work without feature values.") + + feat = np.array(data['feat'], dtype=np.float32) + if feat.shape[1] < 3: + feat = np.concatenate([feat, np.ones([feat.shape[0], 1])], 1) + + # Scale to voxel size. + points *= 1. / self.cfg.voxel_size # Scale = 1/voxel_size + + if attr['split'] in ['training', 'train']: + points, feat, labels = self.augmenter.augment(points, + feat, + labels, + self.cfg.get( + 'augment', None), + seed=rng) + m = points.min(0) + M = points.max(0) + + # Randomly place pointcloud in 4096 size grid. + grid_size = self.cfg.grid_size + offset = -m + np.clip(grid_size - M + m - 0.001, 0, None) * rng.random( + 3) + np.clip(grid_size - M + m + 0.001, None, 0) * rng.random(3) + + points += offset + idxs = (points.min(1) >= 0) * (points.max(1) < 4096) + + points = points[idxs] + feat = feat[idxs] + labels = labels[idxs] + + points = (points.astype(np.int32) + 0.5).astype( + np.float32) # Move points to voxel center. + + data = {} + data['point'] = points + data['feat'] = feat + data['label'] = labels + + return data + + def transform(self, data, attr): + data['point'] = torch.from_numpy(data['point']) + data['feat'] = torch.from_numpy(data['feat']) + data['label'] = torch.from_numpy(data['label']) + + return data + + def update_probs(self, inputs, results, test_probs, test_labels): + result = results.reshape(-1, self.cfg.num_classes) + probs = torch.nn.functional.softmax(result, dim=-1).cpu().data.numpy() + labels = np.argmax(probs, 1) + + self.trans_point_sampler(patchwise=False) + + return probs, labels + + def inference_begin(self, data): + data = self.preprocess(data, {'split': 'test'}) + data['batch_lengths'] = [data['point'].shape[0]] + data = self.transform(data, {}) + + self.inference_input = data + + def inference_preprocess(self): + return self.inference_input + + def inference_end(self, inputs, results): + results = torch.reshape(results, (-1, self.cfg.num_classes)) + + m_softmax = torch.nn.Softmax(dim=-1) + results = m_softmax(results) + results = results.cpu().data.numpy() + + probs = np.reshape(results, [-1, self.cfg.num_classes]) + + pred_l = np.argmax(probs, 1) + + return {'predict_labels': pred_l, 'predict_scores': probs} + + def get_loss(self, Loss, results, inputs, device): + """Calculate the loss on output of the model. + + Attributes: + Loss: Object of type `SemSegLoss`. + results: Output of the model. + inputs: Input of the model. + device: device(cpu or cuda). + + Returns: + Returns loss, labels and scores. + """ + cfg = self.cfg + labels = torch.cat(inputs['data'].label, 0) + + scores, labels = filter_valid_label(results, labels, cfg.num_classes, + cfg.ignored_label_inds, device) + + loss = Loss.weighted_CrossEntropyLoss(scores, labels) + + return loss, labels, scores + + def get_optimizer(self, cfg_pipeline): + optimizer = torch.optim.Adam(self.parameters(), + **cfg_pipeline.optimizer) + scheduler = torch.optim.lr_scheduler.ExponentialLR( + optimizer, cfg_pipeline.scheduler_gamma) + + return optimizer, scheduler + + +MODEL._register_module(SparseConvUnet, 'torch') + + +class BatchNormBlock(nn.Module): + + def __init__(self, m, eps=1e-4, momentum=0.01): + super(BatchNormBlock, self).__init__() + self.bn = nn.BatchNorm1d(m, eps=eps, momentum=momentum) + + def forward(self, feat_list): + lengths = [feat.shape[0] for feat in feat_list] + out = self.bn(torch.cat(feat_list, 0)) + out_list = [] + start = 0 + for l in lengths: + out_list.append(out[start:start + l]) + start += l + + return out_list + + def __name__(self): + return "BatchNormBlock" + + +class ReLUBlock(nn.Module): + + def __init__(self): + super(ReLUBlock, self).__init__() + self.relu = nn.ReLU() + + def forward(self, feat_list): + lengths = [feat.shape[0] for feat in feat_list] + out = self.relu(torch.cat(feat_list, 0)) + out_list = [] + start = 0 + for l in lengths: + out_list.append(out[start:start + l]) + start += l + + return out_list + + def __name__(self): + return "ReLUBlock" + + +class LinearBlock(nn.Module): + + def __init__(self, a, b): + super(LinearBlock, self).__init__() + self.linear = nn.Linear(a, b) + + def forward(self, feat_list): + out_list = [] + for feat in feat_list: + out_list.append(self.linear(feat)) + + return out_list + + def __name__(self): + return "LinearBlock" + + +class InputLayer(nn.Module): + + def __init__(self, voxel_size=1.0): + super(InputLayer, self).__init__() + self.voxel_size = torch.Tensor([voxel_size, voxel_size, voxel_size]) + + def forward(self, features, in_positions): + v = voxelize( + in_positions, + torch.LongTensor([0, + in_positions.shape[0]]).to(in_positions.device), + self.voxel_size, torch.Tensor([0, 0, 0]), + torch.Tensor([40960, 40960, 40960])) + + # Contiguous repeating positions. + in_positions = in_positions[v.voxel_point_indices] + features = features[v.voxel_point_indices] + + # Find reverse mapping. + reverse_map_voxelize = np.zeros((in_positions.shape[0],)) + reverse_map_voxelize[v.voxel_point_indices.cpu().numpy()] = np.arange( + in_positions.shape[0]) + reverse_map_voxelize = reverse_map_voxelize.astype(np.int32) + + # Unique positions. + in_positions = in_positions[v.voxel_point_row_splits[:-1]] + + # Mean of features. + count = v.voxel_point_row_splits[1:] - v.voxel_point_row_splits[:-1] + reverse_map_sort = np.repeat(np.arange(count.shape[0]), + count.cpu().numpy()).astype(np.int32) + + features_avg = in_positions.clone() + features_avg[:, 0] = reduce_subarrays_sum(features[:, 0], + v.voxel_point_row_splits) + features_avg[:, 1] = reduce_subarrays_sum(features[:, 1], + v.voxel_point_row_splits) + features_avg[:, 2] = reduce_subarrays_sum(features[:, 2], + v.voxel_point_row_splits) + + features_avg = features_avg / count.unsqueeze(1) + + return features_avg, in_positions, reverse_map_sort[ + reverse_map_voxelize] + + +class OutputLayer(nn.Module): + + def __init__(self, voxel_size=1.0): + super(OutputLayer, self).__init__() + + def forward(self, features_list, index_map_list): + out = [] + for feat, index_map in zip(features_list, index_map_list): + out.append(feat[index_map]) + return torch.cat(out, 0) + + +class SubmanifoldSparseConv(nn.Module): + + def __init__(self, + in_channels, + filters, + kernel_size, + use_bias=False, + offset=None, + normalize=False): + super(SubmanifoldSparseConv, self).__init__() + + if offset is None: + if kernel_size[0] % 2: + offset = 0. + else: + offset = 0.5 + + offset = torch.full((3,), offset, dtype=torch.float32) + self.net = SparseConv(in_channels=in_channels, + filters=filters, + kernel_size=kernel_size, + use_bias=use_bias, + offset=offset, + normalize=normalize) + + def forward(self, + features_list, + in_positions_list, + out_positions_list=None, + voxel_size=1.0): + if out_positions_list is None: + out_positions_list = in_positions_list + + out_feat = [] + for feat, in_pos, out_pos in zip(features_list, in_positions_list, + out_positions_list): + out_feat.append(self.net(feat, in_pos, out_pos, voxel_size)) + + return out_feat + + def __name__(self): + return "SubmanifoldSparseConv" + + +def calculate_grid(in_positions): + filter = torch.Tensor([[-1, -1, -1], [-1, -1, 0], [-1, 0, -1], [-1, 0, 0], + [0, -1, -1], [0, -1, 0], [0, 0, -1], + [0, 0, 0]]).to(in_positions.device) + + out_pos = in_positions.long().repeat(1, filter.shape[0]).reshape(-1, 3) + filter = filter.repeat(in_positions.shape[0], 1) + + out_pos = out_pos + filter + out_pos = out_pos[out_pos.min(1).values >= 0] + out_pos = out_pos[(~((out_pos.long() % 2).bool()).any(1))] + out_pos = torch.unique(out_pos, dim=0) + + return out_pos + 0.5 + + +class Convolution(nn.Module): + + def __init__(self, + in_channels, + filters, + kernel_size, + use_bias=False, + offset=None, + normalize=False): + super(Convolution, self).__init__() + + if offset is None: + if kernel_size[0] % 2: + offset = 0. + else: + offset = -0.5 + + offset = torch.full((3,), offset, dtype=torch.float32) + self.net = SparseConv(in_channels=in_channels, + filters=filters, + kernel_size=kernel_size, + use_bias=use_bias, + offset=offset, + normalize=normalize) + + def forward(self, features_list, in_positions_list, voxel_size=1.0): + out_positions_list = [] + for in_positions in in_positions_list: + out_positions_list.append(calculate_grid(in_positions)) + + out_feat = [] + for feat, in_pos, out_pos in zip(features_list, in_positions_list, + out_positions_list): + out_feat.append(self.net(feat, in_pos, out_pos, voxel_size)) + + out_positions_list = [out / 2 for out in out_positions_list] + + return out_feat, out_positions_list + + def __name__(self): + return "Convolution" + + +class DeConvolution(nn.Module): + + def __init__(self, + in_channels, + filters, + kernel_size, + use_bias=False, + offset=None, + normalize=False): + super(DeConvolution, self).__init__() + + if offset is None: + if kernel_size[0] % 2: + offset = 0. + else: + offset = -0.5 + + offset = torch.full((3,), offset, dtype=torch.float32) + self.net = SparseConvTranspose(in_channels=in_channels, + filters=filters, + kernel_size=kernel_size, + use_bias=use_bias, + offset=offset, + normalize=normalize) + + def forward(self, + features_list, + in_positions_list, + out_positions_list, + voxel_size=1.0): + out_feat = [] + for feat, in_pos, out_pos in zip(features_list, in_positions_list, + out_positions_list): + out_feat.append(self.net(feat, in_pos, out_pos, voxel_size)) + + return out_feat + + def __name__(self): + return "DeConvolution" + + +class ConcatFeat(nn.Module): + + def __init__(self): + super(ConcatFeat, self).__init__() + + def __name__(self): + return "ConcatFeat" + + def forward(self, feat): + return feat + + +class JoinFeat(nn.Module): + + def __init__(self): + super(JoinFeat, self).__init__() + + def __name__(self): + return "JoinFeat" + + def forward(self, feat_cat, feat): + out = [] + for a, b in zip(feat_cat, feat): + out.append(torch.cat([a, b], -1)) + + return out + + +class NetworkInNetwork(nn.Module): + + def __init__(self, nIn, nOut, bias=False): + super(NetworkInNetwork, self).__init__() + if nIn == nOut: + self.linear = nn.Identity() + else: + self.linear = nn.Linear(nIn, nOut, bias=bias) + + def forward(self, inputs): + out = [] + for inp in inputs: + out.append(self.linear(inp)) + + return out + + +class ResidualBlock(nn.Module): + + def __init__(self, nIn, nOut): + super(ResidualBlock, self).__init__() + + self.lin = NetworkInNetwork(nIn, nOut) + + self.batch_norm1 = BatchNormBlock(nIn) + self.relu1 = ReLUBlock() + self.sub_sparse_conv1 = SubmanifoldSparseConv(in_channels=nIn, + filters=nOut, + kernel_size=[3, 3, 3]) + + self.batch_norm2 = BatchNormBlock(nOut) + self.relu2 = ReLUBlock() + self.sub_sparse_conv2 = SubmanifoldSparseConv(in_channels=nOut, + filters=nOut, + kernel_size=[3, 3, 3]) + + def forward(self, feat_list, pos_list): + out1 = self.lin(feat_list) + feat_list = self.batch_norm1(feat_list) + feat_list = self.relu1(feat_list) + feat_list = self.sub_sparse_conv1(feat_list, pos_list) + feat_list = self.batch_norm2(feat_list) + feat_list = self.relu2(feat_list) + out2 = self.sub_sparse_conv2(feat_list, pos_list) + + return [a + b for a, b in zip(out1, out2)] + + def __name__(self): + return "ResidualBlock" + + +class UNet(nn.Module): + + def __init__(self, + conv_block_reps, + nPlanes, + residual_blocks=False, + downsample=[2, 2], + leakiness=0): + super(UNet, self).__init__() + self.net = nn.ModuleList( + self.get_UNet(nPlanes, residual_blocks, conv_block_reps)) + self.residual_blocks = residual_blocks + + @staticmethod + def block(layers, a, b, residual_blocks): + if residual_blocks: + layers.append(ResidualBlock(a, b)) + + else: + layers.append(BatchNormBlock(a)) + layers.append(ReLUBlock()) + layers.append( + SubmanifoldSparseConv(in_channels=a, + filters=b, + kernel_size=[3, 3, 3])) + + @staticmethod + def get_UNet(nPlanes, residual_blocks, conv_block_reps): + layers = [] + for i in range(conv_block_reps): + UNet.block(layers, nPlanes[0], nPlanes[0], residual_blocks) + + if len(nPlanes) > 1: + layers.append(ConcatFeat()) + layers.append(BatchNormBlock(nPlanes[0])) + layers.append(ReLUBlock()) + layers.append( + Convolution(in_channels=nPlanes[0], + filters=nPlanes[1], + kernel_size=[2, 2, 2])) + layers = layers + UNet.get_UNet(nPlanes[1:], residual_blocks, + conv_block_reps) + layers.append(BatchNormBlock(nPlanes[1])) + layers.append(ReLUBlock()) + layers.append( + DeConvolution(in_channels=nPlanes[1], + filters=nPlanes[0], + kernel_size=[2, 2, 2])) + + layers.append(JoinFeat()) + + for i in range(conv_block_reps): + UNet.block(layers, nPlanes[0] * (2 if i == 0 else 1), + nPlanes[0], residual_blocks) + + return layers + + def forward(self, pos_list, feat_list): + conv_pos = [] + concat_feat = [] + for module in self.net: + if isinstance(module, BatchNormBlock): + feat_list = module(feat_list) + elif isinstance(module, ReLUBlock): + feat_list = module(feat_list) + + elif isinstance(module, ResidualBlock): + feat_list = module(feat_list, pos_list) + + elif isinstance(module, SubmanifoldSparseConv): + feat_list = module(feat_list, pos_list) + + elif isinstance(module, Convolution): + conv_pos.append([pos.clone() for pos in pos_list]) + feat_list, pos_list = module(feat_list, pos_list) + + elif isinstance(module, DeConvolution): + feat_list = module(feat_list, [2 * pos for pos in pos_list], + conv_pos[-1]) + pos_list = conv_pos.pop() + + elif isinstance(module, ConcatFeat): + concat_feat.append([feat.clone() for feat in module(feat_list)]) + + elif isinstance(module, JoinFeat): + feat_list = module(concat_feat.pop(), feat_list) + + else: + raise Exception("Unknown module {}".format(module)) + + return feat_list + + +def load_unet_wts(net, path): + wts = list(torch.load(path).values()) + state_dict = net.state_dict() + i = 0 + for key in state_dict: + if 'offset' in key or 'tracked' in key: + continue + if len(wts[i].shape) == 4: + shp = wts[i].shape + state_dict[key] = np.transpose( + wts[i].reshape(int(shp[0]**(1 / 3)), int(shp[0]**(1 / 3)), + int(shp[0]**(1 / 3)), shp[-2], shp[-1]), + (2, 1, 0, 3, 4)) + else: + state_dict[key] = wts[i] + i += 1 + + net.load_state_dict(state_dict) From b5c5296911c57d51ab6c89dfcdbb2a20f4e6cc87 Mon Sep 17 00:00:00 2001 From: Sanskar Agrawal Date: Fri, 13 May 2022 11:25:55 +0530 Subject: [PATCH 38/50] update linear layers --- ml3d/torch/dataloaders/concat_batcher.py | 13 +++++++++++- ml3d/torch/models/sparseconvnet_megamodel.py | 21 ++++++++++---------- 2 files changed, 23 insertions(+), 11 deletions(-) diff --git a/ml3d/torch/dataloaders/concat_batcher.py b/ml3d/torch/dataloaders/concat_batcher.py index 295e056c1..610543819 100644 --- a/ml3d/torch/dataloaders/concat_batcher.py +++ b/ml3d/torch/dataloaders/concat_batcher.py @@ -452,13 +452,14 @@ def scatter(batch, num_gpu): return [b for b in batches if len(b.point)] # filter empty batch -class SparseConvUnetMegaLoaderBatch: +class SparseConvUnetMegaModelBatch: def __init__(self, batches): pc = [] feat = [] label = [] lengths = [] + dataset_idx = [] for batch in batches: data = batch['data'] @@ -467,10 +468,17 @@ def __init__(self, batches): label.append(data['label']) lengths.append(data['point'].shape[0]) + attr = batch['attr'] + if 'dataset_idx' not in attr: + raise ValueError( + "dataset_idx is missing. Please use MegaLoader.") + dataset_idx.append(attr['dataset_idx']) + self.point = pc self.feat = feat self.label = label self.batch_lengths = lengths + self.dataset_idx = np.array(dataset_idx, dtype=np.int32) def pin_memory(self): self.point = [pc.pin_memory() for pc in self.point] @@ -618,6 +626,9 @@ def collate_fn(self, batches): elif self.model == "SparseConvUnet": return {'data': SparseConvUnetBatch(batches), 'attr': {}} + elif self.model == "SparseConvUnetMegaModel": + return {'data': SparseConvUnetMegaModelBatch(batches), 'attr': {}} + elif self.model == "PointTransformer": return {'data': PointTransformerBatch(batches), 'attr': {}} diff --git a/ml3d/torch/models/sparseconvnet_megamodel.py b/ml3d/torch/models/sparseconvnet_megamodel.py index 4332cf3ea..978f60487 100644 --- a/ml3d/torch/models/sparseconvnet_megamodel.py +++ b/ml3d/torch/models/sparseconvnet_megamodel.py @@ -73,12 +73,7 @@ def __init__( if len(num_classes) != num_heads: raise ValueError("Pass num_classes for each segmentation head.") - self.linear = [] - for i in range(num_heads): - self.linear.append( - nn.Sequential(LinearBlock(multiplier, 2 * multiplier), - LinearBlock(2 * multiplier, num_classes[i]))) - + self.linear = LinearBlock(multiplier, num_classes) self.output_layer = OutputLayer() def forward(self, inputs): @@ -281,14 +276,20 @@ def __name__(self): class LinearBlock(nn.Module): - def __init__(self, a, b): + def __init__(self, in_dim, num_classes): super(LinearBlock, self).__init__() - self.linear = nn.Linear(a, b) + + self.linear = [] + self.num_classes = num_classes + for i in range(len(num_classes)): + self.linear.append( + nn.Sequential(nn.Linear(in_dim, 2 * in_dim), + nn.Linear(2 * in_dim, num_classes[i]))) def forward(self, feat_list): out_list = [] - for feat in feat_list: - out_list.append(self.linear(feat)) + for i, feat in enumerate(feat_list): + out_list.append(self.linear[i](feat)) return out_list From e4c444135ca7421fc4621c616d2c75590dd17eb2 Mon Sep 17 00:00:00 2001 From: Sanskar Agrawal Date: Wed, 25 May 2022 19:56:30 +0530 Subject: [PATCH 39/50] add new pipeline --- ml3d/datasets/megaloader.py | 36 +- ml3d/torch/dataloaders/concat_batcher.py | 7 +- ml3d/torch/models/__init__.py | 3 +- ml3d/torch/models/sparseconvnet_megamodel.py | 78 +- ml3d/torch/modules/losses/__init__.py | 4 +- ml3d/torch/modules/losses/semseg_loss.py | 16 + ml3d/torch/pipelines/__init__.py | 5 +- .../semantic_segmentation_multi_head.py | 789 ++++++++++++++++++ 8 files changed, 892 insertions(+), 46 deletions(-) create mode 100644 ml3d/torch/pipelines/semantic_segmentation_multi_head.py diff --git a/ml3d/datasets/megaloader.py b/ml3d/datasets/megaloader.py index 5dac617bd..8ab8c2fc4 100644 --- a/ml3d/datasets/megaloader.py +++ b/ml3d/datasets/megaloader.py @@ -21,6 +21,7 @@ class MegaLoader(): def __init__(self, config_paths, + batch_size=1, name='MegaLoader', cache_dir='./logs/cache', use_cache=False, @@ -39,8 +40,17 @@ def __init__(self, test_result_folder: The folder where the test results should be stored. """ + kwargs['config_paths'] = config_paths + kwargs['name'] = name + kwargs['cache_dir'] = cache_dir + kwargs['use_cache'] = use_cache + kwargs['ignored_label_inds'] = ignored_label_inds + kwargs['test_result_folder'] = test_result_folder + kwargs['batch_size'] = batch_size + self.cfg = Config(kwargs) self.name = self.cfg.name + self.batch_size = batch_size self.rng = np.random.default_rng(kwargs.get('seed', None)) self.ignored_labels = np.array([]) @@ -48,7 +58,9 @@ def __init__(self, self.configs = [ Config.load_from_file(cfg_path) for cfg_path in config_paths ] - self.datasets = [get_module('dataset', cfg.name) for cfg in configs] + self.datasets = [ + get_module('dataset', cfg.name)(**cfg) for cfg in self.configs + ] def get_split(self, split): """Returns a dataset split. @@ -133,24 +145,34 @@ def __init__(self, dataset, split='training'): ] self.num_datasets = dataset.num_datasets - log.info("Found {} pointclouds for {}".format(len(self.path_list), - split)) + # log.info("Found {} pointclouds for {}".format(len(self.path_list), + # split)) def __len__(self): lens = [len(a) for a in self.dataset_splits] return max(lens) * self.num_datasets def get_data(self, idx): - dataset_idx = idx % self.num_datasets - idx = (idx // self.num_datasets) % len(self.dataset_splits[dataset_idx]) + dataset_idx = (idx // self.dataset.batch_size) % self.num_datasets + idx = ((((idx // self.dataset.batch_size) // self.num_datasets) * + self.dataset.batch_size) + idx % self.dataset.batch_size) % len( + self.dataset_splits[dataset_idx]) + + # dataset_idx = idx % self.num_datasets + # idx = (idx // self.num_datasets) % len(self.dataset_splits[dataset_idx]) data = self.dataset_splits[dataset_idx].get_data(idx) return data def get_attr(self, idx): - dataset_idx = idx % self.num_datasets - idx = (idx // self.num_datasets) % len(self.dataset_splits[dataset_idx]) + dataset_idx = (idx // self.dataset.batch_size) % self.num_datasets + idx = ((((idx // self.dataset.batch_size) // self.num_datasets) * + self.dataset.batch_size) + idx % self.dataset.batch_size) % len( + self.dataset_splits[dataset_idx]) + + # dataset_idx = idx % self.num_datasets + # idx = (idx // self.num_datasets) % len(self.dataset_splits[dataset_idx]) attr = self.dataset_splits[dataset_idx].get_attr(idx) attr['dataset_idx'] = dataset_idx diff --git a/ml3d/torch/dataloaders/concat_batcher.py b/ml3d/torch/dataloaders/concat_batcher.py index 610543819..f304d7c34 100644 --- a/ml3d/torch/dataloaders/concat_batcher.py +++ b/ml3d/torch/dataloaders/concat_batcher.py @@ -474,11 +474,16 @@ def __init__(self, batches): "dataset_idx is missing. Please use MegaLoader.") dataset_idx.append(attr['dataset_idx']) + if len(set(dataset_idx)) != 1: + raise ValueError( + "Multiple datasets in a single batch is not supported." + "Make sure to pass same batch size to dataset and pipeline") + self.point = pc self.feat = feat self.label = label self.batch_lengths = lengths - self.dataset_idx = np.array(dataset_idx, dtype=np.int32) + self.dataset_idx = dataset_idx[0] def pin_memory(self): self.point = [pc.pin_memory() for pc in self.point] diff --git a/ml3d/torch/models/__init__.py b/ml3d/torch/models/__init__.py index 817f0c095..31c0dab52 100644 --- a/ml3d/torch/models/__init__.py +++ b/ml3d/torch/models/__init__.py @@ -7,10 +7,11 @@ from .point_rcnn import PointRCNN from .point_transformer import PointTransformer from .pvcnn import PVCNN +from .sparseconvnet_megamodel import SparseConvUnetMegaModel __all__ = [ 'RandLANet', 'KPFCNN', 'PointPillars', 'PointRCNN', 'SparseConvUnet', - 'PointTransformer', 'PVCNN' + 'PointTransformer', 'PVCNN', 'SparseConvUnetMegaModel' ] try: diff --git a/ml3d/torch/models/sparseconvnet_megamodel.py b/ml3d/torch/models/sparseconvnet_megamodel.py index 978f60487..53b8f2ffe 100644 --- a/ml3d/torch/models/sparseconvnet_megamodel.py +++ b/ml3d/torch/models/sparseconvnet_megamodel.py @@ -29,11 +29,11 @@ class SparseConvUnetMegaModel(BaseModel): def __init__( self, - name="SparseConvUnet", + name="SparseConvUnetMegaModel", device="cuda", num_heads=1, # number of segmentation heads. - multiplier=16, # Proportional to number of neurons in each layer. - voxel_size=0.05, + multiplier=4, # Proportional to number of neurons in each layer. + voxel_size=0.1, conv_block_reps=1, # Conv block repetitions. residual_blocks=False, in_channels=3, @@ -41,24 +41,28 @@ def __init__( grid_size=4096, batcher='ConcatBatcher', augment=None, + ckpt_path=None, **kwargs): - super(SparseConvUnet, self).__init__(name=name, - device=device, - num_heads=num_heads, - multiplier=multiplier, - voxel_size=voxel_size, - conv_block_reps=conv_block_reps, - residual_blocks=residual_blocks, - in_channels=in_channels, - num_classes=num_classes, - grid_size=grid_size, - batcher=batcher, - augment=augment, - **kwargs) + super(SparseConvUnetMegaModel, + self).__init__(name=name, + device=device, + num_heads=num_heads, + multiplier=multiplier, + voxel_size=voxel_size, + conv_block_reps=conv_block_reps, + residual_blocks=residual_blocks, + in_channels=in_channels, + num_classes=num_classes, + grid_size=grid_size, + batcher=batcher, + augment=augment, + ckpt_path=ckpt_path, + **kwargs) cfg = self.cfg self.device = device self.augmenter = SemsegAugmentation(cfg.augment, seed=self.rng) self.multiplier = cfg.multiplier + self.num_heads = num_heads self.input_layer = InputLayer() self.sub_sparse_conv = SubmanifoldSparseConv(in_channels=in_channels, filters=multiplier, @@ -93,7 +97,7 @@ def forward(self, inputs): feat_list = self.unet(pos_list, feat_list) feat_list = self.batch_norm(feat_list) feat_list = self.relu(feat_list) - feat_list = self.linear(feat_list) + feat_list = self.linear(feat_list, inputs.dataset_idx) output = self.output_layer(feat_list, index_map_list) return output @@ -206,30 +210,31 @@ def get_loss(self, Loss, results, inputs, device): results: Output of the model. inputs: Input of the model. device: device(cpu or cuda). - + Returns: Returns loss, labels and scores. """ cfg = self.cfg - labels = torch.cat(inputs['data'].label, 0) - - scores, labels = filter_valid_label(results, labels, cfg.num_classes, - cfg.ignored_label_inds, device) + labels = torch.cat(inputs['data'].label, 0).to(torch.LongTensor()) - loss = Loss.weighted_CrossEntropyLoss(scores, labels) + loss = Loss.weighted_CrossEntropyLoss[inputs['data'].dataset_idx]( + results, labels) - return loss, labels, scores + return loss, labels, results def get_optimizer(self, cfg_pipeline): - optimizer = torch.optim.Adam(self.parameters(), - **cfg_pipeline.optimizer) - scheduler = torch.optim.lr_scheduler.ExponentialLR( - optimizer, cfg_pipeline.scheduler_gamma) + # optimizer = torch.optim.Adam(self.parameters(), + # **cfg_pipeline.optimizer) + # scheduler = torch.optim.lr_scheduler.ExponentialLR( + # optimizer, cfg_pipeline.scheduler_gamma) + + optimizer = torch.optim.Adam(self.parameters(), lr=0.001) + scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, 0.99) return optimizer, scheduler -MODEL._register_module(SparseConvUnet, 'torch') +MODEL._register_module(SparseConvUnetMegaModel, 'torch') class BatchNormBlock(nn.Module): @@ -279,17 +284,19 @@ class LinearBlock(nn.Module): def __init__(self, in_dim, num_classes): super(LinearBlock, self).__init__() - self.linear = [] + linear = [] self.num_classes = num_classes for i in range(len(num_classes)): - self.linear.append( + linear.append( nn.Sequential(nn.Linear(in_dim, 2 * in_dim), nn.Linear(2 * in_dim, num_classes[i]))) - def forward(self, feat_list): + self.linear = nn.ModuleList(linear) + + def forward(self, feat_list, dataset_idx): out_list = [] for i, feat in enumerate(feat_list): - out_list.append(self.linear[i](feat)) + out_list.append(self.linear[dataset_idx](feat)) return out_list @@ -352,7 +359,10 @@ def forward(self, features_list, index_map_list): out = [] for feat, index_map in zip(features_list, index_map_list): out.append(feat[index_map]) - return torch.cat(out, 0) + + out = torch.cat(out, 0) + + return out class SubmanifoldSparseConv(nn.Module): diff --git a/ml3d/torch/modules/losses/__init__.py b/ml3d/torch/modules/losses/__init__.py index 9582cf4b5..00d0a7ec9 100644 --- a/ml3d/torch/modules/losses/__init__.py +++ b/ml3d/torch/modules/losses/__init__.py @@ -1,11 +1,11 @@ """Loss modules""" -from .semseg_loss import filter_valid_label, SemSegLoss +from .semseg_loss import filter_valid_label, SemSegLoss, SemSegLossV2 from .cross_entropy import CrossEntropyLoss from .focal_loss import FocalLoss from .smooth_L1 import SmoothL1Loss __all__ = [ 'filter_valid_label', 'SemSegLoss', 'CrossEntropyLoss', 'FocalLoss', - 'SmoothL1Loss' + 'SmoothL1Loss', 'SemSegLossV2' ] diff --git a/ml3d/torch/modules/losses/semseg_loss.py b/ml3d/torch/modules/losses/semseg_loss.py index 8b7846f65..12a147def 100644 --- a/ml3d/torch/modules/losses/semseg_loss.py +++ b/ml3d/torch/modules/losses/semseg_loss.py @@ -52,3 +52,19 @@ def __init__(self, pipeline, model, dataset, device): self.weighted_CrossEntropyLoss = nn.CrossEntropyLoss(weight=weights) else: self.weighted_CrossEntropyLoss = nn.CrossEntropyLoss() + + +class SemSegLossV2(object): + """Loss functions for semantic segmentation.""" + + def __init__(self, num_heads, num_classes, ignored_labels=[], weights=None): + super(SemSegLossV2, self).__init__() + # weighted_CrossEntropyLoss + self.weighted_CrossEntropyLoss = [] + + for i in range(num_heads): + weights = torch.ones(num_classes[i]) + weights[ignored_labels[i]] = 0 + weights = torch.tensor(weights, dtype=torch.float) + self.weighted_CrossEntropyLoss.append( + nn.CrossEntropyLoss(weight=weights)) diff --git a/ml3d/torch/pipelines/__init__.py b/ml3d/torch/pipelines/__init__.py index e68b13df5..afd906e28 100644 --- a/ml3d/torch/pipelines/__init__.py +++ b/ml3d/torch/pipelines/__init__.py @@ -2,5 +2,8 @@ from .semantic_segmentation import SemanticSegmentation from .object_detection import ObjectDetection +from .semantic_segmentation_multi_head import SemanticSegmentationMultiHead -__all__ = ['SemanticSegmentation', 'ObjectDetection'] +__all__ = [ + 'SemanticSegmentation', 'ObjectDetection', 'SemanticSegmentationMultiHead' +] diff --git a/ml3d/torch/pipelines/semantic_segmentation_multi_head.py b/ml3d/torch/pipelines/semantic_segmentation_multi_head.py new file mode 100644 index 000000000..60eaf8bf6 --- /dev/null +++ b/ml3d/torch/pipelines/semantic_segmentation_multi_head.py @@ -0,0 +1,789 @@ +import logging +import numpy as np +import torch +import torch.distributed as dist + +from os.path import exists, join +from pathlib import Path +from datetime import datetime +from tqdm import tqdm +from torch.utils.tensorboard import SummaryWriter +from torch.utils.data import DataLoader + +# pylint: disable-next=unused-import +from open3d.visualization.tensorboard_plugin import summary +from .base_pipeline import BasePipeline +from ..dataloaders import get_sampler, TorchDataloader, DefaultBatcher, ConcatBatcher +from ..utils import latest_torch_ckpt +from ..modules.losses import SemSegLossV2, filter_valid_label +from ..modules.metrics import SemSegMetric +from ...utils import make_dir, PIPELINE, get_runid, code2md +from ...datasets import InferenceDummySplit + +log = logging.getLogger(__name__) + + +class SemanticSegmentationMultiHead(BasePipeline): + """This class allows you to perform semantic segmentation for both training + and inference using the Torch. This pipeline has multiple stages: Pre- + processing, loading dataset, testing, and inference or training. + + **Example:** + This example loads the Semantic Segmentation and performs a training + using the SemanticKITTI dataset. + + import torch + import torch.nn as nn + + from .base_pipeline import BasePipeline + from torch.utils.tensorboard import SummaryWriter + from ..dataloaders import get_sampler, TorchDataloader, DefaultBatcher, ConcatBatcher + + Mydataset = TorchDataloader(dataset=dataset.get_split('training')), + MyModel = SemanticSegmentation(self,model,dataset=Mydataset, name='SemanticSegmentation', + name='MySemanticSegmentation', + batch_size=4, + val_batch_size=4, + test_batch_size=3, + max_epoch=100, + learning_rate=1e-2, + lr_decays=0.95, + save_ckpt_freq=20, + adam_lr=1e-2, + scheduler_gamma=0.95, + momentum=0.98, + main_log_dir='./logs/', + device='gpu', + split='train', + train_sum_dir='train_log') + + **Args:** + dataset: The 3D ML dataset class. You can use the base dataset, sample datasets , or a custom dataset. + model: The model to be used for building the pipeline. + name: The name of the current training. + batch_size: The batch size to be used for training. + val_batch_size: The batch size to be used for validation. + test_batch_size: The batch size to be used for testing. + max_epoch: The maximum size of the epoch to be used for training. + leanring_rate: The hyperparameter that controls the weights during training. Also, known as step size. + lr_decays: The learning rate decay for the training. + save_ckpt_freq: The frequency in which the checkpoint should be saved. + adam_lr: The leanring rate to be applied for Adam optimization. + scheduler_gamma: The decaying factor associated with the scheduler. + momentum: The momentum that accelerates the training rate schedule. + main_log_dir: The directory where logs are stored. + device: The device to be used for training. + split: The dataset split to be used. In this example, we have used "train". + train_sum_dir: The directory where the trainig summary is stored. + + **Returns:** + class: The corresponding class. + """ + + def __init__( + self, + model, + dataset=None, + name='SemanticSegmentation', + batch_size=4, + val_batch_size=4, + test_batch_size=3, + max_epoch=100, # maximum epoch during training + learning_rate=1e-2, # initial learning rate + lr_decays=0.95, + save_ckpt_freq=20, + adam_lr=1e-2, + scheduler_gamma=0.95, + momentum=0.98, + main_log_dir='./logs/', + device='cuda', + split='train', + train_sum_dir='train_log', + **kwargs): + + super().__init__(model=model, + dataset=dataset, + name=name, + batch_size=batch_size, + val_batch_size=val_batch_size, + test_batch_size=test_batch_size, + max_epoch=max_epoch, + learning_rate=learning_rate, + lr_decays=lr_decays, + save_ckpt_freq=save_ckpt_freq, + adam_lr=adam_lr, + scheduler_gamma=scheduler_gamma, + momentum=momentum, + main_log_dir=main_log_dir, + device=device, + split=split, + train_sum_dir=train_sum_dir, + **kwargs) + + def run_inference(self, data): + """Run inference on given data. + + Args: + data: A raw data. + + Returns: + Returns the inference results. + """ + cfg = self.cfg + model = self.model + device = self.device + + model.to(device) + model.device = device + model.eval() + + batcher = self.get_batcher(device) + infer_dataset = InferenceDummySplit(data) + self.dataset_split = infer_dataset + infer_sampler = infer_dataset.sampler + infer_split = TorchDataloader(dataset=infer_dataset, + preprocess=model.preprocess, + transform=model.transform, + sampler=infer_sampler, + use_cache=False) + infer_loader = DataLoader(infer_split, + batch_size=cfg.batch_size, + sampler=get_sampler(infer_sampler), + collate_fn=batcher.collate_fn) + + model.trans_point_sampler = infer_sampler.get_point_sampler() + self.curr_cloud_id = -1 + self.test_probs = [] + self.test_labels = [] + self.ori_test_probs = [] + self.ori_test_labels = [] + + with torch.no_grad(): + for unused_step, inputs in enumerate(infer_loader): + results = model(inputs['data']) + self.update_tests(infer_sampler, inputs, results) + + inference_result = { + 'predict_labels': self.ori_test_labels.pop(), + 'predict_scores': self.ori_test_probs.pop() + } + + metric = SemSegMetric() + + valid_scores, valid_labels = filter_valid_label( + torch.tensor(inference_result['predict_scores']), + torch.tensor(data['label']), model.cfg.num_classes, + model.cfg.ignored_label_inds, device) + + metric.update(valid_scores, valid_labels) + log.info(f"Accuracy : {metric.acc()}") + log.info(f"IoU : {metric.iou()}") + + return inference_result + + def run_test(self): + """Run the test using the data passed.""" + model = self.model + dataset = self.dataset + device = self.device + cfg = self.cfg + model.device = device + model.to(device) + model.eval() + self.metric_test = SemSegMetric() + + timestamp = datetime.now().strftime('%Y-%m-%d_%H:%M:%S') + + log.info("DEVICE : {}".format(device)) + log_file_path = join(cfg.logs_dir, 'log_test_' + timestamp + '.txt') + log.info("Logging in file : {}".format(log_file_path)) + log.addHandler(logging.FileHandler(log_file_path)) + + batcher = self.get_batcher(device) + + test_dataset = dataset.get_split('test') + test_sampler = test_dataset.sampler + test_split = TorchDataloader(dataset=test_dataset, + preprocess=model.preprocess, + transform=model.transform, + sampler=test_sampler, + use_cache=dataset.cfg.use_cache) + test_loader = DataLoader(test_split, + batch_size=cfg.test_batch_size, + sampler=get_sampler(test_sampler), + collate_fn=batcher.collate_fn) + + self.dataset_split = test_dataset + + self.load_ckpt(model.cfg.ckpt_path) + + model.trans_point_sampler = test_sampler.get_point_sampler() + self.curr_cloud_id = -1 + self.test_probs = [] + self.test_labels = [] + self.ori_test_probs = [] + self.ori_test_labels = [] + + record_summary = cfg.get('summary').get('record_for', []) + log.info("Started testing") + + with torch.no_grad(): + for unused_step, inputs in enumerate(test_loader): + if hasattr(inputs['data'], 'to'): + inputs['data'].to(device) + results = model(inputs['data']) + self.update_tests(test_sampler, inputs, results) + + if self.complete_infer: + inference_result = { + 'predict_labels': self.ori_test_labels.pop(), + 'predict_scores': self.ori_test_probs.pop() + } + attr = self.dataset_split.get_attr(test_sampler.cloud_id) + gt_labels = self.dataset_split.get_data( + test_sampler.cloud_id)['label'] + if (gt_labels > 0).any(): + valid_scores, valid_labels = filter_valid_label( + torch.tensor( + inference_result['predict_scores']).to(device), + torch.tensor(gt_labels).to(device), + model.cfg.num_classes, model.cfg.ignored_label_inds, + device) + + self.metric_test.update(valid_scores, valid_labels) + log.info(f"Accuracy : {self.metric_test.acc()}") + log.info(f"IoU : {self.metric_test.iou()}") + dataset.save_test_result(inference_result, attr) + # Save only for the first batch + if 'test' in record_summary and 'test' not in self.summary: + self.summary['test'] = self.get_3d_summary( + results, inputs['data'], 0, save_gt=False) + log.info( + f"Overall Testing Accuracy : {self.metric_test.acc()[-1]}, mIoU : {self.metric_test.iou()[-1]}" + ) + + log.info("Finished testing") + + def update_tests(self, sampler, inputs, results): + """Update tests using sampler, inputs, and results.""" + split = sampler.split + end_threshold = 0.5 + if self.curr_cloud_id != sampler.cloud_id: + self.curr_cloud_id = sampler.cloud_id + num_points = sampler.possibilities[sampler.cloud_id].shape[0] + self.pbar = tqdm(total=num_points, + desc="{} {}/{}".format(split, self.curr_cloud_id, + len(sampler.dataset))) + self.pbar_update = 0 + self.test_probs.append( + np.zeros(shape=[num_points, self.model.cfg.num_classes], + dtype=np.float16)) + self.test_labels.append(np.zeros(shape=[num_points], + dtype=np.int16)) + self.complete_infer = False + + this_possiblility = sampler.possibilities[sampler.cloud_id] + self.pbar.update( + this_possiblility[this_possiblility > end_threshold].shape[0] - + self.pbar_update) + self.pbar_update = this_possiblility[ + this_possiblility > end_threshold].shape[0] + self.test_probs[self.curr_cloud_id], self.test_labels[ + self.curr_cloud_id] = self.model.update_probs( + inputs, results, self.test_probs[self.curr_cloud_id], + self.test_labels[self.curr_cloud_id]) + + if (split in ['test'] and + this_possiblility[this_possiblility > end_threshold].shape[0] + == this_possiblility.shape[0]): + + proj_inds = self.model.preprocess( + self.dataset_split.get_data(self.curr_cloud_id), { + 'split': split + }).get('proj_inds', None) + if proj_inds is None: + proj_inds = np.arange( + self.test_probs[self.curr_cloud_id].shape[0]) + self.ori_test_probs.append( + self.test_probs[self.curr_cloud_id][proj_inds]) + self.ori_test_labels.append( + self.test_labels[self.curr_cloud_id][proj_inds]) + self.complete_infer = True + + def run_train(self): + torch.manual_seed(self.rng.integers(np.iinfo( + np.int32).max)) # Random reproducible seed for torch + rank = self.rank # Rank for distributed training + model = self.model + device = self.device + dataset = self.dataset + num_heads = model.num_heads + + cfg = self.cfg + if rank == 0: + log.info("DEVICE : {}".format(device)) + timestamp = datetime.now().strftime('%Y-%m-%d_%H:%M:%S') + + log_file_path = join(cfg.logs_dir, + 'log_train_' + timestamp + '.txt') + log.info("Logging in file : {}".format(log_file_path)) + log.addHandler(logging.FileHandler(log_file_path)) + + Loss = SemSegLossV2(model.num_heads, model.cfg.num_classes, + model.cfg.ignored_label_inds) + self.metric_train = [SemSegMetric() for i in range(num_heads)] + self.metric_val = [SemSegMetric() for i in range(num_heads)] + + self.batcher = self.get_batcher(device) + + train_dataset = dataset.get_split('train') + train_sampler = None + + valid_dataset = dataset.get_split('val') + valid_sampler = None + + train_split = TorchDataloader(dataset=train_dataset, + preprocess=model.preprocess, + transform=model.transform, + sampler=train_sampler, + use_cache=dataset.cfg.use_cache, + steps_per_epoch=dataset.cfg.get( + 'steps_per_epoch_train', None)) + + valid_split = TorchDataloader(dataset=valid_dataset, + preprocess=model.preprocess, + transform=model.transform, + sampler=valid_sampler, + use_cache=dataset.cfg.use_cache, + steps_per_epoch=dataset.cfg.get( + 'steps_per_epoch_valid', None)) + + if self.distributed: + if train_sampler is not None or valid_sampler is not None: + raise NotImplementedError( + "Distributed training with sampler is not supported yet!") + train_sampler = torch.utils.data.distributed.DistributedSampler( + train_split) + valid_sampler = torch.utils.data.distributed.DistributedSampler( + valid_split) + + train_loader = DataLoader( + train_split, + batch_size=cfg.batch_size, + sampler=train_sampler + if self.distributed else get_sampler(train_sampler), + num_workers=cfg.get('num_workers', 0), + pin_memory=cfg.get('pin_memory', False), + collate_fn=self.batcher.collate_fn, + worker_init_fn=lambda x: np.random.seed(x + np.uint32( + torch.utils.data.get_worker_info().seed)) + ) # numpy expects np.uint32, whereas torch returns np.uint64. + + valid_loader = DataLoader( + valid_split, + batch_size=cfg.val_batch_size, + sampler=valid_sampler + if self.distributed else get_sampler(valid_sampler), + num_workers=cfg.get('num_workers', 0), + pin_memory=cfg.get('pin_memory', True), + collate_fn=self.batcher.collate_fn, + worker_init_fn=lambda x: np.random.seed(x + np.uint32( + torch.utils.data.get_worker_info().seed))) + + # Optimizer must be created after moving model to specific device. + model.to(self.device) + model.device = self.device + + self.optimizer, self.scheduler = model.get_optimizer(cfg) + + is_resume = model.cfg.get('is_resume', True) + start_ep = self.load_ckpt(model.cfg.ckpt_path, is_resume=is_resume) + + dataset_name = dataset.name if dataset is not None else '' + tensorboard_dir = join( + self.cfg.train_sum_dir, + model.__class__.__name__ + '_' + dataset_name + '_torch') + runid = get_runid(tensorboard_dir) + self.tensorboard_dir = join(self.cfg.train_sum_dir, + runid + '_' + Path(tensorboard_dir).name) + + writer = SummaryWriter(self.tensorboard_dir) + if rank == 0: + self.save_config(writer) + log.info("Writing summary in {}.".format(self.tensorboard_dir)) + + # wrap model for multiple GPU + if self.distributed: + model = torch.nn.parallel.DistributedDataParallel( + model, device_ids=[self.device]) + model.get_loss = model.module.get_loss + model.cfg = model.module.cfg + + record_summary = cfg.get('summary').get('record_for', + []) if self.rank == 0 else [] + + if rank == 0: + log.info("Started training") + + for epoch in range(start_ep, cfg.max_epoch + 1): + log.info(f'=== EPOCH {epoch:d}/{cfg.max_epoch:d} ===') + if self.distributed: + train_sampler.set_epoch(epoch) + + model.train() + for i in range(num_heads): + self.metric_train[i].reset() + self.metric_val[i].reset() + self.losses = [[] for i in range(num_heads)] + # model.trans_point_sampler = train_sampler.get_point_sampler() # TODO: fix this for model with samplers. + + progress_bar = tqdm(train_loader, desc='training') + for inputs in progress_bar: + if hasattr(inputs['data'], 'to'): + inputs['data'].to(device) + dataset_idx = inputs['data'].dataset_idx + self.optimizer.zero_grad() + results = model(inputs['data']) + loss, gt_labels, predict_scores = model.get_loss( + Loss, results, inputs, device) + + if predict_scores.size()[-1] == 0: + continue + + loss.backward() + if model.cfg.get('grad_clip_norm', -1) > 0: + if self.distributed: + torch.nn.utils.clip_grad_value_( + model.module.parameters(), model.cfg.grad_clip_norm) + else: + torch.nn.utils.clip_grad_value_( + model.parameters(), model.cfg.grad_clip_norm) + + self.optimizer.step() + + self.metric_train[dataset_idx].update(predict_scores, gt_labels) + + self.losses[dataset_idx].append(loss.cpu().item()) + + # Save only for the first pcd in batch + if 'train' in record_summary and progress_bar.n == 0: + self.summary['train'] = self.get_3d_summary( + results, inputs['data'], epoch) + + desc = "training - Epoch: %d, loss (dataset : %d): %.3f" % ( + epoch, dataset_idx, loss.cpu().item()) + # if rank == 0: + progress_bar.set_description(desc) + progress_bar.refresh() + + if self.distributed: + dist.barrier() + + if self.scheduler is not None: + self.scheduler.step() + + # --------------------- validation + model.eval() + self.valid_losses = [[] for i in range(num_heads)] + # model.trans_point_sampler = valid_sampler.get_point_sampler() + + with torch.no_grad(): + for step, inputs in enumerate( + tqdm(valid_loader, desc='validation')): + if hasattr(inputs['data'], 'to'): + inputs['data'].to(device) + + results = model(inputs['data']) + loss, gt_labels, predict_scores = model.get_loss( + Loss, results, inputs, device) + + if predict_scores.size()[-1] == 0: + continue + + self.metric_val.update(predict_scores, gt_labels) + + self.valid_losses[dataset_idx].append(loss.cpu().item()) + # Save only for the first batch + if 'valid' in record_summary and step == 0: + self.summary['valid'] = self.get_3d_summary( + results, inputs['data'], epoch) + + if self.distributed: + metric_gather = [None for _ in range(dist.get_world_size())] + dist.gather_object(self.metric_val, + metric_gather if rank == 0 else None, + dst=0) + if rank == 0: + for m in metric_gather[1:]: + self.metric_val += m + + dist.barrier() + + if rank == 0: + for i in range(num_heads): + self.save_logs(writer, epoch, i) + if epoch % cfg.save_ckpt_freq == 0 or epoch == cfg.max_epoch: + self.save_ckpt(epoch) + + def get_batcher(self, device, split='training'): + """Get the batcher to be used based on the device and split.""" + batcher_name = getattr(self.model.cfg, 'batcher') + + if batcher_name == 'DefaultBatcher': + batcher = DefaultBatcher() + elif batcher_name == 'ConcatBatcher': + batcher = ConcatBatcher(device, self.model.cfg.name) + else: + batcher = None + return batcher + + def get_3d_summary(self, results, input_data, epoch, save_gt=True): + """ + Create visualization for network inputs and outputs. + + Args: + results: Model output (see below). + input_data: Model input (see below). + epoch (int): step + save_gt (bool): Save ground truth (for 'train' or 'valid' stages). + + RandLaNet: + results (Tensor(B, N, C)): Prediction scores for all classes + inputs_batch: Batch of pointclouds and labels as a Dict with keys: + 'xyz': First element is Tensor(B,N,3) points + 'labels': (B, N) (optional) labels + + SparseConvUNet: + results (Tensor(SN, C)): Prediction scores for all classes. SN is + total points in the batch. + input_batch (Dict): Batch of pointclouds and labels. Keys should be: + 'point' [Tensor(SN,3), float]: Concatenated points. + 'batch_lengths' [Tensor(B,), int]: Number of points in each + point cloud of the batch. + 'label' [Tensor(SN,) (optional)]: Concatenated labels. + + Returns: + [Dict] visualizations of inputs and outputs suitable to save as an + Open3D for TensorBoard summary. + """ + if not hasattr(self, "_first_step"): + self._first_step = epoch + label_to_names = self.dataset.get_label_to_names() + cfg = self.cfg.get('summary') + max_pts = cfg.get('max_pts') + if max_pts is None: + max_pts = np.iinfo(np.int32).max + use_reference = cfg.get('use_reference', False) + max_outputs = cfg.get('max_outputs', 1) + input_pcd = [] + gt_labels = [] + predict_labels = [] + + def to_sum_fmt(tensor, add_dims=(0, 0), dtype=torch.int32): + sten = tensor.cpu().detach().type(dtype) + new_shape = (1,) * add_dims[0] + sten.shape + (1,) * add_dims[1] + return sten.reshape(new_shape) + + # Variable size point clouds + if self.model.cfg['name'] in ('KPFCNN', 'KPConv'): + batch_lengths = input_data.lengths[0].detach().numpy() + row_splits = np.hstack(((0,), np.cumsum(batch_lengths))) + max_outputs = min(max_outputs, len(row_splits) - 1) + for k in range(max_outputs): + blen_k = row_splits[k + 1] - row_splits[k] + pcd_step = int(np.ceil(blen_k / min(max_pts, blen_k))) + res_pcd = results[row_splits[k]:row_splits[k + 1]:pcd_step, :] + predict_labels.append( + to_sum_fmt(torch.argmax(res_pcd, 1), (0, 1))) + if self._first_step != epoch and use_reference: + continue + pointcloud = input_data.points[0][ + row_splits[k]:row_splits[k + 1]:pcd_step] + input_pcd.append( + to_sum_fmt(pointcloud[:, :3], (0, 0), torch.float32)) + if torch.any(input_data.labels != 0): + gtl = input_data.labels[row_splits[k]:row_splits[k + 1]] + gt_labels.append(to_sum_fmt(gtl, (0, 1))) + + elif self.model.cfg['name'] in ('SparseConvUnet', 'PointTransformer'): + if self.model.cfg['name'] == 'SparseConvUnet': + row_splits = np.hstack( + ((0,), np.cumsum(input_data.batch_lengths))) + else: + row_splits = input_data.row_splits + max_outputs = min(max_outputs, len(row_splits) - 1) + for k in range(max_outputs): + blen_k = row_splits[k + 1] - row_splits[k] + pcd_step = int(np.ceil(blen_k / min(max_pts, blen_k))) + res_pcd = results[row_splits[k]:row_splits[k + 1]:pcd_step, :] + predict_labels.append( + to_sum_fmt(torch.argmax(res_pcd, 1), (0, 1))) + if self._first_step != epoch and use_reference: + continue + if self.model.cfg['name'] == 'SparseConvUnet': + pointcloud = input_data.point[k] + else: + pointcloud = input_data.point[ + row_splits[k]:row_splits[k + 1]:pcd_step] + input_pcd.append( + to_sum_fmt(pointcloud[:, :3], (0, 0), torch.float32)) + if getattr(input_data, 'label', None) is not None: + if self.model.cfg['name'] == 'SparseConvUnet': + gtl = input_data.label[k] + else: + gtl = input_data.label[ + row_splits[k]:row_splits[k + 1]:pcd_step] + gt_labels.append(to_sum_fmt(gtl, (0, 1))) + # Fixed size point clouds + elif self.model.cfg['name'] in ('RandLANet', 'PVCNN'): # Tuple input + if self.model.cfg['name'] == 'RandLANet': + pointcloud = input_data['xyz'][0] # 0 => input to first layer + elif self.model.cfg['name'] == 'PVCNN': + pointcloud = input_data['point'].transpose(1, 2) + pcd_step = int( + np.ceil(pointcloud.shape[1] / + min(max_pts, pointcloud.shape[1]))) + predict_labels = to_sum_fmt( + torch.argmax(results[:max_outputs, ::pcd_step, :], 2), (0, 1)) + if self._first_step == epoch or not use_reference: + input_pcd = to_sum_fmt(pointcloud[:max_outputs, ::pcd_step, :3], + (0, 0), torch.float32) + if save_gt: + gtl = input_data.get('label', + input_data.get('labels', None)) + if gtl is None: + raise ValueError("input_data does not have label(s).") + gt_labels = to_sum_fmt(gtl[:max_outputs, ::pcd_step], + (0, 1)) + else: + raise NotImplementedError( + "Saving 3D summary for the model " + f"{self.model.cfg['name']} is not implemented.") + + def get_reference_or(data_tensor): + if self._first_step == epoch or not use_reference: + return data_tensor + return self._first_step + + summary_dict = { + 'semantic_segmentation': { + "vertex_positions": get_reference_or(input_pcd), + "vertex_gt_labels": get_reference_or(gt_labels), + "vertex_predict_labels": predict_labels, + 'label_to_names': label_to_names + } + } + return summary_dict + + def save_logs(self, writer, epoch, dataset_idx): + """Save logs from the training and send results to TensorBoard.""" + train_accs = self.metric_train[dataset_idx].acc() + # val_accs = self.metric_val[dataset_idx].acc() + val_accs = self.metric_train[dataset_idx].acc() + + train_ious = self.metric_train[dataset_idx].iou() + # val_ious = self.metric_val[dataset_idx].iou() + val_ious = self.metric_train[dataset_idx].iou() + + loss_dict = { + 'Training loss': np.mean(self.losses[dataset_idx]), + 'Validation loss': np.mean(self.valid_losses[dataset_idx]) + } + acc_dicts = [{ + 'Training accuracy': acc, + 'Validation accuracy': val_acc + } for acc, val_acc in zip(train_accs, val_accs)] + + iou_dicts = [{ + 'Training IoU': iou, + 'Validation IoU': val_iou + } for iou, val_iou in zip(train_ious, val_ious)] + + for key, val in loss_dict.items(): + writer.add_scalar(key, val, epoch) + for key, val in acc_dicts[-1].items(): + writer.add_scalar("{}/ Overall".format(key), val, epoch) + for key, val in iou_dicts[-1].items(): + writer.add_scalar("{}/ Overall".format(key), val, epoch) + + log.info(f"Dataset Index : {dataset_idx}") + log.info(f"Loss train: {loss_dict['Training loss']:.3f} " + f" eval: {loss_dict['Validation loss']:.3f}") + log.info(f"Mean acc train: {acc_dicts[-1]['Training accuracy']:.3f} " + f" eval: {acc_dicts[-1]['Validation accuracy']:.3f}") + log.info(f"Mean IoU train: {iou_dicts[-1]['Training IoU']:.3f} " + f" eval: {iou_dicts[-1]['Validation IoU']:.3f}") + + for stage in self.summary: + for key, summary_dict in self.summary[stage].items(): + label_to_names = summary_dict.pop('label_to_names', None) + writer.add_3d('/'.join((stage, key)), + summary_dict, + epoch, + max_outputs=0, + label_to_names=label_to_names) + + def load_ckpt(self, ckpt_path=None, is_resume=True): + """Load a checkpoint. You must pass the checkpoint and indicate if you + want to resume. + """ + train_ckpt_dir = join(self.cfg.logs_dir, 'checkpoint') + if self.rank == 0: + make_dir(train_ckpt_dir) + if self.distributed: + dist.barrier() + + if ckpt_path is None: + ckpt_path = latest_torch_ckpt(train_ckpt_dir) + if ckpt_path is not None and is_resume: + log.info("ckpt_path not given. Restore from the latest ckpt") + else: + log.info('Initializing from scratch.') + return 0 + + if not exists(ckpt_path): + raise FileNotFoundError(f' ckpt {ckpt_path} not found') + + log.info(f'Loading checkpoint {ckpt_path}') + ckpt = torch.load(ckpt_path, map_location=self.device) + + self.model.load_state_dict(ckpt['model_state_dict']) + if 'optimizer_state_dict' in ckpt and hasattr(self, 'optimizer'): + log.info('Loading checkpoint optimizer_state_dict') + self.optimizer.load_state_dict(ckpt['optimizer_state_dict']) + if 'scheduler_state_dict' in ckpt and hasattr(self, 'scheduler'): + log.info('Loading checkpoint scheduler_state_dict') + self.scheduler.load_state_dict(ckpt['scheduler_state_dict']) + + epoch = 0 if 'epoch' not in ckpt else ckpt['epoch'] + + return epoch + + def save_ckpt(self, epoch): + """Save a checkpoint at the passed epoch.""" + path_ckpt = join(self.cfg.logs_dir, 'checkpoint') + make_dir(path_ckpt) + torch.save( + dict(epoch=epoch, + model_state_dict=self.model.state_dict(), + optimizer_state_dict=self.optimizer.state_dict(), + scheduler_state_dict=self.scheduler.state_dict()), + join(path_ckpt, f'ckpt_{epoch:05d}.pth')) + log.info(f'Epoch {epoch:3d}: save ckpt to {path_ckpt:s}') + + def save_config(self, writer): + """Save experiment configuration with tensorboard summary.""" + if hasattr(self, 'cfg_tb'): + writer.add_text("Description/Open3D-ML", self.cfg_tb['readme'], 0) + writer.add_text("Description/Command line", self.cfg_tb['cmd_line'], + 0) + writer.add_text('Configuration/Dataset', + code2md(self.cfg_tb['dataset'], language='json'), 0) + writer.add_text('Configuration/Model', + code2md(self.cfg_tb['model'], language='json'), 0) + writer.add_text('Configuration/Pipeline', + code2md(self.cfg_tb['pipeline'], language='json'), + 0) + + +PIPELINE._register_module(SemanticSegmentationMultiHead, "torch") From ea5f3035895c7b44142bdafd6d64f66cdce4c515 Mon Sep 17 00:00:00 2001 From: Sanskar Agrawal Date: Thu, 26 May 2022 03:00:45 +0530 Subject: [PATCH 40/50] add nuscenes semseg --- ml3d/datasets/__init__.py | 4 +- ml3d/datasets/nuscenes_semseg.py | 230 +++++++++++++++++++++++++++++++ ml3d/datasets/waymo_semseg.py | 2 +- 3 files changed, 234 insertions(+), 2 deletions(-) create mode 100644 ml3d/datasets/nuscenes_semseg.py diff --git a/ml3d/datasets/__init__.py b/ml3d/datasets/__init__.py index befae57ee..1677759a5 100644 --- a/ml3d/datasets/__init__.py +++ b/ml3d/datasets/__init__.py @@ -14,6 +14,7 @@ from .kitti import KITTI from .nuscenes import NuScenes +from .nuscenes_semseg import NuScenesSemSeg from .waymo import Waymo from .waymo_semseg import WaymoSemSeg from .lyft import Lyft @@ -30,5 +31,6 @@ 'Custom3D', 'utils', 'augment', 'samplers', 'KITTI', 'Waymo', 'NuScenes', 'Lyft', 'ShapeNet', 'SemSegRandomSampler', 'InferenceDummySplit', 'SemSegSpatiallyRegularSampler', 'Argoverse', 'Scannet', 'SunRGBD', - 'MatterportObjects', 'WaymoSemSeg', 'KITTI360', 'MegaLoader' + 'MatterportObjects', 'WaymoSemSeg', 'KITTI360', 'MegaLoader', + 'NuScenesSemSeg' ] diff --git a/ml3d/datasets/nuscenes_semseg.py b/ml3d/datasets/nuscenes_semseg.py new file mode 100644 index 000000000..49cbf6a1a --- /dev/null +++ b/ml3d/datasets/nuscenes_semseg.py @@ -0,0 +1,230 @@ +import os +import pickle +from os.path import join +from pathlib import Path +import logging +import numpy as np +from scipy.spatial.transform import Rotation as R + +from .base_dataset import BaseDataset +from ..utils import DATASET +from .utils import BEVBox3D +import open3d as o3d + +log = logging.getLogger(__name__) + + +class NuScenesSemSeg(BaseDataset): + """This class is used to create a dataset based on the NuScenes 3D dataset, + and used in object detection, visualizer, training, or testing. + + The NuScenes 3D dataset is best suited for autonomous driving applications. + """ + + def __init__(self, + dataset_path, + info_path=None, + name='NuScenes', + cache_dir='./logs/cache', + use_cache=False, + **kwargs): + """Initialize the function by passing the dataset and other details. + + Args: + dataset_path: The path to the dataset to use. + info_path: The path to the file that includes information about the + dataset. This is default to dataset path if nothing is provided. + name: The name of the dataset (NuScenes in this case). + cache_dir: The directory where the cache is stored. + use_cache: Indicates if the dataset should be cached. + + Returns: + class: The corresponding class. + """ + if info_path is None: + info_path = dataset_path + + super().__init__(dataset_path=dataset_path, + info_path=info_path, + name=name, + cache_dir=cache_dir, + use_cache=use_cache, + **kwargs) + + cfg = self.cfg + + self.name = cfg.name + self.dataset_path = cfg.dataset_path + self.label_to_names = self.get_label_to_names() + self.num_classes = len(self.label_to_names) + + self.train_info = {} + self.test_info = {} + self.val_info = {} + + if os.path.exists(join(info_path, 'infos_train.pkl')): + self.train_info = pickle.load( + open(join(info_path, 'infos_train.pkl'), 'rb')) + + if os.path.exists(join(info_path, 'infos_val.pkl')): + self.val_info = pickle.load( + open(join(info_path, 'infos_val.pkl'), 'rb')) + + if os.path.exists(join(info_path, 'infos_test.pkl')): + self.test_info = pickle.load( + open(join(info_path, 'infos_test.pkl'), 'rb')) + + @staticmethod + def get_label_to_names(): + """Returns a label to names dictionary object. + + Returns: + A dict where keys are label numbers and + values are the corresponding names. + """ + label_to_names = { + 0: 'ignore', + 1: 'barrier', + 2: 'bicycle', + 3: 'bus', + 4: 'car', + 5: 'construction_vehicle', + 6: 'motorcycle', + 7: 'pedestrian', + 8: 'traffic_cone', + 9: 'trailer', + 10: 'truck', + 11: 'driveable_surface', + 12: 'other_flat', + 13: 'sidewalk', + 14: 'terrain', + 15: 'manmade', + 16: 'vegetation' + } + return label_to_names + + @staticmethod + def read_lidar(path): + """Reads lidar data from the path provided. + + Returns: + A data object with lidar information. + """ + assert Path(path).exists() + + return np.fromfile(path, dtype=np.float32).reshape(-1, 5) + + @staticmethod + def read_lidarseg(path): + """Reads semantic data from the path provided. + + Returns: + A data object with semantic information. + """ + assert Path(path).exists() + + return np.fromfile(path, dtype=np.uint8).reshape(-1,).astype(np.int32) + + def get_split(self, split): + """Returns a dataset split. + + Args: + split: A string identifying the dataset split that is usually one of + 'training', 'test', 'validation', or 'all'. + + Returns: + A dataset split object providing the requested subset of the data. + """ + return NuScenesSemSegSplit(self, split=split) + + def get_split_list(self, split): + """Returns the list of data splits available. + + Args: + split: A string identifying the dataset split that is usually one of + 'training', 'test', 'validation', or 'all'. + + Returns: + A dataset split object providing the requested subset of the data. + + Raises: + ValueError: Indicates that the split name passed is incorrect. The + split name should be one of 'training', 'test', 'validation', or + 'all'. + """ + if split in ['train', 'training']: + return self.train_info + elif split in ['test', 'testing']: + return self.test_info + elif split in ['val', 'validation']: + return self.val_info + + raise ValueError("Invalid split {}".format(split)) + + def is_tested(): + """Checks if a datum in the dataset has been tested. + + Args: + dataset: The current dataset to which the datum belongs to. + attr: The attribute that needs to be checked. + + Returns: + If the dataum attribute is tested, then return the path where the + attribute is stored; else, returns false. + """ + pass + + def save_test_result(): + """Saves the output of a model. + + Args: + results: The output of a model for the datum associated with the + attribute passed. + attr: The attributes that correspond to the outputs passed in + results. + """ + pass + + +class NuScenesSemSegSplit(): + + def __init__(self, dataset, split='train'): + self.cfg = dataset.cfg + + self.infos = dataset.get_split_list(split) + self.path_list = [] + for info in self.infos: + self.path_list.append(info['lidar_path']) + + log.info("Found {} pointclouds for {}".format(len(self.infos), split)) + + self.split = split + self.dataset = dataset + + def __len__(self): + return len(self.infos) + + def get_data(self, idx): + info = self.infos[idx] + lidar_path = info['lidar_path'] + lidarseg_path = info['lidarseg_path'] + + pc = self.dataset.read_lidar(lidar_path) + feat = pc[:, 3:4] + pc = pc[:, :3] + lidarseg = self.dataset.read_lidarseg(lidarseg_path) + + data = {'point': pc, 'feat': feat, 'label': lidarseg} + + return data + + def get_attr(self, idx): + info = self.infos[idx] + pc_path = info['lidar_path'] + name = Path(pc_path).name.split('.')[0] + + attr = {'name': name, 'path': str(pc_path), 'split': self.split} + return attr + + +DATASET._register_module(NuScenesSemSeg) diff --git a/ml3d/datasets/waymo_semseg.py b/ml3d/datasets/waymo_semseg.py index 3c3725fe4..b2dbd99b8 100644 --- a/ml3d/datasets/waymo_semseg.py +++ b/ml3d/datasets/waymo_semseg.py @@ -47,8 +47,8 @@ def __init__(self, self.name = cfg.name self.dataset_path = cfg.dataset_path - self.num_classes = 23 self.label_to_names = self.get_label_to_names() + self.num_classes = len(self.label_to_names) self.shuffle = kwargs.get('shuffle', False) self.all_files = sorted( From 4ff5fbe0d283adf3992a37f806f09df1304dba0a Mon Sep 17 00:00:00 2001 From: Sanskar Agrawal Date: Fri, 3 Jun 2022 14:53:44 +0530 Subject: [PATCH 41/50] add kitti360 --- ml3d/datasets/kitti360.py | 236 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 236 insertions(+) create mode 100644 ml3d/datasets/kitti360.py diff --git a/ml3d/datasets/kitti360.py b/ml3d/datasets/kitti360.py new file mode 100644 index 000000000..d7e16d2b8 --- /dev/null +++ b/ml3d/datasets/kitti360.py @@ -0,0 +1,236 @@ +import numpy as np +import pandas as pd +import os, pickle +import logging +import open3d as o3d + +from pathlib import Path +from os.path import join, exists, dirname, abspath, isdir +from sklearn.neighbors import KDTree +from tqdm import tqdm +from glob import glob + +from .utils import DataProcessing, get_min_bbox, BEVBox3D +from .base_dataset import BaseDataset, BaseDatasetSplit +from ..utils import make_dir, DATASET + +log = logging.getLogger(__name__) + + +class KITTI360(BaseDataset): + """This class is used to create a dataset based on the KITTI 360 + dataset, and used in visualizer, training, or testing. + """ + + def __init__(self, + dataset_path, + name='KITTI360', + cache_dir='./logs/cache', + use_cache=False, + class_weights=[ + 3370714, 2856755, 4919229, 318158, 375640, 478001, 974733, + 650464, 791496, 88727, 1284130, 229758, 2272837 + ], + num_points=40960, + ignored_label_inds=[], + test_result_folder='./test', + **kwargs): + """Initialize the function by passing the dataset and other details. + + Args: + dataset_path: The path to the dataset to use (parent directory of data_3d_semantics). + name: The name of the dataset (KITTI360 in this case). + cache_dir: The directory where the cache is stored. + use_cache: Indicates if the dataset should be cached. + class_weights: The class weights to use in the dataset. + num_points: The maximum number of points to use when splitting the dataset. + ignored_label_inds: A list of labels that should be ignored in the dataset. + test_result_folder: The folder where the test results should be stored. + """ + super().__init__(dataset_path=dataset_path, + name=name, + cache_dir=cache_dir, + use_cache=use_cache, + class_weights=class_weights, + test_result_folder=test_result_folder, + num_points=num_points, + ignored_label_inds=ignored_label_inds, + **kwargs) + + cfg = self.cfg + + self.label_to_names = self.get_label_to_names() + self.num_classes = len(self.label_to_names) + self.label_values = np.sort([k for k, v in self.label_to_names.items()]) + self.label_to_idx = {l: i for i, l in enumerate(self.label_values)} + self.ignored_labels = np.array([]) + + if not os.path.exists( + os.path.join( + dataset_path, + 'data_3d_semantics/train/2013_05_28_drive_train.txt')): + raise ValueError( + "Invalid Path, make sure dataset_path is the parent directory of data_3d_semantics." + ) + + with open( + os.path.join( + dataset_path, + 'data_3d_semantics/train/2013_05_28_drive_train.txt'), + 'r') as f: + train_paths = f.read().split('\n')[:-1] + train_paths = [os.path.join(dataset_path, p) for p in train_paths] + + with open( + os.path.join( + dataset_path, + 'data_3d_semantics/train/2013_05_28_drive_val.txt'), + 'r') as f: + val_paths = f.read().split('\n')[:-1] + val_paths = [os.path.join(dataset_path, p) for p in val_paths] + + self.train_files = train_paths + self.val_files = val_paths + self.test_files = sorted( + glob( + os.path.join(dataset_path, + 'data_3d_semantics/test/*/static/*.ply'))) + + @staticmethod + def get_label_to_names(): + """Returns a label to names dictionary object. + + Returns: + A dict where keys are label numbers and + values are the corresponding names. + """ + label_to_names = { + 0: 'ceiling', + 1: 'floor', + 2: 'wall', + 3: 'beam', + 4: 'column', + 5: 'window', + 6: 'door', + 7: 'table', + 8: 'chair', + 9: 'sofa', + 10: 'bookcase', + 11: 'board', + 12: 'clutter' + } + return label_to_names + + def get_split(self, split): + """Returns a dataset split. + + Args: + split: A string identifying the dataset split that is usually one of + 'training', 'test', 'validation', or 'all'. + + Returns: + A dataset split object providing the requested subset of the data. + """ + return KITTI360Split(self, split=split) + + def get_split_list(self, split): + if split in ['train', 'training']: + return self.train_files + elif split in ['val', 'validation']: + return self.val_files + elif split in ['test', 'testing']: + return self.test_files + elif split == 'all': + return self.train_files + self.val_files + self.test_files + else: + raise ValueError("Invalid split {}".format(split)) + + def is_tested(self, attr): + + cfg = self.cfg + name = attr['name'] + path = cfg.test_result_folder + store_path = join(path, self.name, name + '.npy') + if exists(store_path): + print("{} already exists.".format(store_path)) + return True + else: + return False + + """Saves the output of a model. + + Args: + results: The output of a model for the datum associated with the attribute passed. + attr: The attributes that correspond to the outputs passed in results. + """ + + def save_test_result(self, results, attr): + + cfg = self.cfg + name = attr['name'].split('.')[0] + path = cfg.test_result_folder + make_dir(path) + + pred = results['predict_labels'] + pred = np.array(pred) + + for ign in cfg.ignored_label_inds: + pred[pred >= ign] += 1 + + store_path = join(path, self.name, name + '.npy') + make_dir(Path(store_path).parent) + np.save(store_path, pred) + log.info("Saved {} in {}.".format(name, store_path)) + + +class KITTI360Split(BaseDatasetSplit): + """This class is used to create a split for KITTI360 dataset. + + Initialize the class. + + Args: + dataset: The dataset to split. + split: A string identifying the dataset split that is usually one of + 'training', 'test', 'validation', or 'all'. + **kwargs: The configuration of the model as keyword arguments. + + Returns: + A dataset split object providing the requested subset of the data. + """ + + def __init__(self, dataset, split='training'): + super().__init__(dataset, split=split) + log.info("Found {} pointclouds for {}".format(len(self.path_list), + split)) + + def __len__(self): + return len(self.path_list) + + def get_data(self, idx): + pc_path = self.path_list[idx] + + pc = o3d.t.io.read_point_cloud(pc_path) + + points = pc.point['positions'].numpy().astype(np.float32) + feat = pc.point['colors'].numpy().astype(np.float32) + labels = pc.point['semantic'].numpy().astype(np.int32).reshape((-1,)) + + data = { + 'point': points, + 'feat': feat, + 'label': labels, + } + + return data + + def get_attr(self, idx): + pc_path = Path(self.path_list[idx]) + name = pc_path.name.replace('.pkl', '') + + pc_path = str(pc_path) + split = self.split + attr = {'idx': idx, 'name': name, 'path': pc_path, 'split': split} + return attr + + +DATASET._register_module(KITTI360) From 43c39c7f3dcfb5eab21e18d77f6238a822c6b481 Mon Sep 17 00:00:00 2001 From: Sanskar Agrawal Date: Tue, 7 Jun 2022 05:28:22 -0700 Subject: [PATCH 42/50] fixes --- ml3d/configs/default_cfgs/nuscenes_semseg.yml | 6 +++++ ml3d/configs/default_cfgs/semantickitti.yml | 13 +++------ ml3d/configs/default_cfgs/waymo_semseg.yml | 6 +++++ ml3d/datasets/nuscenes_semseg.py | 27 ++++++------------- ml3d/datasets/semantickitti.py | 2 +- ml3d/datasets/waymo_semseg.py | 4 +-- ml3d/torch/dataloaders/concat_batcher.py | 2 +- ml3d/torch/models/sparseconvnet_megamodel.py | 18 +++++-------- ml3d/torch/modules/losses/semseg_loss.py | 4 +-- ml3d/torch/modules/metrics/semseg_metric.py | 7 +++++ .../semantic_segmentation_multi_head.py | 15 ++++++----- 11 files changed, 51 insertions(+), 53 deletions(-) create mode 100644 ml3d/configs/default_cfgs/nuscenes_semseg.yml create mode 100644 ml3d/configs/default_cfgs/waymo_semseg.yml diff --git a/ml3d/configs/default_cfgs/nuscenes_semseg.yml b/ml3d/configs/default_cfgs/nuscenes_semseg.yml new file mode 100644 index 000000000..16a667c35 --- /dev/null +++ b/ml3d/configs/default_cfgs/nuscenes_semseg.yml @@ -0,0 +1,6 @@ +name: NuScenesSemSeg +dataset_path: /export/share/projects/open3d_ml/NuScenes/processed/ +cache_dir: ./logs/cache +class_weights: [] +ignored_label_inds: [0] +use_cache: False diff --git a/ml3d/configs/default_cfgs/semantickitti.yml b/ml3d/configs/default_cfgs/semantickitti.yml index bc46ab895..733b37e01 100644 --- a/ml3d/configs/default_cfgs/semantickitti.yml +++ b/ml3d/configs/default_cfgs/semantickitti.yml @@ -1,13 +1,6 @@ name: SemanticKITTI -dataset_path: # path/to/your/dataset +dataset_path: /export/share/datasets/SemanticKITTI/ cache_dir: ./logs/cache +ignored_label_inds: [0] use_cache: false -class_weights: [55437630, 320797, 541736, 2578735, 3274484, 552662, -184064, 78858, 240942562, 17294618, 170599734, 6369672, 230413074, 101130274, -476491114, 9833174, 129609852, 4506626, 1168181] -test_result_folder: ./test -test_split: ['11', '12', '13', '14', '15', '16', '17', '18', '19', '20', '21'] -training_split: ['00', '01', '02', '03', '04', '05', '06', '07', '09', '10'] -validation_split: ['08'] -all_split: ['00', '01', '02', '03', '04', '05', '06', '07', '09', -'08', '10', '11', '12', '13', '14', '15', '16', '17', '18', '19', '20', '21'] +class_weights: [] diff --git a/ml3d/configs/default_cfgs/waymo_semseg.yml b/ml3d/configs/default_cfgs/waymo_semseg.yml new file mode 100644 index 000000000..fc7c73cd8 --- /dev/null +++ b/ml3d/configs/default_cfgs/waymo_semseg.yml @@ -0,0 +1,6 @@ +name: WaymoSemSeg +dataset_path: /export/share/datasets/Waymo_1.3/processed/ +cache_dir: ./logs/cache +class_weights: [] +ignored_label_inds: [0] +use_cache: False diff --git a/ml3d/datasets/nuscenes_semseg.py b/ml3d/datasets/nuscenes_semseg.py index 49cbf6a1a..b539b01c4 100644 --- a/ml3d/datasets/nuscenes_semseg.py +++ b/ml3d/datasets/nuscenes_semseg.py @@ -82,25 +82,14 @@ def get_label_to_names(): A dict where keys are label numbers and values are the corresponding names. """ - label_to_names = { - 0: 'ignore', - 1: 'barrier', - 2: 'bicycle', - 3: 'bus', - 4: 'car', - 5: 'construction_vehicle', - 6: 'motorcycle', - 7: 'pedestrian', - 8: 'traffic_cone', - 9: 'trailer', - 10: 'truck', - 11: 'driveable_surface', - 12: 'other_flat', - 13: 'sidewalk', - 14: 'terrain', - 15: 'manmade', - 16: 'vegetation' - } + + classes = "noise, Car, Truck, Bendy Bus, Rigid Bus, Construction Vehicle, Motorcycle, Bicycle, Bicycle Rack, Trailer, Police Vehicle, Ambulance, Adult Pedestrian, Child Pedestrian, Construction Worker, Stroller, Wheelchair, Portable Personal Mobility Vehicle, Police Officer, Animal, Traffic Cone, Temporary Traffic Barrier, Pushable Pullable Object, Debris, Drivable Surface, Sidewalk, Terrain, Flat Other, Manmade, Vegetation, Static Other, Vechicle Ego" + + classes = classes.replace(', ', ',').split(',') + label_to_names = {} + for i in range(len(classes)): + label_to_names[i] = classes[i] + return label_to_names @staticmethod diff --git a/ml3d/datasets/semantickitti.py b/ml3d/datasets/semantickitti.py index d93dddcc9..8ec98f5cc 100644 --- a/ml3d/datasets/semantickitti.py +++ b/ml3d/datasets/semantickitti.py @@ -282,7 +282,7 @@ def get_data(self, idx): data = { 'point': points[:, 0:3], - 'feat': None, + 'feat': points[:, 3:4], 'label': labels, } diff --git a/ml3d/datasets/waymo_semseg.py b/ml3d/datasets/waymo_semseg.py index b2dbd99b8..bbc31b978 100644 --- a/ml3d/datasets/waymo_semseg.py +++ b/ml3d/datasets/waymo_semseg.py @@ -107,7 +107,7 @@ def get_split(self, split): Returns: A dataset split object providing the requested subset of the data. """ - return WaymoSplit(self, split=split) + return WaymoSemSegSplit(self, split=split) def get_split_list(self, split): """Returns the list of data splits available. @@ -162,7 +162,7 @@ def save_test_result(results, attr): raise NotImplementedError() -class WaymoSplit(BaseDatasetSplit): +class WaymoSemSegSplit(BaseDatasetSplit): def __init__(self, dataset, split='train'): super().__init__(dataset, split=split) diff --git a/ml3d/torch/dataloaders/concat_batcher.py b/ml3d/torch/dataloaders/concat_batcher.py index f304d7c34..caca4eb04 100644 --- a/ml3d/torch/dataloaders/concat_batcher.py +++ b/ml3d/torch/dataloaders/concat_batcher.py @@ -476,7 +476,7 @@ def __init__(self, batches): if len(set(dataset_idx)) != 1: raise ValueError( - "Multiple datasets in a single batch is not supported." + "Multiple datasets in a single batch is not supported. " "Make sure to pass same batch size to dataset and pipeline") self.point = pc diff --git a/ml3d/torch/models/sparseconvnet_megamodel.py b/ml3d/torch/models/sparseconvnet_megamodel.py index 53b8f2ffe..7c46a2c2d 100644 --- a/ml3d/torch/models/sparseconvnet_megamodel.py +++ b/ml3d/torch/models/sparseconvnet_megamodel.py @@ -32,10 +32,10 @@ def __init__( name="SparseConvUnetMegaModel", device="cuda", num_heads=1, # number of segmentation heads. - multiplier=4, # Proportional to number of neurons in each layer. - voxel_size=0.1, + multiplier=16, # Proportional to number of neurons in each layer. + voxel_size=0.05, conv_block_reps=1, # Conv block repetitions. - residual_blocks=False, + residual_blocks=True, in_channels=3, num_classes=[20], grid_size=4096, @@ -126,7 +126,7 @@ def preprocess(self, data, attr): feat = np.array(data['feat'], dtype=np.float32) if feat.shape[1] < 3: - feat = np.concatenate([feat, np.ones([feat.shape[0], 1])], 1) + feat = np.concatenate([points, feat[:, 0:1]], 1) # Scale to voxel size. points *= 1. / self.cfg.voxel_size # Scale = 1/voxel_size @@ -215,7 +215,7 @@ def get_loss(self, Loss, results, inputs, device): Returns loss, labels and scores. """ cfg = self.cfg - labels = torch.cat(inputs['data'].label, 0).to(torch.LongTensor()) + labels = torch.cat(inputs['data'].label, 0).to(torch.LongTensor()).to(results.device) loss = Loss.weighted_CrossEntropyLoss[inputs['data'].dataset_idx]( results, labels) @@ -337,12 +337,8 @@ def forward(self, features, in_positions): count.cpu().numpy()).astype(np.int32) features_avg = in_positions.clone() - features_avg[:, 0] = reduce_subarrays_sum(features[:, 0], - v.voxel_point_row_splits) - features_avg[:, 1] = reduce_subarrays_sum(features[:, 1], - v.voxel_point_row_splits) - features_avg[:, 2] = reduce_subarrays_sum(features[:, 2], - v.voxel_point_row_splits) + for i in range(features_avg.shape[1]): + features_avg[:, i] = reduce_subarrays_sum(features[:, i], v.voxel_point_row_splits) features_avg = features_avg / count.unsqueeze(1) diff --git a/ml3d/torch/modules/losses/semseg_loss.py b/ml3d/torch/modules/losses/semseg_loss.py index 12a147def..efc6ce422 100644 --- a/ml3d/torch/modules/losses/semseg_loss.py +++ b/ml3d/torch/modules/losses/semseg_loss.py @@ -57,7 +57,7 @@ def __init__(self, pipeline, model, dataset, device): class SemSegLossV2(object): """Loss functions for semantic segmentation.""" - def __init__(self, num_heads, num_classes, ignored_labels=[], weights=None): + def __init__(self, num_heads, num_classes, ignored_labels=[], device='cpu', weights=None): super(SemSegLossV2, self).__init__() # weighted_CrossEntropyLoss self.weighted_CrossEntropyLoss = [] @@ -65,6 +65,6 @@ def __init__(self, num_heads, num_classes, ignored_labels=[], weights=None): for i in range(num_heads): weights = torch.ones(num_classes[i]) weights[ignored_labels[i]] = 0 - weights = torch.tensor(weights, dtype=torch.float) + weights = torch.tensor(weights, dtype=torch.float).to(device) self.weighted_CrossEntropyLoss.append( nn.CrossEntropyLoss(weight=weights)) diff --git a/ml3d/torch/modules/metrics/semseg_metric.py b/ml3d/torch/modules/metrics/semseg_metric.py index 20f578c22..fa5f99d02 100644 --- a/ml3d/torch/modules/metrics/semseg_metric.py +++ b/ml3d/torch/modules/metrics/semseg_metric.py @@ -13,6 +13,7 @@ def __init__(self): super(SemSegMetric, self).__init__() self.confusion_matrix = None self.num_classes = None + self.count = 0 def update(self, scores, labels): conf = self.get_confusion_matrix(scores, labels) @@ -22,6 +23,7 @@ def update(self, scores, labels): else: assert self.confusion_matrix.shape == conf.shape self.confusion_matrix += conf + self.count += 1 def __iadd__(self, otherMetric): if self.confusion_matrix is None and otherMetric.confusion_matrix is None: @@ -32,6 +34,7 @@ def __iadd__(self, otherMetric): pass else: self.confusion_matrix += otherMetric.confusion_matrix + self.count += len(otherMetric) return self def acc(self): @@ -101,6 +104,10 @@ def iou(self): def reset(self): self.confusion_matrix = None + self.count = 0 + + def __len__(self): + return self.count @staticmethod def get_confusion_matrix(scores, labels): diff --git a/ml3d/torch/pipelines/semantic_segmentation_multi_head.py b/ml3d/torch/pipelines/semantic_segmentation_multi_head.py index 60eaf8bf6..7dd8fbf01 100644 --- a/ml3d/torch/pipelines/semantic_segmentation_multi_head.py +++ b/ml3d/torch/pipelines/semantic_segmentation_multi_head.py @@ -330,7 +330,7 @@ def run_train(self): log.addHandler(logging.FileHandler(log_file_path)) Loss = SemSegLossV2(model.num_heads, model.cfg.num_classes, - model.cfg.ignored_label_inds) + model.cfg.ignored_label_inds, device=device) self.metric_train = [SemSegMetric() for i in range(num_heads)] self.metric_val = [SemSegMetric() for i in range(num_heads)] @@ -492,7 +492,7 @@ def run_train(self): tqdm(valid_loader, desc='validation')): if hasattr(inputs['data'], 'to'): inputs['data'].to(device) - + dataset_idx = inputs['data'].dataset_idx results = model(inputs['data']) loss, gt_labels, predict_scores = model.get_loss( Loss, results, inputs, device) @@ -500,7 +500,7 @@ def run_train(self): if predict_scores.size()[-1] == 0: continue - self.metric_val.update(predict_scores, gt_labels) + self.metric_val[dataset_idx].update(predict_scores, gt_labels) self.valid_losses[dataset_idx].append(loss.cpu().item()) # Save only for the first batch @@ -677,13 +677,14 @@ def get_reference_or(data_tensor): def save_logs(self, writer, epoch, dataset_idx): """Save logs from the training and send results to TensorBoard.""" + if len(self.metric_train[dataset_idx]) == 0 or len(self.metric_val[dataset_idx]) == 0: + return + train_accs = self.metric_train[dataset_idx].acc() - # val_accs = self.metric_val[dataset_idx].acc() - val_accs = self.metric_train[dataset_idx].acc() + val_accs = self.metric_val[dataset_idx].acc() train_ious = self.metric_train[dataset_idx].iou() - # val_ious = self.metric_val[dataset_idx].iou() - val_ious = self.metric_train[dataset_idx].iou() + val_ious = self.metric_val[dataset_idx].iou() loss_dict = { 'Training loss': np.mean(self.losses[dataset_idx]), From 1803d77417ac6d96886088a1ec90a6109de8257c Mon Sep 17 00:00:00 2001 From: Sanskar Agrawal Date: Tue, 7 Jun 2022 05:29:04 -0700 Subject: [PATCH 43/50] apply style --- ml3d/torch/models/sparseconvnet_megamodel.py | 6 ++++-- ml3d/torch/modules/losses/semseg_loss.py | 7 ++++++- .../pipelines/semantic_segmentation_multi_head.py | 12 ++++++++---- 3 files changed, 18 insertions(+), 7 deletions(-) diff --git a/ml3d/torch/models/sparseconvnet_megamodel.py b/ml3d/torch/models/sparseconvnet_megamodel.py index 7c46a2c2d..fe572bfb3 100644 --- a/ml3d/torch/models/sparseconvnet_megamodel.py +++ b/ml3d/torch/models/sparseconvnet_megamodel.py @@ -215,7 +215,8 @@ def get_loss(self, Loss, results, inputs, device): Returns loss, labels and scores. """ cfg = self.cfg - labels = torch.cat(inputs['data'].label, 0).to(torch.LongTensor()).to(results.device) + labels = torch.cat(inputs['data'].label, + 0).to(torch.LongTensor()).to(results.device) loss = Loss.weighted_CrossEntropyLoss[inputs['data'].dataset_idx]( results, labels) @@ -338,7 +339,8 @@ def forward(self, features, in_positions): features_avg = in_positions.clone() for i in range(features_avg.shape[1]): - features_avg[:, i] = reduce_subarrays_sum(features[:, i], v.voxel_point_row_splits) + features_avg[:, i] = reduce_subarrays_sum(features[:, i], + v.voxel_point_row_splits) features_avg = features_avg / count.unsqueeze(1) diff --git a/ml3d/torch/modules/losses/semseg_loss.py b/ml3d/torch/modules/losses/semseg_loss.py index efc6ce422..b1f0a4fb9 100644 --- a/ml3d/torch/modules/losses/semseg_loss.py +++ b/ml3d/torch/modules/losses/semseg_loss.py @@ -57,7 +57,12 @@ def __init__(self, pipeline, model, dataset, device): class SemSegLossV2(object): """Loss functions for semantic segmentation.""" - def __init__(self, num_heads, num_classes, ignored_labels=[], device='cpu', weights=None): + def __init__(self, + num_heads, + num_classes, + ignored_labels=[], + device='cpu', + weights=None): super(SemSegLossV2, self).__init__() # weighted_CrossEntropyLoss self.weighted_CrossEntropyLoss = [] diff --git a/ml3d/torch/pipelines/semantic_segmentation_multi_head.py b/ml3d/torch/pipelines/semantic_segmentation_multi_head.py index 7dd8fbf01..7cea84da8 100644 --- a/ml3d/torch/pipelines/semantic_segmentation_multi_head.py +++ b/ml3d/torch/pipelines/semantic_segmentation_multi_head.py @@ -329,8 +329,10 @@ def run_train(self): log.info("Logging in file : {}".format(log_file_path)) log.addHandler(logging.FileHandler(log_file_path)) - Loss = SemSegLossV2(model.num_heads, model.cfg.num_classes, - model.cfg.ignored_label_inds, device=device) + Loss = SemSegLossV2(model.num_heads, + model.cfg.num_classes, + model.cfg.ignored_label_inds, + device=device) self.metric_train = [SemSegMetric() for i in range(num_heads)] self.metric_val = [SemSegMetric() for i in range(num_heads)] @@ -500,7 +502,8 @@ def run_train(self): if predict_scores.size()[-1] == 0: continue - self.metric_val[dataset_idx].update(predict_scores, gt_labels) + self.metric_val[dataset_idx].update(predict_scores, + gt_labels) self.valid_losses[dataset_idx].append(loss.cpu().item()) # Save only for the first batch @@ -677,7 +680,8 @@ def get_reference_or(data_tensor): def save_logs(self, writer, epoch, dataset_idx): """Save logs from the training and send results to TensorBoard.""" - if len(self.metric_train[dataset_idx]) == 0 or len(self.metric_val[dataset_idx]) == 0: + if len(self.metric_train[dataset_idx]) == 0 or len( + self.metric_val[dataset_idx]) == 0: return train_accs = self.metric_train[dataset_idx].acc() From c8f4589d8a0e78b1b260eafd6b032212a9b36d21 Mon Sep 17 00:00:00 2001 From: Sanskar Agrawal Date: Thu, 9 Jun 2022 08:47:04 -0700 Subject: [PATCH 44/50] fix logging --- ml3d/torch/models/sparseconvnet_megamodel.py | 31 ++++++++++++++++--- ml3d/torch/modules/losses/semseg_loss.py | 2 +- ml3d/torch/modules/metrics/semseg_metric.py | 1 + .../semantic_segmentation_multi_head.py | 31 ++++++++++++------- 4 files changed, 47 insertions(+), 18 deletions(-) diff --git a/ml3d/torch/models/sparseconvnet_megamodel.py b/ml3d/torch/models/sparseconvnet_megamodel.py index fe572bfb3..c8dd6e1dd 100644 --- a/ml3d/torch/models/sparseconvnet_megamodel.py +++ b/ml3d/torch/models/sparseconvnet_megamodel.py @@ -34,6 +34,7 @@ def __init__( num_heads=1, # number of segmentation heads. multiplier=16, # Proportional to number of neurons in each layer. voxel_size=0.05, + varying_input_layers=False, conv_block_reps=1, # Conv block repetitions. residual_blocks=True, in_channels=3, @@ -49,6 +50,7 @@ def __init__( num_heads=num_heads, multiplier=multiplier, voxel_size=voxel_size, + varying_input_layers=varying_input_layers, conv_block_reps=conv_block_reps, residual_blocks=residual_blocks, in_channels=in_channels, @@ -63,10 +65,21 @@ def __init__( self.augmenter = SemsegAugmentation(cfg.augment, seed=self.rng) self.multiplier = cfg.multiplier self.num_heads = num_heads + self.varying_input_layers = varying_input_layers self.input_layer = InputLayer() - self.sub_sparse_conv = SubmanifoldSparseConv(in_channels=in_channels, - filters=multiplier, - kernel_size=[3, 3, 3]) + if self.varying_input_layers: + self.sub_sparse_conv = [ + SubmanifoldSparseConv(in_channels=in_channels, + filters=multiplier, + kernel_size=[3, 3, 3]) + for i in range(num_heads) + ] + self.sub_sparse_conv = nn.ModuleList(self.sub_sparse_conv) + else: + self.sub_sparse_conv = SubmanifoldSparseConv( + in_channels=in_channels, + filters=multiplier, + kernel_size=[3, 3, 3]) self.unet = UNet(conv_block_reps, [ multiplier, 2 * multiplier, 3 * multiplier, 4 * multiplier, 5 * multiplier, 6 * multiplier, 7 * multiplier @@ -93,7 +106,14 @@ def forward(self, inputs): feat_list.append(feat) index_map_list.append(index_map) - feat_list = self.sub_sparse_conv(feat_list, pos_list, voxel_size=1.0) + if self.varying_input_layers: + feat_list = self.sub_sparse_conv[inputs.dataset_idx](feat_list, + pos_list, + voxel_size=1.0) + else: + feat_list = self.sub_sparse_conv(feat_list, + pos_list, + voxel_size=1.0) feat_list = self.unet(pos_list, feat_list) feat_list = self.batch_norm(feat_list) feat_list = self.relu(feat_list) @@ -337,7 +357,8 @@ def forward(self, features, in_positions): reverse_map_sort = np.repeat(np.arange(count.shape[0]), count.cpu().numpy()).astype(np.int32) - features_avg = in_positions.clone() + features_avg = torch.zeros(in_positions.shape[0], + features.shape[1]).to(features.device) for i in range(features_avg.shape[1]): features_avg[:, i] = reduce_subarrays_sum(features[:, i], v.voxel_point_row_splits) diff --git a/ml3d/torch/modules/losses/semseg_loss.py b/ml3d/torch/modules/losses/semseg_loss.py index b1f0a4fb9..299580d17 100644 --- a/ml3d/torch/modules/losses/semseg_loss.py +++ b/ml3d/torch/modules/losses/semseg_loss.py @@ -70,6 +70,6 @@ def __init__(self, for i in range(num_heads): weights = torch.ones(num_classes[i]) weights[ignored_labels[i]] = 0 - weights = torch.tensor(weights, dtype=torch.float).to(device) + weights = weights.to(torch.float).to(device) self.weighted_CrossEntropyLoss.append( nn.CrossEntropyLoss(weight=weights)) diff --git a/ml3d/torch/modules/metrics/semseg_metric.py b/ml3d/torch/modules/metrics/semseg_metric.py index fa5f99d02..472ab8a4b 100644 --- a/ml3d/torch/modules/metrics/semseg_metric.py +++ b/ml3d/torch/modules/metrics/semseg_metric.py @@ -30,6 +30,7 @@ def __iadd__(self, otherMetric): pass elif self.confusion_matrix is None: self.confusion_matrix = otherMetric.confusion_matrix + self.num_classes = otherMetric.num_classes elif otherMetric.confusion_matrix is None: pass else: diff --git a/ml3d/torch/pipelines/semantic_segmentation_multi_head.py b/ml3d/torch/pipelines/semantic_segmentation_multi_head.py index 7cea84da8..799267c20 100644 --- a/ml3d/torch/pipelines/semantic_segmentation_multi_head.py +++ b/ml3d/torch/pipelines/semantic_segmentation_multi_head.py @@ -365,9 +365,9 @@ def run_train(self): raise NotImplementedError( "Distributed training with sampler is not supported yet!") train_sampler = torch.utils.data.distributed.DistributedSampler( - train_split) + train_split, shuffle=False) valid_sampler = torch.utils.data.distributed.DistributedSampler( - valid_split) + valid_split, shuffle=False) train_loader = DataLoader( train_split, @@ -417,7 +417,7 @@ def run_train(self): # wrap model for multiple GPU if self.distributed: model = torch.nn.parallel.DistributedDataParallel( - model, device_ids=[self.device]) + model, device_ids=[self.device], find_unused_parameters=True) model.get_loss = model.module.get_loss model.cfg = model.module.cfg @@ -512,14 +512,19 @@ def run_train(self): results, inputs['data'], epoch) if self.distributed: - metric_gather = [None for _ in range(dist.get_world_size())] - dist.gather_object(self.metric_val, - metric_gather if rank == 0 else None, + gather_arr = [None for _ in range(dist.get_world_size())] + dist.gather_object((self.metric_train, self.metric_val, + self.losses, self.valid_losses), + gather_arr if rank == 0 else None, dst=0) - if rank == 0: - for m in metric_gather[1:]: - self.metric_val += m + if rank == 0: + for m1, m2, l1, l2 in gather_arr[1:]: + for i in range(num_heads): + self.metric_train[i] += m1[i] + self.metric_val[i] += m2[i] + self.losses[i] += l1[i] + self.valid_losses[i] += l2[i] dist.barrier() if rank == 0: @@ -705,11 +710,13 @@ def save_logs(self, writer, epoch, dataset_idx): } for iou, val_iou in zip(train_ious, val_ious)] for key, val in loss_dict.items(): - writer.add_scalar(key, val, epoch) + writer.add_scalar(str(dataset_idx) + " : " + key, val, epoch) for key, val in acc_dicts[-1].items(): - writer.add_scalar("{}/ Overall".format(key), val, epoch) + writer.add_scalar("{} : {}/ Overall".format(dataset_idx, key), val, + epoch) for key, val in iou_dicts[-1].items(): - writer.add_scalar("{}/ Overall".format(key), val, epoch) + writer.add_scalar("{} : {}/ Overall".format(dataset_idx, key), val, + epoch) log.info(f"Dataset Index : {dataset_idx}") log.info(f"Loss train: {loss_dict['Training loss']:.3f} " From 5a4371ad34882d4ec4dd11ce487dc65beaad8144 Mon Sep 17 00:00:00 2001 From: Sanskar Agrawal Date: Tue, 20 Sep 2022 03:37:24 -0700 Subject: [PATCH 45/50] add distributed semseg --- ml3d/configs/default_cfgs/nuscenes_semseg.yml | 2 +- ml3d/configs/default_cfgs/semantickitti.yml | 2 +- ml3d/configs/default_cfgs/waymo_semseg.yml | 2 +- ml3d/datasets/megaloader.py | 42 +- ml3d/datasets/nuscenes_semseg.py | 49 +- ml3d/torch/dataloaders/concat_batcher.py | 3 +- ml3d/torch/models/sparseconvnet_megamodel.py | 42 +- ml3d/torch/modules/losses/semseg_loss.py | 15 +- ml3d/torch/pipelines/semantic_segmentation.py | 4 +- .../semantic_segmentation_multi_head.py | 801 ------------------ 10 files changed, 104 insertions(+), 858 deletions(-) delete mode 100644 ml3d/torch/pipelines/semantic_segmentation_multi_head.py diff --git a/ml3d/configs/default_cfgs/nuscenes_semseg.yml b/ml3d/configs/default_cfgs/nuscenes_semseg.yml index 16a667c35..056af4f05 100644 --- a/ml3d/configs/default_cfgs/nuscenes_semseg.yml +++ b/ml3d/configs/default_cfgs/nuscenes_semseg.yml @@ -1,6 +1,6 @@ name: NuScenesSemSeg dataset_path: /export/share/projects/open3d_ml/NuScenes/processed/ cache_dir: ./logs/cache -class_weights: [] +class_weights: [282265, 7676, 120, 3754, 31974, 1321, 346, 1898, 624, 4537, 13282, 260911, 6588, 57567, 56670, 146511, 100633] ignored_label_inds: [0] use_cache: False diff --git a/ml3d/configs/default_cfgs/semantickitti.yml b/ml3d/configs/default_cfgs/semantickitti.yml index 733b37e01..0f58292b8 100644 --- a/ml3d/configs/default_cfgs/semantickitti.yml +++ b/ml3d/configs/default_cfgs/semantickitti.yml @@ -3,4 +3,4 @@ dataset_path: /export/share/datasets/SemanticKITTI/ cache_dir: ./logs/cache ignored_label_inds: [0] use_cache: false -class_weights: [] +class_weights: [101665, 157022, 631, 1516, 5012, 7085, 1043, 457, 176, 693044, 53132, 494988, 12829, 459669, 236069, 924425, 22780, 255213, 9664, 2024] \ No newline at end of file diff --git a/ml3d/configs/default_cfgs/waymo_semseg.yml b/ml3d/configs/default_cfgs/waymo_semseg.yml index fc7c73cd8..c103aa3ce 100644 --- a/ml3d/configs/default_cfgs/waymo_semseg.yml +++ b/ml3d/configs/default_cfgs/waymo_semseg.yml @@ -1,6 +1,6 @@ name: WaymoSemSeg dataset_path: /export/share/datasets/Waymo_1.3/processed/ cache_dir: ./logs/cache -class_weights: [] +class_weights: [513255, 341079, 39946, 28066, 17254, 104, 1169, 31335, 23359, 2924, 43680, 2149, 483, 639, 1394353, 858409, 90903, 52484, 884591, 24487, 21477, 322212, 229034] ignored_label_inds: [0] use_cache: False diff --git a/ml3d/datasets/megaloader.py b/ml3d/datasets/megaloader.py index 8ab8c2fc4..7324ec94c 100644 --- a/ml3d/datasets/megaloader.py +++ b/ml3d/datasets/megaloader.py @@ -55,12 +55,15 @@ def __init__(self, self.ignored_labels = np.array([]) self.num_datasets = len(config_paths) + self.configs = [ Config.load_from_file(cfg_path) for cfg_path in config_paths ] self.datasets = [ get_module('dataset', cfg.name)(**cfg) for cfg in self.configs ] + self.class_weights = [self.datasets[i].cfg.class_weights for i in range(self.num_datasets)] + def get_split(self, split): """Returns a dataset split. @@ -97,7 +100,7 @@ def is_tested(self, attr): else: return False - def save_test_result(self, results, attr): + def save_test_result(self, results, name): """Saves the output of a model. Args: @@ -105,16 +108,12 @@ def save_test_result(self, results, attr): attr: The attributes that correspond to the outputs passed in results. """ cfg = self.cfg - name = attr['name'].split('.')[0] path = cfg.test_result_folder make_dir(path) pred = results['predict_labels'] pred = np.array(pred) - for ign in cfg.ignored_label_inds: - pred[pred >= ign] += 1 - store_path = join(path, self.name, name + '.npy') make_dir(Path(store_path).parent) np.save(store_path, pred) @@ -136,7 +135,7 @@ class MegaLoaderSplit(): A dataset split object providing the requested subset of the data. """ - def __init__(self, dataset, split='training'): + def __init__(self, dataset, split='training', test_dataset_idx=0): self.cfg = dataset.cfg self.split = split self.dataset = dataset @@ -144,35 +143,36 @@ def __init__(self, dataset, split='training'): a.get_split(split) for a in self.dataset.datasets ] self.num_datasets = dataset.num_datasets + self.test_dataset_idx = test_dataset_idx - # log.info("Found {} pointclouds for {}".format(len(self.path_list), - # split)) + if 'test' in split: + sampler_cls = get_module('sampler', 'SemSegSpatiallyRegularSampler') + self.sampler = sampler_cls(self) def __len__(self): + if 'test' in self.split: + return len(self.dataset_splits[self.test_dataset_idx]) lens = [len(a) for a in self.dataset_splits] return max(lens) * self.num_datasets def get_data(self, idx): - dataset_idx = (idx // self.dataset.batch_size) % self.num_datasets - idx = ((((idx // self.dataset.batch_size) // self.num_datasets) * - self.dataset.batch_size) + idx % self.dataset.batch_size) % len( - self.dataset_splits[dataset_idx]) - - # dataset_idx = idx % self.num_datasets - # idx = (idx // self.num_datasets) % len(self.dataset_splits[dataset_idx]) + if 'test' in self.split: + dataset_idx = self.test_dataset_idx + else: + dataset_idx = idx % self.num_datasets + idx = (idx // self.num_datasets) % len(self.dataset_splits[dataset_idx]) data = self.dataset_splits[dataset_idx].get_data(idx) return data def get_attr(self, idx): - dataset_idx = (idx // self.dataset.batch_size) % self.num_datasets - idx = ((((idx // self.dataset.batch_size) // self.num_datasets) * - self.dataset.batch_size) + idx % self.dataset.batch_size) % len( - self.dataset_splits[dataset_idx]) + if 'test' in self.split: + dataset_idx = self.test_dataset_idx + else: + dataset_idx = idx % self.num_datasets + idx = (idx // self.num_datasets) % len(self.dataset_splits[dataset_idx]) - # dataset_idx = idx % self.num_datasets - # idx = (idx // self.num_datasets) % len(self.dataset_splits[dataset_idx]) attr = self.dataset_splits[dataset_idx].get_attr(idx) attr['dataset_idx'] = dataset_idx diff --git a/ml3d/datasets/nuscenes_semseg.py b/ml3d/datasets/nuscenes_semseg.py index b539b01c4..e7bf4d246 100644 --- a/ml3d/datasets/nuscenes_semseg.py +++ b/ml3d/datasets/nuscenes_semseg.py @@ -74,6 +74,44 @@ def __init__(self, self.test_info = pickle.load( open(join(info_path, 'infos_test.pkl'), 'rb')) + # It comes with 32 classes, but the nuscenes challenge merge similar classes and remove rare classes. + mapping = { + 1: 0, + 5: 0, + 7: 0, + 8: 0, + 10: 0, + 11: 0, + 13: 0, + 19: 0, + 20: 0, + 0: 0, + 29: 0, + 31: 0, + 9: 1, + 14: 2, + 15: 3, + 16: 3, + 17: 4, + 18: 5, + 21: 6, + 2: 7, + 3: 7, + 4: 7, + 6: 7, + 12: 8, + 22: 9, + 23: 10, + 24: 11, + 25: 12, + 26: 13, + 27: 14, + 28: 15, + 30: 16 + } + self.label_mapping = np.array([mapping[i] for i in range(0, len(mapping))], dtype=np.int32) + + @staticmethod def get_label_to_names(): """Returns a label to names dictionary object. @@ -83,8 +121,7 @@ def get_label_to_names(): values are the corresponding names. """ - classes = "noise, Car, Truck, Bendy Bus, Rigid Bus, Construction Vehicle, Motorcycle, Bicycle, Bicycle Rack, Trailer, Police Vehicle, Ambulance, Adult Pedestrian, Child Pedestrian, Construction Worker, Stroller, Wheelchair, Portable Personal Mobility Vehicle, Police Officer, Animal, Traffic Cone, Temporary Traffic Barrier, Pushable Pullable Object, Debris, Drivable Surface, Sidewalk, Terrain, Flat Other, Manmade, Vegetation, Static Other, Vechicle Ego" - + classes = "ignore, barrier, bicycle, bus, car, construction_vehicle, motorcycle, pedestrian, traffic_cone, trailer, trucl, driveable_surface, other_flat, sidewalk, terrain, manmade, vegetation" classes = classes.replace(', ', ',').split(',') label_to_names = {} for i in range(len(classes)): @@ -103,8 +140,7 @@ def read_lidar(path): return np.fromfile(path, dtype=np.float32).reshape(-1, 5) - @staticmethod - def read_lidarseg(path): + def read_lidarseg(self, path): """Reads semantic data from the path provided. Returns: @@ -112,7 +148,10 @@ def read_lidarseg(path): """ assert Path(path).exists() - return np.fromfile(path, dtype=np.uint8).reshape(-1,).astype(np.int32) + labels = np.fromfile(path, dtype=np.uint8).reshape(-1,).astype(np.int32) + + return self.label_mapping[labels] + def get_split(self, split): """Returns a dataset split. diff --git a/ml3d/torch/dataloaders/concat_batcher.py b/ml3d/torch/dataloaders/concat_batcher.py index caca4eb04..a3b6d472b 100644 --- a/ml3d/torch/dataloaders/concat_batcher.py +++ b/ml3d/torch/dataloaders/concat_batcher.py @@ -470,8 +470,7 @@ def __init__(self, batches): attr = batch['attr'] if 'dataset_idx' not in attr: - raise ValueError( - "dataset_idx is missing. Please use MegaLoader.") + attr['dataset_idx'] = -1 dataset_idx.append(attr['dataset_idx']) if len(set(dataset_idx)) != 1: diff --git a/ml3d/torch/models/sparseconvnet_megamodel.py b/ml3d/torch/models/sparseconvnet_megamodel.py index c8dd6e1dd..f9b20e778 100644 --- a/ml3d/torch/models/sparseconvnet_megamodel.py +++ b/ml3d/torch/models/sparseconvnet_megamodel.py @@ -1,6 +1,7 @@ import numpy as np import torch import torch.nn as nn +import logging from .base_model import BaseModel from ...utils import MODEL @@ -9,6 +10,7 @@ from open3d.ml.torch.layers import SparseConv, SparseConvTranspose from open3d.ml.torch.ops import voxelize, reduce_subarrays_sum +log = logging.getLogger(__name__) class SparseConvUnetMegaModel(BaseModel): """Semantic Segmentation model. @@ -39,7 +41,7 @@ def __init__( residual_blocks=True, in_channels=3, num_classes=[20], - grid_size=4096, + grid_size=81920, batcher='ConcatBatcher', augment=None, ckpt_path=None, @@ -67,6 +69,7 @@ def __init__( self.num_heads = num_heads self.varying_input_layers = varying_input_layers self.input_layer = InputLayer() + if self.varying_input_layers: self.sub_sparse_conv = [ SubmanifoldSparseConv(in_channels=in_channels, @@ -145,11 +148,9 @@ def preprocess(self, data, attr): "SparseConvnet doesn't work without feature values.") feat = np.array(data['feat'], dtype=np.float32) - if feat.shape[1] < 3: - feat = np.concatenate([points, feat[:, 0:1]], 1) - # Scale to voxel size. - points *= 1. / self.cfg.voxel_size # Scale = 1/voxel_size + # only use xyz + feat = points.copy() if attr['split'] in ['training', 'train']: points, feat, labels = self.augmenter.augment(points, @@ -158,21 +159,20 @@ def preprocess(self, data, attr): self.cfg.get( 'augment', None), seed=rng) + + # Scale to voxel size. + points *= 1. / self.cfg.voxel_size # Scale = 1/voxel_size + m = points.min(0) M = points.max(0) # Randomly place pointcloud in 4096 size grid. grid_size = self.cfg.grid_size - offset = -m + np.clip(grid_size - M + m - 0.001, 0, None) * rng.random( - 3) + np.clip(grid_size - M + m + 0.001, None, 0) * rng.random(3) - - points += offset - idxs = (points.min(1) >= 0) * (points.max(1) < 4096) - points = points[idxs] - feat = feat[idxs] - labels = labels[idxs] + # make everything positive + offset = -1 * points.min(0) + 500 + points += offset points = (points.astype(np.int32) + 0.5).astype( np.float32) # Move points to voxel center. @@ -244,11 +244,6 @@ def get_loss(self, Loss, results, inputs, device): return loss, labels, results def get_optimizer(self, cfg_pipeline): - # optimizer = torch.optim.Adam(self.parameters(), - # **cfg_pipeline.optimizer) - # scheduler = torch.optim.lr_scheduler.ExponentialLR( - # optimizer, cfg_pipeline.scheduler_gamma) - optimizer = torch.optim.Adam(self.parameters(), lr=0.001) scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, 0.99) @@ -317,7 +312,14 @@ def __init__(self, in_dim, num_classes): def forward(self, feat_list, dataset_idx): out_list = [] for i, feat in enumerate(feat_list): - out_list.append(self.linear[dataset_idx](feat)) + if dataset_idx == -1: + out = [] + for j in range(len(self.num_classes)): + out.append(self.linear[j](feat)) + out = torch.cat(out, -1) + out_list.append(out) + else: + out_list.append(self.linear[dataset_idx](feat)) return out_list @@ -337,7 +339,7 @@ def forward(self, features, in_positions): torch.LongTensor([0, in_positions.shape[0]]).to(in_positions.device), self.voxel_size, torch.Tensor([0, 0, 0]), - torch.Tensor([40960, 40960, 40960])) + torch.Tensor([81920, 81920, 81920])) # Contiguous repeating positions. in_positions = in_positions[v.voxel_point_indices] diff --git a/ml3d/torch/modules/losses/semseg_loss.py b/ml3d/torch/modules/losses/semseg_loss.py index 299580d17..5d0e8c1a2 100644 --- a/ml3d/torch/modules/losses/semseg_loss.py +++ b/ml3d/torch/modules/losses/semseg_loss.py @@ -55,7 +55,7 @@ def __init__(self, pipeline, model, dataset, device): class SemSegLossV2(object): - """Loss functions for semantic segmentation.""" + """Loss functions for multi head semantic segmentation.""" def __init__(self, num_heads, @@ -68,8 +68,13 @@ def __init__(self, self.weighted_CrossEntropyLoss = [] for i in range(num_heads): - weights = torch.ones(num_classes[i]) - weights[ignored_labels[i]] = 0 - weights = weights.to(torch.float).to(device) + if weights is not None and len(weights[i]) != 0: + wts = DataProcessing.get_class_weights(weights[i])[0] + assert len(wts) == num_classes[i], f"num_classes : {num_classes[i]} is not equal to number of class weights : {len(wts)}" + wts = torch.tensor(wts) + else: + wts = torch.ones(num_classes[i]) + wts[ignored_labels[i]] = 0 + wts = wts.to(torch.float).to(device) self.weighted_CrossEntropyLoss.append( - nn.CrossEntropyLoss(weight=weights)) + nn.CrossEntropyLoss(weight=wts)) diff --git a/ml3d/torch/pipelines/semantic_segmentation.py b/ml3d/torch/pipelines/semantic_segmentation.py index cb3b6f3eb..355c34bad 100644 --- a/ml3d/torch/pipelines/semantic_segmentation.py +++ b/ml3d/torch/pipelines/semantic_segmentation.py @@ -160,6 +160,8 @@ def run_inference(self, data): with torch.no_grad(): for unused_step, inputs in enumerate(infer_loader): + if hasattr(inputs['data'], 'to'): + inputs['data'].to(device) results = model(inputs['data']) self.update_tests(infer_sampler, inputs, results) @@ -389,7 +391,7 @@ def run_train(self): torch.utils.data.get_worker_info().seed))) # Optimizer must be created after moving model to specific device. - model.cuda(self.device) + model.to(self.device) model.device = self.device self.optimizer, self.scheduler = model.get_optimizer(cfg) diff --git a/ml3d/torch/pipelines/semantic_segmentation_multi_head.py b/ml3d/torch/pipelines/semantic_segmentation_multi_head.py deleted file mode 100644 index 799267c20..000000000 --- a/ml3d/torch/pipelines/semantic_segmentation_multi_head.py +++ /dev/null @@ -1,801 +0,0 @@ -import logging -import numpy as np -import torch -import torch.distributed as dist - -from os.path import exists, join -from pathlib import Path -from datetime import datetime -from tqdm import tqdm -from torch.utils.tensorboard import SummaryWriter -from torch.utils.data import DataLoader - -# pylint: disable-next=unused-import -from open3d.visualization.tensorboard_plugin import summary -from .base_pipeline import BasePipeline -from ..dataloaders import get_sampler, TorchDataloader, DefaultBatcher, ConcatBatcher -from ..utils import latest_torch_ckpt -from ..modules.losses import SemSegLossV2, filter_valid_label -from ..modules.metrics import SemSegMetric -from ...utils import make_dir, PIPELINE, get_runid, code2md -from ...datasets import InferenceDummySplit - -log = logging.getLogger(__name__) - - -class SemanticSegmentationMultiHead(BasePipeline): - """This class allows you to perform semantic segmentation for both training - and inference using the Torch. This pipeline has multiple stages: Pre- - processing, loading dataset, testing, and inference or training. - - **Example:** - This example loads the Semantic Segmentation and performs a training - using the SemanticKITTI dataset. - - import torch - import torch.nn as nn - - from .base_pipeline import BasePipeline - from torch.utils.tensorboard import SummaryWriter - from ..dataloaders import get_sampler, TorchDataloader, DefaultBatcher, ConcatBatcher - - Mydataset = TorchDataloader(dataset=dataset.get_split('training')), - MyModel = SemanticSegmentation(self,model,dataset=Mydataset, name='SemanticSegmentation', - name='MySemanticSegmentation', - batch_size=4, - val_batch_size=4, - test_batch_size=3, - max_epoch=100, - learning_rate=1e-2, - lr_decays=0.95, - save_ckpt_freq=20, - adam_lr=1e-2, - scheduler_gamma=0.95, - momentum=0.98, - main_log_dir='./logs/', - device='gpu', - split='train', - train_sum_dir='train_log') - - **Args:** - dataset: The 3D ML dataset class. You can use the base dataset, sample datasets , or a custom dataset. - model: The model to be used for building the pipeline. - name: The name of the current training. - batch_size: The batch size to be used for training. - val_batch_size: The batch size to be used for validation. - test_batch_size: The batch size to be used for testing. - max_epoch: The maximum size of the epoch to be used for training. - leanring_rate: The hyperparameter that controls the weights during training. Also, known as step size. - lr_decays: The learning rate decay for the training. - save_ckpt_freq: The frequency in which the checkpoint should be saved. - adam_lr: The leanring rate to be applied for Adam optimization. - scheduler_gamma: The decaying factor associated with the scheduler. - momentum: The momentum that accelerates the training rate schedule. - main_log_dir: The directory where logs are stored. - device: The device to be used for training. - split: The dataset split to be used. In this example, we have used "train". - train_sum_dir: The directory where the trainig summary is stored. - - **Returns:** - class: The corresponding class. - """ - - def __init__( - self, - model, - dataset=None, - name='SemanticSegmentation', - batch_size=4, - val_batch_size=4, - test_batch_size=3, - max_epoch=100, # maximum epoch during training - learning_rate=1e-2, # initial learning rate - lr_decays=0.95, - save_ckpt_freq=20, - adam_lr=1e-2, - scheduler_gamma=0.95, - momentum=0.98, - main_log_dir='./logs/', - device='cuda', - split='train', - train_sum_dir='train_log', - **kwargs): - - super().__init__(model=model, - dataset=dataset, - name=name, - batch_size=batch_size, - val_batch_size=val_batch_size, - test_batch_size=test_batch_size, - max_epoch=max_epoch, - learning_rate=learning_rate, - lr_decays=lr_decays, - save_ckpt_freq=save_ckpt_freq, - adam_lr=adam_lr, - scheduler_gamma=scheduler_gamma, - momentum=momentum, - main_log_dir=main_log_dir, - device=device, - split=split, - train_sum_dir=train_sum_dir, - **kwargs) - - def run_inference(self, data): - """Run inference on given data. - - Args: - data: A raw data. - - Returns: - Returns the inference results. - """ - cfg = self.cfg - model = self.model - device = self.device - - model.to(device) - model.device = device - model.eval() - - batcher = self.get_batcher(device) - infer_dataset = InferenceDummySplit(data) - self.dataset_split = infer_dataset - infer_sampler = infer_dataset.sampler - infer_split = TorchDataloader(dataset=infer_dataset, - preprocess=model.preprocess, - transform=model.transform, - sampler=infer_sampler, - use_cache=False) - infer_loader = DataLoader(infer_split, - batch_size=cfg.batch_size, - sampler=get_sampler(infer_sampler), - collate_fn=batcher.collate_fn) - - model.trans_point_sampler = infer_sampler.get_point_sampler() - self.curr_cloud_id = -1 - self.test_probs = [] - self.test_labels = [] - self.ori_test_probs = [] - self.ori_test_labels = [] - - with torch.no_grad(): - for unused_step, inputs in enumerate(infer_loader): - results = model(inputs['data']) - self.update_tests(infer_sampler, inputs, results) - - inference_result = { - 'predict_labels': self.ori_test_labels.pop(), - 'predict_scores': self.ori_test_probs.pop() - } - - metric = SemSegMetric() - - valid_scores, valid_labels = filter_valid_label( - torch.tensor(inference_result['predict_scores']), - torch.tensor(data['label']), model.cfg.num_classes, - model.cfg.ignored_label_inds, device) - - metric.update(valid_scores, valid_labels) - log.info(f"Accuracy : {metric.acc()}") - log.info(f"IoU : {metric.iou()}") - - return inference_result - - def run_test(self): - """Run the test using the data passed.""" - model = self.model - dataset = self.dataset - device = self.device - cfg = self.cfg - model.device = device - model.to(device) - model.eval() - self.metric_test = SemSegMetric() - - timestamp = datetime.now().strftime('%Y-%m-%d_%H:%M:%S') - - log.info("DEVICE : {}".format(device)) - log_file_path = join(cfg.logs_dir, 'log_test_' + timestamp + '.txt') - log.info("Logging in file : {}".format(log_file_path)) - log.addHandler(logging.FileHandler(log_file_path)) - - batcher = self.get_batcher(device) - - test_dataset = dataset.get_split('test') - test_sampler = test_dataset.sampler - test_split = TorchDataloader(dataset=test_dataset, - preprocess=model.preprocess, - transform=model.transform, - sampler=test_sampler, - use_cache=dataset.cfg.use_cache) - test_loader = DataLoader(test_split, - batch_size=cfg.test_batch_size, - sampler=get_sampler(test_sampler), - collate_fn=batcher.collate_fn) - - self.dataset_split = test_dataset - - self.load_ckpt(model.cfg.ckpt_path) - - model.trans_point_sampler = test_sampler.get_point_sampler() - self.curr_cloud_id = -1 - self.test_probs = [] - self.test_labels = [] - self.ori_test_probs = [] - self.ori_test_labels = [] - - record_summary = cfg.get('summary').get('record_for', []) - log.info("Started testing") - - with torch.no_grad(): - for unused_step, inputs in enumerate(test_loader): - if hasattr(inputs['data'], 'to'): - inputs['data'].to(device) - results = model(inputs['data']) - self.update_tests(test_sampler, inputs, results) - - if self.complete_infer: - inference_result = { - 'predict_labels': self.ori_test_labels.pop(), - 'predict_scores': self.ori_test_probs.pop() - } - attr = self.dataset_split.get_attr(test_sampler.cloud_id) - gt_labels = self.dataset_split.get_data( - test_sampler.cloud_id)['label'] - if (gt_labels > 0).any(): - valid_scores, valid_labels = filter_valid_label( - torch.tensor( - inference_result['predict_scores']).to(device), - torch.tensor(gt_labels).to(device), - model.cfg.num_classes, model.cfg.ignored_label_inds, - device) - - self.metric_test.update(valid_scores, valid_labels) - log.info(f"Accuracy : {self.metric_test.acc()}") - log.info(f"IoU : {self.metric_test.iou()}") - dataset.save_test_result(inference_result, attr) - # Save only for the first batch - if 'test' in record_summary and 'test' not in self.summary: - self.summary['test'] = self.get_3d_summary( - results, inputs['data'], 0, save_gt=False) - log.info( - f"Overall Testing Accuracy : {self.metric_test.acc()[-1]}, mIoU : {self.metric_test.iou()[-1]}" - ) - - log.info("Finished testing") - - def update_tests(self, sampler, inputs, results): - """Update tests using sampler, inputs, and results.""" - split = sampler.split - end_threshold = 0.5 - if self.curr_cloud_id != sampler.cloud_id: - self.curr_cloud_id = sampler.cloud_id - num_points = sampler.possibilities[sampler.cloud_id].shape[0] - self.pbar = tqdm(total=num_points, - desc="{} {}/{}".format(split, self.curr_cloud_id, - len(sampler.dataset))) - self.pbar_update = 0 - self.test_probs.append( - np.zeros(shape=[num_points, self.model.cfg.num_classes], - dtype=np.float16)) - self.test_labels.append(np.zeros(shape=[num_points], - dtype=np.int16)) - self.complete_infer = False - - this_possiblility = sampler.possibilities[sampler.cloud_id] - self.pbar.update( - this_possiblility[this_possiblility > end_threshold].shape[0] - - self.pbar_update) - self.pbar_update = this_possiblility[ - this_possiblility > end_threshold].shape[0] - self.test_probs[self.curr_cloud_id], self.test_labels[ - self.curr_cloud_id] = self.model.update_probs( - inputs, results, self.test_probs[self.curr_cloud_id], - self.test_labels[self.curr_cloud_id]) - - if (split in ['test'] and - this_possiblility[this_possiblility > end_threshold].shape[0] - == this_possiblility.shape[0]): - - proj_inds = self.model.preprocess( - self.dataset_split.get_data(self.curr_cloud_id), { - 'split': split - }).get('proj_inds', None) - if proj_inds is None: - proj_inds = np.arange( - self.test_probs[self.curr_cloud_id].shape[0]) - self.ori_test_probs.append( - self.test_probs[self.curr_cloud_id][proj_inds]) - self.ori_test_labels.append( - self.test_labels[self.curr_cloud_id][proj_inds]) - self.complete_infer = True - - def run_train(self): - torch.manual_seed(self.rng.integers(np.iinfo( - np.int32).max)) # Random reproducible seed for torch - rank = self.rank # Rank for distributed training - model = self.model - device = self.device - dataset = self.dataset - num_heads = model.num_heads - - cfg = self.cfg - if rank == 0: - log.info("DEVICE : {}".format(device)) - timestamp = datetime.now().strftime('%Y-%m-%d_%H:%M:%S') - - log_file_path = join(cfg.logs_dir, - 'log_train_' + timestamp + '.txt') - log.info("Logging in file : {}".format(log_file_path)) - log.addHandler(logging.FileHandler(log_file_path)) - - Loss = SemSegLossV2(model.num_heads, - model.cfg.num_classes, - model.cfg.ignored_label_inds, - device=device) - self.metric_train = [SemSegMetric() for i in range(num_heads)] - self.metric_val = [SemSegMetric() for i in range(num_heads)] - - self.batcher = self.get_batcher(device) - - train_dataset = dataset.get_split('train') - train_sampler = None - - valid_dataset = dataset.get_split('val') - valid_sampler = None - - train_split = TorchDataloader(dataset=train_dataset, - preprocess=model.preprocess, - transform=model.transform, - sampler=train_sampler, - use_cache=dataset.cfg.use_cache, - steps_per_epoch=dataset.cfg.get( - 'steps_per_epoch_train', None)) - - valid_split = TorchDataloader(dataset=valid_dataset, - preprocess=model.preprocess, - transform=model.transform, - sampler=valid_sampler, - use_cache=dataset.cfg.use_cache, - steps_per_epoch=dataset.cfg.get( - 'steps_per_epoch_valid', None)) - - if self.distributed: - if train_sampler is not None or valid_sampler is not None: - raise NotImplementedError( - "Distributed training with sampler is not supported yet!") - train_sampler = torch.utils.data.distributed.DistributedSampler( - train_split, shuffle=False) - valid_sampler = torch.utils.data.distributed.DistributedSampler( - valid_split, shuffle=False) - - train_loader = DataLoader( - train_split, - batch_size=cfg.batch_size, - sampler=train_sampler - if self.distributed else get_sampler(train_sampler), - num_workers=cfg.get('num_workers', 0), - pin_memory=cfg.get('pin_memory', False), - collate_fn=self.batcher.collate_fn, - worker_init_fn=lambda x: np.random.seed(x + np.uint32( - torch.utils.data.get_worker_info().seed)) - ) # numpy expects np.uint32, whereas torch returns np.uint64. - - valid_loader = DataLoader( - valid_split, - batch_size=cfg.val_batch_size, - sampler=valid_sampler - if self.distributed else get_sampler(valid_sampler), - num_workers=cfg.get('num_workers', 0), - pin_memory=cfg.get('pin_memory', True), - collate_fn=self.batcher.collate_fn, - worker_init_fn=lambda x: np.random.seed(x + np.uint32( - torch.utils.data.get_worker_info().seed))) - - # Optimizer must be created after moving model to specific device. - model.to(self.device) - model.device = self.device - - self.optimizer, self.scheduler = model.get_optimizer(cfg) - - is_resume = model.cfg.get('is_resume', True) - start_ep = self.load_ckpt(model.cfg.ckpt_path, is_resume=is_resume) - - dataset_name = dataset.name if dataset is not None else '' - tensorboard_dir = join( - self.cfg.train_sum_dir, - model.__class__.__name__ + '_' + dataset_name + '_torch') - runid = get_runid(tensorboard_dir) - self.tensorboard_dir = join(self.cfg.train_sum_dir, - runid + '_' + Path(tensorboard_dir).name) - - writer = SummaryWriter(self.tensorboard_dir) - if rank == 0: - self.save_config(writer) - log.info("Writing summary in {}.".format(self.tensorboard_dir)) - - # wrap model for multiple GPU - if self.distributed: - model = torch.nn.parallel.DistributedDataParallel( - model, device_ids=[self.device], find_unused_parameters=True) - model.get_loss = model.module.get_loss - model.cfg = model.module.cfg - - record_summary = cfg.get('summary').get('record_for', - []) if self.rank == 0 else [] - - if rank == 0: - log.info("Started training") - - for epoch in range(start_ep, cfg.max_epoch + 1): - log.info(f'=== EPOCH {epoch:d}/{cfg.max_epoch:d} ===') - if self.distributed: - train_sampler.set_epoch(epoch) - - model.train() - for i in range(num_heads): - self.metric_train[i].reset() - self.metric_val[i].reset() - self.losses = [[] for i in range(num_heads)] - # model.trans_point_sampler = train_sampler.get_point_sampler() # TODO: fix this for model with samplers. - - progress_bar = tqdm(train_loader, desc='training') - for inputs in progress_bar: - if hasattr(inputs['data'], 'to'): - inputs['data'].to(device) - dataset_idx = inputs['data'].dataset_idx - self.optimizer.zero_grad() - results = model(inputs['data']) - loss, gt_labels, predict_scores = model.get_loss( - Loss, results, inputs, device) - - if predict_scores.size()[-1] == 0: - continue - - loss.backward() - if model.cfg.get('grad_clip_norm', -1) > 0: - if self.distributed: - torch.nn.utils.clip_grad_value_( - model.module.parameters(), model.cfg.grad_clip_norm) - else: - torch.nn.utils.clip_grad_value_( - model.parameters(), model.cfg.grad_clip_norm) - - self.optimizer.step() - - self.metric_train[dataset_idx].update(predict_scores, gt_labels) - - self.losses[dataset_idx].append(loss.cpu().item()) - - # Save only for the first pcd in batch - if 'train' in record_summary and progress_bar.n == 0: - self.summary['train'] = self.get_3d_summary( - results, inputs['data'], epoch) - - desc = "training - Epoch: %d, loss (dataset : %d): %.3f" % ( - epoch, dataset_idx, loss.cpu().item()) - # if rank == 0: - progress_bar.set_description(desc) - progress_bar.refresh() - - if self.distributed: - dist.barrier() - - if self.scheduler is not None: - self.scheduler.step() - - # --------------------- validation - model.eval() - self.valid_losses = [[] for i in range(num_heads)] - # model.trans_point_sampler = valid_sampler.get_point_sampler() - - with torch.no_grad(): - for step, inputs in enumerate( - tqdm(valid_loader, desc='validation')): - if hasattr(inputs['data'], 'to'): - inputs['data'].to(device) - dataset_idx = inputs['data'].dataset_idx - results = model(inputs['data']) - loss, gt_labels, predict_scores = model.get_loss( - Loss, results, inputs, device) - - if predict_scores.size()[-1] == 0: - continue - - self.metric_val[dataset_idx].update(predict_scores, - gt_labels) - - self.valid_losses[dataset_idx].append(loss.cpu().item()) - # Save only for the first batch - if 'valid' in record_summary and step == 0: - self.summary['valid'] = self.get_3d_summary( - results, inputs['data'], epoch) - - if self.distributed: - gather_arr = [None for _ in range(dist.get_world_size())] - dist.gather_object((self.metric_train, self.metric_val, - self.losses, self.valid_losses), - gather_arr if rank == 0 else None, - dst=0) - - if rank == 0: - for m1, m2, l1, l2 in gather_arr[1:]: - for i in range(num_heads): - self.metric_train[i] += m1[i] - self.metric_val[i] += m2[i] - self.losses[i] += l1[i] - self.valid_losses[i] += l2[i] - dist.barrier() - - if rank == 0: - for i in range(num_heads): - self.save_logs(writer, epoch, i) - if epoch % cfg.save_ckpt_freq == 0 or epoch == cfg.max_epoch: - self.save_ckpt(epoch) - - def get_batcher(self, device, split='training'): - """Get the batcher to be used based on the device and split.""" - batcher_name = getattr(self.model.cfg, 'batcher') - - if batcher_name == 'DefaultBatcher': - batcher = DefaultBatcher() - elif batcher_name == 'ConcatBatcher': - batcher = ConcatBatcher(device, self.model.cfg.name) - else: - batcher = None - return batcher - - def get_3d_summary(self, results, input_data, epoch, save_gt=True): - """ - Create visualization for network inputs and outputs. - - Args: - results: Model output (see below). - input_data: Model input (see below). - epoch (int): step - save_gt (bool): Save ground truth (for 'train' or 'valid' stages). - - RandLaNet: - results (Tensor(B, N, C)): Prediction scores for all classes - inputs_batch: Batch of pointclouds and labels as a Dict with keys: - 'xyz': First element is Tensor(B,N,3) points - 'labels': (B, N) (optional) labels - - SparseConvUNet: - results (Tensor(SN, C)): Prediction scores for all classes. SN is - total points in the batch. - input_batch (Dict): Batch of pointclouds and labels. Keys should be: - 'point' [Tensor(SN,3), float]: Concatenated points. - 'batch_lengths' [Tensor(B,), int]: Number of points in each - point cloud of the batch. - 'label' [Tensor(SN,) (optional)]: Concatenated labels. - - Returns: - [Dict] visualizations of inputs and outputs suitable to save as an - Open3D for TensorBoard summary. - """ - if not hasattr(self, "_first_step"): - self._first_step = epoch - label_to_names = self.dataset.get_label_to_names() - cfg = self.cfg.get('summary') - max_pts = cfg.get('max_pts') - if max_pts is None: - max_pts = np.iinfo(np.int32).max - use_reference = cfg.get('use_reference', False) - max_outputs = cfg.get('max_outputs', 1) - input_pcd = [] - gt_labels = [] - predict_labels = [] - - def to_sum_fmt(tensor, add_dims=(0, 0), dtype=torch.int32): - sten = tensor.cpu().detach().type(dtype) - new_shape = (1,) * add_dims[0] + sten.shape + (1,) * add_dims[1] - return sten.reshape(new_shape) - - # Variable size point clouds - if self.model.cfg['name'] in ('KPFCNN', 'KPConv'): - batch_lengths = input_data.lengths[0].detach().numpy() - row_splits = np.hstack(((0,), np.cumsum(batch_lengths))) - max_outputs = min(max_outputs, len(row_splits) - 1) - for k in range(max_outputs): - blen_k = row_splits[k + 1] - row_splits[k] - pcd_step = int(np.ceil(blen_k / min(max_pts, blen_k))) - res_pcd = results[row_splits[k]:row_splits[k + 1]:pcd_step, :] - predict_labels.append( - to_sum_fmt(torch.argmax(res_pcd, 1), (0, 1))) - if self._first_step != epoch and use_reference: - continue - pointcloud = input_data.points[0][ - row_splits[k]:row_splits[k + 1]:pcd_step] - input_pcd.append( - to_sum_fmt(pointcloud[:, :3], (0, 0), torch.float32)) - if torch.any(input_data.labels != 0): - gtl = input_data.labels[row_splits[k]:row_splits[k + 1]] - gt_labels.append(to_sum_fmt(gtl, (0, 1))) - - elif self.model.cfg['name'] in ('SparseConvUnet', 'PointTransformer'): - if self.model.cfg['name'] == 'SparseConvUnet': - row_splits = np.hstack( - ((0,), np.cumsum(input_data.batch_lengths))) - else: - row_splits = input_data.row_splits - max_outputs = min(max_outputs, len(row_splits) - 1) - for k in range(max_outputs): - blen_k = row_splits[k + 1] - row_splits[k] - pcd_step = int(np.ceil(blen_k / min(max_pts, blen_k))) - res_pcd = results[row_splits[k]:row_splits[k + 1]:pcd_step, :] - predict_labels.append( - to_sum_fmt(torch.argmax(res_pcd, 1), (0, 1))) - if self._first_step != epoch and use_reference: - continue - if self.model.cfg['name'] == 'SparseConvUnet': - pointcloud = input_data.point[k] - else: - pointcloud = input_data.point[ - row_splits[k]:row_splits[k + 1]:pcd_step] - input_pcd.append( - to_sum_fmt(pointcloud[:, :3], (0, 0), torch.float32)) - if getattr(input_data, 'label', None) is not None: - if self.model.cfg['name'] == 'SparseConvUnet': - gtl = input_data.label[k] - else: - gtl = input_data.label[ - row_splits[k]:row_splits[k + 1]:pcd_step] - gt_labels.append(to_sum_fmt(gtl, (0, 1))) - # Fixed size point clouds - elif self.model.cfg['name'] in ('RandLANet', 'PVCNN'): # Tuple input - if self.model.cfg['name'] == 'RandLANet': - pointcloud = input_data['xyz'][0] # 0 => input to first layer - elif self.model.cfg['name'] == 'PVCNN': - pointcloud = input_data['point'].transpose(1, 2) - pcd_step = int( - np.ceil(pointcloud.shape[1] / - min(max_pts, pointcloud.shape[1]))) - predict_labels = to_sum_fmt( - torch.argmax(results[:max_outputs, ::pcd_step, :], 2), (0, 1)) - if self._first_step == epoch or not use_reference: - input_pcd = to_sum_fmt(pointcloud[:max_outputs, ::pcd_step, :3], - (0, 0), torch.float32) - if save_gt: - gtl = input_data.get('label', - input_data.get('labels', None)) - if gtl is None: - raise ValueError("input_data does not have label(s).") - gt_labels = to_sum_fmt(gtl[:max_outputs, ::pcd_step], - (0, 1)) - else: - raise NotImplementedError( - "Saving 3D summary for the model " - f"{self.model.cfg['name']} is not implemented.") - - def get_reference_or(data_tensor): - if self._first_step == epoch or not use_reference: - return data_tensor - return self._first_step - - summary_dict = { - 'semantic_segmentation': { - "vertex_positions": get_reference_or(input_pcd), - "vertex_gt_labels": get_reference_or(gt_labels), - "vertex_predict_labels": predict_labels, - 'label_to_names': label_to_names - } - } - return summary_dict - - def save_logs(self, writer, epoch, dataset_idx): - """Save logs from the training and send results to TensorBoard.""" - if len(self.metric_train[dataset_idx]) == 0 or len( - self.metric_val[dataset_idx]) == 0: - return - - train_accs = self.metric_train[dataset_idx].acc() - val_accs = self.metric_val[dataset_idx].acc() - - train_ious = self.metric_train[dataset_idx].iou() - val_ious = self.metric_val[dataset_idx].iou() - - loss_dict = { - 'Training loss': np.mean(self.losses[dataset_idx]), - 'Validation loss': np.mean(self.valid_losses[dataset_idx]) - } - acc_dicts = [{ - 'Training accuracy': acc, - 'Validation accuracy': val_acc - } for acc, val_acc in zip(train_accs, val_accs)] - - iou_dicts = [{ - 'Training IoU': iou, - 'Validation IoU': val_iou - } for iou, val_iou in zip(train_ious, val_ious)] - - for key, val in loss_dict.items(): - writer.add_scalar(str(dataset_idx) + " : " + key, val, epoch) - for key, val in acc_dicts[-1].items(): - writer.add_scalar("{} : {}/ Overall".format(dataset_idx, key), val, - epoch) - for key, val in iou_dicts[-1].items(): - writer.add_scalar("{} : {}/ Overall".format(dataset_idx, key), val, - epoch) - - log.info(f"Dataset Index : {dataset_idx}") - log.info(f"Loss train: {loss_dict['Training loss']:.3f} " - f" eval: {loss_dict['Validation loss']:.3f}") - log.info(f"Mean acc train: {acc_dicts[-1]['Training accuracy']:.3f} " - f" eval: {acc_dicts[-1]['Validation accuracy']:.3f}") - log.info(f"Mean IoU train: {iou_dicts[-1]['Training IoU']:.3f} " - f" eval: {iou_dicts[-1]['Validation IoU']:.3f}") - - for stage in self.summary: - for key, summary_dict in self.summary[stage].items(): - label_to_names = summary_dict.pop('label_to_names', None) - writer.add_3d('/'.join((stage, key)), - summary_dict, - epoch, - max_outputs=0, - label_to_names=label_to_names) - - def load_ckpt(self, ckpt_path=None, is_resume=True): - """Load a checkpoint. You must pass the checkpoint and indicate if you - want to resume. - """ - train_ckpt_dir = join(self.cfg.logs_dir, 'checkpoint') - if self.rank == 0: - make_dir(train_ckpt_dir) - if self.distributed: - dist.barrier() - - if ckpt_path is None: - ckpt_path = latest_torch_ckpt(train_ckpt_dir) - if ckpt_path is not None and is_resume: - log.info("ckpt_path not given. Restore from the latest ckpt") - else: - log.info('Initializing from scratch.') - return 0 - - if not exists(ckpt_path): - raise FileNotFoundError(f' ckpt {ckpt_path} not found') - - log.info(f'Loading checkpoint {ckpt_path}') - ckpt = torch.load(ckpt_path, map_location=self.device) - - self.model.load_state_dict(ckpt['model_state_dict']) - if 'optimizer_state_dict' in ckpt and hasattr(self, 'optimizer'): - log.info('Loading checkpoint optimizer_state_dict') - self.optimizer.load_state_dict(ckpt['optimizer_state_dict']) - if 'scheduler_state_dict' in ckpt and hasattr(self, 'scheduler'): - log.info('Loading checkpoint scheduler_state_dict') - self.scheduler.load_state_dict(ckpt['scheduler_state_dict']) - - epoch = 0 if 'epoch' not in ckpt else ckpt['epoch'] - - return epoch - - def save_ckpt(self, epoch): - """Save a checkpoint at the passed epoch.""" - path_ckpt = join(self.cfg.logs_dir, 'checkpoint') - make_dir(path_ckpt) - torch.save( - dict(epoch=epoch, - model_state_dict=self.model.state_dict(), - optimizer_state_dict=self.optimizer.state_dict(), - scheduler_state_dict=self.scheduler.state_dict()), - join(path_ckpt, f'ckpt_{epoch:05d}.pth')) - log.info(f'Epoch {epoch:3d}: save ckpt to {path_ckpt:s}') - - def save_config(self, writer): - """Save experiment configuration with tensorboard summary.""" - if hasattr(self, 'cfg_tb'): - writer.add_text("Description/Open3D-ML", self.cfg_tb['readme'], 0) - writer.add_text("Description/Command line", self.cfg_tb['cmd_line'], - 0) - writer.add_text('Configuration/Dataset', - code2md(self.cfg_tb['dataset'], language='json'), 0) - writer.add_text('Configuration/Model', - code2md(self.cfg_tb['model'], language='json'), 0) - writer.add_text('Configuration/Pipeline', - code2md(self.cfg_tb['pipeline'], language='json'), - 0) - - -PIPELINE._register_module(SemanticSegmentationMultiHead, "torch") From 40f888fd6981f958125fd5de3cda91d5a1ec8d7e Mon Sep 17 00:00:00 2001 From: Sanskar Agrawal Date: Tue, 20 Sep 2022 04:05:54 -0700 Subject: [PATCH 46/50] apply-style --- ml3d/datasets/megaloader.py | 11 +++++++---- ml3d/datasets/nuscenes_semseg.py | 5 ++--- ml3d/torch/dataloaders/concat_batcher.py | 1 + ml3d/torch/models/sparseconvnet_megamodel.py | 1 + ml3d/torch/modules/losses/semseg_loss.py | 3 ++- scripts/run_pipeline.py | 1 + 6 files changed, 14 insertions(+), 8 deletions(-) diff --git a/ml3d/datasets/megaloader.py b/ml3d/datasets/megaloader.py index 7324ec94c..1c6a48dc2 100644 --- a/ml3d/datasets/megaloader.py +++ b/ml3d/datasets/megaloader.py @@ -62,8 +62,9 @@ def __init__(self, self.datasets = [ get_module('dataset', cfg.name)(**cfg) for cfg in self.configs ] - self.class_weights = [self.datasets[i].cfg.class_weights for i in range(self.num_datasets)] - + self.class_weights = [ + self.datasets[i].cfg.class_weights for i in range(self.num_datasets) + ] def get_split(self, split): """Returns a dataset split. @@ -160,7 +161,8 @@ def get_data(self, idx): dataset_idx = self.test_dataset_idx else: dataset_idx = idx % self.num_datasets - idx = (idx // self.num_datasets) % len(self.dataset_splits[dataset_idx]) + idx = (idx // self.num_datasets) % len( + self.dataset_splits[dataset_idx]) data = self.dataset_splits[dataset_idx].get_data(idx) @@ -171,7 +173,8 @@ def get_attr(self, idx): dataset_idx = self.test_dataset_idx else: dataset_idx = idx % self.num_datasets - idx = (idx // self.num_datasets) % len(self.dataset_splits[dataset_idx]) + idx = (idx // self.num_datasets) % len( + self.dataset_splits[dataset_idx]) attr = self.dataset_splits[dataset_idx].get_attr(idx) attr['dataset_idx'] = dataset_idx diff --git a/ml3d/datasets/nuscenes_semseg.py b/ml3d/datasets/nuscenes_semseg.py index e7bf4d246..244672f2c 100644 --- a/ml3d/datasets/nuscenes_semseg.py +++ b/ml3d/datasets/nuscenes_semseg.py @@ -109,8 +109,8 @@ def __init__(self, 28: 15, 30: 16 } - self.label_mapping = np.array([mapping[i] for i in range(0, len(mapping))], dtype=np.int32) - + self.label_mapping = np.array( + [mapping[i] for i in range(0, len(mapping))], dtype=np.int32) @staticmethod def get_label_to_names(): @@ -152,7 +152,6 @@ def read_lidarseg(self, path): return self.label_mapping[labels] - def get_split(self, split): """Returns a dataset split. diff --git a/ml3d/torch/dataloaders/concat_batcher.py b/ml3d/torch/dataloaders/concat_batcher.py index 0c28e04a3..a10684663 100644 --- a/ml3d/torch/dataloaders/concat_batcher.py +++ b/ml3d/torch/dataloaders/concat_batcher.py @@ -451,6 +451,7 @@ def scatter(batch, num_gpu): return [b for b in batches if len(b.point)] # filter empty batch + class PointTransformerBatch: def __init__(self, batches): diff --git a/ml3d/torch/models/sparseconvnet_megamodel.py b/ml3d/torch/models/sparseconvnet_megamodel.py index f9b20e778..09ebec5db 100644 --- a/ml3d/torch/models/sparseconvnet_megamodel.py +++ b/ml3d/torch/models/sparseconvnet_megamodel.py @@ -12,6 +12,7 @@ log = logging.getLogger(__name__) + class SparseConvUnetMegaModel(BaseModel): """Semantic Segmentation model. diff --git a/ml3d/torch/modules/losses/semseg_loss.py b/ml3d/torch/modules/losses/semseg_loss.py index 5d0e8c1a2..56eda4704 100644 --- a/ml3d/torch/modules/losses/semseg_loss.py +++ b/ml3d/torch/modules/losses/semseg_loss.py @@ -70,7 +70,8 @@ def __init__(self, for i in range(num_heads): if weights is not None and len(weights[i]) != 0: wts = DataProcessing.get_class_weights(weights[i])[0] - assert len(wts) == num_classes[i], f"num_classes : {num_classes[i]} is not equal to number of class weights : {len(wts)}" + assert len(wts) == num_classes[ + i], f"num_classes : {num_classes[i]} is not equal to number of class weights : {len(wts)}" wts = torch.tensor(wts) else: wts = torch.ones(num_classes[i]) diff --git a/scripts/run_pipeline.py b/scripts/run_pipeline.py index b55661b5b..64718fdaf 100644 --- a/scripts/run_pipeline.py +++ b/scripts/run_pipeline.py @@ -201,6 +201,7 @@ def setup(rank, world_size, args): def cleanup(): dist.destroy_process_group() + def main_worker(local_rank, Dataset, Model, Pipeline, cfg_dict_dataset, cfg_dict_model, cfg_dict_pipeline, args): rank = args.node_rank * len(args.device_ids) + local_rank From 375bb8fe7d3b6a176c9f18bb8d4d69154fb908c2 Mon Sep 17 00:00:00 2001 From: Sanskar Agrawal Date: Tue, 20 Sep 2022 19:11:39 +0530 Subject: [PATCH 47/50] remove megamodel --- ml3d/configs/default_cfgs/kitti360.yml | 7 + ml3d/configs/default_cfgs/nuscenes_semseg.yml | 2 +- ml3d/configs/default_cfgs/semantickitti.yml | 4 +- ml3d/configs/default_cfgs/waymo_semseg.yml | 2 +- ml3d/configs/pointtransformer_waymo.yml | 17 - ml3d/configs/sparseconvunet_waymo.yml | 18 - ml3d/datasets/__init__.py | 4 +- ml3d/datasets/base_dataset.py | 2 - ml3d/datasets/megaloader.py | 185 ----- ml3d/datasets/nuscenes.py | 14 - ml3d/datasets/semantickitti.py | 2 +- ml3d/datasets/waymo_semseg.py | 2 +- ml3d/torch/models/__init__.py | 3 +- ml3d/torch/models/sparseconvnet_megamodel.py | 723 ------------------ ml3d/torch/modules/losses/__init__.py | 4 +- ml3d/torch/modules/losses/semseg_loss.py | 27 - ml3d/torch/pipelines/__init__.py | 5 +- ml3d/torch/pipelines/semantic_segmentation.py | 1 - scripts/preprocess_nuscenes.py | 5 - scripts/preprocess_waymo.py | 83 +- scripts/preprocess_waymo_semseg.py | 5 - 21 files changed, 51 insertions(+), 1064 deletions(-) create mode 100644 ml3d/configs/default_cfgs/kitti360.yml delete mode 100644 ml3d/datasets/megaloader.py delete mode 100644 ml3d/torch/models/sparseconvnet_megamodel.py diff --git a/ml3d/configs/default_cfgs/kitti360.yml b/ml3d/configs/default_cfgs/kitti360.yml new file mode 100644 index 000000000..6808e3861 --- /dev/null +++ b/ml3d/configs/default_cfgs/kitti360.yml @@ -0,0 +1,7 @@ +name: KITTI360 +dataset_path: /Users/sanskara/data/kitti360/ +cache_dir: ./logs/cache +class_weights: [] +ignored_label_inds: [] +test_result_folder: ./test +use_cache: False diff --git a/ml3d/configs/default_cfgs/nuscenes_semseg.yml b/ml3d/configs/default_cfgs/nuscenes_semseg.yml index 056af4f05..824c797cb 100644 --- a/ml3d/configs/default_cfgs/nuscenes_semseg.yml +++ b/ml3d/configs/default_cfgs/nuscenes_semseg.yml @@ -1,5 +1,5 @@ name: NuScenesSemSeg -dataset_path: /export/share/projects/open3d_ml/NuScenes/processed/ +dataset_path: # path/to/your/dataset cache_dir: ./logs/cache class_weights: [282265, 7676, 120, 3754, 31974, 1321, 346, 1898, 624, 4537, 13282, 260911, 6588, 57567, 56670, 146511, 100633] ignored_label_inds: [0] diff --git a/ml3d/configs/default_cfgs/semantickitti.yml b/ml3d/configs/default_cfgs/semantickitti.yml index 0f58292b8..cd26c0429 100644 --- a/ml3d/configs/default_cfgs/semantickitti.yml +++ b/ml3d/configs/default_cfgs/semantickitti.yml @@ -1,6 +1,6 @@ name: SemanticKITTI -dataset_path: /export/share/datasets/SemanticKITTI/ +dataset_path: # path/to/your/dataset cache_dir: ./logs/cache ignored_label_inds: [0] use_cache: false -class_weights: [101665, 157022, 631, 1516, 5012, 7085, 1043, 457, 176, 693044, 53132, 494988, 12829, 459669, 236069, 924425, 22780, 255213, 9664, 2024] \ No newline at end of file +class_weights: [101665, 157022, 631, 1516, 5012, 7085, 1043, 457, 176, 693044, 53132, 494988, 12829, 459669, 236069, 924425, 22780, 255213, 9664, 2024] diff --git a/ml3d/configs/default_cfgs/waymo_semseg.yml b/ml3d/configs/default_cfgs/waymo_semseg.yml index c103aa3ce..8658c4dae 100644 --- a/ml3d/configs/default_cfgs/waymo_semseg.yml +++ b/ml3d/configs/default_cfgs/waymo_semseg.yml @@ -1,5 +1,5 @@ name: WaymoSemSeg -dataset_path: /export/share/datasets/Waymo_1.3/processed/ +dataset_path: # path/to/your/dataset cache_dir: ./logs/cache class_weights: [513255, 341079, 39946, 28066, 17254, 104, 1169, 31335, 23359, 2924, 43680, 2149, 483, 639, 1394353, 858409, 90903, 52484, 884591, 24487, 21477, 322212, 229034] ignored_label_inds: [0] diff --git a/ml3d/configs/pointtransformer_waymo.yml b/ml3d/configs/pointtransformer_waymo.yml index 355a3b7d1..30605046c 100644 --- a/ml3d/configs/pointtransformer_waymo.yml +++ b/ml3d/configs/pointtransformer_waymo.yml @@ -17,23 +17,6 @@ model: max_voxels: 50000 ignored_label_inds: [-1] augment: {} - # rotate: - # method: vertical - # scale: - # min_s: 0.95 - # max_s: 1.05 - # noise: - # noise_std: 0.005 - # ChromaticAutoContrast: - # randomize_blend_factor: True - # blend_factor: 0.2 - # ChromaticTranslation: - # trans_range_ratio: 0.05 - # ChromaticJitter: - # std: 0.01 - # HueSaturationTranslation: - # hue_max: 0.5 - # saturation_max: 0.2 pipeline: name: SemanticSegmentation optimizer: diff --git a/ml3d/configs/sparseconvunet_waymo.yml b/ml3d/configs/sparseconvunet_waymo.yml index 3ca34d2c7..fac0da5c1 100644 --- a/ml3d/configs/sparseconvunet_waymo.yml +++ b/ml3d/configs/sparseconvunet_waymo.yml @@ -19,24 +19,6 @@ model: grid_size: 4096 ignored_label_inds: [0] augment: {} - # rotate: - # method: vertical - # scale: - # min_s: 0.9 - # max_s: 1.1 - # noise: - # noise_std: 0.01 - # RandomDropout: - # dropout_ratio: 0.2 - # RandomHorizontalFlip: - # axes: [0, 1] - # ChromaticAutoContrast: - # randomize_blend_factor: True - # blend_factor: 0.5 - # ChromaticTranslation: - # trans_range_ratio: 0.1 - # ChromaticJitter: - # std: 0.05 pipeline: name: SemanticSegmentation optimizer: diff --git a/ml3d/datasets/__init__.py b/ml3d/datasets/__init__.py index 1677759a5..d1ca319bc 100644 --- a/ml3d/datasets/__init__.py +++ b/ml3d/datasets/__init__.py @@ -24,13 +24,11 @@ from .sunrgbd import SunRGBD from .matterport_objects import MatterportObjects from .kitti360 import KITTI360 -from .megaloader import MegaLoader __all__ = [ 'SemanticKITTI', 'S3DIS', 'Toronto3D', 'ParisLille3D', 'Semantic3D', 'Custom3D', 'utils', 'augment', 'samplers', 'KITTI', 'Waymo', 'NuScenes', 'Lyft', 'ShapeNet', 'SemSegRandomSampler', 'InferenceDummySplit', 'SemSegSpatiallyRegularSampler', 'Argoverse', 'Scannet', 'SunRGBD', - 'MatterportObjects', 'WaymoSemSeg', 'KITTI360', 'MegaLoader', - 'NuScenesSemSeg' + 'MatterportObjects', 'WaymoSemSeg', 'KITTI360', 'NuScenesSemSeg' ] diff --git a/ml3d/datasets/base_dataset.py b/ml3d/datasets/base_dataset.py index a97a890c3..ec48550aa 100644 --- a/ml3d/datasets/base_dataset.py +++ b/ml3d/datasets/base_dataset.py @@ -127,8 +127,6 @@ def __init__(self, dataset, split='training'): if split in ['test']: sampler_cls = get_module('sampler', 'SemSegSpatiallyRegularSampler') else: - # sampler_cfg = self.cfg.get('sampler', - # {'name': 'SemSegRandomSampler'}) sampler_cfg = self.cfg.get('sampler', {'name': None}) if sampler_cfg['name'] == "None": sampler_cfg['name'] = None diff --git a/ml3d/datasets/megaloader.py b/ml3d/datasets/megaloader.py deleted file mode 100644 index 1c6a48dc2..000000000 --- a/ml3d/datasets/megaloader.py +++ /dev/null @@ -1,185 +0,0 @@ -import numpy as np -import pandas as pd -import os, glob, pickle -from pathlib import Path -from os.path import join, exists, dirname, abspath, isdir -from sklearn.neighbors import KDTree -from tqdm import tqdm -import logging - -from .utils import DataProcessing, get_min_bbox, BEVBox3D -from .base_dataset import BaseDataset, BaseDatasetSplit -from ..utils import make_dir, DATASET, Config, get_module - -log = logging.getLogger(__name__) - - -class MegaLoader(): - """This class is used to create a combination of multiple datasets, - and sample data among them uniformly. - """ - - def __init__(self, - config_paths, - batch_size=1, - name='MegaLoader', - cache_dir='./logs/cache', - use_cache=False, - ignored_label_inds=[], - test_result_folder='./test', - **kwargs): - """Initialize the function by passing the dataset and other details. - - Args: - config_paths: List of dataset config files to use. - dataset_path: The path to the dataset to use (parent directory of data_3d_semantics). - name: The name of the dataset (MegaLoader in this case). - cache_dir: The directory where the cache is stored. - use_cache: Indicates if the dataset should be cached. - ignored_label_inds: A list of labels that should be ignored in the dataset. - test_result_folder: The folder where the test results should be stored. - """ - - kwargs['config_paths'] = config_paths - kwargs['name'] = name - kwargs['cache_dir'] = cache_dir - kwargs['use_cache'] = use_cache - kwargs['ignored_label_inds'] = ignored_label_inds - kwargs['test_result_folder'] = test_result_folder - kwargs['batch_size'] = batch_size - - self.cfg = Config(kwargs) - self.name = self.cfg.name - self.batch_size = batch_size - self.rng = np.random.default_rng(kwargs.get('seed', None)) - self.ignored_labels = np.array([]) - - self.num_datasets = len(config_paths) - - self.configs = [ - Config.load_from_file(cfg_path) for cfg_path in config_paths - ] - self.datasets = [ - get_module('dataset', cfg.name)(**cfg) for cfg in self.configs - ] - self.class_weights = [ - self.datasets[i].cfg.class_weights for i in range(self.num_datasets) - ] - - def get_split(self, split): - """Returns a dataset split. - - Args: - split: A string identifying the dataset split that is usually one of - 'training', 'test', 'validation', or 'all'. - - Returns: - A dataset split object providing the requested subset of the data. - """ - return MegaLoaderSplit(self, split=split) - - def get_split_list(self, split): - if split in ['train', 'training']: - return self.train_files - elif split in ['val', 'validation']: - return self.val_files - elif split in ['test', 'testing']: - return test_files - elif split == 'all': - return self.train_files + self.val_files + self.test_files - else: - raise ValueError("Invalid split {}".format(split)) - - def is_tested(self, attr): - cfg = self.cfg - name = attr['name'] - path = cfg.test_result_folder - store_path = join(path, self.name, name + '.npy') - if exists(store_path): - print("{} already exists.".format(store_path)) - return True - else: - return False - - def save_test_result(self, results, name): - """Saves the output of a model. - - Args: - results: The output of a model for the datum associated with the attribute passed. - attr: The attributes that correspond to the outputs passed in results. - """ - cfg = self.cfg - path = cfg.test_result_folder - make_dir(path) - - pred = results['predict_labels'] - pred = np.array(pred) - - store_path = join(path, self.name, name + '.npy') - make_dir(Path(store_path).parent) - np.save(store_path, pred) - log.info("Saved {} in {}.".format(name, store_path)) - - -class MegaLoaderSplit(): - """This class is used to create a split for MegaLoader dataset. - - Initialize the class. - - Args: - dataset: The dataset to split. - split: A string identifying the dataset split that is usually one of - 'training', 'test', 'validation', or 'all'. - **kwargs: The configuration of the model as keyword arguments. - - Returns: - A dataset split object providing the requested subset of the data. - """ - - def __init__(self, dataset, split='training', test_dataset_idx=0): - self.cfg = dataset.cfg - self.split = split - self.dataset = dataset - self.dataset_splits = [ - a.get_split(split) for a in self.dataset.datasets - ] - self.num_datasets = dataset.num_datasets - self.test_dataset_idx = test_dataset_idx - - if 'test' in split: - sampler_cls = get_module('sampler', 'SemSegSpatiallyRegularSampler') - self.sampler = sampler_cls(self) - - def __len__(self): - if 'test' in self.split: - return len(self.dataset_splits[self.test_dataset_idx]) - lens = [len(a) for a in self.dataset_splits] - return max(lens) * self.num_datasets - - def get_data(self, idx): - if 'test' in self.split: - dataset_idx = self.test_dataset_idx - else: - dataset_idx = idx % self.num_datasets - idx = (idx // self.num_datasets) % len( - self.dataset_splits[dataset_idx]) - - data = self.dataset_splits[dataset_idx].get_data(idx) - - return data - - def get_attr(self, idx): - if 'test' in self.split: - dataset_idx = self.test_dataset_idx - else: - dataset_idx = idx % self.num_datasets - idx = (idx // self.num_datasets) % len( - self.dataset_splits[dataset_idx]) - - attr = self.dataset_splits[dataset_idx].get_attr(idx) - attr['dataset_idx'] = dataset_idx - - return attr - - -DATASET._register_module(MegaLoader) diff --git a/ml3d/datasets/nuscenes.py b/ml3d/datasets/nuscenes.py index 2ef9dbc96..44362e22a 100644 --- a/ml3d/datasets/nuscenes.py +++ b/ml3d/datasets/nuscenes.py @@ -108,17 +108,6 @@ def read_lidar(path): return np.fromfile(path, dtype=np.float32).reshape(-1, 5) - @staticmethod - def read_lidarseg(path): - """Reads semantic data from the path provided. - - Returns: - A data object with semantic information. - """ - assert Path(path).exists() - - return np.fromfile(path, dtype=np.uint8).reshape(-1,).astype(np.int32) - @staticmethod def read_label(info, calib): """Reads labels of bound boxes. @@ -267,7 +256,6 @@ def __len__(self): def get_data(self, idx): info = self.infos[idx] lidar_path = info['lidar_path'] - lidarseg_path = info['lidarseg_path'] world_cam = np.eye(4) world_cam[:3, :3] = R.from_quat(info['lidar2ego_rot']).as_matrix() @@ -276,14 +264,12 @@ def get_data(self, idx): pc = self.dataset.read_lidar(lidar_path) label = self.dataset.read_label(info, calib) - lidarseg = self.dataset.read_lidarseg(lidarseg_path) data = { 'point': pc, 'feat': None, 'calib': calib, 'bounding_boxes': label, - 'label': lidarseg } if 'cams' in info: diff --git a/ml3d/datasets/semantickitti.py b/ml3d/datasets/semantickitti.py index 8ec98f5cc..d93dddcc9 100644 --- a/ml3d/datasets/semantickitti.py +++ b/ml3d/datasets/semantickitti.py @@ -282,7 +282,7 @@ def get_data(self, idx): data = { 'point': points[:, 0:3], - 'feat': points[:, 3:4], + 'feat': None, 'label': labels, } diff --git a/ml3d/datasets/waymo_semseg.py b/ml3d/datasets/waymo_semseg.py index bbc31b978..44946f2fe 100644 --- a/ml3d/datasets/waymo_semseg.py +++ b/ml3d/datasets/waymo_semseg.py @@ -201,4 +201,4 @@ def get_attr(self, idx): return attr -DATASET._register_module(WaymoSemSeg) \ No newline at end of file +DATASET._register_module(WaymoSemSeg) diff --git a/ml3d/torch/models/__init__.py b/ml3d/torch/models/__init__.py index 31c0dab52..817f0c095 100644 --- a/ml3d/torch/models/__init__.py +++ b/ml3d/torch/models/__init__.py @@ -7,11 +7,10 @@ from .point_rcnn import PointRCNN from .point_transformer import PointTransformer from .pvcnn import PVCNN -from .sparseconvnet_megamodel import SparseConvUnetMegaModel __all__ = [ 'RandLANet', 'KPFCNN', 'PointPillars', 'PointRCNN', 'SparseConvUnet', - 'PointTransformer', 'PVCNN', 'SparseConvUnetMegaModel' + 'PointTransformer', 'PVCNN' ] try: diff --git a/ml3d/torch/models/sparseconvnet_megamodel.py b/ml3d/torch/models/sparseconvnet_megamodel.py deleted file mode 100644 index 09ebec5db..000000000 --- a/ml3d/torch/models/sparseconvnet_megamodel.py +++ /dev/null @@ -1,723 +0,0 @@ -import numpy as np -import torch -import torch.nn as nn -import logging - -from .base_model import BaseModel -from ...utils import MODEL -from ..modules.losses import filter_valid_label -from ...datasets.augment import SemsegAugmentation -from open3d.ml.torch.layers import SparseConv, SparseConvTranspose -from open3d.ml.torch.ops import voxelize, reduce_subarrays_sum - -log = logging.getLogger(__name__) - - -class SparseConvUnetMegaModel(BaseModel): - """Semantic Segmentation model. - - Uses UNet architecture replacing convolutions with Sparse Convolutions. - - Attributes: - name: Name of model. - Default to "SparseConvUnet". - device: Which device to use (cpu or cuda). - voxel_size: Voxel length for subsampling. - multiplier: min length of feature length in each layer. - conv_block_reps: repetition of Unet Blocks. - residual_blocks: Whether to use Residual Blocks. - in_channels: Number of features(default 3 for color). - num_classes: Number of classes. - """ - - def __init__( - self, - name="SparseConvUnetMegaModel", - device="cuda", - num_heads=1, # number of segmentation heads. - multiplier=16, # Proportional to number of neurons in each layer. - voxel_size=0.05, - varying_input_layers=False, - conv_block_reps=1, # Conv block repetitions. - residual_blocks=True, - in_channels=3, - num_classes=[20], - grid_size=81920, - batcher='ConcatBatcher', - augment=None, - ckpt_path=None, - **kwargs): - super(SparseConvUnetMegaModel, - self).__init__(name=name, - device=device, - num_heads=num_heads, - multiplier=multiplier, - voxel_size=voxel_size, - varying_input_layers=varying_input_layers, - conv_block_reps=conv_block_reps, - residual_blocks=residual_blocks, - in_channels=in_channels, - num_classes=num_classes, - grid_size=grid_size, - batcher=batcher, - augment=augment, - ckpt_path=ckpt_path, - **kwargs) - cfg = self.cfg - self.device = device - self.augmenter = SemsegAugmentation(cfg.augment, seed=self.rng) - self.multiplier = cfg.multiplier - self.num_heads = num_heads - self.varying_input_layers = varying_input_layers - self.input_layer = InputLayer() - - if self.varying_input_layers: - self.sub_sparse_conv = [ - SubmanifoldSparseConv(in_channels=in_channels, - filters=multiplier, - kernel_size=[3, 3, 3]) - for i in range(num_heads) - ] - self.sub_sparse_conv = nn.ModuleList(self.sub_sparse_conv) - else: - self.sub_sparse_conv = SubmanifoldSparseConv( - in_channels=in_channels, - filters=multiplier, - kernel_size=[3, 3, 3]) - self.unet = UNet(conv_block_reps, [ - multiplier, 2 * multiplier, 3 * multiplier, 4 * multiplier, - 5 * multiplier, 6 * multiplier, 7 * multiplier - ], residual_blocks) - self.batch_norm = BatchNormBlock(multiplier) - self.relu = ReLUBlock() - - if len(num_classes) != num_heads: - raise ValueError("Pass num_classes for each segmentation head.") - - self.linear = LinearBlock(multiplier, num_classes) - self.output_layer = OutputLayer() - - def forward(self, inputs): - pos_list = [] - feat_list = [] - index_map_list = [] - - for i in range(len(inputs.batch_lengths)): - pos = inputs.point[i] - feat = inputs.feat[i] - feat, pos, index_map = self.input_layer(feat, pos) - pos_list.append(pos) - feat_list.append(feat) - index_map_list.append(index_map) - - if self.varying_input_layers: - feat_list = self.sub_sparse_conv[inputs.dataset_idx](feat_list, - pos_list, - voxel_size=1.0) - else: - feat_list = self.sub_sparse_conv(feat_list, - pos_list, - voxel_size=1.0) - feat_list = self.unet(pos_list, feat_list) - feat_list = self.batch_norm(feat_list) - feat_list = self.relu(feat_list) - feat_list = self.linear(feat_list, inputs.dataset_idx) - output = self.output_layer(feat_list, index_map_list) - - return output - - def preprocess(self, data, attr): - # If num_workers > 0, use new RNG with unique seed for each thread. - # Else, use default RNG. - if torch.utils.data.get_worker_info(): - seedseq = np.random.SeedSequence( - torch.utils.data.get_worker_info().seed + - torch.utils.data.get_worker_info().id) - rng = np.random.default_rng(seedseq.spawn(1)[0]) - else: - rng = self.rng - - points = np.array(data['point'], dtype=np.float32) - - if 'label' not in data or data['label'] is None: - labels = np.zeros((points.shape[0],), dtype=np.int32) - else: - labels = np.array(data['label'], dtype=np.int32).reshape((-1,)) - - if 'feat' not in data or data['feat'] is None: - raise Exception( - "SparseConvnet doesn't work without feature values.") - - feat = np.array(data['feat'], dtype=np.float32) - - # only use xyz - feat = points.copy() - - if attr['split'] in ['training', 'train']: - points, feat, labels = self.augmenter.augment(points, - feat, - labels, - self.cfg.get( - 'augment', None), - seed=rng) - - # Scale to voxel size. - points *= 1. / self.cfg.voxel_size # Scale = 1/voxel_size - - m = points.min(0) - M = points.max(0) - - # Randomly place pointcloud in 4096 size grid. - grid_size = self.cfg.grid_size - - # make everything positive - offset = -1 * points.min(0) + 500 - - points += offset - points = (points.astype(np.int32) + 0.5).astype( - np.float32) # Move points to voxel center. - - data = {} - data['point'] = points - data['feat'] = feat - data['label'] = labels - - return data - - def transform(self, data, attr): - data['point'] = torch.from_numpy(data['point']) - data['feat'] = torch.from_numpy(data['feat']) - data['label'] = torch.from_numpy(data['label']) - - return data - - def update_probs(self, inputs, results, test_probs, test_labels): - result = results.reshape(-1, self.cfg.num_classes) - probs = torch.nn.functional.softmax(result, dim=-1).cpu().data.numpy() - labels = np.argmax(probs, 1) - - self.trans_point_sampler(patchwise=False) - - return probs, labels - - def inference_begin(self, data): - data = self.preprocess(data, {'split': 'test'}) - data['batch_lengths'] = [data['point'].shape[0]] - data = self.transform(data, {}) - - self.inference_input = data - - def inference_preprocess(self): - return self.inference_input - - def inference_end(self, inputs, results): - results = torch.reshape(results, (-1, self.cfg.num_classes)) - - m_softmax = torch.nn.Softmax(dim=-1) - results = m_softmax(results) - results = results.cpu().data.numpy() - - probs = np.reshape(results, [-1, self.cfg.num_classes]) - - pred_l = np.argmax(probs, 1) - - return {'predict_labels': pred_l, 'predict_scores': probs} - - def get_loss(self, Loss, results, inputs, device): - """Calculate the loss on output of the model. - - Attributes: - Loss: Object of type `SemSegLoss`. - results: Output of the model. - inputs: Input of the model. - device: device(cpu or cuda). - - Returns: - Returns loss, labels and scores. - """ - cfg = self.cfg - labels = torch.cat(inputs['data'].label, - 0).to(torch.LongTensor()).to(results.device) - - loss = Loss.weighted_CrossEntropyLoss[inputs['data'].dataset_idx]( - results, labels) - - return loss, labels, results - - def get_optimizer(self, cfg_pipeline): - optimizer = torch.optim.Adam(self.parameters(), lr=0.001) - scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, 0.99) - - return optimizer, scheduler - - -MODEL._register_module(SparseConvUnetMegaModel, 'torch') - - -class BatchNormBlock(nn.Module): - - def __init__(self, m, eps=1e-4, momentum=0.01): - super(BatchNormBlock, self).__init__() - self.bn = nn.BatchNorm1d(m, eps=eps, momentum=momentum) - - def forward(self, feat_list): - lengths = [feat.shape[0] for feat in feat_list] - out = self.bn(torch.cat(feat_list, 0)) - out_list = [] - start = 0 - for l in lengths: - out_list.append(out[start:start + l]) - start += l - - return out_list - - def __name__(self): - return "BatchNormBlock" - - -class ReLUBlock(nn.Module): - - def __init__(self): - super(ReLUBlock, self).__init__() - self.relu = nn.ReLU() - - def forward(self, feat_list): - lengths = [feat.shape[0] for feat in feat_list] - out = self.relu(torch.cat(feat_list, 0)) - out_list = [] - start = 0 - for l in lengths: - out_list.append(out[start:start + l]) - start += l - - return out_list - - def __name__(self): - return "ReLUBlock" - - -class LinearBlock(nn.Module): - - def __init__(self, in_dim, num_classes): - super(LinearBlock, self).__init__() - - linear = [] - self.num_classes = num_classes - for i in range(len(num_classes)): - linear.append( - nn.Sequential(nn.Linear(in_dim, 2 * in_dim), - nn.Linear(2 * in_dim, num_classes[i]))) - - self.linear = nn.ModuleList(linear) - - def forward(self, feat_list, dataset_idx): - out_list = [] - for i, feat in enumerate(feat_list): - if dataset_idx == -1: - out = [] - for j in range(len(self.num_classes)): - out.append(self.linear[j](feat)) - out = torch.cat(out, -1) - out_list.append(out) - else: - out_list.append(self.linear[dataset_idx](feat)) - - return out_list - - def __name__(self): - return "LinearBlock" - - -class InputLayer(nn.Module): - - def __init__(self, voxel_size=1.0): - super(InputLayer, self).__init__() - self.voxel_size = torch.Tensor([voxel_size, voxel_size, voxel_size]) - - def forward(self, features, in_positions): - v = voxelize( - in_positions, - torch.LongTensor([0, - in_positions.shape[0]]).to(in_positions.device), - self.voxel_size, torch.Tensor([0, 0, 0]), - torch.Tensor([81920, 81920, 81920])) - - # Contiguous repeating positions. - in_positions = in_positions[v.voxel_point_indices] - features = features[v.voxel_point_indices] - - # Find reverse mapping. - reverse_map_voxelize = np.zeros((in_positions.shape[0],)) - reverse_map_voxelize[v.voxel_point_indices.cpu().numpy()] = np.arange( - in_positions.shape[0]) - reverse_map_voxelize = reverse_map_voxelize.astype(np.int32) - - # Unique positions. - in_positions = in_positions[v.voxel_point_row_splits[:-1]] - - # Mean of features. - count = v.voxel_point_row_splits[1:] - v.voxel_point_row_splits[:-1] - reverse_map_sort = np.repeat(np.arange(count.shape[0]), - count.cpu().numpy()).astype(np.int32) - - features_avg = torch.zeros(in_positions.shape[0], - features.shape[1]).to(features.device) - for i in range(features_avg.shape[1]): - features_avg[:, i] = reduce_subarrays_sum(features[:, i], - v.voxel_point_row_splits) - - features_avg = features_avg / count.unsqueeze(1) - - return features_avg, in_positions, reverse_map_sort[ - reverse_map_voxelize] - - -class OutputLayer(nn.Module): - - def __init__(self, voxel_size=1.0): - super(OutputLayer, self).__init__() - - def forward(self, features_list, index_map_list): - out = [] - for feat, index_map in zip(features_list, index_map_list): - out.append(feat[index_map]) - - out = torch.cat(out, 0) - - return out - - -class SubmanifoldSparseConv(nn.Module): - - def __init__(self, - in_channels, - filters, - kernel_size, - use_bias=False, - offset=None, - normalize=False): - super(SubmanifoldSparseConv, self).__init__() - - if offset is None: - if kernel_size[0] % 2: - offset = 0. - else: - offset = 0.5 - - offset = torch.full((3,), offset, dtype=torch.float32) - self.net = SparseConv(in_channels=in_channels, - filters=filters, - kernel_size=kernel_size, - use_bias=use_bias, - offset=offset, - normalize=normalize) - - def forward(self, - features_list, - in_positions_list, - out_positions_list=None, - voxel_size=1.0): - if out_positions_list is None: - out_positions_list = in_positions_list - - out_feat = [] - for feat, in_pos, out_pos in zip(features_list, in_positions_list, - out_positions_list): - out_feat.append(self.net(feat, in_pos, out_pos, voxel_size)) - - return out_feat - - def __name__(self): - return "SubmanifoldSparseConv" - - -def calculate_grid(in_positions): - filter = torch.Tensor([[-1, -1, -1], [-1, -1, 0], [-1, 0, -1], [-1, 0, 0], - [0, -1, -1], [0, -1, 0], [0, 0, -1], - [0, 0, 0]]).to(in_positions.device) - - out_pos = in_positions.long().repeat(1, filter.shape[0]).reshape(-1, 3) - filter = filter.repeat(in_positions.shape[0], 1) - - out_pos = out_pos + filter - out_pos = out_pos[out_pos.min(1).values >= 0] - out_pos = out_pos[(~((out_pos.long() % 2).bool()).any(1))] - out_pos = torch.unique(out_pos, dim=0) - - return out_pos + 0.5 - - -class Convolution(nn.Module): - - def __init__(self, - in_channels, - filters, - kernel_size, - use_bias=False, - offset=None, - normalize=False): - super(Convolution, self).__init__() - - if offset is None: - if kernel_size[0] % 2: - offset = 0. - else: - offset = -0.5 - - offset = torch.full((3,), offset, dtype=torch.float32) - self.net = SparseConv(in_channels=in_channels, - filters=filters, - kernel_size=kernel_size, - use_bias=use_bias, - offset=offset, - normalize=normalize) - - def forward(self, features_list, in_positions_list, voxel_size=1.0): - out_positions_list = [] - for in_positions in in_positions_list: - out_positions_list.append(calculate_grid(in_positions)) - - out_feat = [] - for feat, in_pos, out_pos in zip(features_list, in_positions_list, - out_positions_list): - out_feat.append(self.net(feat, in_pos, out_pos, voxel_size)) - - out_positions_list = [out / 2 for out in out_positions_list] - - return out_feat, out_positions_list - - def __name__(self): - return "Convolution" - - -class DeConvolution(nn.Module): - - def __init__(self, - in_channels, - filters, - kernel_size, - use_bias=False, - offset=None, - normalize=False): - super(DeConvolution, self).__init__() - - if offset is None: - if kernel_size[0] % 2: - offset = 0. - else: - offset = -0.5 - - offset = torch.full((3,), offset, dtype=torch.float32) - self.net = SparseConvTranspose(in_channels=in_channels, - filters=filters, - kernel_size=kernel_size, - use_bias=use_bias, - offset=offset, - normalize=normalize) - - def forward(self, - features_list, - in_positions_list, - out_positions_list, - voxel_size=1.0): - out_feat = [] - for feat, in_pos, out_pos in zip(features_list, in_positions_list, - out_positions_list): - out_feat.append(self.net(feat, in_pos, out_pos, voxel_size)) - - return out_feat - - def __name__(self): - return "DeConvolution" - - -class ConcatFeat(nn.Module): - - def __init__(self): - super(ConcatFeat, self).__init__() - - def __name__(self): - return "ConcatFeat" - - def forward(self, feat): - return feat - - -class JoinFeat(nn.Module): - - def __init__(self): - super(JoinFeat, self).__init__() - - def __name__(self): - return "JoinFeat" - - def forward(self, feat_cat, feat): - out = [] - for a, b in zip(feat_cat, feat): - out.append(torch.cat([a, b], -1)) - - return out - - -class NetworkInNetwork(nn.Module): - - def __init__(self, nIn, nOut, bias=False): - super(NetworkInNetwork, self).__init__() - if nIn == nOut: - self.linear = nn.Identity() - else: - self.linear = nn.Linear(nIn, nOut, bias=bias) - - def forward(self, inputs): - out = [] - for inp in inputs: - out.append(self.linear(inp)) - - return out - - -class ResidualBlock(nn.Module): - - def __init__(self, nIn, nOut): - super(ResidualBlock, self).__init__() - - self.lin = NetworkInNetwork(nIn, nOut) - - self.batch_norm1 = BatchNormBlock(nIn) - self.relu1 = ReLUBlock() - self.sub_sparse_conv1 = SubmanifoldSparseConv(in_channels=nIn, - filters=nOut, - kernel_size=[3, 3, 3]) - - self.batch_norm2 = BatchNormBlock(nOut) - self.relu2 = ReLUBlock() - self.sub_sparse_conv2 = SubmanifoldSparseConv(in_channels=nOut, - filters=nOut, - kernel_size=[3, 3, 3]) - - def forward(self, feat_list, pos_list): - out1 = self.lin(feat_list) - feat_list = self.batch_norm1(feat_list) - feat_list = self.relu1(feat_list) - feat_list = self.sub_sparse_conv1(feat_list, pos_list) - feat_list = self.batch_norm2(feat_list) - feat_list = self.relu2(feat_list) - out2 = self.sub_sparse_conv2(feat_list, pos_list) - - return [a + b for a, b in zip(out1, out2)] - - def __name__(self): - return "ResidualBlock" - - -class UNet(nn.Module): - - def __init__(self, - conv_block_reps, - nPlanes, - residual_blocks=False, - downsample=[2, 2], - leakiness=0): - super(UNet, self).__init__() - self.net = nn.ModuleList( - self.get_UNet(nPlanes, residual_blocks, conv_block_reps)) - self.residual_blocks = residual_blocks - - @staticmethod - def block(layers, a, b, residual_blocks): - if residual_blocks: - layers.append(ResidualBlock(a, b)) - - else: - layers.append(BatchNormBlock(a)) - layers.append(ReLUBlock()) - layers.append( - SubmanifoldSparseConv(in_channels=a, - filters=b, - kernel_size=[3, 3, 3])) - - @staticmethod - def get_UNet(nPlanes, residual_blocks, conv_block_reps): - layers = [] - for i in range(conv_block_reps): - UNet.block(layers, nPlanes[0], nPlanes[0], residual_blocks) - - if len(nPlanes) > 1: - layers.append(ConcatFeat()) - layers.append(BatchNormBlock(nPlanes[0])) - layers.append(ReLUBlock()) - layers.append( - Convolution(in_channels=nPlanes[0], - filters=nPlanes[1], - kernel_size=[2, 2, 2])) - layers = layers + UNet.get_UNet(nPlanes[1:], residual_blocks, - conv_block_reps) - layers.append(BatchNormBlock(nPlanes[1])) - layers.append(ReLUBlock()) - layers.append( - DeConvolution(in_channels=nPlanes[1], - filters=nPlanes[0], - kernel_size=[2, 2, 2])) - - layers.append(JoinFeat()) - - for i in range(conv_block_reps): - UNet.block(layers, nPlanes[0] * (2 if i == 0 else 1), - nPlanes[0], residual_blocks) - - return layers - - def forward(self, pos_list, feat_list): - conv_pos = [] - concat_feat = [] - for module in self.net: - if isinstance(module, BatchNormBlock): - feat_list = module(feat_list) - elif isinstance(module, ReLUBlock): - feat_list = module(feat_list) - - elif isinstance(module, ResidualBlock): - feat_list = module(feat_list, pos_list) - - elif isinstance(module, SubmanifoldSparseConv): - feat_list = module(feat_list, pos_list) - - elif isinstance(module, Convolution): - conv_pos.append([pos.clone() for pos in pos_list]) - feat_list, pos_list = module(feat_list, pos_list) - - elif isinstance(module, DeConvolution): - feat_list = module(feat_list, [2 * pos for pos in pos_list], - conv_pos[-1]) - pos_list = conv_pos.pop() - - elif isinstance(module, ConcatFeat): - concat_feat.append([feat.clone() for feat in module(feat_list)]) - - elif isinstance(module, JoinFeat): - feat_list = module(concat_feat.pop(), feat_list) - - else: - raise Exception("Unknown module {}".format(module)) - - return feat_list - - -def load_unet_wts(net, path): - wts = list(torch.load(path).values()) - state_dict = net.state_dict() - i = 0 - for key in state_dict: - if 'offset' in key or 'tracked' in key: - continue - if len(wts[i].shape) == 4: - shp = wts[i].shape - state_dict[key] = np.transpose( - wts[i].reshape(int(shp[0]**(1 / 3)), int(shp[0]**(1 / 3)), - int(shp[0]**(1 / 3)), shp[-2], shp[-1]), - (2, 1, 0, 3, 4)) - else: - state_dict[key] = wts[i] - i += 1 - - net.load_state_dict(state_dict) diff --git a/ml3d/torch/modules/losses/__init__.py b/ml3d/torch/modules/losses/__init__.py index 00d0a7ec9..9582cf4b5 100644 --- a/ml3d/torch/modules/losses/__init__.py +++ b/ml3d/torch/modules/losses/__init__.py @@ -1,11 +1,11 @@ """Loss modules""" -from .semseg_loss import filter_valid_label, SemSegLoss, SemSegLossV2 +from .semseg_loss import filter_valid_label, SemSegLoss from .cross_entropy import CrossEntropyLoss from .focal_loss import FocalLoss from .smooth_L1 import SmoothL1Loss __all__ = [ 'filter_valid_label', 'SemSegLoss', 'CrossEntropyLoss', 'FocalLoss', - 'SmoothL1Loss', 'SemSegLossV2' + 'SmoothL1Loss' ] diff --git a/ml3d/torch/modules/losses/semseg_loss.py b/ml3d/torch/modules/losses/semseg_loss.py index 56eda4704..8b7846f65 100644 --- a/ml3d/torch/modules/losses/semseg_loss.py +++ b/ml3d/torch/modules/losses/semseg_loss.py @@ -52,30 +52,3 @@ def __init__(self, pipeline, model, dataset, device): self.weighted_CrossEntropyLoss = nn.CrossEntropyLoss(weight=weights) else: self.weighted_CrossEntropyLoss = nn.CrossEntropyLoss() - - -class SemSegLossV2(object): - """Loss functions for multi head semantic segmentation.""" - - def __init__(self, - num_heads, - num_classes, - ignored_labels=[], - device='cpu', - weights=None): - super(SemSegLossV2, self).__init__() - # weighted_CrossEntropyLoss - self.weighted_CrossEntropyLoss = [] - - for i in range(num_heads): - if weights is not None and len(weights[i]) != 0: - wts = DataProcessing.get_class_weights(weights[i])[0] - assert len(wts) == num_classes[ - i], f"num_classes : {num_classes[i]} is not equal to number of class weights : {len(wts)}" - wts = torch.tensor(wts) - else: - wts = torch.ones(num_classes[i]) - wts[ignored_labels[i]] = 0 - wts = wts.to(torch.float).to(device) - self.weighted_CrossEntropyLoss.append( - nn.CrossEntropyLoss(weight=wts)) diff --git a/ml3d/torch/pipelines/__init__.py b/ml3d/torch/pipelines/__init__.py index afd906e28..e68b13df5 100644 --- a/ml3d/torch/pipelines/__init__.py +++ b/ml3d/torch/pipelines/__init__.py @@ -2,8 +2,5 @@ from .semantic_segmentation import SemanticSegmentation from .object_detection import ObjectDetection -from .semantic_segmentation_multi_head import SemanticSegmentationMultiHead -__all__ = [ - 'SemanticSegmentation', 'ObjectDetection', 'SemanticSegmentationMultiHead' -] +__all__ = ['SemanticSegmentation', 'ObjectDetection'] diff --git a/ml3d/torch/pipelines/semantic_segmentation.py b/ml3d/torch/pipelines/semantic_segmentation.py index 355c34bad..c3a304566 100644 --- a/ml3d/torch/pipelines/semantic_segmentation.py +++ b/ml3d/torch/pipelines/semantic_segmentation.py @@ -470,7 +470,6 @@ def run_train(self): desc = "training - Epoch: %d, loss: %.3f" % (epoch, loss.cpu().item()) - # if rank == 0: progress_bar.set_description(desc) progress_bar.refresh() diff --git a/scripts/preprocess_nuscenes.py b/scripts/preprocess_nuscenes.py index 2e1bce21a..c2ad253c9 100644 --- a/scripts/preprocess_nuscenes.py +++ b/scripts/preprocess_nuscenes.py @@ -64,11 +64,6 @@ def __init__(self, dataset_path, out_path, version='v1.0-trainval'): dataroot=dataset_path, verbose=True) - ## Get semantic label stats - # nusc = self.nusc - # print(nusc.list_lidarseg_categories(sort_by='count')) - # print(nusc.lidarseg_idx2name_mapping) - if version == 'v1.0-trainval': train_scenes = splits.train val_scenes = splits.val diff --git a/scripts/preprocess_waymo.py b/scripts/preprocess_waymo.py index bc0b604a6..6a32dde7d 100644 --- a/scripts/preprocess_waymo.py +++ b/scripts/preprocess_waymo.py @@ -1,26 +1,25 @@ try: from waymo_open_dataset import dataset_pb2 - from waymo_open_dataset.utils import range_image_utils, transform_utils - from waymo_open_dataset.utils.frame_utils import \ - parse_range_image_and_camera_projection except ImportError: raise ImportError( - 'Please clone "https://github.com/waymo-research/waymo-open-dataset.git" ' - 'checkout branch "r1.3", and append its path to PYTHONPATH ' + 'Please run "pip install waymo-open-dataset-tf-2-1-0==1.2.0" ' 'to install the official devkit first.') import logging import numpy as np import os, sys, glob, pickle -import argparse -import tensorflow as tf -import matplotlib.image as mpimg - from pathlib import Path from os.path import join, exists, dirname, abspath from os import makedirs +import random +import argparse +import tensorflow as tf +import matplotlib.image as mpimg from multiprocessing import Pool from tqdm import tqdm +from waymo_open_dataset.utils import range_image_utils, transform_utils +from waymo_open_dataset.utils.frame_utils import \ + parse_range_image_and_camera_projection def parse_args(): @@ -39,10 +38,10 @@ def parse_args(): default=16, type=int) - parser.add_argument('--split', - help='One of {train, val, test} (default train)', - default='train', - type=str) + parser.add_argument('--is_test', + help='True for processing test data (default False)', + default=False, + type=bool) args = parser.parse_args() @@ -59,25 +58,6 @@ class Waymo2KITTI(): """Waymo to KITTI converter. This class converts tfrecord files from Waymo dataset to KITTI format. - KITTI format : (type, truncated, occluded, alpha, bbox, dimensions(3), location(3), - rotation_y(1), score(1, optional)) - type (string): Describes the type of object. - truncated (float): Ranges from 0(non-truncated) to 1(truncated). - occluded (int): Integer(0, 1, 2, 3) signifies state fully visible, partly - occluded, largely occluded, unknown. - alpha (float): Observation angle of object, ranging [-pi..pi]. - bbox (float): 2d bounding box of object in the image. - dimensions (float): 3D object dimensions: h, w, l in meters. - location (float): 3D object location: x,y,z in camera coordinates (in meters). - rotation_y (float): rotation around Y-axis in camera coordinates [-pi..pi]. - score (float): Only for predictions, indicating confidence in detection. - - Conversion writes following files: - pointcloud(np.float32) : pointcloud data with shape [N, 6]. Consists of - (x, y, z, intensity, elongation, timestamp). - images(np.uint8): camera images are saved if `write_image` is True. - calibrations(np.float32): Intinsic and Extrinsic matrix for all cameras. - label(np.float32): Bounding box information in KITTI format. Args: dataset_path (str): Directory to load waymo raw data. @@ -86,11 +66,11 @@ class Waymo2KITTI(): is_test (bool): Whether in the test_mode. Default: False. """ - def __init__(self, dataset_path, save_dir='', workers=8, split='train'): + def __init__(self, dataset_path, save_dir='', workers=8, is_test=False): - self.write_image = False + self.write_image = True self.filter_empty_3dboxes = True - self.filter_no_label_zone_points = False + self.filter_no_label_zone_points = True self.classes = ['VEHICLE', 'PEDESTRIAN', 'SIGN', 'CYCLIST'] @@ -106,8 +86,8 @@ def __init__(self, dataset_path, save_dir='', workers=8, split='train'): self.dataset_path = dataset_path self.save_dir = save_dir self.workers = int(workers) - self.is_test = split == 'test' - self.prefix = split + '_' + self.is_test = is_test + self.prefix = '' self.save_track_id = False self.tfrecord_files = sorted( @@ -157,6 +137,7 @@ def process_one(self, file_idx): if (self.selected_waymo_locations is not None and frame.context.stats.location not in self.selected_waymo_locations): + print("continue") continue if self.write_image: @@ -172,6 +153,8 @@ def __len__(self): return len(self.tfrecord_files) def save_image(self, frame, file_idx, frame_idx): + self.prefix = '' + for img in frame.images: img_path = Path(self.image_save_dir + str(img.name - 1)) / ( self.prefix + str(file_idx).zfill(3) + str(frame_idx).zfill(3) + @@ -227,6 +210,7 @@ def save_calib(self, frame, file_idx, frame_idx): f'{str(file_idx).zfill(3)}{str(frame_idx).zfill(3)}.txt', 'w+') as fp_calib: fp_calib.write(calib_context) + fp_calib.close() def save_pose(self, frame, file_idx, frame_idx): pose = np.array(frame.pose.transform).reshape(4, 4) @@ -244,6 +228,7 @@ def save_label(self, frame, file_idx, frame_idx): for labels in frame.projected_lidar_labels: name = labels.name for label in labels.labels: + # TODO: need a workaround as bbox may not belong to front cam bbox = [ label.box.center_x - label.box.length / 2, label.box.center_y - label.box.width / 2, @@ -272,6 +257,9 @@ def save_label(self, frame, file_idx, frame_idx): if my_type not in self.selected_waymo_classes: continue + # if self.filter_empty_3dboxes and obj.num_lidar_points_in_box < 1: + # continue + height = obj.box.height width = obj.box.width length = obj.box.length @@ -280,6 +268,11 @@ def save_label(self, frame, file_idx, frame_idx): y = obj.box.center_y z = obj.box.center_z + # # project bounding box to the virtual reference frame + # pt_ref = self.T_velo_to_front_cam @ \ + # np.array([x, y, z, 1]).reshape((4, 1)) + # x, y, z, _ = pt_ref.flatten().tolist() + rotation_y = -obj.box.heading - np.pi / 2 track_id = obj.id @@ -313,7 +306,7 @@ def save_label(self, frame, file_idx, frame_idx): fp_label_all.close() def save_lidar(self, frame, file_idx, frame_idx): - range_images, camera_projections, seg_labels, range_image_top_pose = parse_range_image_and_camera_projection( + range_images, camera_projections, range_image_top_pose = parse_range_image_and_camera_projection( frame) # First return @@ -321,7 +314,6 @@ def save_lidar(self, frame, file_idx, frame_idx): self.convert_range_image_to_point_cloud( frame, range_images, - seg_labels, camera_projections, range_image_top_pose, ri_index=0 @@ -335,7 +327,6 @@ def save_lidar(self, frame, file_idx, frame_idx): self.convert_range_image_to_point_cloud( frame, range_images, - seg_labels, camera_projections, range_image_top_pose, ri_index=1 @@ -360,7 +351,6 @@ def save_lidar(self, frame, file_idx, frame_idx): def convert_range_image_to_point_cloud(self, frame, range_images, - seg_labels, camera_projections, range_image_top_pose, ri_index=0): @@ -370,7 +360,6 @@ def convert_range_image_to_point_cloud(self, cp_points = [] intensity = [] elongation = [] - semseg_labels = [] frame_pose = tf.convert_to_tensor( value=np.reshape(np.array(frame.pose.transform), [4, 4])) @@ -391,7 +380,6 @@ def convert_range_image_to_point_cloud(self, range_image_top_pose_tensor_translation) for c in calibrations: range_image = range_images[c.name][ri_index] - seg_label = seg_labels[c.name][ri_index] if len(c.beam_inclinations) == 0: beam_inclinations = range_image_utils.compute_inclination( tf.constant( @@ -428,12 +416,9 @@ def convert_range_image_to_point_cloud(self, frame_pose=frame_pose_local) range_image_cartesian = tf.squeeze(range_image_cartesian, axis=0) - print(range_image_cartesian.shape) points_tensor = tf.gather_nd(range_image_cartesian, tf.compat.v1.where(range_image_mask)) - print(points_tensor.shape) - print(seg_label.shape) - exit(0) + cp = camera_projections[c.name][ri_index] cp_tensor = tf.reshape(tf.convert_to_tensor(value=cp.data), cp.shape.dims) @@ -475,8 +460,6 @@ def cart_to_homo(mat): out_path = args.out_path if out_path is None: args.out_path = args.dataset_path - if args.split not in ['train', 'val', 'test']: - raise ValueError("split must be one of {train, val, test}") converter = Waymo2KITTI(args.dataset_path, args.out_path, args.workers, - args.split) + args.is_test) converter.convert() diff --git a/scripts/preprocess_waymo_semseg.py b/scripts/preprocess_waymo_semseg.py index 70b7f92be..ae4077610 100644 --- a/scripts/preprocess_waymo_semseg.py +++ b/scripts/preprocess_waymo_semseg.py @@ -151,11 +151,6 @@ def create_folder(self): def convert(self): print(f"Start converting {len(self)} files ...") - # for i in tqdm(range(len(self))): - # self.process_one(i) - # with Pool(self.workers) as p: - # tqdm(p.imap(self.process_one, [i for i in range(len(self))]), total=len(self)) - # p.map(self.process_one, [i for i in range(len(self))]) process_map(self.process_one, range(len(self)), max_workers=self.workers) From 50c8c20e65ae3bc962d5a1ff63c05fbac802b0f1 Mon Sep 17 00:00:00 2001 From: Sanskar Agrawal Date: Tue, 20 Sep 2022 20:38:53 +0530 Subject: [PATCH 48/50] fix docstring --- ml3d/datasets/nuscenes_semseg.py | 1 - 1 file changed, 1 deletion(-) diff --git a/ml3d/datasets/nuscenes_semseg.py b/ml3d/datasets/nuscenes_semseg.py index 244672f2c..a8bad47c5 100644 --- a/ml3d/datasets/nuscenes_semseg.py +++ b/ml3d/datasets/nuscenes_semseg.py @@ -120,7 +120,6 @@ def get_label_to_names(): A dict where keys are label numbers and values are the corresponding names. """ - classes = "ignore, barrier, bicycle, bus, car, construction_vehicle, motorcycle, pedestrian, traffic_cone, trailer, trucl, driveable_surface, other_flat, sidewalk, terrain, manmade, vegetation" classes = classes.replace(', ', ',').split(',') label_to_names = {} From 5c2ba3201f4bc9e4095ed665f0efd6db2e8641d6 Mon Sep 17 00:00:00 2001 From: Sanskar Agrawal Date: Thu, 22 Sep 2022 14:41:38 +0530 Subject: [PATCH 49/50] address lgtm --- ml3d/datasets/kitti360.py | 10 ++-------- ml3d/datasets/nuscenes_semseg.py | 12 +++++------- ml3d/datasets/waymo_semseg.py | 17 ++++++----------- scripts/preprocess_waymo.py | 10 ++++------ 4 files changed, 17 insertions(+), 32 deletions(-) diff --git a/ml3d/datasets/kitti360.py b/ml3d/datasets/kitti360.py index d7e16d2b8..7974b5bd1 100644 --- a/ml3d/datasets/kitti360.py +++ b/ml3d/datasets/kitti360.py @@ -1,16 +1,12 @@ import numpy as np -import pandas as pd -import os, pickle +import os import logging import open3d as o3d from pathlib import Path -from os.path import join, exists, dirname, abspath, isdir -from sklearn.neighbors import KDTree -from tqdm import tqdm +from os.path import join, exists from glob import glob -from .utils import DataProcessing, get_min_bbox, BEVBox3D from .base_dataset import BaseDataset, BaseDatasetSplit from ..utils import make_dir, DATASET @@ -57,8 +53,6 @@ def __init__(self, ignored_label_inds=ignored_label_inds, **kwargs) - cfg = self.cfg - self.label_to_names = self.get_label_to_names() self.num_classes = len(self.label_to_names) self.label_values = np.sort([k for k, v in self.label_to_names.items()]) diff --git a/ml3d/datasets/nuscenes_semseg.py b/ml3d/datasets/nuscenes_semseg.py index a8bad47c5..36b57760f 100644 --- a/ml3d/datasets/nuscenes_semseg.py +++ b/ml3d/datasets/nuscenes_semseg.py @@ -1,15 +1,13 @@ import os import pickle -from os.path import join -from pathlib import Path import logging import numpy as np -from scipy.spatial.transform import Rotation as R + +from os.path import join +from pathlib import Path from .base_dataset import BaseDataset from ..utils import DATASET -from .utils import BEVBox3D -import open3d as o3d log = logging.getLogger(__name__) @@ -187,7 +185,7 @@ def get_split_list(self, split): raise ValueError("Invalid split {}".format(split)) - def is_tested(): + def is_tested(self): """Checks if a datum in the dataset has been tested. Args: @@ -200,7 +198,7 @@ def is_tested(): """ pass - def save_test_result(): + def save_test_result(self): """Saves the output of a model. Args: diff --git a/ml3d/datasets/waymo_semseg.py b/ml3d/datasets/waymo_semseg.py index 44946f2fe..b3bbd1b26 100644 --- a/ml3d/datasets/waymo_semseg.py +++ b/ml3d/datasets/waymo_semseg.py @@ -1,14 +1,12 @@ import numpy as np -import os, argparse, pickle, sys -from os.path import exists, join, isfile, dirname, abspath, split +import logging + +from os.path import join from pathlib import Path from glob import glob -import logging -import yaml from .base_dataset import BaseDataset, BaseDatasetSplit -from ..utils import Config, make_dir, DATASET -from .utils import BEVBox3D +from ..utils import DATASET log = logging.getLogger(__name__) @@ -125,12 +123,9 @@ def get_split_list(self, split): 'all'. """ cfg = self.cfg - dataset_path = cfg.dataset_path - file_list = [] if split in ['train', 'training']: return self.train_files - seq_list = cfg.training_split elif split in ['test', 'testing']: return self.test_files elif split in ['val', 'validation']: @@ -140,7 +135,7 @@ def get_split_list(self, split): else: raise ValueError("Invalid split {}".format(split)) - def is_tested(attr): + def is_tested(self, attr): """Checks if a datum in the dataset has been tested. Args: @@ -152,7 +147,7 @@ def is_tested(attr): """ raise NotImplementedError() - def save_test_result(results, attr): + def save_test_result(self, results, attr): """Saves the output of a model. Args: diff --git a/scripts/preprocess_waymo.py b/scripts/preprocess_waymo.py index 6a32dde7d..b34d26fc4 100644 --- a/scripts/preprocess_waymo.py +++ b/scripts/preprocess_waymo.py @@ -7,16 +7,14 @@ import logging import numpy as np -import os, sys, glob, pickle +import glob +import argparse +import tensorflow as tf + from pathlib import Path from os.path import join, exists, dirname, abspath from os import makedirs -import random -import argparse -import tensorflow as tf -import matplotlib.image as mpimg from multiprocessing import Pool -from tqdm import tqdm from waymo_open_dataset.utils import range_image_utils, transform_utils from waymo_open_dataset.utils.frame_utils import \ parse_range_image_and_camera_projection From e97174792feef285d705284d67b3859335c28977 Mon Sep 17 00:00:00 2001 From: Sanskar Agrawal Date: Thu, 22 Sep 2022 10:13:08 -0700 Subject: [PATCH 50/50] update preprocess --- scripts/preprocess_waymo.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/scripts/preprocess_waymo.py b/scripts/preprocess_waymo.py index b34d26fc4..0f677aabd 100644 --- a/scripts/preprocess_waymo.py +++ b/scripts/preprocess_waymo.py @@ -155,7 +155,7 @@ def save_image(self, frame, file_idx, frame_idx): for img in frame.images: img_path = Path(self.image_save_dir + str(img.name - 1)) / ( - self.prefix + str(file_idx).zfill(3) + str(frame_idx).zfill(3) + + self.prefix + str(file_idx).zfill(4) + str(frame_idx).zfill(4) + '.npy') image = tf.io.decode_jpeg(img.image).numpy() @@ -205,7 +205,7 @@ def save_calib(self, frame, file_idx, frame_idx): with open( f'{self.calib_save_dir}/{self.prefix}' + - f'{str(file_idx).zfill(3)}{str(frame_idx).zfill(3)}.txt', + f'{str(file_idx).zfill(4)}{str(frame_idx).zfill(4)}.txt', 'w+') as fp_calib: fp_calib.write(calib_context) fp_calib.close() @@ -214,13 +214,13 @@ def save_pose(self, frame, file_idx, frame_idx): pose = np.array(frame.pose.transform).reshape(4, 4) np.savetxt( join(f'{self.pose_save_dir}/{self.prefix}' + - f'{str(file_idx).zfill(3)}{str(frame_idx).zfill(3)}.txt'), + f'{str(file_idx).zfill(4)}{str(frame_idx).zfill(4)}.txt'), pose) def save_label(self, frame, file_idx, frame_idx): fp_label_all = open( f'{self.label_all_save_dir}/{self.prefix}' + - f'{str(file_idx).zfill(3)}{str(frame_idx).zfill(3)}.txt', 'w+') + f'{str(file_idx).zfill(4)}{str(frame_idx).zfill(4)}.txt', 'w+') id_to_bbox = dict() id_to_name = dict() for labels in frame.projected_lidar_labels: @@ -295,7 +295,7 @@ def save_label(self, frame, file_idx, frame_idx): fp_label = open( f'{self.label_save_dir}{name}/{self.prefix}' + - f'{str(file_idx).zfill(3)}{str(frame_idx).zfill(3)}.txt', 'a') + f'{str(file_idx).zfill(4)}{str(frame_idx).zfill(4)}.txt', 'a') fp_label.write(line) fp_label.close() @@ -304,7 +304,7 @@ def save_label(self, frame, file_idx, frame_idx): fp_label_all.close() def save_lidar(self, frame, file_idx, frame_idx): - range_images, camera_projections, range_image_top_pose = parse_range_image_and_camera_projection( + range_images, camera_projections, _, range_image_top_pose = parse_range_image_and_camera_projection( frame) # First return @@ -343,7 +343,7 @@ def save_lidar(self, frame, file_idx, frame_idx): (points, intensity, elongation, timestamp)) pc_path = f'{self.point_cloud_save_dir}/{self.prefix}' + \ - f'{str(file_idx).zfill(3)}{str(frame_idx).zfill(3)}.bin' + f'{str(file_idx).zfill(4)}{str(frame_idx).zfill(4)}.bin' point_cloud.astype(np.float32).tofile(pc_path) def convert_range_image_to_point_cloud(self,