Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Propagate kwargs through update_coefficients! #143

Merged
merged 24 commits into from
May 29, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
1ed74b6
Recursively propagate kwargs through update_coefficients!
gaurav-arya Jan 30, 2023
293d5eb
Rename accepted_kwarg_fields -> accepted_kwargs
gaurav-arya Mar 12, 2023
1712594
Allow accepted_kwargs=nothing to indicate no wrapping
gaurav-arya Mar 12, 2023
cd70503
Tweak keyword filtering logic
gaurav-arya Mar 12, 2023
17298fb
Test operator update (including kwarg update) in operator algebra test
gaurav-arya Mar 12, 2023
dc3e31d
Support kwargs in function operator
gaurav-arya Mar 12, 2023
5b47ef9
Propagate kwargs for out-of-place function operator update_coefficien…
gaurav-arya Mar 12, 2023
88d5050
Catch function operator error for empty kwargs
gaurav-arya Mar 12, 2023
c7fcd51
Address code review suggestions on diag op construction
gaurav-arya Mar 12, 2023
24aef00
Improve logic for normalizing kwargs
gaurav-arya Mar 12, 2023
44d66bb
Test operator application form in operator algebra test set
gaurav-arya Mar 12, 2023
9d8fdff
Support kwargs in function operator functionals
gaurav-arya Mar 12, 2023
f793c60
Add example
gaurav-arya Mar 12, 2023
3312967
Remove unncessary function call
gaurav-arya Mar 12, 2023
69f0ecc
Rename kwargs_for_op -> accepted_kwargs
gaurav-arya Mar 12, 2023
fef7618
Fix function operator out-of-place update coefficients
gaurav-arya Mar 12, 2023
ceca67c
Use NoKwargFilter() to bypass keyword filtering (rather than nothing)
gaurav-arya Mar 12, 2023
0efa2ce
Remove debug line
gaurav-arya Mar 12, 2023
2adb43f
Merge branch 'master' into ag-kwargs
vpuri3 May 27, 2023
e63e445
fix diagonaloperator update
vpuri3 May 27, 2023
e89d5a1
function op working
vpuri3 May 29, 2023
4818a66
moved kwargs to FunctionOp.traits
vpuri3 May 29, 2023
3854874
tests passing
vpuri3 May 29, 2023
3a665ca
Base.Pairs notdef in LTS
vpuri3 May 29, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions docs/src/interface.md
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,10 @@ the proof to affine operators, so then ``exp(A*t)*v`` operations via Krylov meth
affine as well, and all sorts of things. Thus affine operators have no matrix representation but they
are still compatible with essentially any Krylov method which would otherwise be compatible with
matrix-free representations, hence their support in the SciMLOperators interface.

## Note about keyword arguments to `update_coefficients!`

In rare cases, an operator may be used in a context where additional state is expected to be provided
to `update_coefficients!` beyond `u`, `p`, and `t`. In this case, the operator may accept this additional
state through arbitrary keyword arguments to `update_coefficients!`. When the caller provides these, they will be recursively propagated downwards through composed operators just like `u`, `p`, and `t`, and provided to the operator.
For the [premade SciMLOperators](premade_operators.md), one can specify the keyword arguments used by an operator with an `accepted_kwarg_fields` argument that defaults to an empty tuple.
22 changes: 12 additions & 10 deletions src/batch.jl
Original file line number Diff line number Diff line change
@@ -1,33 +1,35 @@
#
"""
BatchedDiagonalOperator(diag, [; update_func])
BatchedDiagonalOperator(diag; update_func=nothing, accepted_kwarg_fields=())

Represents a time-dependent elementwise scaling (diagonal-scaling) operation.
Acts on `AbstractArray`s of the same size as `diag`. The update function is called
by `update_coefficients!` and is assumed to have the following signature:

update_func(diag::AbstractVector,u,p,t) -> [modifies diag]
update_func(diag::AbstractVector,u,p,t; <accepted kwarg fields>) -> [modifies diag]
"""
struct BatchedDiagonalOperator{T,D,F} <: AbstractSciMLOperator{T}
diag::D
update_func::F

