-
Notifications
You must be signed in to change notification settings - Fork 6
/
Copy pathcomponent_0_culture_relevance_classifier.py
75 lines (61 loc) · 2.31 KB
/
component_0_culture_relevance_classifier.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
import os
import time
import pandas as pd
from transformers import pipeline
import logging
from pipeline.pipeline_component import PipelineComponent
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
class CultureRelevanceClassifier(PipelineComponent):
    """Keep only the rows whose comment text is classified as culturally relevant.

    The component reads a CSV, runs a Hugging Face ``text-classification``
    pipeline over one text column, adds ``pred_label`` / ``pred_score``
    columns, and writes the rows predicted as ``"Yes"`` to the output CSV.

    Keys read from ``config[config_layer]``:
        model_name, device, classifier_path, input_file, output_file,
        field_name_with_comments, batch_size, and (optionally) max_length.
    Keys read from the top-level config:
        result_base_dir, dry_run.

    NOTE(review): ``PipelineComponent`` is defined elsewhere; this class
    assumes it stores the full config as ``self._config`` and provides
    ``check_if_output_exists`` -- confirm against the base class.
    """

    description = (
        "Classify sentences/comments and pick out the culturally-relevant ones"
    )
    config_layer = "0_culture_relevance_classifier"

    # Mapping from the numeric class id to the label written to the output.
    _ID2LABEL = {0: "No", 1: "Yes"}

    def __init__(self, config: dict):
        """Load this component's config section and build the classifier.

        Args:
            config: Full pipeline configuration; must contain
                ``result_base_dir`` and a section named ``config_layer``.
        """
        super().__init__(config)

        # Configuration section specific to this component.
        self._local_config = config[self.config_layer]

        # Make sure the result directory for this component exists.
        os.makedirs(
            os.path.join(config["result_base_dir"], self.config_layer), exist_ok=True
        )

        # Delegate the "output already exists" handling to the base class.
        if "output_file" in self._local_config:
            self.check_if_output_exists(self._local_config["output_file"])

        # Classifier settings.
        self._model_name = self._local_config["model_name"]
        self._device = self._local_config["device"]

        # Build the text-classification pipeline from the configured path.
        self.classifier = pipeline(
            "text-classification",
            model=self._local_config["classifier_path"],
            device=self._device,
        )

    def read_input(self) -> pd.DataFrame:
        """Read the input CSV, truncated to ``dry_run`` rows when that is set.

        Returns:
            The input rows as a DataFrame.
        """
        df = pd.read_csv(self._local_config["input_file"])
        # A missing "dry_run" key is treated the same as an explicit None.
        dry_run = self._config.get("dry_run")
        if dry_run is not None:
            df = df.head(dry_run)
        return df

    @classmethod
    def _to_yes_no(cls, raw_label) -> str:
        """Translate a classifier label into ``"Yes"`` or ``"No"``.

        Accepts labels that are already ``Yes``/``No`` (any casing) as well as
        labels whose last character is the class id, e.g. ``LABEL_1``.

        Raises:
            ValueError: If the label matches neither form.
        """
        text = str(raw_label).strip()
        for name in cls._ID2LABEL.values():
            if text.lower() == name.lower():
                return name
        try:
            return cls._ID2LABEL[int(text[-1])]
        except (IndexError, ValueError, KeyError):
            raise ValueError(
                f"Unrecognised classifier label: {raw_label!r}"
            ) from None

    def _resolve_max_length(self):
        """Return the token truncation length to pass to the classifier.

        Uses the ``max_length`` config key when present. Earlier versions
        passed ``batch_size`` as ``max_length``; that value is kept as a
        fallback so existing configs behave as before, with a warning.
        """
        max_length = self._local_config.get("max_length")
        if max_length is None:
            max_length = self._local_config["batch_size"]
            logger.warning(
                "'max_length' is not set for %s; falling back to batch_size (%s) "
                "as the truncation length",
                self.config_layer,
                max_length,
            )
        return max_length

    def run(self) -> None:
        """Classify every input row and save those predicted as ``"Yes"``."""
        df = self.read_input()
        logger.info(f"total number of samples: {len(df)}")

        column = self._local_config["field_name_with_comments"]
        # Every cell is stringified before it is handed to the classifier.
        texts = [str(value) for value in df[column].tolist()]

        if texts:
            predictions = self.classifier(
                texts,
                batch_size=self._local_config["batch_size"],
                truncation=True,
                max_length=self._resolve_max_length(),
            )
        else:
            # Nothing to classify; skip the model call.
            predictions = []

        df["pred_label"] = [self._to_yes_no(pred["label"]) for pred in predictions]
        df["pred_score"] = [round(pred["score"], 2) for pred in predictions]

        # Copy so the saved frame is independent of the full one.
        relevant = df[df["pred_label"] == "Yes"].copy()
        self.save_output(relevant)
        logger.info("Prediction Done!")

    def save_output(self, df):
        """Write ``df`` to the configured ``output_file`` as CSV, without the index."""
        output_path = self._local_config["output_file"]
        logger.info(f"save to {output_path}")
        df.to_csv(
            output_path,
            index=False,
        )