From 09840ee9e5a8c64dd3e456c2af4adfcaf000cd17 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=BA=90=E6=96=87=E9=9B=A8?= <41315874+fumiama@users.noreply.github.com> Date: Fri, 14 Jun 2024 14:20:57 +0900 Subject: [PATCH] fix(filter): MPS framework doesn't support float64 by replacing x.double() with x.float() --- torchcrepe/filter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchcrepe/filter.py b/torchcrepe/filter.py index dd62ef5..ccade12 100644 --- a/torchcrepe/filter.py +++ b/torchcrepe/filter.py @@ -86,7 +86,7 @@ def median(signals, win_length): mask = mask.contiguous().view(mask.size()[:3] + (-1,)) # Combine the mask with the input tensor - x_masked = torch.where(mask.bool(), x.double(), float("inf")).to(x) + x_masked = torch.where(mask.bool(), x.float(), float("inf")).to(x) # Sort the masked tensor along the last dimension x_sorted, _ = torch.sort(x_masked, dim=-1)