-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathpredict.py
46 lines (34 loc) · 1.15 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
import flair
import json
from flair.data import Sentence
from flair.models import SequenceTagger
def parse_json(path, model_path='final-model.pt', output_path='output.jsonl'):
    """Tag LOC entities in every tweet of a JSON Lines file and append the results.

    Each non-blank line of *path* is parsed as a JSON object; its ``"text"``
    value is tagged with the Flair ``SequenceTagger`` loaded from
    *model_path*, and every ``ner`` span whose label value is ``'LOC'`` is
    collected as ``{"text", "start_offset", "end_offset"}`` (offsets are the
    span's ``start_position`` / ``end_position`` as reported by Flair).
    One record per tweet, ``{"tweet_id": ..., "location_mentions": [...]}``,
    is appended to *output_path* as a JSON line. The input, the tagged
    sentence and the output record are printed for every tweet.

    Args:
        path: Input JSON Lines file, read as UTF-8. Each object must carry
            at least the keys ``"text"`` and ``"tweet_id"``.
        model_path: Path of the trained Flair model to load.
        output_path: JSON Lines file the results are appended to. It is
            opened in append mode, so re-running adds duplicate records.

    Raises:
        json.JSONDecodeError: If a non-blank input line is not valid JSON.
        KeyError: If a tweet lacks ``"text"`` or ``"tweet_id"``.
    """
    # Read the whole input and close the handle deterministically; the bare
    # ``open(...).read()`` form leaked the file descriptor.
    with open(path, encoding='utf-8') as infile:
        lines = infile.read().splitlines()

    model = SequenceTagger.load(model_path)

    # Open the output once rather than once per tweet. Append mode keeps the
    # behaviour of adding to whatever the file already contains.
    with open(output_path, 'a', encoding='utf-8') as outfile:
        for line in lines:
            # A blank line would make json.loads raise; skip it instead.
            if not line.strip():
                continue
            tweet = json.loads(line)
            print("Input")
            print(tweet)

            sentence = Sentence(tweet["text"])
            model.predict(sentence)
            print("Model Prediction")
            print(sentence.to_tagged_string())

            # Keep only location entities from the 'ner' layer.
            loc_list = [
                {
                    "text": entity.text,
                    "start_offset": entity.start_position,
                    "end_offset": entity.end_position,
                }
                for entity in sentence.get_spans('ner')
                if entity.get_label("ner").value == 'LOC'
            ]

            d = {
                "tweet_id": tweet["tweet_id"],
                "location_mentions": loc_list,
            }
            print("Output")
            print(d)
            outfile.write(json.dumps(d))
            outfile.write('\n')
# Script entry point: tag every tweet in the default input file.
# NOTE(review): this runs unconditionally, including on import; consider
# wrapping it in ``if __name__ == "__main__":`` if the module is ever imported.
parse_json(path="input.jsonl")