Skip to content

Commit

Permalink
Fix use_torch flag
Browse files Browse the repository at this point in the history
  • Loading branch information
1pha committed Apr 29, 2024
1 parent 580ff21 commit b175a9f
Showing 1 changed file with 9 additions and 5 deletions.
14 changes: 9 additions & 5 deletions sage/xai/atlas_overlap.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ def flatten_to_dict(arr: np.ndarray,
elif isinstance(arr, torch.Tensor):
# Monai `MetaTensor` would also get caught here.
mask_ = arr.clone()
use_torch = True

if use_abs:
mask_ = np.abs(mask_)
Expand All @@ -114,7 +115,7 @@ def flatten_to_dict(arr: np.ndarray,
# Norm
num_nonzero = torch.sum(roi_mask)
roi_val = torch.nansum(roi_mask * mask_)
xai_dict[label] = (roi_val / num_nonzero).cpu().numpy()
xai_dict[label] = float((roi_val / num_nonzero).cpu().numpy())
mask_ = mask_.cpu().numpy()

else:
Expand All @@ -132,6 +133,7 @@ def project_to_atlas(atlas: Bunch,
title: str = "",
use_abs: bool = True,
vmin: float = None, vmax: float = None,
threshold: float = 0.25,
root_dir: Path | str = None,
verbose: bool = False) -> np.ndarray:
agg_saliency = np.zeros_like(atlas.array)
Expand All @@ -145,12 +147,15 @@ def project_to_atlas(atlas: Bunch,

save = root_dir / "proj_glass.png" if root_dir is not None else None
nilp_.plot_glass_brain(arr=agg_saliency,
target_affine=atlas.nii.affine, title=title,
target_affine=atlas.nii.affine, title=title, cmap=nilp.cm.bwr,
vmin=vmin, vmax=vmax, colorbar=True, plot_abs=use_abs, save=save)

save = root_dir / "proj_mosaic.png" if root_dir is not None else None
nilp_.plot_overlay(arr=agg_saliency, target_affine=atlas.nii.affine,
display_mode="mosaic", threshold=0.25, title=title, colorbar=True,
if (vmin is None) or (vmax is None):
_max = np.abs(agg_saliency).max()
vmin, vmax = -_max, _max
nilp_.plot_overlay(arr=agg_saliency, target_affine=atlas.nii.affine, vmin=vmin, vmax=vmax,
display_mode="mosaic", threshold=threshold, title=title, colorbar=True,
cmap=nilp.cm.red_transparent if use_abs else nilp.cm.bwr, save=save)
return agg_saliency

Expand All @@ -173,7 +178,6 @@ def calculate_overlaps(arr: np.ndarray,

xai_dict, mask_ = flatten_to_dict(arr=arr, atlas=atlas,
use_torch=use_torch, device=device, use_abs=use_abs)

if plot_raw_sal:
_title = f"{title}_RAW Mask"
save = root_dir / "raw_mask.png" if root_dir is not None else root_dir
Expand Down

0 comments on commit b175a9f

Please sign in to comment.