Skip to content

Commit

Permalink
Fix dtype conversion
Browse files Browse the repository at this point in the history
  • Loading branch information
holl- committed Nov 28, 2023
1 parent a70c465 commit 26b1918
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion phiml/backend/_numpy_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -514,7 +514,7 @@ def matrix_solve_least_squares(self, matrix: TensorType, rhs: TensorType) -> Ten

def solve_triangular_dense(self, matrix, rhs, lower: bool, unit_diagonal: bool):
if matrix.ndim == 2:
return scipy.linalg.solve_triangular(matrix, rhs, lower=lower, unit_diagonal=unit_diagonal).astype(matrix.dtype)
return scipy.linalg.solve_triangular(matrix, rhs, lower=lower, unit_diagonal=unit_diagonal)
else:
batch_size = matrix.shape[0]
return np.stack([self.solve_triangular(matrix[b], rhs[b], lower, unit_diagonal) for b in range(batch_size)])
Expand Down

0 comments on commit 26b1918

Please sign in to comment.