Add batch prediction for pipelines #3432
Merged
@@ -12,9 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import List, Optional, Union
+from typing import List, Optional, Union, Tuple, Iterator
 import logging
 from pathlib import Path
+from tqdm import tqdm

 import paddle
 from paddlenlp.transformers import ErnieCrossEncoder, AutoTokenizer
@@ -44,6 +45,9 @@ def __init__(
         model_name_or_path: Union[str, Path],
         top_k: int = 10,
         use_gpu: bool = True,
+        max_seq_len: int = 256,
+        progress_bar: bool = True,
+        batch_size: int = 1000,
     ):
         """
         :param model_name_or_path: Directory of a saved model or the name of a public model e.g.
@@ -66,26 +70,13 @@ def __init__(
         self.transformer_model = ErnieCrossEncoder(model_name_or_path)
         self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
         self.transformer_model.eval()
+        self.progress_bar = progress_bar
+        self.batch_size = batch_size
+        self.max_seq_len = max_seq_len

         if len(self.devices) > 1:
             self.model = paddle.DataParallel(self.transformer_model)

-    def predict_batch(self,
-                      query_doc_list: List[dict],
-                      top_k: int = None,
-                      batch_size: int = None):
-        """
-        Use loaded Ranker model to, for a list of queries, rank each query's supplied list of Document.
-
-        Returns list of dictionary of query and list of document sorted by (desc.) similarity with query
-
-        :param query_doc_list: List of dictionaries containing queries with their retrieved documents
-        :param top_k: The maximum number of answers to return for each query
-        :param batch_size: Number of samples the model receives in one batch for inference
-        :return: List of dictionaries containing query and ranked list of Document
-        """
-        raise NotImplementedError
-
     def predict(self,
                 query: str,
                 documents: List[Document],
@@ -105,7 +96,7 @@ def predict(self,

         features = self.tokenizer([query for doc in documents],
                                   [doc.content for doc in documents],
-                                  max_seq_len=256,
+                                  max_seq_len=self.max_seq_len,
                                   pad_to_max_seq_len=True,
                                   truncation_strategy="longest_first")
@@ -125,6 +116,146 @@ def predict(self,
             reverse=True,
         )

-        # rank documents according to scores
+        # Rank documents according to scores
         sorted_documents = [doc for _, doc in sorted_scores_and_documents]
         return sorted_documents[:top_k]
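
As a quick illustration of the single-query path this hunk touches, a sketch under the assumption that Document exposes a content field (as used throughout this file) and reusing the hypothetical ranker from above:

    docs = [
        Document(content="ERNIE is a pre-trained language model."),
        Document(content="PaddleNLP ships a rich model zoo."),
    ]
    # Scores every document against one query and returns the top_k best
    top_docs = ranker.predict(query="What is ERNIE?", documents=docs, top_k=1)
    print(top_docs[0].content, top_docs[0].score)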
||
def predict_batch( | ||
self, | ||
queries: List[str], | ||
documents: Union[List[Document], List[List[Document]]], | ||
top_k: Optional[int] = None, | ||
batch_size: Optional[int] = None, | ||
) -> Union[List[Document], List[List[Document]]]: | ||
""" | ||
Use loaded ranker model to re-rank the supplied lists of Documents | ||
|
||
Returns lists of Documents sorted by (desc.) similarity with the corresponding queries. | ||
|
||
:param queries: Single query string or list of queries | ||
:param documents: Single list of Documents or list of lists of Documents to be reranked. | ||
:param top_k: The maximum number of documents to return per Document list. | ||
:param batch_size: Number of Documents to process at a time. | ||
""" | ||
+        if top_k is None:
+            top_k = self.top_k
+
+        if batch_size is None:
+            batch_size = self.batch_size
+
+        number_of_docs, all_queries, all_docs, single_list_of_docs = self._preprocess_batch_queries_and_docs(
+            queries=queries, documents=documents)
+        batches = self._get_batches(all_queries=all_queries,
+                                    all_docs=all_docs,
+                                    batch_size=batch_size)
+        pb = tqdm(total=len(all_docs),
+                  disable=not self.progress_bar,
+                  desc="Ranking")
+        preds = []
+        for cur_queries, cur_docs in batches:
+            features = self.tokenizer(cur_queries,
+                                      [doc.content for doc in cur_docs],
+                                      max_seq_len=self.max_seq_len,
+                                      pad_to_max_seq_len=True,
+                                      truncation_strategy="longest_first")
+
+            tensors = {k: paddle.to_tensor(v) for (k, v) in features.items()}
+
+            with paddle.no_grad():

Review comment: Is transformer_model set to eval mode here?
Reply: Yes, self.transformer_model.eval() is already called in __init__.

+                similarity_scores = self.transformer_model.matching(
+                    **tensors).numpy()
+            preds.extend(similarity_scores)
+
+            for doc, rank_score in zip(cur_docs, similarity_scores):
+                doc.rank_score = rank_score
+                doc.score = rank_score
+            pb.update(len(cur_docs))
+        pb.close()
+        if single_list_of_docs:
+            sorted_scores_and_documents = sorted(
+                zip(preds, documents),
+                key=lambda similarity_document_tuple: similarity_document_tuple[0],
+                reverse=True,
+            )
+            sorted_documents = [doc for _, doc in sorted_scores_and_documents]
+            return sorted_documents[:top_k]
+        else:
+            # Group the flat predictions back into one score list per query
+            grouped_predictions = []
+            left_idx = 0
+            right_idx = 0
+            for number in number_of_docs:
+                right_idx = left_idx + number
+                grouped_predictions.append(preds[left_idx:right_idx])
+                left_idx = right_idx
+            result = []
+            for pred_group, doc_group in zip(grouped_predictions, documents):
+                sorted_scores_and_documents = sorted(
+                    zip(pred_group, doc_group),
+                    key=lambda similarity_document_tuple: similarity_document_tuple[0],
+                    reverse=True,
+                )
+                sorted_documents = [doc for _, doc in sorted_scores_and_documents]
+                result.append(sorted_documents[:top_k])
+
+            return result
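
To make the two accepted shapes of documents concrete, a usage sketch continuing the hypothetical ranker and docs from above:

    # Case 1: a single flat list of Documents and exactly one query
    # -> returns one flat list of ranked Documents
    ranked = ranker.predict_batch(queries=["What is ERNIE?"], documents=docs)

    # Case 2: one list of Documents per query (a single query would be broadcast)
    # -> returns one ranked list per query
    ranked_lists = ranker.predict_batch(
        queries=["What is ERNIE?", "What is PaddleNLP?"],
        documents=[docs, docs],
        batch_size=32,
    )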
+    def _preprocess_batch_queries_and_docs(
+        self, queries: List[str],
+        documents: Union[List[Document], List[List[Document]]]
+    ) -> Tuple[List[int], List[str], List[Document], bool]:
+        number_of_docs = []
+        all_queries = []
+        all_docs: List[Document] = []
+        single_list_of_docs = False
+
+        # Docs case 1: single list of Documents -> rerank single list of Documents based on single query
+        if len(documents) > 0 and isinstance(documents[0], Document):
+            if len(queries) != 1:
+                raise Exception(
+                    "Number of queries must be 1 if a single list of Documents is provided.")
+            query = queries[0]
+            number_of_docs = [len(documents)]
+            all_queries = [query] * len(documents)
+            all_docs = documents  # type: ignore
+            single_list_of_docs = True
+
+        # Docs case 2: list of lists of Documents -> rerank each list of Documents based on corresponding query
+        # If queries contains a single query, apply it to each list of Documents
+        if len(documents) > 0 and isinstance(documents[0], list):
+            if len(queries) == 1:
+                queries = queries * len(documents)
+            if len(queries) != len(documents):
+                raise Exception(
+                    "Number of queries must be equal to number of provided Document lists.")
+            for query, cur_docs in zip(queries, documents):
+                if not isinstance(cur_docs, list):
+                    raise Exception(
+                        f"cur_docs was of type {type(cur_docs)}, but expected a list of Documents.")
+                number_of_docs.append(len(cur_docs))
+                all_queries.extend([query] * len(cur_docs))
+                all_docs.extend(cur_docs)
+
+        return number_of_docs, all_queries, all_docs, single_list_of_docs
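
A small worked example of the helper's contract, with d1, d2, d3 standing in for hypothetical Document instances (illustrative only, and note the method is private):

    num_docs, all_queries, all_docs, single = ranker._preprocess_batch_queries_and_docs(
        queries=["q1", "q2"], documents=[[d1, d2], [d3]])
    # num_docs == [2, 1]                 one entry per Document list
    # all_queries == ["q1", "q1", "q2"]  each query repeated per document
    # all_docs == [d1, d2, d3]           flattened for batched tokenization
    # single is False                    list-of-lists input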
+    @staticmethod
+    def _get_batches(
+        all_queries: List[str], all_docs: List[Document],
+        batch_size: Optional[int]
+    ) -> Iterator[Tuple[List[str], List[Document]]]:
+        if batch_size is None:
+            yield all_queries, all_docs
+            return
+        else:
+            for index in range(0, len(all_queries), batch_size):
+                yield (all_queries[index:index + batch_size],
+                       all_docs[index:index + batch_size])
Review comment: Can max_seq_len be passed in as a parameter here?
Reply: Fixed; it is now a constructor argument.