Skip to content

Commit

Permalink
Fix ONNXRT example with upgraded optimum 1.14.0 (#1381)
Browse files Browse the repository at this point in the history
Signed-off-by: Mengni Wang <[email protected]>
Signed-off-by: yuwenzho <[email protected]>
  • Loading branch information
mengniwang95 authored Nov 13, 2023
1 parent f5167dc commit da3442d
Show file tree
Hide file tree
Showing 5 changed files with 97 additions and 18 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import argparse
import os
import subprocess
import optimum.version
from packaging.version import Version
OPTIMUM114_VERSION = Version("1.14.0")


def parse_arguments():
Expand All @@ -12,20 +15,37 @@ def parse_arguments():

def prepare_model(input_model, output_model):
    """Export a Hugging Face model to ONNX with the optimum CLI.

    Args:
        input_model: model id or local path passed to ``--model``.
        output_model: output directory for the exported ONNX files; its
            existence is asserted after the export.

    optimum >= 1.14.0 changed the default export layout, so ``--legacy`` is
    added there to keep the pre-1.14 (merged decoder) layout this example
    expects. NOTE(review): the scraped original ran the export twice — once
    unconditionally and once version-gated; this runs it exactly once by
    building a single command list.
    """
    print("\nexport model...")
    command = [
        "optimum-cli",
        "export",
        "onnx",
        "--model",
        f"{input_model}",
        "--task",
        "text-generation-with-past",
    ]
    # --legacy must come before the positional output path.
    if Version(optimum.version.__version__) >= OPTIMUM114_VERSION:
        command.append("--legacy")
    command.append(f"{output_model}")
    subprocess.run(
        command,
        stdout=subprocess.PIPE,
        text=True,
    )

    assert os.path.exists(output_model), f"{output_model} doesn't exist!"

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ pip install -r requirements.txt
## 2. Prepare Model

```bash
optimum-cli export onnx --model decapoda-research/llama-7b-hf --task text-generation-with-past ./llama_7b
python prepare_model.py --input_model="decapoda-research/llama-7b-hf" --output_model="./llama_7b"
```

# Run
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
import argparse
import os
import subprocess
import optimum.version
from packaging.version import Version
OPTIMUM114_VERSION = Version("1.14.0")


def parse_arguments():
    """Parse the script's command-line options.

    Returns:
        argparse.Namespace with ``input_model`` (optional, defaults to "")
        and ``output_model`` (required).
    """
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument("--input_model", type=str, required=False, default="")
    arg_parser.add_argument("--output_model", type=str, required=True)
    return arg_parser.parse_args()


def prepare_model(input_model, output_model):
    """Export a Hugging Face model to ONNX via the optimum CLI.

    Args:
        input_model: model id or local path forwarded to ``--model``.
        output_model: destination path for the exported ONNX model; its
            existence is asserted once the CLI returns.

    On optimum >= 1.14.0 the ``--legacy`` flag is inserted so the export
    keeps the layout this example was written against.
    """
    print("\nexport model...")
    export_cmd = [
        "optimum-cli",
        "export",
        "onnx",
        "--model",
        f"{input_model}",
        "--task",
        "text-generation-with-past",
    ]
    # Flag goes before the positional output argument.
    if Version(optimum.version.__version__) >= OPTIMUM114_VERSION:
        export_cmd.append("--legacy")
    export_cmd.append(f"{output_model}")
    subprocess.run(export_cmd, stdout=subprocess.PIPE, text=True)

    assert os.path.exists(output_model), f"{output_model} doesn't exist!"


if __name__ == "__main__":
    # Script entry point: parse CLI options and run the ONNX export.
    args = parse_arguments()
    prepare_model(args.input_model, args.output_model)
4 changes: 3 additions & 1 deletion test/adaptor/onnxrt_adaptor/test_weight_only_adaptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,9 @@ def __iter__(self):
class TestWeightOnlyAdaptor(unittest.TestCase):
@classmethod
def setUpClass(self):
cmd = "optimum-cli export onnx --model hf-internal-testing/tiny-random-gptj --task text-generation gptj/"
cmd = (
"optimum-cli export onnx --model hf-internal-testing/tiny-random-gptj --task text-generation --legacy gptj/"
)
p = subprocess.Popen(
cmd, preexec_fn=os.setsid, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True
) # nosec
Expand Down
6 changes: 4 additions & 2 deletions test/model/test_onnx_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,9 @@ def setUp(self):
model = onnx.helper.make_model(graph, **{"opset_imports": [onnx.helper.make_opsetid("", 14)]})
self.matmul_reshape_model = model

cmd = "optimum-cli export onnx --model hf-internal-testing/tiny-random-gptj --task text-generation gptj/"
cmd = (
"optimum-cli export onnx --model hf-internal-testing/tiny-random-gptj --task text-generation --legacy gptj/"
)
p = subprocess.Popen(
cmd, preexec_fn=os.setsid, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True
) # nosec
Expand All @@ -216,7 +218,7 @@ def test_hf_model(self):

config = AutoConfig.from_pretrained("hf_test")
sessions = ORTModelForCausalLM.load_model("hf_test/decoder_model.onnx")
model = ORTModelForCausalLM(sessions[0], config, "hf_test", use_cache=False, use_io_binding=False)
model = ORTModelForCausalLM(sessions, config, model_save_dir="hf_test", use_cache=False, use_io_binding=False)
self.assertNotEqual(model, None)

def test_nodes(self):
Expand Down

0 comments on commit da3442d

Please sign in to comment.