From 0cb6fbf2efff2c88b8d494897896014239282db4 Mon Sep 17 00:00:00 2001 From: Joffrey Bienvenu Date: Sat, 6 Feb 2021 01:43:21 +0100 Subject: [PATCH] Classifier loads labels from remote --- config.py | 15 +++++++++---- filters.toml | 9 ++++---- src/classifying/classifier.py | 40 ++++++++++++++++++++--------------- src/matching/matcher.py | 1 - 4 files changed, 38 insertions(+), 27 deletions(-) diff --git a/config.py b/config.py index 7eee429..c94e46a 100644 --- a/config.py +++ b/config.py @@ -17,9 +17,16 @@ class Config: MODEL_MATCHING = "TF-IDF" # PolyFuzz lightest model - Optimized for matching MODEL_CLASSIFIER = "bert-base-uncased" # HuggingFace smallest BERT model - For tokenization and classifying - # Weights to fine-tune the classifier - MODEL_WEIGHT_URL = "https://static.joffreybvn.be/file/joffreybvn/resachatbot/resa_BERT_model.pt" - MODEL_WEIGHT_LOCAL_COPY = "./assets/model/resa_BERT_model.pt" + # Remote files + b2_base_url = "https://static.joffreybvn.be/file/joffreybvn/resachatbot" - # External files + weight_file = "resa_BERT_model.pt" + MODEL_WEIGHT_URL = f"{b2_base_url}/{weight_file}" # Fine-tuned weights for BERT model + MODEL_WEIGHT_LOCAL_COPY = f"./assets/model/{weight_file}" + + classes_file = "labels.pickle" + MODEL_CLASSES_URL = f"{b2_base_url}/{classes_file}" + MODEL_CLASSES_LOCAL_COPY = f"./assets/model/{classes_file}" + + # Filters FILTERS_TOML = "./filters.toml" diff --git a/filters.toml b/filters.toml index cf6ae65..3c02fa0 100644 --- a/filters.toml +++ b/filters.toml @@ -1,16 +1,15 @@ # TOML document to store filters -[longtalk_hotel_reserve] +[longtalk_make_reservation] # Size of the room: How many people ? - [longtalk_hotel_reserve.people] + [longtalk_make_reservation.people] words = ["pearson", "people"] regex = '''(?P\d)\W%s''' threshold = 0.85 # Duration of the book: How long ? 
- [longtalk_hotel_reserve.duration] - words = ["day"] + [longtalk_make_reservation.duration] + words = ["day", "night"] regex = '''(?P\d)\W%s''' threshold = 0.85 - diff --git a/src/classifying/classifier.py b/src/classifying/classifier.py index 895ffbc..62f994f 100644 --- a/src/classifying/classifier.py +++ b/src/classifying/classifier.py @@ -16,35 +16,41 @@ class Classifier: def __init__(self): - # Load the classes + # Load the classes and the model self.labels = self._load_labels() - - # Download the model and instantiate it - self._download_model() self.model = self._load_model() @staticmethod - def _load_labels() -> dict: - """Load the dictionary labels from a pickle file and return it.""" - - with open('./assets/data/labels.pickle', 'rb') as handle: - return pickle.load(handle) - - @staticmethod - def _download_model(): - """ - Stream and download the model from a given url to the given path. - """ + def __load_remote_file(url: str, local: str): # Open the URL and a local file - with requests.get(config.MODEL_WEIGHT_URL, stream=True) as response: - with open(config.MODEL_WEIGHT_LOCAL_COPY, 'wb') as handle: + with requests.get(url, stream=True) as response: + with open(local, 'wb') as handle: # Stream the model to the local file for chunk in response.iter_content(chunk_size=8192): handle.write(chunk) + def _load_labels(self) -> dict: + """ + Load the dictionary labels from a remote pickle file and return it. + """ + + # Download and save the pickle locally + self.__load_remote_file(config.MODEL_CLASSES_URL, config.MODEL_CLASSES_LOCAL_COPY) + + # Load and return a dictionary + with open(config.MODEL_CLASSES_LOCAL_COPY, 'rb') as handle: + return pickle.load(handle) + def _load_model(self) -> BertForSequenceClassification: + """ + Load the weight of the model from a remote file (around 500 MB), + instantiate and return the model. 
+ """ + + # Download and save the weights locally + self.__load_remote_file(config.MODEL_WEIGHT_URL, config.MODEL_WEIGHT_LOCAL_COPY) # Instantiate the model model = BertForSequenceClassification.from_pretrained( diff --git a/src/matching/matcher.py b/src/matching/matcher.py index 4cba2f6..47063d4 100644 --- a/src/matching/matcher.py +++ b/src/matching/matcher.py @@ -58,7 +58,6 @@ def __load_filters() -> dict: def get_keywords(self, text: str, intent: str) -> dict: keywords = {} - if intent in self.filters: # Split the text into a list of words