Skip to content

Commit

Permalink
Merge pull request #10 from State-of-The-MLOps/feature/training_api
Browse files Browse the repository at this point in the history
Add training API
  • Loading branch information
chl8469 authored Sep 10, 2021
2 parents 213b340 + a28ee89 commit 4d0af71
Show file tree
Hide file tree
Showing 4 changed files with 137 additions and 14 deletions.
54 changes: 54 additions & 0 deletions app/api/router/train.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
# -*- coding: utf-8 -*-
from app.utils import write_yml
import subprocess

from fastapi import APIRouter


router = APIRouter(
prefix="/train",
tags=["train"],
responses={404: {"description": "Not Found"}}
)


@router.put("/")
def train_insurance(
PORT: int = 8080,
experiment_sec: int = 20,
experiment_name: str = 'exp1',
experimenter: str = 'DongUk',
model_name: str = 'insurance_fee_model',
version: float = 0.1
):
"""
Args:
PORT (int): PORT to run NNi. Defaults to 8080
experiment_sec (int): Express the experiment time in seconds Defaults to 20
experiment_name (str): experiment name Defaults to exp1
experimeter (str): experimenter (author) Defaults to DongUk
model_name (str): model name Defaults to insurance_fee_model
version (float): version of experiment Defaults to 0.1
Returns:
msg: Regardless of success or not, return address values including PORT.
"""
path = 'experiments/insurance/'
try:
write_yml(
path,
experiment_name,
experimenter,
model_name,
version
)
subprocess.Popen(
"nnictl create --port {} --config {}/{}.yml && timeout {} && nnictl stop --port {}".format(
PORT, path, model_name, experiment_sec, PORT),
shell=True,
)

except Exception as e:
print('error')
print(e)

return {"msg": f'Check out http://127.0.0.1:{PORT}'}
48 changes: 45 additions & 3 deletions app/utils.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,20 @@
from app.database import engine
import codecs
import pickle
import zipfile
import os
import yaml
import zipfile

import tensorflow as tf

from app.database import engine


base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))

physical_devices = tf.config.list_physical_devices('GPU')
tf.config.experimental.set_memory_growth(physical_devices[0], enable=True)


class MyModel:
def __init__(self):
self._my_model = None
Expand Down Expand Up @@ -37,7 +42,6 @@ def load_tf_model(self, model_name):

return tf_model


def load_model(self):
self._my_model = self.load_tf_model('test_model')

Expand All @@ -48,3 +52,41 @@ def my_model(self):

my_model = MyModel()
my_model.load_model()


def write_yml(
path,
experiment_name,
experimenter,
model_name,
version
):
print(type(version))
with open('{}/{}.yml'.format(path, model_name), 'w') as yml_config_file:
yaml.dump({
'authorName': f'{experimenter}',
'experimentName': f'{experiment_name}',
'trialConcurrency': 1,
'maxExecDuration': '1h',
'maxTrialNum': 10,
'trainingServicePlatform': 'local',
'searchSpacePath': 'search_space.json',
'useAnnotation': False,
'tuner': {
'builtinTunerName': 'Anneal',
'classArgs': {
'optimize_mode': 'minimize'
}},
'trial': {
'command': 'python -e %s -en %s -mn %s -v %f trial.py' % (
experimenter,
experiment_name,
model_name,
version
),
'codeDir': '.'
}}, yml_config_file, default_flow_style=False)

yml_config_file.close()

return
35 changes: 30 additions & 5 deletions experiments/insurance/trial.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import os
import codecs
import pickle
import sys
import getopt

from dotenv import load_dotenv
import pandas as pd
Expand Down Expand Up @@ -161,14 +163,37 @@ def main(params, df, engine, experiment_info, connection):
f'{POSTGRES_SERVER}:{POSTGRES_PORT}/{POSTGRES_DB}'
engine = create_engine(SQLALCHEMY_DATABASE_URL)

argv = sys.argv
experiment_info = {
'path': "'C:\\Users\\TFG5076XG\\Documents\\MLOps'",
'experimenter': "'DongUk'",
'experiment_name': "'insurance0903'",
'model_name': "'keep_update_model'",
'version': 0.1
"experimenter": '',
"experiment_name": '',
"model_name": '',
"version": 0.1
}

try:
opts, etc_args = getopt.getopt(
argv[1:],
"e:en:mn:v:",
[
"experimenter=",
"experiment_name=",
"model_name=",
"version"
])
for opt, arg in opts:
if opt in ('-e', "--experimenter"):
experiment_info['experimenter'] = arg
elif opt in ("-en", "--experiment_name"):
experiment_info['experiment_name'] = arg
elif opt in ("-mn", "--model_name"):
experiment_info['model_name'] = arg
elif opt in ("-v", "--version"):
experiment_info['version'] = arg

except getopt.GetoptError:
sys.exit(2)

df = pd.read_sql(SELECT_ALL_INSURANCE, engine)

with engine.connect() as connection:
Expand Down
14 changes: 8 additions & 6 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from fastapi.middleware.cors import CORSMiddleware
import uvicorn

from app.api.router import predict
from app.api.router import predict, train

app = FastAPI()

Expand All @@ -17,15 +17,17 @@
)

app.include_router(predict.router)
app.include_router(train.router)


@app.get("/")
def hello_world():
return {"message": "Hello World"}


if __name__ == "__main__":
uvicorn.run("main:app",
host="0.0.0.0",
port=8000,
reload=True,
reload_dirs=['app/'])
uvicorn.run("main:app",
host="0.0.0.0",
port=8000,
reload=True,
reload_dirs=['app/'])

0 comments on commit 4d0af71

Please sign in to comment.