From c9a501ae5492df248778bf66f330cd36a0887f60 Mon Sep 17 00:00:00 2001 From: ehddnr301 Date: Mon, 13 Sep 2021 10:59:28 +0900 Subject: [PATCH] Change modeling OOP MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 모델을 load하는것과 predict하는 부분을 class로 만들어서 좀더 객체지향적 프로그래밍을 하고자 하였습니다. --- app/api/router/predict.py | 38 ++++++++-------------- app/utils.py | 68 +++++++++++++++++++++++---------------- 2 files changed, 55 insertions(+), 51 deletions(-) diff --git a/app/api/router/predict.py b/app/api/router/predict.py index 136c5a0..9c72462 100644 --- a/app/api/router/predict.py +++ b/app/api/router/predict.py @@ -1,15 +1,14 @@ # -*- coding: utf-8 -*- -import codecs import numpy as np -import pickle from typing import List -from fastapi import APIRouter, HTTPException +from fastapi import APIRouter +from starlette.concurrency import run_in_threadpool from app import models from app.api.schemas import ModelCorePrediction from app.database import engine -from app.utils import my_model +from app.utils import ScikitLearnModel, my_model models.Base.metadata.create_all(bind=engine) @@ -23,7 +22,7 @@ @router.put("/insurance") -def predict_insurance(info: ModelCorePrediction, model_name: str): +async def predict_insurance(info: ModelCorePrediction, model_name: str): """ Get information and predict insurance fee param: @@ -39,30 +38,21 @@ def predict_insurance(info: ModelCorePrediction, model_name: str): return: insurance_fee: float """ - query = """ - SELECT model_file - FROM model_core - WHERE model_name='{}'; - """.format(model_name) - reg_model = engine.execute(query).fetchone() + def sync_call(info, model_name): + model = ScikitLearnModel(model_name) + model.load_model() - if reg_model is None: - raise HTTPException( - status_code=404, - detail="Model Not Found", - headers={"X-Error": "Model Not Found"}, - ) + info = info.dict() + test_set = np.array([*info.values()]).reshape(1, -1) - loaded_model = pickle.loads( - codecs.decode(reg_model[0], 'base64')) + pred = model.predict_target(test_set) - info = info.dict() - test_set = np.array([*info.values()]).reshape(1, -1) + return {"result": pred.tolist()[0]} - pred = loaded_model.predict(test_set) + result = await run_in_threadpool(sync_call, info, model_name) - return {"result": pred.tolist()[0]} + return result @router.put("/atmos") @@ -71,7 +61,7 @@ async def predict_temperature(time_series: List[float]): return "time series must have 72 values" try: - tf_model = my_model.my_model + tf_model = my_model.model time_series = np.array(time_series).reshape(1, -1, 1) result = tf_model.predict(time_series) return result.tolist() diff --git a/app/utils.py b/app/utils.py index 45d2ced..1f9b2e7 100644 --- a/app/utils.py +++ b/app/utils.py @@ -11,46 +11,60 @@ base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) -# physical_devices = tf.config.list_physical_devices('GPU') -# tf.config.experimental.set_memory_growth(physical_devices[0], enable=True) +physical_devices = tf.config.list_physical_devices('GPU') +if physical_devices: + tf.config.experimental.set_memory_growth(physical_devices[0], enable=True) -class MyModel: - def __init__(self): - self._my_model = None +class CoreModel: - def load_tf_model(self, model_name): - """ - * DB에 있는 텐서플로우 모델을 불러옵니다. - * 모델은 zip형식으로 압축되어 binary로 저장되어 있습니다. - * 모델의 이름을 받아 압축 해제 및 tf_model폴더 아래에 저장한 후 로드하여 - 텐서플로우 모델 객체를 반환합니다. - """ + def __init__(self, model_name): + self.model_name = model_name + self.model = None + self.query = """ + SELECT model_file + FROM model_core + WHERE model_name='{}'; + """.format(self.model_name) - query = f"""SELECT model_file - FROM model_core - WHERE model_name='{model_name}';""" + def load_model(self): + raise Exception - bin_data = engine.execute(query).fetchone()[0] + def predict_target(self, target_data): + return self.model.predict(target_data) - model_buffer = pickle.loads(codecs.decode(bin_data, "base64")) - model_path = os.path.join(base_dir, "tf_model", model_name) - with zipfile.ZipFile(model_buffer, "r") as bf: - bf.extractall(model_path) - tf_model = tf.keras.models.load_model(model_path) +class ScikitLearnModel(CoreModel): + def __init__(self, *args): + super().__init__(*args) + + def load_model(self): + _model = engine.execute(self.query).fetchone() + if _model is None: + raise ValueError('Model Not Found!') - return tf_model + self.model = pickle.loads( + codecs.decode(_model[0], 'base64') + ) + + +class TensorFlowModel(CoreModel): + def __init__(self, *args): + super().__init__(*args) def load_model(self): - self._my_model = self.load_tf_model('test_model') + _model = engine.execute(self.query).fetchone() + if _model is None: + raise ValueError('Model Not Found!') + model_buffer = pickle.loads(codecs.decode(_model[0], "base64")) + model_path = os.path.join(base_dir, "tf_model", self.model_name) - @property - def my_model(self): - return self._my_model + with zipfile.ZipFile(model_buffer, "r") as bf: + bf.extractall(model_path) + self.model = tf.keras.models.load_model(model_path) -my_model = MyModel() +my_model = TensorFlowModel('test_model') my_model.load_model()