From 2d6cd0abfe781752d4977503845d76c42d55cb46 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Sat, 20 Apr 2024 20:14:12 +0200 Subject: [PATCH 01/15] ema test implementation --- apax/train/trainer.py | 31 ++++++++++++++++++++++++++++++- 1 file changed, 30 insertions(+), 1 deletion(-) diff --git a/apax/train/trainer.py b/apax/train/trainer.py index f3add601..c9cdce4b 100644 --- a/apax/train/trainer.py +++ b/apax/train/trainer.py @@ -17,6 +17,22 @@ log = logging.getLogger(__name__) +@jax.jit +def tree_ema(tree1, tree2, alpha): + ema = jax.tree_map(lambda a,b: alpha * a + (1-alpha) * b, tree1, tree2) + return ema + + + +class EMAParameters: + def __init__(self, params, alpha) -> None: + self.alpha = alpha + self.ema_params = params + + def update(self, opt_params): + self.ema_params = tree_ema(opt_params, self.ema_params, self.alpha) + + def fit( state, @@ -90,6 +106,11 @@ def fit( raise ValueError( f"n_epochs <= current epoch from checkpoint ({n_epochs} <= {start_epoch})" ) + + ema = False + if ema: + alpha = 0.9 + ema_handler = EMAParameters(state.params, alpha) devices = len(jax.devices()) if devices > 1 and data_parallel: @@ -152,6 +173,13 @@ def fit( for key, val in train_batch_metrics.compute().items() } + if ema: + ema_handler.update(state.params) + val_params = ema_handler.ema_params + else: + val_params = state.params + + if val_ds is not None: epoch_loss.update({"val_loss": 0.0}) val_batch_metrics = Metrics.empty() @@ -168,8 +196,9 @@ def fit( for batch_idx in range(val_steps_per_epoch): batch = next(batch_val_ds) + batch_loss, val_batch_metrics = val_step( - state.params, batch, val_batch_metrics + val_params, batch, val_batch_metrics ) epoch_loss["val_loss"] += batch_loss batch_pbar.update() From 1ee39a8d41f56d51bdf9263d20b747f57122b9e6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Sat, 20 Apr 2024 21:50:53 +0200 Subject: [PATCH 02/15] removed deprecated fn, added sam contrib optimizer --- 
apax/optimizer/get_optimizer.py | 26 +++++++++++--------------- 1 file changed, 11 insertions(+), 15 deletions(-) diff --git a/apax/optimizer/get_optimizer.py b/apax/optimizer/get_optimizer.py index d079e9bf..5e351654 100644 --- a/apax/optimizer/get_optimizer.py +++ b/apax/optimizer/get_optimizer.py @@ -2,26 +2,18 @@ from typing import Any, Callable import optax +from optax import contrib from flax import traverse_util from flax.core.frozen_dict import freeze log = logging.getLogger(__name__) -def map_nested_fn(fn: Callable[[str, Any], dict]) -> Callable[[dict], dict]: - """ - Recursively apply `fn` to the key-value pairs of a nested dict - See - https://optax.readthedocs.io/en/latest/api.html?highlight=multitransform#multi-transform - """ - - def map_fn(nested_dict): - return { - k: map_fn(v) if isinstance(v, dict) else fn(k, v) - for k, v in nested_dict.items() - } - - return map_fn +def sam(lr=1e-3, b1=0.9, b2=0.999, rho=0.001, sync_period=2): + """A SAM optimizer using Adam for the outer optimizer.""" + opt = optax.adam(lr, b1=b1, b2=b2) + adv_opt = optax.chain(contrib.normalize(), optax.sgd(rho)) + return contrib.sam(opt, adv_opt, sync_period=sync_period) def get_schedule( @@ -66,7 +58,11 @@ def get_opt( Several `optax` optimizers are supported. 
""" log.info("Initializing Optimizer") - opt = getattr(optax, opt_name) + if opt_name == "sam": + opt = sam + else: + print("optname") + opt = getattr(optax, opt_name) nn_opt = make_optimizer(opt, nn_lr, transition_begin, transition_steps, opt_kwargs) emb_opt = make_optimizer(opt, emb_lr, transition_begin, transition_steps, opt_kwargs) From 11ab96f4f8eb42bf1eb8b29b1ae36e295a0c0a87 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Sat, 20 Apr 2024 21:53:09 +0200 Subject: [PATCH 03/15] removed custom sam optimizer --- apax/cli/templates/train_config_full.yaml | 1 - apax/config/train_config.py | 3 --- apax/train/eval.py | 2 +- apax/train/run.py | 1 - apax/train/trainer.py | 21 +++++---------------- 5 files changed, 6 insertions(+), 22 deletions(-) diff --git a/apax/cli/templates/train_config_full.yaml b/apax/cli/templates/train_config_full.yaml index a5fe36d5..1394bb1c 100644 --- a/apax/cli/templates/train_config_full.yaml +++ b/apax/cli/templates/train_config_full.yaml @@ -81,7 +81,6 @@ optimizer: shift_lr: 0.05 zbl_lr: 0.001 transition_begin: 0 - sam_rho: 0.0 callbacks: - name: csv diff --git a/apax/config/train_config.py b/apax/config/train_config.py index 388c0776..077983f9 100644 --- a/apax/config/train_config.py +++ b/apax/config/train_config.py @@ -221,8 +221,6 @@ class OptimizerConfig(BaseModel, frozen=True, extra="forbid"): learning rate schedule. opt_kwargs : dict, default = {} Optimizer keyword arguments. Passed to the `optax` optimizer. - sam_rho : NonNegativeFloat, default = 0.0 - Rho parameter for Sharpness-Aware Minimization. 
""" opt_name: str = "adam" @@ -233,7 +231,6 @@ class OptimizerConfig(BaseModel, frozen=True, extra="forbid"): zbl_lr: NonNegativeFloat = 0.001 transition_begin: int = 0 opt_kwargs: dict = {} - sam_rho: NonNegativeFloat = 0.0 class MetricsConfig(BaseModel, extra="forbid"): diff --git a/apax/train/eval.py b/apax/train/eval.py index 82abf6a4..5996150b 100644 --- a/apax/train/eval.py +++ b/apax/train/eval.py @@ -110,7 +110,7 @@ def predict(model, params, Metrics, loss_fn, test_ds, callbacks, is_ensemble=Fal callbacks.on_train_begin() _, test_step_fn = make_step_fns( - loss_fn, Metrics, model=model, sam_rho=0.0, is_ensemble=is_ensemble + loss_fn, Metrics, model=model, is_ensemble=is_ensemble ) batch_test_ds = test_ds.batch() diff --git a/apax/train/run.py b/apax/train/run.py index 088a4d62..4ddd9fa9 100644 --- a/apax/train/run.py +++ b/apax/train/run.py @@ -198,7 +198,6 @@ def run(user_config: Union[str, os.PathLike, dict], log_level="error"): ckpt_dir=config.data.model_version_path, ckpt_interval=config.checkpoints.ckpt_interval, val_ds=val_ds, - sam_rho=config.optimizer.sam_rho, patience=config.patience, disable_pbar=config.progress_bar.disable_epoch_pbar, disable_batch_pbar=config.progress_bar.disable_batch_pbar, diff --git a/apax/train/trainer.py b/apax/train/trainer.py index c9cdce4b..a57f4b4c 100644 --- a/apax/train/trainer.py +++ b/apax/train/trainer.py @@ -6,6 +6,7 @@ import jax import jax.numpy as jnp +from flax.training.train_state import TrainState import numpy as np from clu import metrics from jax.experimental import mesh_utils @@ -35,7 +36,7 @@ def update(self, opt_params): def fit( - state, + state: TrainState, train_ds: InMemoryDataset, loss_fn, Metrics: metrics.Collection, @@ -44,7 +45,6 @@ def fit( ckpt_dir, ckpt_interval: int = 1, val_ds: Optional[InMemoryDataset] = None, - sam_rho=0.0, patience: Optional[int] = None, disable_pbar: bool = False, disable_batch_pbar: bool = True, @@ -74,8 +74,6 @@ def fit( Interval for saving checkpoints. 
val_ds : InMemoryDataset, default = None Validation dataset. - sam_rho : float, default = 0.0 - Rho parameter for Sharpness-Aware Minimization. patience : int, default = None Patience for early stopping. disable_pbar : bool, default = False @@ -96,7 +94,7 @@ def fit( ckpt_manager = CheckpointManager() train_step, val_step = make_step_fns( - loss_fn, Metrics, model=state.apply_fn, sam_rho=sam_rho, is_ensemble=is_ensemble + loss_fn, Metrics, model=state.apply_fn, is_ensemble=is_ensemble ) if train_ds.n_jit_steps > 1: train_step = jax.jit(functools.partial(jax.lax.scan, train_step)) @@ -107,7 +105,7 @@ def fit( f"n_epochs <= current epoch from checkpoint ({n_epochs} <= {start_epoch})" ) - ema = False + ema = True if ema: alpha = 0.9 ema_handler = EMAParameters(state.params, alpha) @@ -302,21 +300,12 @@ def ensemble_eval_fn(state, inputs, labels): return ensemble_eval_fn -def make_step_fns(loss_fn, Metrics, model, sam_rho, is_ensemble): +def make_step_fns(loss_fn, Metrics, model, is_ensemble): loss_calculator = partial(calc_loss, loss_fn=loss_fn, model=model) grad_fn = jax.value_and_grad(loss_calculator, 0, has_aux=True) - rho = sam_rho def update_step(state, inputs, labels): (loss, predictions), grads = grad_fn(state.params, inputs, labels) - - if rho > 1e-6: - # SAM step - grad_norm = global_norm(grads) - eps = jax.tree_map(lambda g, n: g * rho / n, grads, grad_norm) - params_eps = jax.tree_map(lambda p, e: p + e, state.params, eps) - (loss, _), grads = grad_fn(params_eps, inputs, labels) # maybe get rid of SAM - state = state.apply_gradients(grads=grads) return loss, predictions, state From 99cf4a04c67e7b376c8602406453f8fa3261b01d Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 22 Apr 2024 08:33:41 +0000 Subject: [PATCH 04/15] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- apax/train/trainer.py | 2 +- 1 file changed, 1 insertion(+), 1 
deletion(-) diff --git a/apax/train/trainer.py b/apax/train/trainer.py index a57f4b4c..1a511de5 100644 --- a/apax/train/trainer.py +++ b/apax/train/trainer.py @@ -104,7 +104,7 @@ def fit( raise ValueError( f"n_epochs <= current epoch from checkpoint ({n_epochs} <= {start_epoch})" ) - + ema = True if ema: alpha = 0.9 From 2448e5b4bbc7544f49dd2683b8c94990a41b439d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Mon, 22 Apr 2024 20:53:16 +0200 Subject: [PATCH 05/15] exposed EMA to config --- apax/config/train_config.py | 9 ++++++++ apax/train/parameters.py | 19 ++++++++++++++++ apax/train/run.py | 14 ++++++++---- apax/train/trainer.py | 44 ++++++------------------------------- 4 files changed, 45 insertions(+), 41 deletions(-) create mode 100644 apax/train/parameters.py diff --git a/apax/config/train_config.py b/apax/config/train_config.py index 077983f9..eb35b3a6 100644 --- a/apax/config/train_config.py +++ b/apax/config/train_config.py @@ -354,6 +354,14 @@ class CheckpointConfig(BaseModel, extra="forbid"): reset_layers: List[str] = [] +class WeightAverage(BaseModel, extra="forbid"): + """ + """ + ema_start: int + alpha: float = 0.9 + + + class Config(BaseModel, frozen=True, extra="forbid"): """ Main configuration of a apax training run. 
Parameter that are config classes will @@ -417,6 +425,7 @@ class Config(BaseModel, frozen=True, extra="forbid"): metrics: List[MetricsConfig] = [] loss: List[LossConfig] optimizer: OptimizerConfig = OptimizerConfig() + weight_average: Optional[WeightAverage] = None callbacks: List[CallBack] = [CSVCallback(name="csv")] progress_bar: TrainProgressbarConfig = TrainProgressbarConfig() checkpoints: CheckpointConfig = CheckpointConfig() diff --git a/apax/train/parameters.py b/apax/train/parameters.py new file mode 100644 index 00000000..8011bbcc --- /dev/null +++ b/apax/train/parameters.py @@ -0,0 +1,19 @@ +import jax + +@jax.jit +def tree_ema(tree1, tree2, alpha): + ema = jax.tree_map(lambda a,b: alpha * a + (1-alpha) * b, tree1, tree2) + return ema + + +class EMAParameters: + def __init__(self, ema_start, alpha) -> None: + self.alpha = alpha + self.ema_start = ema_start + self.ema_params = None + + def update(self, opt_params, epoch): + if epoch > self.ema_start: + self.ema_params = tree_ema(opt_params, self.ema_params, self.alpha) + else: + self.ema_params = opt_params diff --git a/apax/train/run.py b/apax/train/run.py index 4ddd9fa9..baae4e6d 100644 --- a/apax/train/run.py +++ b/apax/train/run.py @@ -16,6 +16,7 @@ from apax.train.loss import Loss, LossCollection from apax.train.metrics import initialize_metrics from apax.train.trainer import fit +from apax.train.parameters import EMAParameters from apax.transfer_learning import transfer_parameters from apax.utils.random import seed_py_np_tf @@ -173,11 +174,10 @@ def run(user_config: Union[str, os.PathLike, dict], log_level="error"): # TODO rework optimizer initialization and lr keywords steps_per_epoch = train_ds.steps_per_epoch() - n_epochs = config.n_epochs - transition_steps = steps_per_epoch * n_epochs - config.optimizer.transition_begin tx = get_opt( params, - transition_steps=transition_steps, + config.n_epochs, + steps_per_epoch, **config.optimizer.model_dump(), ) @@ -188,13 +188,18 @@ def run(user_config: 
Union[str, os.PathLike, dict], log_level="error"): if do_transfer_learning: state = transfer_parameters(state, config.checkpoints) + if config.weight_average: + ema_handler = EMAParameters(config.weight_average.ema_start, config.weight_average.alpha) + else: + ema_handler = None + fit( state, train_ds, loss_fn, Metrics, callbacks, - n_epochs, + config.n_epochs, ckpt_dir=config.data.model_version_path, ckpt_interval=config.checkpoints.ckpt_interval, val_ds=val_ds, @@ -203,5 +208,6 @@ def run(user_config: Union[str, os.PathLike, dict], log_level="error"): disable_batch_pbar=config.progress_bar.disable_batch_pbar, is_ensemble=config.n_models > 1, data_parallel=config.data_parallel, + ema_handler=ema_handler, ) log.info("Finished training") diff --git a/apax/train/trainer.py b/apax/train/trainer.py index a57f4b4c..e1a41bb7 100644 --- a/apax/train/trainer.py +++ b/apax/train/trainer.py @@ -15,25 +15,10 @@ from apax.data.input_pipeline import InMemoryDataset from apax.train.checkpoints import CheckpointManager, load_state +from apax.train.parameters import EMAParameters log = logging.getLogger(__name__) -@jax.jit -def tree_ema(tree1, tree2, alpha): - ema = jax.tree_map(lambda a,b: alpha * a + (1-alpha) * b, tree1, tree2) - return ema - - - -class EMAParameters: - def __init__(self, params, alpha) -> None: - self.alpha = alpha - self.ema_params = params - - def update(self, opt_params): - self.ema_params = tree_ema(opt_params, self.ema_params, self.alpha) - - def fit( state: TrainState, @@ -50,6 +35,7 @@ def fit( disable_batch_pbar: bool = True, is_ensemble=False, data_parallel=True, + ema_handler: Optional[EMAParameters]= None, ): """ Trains the model using the provided training dataset. 
@@ -104,11 +90,6 @@ def fit( raise ValueError( f"n_epochs <= current epoch from checkpoint ({n_epochs} <= {start_epoch})" ) - - ema = True - if ema: - alpha = 0.9 - ema_handler = EMAParameters(state.params, alpha) devices = len(jax.devices()) if devices > 1 and data_parallel: @@ -134,6 +115,9 @@ def fit( epoch_start_time = time.time() callbacks.on_epoch_begin(epoch=epoch + 1) + if ema_handler: + ema_handler.update(state.params, epoch) + epoch_loss.update({"train_loss": 0.0}) train_batch_metrics = Metrics.empty() @@ -171,13 +155,12 @@ def fit( for key, val in train_batch_metrics.compute().items() } - if ema: - ema_handler.update(state.params) + if ema_handler: + ema_handler.update(state.params, epoch) val_params = ema_handler.ema_params else: val_params = state.params - if val_ds is not None: epoch_loss.update({"val_loss": 0.0}) val_batch_metrics = Metrics.empty() @@ -194,7 +177,6 @@ def fit( for batch_idx in range(val_steps_per_epoch): batch = next(batch_val_ds) - batch_loss, val_batch_metrics = val_step( val_params, batch, val_batch_metrics ) @@ -246,18 +228,6 @@ def fit( val_ds.cleanup() -def global_norm(updates) -> jnp.ndarray: - """ - Returns the l2 norm of the input. - - Parameters - ---------- - updates: A pytree of ndarrays representing the gradient. 
- """ - norm = jax.tree_map(lambda u: jnp.sqrt(jnp.sum(jnp.square(u))), updates) - return norm - - def calc_loss(params, inputs, labels, loss_fn, model): R, Z, idx, box, offsets = ( inputs["positions"], From cfc192c42a1de580f48e24f9847a0b57978c492f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Mon, 22 Apr 2024 20:53:47 +0200 Subject: [PATCH 06/15] added cyclic cosine schedule --- apax/optimizer/get_optimizer.py | 69 +++++++++++++++++++++++++-------- 1 file changed, 52 insertions(+), 17 deletions(-) diff --git a/apax/optimizer/get_optimizer.py b/apax/optimizer/get_optimizer.py index 5e351654..e5f56372 100644 --- a/apax/optimizer/get_optimizer.py +++ b/apax/optimizer/get_optimizer.py @@ -1,7 +1,9 @@ import logging -from typing import Any, Callable +import jax.numpy as jnp +import numpy as np import optax +from optax._src import base from optax import contrib from flax import traverse_util from flax.core.frozen_dict import freeze @@ -16,34 +18,67 @@ def sam(lr=1e-3, b1=0.9, b2=0.999, rho=0.001, sync_period=2): return contrib.sam(opt, adv_opt, sync_period=sync_period) +def cyclic_cosine_decay_schedule( + init_value: float, + epochs, + steps_per_epoch, + period: int, + amplitude_factor: float = 0.9, + ) -> base.Schedule: + r"""Returns a function which implements cyclic cosine learning rate decay. + + Args: + init_value: An initial value for the learning rate. + + Returns: + schedule + A function that maps step counts to values. + """ + def schedule(count): + + cycle = count // (period * steps_per_epoch) + step_in_period = jnp.mod(count, period * steps_per_epoch) + lr = init_value/2 * (jnp.cos(np.pi * step_in_period / (period* steps_per_epoch)) + 1) + lr = lr * (amplitude_factor**cycle) + return lr + + return schedule + + + + def get_schedule( - lr: float, transition_begin: int, transition_steps: int + lr: float, n_epochs: int, steps_per_epoch: int, ) -> optax._src.base.Schedule: """ builds a linear learning rate schedule. 
""" - lr_schedule = optax.linear_schedule( - init_value=lr, - end_value=1e-6, - transition_begin=transition_begin, - transition_steps=transition_steps, + # lr_schedule = optax.linear_schedule( + # init_value=lr, + # end_value=1e-6, + # transition_begin=0, + # transition_steps=n_epochs *steps_per_epoch, + # ) + + lr_schedule = cyclic_cosine_decay_schedule( + lr, n_epochs, steps_per_epoch, 20, 0.95, ) return lr_schedule -def make_optimizer(opt, lr, transition_begin, transition_steps, opt_kwargs): +def make_optimizer(opt, lr, n_epochs, steps_per_epoch, opt_kwargs): if lr <= 1e-7: optimizer = optax.set_to_zero() else: - schedule = get_schedule(lr, transition_begin, transition_steps) + schedule = get_schedule(lr, n_epochs, steps_per_epoch) optimizer = opt(schedule, **opt_kwargs) return optimizer def get_opt( params, - transition_begin: int, - transition_steps: int, + n_epochs: int, + steps_per_epoch: int, emb_lr: float = 0.02, nn_lr: float = 0.03, scale_lr: float = 0.001, @@ -57,22 +92,22 @@ def get_opt( Builds an optimizer with different learning rates for each parameter group. Several `optax` optimizers are supported. 
""" + log.info("Initializing Optimizer") if opt_name == "sam": opt = sam else: - print("optname") opt = getattr(optax, opt_name) - nn_opt = make_optimizer(opt, nn_lr, transition_begin, transition_steps, opt_kwargs) - emb_opt = make_optimizer(opt, emb_lr, transition_begin, transition_steps, opt_kwargs) + nn_opt = make_optimizer(opt, nn_lr, n_epochs, steps_per_epoch, opt_kwargs) + emb_opt = make_optimizer(opt, emb_lr, n_epochs, steps_per_epoch, opt_kwargs) scale_opt = make_optimizer( - opt, scale_lr, transition_begin, transition_steps, opt_kwargs + opt, scale_lr, n_epochs, steps_per_epoch, opt_kwargs ) shift_opt = make_optimizer( - opt, shift_lr, transition_begin, transition_steps, opt_kwargs + opt, shift_lr, n_epochs, steps_per_epoch, opt_kwargs ) - zbl_opt = make_optimizer(opt, zbl_lr, transition_begin, transition_steps, opt_kwargs) + zbl_opt = make_optimizer(opt, zbl_lr, n_epochs, steps_per_epoch, opt_kwargs) partition_optimizers = { "w": nn_opt, From 8bd10cdd49f1e5efbf4da0d730ed8bd105535f36 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Mon, 22 Apr 2024 21:28:20 +0200 Subject: [PATCH 07/15] exposed lr schedules to config --- apax/config/lr_config.py | 45 ++++++++++++++++++++++ apax/config/train_config.py | 15 ++++---- apax/optimizer/get_optimizer.py | 68 +++++++++++++++++---------------- 3 files changed, 89 insertions(+), 39 deletions(-) create mode 100644 apax/config/lr_config.py diff --git a/apax/config/lr_config.py b/apax/config/lr_config.py new file mode 100644 index 00000000..b64f4f8c --- /dev/null +++ b/apax/config/lr_config.py @@ -0,0 +1,45 @@ +from typing import Literal + +from pydantic import BaseModel, NonNegativeFloat + + +class LRSchedule(BaseModel, frozen=True, extra="forbid"): + name: str + + +class LinearLR(LRSchedule, frozen=True, extra="forbid"): + """ + Configuration of the optimizer. + Learning rates of 0 will freeze the respective parameters. 
+ + Parameters + ---------- + name : Literal["linear"] + transition_begin: int = 0 + Number of steps after which to start decreasing + end_value: NonNegativeFloat = 1e-6 + Final LR at the end of training. + """ + + name: Literal["linear"] + transition_begin: int = 0 + end_value: NonNegativeFloat = 1e-6 + + +class CyclicCosineLR(LRSchedule, frozen=True, extra="forbid"): + """ + Configuration of the cyclic cosine + learning rate schedule. + + Parameters + ---------- + period: int = 20 + Length of a cycle. + decay_factor: NonNegativeFloat = 1.0 + Factor by which to decrease the LR after each cycle. + 1.0 means no decrease. + """ + + name: Literal["cyclic_cosine"] + period: int = 20 + decay_factor: NonNegativeFloat = 1.0 diff --git a/apax/config/train_config.py b/apax/config/train_config.py index eb35b3a6..073fbd49 100644 --- a/apax/config/train_config.py +++ b/apax/config/train_config.py @@ -16,6 +16,7 @@ ) from typing_extensions import Annotated +from apax.config.lr_config import CyclicCosineLR, LinearLR from apax.data.statistics import scale_method_list, shift_method_list log = logging.getLogger(__name__) @@ -216,9 +217,8 @@ class OptimizerConfig(BaseModel, frozen=True, extra="forbid"): Learning rate of the elemental output shifts. zbl_lr : NonNegativeFloat, default = 0.001 Learning rate of the ZBL correction parameters. - transition_begin : int, default = 0 - Number of training steps (not epochs) before the start of the linear - learning rate schedule. + schedule : LRSchedule = LinearLR + Learning rate schedule. opt_kwargs : dict, default = {} Optimizer keyword arguments. Passed to the `optax` optimizer.
""" @@ -229,7 +229,9 @@ class OptimizerConfig(BaseModel, frozen=True, extra="forbid"): scale_lr: NonNegativeFloat = 0.001 shift_lr: NonNegativeFloat = 0.05 zbl_lr: NonNegativeFloat = 0.001 - transition_begin: int = 0 + schedule: Union[LinearLR, CyclicCosineLR] = Field( + LinearLR(name="linear"), discriminator="name" + ) opt_kwargs: dict = {} @@ -355,11 +357,10 @@ class CheckpointConfig(BaseModel, extra="forbid"): class WeightAverage(BaseModel, extra="forbid"): - """ - """ + """ """ + ema_start: int alpha: float = 0.9 - class Config(BaseModel, frozen=True, extra="forbid"): diff --git a/apax/optimizer/get_optimizer.py b/apax/optimizer/get_optimizer.py index e5f56372..d9ea4c22 100644 --- a/apax/optimizer/get_optimizer.py +++ b/apax/optimizer/get_optimizer.py @@ -3,10 +3,10 @@ import jax.numpy as jnp import numpy as np import optax -from optax._src import base -from optax import contrib from flax import traverse_util from flax.core.frozen_dict import freeze +from optax import contrib +from optax._src import base log = logging.getLogger(__name__) @@ -19,12 +19,11 @@ def sam(lr=1e-3, b1=0.9, b2=0.999, rho=0.001, sync_period=2): def cyclic_cosine_decay_schedule( - init_value: float, - epochs, - steps_per_epoch, - period: int, - amplitude_factor: float = 0.9, - ) -> base.Schedule: + init_value: float, + steps_per_epoch, + period: int, + decay_factor: float = 0.9, +) -> base.Schedule: r"""Returns a function which implements cyclic cosine learning rate decay. Args: @@ -34,43 +33,48 @@ def cyclic_cosine_decay_schedule( schedule A function that maps step counts to values. 
""" + def schedule(count): - cycle = count // (period * steps_per_epoch) step_in_period = jnp.mod(count, period * steps_per_epoch) - lr = init_value/2 * (jnp.cos(np.pi * step_in_period / (period* steps_per_epoch)) + 1) - lr = lr * (amplitude_factor**cycle) + lr = ( + init_value + / 2 + * (jnp.cos(np.pi * step_in_period / (period * steps_per_epoch)) + 1) + ) + lr = lr * (decay_factor**cycle) return lr return schedule - - def get_schedule( - lr: float, n_epochs: int, steps_per_epoch: int, + lr: float, + n_epochs: int, + steps_per_epoch: int, + schedule_kwargs: dict, ) -> optax._src.base.Schedule: """ builds a linear learning rate schedule. """ - # lr_schedule = optax.linear_schedule( - # init_value=lr, - # end_value=1e-6, - # transition_begin=0, - # transition_steps=n_epochs *steps_per_epoch, - # ) - - lr_schedule = cyclic_cosine_decay_schedule( - lr, n_epochs, steps_per_epoch, 20, 0.95, - ) + schedule_kwargs = schedule_kwargs.copy() + name = schedule_kwargs.pop("name") + if name == "linear": + lr_schedule = optax.linear_schedule( + init_value=lr, transition_steps=n_epochs * steps_per_epoch, **schedule_kwargs + ) + elif name == "cyclic_cosine": + lr_schedule = cyclic_cosine_decay_schedule(lr, steps_per_epoch, **schedule_kwargs) + else: + raise KeyError(f"unknown learning rate schedule: {name}") return lr_schedule -def make_optimizer(opt, lr, n_epochs, steps_per_epoch, opt_kwargs): +def make_optimizer(opt, lr, n_epochs, steps_per_epoch, opt_kwargs, schedule): if lr <= 1e-7: optimizer = optax.set_to_zero() else: - schedule = get_schedule(lr, n_epochs, steps_per_epoch) + schedule = get_schedule(lr, n_epochs, steps_per_epoch, schedule) optimizer = opt(schedule, **opt_kwargs) return optimizer @@ -86,7 +90,7 @@ def get_opt( zbl_lr: float = 0.001, opt_name: str = "adam", opt_kwargs: dict = {}, - **kwargs, + schedule: dict = {}, ) -> optax._src.base.GradientTransformation: """ Builds an optimizer with different learning rates for each parameter group. 
@@ -99,15 +103,15 @@ def get_opt( else: opt = getattr(optax, opt_name) - nn_opt = make_optimizer(opt, nn_lr, n_epochs, steps_per_epoch, opt_kwargs) - emb_opt = make_optimizer(opt, emb_lr, n_epochs, steps_per_epoch, opt_kwargs) + nn_opt = make_optimizer(opt, nn_lr, n_epochs, steps_per_epoch, opt_kwargs, schedule) + emb_opt = make_optimizer(opt, emb_lr, n_epochs, steps_per_epoch, opt_kwargs, schedule) scale_opt = make_optimizer( - opt, scale_lr, n_epochs, steps_per_epoch, opt_kwargs + opt, scale_lr, n_epochs, steps_per_epoch, opt_kwargs, schedule ) shift_opt = make_optimizer( - opt, shift_lr, n_epochs, steps_per_epoch, opt_kwargs + opt, shift_lr, n_epochs, steps_per_epoch, opt_kwargs, schedule ) - zbl_opt = make_optimizer(opt, zbl_lr, n_epochs, steps_per_epoch, opt_kwargs) + zbl_opt = make_optimizer(opt, zbl_lr, n_epochs, steps_per_epoch, opt_kwargs, schedule) partition_optimizers = { "w": nn_opt, From 8f4e2fcfae75a284c2c6eeba8ef1a3b04017a2ed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Mon, 22 Apr 2024 21:28:32 +0200 Subject: [PATCH 08/15] linting --- apax/train/parameters.py | 3 ++- apax/train/run.py | 6 ++++-- apax/train/trainer.py | 4 ++-- 3 files changed, 8 insertions(+), 5 deletions(-) diff --git a/apax/train/parameters.py b/apax/train/parameters.py index 8011bbcc..1e7ea89e 100644 --- a/apax/train/parameters.py +++ b/apax/train/parameters.py @@ -1,8 +1,9 @@ import jax + @jax.jit def tree_ema(tree1, tree2, alpha): - ema = jax.tree_map(lambda a,b: alpha * a + (1-alpha) * b, tree1, tree2) + ema = jax.tree_map(lambda a, b: alpha * a + (1 - alpha) * b, tree1, tree2) return ema diff --git a/apax/train/run.py b/apax/train/run.py index baae4e6d..fec30fa7 100644 --- a/apax/train/run.py +++ b/apax/train/run.py @@ -15,8 +15,8 @@ from apax.train.checkpoints import create_params, create_train_state from apax.train.loss import Loss, LossCollection from apax.train.metrics import initialize_metrics -from apax.train.trainer import fit from 
apax.train.parameters import EMAParameters +from apax.train.trainer import fit from apax.transfer_learning import transfer_parameters from apax.utils.random import seed_py_np_tf @@ -189,7 +189,9 @@ def run(user_config: Union[str, os.PathLike, dict], log_level="error"): state = transfer_parameters(state, config.checkpoints) if config.weight_average: - ema_handler = EMAParameters(config.weight_average.ema_start, config.weight_average.alpha) + ema_handler = EMAParameters( + config.weight_average.ema_start, config.weight_average.alpha + ) else: ema_handler = None diff --git a/apax/train/trainer.py b/apax/train/trainer.py index e1a41bb7..2d3e3e78 100644 --- a/apax/train/trainer.py +++ b/apax/train/trainer.py @@ -6,9 +6,9 @@ import jax import jax.numpy as jnp -from flax.training.train_state import TrainState import numpy as np from clu import metrics +from flax.training.train_state import TrainState from jax.experimental import mesh_utils from jax.sharding import PositionalSharding from tqdm import trange @@ -35,7 +35,7 @@ def fit( disable_batch_pbar: bool = True, is_ensemble=False, data_parallel=True, - ema_handler: Optional[EMAParameters]= None, + ema_handler: Optional[EMAParameters] = None, ): """ Trains the model using the provided training dataset. 
From f726461ea5e6ca677fe63bc9a1461f5045b11cfd Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Mon, 22 Apr 2024 21:37:55 +0200 Subject: [PATCH 09/15] updated tests for lr schedule options --- apax/cli/templates/train_config_full.yaml | 1 - tests/unit_tests/optimizer/test_get_opt.py | 1 + 2 files changed, 1 insertion(+), 1 deletion(-) diff --git a/apax/cli/templates/train_config_full.yaml b/apax/cli/templates/train_config_full.yaml index 1394bb1c..79b1119e 100644 --- a/apax/cli/templates/train_config_full.yaml +++ b/apax/cli/templates/train_config_full.yaml @@ -80,7 +80,6 @@ optimizer: scale_lr: 0.001 shift_lr: 0.05 zbl_lr: 0.001 - transition_begin: 0 callbacks: - name: csv diff --git a/tests/unit_tests/optimizer/test_get_opt.py b/tests/unit_tests/optimizer/test_get_opt.py index e02428d4..065c30fe 100644 --- a/tests/unit_tests/optimizer/test_get_opt.py +++ b/tests/unit_tests/optimizer/test_get_opt.py @@ -29,6 +29,7 @@ def test_get_opt(): nn_lr=0.01, scale_lr=0.001, shift_lr=0.1, + schedule={"name": "linear", "transition_begin": 0, "end_value": 1e-6}, ) opt_state = opt.init(params=params) From cbb18ed60d508f5370a821585eb2dfa9efd4c2d2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Wed, 24 Apr 2024 14:31:14 +0200 Subject: [PATCH 10/15] initialize zbl strength to 0 --- apax/layers/empirical.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/apax/layers/empirical.py b/apax/layers/empirical.py index 5adea3af..7cfa1b43 100644 --- a/apax/layers/empirical.py +++ b/apax/layers/empirical.py @@ -37,7 +37,7 @@ def setup(self): a_num_isp = inverse_softplus(a_num) coeffs_isp = inverse_softplus(coeffs) exps_isp = inverse_softplus(exps) - rep_scale_isp = inverse_softplus(1.0 / self.ke) + rep_scale_isp = inverse_softplus(0.0) self.a_exp = self.param("a_exp", nn.initializers.constant(a_exp_isp), (1,)) self.a_num = self.param("a_num", nn.initializers.constant(a_num_isp), (1,)) From 
ded78c402465717dabb764ebace9d3cac8d088e4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Wed, 24 Apr 2024 14:44:41 +0200 Subject: [PATCH 11/15] removed unused KE parameter --- apax/layers/empirical.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/apax/layers/empirical.py b/apax/layers/empirical.py index 7cfa1b43..9e02b611 100644 --- a/apax/layers/empirical.py +++ b/apax/layers/empirical.py @@ -26,8 +26,6 @@ class ZBLRepulsion(EmpiricalEnergyTerm): def setup(self): self.distance = vmap(space.distance, 0, 0) - self.ke = 14.3996 - a_exp = 0.23 a_num = 0.46850 coeffs = jnp.array([0.18175, 0.50986, 0.28022, 0.02817])[:, None] From ffc44d361b2b829cd9a21617a1b52f76c888d498 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Wed, 24 Apr 2024 14:49:49 +0200 Subject: [PATCH 12/15] added schedule and weight averaging to full config --- apax/cli/templates/train_config_full.yaml | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/apax/cli/templates/train_config_full.yaml b/apax/cli/templates/train_config_full.yaml index 79b1119e..d8f287f7 100644 --- a/apax/cli/templates/train_config_full.yaml +++ b/apax/cli/templates/train_config_full.yaml @@ -4,6 +4,7 @@ patience: null n_models: 1 n_jitted_steps: 1 data_parallel: True +weight_average: null data: directory: models/ @@ -80,7 +81,10 @@ optimizer: scale_lr: 0.001 shift_lr: 0.05 zbl_lr: 0.001 - + schedule: + name: linear + transition_begin: 0 + end_value: 1e-6 callbacks: - name: csv From e339e571a671900efe7cc97651b90029a0ce1fda Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Wed, 24 Apr 2024 15:05:26 +0200 Subject: [PATCH 13/15] removed use of deprecated ke parameter --- apax/layers/empirical.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/apax/layers/empirical.py b/apax/layers/empirical.py index 9e02b611..6b4c2cbc 100644 --- a/apax/layers/empirical.py +++ b/apax/layers/empirical.py @@ -84,5 +84,5 @@ 
def __call__(self, dr_vec, Z, idx): E_ij = Z_i * Z_j / dr * f * cos_cutoff if self.apply_mask: E_ij = mask_by_neighbor(E_ij, idx) - E = 0.5 * rep_scale * self.ke * fp64_sum(E_ij) + E = 0.5 * rep_scale * fp64_sum(E_ij) return fp64_sum(E) From 3b852249f490902f6f772edd8cb5c2b2e592ea54 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Wed, 24 Apr 2024 15:13:49 +0200 Subject: [PATCH 14/15] added docstrings --- apax/config/train_config.py | 14 ++++++++++++-- apax/train/parameters.py | 14 +++++++++++++- 2 files changed, 25 insertions(+), 3 deletions(-) diff --git a/apax/config/train_config.py b/apax/config/train_config.py index 073fbd49..b9357ec7 100644 --- a/apax/config/train_config.py +++ b/apax/config/train_config.py @@ -357,9 +357,17 @@ class CheckpointConfig(BaseModel, extra="forbid"): class WeightAverage(BaseModel, extra="forbid"): - """ """ + """Applies an exponential moving average to model parameters. - ema_start: int + Parameters + ---------- + ema_start : int, default = 0 + Epoch at which to start averaging models. + alpha : float, default = 0.9 + How much of the new model to use. 1.0 would mean no averaging, 0.0 no updates. + """ + + ema_start: int =0 + alpha: float = 0.9 @@ -402,6 +410,8 @@ class Config(BaseModel, frozen=True, extra="forbid"): | Loss configuration. optimizer : :class:`.OptimizerConfig` | Loss optimizer configuration. + weight_average : :class:`.WeightAverage`, optional + | Options for averaging weights between epochs. callbacks : List of various CallBack classes | Possible callbacks are :class:`.CSVCallback`, | :class:`.TBCallback`, :class:`.MLFlowCallback` diff --git a/apax/train/parameters.py b/apax/train/parameters.py index 1e7ea89e..5adc1830 100644 --- a/apax/train/parameters.py +++ b/apax/train/parameters.py @@ -3,12 +3,24 @@ @jax.jit def tree_ema(tree1, tree2, alpha): + """Exponential moving average of two pytrees.
+ """ ema = jax.tree_map(lambda a, b: alpha * a + (1 - alpha) * b, tree1, tree2) return ema class EMAParameters: - def __init__(self, ema_start, alpha) -> None: + """Handler for tracking an exponential moving average of model parameters. + The EMA parameters are used in the validation loop. + + Parameters + ---------- + ema_start : int + Epoch at which to start averaging models. + alpha : float, default = 0.9 + How much of the new model to use. 1.0 would mean no averaging, 0.0 no updates. + """ + def __init__(self, ema_start: int , alpha: float = 0.9) -> None: self.alpha = alpha self.ema_start = ema_start self.ema_params = None From 0921b3a112f7c929e0561f1295e9ae6dcbfea9de Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Moritz=20R=2E=20Sch=C3=A4fer?= Date: Wed, 24 Apr 2024 15:14:40 +0200 Subject: [PATCH 15/15] linting --- apax/config/train_config.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/apax/config/train_config.py b/apax/config/train_config.py index b9357ec7..d7eb5da3 100644 --- a/apax/config/train_config.py +++ b/apax/config/train_config.py @@ -367,7 +367,7 @@ class WeightAverage(BaseModel, extra="forbid"): How much of the new model to use. 1.0 would mean no averaging, 0.0 no updates. """ - ema_start: int =0 + ema_start: int = 0