ENH: Allow to adaptively update the regularization parameter in Bregman.
jwboth committed Dec 14, 2023
1 parent a426696 commit 4f1b494
Showing 1 changed file with 18 additions and 3 deletions: src/darsia/restoration/split_bregman_tvd.py
@@ -22,6 +22,7 @@ def split_bregman_tvd(
     isotropic: bool = False,
     verbose: Union[bool, int] = False,
     solver: da.Solver = da.Jacobi(),
+    adaptive=None,
 ) -> np.ndarray:
     """Split Bregman algorithm for TV denoising.
@@ -43,6 +44,7 @@ def split_bregman_tvd(
         isotropic (bool): whether to use isotropic TV denoising
         verbose (bool, int): verbosity (frequency if int)
         solver (da.Solver): solver to use for the inner linear system
+        adaptive (lambda, Optional): adaptivity schedule

     Returns:
         array: denoised image
@@ -78,8 +80,8 @@ def _shrink(x: np.ndarray, k: Union[float, np.ndarray]) -> np.ndarray:
         return np.maximum(np.abs(x) - k, 0) * np.sign(x)

     # Define right hand side function
-    def _rhs_function(dt: np.ndarray, bt: np.ndarray) -> np.ndarray:
-        return omega * img + ell * sum(
+    def _rhs_function(dt: np.ndarray, bt: np.ndarray, ellt) -> np.ndarray:
+        return omega * img + ellt * sum(
             [
                 da.forward_diff(img=bt[..., i] - dt[..., i], axis=i, dim=dim)
                 for i in range(dim)
@@ -103,8 +105,9 @@ def _rhs_function(dt: np.ndarray, bt: np.ndarray) -> np.ndarray:

     # Bregman iterations
     for iter in range(max_num_iter):
+
         # First step - solve the stabilized diffusion system.
-        img_new = solver(x0=img_iter, rhs=_rhs_function(d, b))
+        img_new = solver(x0=img_iter, rhs=_rhs_function(d, b, ell))

         # Second step - shrinkage.
         if isotropic:
@@ -121,6 +124,18 @@ def _rhs_function(dt: np.ndarray, bt: np.ndarray) -> np.ndarray:
                 d[..., j] = _shrink(dub, mu / ell)
                 b[..., j] = dub - d[..., j]

+        # Update ell
+        if adaptive is not None and adaptive(iter):
+            grad = np.zeros((*img.shape, dim), dtype=img.dtype)
+            for j in range(dim):
+                grad[..., j] = da.backward_diff(img=img_new, axis=j, dim=dim)
+            ell = 1.0 / np.maximum(np.linalg.norm(grad, ord=1, axis=-1), 1e-12)
+            solver.update_params(
+                mass_coeff=omega,
+                diffusion_coeff=ell,
+                dim=dim,
+            )
+
         # Monitor performance
         relative_increment = np.linalg.norm(img_new - img_iter) / img_nrm
         if verbose if isinstance(verbose, bool) else iter % verbose == 0:
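For orientation, the quantities touched by this commit appear in the standard anisotropic split Bregman iteration for TV denoising (Goldstein and Osher); darsia's exact sign and boundary conventions may differ slightly from this textbook form:

\min_u \; \tfrac{\omega}{2}\|u - f\|_2^2 + \mu \sum_i \|D_i u\|_1,

u^{k+1} = \arg\min_u \; \tfrac{\omega}{2}\|u - f\|_2^2 + \tfrac{\ell}{2}\sum_i \|d_i^k - D_i u - b_i^k\|_2^2,
d_i^{k+1} = \mathrm{shrink}\big(D_i u^{k+1} + b_i^k,\; \mu/\ell\big),
b_i^{k+1} = b_i^k + D_i u^{k+1} - d_i^{k+1},

with f the noisy image and D_i the backward difference along axis i. The new adaptive option re-selects the coupling parameter ell between Bregman iterations, which is why _rhs_function now receives the current value explicitly.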
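The Bregman loop only evaluates the new argument as adaptive(iter), so any callable taking the iteration index and returning a boolean serves as a schedule. A minimal sketch (the function name and the period of 5 are arbitrary choices for illustration, not part of the commit):

def update_every_fifth_iteration(it: int) -> bool:
    # Refresh the regularization parameter ell in every fifth Bregman iteration.
    return it % 5 == 0

# Passed as a keyword argument, e.g. split_bregman_tvd(..., adaptive=update_every_fifth_iteration),
# or inline as adaptive=lambda it: it % 5 == 0.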

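The update itself makes the regularization parameter spatially varying: ell is set pixel-wise to the reciprocal of the 1-norm of the backward-difference gradient of the current iterate, floored at 1e-12, and then handed to the inner solver as its diffusion coefficient. A numpy-only sketch of that weighting, with np.roll-based backward differences standing in for da.backward_diff (wrap-around boundaries and treating every array axis as spatial are simplifying assumptions of this sketch):

import numpy as np

def adaptive_ell(img_new: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    # Backward-difference gradient along every axis, stacked in the last dimension.
    dim = img_new.ndim
    grad = np.zeros((*img_new.shape, dim), dtype=img_new.dtype)
    for j in range(dim):
        # np.roll wraps around at the boundary, unlike darsia's operator.
        grad[..., j] = img_new - np.roll(img_new, 1, axis=j)
    # Pixel-wise weight: large where the iterate is flat, small near strong gradients.
    return 1.0 / np.maximum(np.linalg.norm(grad, ord=1, axis=-1), eps)

The same array is passed to solver.update_params as diffusion_coeff in the hunk above, so the inner linear system tracks the updated weighting.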