Skip to content

Commit

Permalink
Correct the generation of cfg.loo
Browse files Browse the repository at this point in the history
  • Loading branch information
DesmondYuan committed Feb 15, 2022
1 parent 93df2d0 commit 6004120
Showing 1 changed file with 11 additions and 1 deletion.
12 changes: 11 additions & 1 deletion cellbox/cellbox/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,11 @@ def factory(cfg):
cfg.expr_out = tf.compat.v1.placeholder(tf.float32, [None, cfg.n_x], name='expr_out')
cfg.pert = pd.read_csv(os.path.join(cfg.root_dir, cfg.pert_file), header=None, dtype=np.float32)
cfg.expr = pd.read_csv(os.path.join(cfg.root_dir, cfg.expr_file), header=None, dtype=np.float32)
cfg.loo = np.vstack(np.where(cfg.pert!=0)).T + 1
group_df = pd.DataFrame(np.where(cfg.pert != 0), index=['row_id', 'pert_idx']).T.groupby('row_id')
max_combo_degree = group_df.pert_idx.count().max()
cfg.loo = pd.DataFrame(group_df.pert_idx.apply(
lambda x: pad_and_realign(x, max_combo_degree, cfg.n_activity_nodes - 1)
).tolist())

# add noise
if cfg.add_noise_level > 0:
Expand Down Expand Up @@ -68,6 +72,12 @@ def factory(cfg):
return cfg


def pad_and_realign(x, length, idx_shift=0):
x -= idx_shift
padded = np.pad(x, (0, length - len(x)), 'constant')
return padded


def get_tensors(cfg):
# prepare training placeholders
cfg.l1_lambda_placeholder = tf.compat.v1.placeholder(tf.float32, name='l1_lambda')
Expand Down

0 comments on commit 6004120

Please sign in to comment.