From 22992b787f1b5003a7d1ef1d92aee74bd815e56c Mon Sep 17 00:00:00 2001 From: "annie.didier" Date: Thu, 23 May 2024 19:35:52 +0000 Subject: [PATCH] Update geoformer to use cuda 11.3, pytorch 1.11.0, and spconv 2.3.6 --- .devcontainer/Dockerfile_U2004_CUDA113 | 48 +++ .devcontainer/devcontainer.json | 51 +++ checkpoint.py | 47 ++- .../src/bfs_cluster/bfs_cluster.h | 2 +- model/geoformer/geoformer.py | 205 +++++++++--- model/geoformer/geoformer_fs.py | 314 +++++++++++++----- model/geoformer/geoformer_modules.py | 140 +++++--- test_fs.py | 101 ++++-- 8 files changed, 705 insertions(+), 203 deletions(-) create mode 100644 .devcontainer/Dockerfile_U2004_CUDA113 create mode 100644 .devcontainer/devcontainer.json diff --git a/.devcontainer/Dockerfile_U2004_CUDA113 b/.devcontainer/Dockerfile_U2004_CUDA113 new file mode 100644 index 0000000..121f252 --- /dev/null +++ b/.devcontainer/Dockerfile_U2004_CUDA113 @@ -0,0 +1,48 @@ +FROM nvidia/cuda:11.3.1-devel-ubuntu20.04 + +RUN apt-get update && apt-get install wget git -yq +RUN apt-get install build-essential g++ gcc -y +ENV DEBIAN_FRONTEND noninteractive +# Unsure if openmpi is needed +# RUN apt-get update && apt-get install libgl1-mesa-glx libglib2.0-0 libxcb-* \ +# openmpi-bin openmpi-common libopenmpi-dev libgtk2.0-dev -y + +# Install miniconda +ENV CONDA_DIR /opt/conda + +RUN wget --quiet https://repo.anaconda.com/miniconda/Miniconda3-latest-Linux-x86_64.sh -O ~/miniconda.sh && \ + /bin/bash ~/miniconda.sh -b -p /opt/conda + +# Put conda in path so we can use conda activate +ENV PATH=$CONDA_DIR/bin:/usr/local/bin:$PATH +# general packages +RUN conda install python=3.8 +RUN conda install numpy=1.23 +RUN conda install -c anaconda jupyter +RUN echo "numpy==1.23.*" > /opt/conda/conda-meta/pinned +RUN conda install pytorch==1.11.0 torchvision==0.12.0 torchaudio==0.11.0 cudatoolkit=11.3 -c pytorch +RUN conda install conda=22.11 +RUN conda install -c conda-forge setuptools=59.5 + +# Make sure CUDA is visible +# ENV LD_LIBRARY_PATH /usr/local/cuda/lib64:$LD_LIBRARY_PATH +ENV NVIDIA_VISIBLE_DEVICES all +ENV NVIDIA_DRIVER_CAPABILITIES compute,utility +# ARG TORCH_CUDA_ARCH_LIST="8.9" +# Install pointgroup_ops +RUN apt-get install libsparsehash-dev +COPY requirements.txt /tmp/requirements.txt +RUN pip install -r /tmp/requirements.txt +COPY lib /lib +RUN cd /lib/pointgroup_ops && python setup.py develop + +# Install spconv +RUN conda install libboost && pip install pccm +RUN pip install spconv-cu113 + +# Install pointnet2 +# RUN cd /lib/pointnet2 && python setup.py install + +# Install faiss +RUN conda install -c conda-forge faiss-gpu + diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json new file mode 100644 index 0000000..3bb3629 --- /dev/null +++ b/.devcontainer/devcontainer.json @@ -0,0 +1,51 @@ +{ + "build": { + "dockerfile": "Dockerfile_U2004_CUDA113", + "context": "..", + "args": { + "DOCKER_BUILDKIT": "0" + } + }, + "mounts": [ + "source=${localWorkspaceFolder},target=/workspace,type=bind,consistency=cached" + ], + "runArgs": [ + "--gpus", + "all", + "--shm-size", + "16gb", + "-v", + "/tmp/.X11-unix:/tmp.X11-unix" + ], + "containerEnv": { + "NVIDIA_DRIVER_CAPABILITIES": "all", + "DISPLAY": "unix:0" + }, + "forwardPorts": [ + 8887, + 8888, + 8886 + ], + "customizations": { + "vscode": { + "extensions": [ + "ms-python.python", + "ms-python.vscode-pylance", + "ms-toolsai.jupyter", + "ms-python.black-formatter" + ], + "settings": { +         "python.defaultInterpreterPath": "/opt/conda/bin/python", +         
"python.linting.enabled": true, +         "python.linting.pylintEnabled": true, + "[python]": { + "editor.defaultFormatter": "ms-python.black-formatter", + "editor.formatOnSave": true + } +          +        + } + } + }, + "workspaceFolder": "/workspace" +} \ No newline at end of file diff --git a/checkpoint.py b/checkpoint.py index c4b76f6..6b09b28 100644 --- a/checkpoint.py +++ b/checkpoint.py @@ -26,8 +26,12 @@ def align_and_update_state_dicts(model_state_dict, loaded_state_dict): loaded_keys = sorted(list(loaded_state_dict.keys())) # get a matrix of string matches, where each (i, j) entry correspond to the size of the # loaded_key string, if it matches - match_matrix = [len(j) if i.endswith(j) else 0 for i in current_keys for j in loaded_keys] - match_matrix = torch.as_tensor(match_matrix).view(len(current_keys), len(loaded_keys)) + match_matrix = [ + len(j) if i.endswith(j) else 0 for i in current_keys for j in loaded_keys + ] + match_matrix = torch.as_tensor(match_matrix).view( + len(current_keys), len(loaded_keys) + ) max_match_size, idxs = match_matrix.max(1) # remove indices that correspond to no-match idxs[max_match_size == 0] = -1 @@ -44,15 +48,19 @@ def align_and_update_state_dicts(model_state_dict, loaded_state_dict): key = current_keys[idx_new] key_old = loaded_keys[idx_old] if loaded_state_dict[key_old].shape != model_state_dict[key].shape: - # if 'unet' in key or 'input_conv' in key: - # reshaped = loaded_state_dict[key_old].permute(4,0,1,2,3) - # loaded_state_dict[key_old] = reshaped - # else: - print( - "Skip loading parameter {}, required shape{}, " - "loaded shape{}.".format(key, model_state_dict[key].shape, loaded_state_dict[key_old].shape) - ) - loaded_state_dict[key_old] = model_state_dict[key] + if "unet" in key or "input_conv" in key: + reshaped = loaded_state_dict[key_old].permute(4, 0, 1, 2, 3) + loaded_state_dict[key_old] = reshaped + else: + print( + "Skip loading parameter {}, required shape{}, " + "loaded shape{}.".format( + key, + model_state_dict[key].shape, + loaded_state_dict[key_old].shape, + ) + ) + loaded_state_dict[key_old] = model_state_dict[key] model_state_dict[key] = loaded_state_dict[key_old] logger.info( @@ -87,7 +95,16 @@ def mkdir_p(path): raise -def checkpoint(model, optimizer, epoch, log_dir, best_val=None, best_val_iter=None, postfix=None, last=False): +def checkpoint( + model, + optimizer, + epoch, + log_dir, + best_val=None, + best_val_iter=None, + postfix=None, + last=False, +): mkdir_p(log_dir) if last: @@ -95,7 +112,11 @@ def checkpoint(model, optimizer, epoch, log_dir, best_val=None, best_val_iter=No else: filename = f"checkpoint_epoch_{epoch}.pth" checkpoint_file = log_dir + "/" + filename - state = {"epoch": epoch, "state_dict": model.state_dict(), "optimizer": optimizer.state_dict()} + state = { + "epoch": epoch, + "state_dict": model.state_dict(), + "optimizer": optimizer.state_dict(), + } torch.save(state, checkpoint_file) logging.info(f"Checkpoint saved to {checkpoint_file}") diff --git a/lib/pointgroup_ops/src/bfs_cluster/bfs_cluster.h b/lib/pointgroup_ops/src/bfs_cluster/bfs_cluster.h index 849ea6e..d971f56 100644 --- a/lib/pointgroup_ops/src/bfs_cluster/bfs_cluster.h +++ b/lib/pointgroup_ops/src/bfs_cluster/bfs_cluster.h @@ -8,7 +8,7 @@ All Rights Reserved 2020. 
#define BFS_CLUSTER_H #include #include -#include +// #include #include "../datatype/datatype.h" diff --git a/model/geoformer/geoformer.py b/model/geoformer/geoformer.py index 0e1f81a..3ceab24 100644 --- a/model/geoformer/geoformer.py +++ b/model/geoformer/geoformer.py @@ -1,7 +1,10 @@ import functools # import spconv.pytorch as spconv -import spconv as spconv +# import spconv as spconv +from spconv.pytorch.modules import SparseSequential +from spconv.pytorch.conv import SubMConv3d +from spconv.pytorch.core import SparseConvTensor import torch import torch.nn as nn @@ -11,7 +14,12 @@ from lib.pointgroup_ops.functions import pointgroup_ops from lib.pointnet2.pointnet2_modules import PointnetSAModuleVotesSeparate from model.geoformer.geodesic_utils import cal_geodesic_vectorize -from model.geoformer.geoformer_modules import ResidualBlock, UBlock, conv_with_kaiming_uniform, random_downsample +from model.geoformer.geoformer_modules import ( + ResidualBlock, + UBlock, + conv_with_kaiming_uniform, + random_downsample, +) from model.helper import GenericMLP from model.pos_embedding import PositionEmbeddingCoordsSine from model.transformer_detr import TransformerDecoder, TransformerDecoderLayer @@ -39,8 +47,10 @@ def __init__(self): norm_fn = functools.partial(nn.BatchNorm1d, eps=1e-4, momentum=0.1) # backbone - self.input_conv = spconv.SparseSequential( - spconv.SubMConv3d(input_c, m, kernel_size=3, padding=1, bias=False, indice_key="subm1") + self.input_conv = SparseSequential( + SubMConv3d( + input_c, m, kernel_size=3, padding=1, bias=False, indice_key="subm1" + ) ) self.unet = UBlock( [m, 2 * m, 3 * m, 4 * m, 5 * m, 6 * m, 7 * m], @@ -50,11 +60,16 @@ def __init__(self): use_backbone_transformer=True, indice_key_id=1, ) - self.output_layer = spconv.SparseSequential(norm_fn(m), nn.ReLU()) + self.output_layer = SparseSequential(norm_fn(m), nn.ReLU()) # semantic segmentation self.semantic = nn.Sequential( - nn.Linear(m, m, bias=True), norm_fn(m), nn.ReLU(), nn.Linear(m, m, bias=True), norm_fn(m), nn.ReLU() + nn.Linear(m, m, bias=True), + norm_fn(m), + nn.ReLU(), + nn.Linear(m, m, bias=True), + norm_fn(m), + nn.ReLU(), ) self.semantic_linear = nn.Linear(m, classes, bias=True) @@ -76,7 +91,9 @@ def __init__(self): for i in range(before_embedding_conv_num - 1): before_embedding_tower.append(conv_block(cfg.dec_dim, cfg.dec_dim)) before_embedding_tower.append(conv_block(cfg.dec_dim, self.output_dim)) - self.add_module("before_embedding_tower", nn.Sequential(*before_embedding_tower)) + self.add_module( + "before_embedding_tower", nn.Sequential(*before_embedding_tower) + ) # cond inst generate parameters for self.use_coords = True @@ -116,7 +133,9 @@ def __init__(self): ) """ Position embedding """ - self.pos_embedding = PositionEmbeddingCoordsSine(d_pos=cfg.dec_dim, pos_type="fourier", normalize=True) + self.pos_embedding = PositionEmbeddingCoordsSine( + d_pos=cfg.dec_dim, pos_type="fourier", normalize=True + ) """ DETR-Decoder """ decoder_layer = TransformerDecoderLayer( @@ -128,7 +147,9 @@ def __init__(self): use_rel=True, ) - self.decoder = TransformerDecoder(decoder_layer, num_layers=cfg.dec_nlayers, return_intermediate=True) + self.decoder = TransformerDecoder( + decoder_layer, num_layers=cfg.dec_nlayers, return_intermediate=True + ) self.query_projection = GenericMLP( input_dim=cfg.dec_dim, @@ -226,19 +247,27 @@ def generate_proposal( proposals_npoints = torch.sum(mask_logit_b_bool, dim=1) npoints_cond = proposals_npoints >= npoint_thresh - mask_logit_scores = torch.sum(mask_logit_b * 
mask_logit_b_bool.int(), dim=1) / (proposals_npoints + 1e-6) + mask_logit_scores = torch.sum(mask_logit_b * mask_logit_b_bool.int(), dim=1) / ( + proposals_npoints + 1e-6 + ) mask_logit_scores_cond = mask_logit_scores >= score_thresh - cls_logits_scores = torch.gather(cls_logits_b, 1, cls_logits_pred_b.unsqueeze(-1)).squeeze(-1) + cls_logits_scores = torch.gather( + cls_logits_b, 1, cls_logits_pred_b.unsqueeze(-1) + ).squeeze(-1) sem_scores = torch.sum( - semantic_scores_b[None, :, :].expand(n_queries, semantic_scores_b.shape[0], semantic_scores_b.shape[1]) + semantic_scores_b[None, :, :].expand( + n_queries, semantic_scores_b.shape[0], semantic_scores_b.shape[1] + ) * mask_logit_b_bool.int()[:, :, None], dim=1, ) / ( proposals_npoints[:, None] + 1e-6 ) # n_pred, n_clas - sem_scores = torch.gather(sem_scores, 1, cls_logits_pred_b.unsqueeze(-1)).squeeze(-1) + sem_scores = torch.gather( + sem_scores, 1, cls_logits_pred_b.unsqueeze(-1) + ).squeeze(-1) scores = mask_logit_scores * torch.pow(cls_logits_scores, 0.5) * sem_scores @@ -252,7 +281,9 @@ def generate_proposal( scores_final = scores[final_cond] num_insts = scores_final.shape[0] - proposals_pred = torch.zeros((num_insts, num_points), dtype=torch.int, device=mask_logit_b.device) + proposals_pred = torch.zeros( + (num_insts, num_points), dtype=torch.int, device=mask_logit_b.device + ) inst_inds, point_inds = torch.nonzero(masks_final, as_tuple=True) @@ -268,14 +299,18 @@ def parse_dynamic_params(self, params, out_channels): num_instances = params.size(0) num_layers = len(self.weight_nums) - params_splits = list(torch.split_with_sizes(params, self.weight_nums + self.bias_nums, dim=1)) + params_splits = list( + torch.split_with_sizes(params, self.weight_nums + self.bias_nums, dim=1) + ) weight_splits = params_splits[:num_layers] bias_splits = params_splits[num_layers:] for l in range(num_layers): if l < num_layers - 1: - weight_splits[l] = weight_splits[l].reshape(num_instances * out_channels, -1, 1) + weight_splits[l] = weight_splits[l].reshape( + num_instances * out_channels, -1, 1 + ) bias_splits[l] = bias_splits[l].reshape(num_instances * out_channels) else: weight_splits[l] = weight_splits[l].reshape(num_instances, -1, 1) @@ -284,14 +319,26 @@ def parse_dynamic_params(self, params, out_channels): return weight_splits, bias_splits def mask_heads_forward( - self, geo_dist, mask_features, weights, biases, num_insts, coords_, fps_sampling_coords, use_geo=True + self, + geo_dist, + mask_features, + weights, + biases, + num_insts, + coords_, + fps_sampling_coords, + use_geo=True, ): assert mask_features.dim() == 3 n_layers = len(weights) n_mask = mask_features.size(0) - x = mask_features.permute(2, 1, 0).repeat(num_insts, 1, 1) # num_inst * c * N_mask + x = mask_features.permute(2, 1, 0).repeat( + num_insts, 1, 1 + ) # num_inst * c * N_mask - relative_coords = fps_sampling_coords.reshape(-1, 1, 3) - coords_.reshape(1, -1, 3) # N_inst * N_mask * 3 + relative_coords = fps_sampling_coords.reshape(-1, 1, 3) - coords_.reshape( + 1, -1, 3 + ) # N_inst * N_mask * 3 if use_geo: n_queries, n_contexts = geo_dist.shape[:2] @@ -305,9 +352,9 @@ def mask_heads_forward( ) # b x n_queries x n_contexts x 3 cond = (geo_dist < 0).unsqueeze(-1).expand(n_queries, n_contexts, 3) - relative_coords[cond] = relative_coords[cond] + max_geo_dist_context[cond] * torch.sign( - relative_coords[cond] - ) + relative_coords[cond] = relative_coords[cond] + max_geo_dist_context[ + cond + ] * torch.sign(relative_coords[cond]) relative_coords = relative_coords.permute(0, 2, 1) 
x = torch.cat([relative_coords, x], dim=1) # num_inst * (3+c) * N_mask @@ -324,7 +371,13 @@ def mask_heads_forward( return x def get_mask_prediction( - self, geo_dists, param_kernels, mask_features, locs_float_, fps_sampling_locs, batch_offsets_ + self, + geo_dists, + param_kernels, + mask_features, + locs_float_, + fps_sampling_locs, + batch_offsets_, ): # param_kernels = param_kernels.permute(0, 2, 1, 3) # num_layers x batch x n_queries x channel num_layers, n_queries, batch = ( @@ -342,8 +395,12 @@ def get_mask_prediction( 1, 2 ) # batch x n_queries x n_classes - param_kernel2 = param_kernel.transpose(0, 1).flatten(0, 1) # (batch * n_queries) * channel - before_embedding_feature = self.before_embedding_tower(torch.unsqueeze(param_kernel2, dim=2)) + param_kernel2 = param_kernel.transpose(0, 1).flatten( + 0, 1 + ) # (batch * n_queries) * channel + before_embedding_feature = self.before_embedding_tower( + torch.unsqueeze(param_kernel2, dim=2) + ) controllers = self.controller(before_embedding_feature).squeeze(dim=2) controllers = controllers.reshape(batch, n_queries, -1) @@ -394,8 +451,12 @@ def preprocess_input(self, batch_input, batch_size): if cfg.use_coords: feats = torch.cat((feats, locs_float), 1).float() - voxel_feats = pointgroup_ops.voxelization(feats, v2p_map, cfg.mode) # (M, C), float, cuda - sparse_input = spconv.SparseConvTensor(voxel_feats, voxel_coords.int(), spatial_shape, batch_size) + voxel_feats = pointgroup_ops.voxelization( + feats, v2p_map, cfg.mode + ) # (M, C), float, cuda + sparse_input = SparseConvTensor( + voxel_feats, voxel_coords.int(), spatial_shape, batch_size + ) return sparse_input @@ -414,7 +475,9 @@ def forward(self, batch_input, epoch, training=True): batch_input["pc_mins"], ] - output_feats, semantic_scores, semantic_preds = self.forward_backbone(batch_input, batch_size) + output_feats, semantic_scores, semantic_preds = self.forward_backbone( + batch_input, batch_size + ) outputs["semantic_scores"] = semantic_scores if epoch <= self.prepare_epochs: @@ -438,10 +501,14 @@ def forward(self, batch_input, epoch, training=True): output_feats_ = output_feats[fg_idxs] semantic_scores_ = semantic_scores[fg_idxs] - mask_features_ = self.mask_tower(torch.unsqueeze(output_feats_, dim=2).permute(2, 1, 0)).permute(2, 1, 0) + mask_features_ = self.mask_tower( + torch.unsqueeze(output_feats_, dim=2).permute(2, 1, 0) + ).permute(2, 1, 0) # NOTE aggregator - contexts = self.forward_aggregator(locs_float_, output_feats_, batch_offsets_, batch_size) + contexts = self.forward_aggregator( + locs_float_, output_feats_, batch_offsets_, batch_size + ) if contexts is None: outputs["mask_predictions"] = None return outputs @@ -456,20 +523,25 @@ def forward(self, batch_input, epoch, training=True): pre_enc_inds, locs_float_, batch_offsets_, - max_step=128 if self.training else 256, + # max_step=128 if self.training else 256, + max_step=128, neighbor=64, radius=0.05, n_queries=cfg.n_query_points, ) # NOTE transformer decoder - dec_outputs = self.forward_decoder(context_locs, context_feats, query_locs, pc_dims, geo_dists, pre_enc_inds) + dec_outputs = self.forward_decoder( + context_locs, context_feats, query_locs, pc_dims, geo_dists, pre_enc_inds + ) if training: # NOTE subsample for dynamic conv # NOTE: downsample when training to avoid OOM - idxs_subsample, idxs_subsample_raw = random_downsample(batch_offsets_, batch_size, n_subsample=30000) + idxs_subsample, idxs_subsample_raw = random_downsample( + batch_offsets_, batch_size, n_subsample=30000 + ) geo_dists_subsample = [] for b 
in range(batch_size): @@ -479,15 +551,21 @@ def forward(self, batch_input, epoch, training=True): mask_features_subsample = mask_features_[idxs_subsample] locs_float_subsample = locs_float_[idxs_subsample] batch_idxs_subsample = batch_idxs_[idxs_subsample] - batch_offsets_subsample = self.get_batch_offsets(batch_idxs_subsample, batch_size) - + batch_offsets_subsample = self.get_batch_offsets( + batch_idxs_subsample, batch_size + ) outputs["fg_idxs"] = fg_idxs[idxs_subsample] outputs["num_insts"] = cfg.n_query_points * batch_size outputs["batch_idxs"] = batch_idxs_subsample mask_predictions = self.get_mask_prediction( - geo_dists_subsample, dec_outputs, mask_features_subsample, locs_float_subsample, query_locs, batch_offsets_subsample + geo_dists_subsample, + dec_outputs, + mask_features_subsample, + locs_float_subsample, + query_locs, + batch_offsets_subsample, ) outputs["mask_predictions"] = mask_predictions @@ -498,9 +576,14 @@ def forward(self, batch_input, epoch, training=True): outputs["fg_idxs"] = fg_idxs outputs["num_insts"] = cfg.n_query_points * batch_size outputs["batch_idxs"] = batch_idxs_ - + torch.cuda.empty_cache() mask_predictions = self.get_mask_prediction( - geo_dists, dec_outputs, mask_features_, locs_float_, query_locs, batch_offsets_ + geo_dists, + dec_outputs, + mask_features_, + locs_float_, + query_locs, + batch_offsets_, ) outputs["mask_predictions"] = mask_predictions @@ -528,7 +611,9 @@ def forward(self, batch_input, epoch, training=True): return outputs def forward_backbone(self, batch_input, batch_size): - context_backbone = torch.no_grad if "unet" in self.fix_module else torch.enable_grad + context_backbone = ( + torch.no_grad if "unet" in self.fix_module else torch.enable_grad + ) with context_backbone(): p2v_map = batch_input["p2v_map"] @@ -548,8 +633,12 @@ def forward_backbone(self, batch_input, batch_size): return output_feats, semantic_scores, semantic_preds - def forward_aggregator(self, locs_float_, output_feats_, batch_offsets_, batch_size): - context_aggregator = torch.no_grad if "set_aggregator" in self.fix_module else torch.enable_grad + def forward_aggregator( + self, locs_float_, output_feats_, batch_offsets_, batch_size + ): + context_aggregator = ( + torch.no_grad if "set_aggregator" in self.fix_module else torch.enable_grad + ) with context_aggregator(): context_locs = [] @@ -573,14 +662,19 @@ def forward_aggregator(self, locs_float_, output_feats_, batch_offsets_, batch_s npoint = cfg.n_downsampling sampling_indices = torch.tensor( - np.random.choice(batch_points, npoint, replace=False), dtype=torch.long, device=locs_float_.device + np.random.choice(batch_points, npoint, replace=False), + dtype=torch.long, + device=locs_float_.device, ) locs_float_b = locs_float_b[sampling_indices].unsqueeze(0) output_feats_b = output_feats_b[sampling_indices].unsqueeze(0) - context_locs_b, grouped_features_b, grouped_xyz_b, pre_enc_inds_b = self.set_aggregator.group_points( - locs_float_b.contiguous(), output_feats_b.transpose(1, 2).contiguous() + context_locs_b, grouped_features_b, grouped_xyz_b, pre_enc_inds_b = ( + self.set_aggregator.group_points( + locs_float_b.contiguous(), + output_feats_b.transpose(1, 2).contiguous(), + ) ) context_locs.append(context_locs_b) @@ -598,11 +692,15 @@ def forward_aggregator(self, locs_float_, output_feats_, batch_offsets_, batch_s return context_locs, context_feats, pre_enc_inds - def forward_decoder(self, context_locs, context_feats, query_locs, pc_dims, geo_dists, pre_enc_inds): + def forward_decoder( + self, 
context_locs, context_feats, query_locs, pc_dims, geo_dists, pre_enc_inds + ): batch_size = context_locs.shape[0] context_embedding_pos = self.pos_embedding(context_locs, input_range=pc_dims) - context_feats = self.encoder_to_decoder_projection(context_feats.permute(0, 2, 1)) # batch x channel x npoints + context_feats = self.encoder_to_decoder_projection( + context_feats.permute(0, 2, 1) + ) # batch x channel x npoints """ Init dec_inputs by query features """ query_embedding_pos = self.pos_embedding(query_locs, input_range=pc_dims) @@ -623,13 +721,19 @@ def forward_decoder(self, context_locs, context_feats, query_locs, pc_dims, geo_ geo_dist_context = [] for b in range(batch_size): - geo_dist_context_b = geo_dists[b][:, pre_enc_inds[b].long()] # n_queries x n_contexts + geo_dist_context_b = geo_dists[b][ + :, pre_enc_inds[b].long() + ] # n_queries x n_contexts geo_dist_context.append(geo_dist_context_b) - geo_dist_context = torch.stack(geo_dist_context, dim=0) # b x n_queries x n_contexts + geo_dist_context = torch.stack( + geo_dist_context, dim=0 + ) # b x n_queries x n_contexts max_geo_dist_context = torch.max(geo_dist_context, dim=2)[0] # b x n_queries max_geo_val = torch.max(max_geo_dist_context) - max_geo_dist_context[max_geo_dist_context < 0] = max_geo_val # NOTE assign very big value to invalid queries + max_geo_dist_context[max_geo_dist_context < 0] = ( + max_geo_val # NOTE assign very big value to invalid queries + ) max_geo_dist_context = max_geo_dist_context[:, :, None, None].expand( batch_size, n_queries, n_contexts, 3 @@ -641,7 +745,8 @@ def forward_decoder(self, context_locs, context_feats, query_locs, pc_dims, geo_ geo_dist_context[cond] = max_geo_dist_context[cond] + relative_coords[cond] relative_embedding_pos = self.pos_embedding( - geo_dist_context.reshape(batch_size, n_queries * n_contexts, -1), input_range=pc_dims + geo_dist_context.reshape(batch_size, n_queries * n_contexts, -1), + input_range=pc_dims, ).reshape( batch_size, -1, diff --git a/model/geoformer/geoformer_fs.py b/model/geoformer/geoformer_fs.py index 5692c9b..1ce29b0 100644 --- a/model/geoformer/geoformer_fs.py +++ b/model/geoformer/geoformer_fs.py @@ -1,5 +1,9 @@ import functools -import spconv as spconv + +# import spconv as spconv +from spconv.pytorch.modules import SparseSequential +from spconv.pytorch.conv import SubMConv3d +from spconv.pytorch.core import SparseConvTensor import torch import torch.nn as nn @@ -9,7 +13,12 @@ from lib.pointgroup_ops.functions import pointgroup_ops from lib.pointnet2.pointnet2_modules import PointnetSAModuleVotesSeparate from model.geoformer.geodesic_utils import cal_geodesic_vectorize -from model.geoformer.geoformer_modules import ResidualBlock, UBlock, conv_with_kaiming_uniform, random_downsample +from model.geoformer.geoformer_modules import ( + ResidualBlock, + UBlock, + conv_with_kaiming_uniform, + random_downsample, +) from model.helper import GenericMLP from model.pos_embedding import PositionEmbeddingCoordsSine from model.transformer_detr import TransformerDecoder, TransformerDecoderLayer @@ -37,8 +46,10 @@ def __init__(self): norm_fn = functools.partial(nn.BatchNorm1d, eps=1e-4, momentum=0.1) # backbone - self.input_conv = spconv.SparseSequential( - spconv.SubMConv3d(input_c, m, kernel_size=3, padding=1, bias=False, indice_key="subm1") + self.input_conv = SparseSequential( + SubMConv3d( + input_c, m, kernel_size=3, padding=1, bias=False, indice_key="subm1" + ) ) self.unet = UBlock( [m, 2 * m, 3 * m, 4 * m, 5 * m, 6 * m, 7 * m], @@ -48,11 +59,16 @@ def 
__init__(self): use_backbone_transformer=True, indice_key_id=1, ) - self.output_layer = spconv.SparseSequential(norm_fn(m), nn.ReLU()) + self.output_layer = SparseSequential(norm_fn(m), nn.ReLU()) # semantic segmentation self.semantic = nn.Sequential( - nn.Linear(m, m, bias=True), norm_fn(m), nn.ReLU(), nn.Linear(m, m, bias=True), norm_fn(m), nn.ReLU() + nn.Linear(m, m, bias=True), + norm_fn(m), + nn.ReLU(), + nn.Linear(m, m, bias=True), + norm_fn(m), + nn.ReLU(), ) self.semantic_linear = nn.Linear(m, classes, bias=True) @@ -73,7 +89,9 @@ def __init__(self): for i in range(before_embedding_conv_num - 1): before_embedding_tower.append(conv_block(cfg.dec_dim, cfg.dec_dim)) before_embedding_tower.append(conv_block(cfg.dec_dim, self.output_dim)) - self.add_module("before_embedding_tower", nn.Sequential(*before_embedding_tower)) + self.add_module( + "before_embedding_tower", nn.Sequential(*before_embedding_tower) + ) # cond inst generate parameters for self.use_coords = True @@ -113,7 +131,9 @@ def __init__(self): ) """ Position embedding """ - self.pos_embedding = PositionEmbeddingCoordsSine(d_pos=cfg.dec_dim, pos_type="fourier", normalize=True) + self.pos_embedding = PositionEmbeddingCoordsSine( + d_pos=cfg.dec_dim, pos_type="fourier", normalize=True + ) """ DETR-Decoder """ decoder_layer = TransformerDecoderLayer( @@ -125,7 +145,9 @@ def __init__(self): use_rel=True, ) - self.decoder = TransformerDecoder(decoder_layer, num_layers=cfg.dec_nlayers, return_intermediate=True) + self.decoder = TransformerDecoder( + decoder_layer, num_layers=cfg.dec_nlayers, return_intermediate=True + ) self.query_projection = GenericMLP( input_dim=cfg.dec_dim, @@ -215,7 +237,9 @@ def generate_proposal( proposals_npoints = torch.sum(mask_logit_b_bool, dim=1) npoints_cond = proposals_npoints >= npoint_thresh - mask_logit_scores = torch.sum(mask_logit_b * mask_logit_b_bool.int(), dim=1) / (proposals_npoints + 1e-6) + mask_logit_scores = torch.sum(mask_logit_b * mask_logit_b_bool.int(), dim=1) / ( + proposals_npoints + 1e-6 + ) mask_logit_scores_cond = mask_logit_scores >= score_thresh scores = mask_logit_scores * torch.pow(similarity_score_b, 0.5) @@ -229,7 +253,9 @@ def generate_proposal( scores_final = scores[final_cond] num_insts = scores_final.shape[0] - proposals_pred = torch.zeros((num_insts, num_points), dtype=torch.int, device=mask_logit_b.device) + proposals_pred = torch.zeros( + (num_insts, num_points), dtype=torch.int, device=mask_logit_b.device + ) inst_inds, point_inds = torch.nonzero(masks_final, as_tuple=True) @@ -245,14 +271,18 @@ def parse_dynamic_params(self, params, out_channels): num_instances = params.size(0) num_layers = len(self.weight_nums) - params_splits = list(torch.split_with_sizes(params, self.weight_nums + self.bias_nums, dim=1)) + params_splits = list( + torch.split_with_sizes(params, self.weight_nums + self.bias_nums, dim=1) + ) weight_splits = params_splits[:num_layers] bias_splits = params_splits[num_layers:] for l in range(num_layers): if l < num_layers - 1: - weight_splits[l] = weight_splits[l].reshape(num_instances * out_channels, -1, 1) + weight_splits[l] = weight_splits[l].reshape( + num_instances * out_channels, -1, 1 + ) bias_splits[l] = bias_splits[l].reshape(num_instances * out_channels) else: weight_splits[l] = weight_splits[l].reshape(num_instances, -1, 1) @@ -261,16 +291,28 @@ def parse_dynamic_params(self, params, out_channels): return weight_splits, bias_splits def mask_heads_forward( - self, geo_dist, mask_features, weights, biases, num_insts, coords_, 
fps_sampling_coords, use_geo=True - ): + self, + geo_dist, + mask_features, + weights, + biases, + num_insts, + coords_, + fps_sampling_coords, + use_geo=True, + ): assert mask_features.dim() == 3 n_layers = len(weights) n_mask = mask_features.size(0) - x = mask_features.permute(2, 1, 0).repeat(num_insts, 1, 1) # num_inst * c * N_mask + x = mask_features.permute(2, 1, 0).repeat( + num_insts, 1, 1 + ) # num_inst * c * N_mask geo_dist = geo_dist.cuda() - relative_coords = fps_sampling_coords[:, None, :] - coords_[None, :, :] # N_inst * N_mask * 3 + relative_coords = ( + fps_sampling_coords[:, None, :] - coords_[None, :, :] + ) # N_inst * N_mask * 3 if use_geo: n_queries, n_contexts = geo_dist.shape[:2] @@ -284,9 +326,9 @@ def mask_heads_forward( ) # b x n_queries x n_contexts x 3 cond = (geo_dist < 0).unsqueeze(-1).expand(n_queries, n_contexts, 3) - relative_coords[cond] = relative_coords[cond] + max_geo_dist_context[cond] * torch.sign( - relative_coords[cond] - ) + relative_coords[cond] = relative_coords[cond] + max_geo_dist_context[ + cond + ] * torch.sign(relative_coords[cond]) relative_coords = relative_coords.permute(0, 2, 1) x = torch.cat([relative_coords, x], dim=1) # num_inst * (3+c) * N_mask @@ -300,13 +342,19 @@ def mask_heads_forward( return x def get_mask_prediction( - self, geo_dists, param_kernels, mask_features, locs_float_, fps_sampling_locs, batch_offsets_ + self, + geo_dists, + param_kernels, + mask_features, + locs_float_, + fps_sampling_locs, + batch_offsets_, ): # param_kernels = param_kernels.permute(0, 2, 1, 3) # num_layers x batch x n_queries x channel num_layers, n_queries, batch = ( param_kernels.shape[0], param_kernels.shape[1], - param_kernels.shape[2] + param_kernels.shape[2], ) outputs = [] @@ -314,8 +362,12 @@ def get_mask_prediction( param_kernel = param_kernels[l] # n_queries x batch x channel # mlp head outputs are (num_layers x batch) x noutput x nqueries, so transpose last two dims - param_kernel2 = param_kernel.transpose(0, 1).flatten(0, 1) # (batch * n_queries) * channel - before_embedding_feature = self.before_embedding_tower(torch.unsqueeze(param_kernel2, dim=2)) + param_kernel2 = param_kernel.transpose(0, 1).flatten( + 0, 1 + ) # (batch * n_queries) * channel + before_embedding_feature = self.before_embedding_tower( + torch.unsqueeze(param_kernel2, dim=2) + ) controllers = self.controller(before_embedding_feature).squeeze(dim=2) controllers = controllers.reshape(batch, n_queries, -1) @@ -369,8 +421,12 @@ def preprocess_input(self, batch_input, batch_size): if cfg.use_coords: feats = torch.cat((feats, locs_float), 1).float() - voxel_feats = pointgroup_ops.voxelization(feats, v2p_map, cfg.mode) # (M, C), float, cuda - sparse_input = spconv.SparseConvTensor(voxel_feats, voxel_coords.int(), spatial_shape, batch_size) + voxel_feats = pointgroup_ops.voxelization( + feats, v2p_map, cfg.mode + ) # (M, C), float, cuda + sparse_input = SparseConvTensor( + voxel_feats, voxel_coords.int(), spatial_shape, batch_size + ) return sparse_input @@ -410,18 +466,33 @@ def process_support(self, batch_input, training=True): output_feats_b = output_feats_[start:end, :].unsqueeze(0) # support_embedding = torch.mean(output_feats_b, dim=1) # channel - context_locs_b, grouped_features, grouped_xyz, pre_enc_inds = self.set_aggregator.group_points( - locs_float_b.contiguous(), output_feats_b.transpose(1, 2).contiguous(), npoint_new=32 + context_locs_b, grouped_features, grouped_xyz, pre_enc_inds = ( + self.set_aggregator.group_points( + locs_float_b.contiguous(), + 
output_feats_b.transpose(1, 2).contiguous(), + npoint_new=32, + ) + ) + context_feats_b = self.set_aggregator.mlp( + grouped_features, grouped_xyz, pooling="avg" ) - context_feats_b = self.set_aggregator.mlp(grouped_features, grouped_xyz, pooling="avg") - context_feats_b = context_feats_b.transpose(1, 2) # 1 x n_point x channel + context_feats_b = context_feats_b.transpose( + 1, 2 + ) # 1 x n_point x channel support_embedding = torch.mean(context_feats_b, dim=1) # channel support_embeddings.append(support_embedding) support_embeddings = torch.cat(support_embeddings) # batch x channel return support_embeddings - def forward(self, support_dict, scene_dict, training=True, remember=False, support_embeddings=None): + def forward( + self, + support_dict, + scene_dict, + training=True, + remember=False, + support_embeddings=None, + ): outputs = {} batch_idxs = scene_dict["locs"][:, 0].int() @@ -457,7 +528,9 @@ def forward(self, support_dict, scene_dict, training=True, remember=False, suppo outputs["semantic_scores"] = semantic_scores else: # with torch.cuda.amp.autocast(enabled=True): - output_feats, semantic_scores, semantic_preds = self.forward_backbone(scene_dict, batch_size) + output_feats, semantic_scores, semantic_preds = self.forward_backbone( + scene_dict, batch_size + ) outputs["semantic_scores"] = semantic_scores @@ -474,16 +547,26 @@ def forward(self, support_dict, scene_dict, training=True, remember=False, suppo output_feats_ = output_feats[fg_idxs] semantic_preds_ = semantic_preds[fg_idxs] - context_mask_tower = ( - torch.enable_grad if self.training and "mask_tower" not in self.fix_module else torch.no_grad - ) - with context_mask_tower(): - mask_features_ = self.mask_tower(torch.unsqueeze(output_feats_, dim=2).permute(2, 1, 0)).permute( - 2, 1, 0 + try: + context_mask_tower = ( + torch.enable_grad + if self.training and "mask_tower" not in self.fix_module + else torch.no_grad ) + with context_mask_tower(): + mask_features_ = self.mask_tower( + torch.unsqueeze(output_feats_, dim=2).permute(2, 1, 0) + ).permute(2, 1, 0) + except Exception as e: + print( + f"Failed on input scene: {scene_dict} with output_feats: {output_feats_}" + ) + raise e # NOTE aggregator - contexts = self.forward_aggregator(locs_float_, output_feats_, batch_offsets_, batch_size) + contexts = self.forward_aggregator( + locs_float_, output_feats_, batch_offsets_, batch_size + ) if contexts is None: outputs["mask_predictions"] = None return outputs @@ -527,27 +610,40 @@ def forward(self, support_dict, scene_dict, training=True, remember=False, suppo return outputs if support_embeddings is None: - support_embeddings = self.process_support(support_dict, training) # batch x channel + support_embeddings = self.process_support( + support_dict, training + ) # batch x channel with torch.cuda.amp.autocast(enabled=True): # NOTE aggregate support and query feats - channel_wise_tensor = context_feats * support_embeddings.unsqueeze(1).repeat(1, cfg.n_decode_point, 1) - subtraction_tensor = context_feats - support_embeddings.unsqueeze(1).repeat(1, cfg.n_decode_point, 1) + channel_wise_tensor = context_feats * support_embeddings.unsqueeze( + 1 + ).repeat(1, cfg.n_decode_point, 1) + subtraction_tensor = context_feats - support_embeddings.unsqueeze(1).repeat( + 1, cfg.n_decode_point, 1 + ) aggregation_tensor = torch.cat( [channel_wise_tensor, subtraction_tensor, context_feats], dim=2 ) # batch * n_sampling *(3*channel) # NOTE transformer decoder - + dec_outputs = self.forward_decoder( - context_locs, aggregation_tensor, 
query_locs, pc_dims, geo_dists, pre_enc_inds + context_locs, + aggregation_tensor, + query_locs, + pc_dims, + geo_dists, + pre_enc_inds, ) if not training: dec_outputs = dec_outputs[-1:, ...] else: # NOTE: downsample when training to avoid OOM - idxs_subsample, idxs_subsample_raw = random_downsample(batch_offsets_, batch_size, n_subsample=30000) + idxs_subsample, idxs_subsample_raw = random_downsample( + batch_offsets_, batch_size, n_subsample=30000 + ) geo_dists2 = [] for b in range(batch_size): @@ -564,13 +660,23 @@ def forward(self, support_dict, scene_dict, training=True, remember=False, suppo # NOTE dynamic convolution mask_predictions = self.get_mask_prediction( - geo_dists, dec_outputs, mask_features_, locs_float_, query_locs, batch_offsets_ + geo_dists, + dec_outputs, + mask_features_, + locs_float_, + query_locs, + batch_offsets_, ) mask_logit_final = mask_predictions[-1]["mask_logits"] - similarity_score = self.similarity_net(aggregation_tensor[:, :cfg.n_query_points, :].flatten(0,1)).squeeze(-1).reshape(batch_size, cfg.n_query_points) # batch x n_sampling - + similarity_score = ( + self.similarity_net( + aggregation_tensor[:, : cfg.n_query_points, :].flatten(0, 1) + ) + .squeeze(-1) + .reshape(batch_size, cfg.n_query_points) + ) # batch x n_sampling if training: outputs["fg_idxs"] = fg_idxs @@ -596,7 +702,11 @@ def forward(self, support_dict, scene_dict, training=True, remember=False, suppo return outputs def forward_backbone(self, batch_input, batch_size): - context_backbone = torch.enable_grad if self.training and "unet" not in self.fix_module else torch.no_grad + context_backbone = ( + torch.enable_grad + if self.training and "unet" not in self.fix_module + else torch.no_grad + ) with context_backbone(): p2v_map = batch_input["p2v_map"] @@ -616,9 +726,13 @@ def forward_backbone(self, batch_input, batch_size): return output_feats, semantic_scores, semantic_preds - def forward_aggregator(self, locs_float_, output_feats_, batch_offsets_, batch_size): + def forward_aggregator( + self, locs_float_, output_feats_, batch_offsets_, batch_size + ): context_aggregator = ( - torch.enable_grad if self.training and "set_aggregator" not in self.fix_module else torch.no_grad + torch.enable_grad + if self.training and "set_aggregator" not in self.fix_module + else torch.no_grad ) with context_aggregator(): @@ -640,8 +754,11 @@ def forward_aggregator(self, locs_float_, output_feats_, batch_offsets_, batch_s locs_float_b = locs_float_b.unsqueeze(0) output_feats_b = output_feats_b.unsqueeze(0) - context_locs_b, grouped_features_b, grouped_xyz_b, pre_enc_inds_b = self.set_aggregator.group_points( - locs_float_b.contiguous(), output_feats_b.transpose(1, 2).contiguous() + context_locs_b, grouped_features_b, grouped_xyz_b, pre_enc_inds_b = ( + self.set_aggregator.group_points( + locs_float_b.contiguous(), + output_feats_b.transpose(1, 2).contiguous(), + ) ) context_locs.append(context_locs_b) @@ -659,11 +776,15 @@ def forward_aggregator(self, locs_float_, output_feats_, batch_offsets_, batch_s return context_locs, context_feats, pre_enc_inds - def forward_decoder(self, context_locs, context_feats, query_locs, pc_dims, geo_dists, pre_enc_inds): + def forward_decoder( + self, context_locs, context_feats, query_locs, pc_dims, geo_dists, pre_enc_inds + ): batch_size = context_locs.shape[0] context_embedding_pos = self.pos_embedding(context_locs, input_range=pc_dims) - context_feats = self.encoder_to_decoder_projection(context_feats.permute(0, 2, 1)) # batch x channel x npoints + context_feats = 
self.encoder_to_decoder_projection( + context_feats.permute(0, 2, 1) + ) # batch x channel x npoints """ Init dec_inputs by query features """ query_embedding_pos = self.pos_embedding(query_locs, input_range=pc_dims) @@ -684,13 +805,19 @@ def forward_decoder(self, context_locs, context_feats, query_locs, pc_dims, geo_ geo_dist_context = [] for b in range(batch_size): - geo_dist_context_b = geo_dists[b][:, pre_enc_inds[b].long()] # n_queries x n_contexts + geo_dist_context_b = geo_dists[b][ + :, pre_enc_inds[b].long() + ] # n_queries x n_contexts geo_dist_context.append(geo_dist_context_b) - geo_dist_context = torch.stack(geo_dist_context, dim=0) # b x n_queries x n_contexts + geo_dist_context = torch.stack( + geo_dist_context, dim=0 + ) # b x n_queries x n_contexts max_geo_dist_context = torch.max(geo_dist_context, dim=2)[0] # b x n_queries max_geo_val = torch.max(max_geo_dist_context) - max_geo_dist_context[max_geo_dist_context < 0] = max_geo_val # NOTE assign very big value to invalid queries + max_geo_dist_context[max_geo_dist_context < 0] = ( + max_geo_val # NOTE assign very big value to invalid queries + ) max_geo_dist_context = max_geo_dist_context[:, :, None, None].expand( batch_size, n_queries, n_contexts, 3 @@ -702,7 +829,8 @@ def forward_decoder(self, context_locs, context_feats, query_locs, pc_dims, geo_ geo_dist_context[cond] = max_geo_dist_context[cond] + relative_coords[cond] relative_embedding_pos = self.pos_embedding( - geo_dist_context.reshape(batch_size, n_queries * n_contexts, -1), input_range=pc_dims + geo_dist_context.reshape(batch_size, n_queries * n_contexts, -1), + input_range=pc_dims, ).reshape( batch_size, -1, @@ -723,7 +851,13 @@ def forward_decoder(self, context_locs, context_feats, query_locs, pc_dims, geo_ return dec_outputs def get_similarity( - self, mask_logit_final, batch_offsets_, locs_float_, output_feats_, support_embeddings, pre_enc_inds_mask=None + self, + mask_logit_final, + batch_offsets_, + locs_float_, + output_feats_, + support_embeddings, + pre_enc_inds_mask=None, ): batch_size = len(mask_logit_final) @@ -737,52 +871,74 @@ def get_similarity( end = batch_offsets_[b + 1] locs_float_b = locs_float_[start:end, :].unsqueeze(0) - output_feats_b = output_feats_[start:end, :].unsqueeze(0) # 1, n_point, f + output_feats_b = output_feats_[start:end, :].unsqueeze( + 0 + ) # 1, n_point, f npoint_new = min(4096, end - start) - context_locs_b, grouped_features, grouped_xyz, pre_enc_inds_b1 = self.set_aggregator.group_points( - locs_float_b.contiguous(), - output_feats_b.transpose(1, 2).contiguous(), - npoint_new=npoint_new, - inds=None if no_cache else pre_enc_inds_mask[b], + context_locs_b, grouped_features, grouped_xyz, pre_enc_inds_b1 = ( + self.set_aggregator.group_points( + locs_float_b.contiguous(), + output_feats_b.transpose(1, 2).contiguous(), + npoint_new=npoint_new, + inds=None if no_cache else pre_enc_inds_mask[b], + ) ) - context_feats_b1 = self.set_aggregator.mlp(grouped_features, grouped_xyz, pooling="avg") - context_feats_b1 = context_feats_b1.transpose(1, 2) # 1 x n_point x channel + context_feats_b1 = self.set_aggregator.mlp( + grouped_features, grouped_xyz, pooling="avg" + ) + context_feats_b1 = context_feats_b1.transpose( + 1, 2 + ) # 1 x n_point x channel - mask_logit_final_b = mask_logit_final[b].detach().sigmoid() # n_queries, mask + mask_logit_final_b = ( + mask_logit_final[b].detach().sigmoid() + ) # n_queries, mask - mask_logit_final_b = mask_logit_final_b[:, pre_enc_inds_b1.squeeze(0).long()] + mask_logit_final_b = 
mask_logit_final_b[ + :, pre_enc_inds_b1.squeeze(0).long() + ] mask_logit_final_bool = mask_logit_final_b >= 0.2 # n_queries, mask count_mask = torch.sum(mask_logit_final_bool, dim=1).int() output_feats_b_expand = context_feats_b1.expand( - count_mask.shape[0], context_feats_b1.shape[1], context_feats_b1.shape[2] + count_mask.shape[0], + context_feats_b1.shape[1], + context_feats_b1.shape[2], ) # n_queries, mask, f final_mask_features = torch.sum( (output_feats_b_expand * mask_logit_final_bool[:, :, None]), dim=1 ) # n_queries, f - final_mask_features = final_mask_features / (count_mask[:, None] + 1e-6) # n_queries, f + final_mask_features = final_mask_features / ( + count_mask[:, None] + 1e-6 + ) # n_queries, f final_mask_features[count_mask <= 1] = 0.0 final_mask_features_arr.append(final_mask_features) if no_cache: pre_enc_inds_mask.append(pre_enc_inds_b1) - final_mask_features_arr = torch.stack(final_mask_features_arr, dim=0) # batch, n_queries, f + final_mask_features_arr = torch.stack( + final_mask_features_arr, dim=0 + ) # batch, n_queries, f """ channel-wise correlate """ - channel_wise_tensor_sim = final_mask_features_arr * support_embeddings.unsqueeze(1).repeat( - 1, final_mask_features_arr.shape[1], 1 - ) - subtraction_tensor_sim = final_mask_features_arr - support_embeddings.unsqueeze(1).repeat( - 1, final_mask_features_arr.shape[1], 1 + channel_wise_tensor_sim = ( + final_mask_features_arr + * support_embeddings.unsqueeze(1).repeat( + 1, final_mask_features_arr.shape[1], 1 + ) ) + subtraction_tensor_sim = final_mask_features_arr - support_embeddings.unsqueeze( + 1 + ).repeat(1, final_mask_features_arr.shape[1], 1) aggregation_tensor_sim = torch.cat( - [channel_wise_tensor_sim, subtraction_tensor_sim, final_mask_features_arr], dim=2 + [channel_wise_tensor_sim, subtraction_tensor_sim, final_mask_features_arr], + dim=2, ) similarity_score = ( self.similarity_net(aggregation_tensor_sim.flatten(0, 1)) diff --git a/model/geoformer/geoformer_modules.py b/model/geoformer/geoformer_modules.py index f81d496..3d2bf3c 100644 --- a/model/geoformer/geoformer_modules.py +++ b/model/geoformer/geoformer_modules.py @@ -3,35 +3,54 @@ import torch import torch.nn as nn from model.transformer import TransformerEncoder -from spconv.modules import SparseModule +from spconv.pytorch.conv import SubMConv3d, SparseInverseConv3d, SparseConv3d +from spconv.pytorch.modules import SparseModule, SparseSequential +from spconv.pytorch.core import SparseConvTensor from util.warpper import BatchNorm1d, Conv1d import numpy as np + class ResidualBlock(SparseModule): def __init__(self, in_channels, out_channels, norm_fn, indice_key=None): super().__init__() if in_channels == out_channels: - self.i_branch = spconv.SparseSequential(nn.Identity()) + self.i_branch = SparseSequential(nn.Identity()) else: - self.i_branch = spconv.SparseSequential( - spconv.SubMConv3d(in_channels, out_channels, kernel_size=1, bias=False) + self.i_branch = SparseSequential( + SubMConv3d(in_channels, out_channels, kernel_size=1, bias=False) ) - self.conv_branch = spconv.SparseSequential( + self.conv_branch = SparseSequential( norm_fn(in_channels), nn.ReLU(), - spconv.SubMConv3d(in_channels, out_channels, kernel_size=3, padding=1, bias=False, indice_key=indice_key), + SubMConv3d( + in_channels, + out_channels, + kernel_size=3, + padding=1, + bias=False, + indice_key=indice_key, + ), norm_fn(out_channels), nn.ReLU(), - spconv.SubMConv3d(out_channels, out_channels, kernel_size=3, padding=1, bias=False, indice_key=indice_key), + SubMConv3d( + 
out_channels, + out_channels, + kernel_size=3, + padding=1, + bias=False, + indice_key=indice_key, + ), ) def forward(self, input): - identity = spconv.SparseConvTensor(input.features, input.indices, input.spatial_shape, input.batch_size) + identity = SparseConvTensor( + input.features, input.indices, input.spatial_shape, input.batch_size + ) output = self.conv_branch(input) - output.features += self.i_branch(identity).features - + new_features = output.features + self.i_branch(identity).features + output = output.replace_feature(new_features) return output @@ -39,10 +58,17 @@ class VGGBlock(SparseModule): def __init__(self, in_channels, out_channels, norm_fn, indice_key=None): super().__init__() - self.conv_layers = spconv.SparseSequential( + self.conv_layers = SparseSequential( norm_fn(in_channels), nn.ReLU(), - spconv.SubMConv3d(in_channels, out_channels, kernel_size=3, padding=1, bias=False, indice_key=indice_key), + SubMConv3d( + in_channels, + out_channels, + kernel_size=3, + padding=1, + bias=False, + indice_key=indice_key, + ), ) def forward(self, input): @@ -50,31 +76,46 @@ def forward(self, input): class UBlock(nn.Module): - def __init__(self, nPlanes, norm_fn, block_reps, block, use_backbone_transformer=False, indice_key_id=1): + def __init__( + self, + nPlanes, + norm_fn, + block_reps, + block, + use_backbone_transformer=False, + indice_key_id=1, + ): super().__init__() self.nPlanes = nPlanes blocks = { - "block{}".format(i): block(nPlanes[0], nPlanes[0], norm_fn, indice_key="subm{}".format(indice_key_id)) + "block{}".format(i): block( + nPlanes[0], + nPlanes[0], + norm_fn, + indice_key="subm{}".format(indice_key_id), + ) for i in range(block_reps) } blocks = OrderedDict(blocks) - self.blocks = spconv.SparseSequential(blocks) + self.blocks = SparseSequential(blocks) if len(nPlanes) <= 2 and use_backbone_transformer: d_model = 128 self.before_transformer_linear = nn.Linear(nPlanes[0], d_model) - self.transformer = TransformerEncoder(d_model=d_model, N=2, heads=4, d_ff=64) + self.transformer = TransformerEncoder( + d_model=d_model, N=2, heads=4, d_ff=64 + ) self.after_transformer_linear = nn.Linear(d_model, nPlanes[0]) else: self.before_transformer_linear = None self.transformer = None self.after_transformer_linear = None if len(nPlanes) > 1: - self.conv = spconv.SparseSequential( + self.conv = SparseSequential( norm_fn(nPlanes[0]), nn.ReLU(), - spconv.SparseConv3d( + SparseConv3d( nPlanes[0], nPlanes[1], kernel_size=2, @@ -85,46 +126,64 @@ def __init__(self, nPlanes, norm_fn, block_reps, block, use_backbone_transformer ) self.u = UBlock( - nPlanes[1:], norm_fn, block_reps, block, use_backbone_transformer, indice_key_id=indice_key_id + 1 + nPlanes[1:], + norm_fn, + block_reps, + block, + use_backbone_transformer, + indice_key_id=indice_key_id + 1, ) - self.deconv = spconv.SparseSequential( + self.deconv = SparseSequential( norm_fn(nPlanes[1]), nn.ReLU(), - spconv.SparseInverseConv3d( - nPlanes[1], nPlanes[0], kernel_size=2, bias=False, indice_key="spconv{}".format(indice_key_id) + SparseInverseConv3d( + nPlanes[1], + nPlanes[0], + kernel_size=2, + bias=False, + indice_key="spconv{}".format(indice_key_id), ), ) blocks_tail = {} for i in range(block_reps): blocks_tail["block{}".format(i)] = block( - nPlanes[0] * (2 - i), nPlanes[0], norm_fn, indice_key="subm{}".format(indice_key_id) + nPlanes[0] * (2 - i), + nPlanes[0], + norm_fn, + indice_key="subm{}".format(indice_key_id), ) blocks_tail = OrderedDict(blocks_tail) - self.blocks_tail = spconv.SparseSequential(blocks_tail) + 
self.blocks_tail = SparseSequential(blocks_tail) def forward(self, input): output = self.blocks(input) - identity = spconv.SparseConvTensor(output.features, output.indices, output.spatial_shape, output.batch_size) + identity = SparseConvTensor( + output.features, output.indices, output.spatial_shape, output.batch_size + ) if len(self.nPlanes) > 1: output_decoder = self.conv(output) output_decoder = self.u(output_decoder) output_decoder = self.deconv(output_decoder) - output.features = torch.cat((identity.features, output_decoder.features), dim=1) + output = output.replace_feature( + torch.cat((identity.features, output_decoder.features), dim=1) + ) output = self.blocks_tail(output) - + if self.before_transformer_linear: batch_ids = output.indices[:, 0] xyz = output.indices[:, 1:].float() feats = output.features before_params_feats = self.before_transformer_linear(feats) - feats = self.transformer(xyz=xyz, features=before_params_feats, batch_ids=batch_ids) + feats = self.transformer( + xyz=xyz, features=before_params_feats, batch_ids=batch_ids + ) feats = self.after_transformer_linear(feats) - output.features = feats + output = output.replace_feature(feats) return output @@ -139,7 +198,13 @@ def make_conv(in_channels, out_channels): groups = 1 conv = conv_func( - in_channels, out_channels, kernel_size=1, stride=1, padding=0, groups=groups, bias=(norm is None) + in_channels, + out_channels, + kernel_size=1, + stride=1, + padding=0, + groups=groups, + bias=(norm is None), ) nn.init.kaiming_uniform_(conv.weight, a=1) @@ -161,7 +226,6 @@ def make_conv(in_channels, out_channels): return make_conv - def random_downsample(batch_offsets, batch_size, n_subsample=30000): idxs_subsample = [] idxs_subsample_raw = [] @@ -170,14 +234,14 @@ def random_downsample(batch_offsets, batch_size, n_subsample=30000): num_points_b = (end - start).cpu() if n_subsample == -1 or n_subsample >= num_points_b: - new_inds = torch.arange(num_points_b, dtype=torch.long, device=batch_offsets.device) + new_inds = torch.arange( + num_points_b, dtype=torch.long, device=batch_offsets.device + ) else: - new_inds = ( - torch.tensor( - np.random.choice(num_points_b, n_subsample, replace=False), - dtype=torch.long, - device=batch_offsets.device, - ) + new_inds = torch.tensor( + np.random.choice(num_points_b, n_subsample, replace=False), + dtype=torch.long, + device=batch_offsets.device, ) idxs_subsample_raw.append(new_inds) idxs_subsample.append(new_inds + start) diff --git a/test_fs.py b/test_fs.py index 70b6c37..dd4c98e 100644 --- a/test_fs.py +++ b/test_fs.py @@ -31,7 +31,9 @@ def init(): def load_set_support(model, dataset): - set_support_name = cfg.type_support + str(cfg.cvfold) + "_" + str(cfg.k_shot) + "shot_10sets.pth" + set_support_name = ( + cfg.type_support + str(cfg.cvfold) + "_" + str(cfg.k_shot) + "shot_10sets.pth" + ) set_support_file = os.path.join("exp", cfg.file_support, set_support_name) # print(set_support_file) @@ -57,18 +59,31 @@ def load_set_support(model, dataset): for i in range(cfg.k_shot): support_tuple = list_scenes[i] - support_scene_name, support_instance_id = support_tuple[0], support_tuple[1] + support_scene_name, support_instance_id = ( + support_tuple[0], + support_tuple[1], + ) ( support_xyz_middle, support_xyz_scaled, support_rgb, support_label, support_instance_label, - ) = dataset.load_single(support_scene_name, aug=False, permutate=False, val=True, support=True) + ) = dataset.load_single( + support_scene_name, + aug=False, + permutate=False, + val=True, + support=True, + ) - support_mask = 
(support_instance_label == support_instance_id).astype(int) + support_mask = ( + support_instance_label == support_instance_id + ).astype(int) - support_batch_offsets = torch.tensor([0, support_xyz_middle.shape[0]], dtype=torch.int) + support_batch_offsets = torch.tensor( + [0, support_xyz_middle.shape[0]], dtype=torch.int + ) support_masks_offset = torch.tensor( [0, np.count_nonzero(support_mask)], dtype=torch.int ) # int (B+1) @@ -79,14 +94,22 @@ def load_set_support(model, dataset): ], 1, ) - support_locs_float = torch.from_numpy(support_xyz_middle).to(torch.float32) - support_feats = torch.from_numpy(support_rgb).to(torch.float32) # float (N, C) + support_locs_float = torch.from_numpy(support_xyz_middle).to( + torch.float32 + ) + support_feats = torch.from_numpy(support_rgb).to( + torch.float32 + ) # float (N, C) support_masks = torch.from_numpy(support_mask) - support_spatial_shape = np.clip((support_locs.max(0)[0][1:] + 1).numpy(), cfg.full_scale[0], None) + support_spatial_shape = np.clip( + (support_locs.max(0)[0][1:] + 1).numpy(), + cfg.full_scale[0], + None, + ) # voxelize - support_voxel_locs, support_p2v_map, support_v2p_map = pointgroup_ops.voxelization_idx( - support_locs, 1, dataset.mode + support_voxel_locs, support_p2v_map, support_v2p_map = ( + pointgroup_ops.voxelization_idx(support_locs, 1, dataset.mode) ) support_dict = { @@ -154,13 +177,17 @@ def do_test(model, dataset): if torch.is_tensor(query_dict[key]): query_dict[key] = query_dict[key].to(net_device) - for j, (label, support_dict) in enumerate(zip(active_label, list_support_dicts)): + for j, (label, support_dict) in enumerate( + zip(active_label, list_support_dicts) + ): for k in range(cfg.run_num): # NOTE number of runs remember = False if (j == 0 and k == 0) else True support_embeddings = None if cfg.fix_support: - support_embeddings = set_support_vectors[k][label].unsqueeze(0).to(net_device) + support_embeddings = ( + set_support_vectors[k][label].unsqueeze(0).to(net_device) + ) else: for key in support_dict: if torch.is_tensor(support_dict[key]): @@ -180,16 +207,22 @@ def do_test(model, dataset): continue benchmark_label = BENCHMARK_SEMANTIC_LABELS[label] - cluster_semantic = torch.ones((proposals_pred.shape[0])).cuda() * benchmark_label + cluster_semantic = ( + torch.ones((proposals_pred.shape[0])).cuda() * benchmark_label + ) clusters[k].append(proposals_pred) cluster_scores[k].append(scores_pred) cluster_semantic_id[k].append(cluster_semantic) # torch.cuda.empty_cache() - + print(f"clusters: {clusters}") + print(f"cluster_scores: {cluster_scores}") + print(f"cluster_semantic_id: {cluster_semantic_id}") test_scene_name_arr.append(test_scene_name) - gt_file_name = os.path.join(cfg.data_root, cfg.dataset, "val_gt", test_scene_name + ".txt") + gt_file_name = os.path.join( + cfg.data_root, cfg.dataset, "val_gt", test_scene_name + ".txt" + ) gt_file_arr.append(gt_file_name) for k in range(cfg.run_num): @@ -208,12 +241,14 @@ def do_test(model, dataset): clusters[k].float(), cluster_scores[k], cluster_semantic_id[k], - final_score_thresh=0.5 + final_score_thresh=0.5, ) clusters[k] = clusters[k][pick_idxs_cluster].cpu().numpy() cluster_scores[k] = cluster_scores[k][pick_idxs_cluster].cpu().numpy() - cluster_semantic_id[k] = cluster_semantic_id[k][pick_idxs_cluster].cpu().numpy() + cluster_semantic_id[k] = ( + cluster_semantic_id[k][pick_idxs_cluster].cpu().numpy() + ) nclusters[k] = clusters[k].shape[0] if cfg.eval: @@ -228,7 +263,9 @@ def do_test(model, dataset): logger.info( f"Test scene {i+1}/{num_test_scenes}: 
{test_scene_name} | Elapsed time: {int(overlap_time)}s | Remaining time: {int(overlap_time * float(num_test_scenes-(i+1))/(i+1))}s" ) - logger.info(f"Num points: {N} | Num instances of {cfg.run_num} runs: {nclusters}") + logger.info( + f"Num points: {N} | Num instances of {cfg.run_num} runs: {nclusters}" + ) # evaluation if cfg.eval: @@ -245,7 +282,9 @@ def do_test(model, dataset): test_scene_name = test_scene_name_arr[i] gt_ids = load_ids(gt_file_name) - gt2pred, pred2gt = eval.assign_instances_for_scan(test_scene_name, pred_info, gt_ids) + gt2pred, pred2gt = eval.assign_instances_for_scan( + test_scene_name, pred_info, gt_ids + ) matches[test_scene_name] = {} matches[test_scene_name]["gt"] = gt2pred matches[test_scene_name]["pred"] = pred2gt @@ -268,20 +307,38 @@ def do_test(model, dataset): model = GeoFormerFS() model = model.cuda() - logger.info("# parameters (model): {}".format(sum([x.nelement() for x in model.parameters()]))) + logger.info( + "# parameters (model): {}".format( + sum([x.nelement() for x in model.parameters()]) + ) + ) checkpoint_fn = cfg.resume if os.path.isfile(checkpoint_fn): logger.info("=> loading checkpoint '{}'".format(checkpoint_fn)) state = torch.load(checkpoint_fn) + # print(f"state is") + # for key, val in state.items(): + # print(f"{key}: {val.shape if hasattr(val, 'shape') else type(val)}") model_state_dict = model.state_dict() - loaded_state_dict = strip_prefix_if_present(state["state_dict"], prefix="module.") + # print(f"model state dict is ") + # for key, val in model_state_dict.items(): + # print(f"{key}: {val.shape if hasattr(val, 'shape') else type(val)}") + loaded_state_dict = strip_prefix_if_present( + state["state_dict"], prefix="module." + ) + # print(f"loaded_state_dict is") + # for key, val in loaded_state_dict.items(): + # print(f"{key}: {val.shape if hasattr(val, 'shape') else type(val)}") align_and_update_state_dicts(model_state_dict, loaded_state_dict) + # print(f"After alignment model state dict is: ") + # for key, val in model_state_dict.items(): + # print(f"{key}: {val.shape if hasattr(val, 'shape') else type(val)}") model.load_state_dict(model_state_dict) logger.info("=> loaded checkpoint '{}')".format(checkpoint_fn)) else: - raise RuntimeError + raise RuntimeError(f"No checkpoint found at {checkpoint_fn}") dataset = FSInstDataset(split_set="val")
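
Note (illustrative, not part of the patch): the recurring change in model/geoformer/geoformer_modules.py is the spconv 1.x -> 2.x feature-update pattern -- spconv 2.x SparseConvTensor features are no longer mutated in place, so residual and skip additions go through replace_feature(). A minimal sketch of that pattern is below; the helper name residual_add is hypothetical and only mirrors what ResidualBlock.forward now does.

    from spconv.pytorch.core import SparseConvTensor

    def residual_add(conv_branch, identity_branch, x):
        # Rebuild the identity tensor instead of aliasing x, as in ResidualBlock.forward.
        identity = SparseConvTensor(x.features, x.indices, x.spatial_shape, x.batch_size)
        out = conv_branch(x)
        # spconv 1.x: out.features += identity_branch(identity).features  (in-place, removed in 2.x)
        # spconv 2.x: replace_feature() returns a new SparseConvTensor carrying the summed features.
        return out.replace_feature(out.features + identity_branch(identity).features)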