-
-
Notifications
You must be signed in to change notification settings - Fork 16
/
Copy pathmain.py
72 lines (59 loc) · 2.25 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
from threading import Semaphore
import cv2
import numpy as np
import uvicorn
from fastapi import FastAPI, UploadFile, HTTPException
from loguru import logger
from pydantic import BaseModel, Field
from pydantic_settings import BaseSettings
from ultralytics import YOLO
class Conf(BaseSettings):
    """Service settings: model location, concurrency limit and listening port."""

    # "model_path" begins with pydantic's protected "model_" prefix. Narrow the
    # protected namespaces so declaring the field does not emit a
    # namespace-conflict warning, while still guarding the "settings_" prefix.
    model_config = {"protected_namespaces": ("settings_",)}

    model_path: str = Field(default="yolo-doclaynet.pt", description="Model path")
    # This value sizes the module-level Semaphore: 0 would block every request
    # forever and a negative value is rejected by Semaphore, so require >= 1.
    max_connections: int = Field(
        default=10,
        ge=1,
        description="Maximum number of connections",
    )
    # Restrict to the valid TCP port range (0 lets the OS pick a free port).
    port: int = Field(default=8000, ge=0, le=65535, description="Port number")
# Build the settings object once at import time; everything below reads from it.
conf = Conf()
# Application object that the /api/detect route is registered on.
app = FastAPI()
# Construct the YOLO model from the configured path; the single instance is
# shared by every request handled in this process.
model = YOLO(conf.model_path)
# Caps how many requests may be inside detect()'s `with semaphore:` section
# at the same time.
semaphore = Semaphore(conf.max_connections)
# One detected object: its class label and its bounding box.
# (Documented with comments rather than a class docstring so the generated
# JSON schema stays exactly as before.)
class LabelBox(BaseModel):
    # Passing `example=` directly to Field is a deprecated extra keyword in
    # pydantic v2; `json_schema_extra` produces the same schema entry.
    label: str = Field(
        description="Label of the object",
        json_schema_extra={"example": "Text"},
    )
    # Four floats; detect() fills them as [x1, y1, x2, y2].
    box: list[float] = Field(
        description="Bounding box coordinates",
        json_schema_extra={"example": [0.0, 0.0, 0.0, 0.0]},
    )
# Response body of /api/detect: the detections plus the model's timing info.
# (Documented with comments rather than a class docstring so the generated
# JSON schema stays exactly as before.)
class DetectResponse(BaseModel):
    # Passing `example=` directly to Field is a deprecated extra keyword in
    # pydantic v2; `json_schema_extra` produces the same schema entry.
    label_boxes: list[LabelBox] = Field(
        description="List of detected objects",
        json_schema_extra={
            "example": [{"label": "Text", "box": [0.0, 0.0, 0.0, 0.0]}],
        },
    )
    # Timing dictionary passed through unchanged from the prediction result.
    speed: dict = Field(
        description="Speed in milliseconds",
        json_schema_extra={
            "example": {"preprocess": 0.0, "inference": 0.0, "postprocess": 0.0},
        },
    )
@app.post("/api/detect")
def detect(image: UploadFile) -> DetectResponse:
    """Run the model on an uploaded image and return the labelled boxes.

    Args:
        image: Uploaded file whose raw bytes are decoded with OpenCV.

    Returns:
        A ``DetectResponse`` with one ``LabelBox`` per detection (normalised
        coordinates scaled by the decoded image's width/height) and the
        timing dictionary reported by the prediction result.

    Raises:
        HTTPException: status 400 when the upload is empty or cannot be
            decoded as an image.
    """
    logger.info(f"Received image: {image.filename}, {image.size}")

    # The semaphore bounds how many requests read, decode and run inference
    # at the same time; it is released automatically if a 400 is raised.
    with semaphore:
        payload = image.file.read()
        # cv2.imdecode raises cv2.error on an empty buffer instead of
        # returning None, which would surface as a 500; reject it up front.
        if not payload:
            raise HTTPException(status_code=400, detail="Invalid image")

        frame = cv2.imdecode(np.frombuffer(payload, np.uint8), cv2.IMREAD_COLOR)
        if frame is None:
            raise HTTPException(status_code=400, detail="Invalid image")

        result = model.predict(frame, verbose=False)[0]

    height, width = result.orig_shape[:2]

    # xyxyn rows are normalised [x1, y1, x2, y2]; scale them by the image size.
    label_boxes = [
        LabelBox(
            label=result.names[int(class_id)],
            box=[x1 * width, y1 * height, x2 * width, y2 * height],
        )
        for class_id, (x1, y1, x2, y2) in zip(
            result.boxes.cls.tolist(), result.boxes.xyxyn.tolist()
        )
    ]

    logger.info(
        f"Detected objects: {len(label_boxes)}, Image size: {width}x{height}, Speed: {result.speed}"
    )
    return DetectResponse(label_boxes=label_boxes, speed=result.speed)
if __name__ == "__main__":
    # Script entry point: serve the app on all interfaces at the configured port.
    uvicorn.run(app, port=conf.port, host="0.0.0.0")