diff --git a/app/api/router/predict.py b/app/api/router/predict.py index d82bff6..09c986c 100644 --- a/app/api/router/predict.py +++ b/app/api/router/predict.py @@ -1,7 +1,13 @@ # -*- coding: utf-8 -*- -from fastapi import APIRouter +import os +import pickle +import numpy as np +from fastapi.param_functions import Depends +from app.api.schemas import RegModelPrediction +from fastapi import APIRouter, HTTPException +from sqlalchemy.orm import Session -from app import models +from app import crud, models from app.database import engine from app.database import SessionLocal @@ -27,3 +33,47 @@ def get_db(): @router.get("/") def hello_world(): return {"message": "Hello predict"} + + +@router.put("/insurance") +def predict_insurance(info: RegModelPrediction, db: Session = Depends(get_db)): + """ + Get information and predict insurance fee + param: + info: + # 임시로 int형태를 받도록 제작 + # preprocess 단계를 거치도록 만들 예정 + age: int + sex: int + bmi: float + children: int + smoker: int + region: int + return: + insurance_fee: float + """ + reg_model = crud.get_reg_model(db, model_name=info.model_name) + + if reg_model: + loaded_model = pickle.load(open( + os.path.join(reg_model.path, f'{reg_model.model_name}.pkl'), + 'rb') + ) + test_set = np.array([ + info.age, + info.sex, + info.bmi, + info.children, + info.smoker, + info.region + ]).reshape(1, -1) + + pred = loaded_model.predict(test_set) + + return {"result": pred.tolist()[0]} + else: + raise HTTPException( + status_code=404, + detail="Model Not Found", + headers={"X-Error": "Model Not Found"}, + ) diff --git a/app/api/schemas.py b/app/api/schemas.py index 57d95dc..4506481 100644 --- a/app/api/schemas.py +++ b/app/api/schemas.py @@ -43,3 +43,21 @@ class ClfModelCreate(ClfModelBase): class ClfModel(ClfModelBase): class Config: orm_mode = True + + +class RegModelBase(BaseModel): + model_name: str + + +class RegModelPrediction(RegModelBase): + age: int + sex: int + bmi: float + children: int + smoker: int + region: int + + +class RegModel(RegModelBase): + class Config: + orm_mode = True diff --git a/app/crud.py b/app/crud.py index c9181b4..3d91ed6 100644 --- a/app/crud.py +++ b/app/crud.py @@ -31,3 +31,9 @@ def create_clf_model(db: Session, clf_model: schemas.ClfModelCreate): db.commit() db.refresh(db_cf_model) return db_cf_model + + +def get_reg_model(db: Session, model_name: schemas.RegModelBase): + return db.query(models.RegModel).filter( + models.RegModel.model_name == model_name + ).first() diff --git a/app/models.py b/app/models.py index 3c85211..d99b57b 100644 --- a/app/models.py +++ b/app/models.py @@ -42,7 +42,7 @@ class RegModel(Base): path = Column(String, nullable=False) model_metadata = relationship( - "reg_model_metadata", backref="reg_model.model_name") + "RegModelMetadata", backref="reg_model.model_name") class RegModelMetadata(Base):