From 6e6c0938eedb88d0d728edfe38f99a7653cba795 Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Wed, 1 Jan 2025 19:36:57 -0500
Subject: [PATCH 1/3] feat: emit batch_norm ops from stablehlo

---
 ext/LuxReactantExt/LuxReactantExt.jl |  1 +
 lib/LuxLib/Project.toml              | 19 ++++---
 lib/LuxLib/ext/LuxLibReactantExt.jl  | 85 ++++++++++++++++++++++++++++
 3 files changed, 97 insertions(+), 8 deletions(-)
 create mode 100644 lib/LuxLib/ext/LuxLibReactantExt.jl

diff --git a/ext/LuxReactantExt/LuxReactantExt.jl b/ext/LuxReactantExt/LuxReactantExt.jl
index 292132759e..14acc442ef 100644
--- a/ext/LuxReactantExt/LuxReactantExt.jl
+++ b/ext/LuxReactantExt/LuxReactantExt.jl
@@ -17,6 +17,7 @@ Utils.contiguous(x::AnyTracedRArray) = Reactant.TracedUtils.materialize_traced_a
 
 Utils.eltype(::Type{<:TracedRArray{T, N}}) where {T, N} = T
 Utils.eltype(::Type{<:TracedRNumber{T}}) where {T} = T
+Utils.eltype(x::Reactant.AnyTracedRArray) = Reactant.unwrapped_eltype(x)
 
 function Utils.promote_to(::Type{T}, x::Number) where {T <: Number}
     x isa Reactant.TracedType && return x
diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml
index a9ec915cb1..9851e5adfe 100644
--- a/lib/LuxLib/Project.toml
+++ b/lib/LuxLib/Project.toml
@@ -19,8 +19,8 @@ LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
 MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
 Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
 NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
-Preferences = "21216c6a-2e73-6563-6e65-726566657250"
 Polyester = "f517fe37-dbe3-4b94-8317-1923a5111588"
+Preferences = "21216c6a-2e73-6563-6e65-726566657250"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
 Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
@@ -32,23 +32,29 @@ AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
 AppleAccelerate = "13e28ba4-7ad8-5781-acae-3021b1ed3924"
 BLISBLAS = "6f275bd8-fec0-4d39-945b-7e95a765fa1e"
 CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
-MKL = "33e6dc65-8f57-5167-99aa-e5a354878fb2"
 Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
 LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
+MKL = "33e6dc65-8f57-5167-99aa-e5a354878fb2"
 Octavian = "6fd5a793-0b7e-452c-907f-f8bfe9c57db4"
+Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
 ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
 SLEEFPirates = "476501e8-09a2-5ece-8869-fb82de89a1fa"
 Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
 cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
 
+[sources]
+LuxCore = {path = "../LuxCore"}
+MLDataDevices = {path = "../MLDataDevices"}
+
 [extensions]
 LuxLibAppleAccelerateExt = "AppleAccelerate"
 LuxLibBLISBLASExt = "BLISBLAS"
 LuxLibCUDAExt = "CUDA"
-LuxLibMKLExt = "MKL"
 LuxLibEnzymeExt = "Enzyme"
 LuxLibLoopVectorizationExt = "LoopVectorization"
+LuxLibMKLExt = "MKL"
 LuxLibOctavianExt = ["Octavian", "LoopVectorization"]
+LuxLibReactantExt = "Reactant"
 LuxLibReverseDiffExt = "ReverseDiff"
 LuxLibSLEEFPiratesExt = "SLEEFPirates"
 LuxLibTrackerAMDGPUExt = ["AMDGPU", "Tracker"]
@@ -79,9 +85,10 @@ MLDataDevices = "1.6"
 Markdown = "1.10"
 NNlib = "0.9.26"
 Octavian = "0.3.28"
-Preferences = "1.4.3"
 Polyester = "0.7.15"
+Preferences = "1.4.3"
 Random = "1.10"
+Reactant = "0.2.13"
 Reexport = "1"
 ReverseDiff = "1.15"
 SLEEFPirates = "0.6.43"
@@ -91,7 +98,3 @@ Statistics = "1.10"
 Tracker = "0.2.36"
 cuDNN = "1.3"
 julia = "1.10"
