Skip to content

Commit

Permalink
[field] Fix 1D staggered grid
Browse files Browse the repository at this point in the history
  • Loading branch information
holl- committed Dec 11, 2023
1 parent 40ad189 commit 599cc0e
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
4 changes: 2 additions & 2 deletions phi/field/_field_math.py
Original file line number Diff line number Diff line change
Expand Up @@ -346,8 +346,8 @@ def stagger(field: Field,
width_lower, width_upper = {dim: (0, -1)}, {dim: (-1, 0)}
all_lower.append(math.pad(field.values, width_lower, field.extrapolation, bounds=field.bounds))
all_upper.append(math.pad(field.values, width_upper, field.extrapolation, bounds=field.bounds))
all_upper = math.stack(all_upper, channel('vector'))
all_lower = math.stack(all_lower, channel('vector'))
all_upper = math.stack(all_upper, dual(vector=field.resolution.names))
all_lower = math.stack(all_lower, dual(vector=field.resolution.names))
values = face_function(all_lower, all_upper)
result = StaggeredGrid(values, bounds=field.bounds, extrapolation=extrapolation)
assert result.shape.spatial == field.shape.spatial
Expand Down
4 changes: 2 additions & 2 deletions phi/field/_grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ def StaggeredGrid(values: Any = 0.,
if resolution is None and not resolution_:
assert isinstance(values, Tensor), "Grid resolution must be specified when 'values' is not a Tensor."
if not all(extrapolation.valid_outer_faces(d)[0] != extrapolation.valid_outer_faces(d)[1] for d in spatial(values).names): # non-uniform values required
if values.shape.is_uniform:
if '~vector' not in values.shape:
values = unstack_staggered_tensor(values, extrapolation)
resolution = resolution_from_staggered_tensor(values, extrapolation)
else:
Expand All @@ -144,7 +144,7 @@ def StaggeredGrid(values: Any = 0.,
if not spatial(values):
values = expand_staggered(values, resolution, extrapolation)
if not all(extrapolation.valid_outer_faces(d)[0] != extrapolation.valid_outer_faces(d)[1] for d in resolution.names): # non-uniform values required
if values.shape.is_uniform:
if '~vector' not in values.shape:
values = unstack_staggered_tensor(values, extrapolation)
else: # Keep dim order from data and check it matches resolution
assert set(resolution_from_staggered_tensor(values, extrapolation)) == set(resolution), f"Failed to create StaggeredGrid: values {values.shape} do not match given resolution {resolution} for extrapolation {extrapolation}. See https://tum-pbs.github.io/PhiFlow/Staggered_Grids.html"
Expand Down

0 comments on commit 599cc0e

Please sign in to comment.