Skip to content

Commit

Permalink
Update onnx docs (#2561)
Browse files Browse the repository at this point in the history
  • Loading branch information
AndreSlavescu authored Aug 5, 2024
1 parent 54b2a95 commit b467d4a
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 4 deletions.
2 changes: 1 addition & 1 deletion docs/onnx-conversion.md
Original file line number Diff line number Diff line change
Expand Up @@ -328,7 +328,7 @@ First, we want to store the newly generated models in the ```~/.cache/anserini/e

```bash
cd src/main/python/onnx/models
cp splade-cocondenser-ensembledistil-optimized.onnx ~/.cache/anserini/encoders/
cp splade-cocondenser-ensembledistil-optimized.onnx splade-cocondenser-ensembledistil-vocab.txt ~/.cache/anserini/encoders/
```

Second, run the end-to-end regression, as described in the previously mentioned documentation, with the generated ONNX model.
11 changes: 9 additions & 2 deletions src/main/python/onnx/convert_hf_model_to_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def get_dynamic_axes(input_names, output_names):
dynamic_axes[name] = {0: 'batch_size', 1: 'sequence'}
return dynamic_axes

def convert_model_to_onnx(text, model, tokenizer, onnx_path, device):
def convert_model_to_onnx(text, model, tokenizer, onnx_path, vocab_path, device):
print(model) # this prints the model structure for better understanding (optional)
model.eval()

Expand Down Expand Up @@ -70,6 +70,12 @@ def convert_model_to_onnx(text, model, tokenizer, onnx_path, device):
onnx.checker.check_model(onnx_model)
print("ONNX model checked successfully")

vocab = tokenizer.get_vocab()
with open(vocab_path, 'w', encoding='utf-8') as f:
for token, index in sorted(vocab.items(), key=lambda x: x[1]):
f.write(f"{token}\n")
print(f"Vocabulary saved to {vocab_path}")

# small inference session for testing
ort_session = onnxruntime.InferenceSession(onnx_path)
ort_inputs = {k: v.cpu().numpy() for k, v in test_input.items()}
Expand All @@ -89,5 +95,6 @@ def convert_model_to_onnx(text, model, tokenizer, onnx_path, device):

os.makedirs("models", exist_ok=True)
onnx_path = f"models/{model_prefix}.onnx"
vocab_path = f"models/{model_prefix}-vocab.txt"

convert_model_to_onnx(args.text, model, tokenizer, onnx_path, device=device)
convert_model_to_onnx(args.text, model, tokenizer, onnx_path, vocab_path, device=device)
2 changes: 1 addition & 1 deletion src/main/python/onnx/optimize_onnx_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def optimize_onnx_model(model_path, print_stats=False):
# optimized_model.convert_float_to_float16()

# Save the optimized model
model_name = model_path.split(".")[0]
model_name = model_path.rsplit(".onnx", 1)[0]
optimized_model_path = f'{model_name}-optimized.onnx'
optimized_model.save_model_to_file(optimized_model_path)
print(f"ONNX model optimization successful. Saved to {optimized_model_path}")
Expand Down

0 comments on commit b467d4a

Please sign in to comment.