Skip to content

Commit

Permalink
Merge pull request #47 from sanderlab/add_function_docstrings
Browse files Browse the repository at this point in the history
Add function docstrings
  • Loading branch information
cannin authored Jun 1, 2023
2 parents ba6fe63 + 356f8a1 commit f1a477f
Show file tree
Hide file tree
Showing 8 changed files with 173 additions and 73 deletions.
1 change: 0 additions & 1 deletion cellbox/cellbox/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,3 @@
from cellbox.version import __version__, VERSION, get_msg

get_msg()
#
35 changes: 24 additions & 11 deletions cellbox/cellbox/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,23 @@
import numpy as np
import pandas as pd
import tensorflow.compat.v1 as tf
from typing import Mapping, Any
from scipy import sparse
tf.disable_v2_behavior()


def factory(cfg):
"""formulate training dataset"""
"""Formulates the training dataset.
This factory conducts the following three steps of data processing.
(1) Create variable placeholders for the perturbation and expression
vectors (input and output).
(2) [Optional] Add noise to the loaded data. This was used for
corruption analyses.
(3) Data partitioning given cellbox.config.Config.experiment_type.
The results are in the form of a dictionary.
(4) Creates a feeding dictionary for the session call.
"""
# Prepare data
if cfg.sparse_data:
cfg.pert_in = tf.compat.v1.sparse.placeholder(tf.float32, [None, cfg.n_x], name='pert_in')
Expand All @@ -31,7 +42,7 @@ def factory(cfg):
lambda x: pad_and_realign(x, max_combo_degree, cfg.n_activity_nodes - 1)
).tolist())

# add noise
# Add noise
if cfg.add_noise_level > 0:
np.random.seed(cfg.seed)
assert not cfg.sparse_data, "Adding noise to sparse data format is yet to be supported"
Expand Down Expand Up @@ -73,13 +84,15 @@ def factory(cfg):
return cfg


def pad_and_realign(x, length, idx_shift=0):
def pad_and_realign(x: tf.Tensor, length: int, idx_shift: int=0) -> tf.Tensor:
"""Add zeros to the given tensor of perturbation indices."""
x -= idx_shift
padded = np.pad(x, (0, length - len(x)), 'constant')
return padded


def get_tensors(cfg):
def get_tensors(cfg) -> None:
"""Gets the dataset iterators and regularization placeholders."""
# prepare training placeholders
cfg.l1_lambda_placeholder = tf.compat.v1.placeholder(tf.float32, name='l1_lambda')
cfg.l2_lambda_placeholder = tf.compat.v1.placeholder(tf.float32, name='l2_lambda')
Expand All @@ -95,8 +108,8 @@ def get_tensors(cfg):
return cfg


def s2c(cfg):
"""data parition for single-to-combo experiments"""
def s2c(cfg) -> Mapping[str, Any]:
"""Data parition for single-to-combo experiments"""
double_idx = cfg.loo.all(axis=1)
testidx = double_idx

Expand Down Expand Up @@ -135,7 +148,7 @@ def s2c(cfg):
return dataset


def loo(cfg, singles):
def loo(cfg, singles) -> Mapping[str, Any]:
"""data parition for leave-one-drug-out experiments"""
drug_index = int(cfg.drug_index)
double_idx = cfg.loo.all(axis=1)
Expand Down Expand Up @@ -181,7 +194,7 @@ def loo(cfg, singles):
return dataset


def random_partition(cfg):
def random_partition(cfg) -> Mapping[str, Any]:
"""random dataset partition"""
nexp, _ = cfg.pert.shape
nvalid = int(nexp * cfg.trainset_ratio)
Expand Down Expand Up @@ -222,7 +235,7 @@ def random_partition(cfg):
return dataset


def random_partition_with_replicates(cfg):
def random_partition_with_replicates(cfg) -> Mapping[str, Any]:
"""random dataset partition"""
nexp = len(np.unique(cfg.loo, axis=0))
nvalid = int(nexp * cfg.trainset_ratio)
Expand Down Expand Up @@ -268,9 +281,9 @@ def random_partition_with_replicates(cfg):
return dataset



