# pipeline.py
import torch
from transformers import pipeline as hf_pipeline
from flair.models import TARSTagger
from flair.data import Sentence


class ZeroShotNERPipeline:
    """Zero-shot named entity recognition built on Flair's TARS tagger."""

    def __init__(self, model="tars-ner"):
        self.model = TARSTagger.load(model)

    def __call__(self, query, labels):
        # Register the candidate labels as an ad-hoc NER task and switch to it.
        self.model.add_and_switch_to_new_task(
            "task", labels, label_type="ner", force_switch=True
        )
        # Accept either a single string or a list of strings.
        inputs = query if isinstance(query, list) else [query]
        sentences = [Sentence(s) for s in inputs]
        self.model.predict(sentences)
        return [sentence.to_tagged_string("ner") for sentence in sentences]
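
# Usage sketch for ZeroShotNERPipeline (illustrative only; the example sentence and
# label set are not from this repo, and the exact tagged-string format depends on the
# installed Flair version):
#
#   tagger = ZeroShotNERPipeline()
#   tagger("Jeff Dean works at Google in Mountain View.",
#          ["person", "organization", "location"])
#   # -> one tagged string per input, with the predicted spans marked by the labels above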


class SemanticSearchPipeline:
    """Rank context passages against queries via cosine similarity of mean-pooled embeddings."""

    def __init__(self, model, context=None, device="cpu"):
        self.model = model
        self.device = device
        self.fx = hf_pipeline("feature-extraction", model=model)
        self.context_emb = None
        # Precompute embeddings for the context, if given.
        if context is not None:
            self.context_emb = self.compute_embeddings(context)

    def __call__(self, query, context=None, temperature=0.01, return_probs=True):
        query_emb = self.compute_embeddings(query)
        if context is not None:
            self.context_emb = self.compute_embeddings(context)
        elif self.context_emb is None:
            raise ValueError("No context was given.")
        # Cosine similarity between each query and each context passage
        # (the embeddings are already L2-normalised).
        sim = torch.einsum("ij,kj->ik", query_emb, self.context_emb)
        if return_probs:
            # A low temperature sharpens the softmax over the context passages.
            return torch.softmax(sim / temperature, dim=-1)
        return sim

    def compute_embeddings(self, text):
        # The feature-extraction pipeline returns token embeddings per input
        # (nested lists of shape (1, seq_len, hidden)); mean-pool over the token
        # dimension and L2-normalise each resulting vector.
        if not isinstance(text, list):
            text = [text]
        features = [torch.tensor(f).to(self.device) for f in self.fx(text)]
        emb = torch.stack([f.squeeze(0).mean(dim=0) for f in features])
        return emb / torch.linalg.norm(emb, ord=2, dim=-1, keepdim=True)
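

# Minimal end-to-end sketch for SemanticSearchPipeline. The checkpoint name
# "sentence-transformers/all-MiniLM-L6-v2" and the example passages are illustrative
# assumptions, not part of the original pipeline; running this downloads the model.
if __name__ == "__main__":
    passages = [
        "The Eiffel Tower is in Paris.",
        "The Colosseum is in Rome.",
        "Mount Fuji is the highest mountain in Japan.",
    ]
    search = SemanticSearchPipeline(
        "sentence-transformers/all-MiniLM-L6-v2", context=passages
    )
    # Probabilities over the three passages; the first should score highest.
    print(search("Which city is the Eiffel Tower in?"))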