feat: emit batch_norm ops from stablehlo #1142

Merged 3 commits on Jan 17, 2025
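
In short: when a Lux model is compiled with Reactant, inference-mode `BatchNorm` now lowers to a single `stablehlo.batch_norm_inference` op. A minimal usage sketch, mirroring the test added in this PR (the layer sizes, input shape, and explicit `MLDataDevices` import are illustrative):

```julia
using Lux, MLDataDevices, Random, Reactant

model = Chain(Dense(2 => 3, tanh), BatchNorm(3), Dense(3 => 2))
ps, st = Lux.setup(Random.default_rng(), model)

dev = reactant_device()            # from MLDataDevices
x_ra = rand(Float32, 2, 4) |> dev
ps_ra = ps |> dev
st_ra = st |> dev

# In test mode, BatchNorm should now appear as stablehlo.batch_norm_inference
# in the generated module rather than being decomposed into primitive ops.
hlo = @code_hlo model(x_ra, ps_ra, Lux.testmode(st_ra))
contains(repr(hlo), "stablehlo.batch_norm_inference")   # expected: true
```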
2 changes: 1 addition & 1 deletion Project.toml
@@ -99,7 +99,7 @@ GPUArraysCore = "0.1.6, 0.2"
LinearAlgebra = "1.10"
LossFunctions = "0.11.1, 1"
LuxCore = "1.2"
LuxLib = "1.3.7"
LuxLib = "1.5.0"
MLDataDevices = "1.6.6"
MLUtils = "0.4.4"
MPI = "0.20.19"
1 change: 1 addition & 0 deletions ext/LuxReactantExt/LuxReactantExt.jl
@@ -17,6 +17,7 @@ Utils.contiguous(x::AnyTracedRArray) = Reactant.TracedUtils.materialize_traced_a

Utils.eltype(::Type{<:TracedRArray{T, N}}) where {T, N} = T
Utils.eltype(::Type{<:TracedRNumber{T}}) where {T} = T
Utils.eltype(x::Reactant.AnyTracedRArray) = Reactant.unwrapped_eltype(x)

function Utils.promote_to(::Type{T}, x::Number) where {T <: Number}
x isa Reactant.TracedType && return x
21 changes: 12 additions & 9 deletions lib/LuxLib/Project.toml
@@ -1,7 +1,7 @@
name = "LuxLib"
uuid = "82251201-b29d-42c6-8e01-566dec8acb11"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "1.4.1"
version = "1.5.0"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
@@ -19,8 +19,8 @@ LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
Polyester = "f517fe37-dbe3-4b94-8317-1923a5111588"
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
@@ -32,23 +32,29 @@ AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
AppleAccelerate = "13e28ba4-7ad8-5781-acae-3021b1ed3924"
BLISBLAS = "6f275bd8-fec0-4d39-945b-7e95a765fa1e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
MKL = "33e6dc65-8f57-5167-99aa-e5a354878fb2"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
MKL = "33e6dc65-8f57-5167-99aa-e5a354878fb2"
Octavian = "6fd5a793-0b7e-452c-907f-f8bfe9c57db4"
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SLEEFPirates = "476501e8-09a2-5ece-8869-fb82de89a1fa"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"

[sources]
LuxCore = {path = "../LuxCore"}
MLDataDevices = {path = "../MLDataDevices"}

[extensions]
LuxLibAppleAccelerateExt = "AppleAccelerate"
LuxLibBLISBLASExt = "BLISBLAS"
LuxLibCUDAExt = "CUDA"
LuxLibMKLExt = "MKL"
LuxLibEnzymeExt = "Enzyme"
LuxLibLoopVectorizationExt = "LoopVectorization"
LuxLibMKLExt = "MKL"
LuxLibOctavianExt = ["Octavian", "LoopVectorization"]
LuxLibReactantExt = "Reactant"
LuxLibReverseDiffExt = "ReverseDiff"
LuxLibSLEEFPiratesExt = "SLEEFPirates"
LuxLibTrackerAMDGPUExt = ["AMDGPU", "Tracker"]
@@ -79,9 +85,10 @@ MLDataDevices = "1.6"
Markdown = "1.10"
NNlib = "0.9.26"
Octavian = "0.3.28"
Preferences = "1.4.3"
Polyester = "0.7.15"
Preferences = "1.4.3"
Random = "1.10"
Reactant = "0.2.13"
Reexport = "1"
ReverseDiff = "1.15"
SLEEFPirates = "0.6.43"
@@ -91,7 +98,3 @@ Statistics = "1.10"
Tracker = "0.2.36"
cuDNN = "1.3"
julia = "1.10"

[sources]
LuxCore = { path = "../LuxCore" }
MLDataDevices = { path = "../MLDataDevices" }
106 changes: 106 additions & 0 deletions lib/LuxLib/ext/LuxLibReactantExt.jl
@@ -0,0 +1,106 @@
module LuxLibReactantExt

using Reactant: Reactant, MLIR, Ops, TracedUtils, TracedRArray, AnyTracedRArray,
AnyTracedRVector, TracedRNumber
using Static: False

using LuxLib: LuxLib, Impl, Optional, Utils

# Most of the NN code gen happens in Reactant.jl via an extension on NNlib, however,
# NNlib doesn't have certain ops implemented. In those cases we can emit more optimized
# StableHLO
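# For reference (StableHLO semantics, not specific to this code): batch_norm_inference
# normalizes each feature channel as (x - μ) / sqrt(σ² + ε) * γ + β, with μ and σ²
# taken from the running statistics or computed on the fly below.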
function Impl.batchnorm(
x::AnyTracedRArray{T},
γ::Optional{<:AnyTracedRVector}, β::Optional{<:AnyTracedRVector},
rμ::Optional{<:AnyTracedRVector}, rσ²::Optional{<:AnyTracedRVector},
::False, act::F, momentum, ϵ
) where {T, F}
x = TracedUtils.materialize_traced_array(x)

γ = if γ === nothing
Ops.constant(fill(T(1), size(x, ndims(x) - 1)))
else
TracedUtils.materialize_traced_array(γ)
end
β = if β === nothing
Ops.constant(fill(T(0), size(x, ndims(x) - 1)))
else
TracedUtils.materialize_traced_array(β)
end

if rμ === nothing && rσ² === nothing
μ, σ² = Impl.mean_var(
x; dims=Utils.unsafe_known(Impl.batchnorm_reduce_dims(x)), corrected=false
)
μ = TracedUtils.materialize_traced_array(vec(μ))
σ² = TracedUtils.materialize_traced_array(vec(σ²))
else
@assert rμ !== nothing && rσ² !== nothing
μ = TracedUtils.materialize_traced_array(rμ)
σ² = TracedUtils.materialize_traced_array(rσ²)
end

res = MLIR.IR.result(
MLIR.Dialects.stablehlo.batch_norm_inference(
TracedUtils.get_mlir_data(x),
TracedUtils.get_mlir_data(γ),
TracedUtils.get_mlir_data(β),
TracedUtils.get_mlir_data(μ),
TracedUtils.get_mlir_data(σ²);
epsilon=Float32(ϵ),
feature_index=Int64(ndims(x) - 2)
),
1
)

return act.(TracedRArray{T, ndims(x)}((), res, size(x))), rμ, rσ²
end

# The following code is commented out since we don't have Batchnorm Op Adjoint registered
# for EnzymeJAX yet
#=
function Impl.batchnorm(
x::AnyTracedRArray{T},
γ::Optional{<:AnyTracedRVector}, β::Optional{<:AnyTracedRVector},
rμ::Optional{<:AnyTracedRVector}, rσ²::Optional{<:AnyTracedRVector},
training::StaticBool, act::F, momentum, ϵ
) where {T, F}
x = TracedUtils.materialize_traced_array(x)

γ = if γ === nothing
Ops.constant(fill(T(1), size(x, ndims(x) - 1)))
else
TracedUtils.materialize_traced_array(γ)
end
β = if β === nothing
Ops.constant(fill(T(0), size(x, ndims(x) - 1)))
else
TracedUtils.materialize_traced_array(β)
end

op = MLIR.Dialects.stablehlo.batch_norm_training(
TracedUtils.get_mlir_data(x),
TracedUtils.get_mlir_data(γ),
TracedUtils.get_mlir_data(β);
epsilon=Float32(ϵ),
feature_index=Int64(ndims(x) - 2)
)

res = act.(TracedRArray{T, ndims(x)}((), MLIR.IR.result(op, 1), size(x)))
μ = TracedRArray{T, 1}((), MLIR.IR.result(op, 2), size(x, ndims(x) - 1))
σ² = TracedRArray{T, 1}((), MLIR.IR.result(op, 3), size(x, ndims(x) - 1))

if rμ === nothing && rσ² === nothing
return res, nothing, nothing
else
@assert rμ !== nothing && rσ² !== nothing
m = T(Impl.accum_size(x, Impl.batchnorm_reduce_dims(x)))
rμ, rσ² = Impl.update_running_statistics(
rμ, rσ², μ, σ², momentum, momentum * m / (m - one(m))
)
return res, rμ, rσ²
end
end
=#

end
67 changes: 67 additions & 0 deletions test/reactant/layer_tests.jl
@@ -99,3 +99,70 @@ end
end
end
end

@testitem "BatchNorm Layer" tags=[:reactant] setup=[
SharedTestSetup, SharedReactantLayersTestSetup] skip=:(Sys.iswindows()) begin
using Reactant, Lux, Random

@testset "$(mode)" for (mode, atype, dev, ongpu) in MODES
if mode == "amdgpu"
@warn "Skipping AMDGPU tests for Reactant"
continue
end

dev = reactant_device(; force=true)

if ongpu
Reactant.set_default_backend("gpu")
else
Reactant.set_default_backend("cpu")
end

@testset for track_stats in (true, false), affine in (true, false),
act in (identity, tanh)

model = Chain(
Dense(2 => 3, tanh),
BatchNorm(3, act; track_stats, affine, init_bias=rand32, init_scale=rand32),
Dense(3 => 2)
)

x = rand(Float32, 2, 4)
ps, st = Lux.setup(Random.default_rng(), model)

x_ra = x |> dev
ps_ra = ps |> dev
st_ra = st |> dev

y, st2 = model(x, ps, st)
y_ra, st2_ra = @jit model(x_ra, ps_ra, st_ra)

@test y≈y_ra rtol=1e-3 atol=1e-3
if track_stats
@test st2.layer_2.running_mean≈st2_ra.layer_2.running_mean rtol=1e-3 atol=1e-3
@test st2.layer_2.running_var≈st2_ra.layer_2.running_var rtol=1e-3 atol=1e-3
end

# TODO: Check for stablehlo.batch_norm_training once we emit it in LuxLib

@testset "gradient" begin
∂x, ∂ps = ∇sumabs2_zygote(model, x, ps, st)
∂x_ra, ∂ps_ra = @jit ∇sumabs2_enzyme(model, x_ra, ps_ra, st_ra)
@test ∂x_ra≈∂x atol=1e-2 rtol=1e-2
@test check_approx(∂ps_ra, ∂ps; atol=1e-2, rtol=1e-2)
end

y2, st3 = model(x, ps, Lux.testmode(st2))
y2_ra, st3_ra = @jit model(x_ra, ps_ra, Lux.testmode(st2_ra))

@test y2≈y2_ra rtol=1e-3 atol=1e-3
if track_stats
@test st3.layer_2.running_mean≈st3_ra.layer_2.running_mean rtol=1e-3 atol=1e-3
@test st3.layer_2.running_var≈st3_ra.layer_2.running_var rtol=1e-3 atol=1e-3
end

hlo = @code_hlo model(x_ra, ps_ra, Lux.testmode(st_ra))
@test contains(repr(hlo), "stablehlo.batch_norm_inference")
end
end
end