function BatchedDiagonalOperator(
diag::AbstractArray;
update_func=DEFAULT_UPDATE_FUNC
update_func=nothing,
accepted_kwarg_fields=()
)
_update_func = preprocess_update_func(update_func, accepted_kwarg_fields)
new{
eltype(diag),
typeof(diag),
typeof(update_func)
typeof(_update_func)
}(
diag, update_func,
diag, _update_func,
)
end
end

function DiagonalOperator(u::AbstractArray; update_func=DEFAULT_UPDATE_FUNC)
BatchedDiagonalOperator(u; update_func=update_func)
function DiagonalOperator(u::AbstractArray; update_func=nothing, accepted_kwarg_fields=())
BatchedDiagonalOperator(u; update_func, accepted_kwarg_fields)
end

# traits
Expand All @@ -40,7 +42,7 @@ function Base.conj(L::BatchedDiagonalOperator) # TODO - test this thoroughly
update_func = if isreal(L)
L.update_func
else
(L,u,p,t) -> conj(L.update_func(conj(L.diag),u,p,t))
(L,u,p,t; kwargs...) -> conj(L.update_func(conj(L.diag),u,p,t; kwargs...))
end
BatchedDiagonalOperator(diag; update_func=update_func)
end
Expand All @@ -57,15 +59,15 @@ function LinearAlgebra.ishermitian(L::BatchedDiagonalOperator)
end
LinearAlgebra.isposdef(L::BatchedDiagonalOperator) = isposdef(Diagonal(vec(L.diag)))

isconstant(L::BatchedDiagonalOperator) = L.update_func == DEFAULT_UPDATE_FUNC
isconstant(L::BatchedDiagonalOperator) = update_func_isconstant(L.update_func)
islinear(::BatchedDiagonalOperator) = true
has_adjoint(L::BatchedDiagonalOperator) = true
has_ldiv(L::BatchedDiagonalOperator) = all(x -> !iszero(x), L.diag)
has_ldiv!(L::BatchedDiagonalOperator) = has_ldiv(L)

getops(L::BatchedDiagonalOperator) = (L.diag,)

update_coefficients!(L::BatchedDiagonalOperator,u,p,t) = (L.update_func(L.diag,u,p,t); nothing)
update_coefficients!(L::BatchedDiagonalOperator,u,p,t; kwargs...) = (L.update_func(L.diag,u,p,t; kwargs...); nothing)

# operator application
Base.:*(L::BatchedDiagonalOperator, u::AbstractVecOrMat) = L.diag .* u
Expand Down
4 changes: 2 additions & 2 deletions src/func.jl
Original file line number Diff line number Diff line change
Expand Up @@ -192,10 +192,10 @@ function update_coefficients(L::FunctionOperator, u, p, t)
)
end

function update_coefficients!(L::FunctionOperator, u, p, t)
function update_coefficients!(L::FunctionOperator, u, p, t; kwargs...)
ops = getops(L)
for op in ops
update_coefficients!(op, u, p, t)
update_coefficients!(op, u, p, t; kwargs...)
end

L.p = p
Expand Down
24 changes: 18 additions & 6 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,31 @@ out-of-place form B = update_coefficients(A,u,p,t).
"""
function (::AbstractSciMLOperator) end

# Utilities for update functions
DEFAULT_UPDATE_FUNC(A,u,p,t) = A
function preprocess_update_func(update_func, accepted_kwarg_fields)
update_func = (update_func === nothing) ? DEFAULT_UPDATE_FUNC : update_func
return FilterKwargs(update_func, accepted_kwarg_fields)
end
function update_func_isconstant(update_func)
if update_func isa FilterKwargs
return update_func.f == DEFAULT_UPDATE_FUNC
else
return update_func == DEFAULT_UPDATE_FUNC
end
end

update_coefficients!(L,u,p,t) = nothing
update_coefficients(L,u,p,t) = L
function update_coefficients!(L::AbstractSciMLOperator, u, p, t)
update_coefficients!(L,u,p,t; kwargs...) = nothing
update_coefficients(L,u,p,t; kwargs...) = L
function update_coefficients!(L::AbstractSciMLOperator, u, p, t; kwargs...)
for op in getops(L)
update_coefficients!(op, u, p, t)
update_coefficients!(op, u, p, t; kwargs...)
end
nothing
end

(L::AbstractSciMLOperator)(u, p, t) = (update_coefficients!(L, u, p, t); L * u)
(L::AbstractSciMLOperator)(du, u, p, t) = (update_coefficients!(L, u, p, t); mul!(du, L, u))
(L::AbstractSciMLOperator)(u, p, t; kwargs...) = (update_coefficients!(L, u, p, t; kwargs...); L * u)
(L::AbstractSciMLOperator)(du, u, p, t; kwargs...) = (update_coefficients!(L, u, p, t; kwargs...); mul!(du, L, u))

###
# caching interface
Expand Down
63 changes: 37 additions & 26 deletions src/matrix.jl
Original file line number Diff line number Diff line change
@@ -1,18 +1,20 @@
#
"""
MatrixOperator(A[; update_func])
MatrixOperator(A; update_func=nothing, accepted_kwarg_fields=())

