Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Renaming to CSP where appropriate, and other related things #75

Merged
merged 11 commits into from
Oct 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 15 additions & 15 deletions netam/dasm.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,38 +25,38 @@ class DASMDataset(dnsm.DNSMDataset):
def update_neutral_probs(self):
neutral_aa_probs_l = []

for nt_parent, mask, rates, branch_length, subs_probs in zip(
for nt_parent, mask, nt_rates, branch_length, nt_csps in zip(
self.nt_parents,
self.mask,
self.all_rates,
self.nt_ratess,
self._branch_lengths,
self.all_subs_probs,
self.nt_cspss,
):
mask = mask.to("cpu")
rates = rates.to("cpu")
subs_probs = subs_probs.to("cpu")
nt_rates = nt_rates.to("cpu")
nt_csps = nt_csps.to("cpu")
# Note we are replacing all Ns with As, which means that we need to be careful
# with masking out these positions later. We do this below.
parent_idxs = sequences.nt_idx_tensor_of_str(nt_parent.replace("N", "A"))
parent_len = len(nt_parent)

mut_probs = 1.0 - torch.exp(-branch_length * rates[:parent_len])
normed_subs_probs = molevol.normalize_sub_probs(
parent_idxs, subs_probs[:parent_len, :]
mut_probs = 1.0 - torch.exp(-branch_length * nt_rates[:parent_len])
normed_nt_csps = molevol.normalize_sub_probs(
parent_idxs, nt_csps[:parent_len, :]
)

neutral_aa_probs = molevol.neutral_aa_probs(
parent_idxs.reshape(-1, 3),
mut_probs.reshape(-1, 3),
normed_subs_probs.reshape(-1, 3, 4),
normed_nt_csps.reshape(-1, 3, 4),
)

if not torch.isfinite(neutral_aa_probs).all():
print(f"Found a non-finite neutral_aa_probs")
print(f"nt_parent: {nt_parent}")
print(f"mask: {mask}")
print(f"rates: {rates}")
print(f"subs_probs: {subs_probs}")
print(f"nt_rates: {nt_rates}")
print(f"nt_csps: {nt_csps}")
print(f"branch_length: {branch_length}")
raise ValueError(f"neutral_aa_probs is not finite: {neutral_aa_probs}")

Expand Down Expand Up @@ -84,8 +84,8 @@ def __getitem__(self, idx):
"subs_indicator": self.aa_subs_indicator_tensor[idx],
"mask": self.mask[idx],
"log_neutral_aa_probs": self.log_neutral_aa_probs[idx],
"rates": self.all_rates[idx],
"subs_probs": self.all_subs_probs[idx],
"nt_rates": self.nt_ratess[idx],
"nt_csps": self.nt_cspss[idx],
}

def to(self, device):
Expand All @@ -94,8 +94,8 @@ def to(self, device):
self.aa_subs_indicator_tensor = self.aa_subs_indicator_tensor.to(device)
self.mask = self.mask.to(device)
self.log_neutral_aa_probs = self.log_neutral_aa_probs.to(device)
self.all_rates = self.all_rates.to(device)
self.all_subs_probs = self.all_subs_probs.to(device)
self.nt_ratess = self.nt_ratess.to(device)
self.nt_cspss = self.nt_cspss.to(device)
if self.multihit_model is not None:
self.multihit_model = self.multihit_model.to(device)

Expand Down
78 changes: 40 additions & 38 deletions netam/dnsm.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,15 +40,15 @@ def __init__(
self,
nt_parents: pd.Series,
nt_children: pd.Series,
all_rates: torch.Tensor,
all_subs_probs: torch.Tensor,
nt_ratess: torch.Tensor,
nt_cspss: torch.Tensor,
branch_lengths: torch.Tensor,
multihit_model=None,
):
self.nt_parents = nt_parents
self.nt_children = nt_children
self.all_rates = all_rates
self.all_subs_probs = all_subs_probs
self.nt_ratess = nt_ratess
self.nt_cspss = nt_cspss
self.multihit_model = copy.deepcopy(multihit_model)
if multihit_model is not None:
# We want these parameters to act like fixed data. This is essential
Expand Down Expand Up @@ -98,8 +98,8 @@ def of_seriess(
cls,
nt_parents: pd.Series,
nt_children: pd.Series,
all_rates_series: pd.Series,
all_subs_probs_series: pd.Series,
nt_rates_series: pd.Series,
nt_csps_series: pd.Series,
branch_length_multiplier=5.0,
multihit_model=None,
):
Expand All @@ -119,8 +119,8 @@ def of_seriess(
return cls(
nt_parents.reset_index(drop=True),
nt_children.reset_index(drop=True),
stack_heterogeneous(all_rates_series.reset_index(drop=True)),
stack_heterogeneous(all_subs_probs_series.reset_index(drop=True)),
stack_heterogeneous(nt_rates_series.reset_index(drop=True)),
stack_heterogeneous(nt_csps_series.reset_index(drop=True)),
initial_branch_lengths,
multihit_model=multihit_model,
)
Expand All @@ -129,12 +129,14 @@ def of_seriess(
def of_pcp_df(cls, pcp_df, branch_length_multiplier=5.0, multihit_model=None):
"""Alternative constructor that takes in a pcp_df and calculates the initial
branch lengths."""
assert "rates" in pcp_df.columns, "pcp_df must have a neutral rates column"
assert (
"nt_rates" in pcp_df.columns
), "pcp_df must have a neutral nt_rates column"
return cls.of_seriess(
pcp_df["parent"],
pcp_df["child"],
pcp_df["rates"],
pcp_df["subs_probs"],
pcp_df["nt_rates"],
pcp_df["nt_csps"],
branch_length_multiplier=branch_length_multiplier,
multihit_model=multihit_model,
)
Expand Down Expand Up @@ -172,8 +174,8 @@ def clone(self):
new_dataset = self.__class__(
self.nt_parents,
self.nt_children,
self.all_rates.copy(),
self.all_subs_probs.copy(),
self.nt_ratess.copy(),
self.nt_cspss.copy(),
self._branch_lengths.copy(),
multihit_model=self.multihit_model,
)
Expand All @@ -189,8 +191,8 @@ def subset_via_indices(self, indices):
new_dataset = self.__class__(
self.nt_parents[indices].reset_index(drop=True),
self.nt_children[indices].reset_index(drop=True),
self.all_rates[indices],
self.all_subs_probs[indices],
self.nt_ratess[indices],
self.nt_cspss[indices],
self._branch_lengths[indices],
multihit_model=self.multihit_model,
)
Expand Down Expand Up @@ -240,16 +242,16 @@ def update_neutral_probs(self):
"""
neutral_aa_mut_prob_l = []

for nt_parent, mask, rates, branch_length, subs_probs in zip(
for nt_parent, mask, nt_rates, branch_length, nt_csps in zip(
self.nt_parents,
self.mask,
self.all_rates,
self.nt_ratess,
self._branch_lengths,
self.all_subs_probs,
self.nt_cspss,
):
mask = mask.to("cpu")
rates = rates.to("cpu")
subs_probs = subs_probs.to("cpu")
nt_rates = nt_rates.to("cpu")
nt_csps = nt_csps.to("cpu")
if self.multihit_model is not None:
multihit_model = copy.deepcopy(self.multihit_model).to("cpu")
else:
Expand All @@ -259,24 +261,24 @@ def update_neutral_probs(self):
parent_idxs = sequences.nt_idx_tensor_of_str(nt_parent.replace("N", "A"))
parent_len = len(nt_parent)

mut_probs = 1.0 - torch.exp(-branch_length * rates[:parent_len])
normed_subs_probs = molevol.normalize_sub_probs(
parent_idxs, subs_probs[:parent_len, :]
mut_probs = 1.0 - torch.exp(-branch_length * nt_rates[:parent_len])
normed_nt_csps = molevol.normalize_sub_probs(
parent_idxs, nt_csps[:parent_len, :]
)

neutral_aa_mut_prob = molevol.neutral_aa_mut_probs(
parent_idxs.reshape(-1, 3),
mut_probs.reshape(-1, 3),
normed_subs_probs.reshape(-1, 3, 4),
normed_nt_csps.reshape(-1, 3, 4),
multihit_model=multihit_model,
)

if not torch.isfinite(neutral_aa_mut_prob).all():
print(f"Found a non-finite neutral_aa_mut_prob")
print(f"nt_parent: {nt_parent}")
print(f"mask: {mask}")
print(f"rates: {rates}")
print(f"subs_probs: {subs_probs}")
print(f"nt_rates: {nt_rates}")
print(f"nt_csps: {nt_csps}")
print(f"branch_length: {branch_length}")
raise ValueError(
f"neutral_aa_mut_prob is not finite: {neutral_aa_mut_prob}"
Expand Down Expand Up @@ -308,17 +310,17 @@ def __getitem__(self, idx):
"subs_indicator": self.aa_subs_indicator_tensor[idx],
"mask": self.mask[idx],
"log_neutral_aa_mut_probs": self.log_neutral_aa_mut_probs[idx],
"rates": self.all_rates[idx],
"subs_probs": self.all_subs_probs[idx],
"nt_rates": self.nt_ratess[idx],
"nt_csps": self.nt_cspss[idx],
}

def to(self, device):
self.aa_parents_idxs = self.aa_parents_idxs.to(device)
self.aa_subs_indicator_tensor = self.aa_subs_indicator_tensor.to(device)
self.mask = self.mask.to(device)
self.log_neutral_aa_mut_probs = self.log_neutral_aa_mut_probs.to(device)
self.all_rates = self.all_rates.to(device)
self.all_subs_probs = self.all_subs_probs.to(device)
self.nt_ratess = self.nt_ratess.to(device)
self.nt_cspss = self.nt_cspss.to(device)
if self.multihit_model is not None:
self.multihit_model = self.multihit_model.to(device)

Expand Down Expand Up @@ -390,8 +392,8 @@ def _find_optimal_branch_length(
self,
parent,
child,
rates,
subs_probs,
nt_rates,
nt_csps,
starting_branch_length,
multihit_model,
**optimization_kwargs,
Expand All @@ -400,7 +402,7 @@ def _find_optimal_branch_length(
return 0.0
sel_matrix = self.build_selection_matrix_from_parent(parent)
log_pcp_probability = molevol.mutsel_log_pcp_probability_of(
sel_matrix, parent, child, rates, subs_probs, multihit_model
sel_matrix, parent, child, nt_rates, nt_csps, multihit_model
)
if isinstance(starting_branch_length, torch.Tensor):
starting_branch_length = starting_branch_length.detach().item()
Expand All @@ -412,12 +414,12 @@ def serial_find_optimal_branch_lengths(self, dataset, **optimization_kwargs):
optimal_lengths = []
failed_count = 0

for parent, child, rates, subs_probs, starting_length in tqdm(
for parent, child, nt_rates, nt_csps, starting_length in tqdm(
zip(
dataset.nt_parents,
dataset.nt_children,
dataset.all_rates,
dataset.all_subs_probs,
dataset.nt_ratess,
dataset.nt_cspss,
dataset.branch_lengths,
),
total=len(dataset.nt_parents),
Expand All @@ -426,8 +428,8 @@ def serial_find_optimal_branch_lengths(self, dataset, **optimization_kwargs):
branch_length, failed_to_converge = self._find_optimal_branch_length(
parent,
child,
rates[: len(parent)],
subs_probs[: len(parent), :],
nt_rates[: len(parent)],
nt_csps[: len(parent), :],
starting_length,
dataset.multihit_model,
**optimization_kwargs,
Expand Down
4 changes: 2 additions & 2 deletions netam/framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,8 +370,8 @@ def load_pcp_df(pcp_df_path_gz, sample_count=None, chosen_v_families=None):
def add_shm_model_outputs_to_pcp_df(pcp_df, crepe_prefix, device=None):
crepe = load_crepe(crepe_prefix, device=device)
rates, csps = trimmed_shm_model_outputs_of_crepe(crepe, pcp_df["parent"])
pcp_df["rates"] = rates
pcp_df["subs_probs"] = csps
pcp_df["nt_rates"] = rates
pcp_df["nt_csps"] = csps
return pcp_df


Expand Down
Loading