
Add support for GlobalMaxPool and GlobalMeanPool #88

Merged (10 commits) on Jul 14, 2024
1 change: 1 addition & 0 deletions README.md
@@ -102,6 +102,7 @@ Sigmoid
Softmax
Squeeze
Tanh
Unsqueeze
```

## Adding Operations
14 changes: 14 additions & 0 deletions src/deserialize/constraints.jl
@@ -76,6 +76,13 @@ NaiveNASflux.layer(r::Reshape) = r
NaiveNASflux.actdim(r::Reshape) = r.adim
NaiveNASflux.actrank(r::Reshape) = length(r.dims)

function Base.show(io::IO, s::Reshape)
    print(io, "Reshape(dims=")
    ioc = IOContext(io, :prefix => "[", :suffix => "]")
    show(ioc, s.dims)
    print(io, ")")
end

# If Reshape is wrapped in an AbstractMutableComp we will hit this method instead due to how NaiveNASflux unwraps things
NaiveNASlib.compconstraint!(case, s::NaiveNASlib.AbstractJuMPΔSizeStrategy, ::Type{<:Reshape}, data) = NaiveNASlib.compconstraint!(case, s, layer(data.vertex), data)

@@ -179,6 +186,13 @@ NaiveNASflux.actrank(f::Flatten) = 1

calc_outsize(::Flatten, v) = 0

function Base.show(io::IO, s::Flatten)
    print(io, "Flatten(dim=")
    ioc = IOContext(io, :prefix => "[", :suffix => "]")
    show(ioc, s.dim)
    print(io, ")")
end

# If Flatten is wrapped in an AbstractMutableComp we will hit this method instead due to how NaiveNASflux unwraps things
NaiveNASlib.compconstraint!(case, s::NaiveNASlib.AbstractJuMPΔSizeStrategy, ::Type{<:Flatten}, data) = NaiveNASlib.compconstraint!(case, s, layer(data.vertex), data)

104 changes: 93 additions & 11 deletions src/deserialize/ops.jl
@@ -208,32 +208,114 @@
fluxlayers[:Dropout] = params -> Dropout(get(params, :ratio, 0.5))
fluxlayertypes[:Dropout] = (pars...) -> FluxDropOut()


invariantops[:GlobalAveragePool] = function(params)
    wrap = get(params, :wrap, identity)
-   return x -> globalmeanpool(x, wrap)
+   return wrap ∘ GlobalMeanPool()
end
fluxlayertypes[:GlobalAveragePool] = (pars...) -> FluxPoolLayer()

-function globalmeanpool(x::AbstractArray{T,N}, wrap) where T where N
-    wrap(MeanPool(size(x)[1:N-2])(x))
-end

invariantops[:GlobalMaxPool] = function(params)
    wrap = get(params, :wrap, identity)
-   return x -> globalmaxpool(x, wrap)
+   return wrap ∘ GlobalMaxPool()
end
fluxlayertypes[:GlobalMaxPool] = (pars...) -> FluxPoolLayer()

-function globalmaxpool(x::AbstractArray{T,N}, wrap) where T where N
-    wrap(MaxPool(size(x)[1:N-2])(x))
-end
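For orientation, a minimal sketch of what the new deserialization returns (assuming Flux is loaded; `invariantops` is an internal table and the empty params `Dict` is just for illustration):

```julia-repl
julia> op = ONNXNaiveNASflux.invariantops[:GlobalAveragePool](Dict())  # wrap defaults to identity
identity ∘ GlobalMeanPool()

julia> size(op(rand(Float32, 5, 5, 3, 2)))  # W×H×C×N input
(1, 1, 3, 2)
```

Composing with Flux's own `GlobalMeanPool`/`GlobalMaxPool` layers replaces the old closures, which rebuilt a `MeanPool`/`MaxPool` from the input size on every call.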
"""
Squeeze(dims)

Callable struct which performs `dropdims` on its input using the provided `dims`, where `dims` is compliant with the ONNX `Squeeze` OP (meaning it can be `missing` or use numpy indexing).

Mainly exists for pretty printing reasons, as its task could otherwise be performed by a partially applied function.

Designed to only be used when deserializing the `Squeeze` operation.
"""
struct Squeeze{D}
    dims::D
end
(s::Squeeze)(x) = dropdims(x; dims=s.dims)
(s::Squeeze{Missing})(x) = dropdims(x; dims=Tuple(findall(i -> i == 1, size(x))))
(s::Squeeze{<:NumPyAxes})(x) = dropdims(x; dims=Tuple(numpy2fluxdim(s.dims, ndims(x))))

Base.show(io::IO, ::Squeeze{Missing}) = print(io, "Squeeze")
function Base.show(io::IO, s::Squeeze)
    print(io, "Squeeze(dims=")
    ioc = IOContext(io, :prefix => "[", :suffix => "]")
    show(ioc, s.dims)
    print(io, ")")
end
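As a side note on these `show` methods: the `:prefix`/`:suffix` keys in the `IOContext` let a nested `NumPyAxes` (defined in src/shapes.jl below) print as a plain bracketed list. A small sketch:

```julia-repl
julia> sprint(show, Squeeze(NumPyAxes([0, 1])))
"Squeeze(dims=[end,end-1])"

julia> sprint(show, NumPyAxes([0, 1]))  # without the prefix/suffix override
"NumPyAxes[end,end-1]"
```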


invariantops[:Squeeze] = function(params)
    np_axes = get(params, :axes, missing)
-   dimfun = ismissing(np_axes) ? x -> Tuple(findall(i -> i == 1, size(x))) : x -> Tuple(numpy2fluxdim.(np_axes, ndims(x)))
-   return x -> dropdims(x, dims=dimfun(x))
+   dims = if !ismissing(np_axes)
+       NumPyAxes(Tuple(np_axes))
+   else
+       np_axes
+   end
+   return Squeeze(dims)
end
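A hypothetical round through the deserializer (axes values chosen for illustration; `invariantops` is internal):

```julia-repl
julia> sq = ONNXNaiveNASflux.invariantops[:Squeeze](Dict(:axes => [0, -1]))
Squeeze(dims=[end,1])

julia> size(sq(ones(1, 2, 3, 1)))  # numpy axis 0 is the last Julia dim, -1 the first
(2, 3)
```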

"""
Unsqueeze(dims)

Callable struct which performs `reshape` on its input using the provided `dims`, where `dims` is compliant with the ONNX `Unsqueeze` OP (meaning it can use numpy indexing).

Mainly exists for pretty printing reasons, as its task could otherwise be performed by a partially applied function.

Designed to only be used when deserializing the `Unsqueeze` operation.
"""
struct Unsqueeze{D}
    dims::D
end

(u::Unsqueeze)(x) = unsqueeze_onnx(x, u.dims)

function Base.show(io::IO, s::Unsqueeze)
    print(io, "Unsqueeze(dims=")
    ioc = IOContext(io, :prefix => "[", :suffix => "]")
    show(ioc, s.dims)
    print(io, ")")
end

invariantops[:Unsqueeze] = function(params)
    haskey(params, :axes) || throw(ArgumentError("Must supply axes for Unsqueeze!"))
    return Unsqueeze(NumPyAxes(params[:axes]))
end

unsqueeze_onnx(x, np_axes) = reshape(x, insdims(size(x), np_axes))
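For illustration, a sketch of the callable in action (`insdims` is defined just below; values are made up):

```julia-repl
julia> u = Unsqueeze(NumPyAxes([0]))
Unsqueeze(dims=[end])

julia> size(u(ones(2, 3)))  # numpy axis 0 becomes a new trailing Julia dim
(2, 3, 1)
```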

struct Sorted{T}
    vals::T
    function Sorted(x)
        vals = issorted(x) ? x : sort(x)
        new{typeof(vals)}(vals)
    end
end
Base.getindex(s::Sorted, args...) = Base.getindex(s.vals, args...)
Base.length(s::Sorted) = length(s.vals)

# Probably premature optimization: allow users to avoid numpy2fluxdim and sorting if they really want to.

function insdims(orgsize, np_axes::NumPyAxes; ndimsout=length(orgsize) + length(np_axes), kwargs...)
    insdims(orgsize, numpy2fluxdim(np_axes, ndimsout); ndimsout, kwargs...)
end

insdims(orgsize, dimstoadd; kwargs...) = insdims(orgsize, Sorted(dimstoadd); kwargs...)
insdims(orgsize, dims::Sorted; ndimsout=length(orgsize) + length(dims), inssize=Returns(1)) = let
    currax = Ref(1)
    dimoffs = Ref(0)
    ntuple(ndimsout) do i
        if currax[] <= length(dims) && dims[currax[]] == i
            ins = inssize(currax[])
            currax[] += 1
            dimoffs[] += 1
            ins
        else
            orgsize[i - dimoffs[]]
        end
    end
end
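A worked sketch of `insdims` (sizes made up for the example; the `inssize` keyword lets a caller insert something other than singleton dimensions):

```julia-repl
julia> insdims((2, 3), [1, 4])  # insert singleton dims at Julia positions 1 and 4
(1, 2, 3, 1)

julia> insdims((2, 3), NumPyAxes([0]))  # numpy axis 0 maps to the last Julia dim
(2, 3, 1)

julia> insdims((2, 3), [2]; inssize=Returns(5))
(2, 5, 3)
```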


invariantops[:ReduceMean] = function(params)
    np_axes = get(params, :axes, missing)
14 changes: 12 additions & 2 deletions src/serialize/serialize.jl
@@ -401,6 +401,7 @@
Flux.selu(pp::AbstractProbe) = attribfun(identity, "Selu", pp)
Flux.selu(pp::AbstractProbe, γ, α) = attribfun(identity, "Selu", pp; attributes = ONNX.AttributeProto.(["gamma", "alpha"], [γ, α]))
Flux.σ(pp::AbstractProbe) = attribfun(identity, "Sigmoid", pp)
Flux.sigmoid_fast(pp::AbstractProbe) = attribfun(identity, "Sigmoid", pp) # Flux-specific construct

Base.tanh(pp::AbstractProbe) = attribfun(identity, "Tanh", pp)
Flux.softmax(pp::AbstractProbe; dims=1) = onnxsoftmax(pp, np_axis = flux2numpydim(dims[end], ndims(pp)))
@@ -410,6 +411,8 @@
(l::Flux.MeanPool)(pp::AbstractProbe) = attribfun(s -> outshape(l, s), "AveragePool", pp; attributes = attribs(l))
(l::Flux.Dropout)(pp::AbstractProbe) = attribfun(identity, "Dropout", pp; attributes = [ONNX.AttributeProto("ratio", l.p)])

(l::Flux.GlobalMaxPool)(pp::AbstractProbe) = globalmaxpool(pp, identity)
(l::Flux.GlobalMeanPool)(pp::AbstractProbe) = globalmeanpool(pp, identity)

globalmeanpool(pp::AbstractProbe, wrap) = globalpool(pp, wrap, "GlobalAveragePool")
globalmaxpool(pp::AbstractProbe, wrap) = globalpool(pp, wrap, "GlobalMaxPool")
@@ -498,8 +501,6 @@


function axisfun(fshape, optype, pps::AbstractProbe...; dims, axname="axes")
-   fname = recursename(lowercase(optype), nextname(pps[1]))
-
    attrib = if isempty(dims)
        ONNX.AttributeProto[]
    else
@@ -508,6 +509,11 @@
        np_axis = flux2numpydim.(dims, ndims(pok[1]))
        [ONNX.AttributeProto(axname, np_axis)]
    end
+   axisfun(fshape, optype, attrib, pps...)
+end
+
+function axisfun(fshape, optype, attrib::AbstractArray{<:ONNX.AttributeProto}, pps::AbstractProbe...)
+   fname = recursename(lowercase(optype), nextname(pps[1]))
    add!(pps[1], ONNX.NodeProto(
        input = collect(name.(pps)),
@@ -576,3 +582,7 @@
    end
    return newfrom(pp, fname, fshape)
end

Flux.unsqueeze(pp::AbstractProbe; dims) = axisfun(s -> insdims(s, dims), "Unsqueeze", pp; dims=scal2tup(dims))
unsqueeze_onnx(pp::AbstractProbe, npa::NumPyAxes) = axisfun(s -> insdims(s, npa), "Unsqueeze", [ONNX.AttributeProto("axes", npa.axes)], pp)
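As a sanity check relating the serialization methods above to the Julia-side semantics, the shape function agrees with Flux's `unsqueeze` (a sketch; shapes picked arbitrarily, `insdims` is the internal helper from ops.jl):

```julia-repl
julia> using Flux

julia> size(Flux.unsqueeze(rand(2, 3); dims=2))
(2, 1, 3)

julia> ONNXNaiveNASflux.insdims((2, 3), 2)  # the fshape used when serializing the op
(2, 1, 3)
```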

38 changes: 37 additions & 1 deletion src/shapes.jl
@@ -9,8 +9,44 @@ NaiveNASflux.actrank(::Shape1D) = 0
# TODO: Move to NaiveNASflux
NaiveNASflux.nin(sc::SkipConnection) = nin(sc.layers)

"""
NumPyAxes(axes)

Represents `axes` using numpy conventions: because numpy arrays are row major while Julia arrays are column major, numpy axis `0` maps to the last Julia dimension, `1` to the second last, `-1` to the first, and so on.

Mainly intended to be used when deserializing ONNX OPs which operate along a provided set of dimensions, to mark that we don't yet know how to translate the numpy axes to Julia axes.

Primarily exists for pretty printing reasons, so we can show which dimensions certain OPs will operate along in a Julia-compliant manner (e.g. numpy index `1` is shown as `end-1`).

### Examples
```julia-repl
julia> ONNXNaiveNASflux.NumPyAxes([0, 1, 2, -1, -2])
NumPyAxes[end,end-1,end-2,1,2]
```
"""
struct NumPyAxes{AX}
    axes::AX
end
# This is cute, but generally not needed since all ops which make use of this need to supply some dims argument
# which is pretty much always required to be a Tuple or an Int by some method quite high up in the call stack.
#Base.to_index(x::AbstractArray, npa::NumPyAxes) = numpy2fluxdim.(collect(npa.axes), ndims(x))
Base.length(npa::NumPyAxes) = length(npa.axes)

function Base.show(io::IO, npa::NumPyAxes)
    print(io, get(io, :prefix, "NumPyAxes["))
    indstr = map(npa.axes) do ax
        ax === 0 && return "end"
        ax > 0 && return string("end-", ax)
        -ax
    end
    print(io, join(indstr, ','))
    print(io, get(io, :suffix, "]"))
end

numpy2fluxdim(np_axis, v::AbstractVertex) = numpy2fluxdim(np_axis, 1 + NaiveNASflux.actrank(v)[1])
-numpy2fluxdim(np_axis, ndims) = np_axis >= 0 ? ndims - np_axis : abs(np_axis)
+numpy2fluxdim(npa::NumPyAxes, ndims::Integer) = numpy2fluxdim.(npa.axes, ndims)
+numpy2fluxdim(np_axis, ndims) = np_axis >= 0 ? ndims - np_axis : -np_axis
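A few spot checks of the conversion, using a 4-D shape (e.g. a W×H×C×N activation) as a sketch:

```julia-repl
julia> numpy2fluxdim(0, 4)   # numpy's first axis is Julia's last dim
4

julia> numpy2fluxdim(-1, 4)  # numpy's last axis (-1) is Julia's first dim
1

julia> numpy2fluxdim(NumPyAxes([0, 2]), 4)
2-element Vector{Int64}:
 4
 2
```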

flux2numpydim(dim, ndims) = ndims - dim

Expand Down
24 changes: 24 additions & 0 deletions test/deserialize/Artifacts.toml
@@ -342,3 +342,27 @@ git-tree-sha1 = "458aa7b72548ef30e68311e2ba83fee5232fdcdd"

[test_squeeze_negative_axes]
git-tree-sha1 = "085e5374680b94e4ac1748a9a944ab67eb933ea5"

[test_unsqueeze_axis_0]
git-tree-sha1 = "e5b5cba31574fa5742cdbfbd76ef13f3ad49a91b"

[test_unsqueeze_axis_1]
git-tree-sha1 = "161893d3dda3c7ced6c7882e237b4488deaa7c41"

[test_unsqueeze_axis_2]
git-tree-sha1 = "38acbb013bfb691b4e75a79aeb49d7220666bd6a"

[test_unsqueeze_axis_3]
git-tree-sha1 = "db2c2acf07c3010fea726f104c4d5e9c922b92ac"

[test_unsqueeze_negative_axes]
git-tree-sha1 = "3b3de957f53bba7e4f14a8e7325e100d9cb66482"

[test_unsqueeze_three_axes]
git-tree-sha1 = "376f228da8a6d73d7d4f5248496b9074cd349a12"

[test_unsqueeze_two_axes]
git-tree-sha1 = "a607099a395a3a323ac3528b3a155034355d582f"

[test_unsqueeze_unsorted_axes]
git-tree-sha1 = "22deacbe190142e6abf4631b332e087acee00ac4"
46 changes: 46 additions & 0 deletions test/deserialize/deserialize.jl
@@ -216,6 +216,14 @@ end
(name="test_softmax_negative_axis", ninputs=1, noutputs=1, fd=invariantops),
(name="test_squeeze", ninputs=1, noutputs=1, fd=invariantops),
(name="test_squeeze_negative_axes", ninputs=1, noutputs=1, fd=invariantops),
(name="test_unsqueeze_axis_0", ninputs=1, noutputs=1, fd=invariantops),
(name="test_unsqueeze_axis_1", ninputs=1, noutputs=1, fd=invariantops),
(name="test_unsqueeze_axis_2", ninputs=1, noutputs=1, fd=invariantops),
(name="test_unsqueeze_axis_3", ninputs=1, noutputs=1, fd=invariantops),
(name="test_unsqueeze_negative_axes", ninputs=1, noutputs=1, fd=invariantops),
(name="test_unsqueeze_three_axes", ninputs=1, noutputs=1, fd=invariantops),
(name="test_unsqueeze_two_axes", ninputs=1, noutputs=1, fd=invariantops),
(name="test_unsqueeze_unsorted_axes", ninputs=1, noutputs=1, fd=invariantops),
)

model, gb, inputs, outputs = prepare_node_test(tc.name, tc.ninputs, tc.noutputs)
@@ -417,4 +425,42 @@
    @test nout(g[end]) == 60
    @test length(defaultutility(g[end])) == exputilsize
end

@testset "Squeeze" begin
import ONNXNaiveNASflux: Squeeze, NumPyAxes

# The top two are not hit through any other test
@test size(Squeeze((1,2))(reshape(ones(2,3), 1, 1, 2, 1, 3, 1))) == (2, 1, 3, 1)
@test size(Squeeze(missing)(reshape(ones(2,3), 1, 1, 2, 1, 3, 1))) == (2, 3)
@test size(Squeeze(NumPyAxes([-1, 0, 2]))(reshape(ones(2,3), 1, 1, 2, 1, 3, 1))) == (1, 2, 3)
end

@testset "Pretty printing" begin
import ONNXNaiveNASflux: NumPyAxes

@testset "NumPyAxes" begin
@test sprint(show, NumPyAxes([0, 1, 2, -1, 2])) == "NumPyAxes[end,end-1,end-2,1,end-2]"
end

@testset "Squeeze" begin
import ONNXNaiveNASflux: Squeeze
@test sprint(show, Squeeze(NumPyAxes([0, 1, -2]))) == "Squeeze(dims=[end,end-1,2])"
@test sprint(show, Squeeze(missing)) == "Squeeze"
end

@testset "Unsqueeze" begin
import ONNXNaiveNASflux: Unsqueeze
@test sprint(show, Unsqueeze(NumPyAxes([0, 1, -2]))) == "Unsqueeze(dims=[end,end-1,2])"
end

@testset "Reshape" begin
import ONNXNaiveNASflux: Reshape
@test sprint(show, Reshape((1, 2))) == "Reshape(dims=(1, 2))"
end

@testset "Flatten" begin
import ONNXNaiveNASflux: Flatten
@test sprint(show, Flatten(2)) == "Flatten(dim=2)"
end
end
end