Skip to content

Commit

Permalink
Merge pull request #4 from data-exp-lab/dtype_handling
Browse files Browse the repository at this point in the history
handle different dtypes
  • Loading branch information
chrishavlin authored Oct 2, 2024
2 parents c37c6ac + 6bcded3 commit 60ad9fd
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 7 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ ignore = [
"PLR09", # Too many <...>
"PLR2004", # Magic value used in comparison
"ISC001", # Conflicts with formatter
"SIM108", # Use ternary operator
]
isort.required-imports = ["from __future__ import annotations"]
# Uncomment if using a _compat.typing backport
Expand Down
25 changes: 20 additions & 5 deletions src/pyramid_sampler/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,8 +132,10 @@ def _downsample_by_one_level(
level = coarse_level
fine_level = level - 1
lev_shape = self._get_level_shape(level)

field1 = zarr.open(self.zarr_store_path)[zarr_field]
field1.empty(level, shape=lev_shape, chunks=self.chunks)
dtype = field1[fine_level].dtype
field1.empty(level, shape=lev_shape, chunks=self.chunks, dtype=dtype)

numchunks = field1[str(level)].nchunks

Expand Down Expand Up @@ -197,7 +199,8 @@ def _write_chunk_values(
)

coarse_zarr = zarr.open(zarr_file)[zarr_field][str(level)]
coarse_zarr[si[0] : ei[0], si[1] : ei[1] :, si[2] : ei[2]] = outvals
dtype = coarse_zarr.dtype
coarse_zarr[si[0] : ei[0], si[1] : ei[1] :, si[2] : ei[2]] = outvals.astype(dtype)

return 1

Expand All @@ -208,15 +211,27 @@ def initialize_test_image(
base_resolution: tuple[int, int, int],
chunks: int | tuple[int, int, int] | None = None,
overwrite_field: bool = True,
dtype: str | type | None = None,
) -> None:
if dtype is None:
dtype = np.float64
field1 = zarr_store.create_group(zarr_field, overwrite=overwrite_field)

if chunks is None:
chunks = (64, 64, 64)
lev0 = da.random.random(base_resolution, chunks=chunks)
fac: int | float
if np.issubdtype(dtype, np.integer):
fac = 100
elif np.issubdtype(dtype, np.floating):
fac = 1.0
else:
msg = f"Unexpected dtype of {dtype}"
raise RuntimeError(msg)
lev0 = fac * da.random.random(base_resolution, chunks=chunks)
lev0 = lev0.astype(dtype)
halfway = np.asarray(base_resolution) // 2
lev0[0 : halfway[0], 0 : halfway[1], 0 : halfway[2]] = (
lev0[0 : halfway[0], 0 : halfway[1], 0 : halfway[2]] + 0.5
lev0[0 : halfway[0], 0 : halfway[1], 0 : halfway[2]] + 0.5 * fac
)
field1.empty(0, shape=base_resolution, chunks=chunks)
field1.empty(0, shape=base_resolution, chunks=chunks, dtype=dtype)
da.to_zarr(lev0, field1["0"])
8 changes: 6 additions & 2 deletions tests/test_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,20 +27,24 @@ def test_initialize_test_image(tmp_path):
assert zarr_store[fieldname][0].shape == res


def test_downsampler(tmp_path):
@pytest.mark.parametrize("dtype", ["float32", np.float64, "int", np.int32, np.int16])
def test_downsampler(tmp_path, dtype):
tmp_zrr = str(tmp_path / "myzarr.zarr")
zarr_store = zarr.open(tmp_zrr)
res = (32, 32, 32)
chunks = (8, 8, 8)
fieldname = "test_field"
initialize_test_image(zarr_store, fieldname, res, chunks, overwrite_field=False)
initialize_test_image(
zarr_store, fieldname, res, chunks, overwrite_field=False, dtype=dtype
)

dsr = Downsampler(tmp_zrr, (2, 2, 2), res, chunks)

dsr.downsample(10, fieldname)
expected_max_lev = 2
for lev in range(expected_max_lev + 1):
assert lev in zarr_store[fieldname]
assert zarr_store[fieldname][lev].dtype == np.dtype(dtype)

with pytest.raises(ValueError, match="max_level must exceed 0"):
dsr.downsample(0, fieldname)
Expand Down

0 comments on commit 60ad9fd

Please sign in to comment.