Skip to content

Commit

Permalink
addtl format fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
daviddpruitt committed Aug 7, 2024
1 parent 30749a4 commit f5ad041
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 14 deletions.
2 changes: 1 addition & 1 deletion earth2grid/csrc/healpixpad/healpixpad.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ void HEALPixPad_bwd_fp64(int dimI, // batch size
double *dataIn_d,
double *dataOut_d,
cudaStream_t stream=0);

#ifdef __cplusplus
}
#endif
Expand Down
14 changes: 7 additions & 7 deletions earth2grid/csrc/healpixpad/healpixpad_cuda_bwd.cu
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
#include <stdlib.h>
#include <cuda_runtime.h>
#include "cudamacro.h"
#include "healpixpad.h"
#include "healpixpad.h"

#define THREADS 64

Expand Down Expand Up @@ -66,7 +66,7 @@ __global__ void HEALPixPadBck_bulk_k(const int padSize,
if (tid >= ((long long)dimI)*dimJ*dimK*dimL*dimM) {
return;
}

const long long sliceId = tid / (dimM*dimL);

const int i = (tid % (dimM*dimL)) / dimM;
Expand Down Expand Up @@ -747,9 +747,9 @@ void HEALPixPad_bwd_fp32(int padSize,
float *dataIn_d,
float *dataOut_d,
cudaStream_t stream) {

HEALPixPadBck<float>(padSize, dimI, dimJ, dimK, dimL, dimM, dataIn_d, dataOut_d, stream);

return;
}

Expand All @@ -762,9 +762,9 @@ void HEALPixPad_bwd_fp64(int padSize,
double *dataIn_d,
double *dataOut_d,
cudaStream_t stream) {

HEALPixPadBck<double>(padSize, dimI, dimJ, dimK, dimL, dimM, dataIn_d, dataOut_d, stream);

return;
}

Expand Down Expand Up @@ -834,4 +834,4 @@ std::vector<torch::Tensor> healpixpad_cuda_backward(
}

return {goutput};
}
}
7 changes: 3 additions & 4 deletions earth2grid/csrc/healpixpad/healpixpad_cuda_fwd.cu
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
#define DIV_UP(a,b) (((a)+((b)-1))/(b))

// All coordinates are w.r.t. a face[dimK][dimL][dimM]:
//
//
// ^ k-axis
// /
// *---------*
Expand Down Expand Up @@ -669,7 +669,7 @@ void HEALPixPad_fp32(
cudaStream_t stream) {

HEALPixPadFwd<float>(padSize, dimI, dimJ, dimK, dimL, dimM, dataIn_d, dataOut_d, stream);

return;
}

Expand Down Expand Up @@ -706,7 +706,7 @@ std::vector<torch::Tensor> healpixpad_cuda_forward(

// get cuda stream:
cudaStream_t my_stream = c10::cuda::getCurrentCUDAStream(input.device().index()).stream();

switch (input.scalar_type()) {
case torch::ScalarType::Double:
HEALPixPadFwd<double>(
Expand Down Expand Up @@ -760,4 +760,3 @@ std::vector<torch::Tensor> healpixpad_cuda_forward(

return {output};
}

8 changes: 6 additions & 2 deletions earth2grid/healpixpad.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ class HEALPixPadFunction(torch.autograd.Function):
"""
A torch autograd class that pads a healpixpad xy tensor
"""

@staticmethod
def forward(ctx, input, pad):
"""
Expand All @@ -43,7 +44,9 @@ def forward(ctx, input, pad):
"""
ctx.pad = pad
if len(input.shape) != 5:
raise ValueError(f"Input tensor must be have 4 dimensions (B, F, C, H, W), got {len(input.shape)} dimensions instead")
raise ValueError(
f"Input tensor must be have 4 dimensions (B, F, C, H, W), got {len(input.shape)} dimensions instead"
)
# make contiguous
input = input.contiguous()
out = healpixpad_cuda.forward(input, pad)[0]
Expand All @@ -56,6 +59,7 @@ def backward(ctx, grad):
out = healpixpad_cuda.backward(grad, pad)[0]
return out, None


class HEALPixPad(torch.nn.Module):
"""
A torch module that handles padding of healpixpad xy tensors
Expand All @@ -65,6 +69,7 @@ class HEALPixPad(torch.nn.Module):
padding: int
The amount to pad the tensors
"""

def __init__(self, padding: int):
    # Store the pad width for use by forward(); padding is the number of
    # rings of halo cells added around each HEALPix face per call.
    # NOTE(review): no validation here — presumably padding must be a
    # positive int (the CUDA kernels index with it); confirm upstream.
    super(HEALPixPad, self).__init__()
    self.padding = padding
Expand All @@ -84,4 +89,3 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
torch.tensor: The padded tensor
"""
return HEALPixPadFunction.apply(input, self.padding)

0 comments on commit f5ad041

Please sign in to comment.