diff --git a/phiml/backend/torch/_torch_backend.py b/phiml/backend/torch/_torch_backend.py index 30f49476..795b088a 100644 --- a/phiml/backend/torch/_torch_backend.py +++ b/phiml/backend/torch/_torch_backend.py @@ -398,6 +398,7 @@ def all(self, boolean_tensor, axis=None, keepdims=False): def quantile(self, x, quantiles): x = self.to_float(x) + quantiles = self.to_float(quantiles) result = torch.quantile(x, quantiles, dim=-1) return result