Skip to content

Commit

Permalink
Merge pull request #25 from State-of-The-MLOps/feature/atmos_train
Browse files Browse the repository at this point in the history
Feature/atmos train
  • Loading branch information
chl8469 authored Oct 6, 2021
2 parents 02ba16d + c794eff commit 0f3f4fa
Show file tree
Hide file tree
Showing 4 changed files with 127 additions and 118 deletions.
8 changes: 4 additions & 4 deletions app/api/router/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def sync_call(info, model_name):

except Exception as e:
L.error(e)
return {"error": str(e)}
return {"result": "Can't predict", "error": str(e)}


@router.put("/atmos")
Expand All @@ -71,7 +71,7 @@ async def predict_temperature(time_series: List[float]):
"""
if len(time_series) != 72:
L.error(f"input time_series: {time_series} is not valid")
return "time series must have 72 values"
return {"result": "time series must have 72 values", "error": None}

def sync_pred_ts(time_series):
"""
Expand All @@ -83,12 +83,12 @@ def sync_pred_ts(time_series):
f"Predict Args info: {time_series.flatten().tolist()}\n\tmodel_name: {my_model.model_name}\n\tPrediction Result: {result.tolist()[0]}"
)

return result
return {"result": result, "error": None}

try:
result = await run_in_threadpool(sync_pred_ts, time_series)
return result.tolist()

except Exception as e:
L.error(e)
return {"error": str(e)}
return {"result": "Can't predict", "error": str(e)}
20 changes: 9 additions & 11 deletions app/api/router/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from fastapi import APIRouter

from app.utils import NniWatcher, ExprimentOwl, base_dir, check_expr_over, get_free_port, write_yml
from app.utils import NniWatcher, ExperimentOwl, base_dir, get_free_port, write_yml
from logger import L

router = APIRouter(
Expand Down Expand Up @@ -56,11 +56,11 @@ def train_insurance(
m_process.start()

L.info(nni_create_result)
return nni_create_result
return {"msg": nni_create_result, "error": None}

except Exception as e:
L.error(e)
return {"error": str(e)}
return {"msg": "Can't start experiment", "error": str(e)}


@router.put("/atmos")
Expand Down Expand Up @@ -90,22 +90,20 @@ def train_atmos(expr_name: str):
if sucs_msg in nni_create_result:
p = re.compile(r"The experiment id is ([a-zA-Z0-9]+)\n")
expr_id = p.findall(nni_create_result)[0]
check_expr = ExprimentOwl(expr_id, expr_name, expr_path)
check_expr = ExperimentOwl(expr_id, expr_name, expr_path)
check_expr.add("update_tfmodeldb")
check_expr.add("modelfile_cleaner")

m_process = multiprocessing.Process(
target=check_expr.execute
)

m_process = multiprocessing.Process(target=check_expr.execute)
m_process.start()

L.info(nni_create_result)
return nni_create_result
return {"msg": nni_create_result, "error": None}

else:
L.error(nni_create_result)
return {"error": nni_create_result}
return {"msg": nni_create_result, "error": None}

except Exception as e:
L.error(e)
return {"error": str(e)}
return {"msg": "Can't start experiment", "error": str(e)}
Loading

0 comments on commit 0f3f4fa

Please sign in to comment.