From a28ee89266aaabd54301b677701e88861a6d6a2a Mon Sep 17 00:00:00 2001
From: ehddnr301 <dy950328@gmail.com>
Date: Fri, 10 Sep 2021 13:17:33 +0900
Subject: [PATCH] Add training API
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

training 하는것을 API로 만들어서 요청을 받습니다.
요청받은 값으로 config.yml 을 만들고 그 파일로 training을 진행합니다.
---
 app/api/router/train.py        | 54 ++++++++++++++++++++++++++++++++++
 app/utils.py                   | 48 ++++++++++++++++++++++++++++--
 experiments/insurance/trial.py | 35 ++++++++++++++++++----
 main.py                        | 14 +++++----
 4 files changed, 137 insertions(+), 14 deletions(-)
 create mode 100644 app/api/router/train.py

diff --git a/app/api/router/train.py b/app/api/router/train.py
new file mode 100644
index 0000000..1be02b5
--- /dev/null
+++ b/app/api/router/train.py
@@ -0,0 +1,54 @@
+# -*- coding: utf-8 -*-
+from app.utils import write_yml
+import subprocess
+
+from fastapi import APIRouter
+
+
+router = APIRouter(
+    prefix="/train",
+    tags=["train"],
+    responses={404: {"description": "Not Found"}}
+)
+
+
+@router.put("/")
+def train_insurance(
+    PORT: int = 8080,
+    experiment_sec: int = 20,
+    experiment_name: str = 'exp1',
+    experimenter: str = 'DongUk',
+    model_name: str = 'insurance_fee_model',
+    version: float = 0.1
+):
+    """
+    Args:
+        PORT (int): PORT to run NNi. Defaults to 8080
+        experiment_sec (int): Express the experiment time in seconds Defaults to 20
+        experiment_name (str): experiment name Defaults to exp1
+        experimeter (str): experimenter (author) Defaults to DongUk
+        model_name (str): model name Defaults to insurance_fee_model
+        version (float): version of experiment Defaults to 0.1
+    Returns:
+        msg: Regardless of success or not, return address values including PORT.
+    """
+    path = 'experiments/insurance/'
+    try:
+        write_yml(
+            path,
+            experiment_name,
+            experimenter,
+            model_name,
+            version
+        )
+        subprocess.Popen(
+            "nnictl create --port {} --config {}/{}.yml && timeout {} && nnictl stop --port {}".format(
+                PORT, path, model_name, experiment_sec, PORT),
+            shell=True,
+        )
+
+    except Exception as e:
+        print('error')
+        print(e)
+
+    return {"msg": f'Check out http://127.0.0.1:{PORT}'}
diff --git a/app/utils.py b/app/utils.py
index 1da705e..460b24e 100644
--- a/app/utils.py
+++ b/app/utils.py
@@ -1,15 +1,20 @@
-from app.database import engine
 import codecs
 import pickle
-import zipfile
 import os
+import yaml
+import zipfile
+
 import tensorflow as tf
 
+from app.database import engine
+
+
 base_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
 
 physical_devices = tf.config.list_physical_devices('GPU')
 tf.config.experimental.set_memory_growth(physical_devices[0], enable=True)
 
+
 class MyModel:
     def __init__(self):
         self._my_model = None
@@ -37,7 +42,6 @@ def load_tf_model(self, model_name):
 
         return tf_model
 
-
     def load_model(self):
         self._my_model = self.load_tf_model('test_model')
 
@@ -48,3 +52,41 @@ def my_model(self):
 
 my_model = MyModel()
 my_model.load_model()
+
+
+def write_yml(
+    path,
+    experiment_name,
+    experimenter,
+    model_name,
+    version
+):
+    print(type(version))
+    with open('{}/{}.yml'.format(path, model_name), 'w') as yml_config_file:
+        yaml.dump({
+            'authorName': f'{experimenter}',
+            'experimentName': f'{experiment_name}',
+            'trialConcurrency': 1,
+            'maxExecDuration': '1h',
+            'maxTrialNum': 10,
+            'trainingServicePlatform': 'local',
+            'searchSpacePath': 'search_space.json',
+            'useAnnotation': False,
+            'tuner': {
+                'builtinTunerName': 'Anneal',
+                'classArgs': {
+                    'optimize_mode': 'minimize'
+                }},
+            'trial': {
+                'command': 'python -e %s -en %s -mn %s -v %f trial.py' % (
+                    experimenter,
+                    experiment_name,
+                    model_name,
+                    version
+                ),
+                'codeDir': '.'
+            }}, yml_config_file, default_flow_style=False)
+
+        yml_config_file.close()
+
+    return
diff --git a/experiments/insurance/trial.py b/experiments/insurance/trial.py
index 81eeef4..b027a3d 100644
--- a/experiments/insurance/trial.py
+++ b/experiments/insurance/trial.py
@@ -1,6 +1,8 @@
 import os
 import codecs
 import pickle
+import sys
+import getopt
 
 from dotenv import load_dotenv
 import pandas as pd
@@ -161,14 +163,37 @@ def main(params, df, engine, experiment_info, connection):
         f'{POSTGRES_SERVER}:{POSTGRES_PORT}/{POSTGRES_DB}'
     engine = create_engine(SQLALCHEMY_DATABASE_URL)
 
+    argv = sys.argv
     experiment_info = {
-        'path': "'C:\\Users\\TFG5076XG\\Documents\\MLOps'",
-        'experimenter': "'DongUk'",
-        'experiment_name': "'insurance0903'",
-        'model_name': "'keep_update_model'",
-        'version': 0.1
+        "experimenter": '',
+        "experiment_name": '',
+        "model_name": '',
+        "version": 0.1
     }
 
+    try:
+        opts, etc_args = getopt.getopt(
+            argv[1:],
+            "e:en:mn:v:",
+            [
+                "experimenter=",
+                "experiment_name=",
+                "model_name=",
+                "version"
+            ])
+        for opt, arg in opts:
+            if opt in ('-e', "--experimenter"):
+                experiment_info['experimenter'] = arg
+            elif opt in ("-en", "--experiment_name"):
+                experiment_info['experiment_name'] = arg
+            elif opt in ("-mn", "--model_name"):
+                experiment_info['model_name'] = arg
+            elif opt in ("-v", "--version"):
+                experiment_info['version'] = arg
+
+    except getopt.GetoptError:
+        sys.exit(2)
+
     df = pd.read_sql(SELECT_ALL_INSURANCE, engine)
 
     with engine.connect() as connection:
diff --git a/main.py b/main.py
index 0adecc5..68d6f21 100644
--- a/main.py
+++ b/main.py
@@ -2,7 +2,7 @@
 from fastapi.middleware.cors import CORSMiddleware
 import uvicorn
 
-from app.api.router import predict
+from app.api.router import predict, train
 
 app = FastAPI()
 
@@ -17,15 +17,17 @@
 )
 
 app.include_router(predict.router)
+app.include_router(train.router)
 
 
 @app.get("/")
 def hello_world():
     return {"message": "Hello World"}
 
+
 if __name__ == "__main__":
-    uvicorn.run("main:app", 
-                host="0.0.0.0", 
-                port=8000, 
-                reload=True, 
-                reload_dirs=['app/'])
\ No newline at end of file
+    uvicorn.run("main:app",
+                host="0.0.0.0",
+                port=8000,
+                reload=True,
+                reload_dirs=['app/'])