From 2ffff717b17d2247a8e147e1e3065ba7de2e04c8 Mon Sep 17 00:00:00 2001
From: Benjamin Huo
Date: Sun, 30 Jul 2023 16:22:44 +0800
Subject: [PATCH] Add inference support for Macbook silicon chip

Signed-off-by: Benjamin Huo
---
 inference/README.md            | 2 ++
 inference/serve/gorilla_cli.py | 2 ++
 2 files changed, 4 insertions(+)

diff --git a/inference/README.md b/inference/README.md
index 9edbe83fe..10fc72d1e 100644
--- a/inference/README.md
+++ b/inference/README.md
@@ -56,6 +56,8 @@ For the falcon-7b model, you can use the following command:
 python3 serve/gorilla_falcon_cli.py --model-path path/to/gorilla-falcon-7b-hf-v0
 ```
 
+> Add "--device mps" if you are running on your Mac with Apple silicon (M1, M2, etc)
+
 ### [Optional] Batch Inference on a Prompt File
 
 After downloading the model, you need to make a jsonl file containing all the question you want to inference through Gorilla. Here is [one example](https://github.com/ShishirPatil/gorilla/blob/main/inference/example_questions/example_questions.jsonl):
diff --git a/inference/serve/gorilla_cli.py b/inference/serve/gorilla_cli.py
index 4938dfacc..9a0635098 100644
--- a/inference/serve/gorilla_cli.py
+++ b/inference/serve/gorilla_cli.py
@@ -67,6 +67,8 @@ def load_model(
                 }
             else:
                 kwargs["max_memory"] = {i: max_gpu_memory for i in range(num_gpus)}
+    elif device == "mps":
+        kwargs = {"torch_dtype": torch.float16}
     else:
         raise ValueError(f"Invalid device: {device}")
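
For context (not part of the patch): a minimal sketch, assuming PyTorch >= 1.12 built with MPS support, of how to confirm the MPS backend is usable before invoking gorilla_cli.py with --device mps. The tensor shape and printed messages are illustrative only.

import torch

if torch.backends.mps.is_available():
    device = torch.device("mps")
    # float16, matching the kwargs the patch sets in load_model() for the "mps" branch
    x = torch.ones(2, 2, dtype=torch.float16, device=device)
    print((x * 2).to("cpu"))  # the computation runs on the Apple GPU via Metal
else:
    # torch.backends.mps.is_built() distinguishes "PyTorch built without MPS"
    # from "no MPS-capable device found"
    print("MPS not available; run gorilla_cli.py with --device cpu instead")

If the check fails, fall back to the default CPU/GPU invocation documented in inference/README.md.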