Skip to content

Commit

Permalink
Classifier loads labels from remote
Browse files Browse the repository at this point in the history
  • Loading branch information
Joffreybvn committed Feb 6, 2021
1 parent 6b0d122 commit 0cb6fbf
Show file tree
Hide file tree
Showing 4 changed files with 38 additions and 27 deletions.
15 changes: 11 additions & 4 deletions config.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,16 @@ class Config:
MODEL_MATCHING = "TF-IDF" # PolyFuzz lightest model - Optimized for matching
MODEL_CLASSIFIER = "bert-base-uncased" # HuggingFace smallest BERT model - For tokenization and classifying

# Weights to fine-tune the classifier
MODEL_WEIGHT_URL = "https://static.joffreybvn.be/file/joffreybvn/resachatbot/resa_BERT_model.pt"
MODEL_WEIGHT_LOCAL_COPY = "./assets/model/resa_BERT_model.pt"
# Remote files
b2_base_url = "https://static.joffreybvn.be/file/joffreybvn/resachatbot"

# External files
weight_file = "resa_BERT_model.pt"
MODEL_WEIGHT_URL = f"{b2_base_url}/{weight_file}" # Fine-tuned weights for BERT model
MODEL_WEIGHT_LOCAL_COPY = f"./assets/model/{weight_file}"

classes_file = "labels.pickle"
MODEL_CLASSES_URL = f"{b2_base_url}/{classes_file}"
MODEL_CLASSES_LOCAL_COPY = f"./assets/model/{classes_file}"

# Filters
FILTERS_TOML = "./filters.toml"
9 changes: 4 additions & 5 deletions filters.toml
Original file line number Diff line number Diff line change
@@ -1,16 +1,15 @@
# TOML document to store filters

[longtalk_hotel_reserve]
[longtalk_make_reservation]

# Size of the room: How many people ?
[longtalk_hotel_reserve.people]
[longtalk_make_reservation.people]
words = ["pearson", "people"]
regex = '''(?P<people>\d)\W%s'''
threshold = 0.85

# Duration of the book: How long ?
[longtalk_hotel_reserve.duration]
words = ["day"]
[longtalk_make_reservation.duration]
words = ["day", "night"]
regex = '''(?P<duration>\d)\W%s'''
threshold = 0.85

40 changes: 23 additions & 17 deletions src/classifying/classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,35 +16,41 @@ class Classifier:

def __init__(self):

# Load the classes
# Load the classes and the model
self.labels = self._load_labels()

# Download the model and instantiate it
self._download_model()
self.model = self._load_model()

@staticmethod
def _load_labels() -> dict:
"""Load the dictionary labels from a pickle file and return it."""

with open('./assets/data/labels.pickle', 'rb') as handle:
return pickle.load(handle)

@staticmethod
def _download_model():
"""
Stream and download the model from a given url to the given path.
"""
def __load_remote_file(url: str, local: str):

# Open the URL and a local file
with requests.get(config.MODEL_WEIGHT_URL, stream=True) as response:
with open(config.MODEL_WEIGHT_LOCAL_COPY, 'wb') as handle:
with requests.get(url, stream=True) as response:
with open(local, 'wb') as handle:

# Stream the model to the local file
for chunk in response.iter_content(chunk_size=8192):
handle.write(chunk)

def _load_labels(self) -> dict:
"""
Load the dictionary labels from a remote pickle file and return it.
"""

# Download and save the pickle locally
self.__load_remote_file(config.MODEL_CLASSES_URL, config.MODEL_CLASSES_LOCAL_COPY)

# Load and return a dictionary
with open(config.MODEL_CLASSES_LOCAL_COPY, 'rb') as handle:
return pickle.load(handle)

def _load_model(self) -> BertForSequenceClassification:
"""
Load the weight of the model from a remote file (around 500 Mo),
instantiate and return the model.
"""

# Download and save the weights locally
self.__load_remote_file(config.MODEL_WEIGHT_URL, config.MODEL_WEIGHT_LOCAL_COPY)

# Instantiate the model
model = BertForSequenceClassification.from_pretrained(
Expand Down
1 change: 0 additions & 1 deletion src/matching/matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ def __load_filters() -> dict:
def get_keywords(self, text: str, intent: str) -> dict:

keywords = {}

if intent in self.filters:

# Split the text into a list of words
Expand Down

0 comments on commit 0cb6fbf

Please sign in to comment.