Skip to content

Commit

Permalink
fix cpu default
Browse files Browse the repository at this point in the history
  • Loading branch information
daubners committed Jan 21, 2025
1 parent 96b6e94 commit b9fa69e
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions taufactor/metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ def crop_area_of_interest_numpy(array, labels):
max(min_idx[2]-3, 0):min(max_idx[2]+4, array.shape[2])]
return sub_array

def gaussian_kernel_3d_torch(size=3, sigma=1.0, device=torch.device('cuda')):
def gaussian_kernel_3d_torch(size=3, sigma=1.0, device):
"""Creates a 3D Gaussian kernel using PyTorch"""
ax = torch.linspace(-(size // 2), size // 2, size)
xx, yy, zz = torch.meshgrid(ax, ax, ax, indexing="ij")
Expand Down Expand Up @@ -90,6 +90,10 @@ def specific_surface_area(img, spacing=(1,1,1), phases={}, method='gradient', de
[dx,dy,dz] = spacing
surface_areas = {}

if torch.device(device).type.startswith('cuda') and not torch.cuda.is_available():
device = torch.device('cpu')
warnings.warn("CUDA not available, defaulting device to cpu.")

if (method == 'gradient') | (method == 'face_counting'):
if type(img) is not type(torch.tensor(1)):
tensor = torch.tensor(img)
Expand All @@ -102,7 +106,7 @@ def specific_surface_area(img, spacing=(1,1,1), phases={}, method='gradient', de
labels = torch.unique(tensor)
labels = labels.int()
phases = {str(label.item()): label.item() for label in labels}
gaussian = gaussian_kernel_3d_torch(device=device)
gaussian = gaussian_kernel_3d_torch(device)

volume = torch.numel(tensor)
for name, label in phases.items():
Expand Down

0 comments on commit b9fa69e

Please sign in to comment.