This repository has been archived by the owner on Nov 1, 2021. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 142
/
Copy path app.py
90 lines (73 loc) · 3.33 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
__copyright__ = "Copyright (c) 2021 Jina AI Limited. All rights reserved."
__license__ = "Apache-2.0"
import os
import sys
import click
import random
from jina import Flow, Document, DocumentArray
from jina.logging.predefined import default_logger as logger
# Default for the --num_docs CLI option: the most documents `index` will read.
# Overridable through the JINA_MAX_DOCS environment variable, which is parsed
# with int() at import time (a non-numeric value fails fast here).
MAX_DOCS = int(os.getenv('JINA_MAX_DOCS', '10000'))
def config(dataset: str):
    """Fill in the JINA_* environment variables this app relies on.

    Only missing variables are set; anything already present in the
    environment is left untouched.

    :param dataset: ``'toy'`` defaults ``JINA_DATA_FILE`` to
        ``data/toy-input.txt``, ``'full'`` defaults it to ``data/input.txt``;
        any other value leaves ``JINA_DATA_FILE`` alone.
    """
    default_data_file = (
        'data/toy-input.txt' if dataset == 'toy'
        else 'data/input.txt' if dataset == 'full'
        else None
    )
    if default_data_file is not None:
        os.environ.setdefault('JINA_DATA_FILE', default_data_file)

    os.environ.setdefault('JINA_PORT', '45678')

    # The workspace defaults to a `workspace` directory next to this file.
    here = os.path.dirname(os.path.abspath(__file__))
    workspace = os.environ.setdefault('JINA_WORKSPACE', os.path.join(here, 'workspace'))
    # "<host path>:/workspace/workspace" -- presumably a Docker-style volume
    # mount consumed by the flow config; TODO confirm against flows/flow.yml.
    os.environ.setdefault('JINA_WORKSPACE_MOUNT', f'{workspace}:/workspace/workspace')
def print_topk(resp, sentence):
    """Print the ranked matches for every document in a search response.

    :param resp: a response object exposing ``resp.data.docs``; each doc's
        ``matches`` must carry a ``'cosine'`` entry in ``scores`` and a ``text``.
    :param sentence: the original query text, echoed in the header line.
    """
    header = f"\n\n\nTa-Dah🔮, here's what we found for: {sentence}"
    for query_doc in resp.data.docs:
        print(header)
        for rank, hit in enumerate(query_doc.matches):
            cosine = hit.scores['cosine'].value
            print(f'> {rank:>2d}({cosine:.2f}). {hit.text}')
def input_generator(num_docs: int, file_path: str):
    """Yield up to ``num_docs`` randomly chosen lines of ``file_path`` as Documents.

    The whole file is read into memory and shuffled, so a call with
    ``num_docs`` smaller than the line count yields a random subset of lines
    (no line is yielded twice).

    :param num_docs: maximum number of documents to yield; ``<= 0`` yields nothing.
    :param file_path: path to a UTF-8 text file; each line becomes one document.
    :yields: one ``Document`` per selected line, with the trailing newline removed.
    """
    # Decode explicitly as UTF-8 so the result doesn't depend on the platform's
    # locale encoding, and strip the '\n' that file iteration keeps so the
    # indexed (and later printed) texts don't carry a stray line break.
    # The file is closed before the first document is yielded.
    with open(file_path, encoding='utf-8') as file:
        lines = [line.rstrip('\n') for line in file]
    random.shuffle(lines)
    # max(..., 0) keeps a negative num_docs from slicing from the end.
    for line in lines[:max(num_docs, 0)]:
        yield Document(text=line)
def index(num_docs):
    """Index up to ``num_docs`` lines of the configured data file.

    The Flow is built from ``flows/flow.yml`` (resolved against the current
    working directory). The data file comes from ``JINA_DATA_FILE`` (normally
    set by ``config``) and is joined onto this module's directory; an absolute
    path is used as-is by ``os.path.join``.
    """
    flow = Flow().load_config('flows/flow.yml')
    module_dir = os.path.dirname(__file__)
    # NOTE(review): if JINA_DATA_FILE is unset this join raises a TypeError.
    data_path = os.path.join(module_dir, os.environ.get('JINA_DATA_FILE', None))
    with flow:
        docs = input_generator(num_docs, data_path)
        flow.post(on='/index', inputs=docs, show_progress=True)
def query(top_k):
    """Prompt for one sentence, search it through the Flow, and print the matches.

    :param top_k: forwarded to the ``/search`` endpoint as the ``top_k``
        request parameter (presumably the number of matches to return;
        confirm in the executors' configuration).
    """
    flow = Flow().load_config('flows/flow.yml')
    with flow:
        sentence = input('Please type a sentence: ')
        request_docs = DocumentArray([Document(content=sentence)])
        responses = flow.post(
            on='/search',
            inputs=request_docs,
            parameters={'top_k': top_k},
            line_format='text',
            return_results=True,
        )
        # A single query document was sent, so only the first response is shown.
        print_topk(responses[0], sentence)
@click.command()
@click.option(
    '--task',
    '-t',
    type=click.Choice(['index', 'query'], case_sensitive=False),
    # Without a task the command used to exit 0 having done nothing; make
    # click report the missing option instead.
    required=True,
    help='Index the dataset, or run an interactive query against the index.',
)
@click.option('--num_docs', '-n', default=MAX_DOCS,
              help='Maximum number of documents to index.')
@click.option('--top_k', '-k', default=5,
              help='top_k request parameter sent with each search.')
@click.option('--dataset', '-d', type=click.Choice(['toy', 'full']), default='toy',
              help='Which bundled data file to index.')
def main(task, num_docs, top_k, dataset):
    """Index the configured dataset or query it interactively."""
    config(dataset)
    if task == 'index':
        workspace = os.environ.get('JINA_WORKSPACE')
        # Refuse to index into an existing workspace rather than mixing
        # old and new index data.
        if os.path.exists(workspace):
            logger.error(
                '\n +---------------------------------------------------------------------------------+ '
                '\n | 🤖🤖🤖 | '
                f'\n | The directory {workspace} already exists. Please remove it before indexing again. | '
                '\n | 🤖🤖🤖 | '
                '\n +---------------------------------------------------------------------------------+'
            )
            sys.exit(1)
        index(num_docs)
    elif task == 'query':
        query(top_k)
# Run the click CLI only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()