Skip to content

Commit

Permalink
Specify ASCEND NPU for inference.
Browse files Browse the repository at this point in the history
  • Loading branch information
as12138 committed Dec 3, 2024
1 parent 1cd4b74 commit 8fe76b9
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 0 deletions.
9 changes: 9 additions & 0 deletions fastchat/serve/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
- Type "!!save <filename>" to save the conversation history to a json file.
- Type "!!load <filename>" to load a conversation history from a json file.
"""

import argparse
import os
import re
Expand Down Expand Up @@ -197,6 +198,14 @@ def main(args):
)
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
os.environ["XPU_VISIBLE_DEVICES"] = args.gpus
if len(args.gpus.split(",")) == 1:
try:
import torch_npu

torch.npu.set_device(int(args.gpus))
print(f"NPU is available, now model is running on npu:{args.gpus}")
except ModuleNotFoundError:
pass
if args.enable_exllama:
exllama_config = ExllamaConfig(
max_seq_len=args.exllama_max_seq_len,
Expand Down
9 changes: 9 additions & 0 deletions fastchat/serve/model_worker.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
A model worker that executes the model.
"""

import argparse
import base64
import gc
Expand Down Expand Up @@ -351,6 +352,14 @@ def create_model_worker():
f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!"
)
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
if len(args.gpus.split(",")) == 1:
try:
import torch_npu

torch.npu.set_device(int(args.gpus))
print(f"NPU is available, now model is running on npu:{args.gpus}")
except ModuleNotFoundError:
pass

gptq_config = GptqConfig(
ckpt=args.gptq_ckpt or args.model_path,
Expand Down
9 changes: 9 additions & 0 deletions fastchat/serve/multi_model_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
We recommend using this with multiple Peft models (with `peft` in the name)
where all Peft models are trained on the exact same base model.
"""

import argparse
import asyncio
import dataclasses
Expand Down Expand Up @@ -206,6 +207,14 @@ def create_multi_model_worker():
f"Larger --num-gpus ({args.num_gpus}) than --gpus {args.gpus}!"
)
os.environ["CUDA_VISIBLE_DEVICES"] = args.gpus
if len(args.gpus.split(",")) == 1:
try:
import torch_npu

torch.npu.set_device(int(args.gpus))
print(f"NPU is available, now model is running on npu:{args.gpus}")
except ModuleNotFoundError:
pass

gptq_config = GptqConfig(
ckpt=args.gptq_ckpt or args.model_path,
Expand Down

0 comments on commit 8fe76b9

Please sign in to comment.