feat: add flash attn to inference and eval scripts #132
Conversation
Signed-off-by: Anh-Uong <[email protected]>
dc2a2ea to 4c5768c
Thanks Anh! One question
scripts/run_inference.py
Outdated
attn_implementation="flash_attention_2" if use_flash_attn else None,
torch_dtype=torch.float16 if use_flash_attn else None,
Is there a reason that float16 is being hardcoded here since flash attention also supports bfloat32?
Yeah I wasn't sure about this, I think I did it based on this comment and perhaps some examples I saw like here. But I think you're very right that this should be removed
If it's removed, what torch_dtype will be used to load the model? If trying to load a model that uses float32, for example, this would fail, so should this be set by the user and default to bfloat16 as we have in sft_trainer?
For example, I get this warning:
Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes, but the current dype in GPTBigCodeForCausalLM is torch.float32. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator, or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="flash_attention_2", torch_dtype=torch.float16)`
and then the error: RuntimeError: FlashAttention only support fp16 and bf16 data type
Ahh sorry, I meant bfloat16, not bfloat32 :) The best choice probably depends on what type the model was tuned with, but the safer choice is probably bfloat16, since it has a bigger exponent (same range as float32).
It would be nice to check what happens if you try loading with bfloat16 on a device that doesn't support it, to see whether it has a fallback (and whether that actually modifies the loaded model's dtype, or just the way computations are handled) or throws. But the most reasonable choice is probably to do the same thing we do for tuning to select the default data type at tune time, or check whether bf16 is supported and use it if we can. The easiest way to check if bf16 is available is:
torch.cuda.is_available() and torch.cuda.is_bf16_supported()
(the latter throws if you call it in torch with no CUDA installed, which is why the is_available() check is needed)
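A minimal sketch of that check wrapped into a small dtype picker; the float16 fallback is an assumption for illustration, not something settled in this thread:

import torch

def pick_torch_dtype():
    # Use bf16 only when CUDA is present and the device supports it;
    # guard with is_available() first, since is_bf16_supported() can throw
    # when no CUDA is installed. The float16 fallback is illustrative only.
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
        return torch.bfloat16
    return torch.float16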
We use bfloat16 as the default for tuning (ref). So bfloat16 is the right default here, and I assume that since we don't explicitly handle it, from_pretrained has a sensible fallback.
Modified to load with bfloat16 and I tried testing on a V100, where torch.cuda.is_available() == True and torch.cuda.is_bf16_supported() == False. How do I check the dtype of the model? I ran:
>>> model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m", torch_dtype=torch.bfloat16)
config.json: 100%|█████████████████████████████████████████████████████████████████████████| 693/693 [00:00<00:00, 4.49MB/s]
model.safetensors: 100%|███████████████████████████████████████████████████████████████| 1.12G/1.12G [00:14<00:00, 79.3MB/s]
>>> model.config.torch_dtype
torch.bfloat16
The config.torch_dtype is set by the model at save time, I think! The one you want is just the model's .dtype, which reflects the current instance loaded on your device. I don't think conversions like .to (which moves the model to different devices/dtypes etc.) would update the config-based attribute, but I am not sure.
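A small sketch contrasting the two attributes discussed above (the model name is reused from the earlier snippet; whether .to() updates config.torch_dtype is left as the open question it is here):

import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "bigscience/bloom-560m", torch_dtype=torch.bfloat16
)
print(model.config.torch_dtype)  # dtype recorded in the config
print(model.dtype)               # dtype of the parameters actually loaded

model = model.to(torch.float16)
print(model.dtype)               # reflects the conversion
# Whether model.config.torch_dtype is updated by .to() is the open question above.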
Signed-off-by: Anh-Uong <[email protected]>
Signed-off-by: Anh-Uong <[email protected]>
LGTM
…stack#132)
* add flash attn to inference and eval scripts
Signed-off-by: Anh-Uong <[email protected]>
* load model with torch_dtype bfloat16
Signed-off-by: Anh-Uong <[email protected]>
---------
Signed-off-by: Anh-Uong <[email protected]>
Signed-off-by: aaron.chew1 <[email protected]>
Description of the change
Add a use_flash_attn flag for the inference and evaluation scripts. When used, the model is loaded with flash attention, which can affect the inference results.
Related issue number
related to: #103
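A minimal sketch of how the flag might gate the loading call once the bfloat16 change discussed above is in place; the function and argument names are assumptions, not the actual script code:

import torch
from transformers import AutoModelForCausalLM

def load_for_inference(model_name_or_path: str, use_flash_attn: bool = False):
    # With the flag set, load with flash attention 2 and bfloat16 as discussed
    # in the review thread; otherwise fall back to the library defaults.
    return AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        attn_implementation="flash_attention_2" if use_flash_attn else None,
        torch_dtype=torch.bfloat16 if use_flash_attn else None,
    )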
How to verify the PR
Ran the scripts loading a llama-7b model and a LoRA-tuned llama-7b model that requires merging. Saw that when using the flag, the model is loaded with flash attention, as noted by the log output.
inference:
Was the PR tested