From 6bcded38fcbf9f1eedffaad070bc1e8d9d5f5f4b Mon Sep 17 00:00:00 2001 From: chavlin Date: Wed, 2 Oct 2024 14:37:02 -0500 Subject: [PATCH] handle different dtypes --- pyproject.toml | 1 + src/pyramid_sampler/sampler.py | 25 ++++++++++++++++++++----- tests/test_sampler.py | 8 ++++++-- 3 files changed, 27 insertions(+), 7 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 1b3fdb4..c6bdb28 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -131,6 +131,7 @@ ignore = [ "PLR09", # Too many <...> "PLR2004", # Magic value used in comparison "ISC001", # Conflicts with formatter + "SIM108", # Use ternary operator ] isort.required-imports = ["from __future__ import annotations"] # Uncomment if using a _compat.typing backport diff --git a/src/pyramid_sampler/sampler.py b/src/pyramid_sampler/sampler.py index 5fe0e0d..0d4e1a1 100644 --- a/src/pyramid_sampler/sampler.py +++ b/src/pyramid_sampler/sampler.py @@ -132,8 +132,10 @@ def _downsample_by_one_level( level = coarse_level fine_level = level - 1 lev_shape = self._get_level_shape(level) + field1 = zarr.open(self.zarr_store_path)[zarr_field] - field1.empty(level, shape=lev_shape, chunks=self.chunks) + dtype = field1[fine_level].dtype + field1.empty(level, shape=lev_shape, chunks=self.chunks, dtype=dtype) numchunks = field1[str(level)].nchunks @@ -197,7 +199,8 @@ def _write_chunk_values( ) coarse_zarr = zarr.open(zarr_file)[zarr_field][str(level)] - coarse_zarr[si[0] : ei[0], si[1] : ei[1] :, si[2] : ei[2]] = outvals + dtype = coarse_zarr.dtype + coarse_zarr[si[0] : ei[0], si[1] : ei[1] :, si[2] : ei[2]] = outvals.astype(dtype) return 1 @@ -208,15 +211,27 @@ def initialize_test_image( base_resolution: tuple[int, int, int], chunks: int | tuple[int, int, int] | None = None, overwrite_field: bool = True, + dtype: str | type | None = None, ) -> None: + if dtype is None: + dtype = np.float64 field1 = zarr_store.create_group(zarr_field, overwrite=overwrite_field) if chunks is None: chunks = (64, 64, 64) - lev0 = da.random.random(base_resolution, chunks=chunks) + fac: int | float + if np.issubdtype(dtype, np.integer): + fac = 100 + elif np.issubdtype(dtype, np.floating): + fac = 1.0 + else: + msg = f"Unexpected dtype of {dtype}" + raise RuntimeError(msg) + lev0 = fac * da.random.random(base_resolution, chunks=chunks) + lev0 = lev0.astype(dtype) halfway = np.asarray(base_resolution) // 2 lev0[0 : halfway[0], 0 : halfway[1], 0 : halfway[2]] = ( - lev0[0 : halfway[0], 0 : halfway[1], 0 : halfway[2]] + 0.5 + lev0[0 : halfway[0], 0 : halfway[1], 0 : halfway[2]] + 0.5 * fac ) - field1.empty(0, shape=base_resolution, chunks=chunks) + field1.empty(0, shape=base_resolution, chunks=chunks, dtype=dtype) da.to_zarr(lev0, field1["0"]) diff --git a/tests/test_sampler.py b/tests/test_sampler.py index f9e04a3..d03c64f 100644 --- a/tests/test_sampler.py +++ b/tests/test_sampler.py @@ -27,13 +27,16 @@ def test_initialize_test_image(tmp_path): assert zarr_store[fieldname][0].shape == res -def test_downsampler(tmp_path): +@pytest.mark.parametrize("dtype", ["float32", np.float64, "int", np.int32, np.int16]) +def test_downsampler(tmp_path, dtype): tmp_zrr = str(tmp_path / "myzarr.zarr") zarr_store = zarr.open(tmp_zrr) res = (32, 32, 32) chunks = (8, 8, 8) fieldname = "test_field" - initialize_test_image(zarr_store, fieldname, res, chunks, overwrite_field=False) + initialize_test_image( + zarr_store, fieldname, res, chunks, overwrite_field=False, dtype=dtype + ) dsr = Downsampler(tmp_zrr, (2, 2, 2), res, chunks) @@ -41,6 +44,7 @@ def test_downsampler(tmp_path): expected_max_lev = 2 for lev in range(expected_max_lev + 1): assert lev in zarr_store[fieldname] + assert zarr_store[fieldname][lev].dtype == np.dtype(dtype) with pytest.raises(ValueError, match="max_level must exceed 0"): dsr.downsample(0, fieldname)