Skip to content

Commit

Permalink
Merge pull request #41 from sanderlab/tf2_refactor
Browse files Browse the repository at this point in the history
TF2 Refactor
  • Loading branch information
cannin authored Feb 14, 2023
2 parents 1dfa9fb + 216afbd commit a1f1d20
Show file tree
Hide file tree
Showing 12 changed files with 2,400 additions and 46 deletions.
8 changes: 4 additions & 4 deletions binder/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,16 @@ importlib-metadata==1.7.0
Keras-Applications==1.0.8
Keras-Preprocessing==1.1.2
Markdown==3.2.2
numpy==1.16.0
numpy==1.19.5
opt-einsum==3.2.1
pandas==0.24.2
protobuf==3.12.2
python-dateutil==2.8.1
pytz==2020.1
six==1.15.0
tensorboard==1.15.0
tensorflow==1.15.0
tensorflow-estimator==1.15.1
tensorboard==2.6.0
tensorflow==2.6.2
tensorflow-estimator==2.6.0
termcolor==1.1.0
Werkzeug==1.0.1
wrapt==1.12.1
Expand Down
3 changes: 2 additions & 1 deletion cellbox/cellbox/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,9 @@
import os
import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow.compat.v1 as tf
from scipy import sparse
tf.disable_v2_behavior()


def factory(cfg):
Expand Down
3 changes: 2 additions & 1 deletion cellbox/cellbox/kernel.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
degree of ODEs, and the envelope forms
"""

import tensorflow as tf
import tensorflow.compat.v1 as tf
tf.disable_v2_behavior()


def get_envelope(args):
Expand Down
3 changes: 2 additions & 1 deletion cellbox/cellbox/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
"""

import numpy as np
import tensorflow as tf
import tensorflow.compat.v1 as tf
import cellbox.kernel
from cellbox.utils import loss, optimize
# import tensorflow_probability as tfp
tf.disable_v2_behavior()


def factory(args):
Expand Down
15 changes: 8 additions & 7 deletions cellbox/cellbox/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@
import time
import numpy as np
import pandas as pd
import tensorflow as tf
import tensorflow.compat.v1 as tf
from tensorflow.compat.v1.errors import OutOfRangeError
import cellbox
from cellbox.utils import TimeLogger
tf.disable_v2_behavior()


