Merge pull request #99 from DrChainsaw/Flux0.15
Updates for Flux 0.15
DrChainsaw authored Dec 12, 2024
2 parents 4a3f464 + 9034ab5 commit 61b026b
Showing 14 changed files with 753 additions and 282 deletions.
4 changes: 2 additions & 2 deletions Project.toml
@@ -17,8 +17,8 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
ChainRulesCore = "1"
Flux = "0.13, 0.14"
Functors = "0.4"
Flux = "0.15.2"
Functors = "0.4, 0.5"
JuMP = "0.21, 0.22, 0.23, 1"
NaiveNASflux = "2.0.10"
NaiveNASlib = "2.0.11"
7 changes: 4 additions & 3 deletions src/ONNXNaiveNASflux.jl
@@ -5,14 +5,14 @@ include("baseonnx/BaseOnnx.jl")
import .BaseOnnx: array
const ONNX = BaseOnnx
using Flux
using Flux: params
import Functors
using NaiveNASflux
using NaiveNASflux: weights, bias
using NaiveNASflux: indim, outdim, actdim, actrank, layertype, wrapped
using NaiveNASflux: FluxLayer, FluxParLayer, FluxNoParLayer, FluxDense, FluxConvolutional, FluxConv, FluxConvTranspose,
FluxBatchNorm, FluxInstanceNorm, FluxRecurrent, FluxRnn, FluxLstm, FluxGru, FluxTransparentLayer,
FluxPoolLayer, FluxDropOut, Flux2D, GenericFluxConvolutional, GenericFlux2D, GenericFluxRecurrent
FluxBatchNorm, FluxInstanceNorm, FluxRecurrent, FluxRecurrentCell, FluxRnn, FluxRnnCell, FluxLstm,
FluxLstmCell, FluxGru, FluxGruCell, FluxTransparentLayer, FluxPoolLayer, FluxDropOut, Flux2D,
GenericFluxConvolutional, GenericFlux2D, GenericFluxRecurrent
using Setfield
using Statistics
import Pkg
@@ -34,6 +34,7 @@ include("deserialize/graphbuilder.jl")
include("deserialize/combine.jl")
include("deserialize/deserialize.jl")

include("serialize/traceprobes.jl")
include("serialize/namingutil.jl")
include("serialize/serialize.jl")

1 change: 1 addition & 0 deletions src/deserialize/combine.jl
@@ -92,6 +92,7 @@ check_squeeze(nsqueeze::OnnxNode, gb::CompGraphBuilder, innode::OnnxNode, ::Val{

function check_squeeze(nsqueeze::OnnxNode, gb::CompGraphBuilder, innode::OnnxNode, ::RecurrentLayer)
@debug "Remove squeeze after $innode"
innode.attribute[SQUEEZED_RECURRENT_KEY] = true
return retnode(innode, gb)
end

31 changes: 20 additions & 11 deletions src/deserialize/deserialize.jl
@@ -21,12 +21,19 @@ load(m::ONNX.ModelProto, insizes...; kwargs...) = load(m.graph, insizes...; kwar
load(g::ONNX.GraphProto, insizes...; kwargs...) = CompGraph(g, insizes...; kwargs...)
NaiveNASlib.CompGraph(g::ONNX.GraphProto, insizes...; kwargs...) = CompGraph(CompGraphBuilder(g, insizes...); kwargs...)
function NaiveNASlib.CompGraph(gb::CompGraphBuilder; vfun = create_vertex_default, infer_shapes=true)
outputs::Vector{AbstractVertex} = vertex.(gb, node.(name.(gb.g.output), gb), vfun)
graph = CompGraph(gb.inputs, outputs)
if infer_shapes
try_infer_sizes!(graph, (get(gb.sizes, n, (missing,)) for n in name.(inputs(graph)))...)
end
return graph
# unique here is a bit of a hack for the LSTM testcase where an LSTM is the last layer.
# Flux's LSTM outputs a tuple, which is translated to two outputs when serializing.
# However, the end result is that gb.g.output has one entry per output, meaning
# we would otherwise add the same LSTM vertex twice as an output layer.
# This type of ambiguity (i.e. do I want the output from vertex X twice, or does it actually
# output a tuple?) is why adding support for multi-output vertices seems quite painful,
# at least with the current state of this package.
outputs::Vector{AbstractVertex} = unique(vertex.(gb, node.(name.(gb.g.output), gb), vfun))
graph = CompGraph(gb.inputs, outputs)
if infer_shapes
try_infer_sizes!(graph, (get(gb.sizes, n, (missing,)) for n in name.(inputs(graph)))...)
end
return graph
end
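To make the deduplication above concrete, here is a tiny illustration (not from this commit; the names are stand-ins for the builder lookups): two ONNX graph outputs, e.g. Y and Y_h of a final LSTM, can resolve to the same vertex, and `unique` collapses them to a single output vertex.

    node_for = Dict("Y" => :lstm_vertex, "Y_h" => :lstm_vertex)   # illustrative stand-ins
    unique(map(nm -> node_for[nm], ["Y", "Y_h"]))                 # -> [:lstm_vertex]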

"""
@@ -39,12 +46,14 @@ Inputs to the returned vertex are created recursively based on state in `gb`.
function vertex(gb::CompGraphBuilder, n::OnnxNode, vfun = create_vertex_default)
return get!(gb.created, n) do
n_create, ins = check_combine(gb, n)
invertices = map(ni -> vertex(gb, ni, vfun), ins)
v = vfun(n_create, invertices)
if is_input(v)
push!(gb.inputs, v)
get!(gb.created, n_create) do
invertices = map(ni -> vertex(gb, ni, vfun), ins)
v = vfun(n_create, invertices)
if is_input(v)
push!(gb.inputs, v)
end
return v
end
return v
end
end
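The nested `get!` uses the standard do-block form from Julia Base, which runs the block only when the key is missing. A minimal sketch of that idiom (not from this commit):

    cache = Dict{Symbol, Any}()
    get!(cache, :x) do
        println("building :x")                    # runs only on the first lookup
        42
    end
    get!(() -> error("not reached"), cache, :x)   # second lookup: returns the cached 42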

16 changes: 12 additions & 4 deletions src/deserialize/graphbuilder.jl
@@ -1,4 +1,6 @@

const ACTIVE_OUTPUTS_ATTRIBUTE_KEY = :ONNXNaiveNASflux_ACTIVE_OUTPUTS

"""
OnnxNode(proto::ONNX.NodeProto, params::Vector{ONNX.TensorProto}, attribs::Dict)
@@ -9,7 +11,11 @@ struct OnnxNode
params::Vector{ONNX.TensorProto}
attribute::Dict{Symbol, Any} # Must be Any or else we might overspecialize, which would prevent entries from being added later
end
OnnxNode(proto, ps) = OnnxNode(proto, ps, Dict{Symbol, Any}(Dict(proto.attribute)))
function OnnxNode(proto, ps)
attribute = Dict{Symbol, Any}(Dict(proto.attribute))
attribute[ACTIVE_OUTPUTS_ATTRIBUTE_KEY] = findall(!isempty, output(proto))
OnnxNode(proto, ps, attribute)
end
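The `findall(!isempty, ...)` call records which ONNX output slots are actually used, since optional outputs are encoded as empty names. For example (hypothetical output list for an RNN node):

    findall(!isempty, ["Y", "", "Y_h"])   # -> [1, 3]: output 2 is unused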

"""
CompGraphBuilder(g::ONNX.Types.Graph, sizes::Dict{String, <:Tuple})
@@ -60,6 +66,8 @@ function output_to_node(nodes, initdict)
ps = params(nodeproto, initdict)
node = OnnxNode(nodeproto, ps)
for outname in output(node)
# TODO: Custom error type for this
@assert outname ∉ keys(allnodes) "Duplicate output name found: $(outname)!"
allnodes[outname] = node
end
end
@@ -186,6 +194,6 @@ end
optype(n::ONNX.NodeProto) = Symbol(n.op_type)
optype(n::OnnxNode) = optype(n.proto)

Flux.params(n::ONNX.NodeProto, initdict) = params(Val(optype(n)), n, initdict)
Flux.params(::Val, n::ONNX.NodeProto, initdict) = map(pname -> initdict[pname], setdiff(input(n), innames(n))) # Inputs which are not other vertices
Flux.params(n::OnnxNode) = n.params .|> array
params(n::ONNX.NodeProto, initdict) = params(Val(optype(n)), n, initdict)
params(::Val, n::ONNX.NodeProto, initdict) = map(pname -> initdict[pname], setdiff(input(n), innames(n))) # Inputs which are not other vertices
params(n::OnnxNode) = n.params .|> array
144 changes: 131 additions & 13 deletions src/deserialize/ops.jl
@@ -9,6 +9,8 @@ const pseudotransparentops = Dict{Symbol, Any}()
const verts = Dict{Symbol, Any}()
const fluxlayertypes = Dict{Symbol, Any}()

layerfuns = Dict{Symbol, Any}()

# Rundown of the basic idea here:

# Aspect 1
@@ -38,13 +40,60 @@ const fluxlayertypes = Dict{Symbol, Any}()

# Functions which have dedicated vertex construction methods, such as Concat and Add end up in verts.


"""
OutputSelection(selection, wrapped)
Selects outputs from `wrapped` using `selection`.
Typically used when `wrapped` outputs a `Tuple` from which other nodes in the computation graph
only want a subset.
Can also be used to transform Flux output to ONNX output. One example is recurrent layers, where
Flux outputs all time steps of the hidden state while some ONNX outputs contain only the last step.
Note that the more useful and generic InputSelection (which would allow a node to pick a subset
of some other node's output as its input) is not yet implemented. OutputSelection only works when
1) all nodes which take input from `wrapped` want exactly the same outputs, or 2) on output nodes
(which is the reason why I bothered to implement it to begin with).
"""
struct OutputSelection{FS, L} <: NaiveNASflux.AbstractMutableComp
selection::FS
wrapped::L
end
NaiveNASflux.wrapped(o::OutputSelection) = o.wrapped
(o::OutputSelection)(x...) = _apply_selection(o.selection, wrapped(o)(x...))

_apply_selection(fs::Tuple, x) = map(f -> f(x), fs)
_apply_selection(f, x) = f(x)
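A minimal usage sketch of `OutputSelection` (not from this commit; the wrapped callable is a stand-in for e.g. a Flux recurrent layer returning features × time × batch):

    layer = x -> cumsum(x; dims=2)                        # fake sequence output
    sel = OutputSelection((identity, h -> selectdim(h, 2, lastindex(h, 2))), layer)
    full, laststep = sel(rand(Float32, 4, 7, 2))
    size(full), size(laststep)                            # ((4, 7, 2), (4, 2))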

# Used for recurrent layers since ONNX specifies one extra dimension for the number of directions,
# which Flux does not have
struct AddSingletonDim{L} <: NaiveNASflux.AbstractMutableComp
dim::Int
wrapped::L
end
NaiveNASflux.wrapped(a::AddSingletonDim) = a.wrapped
function (a::AddSingletonDim)(x)
y = wrapped(a)(x)
_apply_add_singleton_dim(y, a.dim)
end

_apply_add_singleton_dim(x, dim) = reshape(x, size(x)[1:dim-1]..., 1, size(x)[dim:end]...)
_apply_add_singleton_dim(xt::Tuple, dim) = map(x -> _apply_add_singleton_dim(x, dim), xt)
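For example, with the num_directions dimension inserted at position 3 (a quick check of the helpers above, not from this commit; the wrapped callable is arbitrary):

    size(_apply_add_singleton_dim(rand(Float32, 4, 7, 2), 3))   # (4, 7, 1, 2)
    ad = AddSingletonDim(3, x -> 2 .* x)                        # wraps any callable
    size(ad(rand(Float32, 4, 7, 2)))                            # also (4, 7, 1, 2)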

struct OpNotSupportedError <: Exception
msg::String
end
OpNotSupportedError(op_type::Symbol) = OpNotSupportedError(string("Operation type ", op_type, " not supported!"))
Base.showerror(io::IO, e::OpNotSupportedError) = print(io, "OpNotSupportedError: ", e.msg)

sources[:Constant] = params -> constant(Val.(keys(params))..., values(params)...)
sources[:Constant] = function(params)
params = if ACTIVE_OUTPUTS_ATTRIBUTE_KEY in keys(params)
delete!(copy(params), ACTIVE_OUTPUTS_ATTRIBUTE_KEY)
end
constant(Val.(keys(params))..., values(params)...)
end
constant(::Val{:value}, val::ONNX.TensorProto) = val |> array
constant(::Val{:value}, val) = val
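The `Val`-based dispatch above means `sources[:Constant]` simply splats the attribute keys as `Val`s; a plain value passes straight through while a `TensorProto` is decoded via `array`. For instance (illustrative value):

    constant(Val(:value), Float32[1 2; 3 4])   # returned as-is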

@@ -149,21 +198,55 @@ actlayers[:InstanceNormalization] = function(params, γ, β)
end
fluxlayertypes[:InstanceNormalization] = (pars...) -> FluxInstanceNorm()

fluxrecurrentlayers[:RNN] = function(params, Wi_WBi, Wh_WBh, Wb_Rb=default_Wb_Rb(Wh_WBh), seqlen=[], h3d = default_init_h(Wb_Rb, 2))
@assert size(Wi_WBi, 3) == 1 "Num directions must be 1! Bidirectional (num directions = 2) not supported!" # TODO: Add...
const SQUEEZED_RECURRENT_KEY = :ONNXNaiveNASflux_SQUEEZED_RECURRENT_KEY

Wi,Wh,b,h = recurrent_arrays(FluxRnn(), Wi_WBi, Wh_WBh, Wb_Rb, h3d)
fluxrecurrentlayers[:RNN] = function(params, Wi_WBi, Wh_WBh, Wb_Rb=default_Wb_Rb(Wh_WBh), seqlen=[], h3d = nothing)
@assert size(Wi_WBi, 3) == 1 "Num directions must be 1! Bidirectional (num directions = 2) not supported!" # TODO: Add...
if !isnothing(h3d)
# We could probably create some wrapper struct for this if anyone ever needs it...
@warn "Got initial hidden state for RNN. This can't be stored in Flux > 0.15 and will be ignored."
end
Wi,Wh,b = recurrent_arrays(FluxRnnCell(), Wi_WBi, Wh_WBh, Wb_Rb)
act = rnnactfuns[Symbol(get(params, :activations, ["Tanh"])[])](1, params)
cell = Flux.RNNCell(act, Wi, Wh, b, fill!(similar(h), 0))
return Flux.Recur(cell, h)
cell = Flux.RNNCell(act, Wi, Wh, b)
return Flux.RNN(cell)
end
fluxlayertypes[:RNN] = (pars...) -> FluxRnn()

_onnx_rnn_output1(h) = h
# Select last timestep
_onnx_rnn_output2(h::AbstractArray) = selectdim(h, 2, lastindex(h, 2))

fluxrecurrentlayers[:LSTM] = function(params, Wi_WBi, Wh_WBh, Wb_Rb=default_Wb_Rb(Wh_WBh), seqlen=[1], h3d = default_init_h(Wb_Rb, 8), c3d=default_init_h(Wb_Rb,8), peep=nothing)
_rnn_output_selection(i) = i === 1 ? _onnx_rnn_output1 :
i === 2 ? _onnx_rnn_output2 :
throw(ArgumentError("Unsupported RNN output: $i"))

layerfuns[:RNN] = function(params, args...)
active_outputs = params[ACTIVE_OUTPUTS_ATTRIBUTE_KEY]
selection = if length(active_outputs) == 1
_rnn_output_selection(only(active_outputs))
else
ntuple(i -> _rnn_output_selection(active_outputs[i]), length(active_outputs))
end
paddims = haskey(params, SQUEEZED_RECURRENT_KEY) ? identity : l -> AddSingletonDim(3, l)
layer -> paddims(OutputSelection(selection, layer))
end
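A quick sketch of what the two ONNX RNN output selections do to a Flux-style hidden-state sequence (not from this commit; sizes chosen arbitrarily):

    h = rand(Float32, 4, 7, 2)      # hidden_size × time steps × batch
    _onnx_rnn_output1(h) === h      # ONNX output 1: the full sequence
    size(_onnx_rnn_output2(h))      # (4, 2): ONNX output 2 keeps only the last time step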


fluxrecurrentlayers[:LSTM] = function(params, Wi_WBi, Wh_WBh, Wb_Rb=default_Wb_Rb(Wh_WBh), seqlen=[1], h3d = nothing, c3d = nothing, peep=nothing)
@assert size(Wi_WBi, 3) == 1 "Num directions must be 1! Bidirectional (num directions = 2) not supported!" # TODO: Add...
@assert isnothing(peep) "Peepholes not supported!" # Or?
Wi,Wh,b,h,c = recurrent_arrays(FluxLstm(), Wi_WBi, Wh_WBh, Wb_Rb, h3d, c3d)
if !isnothing(h3d)
# We could probably create some wrapper struct for this if anyone ever needs it...
@warn "Got initial hidden state for LSTM. This can't be stored in Flux > 0.15 and will be ignored."
end

if !isnothing(c3d)
# We could probably create some wrapper struct for this if anyone ever needs it...
@warn "Got initial cell state for LSTM. This can't be stored in Flux > 0.15 and will be ignored."
end

Wi,Wh,b = recurrent_arrays(FluxLstmCell(), Wi_WBi, Wh_WBh, Wb_Rb)
# Flux only supports default activation functions
# We can only check that the given values don't deviate
supported = [:Sigmoid, :Tanh, :Tanh]
@@ -172,13 +255,40 @@ fluxrecurrentlayers[:LSTM] = function(params, Wi_WBi, Wh_WBh, Wb_Rb=default_Wb_R
e == a
end "Got unsupported activation function: $acts"

# b, h and c must all be of the same type when creating a cell, but
# it is actually Recur which has the state
cell = Flux.LSTMCell(Wi, Wh, b, (fill!(similar(h), 0), fill!(similar(c), 0)))
return Flux.Recur(cell, (h, c))
# Should not be a problem when/if Flux adds this back as an optional output
@assert 3 ∉ params[ACTIVE_OUTPUTS_ATTRIBUTE_KEY] "LSTM output 3 (the cell state) not implemented!"

cell = Flux.LSTMCell(Wi, Wh, b)
return Flux.LSTM(cell)
end
fluxlayertypes[:LSTM] = (pars...) -> FluxLstm()

_onnx_lstm_output1(h::AbstractArray) = h
_onnx_lstm_output2(h::AbstractArray) = selectdim(h, 2, lastindex(h, 2))
_onnx_lstm_output3(::AbstractArray) = throw(ArgumentError("LSTM output nr 3 (cell state) requires Flux.LSTM to output state. Please check your layer configuration!"))

_onnx_lstm_output1((h, c)::NTuple{2, AbstractArray}) = h
_onnx_lstm_output2((h, c)::NTuple{2, AbstractArray}) = selectdim(h, 2, lastindex(h, 2))
_onnx_lstm_output3((h, c)::NTuple{2, AbstractArray}) = selectdim(c, 2, lastindex(c, 2))

_lstm_output_selection(i) = i === 1 ? _onnx_lstm_output1 :
i === 2 ? _onnx_lstm_output2 :
i === 3 ? _onnx_lstm_output3 :
throw(ArgumentError("Unsupported LSTM output: $i"))

layerfuns[:LSTM] = function(params, args...)
active_outputs = params[ACTIVE_OUTPUTS_ATTRIBUTE_KEY]
selection = if length(active_outputs) == 1
# Can we be sure the receiver does not want a single-element tuple here? No, we can't :( :( :(
_lstm_output_selection(only(active_outputs))
else
ntuple(i -> _lstm_output_selection(active_outputs[i]), length(active_outputs))
end
paddims = haskey(params, SQUEEZED_RECURRENT_KEY) ? identity : l -> AddSingletonDim(3, l)
layer -> paddims(OutputSelection(selection, layer))
end


function recurrent_arrays(lt, Wi_WBi, Wh_WBh, Wb_Rb, h3ds...)
# ONNX weights are on the form [num_directions, hidden_size, input_size] (where num_directions is 2 for bidirectional else 1)
# Flux weights are of shape [hidden_size, input_size]
@@ -416,7 +526,15 @@ function refresh()
end

for (s, f) in fluxlayers
verts[s] = (name, inputs, args...;kwargs...) -> fluxvertex(name, f(args...), inputs...; kwargs...)
verts[s] = function(name, inputs, args...; kwargs...)
# This is typically used to select outputs, e.g. from recurrent layers
kwargsnew = if s in keys(layerfuns)
mergewith(∘, Dict(:layerfun => layerfuns[s](args...)), Dict(kwargs))
else
kwargs
end
fluxvertex(name, f(args...), inputs...; kwargsnew...)
end
end
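The `mergewith(∘, ...)` call composes the output-selection wrapper with any `layerfun` keyword already passed by the caller rather than overwriting it. A minimal sketch of the idiom (illustrative wrapper names, not from this commit):

    wrap_outer(l) = ("outer", l)
    wrap_inner(l) = ("inner", l)
    d = mergewith(∘, Dict(:layerfun => wrap_outer), Dict(:layerfun => wrap_inner))
    d[:layerfun]("layer")   # -> ("outer", ("inner", "layer"))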

for (s, f) in invariantops