
feat: dataclass args for accelerated MoE tuning #390

Merged

Conversation

willmj (Collaborator) commented Nov 15, 2024

Description of the change

This PR adds one dataclass argument that enables accelerated MoE for sft_trainer.py via the new fms-acceleration accelerated-moe plugin, allowing accelerated MoE full fine-tuning with the --fast_moe flag. --fast_moe enables a technique that trains the experts of Mixture of Experts (MoE) models in parallel instead of sequentially.
With this flag, we expect a major speedup in train time and a decrease in memory usage for Mixture of Experts models.

| Framework Config | EP Degree (parameter) | Model | GPUs | Train Runtime | Speedup | Memory Usage | Memory Savings |
|---|---|---|---|---|---|---|---|
| none | N/A | granite 3b a800 | 1 | 2371.93 | base | 71199 | base |
| Scatter MoE | 1 | granite 3b a800 | 1 | 742.739 | 3.19 | 71187 | 1.0 |
| Scatter MoE + Padding Free | 1 | granite 3b a800 | 1 | 631.976 | 3.75 | 48401 | 0.68 |
| Scatter MoE + Padding Free + foak | 1 | granite 3b a800 | 1 | 615.453 | 3.85 | 42651 | 0.6 |
| none | N/A | mixtral 8x7b | 8 | 4180.95 | base | 65607 | base |
| Scatter MoE | 8 | mixtral 8x7b | 8 | 1071.2 | 3.9 | 52004.8 | 0.79 |
| Scatter MoE + Padding Free + foak | 8 | mixtral 8x7b | 8 | 1043.67 | 4.01 | 51961.2 | 0.79 |
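
(Speedup is the baseline train runtime divided by the configuration's runtime, e.g. 2371.93 / 742.739 ≈ 3.19; Memory Savings is the configuration's memory usage as a fraction of the baseline, e.g. 48401 / 71199 ≈ 0.68, so lower is better.)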

Related issue number

How to verify the PR

This PR is a work in progress; it requires more testing and the official release of fms-acceleration-moe.

  • To verify, run a tuning job with fast_moe (see the config sketch after this list).
  • Run a tuning job with other plugins added on top of fast_moe
  • Ensure that incorrect parameters result in failures
  • Ensure that non-MoE models cannot be trained with this plugin set
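
A minimal job config sketch for such a run, mirroring the JSON configs shown later in this thread (all paths are placeholders; only the fast_moe key is specific to this PR):

      {
          "model_name_or_path": "<path-to-moe-model>",
          "training_data_path": "<path-to-train-data.json>",
          "output_dir": "<output-dir>",
          "num_train_epochs": 1.0,
          "per_device_train_batch_size": 2,
          "learning_rate": 1e-5,
          "response_template": "\n### Response:",
          "dataset_text_field": "output",
          "fast_moe": 1
      }

Other acceleration options from the table above (for example Padding Free) can then be layered on top for the second bullet.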

Was the PR tested

  • I have added >=1 unit test(s) for every new method I have added.
  • I have ensured all unit tests pass


Thanks for making a pull request! 😃
One of the maintainers will review and advise on the next steps.

@github-actions github-actions bot added the feat label Nov 15, 2024
willmj (Collaborator, Author) commented Nov 21, 2024

Tested using the new flag on Granite 3.0 3b MoE; inference testing is up next.

Regular MoE tuning

Tested this branch without fast_moe:

      {
          "model_name_or_path": "/ibm_dmf_lakehouse/models/base_training/shared/granite-3.0-3b-a800m-base/r240924a",
          "training_data_path": "/testing/tuning/input/cc_tone_sft_format_1000_train.json",
          "output_dir": "/testing/tuning/output/granite-3b-moe/ft/20241120_1014-tone",
          "save_model_dir": "/testing/tuning/output/granite-3b-moe/ft/20241120_1014-tone/save_model",
          "num_train_epochs": 10.0,
          "per_device_train_batch_size": 2,
          "gradient_accumulation_steps": 1,
          "learning_rate": 1e-5,
          "response_template": "\n### Response:",
          "dataset_text_field": "output"
      }

Training logs:

{'loss': 0.8331, 'grad_norm': 364.0, 'learning_rate': 9e-06, 'epoch': 1.0}
{'loss': 0.4259, 'grad_norm': 0.10986328125, 'learning_rate': 8.000000000000001e-06, 'epoch': 2.0}
{'loss': 0.1667, 'grad_norm': 25.25, 'learning_rate': 7e-06, 'epoch': 3.0}
{'loss': 0.0304, 'grad_norm': 21.625, 'learning_rate': 6e-06, 'epoch': 4.0}
{'loss': 0.0023, 'grad_norm': 0.005828857421875, 'learning_rate': 5e-06, 'epoch': 5.0}
{'loss': 0.0004, 'grad_norm': 0.005157470703125, 'learning_rate': 4.000000000000001e-06, 'epoch': 6.0}
{'loss': 0.0001, 'grad_norm': 0.0038604736328125, 'learning_rate': 3e-06, 'epoch': 7.0}
{'loss': 0.0001, 'grad_norm': 0.000469207763671875, 'learning_rate': 2.0000000000000003e-06, 'epoch': 8.0}
{'loss': 0.0001, 'grad_norm': 0.004547119140625, 'learning_rate': 1.0000000000000002e-06, 'epoch': 9.0}
{'loss': 0.0001, 'grad_norm': 0.01324462890625, 'learning_rate': 0.0, 'epoch': 10.0}
{'train_runtime': 5311.528, 'train_samples_per_second': 1.883, 'train_steps_per_second': 0.941, 'train_loss': 0.1459229184500873, 'epoch': 10.0}

Location: /testing/tuning/output/granite-3b-moe/ft/20241121_1314-tone/save_model

Fast MoE

And with fast_moe:

      {
          "model_name_or_path": "/ibm_dmf_lakehouse/models/base_training/shared/granite-3.0-3b-a800m-base/r240924a",
          "training_data_path": "/testing/tuning/input/cc_tone_sft_format_1000_train.json",
          "output_dir": "/testing/tuning/output/granite-3b-moe/ft/20241121_1014-tone-FAST",
          "save_model_dir": "/testing/tuning/output/granite-3b-moe/ft/20241121_1014-tone-FAST/save_model",
          "num_train_epochs": 10.0,
          "per_device_train_batch_size": 2,
          "gradient_accumulation_steps": 1,
          "learning_rate": 1e-5,
          "response_template": "\n### Response:",
          "dataset_text_field": "output",
          "fast_moe": 1
      }

Training logs:

{'loss': 0.4279, 'grad_norm': 0.076171875, 'learning_rate': 8.000000000000001e-06, 'epoch': 2.0}
{'loss': 0.1377, 'grad_norm': 3.78125, 'learning_rate': 7e-06, 'epoch': 3.0}
{'loss': 0.0384, 'grad_norm': 0.81640625, 'learning_rate': 6e-06, 'epoch': 4.0}
{'loss': 0.0031, 'grad_norm': 0.003997802734375, 'learning_rate': 5e-06, 'epoch': 5.0}
{'loss': 0.0006, 'grad_norm': 0.002044677734375, 'learning_rate': 4.000000000000001e-06, 'epoch': 6.0}
{'loss': 0.0002, 'grad_norm': 0.0032196044921875, 'learning_rate': 3e-06, 'epoch': 7.0}
{'loss': 0.0001, 'grad_norm': 0.002288818359375, 'learning_rate': 2.0000000000000003e-06, 'epoch': 8.0}
{'loss': 0.0001, 'grad_norm': 0.0087890625, 'learning_rate': 1.0000000000000002e-06, 'epoch': 9.0}
{'loss': 0.0001, 'grad_norm': 0.0115966796875, 'learning_rate': 0.0, 'epoch': 10.0}
{'train_runtime': 2140.2943, 'train_samples_per_second': 4.672, 'train_steps_per_second': 2.336, 'train_loss': 0.14420232288464904, 'epoch': 10.0}

Location: /testing/tuning/output/granite-3b-moe/ft/20241121_1315-tone-FAST/save_model

Results

We see a 2.48x speedup in train runtime with fast_moe (5311.528 s / 2140.294 s ≈ 2.48).

Review thread on this hunk in tuning/config/acceleration_configs/utils.py:

    for f in fields(dataclass):
        nested_type = type_hints[f.name]
        values = getattr(dataclass, f.name)
        if values is not None and not is_dataclass(values):
            values = nested_type(*values)
            if isinstance(values, Iterable) and not isinstance(values, (str, bytes)):
Collaborator:

@willmj can you explain this fix? In this PR we are not doing anything much different than before, so why is this needed?

willmj (Collaborator, Author):

Since the only value in FastMoeConfig is an int, the previous implementation's attempt to unpack it caused this error:

Traceback (most recent call last):
  File "/home/tuning/.local/lib/python3.11/site-packages/tuning/sft_trainer.py", line 601, in main
    ) = parse_arguments(parser, job_config)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tuning/.local/lib/python3.11/site-packages/tuning/sft_trainer.py", line 535, in parse_arguments
    ) = parser.parse_dict(json_config, allow_extra_keys=True)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/home/tuning/.local/lib/python3.11/site-packages/transformers/hf_argparser.py", line 374, in parse_dict
    obj = dtype(**inputs)
          ^^^^^^^^^^^^^^^
  File "<string>", line 4, in __init__
  File "/home/tuning/.local/lib/python3.11/site-packages/tuning/config/acceleration_configs/fast_moe.py", line 36, in __post_init__
    ensure_nested_dataclasses_initialized(self)
  File "/home/tuning/.local/lib/python3.11/site-packages/tuning/config/acceleration_configs/utils.py", line 34, in ensure_nested_dataclasses_initialized
    values = nested_type(*values)
             ^^^^^^^^^^^^^^^^^^^^
TypeError: tuning.config.acceleration_configs.utils.parsable_dataclass.<locals>.ParsableDataclass() argument after * must be an iterable, not int

So this seemed like the best way to avoid that. If you have another solution though let me know.
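
For reference, a minimal sketch of the guarded initialization. The helper's module path and name come from the traceback above, but the body below is only an illustration of the guarded branch, not necessarily the merged code (the write-back via setattr is an assumption):

    from collections.abc import Iterable
    from dataclasses import fields, is_dataclass
    from typing import get_type_hints

    def ensure_nested_dataclasses_initialized(config):
        # Resolve each field's declared type so nested dataclasses can be
        # constructed from the raw values produced by argument parsing.
        type_hints = get_type_hints(type(config))
        for f in fields(config):
            nested_type = type_hints[f.name]
            values = getattr(config, f.name)
            if values is not None and not is_dataclass(values):
                if isinstance(values, Iterable) and not isinstance(values, (str, bytes)):
                    # a tuple/list of field values can be unpacked positionally
                    values = nested_type(*values)
                else:
                    # a scalar (e.g. the int parsed for --fast_moe) is passed
                    # as a single argument instead of being unpacked
                    values = nested_type(values)
                setattr(config, f.name, values)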

Collaborator:

@willmj thanks for the explanation. This fix seems reasonable, but I would suggest adding comments in the code to explain it.

Collaborator:

I'm wondering if it makes more sense to check whether the value is an int, so the else branch handles the more common case instead. You could also future-proof it by checking for int/float/bool.

fabianlim (Collaborator) commented Nov 22, 2024

@willmj In the original PR we reported benchmarks where the batch sizes were different, but the numbers you report here are in that ballpark.

cf. the benchmark numbers in the table from the original PR.

willmj (Collaborator, Author) commented Dec 9, 2024

After running checkpoint utils on the branch Fabian created for safetensors, vLLM inference ran as expected:

% grpcurl -plaintext -proto ./proto/generation.proto -d "{\"params\":{\"method\":\"GREEDY\", \"stopping\": {\"max_new_tokens\": 128}}, \"requests\": [{\"text\":\"### Text: @sho_help @showtime your arrive is terrible streaming is stop and start every couple mins. Get it together it's xmas\n\n### Label:\"}]}" localhost:8033 fmaas.GenerationService/Generate
{
  "responses": [
    {
      "generatedTokenCount": 128,
      "text": " sad, frustrated, anxious, anxious, frustrated, sad, anxious, anxious, frustrated, sad, frustrated, anxious, frustrated, sad, frustrated, anxious, sad, frustrated, anxious, sad, frustrated, anxious, sad, frustrated, anxious, sad, frustrated, anxious, sad, frustrated, anxious, sad, frustrated, anxious, sad, frustrated, anxious, sad,",
      "inputTokenCount": 38,
      "stopReason": "MAX_TOKENS"
    }
  ]
}

Post-processing completed with this script (thanks again Fabian!):

from fms_acceleration_moe.utils.checkpoint_utils import (
    get_state_dict_from_safe_checkpoint,
    recover_original_state_dict_from_checkpoint,
    save_single_safetensor,
)
from safetensors.torch import save_file
from transformers.utils import SAFE_WEIGHTS_NAME, CONFIG_NAME
import os, shutil, json

checkpoint_dir = '<scattermoe-checkpoint-dir>'

output_dir = '<output-dir>'
pretrained_model_name_or_path = '<original-model-dir>'

config_file = os.path.join(checkpoint_dir, CONFIG_NAME)
target_config_file = os.path.join(output_dir, CONFIG_NAME)
if os.path.exists(config_file):
    shutil.copyfile(config_file, target_config_file)

    if not pretrained_model_name_or_path:
        with open(target_config_file) as f:
            pretrained_model_name_or_path = json.load(f).get("_name_or_path")


# load the ScatterMoE-format state dict from the tuned checkpoint
sd = get_state_dict_from_safe_checkpoint(checkpoint_dir)

# map the expert weights back to the original model's parameter names
sd = recover_original_state_dict_from_checkpoint(
    sd, pretrained_model_name_or_path
)


# write the recovered state dict back out as a single safetensors file
save_single_safetensor(
    {k: v.contiguous() for k, v in sd.items()},
    output_dir,
    metadata={"format": "pt"},
)


from transformers import AutoModelForCausalLM

# test if we can load the converted state dict
model = AutoModelForCausalLM.from_pretrained(output_dir)
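
(The .contiguous() call and the {"format": "pt"} metadata are presumably needed because safetensors refuses to serialize non-contiguous tensors and transformers looks for the "pt" format tag when loading safetensors weights.)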

Fast MoE model saved in: /testing/tuning/output/granite-3b-moe/ft/20241121_1315-tone-FAST/save_model
Reconstructed state-dict model saved in: /testing/tuning/output/granite-3b-moe/ft/20241121_1315-tone-FAST/standard-sd

@willmj willmj marked this pull request as ready for review December 9, 2024 18:49
@willmj willmj requested a review from kmehant as a code owner December 9, 2024 18:49
@fabianlim fabianlim changed the title feat: [WIP] dataclass args for accelerated MoE tuning feat: dataclass args for accelerated MoE tuning Dec 10, 2024
@willmj willmj force-pushed the feat-dataclass-args-scattermoe branch from 443b6b5 to c746655 Compare January 2, 2025 19:47
Signed-off-by: Will Johnson <[email protected]>
@willmj willmj force-pushed the feat-dataclass-args-scattermoe branch from eda08da to bd7e2ad Compare January 3, 2025 16:52
Signed-off-by: Will Johnson <[email protected]>
anhuong (Collaborator) left a comment:

Thanks for adding the fast MoE plugin and flag Will! A few clarifying questions...

README.md:
    @@ -760,6 +762,9 @@ Notes:
    * Notes on Multipack
      - works only for *multi-gpu*.
      - currently only includes the version of *multipack* optimized for linear attention implementations like *flash-attn*.
    * Notes of Fast MOE
      - `--fast_moe` is an integer value that configures the amount of expert parallel sharding (ep_degree).
Collaborator:

Is there a recommended value to use for this? I see the default is 1, would 1 work to configure the amount of expert parallel sharding or does it need to be greater than 1?

willmj (Collaborator, Author):

From my understanding, 1 should work in every case and does configure the EP sharding. @fabianlim might be able to provide more insight.

Collaborator:

Also for my understanding, so this flag can be used with any MoE model and only for fine tuning (as described in fms-acceleration docs)?

Collaborator:

Is there a recommended value to use for this? I see the default is 1, would 1 work to configure the amount of expert parallel sharding or does it need to be greater than 1?

If you use 1, you will switch to the ScatterMoE kernels without sharding. If > 1, it will be the kernels plus sharding. In both cases gains will be observed; the ep in the names refers to the expert-parallel world size.

Also for my understanding, so this flag can be used with any MoE model and only for fine tuning (as described in fms-acceleration docs)?

No, sorry. Our MoE plugin requires a module swap, which means we need code that tells it how to swap the MoE module for any particular model. Currently two MoE models are supported with representative specs; for additional model support we need to add a new spec, but hopefully it's a copy-and-paste effort if the architectures are not very different.
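
To illustrate the > 1 case, a hypothetical config fragment for an 8-way expert-parallel run such as the Mixtral row in the benchmark table (key names mirror the configs earlier in this thread; launching with a matching number of processes/GPUs is assumed):

      {
          "model_name_or_path": "<path-to-mixtral-8x7b>",
          "training_data_path": "<path-to-train-data.json>",
          "output_dir": "<output-dir>",
          "fast_moe": 8
      }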

Other review threads on README.md, build/Dockerfile, and tests/acceleration/test_acceleration_framework.py were marked resolved.

willmj added 2 commits January 7, 2025 10:14
Signed-off-by: Will Johnson <[email protected]>
anhuong previously approved these changes Jan 7, 2025
anhuong (Collaborator) left a comment:

Small spelling error but otherwise the changes look good 👍


willmj (Collaborator, Author) commented Jan 7, 2025

Also for my understanding, so this flag can be used with any MoE model and only for fine tuning (as described in fms-acceleration docs)?

Yes that's correct! LoRA activation will be coming soon though.

@anhuong anhuong merged commit 8851227 into foundation-model-stack:main Jan 7, 2025
8 checks passed