diff --git a/app/api/router/predict.py b/app/api/router/predict.py index 08f4a5f..e9aabe1 100644 --- a/app/api/router/predict.py +++ b/app/api/router/predict.py @@ -1,9 +1,10 @@ # -*- coding: utf-8 -*- import os +import codecs import pickle import numpy as np from fastapi.param_functions import Depends -from app.api.schemas import RegModelPrediction +from app.api.schemas import ModelCorePrediction from fastapi import APIRouter, HTTPException from sqlalchemy.orm import Session @@ -31,7 +32,7 @@ def get_db(): @router.put("/insurance") -def predict_insurance(info: RegModelPrediction, db: Session = Depends(get_db)): +def predict_insurance(info: ModelCorePrediction, model_name: str, db: Session = Depends(get_db)): """ Get information and predict insurance fee param: @@ -47,13 +48,12 @@ def predict_insurance(info: RegModelPrediction, db: Session = Depends(get_db)): return: insurance_fee: float """ - reg_model = crud.get_reg_model(db, model_name=info.model_name) + reg_model = crud.get_reg_model(db, model_name=model_name) if reg_model: - loaded_model = pickle.load(open( - os.path.join(reg_model.path, f'{reg_model.model_name}.pkl'), - 'rb') - ) + loaded_model = pickle.loads( + codecs.decode(reg_model.model_file, 'base64')) + test_set = np.array([ info.age, info.sex, diff --git a/app/api/schemas.py b/app/api/schemas.py index 0df4913..87698ac 100644 --- a/app/api/schemas.py +++ b/app/api/schemas.py @@ -1,11 +1,11 @@ from pydantic import BaseModel -class RegModelBase(BaseModel): +class ModelCoreBase(BaseModel): model_name: str -class RegModelPrediction(RegModelBase): +class ModelCorePrediction(BaseModel): age: int sex: int bmi: float @@ -14,6 +14,6 @@ class RegModelPrediction(RegModelBase): region: int -class RegModel(RegModelBase): +class ModelCore(ModelCoreBase): class Config: orm_mode = True diff --git a/app/crud.py b/app/crud.py index a00180a..640f9d1 100644 --- a/app/crud.py +++ b/app/crud.py @@ -4,7 +4,5 @@ from app.api import schemas -def get_reg_model(db: Session, model_name: schemas.RegModelBase): - return db.query(models.RegModel).filter( - models.RegModel.model_name == model_name - ).first() +def get_reg_model(db: Session, model_name: schemas.ModelCoreBase): + return db.query(models.ModelCore).filter_by(model_name=model_name).first() diff --git a/app/models.py b/app/models.py index 41ddaf2..65b58d1 100644 --- a/app/models.py +++ b/app/models.py @@ -1,6 +1,6 @@ # -*- coding: utf-8 -*- import datetime -from sqlalchemy import Column, String, FLOAT, DateTime, ForeignKey +from sqlalchemy import Column, Integer, String, FLOAT, DateTime, ForeignKey, LargeBinary from sqlalchemy.sql.functions import now from sqlalchemy.orm import relationship @@ -9,22 +9,22 @@ KST = datetime.timezone(datetime.timedelta(hours=9)) -class RegModel(Base): - __tablename__ = 'reg_model' +class ModelCore(Base): + __tablename__ = 'model_core' model_name = Column(String, primary_key=True) - path = Column(String, nullable=False) + model_file = Column(LargeBinary, nullable=False) - model_metadata = relationship( - "RegModelMetadata", backref="reg_model.model_name") + model_metadata_relation = relationship( + "ModelMetadata", backref="model_core.model_name") -class RegModelMetadata(Base): - __tablename__ = 'reg_model_metadata' +class ModelMetadata(Base): + __tablename__ = 'model_metadata' experiment_name = Column(String, primary_key=True) - reg_model_name = Column(String, ForeignKey( - 'reg_model.model_name'), nullable=False) + model_core_name = Column(String, ForeignKey( + 'model_core.model_name'), nullable=False) experimenter = Column(String, nullable=False) version = Column(FLOAT) train_mae = Column(FLOAT, nullable=False) diff --git a/module/config.yml b/experiments/insurance/config.yml similarity index 100% rename from module/config.yml rename to experiments/insurance/config.yml diff --git a/module/query.py b/experiments/insurance/query.py similarity index 63% rename from module/query.py rename to experiments/insurance/query.py index 992cde5..f1fb600 100644 --- a/module/query.py +++ b/experiments/insurance/query.py @@ -1,18 +1,18 @@ # INSERT -INSERT_REG_MODEL = """ - INSERT INTO reg_model ( +INSERT_MODEL_CORE = """ + INSERT INTO model_core ( model_name, - path + model_file ) VALUES( %s, - %s + '%s' ) """ -INSERT_REG_MODEL_METADATA = """ - INSERT INTO reg_model_metadata ( +INSERT_MODEL_METADATA = """ + INSERT INTO model_metadata ( experiment_name, - reg_model_name, + model_core_name, experimenter, version, train_mae, @@ -32,8 +32,8 @@ """ # UPDATE -UPDATE_REG_MODEL_METADATA = """ - UPDATE reg_model_metadata +UPDATE_MODEL_METADATA = """ + UPDATE model_metadata SET train_mae = %s, val_mae = %s, @@ -41,6 +41,13 @@ val_mse = %s WHERE experiment_name = %s """ +UPDATE_MODEL_CORE = """ + UPDATE model_core + SET + model_file = '%s' + WHERE + model_name = %s + """ # pd READ_SQL SELECT_ALL_INSURANCE = """ @@ -50,12 +57,12 @@ SELECT_VAL_MAE = """ SELECT val_mae - FROM reg_model_metadata - WHERE reg_model_name = %s + FROM model_metadata + WHERE model_core_name = %s """ -SELECT_REG_MODEL = """ +SELECT_MODEL_CORE = """ SELECT * - FROM reg_model + FROM model_core WHERE model_name = %s """ diff --git a/module/search_space.json b/experiments/insurance/search_space.json similarity index 100% rename from module/search_space.json rename to experiments/insurance/search_space.json diff --git a/module/trial.py b/experiments/insurance/trial.py similarity index 90% rename from module/trial.py rename to experiments/insurance/trial.py index edc2fee..81eeef4 100644 --- a/module/trial.py +++ b/experiments/insurance/trial.py @@ -1,4 +1,5 @@ import os +import codecs import pickle from dotenv import load_dotenv @@ -108,14 +109,13 @@ def main(params, df, engine, experiment_info, connection): tr_mse_mean = np.mean(tr_mse) tr_mae_mean = np.mean(tr_mae) - best_model = pd.read_sql(SELECT_REG_MODEL % (model_name), engine) + best_model = pd.read_sql(SELECT_MODEL_CORE % (model_name), engine) if len(best_model) == 0: - with open(f"{os.path.join(path, model_name)}.pkl".replace("'", ""), "wb") as f: - pickle.dump(model, f) - connection.execute(INSERT_REG_MODEL % (model_name, path)) - connection.execute(INSERT_REG_MODEL_METADATA % ( + pickled_model = codecs.encode(pickle.dumps(model), "base64").decode() + connection.execute(INSERT_MODEL_CORE % (model_name, pickled_model)) + connection.execute(INSERT_MODEL_METADATA % ( experiment_name, model_name, experimenter, @@ -130,10 +130,13 @@ def main(params, df, engine, experiment_info, connection): best_model_metadata = pd.read_sql( SELECT_VAL_MAE % (model_name), engine) saved_score = best_model_metadata.values[0] + if saved_score > valid_mae: - with open(f"{os.path.join(path, model_name)}.pkl".replace("'", ""), "wb") as f: - pickle.dump(model, f) - connection.execute(UPDATE_REG_MODEL_METADATA % ( + pickled_model = codecs.encode( + pickle.dumps(model), "base64").decode() + + connection.execute(UPDATE_MODEL_CORE % (pickled_model, model_name)) + connection.execute(UPDATE_MODEL_METADATA % ( tr_mae_mean, cv_mae_mean, tr_mse_mean, diff --git a/module/__init__.py b/module/__init__.py deleted file mode 100644 index e69de29..0000000 diff --git a/test_mnist.pkl b/test_mnist.pkl deleted file mode 100644 index 72b0f57..0000000 Binary files a/test_mnist.pkl and /dev/null differ