Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[API] update api for multi file and pdf support #1522

Merged
merged 11 commits into from
Apr 11, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
43 changes: 33 additions & 10 deletions api/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,21 @@ with this snippet:
import requests
with open('/path/to/your/img.jpg', 'rb') as f:
data = f.read()
print(requests.post("http://localhost:8080/detection", files={'file': data}).json())
print(requests.post("http://localhost:8080/detection", files={'files': [data]}).json())
```

should yield

```json
[{'box': [0.826171875, 0.185546875, 0.90234375, 0.201171875]},
{'box': [0.75390625, 0.185546875, 0.8173828125, 0.201171875]}]
[
{
"name": "invitation.png",
"boxes": [
[0.50390625, 0.712890625, 0.5185546875, 0.720703125],
[0.4716796875, 0.712890625, 0.48828125, 0.720703125]
]
  }
]
```

#### Text recognition
Expand All @@ -58,13 +65,18 @@ with this snippet:
import requests
with open('/path/to/your/img.jpg', 'rb') as f:
data = f.read()
print(requests.post("http://localhost:8080/recognition", files={'file': data}).json())
print(requests.post("http://localhost:8080/recognition", files={'files': [data]}).json())
```

should yield

```json
{'value': 'invite'}
[
{
"name": "invitation.png",
"value": "invite"
  }
]
```

