Skip to content
This repository has been archived by the owner on Mar 12, 2021. It is now read-only.

Attempt at mixed precision for contraction #598

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 9 additions & 7 deletions src/tensor/wrappers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ function elementwiseTrinary!(
descC = CuTensorDescriptor(C; op = opC)
@assert size(C) == size(D) && strides(C) == strides(D)
descD = descC # must currently be identical
#typeCompute = cudaDataType(T)
typeCompute = cudaDataType(T)
modeA = collect(Cint, Ainds)
modeB = collect(Cint, Binds)
Expand Down Expand Up @@ -207,9 +206,10 @@ function contraction!(
alpha::Number, A::CuArray, Ainds::ModeType, opA::cutensorOperator_t,
B::CuArray, Binds::ModeType, opB::cutensorOperator_t,
beta::Number, C::CuArray, Cinds::ModeType, opC::cutensorOperator_t,
opOut::cutensorOperator_t,
opOut::cutensorOperator_t;
pref::cutensorWorksizePreference_t=CUTENSOR_WORKSPACE_RECOMMENDED,
algo::cutensorAlgo_t=CUTENSOR_ALGO_DEFAULT, stream::CuStream=CuDefaultStream())
algo::cutensorAlgo_t=CUTENSOR_ALGO_DEFAULT, stream::CuStream=CuDefaultStream(),
compute_type::Type=eltype(C))

