Skip to content

Commit

Permalink
update api
Browse files Browse the repository at this point in the history
  • Loading branch information
felixdittrich92 committed Mar 25, 2024
1 parent afb9358 commit 6086d74
Show file tree
Hide file tree
Showing 15 changed files with 207 additions and 81 deletions.
43 changes: 33 additions & 10 deletions api/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,21 @@ with this snippet:
import requests
with open('/path/to/your/img.jpg', 'rb') as f:
data = f.read()
print(requests.post("http://localhost:8080/detection", files={'file': data}).json())
print(requests.post("http://localhost:8080/detection", files={'files': [data]}).json())
```

should yield

```json
[{'box': [0.826171875, 0.185546875, 0.90234375, 0.201171875]},
{'box': [0.75390625, 0.185546875, 0.8173828125, 0.201171875]}]
[
{
"name": "invitation.png",
"boxes": [
[0.50390625, 0.712890625, 0.5185546875, 0.720703125],
[0.4716796875, 0.712890625, 0.48828125, 0.720703125]
]
},
]
```

#### Text recognition
Expand All @@ -58,13 +65,18 @@ with this snippet:
import requests
with open('/path/to/your/img.jpg', 'rb') as f:
data = f.read()
print(requests.post("http://localhost:8080/recognition", files={'file': data}).json())
print(requests.post("http://localhost:8080/recognition", files={'files': [data]}).json())
```

should yield

```json
{'value': 'invite'}
[
{
"name": "invitation.png",
"value": "invite"
},
]
```

