Skip to content

Commit

Permalink
replacing squeeze with advanced indexing
Browse files Browse the repository at this point in the history
  • Loading branch information
matsen committed Jan 16, 2025
1 parent 4d2debf commit 7997c6e
Showing 1 changed file with 5 additions and 13 deletions.
18 changes: 5 additions & 13 deletions netam/dcsm.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,19 +305,11 @@ def predictions_of_batch(self, batch):
# Create mask for valid (non-ambiguous) codons.
valid_mask = parent_indices != AMBIGUOUS_CODON_IDX # Shape: [B, L]

# Now unsqueeze indices for scatter operation.
parent_indices = parent_indices.unsqueeze(-1) # Shape: [B, L, 1]

# Zero out valid parent codon entries.
preds[valid_mask, :].scatter_(-1, parent_indices[valid_mask, :], 0.0)

# Sum non-parent probabilities for valid indices.
non_parent_sum = preds[valid_mask, :].sum(dim=-1, keepdim=True)

# Set parent probability for valid indices.
preds[valid_mask, :].scatter_(
-1, parent_indices[valid_mask, :], 1.0 - non_parent_sum
)
# Zero out valid parent codon entries, and then assign so it's
# 1-(sum of the non-parent codons).
preds[valid_mask, parent_indices[valid_mask]] = 0.0
non_parent_sum = preds[valid_mask, :].sum(dim=-1)
preds[valid_mask, parent_indices[valid_mask]] = 1.0 - non_parent_sum

# We have to clamp the predictions to avoid log(0) issues.
preds = torch.clamp(preds, min=torch.finfo(preds.dtype).eps)
Expand Down

0 comments on commit 7997c6e

Please sign in to comment.