Skip to content

Commit

Permalink
One transformer set for each fold and for train_valid dataset
Browse files Browse the repository at this point in the history
  • Loading branch information
stewarthe6 committed Dec 9, 2024
1 parent eadd9d4 commit d44ee76
Show file tree
Hide file tree
Showing 4 changed files with 144 additions and 104 deletions.
2 changes: 1 addition & 1 deletion atomsci/ddm/pipeline/model_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ def load_featurize_data(self, params=None):
# is fitted to the training data only. The transformers are then applied to the training,
# validation and test sets separately.
if not params.split_only:
self.model_wrapper.create_transformers(self.data)
self.model_wrapper.create_transformers(trans.get_all_training_datasets(self.data))
else:
self.run_mode = ''

Expand Down
147 changes: 80 additions & 67 deletions atomsci/ddm/pipeline/model_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,9 +283,9 @@ def __init__(self, params, featurizer, ds_client):
self.output_dir = self.params.output_dir
self.model_dir = os.path.join(self.output_dir, 'model')
os.makedirs(self.model_dir, exist_ok=True)
self.transformers = []
self.transformers_x = []
self.transformers_w = []
self.transformers = {}
self.transformers_x = {}
self.transformers_w = {}

# ****************************************************************************************

Expand Down Expand Up @@ -336,7 +336,7 @@ def _create_output_transformers(self, model_dataset):
"""
# TODO: Just a warning, we may have response transformers for classification datasets in the future
if self.params.prediction_type=='regression' and self.params.transformers is True:
self.transformers = [trans.NormalizationTransformerMissingData(transform_y=True, dataset=model_dataset.dataset)]
return [trans.NormalizationTransformerMissingData(transform_y=True, dataset=model_dataset.dataset)]

# ****************************************************************************************

Expand All @@ -351,15 +351,15 @@ def _create_feature_transformers(self, model_dataset):
transformers_x: A list of deepchem transformation objects on featurizers, only if conditions are met.
"""
# Set up transformers for features, if needed
self.transformers_x = trans.create_feature_transformers(self.params, model_dataset)
return trans.create_feature_transformers(self.params, model_dataset)

# ****************************************************************************************

def create_transformers(self, model_dataset):
def create_transformers(self, training_datasets):
"""Initialize transformers for responses, features and weights, and persist them for later.
Args:
model_dataset: The ModelDataset object that handles the current dataset
training_datasets: The ModelDataset object that handles the current dataset
Side effects:
Overwrites the attributes:
Expand All @@ -372,22 +372,23 @@ def create_transformers(self, model_dataset):
params.transformer_key: A string pointing to the dataset key containing the transformer in the datastore, or the path to the transformer
"""
self._create_output_transformers(model_dataset)
for k, td in training_datasets.items():
self.transformers[k] = self._create_output_transformers(td)

self._create_feature_transformers(model_dataset)
self.transformers_x[k] = self._create_feature_transformers(td)

# Set up transformers for weights, if needed
self.transformers_w = trans.create_weight_transformers(self.params, model_dataset)
# Set up transformers for weights, if needed
self.transformers_w[k] = trans.create_weight_transformers(self.params, td)

if len(self.transformers) + len(self.transformers_x) + len(self.transformers_w) > 0:
if len(self.transformers[k]) + len(self.transformers_x[k]) + len(self.transformers_w[k]) > 0:

# Transformers are no longer saved as separate datastore objects; they are included in the model tarball
self.params.transformer_key = os.path.join(self.output_dir, 'transformers.pkl')
with open(self.params.transformer_key, 'wb') as txfmrpkl:
pickle.dump((self.transformers, self.transformers_x, self.transformers_w), txfmrpkl)
self.log.info("Wrote transformers to %s" % self.params.transformer_key)
self.params.transformer_oid = ""
self.params.transformer_bucket = ""
# Transformers are no longer saved as separate datastore objects; they are included in the model tarball
self.params.transformer_key = os.path.join(self.output_dir, f'transformers_{k}.pkl')
with open(self.params.transformer_key, 'wb') as txfmrpkl:
pickle.dump((self.transformers[k], self.transformers_x[k], self.transformers_w[k]), txfmrpkl)
self.log.info("Wrote transformers to %s" % self.params.transformer_key)
self.params.transformer_oid = ""
self.params.transformer_bucket = ""

