Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ENH(psoct): add option for output data type #29

Merged
merged 13 commits into from
Nov 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 9 additions & 2 deletions linc_convert/modalities/psoct/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,17 +203,24 @@ def generate_pyramid(
slice(i * max_load, min((i + 1) * max_load, n))
for i, n in zip(chunk_index, prev_shape)
]
fullshape = omz[str(level - 1)].shape
dat = omz[str(level - 1)][tuple(slicer)]

# Discard the last voxel along odd dimensions
crop = [0 if x == 1 else x % 2 for x in dat.shape[-ndim:]]
crop = [
0 if y == 1 else x % 2 for x, y in zip(dat.shape[-ndim:], fullshape)
]
# Don't crop the axis not down-sampling
# cannot do if not no_pyramid_axis since it could be 0
if no_pyramid_axis is not None:
crop[no_pyramid_axis] = 0
slcr = [slice(-1) if x else slice(None) for x in crop]
dat = dat[tuple([Ellipsis, *slcr])]

if any(n == 0 for n in dat.shape):
# last strip had a single voxel, nothing to do
continue

patch_shape = dat.shape[-ndim:]

# Reshape into patches of shape 2x2x2
Expand All @@ -234,7 +241,7 @@ def generate_pyramid(
# -> flatten patches
smaller_shape = [max(n // 2, 1) for n in patch_shape]
if no_pyramid_axis is not None:
smaller_shape[2 * no_pyramid_axis] = patch_shape[no_pyramid_axis]
smaller_shape[no_pyramid_axis] = patch_shape[no_pyramid_axis]

dat = dat.reshape(batch + smaller_shape + [-1])

Expand Down
222 changes: 150 additions & 72 deletions linc_convert/modalities/psoct/multi_slice.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,9 @@
import json
import math
import os
from contextlib import contextmanager
from functools import wraps
from itertools import product
from typing import Any, Callable, Optional
from typing import Callable, Mapping, Optional
from warnings import warn

import cyclopts
Expand All @@ -38,54 +37,126 @@


def _automap(func: Callable) -> Callable:
"""Decorator to automatically map the array in the mat file.""" # noqa: D401
"""Automatically maps the array in the mat file."""

@wraps(func)
def wrapper(inp: str, out: str = None, **kwargs: dict) -> Any: # noqa: ANN401
def wrapper(inp: list[str], out: str = None, **kwargs: dict) -> callable:
if out is None:
out = os.path.splitext(inp[0])[0]
out += ".nii.zarr" if kwargs.get("nii", False) else ".ome.zarr"
kwargs["nii"] = kwargs.get("nii", False) or out.endswith(".nii.zarr")
with _mapmat(inp, kwargs.get("key", None)) as dat:
return func(dat, out, **kwargs)
dat = _mapmat(inp, kwargs.get("key", None))
return func(dat, out, **kwargs)

return wrapper


@contextmanager
def _mapmat(fnames: list[str], key: str = None) -> None:
"""Load or memory-map an array stored in a .mat file."""
loaded_data = []

for fname in fnames:
try:
# "New" .mat file
f = h5py.File(fname, "r")
except Exception:
# "Old" .mat file
f = loadmat(fname)

class _ArrayWrapper:
def _get_key(self, f: Mapping) -> str:
key = self.key
if key is None:
if not len(f.keys()):
raise Exception(f"{fname} is empty")
key = list(f.keys())[0]
raise Exception(f"{self.file} is empty")
for key in f.keys():
if key[:1] != "_":
break
if len(f.keys()) > 1:
warn(
f"More than one key in .mat file {fname}, "
f"More than one key in .mat file {self.file}, "
f'arbitrarily loading "{key}"'
)

if key not in f.keys():
raise Exception(f"Key {key} not found in file {fname}")
raise Exception(f"Key {key} not found in file {self.file}")

return key


class _H5ArrayWrapper(_ArrayWrapper):
    """Lazy view over an array stored in an HDF5 ("new", v7.3) .mat file.

    Keeps the ``h5py.File`` handle open so the dataset can be sliced on
    demand without reading the whole volume into memory.
    """

    def __init__(self, file: h5py.File, key: str | None) -> None:
        self.file = file
        self.key = key
        # Resolve the dataset now; slicing stays lazy through h5py.
        self.array = file.get(self._get_key(self.file))

    def __del__(self) -> None:
        # Best-effort cleanup; `file` may already be None after load().
        if hasattr(self.file, "close"):
            self.file.close()

    def load(self) -> np.ndarray:
        """Read the full array into memory and close the file handle."""
        self.array = self.array[...]
        if hasattr(self.file, "close"):
            self.file.close()
        self.file = None
        return self.array

    @property
    def shape(self) -> tuple[int, ...]:
        return self.array.shape

    @property
    def dtype(self) -> np.dtype:
        return self.array.dtype

    def __len__(self) -> int:
        return len(self.array)

    def __getitem__(self, index: object) -> np.ndarray:
        return self.array[index]


if len(fnames) == 1:
yield f.get(key)
if hasattr(f, "close"):
f.close()
break
loaded_data.append(f.get(key))
yield loaded_data
# yield np.stack(loaded_data, axis=-1)
class _MatArrayWrapper(_ArrayWrapper):
    """Deferred loader for arrays in "old" (pre-7.3) .mat files.

    ``scipy.io.loadmat`` offers no memory-mapped access here, so the
    file is read in full on first use and cached in ``self.array``.
    """

    def __init__(self, file: str, key: str | None) -> None:
        self.file = file  # path to the .mat file; None once loaded
        self.key = key
        self.array = None  # populated lazily by load()

    def __del__(self) -> None:
        # `file` is a path string (or None); the hasattr guard makes
        # this a no-op, kept to mirror _H5ArrayWrapper's cleanup hook.
        if hasattr(self.file, "close"):
            self.file.close()

    def load(self) -> np.ndarray:
        """Read the whole array into memory (idempotent)."""
        # Guard so repeated calls don't attempt loadmat(None) after
        # the path has been cleared by a previous load.
        if self.array is None:
            f = loadmat(self.file)
            self.array = f.get(self._get_key(f))
            self.file = None
        return self.array

    @property
    def shape(self) -> tuple[int, ...]:
        return self.load().shape

    @property
    def dtype(self) -> np.dtype:
        return self.load().dtype

    def __len__(self) -> int:
        return len(self.load())

    def __getitem__(self, index: object) -> np.ndarray:
        return self.load()[index]


def _mapmat(fnames: list[str], key: str = None) -> list[_ArrayWrapper]:
    """Open each .mat file and wrap its array for lazy access.

    Parameters
    ----------
    fnames
        Paths to the .mat files, one per slice.
    key
        Name of the variable to read from each file. If None, the
        first non-private key is used (see ``_ArrayWrapper._get_key``).

    Returns
    -------
    list[_ArrayWrapper]
        One wrapper per input file: HDF5-backed for "new" (v7.3) .mat
        files, scipy-backed otherwise.
    """

    def make_wrapper(fname: str) -> _ArrayWrapper:
        # v7.3 .mat files are HDF5 containers; h5py.File raises on the
        # older binary format, in which case we fall back to scipy.
        try:
            return _H5ArrayWrapper(h5py.File(fname, "r"), key)
        except Exception:
            return _MatArrayWrapper(fname, key)

    return [make_wrapper(fname) for fname in fnames]


@multi_slice.default
Expand All @@ -105,12 +176,17 @@ def convert(
nii: bool = False,
orientation: str = "RAS",
center: bool = True,
dtype: str | None = None,
) -> None:
"""
Matlab to OME-Zarr.

Convert OCT volumes in raw matlab files
into a pyramidal OME-ZARR (or NIfTI-Zarr) hierarchy.
Convert OCT volumes in raw matlab files into a pyramidal
OME-ZARR (or NIfTI-Zarr) hierarchy.

This command assumes that each slice in a volume is stored in a
different mat file. All slices must have the same shape, and will
be concatenated into a 3D Zarr.

Parameters
----------
Expand All @@ -133,13 +209,15 @@ def convert(
max_levels
Maximum number of pyramid levels
no_pool
Index of dimension to not pool when building pyramid
Index of dimension to not pool when building pyramid.
nii
Convert to nifti-zarr. True if path ends in ".nii.zarr"
orientation
Orientation of the volume
center
Set RAS[0, 0, 0] at FOV center
dtype
Data type to write into
"""
if isinstance(compressor_opt, str):
compressor_opt = ast.literal_eval(compressor_opt)
Expand All @@ -163,24 +241,25 @@ def convert(
omz = zarr.storage.DirectoryStore(out)
omz = zarr.group(store=omz, overwrite=True)

if not hasattr(inp[0], "dtype"):
raise Exception("Input is not numpy array. This is likely unexpected")
if len(inp[0].shape) != 2:
raise Exception("Input array is not 2d")
# if not hasattr(inp[0], "dtype"):
# raise Exception("Input is not an array. This is likely unexpected")
if len(inp[0].shape) < 2:
raise Exception("Input array is not 2d:", inp[0].shape)
# Prepare chunking options
dtype = dtype or np.dtype(inp[0].dtype).str
opt = {
"dimension_separator": r"/",
"order": "F",
"dtype": np.dtype(inp[0].dtype).str,
"dtype": dtype,
"fill_value": None,
"compressor": make_compressor(compressor, **compressor_opt),
}
inp: list = inp
inp_shape = (*inp[0].shape, len(inp))
inp_chunk = [min(x, max_load) for x in inp_shape]
nk = ceildiv(inp_shape[0], inp_chunk[0])
nj = ceildiv(inp_shape[1], inp_chunk[1])
ni = ceildiv(inp_shape[2], inp_chunk[2])
inp_chunk = [min(x, max_load) for x in inp_shape[-3:]]
nk = ceildiv(inp_shape[-3], inp_chunk[0])
nj = ceildiv(inp_shape[-2], inp_chunk[1])
ni = len(inp)

nblevels = min(
[int(math.ceil(math.log2(x))) for i, x in enumerate(inp_shape) if i != no_pool]
Expand All @@ -193,34 +272,33 @@ def convert(
omz.create_dataset(str(0), shape=inp_shape, **opt)

# iterate across input chunks
for i, j, k in product(range(ni), range(nj), range(nk)):
loaded_chunk = np.stack(
[
inp[index][
k * inp_chunk[0] : (k + 1) * inp_chunk[0],
j * inp_chunk[1] : (j + 1) * inp_chunk[1],
]
for index in range(i * inp_chunk[2], (i + 1) * inp_chunk[2])
],
axis=-1,
)

print(
f"[{i + 1:03d}, {j + 1:03d}, {k + 1:03d}]",
"/",
f"[{ni:03d}, {nj:03d}, {nk:03d}]",
# f"({1 + level}/{nblevels})",
end="\r",
)

# save current chunk
omz["0"][
k * inp_chunk[0] : k * inp_chunk[0] + loaded_chunk.shape[0],
j * inp_chunk[1] : j * inp_chunk[1] + loaded_chunk.shape[1],
i * inp_chunk[2] : i * inp_chunk[2] + loaded_chunk.shape[2],
] = loaded_chunk

generate_pyramid(omz, nblevels - 1, mode="mean")
for i in range(ni):
for j, k in product(range(nj), range(nk)):
loaded_chunk = inp[i][
...,
k * inp_chunk[0] : (k + 1) * inp_chunk[0],
j * inp_chunk[1] : (j + 1) * inp_chunk[1],
]

print(
f"[{i + 1:03d}, {j + 1:03d}, {k + 1:03d}]",
"/",
f"[{ni:03d}, {nj:03d}, {nk:03d}]",
# f"({1 + level}/{nblevels})",
end="\r",
)

# save current chunk
omz["0"][
...,
k * inp_chunk[0] : k * inp_chunk[0] + loaded_chunk.shape[0],
j * inp_chunk[1] : j * inp_chunk[1] + loaded_chunk.shape[1],
i,
] = loaded_chunk

inp[i] = None # no ref count -> delete array

generate_pyramid(omz, nblevels - 1, mode="mean", no_pyramid_axis=no_pool)

print("")

Expand All @@ -234,7 +312,7 @@ def convert(
no_pool=no_pool,
space_unit=ome_unit,
space_scale=vx,
multiscales_type=("2x2x2" if no_pool is None else "2x2") + "mean window",
multiscales_type=(("2x2x2" if no_pool is None else "2x2") + "mean window"),
)

if not nii:
Expand Down
Loading
Loading