#### End-to-end OCR
Expand All @@ -78,14 +90,25 @@ with this snippet:
import requests
with open('/path/to/your/img.jpg', 'rb') as f:
data = f.read()
print(requests.post("http://localhost:8080/ocr", files={'file': data}).json())
print(requests.post("http://localhost:8080/ocr", files={'files': [data]}).json())
```

should yield

```json
[{'box': [0.75390625, 0.185546875, 0.8173828125, 0.201171875],
'value': 'Hello'},
{'box': [0.826171875, 0.185546875, 0.90234375, 0.201171875],
'value': 'world!'}]
[
{
"name": "hello_world.jpg",
"items": [
{
"value": "Hello",
"box": [0.005859375, 0.003312938981562763, 0.0205078125, 0.0332854340430202]
},
{
"value": "world!",
"box": [0.005859375, 0.003312938981562763, 0.0205078125, 0.0332854340430202]
felixdittrich92 marked this conversation as resolved.
Show resolved Hide resolved
      }
    ]
  }
]
```
26 changes: 20 additions & 6 deletions api/app/routes/detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,19 +5,33 @@

from typing import List

from fastapi import APIRouter, File, UploadFile, status
from fastapi import APIRouter, File, HTTPException, UploadFile, status

from app.schemas import DetectionOut
from app.vision import det_predictor
from doctr.file_utils import CLASS_NAME
from doctr.io import decode_img_as_tensor
from doctr.io import DocumentFile

router = APIRouter()


@router.post("/", response_model=List[DetectionOut], status_code=status.HTTP_200_OK, summary="Perform text detection")
async def text_detection(files: List[UploadFile] = File(...)):
    """Run the docTR text detection model on one or more uploaded files.

    Args:
        files: uploaded images (JPEG/PNG) or PDF documents.

    Returns:
        One ``DetectionOut`` per uploaded file with relative word boxes.

    Raises:
        HTTPException: 400 when a file has an unsupported MIME type.
    """
    boxes: List[DetectionOut] = []
    for file in files:
        mime_type = file.content_type
        if mime_type in ["image/jpeg", "image/png"]:
            content = DocumentFile.from_images([await file.read()])
        elif mime_type == "application/pdf":
            content = DocumentFile.from_pdf(await file.read())
        else:
            raise HTTPException(status_code=400, detail=f"Unsupported file format for detection endpoint: {mime_type}")

        # NOTE(review): only the first page is used ([0]) — multi-page PDFs
        # return boxes for page 1 only. Drop the last column (objectness score),
        # keeping the 4 box coordinates.
        boxes.append(
            DetectionOut(
                name=file.filename or "",
                boxes=[box.tolist() for box in det_predictor(content)[0][CLASS_NAME][:, :-1]],
            )
        )

    return boxes
52 changes: 36 additions & 16 deletions api/app/routes/kie.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,27 +3,47 @@
# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

from typing import Dict, List
from typing import List

from fastapi import APIRouter, File, UploadFile, status
from fastapi import APIRouter, File, HTTPException, UploadFile, status

from app.schemas import OCROut
from app.schemas import KIEElement, KIEOut
from app.vision import kie_predictor
from doctr.io import decode_img_as_tensor
from doctr.io import DocumentFile

router = APIRouter()


@router.post("/", response_model=Dict[str, List[OCROut]], status_code=status.HTTP_200_OK, summary="Perform KIE")
async def perform_kie(file: UploadFile = File(...)):
@router.post("/", response_model=List[KIEOut], status_code=status.HTTP_200_OK, summary="Perform KIE")
async def perform_kie(files: List[UploadFile] = [File(...)]):

Check warning on line 18 in api/app/routes/kie.py

View check run for this annotation

Codacy Production / Codacy Static Code Analysis

api/app/routes/kie.py#L18

as argument
"""Runs docTR KIE model to analyze the input image"""
img = decode_img_as_tensor(file.file.read())
out = kie_predictor([img])

return {
class_name: [
OCROut(box=(*prediction.geometry[0], *prediction.geometry[1]), value=prediction.value)
for prediction in out.pages[0].predictions[class_name]
]
for class_name in out.pages[0].predictions.keys()
}
results: List[KIEOut] = []
for file in files:
mime_type = file.content_type
if mime_type in ["image/jpeg", "image/png"]:
content = DocumentFile.from_images([await file.read()])
elif mime_type == "application/pdf":
content = DocumentFile.from_pdf(await file.read())
else:
raise HTTPException(status_code=400, detail=f"Unsupported file format for KIE endpoint: {mime_type}")
felixdittrich92 marked this conversation as resolved.
Show resolved Hide resolved

out = kie_predictor(content)

for page in out.pages:
results.append(
KIEOut(
name=file.filename or "",
predictions=[
KIEElement(
class_name=class_name,
items=[
dict(value=prediction.value, box=(*prediction.geometry[0], *prediction.geometry[1]))
for prediction in page.predictions[class_name]
],
)
for class_name in page.predictions.keys()
],
)
)

return results
40 changes: 28 additions & 12 deletions api/app/routes/ocr.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,24 +5,40 @@

from typing import List

from fastapi import APIRouter, File, UploadFile, status
from fastapi import APIRouter, File, HTTPException, UploadFile, status

from app.schemas import OCROut
from app.vision import predictor
from doctr.io import decode_img_as_tensor
from doctr.io import DocumentFile

router = APIRouter()


@router.post("/", response_model=List[OCROut], status_code=status.HTTP_200_OK, summary="Perform OCR")
async def perform_ocr(files: List[UploadFile] = File(...)):
    """Run the docTR end-to-end OCR model on one or more uploaded files.

    Args:
        files: uploaded images (JPEG/PNG) or PDF documents.

    Returns:
        One ``OCROut`` per analyzed page, listing every recognized word
        with its relative bounding box.

    Raises:
        HTTPException: 400 when a file has an unsupported MIME type.
    """
    results: List[OCROut] = []
    for file in files:
        mime_type = file.content_type
        if mime_type in ["image/jpeg", "image/png"]:
            content = DocumentFile.from_images([await file.read()])
        elif mime_type == "application/pdf":
            content = DocumentFile.from_pdf(await file.read())
        else:
            raise HTTPException(status_code=400, detail=f"Unsupported file format for OCR endpoint: {mime_type}")

        out = predictor(content)
        # One result entry per page so multi-page PDFs are fully reported.
        for page in out.pages:
            results.append(
                OCROut(
                    name=file.filename or "",
                    items=[
                        # Flatten ((xmin, ymin), (xmax, ymax)) into a 4-tuple
                        dict(value=word.value, box=(*word.geometry[0], *word.geometry[1]))
                        for block in page.blocks
                        for line in block.lines
                        for word in line.words
                    ],
                )
            )

    return results
28 changes: 21 additions & 7 deletions api/app/routes/recognition.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,32 @@
# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

from fastapi import APIRouter, File, UploadFile, status
from typing import List

from fastapi import APIRouter, File, HTTPException, UploadFile, status

from app.schemas import RecognitionOut
from app.vision import reco_predictor
from doctr.io import decode_img_as_tensor
from doctr.io import DocumentFile

router = APIRouter()


@router.post(
    "/", response_model=List[RecognitionOut], status_code=status.HTTP_200_OK, summary="Perform text recognition"
)
async def text_recognition(files: List[UploadFile] = File(...)):
    """Run the docTR text recognition model on one or more uploaded word crops.

    Args:
        files: uploaded word-crop images (JPEG/PNG). PDFs are rejected because
            recognition operates on pre-cropped words, not full pages.

    Returns:
        One ``RecognitionOut`` per uploaded file with the recognized string.

    Raises:
        HTTPException: 400 when a file has an unsupported MIME type.
    """
    words: List[RecognitionOut] = []
    for file in files:
        mime_type = file.content_type
        if mime_type in ["image/jpeg", "image/png"]:
            content = DocumentFile.from_images([await file.read()])
        else:
            raise HTTPException(
                status_code=400, detail=f"Unsupported file format for recognition endpoint: {mime_type}"
            )

        # reco_predictor returns a list of (value, confidence) pairs; keep the value.
        words.append(RecognitionOut(name=file.filename or "", value=reco_predictor(content)[0][0]))

    return words
28 changes: 22 additions & 6 deletions api/app/schemas.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,35 @@
# This program is licensed under the Apache License 2.0.
# See LICENSE or go to <https://opensource.org/licenses/Apache-2.0> for full license details.

from typing import Tuple
from typing import Dict, List, Tuple, Union

from pydantic import BaseModel, Field


# Recognition output
class RecognitionOut(BaseModel):
    # Source filename the prediction belongs to
    name: str = Field(..., examples=["example.jpg"])
    # Recognized text value for the uploaded word crop
    value: str = Field(..., examples=["Hello"])


class DetectionOut(BaseModel):
    # Source filename the prediction belongs to
    name: str = Field(..., examples=["example.jpg"])
    # Relative (xmin, ymin, xmax, ymax) boxes, one per detected word
    boxes: List[Tuple[float, float, float, float]]


class OCROut(BaseModel):
    # Source filename the prediction belongs to
    name: str = Field(..., examples=["example.jpg"])
    # One entry per word: {"value": <text>, "box": (xmin, ymin, xmax, ymax)}
    items: List[Dict[str, Union[str, Tuple[float, float, float, float]]]] = Field(
        ..., examples=[{"value": "example", "box": [0.0, 0.0, 0.0, 0.0]}]
    )


class KIEElement(BaseModel):
    # Predicted KIE class label
    class_name: str = Field(..., examples=["example"])
    # One entry per prediction: {"value": <text>, "box": (xmin, ymin, xmax, ymax)}
    items: List[Dict[str, Union[str, Tuple[float, float, float, float]]]] = Field(
        ..., examples=[{"value": "example", "box": [0.0, 0.0, 0.0, 0.0]}]
    )


class KIEOut(BaseModel):
    # Source filename the predictions belong to
    name: str = Field(..., examples=["example.jpg"])
    # Per-class predictions for one page
    predictions: List[KIEElement]
4 changes: 2 additions & 2 deletions api/app/vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from doctr.models import kie_predictor, ocr_predictor

# Straight-page mode keeps boxes axis-aligned, matching the 4-float box schema.
predictor = ocr_predictor(pretrained=True, assume_straight_pages=True)
det_predictor = predictor.det_predictor
reco_predictor = predictor.reco_predictor
# NOTE(review): this rebinds the imported factory name `kie_predictor` to the
# built predictor instance — consider a distinct name to avoid shadowing.
kie_predictor = kie_predictor(pretrained=True, assume_straight_pages=True)
2 changes: 1 addition & 1 deletion api/docker-compose.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
version: '3.7'
version: '3.8'

services:
web:
Expand Down
1 change: 0 additions & 1 deletion api/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@ license = "Apache-2.0"

[tool.poetry.dependencies]
python = ">=3.9,<3.12"
tensorflow = ">=2.11.0,<2.16.0" # cf. https://github.com/mindee/doctr/pull/1461
python-doctr = {git = "https://github.com/mindee/doctr.git", extras = ['tf'], branch = "main" }
# Fastapi: minimum version required to avoid pydantic error
# cf. https://github.com/tiangolo/fastapi/issues/4168
Expand Down
7 changes: 7 additions & 0 deletions api/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,13 @@ def mock_detection_image(tmpdir_factory):
return requests.get(url).content


@pytest_asyncio.fixture(scope="session")
def mock_txt_file(tmpdir_factory):
    # Plain-text payload (bytes) used to exercise the endpoints'
    # unsupported-MIME-type path, which must respond with HTTP 400.
    txt_file = tmpdir_factory.mktemp("data").join("mock.txt")
    txt_file.write("mock text")
    # Read back in binary mode so the fixture yields bytes, like a real upload
    return txt_file.read("rb")


@pytest_asyncio.fixture(scope="function")
async def test_app_asyncio():
# for httpx>=20, follow_redirects=True (cf. https://github.com/encode/httpx/releases/tag/0.20.0)
Expand Down
13 changes: 9 additions & 4 deletions api/tests/routes/test_detection.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@


@pytest.mark.asyncio
async def test_text_detection(test_app_asyncio, mock_detection_image):
response = await test_app_asyncio.post("/detection", files={"file": mock_detection_image})
async def test_text_detection(test_app_asyncio, mock_detection_image, mock_txt_file):
response = await test_app_asyncio.post("/detection", files={"files": [mock_detection_image] * 2})
assert response.status_code == 200
json_response = response.json()

Expand All @@ -16,9 +16,14 @@ async def test_text_detection(test_app_asyncio, mock_detection_image):
gt_boxes[:, [1, 3]] = gt_boxes[:, [1, 3]] / 2339

# Check that IoU with GT if reasonable
assert isinstance(json_response, list) and len(json_response) == gt_boxes.shape[0]
pred_boxes = np.array([elt["box"] for elt in json_response])
assert isinstance(json_response, list) and len(json_response) == 2
first_pred = json_response[0]
assert isinstance(first_pred, dict) and len(first_pred["boxes"]) == gt_boxes.shape[0]
pred_boxes = np.array(first_pred["boxes"])
iou_mat = box_iou(gt_boxes, pred_boxes)
gt_idxs, pred_idxs = linear_sum_assignment(-iou_mat)
is_kept = iou_mat[gt_idxs, pred_idxs] >= 0.8
assert gt_idxs[is_kept].shape[0] == gt_boxes.shape[0]
felixdittrich92 marked this conversation as resolved.
Show resolved Hide resolved

response = await test_app_asyncio.post("/detection", files={"files": [mock_txt_file]})
assert response.status_code == 400
Loading
Loading