From 8c9dbbf63d4138a3c19cb708ed1bae571a3db2ff Mon Sep 17 00:00:00 2001
From: Nicolas Maier <20108526+MaierN@users.noreply.github.com>
Date: Fri, 24 Nov 2023 01:18:04 +0100
Subject: [PATCH 1/3] fix: remove obsolete ml_nms
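
The CUDA ml_nms kernel appears to be dead code: the Python-side ml_nms in
adet/layers/ml_nms.py is built on detectron2's batched_nms and does not call
into this extension, and the .cu file depends on the THC API (THCCeilDiv,
THCudaMalloc, THCudaFree), which recent PyTorch releases no longer ship, so
it breaks the build. For context, a simplified sketch of the Python-side
replacement (adapted, not verbatim; see adet/layers/ml_nms.py):

    from detectron2.layers import batched_nms

    def ml_nms(boxlist, nms_thresh, max_proposals=-1):
        # Multi-label NMS: batched_nms offsets boxes by class label,
        # so suppression only happens within each class.
        if nms_thresh <= 0:
            return boxlist
        keep = batched_nms(boxlist.pred_boxes.tensor, boxlist.scores,
                           boxlist.pred_classes, nms_thresh)
        if max_proposals > 0:
            keep = keep[:max_proposals]
        return boxlist[keep]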
---
 adet/layers/csrc/ml_nms/ml_nms.cu | 139 ------------------------------
 adet/layers/csrc/ml_nms/ml_nms.h  |  32 -------
 adet/layers/csrc/vision.cpp       |   3 +-
 3 files changed, 1 insertion(+), 173 deletions(-)
 delete mode 100644 adet/layers/csrc/ml_nms/ml_nms.cu
 delete mode 100644 adet/layers/csrc/ml_nms/ml_nms.h

diff --git a/adet/layers/csrc/ml_nms/ml_nms.cu b/adet/layers/csrc/ml_nms/ml_nms.cu
deleted file mode 100644
index f1c1a4206..000000000
--- a/adet/layers/csrc/ml_nms/ml_nms.cu
+++ /dev/null
@@ -1,139 +0,0 @@
-// Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
-#include <ATen/ATen.h>
-#include <ATen/cuda/CUDAContext.h>
-#include <THC/THC.h>
-#include <THC/THCDeviceUtils.cuh>
-
-#include <vector>
-#include <iostream>
-
-int const threadsPerBlock = sizeof(unsigned long long) * 8;
-
-__device__ inline float devIoU(float const * const a, float const * const b) {
-  if (a[5] != b[5]) {
-    return 0.0;
-  }
-  float left = max(a[0], b[0]), right = min(a[2], b[2]);
-  float top = max(a[1], b[1]), bottom = min(a[3], b[3]);
-  float width = max(right - left + 1, 0.f), height = max(bottom - top + 1, 0.f);
-  float interS = width * height;
-  float Sa = (a[2] - a[0] + 1) * (a[3] - a[1] + 1);
-  float Sb = (b[2] - b[0] + 1) * (b[3] - b[1] + 1);
-  return interS / (Sa + Sb - interS);
-}
-
-__global__ void ml_nms_kernel(const int n_boxes, const float nms_overlap_thresh,
-                              const float *dev_boxes, unsigned long long *dev_mask) {
-  const int row_start = blockIdx.y;
-  const int col_start = blockIdx.x;
-
-  // if (row_start > col_start) return;
-
-  const int row_size =
-        min(n_boxes - row_start * threadsPerBlock, threadsPerBlock);
-  const int col_size =
-        min(n_boxes - col_start * threadsPerBlock, threadsPerBlock);
-
-  __shared__ float block_boxes[threadsPerBlock * 6];
-  if (threadIdx.x < col_size) {
-    block_boxes[threadIdx.x * 6 + 0] =
-        dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 6 + 0];
-    block_boxes[threadIdx.x * 6 + 1] =
-        dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 6 + 1];
-    block_boxes[threadIdx.x * 6 + 2] =
-        dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 6 + 2];
-    block_boxes[threadIdx.x * 6 + 3] =
-        dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 6 + 3];
-    block_boxes[threadIdx.x * 6 + 4] =
-        dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 6 + 4];
-    block_boxes[threadIdx.x * 6 + 5] =
-        dev_boxes[(threadsPerBlock * col_start + threadIdx.x) * 6 + 5];
-  }
-  __syncthreads();
-
-  if (threadIdx.x < row_size) {
-    const int cur_box_idx = threadsPerBlock * row_start + threadIdx.x;
-    const float *cur_box = dev_boxes + cur_box_idx * 6;
-    int i = 0;
-    unsigned long long t = 0;
-    int start = 0;
-    if (row_start == col_start) {
-      start = threadIdx.x + 1;
-    }
-    for (i = start; i < col_size; i++) {
-      if (devIoU(cur_box, block_boxes + i * 6) > nms_overlap_thresh) {
-        t |= 1ULL << i;
-      }
-    }
-    const int col_blocks = THCCeilDiv(n_boxes, threadsPerBlock);
-    dev_mask[cur_box_idx * col_blocks + col_start] = t;
-  }
-}
-
-namespace adet {
-
-// boxes is a N x 6 tensor
-at::Tensor ml_nms_cuda(const at::Tensor boxes, const float nms_overlap_thresh) {
-  using scalar_t = float;
-  AT_ASSERTM(boxes.type().is_cuda(), "boxes must be a CUDA tensor");
-  auto scores = boxes.select(1, 4);
-  auto order_t = std::get<1>(scores.sort(0, /* descending=*/true));
-  auto boxes_sorted = boxes.index_select(0, order_t);
-
-  int boxes_num = boxes.size(0);
-
-  const int col_blocks = THCCeilDiv(boxes_num, threadsPerBlock);
-
-  scalar_t* boxes_dev = boxes_sorted.data<scalar_t>();
-
-  THCState *state = at::globalContext().lazyInitCUDA(); // TODO replace with getTHCState
-
-  unsigned long long* mask_dev = NULL;
-  //THCudaCheck(THCudaMalloc(state, (void**) &mask_dev,
-  //                      boxes_num * col_blocks * sizeof(unsigned long long)));
-
-  mask_dev = (unsigned long long*) THCudaMalloc(state, boxes_num * col_blocks * sizeof(unsigned long long));
-
-  dim3 blocks(THCCeilDiv(boxes_num, threadsPerBlock),
-              THCCeilDiv(boxes_num, threadsPerBlock));
-  dim3 threads(threadsPerBlock);
-  ml_nms_kernel<<<blocks, threads>>>(boxes_num,
-                                     nms_overlap_thresh,
-                                     boxes_dev,
-                                     mask_dev);
-
-  std::vector<unsigned long long> mask_host(boxes_num * col_blocks);
-  THCudaCheck(cudaMemcpy(&mask_host[0],
-                         mask_dev,
-                         sizeof(unsigned long long) * boxes_num * col_blocks,
-                         cudaMemcpyDeviceToHost));
-
-  std::vector<unsigned long long> remv(col_blocks);
-  memset(&remv[0], 0, sizeof(unsigned long long) * col_blocks);
-
-  at::Tensor keep = at::empty({boxes_num}, boxes.options().dtype(at::kLong).device(at::kCPU));
-  int64_t* keep_out = keep.data<int64_t>();
-
-  int num_to_keep = 0;
-  for (int i = 0; i < boxes_num; i++) {
-    int nblock = i / threadsPerBlock;
-    int inblock = i % threadsPerBlock;
-
-    if (!(remv[nblock] & (1ULL << inblock))) {
-      keep_out[num_to_keep++] = i;
-      unsigned long long *p = &mask_host[0] + i * col_blocks;
-      for (int j = nblock; j < col_blocks; j++) {
-        remv[j] |= p[j];
-      }
-    }
-  }
-
-  THCudaFree(state, mask_dev);
-  // TODO improve this part
-  return std::get<0>(order_t.index({
-                       keep.narrow(/*dim=*/0, /*start=*/0, /*length=*/num_to_keep).to(
-                         order_t.device(), keep.scalar_type())
-                     }).sort(0, false));
-}
-
-} // namespace adet
\ No newline at end of file
diff --git a/adet/layers/csrc/ml_nms/ml_nms.h b/adet/layers/csrc/ml_nms/ml_nms.h
deleted file mode 100644
index f33851a18..000000000
--- a/adet/layers/csrc/ml_nms/ml_nms.h
+++ /dev/null
@@ -1,32 +0,0 @@
-#pragma once
-#include <torch/extension.h>
-
-namespace adet {
-
-
-#ifdef WITH_CUDA
-at::Tensor ml_nms_cuda(
-    const at::Tensor dets,
-    const float threshold);
-#endif
-
-at::Tensor ml_nms(const at::Tensor& dets,
-                  const at::Tensor& scores,
-                  const at::Tensor& labels,
-                  const float threshold) {
-
-  if (dets.type().is_cuda()) {
-#ifdef WITH_CUDA
-    // TODO raise error if not compiled with CUDA
-    if (dets.numel() == 0)
-      return at::empty({0}, dets.options().dtype(at::kLong).device(at::kCPU));
-    auto b = at::cat({dets, scores.unsqueeze(1), labels.unsqueeze(1)}, 1);
-    return ml_nms_cuda(b, threshold);
-#else
-    AT_ERROR("Not compiled with GPU support");
-#endif
-  }
-  AT_ERROR("CPU version not implemented");
-}
-
-} // namespace adet
diff --git a/adet/layers/csrc/vision.cpp b/adet/layers/csrc/vision.cpp
index d780a95e3..623b9c046 100644
--- a/adet/layers/csrc/vision.cpp
+++ b/adet/layers/csrc/vision.cpp
@@ -1,6 +1,6 @@
 // Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
-#include "ml_nms/ml_nms.h"
+#include <torch/extension.h>
 #include "DefROIAlign/DefROIAlign.h"
 #include "BezierAlign/BezierAlign.h"
@@ -53,7 +53,6 @@ std::string get_compiler_version() {
 }
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-  m.def("ml_nms", &ml_nms, "Multi-Label NMS");
   m.def("def_roi_align_forward", &DefROIAlign_forward, "def_roi_align_forward");
   m.def("def_roi_align_backward", &DefROIAlign_backward, "def_roi_align_backward");
   m.def("bezier_align_forward", &BezierAlign_forward, "bezier_align_forward");

From 86d1df92b9a061305154335e3df799dbdd442f62 Mon Sep 17 00:00:00 2001
From: Nicolas Maier <20108526+MaierN@users.noreply.github.com>
Date: Sun, 16 Jun 2024 20:01:02 +0200
Subject: [PATCH 2/3] fix torch warning related to future change of default
 behavior of 'meshgrid' method
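
Calling torch.meshgrid without an explicit `indexing` argument makes PyTorch
(since 1.10) emit a UserWarning: the historical default is "ij" (matrix
indexing), and upstream plans to switch the default to "xy" to match NumPy.
Passing indexing="ij" keeps the current behavior explicit and silences the
warning. A minimal illustration of the two modes (not part of this patch):

    import torch

    ys = torch.arange(2)  # 2 rows
    xs = torch.arange(3)  # 3 columns

    # "ij" is the historical default: grids have shape (len(ys), len(xs))
    gy, gx = torch.meshgrid(ys, xs, indexing="ij")
    assert gy.shape == (2, 3) and gx.shape == (2, 3)

    # "xy" (NumPy-style) swaps the first two dimensions instead
    gy2, gx2 = torch.meshgrid(ys, xs, indexing="xy")
    assert gy2.shape == (3, 2) and gx2.shape == (3, 2)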
---
 adet/modeling/MEInst/MEInst.py       | 2 +-
 adet/modeling/batext/batext.py       | 2 +-
 adet/modeling/fcpose/fcpose_head.py  | 2 +-
 adet/modeling/roi_heads/text_head.py | 2 +-
 adet/modeling/solov2/solov2.py       | 4 ++--
 adet/utils/comm.py                   | 2 +-
 6 files changed, 7 insertions(+), 7 deletions(-)

diff --git a/adet/modeling/MEInst/MEInst.py b/adet/modeling/MEInst/MEInst.py
index d5e18221d..6f477ef0b 100644
--- a/adet/modeling/MEInst/MEInst.py
+++ b/adet/modeling/MEInst/MEInst.py
@@ -169,7 +169,7 @@ def compute_locations_per_level(h, w, stride, device):
         0, h * stride, step=stride,
         dtype=torch.float32, device=device
     )
-    shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
+    shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x, indexing="ij")
     shift_x = shift_x.reshape(-1)
     shift_y = shift_y.reshape(-1)
     locations = torch.stack((shift_x, shift_y), dim=1) + stride // 2
diff --git a/adet/modeling/batext/batext.py b/adet/modeling/batext/batext.py
index 974bec5cd..4236af7dd 100644
--- a/adet/modeling/batext/batext.py
+++ b/adet/modeling/batext/batext.py
@@ -155,7 +155,7 @@ def compute_locations_per_level(self, h, w, stride, device):
         0, h * stride, step=stride,
         dtype=torch.float32, device=device
     )
-    shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
+    shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x, indexing="ij")
     shift_x = shift_x.reshape(-1)
     shift_y = shift_y.reshape(-1)
     locations = torch.stack((shift_x, shift_y), dim=1) + stride // 2
diff --git a/adet/modeling/fcpose/fcpose_head.py b/adet/modeling/fcpose/fcpose_head.py
index 63820703e..a1516bec0 100644
--- a/adet/modeling/fcpose/fcpose_head.py
+++ b/adet/modeling/fcpose/fcpose_head.py
@@ -30,7 +30,7 @@ def compute_locations_per_level(h, w, stride, device):
         0, h * stride, step=stride,
         dtype=torch.float32, device=device
     )
-    shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
+    shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x, indexing="ij")
     shift_x = shift_x.reshape(-1)
     shift_y = shift_y.reshape(-1)
     locations = torch.stack((shift_x, shift_y), dim=1) + stride // 2
diff --git a/adet/modeling/roi_heads/text_head.py b/adet/modeling/roi_heads/text_head.py
index b40d1703e..5b63e5f7b 100644
--- a/adet/modeling/roi_heads/text_head.py
+++ b/adet/modeling/roi_heads/text_head.py
@@ -85,7 +85,7 @@ def __init__(self, cfg):
     def forward(self, features):
         x_range = torch.linspace(-1, 1, features.shape[-1], device=features.device)
         y_range = torch.linspace(-1, 1, features.shape[-2], device=features.device)
-        y, x = torch.meshgrid(y_range, x_range)
+        y, x = torch.meshgrid(y_range, x_range, indexing="ij")
         y = y.expand([features.shape[0], 1, -1, -1])
         x = x.expand([features.shape[0], 1, -1, -1])
         coord_feat = torch.cat([x, y], 1)
diff --git a/adet/modeling/solov2/solov2.py b/adet/modeling/solov2/solov2.py
index 5fd4a50f7..4a938902a 100644
--- a/adet/modeling/solov2/solov2.py
+++ b/adet/modeling/solov2/solov2.py
@@ -606,7 +606,7 @@ def forward(self, features):
         # concat coord
         x_range = torch.linspace(-1, 1, ins_kernel_feat.shape[-1], device=ins_kernel_feat.device)
         y_range = torch.linspace(-1, 1, ins_kernel_feat.shape[-2], device=ins_kernel_feat.device)
-        y, x = torch.meshgrid(y_range, x_range)
+        y, x = torch.meshgrid(y_range, x_range, indexing="ij")
         y = y.expand([ins_kernel_feat.shape[0], 1, -1, -1])
         x = x.expand([ins_kernel_feat.shape[0], 1, -1, -1])
         coord_feat = torch.cat([x, y], 1)
@@ -732,7 +732,7 @@ def forward(self, features):
             if i == 3:  # add for coord.
                 x_range = torch.linspace(-1, 1, mask_feat.shape[-1], device=mask_feat.device)
                 y_range = torch.linspace(-1, 1, mask_feat.shape[-2], device=mask_feat.device)
-                y, x = torch.meshgrid(y_range, x_range)
+                y, x = torch.meshgrid(y_range, x_range, indexing="ij")
                 y = y.expand([mask_feat.shape[0], 1, -1, -1])
                 x = x.expand([mask_feat.shape[0], 1, -1, -1])
                 coord_feat = torch.cat([x, y], 1)
diff --git a/adet/utils/comm.py b/adet/utils/comm.py
index 78f2f32dd..43c2511f0 100644
--- a/adet/utils/comm.py
+++ b/adet/utils/comm.py
@@ -54,7 +54,7 @@ def compute_locations(h, w, stride, device):
         0, h * stride, step=stride,
         dtype=torch.float32, device=device
     )
-    shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x)
+    shift_y, shift_x = torch.meshgrid(shifts_y, shifts_x, indexing="ij")
     shift_x = shift_x.reshape(-1)
     shift_y = shift_y.reshape(-1)
     locations = torch.stack((shift_x, shift_y), dim=1) + stride // 2

From ff6b7ff9aba377c2cd09d0f1c21195f0500af62a Mon Sep 17 00:00:00 2001
From: Nicolas Maier <20108526+MaierN@users.noreply.github.com>
Date: Sun, 16 Jun 2024 20:06:29 +0200
Subject: [PATCH 3/3] fix torch warning related to class decorator
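
Applying @torch.no_grad() to a class makes PyTorch warn that decorating
classes is deprecated and suggest decorating the `__init__` method directly
instead; decorating the class only ever affected construction, so the move
preserves behavior. The pattern in isolation (an illustrative sketch, not
code from this repository):

    import torch
    from torch import nn

    class Encoder(nn.Module):
        @torch.no_grad()  # decorate the method, not the class
        def __init__(self, w):
            super().__init__()
            # Even if w requires grad, ops here are not tracked by autograd.
            self.proj = w * 2

    w = torch.randn(8, 8, requires_grad=True)
    enc = Encoder(w)
    assert not enc.proj.requires_grad  # no_grad was active during __init__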
---
 adet/modeling/MEInst/MaskEncoding.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/adet/modeling/MEInst/MaskEncoding.py b/adet/modeling/MEInst/MaskEncoding.py
index ae5c71862..882d375af 100755
--- a/adet/modeling/MEInst/MaskEncoding.py
+++ b/adet/modeling/MEInst/MaskEncoding.py
@@ -6,7 +6,6 @@
 VALUE_MIN = 0.01
 
 
-@torch.no_grad()
 class PCAMaskEncoding(nn.Module):
     """
     To do the mask encoding of PCA.
@@ -28,6 +27,7 @@ class PCAMaskEncoding(nn.Module):
         making data respect some hard-wired assumptions.
         sigmoid: (bool) whether to apply inverse sigmoid before transform.
     """
+    @torch.no_grad()
     def __init__(self, cfg):
         super().__init__()
         self.cfg = cfg