Skip to content

Commit

Permalink
Updated transformer test to correctly test the standard deviation and to use a new json file specific to that test
Browse files Browse the repository at this point in the history
  • Loading branch information
stewarthe6 committed Dec 16, 2024
1 parent 74c1953 commit 58c454d
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 3 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
{
"dataset_key" : "replaced",
"datastore" : "False",
"uncertainty": "False",
"splitter": "scaffold",
"split_valid_frac": "0.20",
"split_test_frac": "0.20",
"split_strategy": "train_valid_test",
"prediction_type": "classification",
"model_choice_score_type": "roc_auc",
"response_cols" : "active",
"id_col": "compound_id",
"smiles_col" : "rdkit_smiles",
"result_dir": "replaced",
"system": "LC",
"transformers": "True",
"model_type": "NN",
"featurizer": "computed_descriptors",
"descriptor_type": "rdkit_raw",
"weight_transform_type": "balancing",
"learning_rate": ".0007",
"layer_sizes": "20,10",
"dropouts": "0.3,0.3",
"save_results": "False",
"max_epochs": "2",
"early_stopping_patience": "2",
"verbose": "False",
"seed":"0"
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,11 @@ def test_balancing_transformer():
def test_all_transformers():
res_dir = tempfile.mkdtemp()
dskey = os.path.join(res_dir, 'special_test_dset.csv')
params = params_w_balan(dskey, res_dir)
params = read_params(
make_relative_to_file('jsons/all_transforms.json'),
dskey,
res_dir
)
make_test_datasets.make_test_dataset_and_split(dskey, params['descriptor_type'])

params['previously_featurized'] = True
Expand Down Expand Up @@ -77,7 +81,7 @@ def test_all_transformers():
# untransformed mean is 10; expected transformed mean is (10 - 0) / 2 = 5
assert abs(np.mean(trans_valid_dset.X) - 5) < 1e-4
# untransformed std is 5; expected transformed std is 5 / 2 = 2.5
assert abs(np.std(trans_valid_dset.X) - (2.5))
assert abs(np.std(trans_valid_dset.X) - (2.5)) < 1e-4
# validation has a 50/50 split. Majority class * 4 should equal oversampled minority class
valid_weights = trans_valid_dset.w
(valid_weight1, valid_weight2), (valid_count1, valid_count2) = np.unique(valid_weights, return_counts=True)
Expand Down Expand Up @@ -132,5 +136,5 @@ def params_w_balan(dset_key, res_dir):
return params

if __name__ == '__main__':
#test_all_transformers()
test_all_transformers()
test_balancing_transformer()

0 comments on commit 58c454d

Please sign in to comment.