Represents a time-dependent linear operator given by an AbstractMatrix. The
update function is called by `update_coefficients!` and is assumed to have
the following signature:

update_func(A::AbstractMatrix,u,p,t) -> [modifies A]
update_func(A::AbstractMatrix,u,p,t; <accepted kwarg fields>) -> [modifies A]
"""
struct MatrixOperator{T,AType<:AbstractMatrix{T},F} <: AbstractSciMLOperator{T}
A::AType
update_func::F
MatrixOperator(A::AType; update_func=DEFAULT_UPDATE_FUNC) where{AType} =
new{eltype(A),AType,typeof(update_func)}(A, update_func)
function MatrixOperator(A::AType; update_func=nothing, accepted_kwarg_fields=()) where {AType}
_update_func = preprocess_update_func(update_func, accepted_kwarg_fields)
new{eltype(A),AType,typeof(_update_func)}(A, _update_func)
end
end

# constructors
Expand All @@ -39,21 +41,21 @@ for op in (
if isconstant(L)
MatrixOperator($op(L.A))
else
update_func = (A,u,p,t) -> $op(L.update_func($op(L.A),u,p,t))
update_func = (A,u,p,t; kwargs...) -> $op(L.update_func($op(L.A),u,p,t; kwargs...))
MatrixOperator($op(L.A); update_func = update_func)
end
end
end
Base.conj(L::MatrixOperator) = MatrixOperator(
conj(L.A);
update_func= (A,u,b,t) -> conj(L.update_func(conj(L.A),u,p,t))
update_func= (A,u,p,t; kwargs...) -> conj(L.update_func(conj(L.A),u,p,t; kwargs...))
)

has_adjoint(A::MatrixOperator) = has_adjoint(A.A)
update_coefficients!(L::MatrixOperator,u,p,t) = (L.update_func(L.A,u,p,t); nothing)
update_coefficients!(L::MatrixOperator,u,p,t; kwargs...) = (L.update_func(L.A,u,p,t; kwargs...); nothing)

getops(L::MatrixOperator) = (L.A)
isconstant(L::MatrixOperator) = L.update_func == DEFAULT_UPDATE_FUNC
isconstant(L::MatrixOperator) = update_func_isconstant(L.update_func)
Base.iszero(L::MatrixOperator) = iszero(L.A)

SparseArrays.sparse(L::MatrixOperator) = sparse(L.A)
Expand Down Expand Up @@ -88,13 +90,13 @@ LinearAlgebra.ldiv!(v::AbstractVecOrMat, L::MatrixOperator, u::AbstractVecOrMat)
LinearAlgebra.ldiv!(L::MatrixOperator, u::AbstractVecOrMat) = ldiv!(L.A, u)

"""
DiagonalOperator(diag, [; update_func])
DiagonalOperator(diag; update_func=nothing, accepted_kwarg_fields=())

Represents a time-dependent elementwise scaling (diagonal-scaling) operation.
The update function is called by `update_coefficients!` and is assumed to have
the following signature:

update_func(diag::AbstractVector,u,p,t) -> [modifies diag]
update_func(diag::AbstractVector,u,p,t; <accepted kwarg fields>) -> [modifies diag]

