
ValueError: Target module Dropout(p=0.05, inplace=False) is not supported. Currently, only the following modules are supported: torch.nn.Linear, torch.nn.Embedding, torch.nn.Conv2d, transformers.pytorch_utils.Conv1D. #2286

Closed
gyuilLim opened this issue Dec 17, 2024 · 6 comments

Comments


gyuilLim commented Dec 17, 2024

System Info

Library version: PEFT==0.13.2, PyTorch==2.4.0, Transformers==4.46.3
Python version: 3.8.19
CUDA version: 12.6

I am trying to implement Low-Rank Adaptation (LoRA) in my model, but I encountered the following error when running the training script:

ValueError: Target module Dropout(p=0.05, inplace=False) is not supported. Currently, only the following modules are supported: torch.nn.Linear, torch.nn.Embedding, torch.nn.Conv2d, transformers.pytorch_utils.Conv1D.

[rank0]: Traceback (most recent call last):
[rank0]:   File "/home/vision/gyuil/lab/vga_finetuning/LLaVA/llava/train/train_mem.py", line 6, in <module>
[rank0]:     train(attn_implementation="flash_attention_2")
[rank0]:   File "/home/vision/gyuil/lab/vga_finetuning/LLaVA/llava/train/train.py", line 921, in train
[rank0]:     model = get_peft_model(model, lora_config)
[rank0]:   File "/home/vision/anaconda3/envs/torch/lib/python3.8/site-packages/peft/mapping.py", line 194, in get_peft_model
[rank0]:     return MODEL_TYPE_TO_PEFT_MODEL_MAPPING[peft_config.task_type](
[rank0]:   File "/home/vision/anaconda3/envs/torch/lib/python3.8/site-packages/peft/peft_model.py", line 1609, in __init__
[rank0]:     super().__init__(model, peft_config, adapter_name, **kwargs)
[rank0]:   File "/home/vision/anaconda3/envs/torch/lib/python3.8/site-packages/peft/peft_model.py", line 171, in __init__
[rank0]:     self.base_model = cls(model, {adapter_name: peft_config}, adapter_name)
[rank0]:   File "/home/vision/anaconda3/envs/torch/lib/python3.8/site-packages/peft/tuners/lora/model.py", line 141, in __init__
[rank0]:     super().__init__(model, config, adapter_name, low_cpu_mem_usage=low_cpu_mem_usage)
[rank0]:   File "/home/vision/anaconda3/envs/torch/lib/python3.8/site-packages/peft/tuners/tuners_utils.py", line 184, in __init__
[rank0]:     self.inject_adapter(self.model, adapter_name, low_cpu_mem_usage=low_cpu_mem_usage)
[rank0]:   File "/home/vision/anaconda3/envs/torch/lib/python3.8/site-packages/peft/tuners/tuners_utils.py", line 496, in inject_adapter
[rank0]:     self._create_and_replace(peft_config, adapter_name, target, target_name, parent, current_key=key)
[rank0]:   File "/home/vision/anaconda3/envs/torch/lib/python3.8/site-packages/peft/tuners/lora/model.py", line 227, in _create_and_replace
[rank0]:     new_module = self._create_new_module(lora_config, adapter_name, target, **kwargs)
[rank0]:   File "/home/vision/anaconda3/envs/torch/lib/python3.8/site-packages/peft/tuners/lora/model.py", line 353, in _create_new_module
[rank0]:     raise ValueError(
[rank0]: ValueError: Target module Dropout(p=0.05, inplace=False) is not supported. Currently, only the following modules are supported: `torch.nn.Linear`, `torch.nn.Embedding`, `torch.nn.Conv2d`, `transformers.pytorch_utils.Conv1D`.

It seems that the LoRA implementation currently does not allow for Dropout layers to be included as target modules. Could you provide guidance on how to properly handle dropout with LoRA or whether it will be supported in future updates?

Thank you for your assistance!


gyuilLim commented Dec 17, 2024

I thought it was a Python version issue, so I upgraded to 3.9, but the same error still occurs.

So I guess that if a Dropout layer is not inside a Conv2d or Linear sequential block and instead exists independently outside it, the dropout layer is not converted to lora_dropout.

As a temporary workaround, I added the following code at line 225 in peft/tuners/lora/model.py:

else:
    # Workaround: skip unsupported Dropout targets instead of raising a ValueError
    if isinstance(target, nn.Dropout):
        return

Although this is not a fundamental solution, it resolves the issue for now.

BenjaminBossan (Member) commented

Could you please show the code you use that results in the error? That way, we can take a better look at it. It appears that your target_modules is defined in a way that includes a dropout layer.

Regarding dropout: yes, you cannot target dropout layers with LoRA. That wouldn't really make sense, since dropout layers don't have any learnable parameters. Note that lora_dropout is something different: it is dropout applied to the input of the LoRA layer, not LoRA applied to a dropout layer.
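
For illustration, here is a minimal sketch with a toy model (the layer names and config values are made up) that triggers the same error, since LoRA can only wrap modules that carry learnable weights:

import torch.nn as nn
from peft import LoraConfig, get_peft_model

# toy model: a Linear followed by a standalone Dropout
model = nn.Sequential(nn.Linear(16, 16), nn.Dropout(p=0.05))

# naming the dropout module ("1") in target_modules triggers the ValueError above;
# lora_dropout, by contrast, only adds dropout in front of the LoRA branch of "0"
config = LoraConfig(target_modules=["0", "1"], lora_dropout=0.05)
peft_model = get_peft_model(model, config)  # raises: Target module Dropout(...) is not supported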

so I upgraded to 3.9

This is generally a good idea, since Python 3.8 and below have reached their end of life and thus no longer receive security updates.

gyuilLim (Author) commented

I'm fine-tuning the llava-v1.6-mistral-7b model in two stages. The first fine-tuning proceeds without any issues. However, I encounter the above error during the second fine-tuning stage, and I'm not sure about the cause.

This is main_train.py

[screenshot of main_train.py]

The part where the error actually occurs is as follows.

[screenshot of the code where the error occurs]

gyuilLim (Author) commented

Is it necessary to use a different method instead of get_peft_model to load a model trained with get_peft_model for further training?

I'm sorry to bother you.

BenjaminBossan (Member) commented

Is it necessary to use a different method instead of get_peft_model to load a model trained with get_peft_model for further training?

Indeed, this is necessary. You should use

from peft import PeftModel

base_model = ...  # load llava model here
peft_model = PeftModel.from_pretrained(base_model, <path-to-saved-peft-model>)

Then, if you want to further fine-tune this model, you have two options: you can pass is_trainable=True to from_pretrained and continue fine-tuning the loaded LoRA adapter, or you can add a new LoRA adapter and fine-tune that one by calling peft_model.add_adapter(<adapter-name>, lora_config) (and peft_model.set_adapter(<adapter-name>) to activate it).
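
A rough sketch of both options (the checkpoint path "lora-checkpoint", the adapter name "stage2", and the LoraConfig values below are placeholders, adjust them to your setup):

from peft import LoraConfig, PeftModel

base_model = ...  # load the plain llava base model here, not the LoRA checkpoint

# Option 1: keep training the adapter saved after the first stage
peft_model = PeftModel.from_pretrained(base_model, "lora-checkpoint", is_trainable=True)

# Option 2: keep the first-stage adapter frozen and train a new adapter on top
peft_model = PeftModel.from_pretrained(base_model, "lora-checkpoint")
stage2_config = LoraConfig(r=16, lora_alpha=32, lora_dropout=0.05,
                           target_modules=["q_proj", "v_proj"])
peft_model.add_adapter("stage2", stage2_config)
peft_model.set_adapter("stage2")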

Please try the suggestion and see if it solves your issue.

This is main_train.py

What is the result of find_all_linear_names?


gyuilLim commented Dec 19, 2024

@BenjaminBossan

As you mentioned, loading the LoRA fine-tuning model with PeftModel.from_pretrained() and continuing the training solved the issue.

The error occurred because PeftModel.from_pretrained() was not used when loading the model that had already been trained with LoRA (PEFT).

Thank you !!!

Here is the result of find_all_linear_names.

def find_all_linear_names(model):
    cls = torch.nn.Linear
    lora_module_names = set()
    multimodal_keywords = ['mm_projector', 'vision_tower', 'vision_resampler']
    for name, module in model.named_modules():
        if any(mm_keyword in name for mm_keyword in multimodal_keywords):
            continue
        if isinstance(module, cls):
            names = name.split('.')
            lora_module_names.add(names[0] if len(names) == 1 else names[-1])

    if 'lm_head' in lora_module_names: # needed for 16-bit
        lora_module_names.remove('lm_head')
    return list(lora_module_names)
    
## result : ['default', 'base_layer']
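
Presumably these names come from the LoRA modules already present in the first-stage checkpoint: each wrapped Linear contains an inner base_layer plus lora_A.default / lora_B.default submodules (all nn.Linear), which is why the scan returns 'base_layer' and 'default'. Feeding 'default' back into target_modules then also matches lora_dropout.default, an nn.Dropout, which would explain the Dropout(p=0.05) in the error message. A small sketch with a toy model (made-up sizes) that shows these names:

import torch.nn as nn
from peft import LoraConfig, get_peft_model

toy = nn.Sequential(nn.Linear(8, 8))
wrapped = get_peft_model(toy, LoraConfig(target_modules=["0"], lora_dropout=0.05))

# prints names ending in base_layer, lora_A.default, lora_B.default (nn.Linear)
# and lora_dropout.default (nn.Dropout)
for name, module in wrapped.named_modules():
    if isinstance(module, (nn.Linear, nn.Dropout)):
        print(name, type(module).__name__)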