#### End-to-end OCR
Expand All @@ -78,14 +90,25 @@ with this snippet:
import requests
with open('/path/to/your/img.jpg', 'rb') as f:
data = f.read()
print(requests.post("http://localhost:8080/ocr", files={'file': data}).json())
print(requests.post("http://localhost:8080/ocr", files={'files': [data]}).json())
```

should yield

```json
[{'box': [0.75390625, 0.185546875, 0.8173828125, 0.201171875],
'value': 'Hello'},
{'box': [0.826171875, 0.185546875, 0.90234375, 0.201171875],
'value': 'world!'}]
[
{
"name": "hello_world.jpg",
"items": [
{
"value": "Hello",
"box": [0.005859375, 0.003312938981562763, 0.0205078125, 0.0332854340430202]
},
{
"value": "world!",
"box": [0.005859375, 0.003312938981562763, 0.0205078125, 0.0332854340430202]
},
],
}
]
```
26 changes: 20 additions & 6 deletions api/app/routes/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,33 @@

from typing import List

from fastapi import APIRouter, File, UploadFile, status
from fastapi import APIRouter, File, HTTPException, UploadFile, status

from app.schemas import DetectionOut
from app.vision import det_predictor
from doctr.file_utils import CLASS_NAME
from doctr.io import decode_img_as_tensor
from doctr.io import DocumentFile

router = APIRouter()


@router.post("/", response_model=List[DetectionOut], status_code=status.HTTP_200_OK, summary="Perform text detection")
async def text_detection(file: UploadFile = File(...)):
async def text_detection(files: List[UploadFile] = [File(...)]):

Check warning on line 19 in api/app/routes/detection.py

View check run for this annotation

Codacy Production / Codacy Static Code Analysis

api/app/routes/detection.py#L19

as argument
"""Runs docTR text detection model to analyze the input image"""
img = decode_img_as_tensor(file.file.read())
boxes = det_predictor([img])[0]
return [DetectionOut(box=box.tolist()) for box in boxes[CLASS_NAME][:, :-1]]
boxes: List[DetectionOut] = []
for file in files:
mime_type = file.content_type
if mime_type in ["image/jpeg", "image/png"]:
content = DocumentFile.from_images([await file.read()])
elif mime_type == "application/pdf":
content = DocumentFile.from_pdf(await file.read())
else:
raise HTTPException(status_code=400, detail=f"Unsupported file format for detection endpoint: {mime_type}")

boxes.append(
DetectionOut(
name=file.filename or "", boxes=[box.tolist() for box in det_predictor(content)[0][CLASS_NAME][:, :-1]]
)
)

return boxes
52 changes: 36 additions & 16 deletions api/app/routes/kie.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,27 +3,47 @@
# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

from typing import Dict, List
from typing import List

from fastapi import APIRouter, File, UploadFile, status
from fastapi import APIRouter, File, HTTPException, UploadFile, status

from app.schemas import OCROut
from app.schemas import KIEElement, KIEOut
from app.vision import kie_predictor
from doctr.io import decode_img_as_tensor
from doctr.io import DocumentFile

router = APIRouter()


@router.post("/", response_model=Dict[str, List[OCROut]], status_code=status.HTTP_200_OK, summary="Perform KIE")
async def perform_kie(file: UploadFile = File(...)):
@router.post("/", response_model=List[KIEOut], status_code=status.HTTP_200_OK, summary="Perform KIE")
async def perform_kie(files: List[UploadFile] = [File(...)]):

Check warning on line 18 in api/app/routes/kie.py

View check run for this annotation

Codacy Production / Codacy Static Code Analysis

api/app/routes/kie.py#L18

as argument
"""Runs docTR KIE model to analyze the input image"""
img = decode_img_as_tensor(file.file.read())
out = kie_predictor([img])

return {
class_name: [
OCROut(box=(*prediction.geometry[0], *prediction.geometry[1]), value=prediction.value)
for prediction in out.pages[0].predictions[class_name]
]
for class_name in out.pages[0].predictions.keys()
}
results: List[KIEOut] = []
for file in files:
mime_type = file.content_type
if mime_type in ["image/jpeg", "image/png"]:
content = DocumentFile.from_images([await file.read()])
elif mime_type == "application/pdf":
content = DocumentFile.from_pdf(await file.read())
else:
raise HTTPException(status_code=400, detail=f"Unsupported file format for KIE endpoint: {mime_type}")

out = kie_predictor(content)

for page in out.pages:
results.append(
KIEOut(
name=file.filename or "",
predictions=[
KIEElement(
class_name=class_name,
items=[
dict(value=prediction.value, box=(*prediction.geometry[0], *prediction.geometry[1]))
for prediction in page.predictions[class_name]
],
)
for class_name in page.predictions.keys()
],
)
)

return results
40 changes: 28 additions & 12 deletions api/app/routes/ocr.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,40 @@

from typing import List

from fastapi import APIRouter, File, UploadFile, status
from fastapi import APIRouter, File, HTTPException, UploadFile, status

from app.schemas import OCROut
from app.vision import predictor
from doctr.io import decode_img_as_tensor
from doctr.io import DocumentFile

router = APIRouter()


@router.post("/", response_model=List[OCROut], status_code=status.HTTP_200_OK, summary="Perform OCR")
async def perform_ocr(file: UploadFile = File(...)):
async def perform_ocr(files: List[UploadFile] = [File(...)]):

Check warning on line 18 in api/app/routes/ocr.py

View check run for this annotation

Codacy Production / Codacy Static Code Analysis

api/app/routes/ocr.py#L18

as argument
"""Runs docTR OCR model to analyze the input image"""
img = decode_img_as_tensor(file.file.read())
out = predictor([img])

return [
OCROut(box=(*word.geometry[0], *word.geometry[1]), value=word.value)
for block in out.pages[0].blocks
for line in block.lines
for word in line.words
]
results: List[OCROut] = []
for file in files:
mime_type = file.content_type
if mime_type in ["image/jpeg", "image/png"]:
content = DocumentFile.from_images([await file.read()])
elif mime_type == "application/pdf":
content = DocumentFile.from_pdf(await file.read())
else:
raise HTTPException(status_code=400, detail=f"Unsupported file format for OCR endpoint: {mime_type}")

out = predictor(content)
for page in out.pages:
results.append(
OCROut(
name=file.filename or "",
items=[
dict(value=word.value, box=(*word.geometry[0], *word.geometry[1]))
for block in page.blocks
for line in block.lines
for word in line.words
],
)
)

return results
28 changes: 21 additions & 7 deletions api/app/routes/recognition.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,32 @@
# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

from fastapi import APIRouter, File, UploadFile, status
from typing import List

from fastapi import APIRouter, File, HTTPException, UploadFile, status

from app.schemas import RecognitionOut
from app.vision import reco_predictor
from doctr.io import decode_img_as_tensor
from doctr.io import DocumentFile

router = APIRouter()


@router.post("/", response_model=RecognitionOut, status_code=status.HTTP_200_OK, summary="Perform text recognition")
async def text_recognition(file: UploadFile = File(...)):
@router.post(
"/", response_model=List[RecognitionOut], status_code=status.HTTP_200_OK, summary="Perform text recognition"
)
async def text_recognition(files: List[UploadFile] = [File(...)]):

Check warning on line 20 in api/app/routes/recognition.py

View check run for this annotation

Codacy Production / Codacy Static Code Analysis

api/app/routes/recognition.py#L20

as argument
"""Runs docTR text recognition model to analyze the input image"""
img = decode_img_as_tensor(file.file.read())
out = reco_predictor([img])
return RecognitionOut(value=out[0][0])
words: List[RecognitionOut] = []
for file in files:
mime_type = file.content_type
if mime_type in ["image/jpeg", "image/png"]:
content = DocumentFile.from_images([await file.read()])
else:
raise HTTPException(
status_code=400, detail=f"Unsupported file format for recognition endpoint: {mime_type}"
)

words.append(RecognitionOut(name=file.filename or "", value=reco_predictor(content)[0][0]))

return words
28 changes: 22 additions & 6 deletions api/app/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,35 @@
# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

from typing import Tuple
from typing import Dict, List, Tuple, Union

from pydantic import BaseModel, Field


# Recognition output
class RecognitionOut(BaseModel):
value: str = Field(..., example="Hello")
name: str = Field(..., examples=["example.jpg"])
value: str = Field(..., examples=["Hello"])


class DetectionOut(BaseModel):
box: Tuple[float, float, float, float]
name: str = Field(..., examples=["example.jpg"])
boxes: List[Tuple[float, float, float, float]]


class OCROut(RecognitionOut, DetectionOut):
pass
class OCROut(BaseModel):
name: str = Field(..., examples=["example.jpg"])
items: List[Dict[str, Union[str, Tuple[float, float, float, float]]]] = Field(
..., examples=[{"value": "example", "box": [0.0, 0.0, 0.0, 0.0]}]
)


class KIEElement(BaseModel):
class_name: str = Field(..., examples=["example"])
items: List[Dict[str, Union[str, Tuple[float, float, float, float]]]] = Field(
..., examples=[{"value": "example", "box": [0.0, 0.0, 0.0, 0.0]}]
)


class KIEOut(BaseModel):
name: str = Field(..., examples=["example.jpg"])
predictions: List[KIEElement]
4 changes: 2 additions & 2 deletions api/app/vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from doctr.models import kie_predictor, ocr_predictor

predictor = ocr_predictor(pretrained=True)
predictor = ocr_predictor(pretrained=True, assume_straight_pages=True)
det_predictor = predictor.det_predictor
reco_predictor = predictor.reco_predictor
kie_predictor = kie_predictor(pretrained=True)
kie_predictor = kie_predictor(pretrained=True, assume_straight_pages=True)
2 changes: 1 addition & 1 deletion api/docker-compose.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
version: '3.7'
version: '3.8'

services:
web:
Expand Down
1 change: 0 additions & 1 deletion api/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ license = "Apache-2.0"

[tool.poetry.dependencies]
python = ">=3.8.2,<3.11" # pypdfium2 needs a python version above 3.8.2
tensorflow = ">=2.11.0,<2.16.0" # cf. https://github.com/mindee/doctr/pull/1461
python-doctr = {git = "https://github.com/mindee/doctr.git", extras = ['tf'], branch = "main" }
# Fastapi: minimum version required to avoid pydantic error
# cf. https://github.com/tiangolo/fastapi/issues/4168
Expand Down
7 changes: 7 additions & 0 deletions api/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,13 @@ def mock_detection_image(tmpdir_factory):
return requests.get(url).content


@pytest_asyncio.fixture(scope="session")
def mock_txt_file(tmpdir_factory):
txt_file = tmpdir_factory.mktemp("data").join("mock.txt")
txt_file.write("mock text")
return txt_file.read("rb")


@pytest_asyncio.fixture(scope="function")
async def test_app_asyncio():
# for httpx>=20, follow_redirects=True (cf. https://github.com/encode/httpx/releases/tag/0.20.0)
Expand Down
13 changes: 9 additions & 4 deletions api/tests/routes/test_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@


@pytest.mark.asyncio
async def test_text_detection(test_app_asyncio, mock_detection_image):
response = await test_app_asyncio.post("/detection", files={"file": mock_detection_image})
async def test_text_detection(test_app_asyncio, mock_detection_image, mock_txt_file):
response = await test_app_asyncio.post("/detection", files={"files": [mock_detection_image] * 2})
assert response.status_code == 200
json_response = response.json()

Expand All @@ -16,9 +16,14 @@ async def test_text_detection(test_app_asyncio, mock_detection_image):
gt_boxes[:, [1, 3]] = gt_boxes[:, [1, 3]] / 2339

# Check that IoU with GT if reasonable
assert isinstance(json_response, list) and len(json_response) == gt_boxes.shape[0]
pred_boxes = np.array([elt["box"] for elt in json_response])
assert isinstance(json_response, list) and len(json_response) == 2
first_pred = json_response[0]
assert isinstance(first_pred, dict) and len(first_pred["boxes"]) == gt_boxes.shape[0]
pred_boxes = np.array(first_pred["boxes"])
iou_mat = box_iou(gt_boxes, pred_boxes)
gt_idxs, pred_idxs = linear_sum_assignment(-iou_mat)
is_kept = iou_mat[gt_idxs, pred_idxs] >= 0.8
assert gt_idxs[is_kept].shape[0] == gt_boxes.shape[0]

response = await test_app_asyncio.post("/detection", files={"files": [mock_txt_file]})
assert response.status_code == 400
Loading

0 comments on commit 6086d74

Please sign in to comment.