diff --git a/.gitignore b/.gitignore
index 2eea525..867b6d8 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1 +1,6 @@
-.env
\ No newline at end of file
+.env
+*.pkl
+__pycache__
+tf_model/**/*
+log.txt
+experiments/**/temp/
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
index 84a6bde..3cb791c 100644
--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml
@@ -1,10 +1,6 @@
 repos:
-  -
-    repo: https://github.com/PyCQA/flake8
-    rev: 3.9.2
+  - repo: https://github.com/psf/black
+    rev: 20.8b1
     hooks:
-      - id: flake8
-
-
-
+      - id: black
+        language_version: python3
diff --git a/README.md b/README.md
index 8ffc8f4..f70e538 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,6 @@
 # MLOps
-👊 Build MLOps system step by step 👊
\ No newline at end of file
+👊 Build MLOps system step by step 👊
+
+## Docs
+
+- [API DOCS](./docs/api-list.md)
\ No newline at end of file
diff --git a/app/router/__init__.py b/app/api/__init__.py
similarity index 100%
rename from app/router/__init__.py
rename to app/api/__init__.py
diff --git a/app/api/router/__init__.py b/app/api/router/__init__.py
new file mode 100644
index 0000000..e69de29
diff --git a/app/api/router/predict.py b/app/api/router/predict.py
new file mode 100644
index 0000000..4c76874
--- /dev/null
+++ b/app/api/router/predict.py
@@ -0,0 +1,94 @@
+# -*- coding: utf-8 -*-
+from typing import List
+
+
+import numpy as np
+from fastapi import APIRouter
+from starlette.concurrency import run_in_threadpool
+
+from app import models
+from app.api.schemas import ModelCorePrediction
+from app.database import engine
+from app.utils import ScikitLearnModel, my_model
+from logger import L
+
+
+models.Base.metadata.create_all(bind=engine)
+
+
+router = APIRouter(
+    prefix="/predict", tags=["predict"], responses={404: {"description": "Not Found"}}
+)
+
+
+@router.put("/insurance")
+async def predict_insurance(info: ModelCorePrediction, model_name: str):
+    """
+    Takes customer information and returns the predicted insurance fee.
+
+    Args:
+        info(dict): Expects the following values: age(int), sex(int), bmi(float), children(int), smoker(int), region(int)
+
+    Returns:
+        insurance_fee(float): The predicted insurance fee.
+    """
+
+    def sync_call(info, model_name):
+        """
+        Wraps the blocking (non-async) work so it can run in a thread pool; inputs and outputs are the same as the parent function.
+        """
+        model = ScikitLearnModel(model_name)
+        model.load_model()
+
+        info = info.dict()
+        test_set = np.array([*info.values()]).reshape(1, -1)
+
+        pred = model.predict_target(test_set)
+        return {"result": pred.tolist()[0]}
+
+    try:
+        result = await run_in_threadpool(sync_call, info, model_name)
+        L.info(
+            f"Predict Args info: {info}\n\tmodel_name: {model_name}\n\tPrediction Result: {result}"
+        )
+        return result
+
+    except Exception as e:
+        L.error(e)
+        return {"result": "Can't predict", "error": str(e)}
+
+
+@router.put("/atmos")
+async def predict_temperature(time_series: List[float]):
+    """
+    Takes a temperature time series at 1-hour intervals and predicts the temperature for the next 24 hours, also at 1-hour intervals.
+
+    Args:
+        time_series(List): Hourly temperature series covering 72 hours. Must contain exactly 72 elements.
+
+    Returns:
+        List[float]: Hourly temperature predictions for the 24 hours following the input series, delivered in the "result" field.
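+
+    Examples:
+        Illustrative values only; the input must be exactly 72 floats and the
+        response carries 24 predicted values.
+
+        >>> await predict_temperature([20.1] * 72)
+        {'result': [20.3, 20.4, ...], 'error': None}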
+ """ + if len(time_series) != 72: + L.error(f"input time_series: {time_series} is not valid") + return {"result": "time series must have 72 values", "error": None} + + def sync_pred_ts(time_series): + """ + none sync ํ•จ์ˆ˜๋ฅผ sync๋กœ ๋งŒ๋“ค์–ด ์ฃผ๊ธฐ ์œ„ํ•œ ํ•จ์ˆ˜์ด๋ฉฐ ์ž…์ถœ๋ ฅ์€ ๋ถ€๋ชจ ํ•จ์ˆ˜์™€ ๊ฐ™์Šต๋‹ˆ๋‹ค. + """ + time_series = np.array(time_series).reshape(1, -1, 1) + result = my_model.predict_target(time_series) + L.info( + f"Predict Args info: {time_series.flatten().tolist()}\n\tmodel_name: {my_model.model_name}\n\tPrediction Result: {result.tolist()[0]}" + ) + + return {"result": result, "error": None} + + try: + result = await run_in_threadpool(sync_pred_ts, time_series) + return result.tolist() + + except Exception as e: + L.error(e) + return {"result": "Can't predict", "error": str(e)} diff --git a/app/api/router/train.py b/app/api/router/train.py new file mode 100644 index 0000000..5fda7b8 --- /dev/null +++ b/app/api/router/train.py @@ -0,0 +1,109 @@ +import multiprocessing +import os +import re +import subprocess + + +from fastapi import APIRouter + +from app.utils import NniWatcher, ExperimentOwl, base_dir, get_free_port, write_yml +from logger import L + +router = APIRouter( + prefix="/train", tags=["train"], responses={404: {"description": "Not Found"}} +) + + +@router.put("/insurance") +def train_insurance( + experiment_name: str = "exp1", + experimenter: str = "DongUk", + model_name: str = "insurance_fee_model", + version: float = 0.1, +): + """ + insurance์™€ ๊ด€๋ จ๋œ ํ•™์Šต์„ ์‹คํ–‰ํ•˜๊ธฐ ์œ„ํ•œ API์ž…๋‹ˆ๋‹ค. + + Args: + experiment_name (str): ์‹คํ—˜์ด๋ฆ„. ๊ธฐ๋ณธ ๊ฐ’: exp1 + experimenter (str): ์‹คํ—˜์ž์˜ ์ด๋ฆ„. ๊ธฐ๋ณธ ๊ฐ’: DongUk + model_name (str): ๋ชจ๋ธ์˜ ์ด๋ฆ„. ๊ธฐ๋ณธ ๊ฐ’: insurance_fee_model + version (float): ์‹คํ—˜์˜ ๋ฒ„์ „. ๊ธฐ๋ณธ ๊ฐ’: 0.1 + + Returns: + msg: ์‹คํ—˜ ์‹คํ–‰์˜ ์„ฑ๊ณต๊ณผ ์ƒ๊ด€์—†์ด ํฌํŠธ๋ฒˆํ˜ธ๋ฅผ ํฌํ•จํ•œ NNI Dashboard์˜ ์ฃผ์†Œ๋ฅผ ๋ฐ˜ํ™˜ํ•ฉ๋‹ˆ๋‹ค. + + Note: + ์‹คํ—˜์˜ ์ตœ์ข… ๊ฒฐ๊ณผ๋ฅผ ๋ฐ˜ํ™˜ํ•˜์ง€ ์•Š์Šต๋‹ˆ๋‹ค. + """ + PORT = get_free_port() + L.info( + f"Train Args info\n\texperiment_name: {experiment_name}\n\texperimenter: {experimenter}\n\tmodel_name: {model_name}\n\tversion: {version}" + ) + path = "experiments/insurance/" + try: + write_yml(path, experiment_name, experimenter, model_name, version) + nni_create_result = subprocess.getoutput( + "nnictl create --port {} --config {}/{}.yml".format(PORT, path, model_name) + ) + sucs_msg = "Successfully started experiment!" + + if sucs_msg in nni_create_result: + p = re.compile(r"The experiment id is ([a-zA-Z0-9]+)\n") + expr_id = p.findall(nni_create_result)[0] + nni_watcher = NniWatcher(expr_id, experiment_name, experimenter, version) + m_process = multiprocessing.Process(target=nni_watcher.excute) + m_process.start() + + L.info(nni_create_result) + return {"msg": nni_create_result, "error": None} + + except Exception as e: + L.error(e) + return {"msg": "Can't start experiment", "error": str(e)} + + +@router.put("/atmos") +def train_atmos(expr_name: str): + """ + ์˜จ๋„ ์‹œ๊ณ„์—ด๊ณผ ๊ด€๋ จ๋œ ํ•™์Šต์„ ์‹คํ–‰ํ•˜๊ธฐ ์œ„ํ•œ API์ž…๋‹ˆ๋‹ค. + + Args: + expr_name(str): NNI๊ฐ€ ์‹คํ–‰ํ•  ์‹คํ—˜์˜ ์ด๋ฆ„ ์ž…๋‹ˆ๋‹ค. ์ด ํŒŒ๋ผ๋ฏธํ„ฐ๋ฅผ ๊ธฐ๋ฐ˜์œผ๋กœ project_dir/experiments/[expr_name] ๊ฒฝ๋กœ๋กœ ์ฐพ์•„๊ฐ€ config.yml์„ ์ด์šฉํ•˜์—ฌ NNI๋ฅผ ์‹คํ–‰ํ•ฉ๋‹ˆ๋‹ค. + + Returns: + str: NNI์‹คํ—˜์ด ์‹คํ–‰๋œ ๊ฒฐ๊ณผ๊ฐ’์„ ๋ฐ˜ํ™˜ํ•˜๊ฑฐ๋‚˜ ์‹คํ–‰๊ณผ์ •์—์„œ ๋ฐœ์ƒํ•œ ์—๋Ÿฌ ๋ฉ”์„ธ์ง€๋ฅผ ๋ฐ˜ํ™˜ํ•ฉ๋‹ˆ๋‹ค. + + Note: + ์‹คํ—˜์˜ ์ตœ์ข… ๊ฒฐ๊ณผ๋ฅผ ๋ฐ˜ํ™˜ํ•˜์ง€ ์•Š์Šต๋‹ˆ๋‹ค. 
+ """ + + nni_port = get_free_port() + expr_path = os.path.join(base_dir, "experiments", expr_name) + + try: + nni_create_result = subprocess.getoutput( + "nnictl create --port {} --config {}/config.yml".format(nni_port, expr_path) + ) + sucs_msg = "Successfully started experiment!" + + if sucs_msg in nni_create_result: + p = re.compile(r"The experiment id is ([a-zA-Z0-9]+)\n") + expr_id = p.findall(nni_create_result)[0] + check_expr = ExperimentOwl(expr_id, expr_name, expr_path) + check_expr.add("update_tfmodeldb") + check_expr.add("modelfile_cleaner") + + m_process = multiprocessing.Process(target=check_expr.execute) + m_process.start() + + L.info(nni_create_result) + return {"msg": nni_create_result, "error": None} + + else: + L.error(nni_create_result) + return {"msg": nni_create_result, "error": None} + + except Exception as e: + L.error(e) + return {"msg": "Can't start experiment", "error": str(e)} diff --git a/app/api/schemas.py b/app/api/schemas.py new file mode 100644 index 0000000..73118ea --- /dev/null +++ b/app/api/schemas.py @@ -0,0 +1,31 @@ +from pydantic import BaseModel + + +class ModelCoreBase(BaseModel): + model_name: str + + +class ModelCorePrediction(BaseModel): + """ + predict_insurance API์˜ ์ž…๋ ฅ ๊ฐ’ ๊ฒ€์ฆ์„ ์œ„ํ•œ pydantic ํด๋ž˜์Šค์ž…๋‹ˆ๋‹ค. + + Attributes: + age(int) + sex(int) + bmi(float) + children(int) + smoker(int) + region(int) + """ + + age: int + sex: int + bmi: float + children: int + smoker: int + region: int + + +class ModelCore(ModelCoreBase): + class Config: + orm_mode = True diff --git a/app/database.py b/app/database.py index 0c04917..cc2bcf7 100644 --- a/app/database.py +++ b/app/database.py @@ -1,23 +1,43 @@ import os + +from dotenv import load_dotenv from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker from sqlalchemy.ext.declarative import declarative_base -from dotenv import load_dotenv +load_dotenv(verbose=True) def connect(db): + """ + database์™€์˜ ์—ฐ๊ฒฐ์„ ์œ„ํ•œ ํ•จ์ˆ˜ ์ž…๋‹ˆ๋‹ค. + + Args: + db(str): ์‚ฌ์šฉํ•  ๋ฐ์ดํ„ฐ๋ฒ ์ด์Šค์˜ ์ด๋ฆ„์„ ์ „๋‹ฌ๋ฐ›์Šต๋‹ˆ๋‹ค. - load_dotenv(verbose=True) + Returns: + created database engine: ๋ฐ์ดํ„ฐ๋ฒ ์ด์Šค์— ์—ฐ๊ฒฐ๋œ ๊ฐ์ฒด๋ฅผ ๋ฐ˜ํ™˜ํ•ฉ๋‹ˆ๋‹ค. 
+
+    Examples:
+        >>> engine = connect("my_db")
+        >>> query = "SHOW timezone;"
+        >>> engine.execute(query).fetchall()
+        [('Asia/Seoul',)]
+        >>> print(engine)
+        Engine(postgresql://postgres:***@127.0.0.1:5432/my_db)
+    """
+    print(db)

     POSTGRES_USER = os.getenv("POSTGRES_USER")
     POSTGRES_PASSWORD = os.getenv("POSTGRES_PASSWORD")
     POSTGRES_PORT = os.getenv("POSTGRES_PORT")
     POSTGRES_SERVER = os.getenv("POSTGRES_SERVER")
-    POSTGRES_DB = db

-    SQLALCHEMY_DATABASE_URL = f"postgresql://{POSTGRES_USER}:{POSTGRES_PASSWORD}\
-        @{POSTGRES_SERVER}:{POSTGRES_PORT}/{POSTGRES_DB}"
+
+    SQLALCHEMY_DATABASE_URL = (
+        f"postgresql://{POSTGRES_USER}:{POSTGRES_PASSWORD}@"
+        + f"{POSTGRES_SERVER}:{POSTGRES_PORT}/{db}"
+    )

     connection = create_engine(SQLALCHEMY_DATABASE_URL)

@@ -25,7 +45,6 @@ def connect(db):

 POSTGRES_DB = os.getenv("POSTGRES_DB")

-
 engine = connect(POSTGRES_DB)
 SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
 Base = declarative_base()
diff --git a/app/models.py b/app/models.py
new file mode 100644
index 0000000..efd76e5
--- /dev/null
+++ b/app/models.py
@@ -0,0 +1,53 @@
+# -*- coding: utf-8 -*-
+import datetime
+
+
+from sqlalchemy import Column, Integer, String, FLOAT, DateTime, ForeignKey, LargeBinary
+from sqlalchemy.sql.functions import now
+from sqlalchemy.orm import relationship
+
+from app.database import Base
+
+KST = datetime.timezone(datetime.timedelta(hours=9))
+
+
+class ModelCore(Base):
+    __tablename__ = "model_core"
+
+    model_name = Column(String, primary_key=True)
+    model_file = Column(LargeBinary, nullable=False)
+
+    model_metadata_relation = relationship(
+        "ModelMetadata", backref="model_core.model_name"
+    )
+
+
+class ModelMetadata(Base):
+    __tablename__ = "model_metadata"
+
+    experiment_name = Column(String, primary_key=True)
+    model_core_name = Column(
+        String, ForeignKey("model_core.model_name"), nullable=False
+    )
+    experimenter = Column(String, nullable=False)
+    version = Column(FLOAT)
+    train_mae = Column(FLOAT, nullable=False)
+    val_mae = Column(FLOAT, nullable=False)
+    train_mse = Column(FLOAT, nullable=False)
+    val_mse = Column(FLOAT, nullable=False)
+    created_at = Column(DateTime(timezone=True), server_default=now())
+
+
+class TempModelData(Base):
+    __tablename__ = "temp_model_data"
+
+    id = Column(Integer, primary_key=True, autoincrement=True)
+    model_name = Column(String, nullable=False)
+    model_file = Column(LargeBinary, nullable=False)
+    experiment_name = Column(String, nullable=False)
+    experimenter = Column(String, nullable=False)
+    version = Column(FLOAT, nullable=False)
+    train_mae = Column(FLOAT, nullable=False)
+    val_mae = Column(FLOAT, nullable=False)
+    train_mse = Column(FLOAT, nullable=False)
+    val_mse = Column(FLOAT, nullable=False)
diff --git a/app/query.py b/app/query.py
new file mode 100644
index 0000000..33f6e9b
--- /dev/null
+++ b/app/query.py
@@ -0,0 +1,102 @@
+UPDATE_TEMP_MODEL_DATA = """
+    DELETE FROM temp_model_data
+    WHERE id NOT IN (
+        SELECT id
+        FROM temp_model_data
+        WHERE experiment_name = '{}'
+        ORDER BY {}
+        LIMIT {}
+    )
+    """
+
+
+SELECT_TEMP_MODEL_BY_EXPR_NAME = """
+    SELECT *
+    FROM temp_model_data
+    WHERE experiment_name = '{}'
+    ORDER BY {};
+    """
+
+
+SELECT_MODEL_METADATA_BY_EXPR_NAME = """
+    SELECT *
+    FROM model_metadata
+    WHERE experiment_name = '{}'
+    """
+
+INSERT_MODEL_CORE = """
+    INSERT INTO model_core (
+        model_name,
+        model_file
+    ) VALUES(
+        '{}',
+        '{}'
+    )
+    """
+
+INSERT_MODEL_METADATA = """
+    INSERT INTO model_metadata (
+        experiment_name,
+        model_core_name,
+        experimenter,
+        version,
+        train_mae,
+        val_mae,
+        train_mse,
+        val_mse
+    ) VALUES (
+        '{}',
+        '{}',
+        '{}',
+        '{}',
+        '{}',
+        '{}',
+        '{}',
+        '{}'
+    )
+"""
+
+UPDATE_MODEL_CORE = """
+    UPDATE model_core
+    SET
+        model_file = '{}'
+    WHERE
+        model_name = '{}'
+    """
+
+UPDATE_MODEL_METADATA = """
+    UPDATE model_metadata
+    SET
+        train_mae = {},
+        val_mae = {},
+        train_mse = {},
+        val_mse = {}
+    WHERE experiment_name = '{}'
+    """
+
+DELETE_ALL_EXPERIMENTS_BY_EXPR_NAME = """
+    DELETE FROM temp_model_data
+    WHERE experiment_name = '{}'
+"""
+
+INSERT_OR_UPDATE_MODEL = """
+UPDATE model_core
+SET model_name='{mn}', model_file='{mf}'
+WHERE model_core.model_name='{mn}';
+INSERT INTO model_core (model_name, model_file)
+SELECT '{mn}', '{mf}'
+WHERE NOT EXISTS (SELECT 1
+                  FROM model_core as mc
+                  WHERE mc.model_name = '{mn}');
+"""
+
+INSERT_OR_UPDATE_SCORE = """
+UPDATE atmos_model_metadata
+SET mae='{score1}', mse='{score2}'
+WHERE atmos_model_metadata.model_name='{mn}';
+INSERT INTO atmos_model_metadata (model_name, experiment_id, mae, mse)
+SELECT '{mn}', '{expr_id}', '{score1}', '{score2}'
+WHERE NOT EXISTS (SELECT 1
+                  FROM atmos_model_metadata as amm
+                  WHERE amm.model_name = '{mn}');
+"""
\ No newline at end of file
diff --git a/app/router/predict.py b/app/router/predict.py
deleted file mode 100644
index f60a49e..0000000
--- a/app/router/predict.py
+++ /dev/null
@@ -1,12 +0,0 @@
-from fastapi import APIRouter
-
-router = APIRouter(
-    prefix="/predict",
-    tags=["predict"],
-    responses={404: {"description": "Not Found"}}
-)
-
-
-@router.get("/")
-def hello_world():
-    return {"message": "Hello predict"}
diff --git a/app/utils.py b/app/utils.py
new file mode 100644
index 0000000..57fa070
--- /dev/null
+++ b/app/utils.py
@@ -0,0 +1,518 @@
+import codecs
+import glob
+import io
+import multiprocessing
+import os
+import pickle
+import re
+import shutil
+import socketserver
+import subprocess
+import time
+import zipfile
+
+
+import tensorflow as tf
+import yaml
+
+from app.database import engine
+from app.query import *
+from logger import L
+
+os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+
+base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
+
+physical_devices = tf.config.list_physical_devices("GPU")
+if physical_devices:
+    tf.config.experimental.set_memory_growth(physical_devices[0], enable=True)
+
+
+class CoreModel:
+    """
+    Loads the ML model used when a predict API call comes in.
+
+    Attributes:
+        model_name(str): Name of the model used for prediction
+        model(obj): Instance variable in which the model is stored
+        query(str): SQL query that loads the model from the database based on the given model name.
+    """
+
+    def __init__(self, model_name):
+        self.model_name = model_name
+        self.model = None
+        self.query = """
+            SELECT model_file
+            FROM model_core
+            WHERE model_name='{}';
+            """.format(
+            self.model_name
+        )
+
+    def load_model(self):
+        """
+        Raises an exception if a subclass does not implement this method.
+        """
+        raise NotImplementedError
+
+    def predict_target(self, target_data):
+        """
+        Runs a prediction with the model that was loaded from the database into the instance variable.
+
+        Args:
+            target_data: Value passed to the predict API call. Its type depends on the model.
+
+        Returns:
+            The predicted value. Its type depends on the model.
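+
+        Examples:
+            A sketch with hypothetical values; the input shape here matches the
+            insurance model, but real shapes depend on the loaded model.
+
+            >>> model.predict_target(np.array([[37, 1, 25.7, 2, 0, 3]]))
+            array([13543.25])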
+ """ + return self.model.predict(target_data) + + +class ScikitLearnModel(CoreModel): + """ + Scikit learn ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ ๊ธฐ๋ฐ˜์˜ ๋ชจ๋ธ์„ ๋ถˆ๋Ÿฌ์˜ค๊ธฐ ์œ„ํ•œ ํด๋ž˜์Šค์ž…๋‹ˆ๋‹ค. + Examples: + >>> sk_model = ScikitLearnModel("my_model") + >>> sk_model.load_model() + >>> sk_model.predict_target(target) + predict result + """ + + def __init__(self, *args): + super().__init__(*args) + + def load_model(self): + """ + ๋ชจ๋ธ์„ ๋ฐ์ดํ„ฐ๋ฒ ์ด์Šค์—์„œ ๋ถˆ๋Ÿฌ์™€ ์ธ์Šคํ„ด์Šค ๋ณ€์ˆ˜์— ์ €์žฅํ•˜๋Š” ํ•จ์ˆ˜ ์ž…๋‹ˆ๋‹ค. ์ƒ์†๋ฐ›์€ ๋ถ€๋ชจํด๋ž˜์Šค์˜ ์ธ์Šคํ„ด์Šค ๋ณ€์ˆ˜๋ฅผ ์ด์šฉํ•˜๋ฉฐ, ๋ฐ˜ํ™˜ ๊ฐ’์€ ์—†์Šต๋‹ˆ๋‹ค. + """ + _model = engine.execute(self.query).fetchone() + if _model is None: + raise ValueError("Model Not Found!") + + self.model = pickle.loads(codecs.decode(_model[0], "base64")) + + +class TensorFlowModel(CoreModel): + """ + Tensorflow ๋ผ์ด๋ธŒ๋Ÿฌ๋ฆฌ ๊ธฐ๋ฐ˜์˜ ๋ชจ๋ธ์„ ๋ถˆ๋Ÿฌ์˜ค๊ธฐ ์œ„ํ•œ ํด๋ž˜์Šค์ž…๋‹ˆ๋‹ค. + Examples: + >>> tf_model = TensorflowModel("my_model") + >>> tf_model.load_model() + >>> tf_model.predict_target(target) + predict result + """ + + def __init__(self, *args): + super().__init__(*args) + + def load_model(self): + """ + ๋ชจ๋ธ์„ ๋ฐ์ดํ„ฐ๋ฒ ์ด์Šค์—์„œ ๋ถˆ๋Ÿฌ์™€ ์ธ์Šคํ„ด์Šค ๋ณ€์ˆ˜์— ์ €์žฅํ•˜๋Š” ํ•จ์ˆ˜ ์ž…๋‹ˆ๋‹ค. ์ƒ์†๋ฐ›์€ ๋ถ€๋ชจํด๋ž˜์Šค์˜ ์ธ์Šคํ„ด์Šค ๋ณ€์ˆ˜๋ฅผ ์ด์šฉํ•˜๋ฉฐ, ๋ฐ˜ํ™˜ ๊ฐ’์€ ์—†์Šต๋‹ˆ๋‹ค. + """ + _model = engine.execute(self.query).fetchone() + if _model is None: + raise ValueError("Model Not Found!") + model_buffer = pickle.loads(codecs.decode(_model[0], "base64")) + model_path = os.path.join(base_dir, "tf_model", self.model_name) + + with zipfile.ZipFile(model_buffer, "r") as bf: + bf.extractall(model_path) + self.model = tf.keras.models.load_model(model_path) + + +my_model = TensorFlowModel("test_model") +my_model.load_model() + + +def write_yml(path, experiment_name, experimenter, model_name, version): + """ + NNI ์‹คํ—˜์„ ์‹œ์ž‘ํ•˜๊ธฐ ์œ„ํ•œ config.ymlํŒŒ์ผ์„ ์ž‘์„ฑํ•˜๋Š” ํ•จ์ˆ˜ ์ž…๋‹ˆ๋‹ค. + + Args: + path(str): ์‹คํ—˜์˜ ๊ฒฝ๋กœ + experiment_name(str): ์‹คํ—˜์˜ ์ด๋ฆ„ + experimenter(str): ์‹คํ—˜์ž์˜ ์ด๋ฆ„ + model_name(str): ๋ชจ๋ธ์˜ ์ด๋ฆ„ + version(float): ๋ฒ„์ „ + + Returns: + ๋ฐ˜ํ™˜ ๊ฐ’์€ ์—†์œผ๋ฉฐ ์ž…๋ ฅ๋ฐ›์€ ๊ฒฝ๋กœ๋กœ ymlํŒŒ์ผ์ด ์ž‘์„ฑ๋ฉ๋‹ˆ๋‹ค. + """ + with open("{}/{}.yml".format(path, model_name), "w") as yml_config_file: + yaml.dump( + { + "authorName": f"{experimenter}", + "experimentName": f"{experiment_name}", + "trialConcurrency": 1, + "maxExecDuration": "1h", + "maxTrialNum": 10, + "trainingServicePlatform": "local", + "searchSpacePath": "search_space.json", + "useAnnotation": False, + "tuner": { + "builtinTunerName": "Anneal", + "classArgs": {"optimize_mode": "minimize"}, + }, + "trial": { + "command": "python trial.py -e {} -n {} -m {} -v {}".format( + experimenter, experiment_name, model_name, version + ), + "codeDir": ".", + }, + }, + yml_config_file, + default_flow_style=False, + ) + + yml_config_file.close() + + return + + +class NniWatcher: + """ + experiment_id๋ฅผ ์ž…๋ ฅ๋ฐ›์•„ ํ•ด๋‹น id๋ฅผ ๊ฐ€์ง„ nni ์‹คํ—˜์„ ๋ชจ๋‹ˆํ„ฐ๋งํ•˜๊ณ  ๋ชจ๋ธ ํŒŒ์ผ์„ ๊ด€๋ฆฌํ•ด์ฃผ๋Š” ํด๋ž˜์Šค์ž…๋‹ˆ๋‹ค. + ์ƒ์„ฑ๋˜๋Š” scikit learn ๋ชจ๋ธ์„ DB์˜ ์ž„์‹œ ํ…Œ์ด๋ธ”์— ์ €์žฅํ•˜์—ฌ ์ฃผ๊ธฐ์ ์œผ๋กœ ์—…๋ฐ์ดํŠธ ํ•ฉ๋‹ˆ๋‹ค. + ์ดํ›„ ์‹คํ—˜์˜ ๋ชจ๋“  ํ”„๋กœ์„ธ์Šค๊ฐ€ ์ข…๋ฃŒ๋˜๋ฉด ๊ฐ€์žฅ ์„ฑ๋Šฅ์ด ์ข‹์€ ๋ชจ๋ธ๊ณผ ์ ์ˆ˜๋ฅผ ์—…๋ฐ์ดํŠธ ํ•ฉ๋‹ˆ๋‹ค. 
+
+    Attributes:
+        experiment_id(str): id generated when the NNI experiment is started
+        experiment_name(str): name of the experiment
+        experimenter(str): name of the experimenter
+        version(str): version of the experiment
+        minute(int): monitoring interval in minutes
+        is_kill(bool, default=True): whether to stop the experiment once monitoring sees it finish
+        top_cnt(int, default=3): maximum number of experiments to keep in the temporary table
+        evaluation_criteria(str, default="val_mae"): metric used to decide whether to update the model
+
+    Examples:
+        >>> watcher = NniWatcher(expr_id, experiment_name, experimenter, version)
+        >>> watcher.execute()
+    """
+
+    def __init__(
+        self,
+        experiment_id,
+        experiment_name,
+        experimenter,
+        version,
+        minute=1,
+        is_kill=True,
+        top_cnt=3,
+        evaluation_criteria="val_mae",
+    ):
+        self.experiment_id = experiment_id
+        self.experiment_name = experiment_name
+        self.experimenter = experimenter
+        self.version = version
+        self.is_kill = is_kill
+        self.top_cnt = top_cnt
+        self.evaluation_criteria = evaluation_criteria
+        self._wait_minute = minute * 60
+        self._experiment_list = None
+        self._running_experiment = None
+
+    def execute(self):
+        """
+        Runs every step in order.
+        """
+        self.watch_process()
+        self.model_final_update()
+
+    def get_running_experiment(self):
+        """
+        Fetches and stores the list of running experiments.
+        """
+        self._experiment_list = subprocess.getoutput("nnictl experiment list")
+        self._running_experiment = [
+            expr
+            for expr in self._experiment_list.split("\n")
+            if self.experiment_id in expr
+        ]
+        L.info(self._running_experiment)
+
+    def watch_process(self):
+        """
+        Checks at the user-defined interval whether the experiment process is still running, and stops the experiment once its status changes to "DONE".
+        Periodically updates the model scores in the DB while waiting.
+        """
+        if self.is_kill:
+            while True:
+                self.get_running_experiment()
+                if self._running_experiment and ("DONE" in self._running_experiment[0]):
+                    _stop_expr = subprocess.getoutput(
+                        "nnictl stop {}".format(self.experiment_id)
+                    )
+                    L.info(_stop_expr)
+                    break
+
+                elif self.experiment_id not in self._experiment_list:
+                    L.error("Experiment ID not in Current Experiment List")
+                    L.info(self._experiment_list)
+                    break
+
+                else:
+                    self.model_keep_update()
+                    time.sleep(self._wait_minute)
+
+    def model_keep_update(self):
+        """
+        Keeps only the top-scoring scikit-learn models of the experiment in the temporary DB table.
+        """
+        engine.execute(
+            UPDATE_TEMP_MODEL_DATA.format(
+                self.experiment_name, self.evaluation_criteria, self.top_cnt
+            )
+        )
+
+    def model_final_update(self):
+        """
+        Runs when the experiment finishes; writes the final score and model file of the best model to the DB.
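+
+        Note:
+            Lower is treated as better for `evaluation_criteria` (e.g. val_mae):
+            the stored model is replaced only when the new final score is smaller
+            than the one already saved.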
+ """ + final_result = engine.execute( + SELECT_TEMP_MODEL_BY_EXPR_NAME.format( + self.experiment_name, self.evaluation_criteria + ) + ).fetchone() + + saved_result = engine.execute( + SELECT_MODEL_METADATA_BY_EXPR_NAME.format(self.experiment_name) + ).fetchone() + + a = pickle.loads(codecs.decode(final_result.model_file, "base64")) + pickled_model = codecs.encode(pickle.dumps(a), "base64").decode() + + if saved_result is None: + engine.execute( + INSERT_MODEL_CORE.format(final_result.model_name, pickled_model) + ) + engine.execute( + INSERT_MODEL_METADATA.format( + self.experiment_name, + final_result.model_name, + self.experimenter, + self.version, + final_result.train_mae, + final_result.val_mae, + final_result.train_mse, + final_result.val_mse, + ) + ) + elif ( + saved_result[self.evaluation_criteria] + > final_result[self.evaluation_criteria] + ): + engine.execute( + UPDATE_MODEL_CORE.format(pickled_model, saved_result.model_name) + ) + engine.execute( + UPDATE_MODEL_METADATA.format( + final_result.train_mae, + final_result.val_mae, + final_result.train_mse, + final_result.val_mse, + self.experiment_name, + ) + ) + + engine.execute(DELETE_ALL_EXPERIMENTS_BY_EXPR_NAME.format(self.experiment_name)) + + +def zip_model(model_path): + """ + ์ž…๋ ฅ๋ฐ›์€ ๋ชจ๋ธ์˜ ๊ฒฝ๋กœ๋ฅผ ์ฐพ์•„๊ฐ€ ๋ชจ๋ธ์„ ์••์ถ•ํ•˜์—ฌ ๋ฉ”๋ชจ๋ฆฌ ๋ฒ„ํผ๋ฅผ ๋ฐ˜ํ™˜ํ•ฉ๋‹ˆ๋‹ค. + + Args: + model_path(str): ๋ชจ๋ธ์ด ์žˆ๋Š” ๊ฒฝ๋กœ์ž…๋‹ˆ๋‹ค. + + Returns: + memory buffer: ๋ชจ๋ธ์„ ์••์ถ•ํ•œ ๋ฉ”๋ชจ๋ฆฌ ๋ฒ„ํผ๋ฅผ ๋ฐ˜ํ™˜ํ•ฉ๋‹ˆ๋‹ค. + + Note: + ๋ชจ๋ธ์„ ๋””์Šคํฌ์— ํŒŒ์ผ๋กœ ์ €์žฅํ•˜์ง€ ์•Š์Šต๋‹ˆ๋‹ค. + """ + model_buffer = io.BytesIO() + + basedir = os.path.basename(model_path) + + with zipfile.ZipFile(model_buffer, "w") as zf: + for root, dirs, files in os.walk(model_path): + + def make_arcname(x): + return os.path.join(root.split(basedir)[-1], x) + + for dr in dirs: + dir_path = os.path.join(root, dr) + zf.write(filename=dir_path, arcname=make_arcname(dr)) + for file in files: + file_path = os.path.join(root, file) + zf.write(filename=file_path, arcname=make_arcname(file)) + + return model_buffer + + +def get_free_port(): + """ + ํ˜ธ์ถœ ์ฆ‰์‹œ ์‚ฌ์šฉ๊ฐ€๋Šฅํ•œ ํฌํŠธ๋ฒˆํ˜ธ๋ฅผ ๋ฐ˜ํ™˜ํ•ฉ๋‹ˆ๋‹ค. + + Returns: + ํ˜„์žฌ ์‚ฌ์šฉ๊ฐ€๋Šฅํ•œ ํฌํŠธ๋ฒˆํ˜ธ + + Examples: + >>> avail_port = get_free_port() # ์‚ฌ์šฉ ๊ฐ€๋Šฅํ•œ ํฌํŠธ, ๊ทธ๋•Œ๊ทธ๋•Œ ๋‹ค๋ฆ„ + >>> print(avail_port) + 45675 + """ + with socketserver.TCPServer(("localhost", 0), None) as s: + free_port = s.server_address[1] + return free_port + + +class ExperimentOwl: + """ + experiment_id๋ฅผ ์ž…๋ ฅ๋ฐ›์•„ ํ•ด๋‹น id๋ฅผ ๊ฐ€์ง„ nni ์‹คํ—˜์„ ๋ชจ๋‹ˆํ„ฐ๋งํ•˜๊ณ  ๋ชจ๋ธ ํŒŒ์ผ์„ ๊ด€๋ฆฌํ•ด์ฃผ๋Š” ํด๋ž˜์Šค์ž…๋‹ˆ๋‹ค. + ํ•„์š”ํ•œ ๊ธฐ๋Šฅ์„ instance.add("method name") ๋ฉ”์„œ๋“œ๋กœ ์ถ”๊ฐ€ํ•˜์—ฌ ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. + + ํ˜„์žฌ ๋ณด์œ ํ•œ ๊ธฐ๋Šฅ + 1. (๊ธฐ๋ณธ)nnictl experiment list(shell command)๋ฅผ ์ฃผ๊ธฐ์ ์œผ๋กœ ํ˜ธ์ถœํ•˜์—ฌ ์‹คํ—˜์ด ํ˜„์žฌ ์ง„ํ–‰์ค‘์ธ์ง€ ํŒŒ์•…ํ•ฉ๋‹ˆ๋‹ค. + ์‹คํ—˜์˜ ์ƒํƒœ๊ฐ€ DONE์œผ๋กœ ๋ณ€๊ฒฝ๋˜๋ฉด ์ตœ๊ณ ์ ์ˆ˜ ๋ชจ๋ธ์„ ๋ฐ์ดํ„ฐ๋ฒ ์ด์Šค์— ์ €์žฅํ•˜๊ณ  nnictl stop experiment_id๋ฅผ ์‹คํ–‰ํ•˜์—ฌ ์‹คํ—˜์„ ์ข…๋ฃŒํ•œ ํ›„ ํ”„๋กœ์„ธ์Šค๊ฐ€ ์ข…๋ฃŒ๋ฉ๋‹ˆ๋‹ค. + + 2. ํŒŒ์ผ๋กœ ์ƒ์„ฑ๋˜๋Š” ๋ชจ๋ธ์ด ๋„ˆ๋ฌด ๋งŽ์•„์ง€์ง€ ์•Š๋„๋ก ์œ ์ง€ํ•ฉ๋‹ˆ๋‹ค.(3๊ฐœ ์ด์ƒ ๋ชจ๋ธ์ด ์ƒ์„ฑ๋˜๋ฉด ์„ฑ๋Šฅ์ˆœ์œผ๋กœ 3์œ„ ๋ฏธ๋งŒ์€ ์‚ญ์ œ) instance ์ƒ์„ฑ ์‹œ + mfile_manage = False๋กœ ๊ธฐ๋Šฅ์„ ์‚ฌ์šฉํ•˜์ง€ ์•Š์„ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค.(default True) + + 3. 
+    3. (method) update_tfmodeldb
+       When an experiment run with TensorFlow finishes, saves the resulting model to the DB, or updates it when the score has improved.
+
+    4. (method) modelfile_cleaner
+       Once every experiment has finished and the model has been saved, deletes all model files left in the temp folder.
+
+    Attributes:
+        experiment_id(str): id generated when the NNI experiment is started
+        experiment_name(str): name of the experiment
+        experiment_path(str): path of the experiment code
+        mfile_manage(bool, default=True): whether to delete files periodically
+        time(int or float, default=5): monitoring interval (minutes)
+
+    Examples:
+        >>> owl = ExperimentOwl(id, name, path)
+        >>> owl.add("update_tfmodeldb")
+        >>> owl.add("modelfile_cleaner")
+        >>> owl.execute()
+    """
+
+    def __init__(
+        self, experiment_id, experiment_name, experiment_path, mfile_manage=True, time=5
+    ):
+        self.__minute = 60
+        self.time = time * self.__minute
+        self.experiment_id = experiment_id
+        self.experiment_name = experiment_name
+        self.experiment_path = experiment_path
+        self.mfile_manage = mfile_manage
+        self.__func_list = [self.main]
+
+    def execute(self):
+        """
+        Runs every method registered with instance.add("method name"), in order.
+        """
+        for func in self.__func_list:
+            func()
+
+    def add(self, func_name):
+        func = getattr(self, func_name)
+        self.__func_list.append(func)
+
+    def main(self):
+        """
+        Default feature that runs whenever an ExperimentOwl instance is executed.
+        Checks at the user-defined interval whether the experiment process is still running, and stops the experiment once its status changes to "DONE".
+        If mfile_manage was left True when the instance was created, keeps the number of model files down by deleting everything below the top 3 by score. (default True)
+        """
+        while True:
+            time.sleep(self.time)
+
+            expr_list = subprocess.getoutput("nnictl experiment list")
+
+            running_expr = [
+                expr for expr in expr_list.split("\n") if self.experiment_id in expr
+            ]
+            L.info(running_expr)
+            if running_expr and ("DONE" in running_expr[0]):
+                stop_expr = subprocess.getoutput(
+                    "nnictl stop {}".format(self.experiment_id)
+                )
+                L.info(stop_expr)
+                break
+
+            elif self.experiment_id not in expr_list:
+                L.info(expr_list)
+                break
+
+            else:
+                if self.mfile_manage:
+                    model_path = os.path.join(
+                        self.experiment_path,
+                        "temp",
+                        "*_{}*".format(self.experiment_name),
+                    )
+                    exprs = glob.glob(model_path)
+                    if len(exprs) > 3:
+                        exprs.sort()
+                        for expr_dir in exprs[3:]:
+                            shutil.rmtree(expr_dir)
+
+    def update_tfmodeldb(self):
+        """
+        When the experiment finishes, saves the model to the DB, or, if a model with the same name already exists, compares scores and updates it.
+ """ + model_path = os.path.join( + self.experiment_path, "temp", "*_{}*".format(self.experiment_name) + ) + exprs = glob.glob(model_path) + if not exprs: + return 0 + + exprs.sort() + exprs = exprs[0] + metrics = os.path.basename(exprs).split("_")[:2] + metrics = [float(metric) for metric in metrics] + + score_sql = """SELECT mae + FROM atmos_model_metadata + WHERE model_name = '{}' + ORDER BY mae;""".format( + self.experiment_name + ) + saved_score = engine.execute(score_sql).fetchone() + + if not saved_score or (metrics[0] < saved_score[0]): + winner_model = os.path.join( + os.path.join(self.experiment_path, "temp", self.experiment_name) + ) + if os.path.exists: + shutil.rmtree(winner_model) + os.rename(exprs, winner_model) + + m_buffer = zip_model(winner_model) + encode_model = codecs.encode(pickle.dumps(m_buffer), "base64").decode() + + engine.execute( + INSERT_OR_UPDATE_MODEL.format(mn=self.experiment_name, mf=encode_model) + ) + engine.execute( + INSERT_OR_UPDATE_SCORE.format( + mn=self.experiment_name, + expr_id=self.experiment_id, + score1=metrics[0], + score2=metrics[1], + ) + ) + L.info("saved model %s %s" % (self.experiment_id, self.experiment_name)) + + def modelfile_cleaner(self): + """ + temp ํด๋”์— ์žˆ๋Š” ๋ชจ๋“  ๋ชจ๋ธํŒŒ์ผ์„ ์‚ญ์ œํ•ฉ๋‹ˆ๋‹ค. + ๊ฐ€์žฅ ๋งˆ์ง€๋ง‰์— ์‹คํ–‰ํ•˜์—ฌ ์ €์žฅ๋˜๊ณ  ๋‚จ์€ ๋ชจ๋ธํŒŒ์ผ๋“ค์„ ์‚ญ์ œํ•˜๋Š” ์šฉ๋„๋กœ ์‚ฌ์šฉํ•  ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค. + """ + model_path = os.path.join(self.experiment_path, "temp", "*") + exprs = glob.glob(model_path) + [shutil.rmtree(_) for _ in exprs] diff --git a/docs/api-list.md b/docs/api-list.md new file mode 100644 index 0000000..bd022a6 --- /dev/null +++ b/docs/api-list.md @@ -0,0 +1,141 @@ +# API List + +API_URL : 127.0.0.1 + +- [API List](#api-list) + - [Train](#train) + - [Insurance](#insuranceํ›ˆ๋ จ) + - [Temperature](#Temperatureํ›ˆ๋ จ) + - [Prediction](#predict) + - [Insurance](#insurance์˜ˆ์ธก) + - [Temperature](#Temperature์˜ˆ์ธก) + + +## Train + +### Insuranceํ›ˆ๋ จ + +#### ์š”์ฒญ + +``` +PUT {{API_URL}}/train/insurance +``` + +| ํŒŒ๋ผ๋ฏธํ„ฐ | ํŒŒ๋ผ๋ฏธํ„ฐ ์œ ํ˜• | ๋ฐ์ดํ„ฐ ํƒ€์ž… | ํ•„์ˆ˜ ์—ฌ๋ถ€ | ์„ค๋ช… | +| --------------- | ------------- | ----------- | --------- | ----------------------- | +| `experiment_name` | `body` | `str` | `(default) exp1` | ํ•™์Šต์ด๋ฆ„ | +| `experimenter` | `body` | `str` | `(default) DongUk` | ์—ฐ๊ตฌ์ž ์ด๋ฆ„ | +| `model_name` | `body` | `str` | `(default) insurance_fee_model` | ํ•™์Šต ๋ชจ๋ธ ์ด๋ฆ„ | +| `version` | `body` | `float` | `(default) 0.1` | ๋ชจ๋ธ ๋ฒ„์ „ | + + +
+
+#### Response
+
+| Key | Type | Description |
+| -------------- | ----------- | ------------- |
+| `msg` | `string` | NNI dashboard info |
+| `error` | `string` | Error details |
+
+
+```jsonc
+{
+  "msg": "Info Message",
+  "error": "Error info"
+}
+```
+
+
+### Temperature train
+
+#### Request
+
+| Parameter | Location | Type | Required | Description |
+| --------------- | ------------- | ----------- | --------- | ----------------------- |
+| `expr_name` | `query` | `string` | ✅ | Experiment name |
+
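+For illustration, a sketch with Python's `requests` (an assumption; any HTTP
+client works). "atmos_tmp_01" is the experiment directory bundled in this repo:
+
+```python
+import requests
+
+resp = requests.put(
+    "http://127.0.0.1:8000/train/atmos",
+    params={"expr_name": "atmos_tmp_01"},
+)
+print(resp.json())  # {"msg": "Successfully started experiment! ...", "error": None}
+```
+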
+
+#### Response
+
+| Key | Type | Description |
+| -------------- | ----------- | ------------- |
+| `msg` | `string` | NNI dashboard info |
+| `error` | `string` | Error details |
+
+
+```jsonc
+{
+  "msg": "Info Message",
+  "error": "Error info"
+}
+```
+
+## Predict
+
+### Insurance prediction
+
+#### Request
+
+```
+PUT {{API_URL}}/predict/insurance
+```
+
+| Parameter | Location | Type | Required | Description |
+| --------------- | ------------- | ----------- | --------- | ----------------------- |
+| `model_name` | `query` | `str` | ✅ | Name of the model to load |
+| `age` | `body` | `int` | ✅ | Age |
+| `sex` | `body` | `int` | ✅ | Sex |
+| `bmi` | `body` | `float` | ✅ | BMI |
+| `children` | `body` | `int` | ✅ | Number of children |
+| `smoker` | `body` | `int` | ✅ | Smoker or not |
+| `region` | `body` | `int` | ✅ | Region of residence |
+
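+For illustration, a sketch with Python's `requests` (an assumption); `model_name`
+is passed as a query parameter while the remaining fields form the JSON body,
+and the field values are hypothetical:
+
+```python
+import requests
+
+resp = requests.put(
+    "http://127.0.0.1:8000/predict/insurance",
+    params={"model_name": "insurance_fee_model"},
+    json={"age": 37, "sex": 1, "bmi": 25.7, "children": 2, "smoker": 0, "region": 3},
+)
+print(resp.json())  # {"result": <predicted fee>, "error": None}
+```
+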
+
+#### Response
+
+| Key | Type | Description |
+| -------------- | ----------- | ------------- |
+| `result` | `float` | Predicted insurance fee |
+| `error` | `string` | Error details |
+
+
+```jsonc
+{
+  "result": 3213.123,
+  "error": "Error info"
+}
+```
+
+
+### Temperature prediction
+
+#### Request
+
+```
+PUT {{API_URL}}/predict/atmos
+```
+
+| Parameter | Location | Type | Required | Description |
+| --------------- | ------------- | ----------- | --------- | ----------------------- |
+| `time_series` | `body` | `List[float]` | ✅ | Hourly temperatures for the past 72 hours |
+
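+For illustration, a sketch with Python's `requests` (an assumption); the body is
+a bare JSON array of exactly 72 floats, and the values are hypothetical:
+
+```python
+import requests
+
+resp = requests.put(
+    "http://127.0.0.1:8000/predict/atmos",
+    json=[20.0] * 72,  # one temperature per hour for the past 72 hours
+)
+print(resp.json())  # {"result": [...24 hourly predictions...], "error": None}
+```
+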
+
+#### Response
+
+| Key | Type | Description |
+| -------------- | ----------- | ------------- |
+| `result` | `List[float]` | Predicted temperatures for the next 24 hours |
+| `error` | `string` | Error details |
+
+
+```jsonc
+{
+  "result": [32.32, 33.32, 34.11...],
+  "error": "Error info"
+}
+```
diff --git a/experiments/atmos_tmp_01/config.yml b/experiments/atmos_tmp_01/config.yml
new file mode 100644
index 0000000..031af06
--- /dev/null
+++ b/experiments/atmos_tmp_01/config.yml
@@ -0,0 +1,19 @@
+experimentName: GRU
+searchSpaceFile: search_space.json
+trialCommand: python train.py
+trialCodeDirectory: .
+trialConcurrency: 1
+maxExperimentDuration: 2h
+maxTrialNumber: 2
+tuner:
+  # choice:
+  # TPE, Anneal, Evolution, SMAC, BatchTuner, GridSearch, Hyperband
+  # NetworkMorphism, MetisTuner, BOHB, GPTuner, PBTTuner, DNGOTuner
+  # SMAC needs to be installed (pip install nni[SMAC])
+  # https://nni.readthedocs.io/en/stable/Tuner/BuiltinTuner.html#Evolution
+  name: Anneal
+  classArgs:
+    optimize_mode: minimize # maximize or minimize
+trainingService:
+  platform: local
+  useActiveGpu: True
\ No newline at end of file
diff --git a/experiments/atmos_tmp_01/preprocessing.py b/experiments/atmos_tmp_01/preprocessing.py
new file mode 100644
index 0000000..981f638
--- /dev/null
+++ b/experiments/atmos_tmp_01/preprocessing.py
@@ -0,0 +1,11 @@
+import pandas as pd
+
+
+def preprocess(data):
+
+    # missing data
+    data = data.fillna(method="ffill")
+
+    # etc.
+
+    return data
diff --git a/experiments/atmos_tmp_01/search_space.json b/experiments/atmos_tmp_01/search_space.json
new file mode 100644
index 0000000..40fdad8
--- /dev/null
+++ b/experiments/atmos_tmp_01/search_space.json
@@ -0,0 +1,4 @@
+{
+    "layer_n": {"_type":"randint", "_value":[2, 3]},
+    "cell": {"_type":"randint", "_value":[24, 30]}
+}
\ No newline at end of file
diff --git a/experiments/atmos_tmp_01/train.py b/experiments/atmos_tmp_01/train.py
new file mode 100644
index 0000000..4a541c0
--- /dev/null
+++ b/experiments/atmos_tmp_01/train.py
@@ -0,0 +1,147 @@
+import os
+import sys
+import time
+from preprocessing import preprocess
+
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+import numpy as np
+import nni
+import pandas as pd
+import tensorflow as tf
+from tensorflow import keras
+from tensorflow.keras.models import Sequential
+from tensorflow.keras.layers import Dense
+from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
+from tensorflow.keras.layers import GRU
+from sklearn.metrics import mean_absolute_error, mean_squared_error
+
+
+from expr_db import connect
+
+physical_devices = tf.config.list_physical_devices("GPU")
+if physical_devices:
+    tf.config.experimental.set_memory_growth(physical_devices[0], enable=True)
+
+
+def make_dataset(data, label, window_size=365, predsize=None):
+    # Slide a window of `window_size` steps over the series; each window becomes one
+    # feature sample, and the following step (or `predsize` steps) becomes its label.
+    feature_list = []
+    label_list = []
+
+    if isinstance(predsize, int):
+        for i in range(len(data) - (window_size + predsize)):
+            feature_list.append(np.array(data.iloc[i : i + window_size]))
+            label_list.append(
+                np.array(label.iloc[i + window_size : i + window_size + predsize])
+            )
+    else:
+        for i in range(len(data) - window_size):
+            feature_list.append(np.array(data.iloc[i : i + window_size]))
+            label_list.append(np.array(label.iloc[i + window_size]))
+
+    return np.array(feature_list), np.array(label_list)
+
+
+def split_data(data_length, ratio):
+    """
+    return index based on ratio
+    --------------------------------------------------
+    example
+
+    >>> split_data(data_length = 20, ratio = [1,2,3])
+    [3, 10]
+    --------------------------------------------------
+    """
+    ratio = np.cumsum(np.array(ratio) / np.sum(ratio))
+
+    idx = []
+    for i in ratio[:-1]:
+        idx.append(round(data_length * i))
+
+    return idx
+
+
+def main(params):
+    con = connect("postgres")
+    data = pd.read_sql("select tmp from atmos_stn108;", con)
+
+    data = preprocess(data)
+
+    train_feature, train_label = make_dataset(data, data, 72, 24)
+
+    idx = split_data(train_feature.shape[0], [6, 3, 1])
+    X_train, X_valid, X_test = (
+        train_feature[: idx[0]],
+        train_feature[idx[0] : idx[1]],
+        train_feature[idx[1] :],
+    )
+    y_train, y_valid, y_test = (
+        train_label[: idx[0]],
+        train_label[idx[0] : idx[1]],
+        train_label[idx[1] :],
+    )
+
+    model = Sequential()
+    for layer in range(params["layer_n"]):
+        if layer == params["layer_n"] - 1:
+            model.add(GRU(params["cell"]))
+        else:
+            model.add(
+                GRU(
+                    params["cell"],
+                    return_sequences=True,
+                    input_shape=[None, train_feature.shape[2]],
+                )
+            )
+    model.add(Dense(24))
+
+    base_dir = os.path.dirname(os.path.abspath(__file__))
+    parent_dir = base_dir.split(os.path.sep)[-1]
+
+    model_path = "./"
+    model.compile(loss="mae", optimizer=keras.optimizers.Adam(lr=0.001))
+    early_stop = EarlyStopping(monitor="val_loss", patience=5)
+    expr_time = time.strftime("%y%m%d_%H%M%S")
+    model_path = os.path.join(model_path, "temp")
+    if not os.path.exists(model_path):
+        os.makedirs(model_path)
+
+    # The experiment start time is a stopgap to tell models apart; with several
+    # workers running at once it can still collide. Needs a better scheme later!
+    filename = os.path.join(model_path, f"{parent_dir}_{expr_time}")
+    print(filename)
+    checkpoint = ModelCheckpoint(
+        filename,
+        monitor="val_loss",
+        verbose=1,
+        save_best_only=True,
+        mode="auto",
+    )
+
+    model.fit(
+        X_train,
+        y_train,
+        epochs=2,
+        batch_size=128,
+        validation_data=(X_valid, y_valid),
+        callbacks=[early_stop, checkpoint],
+    )
+
+    y_true = y_test.reshape(y_test.shape[0], y_test.shape[1])
+    y_hat = model.predict(X_test)
+
+    mae = mean_absolute_error(y_true, y_hat)
+    mse = mean_squared_error(y_true, y_hat)
+
+    # Prefix the saved model directory with its scores so the watcher can rank
+    # models by filename.
+    src_f = os.path.join(model_path, f"{parent_dir}_{expr_time}")
+    dst_f = os.path.join(
+        model_path, f"{mae:.03f}_{mse:.03f}_{parent_dir}_{expr_time}"
+    )
+    os.rename(src_f, dst_f)
+
+    nni.report_final_result(mae)
+
+
+if __name__ == "__main__":
+    params = nni.get_next_parameter()
+    main(params)
diff --git a/experiments/expr_db.py b/experiments/expr_db.py
new file mode 100644
index 0000000..fc6ae08
--- /dev/null
+++ b/experiments/expr_db.py
@@ -0,0 +1,21 @@
+import os
+from dotenv import load_dotenv
+import sqlalchemy
+
+
+def connect(db="postgres"):
+    """Returns a connection and a metadata object"""
+
+    load_dotenv(verbose=True)
+
+    POSTGRES_USER = os.getenv("POSTGRES_USER")
+    POSTGRES_PASSWORD = os.getenv("POSTGRES_PASSWORD")
+    POSTGRES_SERVER = os.getenv("POSTGRES_SERVER")
+    POSTGRES_PORT = os.getenv("POSTGRES_PORT")
+    POSTGRES_DB = db
+
+    url = f"postgresql://{POSTGRES_USER}:{POSTGRES_PASSWORD}@{POSTGRES_SERVER}:{POSTGRES_PORT}/{POSTGRES_DB}"
+
+    connection = sqlalchemy.create_engine(url)
+
+    return connection
diff --git a/experiments/insurance/config.yml b/experiments/insurance/config.yml
new file mode 100644
index 0000000..4a22ca2
--- /dev/null
+++ b/experiments/insurance/config.yml
@@ -0,0 +1,21 @@
+authorName: ehddnr
+experimentName: Lab04
+trialConcurrency: 1
+maxExecDuration: 1h
+maxTrialNum: 10
+#choice: local, remote, pai
+trainingServicePlatform: local
+#nniManagerIp:
+#choice: true, false
+searchSpacePath: search_space.json
+useAnnotation: false
+tuner:
+  #choice: TPE, Random, Anneal, Evolution, BatchTuner, MetisTuner
+  #SMAC (SMAC should be installed through nnictl)
+  builtinTunerName: Anneal
+  classArgs:
+    #choice: maximize, minimize
+    optimize_mode: minimize
+trial:
+  command: python trial.py
+  codeDir: .
\ No newline at end of file
diff --git a/experiments/insurance/query.py b/experiments/insurance/query.py
new file mode 100644
index 0000000..8476f9e
--- /dev/null
+++ b/experiments/insurance/query.py
@@ -0,0 +1,95 @@
+# insert temp
+
+INSERT_TEMP_MODEL = """
+    INSERT INTO temp_model_data (
+        model_name,
+        model_file,
+        experiment_name,
+        experimenter,
+        version,
+        train_mae,
+        val_mae,
+        train_mse,
+        val_mse
+    ) VALUES (
+        {},
+        '{}',
+        {},
+        {},
+        {},
+        {},
+        {},
+        {},
+        {}
+    )
+"""
+
+
+# INSERT
+INSERT_MODEL_CORE = """
+    INSERT INTO model_core (
+        model_name,
+        model_file
+    ) VALUES(
+        %s,
+        '%s'
+    )
+    """
+
+INSERT_MODEL_METADATA = """
+    INSERT INTO model_metadata (
+        experiment_name,
+        model_core_name,
+        experimenter,
+        version,
+        train_mae,
+        val_mae,
+        train_mse,
+        val_mse
+    ) VALUES (
+        %s,
+        %s,
+        %s,
+        %s,
+        %s,
+        %s,
+        %s,
+        %s
+    )
+    """
+
+# UPDATE
+UPDATE_MODEL_METADATA = """
+    UPDATE model_metadata
+    SET
+        train_mae = %s,
+        val_mae = %s,
+        train_mse = %s,
+        val_mse = %s
+    WHERE experiment_name = %s
+    """
+UPDATE_MODEL_CORE = """
+    UPDATE model_core
+    SET
+        model_file = '%s'
+    WHERE
+        model_name = %s
+    """
+
+# pd READ_SQL
+SELECT_ALL_INSURANCE = """
+    SELECT *
+    FROM insurance
+    """
+
+SELECT_VAL_MAE = """
+    SELECT val_mae
+    FROM model_metadata
+    WHERE model_core_name = %s
+    """
+
+SELECT_MODEL_CORE = """
+    SELECT *
+    FROM model_core
+    WHERE model_name = %s
+    """
diff --git a/experiments/insurance/search_space.json b/experiments/insurance/search_space.json
new file mode 100644
index 0000000..774f778
--- /dev/null
+++ b/experiments/insurance/search_space.json
@@ -0,0 +1,6 @@
+{
+    "max_depth": {"_type":"randint", "_value":[2, 10]},
+    "colsample_bytree": {"_type":"uniform", "_value":[0.3, 1.0]},
+    "learning_rate": {"_type":"uniform", "_value":[0.001, 0.1]},
+    "n_estimators": {"_type":"randint", "_value":[100, 200]}
+}
\ No newline at end of file
diff --git a/experiments/insurance/trial.py b/experiments/insurance/trial.py
new file mode 100644
index 0000000..6c61747
--- /dev/null
+++ b/experiments/insurance/trial.py
@@ -0,0 +1,161 @@
+import codecs
+import getopt
+import os
+import pickle
+import sys
+
+
+sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+
+
+import nni
+import numpy as np
+import pandas as pd
+from dotenv import load_dotenv
+from expr_db import connect
+from query import *
+from sklearn.metrics import mean_absolute_error, mean_squared_error
+from sklearn.model_selection import KFold, train_test_split
+from sklearn.preprocessing import LabelEncoder
+from xgboost.sklearn import XGBRegressor
+
+load_dotenv(verbose=True)
+
+
+def preprocess(x_train, x_valid, col_list):
+    """
+    param:
+        x_train : train dataset dataframe
+        x_valid : validation dataset dataframe
+        col_list : columns that require LabelEncoding
+    return:
+        tmp_x_train.values : numpy.ndarray
+        tmp_x_valid.values : numpy.ndarray
+    """
+    tmp_x_train = x_train.copy()
+    tmp_x_valid = x_valid.copy()
+
+    tmp_x_train.reset_index(drop=True, inplace=True)
+    tmp_x_valid.reset_index(drop=True, inplace=True)
+
+    encoder = LabelEncoder()
+
+    for col in col_list:
+        tmp_x_train.loc[:, col] = encoder.fit_transform(tmp_x_train.loc[:, col])
+        tmp_x_valid.loc[:, col] = encoder.transform(tmp_x_valid.loc[:, col])
+
+    return tmp_x_train.values, tmp_x_valid.values
+
+
+def main(params, engine, experiment_info, connection):
+    """
+    param:
+        params: parameters determined by NNI
+        engine: sqlalchemy engine
+        experiment_info: information about the experiment [dict]
+        connection: connection used to communicate with the DB
+    """
+
+    df = pd.read_sql(SELECT_ALL_INSURANCE, engine)
+    experimenter = experiment_info["experimenter"]
+    experiment_name = experiment_info["experiment_name"]
+    model_name = experiment_info["model_name"]
+    version = experiment_info["version"]
+
+    label_col = ["sex", "smoker", "region"]
+
+    y = df.charges.to_frame()
+    x = df.iloc[:, :-1]
+
+    x_train, x_valid, y_train, y_valid = train_test_split(
+        x, y, test_size=0.2, random_state=42
+    )
+
+    kf = KFold(n_splits=5, shuffle=True, random_state=42)
+
+    cv_mse, cv_mae, tr_mse, tr_mae = [], [], [], []
+    fold_mae, fold_model = 1e10, None
+
+    for trn_idx, val_idx in kf.split(x, y):
+        x_train, y_train = x.iloc[trn_idx], y.iloc[trn_idx]
+        x_valid, y_valid = x.iloc[val_idx], y.iloc[val_idx]
+
+        # Preprocessing
+        x_tra, x_val = preprocess(x_train, x_valid, label_col)
+
+        # Define the model and pass the hyperparameters
+        model = XGBRegressor(**params)
+
+        # Train the model with early stopping
+        model.fit(x_tra, y_train, eval_set=[(x_val, y_valid)], early_stopping_rounds=10)
+
+        y_train_pred = model.predict(x_tra)
+        y_valid_pred = model.predict(x_val)
+        # Compute the losses
+        train_mse = mean_squared_error(y_train, y_train_pred)
+        valid_mse = mean_squared_error(y_valid, y_valid_pred)
+        train_mae = mean_absolute_error(y_train, y_train_pred)
+        valid_mae = mean_absolute_error(y_valid, y_valid_pred)
+
+        cv_mse.append(valid_mse)
+        cv_mae.append(valid_mae)
+        tr_mse.append(train_mse)
+        tr_mae.append(train_mae)
+
+        # Keep the model from the fold with the lowest validation MAE
+        new_mae = min(fold_mae, valid_mae)
+        if new_mae != fold_mae:
+            fold_mae = new_mae
+            fold_model = model
+
+    cv_mse_mean = np.mean(cv_mse)
+    cv_mae_mean = np.mean(cv_mae)
+    tr_mse_mean = np.mean(tr_mse)
+    tr_mae_mean = np.mean(tr_mae)
+    pickled_model = codecs.encode(pickle.dumps(fold_model), "base64").decode()
+
+    connection.execute(
+        INSERT_TEMP_MODEL.format(
+            model_name,
+            pickled_model,
+            experiment_name,
+            experimenter,
+            version,
+            tr_mae_mean,
+            cv_mae_mean,
+            tr_mse_mean,
+            cv_mse_mean,
+        )
+    )
+
+    nni.report_final_result(cv_mae_mean)
+    print("Final result is %g" % cv_mae_mean)
+    print("Send final result done.")
+
+
+if __name__ == "__main__":
+    params = nni.get_next_parameter()
+    engine = connect()
+    argv = sys.argv
+    experiment_info = {}
+
+    try:
+        opts, etc_args = getopt.getopt(
+            argv[1:],
+            "e:n:m:v:",
+            ["experimenter=", "experiment_name=", "model_name=", "version="],
+        )
+        for opt, arg in opts:
+            if opt in ("-e", "--experimenter"):
+                experiment_info["experimenter"] = f"'{arg}'"
+            elif opt in ("-n", "--experiment_name"):
+                experiment_info["experiment_name"] = f"'{arg}'"
+            elif opt in ("-m", "--model_name"):
+                experiment_info["model_name"] = f"'{arg}'"
+            elif opt in ("-v", "--version"):
+                experiment_info["version"] = arg
+
+    except getopt.GetoptError:
+        sys.exit(2)
+
+    with engine.connect() as connection:
+        with connection.begin():
+            main(params, engine, experiment_info, connection)
diff --git a/logger.py b/logger.py
new file mode 100644
index 0000000..b3f056d
--- /dev/null
+++ b/logger.py
@@ -0,0 +1,21 @@
+import logging
+import logging.handlers
+from colorlog import ColoredFormatter
+
+
+L = logging.getLogger("snowdeer_log")
+L.setLevel(logging.DEBUG)
+
+formatter = ColoredFormatter(
+    fmt="%(log_color)s [%(levelname)s] %(reset)s %(asctime)s [%(filename)s:%(lineno)d - %(funcName)20s()]\n\t%(message)s",
+    datefmt="%y-%m-%d %H:%M:%S",
+)
+
+fileHandler = logging.FileHandler("./log.txt")
+streamHandler = logging.StreamHandler()
+
+fileHandler.setFormatter(formatter)
+streamHandler.setFormatter(formatter)
+
+L.addHandler(fileHandler)
+L.addHandler(streamHandler)
diff --git a/main.py b/main.py
index 454a233..9e40cab 100644
--- a/main.py
+++ b/main.py
@@ -1,21 +1,32 @@
 from fastapi import FastAPI
+from fastapi.middleware.cors import CORSMiddleware
+import uvicorn

-from app.router import predict
-from app.database import SessionLocal
+
+from app.api.router import predict, train

 app = FastAPI()

-app.include_router(predict.router)
+origins = ["*"]
+app.add_middleware(
+    CORSMiddleware,
+    allow_origins=origins,
+    allow_credentials=True,
+    allow_methods=["*"],
+    allow_headers=["*"],
+)

-def get_db():
-    db = SessionLocal()
-    try:
-        yield db
-    finally:
-        db.close()
+app.include_router(predict.router)
+app.include_router(train.router)


 @app.get("/")
 def hello_world():
     return {"message": "Hello World"}
+
+
+if __name__ == "__main__":
+    uvicorn.run(
+        "main:app", host="0.0.0.0", port=8000, reload=True, reload_dirs=["app/"]
+    )
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000..ea0763c
Binary files /dev/null and b/requirements.txt differ