-
Notifications
You must be signed in to change notification settings - Fork 9
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #26 from State-of-The-MLOps/develop
Develop
- Loading branch information
Showing
26 changed files
with
1,614 additions
and
37 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,6 @@ | ||
.env | ||
.env | ||
*.pkl | ||
__pycache__ | ||
tf_model/**/* | ||
log.txt | ||
experiments/**/temp/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,10 +1,6 @@ | ||
repos: | ||
|
||
- repo: https://github.com/PyCQA/flake8 | ||
rev: 3.9.2 | ||
- repo: https://github.com/psf/black | ||
rev: 20.8b1 | ||
hooks: | ||
- id: flake8 | ||
|
||
|
||
|
||
|
||
- id: black | ||
language_version: python3 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,2 +1,6 @@ | ||
# MLOps | ||
👊 Build MLOps system step by step 👊 | ||
👊 Build MLOps system step by step 👊 | ||
|
||
## 문서 | ||
|
||
- [API DOCS](./docs/api-list.md) |
File renamed without changes.
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
# -*- coding: utf-8 -*- | ||
from typing import List | ||
|
||
|
||
import numpy as np | ||
from fastapi import APIRouter | ||
from starlette.concurrency import run_in_threadpool | ||
|
||
from app import models | ||
from app.api.schemas import ModelCorePrediction | ||
from app.database import engine | ||
from app.utils import ScikitLearnModel, my_model | ||
from logger import L | ||
|
||
|
||
models.Base.metadata.create_all(bind=engine) | ||
|
||
|
||
router = APIRouter( | ||
prefix="/predict", tags=["predict"], responses={404: {"description": "Not Found"}} | ||
) | ||
|
||
|
||
@router.put("/insurance")
async def predict_insurance(info: ModelCorePrediction, model_name: str):
    """
    Predict an insurance fee from the submitted attributes.

    Args:
        info(ModelCorePrediction): Request body carrying age(int), sex(int),
            bmi(float), children(int), smoker(int) and region(int).
        model_name(str): Name passed to ``ScikitLearnModel`` to select the
            model that is loaded for this request.

    Returns:
        dict: ``{"result": <prediction>}`` on success, otherwise
        ``{"result": "Can't predict", "error": <message>}``.
    """

    def _predict_blocking(payload, name):
        """
        Blocking half of the endpoint, executed in the thread pool.

        Takes the same inputs as the enclosing endpoint and returns the
        success payload.
        """
        estimator = ScikitLearnModel(name)
        estimator.load_model()

        # Single-row feature matrix; column order is whatever order
        # payload.dict() yields its values in.
        feature_values = list(payload.dict().values())
        features = np.array(feature_values).reshape(1, -1)

        prediction = estimator.predict_target(features)
        first_value = prediction.tolist()[0]
        return {"result": first_value}

    try:
        # Model loading and inference are blocking, so keep them off the
        # event loop.
        result = await run_in_threadpool(_predict_blocking, info, model_name)
        L.info(
            f"Predict Args info: {info}\n\tmodel_name: {model_name}\n\tPrediction Result: {result}"
        )
        return result

    except Exception as e:
        L.error(e)
        return {"result": "Can't predict", "error": str(e)}
|
||
|
||
@router.put("/atmos")
async def predict_temperature(time_series: List[float]):
    """
    Forecast the next 24 hours of temperature from an hourly time series.

    Args:
        time_series(List): Hourly temperatures covering 72 hours. Must
            contain exactly 72 elements.

    Returns:
        dict: ``{"result": <forecast as nested list>, "error": None}`` on
        success. On invalid input or failure, ``result`` holds a message
        and ``error`` holds the exception text (or ``None``).
    """
    if len(time_series) != 72:
        L.error(f"input time_series: {time_series} is not valid")
        return {"result": "time series must have 72 values", "error": None}

    def sync_pred_ts(time_series):
        """
        Blocking half of the endpoint, executed in the thread pool.

        Takes the same input as the enclosing endpoint and returns the
        success payload.
        """
        # Shape (1, 72, 1): one sample, 72 steps, one feature per step.
        time_series = np.array(time_series).reshape(1, -1, 1)
        result = my_model.predict_target(time_series)
        L.info(
            f"Predict Args info: {time_series.flatten().tolist()}\n\tmodel_name: {my_model.model_name}\n\tPrediction Result: {result.tolist()[0]}"
        )

        # Convert to plain lists here. The previous code put the raw array
        # in the dict and the caller then called .tolist() on the dict,
        # which raised AttributeError on every valid request.
        return {"result": result.tolist(), "error": None}

    try:
        # The helper already returns the complete response payload, in the
        # same {"result", "error"} envelope as the failure paths.
        return await run_in_threadpool(sync_pred_ts, time_series)

    except Exception as e:
        L.error(e)
        return {"result": "Can't predict", "error": str(e)}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,109 @@ | ||