-
-[sources]
-LuxCore = { path = "../LuxCore" }
-MLDataDevices = { path = "../MLDataDevices" }
diff --git a/lib/LuxLib/ext/LuxLibReactantExt.jl b/lib/LuxLib/ext/LuxLibReactantExt.jl
new file mode 100644
index 0000000000..b3ad6d0b55
--- /dev/null
+++ b/lib/LuxLib/ext/LuxLibReactantExt.jl
@@ -0,0 +1,85 @@
+module LuxLibReactantExt
+
+using Reactant: Reactant, MLIR, Ops, TracedUtils, TracedRArray, AnyTracedRArray,
+    AnyTracedRVector, TracedRNumber
+using Static: StaticBool, True, False
+
+using LuxLib: LuxLib, Impl, Optional, Utils
+
+# Most of the NN code gen happens in Reactant.jl via an extension on NNlib, however,
+# NNlib doesn't have certain ops implemented. In those cases we can emit more optimized
+# StableHLO
+
+function Impl.batchnorm(
+        x::AnyTracedRArray{T},
+        γ::Optional{<:AnyTracedRVector}, β::Optional{<:AnyTracedRVector},
+        rμ::Optional{<:AnyTracedRVector}, rσ²::Optional{<:AnyTracedRVector},
+        training::StaticBool, act::F, momentum, ϵ
+) where {T, F}
+    x = TracedUtils.materialize_traced_array(x)
+
+    γ = if γ === nothing
+        Ops.constant(fill(T(1), size(x, ndims(x) - 1)))
+    else
+        TracedUtils.materialize_traced_array(γ)
+    end
+    β = if β === nothing
+        Ops.constant(fill(T(0), size(x, ndims(x) - 1)))
+    else
+        TracedUtils.materialize_traced_array(β)
+    end
+
+    if training isa True
+        op = MLIR.Dialects.stablehlo.batch_norm_training(
+            TracedUtils.get_mlir_data(x),
+            TracedUtils.get_mlir_data(γ),
+            TracedUtils.get_mlir_data(β);
+            epsilon=Float32(ϵ),
+            feature_index=Int64(ndims(x) - 2)
+        )
+
+        res = act.(TracedRArray{T, ndims(x)}((), MLIR.IR.result(op, 1), size(x)))
+        μ = TracedRArray{T, 1}((), MLIR.IR.result(op, 2), size(x, ndims(x) - 1))
+        σ² = TracedRArray{T, 1}((), MLIR.IR.result(op, 3), size(x, ndims(x) - 1))
+
+        if rμ === nothing && rσ² === nothing
+            return res, nothing, nothing
+        else
+            @assert rμ !== nothing && rσ² !== nothing
+            m = T(Impl.accum_size(x, Impl.batchnorm_reduce_dims(x)))
+            rμ, rσ² = Impl.update_running_statistics(
+                rμ, rσ², μ, σ², momentum, momentum * m / (m - one(m))
+            )
+            return res, rμ, rσ²
+        end
+    else
+        if rμ === nothing && rσ² === nothing
+            μ, σ² = Impl.mean_var(
+                x; dims=Utils.unsafe_known(Impl.batchnorm_reduce_dims(x)), corrected=false
+            )
+            μ = TracedUtils.materialize_traced_array(vec(μ))
+            σ² = TracedUtils.materialize_traced_array(vec(σ²))
+        else
+            @assert rμ !== nothing && rσ² !== nothing
+            μ = TracedUtils.materialize_traced_array(rμ)
+            σ² = TracedUtils.materialize_traced_array(rσ²)
+        end
+
+        res = MLIR.IR.result(
+            MLIR.Dialects.stablehlo.batch_norm_inference(
+                TracedUtils.get_mlir_data(x),
+                TracedUtils.get_mlir_data(γ),
+                TracedUtils.get_mlir_data(β),
+                TracedUtils.get_mlir_data(μ),
+                TracedUtils.get_mlir_data(σ²);
+                epsilon=Float32(ϵ),
+                feature_index=Int64(ndims(x) - 2)
+            ),
+            1
+        )
+
+        return act.(TracedRArray{T, ndims(x)}((), res, size(x))), rμ, rσ²
+    end
+end
+
+end

