diff --git a/cellbox/cellbox/__init__.py b/cellbox/cellbox/__init__.py index c9da2f2..0a6da15 100644 --- a/cellbox/cellbox/__init__.py +++ b/cellbox/cellbox/__init__.py @@ -10,4 +10,3 @@ from cellbox.version import __version__, VERSION, get_msg get_msg() -# diff --git a/cellbox/cellbox/dataset.py b/cellbox/cellbox/dataset.py index 15854f7..21eae11 100644 --- a/cellbox/cellbox/dataset.py +++ b/cellbox/cellbox/dataset.py @@ -8,12 +8,23 @@ import numpy as np import pandas as pd import tensorflow.compat.v1 as tf +from typing import Mapping, Any from scipy import sparse tf.disable_v2_behavior() def factory(cfg): - """formulate training dataset""" + """Formulates the training dataset. + + This factory conducts the following three steps of data processing. + (1) Create variable placeholders for the perturbation and expression + vectors (input and output). + (2) [Optional] Add noise to the loaded data. This was used for + corruption analyses. + (3) Data partitioning given cellbox.config.Config.experiment_type. + The results are in the form of a dictionary. + (4) Creates a feeding dictionary for the session call. 
+ """ # Prepare data if cfg.sparse_data: cfg.pert_in = tf.compat.v1.sparse.placeholder(tf.float32, [None, cfg.n_x], name='pert_in') @@ -31,7 +42,7 @@ def factory(cfg): lambda x: pad_and_realign(x, max_combo_degree, cfg.n_activity_nodes - 1) ).tolist()) - # add noise + # Add noise if cfg.add_noise_level > 0: np.random.seed(cfg.seed) assert not cfg.sparse_data, "Adding noise to sparse data format is yet to be supported" @@ -73,13 +84,15 @@ def factory(cfg): return cfg -def pad_and_realign(x, length, idx_shift=0): +def pad_and_realign(x: tf.Tensor, length: int, idx_shift: int=0) -> tf.Tensor: + """Add zeros to the given tensor of perturbation indices.""" x -= idx_shift padded = np.pad(x, (0, length - len(x)), 'constant') return padded -def get_tensors(cfg): +def get_tensors(cfg) -> None: + """Gets the dataset iterators and regularization placeholders.""" # prepare training placeholders cfg.l1_lambda_placeholder = tf.compat.v1.placeholder(tf.float32, name='l1_lambda') cfg.l2_lambda_placeholder = tf.compat.v1.placeholder(tf.float32, name='l2_lambda') @@ -95,8 +108,8 @@ def get_tensors(cfg): return cfg -def s2c(cfg): - """data parition for single-to-combo experiments""" +def s2c(cfg) -> Mapping[str, Any]: + """Data parition for single-to-combo experiments""" double_idx = cfg.loo.all(axis=1) testidx = double_idx @@ -135,7 +148,7 @@ def s2c(cfg): return dataset -def loo(cfg, singles): +def loo(cfg, singles) -> Mapping[str, Any]: """data parition for leave-one-drug-out experiments""" drug_index = int(cfg.drug_index) double_idx = cfg.loo.all(axis=1) @@ -181,7 +194,7 @@ def loo(cfg, singles): return dataset -def random_partition(cfg): +def random_partition(cfg) -> Mapping[str, Any]: """random dataset partition""" nexp, _ = cfg.pert.shape nvalid = int(nexp * cfg.trainset_ratio) @@ -222,7 +235,7 @@ def random_partition(cfg): return dataset -def random_partition_with_replicates(cfg): +def random_partition_with_replicates(cfg) -> Mapping[str, Any]: """random dataset partition""" 
nexp = len(np.unique(cfg.loo, axis=0)) nvalid = int(nexp * cfg.trainset_ratio) @@ -268,9 +281,9 @@ def random_partition_with_replicates(cfg): return dataset - def sparse_to_feedable_arrays(npz): - """convert sparse matrix to arrays""" + """Converts sparse matrices to full arrays.""" + # Not currently in use. coo = npz.tocoo() indices = [[i, j] for i, j in zip(coo.row, coo.col)] values = coo.data diff --git a/cellbox/cellbox/kernel.py b/cellbox/cellbox/kernel.py index 5d5af21..869da90 100644 --- a/cellbox/cellbox/kernel.py +++ b/cellbox/cellbox/kernel.py @@ -4,11 +4,18 @@ """ import tensorflow.compat.v1 as tf +from typing import Callable, Mapping, Any tf.disable_v2_behavior() -def get_envelope(args): - """get the envelope form based on the given argument""" +def get_envelope(args) -> Callable[[tf.Tensor], tf.Tensor]: + """Gets the envelope form based on the given argument. + + Returns: + A function that takes in a tensor and returns a tensor with the same shape. + This function should apply a specific transformation such as the Hill's + equation. + """ if args.envelope_form == 'tanh': args.envelope_fn = tf.tanh elif args.envelope_form == 'polynomial': @@ -31,8 +38,15 @@ def get_envelope(args): return args.envelope_fn -def get_dxdt(args, params): - """calculate the derivatives dx/dt in the ODEs""" +def get_dxdt( + args, params: Mapping[str, tf.Tensor]) -> Callable[[tf.Tensor], tf.Tensor]: + """Calculates the derivatives dx/dt in the ODEs. + + Returns: + A function that takes in a tensor and returns a tensor with the same shape. + This function should apply an envelope function with given params, i.e., + f(params, x). + """ if args.ode_degree == 1: def weighted_sum(x): return tf.matmul(params['W'], x) @@ -55,8 +69,8 @@ def weighted_sum(x): raise Exception("Illegal envelope type. 
Choose from [0,1,2].") -def get_ode_solver(args): - """get the ODE solver based on the given argument""" +def get_ode_solver(args) -> Callable[Any, tf.Tensor]: + """Gets the ODE solver based on the given argument.""" if args.ode_solver == 'heun': return heun_solver if args.ode_solver == 'euler': @@ -68,7 +82,7 @@ def get_ode_solver(args): raise Exception("Illegal ODE solver. Use [heun, euler, rk4, midpoint]") -def heun_solver(x, t_mu, dT, n_T, _dXdt, n_activity_nodes=None): +def heun_solver(x, t_mu, dT, n_T, _dXdt, n_activity_nodes=None) -> tf.Tensor: """Heun's ODE solver""" xs = [] n_x = t_mu.shape[0] @@ -83,7 +97,7 @@ def heun_solver(x, t_mu, dT, n_T, _dXdt, n_activity_nodes=None): return xs -def euler_solver(x, t_mu, dT, n_T, _dXdt, n_activity_nodes=None): +def euler_solver(x, t_mu, dT, n_T, _dXdt, n_activity_nodes=None) -> tf.Tensor: """Euler's method""" xs = [] n_x = t_mu.shape[0] @@ -97,7 +111,7 @@ def euler_solver(x, t_mu, dT, n_T, _dXdt, n_activity_nodes=None): return xs -def midpoint_solver(x, t_mu, dT, n_T, _dXdt, n_activity_nodes=None): +def midpoint_solver(x, t_mu, dT, n_T, _dXdt, n_activity_nodes=None) -> tf.Tensor: """Midpoint method""" xs = [] n_x = t_mu.shape[0] @@ -112,7 +126,7 @@ def midpoint_solver(x, t_mu, dT, n_T, _dXdt, n_activity_nodes=None): return xs -def rk4_solver(x, t_mu, dT, n_T, _dXdt, n_activity_nodes=None): +def rk4_solver(x, t_mu, dT, n_T, _dXdt, n_activity_nodes=None) -> tf.Tensor: """Runge-Kutta method""" xs = [] n_x = t_mu.shape[0] diff --git a/cellbox/cellbox/model.py b/cellbox/cellbox/model.py index 7477d2d..20132cb 100644 --- a/cellbox/cellbox/model.py +++ b/cellbox/cellbox/model.py @@ -7,12 +7,11 @@ import tensorflow.compat.v1 as tf import cellbox.kernel from cellbox.utils import loss, optimize -# import tensorflow_probability as tfp tf.disable_v2_behavior() def factory(args): - """define model type based on configuration input""" + """Defines the model given the model args.""" if args.model == 'CellBox': return 
CellBox(args).build() # Deprecated for now, use scikit-learn instead @@ -31,8 +30,8 @@ def factory(args): class PertBio: - """define abstract perturbation model""" - def __init__(self, args): + """Defines the abstract perturbation model.""" + def __init__(self, args) -> None: self.args = args self.n_x = args.n_x self.pert_in, self.expr_out = args.pert_in, args.expr_out @@ -43,8 +42,8 @@ def __init__(self, args): self.l1_lambda, self.l2_lambda = self.args.l1_lambda_placeholder, self.args.l2_lambda_placeholder self.lr = self.args.lr - def get_ops(self): - """get operators for tensorflow""" + def get_ops(self) -> None: + """Gets operators for tensorflow.""" if self.args.weight_loss == 'expr': self.train_loss, self.train_mse_loss = loss(self.train_y, self.train_yhat, self.params['W'], self.l1_lambda, self.l2_lambda, weight=self.train_y) diff --git a/cellbox/cellbox/train.py b/cellbox/cellbox/train.py index 963468a..33e99dc 100644 --- a/cellbox/cellbox/train.py +++ b/cellbox/cellbox/train.py @@ -10,17 +10,20 @@ import tensorflow.compat.v1 as tf from tensorflow.compat.v1.errors import OutOfRangeError import cellbox +from typing import Sequence, Any from cellbox.utils import TimeLogger tf.disable_v2_behavior() -def train_substage(model, sess, lr_val, l1_lambda, l2_lambda, n_epoch, n_iter, n_iter_buffer, n_iter_patience, args): +def train_substage( + model, sess, lr_val, l1_lambda, l2_lambda, n_epoch, + n_iter, n_iter_buffer, n_iter_patience, args) -> None: """ - Training function that does one stage of training. The stage training can be repeated and modified to give better - training result. + Training function that does one stage of training. The stage training + can be repeated and modified to give better training result. 
Args: - model (CellBox): an CellBox instance + model (cellbox.model.PertBio): a CellBox instance sess (tf.Session): current session, need reinitialization for every nT lr_val (float): learning rate (read in from config file) l1_lambda (float): l1 regularization weight @@ -29,7 +32,7 @@ def train_substage(model, sess, lr_val, l1_lambda, l2_lambda, n_epoch, n_iter, n n_iter (int): maximum number of iterations n_iter_buffer (int): training loss moving average window n_iter_patience (int): training loss tolerance - args: Args or configs + args (cellbox.config.Config): The model args. """ stages = glob.glob("*best*.csv") @@ -48,7 +51,8 @@ def train_substage(model, sess, lr_val, l1_lambda, l2_lambda, n_epoch, n_iter, n model.l1_lambda: l1_lambda, model.l2_lambda: l2_lambda }) - args.logger.log("--------- lr: {}\tl1: {}\tl2: {}\t".format(lr_val, l1_lambda, l2_lambda)) + args.logger.log("--------- lr: {}\tl1: {}\tl2: {}\t".format( + lr_val, l1_lambda, l2_lambda)) sess.run(model.iter_monitor.initializer, feed_dict=args.feed_dicts['valid_set']) for idx_epoch in range(n_epoch): @@ -63,16 +67,18 @@ def train_substage(model, sess, lr_val, l1_lambda, l2_lambda, n_epoch, n_iter, n t0 = time.perf_counter() try: _, loss_train_i, loss_train_mse_i = sess.run( - (model.op_optimize, model.train_loss, model.train_mse_loss), feed_dict=args.feed_dicts['train_set']) + (model.op_optimize, model.train_loss, model.train_mse_loss), + feed_dict=args.feed_dicts['train_set']) except OutOfRangeError: # for iter_train break # record training loss_valid_i, loss_valid_mse_i = sess.run( - (model.monitor_loss, model.monitor_mse_loss), feed_dict=args.feed_dicts['valid_set']) + (model.monitor_loss, model.monitor_mse_loss), + feed_dict=args.feed_dicts['valid_set']) new_loss = best_params.avg_n_iters_loss(loss_valid_i) if args.export_verbose > 0: - print(("Substage:{}\tEpoch:{}/{}\tIteration: {}/{}" + "\tloss (train):{:1.6f}\tloss (buffer on valid):" + print(("Substage:{}\tEpoch:{}/{}\tIteration: {}/{}" 
+"\tloss (train):{:1.6f}\tloss (buffer on valid):" "{:1.6f}" + "\tbest:{:1.6f}\tTolerance: {}/{}").format(substage_i, idx_epoch, n_epoch, idx_iter, n_iter, loss_train_i, new_loss, best_params.loss_min, n_unchanged, @@ -92,15 +98,17 @@ def train_substage(model, sess, lr_val, l1_lambda, l2_lambda, n_epoch, n_iter, n # Evaluation on valid set t0 = time.perf_counter() sess.run(model.iter_eval.initializer, feed_dict=args.feed_dicts['valid_set']) - loss_valid_i, loss_valid_mse_i = eval_model(sess, model.iter_eval, (model.eval_loss, model.eval_mse_loss), - args.feed_dicts['valid_set'], n_batches_eval=args.n_batches_eval) + loss_valid_i, loss_valid_mse_i = eval_model( + sess, model.iter_eval, (model.eval_loss, model.eval_mse_loss), + args.feed_dicts['valid_set'], n_batches_eval=args.n_batches_eval) append_record("record_eval.csv", [-1, None, None, loss_valid_i, None, loss_valid_mse_i, None, time.perf_counter() - t0]) # Evaluation on test set t0 = time.perf_counter() sess.run(model.iter_eval.initializer, feed_dict=args.feed_dicts['test_set']) - loss_test_mse = eval_model(sess, model.iter_eval, model.eval_mse_loss, - args.feed_dicts['test_set'], n_batches_eval=args.n_batches_eval) + loss_test_mse = eval_model( + sess, model.iter_eval, model.eval_mse_loss, + args.feed_dicts['test_set'], n_batches_eval=args.n_batches_eval) append_record("record_eval.csv", [-1, None, None, None, None, None, loss_test_mse, time.perf_counter() - t0]) best_params.save() @@ -108,16 +116,31 @@ def train_substage(model, sess, lr_val, l1_lambda, l2_lambda, n_epoch, n_iter, n save_model(args.saver, sess, './' + args.ckpt_name) -def append_record(filename, contents): - """define function for appending training record""" +def append_record(filename: str, contents: Sequence[Any]) -> None: + """Appends the contents to the log.""" with open(filename, 'a') as f: for content in contents: f.write('{},'.format(content)) f.write('\n') -def eval_model(sess, eval_iter, obj_fn, eval_dict, return_avg=True, 
n_batches_eval=None): - """simulate the model for prediction""" +def eval_model( + sess, eval_iter, obj_fn, eval_dict, return_avg=True, + n_batches_eval=None) -> np.ndarray: + """Uses a given model to make predictions. + + Args: + sess: The training session for tensorflow v1. + eval_iter: The data iterator used for evaluation. + obj_fn: The operator used to evaluate. + eval_dict: The feed_dict used for sess.run(). + return_avg: Whether to calculate an average of the evaluated tensor. + n_batches_eval: The max number of batches used for training. If None, + uses all the batches. + + Returns: + The evaluated objective function. + """ sess.run(eval_iter.initializer, feed_dict=eval_dict) counter = 0 eval_results = [] @@ -134,13 +157,14 @@ def eval_model(sess, eval_iter, obj_fn, eval_dict, return_avg=True, n_batches_ev return np.vstack(eval_results) -def train_model(model, args): - """Train the model""" +def train_model(model, args) -> None: + """Trains the model given the model instance and args.""" args.logger = TimeLogger(time_logger_step=1, hierachy=2) # Check if all variables in scope # TODO: put variables under appropriate scopes - for i in tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.GLOBAL_VARIABLES, scope='initialization'): + for i in tf.compat.v1.get_collection( + tf.compat.v1.GraphKeys.GLOBAL_VARIABLES, scope='initialization'): print(i) # Initialization @@ -175,16 +199,23 @@ def train_model(model, args): tf.compat.v1.reset_default_graph() -def save_model(saver, sess, path): - """save model""" +def save_model(saver, sess, path) -> None: + """Saves the model and session to a given filepath.""" # Save the variables to disk. tmp = saver.save(sess, path) print("Model saved in path: %s" % tmp) class Screenshot(dict): - """summarize the model""" + """The class that tracks the model metadata.""" + def __init__(self, args, n_iter_buffer): + """Creates the instance and initializes the variables. + + Args: + args (cellbox.config.Config): The model args. 
+ n_iter_buffer (int): The moving average window for model losses. + """ # initialize loss_min super().__init__() self.loss_min = 1000 @@ -197,14 +228,23 @@ def __init__(self, args, n_iter_buffer): self.substage_i = [] self.export_verbose = args.export_verbose - def avg_n_iters_loss(self, new_loss): - """average the last few losses""" + def avg_n_iters_loss(self, new_loss: float) -> float: + """Averages the last few losses.""" self.saved_losses = self.saved_losses + [new_loss] self.saved_losses = self.saved_losses[-self.n_iter_buffer:] return sum(self.saved_losses) / len(self.saved_losses) def screenshot(self, sess, model, substage_i, node_index, loss_min, args): - """evaluate models""" + """Evaluates the model performance and updates the summary files and best config. + + Args: + sess (tf.Session): The session for tensorflow v1. + model (cellbox.model.PertBio): The instance for the model. + substage_i (int): The index for the training substage. + node_index (pandas.DataFrame): A dataframe of node indices. + loss_min (int): The best loss so far. + args (cellbox.config.Config): The model args. + """ self.substage_i = substage_i self.loss_min = loss_min # Save the variables to disk. 
@@ -244,7 +284,7 @@ def screenshot(self, sess, model, substage_i, node_index, loss_min, args): pass def save(self): - """save model parameters""" + """Exports the best model metadata to a CSV file.""" for file in glob.glob(str(self.substage_i) + "_best.*.csv"): os.remove(file) for key in self: diff --git a/cellbox/cellbox/utils.py b/cellbox/cellbox/utils.py index 32647fc..3820f74 100644 --- a/cellbox/cellbox/utils.py +++ b/cellbox/cellbox/utils.py @@ -7,10 +7,29 @@ import hashlib import tensorflow.compat.v1 as tf import json +from typing import Union, Tuple, Any tf.disable_v2_behavior() -def loss(x_gold, x_hat, W, l1=0, l2=0, weight=1.): - """evaluate loss""" +def loss( + x_gold: Union[tf.SparseTensor, tf.Tensor], + x_hat: Union[tf.SparseTensor, tf.Tensor], + W: tf.Variable, + l1: float = 0.0, + l2: float = 0.0, + weight: float = 1.0) -> Tuple[tf.Tensor, tf.Tensor]: + """Evaluates the losses. + + Args: + x_gold: The ground truth tensor for values. Expected shape: [B, N]. + x_hat: The predicted tensor for values. Expected shape: [B, N]. + W: The variable for the weight matrix, used for regularization. + l1: The lambda 1 for L1 loss. + l2: The lambda 2 for L2 loss. + weight: The balance weight for the MSE loss when calculating the total loss. + + Returns: + A tuple of the total loss (including regularization) and the raw MSE loss. + """ if isinstance(x_gold, tf.SparseTensor): x_gold = tf.sparse.to_dense(x_gold) @@ -22,18 +41,24 @@ def loss(x_gold, x_hat, W, l1=0, l2=0, weight=1.): return loss_full, loss_mse -def optimize(loss_in, lr, optimizer=tf.compat.v1.train.AdamOptimizer, var_list=None): +def optimize( + loss_in: tf.Tensor, + lr: tf.Variable, + optimizer: tf.compat.v1.train.Optimizer = tf.compat.v1.train.AdamOptimizer, + var_list: Any = None): """ - Optimize the training loss using Adam + Optimizes the training loss, using the Adam optimizer by default. 
Args: loss_in (float): training loss, mean squared error + L1 regularization term lr (float): placeholder for learning rate optimizer: default tf.train.AdamOptimizer - var_list: list of vars to be optimized + var_list (list): a list of vars to be optimized + Returns: - opt_op (optimizer): op to optimize the training loss - loss (loss): training loss, including regularization if applicable + A tuple of two items: + opt_op (optimizer): op to optimize the training loss + loss (loss): training loss, including regularization if applicable """ if var_list is None: var_list = tf.compat.v1.get_collection(tf.compat.v1.GraphKeys.GLOBAL_VARIABLES) @@ -44,24 +69,29 @@ def optimize(loss_in, lr, optimizer=tf.compat.v1.train.AdamOptimizer, var_list=N class TimeLogger: - """calculate training time""" - def __init__(self, time_logger_step=1, hierachy=1): + """The class used to measure and log the training time.""" + def __init__(self, time_logger_step: int = 1, hierachy: int = 1) -> None: + """Creates the logger instance and initializes the variables. + + Args: + time_logger_step: The listening frequency, default=1. + hierachy: The header #s of the printed log string during logging. Used + when running `grep` from the log files. + """ self.time_logger_step = time_logger_step self.step_count = 0 self.hierachy = hierachy self.time = time.time() - def log(self, s): - """time log""" + def log(self, s: str) -> None: + """Tracks and logs the time difference.""" if self.step_count % self.time_logger_step == 0: print("#" * 4 * self.hierachy, " ", s, " --time elapsed: %.2f" % (time.time() - self.time)) self.time = time.time() self.step_count += 1 -def md5(obj): - """ - returns a hashed with md5 string of the key - """ +def md5(obj: Any) -> str: + """Returns an MD5-hashed string of the key. 
Used as identifiers for file I/O.""" key = json.dumps(vars(obj), sort_keys=True) return hashlib.md5(key.encode()).hexdigest() diff --git a/cellbox/cellbox/version.py b/cellbox/cellbox/version.py index 90e5768..0edeeb8 100644 --- a/cellbox/cellbox/version.py +++ b/cellbox/cellbox/version.py @@ -4,11 +4,14 @@ __version__ = '0.3.2' VERSION = __version__ +# TODO(desmondyuan): update 0.3.3 +# * new function docstrings +# * remove the binder implementation +# * add basic tests +def get_msg() -> None: + """Print the version history.""" -def get_msg(): - """get version history""" - # for test: installation completed changelog = [ """ version 0.0.2 diff --git a/scripts/main.py b/scripts/main.py index 57177de..ab8f1dc 100644 --- a/scripts/main.py +++ b/scripts/main.py @@ -18,12 +18,14 @@ def set_seed(in_seed): + """Sets random seeds for numpy and tensorflow.""" int_seed = int(in_seed) tf.compat.v1.set_random_seed(int_seed) np.random.seed(int_seed) def prepare_workdir(in_cfg): + """Creates the working directory for each experiment and generates necessary files.""" # Read Data in_cfg.root_dir = os.getcwd() in_cfg.node_index = pd.read_csv(in_cfg.node_index_file, header=None, names=None) \