Skip to content

Commit

Permalink
Float intrinsics fixes & test improvements (#531)
Browse files Browse the repository at this point in the history
* Use `Base.tanpi` in intrinsics

`tanpi` has been in Julia since 1.10, so all supported versions have it

* Test more intrinsics and fix `min`/`max`

Also clean up the different tests

* [NFC] Add commented out intrinsics to test once added

* `clamp` & `sign`

* 2-arg atan

* 3-arg max and min
  • Loading branch information
christiangnrd authored Feb 4, 2025
1 parent ca092c8 commit b8ab3b6
Show file tree
Hide file tree
Showing 2 changed files with 265 additions and 53 deletions.
39 changes: 30 additions & 9 deletions src/device/intrinsics/math.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,14 @@ using Base.Math: throw_complex_domainerror
# - add support for vector types
# - consider emitting LLVM intrinsics and lowering those in the back-end

### Common Intrinsics
# Clamp `x` using the bounds `minval` and `maxval` by forwarding all three operands to the
# AIR clamp intrinsics through `llvmcall`. `clamp_fast` is a separate device function bound
# to `air.fast_clamp.f32` (Float32 only); the `Base.clamp` overrides cover Float32/Float16.
# NOTE(review): behaviour for NaN operands or `minval > maxval` is decided by the intrinsic
# and is not visible here — confirm against the Metal Shading Language specification.
@device_function clamp_fast(x::Float32, minval::Float32, maxval::Float32) = ccall("extern air.fast_clamp.f32", llvmcall, Cfloat, (Cfloat, Cfloat, Cfloat), x, minval, maxval)
@device_override Base.clamp(x::Float32, minval::Float32, maxval::Float32) = ccall("extern air.clamp.f32", llvmcall, Cfloat, (Cfloat, Cfloat, Cfloat), x, minval, maxval)
@device_override Base.clamp(x::Float16, minval::Float16, maxval::Float16) = ccall("extern air.clamp.f16", llvmcall, Float16, (Float16, Float16, Float16), x, minval, maxval)

# Sign of `x`, forwarded to the AIR `sign` intrinsics.
# NOTE(review): MSL presumably returns 0.0 for a NaN input, whereas host `Base.sign`
# propagates NaN — confirm before relying on NaN behaviour.
@device_override Base.sign(x::Float32) = ccall("extern air.sign.f32", llvmcall, Cfloat, (Cfloat,), x)
@device_override Base.sign(x::Float16) = ccall("extern air.sign.f16", llvmcall, Float16, (Float16,), x)

### Floating Point Intrinsics

## Metal only supports single and half-precision floating-point types (and their vector counterparts)
Expand All @@ -17,13 +25,21 @@ using Base.Math: throw_complex_domainerror
# Absolute value, forwarded to the AIR `fabs` intrinsics for Float32 and Float16.
@device_override Base.abs(x::Float32) = ccall("extern air.fabs.f32", llvmcall, Cfloat, (Cfloat,), x)
@device_override Base.abs(x::Float16) = ccall("extern air.fabs.f16", llvmcall, Float16, (Float16,), x)

# Two-argument minimum, forwarded to the binary AIR `fmin` intrinsics. Both operands must
# be passed: the intrinsic is binary, so single-argument calls are deliberately not
# overridden here and are left to Base's own methods.
# NOTE(review): NaN handling is decided by the intrinsic and may differ from host
# `Base.min`, which propagates NaN — confirm against the MSL specification.
@device_override FastMath.min_fast(x::Float32, y::Float32) = ccall("extern air.fast_fmin.f32", llvmcall, Cfloat, (Cfloat, Cfloat), x, y)
@device_override Base.min(x::Float32, y::Float32) = ccall("extern air.fmin.f32", llvmcall, Cfloat, (Cfloat, Cfloat), x, y)
@device_override Base.min(x::Float16, y::Float16) = ccall("extern air.fmin.f16", llvmcall, Float16, (Float16, Float16), x, y)

# Three-argument minimum, forwarded to the dedicated AIR `fmin3` intrinsics.
@device_override FastMath.min_fast(x::Float32, y::Float32, z::Float32) = ccall("extern air.fast_fmin3.f32", llvmcall, Cfloat, (Cfloat, Cfloat, Cfloat), x, y, z)
@device_override Base.min(x::Float32, y::Float32, z::Float32) = ccall("extern air.fmin3.f32", llvmcall, Cfloat, (Cfloat, Cfloat, Cfloat), x, y, z)
@device_override Base.min(x::Float16, y::Float16, z::Float16) = ccall("extern air.fmin3.f16", llvmcall, Float16, (Float16, Float16, Float16), x, y, z)

# Two-argument maximum, forwarded to the binary AIR `fmax` intrinsics. Both operands must
# be passed: the intrinsic is binary, so single-argument calls are deliberately not
# overridden here and are left to Base's own methods.
# NOTE(review): NaN handling is decided by the intrinsic and may differ from host
# `Base.max`, which propagates NaN — confirm against the MSL specification.
@device_override FastMath.max_fast(x::Float32, y::Float32) = ccall("extern air.fast_fmax.f32", llvmcall, Cfloat, (Cfloat, Cfloat), x, y)
@device_override Base.max(x::Float32, y::Float32) = ccall("extern air.fmax.f32", llvmcall, Cfloat, (Cfloat, Cfloat), x, y)
@device_override Base.max(x::Float16, y::Float16) = ccall("extern air.fmax.f16", llvmcall, Float16, (Float16, Float16), x, y)

# Three-argument maximum, forwarded to the dedicated AIR `fmax3` intrinsics.
@device_override FastMath.max_fast(x::Float32, y::Float32, z::Float32) = ccall("extern air.fast_fmax3.f32", llvmcall, Cfloat, (Cfloat, Cfloat, Cfloat), x, y, z)
@device_override Base.max(x::Float32, y::Float32, z::Float32) = ccall("extern air.fmax3.f32", llvmcall, Cfloat, (Cfloat, Cfloat, Cfloat), x, y, z)
@device_override Base.max(x::Float16, y::Float16, z::Float16) = ccall("extern air.fmax3.f16", llvmcall, Float16, (Float16, Float16, Float16), x, y, z)

# Arc cosine: the `@fastmath` variant maps to `air.fast_acos.f32`, the regular Float32
# method to `air.acos.f32`.
@device_override FastMath.acos_fast(x::Float32) = ccall("extern air.fast_acos.f32", llvmcall, Cfloat, (Cfloat,), x)
@device_override Base.acos(x::Float32) = ccall("extern air.acos.f32", llvmcall, Cfloat, (Cfloat,), x)
Expand All @@ -45,6 +61,10 @@ using Base.Math: throw_complex_domainerror
# One-argument arc tangent, forwarded to the AIR `atan` intrinsics.
@device_override Base.atan(x::Float32) = ccall("extern air.atan.f32", llvmcall, Cfloat, (Cfloat,), x)
@device_override Base.atan(x::Float16) = ccall("extern air.atan.f16", llvmcall, Float16, (Float16,), x)

# Two-argument arc tangent, forwarded positionally to the AIR `atan2` intrinsics.
# Julia's `atan(y, x)` takes the ordinate first, so the parameter named `x` below is the
# one Julia documents as `y`.
# NOTE(review): assumes `air.atan2` uses the same (y, x) operand order — confirm.
@device_override FastMath.atan_fast(x::Float32, y::Float32) = ccall("extern air.fast_atan2.f32", llvmcall, Cfloat, (Cfloat, Cfloat), x, y)
@device_override Base.atan(x::Float32, y::Float32) = ccall("extern air.atan2.f32", llvmcall, Cfloat, (Cfloat, Cfloat), x, y)
@device_override Base.atan(x::Float16, y::Float16) = ccall("extern air.atan2.f16", llvmcall, Float16, (Float16, Float16), x, y)

# Inverse hyperbolic tangent, forwarded to the AIR `atanh` intrinsics.
@device_override FastMath.atanh_fast(x::Float32) = ccall("extern air.fast_atanh.f32", llvmcall, Cfloat, (Cfloat,), x)
@device_override Base.atanh(x::Float32) = ccall("extern air.atanh.f32", llvmcall, Cfloat, (Cfloat,), x)
@device_override Base.atanh(x::Float16) = ccall("extern air.atanh.f16", llvmcall, Float16, (Float16,), x)
Expand Down Expand Up @@ -240,6 +260,7 @@ end
s = ccall("extern air.sincos.f32", llvmcall, Cfloat, (Cfloat, Ptr{Cfloat}), x, c)
(s, c[])
end
# XXX: Broken
@device_override function Base.sincos(x::Float16)
c = Ref{Float16}()
s = ccall("extern air.sincos.f16", llvmcall, Float16, (Float16, Ptr{Float16}), x, c)
Expand Down Expand Up @@ -267,8 +288,8 @@ end
@device_override Base.tanh(x::Float16) = ccall("extern air.tanh.f16", llvmcall, Float16, (Float16,), x)

# tan(π * x). The precise variants override `Base.tanpi` (available since Julia 1.10)
# rather than defining a module-local `tanpi`, so ordinary `tanpi(x)` calls in kernels
# reach the AIR intrinsic. Base has no fast-math counterpart to override, so the fast
# variant stays a separate device function.
@device_function tanpi_fast(x::Float32) = ccall("extern air.fast_tanpi.f32", llvmcall, Cfloat, (Cfloat,), x)
@device_override Base.tanpi(x::Float32) = ccall("extern air.tanpi.f32", llvmcall, Cfloat, (Cfloat,), x)
@device_override Base.tanpi(x::Float16) = ccall("extern air.tanpi.f16", llvmcall, Float16, (Float16,), x)

# Truncation: `trunc_fast` is a separate device function bound to `air.fast_trunc.f32`;
# the `Base.trunc` override for Float32 is bound to `air.trunc.f32`.
@device_function trunc_fast(x::Float32) = ccall("extern air.fast_trunc.f32", llvmcall, Cfloat, (Cfloat,), x)
@device_override Base.trunc(x::Float32) = ccall("extern air.trunc.f32", llvmcall, Cfloat, (Cfloat,), x)
Expand Down Expand Up @@ -418,7 +439,7 @@ end
j = fma(1.442695f0, a, 12582912.0f0)
j = j - 12582912.0f0
i = unsafe_trunc(Int32, j)
f = fma(j, -6.93145752f-1, a) # log_2_hi
f = fma(j, -6.93145752f-1, a) # log_2_hi
f = fma(j, -1.42860677f-6, f) # log_2_lo

# approximate r = exp(f)-1 on interval [-log(2)/2, +log(2)/2]
Expand Down
279 changes: 235 additions & 44 deletions test/device/intrinsics.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using SpecialFunctions
using Metal: metal_support
using Random
using SpecialFunctions

@testset "arguments" begin
@on_device dispatch_quadgroups_per_threadgroup()
Expand Down Expand Up @@ -103,71 +104,261 @@ end

############################################################################################

# Single-argument functions exercised by the "math" testset. Each entry is annotated with
# the corresponding MSL signature; commented-out entries are intrinsics that are not
# tested yet.
MATH_INTR_FUNCS_1_ARG = [
# Common functions
# saturate, # T saturate(T x) Clamp between 0.0 and 1.0
sign, # T sign(T x) returns 0.0 if x is NaN

# float math
acos, # T acos(T x)
asin, # T asin(T x)
asinh, # T asinh(T x)
atan, # T atan(T x)
atanh, # T atanh(T x)
ceil, # T ceil(T x)
cos, # T cos(T x)
cosh, # T cosh(T x)
cospi, # T cospi(T x)
exp, # T exp(T x)
exp2, # T exp2(T x)
exp10, # T exp10(T x)
abs, # T [f]abs(T x)
floor, # T floor(T x)
Metal.fract, # T fract(T x)
# ilogb, # Ti ilogb(T x)
log, # T log(T x)
log2, # T log2(T x)
log10, # T log10(T x)
# Metal.rint, # T rint(T x) # TODO: Add test. Not sure what the behaviour actually is
round, # T round(T x)
Metal.rsqrt, # T rsqrt(T x)
sin, # T sin(T x)
sinh, # T sinh(T x)
sinpi, # T sinpi(T x)
sqrt, # T sqrt(T x)
tan, # T tan(T x)
tanh, # T tanh(T x)
tanpi, # T tanpi(T x)
trunc, # T trunc(T x)
]
# Host-side reference implementations so `Metal.rsqrt` and `Metal.fract` can be broadcast
# over CPU arrays when computing the expected values for the tests.
# NOTE(review): presumably these replace host methods Metal already defines for exactly
# these signatures, so the argument types must stay concrete — confirm.
Metal.rsqrt(x::Float16) = 1 / sqrt(x)
Metal.rsqrt(x::Float32) = 1 / sqrt(x)
Metal.fract(x::Float16) = mod(x, 1)
Metal.fract(x::Float32) = mod(x, 1)

# Two-argument functions exercised by the "math" testset. Commented-out entries are
# intrinsics that are not tested yet.
MATH_INTR_FUNCS_2_ARG = [
# Common functions
# step, # T step(T edge, T x) Returns 0.0 if x < edge, otherwise it returns 1.0

# float math
atan, # T atan2(T y, T x) Compute arc tangent of y over x (Julia: atan(y, x)).
# fdim, # T fdim(T x, T y)
max, # T [f]max(T x, T y)
min, # T [f]min(T x, T y)
# fmod, # T fmod(T x, T y)
# frexp, # T frexp(T x, Ti &exponent)
# ldexp, # T ldexp(T x, Ti k)
# modf, # T modf(T x, T &intval)
# nextafter, # T nextafter(T x, T y) # Metal 3.1+
# sincos,
hypot, # NOT MSL but tested the same
]

# Three-argument functions exercised by the "math" testset. Commented-out entries are
# intrinsics that are not tested yet.
MATH_INTR_FUNCS_3_ARG = [
# Common functions
# mix, # T mix(T x, T y, T a) # x+(y-x)*a
# smoothstep, # T smoothstep(T edge0, T edge1, T x)
fma, # T fma(T a, T b, T c)
max, # T max3(T x, T y, T z)
# median3, # T median3(T x, T y, T z)
min, # T min3(T x, T y, T z)
]

# Runs every listed intrinsic on the GPU (one thread per element) for Float32 and Float16
# and compares the result with the same function broadcast over the host copy of the input.
@testset "math" begin
    # 1-arg functions
    @testset "$(fun)()::$T" for fun in MATH_INTR_FUNCS_1_ARG, T in (Float32, Float16)
        # Logarithms and square roots only get non-negative inputs; everything else also
        # sees signed zeros and a negative value.
        cpuarr = if fun in [log, log2, log10, Metal.rsqrt, sqrt]
            rand(T, 4)
        else
            T[0.0, -0.0, rand(T), -rand(T)]
        end

        mtlarr = MtlArray(cpuarr)

        mtlout = fill!(similar(mtlarr), 0)

        function kernel(res, arr)
            idx = thread_position_in_grid_1d()
            res[idx] = fun(arr[idx])
            return nothing
        end
        Metal.@sync @metal threads = length(mtlout) kernel(mtlout, mtlarr)
        @eval @test Array($mtlout) ≈ $fun.($cpuarr)
    end
    # 2-arg functions
    @testset "$(fun)()::$T" for T in (Float32, Float16), fun in MATH_INTR_FUNCS_2_ARG
        N = 4
        arr1 = randn(T, N)
        arr2 = randn(T, N)
        mtlarr1 = MtlArray(arr1)
        mtlarr2 = MtlArray(arr2)

        mtlout = fill!(similar(mtlarr1), 0)

        function kernel(res, x, y)
            idx = thread_position_in_grid_1d()
            res[idx] = fun(x[idx], y[idx])
            return nothing
        end
        Metal.@sync @metal threads = N kernel(mtlout, mtlarr1, mtlarr2)
        @eval @test Array($mtlout) ≈ $fun.($arr1, $arr2)
    end
    # 3-arg functions
    @testset "$(fun)()::$T" for T in (Float32, Float16), fun in MATH_INTR_FUNCS_3_ARG
        N = 4
        arr1 = randn(T, N)
        arr2 = randn(T, N)
        arr3 = randn(T, N)

        mtlarr1 = MtlArray(arr1)
        mtlarr2 = MtlArray(arr2)
        mtlarr3 = MtlArray(arr3)

        mtlout = fill!(similar(mtlarr1), 0)

        function kernel(res, x, y, z)
            idx = thread_position_in_grid_1d()
            res[idx] = fun(x[idx], y[idx], z[idx])
            return nothing
        end
        Metal.@sync @metal threads = N kernel(mtlout, mtlarr1, mtlarr2, mtlarr3)
        @eval @test Array($mtlout) ≈ $fun.($arr1, $arr2, $arr3)
    end
end
# Intrinsics that need bespoke inputs, multiple outputs, or a different comparison than
# the table-driven "math" testset provides.
@testset "unique math" begin
    @testset "$T" for T in (Float32, Float16)
        let # acosh
            # Shift the inputs so every value is >= 1; the first element is exactly 1.
            arr = T[0, rand(T, 3)...] .+ T(1)
            buffer = MtlArray(arr)
            vec = acosh.(buffer)
            @test Array(vec) ≈ acosh.(arr)
        end

        let # sincos
            N = 4
            arr = rand(T, N)
            bufferA = MtlArray(arr)
            bufferB = MtlArray(arr)
            # Reads the input from `arr_cos`, then writes the sine to `arr_sin` and
            # overwrites `arr_cos` with the cosine.
            function intr_test3(arr_sin, arr_cos)
                idx = thread_position_in_grid_1d()
                sinres, cosres = sincos(arr_cos[idx])
                arr_sin[idx] = sinres
                arr_cos[idx] = cosres
                return nothing
            end
            # Broken with Float16
            if T == Float16
                @test_broken Metal.@sync @metal threads = N intr_test3(bufferA, bufferB)
            else
                Metal.@sync @metal threads = N intr_test3(bufferA, bufferB)
                @test Array(bufferA) ≈ sin.(arr)
                @test Array(bufferB) ≈ cos.(arr)
            end
        end

        let # clamp
            N = 4
            input = randn(T, N)
            minval = fill(T(-0.6), N)
            maxval = fill(T(0.6), N)

            mtlin = MtlArray(input)
            mtlminval = MtlArray(minval)
            mtlmaxval = MtlArray(maxval)

            mtlout = fill!(similar(mtlin), 0)

            function kernel(res, x, y, z)
                idx = thread_position_in_grid_1d()
                res[idx] = clamp(x[idx], y[idx], z[idx])
                return nothing
            end
            Metal.@sync @metal threads = N kernel(mtlout, mtlin, mtlminval, mtlmaxval)
            # Clamping only selects one of its operands, so the comparison is exact.
            @test Array(mtlout) == clamp.(input, minval, maxval)
        end

        let #pow
            N = 4
            arr1 = rand(T, N)
            arr2 = rand(T, N)
            mtlarr1 = MtlArray(arr1)
            mtlarr2 = MtlArray(arr2)

            mtlout = fill!(similar(mtlarr1), 0)

            function kernel(res, x, y)
                idx = thread_position_in_grid_1d()
                res[idx] = x[idx]^y[idx]
                return nothing
            end
            Metal.@sync @metal threads = N kernel(mtlout, mtlarr1, mtlarr2)
            @test Array(mtlout) ≈ arr1 .^ arr2
        end

        let #powr
            N = 4
            arr1 = rand(T, N)
            arr2 = rand(T, N)
            mtlarr1 = MtlArray(arr1)
            mtlarr2 = MtlArray(arr2)

            mtlout = fill!(similar(mtlarr1), 0)

            function kernel(res, x, y)
                idx = thread_position_in_grid_1d()
                res[idx] = Metal.powr(x[idx], y[idx])
                return nothing
            end
            Metal.@sync @metal threads = N kernel(mtlout, mtlarr1, mtlarr2)
            @test Array(mtlout) ≈ arr1 .^ arr2
        end

        # NOTE(review): the remaining cases build Float32 inputs regardless of `T`, so
        # they run identically for both element types — confirm whether that is intended.
        let # log1p
            arr = collect(LinRange(nextfloat(-1.0f0), 10.0f0, 20))
            buffer = MtlArray(arr)
            vec = Array(log1p.(buffer))
            @test vec ≈ log1p.(arr)
        end

        let # erf
            arr = collect(LinRange(nextfloat(-3.0f0), 3.0f0, 20))
            buffer = MtlArray(arr)
            vec = Array(SpecialFunctions.erf.(buffer))
            @test vec ≈ SpecialFunctions.erf.(arr)
        end

        let # erfc
            arr = collect(LinRange(nextfloat(-3.0f0), 3.0f0, 20))
            buffer = MtlArray(arr)
            vec = Array(SpecialFunctions.erfc.(buffer))
            @test vec ≈ SpecialFunctions.erfc.(arr)
        end

        let # erfinv
            arr = collect(LinRange(-1.0f0, 1.0f0, 20))
            buffer = MtlArray(arr)
            vec = Array(SpecialFunctions.erfinv.(buffer))
            @test vec ≈ SpecialFunctions.erfinv.(arr)
        end

        let # expm1
            arr = collect(LinRange(nextfloat(-88.0f0), 88.0f0, 100))
            buffer = MtlArray(arr)
            vec = Array(expm1.(buffer))
            @test vec ≈ expm1.(arr)
        end
    end
end

############################################################################################
Expand Down

0 comments on commit b8ab3b6

Please sign in to comment.