Skip to content

Commit

Permalink
Merge remote-tracking branch 'fdeconinck/feature/multimodal_metric_th…
Browse files Browse the repository at this point in the history
…reshold_override' into feature/multimodal_metric_threshold_override
  • Loading branch information
Florian Deconinck committed Oct 9, 2024
2 parents 020a259 + c9da47b commit 75f886f
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 4 deletions.
24 changes: 23 additions & 1 deletion ndsl/dsl/gt4py_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,11 +53,14 @@ def wrapper(*args, **kwargs) -> Any:
def _mask_to_dimensions(
mask: Tuple[bool, ...], shape: Sequence[int]
) -> List[Union[str, int]]:
assert len(mask) == 3
assert len(mask) >= 3
dimensions: List[Union[str, int]] = []
for i, axis in enumerate(("I", "J", "K")):
if mask[i]:
dimensions.append(axis)
if len(mask) > 3:
for i in range(3, len(mask)):
dimensions.append(str(shape[i]))
offset = int(sum(mask))
dimensions.extend(shape[offset:])
return dimensions
Expand Down Expand Up @@ -154,6 +157,8 @@ def make_storage_data(
data = _make_storage_data_2d(
data, shape, start, dummy, axis, read_only, backend=backend
)
elif n_dims >= 4:
data = _make_storage_data_Nd(data, shape, start, backend=backend)
else:
data = _make_storage_data_3d(data, shape, start, backend=backend)

Expand Down Expand Up @@ -257,6 +262,21 @@ def _make_storage_data_3d(
return buffer


def _make_storage_data_Nd(
data: Field,
shape: Tuple[int, ...],
start: Tuple[int, ...] = None,
*,
backend: str,
) -> Field:
if start is None:
start = tuple([0] * data.ndim)
buffer = zeros(shape, backend=backend)
idx = tuple([slice(start[i], start[i] + data.shape[i]) for i in range(len(start))])
buffer[idx] = asarray(data, type(buffer))
return buffer


def make_storage_from_shape(
shape: Tuple[int, ...],
origin: Tuple[int, ...] = origin,
Expand Down Expand Up @@ -310,6 +330,7 @@ def make_storage_dict(
axis: int = 2,
*,
backend: str,
dtype: DTypes = Float,
) -> Dict[str, "Field"]:
assert names is not None, "for 4d variable storages, specify a list of names"
if shape is None:
Expand All @@ -324,6 +345,7 @@ def make_storage_dict(
dummy=dummy,
axis=axis,
backend=backend,
dtype=dtype,
)
return data_dict

Expand Down
3 changes: 3 additions & 0 deletions ndsl/dsl/stencil.py
Original file line number Diff line number Diff line change
Expand Up @@ -737,6 +737,9 @@ def axis_offsets(
"local_js": gtscript.J[0] + self.jsc - origin[1],
"j_end": j_end,
"local_je": gtscript.J[-1] + self.jec - origin[1] - domain[1] + 1,
"k_start": origin[2] if len(origin) > 2 else 0,
"k_end": (origin[2] if len(origin) > 2 else 0)
+ (domain[2] - 1 if len(domain) > 2 else 0),
}

def get_origin_domain(
Expand Down
17 changes: 14 additions & 3 deletions ndsl/stencils/testing/translate.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import ndsl.dsl.gt4py_utils as utils
from ndsl.dsl.stencil import StencilFactory
from ndsl.dsl.typing import Field, Float # noqa: F401
from ndsl.dsl.typing import Field, Float, Int # noqa: F401
from ndsl.quantity import Quantity
from ndsl.stencils.testing.grid import Grid # type: ignore

Expand Down Expand Up @@ -116,6 +116,12 @@ def make_storage_data(
elif not full_shape and len(array.shape) < 3 and axis == len(array.shape) - 1:
use_shape[1] = 1
start = (int(istart), int(jstart), int(kstart))
if "float" in str(array.dtype):
dtype = Float
elif "int" in str(array.dtype):
dtype = Int
else:
dtype = array.dtype
if names_4d:
return utils.make_storage_dict(
array,
Expand All @@ -126,8 +132,12 @@ def make_storage_data(
axis=axis,
names=names_4d,
backend=self.stencil_factory.backend,
dtype=dtype,
)
else:
if len(array.shape) == 4:
start = (int(istart), int(jstart), int(kstart), 0) # type: ignore
use_shape.append(array.shape[-1])
return utils.make_storage_data(
array,
tuple(use_shape),
Expand All @@ -137,6 +147,7 @@ def make_storage_data(
axis=axis,
read_only=read_only,
backend=self.stencil_factory.backend,
dtype=dtype,
)

def storage_vars(self):
Expand All @@ -162,7 +173,7 @@ def collect_start_indices(self, datashape, varinfo):
kstart = self.get_index_from_info(varinfo, "kstart", 0)
return istart, jstart, kstart

def make_storage_data_input_vars(self, inputs, storage_vars=None):
def make_storage_data_input_vars(self, inputs, storage_vars=None, dict_4d=True):
inputs_in = {**inputs}
inputs_out = {}
if storage_vars is None:
Expand All @@ -188,7 +199,7 @@ def make_storage_data_input_vars(self, inputs, storage_vars=None):
)

names_4d = None
if len(inputs_in[serialname].shape) == 4:
if (len(inputs_in[serialname].shape) == 4) and dict_4d:
names_4d = info.get("names_4d", utils.tracer_variables)

dummy_axes = info.get("dummy_axes", None)
Expand Down
17 changes: 17 additions & 0 deletions tests/dsl/test_stencil_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,23 @@ def test_get_stencils_with_varied_bounds_and_regions(backend: str):
np.testing.assert_array_equal(q_orig.data, q_ref.data)


def test_stencil_vertical_bounds(backend: str):
factory = get_stencil_factory(backend)
origins = [(3, 3, 0), (2, 2, 1)]
domains = [(1, 1, 3), (2, 2, 4)]
stencils = get_stencils_with_varied_bounds(
add_1_in_region_stencil,
origins,
domains,
stencil_factory=factory,
)

assert "k_start" in stencils[0].externals and stencils[0].externals["k_start"] == 0
assert "k_end" in stencils[0].externals and stencils[0].externals["k_end"] == 2
assert "k_start" in stencils[1].externals and stencils[1].externals["k_start"] == 1
assert "k_end" in stencils[1].externals and stencils[1].externals["k_end"] == 4


@pytest.mark.parametrize("enabled", [True, False])
def test_stencil_factory_numpy_comparison_from_dims_halo(enabled: bool):
backend = "numpy"
Expand Down

0 comments on commit 75f886f

Please sign in to comment.