feat: add flash attn to inference and eval scripts #132

Merged

Conversation

@anhuong (Collaborator) commented Apr 23, 2024

Description of the change

Add a use_flash_attn flag for the inference and evaluation scripts. When set, the model is loaded using flash attention, which can affect the inference results.
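A minimal sketch of the loading pattern this enables (illustrative only; the load_model name and the bfloat16 choice are assumptions, and the dtype question is discussed in the review thread below):

import torch
from transformers import AutoModelForCausalLM

def load_model(model_name_or_path, use_flash_attn=False):
    # Only request flash attention (and a half-precision dtype, which it requires)
    # when the flag is set; otherwise keep the library defaults.
    return AutoModelForCausalLM.from_pretrained(
        model_name_or_path,
        attn_implementation="flash_attention_2" if use_flash_attn else None,
        torch_dtype=torch.bfloat16 if use_flash_attn else None,
    )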

Related issue number

related to: #103

How to verify the PR

Ran the scripts loading a llama-7b model and a LoRA-tuned llama-7b model that requires merging. When the flag is used, the model is loaded with flash attention, as shown in the log output.

inference:

# without
$ python scripts/run_inference.py --model /llama-eval-pvc/LLaMa/models/hf/7B --out_file llama_7b --max_new_tokens 50 --text "Today is a good day. What do you want to do today?" 
No adapter config found! Loading as a merged model...
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████| 2/2 [00:25<00:00, 12.86s/it]
Inferred device: cuda
100%|█████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:02<00:00,  2.29s/it]
Exported results to: llama_7b

# with flash attn
$ python scripts/run_inference.py --model /llama-eval-pvc/LLaMa/models/hf/7B --out_file llama_7b_flash --max_new_tokens 50 --text "Today is a good day. What do you want to do today?" --use_flash_attn
No adapter config found! Loading as a merged model...
You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with model.to('cuda').
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████| 2/2 [00:02<00:00,  1.35s/it]
Inferred device: cuda
100%|█████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:01<00:00,  1.80s/it]
Exported results to: llama_7b_flash

# lora tuned model with flash attn
$ python scripts/run_inference.py --model /nfs-storage-pvc/llama2-7b-twitter-lora-default-checkpoint-15/ --base_model_name_or_path /llama-eval-pvc/LLaMa/models/hf/7B --out_file llama_7b_lora --max_new_tokens 50 --text "Today is a good day. What do you want to do today?" --use_flash_attn
You are attempting to use Flash Attention 2.0 with a model not initialized on GPU. Make sure to move the model to GPU after initializing it on CPU with `model.to('cuda')`.
Loading checkpoint shards: 100%|██████████████████████████████████████████████████████████████| 2/2 [00:02<00:00,  1.03s/it]
Inferred device: cuda
100%|█████████████████████████████████████████████████████████████████████████████████████████| 1/1 [00:02<00:00,  2.02s/it]
Exported results to: llama_7b_lora

Was the PR tested

  • I have added >=1 unit test(s) for every new method I have added.
  • I have ensured all unit tests pass

@anhuong anhuong force-pushed the flash-attn-inference branch from dc2a2ea to 4c5768c Compare April 23, 2024 17:02
@anhuong anhuong changed the title add flash attn to inference and eval scripts feat: add flash attn to inference and eval scripts Apr 24, 2024
@alex-jw-brooks (Collaborator) left a comment

Thanks Anh! One question

attn_implementation="flash_attention_2"
if use_flash_attn
else None,
torch_dtype=torch.float16 if use_flash_attn else None,
@alex-jw-brooks (Collaborator):

Is there a reason that float16 is being hardcoded here since flash attention also supports bfloat32?

@anhuong (Collaborator, author):

Yeah, I wasn't sure about this. I think I did it based on this comment and perhaps some examples I saw, like here. But I think you're right that this should be removed.

@anhuong (Collaborator, author) commented Apr 25, 2024

If it's removed, what torch_dtype will be used to load the model? If you try to load a model that uses float32, for example, this would fail. Should this be set by the user, and default to bfloat16 as we have in sft_trainer?

For example, I get the warning:

Flash Attention 2.0 only supports torch.float16 and torch.bfloat16 dtypes, but the current dype in GPTBigCodeForCausalLM is torch.float32. You should run training or inference using Automatic Mixed-Precision via the `with torch.autocast(device_type='torch_device'):` decorator, or load the model with the `torch_dtype` argument. Example: `model = AutoModel.from_pretrained("openai/whisper-tiny", attn_implementation="flash_attention_2", torch_dtype=torch.float16)`

and then the error: RuntimeError: FlashAttention only support fp16 and bf16 data type
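A sketch of that failure mode and the explicit-dtype workaround (the GPTBigCode checkpoint name is only an illustrative example assumed to store float32 weights; running it requires flash-attn and a supported GPU, and this is not code from the PR):

import torch
from transformers import AutoModelForCausalLM

# Requesting flash attention without a torch_dtype keeps the checkpoint's float32
# weights, which produces the warning above and later the FlashAttention RuntimeError:
# model = AutoModelForCausalLM.from_pretrained(
#     "bigcode/gpt_bigcode-santacoder", attn_implementation="flash_attention_2"
# )

# Passing an explicit half-precision dtype avoids it:
model = AutoModelForCausalLM.from_pretrained(
    "bigcode/gpt_bigcode-santacoder",
    attn_implementation="flash_attention_2",
    torch_dtype=torch.bfloat16,
)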

@alex-jw-brooks (Collaborator) commented May 1, 2024

Ahh sorry, I meant bfloat16, not bfloat32 :) The best choice probably depends on what type the model is tuned with, but the safer choice is probably bfloat16 since it has a bigger exponent (same range as float32).

It would be nice to check the behavior when loading with bfloat16 on a device that doesn't support it, to see whether it has a fallback (and whether that actually modifies the loaded model's dtype, or just the way computations are handled) or throws. But the most reasonable choice is probably to do the same thing we do for tuning to select the default data type, or to check whether bf16 is supported and use it if we can. The easiest way to check whether bf16 is available is:

torch.cuda.is_available() and torch.cuda.is_bf16_supported()
(the latter throws if you call it in torch with no CUDA available, which is why the is_available() check is needed)
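A minimal sketch of that check as a small helper (the name and the float16 fallback are illustrative, not from this PR):

import torch

def pick_dtype():
    # Guard the bf16 check: torch.cuda.is_bf16_supported() can raise when no
    # CUDA device is present, so check torch.cuda.is_available() first.
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
        return torch.bfloat16
    # float16 is the other dtype FlashAttention accepts (e.g. on a V100).
    return torch.float16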

Collaborator:

We use bfloat16 as the default for tuning (ref), so bfloat16 is the right default here. I assume that, since we don't explicitly handle it, from_pretrained has a sensible fallback.

@anhuong (Collaborator, author):

Modified to load with bfloat16, and I tried testing on a V100 where torch.cuda.is_available() == True and torch.cuda.is_bf16_supported() == False. How do I check the dtype of the model? I ran:

>>> model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m", torch_dtype=torch.bfloat16)
config.json: 100%|█████████████████████████████████████████████████████████████████████████| 693/693 [00:00<00:00, 4.49MB/s]
model.safetensors: 100%|███████████████████████████████████████████████████████████████| 1.12G/1.12G [00:14<00:00, 79.3MB/s]
>>> model.config.torch_dtype
torch.bfloat16

Collaborator:

The config.torch_dtype is set by the model at save time, I think! The one you want is the model's .dtype, which reflects the current instance loaded on your device. I don't think conversions like .to() (which moves the model to different devices/dtypes, etc.) would update the config-based attribute, but I am not sure.
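For example, contrasting the two attributes with the same bloom-560m load as above (a quick REPL sketch, not from this PR):

>>> import torch
>>> from transformers import AutoModelForCausalLM
>>> model = AutoModelForCausalLM.from_pretrained("bigscience/bloom-560m", torch_dtype=torch.bfloat16)
>>> model.dtype                 # dtype of the instance actually loaded
torch.bfloat16
>>> model.config.torch_dtype    # attribute recorded in the config
torch.bfloat16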

@alex-jw-brooks (Collaborator) left a comment

LGTM

@alex-jw-brooks alex-jw-brooks merged commit dd29d49 into foundation-model-stack:main May 2, 2024
6 checks passed
achew010 pushed a commit to achew010/fms-hf-tuning that referenced this pull request May 6, 2024
…stack#132)

* add flash attn to inference and eval scripts

Signed-off-by: Anh-Uong <[email protected]>

* load model with torch_dtype bfloat16

Signed-off-by: Anh-Uong <[email protected]>

---------

Signed-off-by: Anh-Uong <[email protected]>
Signed-off-by: aaron.chew1 <[email protected]>