Skip to content

Commit

Permalink
fix various fma issues (#533)
Browse files Browse the repository at this point in the history
This makes `src/builtins.jl` in sync with `bin/generate_builtins.jl`
again. A lot of this code was also out-of-date, now that we only support
1.6, so this includes some cleanup.
  • Loading branch information
KristofferC authored Apr 6, 2022
1 parent f74effe commit 0649d0d
Show file tree
Hide file tree
Showing 4 changed files with 75 additions and 0 deletions.
26 changes: 26 additions & 0 deletions bin/generate_builtins.jl
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,32 @@ function maybe_evaluate_builtin(frame, call_expr, expand::Bool)
"""
if isa(f, Core.IntrinsicFunction)
cargs = getargs(args, frame)
@static if isdefined(Core.Intrinsics, :have_fma)
if f === Core.Intrinsics.have_fma && length(cargs) == 1
cargs1 = cargs[1]
if cargs1 == Float64
return Some{Any}(FMA_FLOAT64[])
elseif cargs1 == Float32
return Some{Any}(FMA_FLOAT32[])
elseif cargs1 == Float16
return Some{Any}(FMA_FLOAT16[])
end
end
end
if f === Core.Intrinsics.muladd_float && length(cargs) == 3
a, b, c = cargs
Ta, Tb, Tc = typeof(a), typeof(b), typeof(c)
if !(Ta == Tb == Tc)
error("muladd_float: types of a, b, and c must match")
end
if Ta == Float64 && FMA_FLOAT64[]
f = Core.Intrinsics.fma_float
elseif Ta == Float32 && FMA_FLOAT32[]
f = Core.Intrinsics.fma_float
elseif Ta == Float16 && FMA_FLOAT16[]
f = Core.Intrinsics.fma_float
end
end
return Some{Any}(ccall(:jl_f_intrinsic_call, Any, (Any, Ptr{Any}, UInt32), f, cargs, length(cargs)))
""")
print(io,
Expand Down
26 changes: 26 additions & 0 deletions src/builtins.jl
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,32 @@ function maybe_evaluate_builtin(frame, call_expr, expand::Bool)
end
if isa(f, Core.IntrinsicFunction)
cargs = getargs(args, frame)
@static if isdefined(Core.Intrinsics, :have_fma)
if f === Core.Intrinsics.have_fma && length(cargs) == 1
cargs1 = cargs[1]
if cargs1 == Float64
return Some{Any}(FMA_FLOAT64[])
elseif cargs1 == Float32
return Some{Any}(FMA_FLOAT32[])
elseif cargs1 == Float16
return Some{Any}(FMA_FLOAT16[])
end
end
end
if f === Core.Intrinsics.muladd_float && length(cargs) == 3
a, b, c = cargs
Ta, Tb, Tc = typeof(a), typeof(b), typeof(c)
if !(Ta == Tb == Tc)
error("muladd_float: types of a, b, and c must match")
end
if Ta == Float64 && FMA_FLOAT64[]
f = Core.Intrinsics.fma_float
elseif Ta == Float32 && FMA_FLOAT32[]
f = Core.Intrinsics.fma_float
elseif Ta == Float16 && FMA_FLOAT16[]
f = Core.Intrinsics.fma_float
end
end
return Some{Any}(ccall(:jl_f_intrinsic_call, Any, (Any, Ptr{Any}, UInt32), f, cargs, length(cargs)))
end
if isa(f, typeof(kwinvoke))
Expand Down
12 changes: 12 additions & 0 deletions src/packagedef.jl
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,12 @@ function set_compiled_methods()
push!(compiled_modules, Base.Threads)
end

_have_fma_compiled(::Type{T}) where {T} = Core.Intrinsics.have_fma(T)

const FMA_FLOAT64 = Ref(false)
const FMA_FLOAT32 = Ref(false)
const FMA_FLOAT16 = Ref(false)

function __init__()
set_compiled_methods()
COVERAGE[] = Base.JLOptions().code_coverage
Expand Down Expand Up @@ -144,6 +150,12 @@ function __init__()
# compiled_calls[(qsym, RT, Core.svec(AT...), Core.Compiler)] = f
# precompile(f, AT)
# end

@static if isdefined(Base, :have_fma)
FMA_FLOAT64[] = _have_fma_compiled(Float64)
FMA_FLOAT32[] = _have_fma_compiled(Float32)
FMA_FLOAT16[] = _have_fma_compiled(Float16)
end
end

include("precompile.jl")
Expand Down
11 changes: 11 additions & 0 deletions test/interpret.jl
Original file line number Diff line number Diff line change
Expand Up @@ -895,3 +895,14 @@ end
iscallexpr(ex::Expr) = ex.head === :call
@test (@interpret iscallexpr(:(sin(3.14))))
end

if isdefined(Base, :have_fma)
f_fma() = Base.have_fma(Float64)
@testset "fma" begin
@test (@interpret f_fma()) == f_fma()
a, b, c = (1.0585073227945125, -0.00040303348596386557, 1.5051263504758005e-16)
@test (@interpret muladd(a, b, c)) === muladd(a,b,c)
a = 1.0883740903666346; b = 2/3
@test (@interpret a^b) === a^b
end
end

0 comments on commit 0649d0d

Please sign in to comment.