From f0f4086bd7586d7fd77df1393d319577920646ed Mon Sep 17 00:00:00 2001 From: holl- Date: Fri, 24 Nov 2023 13:46:01 +0100 Subject: [PATCH] Fix dtype for bool --- phiml/backend/_numpy_backend.py | 2 ++ phiml/backend/jax/_jax_backend.py | 2 ++ 2 files changed, 4 insertions(+) diff --git a/phiml/backend/_numpy_backend.py b/phiml/backend/_numpy_backend.py index 5d29ee9a..a9719b75 100644 --- a/phiml/backend/_numpy_backend.py +++ b/phiml/backend/_numpy_backend.py @@ -393,6 +393,8 @@ def ifft(self, k, axes: Union[tuple, list]): return np.fft.ifftn(k, axes=axes).astype(k.dtype) def dtype(self, array) -> DType: + if isinstance(array, bool): + return DType(bool) if isinstance(array, int): return DType(int, 32) if isinstance(array, float): diff --git a/phiml/backend/jax/_jax_backend.py b/phiml/backend/jax/_jax_backend.py index 68bccb55..208baea2 100644 --- a/phiml/backend/jax/_jax_backend.py +++ b/phiml/backend/jax/_jax_backend.py @@ -564,6 +564,8 @@ def ifft(self, k, axes: Union[tuple, list]): return jnp.fft.ifftn(k, axes=axes).astype(k.dtype) def dtype(self, array) -> DType: + if isinstance(array, bool): + return DType(bool) if isinstance(array, int): return DType(int, 32) if isinstance(array, float):