Skip to content
This repository has been archived by the owner on Jul 19, 2023. It is now read-only.

Commit

Permalink
Merge pull request #128 from JuliaDiffEq/nnlib_convolutions
Browse files Browse the repository at this point in the history
mul! implemented for 2d and 3d multiplication with NNlib
  • Loading branch information
ChrisRackauckas authored Jul 2, 2019
2 parents 13be768 + ce1269e commit 7213994
Show file tree
Hide file tree
Showing 3 changed files with 68 additions and 1 deletion.
2 changes: 2 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
SuiteSparse = "4607b0f0-06f3-5cda-b6b1-a6196a1729e9"
BandedMatrices = "aae01518-5342-5314-be14-df237901396f"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"


[compat]
julia = "1"
Expand Down
2 changes: 1 addition & 1 deletion src/DiffEqOperators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import Base: +, -, *, /, \, size, getindex, setindex!, Matrix, convert
using DiffEqBase, StaticArrays, LinearAlgebra
import LinearAlgebra: mul!, ldiv!, lmul!, rmul!, axpy!, opnorm, factorize, I
import DiffEqBase: AbstractDiffEqLinearOperator, update_coefficients!, is_constant
using SparseArrays, ForwardDiff, BandedMatrices
using SparseArrays, ForwardDiff, BandedMatrices, NNlib

abstract type AbstractDerivativeOperator{T} <: AbstractDiffEqLinearOperator{T} end
abstract type AbstractDiffEqCompositeOperator{T} <: AbstractDiffEqLinearOperator{T} end
Expand Down
65 changes: 65 additions & 0 deletions src/derivative_operators/derivative_operator_functions.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,11 @@
function LinearAlgebra.mul!(x_temp::AbstractArray{T}, A::DerivativeOperator{T,N}, M::AbstractArray{T}) where {T<:Real,N}

# Check that x_temp has correct dimensions
v = zeros(ndims(x_temp))
v[N] = 2
@assert [size(x_temp)...]+v == [size(M)...]

# Check that axis of differentiation is in the dimensions of M and x_temp
ndimsM = ndims(M)
@assert N <= ndimsM

Expand All @@ -23,6 +25,69 @@ function LinearAlgebra.mul!(x_temp::AbstractArray{T}, A::DerivativeOperator{T,N}
end
end

# Specialized `mul!` methods for 2-d and 3-d arrays when the operator's
# stencil coefficients are stored as a single `SVector` (uniform stencil).
# The interior of the derivative is computed as one convolution via
# NNlib's `conv!`; boundary points along the differentiation axis `N` are
# then filled in slice-by-slice with the existing `convolve_BC_left!` /
# `convolve_BC_right!` kernels.
for MT in [2,3]
    @eval begin
        function LinearAlgebra.mul!(x_temp::AbstractArray{T,$MT}, A::DerivativeOperator{T,N,Wind,T2,S1}, M::AbstractArray{T,$MT}) where {T<:Real,N,Wind,T2,SL,S1<:SArray{Tuple{SL},T,1,SL}}

            # Check that x_temp has correct dimensions: M must be exactly 2
            # points larger than x_temp along the differentiation axis N.
            # (Int vector, not Float64, so the size comparison stays integer.)
            v = zeros(Int, ndims(x_temp))
            v[N] = 2
            @assert [size(x_temp)...]+v == [size(M)...]

            # Check that axis of differentiation is in the dimensions of M and x_temp
            ndimsM = ndims(M)
            @assert N <= ndimsM

            # Reshape x_temp for NNlib.conv!: conv! fills only the interior
            # (non-boundary) points, so take a view that skips the
            # `boundary_point_count` rows at each end of axis N, and append
            # the singleton channel/batch dimensions NNlib expects.
            new_size = Any[size(x_temp)...]
            bpc = A.boundary_point_count
            setindex!(new_size, new_size[N]- 2*bpc, N)
            new_shape = []
            for i in 1:ndimsM
                if i != N
                    push!(new_shape,:)
                else
                    push!(new_shape,bpc+1:new_size[N]+bpc)
                end
            end
            _x_temp = reshape(view(x_temp, new_shape...), (new_size...,1,1))

            # Reshape M for NNlib.conv! (append channel and batch dims)
            _M = reshape(M, (size(M)...,1,1))
            s = A.stencil_coefs
            sl = A.stencil_length

            # Setup W, the kernel for NNlib.conv!: size 1 in every dimension
            # except the differentiation axis, which holds the stencil.
            # Allocate with eltype T so conv! sees homogeneous element types
            # (a Float64 kernel would mismatch for T != Float64).
            Wdims = ones(Int64, ndims(_x_temp))
            Wdims[N] = sl
            W = zeros(T, Wdims...)
            Widx = Any[Wdims...]
            setindex!(Widx,:,N)
            W[Widx...] = s ./ A.dx^A.derivative_order # this will change later
            cv = DenseConvDims(_M, W)

            conv!(_x_temp, _M, W, cv)

            # Now deal with boundaries: iterate over every 1-d slice of M
            # along axis N and apply the boundary stencils at both ends.
            dimsM = [axes(M)...]
            alldims = [1:ndims(M);]
            otherdims = setdiff(alldims, N)

            idx = Any[first(ind) for ind in axes(M)]
            itershape = tuple(dimsM[otherdims]...)
            nidx = length(otherdims)
            indices = Iterators.drop(CartesianIndices(itershape), 0)

            setindex!(idx, :, N)
            for I in indices
                # Overwrite the non-N positions of idx with the current
                # Cartesian index, keeping `:` along the differentiation axis.
                Base.replace_tuples!(nidx, idx, idx, otherdims, I)
                convolve_BC_left!(view(x_temp, idx...), view(M, idx...), A)
                convolve_BC_right!(view(x_temp, idx...), view(M, idx...), A)
            end
        end
    end
end

function *(A::DerivativeOperator{T,N},M::AbstractArray{T}) where {T<:Real,N}
size_x_temp = [size(M)...]
size_x_temp[N] -= 2
Expand Down

0 comments on commit 7213994

Please sign in to comment.