
Commit

Add customized tokenizer.
Mansterteddy committed Oct 17, 2021
1 parent 790080f commit d441673
Showing 6 changed files with 450 additions and 1 deletion.
5 changes: 4 additions & 1 deletion .gitignore
@@ -1 +1,4 @@
-*.onnx
+*.onnx
+*.txt
+*.bin
+*.dll
177 changes: 177 additions & 0 deletions onnxruntime/tokenizer/baseline.py
@@ -0,0 +1,177 @@
import os
import json
import ctypes
import numpy as np
from blingfire import load_model, free_model
from numpy.ctypeslib import ndpointer

blingfire_path = "./blingfiretokdll.dll"
blingfire_model = "./data/xlm_roberta_base.bin"
vocab_path = "./data/vocab.txt"
max_doc_count = 96
max_seq_length = 256
max_query_length = 16
max_title_length = 32
max_url_length = 32

h = load_model(blingfire_model)

print("Load Bling Fire Tokenizer")

# Put this script's directory on PATH so Windows can resolve the DLL's dependencies.
dir_path = os.path.dirname(os.path.realpath(__file__))
os.environ["PATH"] = dir_path + ';' + os.environ["PATH"]
ranklm_lib = ctypes.CDLL("./RankLMTokenization.dll")

ranklm_init = ranklm_lib.RankLMTokenization_SentencePiece_FBV_Init
ranklm_init.argtypes = [ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_int, ctypes.c_char_p, ctypes.c_char_p]
ranklm_init.restype = None

ranklm_id_tokenize = ranklm_lib.RankLMTokenization_SentencePiece_FBV_ID_Tokenize
ranklm_id_tokenize.argtypes = [ctypes.c_char_p, ndpointer(ctypes.c_int),
                               ctypes.c_char_p, ndpointer(ctypes.c_int),
                               ctypes.c_char_p, ndpointer(ctypes.c_int),
                               ctypes.c_char_p, ndpointer(ctypes.c_int),
                               ctypes.c_char_p, ndpointer(ctypes.c_int),
                               ctypes.c_char_p, ndpointer(ctypes.c_int),
                               ctypes.c_int, ndpointer(ctypes.c_int32), ndpointer(ctypes.c_int32), ndpointer(ctypes.c_int32)]
ranklm_id_tokenize.restype = None

ranklm_token_tokenize = ranklm_lib.RankLMTokenization_SentencePiece_FBV_Token_Tokenize
ranklm_token_tokenize.argtypes = [ctypes.c_char_p, ndpointer(ctypes.c_int),
                                  ctypes.c_char_p, ndpointer(ctypes.c_int),
                                  ctypes.c_char_p, ndpointer(ctypes.c_int),
                                  ctypes.c_char_p, ndpointer(ctypes.c_int),
                                  ctypes.c_char_p, ndpointer(ctypes.c_int),
                                  ctypes.c_char_p, ndpointer(ctypes.c_int),
                                  ctypes.c_int, ndpointer(ctypes.c_int32), ndpointer(ctypes.c_int32), ndpointer(ctypes.c_int32)]
ranklm_token_tokenize.restype = None

ranklm_tokenize = ranklm_lib.RankLMTokenization_SentencePiece_FBV_Tokenize
ranklm_tokenize.argtypes = [ctypes.c_void_p, ctypes.c_char_p, ndpointer(ctypes.c_int),
                            ctypes.c_char_p, ndpointer(ctypes.c_int),
                            ctypes.c_char_p, ndpointer(ctypes.c_int),
                            ctypes.c_char_p, ndpointer(ctypes.c_int),
                            ctypes.c_char_p, ndpointer(ctypes.c_int),
                            ctypes.c_char_p, ndpointer(ctypes.c_int),
                            ctypes.c_int, ndpointer(ctypes.c_int32), ndpointer(ctypes.c_int32), ndpointer(ctypes.c_int32)]
ranklm_tokenize.restype = None

ranklm_fb_tokenize = ranklm_lib.RankLMTokenization_SentencePiece_FBV_FB_Tokenize
ranklm_fb_tokenize.argtypes = [ctypes.c_void_p, ctypes.c_char_p, ndpointer(ctypes.c_int),
                               ctypes.c_char_p, ndpointer(ctypes.c_int),
                               ctypes.c_char_p, ndpointer(ctypes.c_int),
                               ctypes.c_char_p, ndpointer(ctypes.c_int),
                               ctypes.c_char_p, ndpointer(ctypes.c_int),
                               ctypes.c_char_p, ndpointer(ctypes.c_int),
                               ctypes.c_char_p, ndpointer(ctypes.c_int),
                               ctypes.c_int, ndpointer(ctypes.c_int32), ndpointer(ctypes.c_int32), ndpointer(ctypes.c_int32)]
ranklm_fb_tokenize.restype = None

# One-time initialization: length limits plus the Bling Fire model and vocab paths.
ranklm_init(max_doc_count, max_seq_length, max_query_length, max_title_length, max_url_length,
            ctypes.c_char_p(blingfire_path.encode("utf-8")), ctypes.c_char_p(vocab_path.encode("utf-8")))

def get_lang_dist_from_market(market):
    # Split a market code such as "en-us" into language and district,
    # defaulting both to "un" when the code has no district part.
    lang_dist = market.split('-')
    if len(lang_dist) >= 2:
        language = "-".join(lang_dist[:-1])
        district = lang_dist[-1]
    else:
        language = "un"
        district = "un"

    return language, district

def get_lang_dist(market, market_json):
    # Prefer the explicit Language/Region fields in market_json; fall back to
    # the market code when the JSON is absent or either field is blank.
    if ("Language" in market_json) and ("Region" in market_json):
        lang_dist_dict = json.loads(market_json)
        language = lang_dist_dict["Language"].lower().strip()
        district = lang_dist_dict["Region"].lower().strip()

        if language == "" or district == "":
            language, district = get_lang_dist_from_market(market)

    else:
        language, district = get_lang_dist_from_market(market)

    return language, district


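# Sample batch. Each instance is [query, snippet, url, title, market, extra];
# the text fields are pre-tokenized id strings, and the trailing field (empty
# here) is both appended to the snippet and parsed as market JSON below.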
input_list = [["23314 454 7560 85 5 3958 32 188131 454 11627 1369", "153 115 13761 3245 30128 21393 6 3958 6 33957 2011 126707 13820 18 75813 121046 6957 1284 18 46667 225006 153 24 33416 6 78175 111202 20179 95 39884 13639 425 16684 23314 194602 78403 2011 124999 153 196423 31 9607 363 36398 96335 68828 9351 45 10763 6635 7026 8834 73395 1230 82678 74", "106 25037 92 6 2566 3114 64 9271 41793 92", "48498 100 71 77463 26249 36049 141496 159201 41 1294 22970 144", "fr-fr", ""], ["11493 5 337 67847", "305 13312 6650 20 351 1507 1202 337 67847 337 67847 11493 123 3177", "78600 30535 113543 81384 64 10248 64 864 910 2507 169 3742 6 7693", "337 67847 11493 123 3177 20 337 67847 35399", "en-id", ""], ["6 8709 71684 1128 56963 9594", "378 122369 268 6 8709 71684 1128 4035 9056 11541 64632 37106 46879 2490 9839 5873 5 1210 37151 153 28292 194546 56963 18617 143964 9594 15 6 192141 10134 2846 1388 6 167039 8709 71684 1128 106000 194546 240762 6995 1173 35645 684 109052 5873 15 6 20212 10134 2846 1388 6 71729 38", "82414 496 9365 65451", "6 8709 71684 1128 14455 9065 9 12865 68818 1764", "zh-tw", ""]]

query_list = b""
snippet_list = b""
url_list = b""
title_list = b""
lang_list = b""
dist_list = b""

query_lengths = []
snippet_lengths = []
url_lengths = []
title_lengths = []
lang_lengths = []
dist_lengths = []

for instance in input_list:

    query = instance[0].strip()
    snippet = instance[1].strip() + " " + instance[5].strip()
    url = instance[2].strip()
    title = instance[3].strip()
    market = instance[4].lower()
    language, district = get_lang_dist(market, instance[-1])

    query_encode = query.encode("utf-8")
    snippet_encode = snippet.encode("utf-8")
    url_encode = url.encode("utf-8")
    title_encode = title.encode("utf-8")
    lang_encode = language.encode("utf-8")
    dist_encode = district.encode("utf-8")

    query_list += query_encode
    snippet_list += snippet_encode
    url_list += url_encode
    title_list += title_encode
    lang_list += lang_encode
    dist_list += dist_encode

    query_lengths.append(len(query_encode))
    snippet_lengths.append(len(snippet_encode))
    url_lengths.append(len(url_encode))
    title_lengths.append(len(title_encode))
    lang_lengths.append(len(lang_encode))
    dist_lengths.append(len(dist_encode))
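
# The DLL consumes each field as one concatenated UTF-8 buffer plus a parallel
# int32 array of per-instance byte lengths; results are written into the
# preallocated (batch_size, max_seq_length) int32 arrays below.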

p_query_list = ctypes.c_char_p(query_list)
p_snippet_list = ctypes.c_char_p(snippet_list)
p_url_list = ctypes.c_char_p(url_list)
p_title_list = ctypes.c_char_p(title_list)
p_lang_list = ctypes.c_char_p(lang_list)
p_dist_list = ctypes.c_char_p(dist_list)

p_query_lengths = np.array(query_lengths, dtype="int32")
p_snippet_lengths = np.array(snippet_lengths, dtype="int32")
p_url_lengths = np.array(url_lengths, dtype="int32")
p_title_lengths = np.array(title_lengths, dtype="int32")
p_lang_lengths = np.array(lang_lengths, dtype="int32")
p_dist_lengths = np.array(dist_lengths, dtype="int32")

batch_size = len(query_lengths)

input_ids = np.zeros((batch_size, max_seq_length), dtype="int32")
segment_ids = np.zeros((batch_size, max_seq_length), dtype="int32")
input_mask = np.zeros((batch_size, max_seq_length), dtype="int32")

ranklm_id_tokenize(p_query_list, p_query_lengths,
                   p_snippet_list, p_snippet_lengths,
                   p_url_list, p_url_lengths,
                   p_title_list, p_title_lengths,
                   p_lang_list, p_lang_lengths,
                   p_dist_list, p_dist_lengths,
                   batch_size, input_ids, segment_ids, input_mask)


print("input_ids: ", input_ids)
print("segment_ids: ", segment_ids)
print("input_mask: ", input_mask)

free_model(h)
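
For reference, a minimal sketch (not part of this commit) of feeding these buffers to the scoring model; the input and output names and the int64 casts are assumptions taken from export_whole_graph.py below:

import onnxruntime as ort

# Assumes ranklm.onnx consumes the three tokenizer outputs under these names.
sess = ort.InferenceSession("ranklm.onnx")
score, logits = sess.run(None, {"input_ids": input_ids.astype("int64"),
                                "segment_ids": segment_ids.astype("int64"),
                                "input_mask": input_mask.astype("int64")})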
13 changes: 13 additions & 0 deletions onnxruntime/tokenizer/export_ranklmtokenizer.py
@@ -0,0 +1,13 @@
import onnx
import onnxruntime as ort
from onnx import helper, onnx_pb as onnx_proto
from onnxruntime_extensions import make_onnx_model

nodes = []
nodes.append(helper.make_node("RankLMTokenizer", ["input"], ["output"], domain='ai.onnx.contrib'))

input = helper.make_tensor_value_info("input", onnx_proto.TensorProto.STRING, [None, None])
output = helper.make_tensor_value_info("output", onnx_proto.TensorProto.INT64, [None, None])
graph = helper.make_graph(nodes, "RankLM", [input], [output])
model = make_onnx_model(graph)
onnx.save(model, "RankLMToken.onnx")
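
A usage sketch for the exported graph, assuming the RankLMTokenizer kernel is compiled into the onnxruntime-extensions custom-op library (string tensors are passed to ONNX Runtime as numpy object arrays):

import numpy as np
import onnxruntime as ort
from onnxruntime_extensions import get_library_path

# Register the custom-op library so the session can resolve ai.onnx.contrib ops.
so = ort.SessionOptions()
so.register_custom_ops_library(get_library_path())
sess = ort.InferenceSession("RankLMToken.onnx", so)
token_ids = sess.run(None, {"input": np.array([["some query text"]], dtype=object)})[0]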
34 changes: 34 additions & 0 deletions onnxruntime/tokenizer/export_raw_tokenizer.py
@@ -0,0 +1,34 @@
import onnx
import onnxruntime as ort
from onnx import helper, onnx_pb as onnx_proto
from onnxruntime_extensions import make_onnx_model

nodes = []
nodes.append(helper.make_node("QueryNormalize", ["query"], ["NormalizedQuery"], domain='ai.onnx.contrib'))
nodes.append(helper.make_node("TitleNormalize", ["title"], ["NormalizedTitle"], domain='ai.onnx.contrib'))
nodes.append(helper.make_node("SnippetNormalize", ["snippet"], ["NormalizedSnippet"], domain='ai.onnx.contrib'))
nodes.append(helper.make_node("UrlNormalize", ["url"], ["NormalizedUrl"], domain='ai.onnx.contrib'))
nodes.append(helper.make_node("MarketNormalize", ["market"], ["NormalizedMarket"], domain='ai.onnx.contrib'))

nodes.append(helper.make_node("QueryTokenize", ["NormalizedQuery"], ["TokenizedQuery"], domain='ai.onnx.contrib'))
nodes.append(helper.make_node("TitleTokenize", ["NormalizedTitle"], ["TokenizedTitle"], domain='ai.onnx.contrib'))
nodes.append(helper.make_node("SnippetTokenize", ["NormalizedSnippet"], ["TokenizedSnippet"], domain='ai.onnx.contrib'))
nodes.append(helper.make_node("UrlTokenize", ["NormalizedUrl"], ["TokenizedUrl"], domain='ai.onnx.contrib'))
nodes.append(helper.make_node("MarketTokenize", ["NormalizedMarket"], ["TokenizedMarket"], domain='ai.onnx.contrib'))

nodes.append(helper.make_node("IdConcat", ["TokenizedQuery", "TokenizedTitle", "TokenizedSnippet", "TokenizedUrl", "TokenizedMarket"],
["input_ids", "segment_ids", "input_mask"], domain='ai.onnx.contrib'))

query = helper.make_tensor_value_info("query", onnx_proto.TensorProto.STRING, [None, None])
title = helper.make_tensor_value_info("title", onnx_proto.TensorProto.STRING, [None, None])
snippet = helper.make_tensor_value_info("snippet", onnx_proto.TensorProto.STRING, [None, None])
url = helper.make_tensor_value_info("url", onnx_proto.TensorProto.STRING, [None, None])
market = helper.make_tensor_value_info("market", onnx_proto.TensorProto.STRING, [None, None])

input_ids = helper.make_tensor_value_info("input_ids", onnx_proto.TensorProto.INT64, [None, None])
segment_ids = helper.make_tensor_value_info("segment_ids", onnx_proto.TensorProto.INT64, [None, None])
input_mask = helper.make_tensor_value_info("input_mask", onnx_proto.TensorProto.INT64, [None, None])

graph = helper.make_graph(nodes, "RankLM", [query, title, snippet, url, market], [input_ids, segment_ids, input_mask])
model = make_onnx_model(graph)
onnx.save(model, "ranklm_raw_tokenizer.onnx")
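
To sanity-check how the normalize, tokenize, and IdConcat stages are wired together, the composed graph can be printed back after saving:

import onnx

m = onnx.load("ranklm_raw_tokenizer.onnx")
print(onnx.helper.printable_graph(m.graph))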
31 changes: 31 additions & 0 deletions onnxruntime/tokenizer/export_whole_graph.py
@@ -0,0 +1,31 @@
import onnx
import onnxruntime as ort
from onnx import helper, onnx_pb as onnx_proto
from onnxruntime_extensions import make_onnx_model

nodes = []
nodes.append(helper.make_node("RankLMTokenizer", ["query", "title", "snippet", "url", "market"], ["input_ids", "segment_ids", "input_mask"], domain='ai.onnx.contrib'))

query = helper.make_tensor_value_info("query", onnx_proto.TensorProto.STRING, [None, None])
title = helper.make_tensor_value_info("title", onnx_proto.TensorProto.STRING, [None, None])
snippet = helper.make_tensor_value_info("snippet", onnx_proto.TensorProto.STRING, [None, None])
url = helper.make_tensor_value_info("url", onnx_proto.TensorProto.STRING, [None, None])
market = helper.make_tensor_value_info("market", onnx_proto.TensorProto.STRING, [None, None])

input_ids = helper.make_tensor_value_info("input_ids", onnx_proto.TensorProto.INT64, [None, None])
segment_ids = helper.make_tensor_value_info("segment_ids", onnx_proto.TensorProto.INT64, [None, None])
input_mask = helper.make_tensor_value_info("input_mask", onnx_proto.TensorProto.INT64, [None, None])

graph = helper.make_graph(nodes, "RankLM", [query, title, snippet, url, market], [input_ids, segment_ids, input_mask])
model = make_onnx_model(graph)
onnx.save(model, "ranklm_optim_tokenizer.onnx")

score = helper.make_tensor_value_info("score", onnx_proto.TensorProto.FLOAT, [None, None])
logits = helper.make_tensor_value_info("logits", onnx_proto.TensorProto.FLOAT, [None, None])

# Reuse the previously exported model's nodes, and carry its initializers over
# so the merged graph keeps the model weights.
prev_graph = onnx.load("ranklm.onnx").graph
nodes.extend(prev_graph.node)

graph = helper.make_graph(nodes, "RankLM", [query, snippet, url, title, market], [score, logits],
                          initializer=prev_graph.initializer)
model = make_onnx_model(graph)
onnx.save(model, "ranklm_whole_onnx.onnx")