-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathfastapi_server.py
82 lines (62 loc) · 2.48 KB
/
fastapi_server.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
'''
FastAPI server code for the CosmoCuisine iOS app.
This server has a single purpose: accept .wav files from the client and classify the spoken language using the
ECAPA-TDNN model pretrained on the VoxLingua107 dataset (https://huggingface.co/TalTechNLP/voxlingua107-epaca-tdnn)
'''
# Standard library imports
import contextlib
import os
import tempfile
# FastAPI imports
from fastapi import FastAPI, HTTPException, UploadFile, File
# Machine Learning Imports
from speechbrain.inference import EncoderClassifier
# Define some elements in the API
@contextlib.asynccontextmanager
async def custom_lifespan(app: FastAPI):
    """
    Load the language-ID classifier once at startup and keep it for the app's lifetime.

    FastAPI/Starlette expect ``lifespan`` to be an async context manager factory.
    Passing a bare async generator function is deprecated in Starlette, so the
    generator is wrapped with ``contextlib.asynccontextmanager``.

    Args:
        app: The FastAPI application. The classifier is stored on it as
            ``app.language_id`` so that request handlers can reach it.
    """
    # Set up the ECAPA-TDNN classifier pre-trained on VoxLingua107.
    # Both the hparams source and the save directory are the local
    # ./Voxlingua107-ECAPA-TDNN folder.
    app.language_id = EncoderClassifier.from_hparams(
        source="./Voxlingua107-ECAPA-TDNN",
        savedir="./Voxlingua107-ECAPA-TDNN",
    )
    # Code before the yield runs at startup; code after it would run at shutdown.
    # There is nothing to release here.
    yield
# Build the application object; the model itself is loaded later by the lifespan hook
app = FastAPI(
    lifespan=custom_lifespan,
    title="CosmoCuisine",
    summary="Find out nutrition facts about packaged goods using audio, vision, and ML!",
)
#========================================
# Data store objects from pydantic
#----------------------------------------
'''See pydantic_models.py
'''
#===========================================
# Machine Learning methods (SpeechBrain)
#-------------------------------------------
# These endpoints expose the SpeechBrain language-identification model through the REST server.
def _classify_wav_bytes(audio_content: bytes):
    """
    Run the language-ID classifier on raw .wav bytes and return the predicted label.

    ``EncoderClassifier.load_audio`` takes a file path rather than bytes. The
    bytes are therefore written to a uniquely named temporary file, which is
    always removed afterwards, even if loading or classification raises.
    """
    # A unique file per request (instead of a shared "./recording.wav") means
    # separate worker processes cannot overwrite each other's audio.
    # It also means the upload does not linger on disk after the request.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        tmp.write(audio_content)
        temp_filepath = tmp.name
    try:
        # Classify language using a path to the wave file, per load_audio expectations
        signal = app.language_id.load_audio(temp_filepath)
        prediction = app.language_id.classify_batch(signal)
        # Per SpeechBrain docs, index 3 of classify_batch's result holds the text
        # labels. Exactly one signal was passed, so take element 0.
        return prediction[3][0]
    finally:
        # Best-effort cleanup: a failed delete must not mask the real result or error
        with contextlib.suppress(OSError):
            os.remove(temp_filepath)


@app.post(
    "/predict",
    response_description="Predict Label from Datapoint",
)
async def predict_datapoint(
    wav_file: UploadFile = File(...)
):
    """
    Classify the spoken language of an uploaded .wav recording.

    The audio is expected as the ``wav_file`` field of a multipart/form-data request.

    Returns:
        ``{"prediction": <label>}``, where the label is ``classify_batch(...)[3][0]``.

    Raises:
        HTTPException(400): The upload could not be read, or it was empty.
        HTTPException(500): Writing the temporary file, loading the audio,
            or running the classifier failed.
    """
    try:
        try:
            # Read .wav file uploaded as part of multipart/form-data request
            audio_content = await wav_file.read()
        except Exception as e:
            raise HTTPException(status_code=400, detail=f"Error reading wav file: {e}") from e
        if not audio_content:
            # An empty upload is a client error, not a model failure
            raise HTTPException(status_code=400, detail="Error reading wav file: uploaded file is empty")
        try:
            # NOTE(review): inference is synchronous and runs on the event-loop
            # thread, so it blocks other requests on this worker while it runs.
            predicted_label = _classify_wav_bytes(audio_content)
        except Exception as e:
            raise HTTPException(status_code=500, detail=f"Error predicting with ML model: {e}") from e
        return {"prediction": predicted_label}
    finally:
        await wav_file.close()