Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Color basis functions by level #95

Merged
merged 5 commits into from
Jan 31, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
133 changes: 99 additions & 34 deletions docs/src/examples_linear_fitting.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,13 @@ In this section we demonstrate how a spline grid can be fitted. We will fit a sp

```@example tutorial
using FileIO
image = load(normpath(@__DIR__, "julia_logo.png"))
using CairoMakie
image = rotr90(load(normpath(@__DIR__, "julia_logo.png")))
fig = Figure()
ratio = size(image, 1) / size(image, 2)
ax = Axis(fig[1, 1], aspect = ratio)
CairoMakie.image!(ax, image)
fig
```

## Defining the spline grid
Expand All @@ -16,21 +22,25 @@ using Colors: Gray
using SplineGrids

degree = (2, 2)
image_array = Float32.(Gray.(image[end:-1:1, :]))'
image_array = Float32.(Gray.(image))
n_sample_points = size(image_array)
n_control_points = ntuple(i -> n_sample_points[i]÷10, 2)
n_control_points = ntuple(i -> n_sample_points[i] ÷ 25, 2)
extent = ntuple(i -> (0, n_sample_points[i]), 2)
dim_out = 1

spline_dimensions = SplineDimension.(n_control_points, degree, n_sample_points)
spline_dimensions = ntuple(
i -> SplineDimension(
n_control_points[i], degree[i], n_sample_points[i]; extent = extent[i]),
2)
spline_grid = SplineGrid(spline_dimensions, dim_out)
spline_grid
```

The basis of this spline geometry looks like this:

```@example tutorial
using CairoMakie

fig_total = Figure(size = (1600, 900)) # hide
plot_basis!(fig_total, spline_grid; i = 1, j = 1, title = "Unrefined spline basis") # hide
plot_basis(spline_grid; title = "Unrefined spline basis")
```

Expand All @@ -43,11 +53,30 @@ using LinearMaps
using IterativeSolvers
using Plots

spline_grid_map = LinearMap(spline_grid)
sol = lsqr(spline_grid_map, vec(image_array))
copyto!(spline_grid.control_points, sol)
evaluate!(spline_grid)
Plots.plot(spline_grid; title = "Least squares fit")
function fit!(spline_grid)
spline_grid_map = LinearMap(spline_grid)
sol = lsqr(spline_grid_map, vec(image_array))
copyto!(
spline_grid.control_points,
reshape(sol, get_n_control_points(spline_grid.control_points), dim_out)
)
evaluate!(spline_grid.control_points)
evaluate!(spline_grid)
end

function plot_fit(spline_grid)
Plots.plot(spline_grid; aspect_ratio = :equal, title = "Least squares fit",
clims = (-0.5, 1.5), cmap = :viridis)
end
# hide
function _plot_fit(spline_grid, j) # hide
ax_fit = Axis(fig_total[2, j], aspect = ratio; title = "Least squares fit") # hide
CairoMakie.heatmap!(ax_fit, spline_grid.eval[:, :, 1], colorrange = (-0.5, 1.5)) # hide
end # hide

fit!(spline_grid)
_plot_fit(spline_grid, 1) # hide
plot_fit(spline_grid)
```

## Matrix
Expand All @@ -57,6 +86,7 @@ The least-squares fitting procedure above is matrix free, but the linear mapping
```@example tutorial
using SparseArrays

# Make an analogous spline geometry with a smaller sample grid so the output dimensionality is not too large
n_control_points = (5, 5)
degree = (2, 2)
n_sample_points = (15, 15)
Expand All @@ -66,46 +96,81 @@ spline_dimensions = SplineDimension.(n_control_points, degree, n_sample_points)
spline_grid_ = SplineGrid(spline_dimensions, dim_out)
spline_grid_map = LinearMap(spline_grid_)
M = sparse(spline_grid_map)
Plots.heatmap(M[end:-1:1,:])
Plots.heatmap(M[end:-1:1, :])
```

