diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index e66011a19a7..42b2f686381 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -906,8 +906,8 @@ def __init__( # Validation is done in the model itself if num_kv_heads is None: # Order is important here. - for attr in ["num_key_value_heads", "num_key_value_heads", "n_head"]: - num_kv_heads = getattr(config, "num_attention_heads", None) + for attr in ["num_key_value_heads", "num_attention_heads", "n_head"]: + num_kv_heads = getattr(config, attr, None) if num_kv_heads is not None: break if num_kv_heads is None: