-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpytorch_classifier.py
36 lines (27 loc) · 998 Bytes
/
pytorch_classifier.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
from typing import List
import torch
import torch.nn as nn
import torch.nn.functional as F
from classifier import Classifier
# Number of input features per sample expected by the default ``Model``.
PYTORCH_MODEL_INPUT_DIM = 4


class Model(nn.Module):
    """Three-layer fully connected network that outputs class probabilities.

    The attribute names ``layer1``, ``layer2`` and ``layer3`` determine the
    ``state_dict`` keys, so they must not be renamed or previously saved
    checkpoints will no longer load.

    Args:
        input_dim: Number of input features per sample.
        hidden_dim: Width of both hidden layers.
        num_classes: Number of output classes (size of the probability vector).
    """

    def __init__(
        self,
        input_dim=PYTORCH_MODEL_INPUT_DIM,
        hidden_dim=50,
        num_classes=3,
    ):
        super().__init__()
        self.layer1 = nn.Linear(input_dim, hidden_dim)
        self.layer2 = nn.Linear(hidden_dim, hidden_dim)
        self.layer3 = nn.Linear(hidden_dim, num_classes)

    def forward(self, x):
        """Map features to per-class probabilities.

        Args:
            x: Float tensor whose last dimension has size ``input_dim``, e.g. a
                single sample ``(input_dim,)`` or a batch ``(batch, input_dim)``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of ``num_classes``; each slice along it sums to 1.
        """
        x = F.relu(self.layer1(x))
        x = F.relu(self.layer2(x))
        # Normalize over the class axis, which is always the last one. The
        # former dim=1 raised on a single 1-D sample and picked the wrong axis
        # for inputs with more than two dimensions.
        # NOTE(review): the network emits probabilities, not logits; if it is
        # trained with nn.CrossEntropyLoss the softmax is applied twice — confirm.
        return F.softmax(self.layer3(x), dim=-1)
class PytorchClassifier(Classifier):
    """Classifier backed by a ``Model`` whose weights are loaded from disk.

    Args:
        model_path: Path to a saved ``Model`` state dict (presumably written
            with ``torch.save(model.state_dict(), path)`` — verify against the
            training code). The base ``Classifier`` is assumed to store it as
            ``self.model_path``, which is read below.
    """

    def __init__(self, model_path):
        super().__init__(model_path)
        self.model = Model()
        # The model is built on the CPU, so load tensors straight onto it.
        # Without map_location, a checkpoint saved from a GPU would require
        # CUDA at load time even though the weights end up on the CPU anyway.
        # NOTE(review): torch.load unpickles the file — only load trusted
        # checkpoints (consider weights_only=True on torch >= 1.13).
        state_dict = torch.load(self.model_path, map_location="cpu")
        self.model.load_state_dict(state_dict)
        # eval() switches off training-only behavior (dropout, batch-norm
        # statistics updates); harmless for the current layers.
        self.model.eval()

    def predict(self, input_data: List[List[float]]):
        """Compute class probabilities for a batch of samples.

        Args:
            input_data: One row of ``PYTORCH_MODEL_INPUT_DIM`` numbers per sample.

        Returns:
            The result of ``self._prepare_response_dict`` (defined on the base
            ``Classifier``) applied to the per-sample probability lists.
        """
        # torch.tensor always copies the data and rejects a bare int, unlike the
        # legacy torch.Tensor(n), which allocates an uninitialized tensor of
        # size n. float32 matches the model's weight dtype.
        batch = torch.tensor(input_data, dtype=torch.float32)
        with torch.no_grad():
            probas = self.model(batch).tolist()
        return self._prepare_response_dict(probas)