# ****************************************************************************************

Expand All @@ -400,58 +401,69 @@ def reload_transformers(self):
# Try local path first to check for transformers unpacked from model tarball
if not trans.transformers_needed(self.params):
return
local_path = f"{self.output_dir}/transformers.pkl"
if os.path.exists(local_path):
self.log.info(f"Reloading transformers from model tarball {local_path}")
with open(local_path, 'rb') as txfmr:
transformers_tuple = pickle.load(txfmr)
else:
if self.params.transformer_key is not None:
if self.params.save_results:
self.log.info(f"Reloading transformers from datastore key {self.params.transformer_key}")
transformers_tuple = dsf.retrieve_dataset_by_datasetkey(
dataset_key = self.params.transformer_key,
bucket = self.params.transformer_bucket,
client = self.ds_client )
else:
self.log.info(f"Reloading transformers from file {self.params.transformer_key}")
with open(self.params.transformer_key, 'rb') as txfmr:
transformers_tuple = pickle.load(txfmr)

for i in trans.get_transformer_keys():
# for backwards compatibity if this file exists, all folds use the same transformers
local_path = f"{self.output_dir}/transformers.pkl"
if not os.path.exists(local_path):
local_path = f"{self.output_dir}/transformers_{i}.pkl"

if os.path.exists(local_path):
self.log.info(f"Reloading transformers from model tarball {local_path}")
with open(local_path, 'rb') as txfmr:
transformers_tuple = pickle.load(txfmr)
else:
# Shouldn't happen
raise Exception("Transformers needed to reload model, but no transformer_key specified.")
if self.params.transformer_key is not None:
if self.params.save_results:
self.log.info(f"Reloading transformers from datastore key {self.params.transformer_key}")
transformers_tuple = dsf.retrieve_dataset_by_datasetkey(
dataset_key = self.params.transformer_key,
bucket = self.params.transformer_bucket,
client = self.ds_client )
else:
self.log.info(f"Reloading transformers from file {self.params.transformer_key}")
with open(self.params.transformer_key, 'rb') as txfmr:
transformers_tuple = pickle.load(txfmr)
else:
# Shouldn't happen
raise Exception("Transformers needed to reload model, but no transformer_key specified.")


if len(transformers_tuple) == 3:
self.transformers, self.transformers_x, self.transformers_w = transformers_tuple
else:
self.transformers, self.transformers_x = transformers_tuple
self.transformers_w = []
if len(transformers_tuple) == 3:
ty, tx, tw = transformers_tuple
else:
ty, tx = transformers_tuple
tw = []

self.transformers[i] = ty
self.transformers_x[i] = tx
self.transformers_w[i] = tw

# ****************************************************************************************

def transform_dataset(self, dataset):
def transform_dataset(self, dataset, fold=0):
"""Transform the responses and/or features in the given DeepChem dataset using the current transformers.
Args:
dataset: The DeepChem DiskDataset that contains a dataset
fold (int): Which fold is being transformed.
Returns:
transformed_dataset: The transformed DeepChem DiskDataset
"""
transformed_dataset = dataset
if len(self.transformers) > 0:
if len(self.transformers[fold]) > 0:
self.log.info("Transforming response data")
for transformer in self.transformers:
for transformer in self.transformers[fold]:
transformed_dataset = transformer.transform(transformed_dataset)
if len(self.transformers_x) > 0:
if len(self.transformers_x[fold]) > 0:
self.log.info("Transforming feature data")
for transformer in self.transformers_x:
for transformer in self.transformers_x[fold]:
transformed_dataset = transformer.transform(transformed_dataset)
if len(self.transformers_w) > 0:
if len(self.transformers_w[fold]) > 0:
self.log.info("Transforming weights")
for transformer in self.transformers_w:
for transformer in self.transformers_w[fold]:
transformed_dataset = transformer.transform(transformed_dataset)

