Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Deactivate control points whose contribution is completely overwritten #93

Merged
merged 11 commits into from
Jan 26, 2025
1 change: 1 addition & 0 deletions docs/src/examples_linear_fitting.md
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ using CairoMakie

spline_grid = add_default_local_refinement(spline_grid)
error_informed_local_refinement!(spline_grid, err)
deactivate_overwritten_control_points!(spline_grid.control_points)
plot_basis(spline_grid)
```

Expand Down
3 changes: 3 additions & 0 deletions docs/src/manual.md
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ activate_local_refinement!(
::AbstractMatrix{Ti}) where {Nin, Nout, Tv, Ti}
activate_local_refinement!(::SplineGrids.AbstractSplineGrid, args...)
activate_local_control_point_range!
error_informed_local_refinement!
deactivate_overwritten_control_points!(::LocallyRefinedControlPoints)
deactivate_overwritten_control_points!(::LocallyRefinedControlPoints, ::Integer)
```

# Structs
Expand Down
4 changes: 4 additions & 0 deletions docs/src/theory_local_refinement.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,9 +43,12 @@ spline_grid = add_default_local_refinement(spline_grid)
activate_local_control_point_range!(spline_grid, 1:4, 1:6)
activate_local_control_point_range!(spline_grid, 1:6, 1:2)
activate_local_control_point_range!(spline_grid, 9:10, 7:10)
deactivate_overwritten_control_points!(spline_grid.control_points)
plot_basis(spline_grid)
```

Note that some of the original basis functions are completely gone, which means that their contribution to the final geometry is completely overwritten. The function `deactivate_overwritten_control_points!` weeds out the control points associated with these overwritten basis functions. This means that every active control point is guaranteed to influence the spline geometry (assuming there is at least one global sample point in the effective support of the basis function associated with that control point).

A nice property of this construction is that it can be iterated, creating a hierarchy. Let's refine the basis some more:

```@example tutorial
Expand All @@ -55,6 +58,7 @@ spline_grid = add_default_local_refinement(spline_grid)
```@example tutorial
activate_local_control_point_range!(spline_grid, 5:12, 1:4)
activate_local_control_point_range!(spline_grid, 7:8, 5:6)
deactivate_overwritten_control_points!(spline_grid.control_points)
plot_basis(spline_grid)
```

