diff --git a/netam/dcsm.py b/netam/dcsm.py index 3d952da8..efbd393e 100644 --- a/netam/dcsm.py +++ b/netam/dcsm.py @@ -305,19 +305,11 @@ def predictions_of_batch(self, batch): # Create mask for valid (non-ambiguous) codons. valid_mask = parent_indices != AMBIGUOUS_CODON_IDX # Shape: [B, L] - # Now unsqueeze indices for scatter operation. - parent_indices = parent_indices.unsqueeze(-1) # Shape: [B, L, 1] - - # Zero out valid parent codon entries. - preds[valid_mask, :].scatter_(-1, parent_indices[valid_mask, :], 0.0) - - # Sum non-parent probabilities for valid indices. - non_parent_sum = preds[valid_mask, :].sum(dim=-1, keepdim=True) - - # Set parent probability for valid indices. - preds[valid_mask, :].scatter_( - -1, parent_indices[valid_mask, :], 1.0 - non_parent_sum - ) + # Zero out valid parent codon entries, and then assign so it's + # 1-(sum of the non-parent codons). + preds[valid_mask, parent_indices[valid_mask]] = 0.0 + non_parent_sum = preds[valid_mask, :].sum(dim=-1) + preds[valid_mask, parent_indices[valid_mask]] = 1.0 - non_parent_sum # We have to clamp the predictions to avoid log(0) issues. preds = torch.clamp(preds, min=torch.finfo(preds.dtype).eps)