Fix FISTA and ISTA maths in docs (#2012)
* Fix FISTA and ISTA maths in docs

* Add maths around k>0

* Remove max_iteration from ista and fista examples

* Remove double title for FISTA

* Remove unneeded members from F/ISTA docs

---------

Signed-off-by: Laura Murgatroyd <[email protected]>
Co-authored-by: Margaret Duff <[email protected]>
lauramurgatroyd and MargaretDuff authored Dec 17, 2024
1 parent b4d178a commit a47cb15
Showing 2 changed files with 30 additions and 48 deletions.
76 changes: 30 additions & 46 deletions Wrappers/Python/cil/optimisation/algorithms/FISTA.py
@@ -29,23 +29,22 @@

class ISTA(Algorithm):

r"""Iterative Shrinkage-Thresholding Algorithm, see :cite:`BeckTeboulle_b`, :cite:`BeckTeboulle_a`.
r"""Iterative Shrinkage-Thresholding Algorithm (ISTA), see :cite:`BeckTeboulle_b`, :cite:`BeckTeboulle_a`, is used to solve:
Iterative Shrinkage-Thresholding Algorithm (ISTA)
.. math:: \min_{x} f(x) + g(x)
.. math:: x^{k+1} = \mathrm{prox}_{\alpha^{k} g}(x^{k} - \alpha^{k}\nabla f(x^{k}))
where :math:`f` is differentiable, and :math:`g` has a *simple* proximal operator.
is used to solve
In each update, the algorithm computes:
.. math:: \min_{x} f(x) + g(x)
.. math:: x_{k+1} = \mathrm{prox}_{\alpha g}(x_{k} - \alpha\nabla f(x_{k}))
where :math:`f` is differentiable, :math:`g` has a *simple* proximal operator and :math:`\alpha^{k}`
is the :code:`step_size` per iteration.
where :math:`\alpha` is the :code:`step_size`.
Note
----
- For a constant step size, i.e., :math:`a^{k}=a` for :math:`k\geq1`, convergence of ISTA
+ For a constant step size, :math:`\alpha`, convergence of ISTA
is guaranteed if
.. math:: \alpha\in(0, \frac{2}{L}),
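
To make the corrected update and the step-size condition concrete, here is a minimal NumPy sketch of the same iteration. This is separate from CIL's implementation; the least-squares :math:`f`, the soft-thresholding prox for :math:`g = \lambda\|\cdot\|_1`, and all names are illustrative assumptions. The step size :code:`0.99 * 2 / L` mirrors the default described in the docstring below.

import numpy as np

def ista(grad_f, prox_g, x0, step_size, iterations=100):
    # x_{k+1} = prox_{alpha g}(x_k - alpha * grad f(x_k))
    x = x0.copy()
    for _ in range(iterations):
        x = prox_g(x - step_size * grad_f(x), step_size)
    return x

# Illustrative problem: f(x) = 0.5*||Ax - b||^2, g(x) = lam*||x||_1.
rng = np.random.default_rng(0)
A = rng.standard_normal((20, 10))
b = rng.standard_normal(20)
lam = 0.1
grad_f = lambda x: A.T @ (A @ x - b)
# prox of alpha*lam*||.||_1 is soft-thresholding at level alpha*lam.
prox_g = lambda z, alpha: np.sign(z) * np.maximum(np.abs(z) - alpha * lam, 0.0)
L = np.linalg.norm(A.T @ A, 2)  # Lipschitz constant of grad f
x = ista(grad_f, prox_g, np.zeros(10), step_size=0.99 * 2 / L)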
@@ -56,16 +55,16 @@ class ISTA(Algorithm):
----------
initial : DataContainer
-     Initial guess of ISTA.
+     Initial guess of ISTA. :math:`x_{0}`
f : Function
Differentiable function. If `None` is passed, the algorithm will use the ZeroFunction.
g : Function or `None`
Convex function with *simple* proximal operator. If `None` is passed, the algorithm will use the ZeroFunction.
step_size : positive :obj:`float` or child class of :meth:`cil.optimisation.utilities.StepSizeRule`', default = None
-     Step size for the gradient step of ISTA. If a float is passed, this is used as a constant step size. If a child class of :meth:`cil.optimisation.utilities.StepSizeRule`' is passed then it's method `get_step_size` is called for each update.
+     Step size for the gradient step of ISTA. If a float is passed, this is used as a constant step size. If a child class of :meth:`cil.optimisation.utilities.StepSizeRule` is passed then its method :meth:`get_step_size` is called for each update.
The default :code:`step_size` is a constant :math:`\frac{0.99*2}{L}` or 1 if `f=None`.
- preconditioner: class with a `apply` method or a function that takes an initialised CIL function as an argument and modifies a provided `gradient`.
-     This could be a custom `preconditioner` or one provided in :meth:`~cil.optimisation.utilities.preconditoner`. If None is passed then `self.gradient_update` will remain unmodified.
+ preconditioner: class with an `apply` method or a function that takes an initialised CIL function as an argument and modifies a provided `gradient`.
+     This could be a custom `preconditioner` or one provided in :meth:`~cil.optimisation.utilities.preconditoner`. If None is passed then `self.gradient_update` will remain unmodified.
kwargs: Keyword arguments
@@ -85,7 +84,7 @@ class ISTA(Algorithm):
>>> f = LeastSquares(A, b=b, c=0.5)
>>> g = ZeroFunction()
>>> ig = Aop.domain
- >>> ista = ISTA(initial = ig.allocate(), f = f, g = g, max_iteration=10)
+ >>> ista = ISTA(initial = ig.allocate(), f = f, g = g)
>>> ista.run()
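
Since :code:`max_iteration` no longer appears in the constructor call, the iteration count would instead be passed when running the algorithm, assuming the usual CIL :code:`Algorithm.run(iterations)` signature; the same applies to the FISTA example below:

>>> ista.run(10)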
@@ -229,25 +228,26 @@ def calculate_objective_function_at_point(self, x):

class FISTA(ISTA):

r"""Fast Iterative Shrinkage-Thresholding Algorithm, see :cite:`BeckTeboulle_b`, :cite:`BeckTeboulle_a`.
r"""Fast Iterative Shrinkage-Thresholding Algorithm (FISTA), see :cite:`BeckTeboulle_b`, :cite:`BeckTeboulle_a`, is used to solve:
.. math:: \min_{x} f(x) + g(x)
where :math:`f` is differentiable and :math:`g` has a *simple* proximal operator.
Fast Iterative Shrinkage-Thresholding Algorithm (FISTA)
In each update the algorithm completes the following steps:
.. math::
\begin{cases}
y_{k} = x_{k} - \alpha\nabla f(x_{k}) \\
x_{k+1} = \mathrm{prox}_{\alpha g}(y_{k})\\
t_{k+1} = \frac{1+\sqrt{1+ 4t_{k}^{2}}}{2}\\
y_{k+1} = x_{k} + \frac{t_{k}-1}{t_{k-1}}(x_{k} - x_{k-1})
x_{k} = \mathrm{prox}_{\alpha g}(y_{k} - \alpha\nabla f(y_{k}))\\
t_{k+1} = \frac{1+\sqrt{1+ 4t_{k}^{2}}}{2}\\
y_{k+1} = x_{k} + \frac{t_{k}-1}{t_{k+1}}(x_{k} - x_{k-1})
\end{cases}
is used to solve
.. math:: \min_{x} f(x) + g(x)
where :math:`\alpha` is the :code:`step_size`.
where :math:`f` is differentiable, :math:`g` has a *simple* proximal operator and :math:`\alpha^{k}`
is the :code:`step_size` per iteration.
Note that the above applies for :math:`k\geq 1`. For :math:`k=0`, :math:`x_{0}` and :math:`y_{0}` are initialised to `initial`, and :math:`t_{1}=1`.
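
As a cross-check of the corrected recursion, here is a minimal NumPy sketch of these steps. It is illustrative only, not CIL's implementation; :code:`grad_f` and :code:`prox_g` are assumed callables as in the ISTA sketch above.

import numpy as np

def fista(grad_f, prox_g, x0, step_size, iterations=100):
    x = x0.copy()   # x_0 initialised to `initial`
    y = x0.copy()   # y_0 initialised to `initial`
    t = 1.0         # t_1 = 1
    for _ in range(iterations):
        x_old = x
        # x_k = prox_{alpha g}(y_k - alpha * grad f(y_k))
        x = prox_g(y - step_size * grad_f(y), step_size)
        # t_{k+1} = (1 + sqrt(1 + 4 t_k^2)) / 2
        t_old, t = t, (1.0 + np.sqrt(1.0 + 4.0 * t * t)) / 2.0
        # y_{k+1} = x_k + ((t_k - 1) / t_{k+1}) * (x_k - x_{k-1})
        y = x + ((t_old - 1.0) / t) * (x - x_old)
    return x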
Parameters
@@ -260,9 +260,9 @@ class FISTA(ISTA):
g : Function or `None`
Convex function with *simple* proximal operator. If `None` is passed, the algorithm will use the ZeroFunction.
step_size : positive :obj:`float` or child class of :meth:`cil.optimisation.utilities.StepSizeRule`', default = None
-     Step size for the gradient step of ISTA. If a float is passed, this is used as a constant step size. If a child class of :meth:`cil.optimisation.utilities.StepSizeRule`' is passed then it's method `get_step_size` is called for each update.
+     Step size for the gradient step of FISTA. If a float is passed, this is used as a constant step size. If a child class of :meth:`cil.optimisation.utilities.StepSizeRule` is passed then its method :meth:`get_step_size` is called for each update.
The default :code:`step_size` is a constant :math:`\frac{1}{L}` or 1 if `f=None`.
- preconditioner: class with a `apply` method or a function that takes an initialised CIL function as an argument and modifies a provided `gradient`.
+ preconditioner : class with an `apply` method or a function that takes an initialised CIL function as an argument and modifies a provided `gradient`.
This could be a custom `preconditioner` or one provided in :meth:`~cil.optimisation.utilities.preconditoner`. If None is passed then `self.gradient_update` will remain unmodified.
kwargs: Keyword arguments
@@ -283,7 +283,7 @@ class FISTA(ISTA):
>>> f = LeastSquares(A, b=b, c=0.5)
>>> g = ZeroFunction()
>>> ig = Aop.domain
- >>> fista = FISTA(initial = ig.allocate(), f = f, g = g, max_iteration=10)
+ >>> fista = FISTA(initial = ig.allocate(), f = f, g = g)
>>> fista.run()
See also
@@ -319,14 +319,14 @@ def __init__(self, initial, f, g, step_size=None, preconditioner=None, **kwargs
step_size=step_size, preconditioner=preconditioner, **kwargs)

def update(self):
r"""Performs a single iteration of FISTA
r"""Performs a single iteration of FISTA. For :math:`k\geq 1`:
.. math::
\begin{cases}
-         x_{k+1} = \mathrm{prox}_{\alpha g}(y_{k} - \alpha\nabla f(y_{k}))\\
+         x_{k} = \mathrm{prox}_{\alpha g}(y_{k} - \alpha\nabla f(y_{k}))\\
t_{k+1} = \frac{1+\sqrt{1+ 4t_{k}^{2}}}{2}\\
-         y_{k+1} = x_{k} + \frac{t_{k}-1}{t_{k-1}}(x_{k} - x_{k-1})
+         y_{k+1} = x_{k} + \frac{t_{k}-1}{t_{k+1}}(x_{k} - x_{k-1})
\end{cases}
"""
@@ -349,19 +349,3 @@ def update(self):

self.x.subtract(self.x_old, out=self.y)
self.y.sapyb(((self.t_old-1)/self.t), self.x, 1.0, out=self.y)
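
For readers tracing the code against the maths: the two in-place calls above implement the momentum step, sketched below under the assumption that :code:`sapyb(a, x, b, out)` computes :code:`a*self + b*x` in place.

# self.y <- self.x - self.x_old                        i.e. x_k - x_{k-1}
# self.y <- ((t_old - 1)/t)*self.y + 1.0*self.x        i.e. y_{k+1} = x_k + ((t_k - 1)/t_{k+1})(x_k - x_{k-1})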


- if __name__ == "__main__":
-
-     from cil.optimisation.functions import L2NormSquared
-     from cil.optimisation.algorithms import GD
-     from cil.framework import ImageGeometry
-     f = L2NormSquared()
-     g = L2NormSquared()
-     ig = ImageGeometry(3, 4, 4)
-     initial = ig.allocate()
-     fista = FISTA(initial, f, g, step_size=1443432)
-     print(fista.is_provably_convergent())
-
-     gd = GD(initial=initial, objective=f, step_size=1023123)
-     print(gd.is_provably_convergent())
2 changes: 0 additions & 2 deletions docs/source/optimisation.rst
@@ -96,14 +96,12 @@ ISTA
----
.. autoclass:: cil.optimisation.algorithms.ISTA
:members:
- :special-members:
:inherited-members: run, update_objective_interval, max_iteration

FISTA
-----
.. autoclass:: cil.optimisation.algorithms.FISTA
:members:
- :special-members:
:inherited-members: run, update_objective_interval, max_iteration

PDHG
