From a4d7b779a58f2f3abb1db28884df349aa6fdffdf Mon Sep 17 00:00:00 2001 From: sky <2953427313@qq.com> Date: Wed, 7 Aug 2024 13:54:56 +0800 Subject: [PATCH] fix v100 TORCH_TYPE --- web_demo_2.6.py | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) diff --git a/web_demo_2.6.py b/web_demo_2.6.py index 2bebc19d..ea58e9a2 100644 --- a/web_demo_2.6.py +++ b/web_demo_2.6.py @@ -33,6 +33,10 @@ device = args.device assert device in ['cuda', 'mps'] + +TORCH_TYPE = torch.bfloat16 if torch.cuda.is_available() and torch.cuda.get_device_capability()[ + 0] >= 8 else torch.float16 + # Load model model_path = 'openbmb/MiniCPM-V-2_6' if 'int4' in model_path: @@ -44,7 +48,7 @@ if args.multi_gpus: from accelerate import load_checkpoint_and_dispatch, init_empty_weights, infer_auto_device_map with init_empty_weights(): - model = AutoModel.from_pretrained(model_path, trust_remote_code=True, attn_implementation='sdpa', torch_dtype=torch.bfloat16) + model = AutoModel.from_pretrained(model_path, trust_remote_code=True, attn_implementation='sdpa', torch_dtype=TORCH_TYPE) device_map = infer_auto_device_map(model, max_memory={0: "10GB", 1: "10GB"}, no_split_module_classes=['SiglipVisionTransformer', 'Qwen2DecoderLayer']) device_id = device_map["llm.model.embed_tokens"] @@ -63,9 +67,9 @@ device_map["llm.model.layers.16"] = device_id2 #print(device_map) - model = load_checkpoint_and_dispatch(model, model_path, dtype=torch.bfloat16, device_map=device_map) + model = load_checkpoint_and_dispatch(model, model_path, dtype=TORCH_TYPE, device_map=device_map) else: - model = AutoModel.from_pretrained(model_path, trust_remote_code=True, torch_dtype=torch.bfloat16) + model = AutoModel.from_pretrained(model_path, trust_remote_code=True, torch_dtype=TORCH_TYPE) model = model.to(device=device) tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True) model.eval() @@ -554,4 +558,3 @@ def select_chat_type(_tab, _app_cfg): # launch demo.launch(share=False, debug=True, show_api=False, server_port=8885, server_name="0.0.0.0") -