diff --git a/ml3d/configs/default_cfgs/kitti360.yml b/ml3d/configs/default_cfgs/kitti360.yml
new file mode 100644
index 000000000..6808e3861
--- /dev/null
+++ b/ml3d/configs/default_cfgs/kitti360.yml
@@ -0,0 +1,7 @@
+name: KITTI360
+dataset_path: # path/to/your/dataset
+cache_dir: ./logs/cache
+class_weights: []
+ignored_label_inds: []
+test_result_folder: ./test
+use_cache: False
diff --git a/ml3d/configs/default_cfgs/nuscenes_semseg.yml b/ml3d/configs/default_cfgs/nuscenes_semseg.yml
new file mode 100644
index 000000000..824c797cb
--- /dev/null
+++ b/ml3d/configs/default_cfgs/nuscenes_semseg.yml
@@ -0,0 +1,6 @@
+name: NuScenesSemSeg
+dataset_path: # path/to/your/dataset
+cache_dir: ./logs/cache
+class_weights: [282265, 7676, 120, 3754, 31974, 1321, 346, 1898, 624, 4537, 13282, 260911, 6588, 57567, 56670, 146511, 100633]
+ignored_label_inds: [0]
+use_cache: False
diff --git a/ml3d/configs/default_cfgs/semantickitti.yml b/ml3d/configs/default_cfgs/semantickitti.yml
index bc46ab895..cd26c0429 100644
--- a/ml3d/configs/default_cfgs/semantickitti.yml
+++ b/ml3d/configs/default_cfgs/semantickitti.yml
@@ -1,13 +1,6 @@
 name: SemanticKITTI
 dataset_path: # path/to/your/dataset
 cache_dir: ./logs/cache
+ignored_label_inds: [0]
 use_cache: false
-class_weights: [55437630, 320797, 541736, 2578735, 3274484, 552662,
-184064, 78858, 240942562, 17294618, 170599734, 6369672, 230413074, 101130274,
-476491114, 9833174, 129609852, 4506626, 1168181]
-test_result_folder: ./test
-test_split: ['11', '12', '13', '14', '15', '16', '17', '18', '19', '20', '21']
-training_split: ['00', '01', '02', '03', '04', '05', '06', '07', '09', '10']
-validation_split: ['08']
-all_split: ['00', '01', '02', '03', '04', '05', '06', '07', '09',
-'08', '10', '11', '12', '13', '14', '15', '16', '17', '18', '19', '20', '21']
+class_weights: [101665, 157022, 631, 1516, 5012, 7085, 1043, 457, 176, 693044, 53132, 494988, 12829, 459669, 236069, 924425, 22780, 255213, 9664, 2024]
diff --git a/ml3d/configs/default_cfgs/waymo_semseg.yml b/ml3d/configs/default_cfgs/waymo_semseg.yml
new file mode 100644
index 000000000..8658c4dae
--- /dev/null
+++ b/ml3d/configs/default_cfgs/waymo_semseg.yml
@@ -0,0 +1,6 @@
+name: WaymoSemSeg
+dataset_path: # path/to/your/dataset
+cache_dir: ./logs/cache
+class_weights: [513255, 341079, 39946, 28066, 17254, 104, 1169, 31335, 23359, 2924, 43680, 2149, 483, 639, 1394353, 858409, 90903, 52484, 884591, 24487, 21477, 322212, 229034]
+ignored_label_inds: [0]
+use_cache: False
diff --git a/ml3d/configs/pointtransformer_waymo.yml b/ml3d/configs/pointtransformer_waymo.yml
new file mode 100644
index 000000000..30605046c
--- /dev/null
+++ b/ml3d/configs/pointtransformer_waymo.yml
@@ -0,0 +1,39 @@
+dataset:
+  name: WaymoSemSeg
+  dataset_path: # path/to/your/dataset
+  cache_dir: ./logs/cache
+  class_weights: []
+  ignored_label_inds: [0]
+  test_result_folder: ./test
+  use_cache: false
+model:
+  name: PointTransformer
+  batcher: ConcatBatcher
+  ckpt_path: # path/to/your/checkpoint
+  in_channels: 3
+  blocks: [2, 3, 4, 6, 3]
+  num_classes: 23
+  voxel_size: 0.04
+  max_voxels: 50000
+  ignored_label_inds: [-1]
+  augment: {}
+pipeline:
+  name: SemanticSegmentation
+  optimizer:
+    lr: 0.02
+    momentum: 0.9
+    weight_decay: 0.0001
+  batch_size: 2
+  learning_rate: 0.01
+  main_log_dir: ./logs
+  max_epoch: 512
+  save_ckpt_freq: 3
+  scheduler_gamma: 0.99
+  test_batch_size: 1
+  train_sum_dir: train_log
+  val_batch_size: 2
+  summary:
+    record_for: []
+    max_pts:
+    use_reference: false
+    max_outputs: 1
diff --git
a/ml3d/configs/sparseconvunet_scannet.yml b/ml3d/configs/sparseconvunet_scannet.yml index 465d88ed5..e23635ca9 100644 --- a/ml3d/configs/sparseconvunet_scannet.yml +++ b/ml3d/configs/sparseconvunet_scannet.yml @@ -6,7 +6,7 @@ dataset: test_result_folder: ./test use_cache: False sampler: - name: 'SemSegRandomSampler' + name: None model: name: SparseConvUnet batcher: ConcatBatcher diff --git a/ml3d/configs/sparseconvunet_waymo.yml b/ml3d/configs/sparseconvunet_waymo.yml new file mode 100644 index 000000000..fac0da5c1 --- /dev/null +++ b/ml3d/configs/sparseconvunet_waymo.yml @@ -0,0 +1,39 @@ +dataset: + name: WaymoSemSeg + dataset_path: # path/to/your/dataset + cache_dir: ./logs/cache + class_weights: [] + ignored_label_inds: [0] + test_result_folder: ./test + use_cache: false +model: + name: SparseConvUnet + batcher: ConcatBatcher + ckpt_path: # path/to/your/checkpoint + multiplier: 32 + voxel_size: 0.02 + residual_blocks: True + conv_block_reps: 1 + in_channels: 3 + num_classes: 23 + grid_size: 4096 + ignored_label_inds: [0] + augment: {} +pipeline: + name: SemanticSegmentation + optimizer: + lr: 0.001 + betas: [0.9, 0.999] + batch_size: 2 + main_log_dir: ./logs + max_epoch: 256 + save_ckpt_freq: 3 + scheduler_gamma: 0.99 + test_batch_size: 1 + train_sum_dir: train_log + val_batch_size: 2 + summary: + record_for: [] + max_pts: + use_reference: false + max_outputs: 1 diff --git a/ml3d/datasets/__init__.py b/ml3d/datasets/__init__.py index 056b8b9f4..d1ca319bc 100644 --- a/ml3d/datasets/__init__.py +++ b/ml3d/datasets/__init__.py @@ -14,18 +14,21 @@ from .kitti import KITTI from .nuscenes import NuScenes +from .nuscenes_semseg import NuScenesSemSeg from .waymo import Waymo +from .waymo_semseg import WaymoSemSeg from .lyft import Lyft from .shapenet import ShapeNet from .argoverse import Argoverse from .scannet import Scannet from .sunrgbd import SunRGBD from .matterport_objects import MatterportObjects +from .kitti360 import KITTI360 __all__ = [ 'SemanticKITTI', 'S3DIS', 'Toronto3D', 'ParisLille3D', 'Semantic3D', 'Custom3D', 'utils', 'augment', 'samplers', 'KITTI', 'Waymo', 'NuScenes', 'Lyft', 'ShapeNet', 'SemSegRandomSampler', 'InferenceDummySplit', 'SemSegSpatiallyRegularSampler', 'Argoverse', 'Scannet', 'SunRGBD', - 'MatterportObjects' + 'MatterportObjects', 'WaymoSemSeg', 'KITTI360', 'NuScenesSemSeg' ] diff --git a/ml3d/datasets/base_dataset.py b/ml3d/datasets/base_dataset.py index bcc62e244..ec48550aa 100644 --- a/ml3d/datasets/base_dataset.py +++ b/ml3d/datasets/base_dataset.py @@ -127,10 +127,14 @@ def __init__(self, dataset, split='training'): if split in ['test']: sampler_cls = get_module('sampler', 'SemSegSpatiallyRegularSampler') else: - sampler_cfg = self.cfg.get('sampler', - {'name': 'SemSegRandomSampler'}) + sampler_cfg = self.cfg.get('sampler', {'name': None}) + if sampler_cfg['name'] == "None": + sampler_cfg['name'] = None sampler_cls = get_module('sampler', sampler_cfg['name']) - self.sampler = sampler_cls(self) + if sampler_cls is None: + self.sampler = None + else: + self.sampler = sampler_cls(self) @abstractmethod def __len__(self): diff --git a/ml3d/datasets/kitti360.py b/ml3d/datasets/kitti360.py new file mode 100644 index 000000000..7974b5bd1 --- /dev/null +++ b/ml3d/datasets/kitti360.py @@ -0,0 +1,230 @@ +import numpy as np +import os +import logging +import open3d as o3d + +from pathlib import Path +from os.path import join, exists +from glob import glob + +from .base_dataset import BaseDataset, BaseDatasetSplit +from ..utils import make_dir, DATASET + +log = 
logging.getLogger(__name__) + + +class KITTI360(BaseDataset): + """This class is used to create a dataset based on the KITTI 360 + dataset, and used in visualizer, training, or testing. + """ + + def __init__(self, + dataset_path, + name='KITTI360', + cache_dir='./logs/cache', + use_cache=False, + class_weights=[ + 3370714, 2856755, 4919229, 318158, 375640, 478001, 974733, + 650464, 791496, 88727, 1284130, 229758, 2272837 + ], + num_points=40960, + ignored_label_inds=[], + test_result_folder='./test', + **kwargs): + """Initialize the function by passing the dataset and other details. + + Args: + dataset_path: The path to the dataset to use (parent directory of data_3d_semantics). + name: The name of the dataset (KITTI360 in this case). + cache_dir: The directory where the cache is stored. + use_cache: Indicates if the dataset should be cached. + class_weights: The class weights to use in the dataset. + num_points: The maximum number of points to use when splitting the dataset. + ignored_label_inds: A list of labels that should be ignored in the dataset. + test_result_folder: The folder where the test results should be stored. + """ + super().__init__(dataset_path=dataset_path, + name=name, + cache_dir=cache_dir, + use_cache=use_cache, + class_weights=class_weights, + test_result_folder=test_result_folder, + num_points=num_points, + ignored_label_inds=ignored_label_inds, + **kwargs) + + self.label_to_names = self.get_label_to_names() + self.num_classes = len(self.label_to_names) + self.label_values = np.sort([k for k, v in self.label_to_names.items()]) + self.label_to_idx = {l: i for i, l in enumerate(self.label_values)} + self.ignored_labels = np.array([]) + + if not os.path.exists( + os.path.join( + dataset_path, + 'data_3d_semantics/train/2013_05_28_drive_train.txt')): + raise ValueError( + "Invalid Path, make sure dataset_path is the parent directory of data_3d_semantics." + ) + + with open( + os.path.join( + dataset_path, + 'data_3d_semantics/train/2013_05_28_drive_train.txt'), + 'r') as f: + train_paths = f.read().split('\n')[:-1] + train_paths = [os.path.join(dataset_path, p) for p in train_paths] + + with open( + os.path.join( + dataset_path, + 'data_3d_semantics/train/2013_05_28_drive_val.txt'), + 'r') as f: + val_paths = f.read().split('\n')[:-1] + val_paths = [os.path.join(dataset_path, p) for p in val_paths] + + self.train_files = train_paths + self.val_files = val_paths + self.test_files = sorted( + glob( + os.path.join(dataset_path, + 'data_3d_semantics/test/*/static/*.ply'))) + + @staticmethod + def get_label_to_names(): + """Returns a label to names dictionary object. + + Returns: + A dict where keys are label numbers and + values are the corresponding names. + """ + label_to_names = { + 0: 'ceiling', + 1: 'floor', + 2: 'wall', + 3: 'beam', + 4: 'column', + 5: 'window', + 6: 'door', + 7: 'table', + 8: 'chair', + 9: 'sofa', + 10: 'bookcase', + 11: 'board', + 12: 'clutter' + } + return label_to_names + + def get_split(self, split): + """Returns a dataset split. + + Args: + split: A string identifying the dataset split that is usually one of + 'training', 'test', 'validation', or 'all'. + + Returns: + A dataset split object providing the requested subset of the data. 
+ """ + return KITTI360Split(self, split=split) + + def get_split_list(self, split): + if split in ['train', 'training']: + return self.train_files + elif split in ['val', 'validation']: + return self.val_files + elif split in ['test', 'testing']: + return self.test_files + elif split == 'all': + return self.train_files + self.val_files + self.test_files + else: + raise ValueError("Invalid split {}".format(split)) + + def is_tested(self, attr): + + cfg = self.cfg + name = attr['name'] + path = cfg.test_result_folder + store_path = join(path, self.name, name + '.npy') + if exists(store_path): + print("{} already exists.".format(store_path)) + return True + else: + return False + + """Saves the output of a model. + + Args: + results: The output of a model for the datum associated with the attribute passed. + attr: The attributes that correspond to the outputs passed in results. + """ + + def save_test_result(self, results, attr): + + cfg = self.cfg + name = attr['name'].split('.')[0] + path = cfg.test_result_folder + make_dir(path) + + pred = results['predict_labels'] + pred = np.array(pred) + + for ign in cfg.ignored_label_inds: + pred[pred >= ign] += 1 + + store_path = join(path, self.name, name + '.npy') + make_dir(Path(store_path).parent) + np.save(store_path, pred) + log.info("Saved {} in {}.".format(name, store_path)) + + +class KITTI360Split(BaseDatasetSplit): + """This class is used to create a split for KITTI360 dataset. + + Initialize the class. + + Args: + dataset: The dataset to split. + split: A string identifying the dataset split that is usually one of + 'training', 'test', 'validation', or 'all'. + **kwargs: The configuration of the model as keyword arguments. + + Returns: + A dataset split object providing the requested subset of the data. + """ + + def __init__(self, dataset, split='training'): + super().__init__(dataset, split=split) + log.info("Found {} pointclouds for {}".format(len(self.path_list), + split)) + + def __len__(self): + return len(self.path_list) + + def get_data(self, idx): + pc_path = self.path_list[idx] + + pc = o3d.t.io.read_point_cloud(pc_path) + + points = pc.point['positions'].numpy().astype(np.float32) + feat = pc.point['colors'].numpy().astype(np.float32) + labels = pc.point['semantic'].numpy().astype(np.int32).reshape((-1,)) + + data = { + 'point': points, + 'feat': feat, + 'label': labels, + } + + return data + + def get_attr(self, idx): + pc_path = Path(self.path_list[idx]) + name = pc_path.name.replace('.pkl', '') + + pc_path = str(pc_path) + split = self.split + attr = {'idx': idx, 'name': name, 'path': pc_path, 'split': split} + return attr + + +DATASET._register_module(KITTI360) diff --git a/ml3d/datasets/nuscenes_semseg.py b/ml3d/datasets/nuscenes_semseg.py new file mode 100644 index 000000000..36b57760f --- /dev/null +++ b/ml3d/datasets/nuscenes_semseg.py @@ -0,0 +1,254 @@ +import os +import pickle +import logging +import numpy as np + +from os.path import join +from pathlib import Path + +from .base_dataset import BaseDataset +from ..utils import DATASET + +log = logging.getLogger(__name__) + + +class NuScenesSemSeg(BaseDataset): + """This class is used to create a dataset based on the NuScenes 3D dataset, + and used in object detection, visualizer, training, or testing. + + The NuScenes 3D dataset is best suited for autonomous driving applications. 
+ """ + + def __init__(self, + dataset_path, + info_path=None, + name='NuScenes', + cache_dir='./logs/cache', + use_cache=False, + **kwargs): + """Initialize the function by passing the dataset and other details. + + Args: + dataset_path: The path to the dataset to use. + info_path: The path to the file that includes information about the + dataset. This is default to dataset path if nothing is provided. + name: The name of the dataset (NuScenes in this case). + cache_dir: The directory where the cache is stored. + use_cache: Indicates if the dataset should be cached. + + Returns: + class: The corresponding class. + """ + if info_path is None: + info_path = dataset_path + + super().__init__(dataset_path=dataset_path, + info_path=info_path, + name=name, + cache_dir=cache_dir, + use_cache=use_cache, + **kwargs) + + cfg = self.cfg + + self.name = cfg.name + self.dataset_path = cfg.dataset_path + self.label_to_names = self.get_label_to_names() + self.num_classes = len(self.label_to_names) + + self.train_info = {} + self.test_info = {} + self.val_info = {} + + if os.path.exists(join(info_path, 'infos_train.pkl')): + self.train_info = pickle.load( + open(join(info_path, 'infos_train.pkl'), 'rb')) + + if os.path.exists(join(info_path, 'infos_val.pkl')): + self.val_info = pickle.load( + open(join(info_path, 'infos_val.pkl'), 'rb')) + + if os.path.exists(join(info_path, 'infos_test.pkl')): + self.test_info = pickle.load( + open(join(info_path, 'infos_test.pkl'), 'rb')) + + # It comes with 32 classes, but the nuscenes challenge merge similar classes and remove rare classes. + mapping = { + 1: 0, + 5: 0, + 7: 0, + 8: 0, + 10: 0, + 11: 0, + 13: 0, + 19: 0, + 20: 0, + 0: 0, + 29: 0, + 31: 0, + 9: 1, + 14: 2, + 15: 3, + 16: 3, + 17: 4, + 18: 5, + 21: 6, + 2: 7, + 3: 7, + 4: 7, + 6: 7, + 12: 8, + 22: 9, + 23: 10, + 24: 11, + 25: 12, + 26: 13, + 27: 14, + 28: 15, + 30: 16 + } + self.label_mapping = np.array( + [mapping[i] for i in range(0, len(mapping))], dtype=np.int32) + + @staticmethod + def get_label_to_names(): + """Returns a label to names dictionary object. + + Returns: + A dict where keys are label numbers and + values are the corresponding names. + """ + classes = "ignore, barrier, bicycle, bus, car, construction_vehicle, motorcycle, pedestrian, traffic_cone, trailer, trucl, driveable_surface, other_flat, sidewalk, terrain, manmade, vegetation" + classes = classes.replace(', ', ',').split(',') + label_to_names = {} + for i in range(len(classes)): + label_to_names[i] = classes[i] + + return label_to_names + + @staticmethod + def read_lidar(path): + """Reads lidar data from the path provided. + + Returns: + A data object with lidar information. + """ + assert Path(path).exists() + + return np.fromfile(path, dtype=np.float32).reshape(-1, 5) + + def read_lidarseg(self, path): + """Reads semantic data from the path provided. + + Returns: + A data object with semantic information. + """ + assert Path(path).exists() + + labels = np.fromfile(path, dtype=np.uint8).reshape(-1,).astype(np.int32) + + return self.label_mapping[labels] + + def get_split(self, split): + """Returns a dataset split. + + Args: + split: A string identifying the dataset split that is usually one of + 'training', 'test', 'validation', or 'all'. + + Returns: + A dataset split object providing the requested subset of the data. + """ + return NuScenesSemSegSplit(self, split=split) + + def get_split_list(self, split): + """Returns the list of data splits available. 
+ + Args: + split: A string identifying the dataset split that is usually one of + 'training', 'test', 'validation', or 'all'. + + Returns: + A dataset split object providing the requested subset of the data. + + Raises: + ValueError: Indicates that the split name passed is incorrect. The + split name should be one of 'training', 'test', 'validation', or + 'all'. + """ + if split in ['train', 'training']: + return self.train_info + elif split in ['test', 'testing']: + return self.test_info + elif split in ['val', 'validation']: + return self.val_info + + raise ValueError("Invalid split {}".format(split)) + + def is_tested(self): + """Checks if a datum in the dataset has been tested. + + Args: + dataset: The current dataset to which the datum belongs to. + attr: The attribute that needs to be checked. + + Returns: + If the dataum attribute is tested, then return the path where the + attribute is stored; else, returns false. + """ + pass + + def save_test_result(self): + """Saves the output of a model. + + Args: + results: The output of a model for the datum associated with the + attribute passed. + attr: The attributes that correspond to the outputs passed in + results. + """ + pass + + +class NuScenesSemSegSplit(): + + def __init__(self, dataset, split='train'): + self.cfg = dataset.cfg + + self.infos = dataset.get_split_list(split) + self.path_list = [] + for info in self.infos: + self.path_list.append(info['lidar_path']) + + log.info("Found {} pointclouds for {}".format(len(self.infos), split)) + + self.split = split + self.dataset = dataset + + def __len__(self): + return len(self.infos) + + def get_data(self, idx): + info = self.infos[idx] + lidar_path = info['lidar_path'] + lidarseg_path = info['lidarseg_path'] + + pc = self.dataset.read_lidar(lidar_path) + feat = pc[:, 3:4] + pc = pc[:, :3] + lidarseg = self.dataset.read_lidarseg(lidarseg_path) + + data = {'point': pc, 'feat': feat, 'label': lidarseg} + + return data + + def get_attr(self, idx): + info = self.infos[idx] + pc_path = info['lidar_path'] + name = Path(pc_path).name.split('.')[0] + + attr = {'name': name, 'path': str(pc_path), 'split': self.split} + return attr + + +DATASET._register_module(NuScenesSemSeg) diff --git a/ml3d/datasets/waymo_semseg.py b/ml3d/datasets/waymo_semseg.py new file mode 100644 index 000000000..b3bbd1b26 --- /dev/null +++ b/ml3d/datasets/waymo_semseg.py @@ -0,0 +1,199 @@ +import numpy as np +import logging + +from os.path import join +from pathlib import Path +from glob import glob + +from .base_dataset import BaseDataset, BaseDatasetSplit +from ..utils import DATASET + +log = logging.getLogger(__name__) + + +class WaymoSemSeg(BaseDataset): + """This class is used to create a dataset based on the Waymo 3D dataset, and + used in object detection, visualizer, training, or testing. + + The Waymo 3D dataset is best suited for autonomous driving applications. + """ + + def __init__(self, + dataset_path, + name='Waymo', + cache_dir='./logs/cache', + use_cache=False, + **kwargs): + """Initialize the function by passing the dataset and other details. + + Args: + dataset_path: The path to the dataset to use. + name: The name of the dataset (Waymo in this case). + cache_dir: The directory where the cache is stored. + use_cache: Indicates if the dataset should be cached. + + Returns: + class: The corresponding class. 
+ """ + super().__init__(dataset_path=dataset_path, + name=name, + cache_dir=cache_dir, + use_cache=use_cache, + **kwargs) + + cfg = self.cfg + + self.name = cfg.name + self.dataset_path = cfg.dataset_path + self.label_to_names = self.get_label_to_names() + self.num_classes = len(self.label_to_names) + self.shuffle = kwargs.get('shuffle', False) + + self.all_files = sorted( + glob(join(cfg.dataset_path, 'velodyne', '*.bin'))) + self.train_files = [] + self.val_files = [] + self.test_files = [] + + for f in self.all_files: + if 'train' in f: + self.train_files.append(f) + elif 'val' in f: + self.val_files.append(f) + elif 'test' in f: + self.test_files.append(f) + else: + log.warning( + f"Skipping {f}, prefix must be one of train, test or val.") + if self.shuffle: + log.info("Shuffling training files...") + self.rng.shuffle(self.train_files) + + @staticmethod + def get_label_to_names(): + """Returns a label to names dictionary object. + + Returns: + A dict where keys are label numbers and + values are the corresponding names. + """ + classes = "Undefined, Car, Truck, Bus, Other Vehicle, Motorcyclist, Bicyclist, Pedestrian, Sign, Traffic Light, Pole, Construction Cone, Bicycle, Motorcycle, Building, Vegetation, Tree Trunk, Curb, Road, Lane Marker, Other Ground, Walkable, Sidewalk" + classes = classes.replace(', ', ',').split(',') + label_to_names = {} + for i in range(len(classes)): + label_to_names[i] = classes[i] + + return label_to_names + + @staticmethod + def read_lidar(path): + """Reads lidar data from the path provided. + + Returns: + pc: pointcloud data with shape [N, 8], where + the format is x,y,z,intensity,elongation,timestamp,instance, semantic_label. + """ + return np.fromfile(path, dtype=np.float32).reshape(-1, 8) + + def get_split(self, split): + """Returns a dataset split. + + Args: + split: A string identifying the dataset split that is usually one of + 'training', 'test', 'validation', or 'all'. + + Returns: + A dataset split object providing the requested subset of the data. + """ + return WaymoSemSegSplit(self, split=split) + + def get_split_list(self, split): + """Returns the list of data splits available. + + Args: + split: A string identifying the dataset split that is usually one of + 'training', 'test', 'validation', or 'all'. + + Returns: + A dataset split object providing the requested subset of the data. + + Raises: + ValueError: Indicates that the split name passed is incorrect. The + split name should be one of 'training', 'test', 'validation', or + 'all'. + """ + cfg = self.cfg + + if split in ['train', 'training']: + return self.train_files + elif split in ['test', 'testing']: + return self.test_files + elif split in ['val', 'validation']: + return self.val_files + elif split in ['all']: + return self.train_files + self.val_files + self.test_files + else: + raise ValueError("Invalid split {}".format(split)) + + def is_tested(self, attr): + """Checks if a datum in the dataset has been tested. + + Args: + attr: The attribute that needs to be checked. + + Returns: + If the datum attribute is tested, then return the path where the + attribute is stored; else, returns false. + """ + raise NotImplementedError() + + def save_test_result(self, results, attr): + """Saves the output of a model. + + Args: + results: The output of a model for the datum associated with the attribute passed. + attr: The attributes that correspond to the outputs passed in results. 
+ """ + raise NotImplementedError() + + +class WaymoSemSegSplit(BaseDatasetSplit): + + def __init__(self, dataset, split='train'): + super().__init__(dataset, split=split) + self.cfg = dataset.cfg + path_list = dataset.get_split_list(split) + log.info("Found {} pointclouds for {}".format(len(path_list), split)) + + self.path_list = path_list + self.split = split + self.dataset = dataset + + def __len__(self): + return len(self.path_list) + + def get_data(self, idx): + pc_path = self.path_list[idx] + + pc = self.dataset.read_lidar(pc_path) + feat = pc[:, 3:5] # intensity, elongation + label = pc[:, 7].astype(np.int32) + pc = pc[:, :3] + + data = { + 'point': pc, + 'feat': feat, + 'label': label, + } + + return data + + def get_attr(self, idx): + pc_path = self.path_list[idx] + name = Path(pc_path).name.split('.')[0] + + attr = {'name': name, 'path': pc_path, 'split': self.split} + return attr + + +DATASET._register_module(WaymoSemSeg) diff --git a/ml3d/torch/dataloaders/torch_sampler.py b/ml3d/torch/dataloaders/torch_sampler.py index 142f01fbc..78e0ebc45 100644 --- a/ml3d/torch/dataloaders/torch_sampler.py +++ b/ml3d/torch/dataloaders/torch_sampler.py @@ -15,4 +15,6 @@ def __len__(self): def get_sampler(sampler): + if sampler is None: + return None return TorchSamplerWrapper(sampler) diff --git a/ml3d/torch/models/point_transformer.py b/ml3d/torch/models/point_transformer.py index 804387628..7169c757a 100644 --- a/ml3d/torch/models/point_transformer.py +++ b/ml3d/torch/models/point_transformer.py @@ -718,6 +718,7 @@ def knn_batch(points, if points_row_splits.shape[0] != queries_row_splits.shape[0]: raise ValueError("KNN(points and queries must have same batch size)") + device = points.device points = points.cpu() queries = queries.cpu() @@ -730,9 +731,11 @@ def knn_batch(points, return_distances=True) if return_distances: return ans.neighbors_index.reshape( - -1, k).long().cuda(), ans.neighbors_distance.reshape(-1, k).cuda() + -1, + k).long().to(device), ans.neighbors_distance.reshape(-1, + k).to(device) else: - return ans.neighbors_index.reshape(-1, k).long().cuda() + return ans.neighbors_index.reshape(-1, k).long().to(device) def interpolation(points, diff --git a/ml3d/torch/models/sparseconvnet.py b/ml3d/torch/models/sparseconvnet.py index 86c555f33..8282bd1b6 100644 --- a/ml3d/torch/models/sparseconvnet.py +++ b/ml3d/torch/models/sparseconvnet.py @@ -115,6 +115,8 @@ def preprocess(self, data, attr): "SparseConvnet doesn't work without feature values.") feat = np.array(data['feat'], dtype=np.float32) + if feat.shape[1] < 3: + feat = np.concatenate([feat, np.ones([feat.shape[0], 1])], 1) # Scale to voxel size. points *= 1. 
/ self.cfg.voxel_size # Scale = 1/voxel_size diff --git a/ml3d/torch/modules/metrics/semseg_metric.py b/ml3d/torch/modules/metrics/semseg_metric.py index 763f31fdf..472ab8a4b 100644 --- a/ml3d/torch/modules/metrics/semseg_metric.py +++ b/ml3d/torch/modules/metrics/semseg_metric.py @@ -13,6 +13,7 @@ def __init__(self): super(SemSegMetric, self).__init__() self.confusion_matrix = None self.num_classes = None + self.count = 0 def update(self, scores, labels): conf = self.get_confusion_matrix(scores, labels) @@ -22,6 +23,20 @@ def update(self, scores, labels): else: assert self.confusion_matrix.shape == conf.shape self.confusion_matrix += conf + self.count += 1 + + def __iadd__(self, otherMetric): + if self.confusion_matrix is None and otherMetric.confusion_matrix is None: + pass + elif self.confusion_matrix is None: + self.confusion_matrix = otherMetric.confusion_matrix + self.num_classes = otherMetric.num_classes + elif otherMetric.confusion_matrix is None: + pass + else: + self.confusion_matrix += otherMetric.confusion_matrix + self.count += len(otherMetric) + return self def acc(self): """Compute the per-class accuracies and the overall accuracy. @@ -90,6 +105,10 @@ def iou(self): def reset(self): self.confusion_matrix = None + self.count = 0 + + def __len__(self): + return self.count @staticmethod def get_confusion_matrix(scores, labels): diff --git a/ml3d/torch/pipelines/base_pipeline.py b/ml3d/torch/pipelines/base_pipeline.py index f466868b9..4e8d79e19 100644 --- a/ml3d/torch/pipelines/base_pipeline.py +++ b/ml3d/torch/pipelines/base_pipeline.py @@ -41,10 +41,6 @@ def __init__(self, self.rng = np.random.default_rng(kwargs.get('seed', None)) self.distributed = distributed - if self.distributed and self.name == "SemanticSegmentation": - raise NotImplementedError( - "Distributed training not implemented for SemanticSegmentation!" 
- ) self.rank = kwargs.get('rank', 0) diff --git a/ml3d/torch/pipelines/semantic_segmentation.py b/ml3d/torch/pipelines/semantic_segmentation.py index 232ce5a30..c3a304566 100644 --- a/ml3d/torch/pipelines/semantic_segmentation.py +++ b/ml3d/torch/pipelines/semantic_segmentation.py @@ -1,11 +1,12 @@ import logging +import numpy as np +import torch +import torch.distributed as dist + from os.path import exists, join from pathlib import Path from datetime import datetime - -import numpy as np from tqdm import tqdm -import torch from torch.utils.tensorboard import SummaryWriter from torch.utils.data import DataLoader @@ -159,6 +160,8 @@ def run_inference(self, data): with torch.no_grad(): for unused_step, inputs in enumerate(infer_loader): + if hasattr(inputs['data'], 'to'): + inputs['data'].to(device) results = model(inputs['data']) self.update_tests(infer_sampler, inputs, results) @@ -312,20 +315,20 @@ def update_tests(self, sampler, inputs, results): def run_train(self): torch.manual_seed(self.rng.integers(np.iinfo( np.int32).max)) # Random reproducible seed for torch + rank = self.rank # Rank for distributed training model = self.model device = self.device - model.device = device dataset = self.dataset cfg = self.cfg - model.to(device) - - log.info("DEVICE : {}".format(device)) - timestamp = datetime.now().strftime('%Y-%m-%d_%H:%M:%S') + if rank == 0: + log.info("DEVICE : {}".format(device)) + timestamp = datetime.now().strftime('%Y-%m-%d_%H:%M:%S') - log_file_path = join(cfg.logs_dir, 'log_train_' + timestamp + '.txt') - log.info("Logging in file : {}".format(log_file_path)) - log.addHandler(logging.FileHandler(log_file_path)) + log_file_path = join(cfg.logs_dir, + 'log_train_' + timestamp + '.txt') + log.info("Logging in file : {}".format(log_file_path)) + log.addHandler(logging.FileHandler(log_file_path)) Loss = SemSegLoss(self, model, dataset, device) self.metric_train = SemSegMetric() @@ -335,6 +338,10 @@ def run_train(self): train_dataset = dataset.get_split('train') train_sampler = train_dataset.sampler + + valid_dataset = dataset.get_split('validation') + valid_sampler = valid_dataset.sampler + train_split = TorchDataloader(dataset=train_dataset, preprocess=model.preprocess, transform=model.transform, @@ -343,19 +350,6 @@ def run_train(self): steps_per_epoch=dataset.cfg.get( 'steps_per_epoch_train', None)) - train_loader = DataLoader( - train_split, - batch_size=cfg.batch_size, - sampler=get_sampler(train_sampler), - num_workers=cfg.get('num_workers', 2), - pin_memory=cfg.get('pin_memory', True), - collate_fn=self.batcher.collate_fn, - worker_init_fn=lambda x: np.random.seed(x + np.uint32( - torch.utils.data.get_worker_info().seed)) - ) # numpy expects np.uint32, whereas torch returns np.uint64. 
- - valid_dataset = dataset.get_split('validation') - valid_sampler = valid_dataset.sampler valid_split = TorchDataloader(dataset=valid_dataset, preprocess=model.preprocess, transform=model.transform, @@ -364,20 +358,46 @@ def run_train(self): steps_per_epoch=dataset.cfg.get( 'steps_per_epoch_valid', None)) + if self.distributed: + if train_sampler is not None or valid_sampler is not None: + raise NotImplementedError( + "Distributed training with sampler is not supported yet!") + train_sampler = torch.utils.data.distributed.DistributedSampler( + train_split) + valid_sampler = torch.utils.data.distributed.DistributedSampler( + valid_split) + + train_loader = DataLoader( + train_split, + batch_size=cfg.batch_size, + sampler=train_sampler + if self.distributed else get_sampler(train_sampler), + num_workers=cfg.get('num_workers', 0), + pin_memory=cfg.get('pin_memory', False), + collate_fn=self.batcher.collate_fn, + worker_init_fn=lambda x: np.random.seed(x + np.uint32( + torch.utils.data.get_worker_info().seed)) + ) # numpy expects np.uint32, whereas torch returns np.uint64. + valid_loader = DataLoader( valid_split, batch_size=cfg.val_batch_size, - sampler=get_sampler(valid_sampler), - num_workers=cfg.get('num_workers', 2), + sampler=valid_sampler + if self.distributed else get_sampler(valid_sampler), + num_workers=cfg.get('num_workers', 0), pin_memory=cfg.get('pin_memory', True), collate_fn=self.batcher.collate_fn, worker_init_fn=lambda x: np.random.seed(x + np.uint32( torch.utils.data.get_worker_info().seed))) + # Optimizer must be created after moving model to specific device. + model.to(self.device) + model.device = self.device + self.optimizer, self.scheduler = model.get_optimizer(cfg) is_resume = model.cfg.get('is_resume', True) - self.load_ckpt(model.cfg.ckpt_path, is_resume=is_resume) + start_ep = self.load_ckpt(model.cfg.ckpt_path, is_resume=is_resume) dataset_name = dataset.name if dataset is not None else '' tensorboard_dir = join( @@ -388,22 +408,36 @@ def run_train(self): runid + '_' + Path(tensorboard_dir).name) writer = SummaryWriter(self.tensorboard_dir) - self.save_config(writer) - log.info("Writing summary in {}.".format(self.tensorboard_dir)) - record_summary = cfg.get('summary').get('record_for', []) + if rank == 0: + self.save_config(writer) + log.info("Writing summary in {}.".format(self.tensorboard_dir)) - log.info("Started training") + # wrap model for multiple GPU + if self.distributed: + model = torch.nn.parallel.DistributedDataParallel( + model, device_ids=[self.device]) + model.get_loss = model.module.get_loss + model.cfg = model.module.cfg - for epoch in range(0, cfg.max_epoch + 1): + record_summary = cfg.get('summary').get('record_for', + []) if self.rank == 0 else [] + if rank == 0: + log.info("Started training") + + for epoch in range(start_ep, cfg.max_epoch + 1): log.info(f'=== EPOCH {epoch:d}/{cfg.max_epoch:d} ===') + if self.distributed: + train_sampler.set_epoch(epoch) + model.train() self.metric_train.reset() self.metric_val.reset() self.losses = [] - model.trans_point_sampler = train_sampler.get_point_sampler() + # model.trans_point_sampler = train_sampler.get_point_sampler() # TODO: fix this for model with samplers. 
- for step, inputs in enumerate(tqdm(train_loader, desc='training')): + progress_bar = tqdm(train_loader, desc='training') + for inputs in progress_bar: if hasattr(inputs['data'], 'to'): inputs['data'].to(device) self.optimizer.zero_grad() @@ -416,24 +450,39 @@ def run_train(self): loss.backward() if model.cfg.get('grad_clip_norm', -1) > 0: - torch.nn.utils.clip_grad_value_(model.parameters(), - model.cfg.grad_clip_norm) + if self.distributed: + torch.nn.utils.clip_grad_value_( + model.module.parameters(), model.cfg.grad_clip_norm) + else: + torch.nn.utils.clip_grad_value_( + model.parameters(), model.cfg.grad_clip_norm) + self.optimizer.step() self.metric_train.update(predict_scores, gt_labels) self.losses.append(loss.cpu().item()) + # Save only for the first pcd in batch - if 'train' in record_summary and step == 0: + if 'train' in record_summary and progress_bar.n == 0: self.summary['train'] = self.get_3d_summary( results, inputs['data'], epoch) - self.scheduler.step() + desc = "training - Epoch: %d, loss: %.3f" % (epoch, + loss.cpu().item()) + progress_bar.set_description(desc) + progress_bar.refresh() + + if self.distributed: + dist.barrier() + + if self.scheduler is not None: + self.scheduler.step() # --------------------- validation model.eval() self.valid_losses = [] - model.trans_point_sampler = valid_sampler.get_point_sampler() + # model.trans_point_sampler = valid_sampler.get_point_sampler() with torch.no_grad(): for step, inputs in enumerate( @@ -456,10 +505,21 @@ def run_train(self): self.summary['valid'] = self.get_3d_summary( results, inputs['data'], epoch) - self.save_logs(writer, epoch) + if self.distributed: + metric_gather = [None for _ in range(dist.get_world_size())] + dist.gather_object(self.metric_val, + metric_gather if rank == 0 else None, + dst=0) + if rank == 0: + for m in metric_gather[1:]: + self.metric_val += m + + dist.barrier() - if epoch % cfg.save_ckpt_freq == 0 or epoch == cfg.max_epoch: - self.save_ckpt(epoch) + if rank == 0: + self.save_logs(writer, epoch) + if epoch % cfg.save_ckpt_freq == 0 or epoch == cfg.max_epoch: + self.save_ckpt(epoch) def get_batcher(self, device, split='training'): """Get the batcher to be used based on the device and split.""" @@ -661,29 +721,37 @@ def load_ckpt(self, ckpt_path=None, is_resume=True): want to resume. """ train_ckpt_dir = join(self.cfg.logs_dir, 'checkpoint') - make_dir(train_ckpt_dir) + if self.rank == 0: + make_dir(train_ckpt_dir) + if self.distributed: + dist.barrier() if ckpt_path is None: ckpt_path = latest_torch_ckpt(train_ckpt_dir) if ckpt_path is not None and is_resume: - log.info('ckpt_path not given. Restore from the latest ckpt') + log.info("ckpt_path not given. 
Restore from the latest ckpt") else: log.info('Initializing from scratch.') - return + return 0 if not exists(ckpt_path): raise FileNotFoundError(f' ckpt {ckpt_path} not found') log.info(f'Loading checkpoint {ckpt_path}') ckpt = torch.load(ckpt_path, map_location=self.device) + self.model.load_state_dict(ckpt['model_state_dict']) if 'optimizer_state_dict' in ckpt and hasattr(self, 'optimizer'): - log.info(f'Loading checkpoint optimizer_state_dict') + log.info('Loading checkpoint optimizer_state_dict') self.optimizer.load_state_dict(ckpt['optimizer_state_dict']) if 'scheduler_state_dict' in ckpt and hasattr(self, 'scheduler'): - log.info(f'Loading checkpoint scheduler_state_dict') + log.info('Loading checkpoint scheduler_state_dict') self.scheduler.load_state_dict(ckpt['scheduler_state_dict']) + epoch = 0 if 'epoch' not in ckpt else ckpt['epoch'] + + return epoch + def save_ckpt(self, epoch): """Save a checkpoint at the passed epoch.""" path_ckpt = join(self.cfg.logs_dir, 'checkpoint') diff --git a/ml3d/utils/builder.py b/ml3d/utils/builder.py index f087d501c..6c2ac7ebb 100644 --- a/ml3d/utils/builder.py +++ b/ml3d/utils/builder.py @@ -55,6 +55,8 @@ def get_module(module_type, module_name, framework=None, **kwargs): return get_from_name(module_name, DATASET, framework) elif module_type == "sampler": + if module_name is None: + return None return get_from_name(module_name, SAMPLER, framework) elif module_type == "model": diff --git a/scripts/preprocess_nuscenes.py b/scripts/preprocess_nuscenes.py index 582240416..c2ad253c9 100644 --- a/scripts/preprocess_nuscenes.py +++ b/scripts/preprocess_nuscenes.py @@ -58,6 +58,7 @@ def __init__(self, dataset_path, out_path, version='v1.0-trainval'): assert version in ['v1.0-trainval', 'v1.0-test', 'v1.0-mini'] self.is_test = 'test' in version self.out_path = out_path + self.dataset_path = dataset_path self.nusc = NuScenes(version=version, dataroot=dataset_path, @@ -249,6 +250,12 @@ def process_scenes(self): data['cams'].update({cam: cam_info}) if not self.is_test: + lidarseg_path = nusc.get('lidarseg', lidar_token)['filename'] + lidarseg_path = os.path.abspath( + os.path.join(self.dataset_path, lidarseg_path)) + assert os.path.exists(lidarseg_path) + data['lidarseg_path'] = lidarseg_path + annotations = [ nusc.get('sample_annotation', token) for token in sample['anns'] diff --git a/scripts/preprocess_waymo.py b/scripts/preprocess_waymo.py index 6b44a61fd..0f677aabd 100644 --- a/scripts/preprocess_waymo.py +++ b/scripts/preprocess_waymo.py @@ -7,16 +7,14 @@ import logging import numpy as np -import os, sys, glob, pickle +import glob import argparse import tensorflow as tf -import matplotlib.image as mpimg from pathlib import Path from os.path import join, exists, dirname, abspath from os import makedirs from multiprocessing import Pool -from tqdm import tqdm from waymo_open_dataset.utils import range_image_utils, transform_utils from waymo_open_dataset.utils.frame_utils import \ parse_range_image_and_camera_projection @@ -38,10 +36,10 @@ def parse_args(): default=16, type=int) - parser.add_argument('--split', - help='One of {train, val, test} (default train)', - default='train', - type=str) + parser.add_argument('--is_test', + help='True for processing test data (default False)', + default=False, + type=bool) args = parser.parse_args() @@ -58,25 +56,6 @@ class Waymo2KITTI(): """Waymo to KITTI converter. This class converts tfrecord files from Waymo dataset to KITTI format. 
- KITTI format : (type, truncated, occluded, alpha, bbox, dimensions(3), location(3), - rotation_y(1), score(1, optional)) - type (string): Describes the type of object. - truncated (float): Ranges from 0(non-truncated) to 1(truncated). - occluded (int): Integer(0, 1, 2, 3) signifies state fully visible, partly - occluded, largely occluded, unknown. - alpha (float): Observation angle of object, ranging [-pi..pi]. - bbox (float): 2d bounding box of object in the image. - dimensions (float): 3D object dimensions: h, w, l in meters. - location (float): 3D object location: x,y,z in camera coordinates (in meters). - rotation_y (float): rotation around Y-axis in camera coordinates [-pi..pi]. - score (float): Only for predictions, indicating confidence in detection. - - Conversion writes following files: - pointcloud(np.float32) : pointcloud data with shape [N, 6]. Consists of - (x, y, z, intensity, elongation, timestamp). - images(np.uint8): camera images are saved if `write_image` is True. - calibrations(np.float32): Intinsic and Extrinsic matrix for all cameras. - label(np.float32): Bounding box information in KITTI format. Args: dataset_path (str): Directory to load waymo raw data. @@ -85,9 +64,9 @@ class Waymo2KITTI(): is_test (bool): Whether in the test_mode. Default: False. """ - def __init__(self, dataset_path, save_dir='', workers=8, split='train'): + def __init__(self, dataset_path, save_dir='', workers=8, is_test=False): - self.write_image = False + self.write_image = True self.filter_empty_3dboxes = True self.filter_no_label_zone_points = True @@ -105,8 +84,8 @@ def __init__(self, dataset_path, save_dir='', workers=8, split='train'): self.dataset_path = dataset_path self.save_dir = save_dir self.workers = int(workers) - self.is_test = split == 'test' - self.prefix = split + '_' + self.is_test = is_test + self.prefix = '' self.save_track_id = False self.tfrecord_files = sorted( @@ -156,6 +135,7 @@ def process_one(self, file_idx): if (self.selected_waymo_locations is not None and frame.context.stats.location not in self.selected_waymo_locations): + print("continue") continue if self.write_image: @@ -171,9 +151,11 @@ def __len__(self): return len(self.tfrecord_files) def save_image(self, frame, file_idx, frame_idx): + self.prefix = '' + for img in frame.images: img_path = Path(self.image_save_dir + str(img.name - 1)) / ( - self.prefix + str(file_idx).zfill(3) + str(frame_idx).zfill(3) + + self.prefix + str(file_idx).zfill(4) + str(frame_idx).zfill(4) + '.npy') image = tf.io.decode_jpeg(img.image).numpy() @@ -223,26 +205,28 @@ def save_calib(self, frame, file_idx, frame_idx): with open( f'{self.calib_save_dir}/{self.prefix}' + - f'{str(file_idx).zfill(3)}{str(frame_idx).zfill(3)}.txt', + f'{str(file_idx).zfill(4)}{str(frame_idx).zfill(4)}.txt', 'w+') as fp_calib: fp_calib.write(calib_context) + fp_calib.close() def save_pose(self, frame, file_idx, frame_idx): pose = np.array(frame.pose.transform).reshape(4, 4) np.savetxt( join(f'{self.pose_save_dir}/{self.prefix}' + - f'{str(file_idx).zfill(3)}{str(frame_idx).zfill(3)}.txt'), + f'{str(file_idx).zfill(4)}{str(frame_idx).zfill(4)}.txt'), pose) def save_label(self, frame, file_idx, frame_idx): fp_label_all = open( f'{self.label_all_save_dir}/{self.prefix}' + - f'{str(file_idx).zfill(3)}{str(frame_idx).zfill(3)}.txt', 'w+') + f'{str(file_idx).zfill(4)}{str(frame_idx).zfill(4)}.txt', 'w+') id_to_bbox = dict() id_to_name = dict() for labels in frame.projected_lidar_labels: name = labels.name for label in labels.labels: + # TODO: need a 
workaround as bbox may not belong to front cam bbox = [ label.box.center_x - label.box.length / 2, label.box.center_y - label.box.width / 2, @@ -271,6 +255,9 @@ def save_label(self, frame, file_idx, frame_idx): if my_type not in self.selected_waymo_classes: continue + # if self.filter_empty_3dboxes and obj.num_lidar_points_in_box < 1: + # continue + height = obj.box.height width = obj.box.width length = obj.box.length @@ -279,6 +266,11 @@ def save_label(self, frame, file_idx, frame_idx): y = obj.box.center_y z = obj.box.center_z + # # project bounding box to the virtual reference frame + # pt_ref = self.T_velo_to_front_cam @ \ + # np.array([x, y, z, 1]).reshape((4, 1)) + # x, y, z, _ = pt_ref.flatten().tolist() + rotation_y = -obj.box.heading - np.pi / 2 track_id = obj.id @@ -303,7 +295,7 @@ def save_label(self, frame, file_idx, frame_idx): fp_label = open( f'{self.label_save_dir}{name}/{self.prefix}' + - f'{str(file_idx).zfill(3)}{str(frame_idx).zfill(3)}.txt', 'a') + f'{str(file_idx).zfill(4)}{str(frame_idx).zfill(4)}.txt', 'a') fp_label.write(line) fp_label.close() @@ -312,7 +304,7 @@ def save_label(self, frame, file_idx, frame_idx): fp_label_all.close() def save_lidar(self, frame, file_idx, frame_idx): - range_images, camera_projections, range_image_top_pose = parse_range_image_and_camera_projection( + range_images, camera_projections, _, range_image_top_pose = parse_range_image_and_camera_projection( frame) # First return @@ -351,7 +343,7 @@ def save_lidar(self, frame, file_idx, frame_idx): (points, intensity, elongation, timestamp)) pc_path = f'{self.point_cloud_save_dir}/{self.prefix}' + \ - f'{str(file_idx).zfill(3)}{str(frame_idx).zfill(3)}.bin' + f'{str(file_idx).zfill(4)}{str(frame_idx).zfill(4)}.bin' point_cloud.astype(np.float32).tofile(pc_path) def convert_range_image_to_point_cloud(self, @@ -466,8 +458,6 @@ def cart_to_homo(mat): out_path = args.out_path if out_path is None: args.out_path = args.dataset_path - if args.split not in ['train', 'val', 'test']: - raise ValueError("split must be one of {train, val, test}") converter = Waymo2KITTI(args.dataset_path, args.out_path, args.workers, - args.split) + args.is_test) converter.convert() diff --git a/scripts/preprocess_waymo_semseg.py b/scripts/preprocess_waymo_semseg.py new file mode 100644 index 000000000..ae4077610 --- /dev/null +++ b/scripts/preprocess_waymo_semseg.py @@ -0,0 +1,425 @@ +try: + from waymo_open_dataset import dataset_pb2 + from waymo_open_dataset.utils import range_image_utils, transform_utils + from waymo_open_dataset.utils.frame_utils import \ + parse_range_image_and_camera_projection +except ImportError: + raise ImportError( + 'Please clone "https://github.com/waymo-research/waymo-open-dataset.git" ' + 'checkout branch "r1.3", and install the official devkit first') + +import logging +import numpy as np +import os, sys, glob, pickle +import argparse +import tensorflow as tf +import matplotlib.image as mpimg + +from pathlib import Path +from os.path import join, exists, dirname, abspath +from os import makedirs +from multiprocessing import Pool +from tqdm import tqdm +from tqdm.contrib.concurrent import process_map # or thread_map + +gpus = tf.config.experimental.list_physical_devices('GPU') +if gpus: + try: + for gpu in gpus: + tf.config.experimental.set_memory_growth(gpu, True) + except RuntimeError as e: + print(e) + + +def parse_args(): + parser = argparse.ArgumentParser(description='Preprocess Waymo Dataset.') + parser.add_argument('--dataset_path', + help='path to Waymo tfrecord files', + 
required=True) + parser.add_argument( + '--out_path', + help='Output path to store pickle (default to dataet_path)', + default=None, + required=False) + + parser.add_argument('--workers', + help='Number of workers.', + default=16, + type=int) + + parser.add_argument('--split', + help='One of {train, val, test} (default train)', + default='train', + type=str) + + args = parser.parse_args() + + dict_args = vars(args) + for k in dict_args: + v = dict_args[k] + print("{}: {}".format(k, v) if v is not None else "{} not given". + format(k)) + + return args + + +class Waymo2KITTI(): + """Waymo to KITTI converter. + + This class converts tfrecord files from Waymo dataset to KITTI format. + KITTI format : (type, truncated, occluded, alpha, bbox, dimensions(3), location(3), + rotation_y(1), score(1, optional)) + type (string): Describes the type of object. + truncated (float): Ranges from 0(non-truncated) to 1(truncated). + occluded (int): Integer(0, 1, 2, 3) signifies state fully visible, partly + occluded, largely occluded, unknown. + alpha (float): Observation angle of object, ranging [-pi..pi]. + bbox (float): 2d bounding box of object in the image. + dimensions (float): 3D object dimensions: h, w, l in meters. + location (float): 3D object location: x,y,z in camera coordinates (in meters). + rotation_y (float): rotation around Y-axis in camera coordinates [-pi..pi]. + score (float): Only for predictions, indicating confidence in detection. + + Conversion writes following files: + pointcloud(np.float32) : pointcloud data with shape [N, 6]. Consists of + (x, y, z, intensity, elongation, timestamp). + images(np.uint8): camera images are saved if `write_image` is True. + calibrations(np.float32): Intinsic and Extrinsic matrix for all cameras. + label(np.float32): Bounding box information in KITTI format. + + Args: + dataset_path (str): Directory to load waymo raw data. + save_dir (str): Directory to save data in KITTI format. + workers (str): Number of workers for the parallel process. + is_test (bool): Whether in the test_mode. Default: False. 
+ """ + + def __init__(self, dataset_path, save_dir='', workers=8, split='train'): + + self.write_image = False + self.filter_empty_3dboxes = True + self.filter_no_label_zone_points = False + + self.classes = ['VEHICLE', 'PEDESTRIAN', 'SIGN', 'CYCLIST'] + + self.lidar_list = [ + '_FRONT', '_FRONT_RIGHT', '_FRONT_LEFT', '_SIDE_RIGHT', '_SIDE_LEFT' + ] + self.type_list = ['UNKNOWN', 'VEHICLE', 'PEDESTRIAN', 'SIGN', 'CYCLIST'] + + self.selected_waymo_classes = self.classes + + self.selected_waymo_locations = None + + self.dataset_path = dataset_path + self.save_dir = save_dir + self.workers = int(workers) + self.is_test = split == 'test' + self.prefix = split + '_' + self.save_track_id = False + + self.tfrecord_files = sorted( + glob.glob(join(self.dataset_path, "*.tfrecord"))) + + self.label_save_dir = f'{self.save_dir}/label_' + self.label_all_save_dir = f'{self.save_dir}/label_all' + self.image_save_dir = f'{self.save_dir}/image_' + self.calib_save_dir = f'{self.save_dir}/calib' + self.point_cloud_save_dir = f'{self.save_dir}/velodyne' + self.pose_save_dir = f'{self.save_dir}/pose' + + self.create_folder() + + def create_folder(self): + if not self.is_test: + dir_list1 = [ + self.label_all_save_dir, self.calib_save_dir, + self.point_cloud_save_dir, self.pose_save_dir + ] + dir_list2 = [self.label_save_dir, self.image_save_dir] + else: + dir_list1 = [ + self.calib_save_dir, self.point_cloud_save_dir, + self.pose_save_dir + ] + dir_list2 = [self.image_save_dir] + for d in dir_list1: + makedirs(d, exist_ok=True) + for d in dir_list2: + for i in range(5): + makedirs(f'{d}{str(i)}', exist_ok=True) + + def convert(self): + print(f"Start converting {len(self)} files ...") + process_map(self.process_one, + range(len(self)), + max_workers=self.workers) + + def process_one(self, file_idx): + print(f"Converting : {file_idx}") + path = self.tfrecord_files[file_idx] + dataset = tf.data.TFRecordDataset(path, compression_type='') + + for frame_idx, data in enumerate(dataset): + frame = dataset_pb2.Frame() + frame.ParseFromString(bytearray(data.numpy())) + if (not frame.lasers[0].ri_return1.segmentation_label_compressed + ) and (not self.is_test): + continue + + if (self.selected_waymo_locations is not None and + frame.context.stats.location + not in self.selected_waymo_locations): + continue + + if self.write_image: + self.save_image(frame, file_idx, frame_idx) + self.save_calib(frame, file_idx, frame_idx) + self.save_lidar(frame, file_idx, frame_idx) + self.save_pose(frame, file_idx, frame_idx) + + def __len__(self): + return len(self.tfrecord_files) + + def save_image(self, frame, file_idx, frame_idx): + for img in frame.images: + img_path = Path(self.image_save_dir + str(img.name - 1)) / ( + self.prefix + str(file_idx).zfill(4) + str(frame_idx).zfill(4) + + '.npy') + image = tf.io.decode_jpeg(img.image).numpy() + + np.save(img_path, image) + + def save_calib(self, frame, file_idx, frame_idx): + # waymo front camera to kitti reference camera + T_front_cam_to_ref = np.array([[0.0, -1.0, 0.0], [0.0, 0.0, -1.0], + [1.0, 0.0, 0.0]]) + camera_calibs = [] + R0_rect = [f'{i:e}' for i in np.eye(3).flatten()] + Tr_velo_to_cams = [] + calib_context = '' + + for camera in frame.context.camera_calibrations: + # extrinsic parameters + T_cam_to_vehicle = np.array(camera.extrinsic.transform).reshape( + 4, 4) + T_vehicle_to_cam = np.linalg.inv(T_cam_to_vehicle) + Tr_velo_to_cam = \ + self.cart_to_homo(T_front_cam_to_ref) @ T_vehicle_to_cam + if camera.name == 1: # FRONT = 1, see dataset.proto for details + 
self.T_velo_to_front_cam = Tr_velo_to_cam.copy() + Tr_velo_to_cam = Tr_velo_to_cam[:3, :].reshape((12,)) + Tr_velo_to_cams.append([f'{i:e}' for i in Tr_velo_to_cam]) + + # intrinsic parameters + camera_calib = np.zeros((3, 4)) + camera_calib[0, 0] = camera.intrinsic[0] + camera_calib[1, 1] = camera.intrinsic[1] + camera_calib[0, 2] = camera.intrinsic[2] + camera_calib[1, 2] = camera.intrinsic[3] + camera_calib[2, 2] = 1 + camera_calib = list(camera_calib.reshape(12)) + camera_calib = [f'{i:e}' for i in camera_calib] + camera_calibs.append(camera_calib) + + # all camera ids are saved as id-1 in the result because + # camera 0 is unknown in the proto + for i in range(5): + calib_context += 'P' + str(i) + ': ' + \ + ' '.join(camera_calibs[i]) + '\n' + calib_context += 'R0_rect' + ': ' + ' '.join(R0_rect) + '\n' + for i in range(5): + calib_context += 'Tr_velo_to_cam_' + str(i) + ': ' + \ + ' '.join(Tr_velo_to_cams[i]) + '\n' + + with open( + f'{self.calib_save_dir}/{self.prefix}' + + f'{str(file_idx).zfill(4)}{str(frame_idx).zfill(4)}.txt', + 'w+') as fp_calib: + fp_calib.write(calib_context) + + def save_pose(self, frame, file_idx, frame_idx): + pose = np.array(frame.pose.transform).reshape(4, 4) + np.savetxt( + join(f'{self.pose_save_dir}/{self.prefix}' + + f'{str(file_idx).zfill(4)}{str(frame_idx).zfill(4)}.txt'), + pose) + + def save_lidar(self, frame, file_idx, frame_idx): + range_images, camera_projections, seg_labels, range_image_top_pose = parse_range_image_and_camera_projection( + frame) + + # First return + points_0, cp_points_0, intensity_0, elongation_0, seg_label_0 = \ + self.convert_range_image_to_point_cloud( + frame, + range_images, + seg_labels, + camera_projections, + range_image_top_pose, + ri_index=0 + ) + points_0 = np.concatenate(points_0, axis=0) + intensity_0 = np.concatenate(intensity_0, axis=0) + elongation_0 = np.concatenate(elongation_0, axis=0) + seg_label_0 = np.concatenate(seg_label_0, axis=0) + + # Second return + points_1, cp_points_1, intensity_1, elongation_1, seg_label_1 = \ + self.convert_range_image_to_point_cloud( + frame, + range_images, + seg_labels, + camera_projections, + range_image_top_pose, + ri_index=1 + ) + points_1 = np.concatenate(points_1, axis=0) + intensity_1 = np.concatenate(intensity_1, axis=0) + elongation_1 = np.concatenate(elongation_1, axis=0) + seg_label_1 = np.concatenate(seg_label_1, axis=0) + + points = np.concatenate([points_0, points_1], axis=0) + intensity = np.concatenate([intensity_0, intensity_1], axis=0) + elongation = np.concatenate([elongation_0, elongation_1], axis=0) + semseg_labels = np.concatenate([seg_label_0, seg_label_1], axis=0) + timestamp = frame.timestamp_micros * np.ones_like(intensity) + + # concatenate x,y,z, intensity, elongation, timestamp (6-dim) + point_cloud = np.column_stack( + (points, intensity, elongation, timestamp, semseg_labels)) + + pc_path = f'{self.point_cloud_save_dir}/{self.prefix}' + \ + f'{str(file_idx).zfill(4)}{str(frame_idx).zfill(4)}.bin' + point_cloud.astype(np.float32).tofile(pc_path) + + def convert_range_image_to_point_cloud(self, + frame, + range_images, + segmentation_labels, + camera_projections, + range_image_top_pose, + ri_index=0): + calibrations = sorted(frame.context.laser_calibrations, + key=lambda c: c.name) + points = [] + cp_points = [] + intensity = [] + elongation = [] + semseg_labels = [] + + frame_pose = tf.convert_to_tensor( + value=np.reshape(np.array(frame.pose.transform), [4, 4])) + # [H, W, 6] + range_image_top_pose_tensor = tf.reshape( + 
tf.convert_to_tensor(value=range_image_top_pose.data), + range_image_top_pose.shape.dims) + # [H, W, 3, 3] + range_image_top_pose_tensor_rotation = \ + transform_utils.get_rotation_matrix( + range_image_top_pose_tensor[..., 0], + range_image_top_pose_tensor[..., 1], + range_image_top_pose_tensor[..., 2]) + range_image_top_pose_tensor_translation = \ + range_image_top_pose_tensor[..., 3:] + range_image_top_pose_tensor = transform_utils.get_transform( + range_image_top_pose_tensor_rotation, + range_image_top_pose_tensor_translation) + for c in calibrations: + range_image = range_images[c.name][ri_index] + if len(c.beam_inclinations) == 0: + beam_inclinations = range_image_utils.compute_inclination( + tf.constant( + [c.beam_inclination_min, c.beam_inclination_max]), + height=range_image.shape.dims[0]) + else: + beam_inclinations = tf.constant(c.beam_inclinations) + + beam_inclinations = tf.reverse(beam_inclinations, axis=[-1]) + extrinsic = np.reshape(np.array(c.extrinsic.transform), [4, 4]) + + range_image_tensor = tf.reshape( + tf.convert_to_tensor(value=range_image.data), + range_image.shape.dims) + pixel_pose_local = None + frame_pose_local = None + if c.name == dataset_pb2.LaserName.TOP: + pixel_pose_local = range_image_top_pose_tensor + pixel_pose_local = tf.expand_dims(pixel_pose_local, axis=0) + frame_pose_local = tf.expand_dims(frame_pose, axis=0) + range_image_mask = range_image_tensor[..., 0] > 0 + + if self.filter_no_label_zone_points: + nlz_mask = range_image_tensor[..., 3] != 1.0 # 1.0: in NLZ + range_image_mask = range_image_mask & nlz_mask + + range_image_cartesian = \ + range_image_utils.extract_point_cloud_from_range_image( + tf.expand_dims(range_image_tensor[..., 0], axis=0), + tf.expand_dims(extrinsic, axis=0), + tf.expand_dims(tf.convert_to_tensor( + value=beam_inclinations), axis=0), + pixel_pose=pixel_pose_local, + frame_pose=frame_pose_local) + + range_image_cartesian = tf.squeeze(range_image_cartesian, axis=0) + points_tensor = tf.gather_nd(range_image_cartesian, + tf.compat.v1.where(range_image_mask)) + cp = camera_projections[c.name][ri_index] + cp_tensor = tf.reshape(tf.convert_to_tensor(value=cp.data), + cp.shape.dims) + cp_points_tensor = tf.gather_nd( + cp_tensor, tf.compat.v1.where(range_image_mask)) + points.append(points_tensor.numpy()) + cp_points.append(cp_points_tensor.numpy()) + + intensity_tensor = tf.gather_nd(range_image_tensor[..., 1], + tf.where(range_image_mask)) + intensity.append(intensity_tensor.numpy()) + + elongation_tensor = tf.gather_nd(range_image_tensor[..., 2], + tf.where(range_image_mask)) + elongation.append(elongation_tensor.numpy()) + + if c.name in segmentation_labels: + sl = segmentation_labels[c.name][ri_index] + sl_tensor = tf.reshape(tf.convert_to_tensor(sl.data), + sl.shape.dims) + sl_points_tensor = tf.gather_nd(sl_tensor, + tf.where(range_image_mask)) + else: + sl_points_tensor = tf.zeros([points_tensor.shape[0], 2], + dtype=tf.int32) + + semseg_labels.append(sl_points_tensor.numpy()) + + return points, cp_points, intensity, elongation, semseg_labels + + @staticmethod + def cart_to_homo(mat): + ret = np.eye(4) + if mat.shape == (3, 3): + ret[:3, :3] = mat + elif mat.shape == (3, 4): + ret[:3, :] = mat + else: + raise ValueError(mat.shape) + return ret + + +if __name__ == '__main__': + + logging.basicConfig( + level=logging.INFO, + format='%(levelname)s - %(asctime)s - %(module)s - %(message)s', + ) + + args = parse_args() + out_path = args.out_path + if out_path is None: + args.out_path = args.dataset_path + if args.split not 
in ['train', 'val', 'test']: + raise ValueError("split must be one of {train, val, test}") + converter = Waymo2KITTI(args.dataset_path, args.out_path, args.workers, + args.split) + converter.convert() diff --git a/scripts/run_pipeline.py b/scripts/run_pipeline.py index 68564e1b9..64718fdaf 100644 --- a/scripts/run_pipeline.py +++ b/scripts/run_pipeline.py @@ -43,6 +43,11 @@ def parse_args(): parser.add_argument('--main_log_dir', help='the dir to save logs and models') parser.add_argument('--seed', help='random seed', default=0) + parser.add_argument('--nodes', help='number of nodes', default=1, type=int) + parser.add_argument('--node_rank', + help='ranking within the nodes, default: 0', + default=0, + type=int) parser.add_argument( '--host', help='Host for distributed training, default: localhost', @@ -197,9 +202,10 @@ def cleanup(): dist.destroy_process_group() -def main_worker(rank, Dataset, Model, Pipeline, cfg_dict_dataset, +def main_worker(local_rank, Dataset, Model, Pipeline, cfg_dict_dataset, cfg_dict_model, cfg_dict_pipeline, args): - world_size = len(args.device_ids) + rank = args.node_rank * len(args.device_ids) + local_rank + world_size = args.nodes * len(args.device_ids) setup(rank, world_size, args) cfg_dict_dataset['rank'] = rank @@ -211,8 +217,10 @@ def main_worker(rank, Dataset, Model, Pipeline, cfg_dict_dataset, cfg_dict_model['seed'] = rng cfg_dict_pipeline['seed'] = rng - device = f"cuda:{args.device_ids[rank]}" - print(f"rank = {rank}, world_size = {world_size}, gpu = {device}") + device = f"cuda:{args.device_ids[local_rank]}" + print( + f"local_rank = {local_rank}, rank = {rank}, world_size = {world_size}, gpu = {device}" + ) cfg_dict_model['device'] = device cfg_dict_pipeline['device'] = device
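
The new --nodes and --node_rank arguments in run_pipeline.py extend the existing single-machine DistributedDataParallel launch to multiple machines: every machine spawns one process per entry in --device_ids, and main_worker derives the global rank from the node rank plus the local process index. A small illustration of that arithmetic (the concrete values and the mp.spawn-style launcher are assumptions for the example, not part of the diff):

# Illustrative only: global rank / world size as computed by the updated main_worker
# for a job with 2 nodes and 4 GPUs per node (--device_ids 0 1 2 3 on each node).
device_ids = [0, 1, 2, 3]   # --device_ids
nodes = 2                   # --nodes
node_rank = 1               # --node_rank of this machine
local_rank = 2              # per-process index handed in by the launcher (e.g. mp.spawn)

world_size = nodes * len(device_ids)             # 8 processes overall
rank = node_rank * len(device_ids) + local_rank  # global rank 6
device = f"cuda:{device_ids[local_rank]}"        # each process still binds a GPU by local index
print(rank, world_size, device)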
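
With WaymoSemSeg, NuScenesSemSeg, and KITTI360 registered above, training can be driven through the usual Open3D-ML entry points. A minimal sketch, assuming the package is built so the new classes are visible under open3d.ml.torch and the data has already been converted with scripts/preprocess_waymo_semseg.py; the dataset path is a placeholder:

# Minimal sketch (not part of this diff): SparseConvUnet on WaymoSemSeg using the
# new sparseconvunet_waymo.yml config. Import paths follow the existing README usage.
import open3d.ml as _ml3d
import open3d.ml.torch as ml3d

cfg = _ml3d.utils.Config.load_from_file("ml3d/configs/sparseconvunet_waymo.yml")
cfg.dataset['dataset_path'] = "/path/to/waymo_semseg"  # placeholder

dataset = ml3d.datasets.WaymoSemSeg(cfg.dataset.pop('dataset_path', None), **cfg.dataset)
model = ml3d.models.SparseConvUnet(**cfg.model)
pipeline = ml3d.pipelines.SemanticSegmentation(model=model, dataset=dataset, **cfg.pipeline)
pipeline.run_train()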
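
scripts/preprocess_waymo_semseg.py writes one float32 .bin file per frame with eight columns, which is the layout WaymoSemSeg.read_lidar expects. A short sketch of reading one back with plain numpy (the file name is only an example of the <split>_<file><frame>.bin pattern produced by the converter):

import numpy as np

# Columns written by preprocess_waymo_semseg.py / read by WaymoSemSeg.read_lidar:
# x, y, z, intensity, elongation, timestamp, instance_label, semantic_label.
pc = np.fromfile("velodyne/train_00000000.bin", dtype=np.float32).reshape(-1, 8)

points = pc[:, :3]                    # xyz
feat = pc[:, 3:5]                     # intensity, elongation (used as model features)
semantic = pc[:, 7].astype(np.int32)  # per-point class id, 0 = Undefined (ignored)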
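
The SemSegMetric changes above add a batch count and an __iadd__ so that per-rank validation metrics can be gathered with dist.gather_object and folded together on rank 0 before accuracy and IoU are computed. A small sketch of that merge on dummy data (the import path and the position of the overall value in acc()/iou() are assumptions based on this diff and the existing metric):

import torch
from ml3d.torch.modules.metrics import SemSegMetric  # import path assumed

# Two "ranks" each accumulate their own confusion matrix, then rank 0 merges them.
m0, m1 = SemSegMetric(), SemSegMetric()
scores = torch.rand(100, 5)              # 100 points, 5 classes (dummy data)
labels = torch.randint(0, 5, (100,))
m0.update(scores, labels)
m1.update(scores, labels)

m0 += m1                                 # adds confusion matrices and batch counts
print(len(m0))                           # 2 batches seen in total
print(m0.acc()[-1], m0.iou()[-1])        # overall accuracy, mean IoU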