Merge pull request #55 from BIMSBbioinfo/survival
Add support for survival modeling
borauyar authored Feb 27, 2024
2 parents ac68079 + 4148f98 commit 0a01f18
Showing 9 changed files with 236 additions and 37 deletions.
11 changes: 11 additions & 0 deletions .github/workflows/benchmarks.yml
@@ -75,12 +75,23 @@ jobs:
curl -L -o dataset2.tgz https://bimsbstatic.mdc-berlin.de/akalin/buyar/flexynesis-benchmark-datasets/dataset2.tgz
tar -xzvf dataset2.tgz
- name: Download LGG_GBM_dataset
run: |
curl -L -o lgggbm_tcga_pub_processed.tgz https://bimsbstatic.mdc-berlin.de/akalin/buyar/flexynesis-benchmark-datasets/lgggbm_tcga_pub_processed.tgz
tar -xzvf lgggbm_tcga_pub_processed.tgz
- name: Run DirectPred
shell: bash -l {0}
run: |
conda activate my_env
flexynesis --data_path dataset1 --model_class DirectPred --target_variables Erlotinib --batch_variables Crizotinib --fusion_type early --hpo_iter 1 --features_min 50 --features_top_percentile 5 --log_transform False --data_types gex,cnv --outdir . --prefix erlotinib_direct --early_stop_patience 3 --use_loss_weighting False --evaluate_baseline_performance False
- name: Run DirectPred_TestSurvival
shell: bash -l {0}
run: |
conda activate my_env
flexynesis --data_path lgggbm_tcga_pub_processed --model_class DirectPred --target_variables STUDY --fusion_type intermediate --hpo_iter 1 --features_min 50 --features_top_percentile 5 --log_transform False --data_types mut,cna --outdir . --prefix lgg_surv --early_stop_patience 3 --use_loss_weighting False --evaluate_baseline_performance False --surv_event_var OS_STATUS --surv_time_var OS_MONTHS
- name: Run supervised_vae
shell: bash -l {0}
run: |
10 changes: 8 additions & 2 deletions flexynesis/__main__.py
@@ -18,10 +18,12 @@ def main():
parser.add_argument("--data_path", help="(Required) Path to the folder with train/test data files", type=str, required = True)
parser.add_argument("--model_class", help="(Required) The kind of model class to instantiate", type=str, choices=["DirectPred", "DirectPredCNN", "DirectPredGCNN", "supervised_vae", "MultiTripletNetwork"], required = True)
parser.add_argument("--target_variables", help="(Required) Which variables in 'clin.csv' to use for predictions, comma-separated if multiple", type = str, required = True)
parser.add_argument('--config_path', type=str, default=None, help='Optional path to an external hyperparameter configuration file in YAML format.')
parser.add_argument("--batch_variables",
help="(Optional) Which variables in 'clin.csv' to use for data integration / batch correction, comma-separated if multiple",
type = str, default = None)
parser.add_argument("--surv_event_var", help="Which column in 'clin.csv' to use as event/status indicator for survival modeling", type = str, default = None)
parser.add_argument("--surv_time_var", help="Which column in 'clin.csv' to use as time/duration indicator for survival modeling", type = str, default = None)
parser.add_argument('--config_path', type=str, default=None, help='Optional path to an external hyperparameter configuration file in YAML format.')
parser.add_argument("--fusion_type", help="How to fuse the omics layers", type=str, choices=["early", "intermediate"], default = 'intermediate')
parser.add_argument("--hpo_iter", help="Number of iterations for hyperparameter optimisation", type=int, default = 5)
parser.add_argument("--correlation_threshold", help="Correlation threshold to drop highly redundant features (default: 0.8; set to 1 for no redundancy filtering)", type=float, default = 0.8)
@@ -100,6 +102,8 @@ class AvailableModels(NamedTuple):
model_class = model_class,
target_variables = args.target_variables,
batch_variables = args.batch_variables,
surv_event_var = args.surv_event_var,
surv_time_var = args.surv_time_var,
config_name = config_name,
config_path = args.config_path,
n_iter=int(args.hpo_iter),
@@ -111,7 +115,9 @@

# evaluate predictions
print("Computing model evaluation metrics")
metrics_df = flexynesis.evaluate_wrapper(model.predict(test_dataset), test_dataset)
metrics_df = flexynesis.evaluate_wrapper(model.predict(test_dataset), test_dataset,
surv_event_var=model.surv_event_var,
surv_time_var=model.surv_time_var)
metrics_df.to_csv(os.path.join(args.outdir, '.'.join([args.prefix, 'stats.csv'])), header=True, index=False)

# print known/predicted labels
13 changes: 10 additions & 3 deletions flexynesis/main.py
@@ -47,11 +47,13 @@ class HyperparameterTuning:
load_and_convert_config(config_path): Loads and converts a configuration file.
"""
def __init__(self, dataset, model_class, config_name, target_variables,
batch_variables = None, n_iter = 10, config_path = None, plot_losses = False,
batch_variables = None, surv_event_var = None, surv_time_var = None, n_iter = 10, config_path = None, plot_losses = False,
val_size = 0.2, use_loss_weighting = True, early_stop_patience = -1):
self.dataset = dataset
self.model_class = model_class
self.target_variables = target_variables.strip().split(',')
self.surv_event_var = surv_event_var
self.surv_time_var = surv_time_var
self.batch_variables = batch_variables.strip().split(',') if batch_variables is not None else None
self.config_name = config_name
self.n_iter = n_iter
@@ -79,8 +81,13 @@ def __init__(self, dataset, model_class, config_name, target_variables,
raise ValueError(f"'{self.config_name}' not found in the default config.")

def objective(self, params, current_step, total_steps):
model = self.model_class(params, self.dataset, self.target_variables,
self.batch_variables, self.val_size, self.use_loss_weighting)
model = self.model_class(config = params, dataset = self.dataset,
target_variables = self.target_variables,
batch_variables = self.batch_variables,
surv_event_var = self.surv_event_var,
surv_time_var = self.surv_time_var,
val_size = self.val_size,
use_loss_weighting = self.use_loss_weighting)
print(params)

mycallbacks = [self.progress_bar]
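For orientation, the two survival arguments flow unchanged from the CLI into this class and on into the model constructor. Below is a minimal sketch of driving the tuner directly from Python; it assumes `dataset` is an already-prepared flexynesis training dataset, that the import paths mirror the file layout in this diff, and that "DirectPred" is a valid config name (all three are assumptions, not shown in the diff):

```python
from flexynesis.main import HyperparameterTuning
from flexynesis.models.direct_pred import DirectPred

# Sketch only: `dataset` is assumed to be a prepared multi-omics training dataset.
tuner = HyperparameterTuning(
    dataset=dataset,
    model_class=DirectPred,
    config_name="DirectPred",      # assumed config name
    target_variables="STUDY",      # comma-separated string, as on the CLI
    surv_event_var="OS_STATUS",    # event/status column in clin.csv
    surv_time_var="OS_MONTHS",     # time/duration column in clin.csv
    n_iter=1,
)
```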
37 changes: 27 additions & 10 deletions flexynesis/models/direct_pred.py
@@ -17,16 +17,21 @@

from ..modules import *



class DirectPred(pl.LightningModule):
def __init__(self, config, dataset, target_variables, batch_variables = None, val_size = 0.2, use_loss_weighting = True):
def __init__(self, config, dataset, target_variables, batch_variables = None,
surv_event_var = None, surv_time_var = None, val_size = 0.2, use_loss_weighting = True):
super(DirectPred, self).__init__()
self.config = config
self.dataset = dataset
self.target_variables = target_variables
self.surv_event_var = surv_event_var
self.surv_time_var = surv_time_var
# both surv event and time variables are assumed to be numerical variables
# we create only one survival variable for the pair (surv_time_var and surv_event_var)
if self.surv_event_var is not None and self.surv_time_var is not None:
self.target_variables = self.target_variables + [self.surv_event_var]
self.batch_variables = batch_variables
self.variables = target_variables + batch_variables if batch_variables else target_variables
self.variables = self.target_variables + batch_variables if batch_variables else self.target_variables
self.val_size = val_size
self.dat_train, self.dat_val = self.prepare_data()
self.feature_importances = {}
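With the LGG/GBM example flags above (`--target_variables STUDY --surv_event_var OS_STATUS --surv_time_var OS_MONTHS`, no batch variables), the constructor logic yields `target_variables == ['STUDY', 'OS_STATUS']` and the same list for `self.variables`: the event variable gets its own prediction head, while `OS_MONTHS` is never a prediction target itself and is only looked up from the labels inside the Cox loss (see `training_step` below).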
@@ -137,9 +142,15 @@ def training_step(self, train_batch, batch_idx):
outputs = self.forward(x_list)
losses = {}
for var in self.variables:
y_hat = outputs[var]
y = y_dict[var]
loss = self.compute_loss(var, y, y_hat)
if var == self.surv_event_var:
durations = y_dict[self.surv_time_var]
events = y_dict[self.surv_event_var]
risk_scores = outputs[var] # output of the survival MLP head
loss = cox_ph_loss(risk_scores, durations, events)
else:
y_hat = outputs[var]
y = y_dict[var]
loss = self.compute_loss(var, y, y_hat)
losses[var] = loss

total_loss = self.compute_total_loss(losses)
@@ -165,9 +176,15 @@ def validation_step(self, val_batch, batch_idx):
outputs = self.forward(x_list)
losses = {}
for var in self.variables:
y_hat = outputs[var]
y = y_dict[var]
loss = self.compute_loss(var, y, y_hat)
if var == self.surv_event_var:
durations = y_dict[self.surv_time_var]
events = y_dict[self.surv_event_var]
risk_scores = outputs[var] # output of the survival MLP head
loss = cox_ph_loss(risk_scores, durations, events)
else:
y_hat = outputs[var]
y = y_dict[var]
loss = self.compute_loss(var, y, y_hat)
losses[var] = loss
total_loss = sum(losses.values())
losses['val_loss'] = total_loss
35 changes: 27 additions & 8 deletions flexynesis/models/supervised_vae.py
@@ -45,13 +45,20 @@ class supervised_vae(pl.LightningModule):
model = supervised_vae(num_layers=2, input_dims=[100, 200], hidden_dims=[64, 32], latent_dim=16, num_class=1)
"""
def __init__(self, config, dataset, target_variables, batch_variables = None, val_size = 0.2, use_loss_weighting = True):
def __init__(self, config, dataset, target_variables, batch_variables = None,
surv_event_var = None, surv_time_var = None, val_size = 0.2, use_loss_weighting = True):
super(supervised_vae, self).__init__()
self.config = config
self.dataset = dataset
self.target_variables = target_variables
self.surv_event_var = surv_event_var
self.surv_time_var = surv_time_var
# both surv event and time variables are assumed to be numerical variables
# we create only one survival variable for the pair (surv_time_var and surv_event_var)
if self.surv_event_var is not None and self.surv_time_var is not None:
self.target_variables = self.target_variables + [self.surv_event_var]
self.batch_variables = batch_variables
self.variables = target_variables + batch_variables if batch_variables else target_variables
self.variables = self.target_variables + batch_variables if batch_variables else self.target_variables
self.val_size = val_size

self.dat_train, self.dat_val = self.prepare_data()
@@ -218,9 +225,15 @@ def training_step(self, train_batch, batch_idx):
losses = {'mmd_loss': mmd_loss}

for var in self.variables:
y_hat = outputs[var]
y = y_dict[var]
loss = self.compute_loss(var, y, y_hat)
if var == self.surv_event_var:
durations = y_dict[self.surv_time_var]
events = y_dict[self.surv_event_var]
risk_scores = outputs[var] # output of the survival MLP head
loss = cox_ph_loss(risk_scores, durations, events)
else:
y_hat = outputs[var]
y = y_dict[var]
loss = self.compute_loss(var, y, y_hat)
losses[var] = loss

total_loss = self.compute_total_loss(losses)
@@ -243,9 +256,15 @@ def validation_step(self, val_batch, batch_idx):
# compute loss values for the supervisor heads
losses = {'mmd_loss': mmd_loss}
for var in self.variables:
y_hat = outputs[var]
y = y_dict[var]
loss = self.compute_loss(var, y, y_hat)
if var == self.surv_event_var:
durations = y_dict[self.surv_time_var]
events = y_dict[self.surv_event_var]
risk_scores = outputs[var] # output of the survival MLP head
loss = cox_ph_loss(risk_scores, durations, events)
else:
y_hat = outputs[var]
y = y_dict[var]
loss = self.compute_loss(var, y, y_hat)
losses[var] = loss

total_loss = sum(losses.values())
35 changes: 27 additions & 8 deletions flexynesis/models/triplet_encoder.py
@@ -63,7 +63,8 @@ def forward(self, x):
class MultiTripletNetwork(pl.LightningModule):
"""
"""
def __init__(self, config, dataset, target_variables, batch_variables = None, val_size = 0.2, use_loss_weighting = True):
def __init__(self, config, dataset, target_variables, batch_variables = None,
surv_event_var = None, surv_time_var = None, val_size = 0.2, use_loss_weighting = True):
"""
Initialize the MultiTripletNetwork with the given parameters.
@@ -74,8 +75,14 @@ def __init__(self, config, dataset, target_variables, batch_variables = None, va

self.config = config
self.target_variables = target_variables
self.surv_event_var = surv_event_var
self.surv_time_var = surv_time_var
# both surv event and time variables are assumed to be numerical variables
# we create only one survival variable for the pair (surv_time_var and surv_event_var)
if self.surv_event_var is not None and self.surv_time_var is not None:
self.target_variables = self.target_variables + [self.surv_event_var]
self.batch_variables = batch_variables
self.variables = target_variables + batch_variables if batch_variables else target_variables
self.variables = self.target_variables + batch_variables if batch_variables else self.target_variables
self.val_size = val_size
self.dataset = dataset
self.ann = self.dataset.ann
@@ -212,9 +219,15 @@ def training_step(self, train_batch, batch_idx):
# compute loss values for the supervisor heads
losses = {'triplet_loss': triplet_loss}
for var in self.variables:
y_hat = outputs[var]
y = y_dict[var]
loss = self.compute_loss(var, y, y_hat)
if var == self.surv_event_var:
durations = y_dict[self.surv_time_var]
events = y_dict[self.surv_event_var]
risk_scores = outputs[var] # output of the survival MLP head
loss = cox_ph_loss(risk_scores, durations, events)
else:
y_hat = outputs[var]
y = y_dict[var]
loss = self.compute_loss(var, y, y_hat)
losses[var] = loss

total_loss = self.compute_total_loss(losses)
@@ -231,9 +244,15 @@ def validation_step(self, val_batch, batch_idx):
# compute loss values for the supervisor heads
losses = {'triplet_loss': triplet_loss}
for var in self.variables:
y_hat = outputs[var]
y = y_dict[var]
loss = self.compute_loss(var, y, y_hat)
if var == self.surv_event_var:
durations = y_dict[self.surv_time_var]
events = y_dict[self.surv_event_var]
risk_scores = outputs[var] # output of the survival MLP head
loss = cox_ph_loss(risk_scores, durations, events)
else:
y_hat = outputs[var]
y = y_dict[var]
loss = self.compute_loss(var, y, y_hat)
losses[var] = loss

total_loss = sum(losses.values())
41 changes: 40 additions & 1 deletion flexynesis/modules.py
@@ -5,7 +5,7 @@
import torch_geometric.nn as gnn


__all__ = ["Encoder", "Decoder", "MLP", "EmbeddingNetwork", "CNN", "GCNN"]
__all__ = ["Encoder", "Decoder", "MLP", "EmbeddingNetwork", "CNN", "GCNN", "cox_ph_loss"]


class Encoder(nn.Module):
@@ -267,3 +267,42 @@ def forward(self, x, edge_index, batch):
x = self.relu_2(x)
x = self.aggregation(x, batch)
return x


def cox_ph_loss(outputs, durations, events):
"""
Calculate the Cox proportional hazards loss.
Args:
outputs (torch.Tensor): The output log-risk scores from the MLP.
durations (torch.Tensor): The observed times (durations) for each sample.
events (torch.Tensor): The event indicators (1 if event occurred, 0 if censored) for each sample.
Returns:
torch.Tensor: The calculated CoxPH loss.
"""
valid_indices = ~torch.isnan(durations) & ~torch.isnan(events)
if valid_indices.sum() > 0:
outputs = outputs[valid_indices]
events = events[valid_indices]
durations = durations[valid_indices]

# Exponentiate the outputs to get the hazard ratios
hazards = torch.exp(outputs)
# Ensure hazards is at least 1D
if hazards.dim() == 0:
hazards = hazards.unsqueeze(0) # Make hazards 1D if it's a scalar
# Get the indices that sort the durations in descending order
sorted_indices = torch.argsort(durations, descending=True)
events_sorted = events[sorted_indices]
# Risk set sums: for each sample, the cumulative sum of hazards over all
# samples with an equal or longer duration (hence the descending sort)
log_risk_set_sum = torch.log(torch.cumsum(hazards[sorted_indices], dim=0))

# Calculate the loss
uncensored_loss = torch.sum(outputs[sorted_indices][events_sorted == 1]) - torch.sum(log_risk_set_sum[events_sorted == 1])
total_loss = -uncensored_loss / torch.sum(events)
else:
total_loss = torch.tensor(0.0, device=outputs.device, requires_grad=True)
if not torch.isfinite(total_loss):
return torch.tensor(0.0, device=outputs.device, requires_grad=True)
return total_loss
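What `cox_ph_loss` computes is the negative log Cox partial likelihood (Breslow-style risk sets), averaged over observed events. With risk scores $\eta_i$, durations $t_i$, and event indicators $\delta_i$:

$$\mathcal{L}(\eta) = -\frac{1}{\sum_i \delta_i} \sum_{i:\,\delta_i = 1} \Big( \eta_i - \log \sum_{j:\, t_j \ge t_i} e^{\eta_j} \Big)$$

A toy example of calling the loss directly; the tensor values are illustrative only, and the import follows the `__all__` export added above:

```python
import torch
from flexynesis.modules import cox_ph_loss

# Four samples: model risk scores, follow-up times, and event flags.
risk_scores = torch.tensor([0.5, -0.2, 1.1, 0.0], requires_grad=True)
durations = torch.tensor([12.0, 30.0, 7.0, 18.0])  # e.g. OS_MONTHS
events = torch.tensor([1.0, 0.0, 1.0, 1.0])        # e.g. OS_STATUS (0 = censored)

loss = cox_ph_loss(risk_scores, durations, events)
loss.backward()  # differentiable, so it can train the survival head
print(loss.item())
```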