def train_substage(model, sess, lr_val, l1_lambda, l2_lambda, n_epoch, n_iter, n_iter_buffer, n_iter_patience, args):
Expand Down Expand Up @@ -59,7 +60,7 @@ def train_substage(model, sess, lr_val, l1_lambda, l2_lambda, n_epoch, n_iter, n
while True:
if idx_iter > n_iter or n_unchanged > n_iter_patience:
break
t0 = time.clock()
t0 = time.perf_counter()
try:
_, loss_train_i, loss_train_mse_i = sess.run(
(model.op_optimize, model.train_loss, model.train_mse_loss), feed_dict=args.feed_dicts['train_set'])
Expand All @@ -78,7 +79,7 @@ def train_substage(model, sess, lr_val, l1_lambda, l2_lambda, n_epoch, n_iter, n
n_iter_patience))
append_record("record_eval.csv",
[idx_epoch, idx_iter, loss_train_i, loss_valid_i, loss_train_mse_i,
loss_valid_mse_i, None, time.clock() - t0])
loss_valid_mse_i, None, time.perf_counter() - t0])
# early stopping
idx_iter += 1
if new_loss < best_params.loss_min:
Expand All @@ -89,18 +90,18 @@ def train_substage(model, sess, lr_val, l1_lambda, l2_lambda, n_epoch, n_iter, n
n_unchanged += 1

# Evaluation on valid set
t0 = time.clock()
t0 = time.perf_counter()
sess.run(model.iter_eval.initializer, feed_dict=args.feed_dicts['valid_set'])
loss_valid_i, loss_valid_mse_i = eval_model(sess, model.iter_eval, (model.eval_loss, model.eval_mse_loss),
args.feed_dicts['valid_set'], n_batches_eval=args.n_batches_eval)
append_record("record_eval.csv", [-1, None, None, loss_valid_i, None, loss_valid_mse_i, None, time.clock() - t0])
append_record("record_eval.csv", [-1, None, None, loss_valid_i, None, loss_valid_mse_i, None, time.perf_counter() - t0])

# Evaluation on test set
t0 = time.clock()
t0 = time.perf_counter()
sess.run(model.iter_eval.initializer, feed_dict=args.feed_dicts['test_set'])
loss_test_mse = eval_model(sess, model.iter_eval, model.eval_mse_loss,
args.feed_dicts['test_set'], n_batches_eval=args.n_batches_eval)
append_record("record_eval.csv", [-1, None, None, None, None, None, loss_test_mse, time.clock() - t0])
append_record("record_eval.csv", [-1, None, None, None, None, None, loss_test_mse, time.perf_counter() - t0])

best_params.save()
args.logger.log("------------------ Substage {} finished!-------------------".format(substage_i))
Expand Down
4 changes: 2 additions & 2 deletions cellbox/cellbox/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,9 @@

import time
import hashlib
import tensorflow as tf
import tensorflow.compat.v1 as tf
import json

tf.disable_v2_behavior()

def loss(x_gold, x_hat, W, l1=0, l2=0, weight=1.):
"""evaluate loss"""
Expand Down
18 changes: 12 additions & 6 deletions cellbox/cellbox/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
This module defines the version of the package
"""

__version__ = '0.3.1'
__version__ = '0.3.2'
VERSION = __version__


Expand Down Expand Up @@ -104,14 +104,14 @@ def get_msg():

"""
version 0.2.3
-- June 8, 2020 --
-- Jun 8, 2020 --
* Add support to L2 loss (alone or together with L1, i.e. elastic net)
* Clean the example configs folder
""",

"""
version 0.3.0
-- June 8, 2020 --
-- Jun 8, 2020 --
Add support for alternative form of perturbation
* Previous: add u on activity nodes
* New: fix activity nodes directly
Expand All @@ -123,10 +123,16 @@ def get_msg():

"""
version 0.3.1
-- Sept 25, 2020 --
-- Sep 25, 2020 --
* Release version for publication
* Add documentation
* Rename package to 'cellbox'
""",

"""
version 0.3.2
-- Feb 10, 2023 --
* Modify CellBox to support TF2
"""
]
print(
Expand All @@ -138,12 +144,12 @@ def get_msg():
" | |___| __/ | | |_) | (_) > < \n"
" \_____\___|_|_|____/ \___/_/\_\ \n"
"Running CellBox scripts developed in Sander lab\n"
"Maintained by Bo Yuan, Judy Shen, and Augustin Luna"
"Maintained by Bo Yuan, Judy Shen, and Augustin Luna; contributions by Daniel Ritter"
)

print(changelog[-1])
print(
"Tutorials and documentations are available at https://github.com/dfci/CellBox\n"
"Tutorials and documentations are available at https://github.com/sanderlab/CellBox\n"
"If you want to discuss the usage or to report a bug, please use the 'Issues' function at GitHub.\n"
"If you find CellBox useful for your research, please consider citing the corresponding publication.\n"
"For more information, please email us at [email protected] and [email protected], "
Expand Down
2 changes: 1 addition & 1 deletion cellbox/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
url="https://github.com/dfci/CellBox",
packages=['cellbox'],
python_requires='>=3.6',
install_requires=['tensorflow==1.15.0', 'numpy==1.16.0', 'pandas==0.24.2', 'scipy==1.3.0'],
install_requires=['tensorflow==2.11.0', 'numpy==1.24.1', 'pandas==1.5.3', 'scipy==1.10.0'],
tests_require=['pytest', 'pandas', 'numpy', 'scipy'],
setup_requires=['pytest-runner', "pytest"],
zip_safe=True,
Expand Down
Loading

0 comments on commit a1f1d20

Please sign in to comment.