From 6004120cb6844c8152ec252b258810f79706abc7 Mon Sep 17 00:00:00 2001
From: Bo Yuan
Date: Tue, 15 Feb 2022 12:27:58 -0800
Subject: [PATCH] Correct the generation of cfg.loo

---
Review notes (free-form area, ignored by `git am`; not part of the commit
message):

* factory(): cfg.loo used to be one (row, column) + 1 pair per nonzero
  entry of cfg.pert. It is now a DataFrame built by grouping the nonzero
  positions by row, so each row lists all of its nonzero column indices.
* pad_and_realign(): subtracts idx_shift from every index, then right-pads
  with np.pad in 'constant' mode up to `length`. The caller passes the
  largest per-row count of nonzero entries as `length` and
  cfg.n_activity_nodes - 1 as `idx_shift`.
* NOTE(review): `x -= idx_shift` is an in-place update of the object handed
  in by groupby(...).apply -- presumably harmless here, but confirm nothing
  reuses that object afterwards.
* NOTE(review): padding presumably fills with 0 (np.pad default), so a
  realigned index of 0 would be indistinguishable from padding -- confirm
  that shifted indices are always >= 1.

 cellbox/cellbox/dataset.py | 12 +++++++++++-
 1 file changed, 11 insertions(+), 1 deletion(-)

diff --git a/cellbox/cellbox/dataset.py b/cellbox/cellbox/dataset.py
index 34c6759..bf676d0 100644
--- a/cellbox/cellbox/dataset.py
+++ b/cellbox/cellbox/dataset.py
@@ -24,7 +24,11 @@ def factory(cfg):
     cfg.expr_out = tf.compat.v1.placeholder(tf.float32, [None, cfg.n_x], name='expr_out')
     cfg.pert = pd.read_csv(os.path.join(cfg.root_dir, cfg.pert_file), header=None, dtype=np.float32)
     cfg.expr = pd.read_csv(os.path.join(cfg.root_dir, cfg.expr_file), header=None, dtype=np.float32)
-    cfg.loo = np.vstack(np.where(cfg.pert!=0)).T + 1
+    group_df = pd.DataFrame(np.where(cfg.pert != 0), index=['row_id', 'pert_idx']).T.groupby('row_id')
+    max_combo_degree = group_df.pert_idx.count().max()
+    cfg.loo = pd.DataFrame(group_df.pert_idx.apply(
+        lambda x: pad_and_realign(x, max_combo_degree, cfg.n_activity_nodes - 1)
+    ).tolist())
 
     # add noise
     if cfg.add_noise_level > 0:
@@ -68,6 +72,12 @@ def factory(cfg):
     return cfg
 
 
+def pad_and_realign(x, length, idx_shift=0):
+    x -= idx_shift
+    padded = np.pad(x, (0, length - len(x)), 'constant')
+    return padded
+
+
 def get_tensors(cfg):
     # prepare training placeholders
     cfg.l1_lambda_placeholder = tf.compat.v1.placeholder(tf.float32, name='l1_lambda')