diff --git a/netam/codon_table.py b/netam/codon_table.py
new file mode 100644
index 00000000..11a17707
--- /dev/null
+++ b/netam/codon_table.py
@@ -0,0 +1,43 @@
+import numpy as np
+
+from Bio.Data import CodonTable
+from netam.sequences import AA_STR_SORTED
+
+
+def single_mutant_aa_indices(codon):
+    """Given a codon, return the amino acid indices for all single-mutant neighbors.
+
+    Args:
+        codon (str): A three-letter codon (e.g., "ATG").
+
+    Returns:
+        list of int: Sorted indices (into AA_STR_SORTED) of the amino acids
+        reachable by a single nucleotide mutation.
+    """
+    standard_table = CodonTable.unambiguous_dna_by_id[1]  # Standard codon table
+    bases = ["A", "C", "G", "T"]
+
+    mutant_aa_indices = set()  # Use a set to avoid duplicates
+
+    # Generate all single-mutant neighbors
+    for pos in range(3):  # Codons have 3 positions
+        for base in bases:
+            if base != codon[pos]:  # Mutate only if it's a different base
+                mutant_codon = codon[:pos] + base + codon[pos + 1 :]
+
+                # Check if the mutant codon translates to a valid amino acid
+                if mutant_codon in standard_table.forward_table:
+                    mutant_aa = standard_table.forward_table[mutant_codon]
+                    mutant_aa_indices.add(AA_STR_SORTED.index(mutant_aa))
+
+    return sorted(mutant_aa_indices)
+
+
+def make_codon_neighbor_indicator(nt_seq):
+    """Create a binary array indicating the single-mutant amino acid neighbors of each
+    codon in a given DNA sequence."""
+    neighbor = np.zeros((len(AA_STR_SORTED), len(nt_seq) // 3), dtype=bool)
+    for i in range(0, len(nt_seq), 3):
+        codon = nt_seq[i : i + 3]
+        neighbor[single_mutant_aa_indices(codon), i // 3] = True
+    return neighbor
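For orientation, a quick sketch of how the new helpers behave. This assumes AA_STR_SORTED is the twenty standard amino acids in alphabetical one-letter order ("ACDEFGHIKLMNPQRSTVWY"); the exact indices below depend on that assumption:

    from netam.codon_table import make_codon_neighbor_indicator, single_mutant_aa_indices

    # TGG (Trp) reaches C, G, L, R, and S by a single nucleotide change; its two
    # stop-codon neighbors (TAG, TGA) are absent from forward_table, so they are
    # dropped rather than indexed.
    print(single_mutant_aa_indices("TGG"))  # [1, 5, 9, 14, 15]

    # One row per amino acid, one boolean column per codon.
    print(make_codon_neighbor_indicator("TGGAAA").shape)  # (20, 2)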
diff --git a/netam/dasm.py b/netam/dasm.py
index 1095e4a5..811e30d6 100644
--- a/netam/dasm.py
+++ b/netam/dasm.py
@@ -3,11 +3,8 @@
 import torch
 import torch.nn.functional as F
 
-from netam.common import (
-    clamp_probability,
-    BIG,
-)
-from netam.dxsm import DXSMDataset, DXSMBurrito
+from netam.common import clamp_probability
+from netam.dxsm import DXSMDataset, DXSMBurrito, zap_predictions_along_diagonal
 import netam.framework as framework
 import netam.molevol as molevol
 import netam.sequences as sequences
@@ -100,28 +97,6 @@ def to(self, device):
             self.multihit_model = self.multihit_model.to(device)
 
 
-def zap_predictions_along_diagonal(predictions, aa_parents_idxs, fill=-BIG):
-    """Set the diagonal (i.e. no amino acid change) of the predictions tensor to -BIG,
-    except where aa_parents_idxs >= 20, which indicates no update should be done."""
-
-    device = predictions.device
-    batch_size, L, _ = predictions.shape
-    batch_indices = torch.arange(batch_size, device=device)[:, None].expand(-1, L)
-    sequence_indices = torch.arange(L, device=device)[None, :].expand(batch_size, -1)
-
-    # Create a mask for valid positions (where aa_parents_idxs is less than 20)
-    valid_mask = aa_parents_idxs < 20
-
-    # Only update the predictions for valid positions
-    predictions[
-        batch_indices[valid_mask],
-        sequence_indices[valid_mask],
-        aa_parents_idxs[valid_mask],
-    ] = fill
-
-    return predictions
-
-
 class DASMBurrito(framework.TwoLossMixin, DXSMBurrito):
     model_type = "dasm"
 
@@ -202,22 +177,6 @@ def loss_of_batch(self, batch):
         csp_loss = self.xent_loss(csp_pred, csp_targets)
         return torch.stack([subs_pos_loss, csp_loss])
 
-    def build_selection_matrix_from_parent_aa(
-        self, aa_parent_idxs: torch.Tensor, mask: torch.Tensor
-    ):
-        """Build a selection matrix from a single parent amino acid sequence. Inputs are
-        expected to be as prepared in the Dataset constructor.
-
-        Values at ambiguous sites are meaningless.
-        """
-        with torch.no_grad():
-            per_aa_selection_factors = self.selection_factors_of_aa_idxs(
-                aa_parent_idxs.unsqueeze(0), mask.unsqueeze(0)
-            ).exp()
-        return zap_predictions_along_diagonal(
-            per_aa_selection_factors, aa_parent_idxs.unsqueeze(0), fill=1.0
-        ).squeeze(0)
-
     # This is not used anywhere, except for in a few tests. Keeping it around
     # for that reason.
     def _build_selection_matrix_from_parent(self, parent: Tuple[str, str]):
diff --git a/netam/dcsm.py b/netam/dcsm.py
new file mode 100644
index 00000000..5cf44b37
--- /dev/null
+++ b/netam/dcsm.py
@@ -0,0 +1,253 @@
+"""Defining the deep codon selection model (DCSM)."""
+
+import copy
+
+import torch
+import torch.nn.functional as F
+
+from netam.common import (
+    clamp_probability,
+    BIG,
+)
+from netam.dxsm import DXSMDataset, DXSMBurrito
+import netam.molevol as molevol
+
+from netam.sequences import (
+    build_stop_codon_indicator_tensor,
+    nt_idx_tensor_of_str,
+    codon_idx_tensor_of_str_ambig,
+    AMBIGUOUS_CODON_IDX,
+    CODON_AA_INDICATOR_MATRIX,
+)
+
+
+class DCSMDataset(DXSMDataset):
+
+    def __init__(
+        self,
+        *args,
+        **kwargs,
+    ):
+        super().__init__(*args, **kwargs)
+        assert len(self.nt_parents) == len(self.nt_children)
+        # We need to add codon index tensors to the dataset.
+
+        self.max_codon_seq_len = self.max_aa_seq_len
+        self.codon_parents_idxss = torch.full_like(
+            self.aa_parents_idxss, AMBIGUOUS_CODON_IDX
+        )
+        self.codon_children_idxss = self.codon_parents_idxss.clone()
+
+        # We are using the modified nt_parents and nt_children here because we
+        # don't want any funky symbols in our codon indices.
+        for i, (nt_parent, nt_child) in enumerate(
+            zip(self.nt_parents, self.nt_children)
+        ):
+            assert len(nt_parent) % 3 == 0
+            codon_seq_len = len(nt_parent) // 3
+            self.codon_parents_idxss[i, :codon_seq_len] = codon_idx_tensor_of_str_ambig(
+                nt_parent
+            )
+            self.codon_children_idxss[i, :codon_seq_len] = (
+                codon_idx_tensor_of_str_ambig(nt_child)
+            )
+        assert torch.max(self.codon_parents_idxss) <= AMBIGUOUS_CODON_IDX
+
+    def update_neutral_probs(self):
+        """Update the neutral mutation probabilities for the dataset.
+
+        This is a somewhat vague name, but that's because it includes all of the various
+        types of neutral mutation probabilities that we might want to compute.
+
+        In this case it's the neutral codon probabilities.
+        """
+        neutral_codon_probs_l = []
+
+        for nt_parent, mask, nt_rates, nt_csps, branch_length in zip(
+            self.nt_parents,
+            self.masks,
+            self.nt_ratess,
+            self.nt_cspss,
+            self._branch_lengths,
+        ):
+            mask = mask.to("cpu")
+            nt_rates = nt_rates.to("cpu")
+            nt_csps = nt_csps.to("cpu")
+            if self.multihit_model is not None:
+                multihit_model = copy.deepcopy(self.multihit_model).to("cpu")
+            else:
+                multihit_model = None
+            # Note we are replacing all Ns with As, which means that we need to be
+            # careful with masking out these positions later. We do this below.
+            parent_idxs = nt_idx_tensor_of_str(nt_parent.replace("N", "A"))
+            parent_len = len(nt_parent)
+
+            mut_probs = 1.0 - torch.exp(-branch_length * nt_rates[:parent_len])
+            nt_csps = nt_csps[:parent_len, :]
+            nt_mask = mask.repeat_interleave(3)[: len(nt_parent)]
+            molevol.check_csps(parent_idxs[nt_mask], nt_csps[: len(nt_parent)][nt_mask])
+
+            neutral_codon_probs = molevol.neutral_codon_probs(
+                parent_idxs.reshape(-1, 3),
+                mut_probs.reshape(-1, 3),
+                nt_csps.reshape(-1, 3, 4),
+                multihit_model=multihit_model,
+            )
+
+            if not torch.isfinite(neutral_codon_probs).all():
+                print("Found a non-finite neutral_codon_prob")
+                print(f"nt_parent: {nt_parent}")
+                print(f"mask: {mask}")
+                print(f"nt_rates: {nt_rates}")
+                print(f"nt_csps: {nt_csps}")
+                print(f"branch_length: {branch_length}")
+                raise ValueError(
+                    f"neutral_codon_probs is not finite: {neutral_codon_probs}"
+                )
+
+            # Ensure that all values are positive before taking the log later
+            neutral_codon_probs = clamp_probability(neutral_codon_probs)
+
+            pad_len = self.max_aa_seq_len - neutral_codon_probs.shape[0]
+            if pad_len > 0:
+                neutral_codon_probs = F.pad(
+                    neutral_codon_probs, (0, 0, 0, pad_len), value=1e-8
+                )
+            # Here we zero out masked positions.
+            neutral_codon_probs *= mask[:, None]
+
+            neutral_codon_probs_l.append(neutral_codon_probs)
+
+        # Note that our masked out positions will have a nan log probability,
+        # which will require us to handle them correctly downstream.
+        self.log_neutral_codon_probss = torch.log(torch.stack(neutral_codon_probs_l))
+
+    def __getitem__(self, idx):
+        return {
+            "codon_parents_idxs": self.codon_parents_idxss[idx],
+            "codon_children_idxs": self.codon_children_idxss[idx],
+            "aa_parents_idxs": self.aa_parents_idxss[idx],
+            "aa_children_idxs": self.aa_children_idxss[idx],
+            "subs_indicator": self.aa_subs_indicators[idx],
+            "mask": self.masks[idx],
+            "log_neutral_codon_probs": self.log_neutral_codon_probss[idx],
+            "nt_rates": self.nt_ratess[idx],
+            "nt_csps": self.nt_cspss[idx],
+        }
+
+    def to(self, device):
+        self.aa_codon_indicator_matrix = self.aa_codon_indicator_matrix.to(device)
+        self.stop_codon_zapper = self.stop_codon_zapper.to(device)
+        self.codon_parents_idxss = self.codon_parents_idxss.to(device)
+        self.codon_children_idxss = self.codon_children_idxss.to(device)
+        self.aa_parents_idxss = self.aa_parents_idxss.to(device)
+        self.aa_children_idxss = self.aa_children_idxss.to(device)
+        self.aa_subs_indicators = self.aa_subs_indicators.to(device)
+        self.masks = self.masks.to(device)
+        self.log_neutral_codon_probss = self.log_neutral_codon_probss.to(device)
+        self.nt_ratess = self.nt_ratess.to(device)
+        self.nt_cspss = self.nt_cspss.to(device)
+        if self.multihit_model is not None:
+            self.multihit_model = self.multihit_model.to(device)
+
+
+class DCSMBurrito(DXSMBurrito):
+
+    model_type = "dcsm"
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.xent_loss = torch.nn.CrossEntropyLoss()
+        self.stop_codon_zapper = (build_stop_codon_indicator_tensor() * -BIG).to(
+            self.device
+        )
+        self.aa_codon_indicator_matrix = CODON_AA_INDICATOR_MATRIX.to(self.device).T
+
+    def prediction_pair_of_batch(self, batch):
+        """Get log neutral codon substitution probabilities and log selection factors
+        for a batch of data.
+
+        We don't mask on the output, which will thus contain junk in all of the masked
+        sites.
+        """
+        aa_parents_idxs = batch["aa_parents_idxs"].to(self.device)
+        mask = batch["mask"].to(self.device)
+        log_neutral_codon_probs = batch["log_neutral_codon_probs"].to(self.device)
+        if not torch.isfinite(log_neutral_codon_probs[mask]).all():
+            raise ValueError(
+                f"log_neutral_codon_probs has non-finite values at relevant positions: {log_neutral_codon_probs[mask]}"
+            )
+        log_selection_factors = self.selection_factors_of_aa_idxs(aa_parents_idxs, mask)
+        return log_neutral_codon_probs, log_selection_factors
+
+    def predictions_of_batch(self, batch):
+        """Make log probability predictions for a batch of data.
+
+        In this case they are log probabilities of codons, which are made to be
+        probabilities by setting the parent codon to 1 - sum(children).
+
+        We then clip the probabilities from below to avoid log(0) issues. So in
+        cases when the sum of the children is > 1, we don't get a normalized
+        probability distribution, but that won't crash the loss calculation,
+        because that step uses softmax.
+
+        Note that we make all ambiguous codons nan in the output, ensuring that
+        they must get properly masked downstream.
+        """
+        log_neutral_codon_probs, log_selection_factors = self.prediction_pair_of_batch(
+            batch
+        )
+
+        # This code block, in other burritos, is done in a separate function,
+        # but we can't do that here because we need to normalize the
+        # probabilities in a way that is not possible without having the index
+        # of the parent codon. Namely, we need to set the parent codon to 1 -
+        # sum(children).
+
+        # The aa_codon_indicator_matrix lifts things up from aa land to codon land.
+        log_preds = (
+            log_neutral_codon_probs
+            + log_selection_factors @ self.aa_codon_indicator_matrix
+            + self.stop_codon_zapper
+        )
+        assert torch.isnan(log_preds).sum() == 0
+
+        parent_indices = batch["codon_parents_idxs"].to(self.device)  # Shape: [B, L]
+        valid_mask = parent_indices != AMBIGUOUS_CODON_IDX  # Shape: [B, L]
+
+        # Convert to linear space so we can add probabilities.
+        preds = torch.exp(log_preds)
+
+        # Zero out the parent indices in preds, while keeping the computation
+        # graph intact.
+        preds_zeroer = torch.ones_like(preds)
+        preds_zeroer[valid_mask, parent_indices[valid_mask]] = 0.0
+        preds = preds * preds_zeroer
+
+        # Calculate the non-parent sum after zeroing out the parent indices.
+        non_parent_sum = preds[valid_mask, :].sum(dim=-1)
+
+        # Add these parent values back in, again keeping the computation graph intact.
+        preds_parent = torch.zeros_like(preds)
+        preds_parent[valid_mask, parent_indices[valid_mask]] = 1.0 - non_parent_sum
+        preds = preds + preds_parent
+
+        # We have to clamp the predictions to avoid log(0) issues.
+        preds = torch.clamp(preds, min=torch.finfo(preds.dtype).eps)
+
+        log_preds = torch.log(preds)
+
+        # Set ambiguous codons to nan to make sure that we handle them correctly downstream.
+        log_preds[~valid_mask, :] = float("nan")
+
+        return log_preds
+
+    def loss_of_batch(self, batch):
+        codon_children_idxs = batch["codon_children_idxs"].to(self.device)
+        mask = batch["mask"].to(self.device)
+
+        predictions = self.predictions_of_batch(batch)[mask]
+        assert torch.isnan(predictions).sum() == 0
+        codon_children_idxs = codon_children_idxs[mask]
+
+        return self.xent_loss(predictions, codon_children_idxs)
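The parent-codon renormalization in predictions_of_batch is easiest to see on a toy tensor. The standalone sketch below mirrors the zeroer/parent-fill steps above with made-up numbers, three codon states instead of 65, and no ambiguous sites:

    import torch

    preds = torch.tensor([[[0.2, 0.3, 0.1]], [[0.25, 0.25, 0.4]]])  # [B=2, L=1, C=3]
    parent = torch.tensor([[0], [2]])  # parent "codon" index at each site
    valid = torch.ones_like(parent, dtype=torch.bool)  # no ambiguous sites here

    # Zero the parent entry multiplicatively, keeping the computation graph intact.
    zeroer = torch.ones_like(preds)
    zeroer[valid, parent[valid]] = 0.0
    preds = preds * zeroer

    # Then set the parent entry to 1 - sum(children) additively.
    non_parent_sum = preds[valid, :].sum(dim=-1)
    parent_fill = torch.zeros_like(preds)
    parent_fill[valid, parent[valid]] = 1.0 - non_parent_sum
    preds = preds + parent_fill

    print(preds)  # [[[0.6, 0.3, 0.1]], [[0.25, 0.25, 0.5]]]; each site now sums to 1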
+ """ + aa_parents_idxs = batch["aa_parents_idxs"].to(self.device) + mask = batch["mask"].to(self.device) + log_neutral_codon_probs = batch["log_neutral_codon_probs"].to(self.device) + if not torch.isfinite(log_neutral_codon_probs[mask]).all(): + raise ValueError( + f"log_neutral_codon_probs has non-finite values at relevant positions: {log_neutral_codon_probs[mask]}" + ) + log_selection_factors = self.selection_factors_of_aa_idxs(aa_parents_idxs, mask) + return log_neutral_codon_probs, log_selection_factors + + def predictions_of_batch(self, batch): + """Make log probability predictions for a batch of data. + + In this case they are log probabilities of codons, which are made to be + probabilities by setting the parent codon to 1 - sum(children). + + After all this, we clip the probabilities below to avoid log(0) issues. + So, in cases when the sum of the children is > 1, we don't give a + normalized probability distribution, but that won't crash the loss + calculation because that step uses softmax. + + Note that make all ambiguous codons nan in the output, ensuring that + they must get properly masked downstream. + """ + log_neutral_codon_probs, log_selection_factors = self.prediction_pair_of_batch( + batch + ) + + # This code block, in other burritos, is done in a separate function, + # but we can't do that here because we need to normalize the + # probabilities in a way that is not possible without having the index + # of the parent codon. Namely, we need to set the parent codon to 1 - + # sum(children). + + # The aa_codon_indicator_matrix lifts things up from aa land to codon land. + log_preds = ( + log_neutral_codon_probs + + log_selection_factors @ self.aa_codon_indicator_matrix + + self.stop_codon_zapper + ) + assert torch.isnan(log_preds).sum() == 0 + + parent_indices = batch["codon_parents_idxs"].to(self.device) # Shape: [B, L] + valid_mask = parent_indices != AMBIGUOUS_CODON_IDX # Shape: [B, L] + + # Convert to linear space so we can add probabilities. + preds = torch.exp(log_preds) + + # Zero out the parent indices in preds, while keeping the computation + # graph intact. + preds_zeroer = torch.ones_like(preds) + preds_zeroer[valid_mask, parent_indices[valid_mask]] = 0.0 + preds = preds * preds_zeroer + + # Calculate the non-parent sum after zeroing out the parent indices. + non_parent_sum = preds[valid_mask, :].sum(dim=-1) + + # Add these parent values back in, again keeping the computation graph intact. + preds_parent = torch.zeros_like(preds) + preds_parent[valid_mask, parent_indices[valid_mask]] = 1.0 - non_parent_sum + preds = preds + preds_parent + + # We have to clamp the predictions to avoid log(0) issues. + preds = torch.clamp(preds, min=torch.finfo(preds.dtype).eps) + + log_preds = torch.log(preds) + + # Set ambiguous codons to nan to make sure that we handle them correctly downstream. 
diff --git a/netam/dxsm.py b/netam/dxsm.py
index 58f13c18..512d3922 100644
--- a/netam/dxsm.py
+++ b/netam/dxsm.py
@@ -19,6 +19,7 @@
     stack_heterogeneous,
     codon_mask_tensor_of,
     assert_pcp_valid,
+    BIG,
 )
 import netam.framework as framework
 import netam.molevol as molevol
@@ -78,6 +79,7 @@ def __init__(
         assert self.masks.shape[1] * 3 == self.nt_cspss.shape[1]
         assert torch.all(self.masks.sum(dim=1) > 0)
         assert torch.max(self.aa_parents_idxss) <= MAX_AA_TOKEN_IDX
+        assert torch.max(self.aa_children_idxss) <= MAX_AA_TOKEN_IDX
         self._branch_lengths = branch_lengths
         self.update_neutral_probs()
 
@@ -421,6 +423,23 @@ def to_crepe(self):
         encoder = framework.PlaceholderEncoder()
         return framework.Crepe(encoder, self.model, training_hyperparameters)
 
+    # This is overridden in DNSMBurrito
+    def build_selection_matrix_from_parent_aa(
+        self, aa_parent_idxs: torch.Tensor, mask: torch.Tensor
+    ):
+        """Build a selection matrix from a single parent amino acid sequence. Inputs are
+        expected to be as prepared in the Dataset constructor.
+
+        Values at ambiguous sites are meaningless.
+        """
+        with torch.no_grad():
+            per_aa_selection_factors = self.selection_factors_of_aa_idxs(
+                aa_parent_idxs.unsqueeze(0), mask.unsqueeze(0)
+            ).exp()
+        return zap_predictions_along_diagonal(
+            per_aa_selection_factors, aa_parent_idxs.unsqueeze(0), fill=1.0
+        ).squeeze(0)
+
     @abstractmethod
     def loss_of_batch(self, batch):
         pass
@@ -430,3 +449,25 @@ def worker_optimize_branch_length(burrito_class, model, dataset, optimization_kw
     """The worker used for parallel branch length optimization."""
     burrito = burrito_class(None, dataset, copy.deepcopy(model))
     return burrito.serial_find_optimal_branch_lengths(dataset, **optimization_kwargs)
+
+
+def zap_predictions_along_diagonal(predictions, aa_parents_idxs, fill=-BIG):
+    """Set the diagonal (i.e. no amino acid change) of the predictions tensor to -BIG,
+    except where aa_parents_idxs >= 20, which indicates no update should be done."""
+
+    device = predictions.device
+    batch_size, L, _ = predictions.shape
+    batch_indices = torch.arange(batch_size, device=device)[:, None].expand(-1, L)
+    sequence_indices = torch.arange(L, device=device)[None, :].expand(batch_size, -1)
+
+    # Create a mask for valid positions (where aa_parents_idxs is less than 20)
+    valid_mask = aa_parents_idxs < 20
+
+    # Only update the predictions for valid positions
+    predictions[
+        batch_indices[valid_mask],
+        sequence_indices[valid_mask],
+        aa_parents_idxs[valid_mask],
+    ] = fill
+
+    return predictions
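A small sanity sketch of zap_predictions_along_diagonal now that it lives in dxsm.py (note that the function mutates its argument in place and also returns it):

    import torch
    from netam.dxsm import zap_predictions_along_diagonal

    preds = torch.zeros(2, 2, 21)
    aa_parents = torch.tensor([[3, 20], [0, 5]])  # 20 marks an ambiguous site

    zap_predictions_along_diagonal(preds, aa_parents, fill=1.0)

    assert preds[0, 0, 3] == 1.0 and preds[1, 0, 0] == 1.0 and preds[1, 1, 5] == 1.0
    assert preds[0, 1].sum() == 0.0  # the ambiguous site is left untouched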
+ """ + with torch.no_grad(): + per_aa_selection_factors = self.selection_factors_of_aa_idxs( + aa_parent_idxs.unsqueeze(0), mask.unsqueeze(0) + ).exp() + return zap_predictions_along_diagonal( + per_aa_selection_factors, aa_parent_idxs.unsqueeze(0), fill=1.0 + ).squeeze(0) + @abstractmethod def loss_of_batch(self, batch): pass @@ -430,3 +449,25 @@ def worker_optimize_branch_length(burrito_class, model, dataset, optimization_kw """The worker used for parallel branch length optimization.""" burrito = burrito_class(None, dataset, copy.deepcopy(model)) return burrito.serial_find_optimal_branch_lengths(dataset, **optimization_kwargs) + + +def zap_predictions_along_diagonal(predictions, aa_parents_idxs, fill=-BIG): + """Set the diagonal (i.e. no amino acid change) of the predictions tensor to -BIG, + except where aa_parents_idxs >= 20, which indicates no update should be done.""" + + device = predictions.device + batch_size, L, _ = predictions.shape + batch_indices = torch.arange(batch_size, device=device)[:, None].expand(-1, L) + sequence_indices = torch.arange(L, device=device)[None, :].expand(batch_size, -1) + + # Create a mask for valid positions (where aa_parents_idxs is less than 20) + valid_mask = aa_parents_idxs < 20 + + # Only update the predictions for valid positions + predictions[ + batch_indices[valid_mask], + sequence_indices[valid_mask], + aa_parents_idxs[valid_mask], + ] = fill + + return predictions diff --git a/netam/framework.py b/netam/framework.py index 1581e035..e0930f16 100644 --- a/netam/framework.py +++ b/netam/framework.py @@ -599,6 +599,9 @@ def process_data_loader(self, data_loader, train_mode=False, loss_reduction=None self.optimizer.zero_grad() scalar_loss.backward() + if torch.isnan(scalar_loss): + raise ValueError(f"NaN in loss: {scalar_loss.item()}") + nan_in_gradients = False for name, param in self.model.named_parameters(): if torch.isnan(param).any(): diff --git a/netam/molevol.py b/netam/molevol.py index 2aef1c10..8c03909b 100644 --- a/netam/molevol.py +++ b/netam/molevol.py @@ -339,14 +339,14 @@ def build_codon_mutsel( return codon_mutsel, sums_too_big -def neutral_aa_probs( +def neutral_codon_probs( parent_codon_idxs: Tensor, codon_mut_probs: Tensor, codon_csps: Tensor, multihit_model=None, ) -> Tensor: - """For every site, what is the probability that the amino acid will mutate to every - amino acid? + """For every site, what is the probability that the site will mutate to every + alternate codon? Args: parent_codon_idxs (torch.Tensor): The parent codons for each sequence. Shape: (codon_count, 3) @@ -354,8 +354,8 @@ def neutral_aa_probs( codon_csps (torch.Tensor): The substitution probabilities for each site in each codon. Shape: (codon_count, 3, 4) Returns: - torch.Tensor: The probability that each site will change to each amino acid. - Shape: (codon_count, 20) + torch.Tensor: The probability that each site will change to each codon. + Shape: (codon_count, 64) """ mut_matrices = build_mutation_matrices( @@ -366,8 +366,36 @@ def neutral_aa_probs( if multihit_model is not None: codon_probs = multihit_model(parent_codon_idxs, codon_probs) + return codon_probs.view(-1, 64) + + +def neutral_aa_probs( + parent_codon_idxs: Tensor, + codon_mut_probs: Tensor, + codon_csps: Tensor, + multihit_model=None, +) -> Tensor: + """For every site, what is the probability that the site will mutate to every + alternate amino acid? + + Args: + parent_codon_idxs (torch.Tensor): The parent codons for each sequence. 
diff --git a/netam/sequences.py b/netam/sequences.py
index 1d574279..76d48219 100644
--- a/netam/sequences.py
+++ b/netam/sequences.py
@@ -30,6 +30,9 @@
 MAX_AA_TOKEN_IDX = MAX_KNOWN_TOKEN_COUNT - 1
 CODONS = ["".join(codon_list) for codon_list in itertools.product(BASES, repeat=3)]
 STOP_CODONS = ["TAA", "TAG", "TGA"]
+AMBIGUOUS_CODON_IDX = len(CODONS)
+
+
 # Each token in RESERVED_TOKENS will appear once in aa strings, and three times
 # in nt strings.
 RESERVED_TOKEN_TRANSLATIONS = {token * 3: token for token in RESERVED_TOKENS}
@@ -117,6 +120,14 @@ def dataset_inputs_of_pcp_df(pcp_df, known_token_count):
     )
 
 
+def build_stop_codon_indicator_tensor():
+    """Return a tensor indicating the stop codons."""
+    stop_codon_indicator = torch.zeros(len(CODONS))
+    for stop_codon in STOP_CODONS:
+        stop_codon_indicator[CODONS.index(stop_codon)] = 1.0
+    return stop_codon_indicator
+
+
 def nt_idx_array_of_str(nt_str):
     """Return the indices of the nucleotides in a string."""
     try:
@@ -153,6 +164,21 @@ def aa_idx_tensor_of_str(aa_str):
         raise
 
 
+def idx_of_codon_allowing_ambiguous(codon):
+    if "N" in codon:
+        return AMBIGUOUS_CODON_IDX
+    else:
+        return CODONS.index(codon)
+
+
+def codon_idx_tensor_of_str_ambig(nt_str):
+    """Return the indices of the codons in a string."""
+    assert len(nt_str) % 3 == 0
+    return torch.tensor(
+        [idx_of_codon_allowing_ambiguous(codon) for codon in iter_codons(nt_str)]
+    )
+
+
 def aa_onehot_tensor_of_str(aa_str):
     aa_onehot = torch.zeros((len(aa_str), 20))
     aa_indices_parent = aa_idx_array_of_str(aa_str)
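Under the hood, a codon's index is just its base-4 value over the BASES ordering used to build CODONS (assumed here to be "A", "C", "G", "T", which the TTG = 62 and TTT = 63 values in test_codon_idx_tensor_of_str below also rely on), with any N-containing codon collapsed to AMBIGUOUS_CODON_IDX. A small sketch:

    import torch
    from netam.sequences import (
        CODONS,
        STOP_CODONS,
        build_stop_codon_indicator_tensor,
        codon_idx_tensor_of_str_ambig,
    )

    # AAA -> 0, TTT -> 63, and an N-containing codon maps to 64.
    print(codon_idx_tensor_of_str_ambig("AAATTTNTT"))  # tensor([ 0, 63, 64])

    # The stop-codon indicator is hot exactly at TAA (48), TAG (50), and TGA (56).
    indicator = build_stop_codon_indicator_tensor()
    assert [CODONS[i] for i in torch.nonzero(indicator).flatten()] == STOP_CODONS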
diff --git a/tests/test_dcsm.py b/tests/test_dcsm.py
new file mode 100644
index 00000000..4a3abbec
--- /dev/null
+++ b/tests/test_dcsm.py
@@ -0,0 +1,49 @@
+import torch
+import pytest
+
+from netam.common import force_spawn
+from netam.sequences import MAX_KNOWN_TOKEN_COUNT
+from netam.models import TransformerBinarySelectionModelWiggleAct
+from netam.dcsm import (
+    DCSMBurrito,
+    DCSMDataset,
+)
+
+
+@pytest.fixture(scope="module")
+def dcsm_burrito(pcp_df):
+    """Fixture that returns the DCSM Burrito object."""
+    force_spawn()
+    pcp_df["in_train"] = True
+    pcp_df.loc[pcp_df.index[-15:], "in_train"] = False
+    train_dataset, val_dataset = DCSMDataset.train_val_datasets_of_pcp_df(
+        pcp_df, MAX_KNOWN_TOKEN_COUNT
+    )
+
+    model = TransformerBinarySelectionModelWiggleAct(
+        nhead=2,
+        d_model_per_head=4,
+        dim_feedforward=256,
+        layer_count=2,
+        output_dim=20,
+    )
+
+    burrito = DCSMBurrito(
+        train_dataset,
+        val_dataset,
+        model,
+        batch_size=32,
+        learning_rate=0.001,
+        min_learning_rate=0.0001,
+    )
+    burrito.joint_train(
+        epochs=1, cycle_count=2, training_method="full", optimize_bl_first_cycle=False
+    )
+    return burrito
+
+
+def test_parallel_branch_length_optimization(dcsm_burrito):
+    dataset = dcsm_burrito.val_dataset
+    parallel_branch_lengths = dcsm_burrito.find_optimal_branch_lengths(dataset)
+    branch_lengths = dcsm_burrito.serial_find_optimal_branch_lengths(dataset)
+    assert torch.allclose(branch_lengths, parallel_branch_lengths)
diff --git a/tests/test_sequences.py b/tests/test_sequences.py
index 17ce83d1..a7df4dfd 100644
--- a/tests/test_sequences.py
+++ b/tests/test_sequences.py
@@ -13,7 +13,9 @@
     CODON_AA_INDICATOR_MATRIX,
     MAX_KNOWN_TOKEN_COUNT,
     AA_AMBIG_IDX,
+    AMBIGUOUS_CODON_IDX,
     aa_onehot_tensor_of_str,
+    codon_idx_tensor_of_str_ambig,
     nt_idx_array_of_str,
     nt_subs_indicator_tensor_of,
     translate_sequences,
@@ -94,6 +96,13 @@ def test_nucleotide_indices_of_codon():
     assert nt_idx_array_of_str("GCG").tolist() == [2, 1, 2]
 
 
+def test_codon_idx_tensor_of_str():
+    nt_str = "AAAAACTTGTTTNTT"
+    expected_output = torch.tensor([0, 1, 62, 63, AMBIGUOUS_CODON_IDX])
+    output = codon_idx_tensor_of_str_ambig(nt_str)
+    assert torch.equal(output, expected_output)
+
+
 def test_aa_onehot_tensor_of_str():
     aa_str = "QY"
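Finally, to make the DCSM loss step concrete: loss_of_batch feeds masked per-site codon log-probabilities straight into CrossEntropyLoss, which renormalizes via softmax, so an unnormalized row (see the predictions_of_batch docstring) is harmless. A toy version with made-up shapes:

    import torch

    sites, n_codons = 7, 64
    log_preds = torch.log_softmax(torch.randn(sites, n_codons), dim=-1)
    child_codons = torch.randint(0, n_codons, (sites,))
    loss = torch.nn.CrossEntropyLoss()(log_preds, child_codons)
    print(float(loss))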