Add model_embedding_dim argument to Dataset constructors (#107)
This PR makes DXSM datasets specific to the embedding dimension of the model that the dataset is intended for.
Minimal changes in dnsm-experiments-1 were required, and these are in matsengrp/dnsm-experiments-1#78.

* `load_pcp_df` and associated functions will load the pcp_df without inserting any special token scaffolding into the sequences. We will always maintain separate heavy- and light-chain columns wherever applicable.
* There will be a free function that takes pairs of heavy- and light-chain sequences and scaffolds them with special tokens so they can be presented to the model (a hypothetical sketch follows this list). This function will take a `known_token_count` argument so that it knows how to process the sequences.
* This free function will be called by the Dataset constructor, which will also need to accept the `known_token_count` parameter.
* Calling the `model` forward/represent functions directly (or through `model.__call__`) will require the user to do any sequence token scaffolding on their own, perhaps using the free function mentioned above.
* Calling the `Crepe.__call__` function on sequences will do the required scaffolding automatically. Perhaps we'll have to strip out the model predictions for special token sites, so the outputs match the input sequence lengths? Erick suggests an input format of something like `crepe([(heavy, None), (heavy1, light1), ...])` to allow heavy- and light-chain sequences to be passed to the crepe for proper scaffolding.
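
A minimal sketch of how the proposed scaffolding function and the `crepe(...)` input format might fit together. The separator token, function names, and the exact handling of `known_token_count` are hypothetical, not netam's actual API:

```python
from typing import Optional, Sequence, Tuple

SEPARATOR = "^"  # hypothetical special token marking the heavy/light boundary


def scaffold_heavy_light(
    heavy: str, light: Optional[str], known_token_count: int
) -> str:
    """Join a heavy/light pair into one model-ready sequence.

    A model that knows only the 20 amino acid tokens gets the bare heavy
    chain; a model with a separator token gets the joined pair.
    """
    if known_token_count <= 20:
        if light:
            raise ValueError("paired input requires a model with special tokens")
        return heavy
    return heavy if not light else heavy + SEPARATOR + light


def scaffold_batch(
    pairs: Sequence[Tuple[str, Optional[str]]], known_token_count: int
) -> list:
    """Scaffold each pair, as in the proposed crepe([(heavy, None), ...]) input."""
    return [scaffold_heavy_light(h, lt, known_token_count) for h, lt in pairs]
```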
willdumm authored Jan 24, 2025
1 parent 5cac227 commit 09d9274
Showing 14 changed files with 670 additions and 199 deletions.
91 changes: 59 additions & 32 deletions netam/common.py
@@ -89,38 +89,6 @@ def generic_mask_tensor_of(ambig_symb, seq_str, length=None):
return mask


def _consider_codon(codon):
"""Return False if codon should be masked, True otherwise."""
if "N" in codon:
return False
elif codon in RESERVED_TOKEN_TRANSLATIONS:
return False
else:
return True


def codon_mask_tensor_of(nt_parent, *other_nt_seqs, aa_length=None):
"""Return a mask tensor indicating codons which contain at least one N.
Codons beyond the length of the sequence are masked. If other_nt_seqs are provided,
the "and" mask will be computed for all sequences. Codons containing marker tokens
are also masked.
"""
if aa_length is None:
aa_length = len(nt_parent) // 3
sequences = (nt_parent,) + other_nt_seqs
mask = [
all(_consider_codon(codon) for codon in codons)
for codons in zip(*(iter_codons(sequence) for sequence in sequences))
]
if len(mask) < aa_length:
mask += [False] * (aa_length - len(mask))
else:
mask = mask[:aa_length]
assert len(mask) == aa_length
return torch.tensor(mask, dtype=torch.bool)


def aa_strs_from_idx_tensor(idx_tensor):
"""Convert a tensor of amino acid indices back to a list of amino acid strings.
@@ -177,6 +145,38 @@ def aa_mask_tensor_of(*args, **kwargs):
return generic_mask_tensor_of("X", *args, **kwargs)


def _consider_codon(codon):
"""Return False if codon should be masked, True otherwise."""
if "N" in codon:
return False
elif codon in RESERVED_TOKEN_TRANSLATIONS:
return False
else:
return True


def codon_mask_tensor_of(nt_parent, *other_nt_seqs, aa_length=None):
"""Return a mask tensor indicating codons which contain at least one N.
Codons beyond the length of the sequence are masked. If other_nt_seqs are provided,
the "and" mask will be computed for all sequences. Codons containing marker tokens
are also masked.
"""
if aa_length is None:
aa_length = len(nt_parent) // 3
sequences = (nt_parent,) + other_nt_seqs
mask = [
all(_consider_codon(codon) for codon in codons)
for codons in zip(*(iter_codons(sequence) for sequence in sequences))
]
if len(mask) < aa_length:
mask += [False] * (aa_length - len(mask))
else:
mask = mask[:aa_length]
assert len(mask) == aa_length
return torch.tensor(mask, dtype=torch.bool)
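
As a quick usage check of the relocated `codon_mask_tensor_of` (assuming none of these codons appear in `RESERVED_TOKEN_TRANSLATIONS`):

```python
import torch

# "ATG" and "GGA" pass, "NNN" contains an N, and the fourth codon lies
# beyond the sequence, so it is padded with False.
mask = codon_mask_tensor_of("ATGNNNGGA", aa_length=4)
assert torch.equal(mask, torch.tensor([True, False, True, False]))
```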


def informative_site_count(seq_str):
return sum(c != "N" for c in seq_str)

@@ -429,6 +429,32 @@ def chunked(iterable, n):
yield chunk


def assume_single_sequence_is_heavy_chain(seq_arg_idx=0):
"""Wraps a function that takes a heavy/light sequence pair as its first argument and
returns a tuple of results.
The wrapped function will assume that if the first argument is a string, it is a
heavy chain sequence, and in that case will return only the heavy chain result.
"""

def decorator(function):
@wraps(function)
def wrapper(*args, **kwargs):
seq = args[seq_arg_idx]
if isinstance(seq, str):
seq = (seq, "")
args = list(args)
args[seq_arg_idx] = seq
res = function(*args, **kwargs)
return res[0]
else:
return function(*args, **kwargs)

return wrapper

return decorator
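
A small usage sketch of this decorator; `lowercase_pair` is illustrative only:

```python
@assume_single_sequence_is_heavy_chain()
def lowercase_pair(seq_pair):
    heavy, light = seq_pair
    return heavy.lower(), light.lower()

lowercase_pair(("QVQLV", "DIVMT"))  # -> ("qvqlv", "divmt")
lowercase_pair("QVQLV")             # -> "qvqlv": a lone string is treated as a heavy chain
```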


def chunk_function(
first_chunkable_idx=0, default_chunk_size=2048, progress_bar_name=None
):
@@ -516,6 +542,7 @@ def parallelize_function(
max_worker_count = min(mp.cpu_count() // 2, max_workers)
if max_worker_count <= 1:
return function
force_spawn()

@wraps(function)
def wrapper(*args, **kwargs):
60 changes: 46 additions & 14 deletions netam/dasm.py
@@ -12,6 +12,7 @@
import netam.molevol as molevol
import netam.sequences as sequences
import copy
from typing import Tuple


class DASMDataset(DXSMDataset):
@@ -99,7 +100,7 @@ def to(self, device):
self.multihit_model = self.multihit_model.to(device)


def zap_predictions_along_diagonal(predictions, aa_parents_idxs):
def zap_predictions_along_diagonal(predictions, aa_parents_idxs, fill=-BIG):
"""Set the diagonal (i.e. no amino acid change) of the predictions tensor to -BIG,
except where aa_parents_idxs >= 20, which indicates no update should be done."""

@@ -116,7 +117,7 @@ def zap_predictions_along_diagonal(predictions, aa_parents_idxs):
batch_indices[valid_mask],
sequence_indices[valid_mask],
aa_parents_idxs[valid_mask],
] = -BIG
] = fill

return predictions
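
An illustrative call (not part of the diff): with the new `fill` argument the wildtype entries can be set to 1.0 instead of -BIG, which the selection-matrix builders below rely on. The tensor shapes are assumptions based on the indexing shown above:

```python
import torch

preds = torch.full((1, 2, 20), 0.5)    # batch of 1, two sites, 20 amino acids
parent_idxs = torch.tensor([[3, 20]])  # site 0: wildtype aa 3; site 1: special token
out = zap_predictions_along_diagonal(preds, parent_idxs, fill=1.0)
# out[0, 0, 3] == 1.0; the special-token site (index >= 20) is left untouched.
```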

@@ -139,10 +140,7 @@ def prediction_pair_of_batch(self, batch):
raise ValueError(
f"log_neutral_aa_probs has non-finite values at relevant positions: {log_neutral_aa_probs[mask]}"
)
# We need the model to see special tokens here. For every other purpose
# they are masked out.
keep_token_mask = mask | sequences.token_mask_of_aa_idxs(aa_parents_idxs)
log_selection_factors = self.model(aa_parents_idxs, keep_token_mask)
log_selection_factors = self.selection_factors_of_aa_idxs(aa_parents_idxs, mask)
return log_neutral_aa_probs, log_selection_factors

def predictions_of_pair(self, log_neutral_aa_probs, log_selection_factors):
@@ -204,19 +202,53 @@ def loss_of_batch(self, batch):
csp_loss = self.xent_loss(csp_pred, csp_targets)
return torch.stack([subs_pos_loss, csp_loss])

def build_selection_matrix_from_parent(self, parent: str):
"""Build a selection matrix from a parent amino acid sequence.
def build_selection_matrix_from_parent_aa(
self, aa_parent_idxs: torch.Tensor, mask: torch.Tensor
):
"""Build a selection matrix from a single parent amino acid sequence. Inputs are
expected to be as prepared in the Dataset constructor.
Values at ambiguous sites are meaningless.
"""
with torch.no_grad():
per_aa_selection_factors = self.selection_factors_of_aa_idxs(
aa_parent_idxs.unsqueeze(0), mask.unsqueeze(0)
).exp()
return zap_predictions_along_diagonal(
per_aa_selection_factors, aa_parent_idxs.unsqueeze(0), fill=1.0
).squeeze(0)

# This is not used anywhere except in a few tests. Keeping it around
# for that reason.
def _build_selection_matrix_from_parent(self, parent: Tuple[str, str]):
"""Build a selection matrix from a parent nucleotide sequence, a heavy-chain,
light-chain pair.
Values at ambiguous sites are meaningless. Returned value is a tuple of
selection matrix for heavy and light chain sequences.
"""
# This is simpler than the equivalent in dnsm.py because we get the selection
# matrix directly. Note that selection_factors_of_aa_str does the exponentiation
# so this indeed gives us the selection factors, not the log selection factors.
parent = sequences.translate_sequence(parent)
per_aa_selection_factors = self.model.selection_factors_of_aa_str(parent)
aa_parent_pair = tuple(map(sequences.translate_sequence, parent))
per_aa_selection_factorss = self.model.selection_factors_of_aa_str(
aa_parent_pair
)

parent = parent.replace("X", "A")
parent_idxs = sequences.aa_idx_array_of_str(parent)
per_aa_selection_factors[torch.arange(len(parent_idxs)), parent_idxs] = 1.0
result = []
for per_aa_selection_factors, aa_parent in zip(
per_aa_selection_factorss, aa_parent_pair
):
aa_parent_idxs = torch.tensor(sequences.aa_idx_array_of_str(aa_parent))
if len(per_aa_selection_factors) > 0:
result.append(
zap_predictions_along_diagonal(
per_aa_selection_factors.unsqueeze(0),
aa_parent_idxs.unsqueeze(0),
fill=1.0,
).squeeze(0)
)
else:
result.append(per_aa_selection_factors)

return per_aa_selection_factors
return tuple(result)
68 changes: 57 additions & 11 deletions netam/dnsm.py
@@ -13,6 +13,8 @@
import netam.molevol as molevol
import netam.sequences as sequences

from typing import Tuple


class DNSMDataset(DXSMDataset):

@@ -127,7 +129,8 @@ def prediction_pair_of_batch(self, batch):
raise ValueError(
f"log_neutral_aa_mut_probs has non-finite values at relevant positions: {log_neutral_aa_mut_probs[mask]}"
)
log_selection_factors = self.model(aa_parents_idxs, mask)
# Right here is where the model is evaluated!
log_selection_factors = self.selection_factors_of_aa_idxs(aa_parents_idxs, mask)
return log_neutral_aa_mut_probs, log_selection_factors

def predictions_of_pair(self, log_neutral_aa_mut_probs, log_selection_factors):
@@ -156,24 +159,67 @@ def loss_of_batch(self, batch):
predictions = self.predictions_of_batch(batch).masked_select(mask)
return self.bce_loss(predictions, aa_subs_indicator)

def build_selection_matrix_from_parent(self, parent: str):
"""Build a selection matrix from a parent amino acid sequence.
def _build_selection_matrix_from_selection_factors(
self, selection_factors, aa_parent_idxs
):
"""Build a selection matrix from a selection factor tensor for a single
sequence.
Values at ambiguous sites are meaningless.
upgrades the provided tensor containing a selection factor per site to a matrix
containing a selection factor per site and amino acid. The wildtype aa selection
factor is set ot 1, and the rest are set to the selection factor.
"""
parent = sequences.translate_sequence(parent)
selection_factors = self.model.selection_factors_of_aa_str(parent)
selection_matrix = torch.zeros((len(selection_factors), 20), dtype=torch.float)
# Every "off-diagonal" entry of the selection matrix is set to the selection
# factor, where "diagonal" means keeping the same amino acid.
selection_matrix[:, :] = selection_factors[:, None]
parent = parent.replace("X", "A")
# Set "diagonal" elements to one.
parent_idxs = sequences.aa_idx_array_of_str(parent)
selection_matrix[torch.arange(len(parent_idxs)), parent_idxs] = 1.0

valid_mask = aa_parent_idxs < 20
selection_matrix[
torch.arange(len(aa_parent_idxs))[valid_mask], aa_parent_idxs[valid_mask]
] = 1.0
selection_matrix[~valid_mask] = 1.0
return selection_matrix
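
A tiny worked example of this arithmetic (values illustrative), reproducing the method's logic standalone:

```python
import torch

factors = torch.tensor([0.5, 2.0])  # one selection factor per site
parent_idxs = torch.tensor([0, 3])  # wildtype amino acid indices
matrix = factors[:, None].expand(2, 20).clone()
matrix[torch.arange(2), parent_idxs] = 1.0
# Row 0 is 0.5 everywhere except column 0 (== 1.0);
# row 1 is 2.0 everywhere except column 3 (== 1.0).
```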

def build_selection_matrix_from_parent_aa(
self, aa_parent_idxs: torch.Tensor, mask: torch.Tensor
):
"""Build a selection matrix from a single parent amino acid sequence.
Values at ambiguous sites are meaningless.
"""
with torch.no_grad():
selection_factors = (
self.selection_factors_of_aa_idxs(
aa_parent_idxs.unsqueeze(0), mask.unsqueeze(0)
)
.squeeze(0)
.exp()
)
return self._build_selection_matrix_from_selection_factors(
selection_factors, aa_parent_idxs
)

def _build_selection_matrix_from_parent(self, parent: Tuple[str, str]):
"""Build a selection matrix from a nucleotide sequence.
Values at ambiguous sites are meaningless.
"""
aa_parent_pair = tuple(map(sequences.translate_sequence, parent))
selection_factorss = self.model.selection_factors_of_aa_str(aa_parent_pair)

result = []
for selection_factors, aa_parent in zip(selection_factorss, aa_parent_pair):
aa_parent_idxs = sequences.aa_idx_array_of_str(aa_parent)
if len(selection_factors) > 0:
result.append(
self._build_selection_matrix_from_selection_factors(
selection_factors, aa_parent_idxs
)
)
else:
result.append(torch.empty(0, 20))
return tuple(result)


class DNSMHyperBurrito(HyperBurrito):
# Note that we have to write the args out explicitly because we use some magic to filter kwargs in the optuna_objective method.