diff --git a/xtuner/tools/mmbench.py b/xtuner/tools/mmbench.py
index 133355b73..93fde3f15 100644
--- a/xtuner/tools/mmbench.py
+++ b/xtuner/tools/mmbench.py
@@ -23,7 +23,8 @@
 from torch.utils.data import Dataset
 from transformers import (AutoModel, AutoModelForCausalLM, AutoTokenizer,
                           BitsAndBytesConfig, CLIPImageProcessor,
-                          CLIPVisionModel, GenerationConfig)
+                          CLIPVisionModel, GenerationConfig,
+                          SiglipImageProcessor, SiglipVisionModel)
 
 from xtuner.dataset.utils import decode_base64_to_image, expand2square
 from xtuner.model.utils import LoadWoInit, prepare_inputs_labels_for_multimodal
@@ -330,10 +331,21 @@ def main():
             'Please specify the `--visual-encoder`!')
         visual_encoder_path = args.visual_encoder
     with LoadWoInit():
-        visual_encoder = CLIPVisionModel.from_pretrained(
-            visual_encoder_path, torch_dtype=TORCH_DTYPE_MAP[args.torch_dtype])
-        image_processor = CLIPImageProcessor.from_pretrained(
-            visual_encoder_path)
+        if 'clip' in visual_encoder_path.lower():
+            visual_encoder = CLIPVisionModel.from_pretrained(
+                visual_encoder_path,
+                torch_dtype=TORCH_DTYPE_MAP[args.torch_dtype])
+            image_processor = CLIPImageProcessor.from_pretrained(
+                visual_encoder_path)
+        elif 'siglip' in visual_encoder_path.lower():
+            visual_encoder = SiglipVisionModel.from_pretrained(
+                visual_encoder_path,
+                torch_dtype=TORCH_DTYPE_MAP[args.torch_dtype])
+            image_processor = SiglipImageProcessor.from_pretrained(
+                visual_encoder_path)
+        else:
+            raise ValueError(
+                f'Unsupported visual encoder: {visual_encoder_path}')
     master_print(f'Load visual_encoder from {visual_encoder_path}')
 
     # load adapter
@@ -506,5 +518,4 @@ def main():
 
 
 if __name__ == '__main__':
-
     main()
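
For reference, the encoder-selection branching introduced above can be read as the following standalone helper. This is a minimal sketch, not part of the patch: the name load_visual_encoder and the example checkpoint paths are illustrative.

# Sketch only: mirrors the CLIP/SigLIP branching added in the hunk above.
# The helper name `load_visual_encoder` and the example paths are hypothetical.
import torch
from transformers import (CLIPImageProcessor, CLIPVisionModel,
                          SiglipImageProcessor, SiglipVisionModel)


def load_visual_encoder(visual_encoder_path, torch_dtype=torch.float16):
    """Pick the vision encoder / image processor pair from the path name."""
    if 'clip' in visual_encoder_path.lower():
        encoder_cls, processor_cls = CLIPVisionModel, CLIPImageProcessor
    elif 'siglip' in visual_encoder_path.lower():
        encoder_cls, processor_cls = SiglipVisionModel, SiglipImageProcessor
    else:
        raise ValueError(f'Unsupported visual encoder: {visual_encoder_path}')
    visual_encoder = encoder_cls.from_pretrained(
        visual_encoder_path, torch_dtype=torch_dtype)
    image_processor = processor_cls.from_pretrained(visual_encoder_path)
    return visual_encoder, image_processor

A path such as openai/clip-vit-large-patch14-336 takes the CLIP branch, while one containing "siglip" (e.g. google/siglip-so400m-patch14-384) takes the SigLIP branch; anything else raises a ValueError instead of silently loading the wrong processor.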