Skip to content

Commit

Permalink
Merge pull request #8 from State-of-The-MLOps/feature/model_save
Browse files Browse the repository at this point in the history
Feature/model save
  • Loading branch information
chl8469 authored Sep 8, 2021
2 parents 5f1b8a4 + be510a1 commit a3105b5
Show file tree
Hide file tree
Showing 10 changed files with 53 additions and 45 deletions.
14 changes: 7 additions & 7 deletions app/api/router/predict.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
# -*- coding: utf-8 -*-
import os
import codecs
import pickle
import numpy as np
from fastapi.param_functions import Depends
from app.api.schemas import RegModelPrediction
from app.api.schemas import ModelCorePrediction
from fastapi import APIRouter, HTTPException
from sqlalchemy.orm import Session

Expand Down Expand Up @@ -31,7 +32,7 @@ def get_db():


@router.put("/insurance")
def predict_insurance(info: RegModelPrediction, db: Session = Depends(get_db)):
def predict_insurance(info: ModelCorePrediction, model_name: str, db: Session = Depends(get_db)):
"""
Get information and predict insurance fee
param:
Expand All @@ -47,13 +48,12 @@ def predict_insurance(info: RegModelPrediction, db: Session = Depends(get_db)):
return:
insurance_fee: float
"""
reg_model = crud.get_reg_model(db, model_name=info.model_name)
reg_model = crud.get_reg_model(db, model_name=model_name)

if reg_model:
loaded_model = pickle.load(open(
os.path.join(reg_model.path, f'{reg_model.model_name}.pkl'),
'rb')
)
loaded_model = pickle.loads(
codecs.decode(reg_model.model_file, 'base64'))