Expand Down
4 changes: 3 additions & 1 deletion ext/SplineGridsMakieExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,12 @@ function SplineGrids.plot_basis!(

eval_view = view(spline_grid.eval, :, :, 1)
basis_functions = zero(eval_view)
CI = CartesianIndices(size(control_points_new)[1:2])

for i in 1:get_n_control_points(spline_grid)
control_points_new .= 0
if control_points_new isa DefaultControlPoints
control_points_new[CartesianIndices(size(control_points_new)[1:2])[i], 1] = 1
control_points_new[CI[i], 1] = 1
else
control_points_new[i, 1] = 1
end
Expand Down
30 changes: 24 additions & 6 deletions src/SplineGrids.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,15 +49,33 @@ include("refinement_matrix.jl")
include("refinement.jl")
include("control_points.jl")
include("spline_grid.jl")
include("adjoint.jl")
include("plot_rec.jl")
include("validation.jl")

export KnotVector, SplineDimension, SplineGrid, NURBSGrid, decompress, evaluate!,
evaluate_adjoint!, insert_knot, refine, RefinementMatrix, rmeye, mult!,
DefaultControlPoints, LocalRefinement, LocallyRefinedControlPoints,
add_default_local_refinement, activate_local_refinement!, get_n_control_points,
plot_basis, plot_basis!, activate_local_control_point_range!,
error_informed_local_refinement!
export activate_local_control_point_range!,
activate_local_refinement!,
add_default_local_refinement,
deactivate_overwritten_control_points!,
decompress,
DefaultControlPoints,
error_informed_local_refinement!,
evaluate_adjoint!,
evaluate!,
get_n_control_points,
insert_knot,
KnotVector,
LocallyRefinedControlPoints,
LocalRefinement,
mult!,
NURBSGrid,
plot_basis,
plot_basis!,
refine,
RefinementMatrix,
rmeye,
SplineDimension,
SplineGrid

# Define names for SplineGridsMakieExt
function plot_basis end
Expand Down
168 changes: 168 additions & 0 deletions src/adjoint.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
@kernel function spline_eval_adjoint_kernel(
control_points,
@Const(basis_function_eval_all),
@Const(sample_indices_all),
@Const(eval),
control_point_kernel_size,
derivative_order,
degrees
)

# Index of the global sample point
J = @index(Global, Cartesian)

# Input and output dimensionality
Nin = ndims(control_points) - 1
Nout = size(control_points)[end]

control_point_index_base = ntuple(
dim_in -> sample_indices_all[dim_in][J[dim_in]] - degrees[dim_in] - 1, Nin)

for I in CartesianIndices(control_point_kernel_size)
control_point_index = ntuple(
dim_in -> control_point_index_base[dim_in] + I[dim_in], Nin)

# Compute basis function product
basis_function_product = one(eltype(eval))

for dim_in in 1:Nin
basis_function_product *= basis_function_eval_all[dim_in][
J[dim_in], I[dim_in], derivative_order[dim_in] + 1]
end

# Add product of multi-dimensional basis function and control point to output
for dim_out in 1:Nout
Atomix.@atomic control_points[
control_point_index..., dim_out] += basis_function_product *
eval[J, dim_out]
end
end
end

"""
evaluate_adjoint!(spline_grid::AbstractSplineGrid{Nin, Nout, false, Tv};
derivative_order::NTuple{Nin, <:Integer} = ntuple(_ -> 0, Nin),
control_points::AbstractControlPointArray{Nin, Nout, Tv} = spline_grid.control_points,
eval::AbstractArray = spline_grid.eval)::Nothing where {Nin, Nout, Tv}

evaluate the adjoint of the linear mapping `control_points -> eval`. This is a computation of the form
`eval -> control_points`. If we write `evaluate!(spline_grid)` as a matrix vector multiplication `eval = M * control_points`,
Then the adjoint is given by `v -> M' * v`. This mapping is used in fitting algorithms.
"""
function evaluate_adjoint!(spline_grid::AbstractSplineGrid{Nin, Nout, false, Tv};
derivative_order::NTuple{Nin, <:Integer} = ntuple(_ -> 0, Nin),
control_points::AbstractControlPointArray{Nin, Nout, Tv} = spline_grid.control_points,
eval::AbstractArray = spline_grid.eval
)::Nothing where {Nin, Nout, Tv}
@assert !is_nurbs(spline_grid) "Adjoint evaluation not supported for NURBS."
(; spline_dimensions) = spline_grid
validate_partial_derivatives(spline_dimensions, derivative_order)
control_points = obtain(control_points)
control_points .= 0
@assert size(control_points) == size(spline_grid.control_points)
@assert size(eval) == size(spline_grid.eval)

basis_function_eval_all = ntuple(i -> spline_dimensions[i].eval, Nin)
sample_indices_all = ntuple(i -> spline_dimensions[i].sample_indices, Nin)
control_point_kernel_size = get_cp_kernel_size(spline_dimensions)
degrees = ntuple(i -> spline_dimensions[i].degree, Nin)

backend = get_backend(eval)
spline_eval_adjoint_kernel(backend)(
control_points,
basis_function_eval_all,
sample_indices_all,
eval,
control_point_kernel_size,
derivative_order,
degrees,
ndrange = size(eval)[1:(end - 1)]
)
synchronize(backend)
return nothing
end

@kernel function refinement_matrix_array_mul_adjoint_kernel(
B,
@Const(Y),
@Const(row_pointer_all),
@Const(column_start_all),
@Const(nzval_all),
refmat_index_all
)
# Y index
I = @index(Global, Cartesian)

Ndims = ndims(Y)

column_start, n_columns = get_row_extends(
I,
refmat_index_all,
row_pointer_all,
column_start_all,
nzval_all
)

for J_base in CartesianIndices(Tuple(n_columns))
# B index
J = ntuple(dim -> J_base[dim] + column_start[dim] - 1, Ndims)
contrib = Y[I]
for dim in 1:Ndims
refmat_index = refmat_index_all[dim]
if !iszero(refmat_index)
row_pointer = row_pointer_all[refmat_index][I[dim]]
contrib *= nzval_all[refmat_index][row_pointer + J_base[dim] - 1]
end
end
if contrib isa Flag
if contrib.flag
B[J...] = contrib
end
else
Atomix.@atomic B[J...] += contrib
end
end
end

function mult_adjoint!(
B::AbstractArray,
As::NTuple{N, <:RefinementMatrix},
Y::AbstractArray,
dims_refinement::NTuple{N, <:Integer}
) where {N}
backend = get_backend(B)
validate_mult_input(Y, As, B, dims_refinement)

n_refmat = length(dims_refinement)
refmat_index_all = ntuple(
dim -> (dim ∈ dims_refinement) ? findfirst(==(dim), dims_refinement) : 0, ndims(Y))

refinement_matrix_array_mul_adjoint_kernel(backend)(
B,
Y,
ntuple(i -> As[i].row_pointer, n_refmat),
ntuple(i -> As[i].column_start, n_refmat),
ntuple(i -> As[i].nzval, n_refmat),
refmat_index_all,
ndrange = size(Y)
)
synchronize(backend)
return nothing
end

@kernel function local_refinement_adjoint_kernel(
refinement_values,
@Const(control_points),
@Const(refinement_indices)
)
i = @index(Global, Linear)

Nin = ndims(control_points) - 1
Nout = size(control_points)[end]

indices = ntuple(dim_in -> refinement_indices[i, dim_in], Nin)

for dim_out in 1:Nout
refinement_values[i, dim_out] = control_points[indices..., dim_out]
end
end
Loading
Loading