feat: lora fine-tuning in FHE + gpt2 use case example #823
Conversation
Force-pushed from 23ebfd9 to f20697d
Timings are from my computer; they haven't been refreshed.
Force-pushed from 6c1caeb to e63afe9
Force-pushed from e63afe9 to 1e92c43
I haven't reviewed everything yet.
The main issues are:
- can we make it work for models other than GPT2 (for example the simple MLP)?
- make LoraTraining use self.training to determine whether it is doing inference or training (see the sketch below)
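A minimal sketch of the self.training idea, assuming a hypothetical LoraTrainingSketch wrapper (the class and argument names are illustrative, not the actual concrete-ml API): the module relies on the standard torch.nn.Module training flag, toggled by model.train() and model.eval(), instead of a separate mode argument.

```python
import torch


class LoraTrainingSketch(torch.nn.Module):
    """Illustrative wrapper that picks the training or inference path via self.training."""

    def __init__(self, inference_model: torch.nn.Module):
        super().__init__()
        self.inference_model = inference_model

    def forward(self, x, labels=None):
        if self.training:
            # Training path (after model.train()): return the loss that drives
            # the LoRA backward pass.
            logits = self.inference_model(x)
            return torch.nn.functional.cross_entropy(logits, labels)
        # Inference path (after model.eval()): plain forward pass.
        return self.inference_model(x)
```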
I would like to remove the transformers dependency:
- move the dependency to the use case example and replace Conv1D with nn.Linear in the use case (a conversion sketch follows this list)
- parametrize the remote module finding function in a way that avoids the dependency
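As an illustration of the first point, here is a minimal sketch of converting a transformers Conv1D into an equivalent nn.Linear on the use case side (the helper name conv1d_to_linear is made up for this example). transformers.Conv1D computes x @ W + b with a weight of shape (in_features, out_features), so the nn.Linear weight is the transpose of the Conv1D weight.

```python
import torch
from transformers import Conv1D as TransformerConv1D


def conv1d_to_linear(conv1d: TransformerConv1D) -> torch.nn.Linear:
    """Build an nn.Linear that computes the same mapping as a transformers Conv1D."""
    in_features, out_features = conv1d.weight.shape
    linear = torch.nn.Linear(in_features, out_features)
    with torch.no_grad():
        # Conv1D computes x @ W + b, nn.Linear computes x @ W.T + b, hence the transpose.
        linear.weight.copy_(conv1d.weight.t())
        linear.bias.copy_(conv1d.bias)
    return linear
```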
src/concrete/ml/torch/lora.py (Outdated)

from typing import List

import torch
from transformers import Conv1D as TransformerConv1D
This adds a dependency on transformers; we had removed it earlier.
Alright, done.
return grad_input, None, None


def get_remote_names(model: torch.nn.Module, include_embedding_layers: bool = False) -> List[str]:
This is too specific for a library function, e.g. the lm_head moniker. I suggest adding two arguments, remote_layer_types and layer_reject_filter. You can then call it with remote_layer_types = [nn.Linear, nn.Embedding, transformers.Conv1D], so CML does not need to know about TransformerConv1D, and layer_reject_filter = ['lm_head']. You can thus remove the include_embedding_layers flag (a sketch follows).
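A rough sketch of the suggested signature, using the parameter names from the comment above (this is the reviewer's proposal, not the merged implementation): the caller supplies the layer types to treat as remote modules and the module names to reject, so the library itself never has to import transformers.

```python
from typing import List, Tuple, Type

import torch


def get_remote_names_sketch(
    model: torch.nn.Module,
    remote_layer_types: Tuple[Type[torch.nn.Module], ...],
    layer_reject_filter: List[str],
) -> List[str]:
    """Return names of modules whose type is in remote_layer_types, skipping rejected names."""
    names = []
    for name, module in model.named_modules():
        if isinstance(module, remote_layer_types) and not any(
            rejected in name for rejected in layer_reject_filter
        ):
            names.append(name)
    return names


# Example call from the use case side, where the transformers import would live:
# from transformers import Conv1D
# remote = get_remote_names_sketch(model, (torch.nn.Linear, torch.nn.Embedding, Conv1D), ["lm_head"])
```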
But in the function we manually add CustomLinear if the user asks for nn.Linear.
I am not sure here. This is very specific to replace_layers_with_custom. I would not let the user select layers themselves. For now we run linear layers in FHE, and we have a workaround for lm_head and the embedding layers until those are fixed.
Force-pushed from 09a96b5 to 0379e15
Coverage failed ❌ (Coverage details)
Excellent work here @RomanBredehoft and @jfrery! Currently running the notebook with 100 epochs.
refs https://github.com/zama-ai/concrete-ml-internal/issues/4522