From 591d128fe2dc7755dcb161011273cfdc4464b256 Mon Sep 17 00:00:00 2001
From: Sai Krishna Kishore Nori
 <68329270+kishore-nori@users.noreply.github.com>
Date: Thu, 23 Jan 2025 23:59:39 +1100
Subject: [PATCH 1/2] Fixing extended trace failure for Adam and AdaMax and
 generalising `alpha` parameter to accept a callable object (scheduler) (#1115)

* adding alpha to state and generalizing alpha to accept a function (scheduler)

* removing unused variables

* adding alpha to AdaMax state and generalizing alpha to accept a function (scheduler)

* updating the docstring for Adam to add description of scheduled alpha constructors

* updating the docstring for AdaMax to add description of scheduled alpha constructors

* adding tests for scheduled Adam and AdaMax, which also cover the extended_trace=true case

* adding default constant alpha case tests for extended_trace=true for Adam and AdaMax
---
 src/multivariate/solvers/first_order/adam.jl  | 61 +++++++++++++++---
 .../solvers/first_order/adamax.jl             | 62 ++++++++++++++++---
 .../solvers/first_order/adam_adamax.jl        | 59 ++++++++++++++++++
 3 files changed, 164 insertions(+), 18 deletions(-)

diff --git a/src/multivariate/solvers/first_order/adam.jl b/src/multivariate/solvers/first_order/adam.jl
index 86519a3c..3057a731 100644
--- a/src/multivariate/solvers/first_order/adam.jl
+++ b/src/multivariate/solvers/first_order/adam.jl
@@ -1,17 +1,37 @@
 """
 # Adam
-## Constructor
+## Constructor with constant `alpha` (default)
+
 ```julia
     Adam(; alpha=0.0001, beta_mean=0.9, beta_var=0.999, epsilon=1e-8)
 ```
+
+## Constructor with scheduled `alpha`
+
+As an alternative to the default usage above, where `alpha` is a fixed constant
+for all iterations, `alpha` may instead be a callable object (a scheduler) that
+maps the current iteration count to the step size used in that iteration's
+update. This allows `alpha` to be scheduled over the iterations as desired:
+
+```julia
+    # alpha_scheduler is any callable mapping the iteration count to an alpha value
+    Adam(; alpha=alpha_scheduler, other_kwargs...)
+```
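+
+For example, a simple geometrically decaying schedule (the name
+`alpha_scheduler` is arbitrary; any callable with this signature works,
+and this one mirrors the scheduler used in the package tests) could be
+
+```julia
+    # starts near 2e-4 and decays towards the default 1e-4 as iterations grow
+    alpha_scheduler(iter) = 0.0001 * (1 + 0.99^iter)
+    Adam(alpha = alpha_scheduler)
+```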
+
 ## Description
-Adam is a gradient based optimizer that choses its search direction by building up estimates of the first two moments of the gradient vector. This makes it suitable for problems with a stochastic objective and thus gradient. The method is introduced in [1] where the related AdaMax method is also introduced, see `?AdaMax` for more information on that method.
+Adam is a gradient-based optimizer that chooses its search direction by
+building up estimates of the first two moments of the gradient vector. This
+makes it suitable for problems with a stochastic objective and thus a
+stochastic gradient. The method is introduced in [1], where the related AdaMax
+method is also introduced; see `?AdaMax` for more information on that method.
 
 ## References
 [1] https://arxiv.org/abs/1412.6980
 """
-struct Adam{T, Tm} <: FirstOrderOptimizer
-    α::T
+struct Adam{Tα, T, Tm} <: FirstOrderOptimizer
+    α::Tα
     β₁::T
     β₂::T
     ϵ::T
@@ -32,20 +52,29 @@ mutable struct AdamState{Tx, T, Tm, Tu, Ti} <: AbstractOptimizerState
     s::Tx
     m::Tm
     u::Tu
+    alpha::T
     iter::Ti
 end
 function reset!(method, state::AdamState, obj, x)
     value_gradient!!(obj, x)
 end
+
+# Constant `alpha`: use the stored value directly.
+function _get_init_params(method::Adam{T}) where T <: Real
+    method.α, method.β₁, method.β₂
+end
+
+# Scheduled `alpha`: evaluate the scheduler at the first iteration.
+function _get_init_params(method::Adam)
+    method.α(1), method.β₁, method.β₂
+end
+
 function initial_state(method::Adam, options, d, initial_x::AbstractArray{T}) where T
     initial_x = copy(initial_x)
 
     value_gradient!!(d, initial_x)
-    α, β₁, β₂ = method.α, method.β₁, method.β₂
+    α, β₁, β₂ = _get_init_params(method)
 
     m = copy(gradient(d))
     u = zero(m)
-    a = 1 - β₁
     iter = 0
 
     AdamState(initial_x, # Maintain current state in state.x
@@ -54,13 +83,29 @@ function initial_state(method::Adam, options, d, initial_x::AbstractArray{T}) wh
                          similar(initial_x), # Maintain current search direction in state.s
                          m,
                          u,
+                         α,
                          iter)
 end
 
+# Constant `alpha`: only advance the iteration counter (state.alpha already
+# holds the constant step size set at initialization).
+function _update_iter_alpha_in_state!(state::AdamState, method::Adam{T}) where T <: Real
+    state.iter = state.iter + 1
+end
+
+# Scheduled `alpha`: advance the counter and evaluate the scheduler for this iteration.
+function _update_iter_alpha_in_state!(state::AdamState, method::Adam)
+    state.iter = state.iter + 1
+    state.alpha = method.α(state.iter)
+end
+
 function update_state!(d, state::AdamState{T}, method::Adam) where T
-    state.iter = state.iter+1
+    _update_iter_alpha_in_state!(state, method)
     value_gradient!(d, state.x)
-    α, β₁, β₂, ϵ = method.α, method.β₁, method.β₂, method.ϵ
+    α, β₁, β₂, ϵ = state.alpha, method.β₁, method.β₂, method.ϵ
     a = 1 - β₁
     b = 1 - β₂
 
diff --git a/src/multivariate/solvers/first_order/adamax.jl b/src/multivariate/solvers/first_order/adamax.jl
index e001d46f..b06963bc 100644
--- a/src/multivariate/solvers/first_order/adamax.jl
+++ b/src/multivariate/solvers/first_order/adamax.jl
@@ -1,18 +1,37 @@
 """
 # AdaMax
-## Constructor
+## Constructor with constant `alpha` (default)
+
 ```julia
     AdaMax(; alpha=0.002, beta_mean=0.9, beta_var=0.999, epsilon=1e-8)
 ```
-## Description
-AdaMax is a gradient based optimizer that choses its search direction by building up estimates of the first two moments of the gradient vector. This makes it suitable for problems with a stochastic objective and thus gradient. The method is introduced in [1] where the related Adam method is also introduced, see `?Adam` for more information on that method.
 
+## Constructor with scheduled `alpha`
+
+As an alternative to the default usage above, where `alpha` is a fixed constant
+for all iterations, `alpha` may instead be a callable object (a scheduler) that
+maps the current iteration count to the step size used in that iteration's
+update. This allows `alpha` to be scheduled over the iterations as desired:
 
+```julia
+    # alpha_scheduler is any callable mapping the iteration count to an alpha value
+    AdaMax(; alpha=alpha_scheduler, other_kwargs...)
+```
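+
+For example, a simple geometrically decaying schedule (the name
+`alpha_scheduler` is arbitrary; any callable with this signature works,
+and this one mirrors the scheduler used in the package tests) could be
+
+```julia
+    # starts near 4e-3 and decays towards the default 2e-3 as iterations grow
+    alpha_scheduler(iter) = 0.002 * (1 + 0.99^iter)
+    AdaMax(alpha = alpha_scheduler)
+```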
+
+## Description
+AdaMax is a gradient-based optimizer that chooses its search direction by
+building up estimates of the first two moments of the gradient vector. This
+makes it suitable for problems with a stochastic objective and thus a
+stochastic gradient. The method is introduced in [1], where the related Adam
+method is also introduced; see `?Adam` for more information on that method.
+
+## References
 [1] https://arxiv.org/abs/1412.6980
 """
-
-struct AdaMax{T,Tm} <: FirstOrderOptimizer
-    α::T
+struct AdaMax{Tα, T, Tm} <: FirstOrderOptimizer
+    α::Tα
     β₁::T
     β₂::T
     ϵ::T
@@ -33,20 +52,29 @@ mutable struct AdaMaxState{Tx, T, Tm, Tu, Ti} <: AbstractOptimizerState
     s::Tx
     m::Tm
     u::Tu
+    alpha::T
     iter::Ti
 end
 function reset!(method, state::AdaMaxState, obj, x)
     value_gradient!!(obj, x)
 end
+
+# Constant `alpha`: use the stored value directly.
+function _get_init_params(method::AdaMax{T}) where T <: Real
+    method.α, method.β₁, method.β₂
+end
+
+# Scheduled `alpha`: evaluate the scheduler at the first iteration.
+function _get_init_params(method::AdaMax)
+    method.α(1), method.β₁, method.β₂
+end
+
 function initial_state(method::AdaMax, options, d, initial_x::AbstractArray{T}) where T
     initial_x = copy(initial_x)
 
     value_gradient!!(d, initial_x)
-    α, β₁, β₂ = method.α, method.β₁, method.β₂
+    α, β₁, β₂ = _get_init_params(method)
 
     m = copy(gradient(d))
     u = zero(m)
-    a = 1 - β₁
     iter = 0
 
     AdaMaxState(initial_x, # Maintain current state in state.x
@@ -55,13 +83,27 @@ function initial_state(method::AdaMax, options, d, initial_x::AbstractArray{T})
                          similar(initial_x), # Maintain current search direction in state.s
                          m,
                          u,
+                         α,
                          iter)
 end
 
+# Constant `alpha`: only advance the iteration counter (state.alpha already
+# holds the constant step size set at initialization).
+function _update_iter_alpha_in_state!(state::AdaMaxState, method::AdaMax{T}) where T <: Real
+    state.iter = state.iter + 1
+end
+
+# Scheduled `alpha`: advance the counter and evaluate the scheduler for this iteration.
+function _update_iter_alpha_in_state!(state::AdaMaxState, method::AdaMax)
+    state.iter = state.iter + 1
+    state.alpha = method.α(state.iter)
+end
+
 function update_state!(d, state::AdaMaxState{T}, method::AdaMax) where T
-    state.iter = state.iter+1
+    _update_iter_alpha_in_state!(state, method)
     value_gradient!(d, state.x)
-    α, β₁, β₂, ϵ = method.α, method.β₁, method.β₂, method.ϵ
+    α, β₁, β₂, ϵ = state.alpha, method.β₁, method.β₂, method.ϵ
     a = 1 - β₁
     m, u = state.m, state.u
 
diff --git a/test/multivariate/solvers/first_order/adam_adamax.jl b/test/multivariate/solvers/first_order/adam_adamax.jl
index 4ead005f..f08f9c6a 100644
--- a/test/multivariate/solvers/first_order/adam_adamax.jl
+++ b/test/multivariate/solvers/first_order/adam_adamax.jl
@@ -21,6 +21,7 @@
                     skip = skip,
                     show_name = debug_printing)
 end
+
 @testset "AdaMax" begin
     f(x) = x[1]^4
     function g!(storage, x)
@@ -45,3 +46,61 @@ end
                     show_name=debug_printing,
                     iteration_exceptions = (("Trigonometric", 1_000_000,),))
 end
+
+@testset "Adam-scheduler" begin
+  f(x) = x[1]^4
+  function g!(storage, x)
+      storage[1] = 4 * x[1]^3
+      return
+  end
+
+  initial_x = [1.0]
+  options = Optim.Options(show_trace = debug_printing, allow_f_increases=true, iterations=100_000)
+  alpha_scheduler(iter) = 0.0001*(1 + 0.99^iter)
+  results = Optim.optimize(f, g!, initial_x, Adam(alpha=alpha_scheduler), options)
+  @test norm(Optim.minimum(results)) < 1e-6
+  @test summary(results) == "Adam"
+
+  # verifying the alpha values over iterations and also testing extended_trace
+  # this way we test both alpha scheduler and the working of
+  # extended_trace=true option
+
+  options = Optim.Options(show_trace = debug_printing, allow_f_increases=true, iterations=1000, extended_trace=true, store_trace=true)
+  results = Optim.optimize(f, g!, initial_x, Adam(alpha=1e-5), options)
+
+  @test prod(map(iter -> results.trace[iter].metadata["Current step size"], 2:results.iterations+1) .== 1e-5)
+
+  options = Optim.Options(show_trace = debug_printing, allow_f_increases=true, iterations=1000, extended_trace=true, store_trace=true)
+  results = Optim.optimize(f, g!, initial_x, Adam(alpha=alpha_scheduler), options)
+
+  @test map(iter -> results.trace[iter].metadata["Current step size"], 2:results.iterations+1) == alpha_scheduler.(1:results.iterations)
+end
+
+@testset "AdaMax-scheduler" begin
+  f(x) = x[1]^4
+  function g!(storage, x)
+      storage[1] = 4 * x[1]^3
+      return
+  end
+
+  initial_x = [1.0]
+  options = Optim.Options(show_trace = debug_printing, allow_f_increases=true, iterations=100_000)
+  alpha_scheduler(iter) = 0.002*(1 + 0.99^iter)
+  results = Optim.optimize(f, g!, initial_x, AdaMax(alpha=alpha_scheduler), options)
+  @test norm(Optim.minimum(results)) < 1e-6
+  @test summary(results) == "AdaMax"
+
+  # verifying the alpha values over iterations and also testing extended_trace
+  # this way we test both alpha scheduler and the working of
+  # extended_trace=true option
+
+  options = Optim.Options(show_trace = debug_printing, allow_f_increases=true, iterations=1000, extended_trace=true, store_trace=true)
+  results = Optim.optimize(f, g!, initial_x, AdaMax(alpha=1e-4), options)
+
+  @test prod(map(iter -> results.trace[iter].metadata["Current step size"], 2:results.iterations+1) .== 1e-4)
+
+  options = Optim.Options(show_trace = debug_printing, allow_f_increases=true, iterations=1000, extended_trace=true, store_trace=true)
+  results = Optim.optimize(f, g!, initial_x, AdaMax(alpha=alpha_scheduler), options)
+
+  @test map(iter -> results.trace[iter].metadata["Current step size"], 2:results.iterations+1) == alpha_scheduler.(1:results.iterations)
+end

From b041bd63373cf2403180c6993eff235b5c010670 Mon Sep 17 00:00:00 2001
From: Patrick Kofod Mogensen <patrick.mogensen@gmail.com>
Date: Thu, 23 Jan 2025 14:27:54 +0100
Subject: [PATCH 2/2] Update Project.toml

---
 Project.toml | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/Project.toml b/Project.toml
index 69ad5487..6d9f4f69 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,6 +1,6 @@
 name = "Optim"
 uuid = "429524aa-4258-5aef-a3af-852621145aeb"
-version = "1.10.0"
+version = "1.11.0"
 
 [deps]
 Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"