Skip to content

Commit

Permalink
Merge branch 'release/0.18.11'
Browse files Browse the repository at this point in the history
  • Loading branch information
floriankrb committed Oct 31, 2023
2 parents 4d7444c + 1b946fd commit 13d7516
Show file tree
Hide file tree
Showing 2 changed files with 44 additions and 53 deletions.
95 changes: 43 additions & 52 deletions climetlab/loaders/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@

LOG = logging.getLogger(__name__)

VERSION = "0.12"
VERSION = "0.13"


class DatasetName:
Expand Down Expand Up @@ -538,32 +538,43 @@ def load_datacube(self, cube, array):
)


def add_zarr_dataset(name, nparray, *, zarr_root, overwrite=True, dtype=np.float32):
def add_zarr_dataset(
name,
nparray,
*,
zarr_root,
overwrite=True,
dtype=np.float32,
**kwargs,
):
if isinstance(nparray, (tuple, list)):
nparray = np.array(nparray, dtype=dtype)
a = zarr_root.create_dataset(
name,
shape=nparray.shape,
dtype=dtype,
overwrite=overwrite,
**kwargs,
)
a[...] = nparray[...]
return a


class ZarrRegistry:
synchronizer_path = None # to be defined in subclasses
synchronizer_name = None # to be defined in subclasses

def __init__(self, path):
assert self.synchronizer_path is not None, self.synchronizer_path
assert self.synchronizer_name is not None, self.synchronizer_name

import zarr

assert isinstance(path, str), path
self.zarr_path = path
self.synchronizer = zarr.ProcessSynchronizer(
os.path.join(self.zarr_path + ".sync", self.synchronizer_path)
)
self.synchronizer = zarr.ProcessSynchronizer(self._synchronizer_path)

@property
def _synchronizer_path(self):
return self.zarr_path + "-" + self.synchronizer_name + ".sync"

def _open_write(self):
import zarr
Expand All @@ -578,6 +589,11 @@ def _open_read(self, sync=True):
else:
return zarr.open(self.zarr_path, mode="r")

def new_dataset(self, *args, **kwargs):
z = self._open_write()
zarr_root = z["_build"]
add_zarr_dataset(*args, zarr_root=zarr_root, overwrite=True, **kwargs)

def add_to_history(self, action, **kwargs):
new = dict(
action=action,
Expand Down Expand Up @@ -608,41 +624,30 @@ class ZarrStatisticsRegistry(ZarrRegistry):
"squares",
"count",
]
synchronizer_path = "statistics.sync"
synchronizer_name = "statistics"

def __init__(self, path):
super().__init__(path)

def create(self):
z = self._open_write()
z = self._open_read()
shape = z["data"].shape
shape = (shape[0], shape[1])

nans = np.full((shape[0], shape[1]), np.nan)
data = np.full(shape, np.nan)
for name in self.build_names:
add_zarr_dataset(
name,
nans,
zarr_root=z["_build"],
dtype=np.float64,
overwrite=True,
)
z = None
self.new_dataset(name, data) # , chunks=None)
self.add_to_history("statistics_initialised")

def __setitem__(self, key, stats):
z = self._open_write()

LOG.debug("Writting stats for ", key)
LOG.debug(f"Writting stats for {key}")
for name in self.build_names:
z["_build"][name][key] = stats[name]
LOG.debug("Written stats for ", key)
LOG.debug(f"Written stats for {key}")

def get_all(self, key=None):
if key is None:
key = slice(None, None)
dic = {}
for name in self.build_names:
dic[name] = self.get(name)[key]
return dic

def get(self, name):
def get_by_name(self, name):
z = self._open_read()
return z["_build"][name]

Expand All @@ -653,7 +658,7 @@ class ZarrBuiltRegistry(ZarrRegistry):
lengths = None
flags = None
z = None
synchronizer_path = "registry.sync"
synchronizer_name = "build"

def get_slice_for(self, i):
lengths = self.get_lengths()
Expand Down Expand Up @@ -682,23 +687,8 @@ def set_flag(self, i, value=True):
z["_build"][self.name_flags][i] = value

def create(self, lengths, overwrite=False):
z = self._open_write()

add_zarr_dataset(
self.name_lengths,
lengths,
zarr_root=z["_build"],
dtype="i4",
overwrite=overwrite,
)
add_zarr_dataset(
self.name_flags,
len(lengths) * [False],
zarr_root=z["_build"],
dtype=bool,
overwrite=overwrite,
)
z = None
self.new_dataset(self.name_lengths, lengths, dtype="i4")
self.new_dataset(self.name_flags, len(lengths) * [False], dtype=bool)
self.add_to_history("initialised")

def reset(self, lengths):
Expand Down Expand Up @@ -1020,6 +1010,7 @@ def add_statistics(self, no_write, **kwargs):
start=statistics_start,
end=statistics_end,
)

self.registry.add_provenance(name="provenance_statistics")

def compute_statistics(self, ds, statistics_start, statistics_end):
Expand Down Expand Up @@ -1049,11 +1040,11 @@ def _compute_statistics(self, ds, statistics_start, statistics_end):

reg = self.statistics_registry

_minimum = np.amin(reg.get("minimum"), axis=0)
_maximum = np.amax(reg.get("maximum"), axis=0)
_count = np.sum(reg.get("count"), axis=0)
_sums = np.sum(reg.get("sums"), axis=0)
_squares = np.sum(reg.get("squares"), axis=0)
_minimum = np.amin(reg.get_by_name("minimum"), axis=0)
_maximum = np.amax(reg.get_by_name("maximum"), axis=0)
_count = np.sum(reg.get_by_name("count"), axis=0)
_sums = np.sum(reg.get_by_name("sums"), axis=0)
_squares = np.sum(reg.get_by_name("squares"), axis=0)
_mean = _sums / _count
_stdev = np.sqrt(_squares / _count - _mean * _mean)

Expand Down
2 changes: 1 addition & 1 deletion climetlab/version
Original file line number Diff line number Diff line change
@@ -1 +1 @@
0.18.10
0.18.11

0 comments on commit 13d7516

Please sign in to comment.