## Local refinement informed by local error

Clearly the error of the fit is largest around the boundary of the text:
Clearly the error of the fit is largest where the text is:

```@example tutorial
err_unrefined = (spline_grid.eval - image_array).^2
Plots.heatmap(err_unrefined[:,:,1]', colormap = c=cgrad(:RdYlGn, rev=true))
title!("Squared error per pixel")
err_unrefined = (spline_grid.eval - image_array) .^ 2

function plot_error(error)
Plots.heatmap(error[:, :, 1]', colormap = cgrad(:RdYlGn, rev = true),
aspect_ratio = :equal, clims = (0, 1))
title!("Squared error per pixel")
end
# hide
function _plot_error(error, j) # hide
ax_err = Axis(fig_total[3, j], aspect = ratio; title = "Squared error per pixel") # hide
CairoMakie.heatmap!(
ax_err, error[:, :, 1], colormap = cgrad(:RdYlGn, rev = true); colorrange = (0, 1)) # hide
end # hide

_plot_error(err_unrefined, 1) # hide
plot_error(err_unrefined)
```

We can easily locally refine the spline basis by mapping this error back on to the control points.

```@example tutorial
spline_grid = add_default_local_refinement(spline_grid)
error_informed_local_refinement!(spline_grid, err_unrefined)
deactivate_overwritten_control_points!(spline_grid.control_points)
plot_basis(spline_grid; title = "Refined spline basis")
function refine(spline_grid)
spline_grid = add_default_local_refinement(spline_grid)
error_informed_local_refinement!(spline_grid, err_unrefined)
deactivate_overwritten_control_points!(spline_grid.control_points)
spline_grid
end

spline_grid = refine(spline_grid)
plot_basis!(fig_total, spline_grid; i = 1, j = 2, title = "Refined spline basis (level 1)") # hide
plot_basis(spline_grid; title = "Refined spline basis (level 1)")
```

We can now fit the image again with the refined basis.

```@example tutorial
spline_grid_map = LinearMap(spline_grid)
sol = lsqr(spline_grid_map, vec(image_array))
copyto!(
spline_grid.control_points,
reshape(sol, get_n_control_points(spline_grid.control_points), dim_out)
)
evaluate!(spline_grid.control_points)
evaluate!(spline_grid)
Plots.plot(spline_grid; title = "Least squares fit")
fit!(spline_grid)
_plot_fit(spline_grid, 2) # hide
plot_fit(spline_grid)
```

and the local error looks a bit better:

```@example tutorial
err_refined = (spline_grid.eval - image_array) .^ 2
_plot_error(err_refined, 2) # hide
plot_error(err_refined)
```

and the local error looks a lot better:
## Iterating local refinement

Let's iterate the local refinement and fitting procedure a few more times to get a nicer result!

