New Dataset superclass #89

Merged · 6 commits · Nov 26, 2024
netam/common.py: 12 additions & 0 deletions
@@ -368,3 +368,15 @@ def linear_bump_lr(epoch, warmup_epochs, total_epochs, max_lr, min_lr):
epoch - warmup_epochs
)
return lr


def encode_sequences(sequences, encoder):
encoded_parents, wt_base_modifiers = zip(
*[encoder.encode_sequence(sequence) for sequence in sequences]
)
masks = [nt_mask_tensor_of(sequence, encoder.site_count) for sequence in sequences]
return (
torch.stack(encoded_parents),
torch.stack(masks),
torch.stack(wt_base_modifiers),
)
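The new module-level helper centralizes the per-sequence encoding that Crepe.encode_sequences previously did inline (see the framework.py diff below). A minimal usage sketch follows; ToyEncoder is hypothetical and only mimics the interface the helper relies on, namely an encode_sequence method returning an (encoded_parent, wt_base_modifier) pair plus a site_count attribute.

import torch

from netam.common import encode_sequences


class ToyEncoder:
    """Illustrative stand-in for a real encoder (e.g. KmerSequenceEncoder does far more)."""

    site_count = 8

    def encode_sequence(self, seq):
        # Map each base to an integer index and right-pad to site_count.
        idxs = torch.tensor(["ACGTN".index(base) for base in seq[: self.site_count]])
        encoded = torch.zeros(self.site_count, dtype=torch.long)
        encoded[: len(idxs)] = idxs
        # The second element stands in for the wild-type base modifier.
        return encoded, torch.zeros(self.site_count)


encoded_parents, masks, wt_base_modifiers = encode_sequences(
    ["ACGTACGT", "ACGTTTGG"], ToyEncoder()
)
# Each returned tensor is stacked along a new batch dimension, one row per sequence.
print(encoded_parents.shape, masks.shape, wt_base_modifiers.shape)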
netam/dasm.py: 1 addition & 0 deletions
@@ -11,6 +11,7 @@
import netam.framework as framework
import netam.molevol as molevol
import netam.sequences as sequences
import copy


class DASMDataset(DXSMDataset):
netam/dxsm.py: 1 addition & 15 deletions
@@ -4,7 +4,6 @@
from functools import partial

import torch
from torch.utils.data import Dataset

# Amazingly, using one thread makes things 50x faster for branch length
# optimization on our server.
@@ -32,7 +31,7 @@
)


class DXSMDataset(Dataset, ABC):
class DXSMDataset(framework.BranchLengthDataset, ABC):
prefix = "dxsm"

def __init__(
@@ -222,19 +221,6 @@ def branch_lengths(self, new_branch_lengths):
self._branch_lengths = new_branch_lengths
self.update_neutral_probs()

def __len__(self):
return len(self.aa_parents_idxss)

def export_branch_lengths(self, out_csv_path):
pd.DataFrame({"branch_length": self.branch_lengths}).to_csv(
out_csv_path, index=False
)

def load_branch_lengths(self, in_csv_path):
self.branch_lengths = torch.Tensor(
pd.read_csv(in_csv_path)["branch_length"].values
)

@abstractmethod
def update_neutral_probs(self):
pass
netam/framework.py: 22 additions & 31 deletions
@@ -25,6 +25,7 @@
BASES_AND_N_TO_INDEX,
BIG,
VRC01_NT_SEQ,
encode_sequences,
)
from netam import models
import netam.molevol as molevol
@@ -132,7 +133,23 @@ def parameters(self):
return {}


class SHMoofDataset(Dataset):
class BranchLengthDataset(Dataset):
def __len__(self):
return len(self.branch_lengths)

def export_branch_lengths(self, out_csv_path):
pd.DataFrame({"branch_length": self.branch_lengths}).to_csv(
out_csv_path, index=False
)

def load_branch_lengths(self, in_csv_path):
self.branch_lengths = pd.read_csv(in_csv_path)["branch_length"].values

def __repr__(self):
return f"{self.__class__.__name__}(Size: {len(self)}) on {self.branch_lengths.device}"


class SHMoofDataset(BranchLengthDataset):
def __init__(self, dataframe, kmer_length, site_count):
super().__init__()
self.encoder = KmerSequenceEncoder(kmer_length, site_count)
@@ -146,9 +163,6 @@ def __init__(self, dataframe, kmer_length, site_count):
) = self.encode_pcps(dataframe)
assert self.encoded_parents.shape[0] == self.branch_lengths.shape[0]

def __len__(self):
return len(self.encoded_parents)

def __getitem__(self, idx):
return (
self.encoded_parents[idx],
@@ -159,9 +173,6 @@ def __getitem__(self, idx):
self.branch_lengths[idx],
)

def __repr__(self):
return f"{self.__class__.__name__}(Size: {len(self)}) on {self.encoded_parents.device}"

def to(self, device):
self.encoded_parents = self.encoded_parents.to(device)
self.masks = self.masks.to(device)
@@ -224,9 +235,6 @@ def export_branch_lengths(self, out_csv_path):
}
).to_csv(out_csv_path, index=False)

def load_branch_lengths(self, in_csv_path):
self.branch_lengths = pd.read_csv(in_csv_path)["branch_length"].values


class Crepe:
"""A lightweight wrapper around a model that can be used for prediction but not
@@ -243,6 +251,9 @@ def __init__(self, encoder, model, training_hyperparameters={}):
self.model = model
self.training_hyperparameters = training_hyperparameters

def __call__(self, sequences):
return self.model.evaluate_sequences(sequences, encoder=self.encoder)

@property
def device(self):
return next(self.model.parameters()).device
@@ -251,27 +262,7 @@ def to(self, device):
self.model.to(device)

def encode_sequences(self, sequences):
encoded_parents, wt_base_modifiers = zip(
*[self.encoder.encode_sequence(sequence) for sequence in sequences]
)
masks = [
nt_mask_tensor_of(sequence, self.encoder.site_count)
for sequence in sequences
]
return (
torch.stack(encoded_parents),
torch.stack(masks),
torch.stack(wt_base_modifiers),
)

def __call__(self, sequences):
encoded_parents, masks, wt_base_modifiers = self.encode_sequences(sequences)
encoded_parents = encoded_parents.to(self.device)
masks = masks.to(self.device)
wt_base_modifiers = wt_base_modifiers.to(self.device)
with torch.no_grad():
outputs = self.model(encoded_parents, masks, wt_base_modifiers)
return tuple(t.detach().cpu() for t in outputs)
return encode_sequences(sequences, self.encoder)

def save(self, prefix):
torch.save(self.model.state_dict(), f"{prefix}.pth")
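Taken together, the framework.py changes hoist the shared behavior (__len__, branch-length CSV export/import, and __repr__) into the new BranchLengthDataset superclass, while Crepe.__call__ now simply delegates to model.evaluate_sequences. A sketch of what a subclass inherits for free; TinyDataset is hypothetical and exists only to exercise the shared methods.

import torch

from netam.framework import BranchLengthDataset


class TinyDataset(BranchLengthDataset):
    """Hypothetical subclass used only to demonstrate the inherited behavior."""

    def __init__(self, branch_lengths):
        self.branch_lengths = branch_lengths


ds = TinyDataset(torch.tensor([0.01, 0.05, 0.1]))
print(len(ds))    # 3, via the shared __len__
print(repr(ds))   # "TinyDataset(Size: 3) on cpu", via the shared __repr__
ds.export_branch_lengths("branch_lengths.csv")
ds.load_branch_lengths("branch_lengths.csv")
# Note that the loaded values come back as a NumPy array rather than a tensor.

Subclasses with their own branch-length semantics still work, since the base class only reads and assigns self.branch_lengths; for example, DXSMDataset's property setter (which refreshes the neutral probabilities) is triggered by the inherited load_branch_lengths.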
netam/models.py: 25 additions & 4 deletions
@@ -16,6 +16,7 @@
PositionalEncoding,
generate_kmers,
aa_mask_tensor_of,
encode_sequences,
)

warnings.filterwarnings(
@@ -49,6 +50,10 @@ def reinitialize_weights(self):
else:
raise ValueError(f"Unrecognized layer type: {type(layer)}")

@property
def device(self):
return next(self.parameters()).device

def freeze(self):
"""Freeze all parameters in the model, disabling gradient computations."""
for param in self.parameters():
@@ -59,6 +64,17 @@ def unfreeze(self):
for param in self.parameters():
param.requires_grad = True

def evaluate_sequences(self, sequences, encoder=None):
if encoder is None:
raise ValueError("An encoder must be provided.")
encoded_parents, masks, wt_base_modifiers = encode_sequences(sequences, encoder)
encoded_parents = encoded_parents.to(self.device)
masks = masks.to(self.device)
wt_base_modifiers = wt_base_modifiers.to(self.device)
with torch.no_grad():
outputs = self(encoded_parents, masks, wt_base_modifiers)
return tuple(t.detach().cpu() for t in outputs)


class KmerModel(ModelBase):
def __init__(self, kmer_length):
@@ -536,6 +552,13 @@ class AbstractBinarySelectionModel(ABC, nn.Module):
def __init__(self):
super().__init__()

@property
def device(self):
return next(self.parameters()).device

def evaluate_sequences(self, sequences: list[str], **kwargs) -> Tensor:
return tuple(self.selection_factors_of_aa_str(seq) for seq in sequences)

def selection_factors_of_aa_str(self, aa_str: str) -> Tensor:
"""Do the forward method then exponentiation without gradients from an amino
acid string.
@@ -548,12 +571,10 @@ def selection_factors_of_aa_str(self, aa_str: str) -> Tensor:
the level of selection for each amino acid at each site.
"""

model_device = next(self.parameters()).device

aa_idxs = aa_idx_tensor_of_str_ambig(aa_str)
aa_idxs = aa_idxs.to(model_device)
aa_idxs = aa_idxs.to(self.device)
mask = aa_mask_tensor_of(aa_str)
mask = mask.to(model_device)
mask = mask.to(self.device)

with torch.no_grad():
model_out = self(aa_idxs.unsqueeze(0), mask.unsqueeze(0)).squeeze(0)
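On the model side, the device property and evaluate_sequences move onto ModelBase, so the device transfer and no_grad evaluation that used to live in Crepe.__call__ now live with the model itself. A hedged sketch, under the assumption that ModelBase can be subclassed with just a forward method (its full definition is outside this diff); both ToyModel and ToyEncoder are hypothetical.

import torch
import torch.nn as nn

from netam.models import ModelBase


class ToyEncoder:
    """Same illustrative encoder as in the common.py sketch above."""

    site_count = 8

    def encode_sequence(self, seq):
        idxs = torch.tensor(["ACGTN".index(base) for base in seq[: self.site_count]])
        encoded = torch.zeros(self.site_count, dtype=torch.long)
        encoded[: len(idxs)] = idxs
        return encoded, torch.zeros(self.site_count)


class ToyModel(ModelBase):
    """Hypothetical model; real netam models have richer forward passes."""

    def __init__(self):
        super().__init__()
        self.scale = nn.Parameter(torch.ones(1))

    def forward(self, encoded_parents, masks, wt_base_modifiers):
        # Return a one-element tuple of per-site rates, zeroed where masked out.
        return (self.scale * masks,)


model = ToyModel()
(rates,) = model.evaluate_sequences(["ACGTACGT", "ACGTTTGG"], encoder=ToyEncoder())
print(model.device, rates.shape)  # expected: cpu torch.Size([2, 8])

For AbstractBinarySelectionModel, evaluate_sequences instead maps selection_factors_of_aa_str over the input amino-acid strings, so Crepe.__call__ now works uniformly across both model families.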