
Full DPO Distributed #2275

Open · sam-pi wants to merge 1 commit into main
Conversation

@sam-pi commented Jan 17, 2025

Context

Adapted from the great work in #1966

What is the purpose of this PR? Is it to

  • add a new feature

Please link to any issues this PR addresses: relates to #2082

Changelog

What are the changes made in this PR?

  • Adds a full DPO distributed training recipe and configs, adapted from the LoRA DPO recipe (the DPO objective both recipes optimize is sketched after this list)
  • Includes integration tests
  • Includes configs for Llama 3.1 8B and 70B models
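
For context, both recipes optimize the objective configured by loss.beta and loss.label_smoothing in the configs below. A minimal sketch of that objective (generic DPO math with illustrative tensor names, not torchtune's actual DPOLoss implementation):

```python
import torch
import torch.nn.functional as F

def dpo_loss(
    policy_chosen_logps: torch.Tensor,    # [batch] summed log-probs of chosen responses (policy)
    policy_rejected_logps: torch.Tensor,  # [batch] summed log-probs of rejected responses (policy)
    ref_chosen_logps: torch.Tensor,       # [batch] same quantities under the frozen reference model
    ref_rejected_logps: torch.Tensor,
    beta: float = 0.05,                   # loss.beta in the Full DPO config below
    label_smoothing: float = 0.0,         # loss.label_smoothing
) -> torch.Tensor:
    # Implicit reward margin: how much more the policy prefers the chosen response
    # over the rejected one, relative to the reference model.
    logits = (policy_chosen_logps - policy_rejected_logps) - (
        ref_chosen_logps - ref_rejected_logps
    )
    # Sigmoid DPO loss with optional label smoothing for noisy preference labels.
    losses = (
        -F.logsigmoid(beta * logits) * (1 - label_smoothing)
        - F.logsigmoid(-beta * logits) * label_smoothing
    )
    return losses.mean()
```

The main structural difference between the two recipes is where the ref_*_logps come from: the full recipe scores the preference pairs with a separately loaded, frozen copy of the base model, while the LoRA recipe reuses the frozen base weights by disabling the adapters.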

Test plan

Please make sure to do each of the following if applicable to your PR. If you're unsure about any of these, just ask and we will happily help. We also have a contributing page for guidance on contributing.

  • run pre-commit hooks and linters (make sure you've first installed them via pre-commit install)
  • add unit tests for any new functionality
  • update docstrings for any new or updated methods or classes
  • run unit tests via pytest tests
  • run recipe tests via pytest tests -m integration_test
  • manually run any new or modified recipes with sufficient proof of correctness
  • include relevant commands and any other artifacts in this summary (pastes of loss curves, eval results, etc.)

UX

If your function changed a public API, please add a dummy example of what the user experience will look like when calling it. Here is a docstring example and a tutorial example.

  • I did not change any public API

Commands and Sample Outputs

Full DPO Config

output_dir: .../Meta-Llama-3.1-8B-Instruct/full_dpo
model:
  _component_: torchtune.models.llama3_1.llama3_1_8b
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  path: .../Meta-Llama-3.1-8B-Instruct/original/tokenizer.model
  max_seq_len: 1024
checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: .../Meta-Llama-3.1-8B-Instruct
  checkpoint_files:
  - model-00001-of-00004.safetensors
  - model-00002-of-00004.safetensors
  - model-00003-of-00004.safetensors
  - model-00004-of-00004.safetensors
  recipe_checkpoint: null
  output_dir: ${output_dir}
  model_type: LLAMA3
resume_from_checkpoint: false
ref_checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: .../Meta-Llama-3.1-8B-Instruct
  checkpoint_files:
  - model-00001-of-00004.safetensors
  - model-00002-of-00004.safetensors
  - model-00003-of-00004.safetensors
  - model-00004-of-00004.safetensors
  recipe_checkpoint: null
  output_dir: ${output_dir}
  model_type: LLAMA3
dataset:
  _component_: torchtune.datasets.stack_exchange_paired_dataset
seed: null
shuffle: true
batch_size: 4
optimizer:
  _component_: torch.optim.AdamW
  fused: true
  weight_decay: 0.05
  lr: 1.0e-06
lr_scheduler:
  _component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup
  num_warmup_steps: 100
loss:
  _component_: torchtune.rlhf.loss.DPOLoss
  beta: 0.05
  label_smoothing: 0
epochs: 1
max_steps_per_epoch: 2000
gradient_accumulation_steps: 4
compile: false
metric_logger:
  _component_: torchtune.training.metric_logging.WandBLogger
  log_dir: ${output_dir}/logs
  project: torchtune
  name: llama3.1-8B-dpo_3605
log_every_n_steps: 1
log_peak_memory_stats: true
device: cuda
dtype: bf16
enable_activation_checkpointing: true
enable_activation_offloading: false
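
The ref_checkpointer section above is what distinguishes the full recipe: it loads a second, frozen copy of the base model that is only used to score the chosen/rejected sequences. A sketch of how those reference log-probs can be computed, assuming the model returns logits of shape [batch, seq_len, vocab] and that prompt/padding positions in the labels are marked with -100 (helper names are hypothetical, not the recipe's actual code):

```python
import torch
import torch.nn.functional as F

def sum_logps(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    # Shift so position t predicts token t+1, then sum log-probs over response tokens only.
    logits, labels = logits[:, :-1], labels[:, 1:]
    mask = labels != -100
    per_token = torch.gather(
        F.log_softmax(logits, dim=-1), 2, labels.clamp(min=0).unsqueeze(-1)
    ).squeeze(-1)
    return (per_token * mask).sum(dim=-1)

@torch.no_grad()
def reference_logps(ref_model, chosen_ids, rejected_ids, chosen_labels, rejected_labels):
    # ref_model: the frozen copy loaded via ref_checkpointer; it never receives gradients.
    ref_model.eval()
    return (
        sum_logps(ref_model(chosen_ids), chosen_labels),
        sum_logps(ref_model(rejected_ids), rejected_labels),
    )
```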

LoRA DPO Config

output_dir: .../Meta-Llama-3.1-8B-Instruct/lora_dpo
model:
  _component_: torchtune.models.llama3_1.lora_llama3_1_8b
  lora_attn_modules:
  - q_proj
  - v_proj
  - output_proj
  apply_lora_to_mlp: true
  apply_lora_to_output: false
  lora_rank: 256
  lora_alpha: 256
  lora_dropout: 0.0
tokenizer:
  _component_: torchtune.models.llama3.llama3_tokenizer
  path: .../Meta-Llama-3.1-8B-Instruct/original/tokenizer.model
  max_seq_len: 1024
checkpointer:
  _component_: torchtune.training.FullModelHFCheckpointer
  checkpoint_dir: .../Meta-Llama-3.1-8B-Instruct
  checkpoint_files:
  - model-00001-of-00004.safetensors
  - model-00002-of-00004.safetensors
  - model-00003-of-00004.safetensors
  - model-00004-of-00004.safetensors
  recipe_checkpoint: null
  output_dir: ${output_dir}
  model_type: LLAMA3
resume_from_checkpoint: false
save_adapter_weights_only: false
dataset:
  _component_: torchtune.datasets.stack_exchange_paired_dataset
seed: null
shuffle: true
batch_size: 4
optimizer:
  _component_: torch.optim.AdamW
  fused: true
  weight_decay: 0.05
  lr: 1.0e-05
lr_scheduler:
  _component_: torchtune.training.lr_schedulers.get_cosine_schedule_with_warmup
  num_warmup_steps: 100
loss:
  _component_: torchtune.rlhf.loss.DPOLoss
  beta: 0.1
  label_smoothing: 0
epochs: 1
max_steps_per_epoch: 100
gradient_accumulation_steps: 4
compile: false
metric_logger:
  _component_: torchtune.training.metric_logging.WandBLogger
  log_dir: ${output_dir}/logs
  project: torchtune
  name: llama3.1-8Blora-dpo_3603
log_every_n_steps: 1
log_peak_memory_stats: true
device: cuda
dtype: bf16
enable_activation_checkpointing: true
enable_activation_offloading: false
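
For comparison, the LoRA config needs no ref_checkpointer: only the adapters are trained and the base weights stay frozen, so the reference model can be recovered by temporarily disabling the adapters during the forward pass. A sketch of that pattern, reusing sum_logps from the full-DPO sketch above and taking the adapter-disabling context manager as an argument so no particular import path is assumed:

```python
import torch

@torch.no_grad()
def lora_reference_logps(
    model, chosen_ids, rejected_ids, chosen_labels, rejected_labels, disable_adapter
):
    # disable_adapter: a context manager that turns the LoRA adapters off for the
    # duration of the block, so the forward pass reduces to the frozen base model.
    with disable_adapter(model):
        return (
            sum_logps(model(chosen_ids), chosen_labels),
            sum_logps(model(rejected_ids), rejected_labels),
        )
```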
[Screenshot attached: 2025-01-16 at 12:39 PM]


pytorch-bot (bot) commented Jan 17, 2025

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/torchtune/2275

Note: Links to docs will display an error until the docs builds have been completed.

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@facebook-github-bot added the CLA Signed label on Jan 17, 2025 (this label is managed by the Facebook bot; authors need to sign the CLA before a PR can be reviewed).
@sam-pi (Author) commented Jan 17, 2025

@joecummings Please take a look and let me know if you have feedback!
