-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathapi.py
99 lines (77 loc) · 2.59 KB
/
api.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
from typing import List
import numpy as np
from fastapi import FastAPI, HTTPException
from fastapi.params import Body
from pydantic.dataclasses import dataclass
from sklearn_classifier import SklearnClassifier
from pytorch_classifier import PytorchClassifier, PYTORCH_MODEL_INPUT_DIM
# Paths handed to the classifier constructors below.
SKLEARN_MODEL_PATH = "trained_models/sklearn.model"
PYTORCH_MODEL_PATH = "trained_models/pytorch.model"

# Both classifiers are built once, when this module is imported, and are
# shared by every request handler.
sklearn_classifier = SklearnClassifier(SKLEARN_MODEL_PATH)
pytorch_classifier = PytorchClassifier(PYTORCH_MODEL_PATH)

# Model key -> classifier instance used to produce predictions.
classifiers = dict(sklearn=sklearn_classifier, pytorch=pytorch_classifier)

# Model key -> required length of each input vector (compared against the
# second axis of the submitted data during validation).
dimensions = dict(
    sklearn=sklearn_classifier.model.n_features_in_,
    pytorch=PYTORCH_MODEL_INPUT_DIM,
)

app = FastAPI()
# Request body accepted by the /sklearn and /pytorch endpoints.
@dataclass
class InputData:
    # Batch of input vectors to classify, one inner list of floats per vector.
    sampleData: List[List[float]]
# Request body for the /universal endpoint: the InputData fields plus a
# selector naming which classifier should handle the request.
@dataclass
class InputDataUniversal(InputData):
    # Key of the classifier to use; the /universal handler rejects values
    # that are not keys of `classifiers`.
    model: str
def input_data_is_valid(model: str, input_values: List[List[float]]) -> bool:
    """
    Check that `input_values` is a well-formed batch for `model`.

    Input data is valid if all the below is true:
    - 2 dimensions
    - vector dimension should match model input dimension
    - values of type `float`

    Args:
        model: Key into the module-level `dimensions` mapping.
        input_values: Nested list of vectors to validate.

    Returns:
        True when every condition above holds, False otherwise. Ragged
        input (rows of different lengths) is reported as invalid rather
        than raising.

    Raises:
        KeyError: If `input_values` is 2-D and `model` is not a key of
            `dimensions`.
    """
    try:
        # Convert once instead of rebuilding the array for every check.
        values = np.array(input_values)
    except (ValueError, TypeError):
        # Recent NumPy releases refuse to build an array from ragged nested
        # lists; such input can never be a valid batch.
        return False

    # Checks are ordered so that shape[1] is only read for 2-D arrays.
    return bool(
        values.ndim == 2
        and values.shape[1] == dimensions[model]
        and values.dtype == float
    )
def get_classification(model: str, input_data: InputData):
    """
    Run the classifier registered under `model` on `input_data.sampleData`.

    Raises an HTTPException with status 422 when the data fails validation,
    and with status 500 (carrying the exception text as detail) when the
    classifier lookup or prediction raises.
    """
    # Guard clause: reject malformed input before touching the model.
    if not input_data_is_valid(model, input_data.sampleData):
        raise HTTPException(status_code=422, detail="Input data is invalid")

    try:
        return classifiers[model].predict(input_data.sampleData)
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
# POST /sklearn: classify the posted samples with the "sklearn" classifier.
@app.post("/sklearn")
def get_classification_sklearn(
    input_data: InputData = Body(..., description="Input data"),
):
    """
    Classify sample color using Scikit-Learn model.
    """
    # Validation and error-to-HTTP mapping are handled by get_classification.
    return get_classification("sklearn", input_data)
# POST /pytorch: classify the posted samples with the "pytorch" classifier.
@app.post("/pytorch")
def get_classification_pytorch(
    input_data: InputData = Body(..., description="Input data"),
):
    """
    Classify sample color using Pytorch model.
    """
    # Validation and error-to-HTTP mapping are handled by get_classification.
    return get_classification("pytorch", input_data)
# POST /universal: classify the posted samples with the classifier named in
# the request body's `model` field.
@app.post("/universal")
def get_classification_universal(
    input_data: InputDataUniversal = Body(
        ..., description="Input data for universal endpoint"
    ),
):
    """
    Classify sample color using Scikit-Learn or Pytorch model.
    """
    # Membership on the dict itself tests its keys, so .keys() is not needed.
    # Unknown model names are rejected with 400 before the sample data is
    # validated.
    if input_data.model not in classifiers:
        raise HTTPException(status_code=400, detail="Invalid model")
    return get_classification(input_data.model, input_data)