
Add mistral fine-tuning and examples #395

Merged: 11 commits, Mar 11, 2024
Conversation


@saum7800 saum7800 commented Mar 9, 2024

Description

This PR contains three changes:

  1. Adds a QLoRA trainer that has been tested to QLoRA-fine-tune a Mistral 7B model on created datasets. It is implemented as a class with a train_model function, similar to the SFTTrainer class in Hugging Face.
  2. Adds in-context examples to the data transformation prompts. This is a minor addition to the existing prompts to improve transformation quality. The majority of the changes to the transformation process will be part of a PR by @ritugala in the coming days.
  3. Adds an examples directory with three smaller examples for: (a) creating transformed data, (b) creating synthetic data, and (c) QLoRA-fine-tuning a Mistral model from an existing dataset. This sets us up well to add future examples for model selection, hyperparameter optimization, full fine-tuning, etc.

@saum7800 saum7800 requested review from neubig and viswavi March 9, 2024 01:03
@neubig (Collaborator) left a comment:

Looks good but I had a few suggestions, mostly for clarity!

Resolved review threads:
- prompt2model/dataset_transformer/prompt_template.py
- prompt2model/model_trainer/peft_trainer.py
- prompt2model/utils/dataset_utils.py
@saum7800 saum7800 requested a review from neubig March 10, 2024 01:58
@saum7800 (Collaborator, author) commented:

I have resolved all the issues you mentioned. There was a linting issue that only showed up in GitHub Actions, not locally; locally, all pre-commit hooks pass. I have instructed flake8 to ignore the two error codes BLK100 (Black would have made changes) and W503 (line break occurred before binary operator). BLK100 is definitely fine to ignore, since we run Black separately anyway. W503 raised errors on many files that have not been modified.

.flake8 (outdated diff):
@@ -1,3 +1,5 @@
 [flake8]
 max-line-length = 88
 extend-ignore = E203,FI10,FI11,FI12,FI13,FI14,FI15,FI16,FI17,FI18
+per-file-ignores = prompt2model/dataset_transformer/prompt_template.py:E501
+ignore = BLK100, W503
Collaborator:

I think these two can actually be ignored globally, not just for this file.

BLK100 can be ignored because we are already running black in a separate action, and W503 seems to be commonly ignored by developers, since it conflicts with another warning in flake8, W504.

Author:

Yeah, makes sense. "ignore" would have ignored them globally too, but I can add them to extend-ignore directly instead of creating a new line for them.
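Under that suggestion, the resulting .flake8 might look like the following sketch (the E203 and FI* codes are copied from the existing config shown in the diff above; whether the per-file-ignores line is still needed depends on the final prompt line lengths):

```ini
[flake8]
max-line-length = 88
; BLK100 and W503 folded into extend-ignore rather than a separate ignore line,
; so flake8's default select/ignore sets are extended instead of replaced.
extend-ignore = E203,BLK100,W503,FI10,FI11,FI12,FI13,FI14,FI15,FI16,FI17,FI18
per-file-ignores = prompt2model/dataset_transformer/prompt_template.py:E501
```

Note that a bare `ignore =` replaces flake8's default ignore list, while `extend-ignore` appends to it, which is usually what is wanted here.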

@viswavi (Collaborator) left a comment:

At a high-level, this looks great! The core of this PR (the QLoRA trainer) looks awesome and I have no technical notes about that.

Stylistically, I think this PR needs some revision in terms of comment formatting and variable naming.

The formatting of comments in this PR doesn't match typical recommended Python style (which we should adhere to). Python's style guidelines suggest using inline comments sparingly, so put comments on their own line whenever possible (which probably applies to all of these comments). Also, comments should almost always be full, grammatical English sentences with proper capitalization and punctuation. Please ensure the comments follow these two guidelines before we merge.

Resolved review threads:
- examples/mistral_qlora_finetune_example.py
Comment on lines 34 to 35:

    plan_prompt_fn: Callable[
-       [str, list[dict], str], str
+       [str, list[dict], str, int], str
Collaborator:

With a complicated function interface like this, it would be good to explain what this function is (and why this particular interface is required) in the "Args:" section of this __init__ function's docstring.

Author:

Sure, did this. I also rearranged the variables to make more logical sense.
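The updated interface from the diff above can be sketched as follows. Only the type signature `Callable[[str, list[dict], str, int], str]` comes from the PR; the parameter names and prompt layout below are hypothetical illustrations:

```python
from typing import Callable

def example_plan_prompt(
    task_description: str,
    in_context_examples: list[dict],
    plan: str,
    num_transformed_points: int,
) -> str:
    """Build a planning prompt from a task description, examples, and a plan.

    All names here are illustrative; only the argument/return types
    match the Callable signature discussed in the review thread.
    """
    shots = "\n".join(str(ex) for ex in in_context_examples)
    return (
        f"Task: {task_description}\n"
        f"Plan: {plan}\n"
        f"Examples:\n{shots}\n"
        f"Transform {num_transformed_points} data points."
    )

# The function satisfies the annotated interface from the diff.
plan_prompt_fn: Callable[[str, list[dict], str, int], str] = example_plan_prompt
prompt = plan_prompt_fn("QA", [{"q": "1+1?", "a": "2"}], "answer briefly", 3)
```

Documenting a callable-typed argument this way in the "Args:" section (what each positional slot means, and why the extra `int` is needed) makes the interface much easier to implement against.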

prompt2model/model_trainer/qlora_trainer.py (resolved thread)

Snippet under discussion:

        self.model, os.path.join(output_dir, "qlora_model")
    )
    self.model = self.model.merge_and_unload()
    self.model.save_pretrained(os.path.join(output_dir, "final_model"))
Collaborator:

I'm wondering if the hard-coding of the model name here could cause problems, e.g. overwriting models.

It would make sense to at least store this as a variable (e.g. PEFT_MODEL_DIRECTORY = os.path.join(output_dir, "final_model")) for better visibility.

Author:

The idea with passing save_folder_path to the function was to provide a path so that nothing in the current directory gets overwritten, but I can also add the variable you mentioned for better visibility, plus a comment in the demo.
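One way to apply the reviewer's suggestion is to define the sub-directory names once instead of scattering os.path.join(...) literals through the trainer. This is a hypothetical helper, not the code merged in the PR; only the "qlora_model" and "final_model" names come from the snippet above:

```python
import os

# Single visible definition of the hard-coded output sub-directories
# (names taken from the snippet under review; the helper itself is a sketch).
QLORA_ADAPTER_SUBDIR = "qlora_model"
FINAL_MODEL_SUBDIR = "final_model"

def peft_model_directories(output_dir: str) -> tuple[str, str]:
    """Return (adapter_dir, merged_model_dir) rooted under output_dir.

    Keeping both paths derived from one output_dir argument preserves the
    original intent: callers choose a fresh save_folder_path, so nothing
    in the current working directory is overwritten.
    """
    return (
        os.path.join(output_dir, QLORA_ADAPTER_SUBDIR),
        os.path.join(output_dir, FINAL_MODEL_SUBDIR),
    )

adapter_dir, final_dir = peft_model_directories("trained_models")
```

With the paths named up front, the save calls at the end of training read as `self.model.save_pretrained(final_dir)`, and any future collision-avoidance logic (e.g. refusing to overwrite an existing directory) has one obvious place to live.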

Resolved review threads:
- prompt2model/model_trainer/qlora_trainer.py
- prompt2model/utils/dataset_utils.py
- prompt2model/utils/parse_responses.py
@saum7800 (Collaborator, author) commented:

All the proposed changes make sense, and I have made them. Please re-review when you get a chance, thanks!

@saum7800 saum7800 requested a review from viswavi March 11, 2024 00:18
@neubig (Collaborator) left a comment:
LGTM now, thanks for doing the revisions!

@viswavi (Collaborator) left a comment:

Tiny nit: we are variously using "QLora" and "QLoRA" in these changes; it might be good to stick to the latter for consistency. (The other places where you refer to this in lowercase, "qlora", are a different story, and I think those cases are totally fine.)

Otherwise, LGTM!

@saum7800 (Collaborator, author) commented:

Made it QLoRA in all comments. Thanks!

@saum7800 saum7800 merged commit 25e0a96 into main Mar 11, 2024
8 checks passed
@saum7800 saum7800 deleted the saumya_mistral_qlora branch March 11, 2024 03:38
3 participants