Merge pull request #344 from LuxDL/ap/lux0.4 #346

Merged: 1 commit, Jul 4, 2023
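This PR (part of the Lux 0.4 update) reworks ext/LuxFluxTransformExt.jl so that `transform` no longer captures Flux parameters in inline anonymous closures. Closures such as `(args...) -> copy(l.weight)` defined inside a package extension hit https://github.com/cjdoris/PackageExtensionCompat.jl/issues/9, so the diff introduces a named-closure factory, `__copy_anonymous_closure`, and uses it for every `init_*` keyword, replacing both the inline closures and the old `_const_return_anon_function` helper. Because the factory receives its value eagerly rather than deferring the computation into a closure body, the bias copies are now guarded with `l.bias isa Bool ? nothing : ...` so that bias-free layers never try to reshape a `Bool`.

A minimal sketch of the pattern, with a hypothetical `make_constant_init` standing in for the real `__copy_anonymous_closure`:

# A named inner function, unlike `(args...) -> x`, sidesteps the
# anonymous-closure problem described in PackageExtensionCompat.jl issue 9.
function make_constant_init(x)
    function constant_init(args...)
        return x
    end
    return constant_init
end

W = rand(Float32, 3, 2)
init_weight = make_constant_init(copy(W))  # value captured eagerly, exactly once
init_weight(nothing, 3, 2) == W            # true: all arguments are ignored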
ext/LuxFluxTransformExt.jl: 67 changes (39 additions, 28 deletions)
@@ -17,6 +17,14 @@ function Base.showerror(io::IO, e::FluxModelConversionError)
end
end

+# Workaround https://github.com/cjdoris/PackageExtensionCompat.jl/issues/9
+function __copy_anonymous_closure(x)
+    function copy_anonymous_closure(args...)
+        return x
+    end
+    return copy_anonymous_closure
+end

"""
FluxLayer(layer)

@@ -119,10 +127,11 @@ end
function transform(l::Flux.Dense; preserve_ps_st::Bool=false, kwargs...)
out_dims, in_dims = size(l.weight)
if preserve_ps_st
+bias = l.bias isa Bool ? nothing : reshape(copy(l.bias), out_dims, 1)
return Dense(in_dims => out_dims,
l.σ;
-init_weight=(args...) -> copy(l.weight),
-init_bias=(args...) -> reshape(copy(l.bias), out_dims, 1),
+init_weight=__copy_anonymous_closure(copy(l.weight)),
+init_bias=__copy_anonymous_closure(bias),
use_bias=!(l.bias isa Bool))
else
return Dense(in_dims => out_dims, l.σ; use_bias=!(l.bias isa Bool))
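Context for the `l.bias isa Bool` checks throughout this diff: a Flux layer constructed with `bias=false` stores the literal `false` in its `bias` field. The old code only built the `reshape(copy(l.bias), ...)` closure, which `use_bias=false` kept from ever running; the eager `__copy_anonymous_closure` would evaluate that expression immediately, so it must now be skipped up front. A small illustration, assuming a recent Flux:

using Flux

Flux.Dense(2 => 3).bias isa Bool              # false: bias is a Vector{Float32}
Flux.Dense(2 => 3; bias=false).bias isa Bool  # true: bias is the literal `false`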
@@ -133,8 +142,8 @@ function transform(l::Flux.Scale; preserve_ps_st::Bool=false, kwargs...)
if preserve_ps_st
return Scale(size(l.scale),
l.σ;
-init_weight=(args...) -> copy(l.scale),
-init_bias=(args...) -> copy(l.bias),
+init_weight=__copy_anonymous_closure(copy(l.scale)),
+init_bias=__copy_anonymous_closure(copy(l.bias)),
use_bias=!(l.bias isa Bool))
else
return Scale(size(l.scale), l.σ; use_bias=!(l.bias isa Bool))
@@ -154,8 +163,8 @@ function transform(l::Flux.Bilinear; preserve_ps_st::Bool=false, kwargs...)
if preserve_ps_st
return Bilinear((in1, in2) => out,
l.σ;
-init_weight=(args...) -> copy(l.weight),
-init_bias=(args...) -> copy(l.bias),
+init_weight=__copy_anonymous_closure(copy(l.weight)),
+init_bias=__copy_anonymous_closure(copy(l.bias)),
use_bias=!(l.bias isa Bool))
else
return Bilinear((in1, in2) => out, l.σ; use_bias=!(l.bias isa Bool))
@@ -180,7 +189,8 @@ end
function transform(l::Flux.Embedding; preserve_ps_st::Bool=true, kwargs...)
out_dims, in_dims = size(l.weight)
if preserve_ps_st
-return Embedding(in_dims => out_dims; init_weight=(args...) -> copy(l.weight))
+return Embedding(in_dims => out_dims;
+    init_weight=__copy_anonymous_closure(copy(l.weight)))
else
return Embedding(in_dims => out_dims)
end
@@ -192,7 +202,8 @@ function transform(l::Flux.Conv; preserve_ps_st::Bool=false, kwargs...)
groups = l.groups
pad = l.pad isa Flux.SamePad ? SamePad() : l.pad
if preserve_ps_st
-_bias = reshape(copy(l.bias), ntuple(_ -> 1, length(k))..., out_chs, 1)
+_bias = l.bias isa Bool ? nothing :
+    reshape(copy(l.bias), ntuple(_ -> 1, length(k))..., out_chs, 1)
return Conv(k,
in_chs * groups => out_chs,
l.σ;
@@ -201,8 +212,8 @@ function transform(l::Flux.Conv; preserve_ps_st::Bool=false, kwargs...)
l.dilation,
groups,
use_bias=!(l.bias isa Bool),
-init_weight=(args...) -> Lux._maybe_flip_conv_weight(l.weight),
-init_bias=(args...) -> _bias)
+init_weight=__copy_anonymous_closure(copy(l.weight)),
+init_bias=__copy_anonymous_closure(_bias))
else
return Conv(k,
in_chs * groups => out_chs,
@@ -221,7 +232,8 @@ function transform(l::Flux.ConvTranspose; preserve_ps_st::Bool=false, kwargs...)
groups = l.groups
pad = l.pad isa Flux.SamePad ? SamePad() : l.pad
if preserve_ps_st
-_bias = reshape(copy(l.bias), ntuple(_ -> 1, length(k))..., out_chs, 1)
+_bias = l.bias isa Bool ? nothing :
+    reshape(copy(l.bias), ntuple(_ -> 1, length(k))..., out_chs, 1)
return ConvTranspose(k,
in_chs * groups => out_chs,
l.σ;
@@ -230,8 +242,8 @@ function transform(l::Flux.ConvTranspose; preserve_ps_st::Bool=false, kwargs...)
l.dilation,
groups,
use_bias=!(l.bias isa Bool),
-init_weight=(args...) -> Lux._maybe_flip_conv_weight(l.weight),
-init_bias=(args...) -> _bias)
+init_weight=__copy_anonymous_closure(copy(l.weight)),
+init_bias=__copy_anonymous_closure(_bias))
else
return ConvTranspose(k,
in_chs * groups => out_chs,
@@ -249,16 +261,17 @@ function transform(l::Flux.CrossCor; preserve_ps_st::Bool=false, kwargs...)
in_chs, out_chs = size(l.weight)[(end - 1):end]
pad = l.pad isa Flux.SamePad ? SamePad() : l.pad
if preserve_ps_st
-_bias = reshape(copy(l.bias), ntuple(_ -> 1, length(k))..., out_chs, 1)
+_bias = l.bias isa Bool ? nothing :
+    reshape(copy(l.bias), ntuple(_ -> 1, length(k))..., out_chs, 1)
return CrossCor(k,
in_chs => out_chs,
l.σ;
l.stride,
pad,
l.dilation,
use_bias=!(l.bias isa Bool),
-init_weight=(args...) -> copy(l.weight),
-init_bias=(args...) -> _bias)
+init_weight=__copy_anonymous_closure(copy(l.weight)),
+init_bias=__copy_anonymous_closure(_bias))
else
return CrossCor(k,
in_chs => out_chs,
@@ -305,8 +318,6 @@ function transform(l::Flux.Upsample{mode}; kwargs...) where {mode}
return Upsample{mode, typeof(l.scale), typeof(l.size)}(l.scale, l.size)
end

-_const_return_anon_function(x) = (args...) -> x
-
function transform(l::Flux.RNNCell; preserve_ps_st::Bool=false, force_preserve::Bool=false)
out_dims, in_dims = size(l.Wi)
if preserve_ps_st
@@ -316,8 +327,8 @@ function transform(l::Flux.RNNCell; preserve_ps_st::Bool=false, force_preserve::Bool=false)
@warn "Preserving Parameters: `Wh` & `Wi` for `Flux.RNNCell` is ambiguous in Lux and hence not supported. Ignoring these parameters." maxlog=1
return RNNCell(in_dims => out_dims,
l.σ;
-init_bias=(args...) -> copy(l.b),
-init_state=(args...) -> copy(l.state0))
+init_bias=__copy_anonymous_closure(copy(l.b)),
+init_state=__copy_anonymous_closure(copy(l.state0)))
else
return RNNCell(in_dims => out_dims, l.σ)
end
@@ -334,9 +345,9 @@ function transform(l::Flux.LSTMCell; preserve_ps_st::Bool=false, force_preserve::Bool=false)
bs = Lux.multigate(l.b, Val(4))
_s, _m = copy.(l.state0)
return LSTMCell(in_dims => out_dims;
-init_bias=_const_return_anon_function.(bs),
-init_state=(args...) -> _s,
-init_memory=(args...) -> _m)
+init_bias=__copy_anonymous_closure.(bs),
+init_state=__copy_anonymous_closure(_s),
+init_memory=__copy_anonymous_closure(_m))
else
return LSTMCell(in_dims => out_dims)
end
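In the recurrent cells above, Flux stacks all gate biases into a single vector `l.b`; `Lux.multigate(l.b, Val(4))` splits it into one chunk per gate, and broadcasting the factory (`__copy_anonymous_closure.(bs)`) then yields a tuple of per-gate initializers. A sketch of the splitting, with a hypothetical `split_gates` standing in for `Lux.multigate`:

# Hypothetical stand-in for Lux.multigate: split a stacked bias into N equal chunks.
function split_gates(b::AbstractVector, ::Val{N}) where {N}
    step = length(b) ÷ N
    return ntuple(i -> b[((i - 1) * step + 1):(i * step)], N)
end

b = collect(1.0:8.0)
split_gates(b, Val(4))  # ([1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0])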
@@ -353,7 +364,7 @@ function transform(l::Flux.GRUCell; preserve_ps_st::Bool=false, force_preserve::Bool=false)
bs = Lux.multigate(l.b, Val(3))
return GRUCell(in_dims => out_dims;
-init_bias=_const_return_anon_function.(bs),
-init_state=(args...) -> copy(l.state0))
+init_bias=__copy_anonymous_closure.(bs),
+init_state=__copy_anonymous_closure(copy(l.state0)))
else
return GRUCell(in_dims => out_dims)
end
@@ -374,8 +385,8 @@ function transform(l::Flux.BatchNorm;
l.track_stats,
epsilon=l.ϵ,
l.momentum,
-init_bias=(args...) -> copy(l.β),
-init_scale=(args...) -> copy(l.γ))
+init_bias=__copy_anonymous_closure(copy(l.β)),
+init_scale=__copy_anonymous_closure(copy(l.γ)))
else
return BatchNorm(l.chs, l.λ; l.affine, l.track_stats, epsilon=l.ϵ, l.momentum)
end
@@ -397,8 +408,8 @@ function transform(l::Flux.GroupNorm;
l.λ;
l.affine,
epsilon=l.ϵ,
-init_bias=(args...) -> copy(l.β),
-init_scale=(args...) -> copy(l.γ))
+init_bias=__copy_anonymous_closure(copy(l.β)),
+init_scale=__copy_anonymous_closure(copy(l.γ)))
else
return GroupNorm(l.chs, l.G, l.λ; l.affine, epsilon=l.ϵ)
end
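For reference, a hedged end-to-end sketch of how these methods are used (assuming the Lux 0.4-era `Lux.transform` entry point; the exact import may differ):

using Flux, Lux, Random

fl = Flux.Dense(2 => 3, tanh)
ll = Lux.transform(fl; preserve_ps_st=true)   # copy weights/bias into the Lux layer
ps, st = Lux.setup(Random.default_rng(), ll)
x = randn(Float32, 2, 5)
y, _ = ll(x, ps, st)                          # y should match fl(x)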