return transformed_dataset
Expand Down Expand Up @@ -486,7 +498,7 @@ def get_train_valid_pred_results(self, perf_data):
return perf_data.get_prediction_results()

# ****************************************************************************************
def get_test_perf_data(self, model_dir, model_dataset):
def get_test_perf_data(self, model_dir, model_dataset, fold):
"""Returns the predicted values and metrics for the current test dataset against
the version of the model stored in model_dir, as a PerfData object.
Expand All @@ -506,14 +518,15 @@ def get_test_perf_data(self, model_dir, model_dataset):
# We pass transformed=False to indicate that the preds and uncertainties we get from
# generate_predictions are already untransformed, so that perf_data.get_prediction_results()
# doesn't untransform them again.
if hasattr(self.transformers[0], "ishybrid"):
if hasattr(self.transformers[0][0], "ishybrid"):
# indicate that we are training a hybrid model
# ASDF need to know what to pass in as the y transform now that they are fold dependent.
perf_data = perf.create_perf_data("hybrid", model_dataset, self.transformers, 'test', is_ki=self.params.is_ki, ki_convert_ratio=self.params.ki_convert_ratio, transformed=False)
else:
perf_data = perf.create_perf_data(self.params.prediction_type, model_dataset, self.transformers, 'test', transformed=False)
test_dset = model_dataset.test_dset
test_preds, test_stds = self.generate_predictions(test_dset)
_ = perf_data.accumulate_preds(test_preds, test_dset.ids, test_stds)
_ = perf_data.accumulate_preds(test_preds, test_dset.ids, test_stds, fold=fold)
return perf_data

# ****************************************************************************************
Expand All @@ -532,7 +545,7 @@ def get_test_pred_results(self, model_dir, model_dataset):
return perf_data.get_prediction_results()

# ****************************************************************************************
def get_full_dataset_perf_data(self, model_dataset):
def get_full_dataset_perf_data(self, model_dataset, fold):
"""Returns the predicted values and metrics from the current model for the full current dataset,
as a PerfData object.
Expand All @@ -555,7 +568,7 @@ def get_full_dataset_perf_data(self, model_dataset):
else:
perf_data = perf.create_perf_data(self.params.prediction_type, model_dataset, self.transformers, 'full', transformed=False)
full_preds, full_stds = self.generate_predictions(model_dataset.dataset)
_ = perf_data.accumulate_preds(full_preds, model_dataset.dataset.ids, full_stds)
_ = perf_data.accumulate_preds(full_preds, model_dataset.dataset.ids, full_stds, fold)
return perf_data

# ****************************************************************************************
Expand Down Expand Up @@ -913,10 +926,10 @@ def train_kfold_cv(self, pipeline):
train_pred = self.model.predict(train_dset, [])
test_pred = self.model.predict(test_dset, [])

train_perf = train_perf_data.accumulate_preds(train_pred, train_dset.ids)
test_perf = test_perf_data.accumulate_preds(test_pred, test_dset.ids)
train_perf = train_perf_data.accumulate_preds(train_pred, train_dset.ids, fold=k)
test_perf = test_perf_data.accumulate_preds(test_pred, test_dset.ids, fold=k)

valid_perf = em.accumulate(ei, subset='valid', dset=valid_dset)
valid_perf = em.accumulate(ei, subset='valid', dset=valid_dset, fold=k)
self.log.info("Fold %d, epoch %d: training %s = %.3f, validation %s = %.3f, test %s = %.3f" % (
k, ei, pipeline.metric_type, train_perf, pipeline.metric_type, valid_perf,
pipeline.metric_type, test_perf))
Expand All @@ -939,7 +952,7 @@ def train_kfold_cv(self, pipeline):

for ei in range(self.best_epoch+1):
self.model.fit(fit_dataset, nb_epoch=1, checkpoint_interval=0, restore=False)
train_perf, test_perf = em.update_epoch(ei, train_dset=fit_dataset, test_dset=test_dset)
train_perf, test_perf = em.update_epoch(ei, train_dset=fit_dataset, test_dset=test_dset, fold='train_valid')

