EMA, SAM update, cyclic LR schedule #269

Merged: 16 commits, Apr 24, 2024
8 changes: 5 additions & 3 deletions apax/cli/templates/train_config_full.yaml
@@ -4,6 +4,7 @@ patience: null
n_models: 1
n_jitted_steps: 1
data_parallel: True
weight_average: null

data:
directory: models/
@@ -80,9 +81,10 @@ optimizer:
scale_lr: 0.001
shift_lr: 0.05
zbl_lr: 0.001
transition_begin: 0
sam_rho: 0.0

schedule:
name: linear
transition_begin: 0
end_value: 1e-6
callbacks:
- name: csv

45 changes: 45 additions & 0 deletions apax/config/lr_config.py
@@ -0,0 +1,45 @@
from typing import Literal

from pydantic import BaseModel, NonNegativeFloat


class LRSchedule(BaseModel, frozen=True, extra="forbid"):
name: str


class LinearLR(LRSchedule, frozen=True, extra="forbid"):
"""
Configuration of the linear learning rate schedule.
The learning rate decays linearly from its initial value to `end_value`.

Parameters
----------
transition_begin : int, default = 0
Number of steps after which to start decreasing the learning rate.
end_value : NonNegativeFloat, default = 1e-6
Final learning rate at the end of training.
"""

name: Literal["linear"]
transition_begin: int = 0
end_value: NonNegativeFloat = 1e-6


class CyclicCosineLR(LRSchedule, frozen=True, extra="forbid"):
"""
Configuration of the cyclic cosine learning rate schedule.
The learning rate follows a cosine decay within each cycle and restarts
at a (possibly decayed) peak at the start of the next cycle.

Parameters
----------
period : int, default = 20
Length of a cycle in epochs.
decay_factor : NonNegativeFloat, default = 1.0
Factor by which to decrease the LR after each cycle.
1.0 means no decrease.
"""

name: Literal["cyclic_cosine"]
period: int = 20
decay_factor: NonNegativeFloat = 1.0
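
As a quick illustration of the two schedule configs added in this file, the sketch below constructs both of them directly. It is a minimal sketch assuming an installed apax with this module importable as `apax.config.lr_config`, and it only uses the fields defined above; the chosen values are illustrative.

```python
# Illustrative sketch only; field names follow the classes defined above.
from apax.config.lr_config import CyclicCosineLR, LinearLR

# Linear decay towards `end_value`, starting after `transition_begin` steps.
linear = LinearLR(name="linear", transition_begin=1000, end_value=1e-6)

# Cosine cycles of `period` epochs; each new cycle's peak LR is scaled
# by `decay_factor` (1.0 keeps the peak constant).
cyclic = CyclicCosineLR(name="cyclic_cosine", period=10, decay_factor=0.9)

# Because the models use extra="forbid", misspelled keys are rejected at
# validation time instead of being silently ignored.
```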
31 changes: 24 additions & 7 deletions apax/config/train_config.py
@@ -16,6 +16,7 @@
)
from typing_extensions import Annotated

from apax.config.lr_config import CyclicCosineLR, LinearLR
from apax.data.statistics import scale_method_list, shift_method_list

log = logging.getLogger(__name__)
@@ -216,13 +217,10 @@ class OptimizerConfig(BaseModel, frozen=True, extra="forbid"):
Learning rate of the elemental output shifts.
zbl_lr : NonNegativeFloat, default = 0.001
Learning rate of the ZBL correction parameters.
transition_begin : int, default = 0
Number of training steps (not epochs) before the start of the linear
learning rate schedule.
schedule : LRSchedule = LinearLR
Learning rate schedule.
opt_kwargs : dict, default = {}
Optimizer keyword arguments. Passed to the `optax` optimizer.
sam_rho : NonNegativeFloat, default = 0.0
Rho parameter for Sharpness-Aware Minimization.
"""

opt_name: str = "adam"
@@ -231,9 +229,10 @@ class OptimizerConfig(BaseModel, frozen=True, extra="forbid"):
scale_lr: NonNegativeFloat = 0.001
shift_lr: NonNegativeFloat = 0.05
zbl_lr: NonNegativeFloat = 0.001
transition_begin: int = 0
schedule: Union[LinearLR, CyclicCosineLR] = Field(
LinearLR(name="linear"), discriminator="name"
)
opt_kwargs: dict = {}
sam_rho: NonNegativeFloat = 0.0


class MetricsConfig(BaseModel, extra="forbid"):
@@ -357,6 +356,21 @@ class CheckpointConfig(BaseModel, extra="forbid"):
reset_layers: List[str] = []


class WeightAverage(BaseModel, extra="forbid"):
"""Applies an exponential moving average to model parameters.

Parameters
----------
ema_start : int, default = 0
Epoch at which to start averaging models.
alpha : float, default = 0.9
How much of the new model to use. 1.0 would mean no averaging, 0.0 no updates.
"""

ema_start: int = 0
alpha: float = 0.9


class Config(BaseModel, frozen=True, extra="forbid"):
"""
Main configuration of an apax training run. Parameters that are config classes will
@@ -396,6 +410,8 @@ class Config(BaseModel, frozen=True, extra="forbid"):
| Loss configuration.
optimizer : :class:`.OptimizerConfig`
| Loss optimizer configuration.
weight_average : :class:`.WeightAverage`, optional
| Options for averaging weights between epochs.
callbacks : List of various CallBack classes
| Possible callbacks are :class:`.CSVCallback`,
| :class:`.TBCallback`, :class:`.MLFlowCallback`
@@ -420,6 +436,7 @@ class Config(BaseModel, frozen=True, extra="forbid"):
metrics: List[MetricsConfig] = []
loss: List[LossConfig]
optimizer: OptimizerConfig = OptimizerConfig()
weight_average: Optional[WeightAverage] = None
callbacks: List[CallBack] = [CSVCallback(name="csv")]
progress_bar: TrainProgressbarConfig = TrainProgressbarConfig()
checkpoints: CheckpointConfig = CheckpointConfig()
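To show how the new options compose on the config side, here is a hedged sketch that builds an `OptimizerConfig` with the cyclic schedule and a `WeightAverage` entry. It only uses classes from this diff; the specific values are illustrative, not recommendations.

```python
# Illustrative sketch; classes are the ones defined in this PR.
from apax.config.lr_config import CyclicCosineLR
from apax.config.train_config import OptimizerConfig, WeightAverage

# "sam" is dispatched to the SAM wrapper in get_optimizer.py; any other
# name is looked up on optax as before.
optimizer = OptimizerConfig(
    opt_name="sam",
    schedule=CyclicCosineLR(name="cyclic_cosine", period=20, decay_factor=0.9),
)

# EMA of the weights starting after epoch 5; alpha=0.9 keeps 90% of the
# newest parameters in every update.
weight_average = WeightAverage(ema_start=5, alpha=0.9)
```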
6 changes: 2 additions & 4 deletions apax/layers/empirical.py
@@ -26,8 +26,6 @@ class ZBLRepulsion(EmpiricalEnergyTerm):
def setup(self):
self.distance = vmap(space.distance, 0, 0)

self.ke = 14.3996

a_exp = 0.23
a_num = 0.46850
coeffs = jnp.array([0.18175, 0.50986, 0.28022, 0.02817])[:, None]
@@ -37,7 +35,7 @@ def setup(self):
a_num_isp = inverse_softplus(a_num)
coeffs_isp = inverse_softplus(coeffs)
exps_isp = inverse_softplus(exps)
rep_scale_isp = inverse_softplus(1.0 / self.ke)
rep_scale_isp = inverse_softplus(0.0)

self.a_exp = self.param("a_exp", nn.initializers.constant(a_exp_isp), (1,))
self.a_num = self.param("a_num", nn.initializers.constant(a_num_isp), (1,))
@@ -86,5 +84,5 @@ def __call__(self, dr_vec, Z, idx):
E_ij = Z_i * Z_j / dr * f * cos_cutoff
if self.apply_mask:
E_ij = mask_by_neighbor(E_ij, idx)
E = 0.5 * rep_scale * self.ke * fp64_sum(E_ij)
E = 0.5 * rep_scale * fp64_sum(E_ij)
return fp64_sum(E)
95 changes: 65 additions & 30 deletions apax/optimizer/get_optimizer.py
@@ -1,82 +1,117 @@
import logging
from typing import Any, Callable

import jax.numpy as jnp
import numpy as np
import optax
from flax import traverse_util
from flax.core.frozen_dict import freeze
from optax import contrib
from optax._src import base

log = logging.getLogger(__name__)


def map_nested_fn(fn: Callable[[str, Any], dict]) -> Callable[[dict], dict]:
"""
Recursively apply `fn` to the key-value pairs of a nested dict
See
https://optax.readthedocs.io/en/latest/api.html?highlight=multitransform#multi-transform
def sam(lr=1e-3, b1=0.9, b2=0.999, rho=0.001, sync_period=2):
"""A SAM optimizer using Adam for the outer optimizer."""
opt = optax.adam(lr, b1=b1, b2=b2)
adv_opt = optax.chain(contrib.normalize(), optax.sgd(rho))
return contrib.sam(opt, adv_opt, sync_period=sync_period)


def cyclic_cosine_decay_schedule(
init_value: float,
steps_per_epoch,
period: int,
decay_factor: float = 0.9,
) -> base.Schedule:
r"""Returns a function which implements cyclic cosine learning rate decay.

Args:
init_value: An initial value for the learning rate.
steps_per_epoch: Number of optimizer steps per epoch.
period: Length of a cycle in epochs.
decay_factor: Factor by which each successive cycle's peak LR is scaled.

Returns:
schedule
A function that maps step counts to values.
"""

def map_fn(nested_dict):
return {
k: map_fn(v) if isinstance(v, dict) else fn(k, v)
for k, v in nested_dict.items()
}
def schedule(count):
cycle = count // (period * steps_per_epoch)
step_in_period = jnp.mod(count, period * steps_per_epoch)
lr = (
init_value
/ 2
* (jnp.cos(np.pi * step_in_period / (period * steps_per_epoch)) + 1)
)
lr = lr * (decay_factor**cycle)
return lr

return map_fn
return schedule


def get_schedule(
lr: float, transition_begin: int, transition_steps: int
lr: float,
n_epochs: int,
steps_per_epoch: int,
schedule_kwargs: dict,
) -> optax._src.base.Schedule:
"""
Builds the learning rate schedule specified in `schedule_kwargs` (linear or cyclic cosine).
"""
lr_schedule = optax.linear_schedule(
init_value=lr,
end_value=1e-6,
transition_begin=transition_begin,
transition_steps=transition_steps,
)
schedule_kwargs = schedule_kwargs.copy()
name = schedule_kwargs.pop("name")
if name == "linear":
lr_schedule = optax.linear_schedule(
init_value=lr, transition_steps=n_epochs * steps_per_epoch, **schedule_kwargs
)
elif name == "cyclic_cosine":
lr_schedule = cyclic_cosine_decay_schedule(lr, steps_per_epoch, **schedule_kwargs)
else:
raise KeyError(f"unknown learning rate schedule: {name}")
return lr_schedule


def make_optimizer(opt, lr, transition_begin, transition_steps, opt_kwargs):
def make_optimizer(opt, lr, n_epochs, steps_per_epoch, opt_kwargs, schedule):
if lr <= 1e-7:
optimizer = optax.set_to_zero()
else:
schedule = get_schedule(lr, transition_begin, transition_steps)
schedule = get_schedule(lr, n_epochs, steps_per_epoch, schedule)
optimizer = opt(schedule, **opt_kwargs)
return optimizer


def get_opt(
params,
transition_begin: int,
transition_steps: int,
n_epochs: int,
steps_per_epoch: int,
emb_lr: float = 0.02,
nn_lr: float = 0.03,
scale_lr: float = 0.001,
shift_lr: float = 0.05,
zbl_lr: float = 0.001,
opt_name: str = "adam",
opt_kwargs: dict = {},
**kwargs,
schedule: dict = {},
) -> optax._src.base.GradientTransformation:
"""
Builds an optimizer with different learning rates for each parameter group.
Several `optax` optimizers are supported.
"""

log.info("Initializing Optimizer")
opt = getattr(optax, opt_name)
if opt_name == "sam":
opt = sam
else:
opt = getattr(optax, opt_name)

nn_opt = make_optimizer(opt, nn_lr, transition_begin, transition_steps, opt_kwargs)
emb_opt = make_optimizer(opt, emb_lr, transition_begin, transition_steps, opt_kwargs)
nn_opt = make_optimizer(opt, nn_lr, n_epochs, steps_per_epoch, opt_kwargs, schedule)
emb_opt = make_optimizer(opt, emb_lr, n_epochs, steps_per_epoch, opt_kwargs, schedule)
scale_opt = make_optimizer(
opt, scale_lr, transition_begin, transition_steps, opt_kwargs
opt, scale_lr, n_epochs, steps_per_epoch, opt_kwargs, schedule
)
shift_opt = make_optimizer(
opt, shift_lr, transition_begin, transition_steps, opt_kwargs
opt, shift_lr, n_epochs, steps_per_epoch, opt_kwargs, schedule
)
zbl_opt = make_optimizer(opt, zbl_lr, transition_begin, transition_steps, opt_kwargs)
zbl_opt = make_optimizer(opt, zbl_lr, n_epochs, steps_per_epoch, opt_kwargs, schedule)

partition_optimizers = {
"w": nn_opt,
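For anyone wanting to sanity-check the new `get_schedule` branches, a small hedged sketch follows; the call signatures match the code above, and the commented values follow directly from the linear and cyclic cosine formulas.

```python
# Illustrative check of the two schedule branches; values follow the math above.
from apax.optimizer.get_optimizer import get_schedule

steps_per_epoch = 100
linear = get_schedule(
    lr=1e-3,
    n_epochs=100,
    steps_per_epoch=steps_per_epoch,
    schedule_kwargs={"name": "linear", "transition_begin": 0, "end_value": 1e-6},
)
cyclic = get_schedule(
    lr=1e-3,
    n_epochs=100,
    steps_per_epoch=steps_per_epoch,
    schedule_kwargs={"name": "cyclic_cosine", "period": 20, "decay_factor": 0.9},
)

print(linear(0), linear(100 * steps_per_epoch))  # 1e-3 ... 1e-6
# A new cosine cycle starts every 20 epochs; its peak is scaled by decay_factor.
print(cyclic(0), cyclic(20 * steps_per_epoch))   # ~1e-3 ... ~9e-4
```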
2 changes: 1 addition & 1 deletion apax/train/eval.py
@@ -110,7 +110,7 @@ def predict(model, params, Metrics, loss_fn, test_ds, callbacks, is_ensemble=Fal

callbacks.on_train_begin()
_, test_step_fn = make_step_fns(
loss_fn, Metrics, model=model, sam_rho=0.0, is_ensemble=is_ensemble
loss_fn, Metrics, model=model, is_ensemble=is_ensemble
)

batch_test_ds = test_ds.batch()
32 changes: 32 additions & 0 deletions apax/train/parameters.py
@@ -0,0 +1,32 @@
import jax


@jax.jit
def tree_ema(tree1, tree2, alpha):
"""Exponential moving average of two pytrees.
"""
ema = jax.tree_map(lambda a, b: alpha * a + (1 - alpha) * b, tree1, tree2)
return ema


class EMAParameters:
"""Handler for tracking an exponential moving average of model parameters.
The EMA parameters are used in the validation loop.

Parameters
----------
ema_start : int
Epoch at which to start averaging models.
alpha : float, default = 0.9
How much of the new model to use. 1.0 would mean no averaging, 0.0 no updates.
"""
def __init__(self, ema_start: int, alpha: float = 0.9) -> None:
self.alpha = alpha
self.ema_start = ema_start
self.ema_params = None

def update(self, opt_params, epoch):
if epoch > self.ema_start:
self.ema_params = tree_ema(opt_params, self.ema_params, self.alpha)
else:
self.ema_params = opt_params
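
A minimal sketch of how `EMAParameters` could be wired into a training loop for validation; `train_one_epoch`, `validate`, `params`, and `n_epochs` are placeholders for this example, not apax functions.

```python
# Hypothetical wiring; train_one_epoch/validate/params/n_epochs are placeholders.
from apax.train.parameters import EMAParameters

ema = EMAParameters(ema_start=5, alpha=0.9)

for epoch in range(n_epochs):
    params = train_one_epoch(params)
    ema.update(params, epoch)
    # Validation uses the averaged parameters once averaging has started.
    validate(ema.ema_params)
```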