diff --git a/tests/compile/test_basic_correctness.py b/tests/compile/test_basic_correctness.py
index 77c56d91d0a8b..4ce5b76bb0036 100644
--- a/tests/compile/test_basic_correctness.py
+++ b/tests/compile/test_basic_correctness.py
@@ -14,9 +14,8 @@
     "model, model_args, pp_size, tp_size, attn_backend, method, fullgraph",
     [
         ("meta-llama/Llama-3.2-1B", [], 2, 2, "FLASH_ATTN", "generate", True),
-        ("nm-testing/Meta-Llama-3-8B-Instruct-W8A8-Dyn-Per-Token-2048-Samples",
-         ["--quantization", "compressed-tensors"
-          ], 1, 1, "FLASH_ATTN", "generate", True),
+        ("THUDM/chatglm3-6b",
+         ["--trust-remote-code"], 1, 1, "FLASH_ATTN", "generate", True),
         ("google/gemma-2-2b-it", [], 1, 2, "FLASHINFER", "generate", True),
         # TODO: add multi-modality test for llava
         ("llava-hf/llava-1.5-7b-hf", [], 2, 1, "FLASHINFER", "generate", False)
diff --git a/vllm/model_executor/models/chatglm.py b/vllm/model_executor/models/chatglm.py
index 8283975b9d8e2..6536801630da2 100644
--- a/vllm/model_executor/models/chatglm.py
+++ b/vllm/model_executor/models/chatglm.py
@@ -12,6 +12,7 @@
 from torch.nn import LayerNorm
 
 from vllm.attention import Attention, AttentionMetadata
+from vllm.compilation.decorators import support_torch_compile
 from vllm.config import CacheConfig, LoRAConfig, MultiModalConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.inputs import INPUT_REGISTRY, DecoderOnlyInputs, InputContext
@@ -472,6 +473,7 @@ def forward(
         return hidden_states


+@support_torch_compile
 class ChatGLMModel(nn.Module):

     def __init__(