Add model_embedding_dim argument to Dataset constructors #107

Merged · 15 commits · Jan 24, 2025
netam/common.py: 91 changes (59 additions & 32 deletions)
@@ -89,38 +89,6 @@ def generic_mask_tensor_of(ambig_symb, seq_str, length=None):
return mask


def _consider_codon(codon):
"""Return False if codon should be masked, True otherwise."""
if "N" in codon:
return False
elif codon in RESERVED_TOKEN_TRANSLATIONS:
return False
else:
return True


def codon_mask_tensor_of(nt_parent, *other_nt_seqs, aa_length=None):
"""Return a mask tensor indicating codons which contain at least one N.

Codons beyond the length of the sequence are masked. If other_nt_seqs are provided,
the "and" mask will be computed for all sequences. Codons containing marker tokens
are also masked.
"""
if aa_length is None:
aa_length = len(nt_parent) // 3
sequences = (nt_parent,) + other_nt_seqs
mask = [
all(_consider_codon(codon) for codon in codons)
for codons in zip(*(iter_codons(sequence) for sequence in sequences))
]
if len(mask) < aa_length:
mask += [False] * (aa_length - len(mask))
else:
mask = mask[:aa_length]
assert len(mask) == aa_length
return torch.tensor(mask, dtype=torch.bool)


def aa_strs_from_idx_tensor(idx_tensor):
"""Convert a tensor of amino acid indices back to a list of amino acid strings.

@@ -177,6 +145,38 @@ def aa_mask_tensor_of(*args, **kwargs):
return generic_mask_tensor_of("X", *args, **kwargs)


def _consider_codon(codon):
"""Return False if codon should be masked, True otherwise."""
if "N" in codon:
return False
elif codon in RESERVED_TOKEN_TRANSLATIONS:
return False
else:
return True


def codon_mask_tensor_of(nt_parent, *other_nt_seqs, aa_length=None):
"""Return a mask tensor indicating codons which contain at least one N.

Codons beyond the length of the sequence are masked. If other_nt_seqs are provided,
the "and" mask will be computed for all sequences. Codons containing marker tokens
are also masked.
"""
if aa_length is None:
aa_length = len(nt_parent) // 3
sequences = (nt_parent,) + other_nt_seqs
mask = [
all(_consider_codon(codon) for codon in codons)
for codons in zip(*(iter_codons(sequence) for sequence in sequences))
]
if len(mask) < aa_length:
mask += [False] * (aa_length - len(mask))
else:
mask = mask[:aa_length]
assert len(mask) == aa_length
return torch.tensor(mask, dtype=torch.bool)
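
As a quick sketch of the masking behavior, assuming ATG, AAA, and GGT are ordinary codons (not reserved marker tokens) and that iter_codons yields successive 3-mers; the printed values follow from those assumptions:

from netam.common import codon_mask_tensor_of

# Single sequence: the middle codon "NAA" contains an N, so it is masked out.
print(codon_mask_tensor_of("ATGNAAGGT"))
# tensor([ True, False,  True])

# Multiple sequences: the "and" of the per-sequence masks is returned.
print(codon_mask_tensor_of("ATGNAAGGT", "ATGAAANGT"))
# tensor([ True, False, False])

# An aa_length beyond the sequence masks the trailing codons.
print(codon_mask_tensor_of("ATGNAAGGT", aa_length=5))
# tensor([ True, False,  True, False, False])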


def informative_site_count(seq_str):
return sum(c != "N" for c in seq_str)

@@ -429,6 +429,32 @@ def chunked(iterable, n):
yield chunk


def assume_single_sequence_is_heavy_chain(seq_arg_idx=0):
"""Wraps a function that takes a heavy/light sequence pair as its first argument and
returns a tuple of results.

The wrapped function will assume that if the first argument is a string, it is a
heavy chain sequence, and in that case will return only the heavy chain result.
"""

def decorator(function):
@wraps(function)
def wrapper(*args, **kwargs):
seq = args[seq_arg_idx]
if isinstance(seq, str):
seq = (seq, "")
args = list(args)
args[seq_arg_idx] = seq
res = function(*args, **kwargs)
return res[0]
else:
return function(*args, **kwargs)

return wrapper

return decorator
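
As a usage sketch of the decorator (chain_lengths below is a hypothetical function, purely for illustration):

from netam.common import assume_single_sequence_is_heavy_chain

@assume_single_sequence_is_heavy_chain(seq_arg_idx=0)
def chain_lengths(seq_pair):
    # Takes a (heavy, light) pair and returns a per-chain tuple of results.
    heavy, light = seq_pair
    return len(heavy), len(light)

print(chain_lengths(("EVQ", "DIQM")))  # (3, 4): pair in, tuple of results out
print(chain_lengths("EVQ"))            # 3: a bare string is treated as a heavy chain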


def chunk_function(
first_chunkable_idx=0, default_chunk_size=2048, progress_bar_name=None
):
@@ -516,6 +542,7 @@ def parallelize_function(
max_worker_count = min(mp.cpu_count() // 2, max_workers)
if max_worker_count <= 1:
return function
force_spawn()

@wraps(function)
def wrapper(*args, **kwargs):
netam/dasm.py: 60 changes (46 additions & 14 deletions)
@@ -12,6 +12,7 @@
import netam.molevol as molevol
import netam.sequences as sequences
import copy
from typing import Tuple


class DASMDataset(DXSMDataset):
@@ -99,7 +100,7 @@ def to(self, device):
self.multihit_model = self.multihit_model.to(device)


def zap_predictions_along_diagonal(predictions, aa_parents_idxs):
def zap_predictions_along_diagonal(predictions, aa_parents_idxs, fill=-BIG):
"""Set the diagonal (i.e. no amino acid change) of the predictions tensor to -BIG,
except where aa_parents_idxs >= 20, which indicates no update should be done."""

@@ -116,7 +117,7 @@ def zap_predictions_along_diagonal(predictions, aa_parents_idxs):
batch_indices[valid_mask],
sequence_indices[valid_mask],
aa_parents_idxs[valid_mask],
] = -BIG
] = fill

return predictions
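
A minimal sketch of the new fill argument in action, with toy tensors (index 21 stands in for any special token, i.e. any index >= 20); per the masking shown above, only positions with a parent index below 20 are written:

import torch
from netam.dasm import zap_predictions_along_diagonal

predictions = torch.zeros(1, 3, 20)           # (batch, site, amino acid)
aa_parents_idxs = torch.tensor([[0, 5, 21]])  # site 2 carries a special token

zapped = zap_predictions_along_diagonal(predictions.clone(), aa_parents_idxs, fill=1.0)
print(zapped[0, 0, 0].item(), zapped[0, 1, 5].item())  # 1.0 1.0: parent aa entries filled
print(zapped[0, 2].sum().item())                       # 0.0: special-token site untouched

In this PR, the selection-matrix builders pass fill=1.0 so the wildtype entry is neutral, while other callers keep the default of -BIG.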

@@ -139,10 +140,7 @@ def prediction_pair_of_batch(self, batch):
raise ValueError(
f"log_neutral_aa_probs has non-finite values at relevant positions: {log_neutral_aa_probs[mask]}"
)
# We need the model to see special tokens here. For every other purpose
# they are masked out.
keep_token_mask = mask | sequences.token_mask_of_aa_idxs(aa_parents_idxs)
log_selection_factors = self.model(aa_parents_idxs, keep_token_mask)
log_selection_factors = self.selection_factors_of_aa_idxs(aa_parents_idxs, mask)
return log_neutral_aa_probs, log_selection_factors

def predictions_of_pair(self, log_neutral_aa_probs, log_selection_factors):
@@ -204,19 +202,53 @@ def loss_of_batch(self, batch):
csp_loss = self.xent_loss(csp_pred, csp_targets)
return torch.stack([subs_pos_loss, csp_loss])

def build_selection_matrix_from_parent(self, parent: str):
"""Build a selection matrix from a parent amino acid sequence.
def build_selection_matrix_from_parent_aa(
self, aa_parent_idxs: torch.Tensor, mask: torch.Tensor
):
"""Build a selection matrix from a single parent amino acid sequence. Inputs are
expected to be as prepared in the Dataset constructor.

Values at ambiguous sites are meaningless.
"""
with torch.no_grad():
per_aa_selection_factors = self.selection_factors_of_aa_idxs(
aa_parent_idxs.unsqueeze(0), mask.unsqueeze(0)
).exp()
return zap_predictions_along_diagonal(
per_aa_selection_factors, aa_parent_idxs.unsqueeze(0), fill=1.0
).squeeze(0)

# This is not used anywhere except in a few tests. Keeping it around for that
# reason.
def _build_selection_matrix_from_parent(self, parent: Tuple[str, str]):
"""Build a selection matrix from a parent nucleotide sequence, a heavy-chain,
light-chain pair.

Values at ambiguous sites are meaningless. Returned value is a tuple of
selection matrix for heavy and light chain sequences.
"""
# This is simpler than the equivalent in dnsm.py because we get the selection
# matrix directly. Note that selection_factors_of_aa_str does the exponentiation
# so this indeed gives us the selection factors, not the log selection factors.
parent = sequences.translate_sequence(parent)
per_aa_selection_factors = self.model.selection_factors_of_aa_str(parent)
aa_parent_pair = tuple(map(sequences.translate_sequence, parent))
per_aa_selection_factorss = self.model.selection_factors_of_aa_str(
aa_parent_pair
)

parent = parent.replace("X", "A")
parent_idxs = sequences.aa_idx_array_of_str(parent)
per_aa_selection_factors[torch.arange(len(parent_idxs)), parent_idxs] = 1.0
result = []
for per_aa_selection_factors, aa_parent in zip(
per_aa_selection_factorss, aa_parent_pair
):
aa_parent_idxs = torch.tensor(sequences.aa_idx_array_of_str(aa_parent))
if len(per_aa_selection_factors) > 0:
result.append(
zap_predictions_along_diagonal(
per_aa_selection_factors.unsqueeze(0),
aa_parent_idxs.unsqueeze(0),
fill=1.0,
).squeeze(0)
)
else:
result.append(per_aa_selection_factors)

return per_aa_selection_factors
return tuple(result)
netam/dnsm.py: 68 changes (57 additions & 11 deletions)
@@ -13,6 +13,8 @@
import netam.molevol as molevol
import netam.sequences as sequences

from typing import Tuple


class DNSMDataset(DXSMDataset):

@@ -127,7 +129,8 @@ def prediction_pair_of_batch(self, batch):
raise ValueError(
f"log_neutral_aa_mut_probs has non-finite values at relevant positions: {log_neutral_aa_mut_probs[mask]}"
)
log_selection_factors = self.model(aa_parents_idxs, mask)
# Right here is where the model is evaluated!
log_selection_factors = self.selection_factors_of_aa_idxs(aa_parents_idxs, mask)
return log_neutral_aa_mut_probs, log_selection_factors

def predictions_of_pair(self, log_neutral_aa_mut_probs, log_selection_factors):
@@ -156,24 +159,67 @@ def loss_of_batch(self, batch):
predictions = self.predictions_of_batch(batch).masked_select(mask)
return self.bce_loss(predictions, aa_subs_indicator)

def build_selection_matrix_from_parent(self, parent: str):
"""Build a selection matrix from a parent amino acid sequence.
def _build_selection_matrix_from_selection_factors(
self, selection_factors, aa_parent_idxs
):
"""Build a selection matrix from a selection factor tensor for a single
sequence.

Values at ambiguous sites are meaningless.
Upgrades the provided tensor containing a selection factor per site to a matrix
containing a selection factor per site and amino acid. The wildtype aa selection
factor is set to 1, and the rest are set to the selection factor.
"""
parent = sequences.translate_sequence(parent)
selection_factors = self.model.selection_factors_of_aa_str(parent)
selection_matrix = torch.zeros((len(selection_factors), 20), dtype=torch.float)
# Every "off-diagonal" entry of the selection matrix is set to the selection
# factor, where "diagonal" means keeping the same amino acid.
selection_matrix[:, :] = selection_factors[:, None]
parent = parent.replace("X", "A")
# Set "diagonal" elements to one.
parent_idxs = sequences.aa_idx_array_of_str(parent)
selection_matrix[torch.arange(len(parent_idxs)), parent_idxs] = 1.0

valid_mask = aa_parent_idxs < 20
selection_matrix[
torch.arange(len(aa_parent_idxs))[valid_mask], aa_parent_idxs[valid_mask]
] = 1.0
selection_matrix[~valid_mask] = 1.0
return selection_matrix
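
A standalone sketch of the transformation this method performs, reimplemented outside the class for illustration (the factor values and indices are hypothetical):

import torch

selection_factors = torch.tensor([0.5, 2.0, 1.5])  # hypothetical per-site factors
aa_parent_idxs = torch.tensor([0, 3, 21])          # index >= 20 marks a special/ambiguous site

# Broadcast each site's factor across all 20 amino acids...
selection_matrix = selection_factors[:, None].expand(-1, 20).clone()
# ...then reset the wildtype ("diagonal") entries, and special sites, to 1.
valid_mask = aa_parent_idxs < 20
selection_matrix[
    torch.arange(len(aa_parent_idxs))[valid_mask], aa_parent_idxs[valid_mask]
] = 1.0
selection_matrix[~valid_mask] = 1.0

print(selection_matrix[0, 0].item())  # 1.0: wildtype amino acid
print(selection_matrix[0, 1].item())  # 0.5: other amino acids get the site factor
print(selection_matrix[2])            # all ones at the special-token site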

def build_selection_matrix_from_parent_aa(
self, aa_parent_idxs: torch.Tensor, mask: torch.Tensor
):
"""Build a selection matrix from a single parent amino acid sequence.

Values at ambiguous sites are meaningless.
"""
with torch.no_grad():
selection_factors = (
self.selection_factors_of_aa_idxs(
aa_parent_idxs.unsqueeze(0), mask.unsqueeze(0)
)
.squeeze(0)
.exp()
)
return self._build_selection_matrix_from_selection_factors(
selection_factors, aa_parent_idxs
)

def _build_selection_matrix_from_parent(self, parent: Tuple[str, str]):
"""Build a selection matrix from a parent nucleotide sequence, given as a
heavy-chain, light-chain pair.

Values at ambiguous sites are meaningless. The returned value is a tuple of
selection matrices for the heavy- and light-chain sequences.
"""
aa_parent_pair = tuple(map(sequences.translate_sequence, parent))
selection_factorss = self.model.selection_factors_of_aa_str(aa_parent_pair)

result = []
for selection_factors, aa_parent in zip(selection_factorss, aa_parent_pair):
aa_parent_idxs = sequences.aa_idx_array_of_str(aa_parent)
if len(selection_factors) > 0:
result.append(
self._build_selection_matrix_from_selection_factors(
selection_factors, aa_parent_idxs
)
)
else:
result.append(torch.empty(0, 20))
return tuple(result)


class DNSMHyperBurrito(HyperBurrito):
# Note that we have to write the args out explicitly because we use some magic to filter kwargs in the optuna_objective method.