diff --git a/README.md b/README.md
index 4ab14a2..df32dbc 100644
--- a/README.md
+++ b/README.md
@@ -102,6 +102,7 @@
 Sigmoid
 Softmax
 Squeeze
 Tanh
+Unsqueeze
 ```
 
 ## Adding Operations
diff --git a/src/deserialize/constraints.jl b/src/deserialize/constraints.jl
index d821887..382a211 100644
--- a/src/deserialize/constraints.jl
+++ b/src/deserialize/constraints.jl
@@ -76,6 +76,13 @@ NaiveNASflux.layer(r::Reshape) = r
 NaiveNASflux.actdim(r::Reshape) = r.adim
 NaiveNASflux.actrank(r::Reshape) = length(r.dims)
 
+function Base.show(io::IO, s::Reshape)
+    print(io, "Reshape(dims=")
+    ioc = IOContext(io, :prefix => "[", :suffix=>"]")
+    show(ioc, s.dims)
+    print(io, ")")
+end
+
 # If Reshape is wrapped in an AbstractMutableComp we will hit this method instead due to how NaiveNASflux unwraps things
 NaiveNASlib.compconstraint!(case, s::NaiveNASlib.AbstractJuMPΔSizeStrategy, ::Type{<:Reshape}, data) = NaiveNASlib.compconstraint!(case, s, layer(data.vertex), data)
@@ -179,6 +186,13 @@ NaiveNASflux.actrank(f::Flatten) = 1
 
 calc_outsize(::Flatten, v) = 0
 
+function Base.show(io::IO, s::Flatten)
+    print(io, "Flatten(dim=")
+    ioc = IOContext(io, :prefix => "[", :suffix=>"]")
+    show(ioc, s.dim)
+    print(io, ")")
+end
+
 # If Flatten is wrapped in an AbstractMutableComp we will hit this method instead due to how NaiveNASflux unwraps things
 NaiveNASlib.compconstraint!(case, s::NaiveNASlib.AbstractJuMPΔSizeStrategy, ::Type{<:Flatten}, data) = NaiveNASlib.compconstraint!(case, s, layer(data.vertex), data)
diff --git a/src/deserialize/ops.jl b/src/deserialize/ops.jl
index fb8e292..95bdfdc 100644
--- a/src/deserialize/ops.jl
+++ b/src/deserialize/ops.jl
@@ -208,32 +208,114 @@ fluxlayertypes[:AveragePool] = (pars...) -> FluxPoolLayer()
 fluxlayers[:Dropout] = params -> Dropout(get(params, :ratio, 0.5))
 fluxlayertypes[:Dropout] = (pars...) -> FluxDropOut()
 
-
 invariantops[:GlobalAveragePool] = function(params)
     wrap = get(params, :wrap, identity)
-    return x -> globalmeanpool(x, wrap)
+    return wrap ∘ GlobalMeanPool()
 end
 fluxlayertypes[:GlobalAveragePool] = (pars...) -> FluxPoolLayer()
 
-function globalmeanpool(x::AbstractArray{T,N}, wrap) where T where N
-    wrap(MeanPool(size(x)[1:N-2])(x))
-end
-
 invariantops[:GlobalMaxPool] = function(params)
     wrap = get(params, :wrap, identity)
-    return x -> globalmaxpool(x, wrap)
+    return wrap ∘ GlobalMaxPool()
 end
 fluxlayertypes[:GlobalMaxPool] = (pars...) -> FluxPoolLayer()
 
-function globalmaxpool(x::AbstractArray{T,N}, wrap) where T where N
-    wrap(MaxPool(size(x)[1:N-2])(x))
+"""
+    Squeeze(dims)
+
+Callable struct which performs `dropdims` on the input using the provided `dims`, where `dims` is compliant with the ONNX OP Squeeze (meaning it can be missing or use numpy indexing).
+
+Mainly exists for pretty printing reasons, as its task could also be performed by partially applied functions.
+
+Designed to be used only when deserializing the `Squeeze` operation.
+"""
+struct Squeeze{D}
+    dims::D
+end
+(s::Squeeze)(x) = dropdims(x; dims=s.dims)
+(s::Squeeze{Missing})(x) = dropdims(x; dims=Tuple(findall(i -> i == 1, size(x))))
+(s::Squeeze{<:NumPyAxes})(x) = dropdims(x; dims=Tuple(numpy2fluxdim(s.dims, ndims(x))))
+
+Base.show(io::IO, ::Squeeze{Missing}) = print(io, "Squeeze")
+function Base.show(io::IO, s::Squeeze)
+    print(io, "Squeeze(dims=")
+    ioc = IOContext(io, :prefix => "[", :suffix=>"]")
+    show(ioc, s.dims)
+    print(io, ")")
 end
 
+
 invariantops[:Squeeze] = function(params)
     np_axes = get(params, :axes, missing)
-    dimfun = ismissing(np_axes) ? x -> Tuple(findall(i -> i == 1, size(x))) : x -> Tuple(numpy2fluxdim.(np_axes, ndims(x)))
-    return x -> dropdims(x, dims=dimfun(x))
+    dims = if !ismissing(np_axes)
+        NumPyAxes(Tuple(np_axes))
+    else
+        np_axes
+    end
+    return Squeeze(dims)
+end
+
+"""
+    Unsqueeze(dims)
+
+Callable struct which performs `reshape` on the input using the provided `dims`, where `dims` is compliant with the ONNX OP `Unsqueeze` (meaning it can use numpy indexing).
+
+Mainly exists for pretty printing reasons, as its task could also be performed by partially applied functions.
+
+Designed to be used only when deserializing the `Unsqueeze` operation.
+"""
+struct Unsqueeze{D}
+    dims::D
+end
+
+(u::Unsqueeze)(x) = unsqueeze_onnx(x, u.dims)
+
+function Base.show(io::IO, s::Unsqueeze)
+    print(io, "Unsqueeze(dims=")
+    ioc = IOContext(io, :prefix => "[", :suffix=>"]")
+    show(ioc, s.dims)
+    print(io, ")")
+end
+
+invariantops[:Unsqueeze] = function(params)
+    haskey(params, :axes) || throw(ArgumentError("Must supply axes for Unsqueeze!"))
+    return Unsqueeze(NumPyAxes(params[:axes]))
+end
+
+unsqueeze_onnx(x, np_axes) = reshape(x, insdims(size(x), np_axes))
+
+struct Sorted{T}
+    vals::T
+    function Sorted(x)
+        vals = issorted(x) ? x : sort(x)
+        new{typeof(vals)}(vals)
+    end
 end
+Base.getindex(s::Sorted, args...) = Base.getindex(s.vals, args...)
+Base.length(s::Sorted) = length(s.vals)
+
+# Probably premature optimization: allow users to avoid numpy2fluxdim and sorting if they really want to.
+function insdims(orgsize, np_axes::NumPyAxes; ndimsout=length(orgsize) + length(np_axes), kwargs...)
+    insdims(orgsize, numpy2fluxdim(np_axes, ndimsout); ndimsout, kwargs...)
+end
+
+insdims(orgsize, dimstoadd; kwargs...) = insdims(orgsize, Sorted(dimstoadd); kwargs...)
+insdims(orgsize, dims::Sorted; ndimsout=length(orgsize) + length(dims), inssize=Returns(1)) = let
+    currax = Ref(1)
+    dimoffs = Ref(0)
+    ntuple(ndimsout) do i
+        if currax[] <= length(dims) && dims[currax[]] == i
+            ins = inssize(currax[])
+            currax[] += 1
+            dimoffs[] += 1
+            ins
+        else
+            orgsize[i - dimoffs[]]
+        end
+    end
+end
+
 invariantops[:ReduceMean] = function(params)
     np_axes = get(params, :axes, missing)
diff --git a/src/serialize/serialize.jl b/src/serialize/serialize.jl
index cd42f48..24c57b2 100644
--- a/src/serialize/serialize.jl
+++ b/src/serialize/serialize.jl
@@ -401,6 +401,7 @@ Flux.elu(pp::AbstractProbe, α=1f0) = attribfun(identity, "Elu", pp; attributes
 Flux.selu(pp::AbstractProbe) = attribfun(identity, "Selu", pp)
 Flux.selu(pp::AbstractProbe, γ, α) = attribfun(identity, "Selu", pp; attributes = ONNX.AttributeProto.(["gamma", "alpha"], [γ, α]))
 Flux.σ(pp::AbstractProbe) = attribfun(identity, "Sigmoid", pp)
+Flux.sigmoid_fast(pp::AbstractProbe) = attribfun(identity, "Sigmoid", pp) # Flux-specific construct
 Base.tanh(pp::AbstractProbe) = attribfun(identity, "Tanh", pp)
 Flux.softmax(pp::AbstractProbe; dims=1) = onnxsoftmax(pp, np_axis = flux2numpydim(dims[end], ndims(pp)))
@@ -410,6 +411,8 @@ onnxsoftmax(pp::AbstractProbe; np_axis=1) = attribfun(identity, "Softmax", pp;
 (l::Flux.MeanPool)(pp::AbstractProbe) = attribfun(s -> outshape(l, s), "AveragePool", pp; attributes = attribs(l))
 (l::Flux.Dropout)(pp::AbstractProbe) = attribfun(identity, "Dropout", pp; attributes = [ONNX.AttributeProto("ratio", l.p)])
+(l::Flux.GlobalMaxPool)(pp::AbstractProbe) = globalmaxpool(pp, identity)
+(l::Flux.GlobalMeanPool)(pp::AbstractProbe) = globalmeanpool(pp, identity)
 
 globalmeanpool(pp::AbstractProbe, wrap) = globalpool(pp, wrap, "GlobalAveragePool")
 globalmaxpool(pp::AbstractProbe, wrap) = globalpool(pp, wrap, "GlobalMaxPool")
@@ -498,8 +501,6 @@ end
 
 function axisfun(fshape, optype, pps::AbstractProbe...; dims, axname="axes")
-    fname = recursename(lowercase(optype), nextname(pps[1]))
-
     attrib = if isempty(dims)
         ONNX.AttributeProto[]
     else
@@ -508,6 +509,11 @@ function axisfun(fshape, optype, pps::AbstractProbe...; dims, axname="axes")
         np_axis = flux2numpydim.(dims, ndims(pok[1]))
         [ONNX.AttributeProto(axname, np_axis)]
     end
+    axisfun(fshape, optype, attrib, pps...)
+end
+
+function axisfun(fshape, optype, attrib::AbstractArray{<:ONNX.AttributeProto}, pps::AbstractProbe...)
+    fname = recursename(lowercase(optype), nextname(pps[1]))
 
     add!(pps[1], ONNX.NodeProto(
         input = collect(name.(pps)),
@@ -576,3 +582,7 @@ function flatten(pp::AbstractProbe, dim)
     end
     return newfrom(pp, fname, fshape)
 end
+
+Flux.unsqueeze(pp::AbstractProbe; dims) = axisfun(s -> insdims(s, dims), "Unsqueeze", pp; dims=scal2tup(dims))
+unsqueeze_onnx(pp::AbstractProbe, npa::NumPyAxes) = axisfun(s -> insdims(s, npa), "Unsqueeze", [ONNX.AttributeProto("axes", npa.axes)], pp)
+
diff --git a/src/shapes.jl b/src/shapes.jl
index f1ddb69..a527592 100644
--- a/src/shapes.jl
+++ b/src/shapes.jl
@@ -9,8 +9,44 @@ NaiveNASflux.actrank(::Shape1D) = 0
 # TODO: Move to NaiveNASflux
 NaiveNASflux.nin(sc::SkipConnection) = nin(sc.layers)
 
+"""
+    NumPyAxes(axes)
+
+Represents `axes` using numpy conventions: `0` means the last dimension, `1` the second to last, `-1` the first, and so on, due to the row-major vs column-major difference between numpy and Julia.
+
+Mainly intended to be used when deserializing ONNX OPs which operate along a provided set of dimensions, to mark that we do not yet know how to translate the numpy axes to Julia axes.
+
+Primarily exists for pretty printing reasons, so we can show which dimensions certain ops will operate along in a Julia-compliant manner (e.g. numpy index `1` is shown as `end-1`).
+
+### Examples
+```julia-repl
+julia> ONNXNaiveNASflux.NumPyAxes([0, 1, 2, -1, -2])
+NumPyAxes[end,end-1,end-2,1,2]
+```
+"""
+struct NumPyAxes{AX}
+    axes::AX
+end
+# This is cute, but generally not needed since all ops which make use of this need to supply some dims argument
+# which is pretty much always required to be a Tuple or an Int by some method quite high up in the call stack.
+#Base.to_index(x::AbstractArray, npa::NumPyAxes) = numpy2fluxdim.(collect(npa.axes), ndims(x))
+Base.length(npa::NumPyAxes) = length(npa.axes)
+
+function Base.show(io::IO, npa::NumPyAxes)
+    print(io, get(io, :prefix, "NumPyAxes["))
+    indstr = map(npa.axes) do ax
+        ax === 0 && return "end"
+        ax > 0 && return string("end-", ax)
+        -ax
+    end
+    print(io, join(indstr, ','))
+    print(io, get(io, :suffix, "]"))
+end
+
 numpy2fluxdim(np_axis, v::AbstractVertex) = numpy2fluxdim(np_axis, 1 + NaiveNASflux.actrank(v)[1])
-numpy2fluxdim(np_axis, ndims) = np_axis >= 0 ? ndims - np_axis : abs(np_axis)
+numpy2fluxdim(npa::NumPyAxes, ndims::Integer) = numpy2fluxdim.(npa.axes, ndims)
+numpy2fluxdim(np_axis, ndims) = np_axis >= 0 ? ndims - np_axis : -np_axis
 
 flux2numpydim(dim, ndims) = ndims - dim
diff --git a/test/deserialize/Artifacts.toml b/test/deserialize/Artifacts.toml
index 6995971..4a174a9 100644
--- a/test/deserialize/Artifacts.toml
+++ b/test/deserialize/Artifacts.toml
@@ -342,3 +342,27 @@ git-tree-sha1 = "458aa7b72548ef30e68311e2ba83fee5232fdcdd"
 
 [test_squeeze_negative_axes]
 git-tree-sha1 = "085e5374680b94e4ac1748a9a944ab67eb933ea5"
+
+[test_unsqueeze_axis_0]
+git-tree-sha1 = "e5b5cba31574fa5742cdbfbd76ef13f3ad49a91b"
+
+[test_unsqueeze_axis_1]
+git-tree-sha1 = "161893d3dda3c7ced6c7882e237b4488deaa7c41"
+
+[test_unsqueeze_axis_2]
+git-tree-sha1 = "38acbb013bfb691b4e75a79aeb49d7220666bd6a"
+
+[test_unsqueeze_axis_3]
+git-tree-sha1 = "db2c2acf07c3010fea726f104c4d5e9c922b92ac"
+
+[test_unsqueeze_negative_axes]
+git-tree-sha1 = "3b3de957f53bba7e4f14a8e7325e100d9cb66482"
+
+[test_unsqueeze_three_axes]
+git-tree-sha1 = "376f228da8a6d73d7d4f5248496b9074cd349a12"
+
+[test_unsqueeze_two_axes]
+git-tree-sha1 = "a607099a395a3a323ac3528b3a155034355d582f"
+
+[test_unsqueeze_unsorted_axes]
+git-tree-sha1 = "22deacbe190142e6abf4631b332e087acee00ac4"
diff --git a/test/deserialize/deserialize.jl b/test/deserialize/deserialize.jl
index 5f47b3d..1041ec1 100644
--- a/test/deserialize/deserialize.jl
+++ b/test/deserialize/deserialize.jl
@@ -216,6 +216,14 @@
         (name="test_softmax_negative_axis", ninputs=1, noutputs=1, fd=invariantops),
         (name="test_squeeze", ninputs=1, noutputs=1, fd=invariantops),
         (name="test_squeeze_negative_axes", ninputs=1, noutputs=1, fd=invariantops),
+        (name="test_unsqueeze_axis_0", ninputs=1, noutputs=1, fd=invariantops),
+        (name="test_unsqueeze_axis_1", ninputs=1, noutputs=1, fd=invariantops),
+        (name="test_unsqueeze_axis_2", ninputs=1, noutputs=1, fd=invariantops),
+        (name="test_unsqueeze_axis_3", ninputs=1, noutputs=1, fd=invariantops),
+        (name="test_unsqueeze_negative_axes", ninputs=1, noutputs=1, fd=invariantops),
+        (name="test_unsqueeze_three_axes", ninputs=1, noutputs=1, fd=invariantops),
+        (name="test_unsqueeze_two_axes", ninputs=1, noutputs=1, fd=invariantops),
+        (name="test_unsqueeze_unsorted_axes", ninputs=1, noutputs=1, fd=invariantops),
     )
 
     model, gb, inputs, outputs = prepare_node_test(tc.name, tc.ninputs, tc.noutputs)
@@ -417,4 +425,42 @@
         @test nout(g[end]) == 60
         @test length(defaultutility(g[end])) == exputilsize
     end
+
+    @testset "Squeeze" begin
+        import ONNXNaiveNASflux: Squeeze, NumPyAxes
+
+        # The first two cases are not hit by any other test
+        @test size(Squeeze((1,2))(reshape(ones(2,3), 1, 1, 2, 1, 3, 1))) == (2, 1, 3, 1)
+        @test size(Squeeze(missing)(reshape(ones(2,3), 1, 1, 2, 1, 3, 1))) == (2, 3)
+        @test size(Squeeze(NumPyAxes([-1, 0, 2]))(reshape(ones(2,3), 1, 1, 2, 1, 3, 1))) == (1, 2, 3)
+    end
+
+    @testset "Pretty printing" begin
+        import ONNXNaiveNASflux: NumPyAxes
+
+        @testset "NumPyAxes" begin
+            @test sprint(show, NumPyAxes([0, 1, 2, -1, 2])) == "NumPyAxes[end,end-1,end-2,1,end-2]"
+        end
+
+        @testset "Squeeze" begin
+            import ONNXNaiveNASflux: Squeeze
+            @test sprint(show, Squeeze(NumPyAxes([0, 1, -2]))) == "Squeeze(dims=[end,end-1,2])"
+            @test sprint(show, Squeeze(missing)) == "Squeeze"
+        end
+
+        @testset "Unsqueeze" begin
+            import ONNXNaiveNASflux: Unsqueeze
+            @test sprint(show, Unsqueeze(NumPyAxes([0, 1, -2]))) == "Unsqueeze(dims=[end,end-1,2])"
+        end
+
+        @testset "Reshape" begin
+            import ONNXNaiveNASflux: Reshape
+            @test sprint(show, Reshape((1, 2))) == "Reshape(dims=(1, 2))"
+        end
+
+        @testset "Flatten" begin
+            import ONNXNaiveNASflux: Flatten
+            @test sprint(show, Flatten(2)) == "Flatten(dim=2)"
+        end
+    end
 end
\ No newline at end of file
diff --git a/test/serialize/serialize.jl b/test/serialize/serialize.jl
index 00e7477..5a14307 100644
--- a/test/serialize/serialize.jl
+++ b/test/serialize/serialize.jl
@@ -26,6 +26,7 @@ using ONNXNaiveNASflux.NaiveNASflux
 import ONNXNaiveNASflux.NaiveNASflux: weights, bias
 import ONNXNaiveNASflux: AbstractProbe, nextname, newfrom, add!, genname, shape, nextshape
+import Flux: unsqueeze
 
 struct NodeProbe{F, S} <: AbstractProbe
     name::String
     namefun::F
@@ -80,9 +81,10 @@
 
     @testset "Dims method $(tc.ot)" for tc in (
-        (f=cat, dims=1, ndims=2, ot=:Concat, axname=:axis),
-        (f=mean, dims=(2, 3), ndims=4, ot=:ReduceMean, axname=:axes),
-        (f=dropdims, dims=(3,), ndims=3, ot=:Squeeze, axname=:axes)
+        (f=cat, dims=1, expdims=1, ndims=2, ot=:Concat, axname=:axis),
+        (f=mean, dims=(2, 3), expdims=[2, 3], ndims=4, ot=:ReduceMean, axname=:axes),
+        (f=dropdims, dims=(3,), expdims=[3], ndims=3, ot=:Squeeze, axname=:axes),
+        (f=unsqueeze, dims=3, expdims=[3], ndims=3, ot=:Unsqueeze, axname=:axes),
     )
 
         inprobe = NodeProbe("input", f -> "output", Tuple(1:tc.ndims))
@@ -96,8 +98,14 @@
         @test output(res) == [name(outprobe)]
         @test optype(res) == tc.ot
        @test name(res) == name(outprobe)
-        expdims = tc.dims isa Tuple ? collect(tc.dims) : tc.dims
-        @test ONNXNaiveNASflux.numpy2fluxdim.(res.attribute[tc.axname], tc.ndims) == expdims
+        @test ONNXNaiveNASflux.numpy2fluxdim.(res.attribute[tc.axname], tc.ndims) == tc.expdims
+
+        x = ones(Float32, ntuple(Returns(1), tc.ndims))
+        invertex = convinputvertex(name(inprobe), 1, tc.ndims-1)
+        @test ONNXNaiveNASflux.verts[tc.ot](name(res), [invertex], res.attribute)(x) == tc.f(x; dims=tc.dims)
+
+        ortout, = onnxruntime_infer(x -> tc.f(x; dims=tc.dims), x)
+        @test ortout == tc.f(x; dims=tc.dims)
     end
 
     @testset "Reshape" begin
@@ -360,7 +368,10 @@
     bnvertex(name, inpt::AbstractVertex, actfun=identity) = fluxvertex(name, BatchNorm(nout(inpt), actfun), inpt)
 
-    mpvertex(name, inpt::AbstractVertex) = fluxvertex(name, MaxPool((2,2); pad=(1,0), stride=(1,2)), inpt)
+    maxpvertex(name, inpt::AbstractVertex) = fluxvertex(name, MaxPool((2,2); pad=(1,0), stride=(1,2)), inpt)
+
+    # TODO: Make it configurable which OP types are merged into a single vertex...
+    gmpvertex(name, inpt::AbstractVertex) = invariantvertex(name, x -> dropdims(GlobalMeanPool()(x); dims=(1,2)), inpt)
 
     fvertex(name, inpt::AbstractVertex, f) = invariantvertex(name, f, inpt)
@@ -500,7 +511,7 @@
         v0 = conv2dinputvertex("input", 3)
         v1 = convvertex("conv1", v0, 4, relu)
         v2 = convvertex("conv2", v1, 5, elu)
-        v3 = fvertex("globmeanpool", v2, x -> ONNXNaiveNASflux.globalmeanpool(x, y -> dropdims(y, dims=(1,2))))
+        v3 = gmpvertex("globalmeanpool", v2)
         v4 = dense("output", v3, 2)
 
         test_named_graph(CompGraph(v0, v4), (2,3))
@@ -509,7 +520,7 @@
     @testset "Linear Conv graph with global pooling without names" begin
         v0 = conv2dinputvertex("input", 3)
         v1 = convvertex("", v0, 4, relu)
-        v2 = invariantvertex(x -> ONNXNaiveNASflux.globalmeanpool(x, y -> dropdims(y, dims=(1,2))), v1)
+        v2 = gmpvertex("", v1)
 
         g_org = CompGraph(v0, v2)
@@ -530,7 +541,7 @@
         v0 = conv2dinputvertex("input", 3)
         v1 = convvertex("conv", v0, 4, relu)
         v2 = bnvertex("batchnorm", v1, elu)
-        v3 = fvertex("globmeanpool", v2, x -> ONNXNaiveNASflux.globalmeanpool(x, y -> dropdims(y, dims=(1,2))))
+        v3 = gmpvertex("globalmeanpool", v2)
         v4 = dense("output", v3, 2, selu)
 
         test_named_graph(CompGraph(v0, v4), (4,6))
@@ -538,9 +549,9 @@
     @testset "Linear Conv and MaxPool graph with global pooling" begin
         v0 = conv2dinputvertex("input", 3)
-        v1 = mpvertex("maxpool", v0)
+        v1 = maxpvertex("maxpool", v0)
         v2 = convvertex("conv", v1, 4, relu)
-        v3 = fvertex("globmeanpool", v2, x -> ONNXNaiveNASflux.globalmeanpool(x, y -> dropdims(y, dims=(1,2))))
+        v3 = gmpvertex("globalmeanpool", v2)
         v4 = dense("output", v3, 2, selu)
 
         test_named_graph(CompGraph(v0, v4), (2,3))
@@ -685,7 +696,7 @@
         v1 = convvertex("conv", v0, 2, elu)
         v2 = bnvertex("batchnorm", v0)
         v3 = concat("conc", v1, v2)
-        v4 = fvertex("globmeanpool", v3, x -> ONNXNaiveNASflux.globalmeanpool(x, y -> dropdims(y, dims=(1,2))))
+        v4 = gmpvertex("globalmeanpool", v3)
         v5 = dense("output", v4, 2, relu)
 
         test_named_graph(CompGraph(v0, v5), (2,3))
@@ -828,7 +839,7 @@
         v0 = conv2dinputvertex("input", 3)
         v1 = convvertex("v1", v0, 2)
         v2 = concat("v2", v1, v0)
-        v3 = fvertex("v3", v2, x -> ONNXNaiveNASflux.globalmeanpool(x, y -> dropdims(y, dims=(1,2))))
+        v3 = gmpvertex("globalmeanpool", v2)
         v4 = dense("v4", v3, 4)
 
         g = remodel(CompGraph(v0, v4))
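
A note on the axis convention that underpins most of this diff: ONNX (numpy, row-major) axes are translated to Julia (column-major) dimensions by `numpy2fluxdim`, and `NumPyAxes` merely postpones that translation until the number of dimensions is known. A minimal sketch using the internal names from `src/shapes.jl` (expected values in comments):

```julia
using ONNXNaiveNASflux: NumPyAxes, numpy2fluxdim

# numpy axis 0 is the first dimension in row-major order, which is the
# last dimension of a column-major Julia array; negative numpy axes
# count from the other end, i.e. from the first Julia dimension.
numpy2fluxdim(0, 4)   # == 4
numpy2fluxdim(1, 4)   # == 3
numpy2fluxdim(-1, 4)  # == 1

# NumPyAxes holds the raw axes until ndims is known, then broadcasts
numpy2fluxdim(NumPyAxes([0, 1, -1]), 4)  # == [4, 3, 1]
```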
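For intuition, this is how the two new callable structs behave after deserialization. A sketch based on the definitions in `src/deserialize/ops.jl`; the sizes are worked out by hand rather than taken from the ONNX test artifacts:

```julia
using ONNXNaiveNASflux: Squeeze, Unsqueeze, NumPyAxes

x = reshape(collect(Float32, 1:6), 1, 2, 1, 3)
size(Squeeze(missing)(x))            # (2, 3): every singleton dim dropped
size(Squeeze(NumPyAxes([1, 3]))(x))  # (2, 3): numpy axes 1 and 3 map to Julia dims 3 and 1

y = ones(Float32, 2, 3)
size(Unsqueeze(NumPyAxes([0]))(y))      # (2, 3, 1): numpy axis 0 becomes a new last Julia dim
size(Unsqueeze(NumPyAxes([0, -1]))(y))  # (1, 2, 3, 1): axes may be given unsorted
```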
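`insdims` is the size-level workhorse behind `Unsqueeze`: given an input size and the positions of the new singleton dimensions, it returns the target shape for `reshape`, and the serialization path reuses it for shape inference. A couple of illustrative calls (again internal names; results derived from the definition above):

```julia
using ONNXNaiveNASflux: insdims, NumPyAxes

insdims((2, 3), (1, 4))          # == (1, 2, 3, 1): singletons inserted at positions 1 and 4
insdims((2, 3), [4, 1])          # == (1, 2, 3, 1): unsorted dims are sorted via the Sorted wrapper
insdims((2, 3), NumPyAxes([0]))  # == (2, 3, 1): numpy axes are translated first
```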
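Finally, a side note on the `GlobalAveragePool`/`GlobalMaxPool` change: the deserialized op is now a plain function composition (`wrap ∘ GlobalMeanPool()`) rather than a closure over the input size, which is also what makes it print nicely. A sketch of what the deserialized op computes; the `Dict` stands in for the parsed ONNX attributes, and the `:wrap` function shown here is a hypothetical example of the hook used when fusing a following squeeze:

```julia
import ONNXNaiveNASflux

# :wrap defaults to identity; this example drops the singleton spatial dims
op = ONNXNaiveNASflux.invariantops[:GlobalAveragePool](Dict(:wrap => x -> dropdims(x; dims=(1, 2))))
size(op(ones(Float32, 4, 4, 3, 2)))  # == (3, 2)
```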