def sparse_to_feedable_arrays(npz):
"""convert sparse matrix to arrays"""
"""Converts sparse matrices to full arrays."""
# Not currently in use.
coo = npz.tocoo()
indices = [[i, j] for i, j in zip(coo.row, coo.col)]
values = coo.data
Expand Down
34 changes: 24 additions & 10 deletions cellbox/cellbox/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,18 @@
"""

import tensorflow.compat.v1 as tf
from typing import Callable, Mapping, Any
tf.disable_v2_behavior()


def get_envelope(args):
"""get the envelope form based on the given argument"""
def get_envelope(args) -> Callable[[tf.Tensor], tf.Tensor]:
"""Gets the envelope form based on the given argument.
Returns:
A function that takes in a tensor and returns a tensor with the same shape.
This function should apply a specific transformation such as the Hill's
equation.
"""
if args.envelope_form == 'tanh':
args.envelope_fn = tf.tanh
elif args.envelope_form == 'polynomial':
Expand All @@ -31,8 +38,15 @@ def get_envelope(args):
return args.envelope_fn


def get_dxdt(args, params):
"""calculate the derivatives dx/dt in the ODEs"""
def get_dxdt(
args, params: Mapping[str, tf.Tensor]) -> Callable[[tf.Tensor], tf.Tensor]:
"""Calculates the derivatives dx/dt in the ODEs.
Returns:
A function that takes in a tensor and returns a tensor with the same shape.
This function should apply an envelope function with given params, i.e.,
f(params, x).
"""
if args.ode_degree == 1:
def weighted_sum(x):
return tf.matmul(params['W'], x)
Expand All @@ -55,8 +69,8 @@ def weighted_sum(x):
raise Exception("Illegal envelope type. Choose from [0,1,2].")


def get_ode_solver(args):
"""get the ODE solver based on the given argument"""
def get_ode_solver(args) -> Callable[Any, tf.Tensor]:
"""Gets the ODE solver based on the given argument."""
if args.ode_solver == 'heun':
return heun_solver
if args.ode_solver == 'euler':
Expand All @@ -68,7 +82,7 @@ def get_ode_solver(args):
raise Exception("Illegal ODE solver. Use [heun, euler, rk4, midpoint]")


def heun_solver(x, t_mu, dT, n_T, _dXdt, n_activity_nodes=None):
def heun_solver(x, t_mu, dT, n_T, _dXdt, n_activity_nodes=None) -> tf.Tensor:
"""Heun's ODE solver"""
xs = []
n_x = t_mu.shape[0]
Expand All @@ -83,7 +97,7 @@ def heun_solver(x, t_mu, dT, n_T, _dXdt, n_activity_nodes=None):
return xs


def euler_solver(x, t_mu, dT, n_T, _dXdt, n_activity_nodes=None):
def euler_solver(x, t_mu, dT, n_T, _dXdt, n_activity_nodes=None) -> tf.Tensor:
"""Euler's method"""
xs = []
n_x = t_mu.shape[0]
Expand All @@ -97,7 +111,7 @@ def euler_solver(x, t_mu, dT, n_T, _dXdt, n_activity_nodes=None):
return xs


def midpoint_solver(x, t_mu, dT, n_T, _dXdt, n_activity_nodes=None):
def midpoint_solver(x, t_mu, dT, n_T, _dXdt, n_activity_nodes=None) -> tf.Tensor:
"""Midpoint method"""
xs = []
n_x = t_mu.shape[0]
Expand All @@ -112,7 +126,7 @@ def midpoint_solver(x, t_mu, dT, n_T, _dXdt, n_activity_nodes=None):
return xs


def rk4_solver(x, t_mu, dT, n_T, _dXdt, n_activity_nodes=None):
def rk4_solver(x, t_mu, dT, n_T, _dXdt, n_activity_nodes=None) -> tf.Tensor:
"""Runge-Kutta method"""
xs = []
n_x = t_mu.shape[0]
Expand Down
11 changes: 5 additions & 6 deletions cellbox/cellbox/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,11 @@
import tensorflow.compat.v1 as tf
import cellbox.kernel
from cellbox.utils import loss, optimize
# import tensorflow_probability as tfp
tf.disable_v2_behavior()


def factory(args):
"""define model type based on configuration input"""
"""Defines the model given the model args."""
if args.model == 'CellBox':
return CellBox(args).build()
# Deprecated for now, use scikit-learn instead
Expand All @@ -31,8 +30,8 @@ def factory(args):


class PertBio:
"""define abstract perturbation model"""
def __init__(self, args):
"""Defines the abstract perturbation model."""
def __init__(self, args) -> None:
self.args = args
self.n_x = args.n_x
self.pert_in, self.expr_out = args.pert_in, args.expr_out
Expand All @@ -43,8 +42,8 @@ def __init__(self, args):
self.l1_lambda, self.l2_lambda = self.args.l1_lambda_placeholder, self.args.l2_lambda_placeholder
self.lr = self.args.lr

def get_ops(self):
"""get operators for tensorflow"""
def get_ops(self) -> None:
"""Gets operators for tensorflow."""
if self.args.weight_loss == 'expr':
self.train_loss, self.train_mse_loss = loss(self.train_y, self.train_yhat, self.params['W'],
self.l1_lambda, self.l2_lambda, weight=self.train_y)
Expand Down
Loading

0 comments on commit f1a477f

Please sign in to comment.