Skip to content

Commit

Permalink
Avoid referencing non-existing dims
Browse files Browse the repository at this point in the history
  • Loading branch information
holl- committed Jan 12, 2025
1 parent 13c9e53 commit 6454121
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
6 changes: 3 additions & 3 deletions phiml/math/_deprecated.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,16 +55,16 @@ def cross_product(vec1: Tensor, vec2: Tensor) -> Tensor:
vec2 = tensor(vec2)
spatial_rank = vec1.vector.size if 'vector' in vec1.shape else vec2.vector.size
if spatial_rank == 2: # Curl in 2D
assert vec2.vector.exists
if vec1.vector.exists:
assert 'vector' in vec2.shape
if 'vector' in vec1.shape:
v1_x, v1_y = vec1.vector
v2_x, v2_y = vec2.vector
return v1_x * v2_y - v1_y * v2_x
else:
v2_x, v2_y = vec2.vector
return vec1 * stack_tensors([-v2_y, v2_x], channel(vec2))
elif spatial_rank == 3: # Curl in 3D
assert vec1.vector.exists and vec2.vector.exists, f"Both vectors must have a 'vector' dimension but got shapes {vec1.shape}, {vec2.shape}"
assert 'vector' in vec1.shape and 'vector' in vec2.shape, f"Both vectors must have a 'vector' dimension but got shapes {vec1.shape}, {vec2.shape}"
v1_x, v1_y, v1_z = vec1.vector
v2_x, v2_y, v2_z = vec2.vector
return stack_tensors([
Expand Down
4 changes: 2 additions & 2 deletions phiml/math/_nd.py
Original file line number Diff line number Diff line change
Expand Up @@ -662,7 +662,7 @@ def spatial_gradient(grid: Tensor,
grid = grid[{stack_dim.name: 0}]
dims = grid.shape.only(dims)
dx = wrap(dx)
if dx.vector.exists:
if 'vector' in dx.shape:
dx = dx.vector[dims]
if dx.vector.size in (None, 1):
dx = dx.vector[0]
Expand Down Expand Up @@ -706,7 +706,7 @@ def laplace(x: Tensor,
"""
if isinstance(dx, (tuple, list)):
dx = wrap(dx, batch('_laplace'))
elif isinstance(dx, Tensor) and dx.vector.exists:
elif isinstance(dx, Tensor) and 'vector' in dx.shape:
dx = rename_dims(dx, 'vector', batch('_laplace'))
if isinstance(x, Extrapolation):
return x.spatial_gradient()
Expand Down

0 comments on commit 6454121

Please sign in to comment.