From 5014ee43a674c0f7ac9da23c50fdc1d8954dcb01 Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Fri, 17 Jan 2025 14:27:56 -0500
Subject: [PATCH 2/3] refactor: only implement inference path for now

---
 Project.toml                        |   2 +-
 lib/LuxLib/Project.toml             |   2 +-
 lib/LuxLib/ext/LuxLibReactantExt.jl | 109 +++++++++++++++++-----------
 test/reactant/layer_tests.jl        |   5 ++
 4 files changed, 72 insertions(+), 46 deletions(-)

diff --git a/Project.toml b/Project.toml
index e34cd555f9..bad32cec73 100644
--- a/Project.toml
+++ b/Project.toml
@@ -99,7 +99,7 @@ GPUArraysCore = "0.1.6, 0.2"
 LinearAlgebra = "1.10"
 LossFunctions = "0.11.1, 1"
 LuxCore = "1.2"
-LuxLib = "1.3.7"
+LuxLib = "1.5.0"
 MLDataDevices = "1.6.6"
 MLUtils = "0.4.4"
 MPI = "0.20.19"
diff --git a/lib/LuxLib/Project.toml b/lib/LuxLib/Project.toml
index 9851e5adfe..35fca7c12f 100644
--- a/lib/LuxLib/Project.toml
+++ b/lib/LuxLib/Project.toml
@@ -1,7 +1,7 @@
 name = "LuxLib"
 uuid = "82251201-b29d-42c6-8e01-566dec8acb11"
 authors = ["Avik Pal and contributors"]
-version = "1.4.1"
+version = "1.5.0"
 
 [deps]
 ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
diff --git a/lib/LuxLib/ext/LuxLibReactantExt.jl b/lib/LuxLib/ext/LuxLibReactantExt.jl
index b3ad6d0b55..36b79cccd9 100644
--- a/lib/LuxLib/ext/LuxLibReactantExt.jl
+++ b/lib/LuxLib/ext/LuxLibReactantExt.jl
@@ -2,19 +2,18 @@ module LuxLibReactantExt
 
 using Reactant: Reactant, MLIR, Ops, TracedUtils, TracedRArray, AnyTracedRArray,
     AnyTracedRVector, TracedRNumber
-using Static: StaticBool, True, False
+using Static: False
 
 using LuxLib: LuxLib, Impl, Optional, Utils
 
 # Most of the NN code gen happens in Reactant.jl via an extension on NNlib, however,
 # NNlib doesn't have certain ops implemented. In those cases we can emit more optimized
 # StableHLO
-
 function Impl.batchnorm(
         x::AnyTracedRArray{T},
         γ::Optional{<:AnyTracedRVector}, β::Optional{<:AnyTracedRVector},
         rμ::Optional{<:AnyTracedRVector}, rσ²::Optional{<:AnyTracedRVector},
-        training::StaticBool, act::F, momentum, ϵ
+        ::False, act::F, momentum, ϵ
 ) where {T, F}
     x = TracedUtils.materialize_traced_array(x)
 
