Enable adjoint method #3

Merged · 5 commits · Feb 29, 2020
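This PR re-enables adjoint-based gradients: `odeint_adjoint` is exported again from the package root, the previously commented-out adjoint tests are restored and ported to TensorFlow 2 `GradientTape` style, and an internal solver variable is marked non-trainable. For reference, the adjoint sensitivity method (standard formulation from Chen et al., 2018, summarized here; not text from this PR) avoids backpropagating through the solver's internal steps by integrating an adjoint state backwards in time:

$$a(t) = \frac{\partial L}{\partial z(t)}, \qquad \frac{da(t)}{dt} = -a(t)^{\top} \frac{\partial f(z(t), t, \theta)}{\partial z}, \qquad \frac{dL}{d\theta} = -\int_{t_1}^{t_0} a(t)^{\top} \frac{\partial f(z(t), t, \theta)}{\partial \theta}\, dt$$

The tests below check that gradients computed this way agree with gradients obtained by differentiating directly through the `dopri5` solver.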
3 changes: 2 additions & 1 deletion tests/check_grad.py
@@ -103,7 +103,8 @@ def get_analytical_jacobian(input, output):
     reentrant = True
     correct_grad_sizes = True
 
-    for i in range(_numel(flat_grad_output)):
+    N = tf.cast(_numel(flat_grad_output), dtype=tf.int32)
+    for i in range(N):
         flat_grad_output *= 0.
         add_one = [0] * (flat_grad_output.shape[0])
         add_one[0] = 1
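The replacement makes the loop bound an explicit integer. A minimal standalone sketch of the pattern (using `tf.size` as a stand-in for the repo's `_numel` helper, which is an assumption on my part):

```python
import tensorflow as tf

x = tf.zeros([3, 4])
n = tf.size(x)  # a scalar EagerTensor, not a Python int

# range() only accepts a scalar integer tensor through its __index__ hook;
# casting (or wrapping in int()) makes the loop bound explicit and guards
# against the element count arriving in an unexpected dtype.
for i in range(int(tf.cast(n, tf.int32))):
    pass
```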
241 changes: 130 additions & 111 deletions tests/gradient_tests.py
@@ -16,7 +16,6 @@
 def max_abs(tensor):
     return tf.reduce_max(tf.abs(tensor))
 
-
 class TestGradient(unittest.TestCase):
 
     def test_huen(self):
@@ -49,116 +48,136 @@ def test_adams(self):
         func = lambda y0, t_points: tfdiffeq.odeint(f, y0, t_points, method='adams')
         self.assertTrue(gradcheck(func, (y0, t_points)))
 
-    # def test_adjoint(self):
-    #     """
-    #     Test against dopri5
-    #     """
-    #     f, y0, t_points, _ = problems.construct_problem(TEST_DEVICE)
-    #
-    #     func = lambda y0, t_points: tfdiffeq.odeint(f, y0, t_points, method='dopri5')
-    #
-    #     tf.set_random_seed(0)
-    #     with tf.GradientTape() as tape:
-    #         tape.watch(t_points)
-    #         ys = func(y0, t_points)
-    #
-    #     # gradys = tf.random_uniform(ys.shape)
-    #     # ys.backward(gradys)
-    #
-    #     # reg_y0_grad = y0.grad
-    #     reg_t_grad, reg_a_grad, reg_b_grad = tape.gradient(ys, [t_points, f.a, f.b])
-    #     # reg_t_grad = t_points.grad
-    #     # reg_a_grad = f.a.grad
-    #     # reg_b_grad = f.b.grad
-    #
-    #     f, y0, t_points, _ = problems.construct_problem(TEST_DEVICE)
-    #
-    #     y0 = (y0,)
-    #
-    #     func = lambda y0, t_points: tfdiffeq.odeint_adjoint(f, y0, t_points, method='dopri5')
-    #
-    #     with tf.GradientTape() as tape:
-    #         tape.watch(t_points)
-    #         ys = func(y0, t_points)
-    #
-    #     grads = tape.gradient(ys, [t_points, f.a, f.b])
-    #     adj_t_grad, adj_a_grad, adj_b_grad = grads
-    #
-    #     # self.assertLess(max_abs(reg_y0_grad - adj_y0_grad), eps)
-    #     self.assertLess(max_abs(reg_t_grad - adj_t_grad), eps)
-    #     self.assertLess(max_abs(reg_a_grad - adj_a_grad), eps)
-    #     self.assertLess(max_abs(reg_b_grad - adj_b_grad), eps)
-
-
-# class TestCompareAdjointGradient(unittest.TestCase):
-#
-#     def problem(self):
-#
-#         class Odefunc(tf.keras.Model):
-#
-#             def __init__(self):
-#                 super(Odefunc, self).__init__()
-#                 self.A = tf.Variable([[-0.1, 2.0], [-2.0, -0.1]], dtype=tf.float64)
-#                 self.unused_module = tf.keras.layers.Dense(5)
-#
-#             def call(self, t, y):
-#                 y = tfdiffeq.cast_double(y)
-#                 return tf.matmul(y ** 3, self.A)
-#
-#         y0 = tf.convert_to_tensor([[2., 0.]])
-#         t_points = tf.linspace(0., 25., 10)
-#         func = Odefunc()
-#         return func, y0, t_points
-#
-#     def test_dopri5_adjoint_against_dopri5(self):
-#         with tf.GradientTape() as tape:
-#             func, y0, t_points = self.problem()
-#             # tape.watch(t_points)
-#             tape.watch(y0)
-#             ys = tfdiffeq.odeint_adjoint(func, y0, t_points, method='dopri5')
-#
-#         adj_y0_grad = tape.gradient(ys, y0)  # y0.grad
-#         # adj_t_grad = tape.gradient(ys, t_points)  # t_points.grad
-#         adj_A_grad = tape.gradient(ys, func.A)  # func.A.grad
-#
-#         print("reached here")
-#         # w_grad, b_grad = tape.gradient(ys, func.unused_module.variables)
-#         # self.assertEqual(max_abs(w_grad), 0)
-#         # self.assertEqual(max_abs(b_grad), 0)
-#
-#         with tf.GradientTape() as tape:
-#             func, y0, t_points = self.problem()
-#             tape.watch(y0)
-#             # tape.watch(t_points)
-#             ys = tfdiffeq.odeint(func, y0, t_points, method='dopri5')
-#
-#         y_grad = tape.gradient(ys, y0)
-#         # t_grad = tape.gradient(ys, t_points)
-#         a_grad = tape.gradient(ys, func.A)
-#
-#         self.assertLess(max_abs(y_grad - adj_y0_grad), 3e-4)
-#         # self.assertLess(max_abs(t_grad - adj_t_grad), 1e-4)
-#         self.assertLess(max_abs(a_grad - adj_A_grad), 2e-3)
-
-#     def test_adams_adjoint_against_dopri5(self):
-#         func, y0, t_points = self.problem()
-#         ys_ = torchdiffeq.odeint_adjoint(func, y0, t_points, method='adams')
-#         gradys = torch.rand_like(ys_) * 0.1
-#         ys_.backward(gradys)
-#
-#         adj_y0_grad = y0.grad
-#         adj_t_grad = t_points.grad
-#         adj_A_grad = func.A.grad
-#         self.assertEqual(max_abs(func.unused_module.weight.grad), 0)
-#         self.assertEqual(max_abs(func.unused_module.bias.grad), 0)
-#
-#         func, y0, t_points = self.problem()
-#         ys = torchdiffeq.odeint(func, y0, t_points, method='dopri5')
-#         ys.backward(gradys)
-#
-#         self.assertLess(max_abs(y0.grad - adj_y0_grad), 5e-2)
-#         self.assertLess(max_abs(t_points.grad - adj_t_grad), 5e-4)
-#         self.assertLess(max_abs(func.A.grad - adj_A_grad), 2e-2)
+    def test_adjoint(self):
+        """
+        Test against dopri5
+        """
+        tf.compat.v1.set_random_seed(0)
+        f, y0, t_points, _ = problems.construct_problem(TEST_DEVICE)
+        y0 = tf.cast(y0, tf.float64)
+        t_points = tf.cast(t_points, tf.float64)
+
+        func = lambda y0, t_points: tfdiffeq.odeint(f, y0, t_points, method='dopri5')
+
+        with tf.GradientTape() as tape:
+            tape.watch(t_points)
+            ys = func(y0, t_points)
+
+        reg_t_grad, reg_a_grad, reg_b_grad = tape.gradient(ys, [t_points, f.a, f.b])
+
+        f, y0, t_points, _ = problems.construct_problem(TEST_DEVICE)
+        y0 = tf.cast(y0, tf.float64)
+        t_points = tf.cast(t_points, tf.float64)
+
+        y0 = (y0,)
+
+        func = lambda y0, t_points: tfdiffeq.odeint_adjoint(f, y0, t_points, method='dopri5')
+
+        with tf.GradientTape() as tape:
+            tape.watch(t_points)
+            ys = func(y0, t_points)
+
+        grads = tape.gradient(ys, [t_points, f.a, f.b])
+        adj_t_grad, adj_a_grad, adj_b_grad = grads
+
+        self.assertLess(max_abs(reg_t_grad - adj_t_grad), 1.2e-7)
+        self.assertLess(max_abs(reg_a_grad - adj_a_grad), 1.2e-7)
+        self.assertLess(max_abs(reg_b_grad - adj_b_grad), 1.2e-7)
+
+
+class TestCompareAdjointGradient(unittest.TestCase):
+
+    def problem(self):
+        tf.keras.backend.set_floatx('float64')
+
+        class Odefunc(tf.keras.Model):
+
+            def __init__(self):
+                super(Odefunc, self).__init__()
+                self.A = tf.Variable([[-0.1, -2.0], [2.0, -0.1]], dtype=tf.float64)
+                self.unused_module = tf.keras.layers.Dense(5, dtype=tf.float64)
+                self.unused_module.build((5,))
+
+            def call(self, t, y):
+                y = tfdiffeq.cast_double(y)
+                return tf.linalg.matvec(self.A, y ** 3)
+
+        y0 = tf.convert_to_tensor([2., 0.], dtype=tf.float64)
+        t_points = tf.linspace(
+            tf.constant(0., dtype=tf.float64),
+            tf.constant(25., dtype=tf.float64),
+            10
+        )
+        func = Odefunc()
+        return func, y0, t_points
+
+    def test_dopri5_adjoint_against_dopri5(self):
+        tf.keras.backend.set_floatx('float64')
+        tf.compat.v1.set_random_seed(0)
+        with tf.GradientTape(persistent=True) as tape:
+            func, y0, t_points = self.problem()
+            tape.watch(t_points)
+            tape.watch(y0)
+            ys = tfdiffeq.odeint_adjoint(func, y0, t_points, method='dopri5')
+
+        gradys = 0.1 * tf.random.uniform(shape=ys.shape, dtype=tf.float64)
+        adj_y0_grad, adj_t_grad, adj_A_grad = tape.gradient(
+            ys,
+            [y0, t_points, func.A],
+            output_gradients=gradys
+        )
+
+        w_grad, b_grad = tape.gradient(ys, func.unused_module.variables)
+        self.assertIsNone(w_grad)
+        self.assertIsNone(b_grad)
+
+        with tf.GradientTape() as tape:
+            func, y0, t_points = self.problem()
+            tape.watch(y0)
+            tape.watch(t_points)
+            ys = tfdiffeq.odeint(func, y0, t_points, method='dopri5')
+
+        y_grad, t_grad, a_grad = tape.gradient(
+            ys,
+            [y0, t_points, func.A],
+            output_gradients=gradys
+        )
+
+        self.assertLess(max_abs(y_grad - adj_y0_grad), 3e-4)
+        self.assertLess(max_abs(t_grad - adj_t_grad), 1e-4)
+        self.assertLess(max_abs(a_grad - adj_A_grad), 2e-3)
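The `output_gradients` argument used above seeds the backward pass with a custom upstream gradient, computing a vector-Jacobian product; it is TensorFlow's counterpart to the `ys_.backward(gradys)` calls in the deleted torchdiffeq-style comments. A minimal illustration (standalone sketch, not from this PR):

```python
import tensorflow as tf

x = tf.Variable([1.0, 2.0])
with tf.GradientTape() as tape:
    y = x ** 2

# Seeding with g computes g * dy/dx elementwise here (a vector-Jacobian
# product), rather than implicitly reducing y to a scalar first.
g = tf.constant([0.1, 0.2])
print(tape.gradient(y, x, output_gradients=g))  # -> [0.2, 0.8]
```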

+    #def test_adams_adjoint_against_dopri5(self):
+    #    tf.keras.backend.set_floatx('float64')
+    #    tf.compat.v1.set_random_seed(0)
+    #    with tf.GradientTape(persistent=True) as tape:
+    #        func, y0, t_points = self.problem()
+    #        tape.watch(t_points)
+    #        tape.watch(y0)
+    #        ys = tfdiffeq.odeint_adjoint(func, y0, t_points, method='adams')
+
+    #    gradys = 0.1 * tf.random.uniform(shape=ys.shape, dtype=tf.float64)
+    #    adj_y0_grad, adj_t_grad, adj_A_grad = tape.gradient(
+    #        ys,
+    #        [y0, t_points, func.A],
+    #        output_gradients=gradys
+    #    )
+
+    #    with tf.GradientTape() as tape:
+    #        func, y0, t_points = self.problem()
+    #        tape.watch(y0)
+    #        tape.watch(t_points)
+    #        ys = tfdiffeq.odeint(func, y0, t_points, method='dopri5')
+
+    #    y_grad, t_grad, a_grad = tape.gradient(
+    #        ys,
+    #        [y0, t_points, func.A],
+    #        output_gradients=gradys
+    #    )
+
+    #    self.assertLess(max_abs(y_grad - adj_y0_grad), 3e-4)
+    #    self.assertLess(max_abs(t_grad - adj_t_grad), 1e-4)
+    #    self.assertLess(max_abs(a_grad - adj_A_grad), 2e-3)
 
 
 if __name__ == '__main__':
34 changes: 20 additions & 14 deletions tests/odeint_tests.py
@@ -68,13 +68,16 @@ def test_dopri5(self):
             with self.subTest(ode=ode):
                 self.assertLess(rel_error(sol, y), error_tol)
 
-    # def test_adjoint(self):
-    #     for ode in problems.PROBLEMS.keys():
-    #         f, y0, t_points, sol = problems.construct_problem(TEST_DEVICE, reverse=True)
-    #
-    #         y = tfdiffeq.odeint_adjoint(f, y0, t_points, method='dopri5')
-    #         with self.subTest(ode=ode):
-    #             self.assertLess(rel_error(sol, y), error_tol)
+    def test_adjoint(self):
+        for ode in problems.PROBLEMS.keys():
+            f, y0, t_points, sol = problems.construct_problem(TEST_DEVICE, reverse=True)
+            y0 = tf.cast(y0, tf.float64)
+            t_points = tf.cast(t_points, tf.float64)
+            sol = tf.cast(sol, tf.float64)
+
+            y = tfdiffeq.odeint_adjoint(f, y0, t_points, method='dopri5')
+            with self.subTest(ode=ode):
+                self.assertLess(rel_error(sol, y), error_tol)
 
 
 class TestSolverBackwardsInTimeError(unittest.TestCase):
@@ -119,13 +122,16 @@ def test_dopri5(self):
             with self.subTest(ode=ode):
                 self.assertLess(rel_error(sol, y), error_tol)
 
-    # def test_adjoint(self):
-    #     for ode in problems.PROBLEMS.keys():
-    #         f, y0, t_points, sol = problems.construct_problem(TEST_DEVICE, reverse=True)
-    #
-    #         y = tfdiffeq.odeint_adjoint(f, y0, t_points, method='dopri5')
-    #         with self.subTest(ode=ode):
-    #             self.assertLess(rel_error(sol, y), error_tol)
+    def test_adjoint(self):
+        for ode in problems.PROBLEMS.keys():
+            f, y0, t_points, sol = problems.construct_problem(TEST_DEVICE, reverse=True)
+            y0 = tf.cast(y0, tf.float64)
+            t_points = tf.cast(t_points, tf.float64)
+            sol = tf.cast(sol, tf.float64)
+
+            y = tfdiffeq.odeint_adjoint(f, y0, t_points, method='dopri5')
+            with self.subTest(ode=ode):
+                self.assertLess(rel_error(sol, y), error_tol)
 
 
 class TestNoIntegration(unittest.TestCase):
2 changes: 1 addition & 1 deletion tfdiffeq/__init__.py
@@ -1,6 +1,6 @@
 # Core imports
 from tfdiffeq.odeint import odeint
-# from tfdiffeq.adjoint import odeint_adjoint
+from tfdiffeq.adjoint import odeint_adjoint
 
 # Utility functions
 from tfdiffeq.misc import cast_double, func_cast_double
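With `odeint_adjoint` exported from the package root again, it becomes a drop-in replacement for `odeint`. A minimal usage sketch mirroring the new tests (the `Linear` model below is an illustrative stand-in, not code from this PR):

```python
import tensorflow as tf
import tfdiffeq

class Linear(tf.keras.Model):
    def __init__(self):
        super(Linear, self).__init__()
        self.A = tf.Variable([[-0.1, -2.0], [2.0, -0.1]], dtype=tf.float64)

    def call(self, t, y):
        return tf.linalg.matvec(self.A, y ** 3)

func = Linear()
y0 = tf.constant([2., 0.], dtype=tf.float64)
t_points = tf.linspace(tf.constant(0., tf.float64), tf.constant(25., tf.float64), 10)

with tf.GradientTape() as tape:
    tape.watch(y0)
    # Gradients flow through the adjoint ODE rather than the solver internals,
    # trading extra backward-pass compute for not storing solver steps.
    ys = tfdiffeq.odeint_adjoint(func, y0, t_points, method='dopri5')

dy0, dA = tape.gradient(ys, [y0, func.A])
```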
2 changes: 1 addition & 1 deletion tfdiffeq/adams.py
@@ -31,7 +31,7 @@ def g_and_explicit_phi(prev_t, next_t, implicit_phi, k):
     dt = next_t - prev_t[0]
 
     with tf.device(prev_t[0].device):
-        g = tf.Variable(tf.zeros([k + 1]))
+        g = tf.Variable(tf.zeros([k + 1]), trainable=False)
 
     explicit_phi = collections.deque(maxlen=k)
     beta = move_to_device(tf.convert_to_tensor(1.), prev_t[0].device)
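Marking the solver's scratch variable as `trainable=False` keeps it out of the set of variables a `GradientTape` watches by default (and out of any enclosing model's `trainable_variables`), which matters now that tapes record through the solver. A small illustration (not from the PR):

```python
import tensorflow as tf

scratch = tf.Variable(tf.zeros([4]), trainable=False)

with tf.GradientTape() as tape:
    y = scratch * 2.0

# Non-trainable variables are not watched automatically, so the tape
# records no gradient with respect to this internal buffer.
print(tape.gradient(y, scratch))  # -> None
```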