
inference with lora raises "expected scalar type Half but found Float" #110

Closed
arceus-jia opened this issue Aug 10, 2023 · 5 comments

@arceus-jia

Checklist before submitting

  • Make sure you are using the latest code from the repository (git pull); some issues have already been resolved and fixed.
  • I have read the FAQ section of the project documentation and searched the existing issues, and found no similar problem or solution.
  • Third-party plugin problems (e.g. llama.cpp, text-generation-webui): it is recommended to also look for solutions in the corresponding project.

Issue type

Model inference

Base model

Alpaca-2-7B

Operating system

Linux

Detailed description of the problem

python scripts/inference/inference_hf.py     --base_model   ziqingyang/chinese-alpaca-2-7b     --with_prompt     --interactive --lora_model ../output/test1/sft_lora_model/

Running inference with the weights produced by SFT (LoRA) instruction fine-tuning raises an error:
expected scalar type Half but found Float
A previous issue says merging the LoRA weights into the base model makes it work, but there is a clear need for loading LoRA dynamically.
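The merge workaround mentioned above can be sketched with peft's `merge_and_unload()`, which folds the LoRA deltas into the base weights so inference no longer mixes fp32 LoRA layers with fp16 base layers. This is a minimal sketch, assuming `transformers` and `peft` are installed; the paths reuse the ones from the command above and it is of course not runnable without the actual weights.

```python
def merge_lora(base_model_path, lora_path, output_dir):
    """Merge LoRA weights into the base model and save the result."""
    import torch
    from transformers import AutoModelForCausalLM
    from peft import PeftModel

    # Load the frozen base model in fp16, then attach the LoRA adapter.
    base = AutoModelForCausalLM.from_pretrained(
        base_model_path, torch_dtype=torch.float16)
    model = PeftModel.from_pretrained(base, lora_path)

    # Fold lora_B @ lora_A into the base weights and drop the adapter
    # modules, so every linear layer ends up in a single dtype.
    merged = model.merge_and_unload()
    merged.save_pretrained(output_dir)

if __name__ == "__main__":
    merge_lora("ziqingyang/chinese-alpaca-2-7b",
               "../output/test1/sft_lora_model/",
               "../output/test1/merged/")
```

After merging, `inference_hf.py` can be pointed at the merged directory without `--lora_model`, which sidesteps the dtype mismatch at the cost of losing dynamic adapter loading.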

Dependencies (required for code-related problems)

# Please paste your dependency information here

Runtime logs or screenshots

  File "/home/ubuntu/ml/llm/Chinese-LLaMA-Alpaca-2/scripts/inference/inference_hf.py", line 182, in <module>
    generation_output = model.generate(
  File "/home/ubuntu/ml/llm/peft/src/peft/peft_model.py", line 581, in generate
    outputs = self.base_model.generate(**kwargs)
  File "/home/ubuntu/anaconda3/envs/llm/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/ubuntu/anaconda3/envs/llm/lib/python3.10/site-packages/transformers/generation/utils.py", line 1588, in generate
    return self.sample(
  File "/home/ubuntu/anaconda3/envs/llm/lib/python3.10/site-packages/transformers/generation/utils.py", line 2642, in sample
    outputs = self(
  File "/home/ubuntu/anaconda3/envs/llm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ubuntu/anaconda3/envs/llm/lib/python3.10/site-packages/accelerate/hooks.py", line 165, in new_forward
    output = old_forward(*args, **kwargs)
  File "/home/ubuntu/anaconda3/envs/llm/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 806, in forward
    outputs = self.model(
  File "/home/ubuntu/anaconda3/envs/llm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ubuntu/anaconda3/envs/llm/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 693, in forward
    layer_outputs = decoder_layer(
  File "/home/ubuntu/anaconda3/envs/llm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ubuntu/anaconda3/envs/llm/lib/python3.10/site-packages/transformers/models/llama/modeling_llama.py", line 408, in forward
    hidden_states, self_attn_weights, present_key_value = self.self_attn(
  File "/home/ubuntu/anaconda3/envs/llm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ubuntu/ml/llm/Chinese-LLaMA-Alpaca-2/scripts/attn_and_long_ctx_patches.py", line 44, in xformers_forward
    query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
  File "/home/ubuntu/anaconda3/envs/llm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ubuntu/ml/llm/peft/src/peft/tuners/lora.py", line 358, in forward
    result += self.lora_B(self.lora_A(self.lora_dropout(x))) * self.scaling
  File "/home/ubuntu/anaconda3/envs/llm/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/ubuntu/anaconda3/envs/llm/lib/python3.10/site-packages/torch/nn/modules/linear.py", line 114, in forward
    return F.linear(input, self.weight, self.bias)
RuntimeError: expected scalar type Half but found Float
@arceus-jia arceus-jia changed the title from "inferences with lora raises expected scalar type Half but found Float" to "inference with lora raises expected scalar type Half but found Float" Aug 10, 2023
@arceus-jia
Author

This is a problem in peft itself. In the newer version, peft's LoRA forward already adds result.to(previous_dtype), so upgrading fixes it. But then inference and training would need two different environments. When you have time, could you update the training code for peft 0.4?
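The peft fix described above can be illustrated without torch: the LoRA branch may be kept in fp32 for numerical stability, so its output has to be cast back to the input's dtype before it is combined with the fp16 base projection. The `Tensor` class below is a toy stand-in (it only tracks a dtype string, it is not torch), and `lora_forward` mirrors the shape of peft's `Linear.forward`; both names are illustrative, not peft's actual code.

```python
class Tensor:
    """Toy stand-in that only tracks a dtype string."""
    def __init__(self, dtype):
        self.dtype = dtype

    def to(self, dtype):
        # Mimics torch's Tensor.to(dtype): returns a cast copy.
        return Tensor(dtype)

def lora_forward(x, base_dtype="float16", lora_dtype="float32"):
    # New peft records the input dtype up front ...
    previous_dtype = x.dtype
    result = Tensor(base_dtype)    # frozen base projection, e.g. q_proj(x)
    lora_out = Tensor(lora_dtype)  # lora_B(lora_A(dropout(x))) * scaling
    # Old peft added the fp32 LoRA output straight onto the fp16 result,
    # which is what triggers "expected scalar type Half but found Float".
    # ... and casts back before combining, so both operands match x:
    result = result.to(previous_dtype)
    lora_out = lora_out.to(previous_dtype)
    return result

out = lora_forward(Tensor("float16"))
print(out.dtype)  # prints: float16
```

This also shows why the commenter's concern holds: training code pinned to the old peft forward and inference code relying on the new cast behave differently, so the two ends need matching peft versions.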

@airaria
Contributor

airaria commented Aug 10, 2023

The newer peft uses noticeably more GPU memory, which is why we have been staying on the older version.
Thanks for the heads-up, though; we will keep an eye on peft releases and upgrade to a newer version when appropriate.

@nerylj

nerylj commented Aug 11, 2023

We hope dynamic LoRA loading can be supported.

@github-actions

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you for your consideration.

@github-actions github-actions bot added the stale label Aug 21, 2023
@github-actions

Closing the issue, since no updates observed. Feel free to re-open if you need any further assistance.

@github-actions github-actions bot closed this as not planned Won't fix, can't repro, duplicate, stale Aug 27, 2023