Skip to content

Commit

Permalink
Add open_datatree benchmark (#9158)
Browse files Browse the repository at this point in the history
* open_datatree performance improvement on NetCDF files

* fixing issue with forward slashes

* fixing issue with pytest

* open datatree in zarr format improvement

* fixing incompatibility in returned object

* passing group parameter to opendatatree method and reducing duplicated code

* passing group parameter to opendatatree method - NetCDF

* Update xarray/backends/netCDF4_.py

renaming variables

Co-authored-by: Tom Nicholas <[email protected]>

* renaming variables

* renaming variables

* renaming group_store variable

* removing _open_datatree_netcdf function not used anymore in open_datatree implementations

* improving performance of open_datatree method

* renaming 'i' variable within list comprehension in open_store method for zarr datatree

* using the default generator instead of loading zarr groups in memory

* fixing issue with group path to avoid using group[1:] notation. Adding group variable typing hints (str | Iterable[str] | callable) under the open_datatree for h5 files. Finally, separating positional from keyword args

* fixing issue with group path to avoid using group[1:] notation and adding group variable typing hints (str | Iterable[str] | callable) under the open_datatree method for netCDF files

* fixing issue with group path to avoid using group[1:] notation and adding group variable typing hints (str | Iterable[str] | callable) under the open_datatree method for zarr files

* adding 'mode' parameter to open_datatree method

* adding 'mode' parameter to H5NetCDFStore.open method

* adding new entry related to open_datatree performance improvement

* adding new entry related to open_datatree performance improvement

* Getting rid of unnecessary parameters for 'open_datatree' method for netCDF4 and Hdf5 backends

* passing parent argument into _iter_zarr_groups instead of group[1:] for creating group path

* adding a benchmark test for opening a deeply nested data tree. This includes a new class named 'IONestedDataTree' and another benchmark class named 'IOReadDataTreeNetCDF4'

* Update doc/whats-new.rst

---------

Co-authored-by: Tom Nicholas <[email protected]>
Co-authored-by: Kai Mühlbauer <[email protected]>
Co-authored-by: Deepak Cherian <[email protected]>
  • Loading branch information
4 people authored Jul 1, 2024
1 parent 24ab84c commit 90e4486
Show file tree
Hide file tree
Showing 2 changed files with 113 additions and 2 deletions.
113 changes: 112 additions & 1 deletion asv_bench/benchmarks/dataset_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
import pandas as pd

import xarray as xr
from xarray.backends.api import open_datatree
from xarray.core.datatree import DataTree

from . import _skip_slow, parameterized, randint, randn, requires_dask

Expand All @@ -16,7 +18,6 @@
except ImportError:
pass


# Set HDF5_USE_FILE_LOCKING to "FALSE" in this process's environment.
# NOTE(review): presumably this disables HDF5 file locking for the benchmarks - confirm.
os.environ["HDF5_USE_FILE_LOCKING"] = "FALSE"

# Names of every engine reported by xarray's list_engines(), excluding "store".
_ENGINES = tuple(xr.backends.list_engines().keys() - {"store"})
Expand Down Expand Up @@ -469,6 +470,116 @@ def create_delayed_write():
return ds.to_netcdf("file.nc", engine="netcdf4", compute=False)


class IONestedDataTree:
    """
    Shared fixture for benchmarks that read/write a heavily nested netCDF
    datatree with xarray.

    Subclasses call :meth:`make_datatree` (typically from ``setup``) to
    populate ``self.ds`` (the dataset reused in several nodes) and
    ``self.dtree`` (the nested tree itself).
    """

    # Benchmark-runner settings (this file lives under asv_bench/benchmarks).
    timeout = 300.0
    repeat = 1
    number = 5

    def make_datatree(self, nchildren=10):
        """
        Build the nested ``DataTree`` and store it on ``self.dtree``.

        For every ``i`` in ``range(nchildren)`` the tree contains four nodes:

        - ``group_<i>``: ``self.ds``
        - ``group_<i>/subgroup_1``: an empty ``Dataset``
        - ``group_<i>/subgroup_2``: a ``Dataset`` with one variable ``"a"``
          holding ``np.arange(1, 10)``
        - ``group_<i>/subgroup_2/sub-subgroup_1``: ``self.ds``

        Parameters
        ----------
        nchildren : int, default: 10
            Number of top-level ``group_<i>`` nodes to create.
        """
        # multiple Dataset
        self.ds = xr.Dataset()
        self.nt = 1000
        self.nx = 90
        self.ny = 45
        self.nchildren = nchildren

        # Floor division keeps chunk sizes integral (250, 30, 15); true
        # division would store floats here, unlike ``time_chunks`` below.
        self.block_chunks = {
            "time": self.nt // 4,
            "lon": self.nx // 3,
            "lat": self.ny // 3,
        }

        self.time_chunks = {"time": int(self.nt / 36)}

        # Coordinates shared by the data variables below.
        times = pd.date_range("1970-01-01", periods=self.nt, freq="D")
        lons = xr.DataArray(
            np.linspace(0, 360, self.nx),
            dims=("lon",),
            attrs={"units": "degrees east", "long_name": "longitude"},
        )
        lats = xr.DataArray(
            np.linspace(-90, 90, self.ny),
            dims=("lat",),
            attrs={"units": "degrees north", "long_name": "latitude"},
        )

        # Two (time, lon, lat) variables and one (lon, lat) float32 variable.
        # NOTE(review): frac_nan=0.2 presumably makes ~20% of values NaN -
        # confirm against the ``randn`` helper.
        self.ds["foo"] = xr.DataArray(
            randn((self.nt, self.nx, self.ny), frac_nan=0.2),
            coords={"lon": lons, "lat": lats, "time": times},
            dims=("time", "lon", "lat"),
            name="foo",
            attrs={"units": "foo units", "description": "a description"},
        )
        self.ds["bar"] = xr.DataArray(
            randn((self.nt, self.nx, self.ny), frac_nan=0.2),
            coords={"lon": lons, "lat": lats, "time": times},
            dims=("time", "lon", "lat"),
            name="bar",
            attrs={"units": "bar units", "description": "a description"},
        )
        self.ds["baz"] = xr.DataArray(
            randn((self.nx, self.ny), frac_nan=0.2).astype(np.float32),
            coords={"lon": lons, "lat": lats},
            dims=("lon", "lat"),
            name="baz",
            attrs={"units": "baz units", "description": "a description"},
        )

        self.ds.attrs = {"history": "created for xarray benchmarking"}

        # Index sets kept on the instance: ``oinds`` holds plain integer
        # arrays per dimension, ``vinds`` holds DataArray indexers sharing
        # dim "x" plus a slice for "lat".
        self.oinds = {
            "time": randint(0, self.nt, 120),
            "lon": randint(0, self.nx, 20),
            "lat": randint(0, self.ny, 10),
        }
        self.vinds = {
            "time": xr.DataArray(randint(0, self.nt, 120), dims="x"),
            "lon": xr.DataArray(randint(0, self.nx, 120), dims="x"),
            "lat": slice(3, 20),
        }

        # Path -> Dataset mappings for each level of the tree; merged below
        # and handed to DataTree.from_dict.
        root = {f"group_{group}": self.ds for group in range(self.nchildren)}
        nested_tree1 = {
            f"group_{group}/subgroup_1": xr.Dataset() for group in range(self.nchildren)
        }
        nested_tree2 = {
            f"group_{group}/subgroup_2": xr.DataArray(np.arange(1, 10)).to_dataset(
                name="a"
            )
            for group in range(self.nchildren)
        }
        nested_tree3 = {
            f"group_{group}/subgroup_2/sub-subgroup_1": self.ds
            for group in range(self.nchildren)
        }
        dtree = root | nested_tree1 | nested_tree2 | nested_tree3
        self.dtree = DataTree.from_dict(dtree)


class IOReadDataTreeNetCDF4(IONestedDataTree):
    """Benchmarks that open/load the nested datatree from a netCDF4 file."""

    def setup(self):
        """Build the nested datatree and write it to ``self.filepath``."""
        # TODO: Lazily skipped in CI as it is very demanding and slow.
        # Improve times and remove errors.
        _skip_slow()

        requires_dask()

        self.make_datatree()
        self.format = "NETCDF4"
        self.filepath = "datatree.nc4.nc"
        # Write the freshly built tree so the timing methods can read it back.
        self.dtree.to_netcdf(filepath=self.filepath)

    def time_load_datatree_netcdf4(self):
        """Open the file with the netcdf4 engine and call ``.load()`` on the result."""
        tree = open_datatree(self.filepath, engine="netcdf4")
        tree.load()

    def time_open_datatree_netcdf4(self):
        """Open the file with the netcdf4 engine without calling ``.load()``."""
        open_datatree(self.filepath, engine="netcdf4")


class IOWriteNetCDFDask:
timeout = 60
repeat = 1
Expand Down
2 changes: 1 addition & 1 deletion xarray/backends/zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,7 +446,7 @@ def open_store(
stacklevel=stacklevel,
zarr_version=zarr_version,
)
group_paths = [str(group / node[1:]) for node in _iter_zarr_groups(zarr_group)]
group_paths = [node for node in _iter_zarr_groups(zarr_group, parent=group)]
return {
group: cls(
zarr_group.get(group),
Expand Down

0 comments on commit 90e4486

Please sign in to comment.