Skip to content

Commit

Permalink
fixing computation graph
Browse files Browse the repository at this point in the history
  • Loading branch information
matsen committed Jan 16, 2025
1 parent 7997c6e commit ea73121
Showing 1 changed file with 11 additions and 4 deletions.
15 changes: 11 additions & 4 deletions netam/dcsm.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,11 +305,18 @@ def predictions_of_batch(self, batch):
# Create mask for valid (non-ambiguous) codons.
valid_mask = parent_indices != AMBIGUOUS_CODON_IDX # Shape: [B, L]

# Zero out valid parent codon entries, and then assign so it's
# 1-(sum of the non-parent codons).
preds[valid_mask, parent_indices[valid_mask]] = 0.0
# Zero out the parent indices in preds, keeping the computation graph intact.
preds_zeroer = torch.ones_like(preds)
preds_zeroer[valid_mask, parent_indices[valid_mask]] = 0.0
preds = preds * preds_zeroer

# Calculate the non-parent sum after zeroing out the parent indices.
non_parent_sum = preds[valid_mask, :].sum(dim=-1)
preds[valid_mask, parent_indices[valid_mask]] = 1.0 - non_parent_sum

# Add these parent values back in, again keeping the computation graph intact.
preds_parent = torch.zeros_like(preds)
preds_parent[valid_mask, parent_indices[valid_mask]] = 1.0 - non_parent_sum
preds = preds + preds_parent

# We have to clamp the predictions to avoid log(0) issues.
preds = torch.clamp(preds, min=torch.finfo(preds.dtype).eps)
Expand Down

0 comments on commit ea73121

Please sign in to comment.