Skip to content

Commit

Permalink
Auto-convert NumPy scalars to primitives in functional math
Browse files Browse the repository at this point in the history
This avoids NumPy conversion errors. This only affects the case that non-Tensors are passed.
  • Loading branch information
holl- committed Jan 16, 2024
1 parent 908b697 commit f35bcd6
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions phiml/math/_tensors.py
Original file line number Diff line number Diff line change
Expand Up @@ -1927,6 +1927,8 @@ def assemble_tree(obj: PhiTreeNodeType, values: List[Tensor]) -> PhiTreeNodeType
elif obj is NATIVE_TENSOR:
value = values.pop(0)
assert isinstance(value, NativeTensor), f"Failed to assemble tree structure. Encountered {value}"
if isinstance(value._native, np.ndarray) and value.shape == EMPTY_SHAPE: # this can be represented as a Python scalar, which leads to less conversion errors
return value._native.item()
return value._native
elif obj is None:
value = values.pop(0)
Expand Down

0 comments on commit f35bcd6

Please sign in to comment.