From 12c605af45f0699a0aad2dc4fa1fe0a88fa0b955 Mon Sep 17 00:00:00 2001 From: kztao Date: Wed, 7 Sep 2022 14:44:14 +0800 Subject: [PATCH] Fix windows dtype bug of neural search (#3182) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Fix windows dtype bug of neural search * Fix windows dtype bug of neural search Co-authored-by: 吴高升 --- applications/neural_search/recall/simcse/inference.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/applications/neural_search/recall/simcse/inference.py b/applications/neural_search/recall/simcse/inference.py index 0e11c6ad65e4..097c348c736f 100644 --- a/applications/neural_search/recall/simcse/inference.py +++ b/applications/neural_search/recall/simcse/inference.py @@ -66,8 +66,10 @@ def convert_example(example, tokenizer, max_seq_length=512, do_evalute=False): max_seq_length=max_seq_length) batchify_fn = lambda samples, fn=Tuple( - Pad(axis=0, pad_val=tokenizer.pad_token_id), # text_input - Pad(axis=0, pad_val=tokenizer.pad_token_type_id), # text_segment + Pad(axis=0, pad_val=tokenizer.pad_token_id, dtype="int64" + ), # text_input + Pad(axis=0, pad_val=tokenizer.pad_token_type_id, dtype="int64" + ), # text_segment ): [data for data in fn(samples)] pretrained_model = AutoModel.from_pretrained(model_name_or_path)