When `diag` is an `AbstractVector` of length N, `L=DiagonalOpeator(diag, ...)`
can be applied to `AbstractArray`s with `size(u, 1) == N`. Each column of the `u`
Expand All @@ -105,11 +107,12 @@ an operator of size `(N, N)` where `N = size(diag, 1)` is the leading length of
`L` then is the elementwise-scaling operation on arrays of `length(u) = length(diag)`
with leading length `size(u, 1) = N`.
"""
function DiagonalOperator(diag::AbstractVector; update_func = DEFAULT_UPDATE_FUNC)
diag_update_func = if update_func == DEFAULT_UPDATE_FUNC
DEFAULT_UPDATE_FUNC
function DiagonalOperator(diag::AbstractVector; update_func=nothing, accepted_kwarg_fields=())
_update_func = preprocess_update_func(update_func, accepted_kwarg_fields)
diag_update_func = if update_func_isconstant(_update_func)
_update_func
else
(A, u, p, t) -> (update_func(A.diag, u, p, t); A)
(A, u, p, t; kwargs...) -> (_update_func(A.diag, u, p, t; kwargs...); A)
end
MatrixOperator(Diagonal(diag); update_func=diag_update_func)
end
Expand Down Expand Up @@ -202,13 +205,13 @@ LinearAlgebra.ldiv!(v::AbstractVecOrMat, L::InvertibleOperator, u::AbstractVecOr
LinearAlgebra.ldiv!(L::InvertibleOperator, u::AbstractVecOrMat) = ldiv!(L.F, u)

"""
L = AffineOperator(A, B, b[; update_func])
L = AffineOperator(A, B, b; update_func=nothing, accepted_kwarg_fields=())
L(u) = A*u + B*b

Represents a time-dependent affine operator. The update function is called
by `update_coefficients!` and is assumed to have the following signature:

update_func(b::AbstractArray,u,p,t) -> [modifies b]
update_func(b::AbstractArray,u,p,t; <accepted kwarg fields>) -> [modifies b]
"""
struct AffineOperator{T,AType,BType,bType,cType,F} <: AbstractSciMLOperator{T}
A::AType
Expand Down Expand Up @@ -236,44 +239,52 @@ end
function AffineOperator(A::Union{AbstractMatrix,AbstractSciMLOperator},
B::Union{AbstractMatrix,AbstractSciMLOperator},
b::AbstractArray;
update_func = DEFAULT_UPDATE_FUNC,
update_func=nothing,
accepted_kwarg_fields=()
)
@assert size(A, 1) == size(B, 1) "Dimension mismatch: A, B don't output vectors
of same size"

_update_func = preprocess_update_func(update_func, accepted_kwarg_fields)

A = A isa AbstractMatrix ? MatrixOperator(A) : A
B = B isa AbstractMatrix ? MatrixOperator(B) : B
cache = B * b

AffineOperator(A, B, b, cache, update_func)
AffineOperator(A, B, b, cache, _update_func)
end

"""
L = AddVector(b[; update_func])
L = AddVector(b; update_func=nothing, accepted_kwarg_fields=())
L(u) = u + b
"""
function AddVector(b::AbstractVecOrMat; update_func = DEFAULT_UPDATE_FUNC)
function AddVector(b::AbstractVecOrMat; update_func=nothing, accepted_kwarg_fields=())
_update_func = preprocess_update_func(update_func, accepted_kwarg_fields)

N = size(b, 1)
Id = IdentityOperator(N)

AffineOperator(Id, Id, b; update_func=update_func)
AffineOperator(Id, Id, b; update_func=_update_func)
end

"""
L = AddVector(B, b[; update_func])
L = AddVector(B, b; update_func=nothing, accepted_kwarg_fields=())
L(u) = u + B*b
"""
function AddVector(B, b::AbstractVecOrMat; update_func = DEFAULT_UPDATE_FUNC)
function AddVector(B, b::AbstractVecOrMat; update_func=nothing, accepted_kwarg_fields=())
_update_func = preprocess_update_func(update_func, accepted_kwarg_fields)

N = size(B, 1)
Id = IdentityOperator(N)

AffineOperator(Id, B, b; update_func=update_func)
AffineOperator(Id, B, b; update_func=_update_func)
end

getops(L::AffineOperator) = (L.A, L.B, L.b)

update_coefficients!(L::AffineOperator,u,p,t) = (L.update_func(L.b,u,p,t); nothing)
isconstant(L::AffineOperator) = (L.update_func == DEFAULT_UPDATE_FUNC) & all(isconstant, (L.A, L.B))
update_coefficients!(L::AffineOperator,u,p,t; kwargs...) = (L.update_func(L.b,u,p,t; kwargs...); nothing)
isconstant(L::AffineOperator) = update_func_isconstant(L.update_func) & all(isconstant, (L.A, L.B))

islinear(::AffineOperator) = false

Base.size(L::AffineOperator) = size(L.A)
Expand Down
16 changes: 9 additions & 7 deletions src/scalar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -90,22 +90,24 @@ end
Base.:+(α::AbstractSciMLScalarOperator) = α

"""
ScalarOperator(val[; update_func])
ScalarOperator(val; update_func=nothing, accepted_kwarg_fields=())