import multiprocessing | ||
import os | ||
import re | ||
import subprocess | ||
|
||
|
||
from fastapi import APIRouter | ||
|
||
from app.utils import NniWatcher, ExperimentOwl, base_dir, get_free_port, write_yml | ||
from logger import L | ||
|
||
router = APIRouter( | ||
prefix="/train", tags=["train"], responses={404: {"description": "Not Found"}} | ||
) | ||
|
||
|
||
@router.put("/insurance")
def train_insurance(
    experiment_name: str = "exp1",
    experimenter: str = "DongUk",
    model_name: str = "insurance_fee_model",
    version: float = 0.1,
):
    """
    Start an NNI training experiment for the insurance model.

    Args:
        experiment_name (str): Name of the experiment. Default: exp1
        experimenter (str): Name of the person running it. Default: DongUk
        model_name (str): Name of the model. Default: insurance_fee_model
        version (float): Version of the experiment. Default: 0.1

    Returns:
        dict: ``{"msg": <nnictl output>, "error": None}`` whether or not
        the experiment started; the output includes the NNI dashboard
        address with its port. If an exception is raised the result is
        ``{"msg": "Can't start experiment", "error": <message>}``.

    Note:
        The final result of the experiment is not returned.
    """
    # Function-scope import so this block does not depend on the module's
    # import section.
    import shlex

    PORT = get_free_port()
    L.info(
        f"Train Args info\n\texperiment_name: {experiment_name}\n\texperimenter: {experimenter}\n\tmodel_name: {model_name}\n\tversion: {version}"
    )
    path = "experiments/insurance/"
    try:
        write_yml(path, experiment_name, experimenter, model_name, version)

        # model_name is caller-supplied and subprocess.getoutput runs the
        # string through a shell, so quote each interpolated value to
        # prevent command injection.
        config_path = "{}/{}.yml".format(path, model_name)
        nni_create_result = subprocess.getoutput(
            "nnictl create --port {} --config {}".format(
                shlex.quote(str(PORT)), shlex.quote(config_path)
            )
        )
        sucs_msg = "Successfully started experiment!"

        if sucs_msg in nni_create_result:
            p = re.compile(r"The experiment id is ([a-zA-Z0-9]+)\n")
            expr_id = p.findall(nni_create_result)[0]
            nni_watcher = NniWatcher(expr_id, experiment_name, experimenter, version)
            # The watcher runs in its own process so the request can return
            # right away ("excute" is the method name NniWatcher exposes).
            m_process = multiprocessing.Process(target=nni_watcher.excute)
            m_process.start()
            L.info(nni_create_result)
        else:
            # Log a failed start as an error, matching train_atmos.
            L.error(nni_create_result)

        return {"msg": nni_create_result, "error": None}

    except Exception as e:
        L.error(e)
        return {"msg": "Can't start experiment", "error": str(e)}
|
||
|
||
@router.put("/atmos")
def train_atmos(expr_name: str):
    """
    Start an NNI training experiment for the temperature time series.

    Args:
        expr_name(str): Name of the experiment NNI should run. NNI is
            started with project_dir/experiments/[expr_name]/config.yml.

    Returns:
        dict: ``{"msg": <nnictl output>, "error": None}`` whether or not
        the experiment started, or
        ``{"msg": "Can't start experiment", "error": <message>}`` if an
        exception is raised.

    Note:
        The final result of the experiment is not returned.
    """
    # Function-scope import so this block does not depend on the module's
    # import section.
    import shlex

    nni_port = get_free_port()
    expr_path = os.path.join(base_dir, "experiments", expr_name)

    try:
        # expr_name is caller-supplied and subprocess.getoutput runs the
        # string through a shell, so quote each interpolated value to
        # prevent command injection.
        config_path = "{}/config.yml".format(expr_path)
        nni_create_result = subprocess.getoutput(
            "nnictl create --port {} --config {}".format(
                shlex.quote(str(nni_port)), shlex.quote(config_path)
            )
        )
        sucs_msg = "Successfully started experiment!"

        if sucs_msg in nni_create_result:
            p = re.compile(r"The experiment id is ([a-zA-Z0-9]+)\n")
            expr_id = p.findall(nni_create_result)[0]
            check_expr = ExperimentOwl(expr_id, expr_name, expr_path)
            check_expr.add("update_tfmodeldb")
            check_expr.add("modelfile_cleaner")

            # Monitoring runs in its own process so the request can return
            # right away.
            m_process = multiprocessing.Process(target=check_expr.execute)
            m_process.start()

            L.info(nni_create_result)
        else:
            L.error(nni_create_result)

        # Both outcomes hand the raw nnictl output back to the caller.
        return {"msg": nni_create_result, "error": None}

    except Exception as e:
        L.error(e)
        return {"msg": "Can't start experiment", "error": str(e)}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
from pydantic import BaseModel | ||
|
||
|
||
class ModelCoreBase(BaseModel):
    """
    Base pydantic schema that carries only the model's name.
    """

    model_name: str
|
||
|
||
class ModelCorePrediction(BaseModel):
    """
    Pydantic schema validating the input of the predict_insurance API.

    Attributes:
        age(int)
        sex(int)
        bmi(float)
        children(int)
        smoker(int)
        region(int)
    """

    # Keep this declaration order: predict_insurance builds its feature row
    # from the values of this schema's dict in order.
    age: int
    sex: int
    bmi: float
    children: int
    smoker: int
    region: int
