diff --git a/src/chainrules.jl b/src/chainrules.jl index d6ae1c4..4f3ba62 100644 --- a/src/chainrules.jl +++ b/src/chainrules.jl @@ -27,7 +27,7 @@ function ChainRulesCore.rrule( ::typeof(*), bm::BlockDiagonal{T, V}, v::StridedVector{T} - ) where {T<:Union{Real, Complex}, V<:Matrix{T}} + ) where {T<:Union{Real, Complex}, V} y = bm * v diff --git a/test/chainrules.jl b/test/chainrules.jl index d4055ed..7f1217d 100644 --- a/test/chainrules.jl +++ b/test/chainrules.jl @@ -1,17 +1,19 @@ @testset "chainrules.jl" begin - @testset "BlockDiagonal" begin - x = [randn(1, 2), randn(2, 2)] - test_rrule(BlockDiagonal, x) - end + @testset for V in (Tuple, Vector) + @testset "BlockDiagonal" begin + x = V([randn(1, 2), randn(2, 2)]) + test_rrule(BlockDiagonal, x) + end - @testset "Matrix" begin - D = BlockDiagonal([randn(1, 2), randn(2, 2)]) - test_rrule(Matrix, D) - end + @testset "Matrix" begin + B = BlockDiagonal(V([randn(1, 2), randn(2, 2)])) + test_rrule(Matrix, B) + end - @testset "BlockDiagonal * Vector" begin - D = BlockDiagonal([rand(2, 3), rand(3, 3)]) - v = rand(6) - test_rrule(*, D, v) + @testset "BlockDiagonal * Vector" begin + B = BlockDiagonal(V([rand(2, 3), rand(3, 3)])) + v = rand(6) + test_rrule(*, B, v) + end end end