From 969b2fcf25bcb4888310214cb23ecc40f7dbb436 Mon Sep 17 00:00:00 2001 From: Erick Matsen Date: Wed, 30 Oct 2024 03:43:48 -0700 Subject: [PATCH 1/5] replacing a normalization with a check --- netam/dasm.py | 7 +++---- netam/dnsm.py | 7 +++---- netam/molevol.py | 40 ++++++++++++++-------------------------- tests/test_molevol.py | 17 ++++++++--------- 4 files changed, 28 insertions(+), 43 deletions(-) diff --git a/netam/dasm.py b/netam/dasm.py index 8ea6f34b..18182fdc 100644 --- a/netam/dasm.py +++ b/netam/dasm.py @@ -41,14 +41,13 @@ def update_neutral_probs(self): parent_len = len(nt_parent) mut_probs = 1.0 - torch.exp(-branch_length * nt_rates[:parent_len]) - normed_nt_csps = molevol.normalize_sub_probs( - parent_idxs, nt_csps[:parent_len, :] - ) + nt_csps = nt_csps[:parent_len, :] + molevol.check_csps(parent_idxs, nt_csps) neutral_aa_probs = molevol.neutral_aa_probs( parent_idxs.reshape(-1, 3), mut_probs.reshape(-1, 3), - normed_nt_csps.reshape(-1, 3, 4), + nt_csps.reshape(-1, 3, 4), ) if not torch.isfinite(neutral_aa_probs).all(): diff --git a/netam/dnsm.py b/netam/dnsm.py index 29e90498..8ea3461a 100644 --- a/netam/dnsm.py +++ b/netam/dnsm.py @@ -260,16 +260,15 @@ def update_neutral_probs(self): # with masking out these positions later. We do this below. parent_idxs = sequences.nt_idx_tensor_of_str(nt_parent.replace("N", "A")) parent_len = len(nt_parent) + molevol.check_csps(parent_idxs, nt_csps) mut_probs = 1.0 - torch.exp(-branch_length * nt_rates[:parent_len]) - normed_nt_csps = molevol.normalize_sub_probs( - parent_idxs, nt_csps[:parent_len, :] - ) + nt_csps = nt_csps[:parent_len, :] neutral_aa_mut_prob = molevol.neutral_aa_mut_probs( parent_idxs.reshape(-1, 3), mut_probs.reshape(-1, 3), - normed_nt_csps.reshape(-1, 3, 4), + nt_csps.reshape(-1, 3, 4), multihit_model=multihit_model, ) diff --git a/netam/molevol.py b/netam/molevol.py index a8994453..2656cf0b 100644 --- a/netam/molevol.py +++ b/netam/molevol.py @@ -1,12 +1,7 @@ """Free functions for molecular evolution computation. -We will follow terminology from Yaari et al 2013, where "mutability" refers to the -probability of a nucleotide mutating at a given site, while "substitution" refers to the -probability of a nucleotide mutating to another nucleotide at a given site conditional -on having a mutation. - -We assume that the mutation and substitution probabilities already take branch length -into account. +CSP means conditional substitution probability. CSPs are the probabilities of +alternate states conditioned on there being a substitution. """ import numpy as np @@ -19,41 +14,34 @@ import netam.sequences as sequences -def normalize_sub_probs(parent_idxs: Tensor, sub_probs: Tensor) -> Tensor: - """Normalize substitution probabilities. - - Given a parent DNA sequence and a 2D PyTorch tensor representing substitution - probabilities, this function sets the probability of the actual nucleotide - in the parent sequence to zero and then normalizes each row to form a valid - probability distribution. +def check_csps(parent_idxs: Tensor, csps: Tensor) -> Tensor: + """Make sure that the CSPs are valid, i.e. that they are a probability + distribution and the parent state is zero. Args: parent_idxs (torch.Tensor): The parent sequence indices. sub_probs (torch.Tensor): A 2D PyTorch tensor representing substitution probabilities. Rows correspond to sites, and columns correspond - to "ACGT" bases. - - Returns: - torch.Tensor: A 2D PyTorch tensor with normalized substitution probabilities. + to states (e.g. nucleotides). """ # Assert that sub_probs are within the range [0, 1] modulo rounding error assert torch.all( - sub_probs >= -1e-6 + csps >= -1e-6 ), "Substitution probabilities must be non-negative" assert torch.all( - sub_probs <= 1 + 1e-6 + csps <= 1 + 1e-6 ), "Substitution probabilities must be less than or equal to 1" # Create an array of row indices that matches the shape of `parent_idxs`. row_indices = torch.arange(len(parent_idxs)) - # Set the entries corresponding to the parent sequence to zero. - sub_probs[row_indices, parent_idxs] = 0.0 - - # Normalize the probabilities. - row_sums = torch.sum(sub_probs, dim=1, keepdim=True) - return sub_probs / row_sums + # Assert that the parent entry has a substitution probability of nearly 0. + assert torch.all( + csps[row_indices, parent_idxs] < 1e-6 + ), "Parent entry must have a substitution probability of nearly 0" + assert torch.allclose( + csps[:len(parent_idxs)].sum(dim=1), torch.ones(len(parent_idxs))) def build_mutation_matrices( diff --git a/tests/test_molevol.py b/tests/test_molevol.py index 0876be1e..d967622a 100644 --- a/tests/test_molevol.py +++ b/tests/test_molevol.py @@ -1,4 +1,6 @@ import torch +import pytest + import netam.molevol as molevol from netam import framework @@ -99,19 +101,16 @@ def test_neutral_aa_mut_probs(): assert torch.allclose(correct_tensor, computed_tensor) -def test_normalize_sub_probs(): +def test_check_csps(): parent_idxs = nt_idx_tensor_of_str("AC") - sub_probs = torch.tensor([[0.2, 0.3, 0.4, 0.1], [0.1, 0.2, 0.3, 0.4]]) - - expected_normalized = torch.tensor( + csp = torch.tensor( [[0.0, 0.375, 0.5, 0.125], [0.125, 0.0, 0.375, 0.5]] ) - normalized_sub_probs = molevol.normalize_sub_probs(parent_idxs, sub_probs) + molevol.check_csps(parent_idxs, csp) - assert normalized_sub_probs.shape == (2, 4), "Result has incorrect shape" - assert torch.allclose( - normalized_sub_probs, expected_normalized - ), "Unexpected normalized values" + not_csp = torch.tensor([[0.2, 0.3, 0.4, 0.1], [0.1, 0.2, 0.3, 0.4]]) + with pytest.raises(AssertionError): + molevol.check_csps(parent_idxs, not_csp) def iterative_aaprob_of_mut_and_sub(parent_codon, mut_probs, csps): From adef61a90f925e916a794d89bdee481726dc484f Mon Sep 17 00:00:00 2001 From: Erick Matsen Date: Wed, 30 Oct 2024 04:30:34 -0700 Subject: [PATCH 2/5] more renaming :eyeroll: --- netam/dasm.py | 28 ++++++++++++++-------------- netam/dnsm.py | 46 +++++++++++++++++++++++----------------------- 2 files changed, 37 insertions(+), 37 deletions(-) diff --git a/netam/dasm.py b/netam/dasm.py index 18182fdc..09d20dba 100644 --- a/netam/dasm.py +++ b/netam/dasm.py @@ -25,12 +25,12 @@ class DASMDataset(dnsm.DNSMDataset): def update_neutral_probs(self): neutral_aa_probs_l = [] - for nt_parent, mask, nt_rates, branch_length, nt_csps in zip( + for nt_parent, mask, nt_rates, nt_csps, branch_length in zip( self.nt_parents, - self.mask, + self.masks, self.nt_ratess, - self._branch_lengths, self.nt_cspss, + self._branch_lengths, ): mask = mask.to("cpu") nt_rates = nt_rates.to("cpu") @@ -74,25 +74,25 @@ def update_neutral_probs(self): # Note that our masked out positions will have a nan log probability, # which will require us to handle them correctly downstream. - self.log_neutral_aa_probs = torch.log(torch.stack(neutral_aa_probs_l)) + self.log_neutral_aa_probss = torch.log(torch.stack(neutral_aa_probs_l)) def __getitem__(self, idx): return { - "aa_parents_idxs": self.aa_parents_idxs[idx], - "aa_children_idxs": self.aa_children_idxs[idx], - "subs_indicator": self.aa_subs_indicator_tensor[idx], - "mask": self.mask[idx], - "log_neutral_aa_probs": self.log_neutral_aa_probs[idx], + "aa_parents_idxs": self.aa_parents_idxss[idx], + "aa_children_idxs": self.aa_children_idxss[idx], + "subs_indicator": self.aa_subs_indicators[idx], + "mask": self.masks[idx], + "log_neutral_aa_probs": self.log_neutral_aa_probss[idx], "nt_rates": self.nt_ratess[idx], "nt_csps": self.nt_cspss[idx], } def to(self, device): - self.aa_parents_idxs = self.aa_parents_idxs.to(device) - self.aa_children_idxs = self.aa_children_idxs.to(device) - self.aa_subs_indicator_tensor = self.aa_subs_indicator_tensor.to(device) - self.mask = self.mask.to(device) - self.log_neutral_aa_probs = self.log_neutral_aa_probs.to(device) + self.aa_parents_idxss = self.aa_parents_idxss.to(device) + self.aa_children_idxss = self.aa_children_idxss.to(device) + self.aa_subs_indicators = self.aa_subs_indicators.to(device) + self.masks = self.masks.to(device) + self.log_neutral_aa_probss = self.log_neutral_aa_probss.to(device) self.nt_ratess = self.nt_ratess.to(device) self.nt_cspss = self.nt_cspss.to(device) if self.multihit_model is not None: diff --git a/netam/dnsm.py b/netam/dnsm.py index 8ea3461a..a55985af 100644 --- a/netam/dnsm.py +++ b/netam/dnsm.py @@ -70,25 +70,25 @@ def __init__( # We have sequences of varying length, so we start with all tensors set # to the ambiguous amino acid, and then will fill in the actual values # below. - self.aa_parents_idxs = torch.full( + self.aa_parents_idxss = torch.full( (pcp_count, self.max_aa_seq_len), MAX_AMBIG_AA_IDX ) - self.aa_children_idxs = self.aa_parents_idxs.clone() - self.aa_subs_indicator_tensor = torch.zeros((pcp_count, self.max_aa_seq_len)) + self.aa_children_idxss = self.aa_parents_idxss.clone() + self.aa_subs_indicators = torch.zeros((pcp_count, self.max_aa_seq_len)) - self.mask = torch.ones((pcp_count, self.max_aa_seq_len), dtype=torch.bool) + self.masks = torch.ones((pcp_count, self.max_aa_seq_len), dtype=torch.bool) for i, (aa_parent, aa_child) in enumerate(zip(aa_parents, aa_children)): - self.mask[i, :] = aa_mask_tensor_of(aa_parent, self.max_aa_seq_len) + self.masks[i, :] = aa_mask_tensor_of(aa_parent, self.max_aa_seq_len) aa_seq_len = len(aa_parent) - self.aa_parents_idxs[i, :aa_seq_len] = aa_idx_tensor_of_str_ambig(aa_parent) - self.aa_children_idxs[i, :aa_seq_len] = aa_idx_tensor_of_str_ambig(aa_child) - self.aa_subs_indicator_tensor[i, :aa_seq_len] = aa_subs_indicator_tensor_of( + self.aa_parents_idxss[i, :aa_seq_len] = aa_idx_tensor_of_str_ambig(aa_parent) + self.aa_children_idxss[i, :aa_seq_len] = aa_idx_tensor_of_str_ambig(aa_child) + self.aa_subs_indicators[i, :aa_seq_len] = aa_subs_indicator_tensor_of( aa_parent, aa_child ) - assert torch.all(self.mask.sum(dim=1) > 0) - assert torch.max(self.aa_parents_idxs) <= MAX_AMBIG_AA_IDX + assert torch.all(self.masks.sum(dim=1) > 0) + assert torch.max(self.aa_parents_idxss) <= MAX_AMBIG_AA_IDX self._branch_lengths = branch_lengths self.update_neutral_probs() @@ -242,12 +242,12 @@ def update_neutral_probs(self): """ neutral_aa_mut_prob_l = [] - for nt_parent, mask, nt_rates, branch_length, nt_csps in zip( + for nt_parent, mask, nt_rates, nt_csps, branch_length in zip( self.nt_parents, - self.mask, + self.masks, self.nt_ratess, - self._branch_lengths, self.nt_cspss, + self._branch_lengths, ): mask = mask.to("cpu") nt_rates = nt_rates.to("cpu") @@ -298,26 +298,26 @@ def update_neutral_probs(self): # Note that our masked out positions will have a nan log probability, # which will require us to handle them correctly downstream. - self.log_neutral_aa_mut_probs = torch.log(torch.stack(neutral_aa_mut_prob_l)) + self.aa_neutral_aa_mut_probss = torch.log(torch.stack(neutral_aa_mut_prob_l)) def __len__(self): - return len(self.aa_parents_idxs) + return len(self.aa_parents_idxss) def __getitem__(self, idx): return { - "aa_parents_idxs": self.aa_parents_idxs[idx], - "subs_indicator": self.aa_subs_indicator_tensor[idx], - "mask": self.mask[idx], - "log_neutral_aa_mut_probs": self.log_neutral_aa_mut_probs[idx], + "aa_parents_idxs": self.aa_parents_idxss[idx], + "subs_indicator": self.aa_subs_indicators[idx], + "mask": self.masks[idx], + "log_neutral_aa_mut_probs": self.aa_neutral_aa_mut_probss[idx], "nt_rates": self.nt_ratess[idx], "nt_csps": self.nt_cspss[idx], } def to(self, device): - self.aa_parents_idxs = self.aa_parents_idxs.to(device) - self.aa_subs_indicator_tensor = self.aa_subs_indicator_tensor.to(device) - self.mask = self.mask.to(device) - self.log_neutral_aa_mut_probs = self.log_neutral_aa_mut_probs.to(device) + self.aa_parents_idxss = self.aa_parents_idxss.to(device) + self.aa_subs_indicators = self.aa_subs_indicators.to(device) + self.masks = self.masks.to(device) + self.aa_neutral_aa_mut_probss = self.aa_neutral_aa_mut_probss.to(device) self.nt_ratess = self.nt_ratess.to(device) self.nt_cspss = self.nt_cspss.to(device) if self.multihit_model is not None: From 41bdfce968cb618c92507f449cf2baa0175fd7d9 Mon Sep 17 00:00:00 2001 From: Erick Matsen Date: Wed, 30 Oct 2024 04:32:08 -0700 Subject: [PATCH 3/5] make format --- netam/dnsm.py | 8 ++++++-- netam/molevol.py | 15 +++++++-------- tests/test_molevol.py | 4 +--- 3 files changed, 14 insertions(+), 13 deletions(-) diff --git a/netam/dnsm.py b/netam/dnsm.py index a55985af..c36082a4 100644 --- a/netam/dnsm.py +++ b/netam/dnsm.py @@ -81,8 +81,12 @@ def __init__( for i, (aa_parent, aa_child) in enumerate(zip(aa_parents, aa_children)): self.masks[i, :] = aa_mask_tensor_of(aa_parent, self.max_aa_seq_len) aa_seq_len = len(aa_parent) - self.aa_parents_idxss[i, :aa_seq_len] = aa_idx_tensor_of_str_ambig(aa_parent) - self.aa_children_idxss[i, :aa_seq_len] = aa_idx_tensor_of_str_ambig(aa_child) + self.aa_parents_idxss[i, :aa_seq_len] = aa_idx_tensor_of_str_ambig( + aa_parent + ) + self.aa_children_idxss[i, :aa_seq_len] = aa_idx_tensor_of_str_ambig( + aa_child + ) self.aa_subs_indicators[i, :aa_seq_len] = aa_subs_indicator_tensor_of( aa_parent, aa_child ) diff --git a/netam/molevol.py b/netam/molevol.py index 2656cf0b..2aef1c10 100644 --- a/netam/molevol.py +++ b/netam/molevol.py @@ -1,7 +1,7 @@ """Free functions for molecular evolution computation. -CSP means conditional substitution probability. CSPs are the probabilities of -alternate states conditioned on there being a substitution. +CSP means conditional substitution probability. CSPs are the probabilities of alternate +states conditioned on there being a substitution. """ import numpy as np @@ -15,8 +15,8 @@ def check_csps(parent_idxs: Tensor, csps: Tensor) -> Tensor: - """Make sure that the CSPs are valid, i.e. that they are a probability - distribution and the parent state is zero. + """Make sure that the CSPs are valid, i.e. that they are a probability distribution + and the parent state is zero. Args: parent_idxs (torch.Tensor): The parent sequence indices. @@ -26,9 +26,7 @@ def check_csps(parent_idxs: Tensor, csps: Tensor) -> Tensor: """ # Assert that sub_probs are within the range [0, 1] modulo rounding error - assert torch.all( - csps >= -1e-6 - ), "Substitution probabilities must be non-negative" + assert torch.all(csps >= -1e-6), "Substitution probabilities must be non-negative" assert torch.all( csps <= 1 + 1e-6 ), "Substitution probabilities must be less than or equal to 1" @@ -41,7 +39,8 @@ def check_csps(parent_idxs: Tensor, csps: Tensor) -> Tensor: csps[row_indices, parent_idxs] < 1e-6 ), "Parent entry must have a substitution probability of nearly 0" assert torch.allclose( - csps[:len(parent_idxs)].sum(dim=1), torch.ones(len(parent_idxs))) + csps[: len(parent_idxs)].sum(dim=1), torch.ones(len(parent_idxs)) + ) def build_mutation_matrices( diff --git a/tests/test_molevol.py b/tests/test_molevol.py index d967622a..62c21777 100644 --- a/tests/test_molevol.py +++ b/tests/test_molevol.py @@ -103,9 +103,7 @@ def test_neutral_aa_mut_probs(): def test_check_csps(): parent_idxs = nt_idx_tensor_of_str("AC") - csp = torch.tensor( - [[0.0, 0.375, 0.5, 0.125], [0.125, 0.0, 0.375, 0.5]] - ) + csp = torch.tensor([[0.0, 0.375, 0.5, 0.125], [0.125, 0.0, 0.375, 0.5]]) molevol.check_csps(parent_idxs, csp) not_csp = torch.tensor([[0.2, 0.3, 0.4, 0.1], [0.1, 0.2, 0.3, 0.4]]) From e322bf8b88c00495758fa14a873b8882417c0fc5 Mon Sep 17 00:00:00 2001 From: Erick Matsen Date: Wed, 30 Oct 2024 12:09:16 -0700 Subject: [PATCH 4/5] fix --- netam/dnsm.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/netam/dnsm.py b/netam/dnsm.py index c36082a4..448e255c 100644 --- a/netam/dnsm.py +++ b/netam/dnsm.py @@ -302,7 +302,7 @@ def update_neutral_probs(self): # Note that our masked out positions will have a nan log probability, # which will require us to handle them correctly downstream. - self.aa_neutral_aa_mut_probss = torch.log(torch.stack(neutral_aa_mut_prob_l)) + self.log_neutral_aa_mut_probss = torch.log(torch.stack(neutral_aa_mut_prob_l)) def __len__(self): return len(self.aa_parents_idxss) @@ -312,7 +312,7 @@ def __getitem__(self, idx): "aa_parents_idxs": self.aa_parents_idxss[idx], "subs_indicator": self.aa_subs_indicators[idx], "mask": self.masks[idx], - "log_neutral_aa_mut_probs": self.aa_neutral_aa_mut_probss[idx], + "log_neutral_aa_mut_probs": self.log_neutral_aa_mut_probss[idx], "nt_rates": self.nt_ratess[idx], "nt_csps": self.nt_cspss[idx], } @@ -321,7 +321,7 @@ def to(self, device): self.aa_parents_idxss = self.aa_parents_idxss.to(device) self.aa_subs_indicators = self.aa_subs_indicators.to(device) self.masks = self.masks.to(device) - self.aa_neutral_aa_mut_probss = self.aa_neutral_aa_mut_probss.to(device) + self.log_neutral_aa_mut_probss = self.log_neutral_aa_mut_probss.to(device) self.nt_ratess = self.nt_ratess.to(device) self.nt_cspss = self.nt_cspss.to(device) if self.multihit_model is not None: From 427f3a16cccc726290c7a90af879a11c986ac5a5 Mon Sep 17 00:00:00 2001 From: Erick Matsen Date: Wed, 30 Oct 2024 12:10:50 -0700 Subject: [PATCH 5/5] consistentize --- netam/dnsm.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/netam/dnsm.py b/netam/dnsm.py index 448e255c..aabdd466 100644 --- a/netam/dnsm.py +++ b/netam/dnsm.py @@ -310,7 +310,7 @@ def __len__(self): def __getitem__(self, idx): return { "aa_parents_idxs": self.aa_parents_idxss[idx], - "subs_indicator": self.aa_subs_indicators[idx], + "aa_subs_indicator": self.aa_subs_indicators[idx], "mask": self.masks[idx], "log_neutral_aa_mut_probs": self.log_neutral_aa_mut_probss[idx], "nt_rates": self.nt_ratess[idx], @@ -372,7 +372,7 @@ def predictions_of_batch(self, batch): return self.predictions_of_pair(log_neutral_aa_mut_probs, log_selection_factors) def loss_of_batch(self, batch): - aa_subs_indicator = batch["subs_indicator"].to(self.device) + aa_subs_indicator = batch["aa_subs_indicator"].to(self.device) mask = batch["mask"].to(self.device) aa_subs_indicator = aa_subs_indicator.masked_select(mask) predictions = self.predictions_of_batch(batch).masked_select(mask)