Attempt at mixed precision for contraction #598
Conversation
src/tensor/wrappers.jl
Outdated
```diff
@@ -251,8 +251,8 @@ function contraction!(
     cutensorInitContractionPlan(handle(), plan, desc, find, sizeof(workspace))

     cutensorContraction(handle(), plan,
-                        T[alpha], A, B,
-                        T[beta], C, C,
+                        compute_type[alpha], A, B,
```
The behavior w.r.t. the scalars in cuTENSOR is, unfortunately, slightly more involved:
https://docs.nvidia.com/cuda/cutensor/user_guide.html#scalar-types
Thus, it's not a 1:1 mapping. Maybe this is causing an issue?
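For illustration, here is a minimal Julia sketch of one reading of that scalar-types table (the scalar_type helper name is hypothetical and not part of the wrapper): for the reduced-precision compute types the alpha/beta scalars are still expected in single precision, which is why the mapping is not 1:1 with typeCompute.

```julia
# Hypothetical helper sketching the scalar-types table from
# https://docs.nvidia.com/cuda/cutensor/user_guide.html#scalar-types:
# Float16/Float32 compute types take Float32 (or ComplexF32) scalars,
# while Float64 compute takes Float64 (or ComplexF64) scalars.
function scalar_type(compute_type::Type, data_type::Type)
    if compute_type <: Union{Float16, Float32}
        return data_type <: Complex ? ComplexF32 : Float32
    elseif compute_type <: Float64
        return data_type <: Complex ? ComplexF64 : Float64
    else
        error("unsupported compute type $compute_type")
    end
end

scalar_type(Float16, Float32)  # Float32, not Float16
```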
I thought that the alpha and beta constants are of type typeCompute in contraction? That's what the docs here implied to me.
Ah, I see - with that change things seem to be doing better, although mysteriously I'm getting a "not supported" error for all Float32 tensors with a compute type of Float16...
I agree, the documentation is somewhat confusing; we'll add a reference to https://docs.nvidia.com/cuda/cutensor/user_guide.html#scalar-types as part of the contraction docs.
Could you please run the contraction with CUTENSOR_DEBUG=1? Which GPU are you computing on? FP16 support only starts with Volta.
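A quick sketch of both suggestions, assuming CUDA.jl's device/capability helpers are available (the environment variable has to be set before the library is initialized, e.g. before loading the package or when launching Julia):

```julia
# Sketch: enable cuTENSOR's debug output and check that the device is
# Volta (compute capability 7.0) or newer before requesting FP16 compute.
ENV["CUTENSOR_DEBUG"] = "1"   # or launch as: CUTENSOR_DEBUG=1 julia script.jl

using CUDA
if CUDA.capability(CUDA.device()) < v"7.0"
    @warn "FP16 compute types require Volta (sm_70) or newer"
end
```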
DO NOT MERGE THIS
Paul (@springer13) agreed to take a look and see what's up with this sad PR that likely indicates I'm doing something wrong!