diff --git a/notebooks/scratch.ipynb b/notebooks/scratch.ipynb index 4bf08b2..a8b7b1a 100644 --- a/notebooks/scratch.ipynb +++ b/notebooks/scratch.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -11,6 +11,28 @@ "a = [np.random.randint(0, 10, size=(3,2)) for _ in range(10)]\n", "b = [np.random.randint(0, 10, size=(3,2)) for _ in range(10)]" ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "1" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "a = np.array([1,2,3,4])\n", + "b = np.array([4,6,7,8])\n", + "np.intersect1d(a, b).size" + ] } ], "metadata": { diff --git a/test.py b/test.py index 5e98a48..33fb835 100644 --- a/test.py +++ b/test.py @@ -3,7 +3,8 @@ import glob import numpy as np -from test_utils.dataloader import get_dataloader, yield_data_from_tensorflow_dataloader, yield_data_from_pytorch_dataloader +from test_utils.dataloader import get_dataloader, yield_data_from_tensorflow_dataloader, yield_data_from_pytorch_dataloader, \ + s2c_row_inds #def test_model(): @@ -16,6 +17,9 @@ # Test for correct shape def test_correct_shape(): + """ + A function to test if the batch yielded by both Tensorflow and Pytorch has the same shape + """ experiment_config_path = "/users/ngun7t/Documents/cellbox-jun-6/configs_dev/Example.random_partition.json" tensorflow_dataloader_list = get_dataloader(experiment_config_path, tensorflow_code=True) pytorch_dataloader_list = get_dataloader(experiment_config_path, tensorflow_code=False) @@ -44,8 +48,43 @@ def test_correct_shape(): # Test for correct input +def test_single_to_combo(): + """ + A function to test if pytorch and tensorflow dataloaders yield the correct rows in the dataset for s2c experiment + """ + experiment_config_path = "/users/ngun7t/Documents/cellbox-jun-6/configs_dev/Example.single_to_combo.json" + loo_label_dir = "/users/ngun7t/Documents/cellbox-jun-6/data/loo_label.csv" + tensorflow_dataloader_list = get_dataloader(experiment_config_path, tensorflow_code=True) + pytorch_dataloader_list = get_dataloader(experiment_config_path, tensorflow_code=False) + # Get the row index that contains single drugs + rows_with_single_drugs, rows_with_multiple_drugs = s2c_row_inds(loo_label_dir) + # Code to extract the shape of each yield + for tf_dict, torch_dict in zip(tensorflow_dataloader_list, pytorch_dataloader_list): + tf_train_pert, tf_train_expr = yield_data_from_tensorflow_dataloader( + dataloader=tf_dict["iter_train"], + feed_dict=tf_dict["feed_dict"] + ) + torch_train_pert, torch_train_expr = yield_data_from_pytorch_dataloader( + dataloader=torch_dict["iter_train"] + ) + # Assert that the count of batches obtained is equal + assert len(tf_train_pert) == len(torch_train_pert), "Length of number of arrays yield for train pert not equal" + assert len(tf_train_expr) == len(torch_train_expr), "Length of number of arrays yield for train expr not equal" + + # Assert that the shape of each batch is equal, and also it contains the correct row index + for tf_arr, torch_arr in zip(tf_train_pert, torch_train_pert): + assert tf_arr.shape == np.array(torch_arr).shape, f"For pert batches, shape of tf batch = {tf_arr.shape} is not equal to shape of torch batch = {np.array(torch_arr).shape}" + assert np.intersect1d(tf_arr[:, -1], rows_with_multiple_drugs).size == 0, f"batches for tf train set contains data rows that has multiple drugs in s2c mode" + assert np.intersect1d(torch_arr[:, -1], rows_with_multiple_drugs).size == 0, f"batches for torch train set contains data rows that has multiple drugs in s2c mode" + + # Assert that the shape of each batch is equal + for tf_arr, torch_arr in zip(tf_train_expr, torch_train_expr): + assert tf_arr.shape == np.array(torch_arr).shape, f"For expr batches, shape of tf batch = {tf_arr.shape} is not equal to shape of torch batch = {np.array(torch_arr).shape}" + assert np.intersect1d(tf_arr[:, -1], rows_with_multiple_drugs).size == 0, f"batches for tf train set contains data rows that has multiple drugs in s2c mode" + assert np.intersect1d(torch_arr[:, -1], rows_with_multiple_drugs).size == 0, f"batches for torch train set contains data rows that has multiple drugs in s2c mode" + if __name__ == '__main__': diff --git a/test_utils/dataloader.py b/test_utils/dataloader.py index dfa70c9..6adc5a4 100644 --- a/test_utils/dataloader.py +++ b/test_utils/dataloader.py @@ -136,4 +136,15 @@ def yield_data_from_pytorch_dataloader(dataloader): for pert, expr in dataloader: items_pert.append(pert) items_expr.append(expr) - return items_pert, items_expr \ No newline at end of file + return items_pert, items_expr + + +def s2c_row_inds(loo_label_dir): + """ + Identify the rows of the dataset that only has one drug. + The information is stored in the loo_label file + """ + loo_label = pd.read_csv(loo_label_dir, header=None) + rows_with_single_drugs = loo_label.index[(loo_label[[0, 1]] == 0).any(axis=1)].tolist() + rows_with_multiple_drugs = list(set(list(range(loo_label.shape[0]))) - set(rows_with_single_drugs)) + return rows_with_single_drugs, rows_with_multiple_drugs \ No newline at end of file