feat: lora fine-tuning in FHE + gpt2 use case example #823

Merged: 32 commits, Sep 26, 2024

Commits
5af1529
docs: add_hybrid_lora_fine_tuning
RomanBredehoft Aug 2, 2024
0a0c455
chore: add makefile to use case
RomanBredehoft Aug 5, 2024
def6ef2
chore: fix pcc
RomanBredehoft Aug 5, 2024
8f7ccaa
chore: add push_changes target to use case action
RomanBredehoft Aug 5, 2024
fa97cbd
chore: fix pcc
RomanBredehoft Aug 5, 2024
ad4a636
chore: refresh notebook
RomanBredehoft Aug 5, 2024
b8a7df4
chore: clean notebook
RomanBredehoft Aug 5, 2024
d96a8cc
chore: refresh notebook(s) for use case lora_finetune
RomanBredehoft Aug 5, 2024
e53d494
chore: improve refresh notebook and use case
RomanBredehoft Aug 5, 2024
6650622
chore: fix pcc
RomanBredehoft Aug 5, 2024
a16006a
chore: add disable adapters and print lora weights
RomanBredehoft Aug 5, 2024
b5a1ba8
chore: add simulation execution
RomanBredehoft Aug 5, 2024
e675d47
chore: clean notebook
RomanBredehoft Aug 5, 2024
72a4c7b
chore: add loss plot
RomanBredehoft Aug 6, 2024
b0eff66
chore: refresh notebook(s) for use case lora_finetune
RomanBredehoft Aug 6, 2024
4881b39
chore: add FHE embedding layers
RomanBredehoft Aug 6, 2024
cf32cb7
chore: update requirements
RomanBredehoft Aug 6, 2024
15a7254
chore: add lm_head
RomanBredehoft Aug 6, 2024
445ee0e
chore: fix remote embedding and lm_head
RomanBredehoft Aug 6, 2024
9053e1b
chore: temporarily remove embedding/lm_head from remote
RomanBredehoft Aug 7, 2024
039f08b
chore: add 16b training, without embedding layers
RomanBredehoft Aug 8, 2024
b5c659d
chore: seed text generation
RomanBredehoft Aug 16, 2024
2f20b0f
chore: rename use case and notebook + add readme + refacto + fix
jfrery Sep 23, 2024
4d51e35
chore: update licenses
jfrery Sep 23, 2024
9024a71
chore: fix forbidden words
jfrery Sep 23, 2024
9d4f956
chore: fix codeblock
jfrery Sep 23, 2024
2e7f15d
chore: update notebook executed
jfrery Sep 23, 2024
8e8409a
chore: lora more generic for the MLP
jfrery Sep 24, 2024
7dca588
chore: add LoraMLP notebook
jfrery Sep 24, 2024
f736d84
chore: pcc + test
jfrery Sep 25, 2024
bbc3266
chore: add docstring loratraining
jfrery Sep 26, 2024
0379e15
chore: make transformer lib optional
jfrery Sep 26, 2024
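The commits above build up LoRA fine-tuning support step by step. As background, LoRA freezes the base weight matrix W and learns a low-rank update B·A, so the adapted layer computes y = Wx + (α/r)·B(Ax). The sketch below is only an illustration of that arithmetic, not the PR's implementation; all names are hypothetical:

```python
import numpy as np

def lora_forward(x, W, A, B, alpha):
    """Linear layer with a LoRA adapter.

    W is the frozen (d_out, d_in) base weight; A (r, d_in) and B (d_out, r)
    form the trainable low-rank update, scaled by alpha / r.
    """
    r = A.shape[0]
    return W @ x + (alpha / r) * (B @ (A @ x))

rng = np.random.default_rng(0)
d_in, d_out, r = 8, 4, 2
W = rng.standard_normal((d_out, d_in))
A = rng.standard_normal((r, d_in))
# B starts at zero, so the adapter is initially a no-op and only
# the small A/B matrices need gradients during fine-tuning.
B = np.zeros((d_out, r))
x = rng.standard_normal(d_in)

assert np.allclose(lora_forward(x, W, A, B, alpha=16), W @ x)
```

With r much smaller than d_in and d_out, the trainable parameter count drops from d_out·d_in to r·(d_in + d_out), which is what makes fine-tuning a model like GPT-2 tractable in a hybrid FHE setting.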
13 changes: 10 additions & 3 deletions .github/workflows/refresh-one-notebook.yaml
Original file line number Diff line number Diff line change
@@ -21,6 +21,7 @@ on:
- FullyConnectedNeuralNetwork \n
- FullyConnectedNeuralNetworkOnMNIST \n
- GLMComparison \n
- GPT2FineTuneHybrid \n
- HealthCarePrediction \n
- ImportingFromScikitLearn \n
- KaggleTitanic \n
@@ -29,6 +30,7 @@ on:
- LinearSVR \n
- LogisticRegression \n
- LogisticRegressionTraining \n
- LoraMLP \n
- PerrorImpactOnFMNIST \n
- PoissonRegression \n
- QGPT2Evaluate \n
@@ -67,6 +69,7 @@ env:
FullyConnectedNeuralNetwork: "docs/advanced_examples/FullyConnectedNeuralNetwork.ipynb"
FullyConnectedNeuralNetworkOnMNIST: "docs/advanced_examples/FullyConnectedNeuralNetworkOnMNIST.ipynb"
GLMComparison: "docs/advanced_examples/GLMComparison.ipynb"
GPT2FineTuneHybrid: "use_case_examples/lora_finetuning/GPT2FineTuneHybrid.ipynb"
HealthCarePrediction: "use_case_examples/disease_prediction/HealthCarePrediction.ipynb"
ImportingFromScikitLearn: "docs/advanced_examples/ImportingFromScikitLearn.ipynb"
KaggleTitanic: "use_case_examples/titanic/KaggleTitanic.ipynb"
@@ -75,6 +78,7 @@ env:
LinearSVR: "docs/advanced_examples/LinearSVR.ipynb"
LogisticRegression: "docs/advanced_examples/LogisticRegression.ipynb"
LogisticRegressionTraining: "docs/advanced_examples/LogisticRegressionTraining.ipynb"
LoraMLP: "docs/advanced_examples/LoraMLP.ipynb"
PerrorImpactOnFMNIST: "use_case_examples/cifar/cifar_brevitas_finetuning/PerrorImpactOnFMNIST.ipynb"
PoissonRegression: "docs/advanced_examples/PoissonRegression.ipynb"
QGPT2Evaluate: "use_case_examples/llm/QGPT2Evaluate.ipynb"
@@ -195,11 +199,11 @@ jobs:
with:
token: ${{ secrets.BOT_TOKEN }}
commit-message: "chore: refresh ${{ github.event.inputs.notebook }} notebook"
branch: "refresh-${{ github.event.inputs.notebook }}-notebook-for-${{ github.ref_name }}"
branch: "refresh-${{ github.event.inputs.notebook }}-notebook-for-branch-${{ github.ref_name }}"
base: "${{ github.ref_name }}"
title: "Refresh ${{ github.event.inputs.notebook }} notebook for ${{ github.ref_name }}"
title: "Refresh ${{ github.event.inputs.notebook }} notebook for branch ${{ github.ref_name }}"
body: "Automatic PR with notebook refresh of ${{ github.event.inputs.notebook }} \
for ${{ github.ref_name }}."
for branch ${{ github.ref_name }}."
add-paths: |
docs/**/*.ipynb
use_case_examples/**/*.ipynb
@@ -211,6 +215,9 @@ jobs:
with:
commit_message: "chore: refresh ${{ github.event.inputs.notebook }} notebook"
add_options: '-u'
file_pattern: |
docs/**/*.ipynb
use_case_examples/**/*.ipynb
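The `file_pattern` globs above restrict the auto-commit to notebook files under `docs/` and `use_case_examples/`. A small Python sketch of what those `**/*.ipynb` patterns select (Python's recursive glob, where `**` matches zero or more directories, behaves like the pathspecs used here for these patterns; the helper name is hypothetical):

```python
import glob
import os
import tempfile

def matching_notebooks(root, patterns=("docs/**/*.ipynb", "use_case_examples/**/*.ipynb")):
    """Return the files under root that the workflow's file_pattern globs pick up."""
    matches = set()
    for pattern in patterns:
        # recursive=True makes '**' span zero or more directory levels
        matches.update(glob.glob(os.path.join(root, pattern), recursive=True))
    return {os.path.relpath(p, root) for p in matches}

with tempfile.TemporaryDirectory() as root:
    for rel in ["docs/a.ipynb", "docs/sub/b.ipynb", "docs/notes.txt",
                "use_case_examples/lora_finetuning/GPT2FineTuneHybrid.ipynb"]:
        path = os.path.join(root, rel)
        os.makedirs(os.path.dirname(path), exist_ok=True)
        open(path, "w").close()
    found = matching_notebooks(root)

assert "docs/notes.txt" not in found  # non-notebook files are never committed
```

Note that `docs/a.ipynb` matches even with no intermediate directory, because `**` can match zero path segments.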


stop-runner-linux:
39 changes: 38 additions & 1 deletion .github/workflows/run_one_use_cases_example.yaml
@@ -19,10 +19,16 @@ on:
- federated_learning
- hybrid_model
- llm
- lora_finetuning
- resnet
- sentiment_analysis_with_transformer
- titanic
# --- refresh_use_cases_list.py: refresh list of use cases currently available [END] ---
push_changes:
description: 'Push refreshed notebook(s)'
required: false
type: boolean
default: false

concurrency:
group: ${{ github.ref }}
@@ -102,6 +108,37 @@ jobs:
USE_CASE=${{ github.event.inputs.use_case }}
make run_one_use_case_example USE_CASE=$USE_CASE

# Pull the latest changes if there are any
- name: Pull latest changes
if: ${{ github.event.inputs.push_changes == 'true' }}
run: |
git pull -X theirs

# If the target branch is main or a release branch, a Pull Request is opened for everyone to
# review.
- name: Open PR
if: ${{ github.event.inputs.push_changes == 'true' && (github.ref_name == 'main' || startsWith(github.ref_name , 'release/')) }}
uses: peter-evans/create-pull-request@c5a7806660adbe173f04e3e038b0ccdcd758773c
with:
token: ${{ secrets.BOT_TOKEN }}
commit-message: "chore: refresh notebook(s) for use case ${{ github.event.inputs.use_case }}"
branch: "refresh-notebook(s)-for-use-case-${{ github.event.inputs.use_case }}-for-branch-${{ github.ref_name }}"
base: "${{ github.ref_name }}"
title: "Refresh notebook(s) for use case ${{ github.event.inputs.use_case }} for branch ${{ github.ref_name }}"
body: "Automatic PR with notebook(s) refresh of use case ${{ github.event.inputs.use_case }} \
for branch ${{ github.ref_name }}."
add-paths: |
use_case_examples/**/*.ipynb

# If the target branch is any other branch, the refreshed notebooks are committed to it directly
- name: Push changes into the current branch
if: ${{ github.event.inputs.push_changes == 'true' && github.ref_name != 'main' && !(startsWith(github.ref_name , 'release/')) }}
uses: stefanzweifel/git-auto-commit-action@8621497c8c39c72f3e2a999a26b4ca1b5058a842 #v5.0.1
with:
commit_message: "chore: refresh notebook(s) for use case ${{ github.event.inputs.use_case }}"
add_options: '-u'
file_pattern: 'use_case_examples/**/*.ipynb'
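The two steps above implement a branching policy: protected branches (`main` and `release/*`) get a reviewable pull request, any other branch gets a direct auto-commit, and neither runs unless the `push_changes` input is enabled. The same decision logic, restated as a small Python function (the function name and return labels are illustrative, not part of the workflow):

```python
def push_strategy(ref_name: str, push_changes: bool) -> str:
    """Mirror the workflow's conditionals for pushing refreshed notebooks."""
    if not push_changes:
        return "skip"  # the push_changes dispatch input gates both steps
    if ref_name == "main" or ref_name.startswith("release/"):
        return "open-pr"  # protected branches require review
    return "auto-commit"  # feature branches get the refresh pushed directly

assert push_strategy("main", True) == "open-pr"
assert push_strategy("release/1.7.x", True) == "open-pr"
assert push_strategy("feat/lora", True) == "auto-commit"
assert push_strategy("feat/lora", False) == "skip"
```

Keeping the two conditions exact complements of each other (as the workflow does with `==`/`!=` and `startsWith`/`!startsWith`) guarantees exactly one push path runs per dispatch.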

stop-runner-linux:
name: Stop EC2 runner
needs: [run-use-case-examples, start-runner-linux]
@@ -162,4 +199,4 @@ jobs:
- run-use-case-examples: ${{ needs.run-use-case-examples.result || 'Did not run.' }}\n\n\
- stop-runner-linux: ${{ needs.stop-runner-linux.result || 'Did not run.'}}"
SLACK_USERNAME: ${{ secrets.BOT_USERNAME }}
SLACK_WEBHOOK: ${{ secrets.SLACK_WEBHOOK }}
SLACK_WEBHOOK: ${{ secrets.SLACK_WEBHOOK }}
2 changes: 1 addition & 1 deletion deps_licenses/licenses_linux_user.txt.md5
@@ -1 +1 @@
31249a607336424af0d790feb9db6252
af423f91bb5313f1f1670ea72892d364
2 changes: 1 addition & 1 deletion deps_licenses/licenses_mac_intel_user.txt.md5
@@ -1 +1 @@
31249a607336424af0d790feb9db6252
af423f91bb5313f1f1670ea72892d364
2 changes: 1 addition & 1 deletion deps_licenses/licenses_mac_silicon_user.txt.md5
@@ -1 +1 @@
31249a607336424af0d790feb9db6252
af423f91bb5313f1f1670ea72892d364
534 changes: 534 additions & 0 deletions docs/advanced_examples/LoraMLP.ipynb


64 changes: 62 additions & 2 deletions poetry.lock


1 change: 1 addition & 0 deletions pyproject.toml
@@ -124,6 +124,7 @@ torchvision = [
{version = "0.17.2", markers = "platform_system=='Darwin' and platform_machine!='arm64'" },
{version = "0.18.1", markers = "platform_system!='Darwin' or platform_machine=='arm64'" }
]
peft = "^0.12.0"

[build-system]
requires = ["poetry-core>=1.0.0"]
2 changes: 1 addition & 1 deletion script/make_utils/run_use_case_examples.sh
@@ -41,7 +41,7 @@ install_requirements() {
local example_dir=$1
if [ -f "${example_dir}/requirements.txt" ]; then
pushd "$example_dir"
if pip install -r requirements.txt; then
if pip install -r requirements.txt --extra-index-url https://pypi.zama.ai/cpu; then
echo "Requirements installed successfully."
else
echo "Failed to install requirements."
15 changes: 12 additions & 3 deletions src/concrete/ml/torch/hybrid_model.py
@@ -565,13 +565,22 @@ def save_and_clear_private_info(self, path: Path, via_mlir=True):
"""
path = Path(path)
path.mkdir(parents=True, exist_ok=True)
for name in self.module_names:
module = self._get_module_by_name(self.model, name)
# Remove private information

# Save the complete model (including private info) for the developer
complete_model_path = path / "complete_model.pth"
torch.save(self.model.state_dict(), complete_model_path.resolve())

def clear_private_info(module):
for attr in ["private_module", "calibration_data", "private_q_module"]:
if hasattr(module, attr):
setattr(module, attr, None)

for child in module.children():
clear_private_info(child)

# Clear private info for the entire model
clear_private_info(self.model)

# Save the model with a specific filename
model_path = path / "model.pth"

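The diff above replaces a flat loop over `self.module_names` with a recursion over the whole module tree, so private attributes are cleared on every submodule, not just the registered remote ones. A stand-in sketch of that recursion (the `Node` class is a hypothetical substitute for `torch.nn.Module`, which exposes the same `children()` iteration):

```python
class Node:
    """Minimal stand-in for a torch.nn.Module with private attributes."""
    def __init__(self, *children):
        self.private_module = object()
        self.calibration_data = [1, 2, 3]
        self.private_q_module = object()
        self._children = list(children)

    def children(self):
        return self._children

def clear_private_info(module):
    # Null out the sensitive attributes on this module...
    for attr in ["private_module", "calibration_data", "private_q_module"]:
        if hasattr(module, attr):
            setattr(module, attr, None)
    # ...then recurse into every submodule, however deeply nested
    for child in module.children():
        clear_private_info(child)

leaf = Node()
root = Node(Node(leaf))
clear_private_info(root)
assert root.private_module is None
assert leaf.calibration_data is None
```

Because the recursion follows `children()`, any submodule missed by the old name-based loop (for example, layers nested inside wrappers) now also has its calibration data and private modules stripped before the client-facing `model.pth` is saved.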