Shifted step function inside the scope of the batch function. (#74)
This avoids creating a new function object (via functools.partial) on each call of mLSTM1900_batch, which might be the cause of our memory-leak issues: the freshly created object never hits lax.scan's function cache.

We have thus sacrificed direct unit-testability of the step function, but we can still test it from the outside through mLSTM1900_batch (see the sketch below). Not that we'd want to modify the step function anyway.
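To make the pattern change concrete, here is a minimal before/after sketch of the structure involved. Everything below is a toy illustration: `step`, `params`, and all shapes are stand-ins rather than jax-unirep's actual code, and the rationale (a per-call `partial` object missing lax.scan's function cache) is the hedged one stated above.

```python
from functools import partial

import jax.numpy as np
from jax import lax

# Toy stand-ins; not the library's objects.
params = {"w": np.eye(3)}


def step(params, carry, x_t):
    # One recurrence step: update the carry and also emit it as output.
    new_carry = np.dot(carry, params["w"]) + x_t
    return new_carry, new_carry


def scan_with_partial(batch):
    # Before this commit: a fresh `partial` object is built on every call,
    # so lax.scan never sees the same function object twice.
    step_func = partial(step, params)
    final, outputs = lax.scan(step_func, init=np.zeros(3), xs=batch)
    return final, outputs


def scan_with_closure(batch):
    # After this commit: the step function is defined inside the batch
    # function's scope and closes over `params` directly.
    def step_inner(carry, x_t):
        return step(params, carry, x_t)

    final, outputs = lax.scan(step_inner, init=np.zeros(3), xs=batch)
    return final, outputs


batch = np.ones((5, 3))
final, outputs = scan_with_closure(batch)
print(outputs.shape)  # (5, 3)
```

Whether the closure fully resolves the leak is left open here, matching the "might be" in the message above.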
ericmjl authored Sep 18, 2020
1 parent e3d7560 commit a11e4b5
Showing 2 changed files with 69 additions and 86 deletions.
135 changes: 68 additions & 67 deletions jax_unirep/layers.py

@@ -160,77 +160,78 @@ def mLSTM1900_batch(
     h_t = np.zeros(params["wmh"].shape[0])
     c_t = np.zeros(params["wmh"].shape[0])
 
-    step_func = partial(mLSTM1900_step, params)
-    (h_final, c_final), outputs = lax.scan(
-        step_func, init=(h_t, c_t), xs=batch
-    )
-    return h_final, c_final, outputs
-
-
-def mLSTM1900_step(
-    params: Dict[str, np.ndarray],
-    carry: Tuple[np.ndarray, np.ndarray],
-    x_t: np.ndarray,
-) -> Tuple[Tuple[np.ndarray, np.ndarray], np.ndarray]:
-    """
-    Implementation of mLSTMCell from UniRep paper, with weight normalization.
-
-    Exact source code reference:
-    https://github.com/churchlab/UniRep/blob/master/unirep.py#L75
-
-    Shapes of parameters:
-
-    - wmx: 10, 1900
-    - wmh: 1900, 1900
-    - wx: 10, 7600
-    - wh: 1900, 7600
-    - gmx: 1900
-    - gmh: 1900
-    - gx: 7600
-    - gh: 7600
-    - b: 7600
-
-    Shapes of inputs:
-
-    - x_t: (1, 10)
-    - carry:
-        - h_t: (1, 1900)
-        - c_t: (1, 1900)
-    """
-    h_t, c_t = carry
-
-    # Perform weight normalization first
-    # (Corresponds to Line 113).
-    # In the original implementation, this is toggled by a boolean flag,
-    # but here we are enabling it by default.
-    params["wx"] = l2_normalize(params["wx"], axis=0) * params["gx"]
-    params["wh"] = l2_normalize(params["wh"], axis=0) * params["gh"]
-    params["wmx"] = l2_normalize(params["wmx"], axis=0) * params["gmx"]
-    params["wmh"] = l2_normalize(params["wmh"], axis=0) * params["gmh"]
-
-    # Shape annotation
-    # (:, 10) @ (10, 1900) * (:, 1900) @ (1900, 1900) => (:, 1900)
-    m = np.matmul(x_t, params["wmx"]) * np.matmul(h_t, params["wmh"])
-
-    # (:, 10) @ (10, 7600) * (:, 1900) @ (1900, 7600) + (7600, ) => (:, 7600)
-    z = np.matmul(x_t, params["wx"]) + np.matmul(m, params["wh"]) + params["b"]
-
-    # Splitting along axis 1, four ways, gets us (:, 1900) as the shape
-    # for each of i, f, o and u
-    i, f, o, u = np.split(z, 4, axis=-1)  # input, forget, output, update
-
-    # Elementwise transforms here.
-    # Shapes are (:, 1900) for each of the four.
-    i = sigmoid(i, version="exp")
-    f = sigmoid(f, version="exp")
-    o = sigmoid(o, version="exp")
-    u = tanh(u)
-
-    # (:, 1900) * (:, 1900) + (:, 1900) * (:, 1900) => (:, 1900)
-    c_t = f * c_t + i * u
-
-    # (:, 1900) * (:, 1900) => (:, 1900)
-    h_t = o * tanh(c_t)
-
-    # h, c each have shape (:, 1900)
-    return (h_t, c_t), h_t  # returned this way to match rest of fundl API.
+    def mLSTM1900_step(
+        carry: Tuple[np.ndarray, np.ndarray],
+        x_t: np.ndarray,
+    ) -> Tuple[Tuple[np.ndarray, np.ndarray], np.ndarray]:
+        """
+        Implementation of mLSTMCell from UniRep paper, with weight normalization.
+
+        Exact source code reference:
+        https://github.com/churchlab/UniRep/blob/master/unirep.py#L75
+
+        Shapes of parameters:
+
+        - wmx: 10, 1900
+        - wmh: 1900, 1900
+        - wx: 10, 7600
+        - wh: 1900, 7600
+        - gmx: 1900
+        - gmh: 1900
+        - gx: 7600
+        - gh: 7600
+        - b: 7600
+
+        Shapes of inputs:
+
+        - x_t: (1, 10)
+        - carry:
+            - h_t: (1, 1900)
+            - c_t: (1, 1900)
+        """
+        h_t, c_t = carry
+
+        # Perform weight normalization first
+        # (Corresponds to Line 113).
+        # In the original implementation, this is toggled by a boolean flag,
+        # but here we are enabling it by default.
+        params["wx"] = l2_normalize(params["wx"], axis=0) * params["gx"]
+        params["wh"] = l2_normalize(params["wh"], axis=0) * params["gh"]
+        params["wmx"] = l2_normalize(params["wmx"], axis=0) * params["gmx"]
+        params["wmh"] = l2_normalize(params["wmh"], axis=0) * params["gmh"]
+
+        # Shape annotation
+        # (:, 10) @ (10, 1900) * (:, 1900) @ (1900, 1900) => (:, 1900)
+        m = np.matmul(x_t, params["wmx"]) * np.matmul(h_t, params["wmh"])
+
+        # (:, 10) @ (10, 7600) * (:, 1900) @ (1900, 7600) + (7600, ) => (:, 7600)
+        z = (
+            np.matmul(x_t, params["wx"])
+            + np.matmul(m, params["wh"])
+            + params["b"]
+        )
+
+        # Splitting along axis 1, four ways, gets us (:, 1900) as the shape
+        # for each of i, f, o and u
+        i, f, o, u = np.split(z, 4, axis=-1)  # input, forget, output, update
+
+        # Elementwise transforms here.
+        # Shapes are (:, 1900) for each of the four.
+        i = sigmoid(i, version="exp")
+        f = sigmoid(f, version="exp")
+        o = sigmoid(o, version="exp")
+        u = tanh(u)
+
+        # (:, 1900) * (:, 1900) + (:, 1900) * (:, 1900) => (:, 1900)
+        c_t = f * c_t + i * u
+
+        # (:, 1900) * (:, 1900) => (:, 1900)
+        h_t = o * tanh(c_t)
+
+        # h, c each have shape (:, 1900)
+        return (h_t, c_t), h_t
+
+    (h_final, c_final), outputs = lax.scan(
+        mLSTM1900_step, init=(h_t, c_t), xs=batch
+    )
+    return h_final, c_final, outputs
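For reference, the step body in the hunk above (unchanged by this commit except for its indentation level) is the mLSTM cell update from the UniRep paper. Restated from the code in math form, with elementwise multiplication written as \odot, and each weight matrix first weight-normalized column-wise via the l2_normalize(...) * g lines:

```latex
% Restatement of mLSTM1900_step from the code above; not new material.
\begin{aligned}
  m_t &= (x_t W_{mx}) \odot (h_{t-1} W_{mh}) \\
  z_t &= x_t W_x + m_t W_h + b \\
  (z_i,\, z_f,\, z_o,\, z_u) &= \operatorname{split}_4(z_t) \\
  i_t &= \sigma(z_i), \quad f_t = \sigma(z_f), \quad
         o_t = \sigma(z_o), \quad u_t = \tanh(z_u) \\
  c_t &= f_t \odot c_{t-1} + i_t \odot u_t \\
  h_t &= o_t \odot \tanh(c_t)
\end{aligned}
```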
20 changes: 1 addition & 19 deletions tests/test_layers.py

@@ -11,7 +11,7 @@
     mLSTM1900_AvgHidden,
     mLSTM1900_batch,
     mLSTM1900_Fusion,
-    mLSTM1900_step,
+    # mLSTM1900_step,
 )
 from jax_unirep.utils import (
     get_embedding,
@@ -23,24 +23,6 @@
 rng = random.PRNGKey(0)
 
 
-def test_mLSTM1900_step():
-    """
-    Given fake data of the correct input shapes,
-    make sure that the output shapes are also correct.
-    """
-    params = load_params_1900()
-
-    x_t = npr.normal(size=(1, 10))
-    h_t = np.zeros(shape=(1, 1900))
-    c_t = np.zeros(shape=(1, 1900))
-
-    carry = (h_t, c_t)
-
-    (h_t, c_t), _ = mLSTM1900_step(params, carry, x_t)
-    assert h_t.shape == (1, 1900)
-    assert c_t.shape == (1, 1900)
-
-
 def test_mLSTM1900_batch():
     """
     Given one fake embedded sequence,
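With the dedicated step test removed, the step function can still be exercised "from the outside", as the commit message notes. Below is a hedged sketch of such a test, not code from this commit: the (seq_len, 10) input shape is inferred from the docstring and the scan over `batch`, the (1900,) carry shapes from the `np.zeros(params["wmh"].shape[0])` initializers, and `load_params_1900` is assumed importable from jax_unirep.utils as in the deleted test.

```python
# Hedged sketch: exercising the now-nested step function indirectly,
# via mLSTM1900_batch. Shapes are inferred from the diff above.
import numpy.random as npr

from jax_unirep.layers import mLSTM1900_batch
from jax_unirep.utils import load_params_1900  # assumed location


def test_mLSTM1900_batch_shapes():
    params = load_params_1900()

    # One fake embedded sequence: 7 residues, 10-dim embeddings (assumed).
    batch = npr.normal(size=(7, 10))

    h_final, c_final, outputs = mLSTM1900_batch(params, batch)

    # h_t and c_t are initialized with shape (1900,) in mLSTM1900_batch,
    # and lax.scan stacks one hidden state per input step.
    assert h_final.shape == (1900,)
    assert c_final.shape == (1900,)
    assert outputs.shape == (7, 1900)
```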
