From dda468ce239233e27606a2bc1a1af415e9ae188c Mon Sep 17 00:00:00 2001
From: Erick Matsen
Date: Tue, 10 Dec 2024 08:39:15 -0800
Subject: [PATCH] a sketch

Squashed WIP history:
- WIP
- WIP; next: fix nonzero subs prob for ^
- WIP
- WIP; remove joined_mode, make everything be joined_mode always
- EOD commit: working on recovering OE plotting with changes
- the final solution?
- partial cleanup
- some cleanup
- more cleanup
---
 netam/common.py         | 31 ++++++++++++-----
 netam/dasm.py           |  9 +++--
 netam/dxsm.py           | 25 ++++++++------
 netam/framework.py      | 75 +++++++++++++++++++++++++++++++++++----
 netam/models.py         |  4 +--
 netam/molevol.py        |  4 +--
 netam/sequences.py      | 72 ++++++++++++++++++++++++++-----------
 tests/conftest.py       |  1 +
 tests/test_dnsm.py      |  5 +--
 tests/test_sequences.py | 16 +++++++++
 10 files changed, 189 insertions(+), 53 deletions(-)

diff --git a/netam/common.py b/netam/common.py
index 7a1970df..a5799f79 100644
--- a/netam/common.py
+++ b/netam/common.py
@@ -13,15 +13,17 @@
 from torch import nn, Tensor
 import multiprocessing as mp
 
-from netam.sequences import iter_codons, apply_aa_mask_to_nt_sequence
+from netam.sequences import (
+    iter_codons,
+    apply_aa_mask_to_nt_sequence,
+    TOKEN_TRANSLATIONS,
+    BASES,
+    BASES_AND_N_TO_INDEX,
+    TOKEN_STR_SORTED,
+)
 
 BIG = 1e9
 SMALL_PROB = 1e-6
-BASES = ["A", "C", "G", "T"]
-BASES_AND_N_TO_INDEX = {"A": 0, "C": 1, "G": 2, "T": 3, "N": 4}
-AA_STR_SORTED = "ACDEFGHIKLMNPQRSTVWY"
-AA_STR_SORTED_AMBIG = AA_STR_SORTED + "X"
-MAX_AMBIG_AA_IDX = len(AA_STR_SORTED_AMBIG) - 1
 
 # I needed some sequence to use to normalize the rate of mutation in the SHM model.
 # So, I chose perhaps the most famous antibody sequence, VRC01:
@@ -65,7 +67,7 @@ def aa_idx_tensor_of_str_ambig(aa_str):
     character."""
     try:
         return torch.tensor(
-            [AA_STR_SORTED_AMBIG.index(aa) for aa in aa_str], dtype=torch.int
+            [TOKEN_STR_SORTED.index(aa) for aa in aa_str], dtype=torch.int
         )
     except ValueError:
         print(f"Found an invalid amino acid in the string: {aa_str}")
@@ -88,17 +90,28 @@ def generic_mask_tensor_of(ambig_symb, seq_str, length=None):
     return mask
 
 
+def _consider_codon(codon):
+    """Return False if codon should be masked, True otherwise."""
+    if "N" in codon:
+        return False
+    elif codon in TOKEN_TRANSLATIONS:
+        return False
+    else:
+        return True
+
+
 def codon_mask_tensor_of(nt_parent, *other_nt_seqs, aa_length=None):
     """Return a mask tensor indicating codons which contain at least one N.
 
     Codons beyond the length of the sequence are masked. If other_nt_seqs are provided,
-    the "and" mask will be computed for all sequences
+    the "and" mask will be computed for all sequences.
+    Codons containing marker tokens are also masked.
     """
     if aa_length is None:
         aa_length = len(nt_parent) // 3
     sequences = (nt_parent,) + other_nt_seqs
     mask = [
-        all("N" not in codon for codon in codons)
+        all(_consider_codon(codon) for codon in codons)
        for codons in zip(*(iter_codons(sequence) for sequence in sequences))
     ]
     if len(mask) < aa_length:
diff --git a/netam/dasm.py b/netam/dasm.py
index 5975ab14..e63748c3 100644
--- a/netam/dasm.py
+++ b/netam/dasm.py
@@ -35,13 +35,17 @@ def update_neutral_probs(self):
                 multihit_model = None
             # Note we are replacing all Ns with As, which means that we need to be careful
             # with masking out these positions later. We do this below.
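+            # For example, a parent codon "ANA" is evaluated here as "AAA", and
+            # the mask from codon_mask_tensor_of masks that position out later.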
+            # TODO handle this some other way
             parent_idxs = sequences.nt_idx_tensor_of_str(nt_parent.replace("N", "A"))
             parent_len = len(nt_parent)
 
             mut_probs = 1.0 - torch.exp(-branch_length * nt_rates[:parent_len])
             nt_csps = nt_csps[:parent_len, :]
             nt_mask = mask.repeat_interleave(3)[: len(nt_parent)]
-            molevol.check_csps(parent_idxs[nt_mask], nt_csps[: len(nt_parent)][nt_mask])
+            molevol.check_csps(
+                parent_idxs[nt_mask],
+                nt_csps[: len(nt_parent)][nt_mask],
+            )
 
             neutral_aa_probs = molevol.neutral_aa_probs(
                 parent_idxs.reshape(-1, 3),
@@ -139,7 +143,8 @@ def prediction_pair_of_batch(self, batch):
             raise ValueError(
                 f"log_neutral_aa_probs has non-finite values at relevant positions: {log_neutral_aa_probs[mask]}"
             )
-        log_selection_factors = self.model(aa_parents_idxs, mask)
+        keep_token_mask = mask | sequences.token_mask_of_aa_idxs(aa_parents_idxs)
+        log_selection_factors = self.model(aa_parents_idxs, keep_token_mask)
         return log_neutral_aa_probs, log_selection_factors
 
     def predictions_of_pair(self, log_neutral_aa_probs, log_selection_factors):
diff --git a/netam/dxsm.py b/netam/dxsm.py
index c118b28b..90cb949d 100644
--- a/netam/dxsm.py
+++ b/netam/dxsm.py
@@ -15,7 +15,6 @@
 from tqdm import tqdm
 
 from netam.common import (
-    MAX_AMBIG_AA_IDX,
     aa_idx_tensor_of_str_ambig,
     stack_heterogeneous,
     codon_mask_tensor_of,
@@ -28,6 +27,8 @@
     translate_sequences,
     apply_aa_mask_to_nt_sequence,
     nt_mutation_frequency,
+    MAX_AA_TOKEN_IDX,
+    TOKEN_REGEX,
 )
 
 
@@ -43,8 +44,9 @@ def __init__(
         branch_lengths: torch.Tensor,
         multihit_model=None,
     ):
-        self.nt_parents = nt_parents
-        self.nt_children = nt_children
+        # TODO test this replacement
+        self.nt_parents = nt_parents.str.replace(TOKEN_REGEX, "N", regex=True)
+        self.nt_children = nt_children.str.replace(TOKEN_REGEX, "N", regex=True)
         self.nt_ratess = nt_ratess
         self.nt_cspss = nt_cspss
         self.multihit_model = copy.deepcopy(multihit_model)
@@ -56,14 +58,16 @@ def __init__(
         assert len(self.nt_parents) == len(self.nt_children)
         pcp_count = len(self.nt_parents)
 
-        aa_parents = translate_sequences(self.nt_parents)
-        aa_children = translate_sequences(self.nt_children)
+        # Important to use the unmodified versions of nt_parents and
+        # nt_children so they still contain special tokens.
+        aa_parents = translate_sequences(nt_parents)
+        aa_children = translate_sequences(nt_children)
         self.max_aa_seq_len = max(len(seq) for seq in aa_parents)
         # We have sequences of varying length, so we start with all tensors set
         # to the ambiguous amino acid, and then will fill in the actual values
         # below.
         self.aa_parents_idxss = torch.full(
-            (pcp_count, self.max_aa_seq_len), MAX_AMBIG_AA_IDX
+            (pcp_count, self.max_aa_seq_len), MAX_AA_TOKEN_IDX
         )
         self.aa_children_idxss = self.aa_parents_idxss.clone()
         self.aa_subs_indicators = torch.zeros((pcp_count, self.max_aa_seq_len))
@@ -90,7 +94,7 @@ def __init__(
         )
 
         assert torch.all(self.masks.sum(dim=1) > 0)
-        assert torch.max(self.aa_parents_idxss) <= MAX_AMBIG_AA_IDX
+        assert torch.max(self.aa_parents_idxss) <= MAX_AA_TOKEN_IDX
 
         self._branch_lengths = branch_lengths
         self.update_neutral_probs()
@@ -296,9 +300,10 @@ def serial_find_optimal_branch_lengths(self, dataset, **optimization_kwargs):
 
     def find_optimal_branch_lengths(self, dataset, **optimization_kwargs):
         worker_count = min(mp.cpu_count() // 2, 10)
-        # # The following can be used when one wants a better traceback.
-        # burrito = self.__class__(None, dataset, copy.deepcopy(self.model))
-        # return burrito.serial_find_optimal_branch_lengths(dataset, **optimization_kwargs)
+        # TODO
+        # The following can be used when one wants a better traceback:
+        # burrito = self.__class__(None, dataset, copy.deepcopy(self.model))
+        # return burrito.serial_find_optimal_branch_lengths(dataset, **optimization_kwargs)
         our_optimize_branch_length = partial(
             worker_optimize_branch_length,
             self.__class__,
diff --git a/netam/framework.py b/netam/framework.py
index 87d50571..1ea8607c 100644
--- a/netam/framework.py
+++ b/netam/framework.py
@@ -2,6 +2,7 @@
 import copy
 import os
 from time import time
+from warnings import warn
 
 import pandas as pd
 import numpy as np
@@ -352,21 +353,75 @@ def trimmed_shm_model_outputs_of_crepe(crepe, parents):
     return trimmed_rates, trimmed_csps
 
 
+def join_chains(pcp_df):
+    """Join the parent and child chains in the pcp_df.
+
+    Make a parent column that is parent_h + "^^^" + parent_l, and do the same
+    for child.
+
+    If parent_h and parent_l are not present, we assume that an existing parent
+    column is the heavy chain. If only one of parent_h or parent_l is present,
+    we place the "^^^" padding to the right of heavy, or to the left of light.
+    """
+    cols = pcp_df.columns
+    if "parent_h" in cols and "parent_l" in cols:
+        assert "child_h" in cols and "child_l" in cols, "child_h or child_l column missing!"
+        pcp_df["parent"] = pcp_df["parent_h"] + "^^^" + pcp_df["parent_l"]
+        pcp_df["child"] = pcp_df["child_h"] + "^^^" + pcp_df["child_l"]
+    elif "parent_h" in cols and "parent_l" not in cols:
+        assert "child_h" in cols, "child_h column missing!"
+        pcp_df["parent"] = pcp_df["parent_h"] + "^^^"
+        pcp_df["child"] = pcp_df["child_h"] + "^^^"
+    elif "parent_h" not in cols and "parent_l" in cols:
+        if "parent" in cols:
+            warn(
+                "Both parent and parent_l columns found. Using only parent_l. "
+                "To use parent as the heavy chain, rename it to parent_h."
+            )
+        assert "child_l" in cols, "child_l column missing!"
+        pcp_df["parent"] = "^^^" + pcp_df["parent_l"]
+        pcp_df["child"] = "^^^" + pcp_df["child_l"]
+    elif "parent" in cols:
+        assert "child" in cols, "child column missing!"
+        # We assume that this is the heavy chain.
+        pcp_df["parent"] += "^^^"
+        pcp_df["child"] += "^^^"
+    else:
+        raise ValueError("Could not find parent and child columns.")
+    pcp_df.drop(
+        columns=["parent_h", "parent_l", "child_h", "child_l"],
+        inplace=True,
+        errors="ignore",
+    )
+    return pcp_df
+
+
 def load_pcp_df(pcp_df_path_gz, sample_count=None, chosen_v_families=None):
     """Load a PCP dataframe from a gzipped CSV file.
 
     `orig_pcp_idx` is the index column from the original file, even if we subset by
     sampling or by choosing V families.
+
+    Heavy and light chain sequences are joined into a single sequence starting
+    with the heavy chain, using a "^^^" separator. If only the heavy or light
+    chain sequence is present, the separator is added to the appropriate side
+    of the available sequence.
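+
+    For example (illustrative sequences, not real data), a row with parent_h
+    "ACGT" and parent_l "TGCA" yields parent "ACGT^^^TGCA", while a heavy-only
+    row yields "ACGT^^^".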
""" pcp_df = ( pd.read_csv(pcp_df_path_gz, compression="gzip", index_col=0) .reset_index() .rename(columns={"index": "orig_pcp_idx"}) ) - pcp_df["v_family"] = pcp_df["v_gene"].str.split("-").str[0] - if chosen_v_families is not None: - chosen_v_families = set(chosen_v_families) - pcp_df = pcp_df[pcp_df["v_family"].isin(chosen_v_families)] + pcp_df = join_chains(pcp_df) + if not ("parent" in pcp_df.columns and "child" in pcp_df.columns): + if "parent_h" in pcp_df.columns and "parent_l" in pcp_df.columns: + pcp_df["parent"] = pcp_df["parent_h"] + pcp_df["child"] = pcp_df["child_h"] + pcp_df.drop(columns=["parent_h", "parent_l", "child_h", "child_l"], inplace=True, errors="ignore") + else: + raise ValueError( + "Could not find parent and child columns. " + ) + + # figure out what to do here: TODO this is only needed for oe plotting, but + # the way its set up will fail without a helpful message. + if "v_gene" in pcp_df.columns: + pcp_df["v_family"] = pcp_df["v_gene"].str.split("-").str[0] + if chosen_v_families is not None: + chosen_v_families = set(chosen_v_families) + pcp_df = pcp_df[pcp_df["v_family"].isin(chosen_v_families)] if sample_count is not None: pcp_df = pcp_df.sample(sample_count) pcp_df.reset_index(drop=True, inplace=True) @@ -374,9 +429,15 @@ def load_pcp_df(pcp_df_path_gz, sample_count=None, chosen_v_families=None): def add_shm_model_outputs_to_pcp_df(pcp_df, crepe): - rates, csps = trimmed_shm_model_outputs_of_crepe(crepe, pcp_df["parent"]) - pcp_df["nt_rates"] = rates - pcp_df["nt_csps"] = csps + # TODO what happens when one of these is empty? or if there's no split? + split_parents = pcp_df["parent"].copy().str.split(pat="^^^", expand=True, regex=False) + h_parents = split_parents[0] + "NNN" + l_parents = split_parents[1] + + h_rates, h_csps = trimmed_shm_model_outputs_of_crepe(crepe, h_parents) + l_rates, l_csps = trimmed_shm_model_outputs_of_crepe(crepe, l_parents) + pcp_df["nt_rates"] = [torch.cat([h_rate, l_rate], dim=0) for h_rate, l_rate in zip(h_rates, l_rates)] + pcp_df["nt_csps"] = [torch.cat([h_csp, l_csp], dim=0) for h_csp, l_csp in zip(h_csps, l_csps)] return pcp_df diff --git a/netam/models.py b/netam/models.py index 09ebd4d1..1edc8989 100644 --- a/netam/models.py +++ b/netam/models.py @@ -10,8 +10,8 @@ from torch import Tensor from netam.hit_class import apply_multihit_correction +from netam.sequences import MAX_AA_TOKEN_IDX from netam.common import ( - MAX_AMBIG_AA_IDX, aa_idx_tensor_of_str_ambig, PositionalEncoding, generate_kmers, @@ -622,7 +622,7 @@ def __init__( self.nhead = nhead self.dim_feedforward = dim_feedforward self.pos_encoder = PositionalEncoding(self.d_model, dropout_prob) - self.amino_acid_embedding = nn.Embedding(MAX_AMBIG_AA_IDX + 1, self.d_model) + self.amino_acid_embedding = nn.Embedding(MAX_AA_TOKEN_IDX + 1, self.d_model) self.encoder_layer = nn.TransformerEncoderLayer( d_model=self.d_model, nhead=nhead, diff --git a/netam/molevol.py b/netam/molevol.py index 2aef1c10..c089764d 100644 --- a/netam/molevol.py +++ b/netam/molevol.py @@ -9,7 +9,7 @@ import torch from torch import Tensor, optim -from netam.sequences import CODON_AA_INDICATOR_MATRIX +from netam.sequences import CODON_AA_INDICATOR_MATRIX, MAX_AA_TOKEN_IDX import netam.sequences as sequences @@ -444,7 +444,7 @@ def mutsel_log_pcp_probability_of( """ assert len(parent) % 3 == 0 - assert sel_matrix.shape == (len(parent) // 3, 20) + assert sel_matrix.shape == (len(parent) // 3, MAX_AA_TOKEN_IDX + 1) parent_idxs = sequences.nt_idx_tensor_of_str(parent) child_idxs = 
 
     parent_idxs = sequences.nt_idx_tensor_of_str(parent)
     child_idxs = sequences.nt_idx_tensor_of_str(child)
diff --git a/netam/sequences.py b/netam/sequences.py
index feea5ad2..4a89d60f 100644
--- a/netam/sequences.py
+++ b/netam/sequences.py
@@ -1,6 +1,7 @@
 """Code for handling sequences and sequence files."""
 
 import itertools
+import re
 
 import torch
 import numpy as np
@@ -8,14 +9,30 @@
 from Bio import SeqIO
 from Bio.Seq import Seq
 
+BASES = ("A", "C", "G", "T")
 AA_STR_SORTED = "ACDEFGHIKLMNPQRSTVWY"
-NT_STR_SORTED = "ACGT"
+# Add additional tokens to this string:
+RESERVED_TOKENS = "^"
+
+NT_STR_SORTED = "".join(BASES)
+BASES_AND_N_TO_INDEX = {base: idx for idx, base in enumerate(NT_STR_SORTED + "N")}
+# The ambiguous character must remain last:
+TOKEN_STR_SORTED = AA_STR_SORTED + RESERVED_TOKENS + "X"
+# TODO add tests for functions that use this:
+RESERVED_TOKEN_AA_BOUNDS = (
+    min(TOKEN_STR_SORTED.index(token) for token in RESERVED_TOKENS),
+    max(TOKEN_STR_SORTED.index(token) for token in RESERVED_TOKENS),
+)
+MAX_AA_TOKEN_IDX = len(TOKEN_STR_SORTED) - 1
 CODONS = [
     "".join(codon_list)
-    for codon_list in itertools.product(["A", "C", "G", "T"], repeat=3)
+    for codon_list in itertools.product(BASES, repeat=3)
 ]
 STOP_CODONS = ["TAA", "TAG", "TGA"]
+# Each token in RESERVED_TOKENS will appear once in aa strings, and three times
+# in nt strings.
+TOKEN_TRANSLATIONS = {token * 3: token for token in RESERVED_TOKENS}
+# A regex character class matching any reserved token:
+TOKEN_REGEX = f"[{''.join(map(re.escape, RESERVED_TOKENS))}]"
 
 
 def nt_idx_array_of_str(nt_str):
     """Return the indices of the nucleotides in a string."""
@@ -25,15 +42,21 @@ def nt_idx_array_of_str(nt_str):
     except ValueError:
         print(f"Found an invalid nucleotide in the string: {nt_str}")
         raise
 
 
 def aa_idx_array_of_str(aa_str):
     """Return the indices of the amino acids in a string."""
     try:
-        return np.array([AA_STR_SORTED.index(aa) for aa in aa_str])
+        return np.array([TOKEN_STR_SORTED.index(aa) for aa in aa_str])
     except ValueError:
         print(f"Found an invalid amino acid in the string: {aa_str}")
         raise
 
 
 def nt_idx_tensor_of_str(nt_str):
     """Return the indices of the nucleotides in a string."""
@@ -47,7 +70,7 @@ def aa_idx_tensor_of_str(aa_str):
     """Return the indices of the amino acids in a string."""
     try:
-        return torch.tensor([AA_STR_SORTED.index(aa) for aa in aa_str])
+        return torch.tensor([TOKEN_STR_SORTED.index(aa) for aa in aa_str])
     except ValueError:
         print(f"Found an invalid amino acid in the string: {aa_str}")
         raise
@@ -90,26 +113,31 @@ def read_fasta_sequences(file_path):
     return sequences
 
 
-def translate_sequences(nt_sequences):
-    aa_sequences = []
-    for seq in nt_sequences:
-        if len(seq) % 3 != 0:
-            raise ValueError(f"The sequence '{seq}' is not a multiple of 3.")
-        aa_seq = str(Seq(seq).translate())
-        if "*" in aa_seq:
-            raise ValueError(f"The sequence '{seq}' contains a stop codon.")
-        aa_sequences.append(aa_seq)
-    return aa_sequences
+def translate_codon(codon):
+    """Translate a codon to an amino acid."""
+    if codon in TOKEN_TRANSLATIONS:
+        return TOKEN_TRANSLATIONS[codon]
+    else:
+        return str(Seq(codon).translate())
 
 
 def translate_sequence(nt_sequence):
-    return translate_sequences([nt_sequence])[0]
+    if len(nt_sequence) % 3 != 0:
+        raise ValueError(f"The sequence '{nt_sequence}' is not a multiple of 3.")
+    aa_seq = "".join(
+        translate_codon(nt_sequence[i : i + 3])
+        for i in range(0, len(nt_sequence), 3)
+    )
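+    # For example, translate_sequence("AAA" + "^^^") gives "K^": the reserved
+    # codon "^^^" translates to the single token "^".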
+    if "*" in aa_seq:
+        raise ValueError(f"The sequence '{nt_sequence}' contains a stop codon.")
+    return aa_seq
+
+
+def translate_sequences(nt_sequences):
+    return [translate_sequence(seq) for seq in nt_sequences]
 
 
 def aa_index_of_codon(codon):
     """Return the index of the amino acid encoded by a codon."""
     aa = translate_sequence(codon)
-    return AA_STR_SORTED.index(aa)
+    return TOKEN_STR_SORTED.index(aa)
 
 
 def generic_mutation_frequency(ambig_symb, parent, child):
@@ -159,12 +187,12 @@ def pcp_criteria_check(parent, child, max_mut_freq=0.3):
 
 def generate_codon_aa_indicator_matrix():
     """Generate a matrix that maps codons (rows) to amino acids (columns)."""
-    matrix = np.zeros((len(CODONS), len(AA_STR_SORTED)))
+    matrix = np.zeros((len(CODONS), len(TOKEN_STR_SORTED)))
 
     for i, codon in enumerate(CODONS):
         try:
             aa = translate_sequences([codon])[0]
-            aa_idx = AA_STR_SORTED.index(aa)
+            aa_idx = TOKEN_STR_SORTED.index(aa)
             matrix[i, aa_idx] = 1
         except ValueError:  # Handle STOP codon
             pass
@@ -206,3 +234,9 @@ def set_wt_to_nan(predictions: torch.Tensor, aa_sequence: str) -> torch.Tensor:
     wt_idxs = aa_idx_tensor_of_str(aa_sequence)
     predictions[torch.arange(len(aa_sequence)), wt_idxs] = float("nan")
     return predictions
+
+
+def token_mask_of_aa_idxs(aa_idxs: torch.Tensor) -> torch.Tensor:
+    """Return a mask indicating positions that contain special indicator tokens."""
+    min_idx, max_idx = RESERVED_TOKEN_AA_BOUNDS
+    return (aa_idxs <= max_idx) & (aa_idxs >= min_idx)
diff --git a/tests/conftest.py b/tests/conftest.py
index c88350cb..7ab6ff21 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -8,6 +8,7 @@
 
 @pytest.fixture(scope="module")
 def pcp_df():
+    # TODO add some checking related to joined mode
     df = load_pcp_df(
         "data/wyatt-10x-1p5m_pcp_2023-11-30_NI.first100.csv.gz",
     )
diff --git a/tests/test_dnsm.py b/tests/test_dnsm.py
index e0bca099..18b449e7 100644
--- a/tests/test_dnsm.py
+++ b/tests/test_dnsm.py
@@ -7,14 +7,15 @@
     crepe_exists,
     load_crepe,
 )
-from netam.common import aa_idx_tensor_of_str_ambig, MAX_AMBIG_AA_IDX, force_spawn
+from netam.sequences import MAX_AA_TOKEN_IDX
+from netam.common import aa_idx_tensor_of_str_ambig, force_spawn
 from netam.models import TransformerBinarySelectionModelWiggleAct
 from netam.dnsm import DNSMBurrito, DNSMDataset
 
 
 def test_aa_idx_tensor_of_str_ambig():
     input_seq = "ACX"
-    expected_output = torch.tensor([0, 1, MAX_AMBIG_AA_IDX], dtype=torch.int)
+    expected_output = torch.tensor([0, 1, MAX_AA_TOKEN_IDX], dtype=torch.int)
     output = aa_idx_tensor_of_str_ambig(input_seq)
     assert torch.equal(output, expected_output)
 
diff --git a/tests/test_sequences.py b/tests/test_sequences.py
index 8866214e..42aab70c 100644
--- a/tests/test_sequences.py
+++ b/tests/test_sequences.py
@@ -5,6 +5,7 @@
+import re
+
 from Bio.Data import CodonTable
 from netam.sequences import (
     AA_STR_SORTED,
+    TOKEN_STR_SORTED,
+    TOKEN_REGEX,
+    aa_idx_tensor_of_str,
+    token_mask_of_aa_idxs,
+    translate_sequence,
     CODONS,
     CODON_AA_INDICATOR_MATRIX,
     aa_onehot_tensor_of_str,
@@ -14,6 +15,21 @@
 )
 
 
+def test_token_order():
+    # If we always add additional tokens to the end, then converting to indices
+    # will not be affected when we have a proper aa string.
+    assert TOKEN_STR_SORTED[: len(AA_STR_SORTED)] == AA_STR_SORTED
+
+
+def test_token_replace():
+    # DXSMDataset replaces reserved tokens with "N" before computing neutral
+    # probabilities; TOKEN_REGEX should match exactly the reserved tokens.
+    assert re.sub(TOKEN_REGEX, "N", "ACGT^^^ACGT") == "ACGTNNNACGT"
+
+
+def test_token_mask():
+    # "^" is a reserved token; ordinary amino acids and the ambiguity code "X"
+    # should not be masked.
+    mask = token_mask_of_aa_idxs(aa_idx_tensor_of_str("AC^X"))
+    assert mask.tolist() == [False, False, True, False]
+
+
 def test_nucleotide_indices_of_codon():
     assert nt_idx_array_of_str("AAA").tolist() == [0, 0, 0]
     assert nt_idx_array_of_str("TAC").tolist() == [3, 0, 1]
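+
+
+def test_translate_sequence_with_tokens():
+    # A possible further sanity check: the reserved codon "^^^" translates to
+    # a single "^" token alongside ordinary codons.
+    assert translate_sequence("AAA" + "^^^") == "K^"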