Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Change modeling OOP #16

Merged
merged 1 commit into from
Sep 13, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
38 changes: 14 additions & 24 deletions app/api/router/predict.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,14 @@
# -*- coding: utf-8 -*-
import codecs
import numpy as np
import pickle
from typing import List

from fastapi import APIRouter, HTTPException
from fastapi import APIRouter
from starlette.concurrency import run_in_threadpool

from app import models
from app.api.schemas import ModelCorePrediction
from app.database import engine
from app.utils import my_model
from app.utils import ScikitLearnModel, my_model


models.Base.metadata.create_all(bind=engine)
Expand All @@ -23,7 +22,7 @@


@router.put("/insurance")
def predict_insurance(info: ModelCorePrediction, model_name: str):
async def predict_insurance(info: ModelCorePrediction, model_name: str):
"""
Get information and predict insurance fee
param:
Expand All @@ -39,30 +38,21 @@ def predict_insurance(info: ModelCorePrediction, model_name: str):
return:
insurance_fee: float
"""
query = """
SELECT model_file
FROM model_core
WHERE model_name='{}';
""".format(model_name)

reg_model = engine.execute(query).fetchone()
def sync_call(info, model_name):
model = ScikitLearnModel(model_name)
model.load_model()

if reg_model is None:
raise HTTPException(
status_code=404,
detail="Model Not Found",
headers={"X-Error": "Model Not Found"},
)
info = info.dict()
test_set = np.array([*info.values()]).reshape(1, -1)

loaded_model = pickle.loads(
codecs.decode(reg_model[0], 'base64'))
pred = model.predict_target(test_set)

info = info.dict()
test_set = np.array([*info.values()]).reshape(1, -1)
return {"result": pred.tolist()[0]}

pred = loaded_model.predict(test_set)
result = await run_in_threadpool(sync_call, info, model_name)

return {"result": pred.tolist()[0]}
return result


@router.put("/atmos")
Expand All @@ -71,7 +61,7 @@ async def predict_temperature(time_series: List[float]):
return "time series must have 72 values"

try:
tf_model = my_model.my_model
tf_model = my_model.model
time_series = np.array(time_series).reshape(1, -1, 1)
result = tf_model.predict(time_series)
return result.tolist()
Expand Down
68 changes: 41 additions & 27 deletions app/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,46 +11,60 @@

base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

# physical_devices = tf.config.list_physical_devices('GPU')
# tf.config.experimental.set_memory_growth(physical_devices[0], enable=True)
physical_devices = tf.config.list_physical_devices('GPU')
if physical_devices:
tf.config.experimental.set_memory_growth(physical_devices[0], enable=True)


class MyModel:
def __init__(self):
self._my_model = None
class CoreModel:

def load_tf_model(self, model_name):
"""
* DB에 있는 텐서플로우 모델을 불러옵니다.
* 모델은 zip형식으로 압축되어 binary로 저장되어 있습니다.
* 모델의 이름을 받아 압축 해제 및 tf_model폴더 아래에 저장한 후 로드하여
텐서플로우 모델 객체를 반환합니다.
"""
def __init__(self, model_name):
self.model_name = model_name
self.model = None
self.query = """
SELECT model_file
FROM model_core
WHERE model_name='{}';
""".format(self.model_name)

query = f"""SELECT model_file
FROM model_core
WHERE model_name='{model_name}';"""
def load_model(self):
raise Exception

bin_data = engine.execute(query).fetchone()[0]
def predict_target(self, target_data):
return self.model.predict(target_data)

model_buffer = pickle.loads(codecs.decode(bin_data, "base64"))
model_path = os.path.join(base_dir, "tf_model", model_name)

with zipfile.ZipFile(model_buffer, "r") as bf:
bf.extractall(model_path)
tf_model = tf.keras.models.load_model(model_path)
class ScikitLearnModel(CoreModel):
def __init__(self, *args):
super().__init__(*args)

def load_model(self):
_model = engine.execute(self.query).fetchone()
if _model is None:
raise ValueError('Model Not Found!')

return tf_model
self.model = pickle.loads(
codecs.decode(_model[0], 'base64')
)


class TensorFlowModel(CoreModel):
def __init__(self, *args):
super().__init__(*args)

def load_model(self):
self._my_model = self.load_tf_model('test_model')
_model = engine.execute(self.query).fetchone()
if _model is None:
raise ValueError('Model Not Found!')
model_buffer = pickle.loads(codecs.decode(_model[0], "base64"))
model_path = os.path.join(base_dir, "tf_model", self.model_name)

@property
def my_model(self):
return self._my_model
with zipfile.ZipFile(model_buffer, "r") as bf:
bf.extractall(model_path)
self.model = tf.keras.models.load_model(model_path)


my_model = MyModel()
my_model = TensorFlowModel('test_model')
my_model.load_model()


Expand Down