diff --git a/src/Metal.jl b/src/Metal.jl index 7eb057e4..28b38c55 100644 --- a/src/Metal.jl +++ b/src/Metal.jl @@ -27,12 +27,12 @@ include("device/utils.jl") include("device/pointer.jl") include("device/array.jl") include("device/runtime.jl") +include("device/intrinsics/version.jl") include("device/intrinsics/arguments.jl") include("device/intrinsics/math.jl") include("device/intrinsics/synchronization.jl") include("device/intrinsics/memory.jl") include("device/intrinsics/simd.jl") -include("device/intrinsics/version.jl") include("device/intrinsics/atomics.jl") include("device/quirks.jl") diff --git a/src/compiler/execution.jl b/src/compiler/execution.jl index 58b2c7f2..54a089a3 100644 --- a/src/compiler/execution.jl +++ b/src/compiler/execution.jl @@ -18,9 +18,9 @@ certain extent arguments will be converted and managed automatically using `mtlc Finally, a call to `mtlcall` is performed, creating a command buffer in the current global command queue then committing it. -There is one supported keyword argument that influences the behavior of `@metal`: +There are a few keyword arguments that influence the behavior of `@metal`: -- `launch`: whether to launch this kernel, defaults to `true`. If `false` the returned +- `launch`: whether to launch this kernel, defaults to `true`. If `false`, the returned kernel object should be launched by calling it and passing arguments again. - `name`: the name of the kernel in the generated code. Defaults to an automatically- generated name. diff --git a/src/device/intrinsics/math.jl b/src/device/intrinsics/math.jl index b3394239..19a49dfc 100644 --- a/src/device/intrinsics/math.jl +++ b/src/device/intrinsics/math.jl @@ -294,6 +294,21 @@ end @device_override Base.trunc(x::Float32) = ccall("extern air.trunc.f32", llvmcall, Cfloat, (Cfloat,), x) @device_override Base.trunc(x::Float16) = ccall("extern air.trunc.f16", llvmcall, Float16, (Float16,), x) +@device_function function nextafter(x::Float32, y::Float32) + if metal_version() >= sv"3.1" # macOS 14+ + ccall("extern air.nextafter.f32", llvmcall, Cfloat, (Cfloat, Cfloat), x, y) + else + nextfloat(x, unsafe_trunc(Int32, sign(y - x))) + end +end +@device_function function nextafter(x::Float16, y::Float16) + if metal_version() >= sv"3.1" # macOS 14+ + ccall("extern air.nextafter.f16", llvmcall, Float16, (Float16, Float16), x, y) + else + nextfloat(x, unsafe_trunc(Int16, sign(y - x))) + end +end + # hypot without use of double # # taken from Cosmopolitan Libc diff --git a/test/device/intrinsics.jl b/test/device/intrinsics.jl index b71f62b1..ecd5e44d 100644 --- a/test/device/intrinsics.jl +++ b/test/device/intrinsics.jl @@ -159,7 +159,6 @@ MATH_INTR_FUNCS_2_ARG = [ # frexp, # T frexp(T x, Ti &exponent) # ldexp, # T ldexp(T x, Ti k) # modf, # T modf(T x, T &intval) - # nextafter, # T nextafter(T x, T y) # Metal 3.1+ hypot, # NOT MSL but tested the same ] @@ -353,6 +352,47 @@ end vec = Array(expm1.(buffer)) @test vec ≈ expm1.(arr) end + + + let # nextafter + function nextafter_test(X, y) + idx = thread_position_in_grid_1d() + X[idx] = Metal.nextafter(X[idx], y) + return nothing + end + + # Check the code is generated as expected + outval = T(0) + function nextafter_out_test() + Metal.nextafter(outval, outval) + return + end + + N = 4 + arr = rand(T, N) + + # test the intrinsic (macOS >= v14) + if metal_support() >= v"3.1" + buffer1 = MtlArray(arr) + Metal.@sync @metal threads = N nextafter_test(buffer1, typemax(T)) + @test Array(buffer1) == nextfloat.(arr) + Metal.@sync @metal threads = N nextafter_test(buffer1, typemin(T)) + @test Array(buffer1) == arr + + ir = sprint(io->(@device_code_llvm io=io dump_module=true @metal nextafter_out_test())) + @test occursin(Regex("@air\\.nextafter\\.f$(8*sizeof(T))"), ir) + end + + # test for metal < 3.1 + buffer2 = MtlArray(arr) + Metal.@sync @metal threads = N metal = v"3.0" nextafter_test(buffer2, typemax(T)) + @test Array(buffer2) == nextfloat.(arr) + Metal.@sync @metal threads = N metal = v"3.0" nextafter_test(buffer2, typemin(T)) + @test Array(buffer2) == arr + + ir = sprint(io->(@device_code_llvm io=io dump_module=true @metal metal = v"3.0" nextafter_out_test())) + @test occursin(Regex("@air\\.sign\\.f$(8*sizeof(T))"), ir) + end end end