self.log.info(f"Combined folds: Epoch {ei}, training {pipeline.metric_type} = {train_perf:.3},"
+ f"test {pipeline.metric_type} = {test_perf:.3}")
Expand Down Expand Up @@ -999,7 +1012,7 @@ def train_with_early_stopping(self, pipeline):
# saved will be the one we created intentionally when we reached a new best validation score.
self.model.fit(train_dset, nb_epoch=1, checkpoint_interval=0)
train_perf, valid_perf, test_perf = em.update_epoch(ei,
train_dset=train_dset, valid_dset=valid_dset, test_dset=test_dset)
train_dset=train_dset, valid_dset=valid_dset, test_dset=test_dset, fold=0)

self.log.info("Epoch %d: training %s = %.3f, validation %s = %.3f, test %s = %.3f" % (
ei, pipeline.metric_type, train_perf, pipeline.metric_type, valid_perf,
Expand Down Expand Up @@ -1455,7 +1468,7 @@ def train(self, pipeline):
valid_loss_ep /= (valid_data.n_ki + valid_data.n_bind)

train_perf, valid_perf, test_perf = em.update_epoch(ei,
train_dset=train_dset, valid_dset=valid_dset, test_dset=test_dset)
train_dset=train_dset, valid_dset=valid_dset, test_dset=test_dset, fold=0)

self.log.info("Epoch %d: training %s = %.3f, training loss = %.3f, validation %s = %.3f, validation loss = %.3f, test %s = %.3f" % (
ei, pipeline.metric_type, train_perf, train_loss_ep, pipeline.metric_type, valid_perf, valid_loss_ep,
Expand Down Expand Up @@ -1650,13 +1663,13 @@ def train(self, pipeline):
self.model.fit(train_dset)

train_pred = self.model.predict(train_dset, [])
train_perf = self.train_perf_data.accumulate_preds(train_pred, train_dset.ids)
train_perf = self.train_perf_data.accumulate_preds(train_pred, train_dset.ids, fold=k)

valid_pred = self.model.predict(valid_dset, [])
valid_perf = self.valid_perf_data.accumulate_preds(valid_pred, valid_dset.ids)
valid_perf = self.valid_perf_data.accumulate_preds(valid_pred, valid_dset.ids, fold=k)

test_pred = self.model.predict(test_dset, [])
test_perf = self.test_perf_data.accumulate_preds(test_pred, test_dset.ids)
test_perf = self.test_perf_data.accumulate_preds(test_pred, test_dset.ids, fold=k)
self.log.info("Fold %d: training %s = %.3f, validation %s = %.3f, test %s = %.3f" % (
k, pipeline.metric_type, train_perf, pipeline.metric_type, valid_perf,
pipeline.metric_type, test_perf))
Expand Down Expand Up @@ -2069,13 +2082,13 @@ def train(self, pipeline):
self.model.fit(train_dset)

train_pred = self.model.predict(train_dset, [])
train_perf = self.train_perf_data.accumulate_preds(train_pred, train_dset.ids)
train_perf = self.train_perf_data.accumulate_preds(train_pred, train_dset.ids, fold=k)

valid_pred = self.model.predict(valid_dset, [])
valid_perf = self.valid_perf_data.accumulate_preds(valid_pred, valid_dset.ids)
valid_perf = self.valid_perf_data.accumulate_preds(valid_pred, valid_dset.ids, fold=k)

test_pred = self.model.predict(test_dset, [])
test_perf = self.test_perf_data.accumulate_preds(test_pred, test_dset.ids)
test_perf = self.test_perf_data.accumulate_preds(test_pred, test_dset.ids, fold=k)
self.log.info("Fold %d: training %s = %.3f, validation %s = %.3f, test %s = %.3f" % (
k, pipeline.metric_type, train_perf, pipeline.metric_type, valid_perf,
pipeline.metric_type, test_perf))
Expand Down
Loading

0 comments on commit d44ee76

Please sign in to comment.