@@ -29,57 +28,79 @@ function Impl.batchnorm(
         TracedUtils.materialize_traced_array(β)
     end
 
-    if training isa True
-        op = MLIR.Dialects.stablehlo.batch_norm_training(
+    if rμ === nothing && rσ² === nothing
+        μ, σ² = Impl.mean_var(
+            x; dims=Utils.unsafe_known(Impl.batchnorm_reduce_dims(x)), corrected=false
+        )
+        μ = TracedUtils.materialize_traced_array(vec(μ))
+        σ² = TracedUtils.materialize_traced_array(vec(σ²))
+    else
+        @assert rμ !== nothing && rσ² !== nothing
+        μ = TracedUtils.materialize_traced_array(rμ)
+        σ² = TracedUtils.materialize_traced_array(rσ²)
+    end
+
+    res = MLIR.IR.result(
+        MLIR.Dialects.stablehlo.batch_norm_inference(
             TracedUtils.get_mlir_data(x),
             TracedUtils.get_mlir_data(γ),
-            TracedUtils.get_mlir_data(β);
+            TracedUtils.get_mlir_data(β),
+            TracedUtils.get_mlir_data(μ),
+            TracedUtils.get_mlir_data(σ²);
             epsilon=Float32(ϵ),
             feature_index=Int64(ndims(x) - 2)
-        )
+        ),
+        1
+    )
 
-        res = act.(TracedRArray{T, ndims(x)}((), MLIR.IR.result(op, 1), size(x)))
-        μ = TracedRArray{T, 1}((), MLIR.IR.result(op, 2), size(x, ndims(x) - 1))
-        σ² = TracedRArray{T, 1}((), MLIR.IR.result(op, 3), size(x, ndims(x) - 1))
+    return act.(TracedRArray{T, ndims(x)}((), res, size(x))), rμ, rσ²
+end
 
-        if rμ === nothing && rσ² === nothing
-            return res, nothing, nothing
-        else
-            @assert rμ !== nothing && rσ² !== nothing
-            m = T(Impl.accum_size(x, Impl.batchnorm_reduce_dims(x)))
-            rμ, rσ² = Impl.update_running_statistics(
-                rμ, rσ², μ, σ², momentum, momentum * m / (m - one(m))
-            )
-            return res, rμ, rσ²
-        end
+# The following code is commented out since we don't have Batchnorm Op Adjoint registered
+# for EnzymeJAX yet
+#=
+function Impl.batchnorm(
+        x::AnyTracedRArray{T},
+        γ::Optional{<:AnyTracedRVector}, β::Optional{<:AnyTracedRVector},
+        rμ::Optional{<:AnyTracedRVector}, rσ²::Optional{<:AnyTracedRVector},
+        training::StaticBool, act::F, momentum, ϵ
+) where {T, F}
+    x = TracedUtils.materialize_traced_array(x)
+
+    γ = if γ === nothing
+        Ops.constant(fill(T(1), size(x, ndims(x) - 1)))
+    else
+        TracedUtils.materialize_traced_array(γ)
+    end
+    β = if β === nothing
+        Ops.constant(fill(T(0), size(x, ndims(x) - 1)))
     else
-        if rμ === nothing && rσ² === nothing
-            μ, σ² = Impl.mean_var(
-                x; dims=Utils.unsafe_known(Impl.batchnorm_reduce_dims(x)), corrected=false
-            )
-            μ = TracedUtils.materialize_traced_array(vec(μ))
-            σ² = TracedUtils.materialize_traced_array(vec(σ²))
-        else
-            @assert rμ !== nothing && rσ² !== nothing
-            μ = TracedUtils.materialize_traced_array(rμ)
-            σ² = TracedUtils.materialize_traced_array(rσ²)
-        end
+        TracedUtils.materialize_traced_array(β)
+    end
 
-        res = MLIR.IR.result(
-            MLIR.Dialects.stablehlo.batch_norm_inference(
-                TracedUtils.get_mlir_data(x),
-                TracedUtils.get_mlir_data(γ),
-                TracedUtils.get_mlir_data(β),
-                TracedUtils.get_mlir_data(μ),
-                TracedUtils.get_mlir_data(σ²);
-                epsilon=Float32(ϵ),
-                feature_index=Int64(ndims(x) - 2)
-            ),
-            1
-        )
+    op = MLIR.Dialects.stablehlo.batch_norm_training(
+        TracedUtils.get_mlir_data(x),
+        TracedUtils.get_mlir_data(γ),
+        TracedUtils.get_mlir_data(β);
+        epsilon=Float32(ϵ),
+        feature_index=Int64(ndims(x) - 2)
+    )
 
-        return act.(TracedRArray{T, ndims(x)}((), res, size(x))), rμ, rσ²
+    res = act.(TracedRArray{T, ndims(x)}((), MLIR.IR.result(op, 1), size(x)))
+    μ = TracedRArray{T, 1}((), MLIR.IR.result(op, 2), size(x, ndims(x) - 1))
+    σ² = TracedRArray{T, 1}((), MLIR.IR.result(op, 3), size(x, ndims(x) - 1))
+
+    if rμ === nothing && rσ² === nothing
+        return res, nothing, nothing
+    else
+        @assert rμ !== nothing && rσ² !== nothing
+        m = T(Impl.accum_size(x, Impl.batchnorm_reduce_dims(x)))
+        rμ, rσ² = Impl.update_running_statistics(
+            rμ, rσ², μ, σ², momentum, momentum * m / (m - one(m))
+        )
+        return res, rμ, rσ²
     end
 end
+=#
 
 end
diff --git a/test/reactant/layer_tests.jl b/test/reactant/layer_tests.jl
index b2b5d8021b..744c875bda 100644
--- a/test/reactant/layer_tests.jl
+++ b/test/reactant/layer_tests.jl
@@ -99,3 +99,8 @@ end
         end
     end
 end
+
+@testitem "BatchNorm Layer" tags=[:reactant] setup=[SharedTestSetup] skip=:(Sys.iswindows()) begin
+    using Reactant, Lux, Random
+
+end

From f074b760a7b73b443688f00c79d26c5b548d21df Mon Sep 17 00:00:00 2001
From: Avik Pal
Date: Fri, 17 Jan 2025 16:17:48 -0500
Subject: [PATCH 3/3] test: batchnorm layers

---
 test/reactant/layer_tests.jl | 64 +++++++++++++++++++++++++++++++++++-
 1 file changed, 63 insertions(+), 1 deletion(-)

diff --git a/test/reactant/layer_tests.jl b/test/reactant/layer_tests.jl
index 744c875bda..f277bb826d 100644
--- a/test/reactant/layer_tests.jl
+++ b/test/reactant/layer_tests.jl
@@ -100,7 +100,69 @@ end
     end
 end
 
-@testitem "BatchNorm Layer" tags=[:reactant] setup=[SharedTestSetup] skip=:(Sys.iswindows()) begin
+@testitem "BatchNorm Layer" tags=[:reactant] setup=[
+    SharedTestSetup, SharedReactantLayersTestSetup] skip=:(Sys.iswindows()) begin
     using Reactant, Lux, Random
 
+    @testset "$(mode)" for (mode, atype, dev, ongpu) in MODES
+        if mode == "amdgpu"
+            @warn "Skipping AMDGPU tests for Reactant"
+            continue
+        end
+
+        dev = reactant_device(; force=true)
+
+        if ongpu
+            Reactant.set_default_backend("gpu")
+        else
+            Reactant.set_default_backend("cpu")
+        end
+
+        @testset for track_stats in (true, false), affine in (true, false),
+            act in (identity, tanh)
+
+            model = Chain(
+                Dense(2 => 3, tanh),
+                BatchNorm(3, act; track_stats, affine, init_bias=rand32, init_scale=rand32),
+                Dense(3 => 2)
+            )
+
+            x = rand(Float32, 2, 4)
+            ps, st = Lux.setup(Random.default_rng(), model)
+
+            x_ra = x |> dev
+            ps_ra = ps |> dev
+            st_ra = st |> dev
+
+            y, st2 = model(x, ps, st)
+            y_ra, st2_ra = @jit model(x_ra, ps_ra, st_ra)
+
+            @test y≈y_ra rtol=1e-3 atol=1e-3
+            if track_stats
+                @test st2.layer_2.running_mean≈st2_ra.layer_2.running_mean rtol=1e-3 atol=1e-3
+                @test st2.layer_2.running_var≈st2_ra.layer_2.running_var rtol=1e-3 atol=1e-3
+            end
+
+            # TODO: Check for stablehlo.batch_norm_training once we emit it in LuxLib
+
"gradient" begin + ∂x, ∂ps = ∇sumabs2_zygote(model, x, ps, st) + ∂x_ra, ∂ps_ra = @jit ∇sumabs2_enzyme(model, x_ra, ps_ra, st_ra) + @test ∂x_ra≈∂x atol=1e-2 rtol=1e-2 + @test check_approx(∂ps_ra, ∂ps; atol=1e-2, rtol=1e-2) + end + + y2, st3 = model(x, ps, Lux.testmode(st2)) + y2_ra, st3_ra = @jit model(x_ra, ps_ra, Lux.testmode(st2_ra)) + + @test y2≈y2_ra rtol=1e-3 atol=1e-3 + if track_stats + @test st3.layer_2.running_mean≈st3_ra.layer_2.running_mean rtol=1e-3 atol=1e-3 + @test st3.layer_2.running_var≈st3_ra.layer_2.running_var rtol=1e-3 atol=1e-3 + end + + hlo = @code_hlo model(x_ra, ps_ra, Lux.testmode(st_ra)) + @test contains(repr(hlo), "stablehlo.batch_norm_inference") + end + end end