[paddle] add metrics for llama-7b #278

Merged
16 commits merged on Oct 7, 2023
263 changes: 0 additions & 263 deletions training/benchmarks/llama1_7B/paddle/model/models/modeling_pp.py

This file was deleted.

38 changes: 16 additions & 22 deletions training/benchmarks/llama1_7B/paddle/run_pretraining.py
@@ -26,6 +26,7 @@
LinearAnnealingWithWarmupDecay,
LlamaConfig,
LlamaForCausalLM,
LlamaForCausalLMPipe,
register_sequence_parallel_allreduce_hooks,
)

@@ -39,7 +40,6 @@
from driver import Driver, Event, dist_paddle
from driver.config_manager import get_properties_from_config
from dataloaders.dataloader import create_pretrained_dataset, get_train_data_file
from model.models.modeling_pp import LlamaForCausalLMPipe
from train.trainer import PretrainingTrainer
from train.training_state import TrainingState

@@ -232,7 +232,7 @@ def main():
model_args, data_args, training_args = parser.parse_dict(
get_properties_from_config(config)
)
data_args.input_dir = gpt_driver.config.data_dir
data_args.input_dir = llama_driver.config.data_dir

if model_args.tokenizer_name_or_path is None:
model_args.tokenizer_name_or_path = model_args.model_name_or_path
@@ -331,24 +331,20 @@ def main():
model.recompute_enable()

# Create the learning rate scheduler and optimizer
if training_args.decay_steps is None:
training_args.decay_steps = training_args.max_steps
warmup_steps = training_args.warmup_ratio * training_args.max_steps

lr_scheduler = None
if training_args.lr_scheduler_type.value == "cosine":
lr_scheduler = CosineAnnealingWithWarmupDecay(
max_lr=training_args.learning_rate,
min_lr=training_args.min_learning_rate,
warmup_step=warmup_steps,
warmup_step=training_args.warmup_steps,
decay_step=training_args.decay_steps,
last_epoch=0,
)
elif training_args.lr_scheduler_type.value == "linear":
lr_scheduler = LinearAnnealingWithWarmupDecay(
max_lr=training_args.learning_rate,
min_lr=training_args.min_learning_rate,
warmup_step=warmup_steps,
warmup_step=training_args.warmup_steps,
decay_step=training_args.decay_steps,
last_epoch=0,
)
@@ -403,22 +399,20 @@ def main():
trainer.log_metrics("train", metrics)
trainer.save_metrics("train", metrics)
trainer.save_state()
training_state.raw_train_time = train_metrics["train_runtime"]
training_state.training_sequences_per_second = train_metrics[
"train_samples_per_second"
]
training_state.loss = train_metrics["train_loss"]
training_state.effective_tokens_per_second = total_effective_tokens / train_metrics["train_runtime"]
training_state.raw_train_time = metrics["train_runtime"]
training_state.training_sequences_per_second = metrics["train_samples_per_second"]
training_state.loss = metrics["train_loss"]
training_state.effective_tokens_per_second = total_effective_tokens / metrics["train_runtime"]
except:
training_state.end_training = False

# End Evaluation
dist_paddle.barrier()
eval_metrics = trainer.evaluate()
training_state.eval_loss = eval_metrics["eval_loss"]
training_state.eval_ppl = eval_metrics["eval_ppl"]
if eval_metrics["eval_ppl"] < config.target_ppl:
training_state.converged_success()
# dist_paddle.barrier()
# eval_metrics = trainer.evaluate()
# training_state.eval_loss = eval_metrics["eval_loss"]
# training_state.eval_ppl = eval_metrics["eval_ppl"]
# if eval_metrics["eval_ppl"] < config.target_ppl:
# training_state.converged_success()

return training_args, training_state, llama_driver

@@ -438,8 +432,8 @@ def main():
"training_sequences_per_second": state.training_sequences_per_second,
"effective_tokens_per_second": state.effective_tokens_per_second,
"converged": state.converged,
"final_loss": state.eval_loss,
"final_ppl": state.eval_ppl,
# "final_loss": state.eval_loss,
# "final_ppl": state.eval_ppl,
"raw_train_time": state.raw_train_time,
"init_time": state.init_time,
}
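For readers skimming the diff, the metric bookkeeping added above can be summarized as a small helper. This is a hedged sketch rather than code from the PR; `collect_train_metrics` and its return dict are illustrative, while the key names (`train_runtime`, `train_samples_per_second`, `train_loss`) are the ones read from the result of `trainer.train()` in the diff.

```python
# Sketch (not part of the PR): how the reported training-state fields are derived
# from the metrics dict returned by trainer.train().
def collect_train_metrics(metrics: dict, total_effective_tokens: int) -> dict:
    train_runtime = metrics["train_runtime"]  # wall-clock training time, seconds
    return {
        "raw_train_time": train_runtime,
        "training_sequences_per_second": metrics["train_samples_per_second"],
        "loss": metrics["train_loss"],
        # tokens actually consumed by training, divided by wall-clock time
        "effective_tokens_per_second": total_effective_tokens / train_runtime,
    }
```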
20 changes: 13 additions & 7 deletions training/nvidia/llama1_7B-paddle/README.md
@@ -31,21 +31,27 @@ wget https://bj.bcebos.com/paddlenlp/models/transformers/llama/data/llama_openwe
| Metric name | Metric value | Notes |
| -------------- | ------------------------------ | ------------------------------------------- |
| Task category | text classification, text generation | |
| Model | llama1_7B | |
| Model | llama1 | |
| Dataset | openwebtext | |
| Config file | config | |
| Data precision | precision, see "Performance metrics" | one of fp32/amp/fp16 |
| Hyperparameter changes | fix_hp, see "Performance metrics" | special hyperparameters needed to saturate the hardware when measuring throughput |
| Parallel strategy | parallel_strategy, see "Performance metrics" | DP, TP, PP, SP |
| Hardware device | nvidia A100 (80G *8) | |
| Hardware device | nvidia A100 (80G * 8) and (40G * 8) | |
| Memory usage | memory(actual/total), see "Performance metrics" | commonly called "device memory"; unit is GiB |
| Throughput | throughput, see "Performance metrics" | training throughput |

* Performance metrics

| Setup | config | precision | fix_hp | parallel_strategy | throughput | memory |
| ------- | ------- | --------- | ------ | ---------------- | ------------ | ------ |
| A100 single node, 8 GPUs (1x8) | config_TP1PP1SH2SP8A100x1x8 | fp16, level="O2" | per_device_bs=4, accumulate=32, (global bs = 2M tokens) | flash_attention=True, recompute=False, use_fused_rms_norm=True, sharding="stage2", sharding_degree=8 | 15.70715 * 2048 / 8 = 4021 tokens/s | 76.98 * 8 GB |
| A100 single node, 8 GPUs (1x8) | config_TP2PP1SH1SP4A100x1x8 | fp16, level="O2" | per_device_bs=4, accumulate=64, (global bs = 2M tokens) | flash_attention=True, recompute=False, use_fused_rms_norm=True, sharding="stage1", sharding_degree=4, tensor_parallel_degree=2 | 14.27326 * 2048 / 8 = 3653 tokens/s | 62.11 * 8 GB |
| A100 single node, 8 GPUs (1x8) | config_TP2PP1SH2SP4A100x1x8 | fp16, level="O2" | per_device_bs=4, accumulate=64, (global bs = 2M tokens) | flash_attention=True, recompute=False, use_fused_rms_norm=True, sharding="stage2", sharding_degree=4, tensor_parallel_degree=2 | 13.48227 * 2048 / 8 = 3451 tokens/s | 57.63 * 8 GB |
| A100 single node, 8 GPUs (1x8) | config_TP2PP4SH1SP1A100x1x8 | fp16, level="O2" | per_device_bs=4, accumulate=64, (global bs = 2M tokens) | flash_attention=True, recompute=False, use_fused_rms_norm=True, sharding="stage2", sharding_degree=4, tensor_parallel_degree=2 | 13.644565 * 2048 / 8 = 3493 tokens/s | 58.62\*2 + 53.51\*2 + 49.46\*2 + 47.95\*2 GB |
| ------- | ------- | --------- | ------ | ---------------- | ------------ | ------ |
| LLaMA-7B | ------- | --------- | ------ | ---------------- | ------------ | ------ |
| A100 single node, 8 GPUs (1x8*80G) | config_TP1PP1SH2SP8A10080Gx1x8 | fp16, level="O2" | per_device_bs=4, accumulate=64, (global bs = 4M tokens) | flash_attention=True, recompute=False, use_fused_rms_norm=True, sharding="stage2", sharding_degree=8 | 16.67 * 2048 / 8 = 4267 tokens/s | 70.09 * 8 GB |
| A100 single node, 8 GPUs (1x8*80G) | config_TP2PP1SH1SP4A10080Gx1x8 | fp16, level="O2" | per_device_bs=4, accumulate=128, (global bs = 4M tokens) | flash_attention=True, recompute=False, use_fused_rms_norm=True, sharding="stage1", sharding_degree=4, tensor_parallel_degree=2 | 15.19 * 2048 / 8 = 3888 tokens/s | 58.73 * 8 GB |
| A100 single node, 8 GPUs (1x8*80G) | config_TP2PP1SH2SP4A10080Gx1x8 | fp16, level="O2" | per_device_bs=4, accumulate=128, (global bs = 4M tokens) | flash_attention=True, recompute=False, use_fused_rms_norm=True, sharding="stage2", sharding_degree=4, tensor_parallel_degree=2 | 14.26 * 2048 / 8 = 3650 tokens/s | 54.01 * 8 GB |
| A100 single node, 8 GPUs (1x8*80G) | config_TP2PP4SH1SP1A10080Gx1x8 | fp16, level="O2" | per_device_bs=4, accumulate=512, (global bs = 4M tokens) | flash_attention=True, recompute=False, use_fused_rms_norm=True, sharding="stage1", tensor_parallel_degree=2, pipeline_parallel_degree=4 | 14.54 * 2048 / 8 = 3722 tokens/s | 46.80\*2 + 38.93\*2 + 31.74\*2 + 26.92\*2 GB |
| LLaMA-7B | ------- | --------- | ------ | ---------------- | ------------ | ------ |
| A100 single node, 8 GPUs (1x8*40G) | config_TP1PP1SH2SP8A10040Gx1x8 | fp16, level="O2" | per_device_bs=2, accumulate=128, (global bs = 4M tokens) | flash_attention=True, recompute=True, use_fused_rms_norm=False, sharding="stage2", sharding_degree=8 | 10.72 * 2048 / 8 = 2744 tokens/s | 33.55 * 8 GB |
| A100 single node, 8 GPUs (1x8*40G) | config_TP2PP1SH1SP4A10040Gx1x8 | fp16, level="O2" | per_device_bs=2, accumulate=256, (global bs = 4M tokens) | flash_attention=True, recompute=True, use_fused_rms_norm=False, sharding="stage1", sharding_degree=4, tensor_parallel_degree=2 | 8.45 * 2048 / 8 = 2163 tokens/s | 28.4 * 8 GB |
| A100 single node, 8 GPUs (1x8*40G) | config_TP2PP1SH2SP4A10040Gx1x8 | fp16, level="O2" | per_device_bs=2, accumulate=256, (global bs = 4M tokens) | flash_attention=True, recompute=True, use_fused_rms_norm=False, sharding="stage2", sharding_degree=4, tensor_parallel_degree=2 | 8.44 * 2048 / 8 = 2160 tokens/s | 25.8 * 8 GB |
| A100 single node, 8 GPUs (1x8*40G) | config_TP2PP4SH1SP1A10040Gx1x8 | fp16, level="O2" | per_device_bs=2, accumulate=1024, (global bs = 4M tokens) | flash_attention=True, recompute=True, use_fused_rms_norm=False, sharding="stage1", tensor_parallel_degree=2, pipeline_parallel_degree=4 | 8.72 * 2048 / 8 = 2232 tokens/s | 20.41\*2 + 19.80\*2 + 19.41\*2 + 20.12\*2 GB |
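The throughput column in both tables follows the same arithmetic: the trainer's reported samples per second, multiplied by the 2048-token sequence length and divided across the 8 GPUs, gives per-GPU token throughput. A minimal sketch of that calculation (the helper name is illustrative, not from the repo):

```python
# Per-GPU token throughput as used in the tables above:
#   samples/s (from the trainer) * tokens per sample / number of GPUs.
def tokens_per_second_per_gpu(samples_per_second: float,
                              seq_length: int = 2048,
                              num_gpus: int = 8) -> float:
    return samples_per_second * seq_length / num_gpus

# Example: the first 80G row reports 16.67 samples/s, i.e.
# 16.67 * 2048 / 8 = 4267 tokens/s per GPU (table values are truncated).
print(int(tokens_per_second_per_gpu(16.67)))  # 4267
```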