Skip to content

Commit

Permalink
Add test for s2c mode
Browse files Browse the repository at this point in the history
  • Loading branch information
Phuc Nguyen committed Jun 15, 2023
1 parent c4d2076 commit 45a9394
Show file tree
Hide file tree
Showing 3 changed files with 75 additions and 3 deletions.
24 changes: 23 additions & 1 deletion notebooks/scratch.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
Expand All @@ -11,6 +11,28 @@
"a = [np.random.randint(0, 10, size=(3,2)) for _ in range(10)]\n",
"b = [np.random.randint(0, 10, size=(3,2)) for _ in range(10)]"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"1"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"a = np.array([1,2,3,4])\n",
"b = np.array([4,6,7,8])\n",
"np.intersect1d(a, b).size"
]
}
],
"metadata": {
Expand Down
41 changes: 40 additions & 1 deletion test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
import glob
import numpy as np

from test_utils.dataloader import get_dataloader, yield_data_from_tensorflow_dataloader, yield_data_from_pytorch_dataloader
from test_utils.dataloader import get_dataloader, yield_data_from_tensorflow_dataloader, yield_data_from_pytorch_dataloader, \
s2c_row_inds


#def test_model():
Expand All @@ -16,6 +17,9 @@

# Test for correct shape
def test_correct_shape():
"""
A function to test if the batch yielded by both Tensorflow and Pytorch has the same shape
"""
experiment_config_path = "/users/ngun7t/Documents/cellbox-jun-6/configs_dev/Example.random_partition.json"
tensorflow_dataloader_list = get_dataloader(experiment_config_path, tensorflow_code=True)
pytorch_dataloader_list = get_dataloader(experiment_config_path, tensorflow_code=False)
Expand Down Expand Up @@ -44,8 +48,43 @@ def test_correct_shape():


# Test for correct input
def test_single_to_combo():
    """
    Test that the Tensorflow and Pytorch dataloaders yield matching training
    batches for the single-to-combo (s2c) experiment, and that no training
    batch contains a row flagged as having multiple drugs.

    For every pair of dataloaders the test checks that:
      * both frameworks yield the same number of pert and expr batches,
      * corresponding batches have identical shapes,
      * the last column of each batch shares no value with the multi-drug
        row indices reported by ``s2c_row_inds``.

    NOTE(review): the last column of every batch (pert and expr alike) is
    read as the dataset row index - this presumably relies on how the
    test_utils dataloaders build their batches; confirm against that module.
    """
    experiment_config_path = "/users/ngun7t/Documents/cellbox-jun-6/configs_dev/Example.single_to_combo.json"
    loo_label_dir = "/users/ngun7t/Documents/cellbox-jun-6/data/loo_label.csv"
    tensorflow_dataloader_list = get_dataloader(experiment_config_path, tensorflow_code=True)
    pytorch_dataloader_list = get_dataloader(experiment_config_path, tensorflow_code=False)

    # Only the multi-drug rows are needed: they must never appear in a
    # training batch. The single-drug list is deliberately discarded.
    _, rows_with_multiple_drugs = s2c_row_inds(loo_label_dir)

    for tf_dict, torch_dict in zip(tensorflow_dataloader_list, pytorch_dataloader_list):
        tf_train_pert, tf_train_expr = yield_data_from_tensorflow_dataloader(
            dataloader=tf_dict["iter_train"],
            feed_dict=tf_dict["feed_dict"]
        )
        torch_train_pert, torch_train_expr = yield_data_from_pytorch_dataloader(
            dataloader=torch_dict["iter_train"]
        )

        # Assert that the count of batches obtained is equal
        assert len(tf_train_pert) == len(torch_train_pert), "Length of number of arrays yield for train pert not equal"
        assert len(tf_train_expr) == len(torch_train_expr), "Length of number of arrays yield for train expr not equal"

        # The pert and expr batches go through exactly the same checks, so
        # run them through one loop instead of two copy-pasted ones.
        batch_groups = (
            ("pert", tf_train_pert, torch_train_pert),
            ("expr", tf_train_expr, torch_train_expr),
        )
        for kind, tf_batches, torch_batches in batch_groups:
            for tf_arr, torch_arr in zip(tf_batches, torch_batches):
                # Convert the torch batch once; reused for shape and row checks.
                torch_np = np.array(torch_arr)

                # Assert that the shape of each batch is equal
                assert tf_arr.shape == torch_np.shape, f"For {kind} batches, shape of tf batch = {tf_arr.shape} is not equal to shape of torch batch = {torch_np.shape}"

                # Assert that the batch contains no multi-drug row index
                assert np.intersect1d(tf_arr[:, -1], rows_with_multiple_drugs).size == 0, "batches for tf train set contains data rows that has multiple drugs in s2c mode"
                assert np.intersect1d(torch_np[:, -1], rows_with_multiple_drugs).size == 0, "batches for torch train set contains data rows that has multiple drugs in s2c mode"


if __name__ == '__main__':

Expand Down
13 changes: 12 additions & 1 deletion test_utils/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,4 +136,15 @@ def yield_data_from_pytorch_dataloader(dataloader):
for pert, expr in dataloader:
items_pert.append(pert)
items_expr.append(expr)
return items_pert, items_expr
return items_pert, items_expr


def s2c_row_inds(loo_label_dir):
    """
    Split the rows of the loo_label file into single-drug and multi-drug rows.

    The file is read as a header-less CSV. A row counts as "single drug" when
    at least one of its first two columns equals 0; every other row counts as
    "multiple drugs".

    Args:
        loo_label_dir: path (or any source accepted by ``pandas.read_csv``)
            of the loo_label CSV file.

    Returns:
        A tuple ``(rows_with_single_drugs, rows_with_multiple_drugs)`` of two
        lists of row indices, each in ascending order. Together they cover
        every row of the file exactly once.
    """
    loo_label = pd.read_csv(loo_label_dir, header=None)

    # One boolean mask drives both lists, so they are complementary by
    # construction and come out in deterministic (ascending) row order,
    # unlike a list built from a set difference.
    is_single_drug = (loo_label[[0, 1]] == 0).any(axis=1)
    rows_with_single_drugs = loo_label.index[is_single_drug].tolist()
    rows_with_multiple_drugs = loo_label.index[~is_single_drug].tolist()
    return rows_with_single_drugs, rows_with_multiple_drugs

0 comments on commit 45a9394

Please sign in to comment.