Adding per-base inference #9

Merged · 53 commits · Feb 14, 2024

Commits
243d327
enter rsmodels
matsen Jan 31, 2024
a0bf606
Add mutation and base indicator creation function
matsen Jan 31, 2024
d8b2f78
Add new base index vectors to SHMoofDataset
matsen Jan 31, 2024
99f65a1
new base idx indicator
matsen Jan 31, 2024
4033851
can train rsmodel now
matsen Jan 31, 2024
2c1dc5f
make format
matsen Jan 31, 2024
e2c18d0
Remove cnn_med_orig
matsen Jan 31, 2024
6250a91
adding rscnn to experiment
matsen Jan 31, 2024
c8c4814
adding profiling script
matsen Jan 31, 2024
c3d97ee
yes to wt_base_multiplier and no to central_base_mapping
matsen Jan 31, 2024
eee5647
oop continuing prev commit
matsen Jan 31, 2024
9e7d386
cleanup and moving around
matsen Jan 31, 2024
3f73276
format
matsen Jan 31, 2024
df2a609
adding loss weights
matsen Jan 31, 2024
3ccf750
Add JoinedRSCNNModel and IndepRSCNNModel
matsen Feb 1, 2024
8895121
adding wt_base_multiplier to other models
matsen Feb 1, 2024
392dd44
returning multiple losses
matsen Feb 1, 2024
eafcd8b
partway through recording multiple losses
matsen Feb 1, 2024
f63d72c
Recording separate losses
matsen Feb 1, 2024
8f3f25b
wt_base_modifier
matsen Feb 1, 2024
1ed046e
make format
matsen Feb 1, 2024
0d3795c
returning logits
matsen Feb 2, 2024
516cb95
being able to set loss weights
matsen Feb 2, 2024
a9decb5
more complex loss reduction
matsen Feb 2, 2024
f9d7a11
moving out load_shmoof_dataframes
matsen Feb 2, 2024
e5baed0
excluding Ns from being considered as mutations
matsen Feb 3, 2024
d71a41a
RSFivemer
matsen Feb 3, 2024
902b2d0
format
matsen Feb 3, 2024
bd04d58
bring name back
matsen Feb 3, 2024
5e0035b
mask in kmer rates
matsen Feb 3, 2024
e9b13d1
joined and hybrid models
matsen Feb 3, 2024
e61c9be
RSSHMoof
matsen Feb 5, 2024
55676f5
make format
matsen Feb 5, 2024
f144e8c
things seem ready to train dnsms with netam-shm!
matsen Feb 7, 2024
ee0ea44
refactoring framework for generality
matsen Feb 7, 2024
ea3d113
Add informative_site_count function
matsen Feb 7, 2024
a921134
make format
matsen Feb 7, 2024
6c87fe1
Remove unnecessary comment in FivemerModel constructor
matsen Feb 9, 2024
2620baa
appropriate trimming of long sequences
matsen Feb 9, 2024
5e9498f
better names
matsen Feb 9, 2024
1bd098d
loss weight of 0.01
matsen Feb 12, 2024
eccf6e7
cleanup
matsen Feb 12, 2024
834f5b6
docstrings
matsen Feb 12, 2024
93e1fee
commit before moving to directly using log site rates
matsen Feb 13, 2024
f1e7602
using site rate weights directly
matsen Feb 13, 2024
34f5864
comment
matsen Feb 13, 2024
774ee1d
better name
matsen Feb 13, 2024
9fff477
comment
matsen Feb 13, 2024
0b39f5e
making rate masking consistent across models
matsen Feb 13, 2024
d9fb3c3
comment
matsen Feb 13, 2024
30c994f
csv.gz now the format for DNSM; fixing tests
matsen Feb 14, 2024
76ad69a
adding data dir
matsen Feb 14, 2024
2c64ca5
more _ignore dir making
matsen Feb 14, 2024
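
Several commits above ("returning multiple losses", "being able to set loss weights", "loss weight of 0.01") converge on one mechanism: per-base (RS) models produce both a rate loss and a base-substitution loss, and training optimizes their weighted sum. A minimal sketch of that reduction; only the 0.01 weight comes from the commit messages, the function name is hypothetical, and the PR's actual implementation may differ:

import torch


def reduce_losses(rate_loss, base_loss, base_loss_weight=0.01):
    # Hypothetical sketch: keep the two components recorded separately
    # (cf. "Recording separate losses"), optimize their weighted sum.
    return rate_loss + base_loss_weight * base_loss


total = reduce_losses(torch.tensor(0.52), torch.tensor(1.30))  # 0.52 + 0.013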
netam/common.py (6 additions, 0 deletions)

@@ -7,8 +7,10 @@
 import torch.optim as optim
 from torch import nn, Tensor

+BIG = 1e9
 SMALL_PROB = 1e-6
 BASES = ["A", "C", "G", "T"]
+BASES_AND_N_TO_INDEX = {"A": 0, "C": 1, "G": 2, "T": 3, "N": 4}
 AA_STR_SORTED = "ACDEFGHIKLMNPQRSTVWY"
 AA_STR_SORTED_AMBIG = AA_STR_SORTED + "X"
 MAX_AMBIG_AA_IDX = len(AA_STR_SORTED_AMBIG) - 1
@@ -53,6 +55,10 @@ def mask_tensor_of(seq_str, length=None):
     return mask


+def informative_site_count(seq_str):
+    return sum(c != "N" for c in seq_str)
+
+
 def clamp_probability(x: Tensor) -> Tensor:
     return torch.clamp(x, min=SMALL_PROB, max=(1.0 - SMALL_PROB))
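
The additions above are the scaffolding for per-base inference: BASES_AND_N_TO_INDEX lets sequences containing N be encoded without special-casing, and informative_site_count counts the non-N sites (cf. the "excluding Ns from being considered as mutations" commit). A small usage sketch; the encode_sequence helper is hypothetical, not part of this PR:

import torch

BASES_AND_N_TO_INDEX = {"A": 0, "C": 1, "G": 2, "T": 3, "N": 4}


def informative_site_count(seq_str):
    return sum(c != "N" for c in seq_str)


def encode_sequence(seq_str):
    # Hypothetical helper: map each base, N included, to its integer index.
    return torch.tensor([BASES_AND_N_TO_INDEX[c] for c in seq_str], dtype=torch.long)


seq = "ACGNNT"
encode_sequence(seq)         # tensor([0, 1, 2, 4, 4, 3])
informative_site_count(seq)  # 4; the two Ns are excluded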
netam/dnsm.py (2 additions, 4 deletions)

@@ -239,10 +239,8 @@ def train_test_datasets_of_pcp_df(pcp_df, train_frac=0.8, branch_length_multiplier


 class DNSMBurrito(framework.Burrito):
-    def __init__(self, *args, device=pick_device(), **kwargs):
+    def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
-        self.device = device
-        self.model.to(self.device)
         self.wrapped_model = WrappedBinaryMutSel(self.model, weights_directory=None)

     def load_branch_lengths(self, in_csv_prefix):
@@ -401,11 +399,11 @@ def burrito_of_model(
         l2_regularization_coeff=1e-6,
         verbose=False,
     ):
+        model.to(device)
         burrito = DNSMBurrito(
            self.train_dataset,
            self.val_dataset,
            model,
-           device=device,
            batch_size=batch_size,
            learning_rate=learning_rate,
            min_learning_rate=min_learning_rate,
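
The net effect of this dnsm.py change is that DNSMBurrito no longer owns device placement; the caller moves the model before constructing the burrito. A sketch of the new calling pattern, assuming pick_device is importable from netam.common and that the keyword argument shown is accepted:

from netam.common import pick_device  # assumed location of pick_device

# model, train_dataset, and val_dataset built as elsewhere in this PR.
device = pick_device()
model.to(device)  # the caller, not DNSMBurrito, now handles device placement
burrito = DNSMBurrito(train_dataset, val_dataset, model, batch_size=1024)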
netam/experiment.py (10 additions, 8 deletions)

@@ -38,14 +38,14 @@ def build_model_instances(self, prename):
                 filter_count=14,
                 dropout_prob=0.1,
             ),
-            f"{prename}_cnn_med_orig": models.CNNModel(
+            f"{prename}_cnn_med": models.CNNModel(
                 kmer_length=3,
-                kernel_size=11,
-                embedding_dim=9,
-                filter_count=9,
-                dropout_prob=0.1,
+                kernel_size=9,
+                embedding_dim=7,
+                filter_count=16,
+                dropout_prob=0.2,
             ),
-            f"{prename}_cnn_med": models.CNNModel(
+            f"{prename}_ind_rscnn_med": models.IndepRSCNNModel(
                 kmer_length=3,
                 kernel_size=9,
                 embedding_dim=7,
@@ -135,7 +135,8 @@ def train_or_load(

         our_burrito_params = deepcopy(self.burrito_params)
         our_burrito_params.update(training_params)
-        burrito = framework.SHMBurrito(
+        burrito_class = framework.burrito_class_of_model(model)
+        burrito = burrito_class(
             train_dataset, val_dataset, model, verbose=False, **our_burrito_params
         )
         train_history = burrito.multi_train(epochs=self.epochs)
@@ -211,7 +212,8 @@ def train_experiment_df(self, experiment_df, pretrained_dir="../pretrained"):

     def calculate_loss(self, model, dataset):
         model.eval()
-        burrito = framework.SHMBurrito(
+        burrito_class = framework.burrito_class_of_model(model)
+        burrito = burrito_class(
             dataset, dataset, model, verbose=False, **self.burrito_params
         )
         loss = burrito.evaluate()
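
Both hunks replace a hard-coded framework.SHMBurrito with a dispatch through framework.burrito_class_of_model, so rate-only models and the new per-base (RS) models each get a matching burrito. The dispatcher's body is not shown in this diff; one plausible shape, using model class names that do appear in the PR plus an assumed RSSHMBurrito:

def burrito_class_of_model(model):
    # Hypothetical sketch: per-base (RS) models need a burrito that tracks
    # both the rate loss and the base-substitution loss; plain rate models
    # keep the original SHMBurrito.
    if isinstance(model, (models.IndepRSCNNModel, models.JoinedRSCNNModel)):
        return RSSHMBurrito  # assumed name for the per-base burrito class
    return SHMBurrito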