From d4848952cca12565b3aebad5b22c65c38eaf0b63 Mon Sep 17 00:00:00 2001 From: Erick Matsen Date: Tue, 14 Jan 2025 12:42:48 -0800 Subject: [PATCH] fixing a device proble --- netam/dcsm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/netam/dcsm.py b/netam/dcsm.py index 9f2a4da8..c766da0c 100644 --- a/netam/dcsm.py +++ b/netam/dcsm.py @@ -347,7 +347,7 @@ def predictions_of_batch(self, batch): # However, we have to unsqueeze because scatter_ requires the `index` # tensor and the `src` value(s) to have the same shape for broadcasting. - parent_indices = batch["codon_parents_idxs"] # Shape: [B, L] + parent_indices = batch["codon_parents_idxs"].to(self.device) # Shape: [B, L] # Create mask for valid (non-ambiguous) codons. valid_mask = parent_indices != AMBIGUOUS_CODON_IDX # Shape: [B, L]