```@example tutorial
err_refined = (spline_grid.eval - image_array).^2
Plots.heatmap(err_refined[:,:,1]', colormap = c=cgrad(:RdYlGn, rev=true), clims = (0, maximum(err_unrefined)))
title!("Squared error per pixel")
```
function _iteration(spline_grid, j) # hide
spline_grid = refine(spline_grid) # hide
plot_basis!(
fig_total, spline_grid; i = 1, j, title = "Refined spline basis (level $(j - 1))") # hide
fit!(spline_grid) # hide
_plot_fit(spline_grid, j) # hide
err_refined = (spline_grid.eval - image_array) .^ 2 # hide
_plot_error(err_refined, j) # hide
spline_grid # hide
end # hide
spline_grid = _iteration(spline_grid, 3) # hide
_iteration(spline_grid, 4) # hide
Colorbar(fig_total[2, 5], colorrange = (-0.5, 1.5)) # hide
Colorbar(fig_total[3, 5], colorrange = (0, 1), colormap = cgrad(:RdYlGn, rev = true)) # hide
fig_total # hide
```
81 changes: 67 additions & 14 deletions ext/SplineGridsMakieExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,32 +7,43 @@ using SplineGrids

function SplineGrids.plot_basis(spline_grid::SplineGrid{2}; kwargs...)
fig = Figure()
ax = Axis3(
fig[1, 1], azimuth = -π / 2, elevation = π / 2, perspectiveness = 0.5; kwargs...)
plot_basis!(ax, spline_grid)
plot_basis!(fig, spline_grid; kwargs...)
fig
end

function SplineGrids.plot_basis!(
ax::Makie.Axis3,
spline_grid::SplineGrid{2, Nout, Tv}
)::Nothing where {Nout, Tv}
fig::Figure, spline_grid::SplineGrid{2}; i = 1, j = 1, kwargs...)
extents = ntuple(i -> spline_grid.spline_dimensions[i].knot_vector.extent, 2)
ratio = (extents[2][2] - extents[2][1]) / (extents[1][2] - extents[1][1])
ax = Axis3(
fig[i, j], azimuth = -π / 2, elevation = π / 2,
perspectiveness = 0.5, aspect = (1, ratio, 1); kwargs...)
plot_basis!(ax, spline_grid)
end

function SplineGrids.plot_basis!(ax::Makie.Axis3, spline_grid::SplineGrid{2})::Nothing
spline_grid = adapt(CPU(), spline_grid)
(; control_points, spline_dimensions) = spline_grid
(; control_points) = spline_grid
control_points_new = deepcopy(control_points)
spline_grid = setproperties(spline_grid; control_points = control_points_new)

plot_basis!(ax, spline_grid, control_points_new)
return nothing
end

function SplineGrids.plot_basis!(
ax::Makie.Axis3,
spline_grid::SplineGrid{2},
control_points_new::DefaultControlPoints
)
(; spline_dimensions) = spline_grid
eval_view = view(spline_grid.eval, :, :, 1)
basis_functions = zero(eval_view)
CI = CartesianIndices(size(control_points_new)[1:2])

for i in 1:get_n_control_points(spline_grid)
control_points_new .= 0
if control_points_new isa DefaultControlPoints
control_points_new[CI[i], 1] = 1
else
control_points_new[i, 1] = 1
end
control_points_new[CI[i], 1] = 1
evaluate!(control_points_new)
evaluate!(spline_grid)
broadcast!(max, basis_functions, basis_functions, eval_view)
Expand All @@ -42,10 +53,52 @@ function SplineGrids.plot_basis!(
spline_dimensions[1].sample_points,
spline_dimensions[2].sample_points,
basis_functions,
colormap = :coolwarm,
color = cbrt.(basis_functions),
colormap = [:black, first(Makie.wong_colors())],
colorrange = (0, 1)
)
return nothing
end

function SplineGrids.plot_basis!(
ax::Makie.Axis3,
spline_grid::SplineGrid{2, Nout, Tv},
control_points_new::LocallyRefinedControlPoints
) where {Nout, Tv}
(; local_refinements) = control_points_new
(; spline_dimensions) = spline_grid
eval_view = view(spline_grid.eval, :, :, 1)
control_points_new .= 0
basis_functions = zero(eval_view)
wong_colors = Makie.wong_colors()
colors = fill(Makie.RGBA{Tv}(0, 0, 0, 1), size(eval_view))

for (level, local_refinement) in enumerate(local_refinements)
(; refinement_values) = local_refinement
basis_functions_level = zero(eval_view)
for i in axes(refinement_values, 1)
refinement_values .= 0
refinement_values[i, 1] = 1
evaluate!(control_points_new)
evaluate!(spline_grid)
broadcast!(max, basis_functions_level, basis_functions_level, eval_view)
end
cmap_level = cgrad([:black, wong_colors[level]])
for I in eachindex(basis_functions)
value = basis_functions[I]
value_level = basis_functions_level[I]
if value_level > value
basis_functions[I] = value_level
colors[I] = get(cmap_level, ∛(value_level))
end
end
end
surface!(
ax,
spline_dimensions[1].sample_points,
spline_dimensions[2].sample_points,
basis_functions,
color = colors
)
end

end # module SplineGridsMakieExt
6 changes: 4 additions & 2 deletions src/adjoint.jl
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ function mult_adjoint!(
) where {N}
backend = get_backend(B)
validate_mult_input(Y, As, B, dims_refinement)
B .= 0

n_refmat = length(dims_refinement)
refmat_index_all = ntuple(
Expand All @@ -152,7 +153,7 @@ end

@kernel function local_refinement_adjoint_kernel(
refinement_values,
@Const(control_points),
control_points,
@Const(refinement_indices)
)
i = @index(Global, Linear)
Expand All @@ -164,6 +165,7 @@ end

for dim_out in 1:Nout
refinement_values[i, dim_out] = control_points[indices..., dim_out]
control_points[indices..., dim_out] = 0
end
end

Expand Down Expand Up @@ -200,4 +202,4 @@ function evaluate_adjoint!(control_points::LocallyRefinedControlPoints)::Nothing
end
end
return nothing
end
end
14 changes: 5 additions & 9 deletions src/control_points.jl
Original file line number Diff line number Diff line change
Expand Up @@ -542,22 +542,20 @@ end
error_informed_local_refinement!(
spline_grid::AbstractSplineGrid{Nin, Nout, HasWeights, Tv, Ti},
error::AbstractArray;
threshold::Union{Number, Nothing} = nothing
threshold_factor::Number = 1.0
)::Nothing where {Nin, Nout, HasWeights, Tv, Ti}

Refine the last level of the locally refined spline grid informed by the `error` array which has the same
shape as `spline_grid.eval`. This is done by:

- mapping the error back onto the control points by using the adjoint of the refinement matrices multiplication
- summing over the output dimensions to obtain a single number per control point stored in `control_grid_error`
- activating each control point whose value is bigger than `threshold`

`threshold` can be explicitly provided but by default it is given by the mean of `control_grid_error`.
- activating each control point whose value is bigger than `threshold = threshold_factor * mean(control_grid_error)`
"""
function error_informed_local_refinement!(
spline_grid::AbstractSplineGrid{Nin, Nout, HasWeights, Tv, Ti},
error::AbstractArray;
threshold::Union{Number, Nothing} = nothing
threshold_factor::Number = 1.0
)::Nothing where {Nin, Nout, HasWeights, Tv, Ti}
@assert size(error)==size(spline_grid.eval) "The error array must have the same size as the eval array."

Expand All @@ -569,9 +567,7 @@ function error_informed_local_refinement!(
control_grid_error = dropdims(sum(control_points_error, dims = Nin + 1), dims = Nin + 1)

# Default threshold if not provided
if isnothing(threshold)
threshold = sum(control_grid_error) / length(control_grid_error)
end
threshold = threshold_factor * sum(control_grid_error) / length(control_grid_error)

# Deduce the refinement indices
refinement_indices_ci = findall(>(threshold), control_grid_error)
Expand Down Expand Up @@ -681,4 +677,4 @@ function deactivate_overwritten_control_points!(
refinement_values = refinement_values_new
)
return nothing
end
end
2 changes: 1 addition & 1 deletion src/util_kernels.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,4 +85,4 @@ end
for dim_in in 1:Nin
indices[i, dim_in] = t[dim_in]
end
end
end
3 changes: 2 additions & 1 deletion src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -242,5 +242,6 @@ end

Base.zero(::Type{Flag}) = Flag(false)
Base.one(::Type{Flag}) = Flag(true)
Base.convert(::Type{Flag}, x::Number) = Flag(Bool(x))

Base.:*(a::Flag, ::Number) = a
Base.:*(a::Flag, ::Number) = a
Loading