From 646ab93af3c7d42549774f2a2a780e1b8a1f257b Mon Sep 17 00:00:00 2001 From: Philipp Holl Date: Sun, 12 Jan 2025 20:15:22 +0100 Subject: [PATCH] Avoid referencing non-existing dims --- phiml/math/_deprecated.py | 6 +++--- phiml/math/_nd.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/phiml/math/_deprecated.py b/phiml/math/_deprecated.py index 5a3e7e25..bdf3faac 100644 --- a/phiml/math/_deprecated.py +++ b/phiml/math/_deprecated.py @@ -55,8 +55,8 @@ def cross_product(vec1: Tensor, vec2: Tensor) -> Tensor: vec2 = tensor(vec2) spatial_rank = vec1.vector.size if 'vector' in vec1.shape else vec2.vector.size if spatial_rank == 2: # Curl in 2D - assert vec2.vector.exists - if vec1.vector.exists: + assert 'vector' in vec2.shape + if 'vector' in vec1.shape: v1_x, v1_y = vec1.vector v2_x, v2_y = vec2.vector return v1_x * v2_y - v1_y * v2_x @@ -64,7 +64,7 @@ def cross_product(vec1: Tensor, vec2: Tensor) -> Tensor: v2_x, v2_y = vec2.vector return vec1 * stack_tensors([-v2_y, v2_x], channel(vec2)) elif spatial_rank == 3: # Curl in 3D - assert vec1.vector.exists and vec2.vector.exists, f"Both vectors must have a 'vector' dimension but got shapes {vec1.shape}, {vec2.shape}" + assert 'vector' in vec1.shape and 'vector' in vec2.shape, f"Both vectors must have a 'vector' dimension but got shapes {vec1.shape}, {vec2.shape}" v1_x, v1_y, v1_z = vec1.vector v2_x, v2_y, v2_z = vec2.vector return stack_tensors([ diff --git a/phiml/math/_nd.py b/phiml/math/_nd.py index 0985086f..7800acad 100644 --- a/phiml/math/_nd.py +++ b/phiml/math/_nd.py @@ -662,7 +662,7 @@ def spatial_gradient(grid: Tensor, grid = grid[{stack_dim.name: 0}] dims = grid.shape.only(dims) dx = wrap(dx) - if dx.vector.exists: + if 'vector' in dx.shape: dx = dx.vector[dims] if dx.vector.size in (None, 1): dx = dx.vector[0] @@ -706,7 +706,7 @@ def laplace(x: Tensor, """ if isinstance(dx, (tuple, list)): dx = wrap(dx, batch('_laplace')) - elif isinstance(dx, Tensor) and dx.vector.exists: + elif isinstance(dx, Tensor) and 'vector' in dx.shape: dx = rename_dims(dx, 'vector', batch('_laplace')) if isinstance(x, Extrapolation): return x.spatial_gradient()