diff --git a/src/darsia/restoration/split_bregman_tvd.py b/src/darsia/restoration/split_bregman_tvd.py index 24286373..8560bc56 100644 --- a/src/darsia/restoration/split_bregman_tvd.py +++ b/src/darsia/restoration/split_bregman_tvd.py @@ -22,6 +22,7 @@ def split_bregman_tvd( isotropic: bool = False, verbose: Union[bool, int] = False, solver: da.Solver = da.Jacobi(), + adaptive=None, ) -> np.ndarray: """Split Bregman algorithm for TV denoising. @@ -43,6 +44,7 @@ def split_bregman_tvd( isotropic (bool): whether to use isotropic TV denoising verbose (bool, int): verbosity (frequency if int) solver (da.Solver): solver to use for the inner linear system + adaptive (lambda, Optional): adaptivity schedule Returns: array: denoised image @@ -78,8 +80,8 @@ def _shrink(x: np.ndarray, k: Union[float, np.ndarray]) -> np.ndarray: return np.maximum(np.abs(x) - k, 0) * np.sign(x) # Define right hand side function - def _rhs_function(dt: np.ndarray, bt: np.ndarray) -> np.ndarray: - return omega * img + ell * sum( + def _rhs_function(dt: np.ndarray, bt: np.ndarray, ellt) -> np.ndarray: + return omega * img + ellt * sum( [ da.forward_diff(img=bt[..., i] - dt[..., i], axis=i, dim=dim) for i in range(dim) @@ -103,8 +105,9 @@ def _rhs_function(dt: np.ndarray, bt: np.ndarray) -> np.ndarray: # Bregman iterations for iter in range(max_num_iter): + # First step - solve the stabilized diffusion system. - img_new = solver(x0=img_iter, rhs=_rhs_function(d, b)) + img_new = solver(x0=img_iter, rhs=_rhs_function(d, b, ell)) # Second step - shrinkage. if isotropic: @@ -121,6 +124,18 @@ def _rhs_function(dt: np.ndarray, bt: np.ndarray) -> np.ndarray: d[..., j] = _shrink(dub, mu / ell) b[..., j] = dub - d[..., j] + # Update ell + if adaptive is not None and adaptive(iter): + grad = np.zeros((*img.shape, dim), dtype=img.dtype) + for j in range(dim): + grad[..., j] = da.backward_diff(img=img_new, axis=j, dim=dim) + ell = 1.0 / np.maximum(np.linalg.norm(grad, ord=1, axis=-1), 1e-12) + solver.update_params( + mass_coeff=omega, + diffusion_coeff=ell, + dim=dim, + ) + # Monitor performance relative_increment = np.linalg.norm(img_new - img_iter) / img_nrm if verbose if isinstance(verbose, bool) else iter % verbose == 0: