Add batch prediction for pipelines #3432

Merged · 3 commits · Oct 12, 2022
Changes from all commits
16 changes: 16 additions & 0 deletions pipelines/examples/semantic-search/semantic_search_example.py
@@ -209,6 +209,22 @@ def semantic_search_tutorial():
})

print_documents(prediction)
# Batch prediction
predictions = pipe.run_batch(queries=["亚马逊河流的介绍", "期货交易手续费指的是什么?"],
params={
"Retriever": {
"top_k": 50
},
"Ranker": {
"top_k": 5
}
})
for i in range(len(predictions['queries'])):
result = {
'documents': predictions['documents'][i],
'query': predictions['queries'][i]
}
print_documents(result)


if __name__ == "__main__":
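Note: run_batch() returns a single dict whose "queries" and "documents" entries are parallel lists, which is why the example above indexes them together. An equivalent iteration using zip (assuming the output layout shown above):

    for query, docs in zip(predictions["queries"], predictions["documents"]):
        print_documents({"query": query, "documents": docs})
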
31 changes: 24 additions & 7 deletions pipelines/pipelines/nodes/base.py
@@ -127,16 +127,33 @@ def _dispatch_run(self, **kwargs) -> Tuple[Dict, str]:
- collate `_debug` information if present
- merge component output with the preceding output and pass it on to the subsequent Component in the Pipeline
"""
return self._dispatch_run_general(self.run, **kwargs)

def _dispatch_run_batch(self, **kwargs):
"""
        The Pipeline calls this method when run_batch() is executed. It in turn calls the
        _dispatch_run_general() method with the correct run method.
"""
return self._dispatch_run_general(self.run_batch, **kwargs)

def _dispatch_run_general(self, run_method: Callable, **kwargs):
"""
This method takes care of the following:
- inspect run_method's signature to validate if all necessary arguments are available
        - pop `debug` and set it on the instance to control debug output
- call run_method with the corresponding arguments and gather output
- collate `_debug` information if present
- merge component output with the preceding output and pass it on to the subsequent Component in the Pipeline
"""
arguments = deepcopy(kwargs)
params = arguments.get("params") or {}

run_signature_args = inspect.signature(self.run).parameters.keys()
run_signature_args = inspect.signature(run_method).parameters.keys()

run_params: Dict[str, Any] = {}
for key, value in params.items():
if key == self.name: # targeted params for this node
if isinstance(value, dict):

# Extract debug attributes
if "debug" in value.keys():
self.debug = value.pop("debug")
@@ -156,19 +173,19 @@ def _dispatch_run(self, **kwargs) -> Tuple[Dict, str]:
if key in run_signature_args:
run_inputs[key] = value

output, stream = self.run(**run_inputs, **run_params)
output, stream = run_method(**run_inputs, **run_params)

# Collect debug information
debug_info = {}
if getattr(self, "debug", None):
# Include input
debug_info["input"] = {**run_inputs, **run_params}
debug_info["input"]["debug"] = self.debug
# Include output
# Include output, exclude _debug to avoid recursion
filtered_output = {
key: value
for key, value in output.items() if key != "_debug"
} # Exclude _debug to avoid recursion
}
debug_info["output"] = filtered_output
# Include custom debug info
custom_debug = output.get("_debug", {})
@@ -182,9 +199,9 @@ def _dispatch_run(self, **kwargs) -> Tuple[Dict, str]:
if all_debug:
output["_debug"] = all_debug

# add "extra" args that were not used by the node
# add "extra" args that were not used by the node, but not the 'inputs' value
for k, v in arguments.items():
if k not in output.keys():
if k not in output.keys() and k != "inputs":
output[k] = v

output["params"] = params
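For orientation, a minimal, self-contained sketch of the dispatch pattern this file now uses: _dispatch_run() and _dispatch_run_batch() both delegate to _dispatch_run_general(), which inspects the chosen run method's signature and forwards only the arguments it accepts. ToyComponent and its methods are illustrative stand-ins, not the library's API:

    import inspect
    from typing import Callable, Dict, Tuple

    class ToyComponent:

        def run(self, query: str) -> Tuple[Dict, str]:
            return {"documents": [f"docs for {query}"]}, "output_1"

        def run_batch(self, queries: list) -> Tuple[Dict, str]:
            return {"documents": [[f"docs for {q}"] for q in queries]}, "output_1"

        def _dispatch_run_general(self, run_method: Callable, **kwargs):
            # Forward only the arguments the target run method accepts.
            accepted = inspect.signature(run_method).parameters.keys()
            return run_method(**{k: v for k, v in kwargs.items() if k in accepted})

        def _dispatch_run(self, **kwargs):
            return self._dispatch_run_general(self.run, **kwargs)

        def _dispatch_run_batch(self, **kwargs):
            return self._dispatch_run_general(self.run_batch, **kwargs)

    output, edge = ToyComponent()._dispatch_run_batch(queries=["q1", "q2"], extra="ignored")
    # output == {"documents": [["docs for q1"], ["docs for q2"]]}, edge == "output_1"
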
26 changes: 24 additions & 2 deletions pipelines/pipelines/nodes/ranker/base.py
@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional
from typing import List, Optional, Union

import logging
from abc import abstractmethod
@@ -48,7 +48,7 @@ def predict_batch(self,
def run(self,
query: str,
documents: List[Document],
top_k: Optional[int] = None): # type: ignore
top_k: Optional[int] = None):
self.query_count += 1
if documents:
predict = self.timing(self.predict, "query_time")
@@ -62,6 +62,28 @@ def run(self,

return output, "output_1"

def run_batch(
self,
queries: List[str],
documents: Union[List[Document], List[List[Document]]],
top_k: Optional[int] = None,
batch_size: Optional[int] = None,
):
self.query_count += len(queries)
predict_batch = self.timing(self.predict_batch, "query_time")
results = predict_batch(queries=queries,
documents=documents,
top_k=top_k,
batch_size=batch_size)

for doc_list in results:
document_ids = [doc.id for doc in doc_list]
logger.debug("Ranked documents with IDs: %s", document_ids)

output = {"documents": results}

return output, "output_1"

def timing(self, fn, attr_name):
"""Wrapper method used to time functions."""

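A hedged usage sketch for the new run_batch() on a concrete ranker. The import paths and the model name are assumptions based on this diff and the package layout, not verified API:

    from pipelines.schema import Document           # assumed import path
    from pipelines.nodes.ranker import ErnieRanker  # assumed export

    # Placeholder model name; substitute a real cross-encoder checkpoint.
    ranker = ErnieRanker(model_name_or_path="rocketqa-zh-dureader-cross-encoder",
                         use_gpu=False)
    docs_per_query = [
        [Document(content="The Amazon is the largest river by discharge volume.")],
        [Document(content="A futures trading fee is charged by the broker per contract.")],
    ]
    output, edge = ranker.run_batch(queries=["Introduction to the Amazon river",
                                             "What is a futures trading fee?"],
                                    documents=docs_per_query,
                                    top_k=1)
    # output["documents"] is a list of ranked Document lists, one per query; edge == "output_1"
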
169 changes: 150 additions & 19 deletions pipelines/pipelines/nodes/ranker/ernie_ranker.py
@@ -12,9 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional, Union
from typing import List, Optional, Union, Tuple, Iterator
import logging
from pathlib import Path
from tqdm import tqdm

import paddle
from paddlenlp.transformers import ErnieCrossEncoder, AutoTokenizer
@@ -44,6 +44,9 @@ def __init__(
model_name_or_path: Union[str, Path],
top_k: int = 10,
use_gpu: bool = True,
max_seq_len: int = 256,
progress_bar: bool = True,
batch_size: int = 1000,
):
"""
:param model_name_or_path: Directory of a saved model or the name of a public model e.g.
@@ -66,26 +70,13 @@ def __init__(
self.transformer_model = ErnieCrossEncoder(model_name_or_path)
self.tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
self.transformer_model.eval()
self.progress_bar = progress_bar
self.batch_size = batch_size
self.max_seq_len = max_seq_len

if len(self.devices) > 1:
self.model = paddle.DataParallel(self.transformer_model)

def predict_batch(self,
query_doc_list: List[dict],
top_k: int = None,
batch_size: int = None):
"""
Use loaded Ranker model to, for a list of queries, rank each query's supplied list of Document.

Returns list of dictionary of query and list of document sorted by (desc.) similarity with query

:param query_doc_list: List of dictionaries containing queries with their retrieved documents
:param top_k: The maximum number of answers to return for each query
:param batch_size: Number of samples the model receives in one batch for inference
:return: List of dictionaries containing query and ranked list of Document
"""
raise NotImplementedError

def predict(self,
query: str,
documents: List[Document],
@@ -105,7 +96,7 @@ def predict(self,

features = self.tokenizer([query for doc in documents],
[doc.content for doc in documents],
max_seq_len=256,
max_seq_len=self.max_seq_len,
pad_to_max_seq_len=True,
truncation_strategy="longest_first")

@@ -125,6 +116,146 @@ def predict(self,
reverse=True,
)

# rank documents according to scores
# Rank documents according to scores
sorted_documents = [doc for _, doc in sorted_scores_and_documents]
return sorted_documents[:top_k]

def predict_batch(
self,
queries: List[str],
documents: Union[List[Document], List[List[Document]]],
top_k: Optional[int] = None,
batch_size: Optional[int] = None,
) -> Union[List[Document], List[List[Document]]]:
"""
        Use the loaded ranker model to re-rank the supplied lists of Documents.

Returns lists of Documents sorted by (desc.) similarity with the corresponding queries.

        :param queries: List of query strings
:param documents: Single list of Documents or list of lists of Documents to be reranked.
:param top_k: The maximum number of documents to return per Document list.
:param batch_size: Number of Documents to process at a time.
"""
if top_k is None:
top_k = self.top_k

if batch_size is None:
batch_size = self.batch_size

number_of_docs, all_queries, all_docs, single_list_of_docs = self._preprocess_batch_queries_and_docs(
queries=queries, documents=documents)
batches = self._get_batches(all_queries=all_queries,
all_docs=all_docs,
batch_size=batch_size)
pb = tqdm(total=len(all_docs),
disable=not self.progress_bar,
desc="Ranking")

preds = []
for cur_queries, cur_docs in batches:
features = self.tokenizer(cur_queries,
[doc.content for doc in cur_docs],
                                      max_seq_len=self.max_seq_len,
Collaborator: Can the max_seq_len here be passed in as a parameter?

Contributor Author: Fixed.
pad_to_max_seq_len=True,
truncation_strategy="longest_first")

tensors = {k: paddle.to_tensor(v) for (k, v) in features.items()}

with paddle.no_grad():
Collaborator: Is the transformer_model here set to eval mode?

Contributor Author: It is set to eval mode in ErnieRanker's __init__.
similarity_scores = self.transformer_model.matching(
**tensors).numpy()
preds.extend(similarity_scores)

for doc, rank_score in zip(cur_docs, similarity_scores):
doc.rank_score = rank_score
doc.score = rank_score
pb.update(len(cur_docs))
pb.close()
if single_list_of_docs:
sorted_scores_and_documents = sorted(
zip(preds, documents),
key=lambda similarity_document_tuple: similarity_document_tuple[
0],
reverse=True,
)
sorted_documents = [doc for _, doc in sorted_scores_and_documents]
return sorted_documents[:top_k]
else:
grouped_predictions = []
left_idx = 0
right_idx = 0
for number in number_of_docs:
right_idx = left_idx + number
                grouped_predictions.append(preds[left_idx:right_idx])
left_idx = right_idx
result = []
for pred_group, doc_group in zip(grouped_predictions, documents):
sorted_scores_and_documents = sorted(
zip(pred_group, doc_group),
key=lambda similarity_document_tuple:
similarity_document_tuple[0],
reverse=True,
)
sorted_documents = [
doc for _, doc in sorted_scores_and_documents
]
result.append(sorted_documents[:top_k])

return result

def _preprocess_batch_queries_and_docs(
self, queries: List[str], documents: Union[List[Document],
List[List[Document]]]
) -> Tuple[List[int], List[str], List[Document], bool]:
number_of_docs = []
all_queries = []
all_docs: List[Document] = []
single_list_of_docs = False

# Docs case 1: single list of Documents -> rerank single list of Documents based on single query
if len(documents) > 0 and isinstance(documents[0], Document):
if len(queries) != 1:
raise Exception(
"Number of queries must be 1 if a single list of Documents is provided."
)
query = queries[0]
number_of_docs = [len(documents)]
all_queries = [query] * len(documents)
all_docs = documents # type: ignore
single_list_of_docs = True

# Docs case 2: list of lists of Documents -> rerank each list of Documents based on corresponding query
# If queries contains a single query, apply it to each list of Documents
if len(documents) > 0 and isinstance(documents[0], list):
if len(queries) == 1:
queries = queries * len(documents)
if len(queries) != len(documents):
raise Exception(
"Number of queries must be equal to number of provided Document lists."
)
for query, cur_docs in zip(queries, documents):
if not isinstance(cur_docs, list):
raise Exception(
f"cur_docs was of type {type(cur_docs)}, but expected a list of Documents."
)
number_of_docs.append(len(cur_docs))
all_queries.extend([query] * len(cur_docs))
all_docs.extend(cur_docs)

return number_of_docs, all_queries, all_docs, single_list_of_docs

@staticmethod
def _get_batches(
all_queries: List[str], all_docs: List[Document],
batch_size: Optional[int]
) -> Iterator[Tuple[List[str], List[Document]]]:
if batch_size is None:
yield all_queries, all_docs
return
else:
            for index in range(0, len(all_queries), batch_size):
                yield (all_queries[index:index + batch_size],
                       all_docs[index:index + batch_size])
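To make the two accepted input shapes concrete, a standalone sketch of the flattening that _preprocess_batch_queries_and_docs performs (plain strings stand in for Document objects):

    def preprocess(queries, documents):
        # Case 1: one flat list of docs -> requires exactly one query, applied to all docs.
        if documents and not isinstance(documents[0], list):
            assert len(queries) == 1
            return [len(documents)], queries * len(documents), list(documents), True
        # Case 2: list of lists -> one query per list (a single query is broadcast).
        if len(queries) == 1:
            queries = queries * len(documents)
        assert len(queries) == len(documents)
        counts, flat_queries, flat_docs = [], [], []
        for query, docs in zip(queries, documents):
            counts.append(len(docs))
            flat_queries.extend([query] * len(docs))
            flat_docs.extend(docs)
        return counts, flat_queries, flat_docs, False

    print(preprocess(["q"], ["d1", "d2"]))
    # ([2], ['q', 'q'], ['d1', 'd2'], True)
    print(preprocess(["q1", "q2"], [["a"], ["b", "c"]]))
    # ([1, 2], ['q1', 'q2', 'q2'], ['a', 'b', 'c'], False)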