From 262122b00d753c3ea1ad14dde956c601b60cf9ee Mon Sep 17 00:00:00 2001
From: akashc1 <43617927+akashc1@users.noreply.github.com>
Date: Fri, 10 Jan 2025 12:02:46 -0800
Subject: [PATCH] llama 3.1 has correct `max_seq_len` for all versions (#2203)

---
 torchtune/models/llama3_1/_model_builders.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/torchtune/models/llama3_1/_model_builders.py b/torchtune/models/llama3_1/_model_builders.py
index b6439b2eb2..f48ce580f5 100644
--- a/torchtune/models/llama3_1/_model_builders.py
+++ b/torchtune/models/llama3_1/_model_builders.py
@@ -73,7 +73,7 @@ def llama3_1_405b() -> TransformerDecoder:
         num_heads=128,
         num_kv_heads=8,
         embed_dim=16384,
-        max_seq_len=8192,
+        max_seq_len=131072,
         intermediate_dim=53248,
         attn_dropout=0.0,
         norm_eps=1e-5,
@@ -236,7 +236,7 @@ def lora_llama3_1_405b(
         num_heads=128,
         num_kv_heads=8,
         embed_dim=16384,
-        max_seq_len=8192,
+        max_seq_len=131072,
         intermediate_dim=53248,
         attn_dropout=0.0,
         norm_eps=1e-5,
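
Note for reviewers: 131072 (= 128 * 1024, i.e. the 128K-token context window Meta ships for all Llama 3.1 sizes) brings the full and LoRA 405B builders in line with the 8B and 70B builders, which already used this value. Below is a minimal sanity-check sketch, not part of the patch itself: it assumes only that `llama3_1_405b` and `lora_llama3_1_405b` are importable from `torchtune.models.llama3_1`, and it inspects the builders' source text rather than instantiating the 405B model, which would be impractical on most machines.

    # sanity_check_max_seq_len.py -- hypothetical helper, not included in this patch
    import inspect
    import re

    from torchtune.models.llama3_1 import llama3_1_405b, lora_llama3_1_405b

    EXPECTED = 131072  # 128 * 1024 tokens, the Llama 3.1 context window

    for builder in (llama3_1_405b, lora_llama3_1_405b):
        # Read the builder's source instead of building the (huge) 405B model.
        src = inspect.getsource(builder)
        found = re.search(r"max_seq_len=(\d+)", src)
        assert found, f"{builder.__name__}: no max_seq_len argument found"
        actual = int(found.group(1))
        assert actual == EXPECTED, (
            f"{builder.__name__}: max_seq_len={actual}, expected {EXPECTED}"
        )
        print(f"{builder.__name__}: max_seq_len={actual} OK")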