feat: emit batch_norm ops from stablehlo
avik-pal committed Dec 21, 2024
1 parent 879a599 commit de62ff2
Showing 2 changed files with 97 additions and 9 deletions.
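For context, the net effect of the change: when LuxLib's batchnorm implementation is invoked with Reactant traced arrays, it now lowers to StableHLO's dedicated batch-norm ops instead of a generic broadcast/reduction lowering. A minimal sketch of how that path is reached, assuming Lux's BatchNorm layer and Reactant's public to_rarray/@compile entry points (shapes are illustrative):

using Lux, LuxLib, Random, Reactant

# WHCN layout: (width, height, channels, batch); channels is the feature dim
x = Reactant.to_rarray(rand(Float32, 8, 8, 4, 2))

model = BatchNorm(4)
ps, st = Lux.setup(Random.default_rng(), model)
ps, st = Reactant.to_rarray(ps), Reactant.to_rarray(st)

# Compiling through Reactant traces into Impl.batchnorm (see the new
# extension below), emitting stablehlo.batch_norm_training/inference
compiled = Reactant.@compile model(x, ps, st)
y, st_new = compiled(x, ps, st)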
21 changes: 12 additions & 9 deletions lib/LuxLib/Project.toml
@@ -1,7 +1,7 @@
name = "LuxLib"
uuid = "82251201-b29d-42c6-8e01-566dec8acb11"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "1.3.10"
version = "1.4.0"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
@@ -19,8 +19,8 @@ LuxCore = "bb33d45b-7691-41d6-9220-0943567d0623"
MLDataDevices = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
Markdown = "d6f4376e-aef5-505a-96c1-9c027394607a"
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
Polyester = "f517fe37-dbe3-4b94-8317-1923a5111588"
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
Static = "aedffcd0-7271-4cad-89d0-dc628f76c6d3"
@@ -32,23 +32,29 @@ AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
AppleAccelerate = "13e28ba4-7ad8-5781-acae-3021b1ed3924"
BLISBLAS = "6f275bd8-fec0-4d39-945b-7e95a765fa1e"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
MKL = "33e6dc65-8f57-5167-99aa-e5a354878fb2"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
MKL = "33e6dc65-8f57-5167-99aa-e5a354878fb2"
Octavian = "6fd5a793-0b7e-452c-907f-f8bfe9c57db4"
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
SLEEFPirates = "476501e8-09a2-5ece-8869-fb82de89a1fa"
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"

[sources]
LuxCore = {path = "../LuxCore"}
MLDataDevices = {path = "../MLDataDevices"}

[extensions]
LuxLibAppleAccelerateExt = "AppleAccelerate"
LuxLibBLISBLASExt = "BLISBLAS"
LuxLibCUDAExt = "CUDA"
LuxLibMKLExt = "MKL"
LuxLibEnzymeExt = "Enzyme"
LuxLibLoopVectorizationExt = "LoopVectorization"
LuxLibMKLExt = "MKL"
LuxLibOctavianExt = ["Octavian", "LoopVectorization"]
LuxLibReactantExt = "Reactant"
LuxLibReverseDiffExt = "ReverseDiff"
LuxLibSLEEFPiratesExt = "SLEEFPirates"
LuxLibTrackerAMDGPUExt = ["AMDGPU", "Tracker"]
@@ -79,9 +85,10 @@ MLDataDevices = "1.6"
Markdown = "1.10"
NNlib = "0.9.24"
Octavian = "0.3.28"
Preferences = "1.4.3"
Polyester = "0.7.15"
Preferences = "1.4.3"
Random = "1.10"
Reactant = "0.2.11"
Reexport = "1"
ReverseDiff = "1.15"
SLEEFPirates = "0.6.43"
@@ -91,7 +98,3 @@ Statistics = "1.10"
Tracker = "0.2.36"
cuDNN = "1.3"
julia = "1.10"

[sources]
LuxCore = { path = "../LuxCore" }
MLDataDevices = { path = "../MLDataDevices" }
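Note that LuxLibReactantExt is a package extension: Julia builds and loads it automatically once both LuxLib and the weak dependency Reactant are present in the environment, so LuxLib itself takes on no hard Reactant dependency. A quick sanity check that the extension is active (Base.get_extension is the standard lookup for loaded extensions):

using LuxLib
using Reactant  # loading the weak dependency triggers LuxLibReactantExt

ext = Base.get_extension(LuxLib, :LuxLibReactantExt)
@assert ext !== nothing  # extension module is compiled and loaded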
85 changes: 85 additions & 0 deletions lib/LuxLib/ext/LuxLibReactantExt.jl
@@ -0,0 +1,85 @@
module LuxLibReactantExt

using Reactant: Reactant, MLIR, Ops, TracedUtils, TracedRArray, AnyTracedRArray,
AnyTracedRVector
using Static: StaticBool, True, False

using LuxLib: LuxLib, Impl, Optional, Utils

# Most of the NN code gen happens in Reactant.jl via an extension on NNlib. However,
# NNlib doesn't have certain ops implemented; in those cases we can emit more optimized
# StableHLO ourselves.

function Impl.batchnorm(
    x::AnyTracedRArray{T},
    γ::Optional{<:AnyTracedRVector}, β::Optional{<:AnyTracedRVector},
    rμ::Optional{<:AnyTracedRVector}, rσ²::Optional{<:AnyTracedRVector},
    training::StaticBool, act::F, momentum::Real, ϵ::Real
) where {T, F}
    x = TracedUtils.materialize_traced_array(x)

    # Default to scale γ = 1 and offset β = 0 when they are not provided.
    γ = if γ === nothing
        Ops.constant(fill(T(1), size(x, ndims(x) - 1)))
    else
        TracedUtils.materialize_traced_array(γ)
    end
    β = if β === nothing
        Ops.constant(fill(T(0), size(x, ndims(x) - 1)))
    else
        TracedUtils.materialize_traced_array(β)
    end

    if training isa True
        # batch_norm_training returns the normalized output plus the batch
        # mean and (biased) batch variance as separate results.
        op = MLIR.Dialects.stablehlo.batch_norm_training(
            TracedUtils.get_mlir_data(x),
            TracedUtils.get_mlir_data(γ),
            TracedUtils.get_mlir_data(β);
            epsilon=Float32(ϵ),
            feature_index=Int64(ndims(x) - 2)
        )

        res = act.(TracedRArray{T, ndims(x)}((), MLIR.IR.result(op, 1), size(x)))
        μ = TracedRArray{T, 1}((), MLIR.IR.result(op, 2), size(x, ndims(x) - 1))
        σ² = TracedRArray{T, 1}((), MLIR.IR.result(op, 3), size(x, ndims(x) - 1))

        if rμ === nothing && rσ² === nothing
            return res, nothing, nothing
        else
            @assert rμ !== nothing && rσ² !== nothing
            m = T(Impl.accum_size(x, Impl.batchnorm_reduce_dims(x)))
            rμ, rσ² = Impl.update_running_statistics(
                rμ, rσ², μ, σ², momentum, momentum * m / (m - one(m))
            )
            return res, rμ, rσ²
        end
    else
        if rμ === nothing && rσ² === nothing
            # No running statistics available: normalize with batch statistics.
            μ, σ² = Impl.mean_var(
                x; dims=Utils.unsafe_known(Impl.batchnorm_reduce_dims(x)), corrected=false
            )
            μ = TracedUtils.materialize_traced_array(vec(μ))
            σ² = TracedUtils.materialize_traced_array(vec(σ²))
        else
            @assert rμ !== nothing && rσ² !== nothing
            μ = TracedUtils.materialize_traced_array(rμ)
            σ² = TracedUtils.materialize_traced_array(rσ²)
        end

        # batch_norm_inference normalizes with the provided statistics.
        res = MLIR.IR.result(
            MLIR.Dialects.stablehlo.batch_norm_inference(
                TracedUtils.get_mlir_data(x),
                TracedUtils.get_mlir_data(γ),
                TracedUtils.get_mlir_data(β),
                TracedUtils.get_mlir_data(μ),
                TracedUtils.get_mlir_data(σ²);
                epsilon=Float32(ϵ),
                feature_index=Int64(ndims(x) - 2)
            ),
            1
        )

        return act.(TracedRArray{T, ndims(x)}((), res, size(x))), rμ, rσ²
    end
end

end
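One subtlety in the training branch above: stablehlo.batch_norm_training returns the biased (population) batch variance, while running statistics conventionally track the unbiased estimate. That is why the call site passes momentum * m / (m - one(m)) as the variance momentum, where m is the per-feature reduction size. Below is a plain-Julia sketch of the update rule this implies; the actual Impl.update_running_statistics internals are not part of this diff, so this is inferred from the call site, not a copy of it:

# Hedged sketch: exponential moving average of the batch statistics, with
# Bessel's correction m / (m - 1) folded into the variance momentum m₂
function update_running_statistics_sketch(rμ, rσ², μ, σ², m₁, m₂)
    rμ_new = @. (1 - m₁) * rμ + m₁ * μ
    rσ²_new = @. (1 - m₁) * rσ² + m₂ * σ²  # m₂ = m₁ * m / (m - 1)
    return rμ_new, rσ²_new
end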
