New Dataset superclass #89

Merged
merged 6 commits on Nov 26, 2024
Changes from 5 commits
12 changes: 12 additions & 0 deletions netam/common.py
@@ -368,3 +368,15 @@ def linear_bump_lr(epoch, warmup_epochs, total_epochs, max_lr, min_lr):
            epoch - warmup_epochs
        )
    return lr


def encode_sequences(sequences, encoder):
    encoded_parents, wt_base_modifiers = zip(
        *[encoder.encode_sequence(sequence) for sequence in sequences]
    )
    masks = [nt_mask_tensor_of(sequence, encoder.site_count) for sequence in sequences]
    return (
        torch.stack(encoded_parents),
        torch.stack(masks),
        torch.stack(wt_base_modifiers),
    )
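As a usage sketch (not part of this PR's diff), the new helper can be called with any encoder that exposes encode_sequence and site_count. The import paths, encoder parameters, and sequences below are illustrative assumptions, based on how KmerSequenceEncoder is constructed in framework.py further down.

# Hypothetical usage sketch, not part of the diff.
from netam.common import encode_sequences
from netam.framework import KmerSequenceEncoder

# Illustrative encoder parameters; 5 and 500 are placeholders, not defaults.
encoder = KmerSequenceEncoder(5, 500)
sequences = ["ACGTACGTAA", "ACGTTCGTAA"]

# Returns three stacked tensors, one row per input sequence:
# encoded parents, nucleotide masks, and wild-type base modifiers.
encoded_parents, masks, wt_base_modifiers = encode_sequences(sequences, encoder)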
1 change: 1 addition & 0 deletions netam/dasm.py
@@ -11,6 +11,7 @@
import netam.framework as framework
import netam.molevol as molevol
import netam.sequences as sequences
import copy


class DASMDataset(DXSMDataset):
16 changes: 1 addition & 15 deletions netam/dxsm.py
@@ -4,7 +4,6 @@
from functools import partial

import torch
from torch.utils.data import Dataset

# Amazingly, using one thread makes things 50x faster for branch length
# optimization on our server.
@@ -32,7 +31,7 @@
)


class DXSMDataset(Dataset, ABC):
class DXSMDataset(framework.BranchLengthDataset, ABC):
    prefix = "dxsm"

    def __init__(
@@ -222,19 +221,6 @@ def branch_lengths(self, new_branch_lengths):
        self._branch_lengths = new_branch_lengths
        self.update_neutral_probs()

    def __len__(self):
        return len(self.aa_parents_idxss)

    def export_branch_lengths(self, out_csv_path):
        pd.DataFrame({"branch_length": self.branch_lengths}).to_csv(
            out_csv_path, index=False
        )

    def load_branch_lengths(self, in_csv_path):
        self.branch_lengths = torch.Tensor(
            pd.read_csv(in_csv_path)["branch_length"].values
        )

    @abstractmethod
    def update_neutral_probs(self):
        pass
55 changes: 24 additions & 31 deletions netam/framework.py
@@ -25,6 +25,7 @@
    BASES_AND_N_TO_INDEX,
    BIG,
    VRC01_NT_SEQ,
    encode_sequences,
)
from netam import models
import netam.molevol as molevol
@@ -132,7 +133,23 @@ def parameters(self):
        return {}


class SHMoofDataset(Dataset):
class BranchLengthDataset(Dataset):
    def __len__(self):
        return len(self.branch_lengths)

    def export_branch_lengths(self, out_csv_path):
        pd.DataFrame({"branch_length": self.branch_lengths}).to_csv(
            out_csv_path, index=False
        )

    def load_branch_lengths(self, in_csv_path):
        self.branch_lengths = pd.read_csv(in_csv_path)["branch_length"].values

    def __repr__(self):
        return f"{self.__class__.__name__}(Size: {len(self)}) on {self.branch_lengths.device}"

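As a quick illustration of what the new superclass provides (not part of this PR's diff), a toy subclass only needs a branch_lengths attribute to inherit __len__, the CSV export/import helpers, and the shared __repr__. The class name and file path below are hypothetical.

# Hypothetical sketch, not part of the diff.
import torch
from netam.framework import BranchLengthDataset

class ToyDataset(BranchLengthDataset):
    def __init__(self, branch_lengths):
        self.branch_lengths = branch_lengths

ds = ToyDataset(torch.tensor([0.01, 0.05, 0.1]))
len(ds)                                 # 3, via the shared __len__
repr(ds)                                # "ToyDataset(Size: 3) on cpu"
ds.export_branch_lengths("toy_bl.csv")  # writes a single branch_length column
ds.load_branch_lengths("toy_bl.csv")    # loads back as a numpy array, per the implementation above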

class SHMoofDataset(BranchLengthDataset):
    def __init__(self, dataframe, kmer_length, site_count):
        super().__init__()
        self.encoder = KmerSequenceEncoder(kmer_length, site_count)
@@ -146,9 +163,6 @@ def __init__(self, dataframe, kmer_length, site_count):
        ) = self.encode_pcps(dataframe)
        assert self.encoded_parents.shape[0] == self.branch_lengths.shape[0]

    def __len__(self):
        return len(self.encoded_parents)

    def __getitem__(self, idx):
        return (
            self.encoded_parents[idx],
@@ -159,9 +173,6 @@ def __getitem__(self, idx):
            self.branch_lengths[idx],
        )

    def __repr__(self):
        return f"{self.__class__.__name__}(Size: {len(self)}) on {self.encoded_parents.device}"

    def to(self, device):
        self.encoded_parents = self.encoded_parents.to(device)
        self.masks = self.masks.to(device)
@@ -224,9 +235,6 @@ def export_branch_lengths(self, out_csv_path):
            }
        ).to_csv(out_csv_path, index=False)

    def load_branch_lengths(self, in_csv_path):
        self.branch_lengths = pd.read_csv(in_csv_path)["branch_length"].values


class Crepe:
    """A lightweight wrapper around a model that can be used for prediction but not
@@ -243,6 +251,11 @@ def __init__(self, encoder, model, training_hyperparameters={}):
        self.model = model
        self.training_hyperparameters = training_hyperparameters

    def __call__(self, sequences):
        return self.model.selection_factors_of_sequences(
            sequences, encoder=self.encoder
        )

    @property
    def device(self):
        return next(self.model.parameters()).device
@@ -251,27 +264,7 @@ def to(self, device):
        self.model.to(device)

    def encode_sequences(self, sequences):
        encoded_parents, wt_base_modifiers = zip(
            *[self.encoder.encode_sequence(sequence) for sequence in sequences]
        )
        masks = [
            nt_mask_tensor_of(sequence, self.encoder.site_count)
            for sequence in sequences
        ]
        return (
            torch.stack(encoded_parents),
            torch.stack(masks),
            torch.stack(wt_base_modifiers),
        )

    def __call__(self, sequences):
        encoded_parents, masks, wt_base_modifiers = self.encode_sequences(sequences)
        encoded_parents = encoded_parents.to(self.device)
        masks = masks.to(self.device)
        wt_base_modifiers = wt_base_modifiers.to(self.device)
        with torch.no_grad():
            outputs = self.model(encoded_parents, masks, wt_base_modifiers)
        return tuple(t.detach().cpu() for t in outputs)
        return encode_sequences(sequences, self.encoder)

    def save(self, prefix):
        torch.save(self.model.state_dict(), f"{prefix}.pth")
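To make the new calling convention concrete, here is a sketch (not part of this PR's diff) of what Crepe.__call__ now does. It assumes an already-constructed crepe and a list of nucleotide sequences compatible with its encoder.

# Hypothetical sketch, not part of the diff.
# Assumes `crepe` is an existing Crepe instance and `sequences` is a list of
# nucleotide strings its encoder can handle.
outputs = crepe(sequences)

# Equivalent to the delegation introduced in __call__ above:
outputs = crepe.model.selection_factors_of_sequences(sequences, encoder=crepe.encoder)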
16 changes: 16 additions & 0 deletions netam/models.py
@@ -16,6 +16,7 @@
    PositionalEncoding,
    generate_kmers,
    aa_mask_tensor_of,
    encode_sequences,
)

warnings.filterwarnings(
@@ -59,6 +60,18 @@ def unfreeze(self):
        for param in self.parameters():
            param.requires_grad = True

    def selection_factors_of_sequences(self, sequences, encoder=None):
        if encoder is None:
            raise ValueError("An encoder must be provided.")
        device = next(self.parameters()).device
Contributor review comment: Let's make this a @property.
        encoded_parents, masks, wt_base_modifiers = encode_sequences(sequences, encoder)
        encoded_parents = encoded_parents.to(device)
        masks = masks.to(device)
        wt_base_modifiers = wt_base_modifiers.to(device)
        with torch.no_grad():
            outputs = self(encoded_parents, masks, wt_base_modifiers)
        return tuple(t.detach().cpu() for t in outputs)

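A sketch (not part of this PR's diff) of using the new method directly, without going through a Crepe; model, encoder, and sequences are assumed to already exist, and the structure of the returned tuple follows whatever the model's forward pass returns.

# Hypothetical sketch, not part of the diff.
# `model` is any ModelBase subclass; `encoder` matches its expected input.
outputs = model.selection_factors_of_sequences(sequences, encoder=encoder)
# `outputs` is a tuple of detached CPU tensors, one per tensor returned by
# the model's forward pass.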

class KmerModel(ModelBase):
    def __init__(self, kmer_length):
@@ -536,6 +549,9 @@ class AbstractBinarySelectionModel(ABC, nn.Module):
    def __init__(self):
        super().__init__()

    def selection_factors_of_sequences(self, sequences: list[str], **kwargs) -> Tensor:
        return tuple(self.selection_factors_of_aa_str(seq) for seq in sequences)

    def selection_factors_of_aa_str(self, aa_str: str) -> Tensor:
        """Do the forward method then exponentiation without gradients from an amino
        acid string.