Training with half-precision doesn't work for the torch tile or CUDA bindings #623
Comments
I see only
@jubueche no, neither work. The error is different for MWE:
As far as I am aware, we don't have any examples for half precision, so I'm not particularly surprised it doesn't work.
I see that there are some compile options
@maljoras maybe you know how to enable that?
@coreylammie on which GPUs did you try this?
A100_80GB. Once we do figure this out, it would be great to add an example for it. I intend on adding an example for MobileBERT/SQuAD anyway, so perhaps we can add a single example using half-precision support for this network/task. |
@coreylammie Note that your MWE was not even training in FP32. I have changed it to the below:
Ideally, this should train. Unfortunately, autocast is only supported for CUDA.
@coreylammie could you use the branch above and see whether all tests pass on GPU and whether you can run the example above? Also, feel free to re-enable autocast. Just remove the
@jubueche first, the MWE was not intended to train. It was intended to reproduce the error, which is raised when
Are you sure this is not supported?
I see. Maybe I forgot to set something. I will check soon. In the meantime, can you check if it runs on GPU in my PR? |
Need to document and add an example. |
Description
Training with half-precision doesn't work for the torch tile or CUDA bindings, e.g., when `with torch.autocast(device_type="cuda", dtype=torch.bfloat16):` is used in conjunction with `rpu_config.runtime.data_type = RPUDataType.HALF`.

How to reproduce
Convert a model with either `InferenceRPUConfig()` or `TorchInferenceRPUConfig()`, specify `rpu_config.runtime.data_type = RPUDataType.HALF`, and use `with torch.autocast(device_type="cuda", dtype=torch.bfloat16):` in the training loop.

Expected behavior
The documentation at https://aihwkit.readthedocs.io/en/latest/api/aihwkit.simulator.parameters.enums.html#aihwkit.simulator.parameters.enums.RPUDataType implies that this is supported.
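For concreteness, the reproduction steps might be sketched roughly as follows. This is an illustrative assumption, not the reporter's actual MWE: the `convert_to_analog` conversion call, the `TorchInferenceRPUConfig` import path, and the toy `nn.Linear` model are taken from the aihwkit API as I understand it, and running it requires aihwkit plus a CUDA-capable GPU:

```python
# Hedged sketch of the reproduction steps (assumes aihwkit is installed
# and a CUDA device is available; names are assumptions, not the
# reporter's original script).
import torch
from torch import nn
from aihwkit.nn.conversion import convert_to_analog
from aihwkit.simulator.configs import TorchInferenceRPUConfig
from aihwkit.simulator.parameters.enums import RPUDataType

rpu_config = TorchInferenceRPUConfig()
rpu_config.runtime.data_type = RPUDataType.HALF  # half-precision tiles

# Convert a toy model to analog and move it to the GPU.
model = convert_to_analog(nn.Linear(16, 8), rpu_config).cuda()
opt = torch.optim.SGD(model.parameters(), lr=0.1)

x = torch.randn(4, 16, device="cuda")
with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
    # The forward pass is where the reported error is raised when
    # HALF tiles are combined with autocast.
    loss = model(x).sum()
loss.backward()
opt.step()
```

If the combination worked as the `RPUDataType` documentation suggests, this loop would complete one optimizer step without error.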