|
||
|
||
class ModelCore(ModelCoreBase):
    """
    ``ModelCoreBase`` schema with pydantic's ORM mode switched on.
    """

    class Config:
        orm_mode = True
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,31 +1,50 @@ | ||
import os | ||
|
||
|
||
from dotenv import load_dotenv | ||
from sqlalchemy import create_engine | ||
from sqlalchemy.orm import sessionmaker | ||
from sqlalchemy.ext.declarative import declarative_base | ||
|
||
from dotenv import load_dotenv | ||
load_dotenv(verbose=True) | ||
|
||
|
||
def connect(db):
    """
    Create a SQLAlchemy engine for a PostgreSQL database.

    The user, password, server and port are read from the POSTGRES_USER,
    POSTGRES_PASSWORD, POSTGRES_SERVER and POSTGRES_PORT environment
    variables.

    Args:
        db(str): Name of the database to use.

    Returns:
        created database engine: Engine object bound to the database.

    Examples:
        >>> engine = connect("my_db")
        >>> query = "SHOW timezone;"
        >>> engine.execute(query).fetchall()
        [('Asia/Seoul',)]
        >>> print(engine)
        Engine(postgresql://postgres:***@127.0.0.1:5432/my_db)
    """
    POSTGRES_USER = os.getenv("POSTGRES_USER")
    POSTGRES_PASSWORD = os.getenv("POSTGRES_PASSWORD")
    POSTGRES_PORT = os.getenv("POSTGRES_PORT")
    POSTGRES_SERVER = os.getenv("POSTGRES_SERVER")

    # Build the URL once. The earlier backslash-continued assignment was
    # dead code, overwritten by this one.
    SQLALCHEMY_DATABASE_URL = (
        f"postgresql://{POSTGRES_USER}:{POSTGRES_PASSWORD}@"
        + f"{POSTGRES_SERVER}:{POSTGRES_PORT}/{db}"
    )

    connection = create_engine(SQLALCHEMY_DATABASE_URL)

    return connection
|
||
|
||
POSTGRES_DB = os.getenv("POSTGRES_DB") | ||
|
||
engine = connect(POSTGRES_DB) | ||
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine) | ||
Base = declarative_base() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
# -*- coding: utf-8 -*- | ||
import datetime | ||
|
||
|
||
from sqlalchemy import Column, Integer, String, FLOAT, DateTime, ForeignKey, LargeBinary | ||
from sqlalchemy.sql.functions import now | ||
from sqlalchemy.orm import relationship | ||
|
||
from app.database import Base | ||
|
||
KST = datetime.timezone(datetime.timedelta(hours=9)) | ||
|
||
|
||
class ModelCore(Base):
    """
    ORM mapping for the ``model_core`` table: a model file keyed by name.
    """

    __tablename__ = "model_core"

    # The model's name is the primary key.
    model_name = Column(String, primary_key=True)
    # Model file stored as binary; required.
    model_file = Column(LargeBinary, nullable=False)

    # Relationship to the ModelMetadata rows.
    # NOTE(review): the backref name contains a dot, which cannot be used
    # with normal attribute access - confirm this is intended.
    model_metadata_relation = relationship(
        "ModelMetadata", backref="model_core.model_name"
    )
|
||
|
||
class ModelMetadata(Base):
    """
    ORM mapping for the ``model_metadata`` table: one row per experiment,
    holding its metrics and a reference to a ``model_core`` row.
    """

    __tablename__ = "model_metadata"

    # The experiment name is the primary key.
    experiment_name = Column(String, primary_key=True)
    # Foreign key to model_core.model_name.
    model_core_name = Column(
        String, ForeignKey("model_core.model_name"), nullable=False
    )
    experimenter = Column(String, nullable=False)
    # The only nullable column in this table.
    version = Column(FLOAT)
    # Train / validation metrics; all required.
    train_mae = Column(FLOAT, nullable=False)
    val_mae = Column(FLOAT, nullable=False)
    train_mse = Column(FLOAT, nullable=False)
    val_mse = Column(FLOAT, nullable=False)
    # Filled in by the database server with now() at insert time.
    created_at = Column(DateTime(timezone=True), server_default=now())
|
||
|
||
class TempModelData(Base):
    """
    ORM mapping for the ``temp_model_data`` table: a model file stored
    together with its experiment information and metrics.
    """

    __tablename__ = "temp_model_data"

    # Auto-incrementing integer primary key.
    id = Column(Integer, primary_key=True, autoincrement=True)
    model_name = Column(String, nullable=False)
    # Model file stored as binary.
    model_file = Column(LargeBinary, nullable=False)
    experiment_name = Column(String, nullable=False)
    experimenter = Column(String, nullable=False)
    version = Column(FLOAT, nullable=False)
    # Train / validation metrics; all required.
    train_mae = Column(FLOAT, nullable=False)
    val_mae = Column(FLOAT, nullable=False)
    train_mse = Column(FLOAT, nullable=False)
    val_mse = Column(FLOAT, nullable=False)
Oops, something went wrong.