test_set = np.array([
info.age,
info.sex,
Expand Down
6 changes: 3 additions & 3 deletions app/api/schemas.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
from pydantic import BaseModel


class RegModelBase(BaseModel):
class ModelCoreBase(BaseModel):
model_name: str


class RegModelPrediction(RegModelBase):
class ModelCorePrediction(BaseModel):
age: int
sex: int
bmi: float
Expand All @@ -14,6 +14,6 @@ class RegModelPrediction(RegModelBase):
region: int


class RegModel(RegModelBase):
class ModelCore(ModelCoreBase):
class Config:
orm_mode = True
6 changes: 2 additions & 4 deletions app/crud.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,5 @@
from app.api import schemas


def get_reg_model(db: Session, model_name: str):
    """Look up a stored regression model row by its name.

    Args:
        db: SQLAlchemy session used to run the query.
        model_name: Value matched against ``RegModel.model_name``. The
            prediction route passes a plain string here, so it is annotated
            as ``str`` rather than as a pydantic schema class.

    Returns:
        The first ``models.RegModel`` row with that name, or ``None`` when
        no row matches.
    """
    return db.query(models.RegModel).filter(
        models.RegModel.model_name == model_name
    ).first()
def get_reg_model(db: Session, model_name: str):
    """Look up a stored model row by its name.

    Args:
        db: SQLAlchemy session used to run the query.
        model_name: Value matched against ``ModelCore.model_name``. The
            prediction route passes a plain string here, so it is annotated
            as ``str`` rather than as a pydantic schema class.

    Returns:
        The first ``models.ModelCore`` row with that name, or ``None`` when
        no row matches.
    """
    return db.query(models.ModelCore).filter_by(model_name=model_name).first()
20 changes: 10 additions & 10 deletions app/models.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# -*- coding: utf-8 -*-
import datetime
from sqlalchemy import Column, String, FLOAT, DateTime, ForeignKey
from sqlalchemy import Column, Integer, String, FLOAT, DateTime, ForeignKey, LargeBinary
from sqlalchemy.sql.functions import now
from sqlalchemy.orm import relationship

Expand All @@ -9,22 +9,22 @@
# Fixed UTC+09:00 timezone object (a constant offset, no DST rules).
# The name suggests Korea Standard Time.
KST = datetime.timezone(datetime.timedelta(hours=9))


class RegModel(Base):
__tablename__ = 'reg_model'
class ModelCore(Base):
__tablename__ = 'model_core'

model_name = Column(String, primary_key=True)
path = Column(String, nullable=False)
model_file = Column(LargeBinary, nullable=False)

model_metadata = relationship(
"RegModelMetadata", backref="reg_model.model_name")
model_metadata_relation = relationship(
"ModelMetadata", backref="model_core.model_name")


class RegModelMetadata(Base):
__tablename__ = 'reg_model_metadata'
class ModelMetadata(Base):
__tablename__ = 'model_metadata'

experiment_name = Column(String, primary_key=True)
reg_model_name = Column(String, ForeignKey(
'reg_model.model_name'), nullable=False)
model_core_name = Column(String, ForeignKey(
'model_core.model_name'), nullable=False)
experimenter = Column(String, nullable=False)
version = Column(FLOAT)
train_mae = Column(FLOAT, nullable=False)
Expand Down
File renamed without changes.
33 changes: 20 additions & 13 deletions module/query.py → experiments/insurance/query.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,18 @@
# INSERT
INSERT_REG_MODEL = """
INSERT INTO reg_model (
INSERT_MODEL_CORE = """
INSERT INTO model_core (
model_name,
path
model_file
) VALUES(
%s,
%s
'%s'
)
"""

INSERT_REG_MODEL_METADATA = """
INSERT INTO reg_model_metadata (
INSERT_MODEL_METADATA = """
INSERT INTO model_metadata (
experiment_name,
reg_model_name,
model_core_name,
experimenter,
version,
train_mae,
Expand All @@ -32,15 +32,22 @@
"""

# UPDATE
UPDATE_REG_MODEL_METADATA = """
UPDATE reg_model_metadata
UPDATE_MODEL_METADATA = """
UPDATE model_metadata
SET
train_mae = %s,
val_mae = %s,
train_mse = %s,
val_mse = %s
WHERE experiment_name = %s
"""
# Overwrites model_file on the model_core row whose model_name matches.
# Placeholders, in order: (model_file, model_name). The template wraps
# model_file in single quotes itself; model_name is inserted unquoted, so the
# caller presumably supplies it already quoted — TODO confirm.
# NOTE(review): the trial script fills this template with Python %-formatting
# rather than driver-side parameter binding, so both values are interpolated
# verbatim into the SQL text. Confirm the inputs are trusted, or move to bound
# parameters together with the call site.
UPDATE_MODEL_CORE = """
UPDATE model_core
SET
model_file = '%s'
WHERE
model_name = %s
"""

# pd READ_SQL
SELECT_ALL_INSURANCE = """
Expand All @@ -50,12 +57,12 @@

SELECT_VAL_MAE = """
SELECT val_mae
FROM reg_model_metadata
WHERE reg_model_name = %s
FROM model_metadata
WHERE model_core_name = %s
"""

SELECT_REG_MODEL = """
SELECT_MODEL_CORE = """
SELECT *
FROM reg_model
FROM model_core
WHERE model_name = %s
"""
File renamed without changes.
19 changes: 11 additions & 8 deletions module/trial.py → experiments/insurance/trial.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import os
import codecs
import pickle

from dotenv import load_dotenv
Expand Down Expand Up @@ -108,14 +109,13 @@ def main(params, df, engine, experiment_info, connection):
tr_mse_mean = np.mean(tr_mse)
tr_mae_mean = np.mean(tr_mae)

best_model = pd.read_sql(SELECT_REG_MODEL % (model_name), engine)
best_model = pd.read_sql(SELECT_MODEL_CORE % (model_name), engine)

if len(best_model) == 0:

with open(f"{os.path.join(path, model_name)}.pkl".replace("'", ""), "wb") as f:
pickle.dump(model, f)
connection.execute(INSERT_REG_MODEL % (model_name, path))
connection.execute(INSERT_REG_MODEL_METADATA % (
pickled_model = codecs.encode(pickle.dumps(model), "base64").decode()
connection.execute(INSERT_MODEL_CORE % (model_name, pickled_model))
connection.execute(INSERT_MODEL_METADATA % (
experiment_name,
model_name,
experimenter,
Expand All @@ -130,10 +130,13 @@ def main(params, df, engine, experiment_info, connection):
best_model_metadata = pd.read_sql(
SELECT_VAL_MAE % (model_name), engine)
saved_score = best_model_metadata.values[0]

if saved_score > valid_mae:
with open(f"{os.path.join(path, model_name)}.pkl".replace("'", ""), "wb") as f:
pickle.dump(model, f)
connection.execute(UPDATE_REG_MODEL_METADATA % (
pickled_model = codecs.encode(
pickle.dumps(model), "base64").decode()

connection.execute(UPDATE_MODEL_CORE % (pickled_model, model_name))
connection.execute(UPDATE_MODEL_METADATA % (
tr_mae_mean,
cv_mae_mean,
tr_mse_mean,
Expand Down
Empty file removed module/__init__.py
Empty file.
Binary file removed test_mnist.pkl
Binary file not shown.

0 comments on commit a3105b5

Please sign in to comment.