Skip to content

Commit

Permalink
Add deprecated rotation functions
Browse files Browse the repository at this point in the history
  • Loading branch information
holl- committed Nov 25, 2024
1 parent ae742a7 commit 30ed469
Show file tree
Hide file tree
Showing 2 changed files with 71 additions and 4 deletions.
2 changes: 1 addition & 1 deletion phiml/math/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@

from ._optimize import solve_linear, solve_nonlinear, minimize, Solve, SolveInfo, ConvergenceException, NotConverged, Diverged, SolveTape, factor_ilu

# from ._deprecated import clip_length, cross_product, cross_product as cross,
from ._deprecated import clip_length, cross_product, cross_product as cross, rotate_vector, rotation_matrix

import sys as _sys
math = _sys.modules[__name__]
Expand Down
73 changes: 70 additions & 3 deletions phiml/math/_deprecated.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from typing import Union
from typing import Union, Optional

from . import Tensor, DimFilter, channel, length, clip, safe_div, tensor
from ._ops import stack_tensors
from ._shape import shape
from ._nd import Tensor, DimFilter, channel, length, tensor, math, wrap, normalize, rename_dims
from ._ops import clip, safe_div, stack_tensors


def clip_length(vec: Tensor, min_len=0, max_len=1, vec_dim: DimFilter = channel, eps: Union[float, Tensor] = None):
Expand Down Expand Up @@ -60,3 +61,69 @@ def cross_product(vec1: Tensor, vec2: Tensor) -> Tensor:
raise AssertionError(f'dims = {spatial_rank}. Vector product not available in > 3 dimensions')



def rotate_vector(vector: math.Tensor, angle: Optional[Union[float, math.Tensor]], invert=False, dim='vector') -> Tensor:
"""
Rotates `vector` around the origin.
Args:
vector: n-dimensional vector with exactly one channel dimension
angle: Euler angle(s) or rotation matrix.
`None` is interpreted as no rotation.
invert: Whether to apply the inverse rotation.
Returns:
Rotated vector as `Tensor`
"""
assert 'vector' in vector.shape, f"vector must have exactly a channel dimension named 'vector'"
if angle is None:
return vector
matrix = rotation_matrix(angle, matrix_dim=channel(vector))
if invert:
matrix = rename_dims(matrix, '~vector,vector', matrix.shape['vector'] + matrix.shape['~vector'])
assert matrix.vector.dual.size == vector.vector.size, f"Rotation matrix from {shape(angle)} is {matrix.vector.dual.size}D but vector {vector.shape} is {vector.vector.size}D."
dim = vector.shape.only(dim)
return math.dot(matrix, dim.as_dual(), vector, dim)


def rotation_matrix(x: Union[float, math.Tensor, None], matrix_dim=channel('vector')) -> Optional[Tensor]:
"""
Create a 2D or 3D rotation matrix from the corresponding angle(s).
Args:
x:
2D: scalar angle
3D: Either vector pointing along the rotation axis with rotation angle as length or Euler angles.
Euler angles need to be laid out along a `angle` channel dimension with dimension names listing the spatial dimensions.
E.g. a 90° rotation about the z-axis is represented by `vec('angles', x=0, y=0, z=PI/2)`.
If a rotation matrix is passed for `angle`, it is returned without modification.
matrix_dim: Matrix dimension for 2D rotations. In 3D, the channel dimension of angle is used.
Returns:
Matrix containing `matrix_dim` in primal and dual form as well as all non-channel dimensions of `x`.
"""
if x is None:
return None
if isinstance(x, Tensor) and '~vector' in x.shape and 'vector' in x.shape.channel and x.shape.get_size('~vector') == x.shape.get_size('vector'):
return x # already a rotation matrix
elif 'angle' in shape(x) and shape(x).get_size('angle') == 3: # 3D Euler angles
assert channel(x).rank == 1 and channel(x).size == 3, f"x for 3D rotations needs to be a 3-vector but got {x}"
s1, s2, s3 = math.sin(x).angle # x, y, z
c1, c2, c3 = math.cos(x).angle
matrix_dim = matrix_dim.with_size(shape(x).get_item_names('angle'))
return wrap([[c3 * c2, c3 * s2 * s1 - s3 * c1, c3 * s2 * c1 + s3 * s1],
[s3 * c2, s3 * s2 * s1 + c3 * c1, s3 * s2 * c1 - c3 * s1],
[-s2, c2 * s1, c2 * c1]], matrix_dim, matrix_dim.as_dual()) # Rz * Ry * Rx (1. rotate about X by first angle)
elif 'vector' in shape(x) and shape(x).get_size('vector') == 3: # 3D axis + x
angle = length(x)
s, c = math.sin(angle), math.cos(angle)
t = 1 - c
k1, k2, k3 = normalize(x, epsilon=1e-12).vector
matrix_dim = matrix_dim.with_size(shape(x).get_item_names('vector'))
return wrap([[c + k1**2 * t, k1 * k2 * t - k3 * s, k1 * k3 * t + k2 * s],
[k2 * k1 * t + k3 * s, c + k2**2 * t, k2 * k3 * t - k1 * s],
[k3 * k1 * t - k2 * s, k3 * k2 * t + k1 * s, c + k3**2 * t]], matrix_dim, matrix_dim.as_dual())
else: # 2D rotation
sin = wrap(math.sin(x))
cos = wrap(math.cos(x))
return wrap([[cos, -sin], [sin, cos]], matrix_dim, matrix_dim.as_dual())

0 comments on commit 30ed469

Please sign in to comment.