Updates for Flux 0.15 #99

Merged (4 commits) on Dec 12, 2024
4 changes: 2 additions & 2 deletions Project.toml
@@ -17,8 +17,8 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
ChainRulesCore = "1"
Flux = "0.13, 0.14"
Functors = "0.4"
Flux = "0.15.2"
Functors = "0.4, 0.5"
JuMP = "0.21, 0.22, 0.23, 1"
NaiveNASflux = "2.0.10"
NaiveNASlib = "2.0.11"
7 changes: 4 additions & 3 deletions src/ONNXNaiveNASflux.jl
@@ -5,14 +5,14 @@ include("baseonnx/BaseOnnx.jl")
import .BaseOnnx: array
const ONNX = BaseOnnx
using Flux
using Flux: params
import Functors
using NaiveNASflux
using NaiveNASflux: weights, bias
using NaiveNASflux: indim, outdim, actdim, actrank, layertype, wrapped
using NaiveNASflux: FluxLayer, FluxParLayer, FluxNoParLayer, FluxDense, FluxConvolutional, FluxConv, FluxConvTranspose,
FluxBatchNorm, FluxInstanceNorm, FluxRecurrent, FluxRnn, FluxLstm, FluxGru, FluxTransparentLayer,
FluxPoolLayer, FluxDropOut, Flux2D, GenericFluxConvolutional, GenericFlux2D, GenericFluxRecurrent
FluxBatchNorm, FluxInstanceNorm, FluxRecurrent, FluxRecurrentCell, FluxRnn, FluxRnnCell, FluxLstm,
FluxLstmCell, FluxGru, FluxGruCell, FluxTransparentLayer, FluxPoolLayer, FluxDropOut, Flux2D,
GenericFluxConvolutional, GenericFlux2D, GenericFluxRecurrent
using Setfield
using Statistics
import Pkg
@@ -34,6 +34,7 @@ include("deserialize/graphbuilder.jl")
include("deserialize/combine.jl")
include("deserialize/deserialize.jl")

include("serialize/traceprobes.jl")
include("serialize/namingutil.jl")
include("serialize/serialize.jl")

1 change: 1 addition & 0 deletions src/deserialize/combine.jl
@@ -92,6 +92,7 @@ check_squeeze(nsqueeze::OnnxNode, gb::CompGraphBuilder, innode::OnnxNode, ::Val{

function check_squeeze(nsqueeze::OnnxNode, gb::CompGraphBuilder, innode::OnnxNode, ::RecurrentLayer)
@debug "Remove squeeze after $innode"
innode.attribute[SQUEEZED_RECURRENT_KEY] = true
return retnode(innode, gb)
end

31 changes: 20 additions & 11 deletions src/deserialize/deserialize.jl
@@ -21,12 +21,19 @@ load(m::ONNX.ModelProto, insizes...; kwargs...) = load(m.graph, insizes...; kwar
load(g::ONNX.GraphProto, insizes...; kwargs...) = CompGraph(g, insizes...; kwargs...)
NaiveNASlib.CompGraph(g::ONNX.GraphProto, insizes...; kwargs...) = CompGraph(CompGraphBuilder(g, insizes...); kwargs...)
function NaiveNASlib.CompGraph(gb::CompGraphBuilder; vfun = create_vertex_default, infer_shapes=true)
outputs::Vector{AbstractVertex} = vertex.(gb, node.(name.(gb.g.output), gb), vfun)
graph = CompGraph(gb.inputs, outputs)
if infer_shapes
try_infer_sizes!(graph, (get(gb.sizes, n, (missing,)) for n in name.(inputs(graph)))...)
end
return graph
# unique here is a bit of a hack for the LSTM testcase where an LSTM is the last layer.
# Flux LSTM outputs a tuple which is translated to having two outputs in serialize.
# However, the end result is that gb.g.output has one entry for each output, which means
# that we would put the same LSTM vertex twice as the output layer.
# This type of ambiguity (i.e. do I want the output from vertex X twice, or does it actually
# output a tuple?) is why adding support for multi-output vertices seems quite painful,
# at least with the current state of this package.
outputs::Vector{AbstractVertex} = unique(vertex.(gb, node.(name.(gb.g.output), gb), vfun))
graph = CompGraph(gb.inputs, outputs)
if infer_shapes
try_infer_sizes!(graph, (get(gb.sizes, n, (missing,)) for n in name.(inputs(graph)))...)
end
return graph
end

"""
@@ -39,12 +46,14 @@ Inputs to the returned vertex are created recursively based on state in `gb`.
function vertex(gb::CompGraphBuilder, n::OnnxNode, vfun = create_vertex_default)
return get!(gb.created, n) do
n_create, ins = check_combine(gb, n)
invertices = map(ni -> vertex(gb, ni, vfun), ins)
v = vfun(n_create, invertices)
if is_input(v)
push!(gb.inputs, v)
get!(gb.created, n_create) do
invertices = map(ni -> vertex(gb, ni, vfun), ins)
v = vfun(n_create, invertices)
if is_input(v)
push!(gb.inputs, v)
end
return v
end
return v
end
end

16 changes: 12 additions & 4 deletions src/deserialize/graphbuilder.jl
@@ -1,4 +1,6 @@

const ACTIVE_OUTPUTS_ATTRIBUTE_KEY = :ONNXNaiveNASflux_ACTIVE_OUTPUTS

"""
OnnxNode(proto::ONNX.NodeProto, params::Vector{ONNX.TensorProto}, attribs::Dict)

@@ -9,7 +11,11 @@ struct OnnxNode
params::Vector{ONNX.TensorProto}
attribute::Dict{Symbol, Any} # Must be Any or else we might overspecialize, preventing stuff from being added later
end
OnnxNode(proto, ps) = OnnxNode(proto, ps, Dict{Symbol, Any}(Dict(proto.attribute)))
function OnnxNode(proto, ps)
attribute = Dict{Symbol, Any}(Dict(proto.attribute))
attribute[ACTIVE_OUTPUTS_ATTRIBUTE_KEY] = findall(!isempty, output(proto))
OnnxNode(proto, ps, attribute)
end
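
As an illustrative aside (hypothetical output names, not part of the diff), the constructor records the indices of the ONNX outputs that are actually named:

# keep the positions of non-empty output names
findall(!isempty, ["rnn_out", "", "rnn_out_last"])  # == [1, 3]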

"""
CompGraphBuilder(g::ONNX.Types.Graph, sizes::Dict{String, <:Tuple})
@@ -60,6 +66,8 @@ function output_to_node(nodes, initdict)
ps = params(nodeproto, initdict)
node = OnnxNode(nodeproto, ps)
for outname in output(node)
# TODO: Custom error type for this
@assert outname ∉ keys(allnodes) "Duplicate output name found: $(outname)!"
allnodes[outname] = node
end
end
@@ -186,6 +194,6 @@ end
optype(n::ONNX.NodeProto) = Symbol(n.op_type)
optype(n::OnnxNode) = optype(n.proto)

Flux.params(n::ONNX.NodeProto, initdict) = params(Val(optype(n)), n, initdict)
Flux.params(::Val, n::ONNX.NodeProto, initdict) = map(pname -> initdict[pname], setdiff(input(n), innames(n))) # Inputs which are not other vertices
Flux.params(n::OnnxNode) = n.params .|> array
params(n::ONNX.NodeProto, initdict) = params(Val(optype(n)), n, initdict)
params(::Val, n::ONNX.NodeProto, initdict) = map(pname -> initdict[pname], setdiff(input(n), innames(n))) # Inputs which are not other vertices
params(n::OnnxNode) = n.params .|> array
144 changes: 131 additions & 13 deletions src/deserialize/ops.jl
@@ -9,6 +9,8 @@
const verts = Dict{Symbol, Any}()
const fluxlayertypes = Dict{Symbol, Any}()

layerfuns = Dict{Symbol, Any}()

# Rundown of the basic idea here:

# Aspect 1
@@ -38,13 +40,60 @@

# Functions which have dedicated vertex construction methods, such as Concat and Add end up in verts.


"""
OutputSelection(selection, wrapped)

Selects outputs from `wrapped` using `selection`.

Typically used when `wrapped` outputs a `Tuple` from which other nodes in the computation graph
only want a subset.

Can also be used to transform Flux output to ONNX output. One example is recurrent layers, where
Flux outputs all time steps of the hidden state while some ONNX outputs contain only the last step.

Note that the more useful and generic InputSelection (which would allow a node to pick a subset
of some other node's output as its input) is not yet implemented. OutputSelection only works when
1) all nodes which take input from `wrapped` want exactly the same outputs and 2) for graph output
nodes (which is the reason it was implemented to begin with).
"""
struct OutputSelection{FS, L} <: NaiveNASflux.AbstractMutableComp
selection::FS
wrapped::L
end
NaiveNASflux.wrapped(o::OutputSelection) = o.wrapped
(o::OutputSelection)(x...) = _apply_selection(o.selection, wrapped(o)(x...))

_apply_selection(fs::Tuple, x) = map(f -> f(x), fs)
_apply_selection(f, x) = f(x)
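
A minimal sketch of how the selection is applied (illustrative values, not part of the diff):

# a tuple of selection functions is mapped over the full wrapped output...
_apply_selection((first, last), (10, 20, 30))  # == (10, 30)
# ...while a single function is applied directly
_apply_selection(identity, (10, 20, 30))       # == (10, 20, 30)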

# Used for recurrent layers since ONNX specifies one extra dimension for the number of directions,
# which Flux does not have
struct AddSingletonDim{L} <: NaiveNASflux.AbstractMutableComp
dim::Int
wrapped::L
end
NaiveNASflux.wrapped(a::AddSingletonDim) = a.wrapped
function (a::AddSingletonDim)(x)
y = wrapped(a)(x)
_apply_add_singleton_dim(y, a.dim)
end

_apply_add_singleton_dim(x, dim) = reshape(x, size(x)[1:dim-1]..., 1, size(x)[dim:end]...)
_apply_add_singleton_dim(xt::Tuple, dim) = map(x -> _apply_add_singleton_dim(x, dim), xt)
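
For illustration (the array size is an assumption, not from the PR), the reshape inserts a length-1 axis at the requested position:

x = rand(Float32, 4, 7, 5)            # a 3D Flux-style recurrent output
size(_apply_add_singleton_dim(x, 3))  # == (4, 7, 1, 5)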

struct OpNotSupportedError <: Exception
msg::String
end
OpNotSupportedError(op_type::Symbol) = OpNotSupportedError(string("Operation type ", op_type, " not supported!"))
Base.showerror(io::IO, e::OpNotSupportedError) = print(io, "OpNotSupportedError: ", e.msg)

sources[:Constant] = params -> constant(Val.(keys(params))..., values(params)...)
sources[:Constant] = function(params)
params = if ACTIVE_OUTPUTS_ATTRIBUTE_KEY in keys(params)
delete!(copy(params), ACTIVE_OUTPUTS_ATTRIBUTE_KEY)
else
params
end
constant(Val.(keys(params))..., values(params)...)
end
constant(::Val{:value}, val::ONNX.TensorProto) = val |> array
constant(::Val{:value}, val) = val
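
A small sketch of the intent (hypothetical attribute dict): the injected active-outputs entry is stripped so that `constant` still dispatches only on the real ONNX attributes.

attrs = Dict{Symbol,Any}(:value => 3, ACTIVE_OUTPUTS_ATTRIBUTE_KEY => [1])
sources[:Constant](attrs)  # == 3, via constant(Val(:value), 3)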

@@ -149,21 +198,55 @@
end
fluxlayertypes[:InstanceNormalization] = (pars...) -> FluxInstanceNorm()

fluxrecurrentlayers[:RNN] = function(params, Wi_WBi, Wh_WBh, Wb_Rb=default_Wb_Rb(Wh_WBh), seqlen=[], h3d = default_init_h(Wb_Rb, 2))
@assert size(Wi_WBi, 3) == 1 "Num directions must be 1! Bidirectional (num directions = 2) not supported!" # TODO: Add...
const SQUEEZED_RECURRENT_KEY = :ONNXNaiveNASflux_SQUEEZED_RECURRENT_KEY

Wi,Wh,b,h = recurrent_arrays(FluxRnn(), Wi_WBi, Wh_WBh, Wb_Rb, h3d)
fluxrecurrentlayers[:RNN] = function(params, Wi_WBi, Wh_WBh, Wb_Rb=default_Wb_Rb(Wh_WBh), seqlen=[], h3d = nothing)
@assert size(Wi_WBi, 3) == 1 "Num directions must be 1! Bidirectional (num directions = 2) not supported!" # TODO: Add...
if !isnothing(h3d)
# We could probably create some wrapper struct for this if anyone ever needs it...
@warn "Got initial hidden state for RNN. This can't be stored in Flux > 0.15 and will be ignored."

end
Wi,Wh,b = recurrent_arrays(FluxRnnCell(), Wi_WBi, Wh_WBh, Wb_Rb)
act = rnnactfuns[Symbol(get(params, :activations, ["Tanh"])[])](1, params)
cell = Flux.RNNCell(act, Wi, Wh, b, fill!(similar(h), 0))
return Flux.Recur(cell, h)
cell = Flux.RNNCell(act, Wi, Wh, b)
return Flux.RNN(cell)
end
fluxlayertypes[:RNN] = (pars...) -> FluxRnn()

_onnx_rnn_output1(h) = h
# Select last timestep
_onnx_rnn_output2(h::AbstractArray) = selectdim(h, 2, lastindex(h, 2))

fluxrecurrentlayers[:LSTM] = function(params, Wi_WBi, Wh_WBh, Wb_Rb=default_Wb_Rb(Wh_WBh), seqlen=[1], h3d = default_init_h(Wb_Rb, 8), c3d=default_init_h(Wb_Rb,8), peep=nothing)
_rnn_output_selection(i) = i === 1 ? _onnx_rnn_output1 :
i === 2 ? _onnx_rnn_output2 :
throw(ArgumentError("Unsupported RNN output: $i"))

layerfuns[:RNN] = function(params, args...)
active_outputs = params[ACTIVE_OUTPUTS_ATTRIBUTE_KEY]
selection = if length(active_outputs) == 1
_rnn_output_selection(only(active_outputs))
else
ntuple(i -> _rnn_output_selection(active_outputs[i]), length(active_outputs))
end
paddims = haskey(params, SQUEEZED_RECURRENT_KEY) ? identity : l -> AddSingletonDim(3, l)
layer -> paddims(OutputSelection(selection, layer))
end
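
A sketch of what this wrapper produces (the attribute dict below is hypothetical, not from the PR):

attrs = Dict{Symbol,Any}(ACTIVE_OUTPUTS_ATTRIBUTE_KEY => [1, 2])
wrap = layerfuns[:RNN](attrs)  # layer -> AddSingletonDim(3, OutputSelection((full h, last step), layer))
# Applied to e.g. Flux.RNN(3 => 4), the wrapped layer returns both the full hidden
# sequence (ONNX output 1) and its last time step (ONNX output 2), each padded with
# a singleton num_directions dimension.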


fluxrecurrentlayers[:LSTM] = function(params, Wi_WBi, Wh_WBh, Wb_Rb=default_Wb_Rb(Wh_WBh), seqlen=[1], h3d = nothing, c3d = nothing, peep=nothing)
@assert size(Wi_WBi, 3) == 1 "Num directions must be 1! Bidirectional (num directions = 2) not supported!" # TODO: Add...
@assert isnothing(peep) "Peepholes not supported!" # Or?
Wi,Wh,b,h,c = recurrent_arrays(FluxLstm(), Wi_WBi, Wh_WBh, Wb_Rb, h3d, c3d)
if !isnothing(h3d)
# We could probably create some wrapper struct for this if anyone ever needs it...
@warn "Got initial hidden state for LSTM. This can't be stored in Flux > 0.15 and will be ignored."

end

if !isnothing(c3d)
# We could probably create some wrapper struct for this if anyone ever needs it...
@warn "Got initial cell state for LSTM. This can't be stored in Flux > 0.15 and will be ignored."

end

Wi,Wh,b = recurrent_arrays(FluxLstmCell(), Wi_WBi, Wh_WBh, Wb_Rb)
# Flux only supports default activation functions
# We can only check that the given values don't deviate
supported = [:Sigmoid, :Tanh, :Tanh]
@@ -172,13 +255,40 @@
e == a
end "Got unsupported activation function: $acts"

# b, h and c must all be of the same type when creating a cell, but
# it is actually Recur which has the state
cell = Flux.LSTMCell(Wi, Wh, b, (fill!(similar(h), 0), fill!(similar(c), 0)))
return Flux.Recur(cell, (h, c))
# Should not be a problem when/if Flux adds this back as an optional output
@assert 3 ∉ params[ACTIVE_OUTPUTS_ATTRIBUTE_KEY] "LSTM output 3 (the cell state) not implemented!"

cell = Flux.LSTMCell(Wi, Wh, b)
return Flux.LSTM(cell)
end
fluxlayertypes[:LSTM] = (pars...) -> FluxLstm()

_onnx_lstm_output1(h::AbstractArray) = h
_onnx_lstm_output2(h::AbstractArray) = selectdim(h, 2, lastindex(h, 2))
_onnx_lstm_output3(::AbstractArray) = throw(ArgumentError("LSTM output 3 (cell state) requires Flux.LSTM to output state. Please check your layer configuration!"))

_onnx_lstm_output1((h, c)::NTuple{2, AbstractArray}) = h
_onnx_lstm_output2((h, c)::NTuple{2, AbstractArray}) = selectdim(h, 2, lastindex(h, 2))
_onnx_lstm_output3((h, c)::NTuple{2, AbstractArray}) = selectdim(c, 2, lastindex(c, 2))

_lstm_output_selection(i) = i === 1 ? _onnx_lstm_output1 :
i === 2 ? _onnx_lstm_output2 :
i === 3 ? _onnx_lstm_output3 :
throw(ArgumentError("Unsupported LSTM output: $i"))

layerfuns[:LSTM] = function(params, args...)
active_outputs = params[ACTIVE_OUTPUTS_ATTRIBUTE_KEY]
selection = if length(active_outputs) == 1
# Can we be sure the receiver does not want a single-element tuple here? No, we can't :( :( :(
_lstm_output_selection(only(active_outputs))
else
ntuple(i -> _lstm_output_selection(active_outputs[i]), length(active_outputs))
end
paddims = haskey(params, SQUEEZED_RECURRENT_KEY) ? identity : l -> AddSingletonDim(3, l)
layer -> paddims(OutputSelection(selection, layer))
end
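
To illustrate the tuple methods above (shapes are assumptions, not from the PR):

h = rand(Float32, 4, 7, 2); c = rand(Float32, 4, 7, 2)  # (hidden, timesteps, batch)
size(_onnx_lstm_output2((h, c)))  # == (4, 2): last time step of the hidden state
size(_onnx_lstm_output3((h, c)))  # == (4, 2): last time step of the cell state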


function recurrent_arrays(lt, Wi_WBi, Wh_WBh, Wb_Rb, h3ds...)
# ONNX weights are on the form [num_directions, hidden_size, input_size] (where num_directions is 2 for bidirectional else 1)
# Flux weights are of shape [hidden_size, input_size]
@@ -416,7 +526,15 @@
end

for (s, f) in fluxlayers
verts[s] = (name, inputs, args...;kwargs...) -> fluxvertex(name, f(args...), inputs...; kwargs...)
verts[s] = function(name, inputs, args...; kwargs...)
# This is typically used to select outputs, e.g. from recurrent layers
kwargsnew = if s in keys(layerfuns)
mergewith(∘, Dict(:layerfun => layerfuns[s](args...)), Dict(kwargs))
else
kwargs
end
fluxvertex(name, f(args...), inputs...; kwargsnew...)
end
end
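
A quick sketch of the mergewith(∘) mechanics (toy functions, not from the PR): an existing :layerfun keyword is composed with the new wrapper rather than overwritten.

d = mergewith(∘, Dict(:layerfun => sqrt), Dict(:layerfun => abs))
d[:layerfun](-4.0)  # == sqrt(abs(-4.0)) == 2.0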

for (s, f) in invariantops