From 1241bba9aae42e583d5cef887ffae768ab542464 Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Thu, 7 Nov 2024 20:10:53 -0500
Subject: [PATCH 1/4] add allowslow

---
 src/NNlib.jl              |  9 +++++++++
 src/batched/batchedmul.jl |  5 ++++-
 src/conv.jl               |  4 ++++
 test/batchedmul.jl        | 11 +++++++++++
 4 files changed, 28 insertions(+), 1 deletion(-)

diff --git a/src/NNlib.jl b/src/NNlib.jl
index 687206fc..87a76197 100644
--- a/src/NNlib.jl
+++ b/src/NNlib.jl
@@ -18,6 +18,15 @@ using Statistics: mean
 
 const Numeric = Union{AbstractArray{<:T}, T} where {T<:Number}
 
+"""
+    allowslow(::Bool)
+
+By default, NNlib will print warnings the first time various slow fallback paths are taken.
+Calling `allowslow(false)` will instead make these into errors.
+"""
+allowslow(flag::Bool) = (SLOWERROR[] = !flag; nothing)
+const SLOWERROR = Ref(true)
+
 # Include APIs
 include("dim_helpers.jl")
 export ConvDims, DenseConvDims, PoolDims, DepthwiseConvDims
diff --git a/src/batched/batchedmul.jl b/src/batched/batchedmul.jl
index ccd9b0e8..c71ed8c3 100644
--- a/src/batched/batchedmul.jl
+++ b/src/batched/batchedmul.jl
@@ -274,7 +274,10 @@ for (TA, fA) in _BATCHED_LIST, (TB, fB) in _BATCHED_LIST
         size(A, 3) == size(C, 3) || size(A, 3) == 1 || throw(DimensionMismatch("batch size mismatch: A != C"))
         size(B, 3) == size(C, 3) || size(B, 3) == 1 || throw(DimensionMismatch("batch size mismatch: B != C"))
 
-        @debug "calling fallback method for batched_mul!" typeof(A) size(A) typeof(B) size(B) typeof(C)
+        @warn "calling fallback method for batched_mul!" typeof(A) size(A) typeof(B) size(B) typeof(C) maxlog=1
+        if SLOWERROR[]
+            error("calling fallback method for batched_mul!")
+        end
 
         Abase, Bbase = _unbatch(A), _unbatch(B)
         sA, oA = size(A,3) == 1 ? (0,1) : (1,0)
diff --git a/src/conv.jl b/src/conv.jl
index fead2ee2..4d67f4a4 100644
--- a/src/conv.jl
+++ b/src/conv.jl
@@ -191,6 +191,7 @@ for (front_name, backend, signature) in (
         if $(string(backend)) == "direct" && yT == Float64  # warn for Float32 + accidental Float64, but don't print warning for ForwardDiff.Dual
             @warn string("Slow fallback implementation invoked for ", $(string(front_name)), "! ",
                 "You probably don't want this; check your datatypes.") yT T1 T2 maxlog=1
+            SLOWERROR[] && error(string("calling slow fallback method for ", $(string(front_name))))
         end
 
         x_cs = Iterators.partition(1:size(in1, 4),
@@ -232,6 +233,7 @@ for (front_name, backend, signature) in (
         if $(string(backend)) == "direct" && yT == Float64  # warn for Float32 + accidental Float64, but don't print warning for ForwardDiff.Dual
             @warn string("Slow fallback implementation invoked for ", $(string(front_name)), "! ",
                 "You probably don't want this; check your datatypes.") yT T1 T2 maxlog=1
+            SLOWERROR[] && error(string("calling slow fallback method for ", $(string(front_name))))
         end
 
 
@@ -275,6 +277,7 @@ for (front_name, backend, signature) in (
         if $(string(backend)) == "direct" && yT == Float64  # warn for Float32 + accidental Float64, but don't print warning for ForwardDiff.Dual
             @warn string("Slow fallback implementation invoked for ", $(string(front_name)), "! ",
                 "You probably don't want this; check your datatypes.") yT T1 T2 maxlog=1
+            SLOWERROR[] && error(string("calling slow fallback method for ", $(string(front_name))))
         end
 
         dw_cs = Iterators.partition(1:size(out, 5),
@@ -326,6 +329,7 @@ for (front_name, backend, signature) in (
         if $(string(backend)) == "direct" && yT == Float64  # warn for Float32 + accidental Float64, but don't print warning for ForwardDiff.Dual
             @warn string("Slow fallback implementation invoked for ", $(string(front_name)), "! ",
                 "You probably don't want this; check your datatypes.") yT T1 T2 maxlog=1
+            SLOWERROR[] && error(string("calling slow fallback method for ", $(string(front_name))))
         end
         $(Symbol("$(front_name)_$(backend)!"))(out, in1, in2, cdims; kwargs...)
     end
diff --git a/test/batchedmul.jl b/test/batchedmul.jl
index 1b8b08e1..b249112c 100644
--- a/test/batchedmul.jl
+++ b/test/batchedmul.jl
@@ -303,3 +303,14 @@ FiniteDifferences.to_vec(x::BatchedTranspose) = FiniteDifferences.to_vec(collect
         gradtest(batched_vec, randn(rng, M, P, B), randn(rng, P))
     end
 end
+
+@testset "warning / error" begin
+    prev = NNlib.SLOWERROR[]
+    NNlib.allowslow(true)
+    A = rand(1:99, 3,4,7)
+    B = rand(1:99, 4,5,7)
+    @test batched_mul(A, B) isa Array  # no error!
+    NNlib.allowslow(false)
+    @test_throws Exception batched_mul(A, B)
+    NNlib.SLOWERROR[] = prev
+end

From b29d00d65617032bbc7bd6239ab90a0fc5690a4d Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Fri, 8 Nov 2024 16:46:52 -0500
Subject: [PATCH 2/4] wrong default!

---
 src/NNlib.jl | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/NNlib.jl b/src/NNlib.jl
index 87a76197..cb82b751 100644
--- a/src/NNlib.jl
+++ b/src/NNlib.jl
@@ -25,7 +25,7 @@ By default, NNlib will print warnings the first time various slow fallback paths
 Calling `allowslow(false)` will instead make these into errors.
 """
 allowslow(flag::Bool) = (SLOWERROR[] = !flag; nothing)
-const SLOWERROR = Ref(true)
+const SLOWERROR = Ref(false)
 
 # Include APIs
 include("dim_helpers.jl")

From d4c4b61b3a5f7033ae7672f6fe2a876b84fec2dc Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Fri, 8 Nov 2024 16:48:47 -0500
Subject: [PATCH 3/4] add to docs

---
 docs/src/reference.md | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/docs/src/reference.md b/docs/src/reference.md
index 5edde719..8ed28b8a 100644
--- a/docs/src/reference.md
+++ b/docs/src/reference.md
@@ -164,3 +164,10 @@ NNlib.glu
 NNlib.within_gradient
 bias_act!
 ```
+
+Finally, this switch changes warnings on various fallback paths into errors.
+It's a bit like `CUDA.allowscalar(false)`.
+
+```@docs
+allowslow
+```

From 5bfbda76965c60f8ee3e3b199f942c2980d0bc9c Mon Sep 17 00:00:00 2001
From: Michael Abbott <32575566+mcabbott@users.noreply.github.com>
Date: Fri, 8 Nov 2024 19:39:47 -0500
Subject: [PATCH 4/4] fix doctest

---
 src/fold.jl | 34 +++++++++++++++++-----------------
 1 file changed, 17 insertions(+), 17 deletions(-)

diff --git a/src/fold.jl b/src/fold.jl
index f3c205e1..6594aeb0 100644
--- a/src/fold.jl
+++ b/src/fold.jl
@@ -16,35 +16,35 @@ and a potential inverse of `unfold`.
 The below example demonstrates that `unfold` uses the same sliding windows as `conv`.
 In general [`batched_mul`](@ref) + `unfold` should not be used to achieve convolution.
 ```jldoctest
-julia> x = reshape([100 2 3 40 5 6 700], 7, 1, 1); # 1D data, 1 channel, batch of 1
+julia> x = reshape(Float32[100 2 3 40 5 6 700], 7, 1, 1); # 1D data, 1 channel, batch of 1
 
-julia> w = reshape([1 0 -1], 3, 1, 1); # 1D conv kernel of length 3
+julia> w = reshape(Float32[1 0 -1], 3, 1, 1); # 1D conv kernel of length 3
 
 julia> kws = (pad=1, stride=2, flipped=true); # use same args for conv and unfold
 
 julia> z = NNlib.unfold(x, size(w); kws...)
-4×3×1 Array{Int64, 3}:
+4×3×1 Array{Float32, 3}:
 [:, :, 1] =
-  0  100   2
-  2    3  40
- 40    5   6
-  6  700   0
+   0.0  100.0   2.0
+   2.0    3.0  40.0
+  40.0    5.0   6.0
+   6.0  700.0   0.0
 
 julia> y1 = conv(x, w; kws...)
-4×1×1 Array{Int64, 3}:
+4×1×1 Array{Float32, 3}:
 [:, :, 1] =
-  -2
- -38
-  34
-   6
+  -2.0
+ -38.0
+  34.0
+   6.0
 
 julia> y2 = z ⊠ w # ⊠ (\\boxtimes) is NNlib.batched_mul
-4×1×1 Array{Int64, 3}:
+4×1×1 Array{Float32, 3}:
 [:, :, 1] =
-  -2
- -38
-  34
-   6
+  -2.0
+ -38.0
+  34.0
+   6.0
 ```
 """
 function unfold(x::AbstractArray{T, N}, kernel_size::NTuple{K}; stride = 1, pad = 0, dilation = 1, flipped = true) where {T, K, N}
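
For reviewers trying out the branch, a minimal sketch of the behaviour the four patches combine to give. It assumes all four are applied in order (so the default is the permissive `Ref(false)` from patch 2) and that `allowslow` stays unexported, matching the qualified `NNlib.allowslow` calls in the testset above.

```julia
using NNlib

# Integer arrays have no BLAS kernel, so batched_mul takes the slow fallback path.
A = rand(1:99, 3, 4, 7)
B = rand(1:99, 4, 5, 7)

batched_mul(A, B)        # works, logging the fallback warning once (maxlog=1)

NNlib.allowslow(false)   # sets SLOWERROR[] = true: fallback paths now throw
try
    batched_mul(A, B)
catch err
    println(err)         # ErrorException("calling fallback method for batched_mul!")
end

NNlib.allowslow(true)    # back to the warn-only default
```

Keeping the flag in a `const SLOWERROR = Ref(false)` rather than a non-`const` global means each `SLOWERROR[]` check is type-stable, so reading it costs almost nothing on the paths it guards.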