Commit 008c3c1

make compatible with latest version of optax
rowanz committed Jan 25, 2022
1 parent 8dc049a commit 008c3c1
Showing 2 changed files with 3 additions and 3 deletions.
finetune/optimization.py (1 addition, 1 deletion)
@@ -12,7 +12,7 @@
 from pretrain.optimization import *
 
 
-class DecayedWeightsDeltaState(OptState):
+class DecayedWeightsDeltaState(NamedTuple):
     """Overall state of the gradient transformation."""
     orig_params: chex.Array  # Momentum
 
pretrain/optimization.py (2 additions, 2 deletions)
@@ -1,5 +1,5 @@
 import optax
-from optax import OptState, GradientTransformation
+from optax import GradientTransformation
 from optax._src.base import NO_PARAMS_MSG
 import jax
 import chex
@@ -14,7 +14,7 @@
 from typing import NamedTuple, Any
 
 
-class ScaleByAdamState(OptState):
+class ScaleByAdamState(NamedTuple):
     """State for the Adam algorithm."""
     count: chex.Array
     mu: optax.Updates
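
For background: in recent optax releases, OptState is a type alias rather than a subclassable base class (optax's own optimizer states are plain typing.NamedTuples), which is why this commit switches both state classes to derive from NamedTuple directly. Below is a minimal sketch of that pattern against a current optax install; the names scale_by_step_count and StepCountState are illustrative, not part of this repository.

import jax
import jax.numpy as jnp
import chex
import optax
from typing import NamedTuple


class StepCountState(NamedTuple):
    """State for a toy transformation (hypothetical, for illustration)."""
    count: chex.Array  # number of updates applied so far


def scale_by_step_count() -> optax.GradientTransformation:
    """Scales updates by 1/step; shows the NamedTuple-state pattern."""

    def init_fn(params):
        del params  # this state does not depend on the parameter values
        return StepCountState(count=jnp.zeros([], jnp.int32))

    def update_fn(updates, state, params=None):
        del params
        new_count = state.count + 1
        # Scale every leaf of the update pytree, purely for illustration.
        updates = jax.tree_util.tree_map(lambda u: u / new_count, updates)
        return updates, StepCountState(count=new_count)

    return optax.GradientTransformation(init_fn, update_fn)

A transformation built this way composes with the rest of an optax pipeline as usual, e.g. optax.chain(scale_by_step_count(), optax.scale(-1e-3)).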