!is_unary(opA) && throw(ArgumentError("opA must be a unary op!"))
!is_unary(opB) && throw(ArgumentError("opB must be a unary op!"))
Expand All @@ -219,8 +219,10 @@ function contraction!(
descB = CuTensorDescriptor(B; op = opB)
descC = CuTensorDescriptor(C; op = opC)
# for now, D must be identical to C (and thus, descD must be identical to descC)
T = eltype(C)
computeType = cutensorComputeType(T) #CUTENSOR_R_MIN_64F #TODO cudaDataType(T)
computeType = cutensorComputeType(compute_type)
# fix this and use a look-up table
#T = sizeof(compute_type) < sizeof(eltype(C)) ? eltype(C) : compute_type
T = sizeof(compute_type) < sizeof(eltype(C)) ? eltype(C) : compute_type
modeA = collect(Cint, Ainds)
modeB = collect(Cint, Binds)
modeC = collect(Cint, Cinds)
Expand Down Expand Up @@ -251,8 +253,8 @@ function contraction!(
cutensorInitContractionPlan(handle(), plan, desc, find, sizeof(workspace))

cutensorContraction(handle(), plan,
T[alpha], A, B,
T[beta], C, C,
T[convert(T, alpha)], A, B,
T[convert(T, beta)], C, C,
workspace, sizeof(workspace), stream)
end
return C
Expand Down
97 changes: 48 additions & 49 deletions test/tensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -389,23 +389,24 @@ end
end

@testset "Contraction" begin
eltypes = ( #(Float16, Float16, Float16), # works for some
# (Float16, Float16, Float32), # works for some but claims otherwise
(Float32, Float32, Float32),
eltypes = ( (Float16, Float16, Float16, Float32), # works for some but claims otherwise
(Float32, Float32, Float32, Float32),
(Float32, Float32, Float32, Float16),
#(Float32, ComplexF32, ComplexF32),
#(ComplexF32, Float32, ComplexF32),
# (Float32, Float32, Float64), # does not work
(Float64, Float64, Float64),
#(Float64, ComplexF64, ComplexF64),
#(ComplexF64, Float64, ComplexF64),
# (ComplexF16, ComplexF16, ComplexF16), # does not work
(ComplexF32, ComplexF32, ComplexF32), # works for some
(Float64, Float64, Float64, Float64),
(Float64, Float64, Float64, Float32),
(Float64, ComplexF64, ComplexF64, ComplexF64),
(ComplexF64, Float64, ComplexF64, ComplexF64),
(ComplexF32, ComplexF32, ComplexF32, ComplexF32), # works for some
# (ComplexF32, ComplexF32, ComplexF64), # does not work
(ComplexF64, ComplexF64, ComplexF64) # works for some
(ComplexF64, ComplexF64, ComplexF64, ComplexF64), # works for some
(ComplexF64, ComplexF64, ComplexF64, ComplexF32) # works for some
)

@testset for NoA=1:3, NoB=1:3, Nc=1:3
@testset for (eltyA, eltyB, eltyC) in eltypes
@testset for (eltyA, eltyB, eltyC, eltyCompute) in eltypes
# setup
dmax = 2^div(18, max(NoA+Nc, NoB+Nc, NoA+NoB))
dimsoA = rand(2:dmax, NoA)
Expand All @@ -424,7 +425,7 @@ end
ipB = invperm(pB)
pC = randperm(NoA + NoB)
ipC = invperm(pC)

compute_rtol = (real(eltyCompute) == Float16 || real(eltyC) == Float16) ? 1e-2 : (real(eltyCompute) == Float32 ? 1e-4 : 1e-6)
dimsA = [dimsoA; dimsc][pA]
indsA = [indsoA; indsc][pA]
dimsB = [dimsc; dimsoB][pB]
Expand All @@ -446,79 +447,77 @@ end
opB = CUTENSOR.CUTENSOR_OP_IDENTITY
opC = CUTENSOR.CUTENSOR_OP_IDENTITY
opOut = CUTENSOR.CUTENSOR_OP_IDENTITY
dC = CUTENSOR.contraction!(1, dA, indsA, opA, dB, indsB, opB,
0, dC, indsC, opC, opOut)
dC = @sync CUTENSOR.contraction!(one(eltyCompute), dA, indsA, opA, dB, indsB, opB, zero(eltyCompute), dC, indsC, opC, opOut, compute_type=eltyCompute)
C = collect(dC)
mC = reshape(permutedims(C, ipC), (loA, loB))
@test mC ≈ mA * mB
@test mC ≈ mA * mB rtol=compute_rtol

# with non-trivial α
α = rand(eltyC)
dC = CUTENSOR.contraction!(α, dA, indsA, opA, dB, indsB, opB,
0, dC, indsC, opC, opOut)
α = rand(eltyCompute)
dC = CUTENSOR.contraction!(α, dA, indsA, opA, dB, indsB, opB, zero(eltyCompute), dC, indsC, opC, opOut, compute_type=eltyCompute)
C = collect(dC)
mC = reshape(permutedims(C, ipC), (loA, loB))
@test mC ≈ α * mA * mB
@test mC ≈ α * mA * mB rtol=compute_rtol

# with non-trivial β
C = rand(eltyC, (dimsC...,))
dC = CuArray(C)
α = rand(eltyC)
β = rand(eltyC)
α = rand(eltyCompute)
β = rand(eltyCompute)
copyto!(dC, C)
dD = CUTENSOR.contraction!(α, dA, indsA, opA, dB, indsB, opB,
β, dC, indsC, opC, opOut)
dD = CUTENSOR.contraction!(α, dA, indsA, opA, dB, indsB, opB, β, dC, indsC, opC, opOut, compute_type=eltyCompute)
D = collect(dD)
mC = reshape(permutedims(C, ipC), (loA, loB))
mD = reshape(permutedims(D, ipC), (loA, loB))
@test mD ≈ α * mA * mB + β * mC
@test mD ≈ α * mA * mB + β * mC rtol=compute_rtol

# with CuTensor objects
ctA = CuTensor(dA, indsA)
ctB = CuTensor(dB, indsB)
ctC = CuTensor(dC, indsC)
ctC = LinearAlgebra.mul!(ctC, ctA, ctB)
C2, C2inds = collect(ctC)
mC = reshape(permutedims(C2, ipC), (loA, loB))
@test mC ≈ mA * mB
ctC = ctA * ctB
C2, C2inds = collect(ctC)
pC2 = convert.(Int, indexin(convert.(Char, C2inds), [indsoA; indsoB]))
mC = reshape(permutedims(C2, invperm(pC2)), (loA, loB))
@test mC ≈ mA * mB

if eltyCompute != Float32 && eltyC != Float16
ctA = CuTensor(dA, indsA)
ctB = CuTensor(dB, indsB)
ctC = CuTensor(dC, indsC)
ctC = LinearAlgebra.mul!(ctC, ctA, ctB)
C2, C2inds = collect(ctC)
mC = reshape(permutedims(C2, ipC), (loA, loB))
@test mC ≈ mA * mB
ctC = ctA * ctB
C2, C2inds = collect(ctC)
pC2 = convert.(Int, indexin(convert.(Char, C2inds), [indsoA; indsoB]))
mC = reshape(permutedims(C2, invperm(pC2)), (loA, loB))
@test mC ≈ mA * mB
end
# with conjugation flag for complex arguments
if !((NoA, NoB, Nc) in ((1,1,3), (1,2,3), (3,1,2)))
# not supported for these specific cases for unknown reason
if eltyA <: Complex
opA = CUTENSOR.CUTENSOR_OP_CONJ
opB = CUTENSOR.CUTENSOR_OP_IDENTITY
opA = CUTENSOR.CUTENSOR_OP_CONJ
opB = CUTENSOR.CUTENSOR_OP_IDENTITY
opOut = CUTENSOR.CUTENSOR_OP_IDENTITY
dC = CUTENSOR.contraction!(1, dA, indsA, opA, dB, indsB, opB,
0, dC, indsC, opC, opOut)
C = collect(dC)
mC = reshape(permutedims(C, ipC), (loA, loB))
@test mC ≈ conj(mA) * mB
dC = CUTENSOR.contraction!(one(eltyCompute), dA, indsA, opA, dB, indsB, opB,
zero(eltyCompute), dC, indsC, opC, opOut, compute_type=eltyCompute)
C = collect(dC)
mC = reshape(permutedims(C, ipC), (loA, loB))
@test mC ≈ conj(mA) * mB rtol=compute_rtol
end
if eltyB <: Complex
opA = CUTENSOR.CUTENSOR_OP_IDENTITY
opB = CUTENSOR.CUTENSOR_OP_CONJ
opOut = CUTENSOR.CUTENSOR_OP_IDENTITY
dC = CUTENSOR.contraction!(1, dA, indsA, opA, dB, indsB, opB,
0, dC, indsC, opC, opOut)
dC = CUTENSOR.contraction!(one(eltyCompute), dA, indsA, opA, dB, indsB, opB,
zero(eltyCompute), dC, indsC, opC, opOut, compute_type=eltyCompute)
C = collect(dC)
mC = reshape(permutedims(C, ipC), (loA, loB))
@test mC ≈ mA*conj(mB)
@test mC ≈ mA*conj(mB) rtol=compute_rtol
end
if eltyA <: Complex && eltyB <: Complex
opA = CUTENSOR.CUTENSOR_OP_CONJ
opB = CUTENSOR.CUTENSOR_OP_CONJ
opOut = CUTENSOR.CUTENSOR_OP_IDENTITY
dC = CUTENSOR.contraction!(1, dA, indsA, opA, dB, indsB, opB,
0, dC, indsC, opC, opOut)
dC = CUTENSOR.contraction!(one(eltyCompute), dA, indsA, opA, dB, indsB, opB,
zero(eltyCompute), dC, indsC, opC, opOut, compute_type=eltyCompute)
C = collect(dC)
mC = reshape(permutedims(C, ipC), (loA, loB))
@test mC ≈ conj(mA)*conj(mB)
@test mC ≈ conj(mA)*conj(mB) rtol=compute_rtol
end
end
end
Expand Down