(α::ScalarOperator)(a::Number) = α * a

Represents a time-dependent scalar/scaling operator. The update function
is called by `update_coefficients!` and is assumed to have the following
signature:

update_func(oldval,u,p,t) -> newval
update_func(oldval,u,p,t; <accepted kwarg fields>) -> newval
"""
mutable struct ScalarOperator{T<:Number,F} <: AbstractSciMLScalarOperator{T}
val::T
update_func::F

ScalarOperator(val::T; update_func=DEFAULT_UPDATE_FUNC) where{T} =
new{T,typeof(update_func)}(val, update_func)
function ScalarOperator(val::T; update_func=nothing, accepted_kwarg_fields=()) where {T}
_update_func = preprocess_update_func(update_func, accepted_kwarg_fields)
new{T,typeof(_update_func)}(val, _update_func)
end
end

# constructors
Expand All @@ -118,7 +120,7 @@ ScalarOperator(λ::UniformScaling) = ScalarOperator(λ.λ)
# traits
function Base.conj(α::ScalarOperator) # TODO - test
val = conj(α.val)
update_func = (oldval,u,p,t) -> α.update_func(oldval |> conj,u,p,t) |> conj
update_func = (oldval,u,p,t; kwargs...) -> α.update_func(oldval |> conj,u,p,t; kwargs...) |> conj
ScalarOperator(val; update_func=update_func)
end

Expand All @@ -132,11 +134,11 @@ Base.abs(α::ScalarOperator) = abs(α.val)
Base.iszero(α::ScalarOperator) = iszero(α.val)

getops(α::ScalarOperator) = (α.val,)
isconstant(α::ScalarOperator) = α.update_func == DEFAULT_UPDATE_FUNC
isconstant(α::ScalarOperator) = update_func_isconstant(α.update_func)
has_ldiv(α::ScalarOperator) = !iszero(α.val)
has_ldiv!(α::ScalarOperator) = has_ldiv(α)

update_coefficients!(L::ScalarOperator,u,p,t) = (L.val = L.update_func(L.val,u,p,t); nothing)
update_coefficients!(L::ScalarOperator,u,p,t; kwargs...) = (L.val = L.update_func(L.val,u,p,t; kwargs...); nothing)

"""
Lazy addition of Scalar Operators
Expand Down
10 changes: 10 additions & 0 deletions src/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,14 @@ end
dims(A) = length(size(A))
dims(::AbstractArray{<:Any,N}) where{N} = N
dims(::AbstractSciMLOperator) = 2

# Keyword argument filtering
struct FilterKwargs{F,K}
f::F
accepted_kwarg_fields::K
end
function (f_filter::FilterKwargs)(args...; kwargs...)
filtered_kwargs = (kwarg => kwargs[kwarg] for kwarg in f_filter.accepted_kwarg_fields)
f_filter.f(args...; filtered_kwargs...)
end
#
11 changes: 11 additions & 0 deletions test/scalar.jl
Original file line number Diff line number Diff line change
Expand Up @@ -82,5 +82,16 @@ end
@test num(v,u,p,t) ≈ val * u

@test convert(Number, num) ≈ val

# Test scalar operator which expects keyword argument to update, modeled in the style of a DiffEq W-operator.
γ = ScalarOperator(0.0; update_func=(args...; dtgamma) -> dtgamma, accepted_kwarg_fields=(:dtgamma,))

dtgamma = rand()
@test γ(u,p,t; dtgamma) ≈ dtgamma * u
@test γ(v,u,p,t; dtgamma) ≈ dtgamma * u

γ_added = γ + α
@test γ_added(u,p,t; dtgamma) ≈ (dtgamma + p) * u
@test γ_added(v,u,p,t; dtgamma) ≈ (dtgamma + p) * u
end
#