From 2e1c6ddd676227d1cbc4cff9771b20595259ba38 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Chlo=C3=A9=20Daems?= <73901882+chloedia@users.noreply.github.com> Date: Wed, 8 Jan 2025 10:35:14 +0100 Subject: [PATCH] fix: Add EngineConfig & StrategyHandler (#211) * fix: Add EngineConfig & StrategyHandler * fix: Configs settings --- libs/megaparse/src/megaparse/configs/auto.py | 47 ++++ .../src/megaparse/examples/parse_file.py | 2 +- libs/megaparse/src/megaparse/megaparse.py | 37 +-- .../src/megaparse/parser/doctr_parser.py | 39 +-- .../src/megaparse/parser/strategy.py | 235 +++++++++++------- .../predictor/doctr_layout_detector.py | 9 +- libs/megaparse/tests/pdf/test_detect_ocr.py | 14 +- .../tests/pdf/test_pdf_processing.py | 13 +- libs/megaparse_sdk/megaparse_sdk/config.py | 15 -- 9 files changed, 258 insertions(+), 153 deletions(-) create mode 100644 libs/megaparse/src/megaparse/configs/auto.py diff --git a/libs/megaparse/src/megaparse/configs/auto.py b/libs/megaparse/src/megaparse/configs/auto.py new file mode 100644 index 0000000..c0034c1 --- /dev/null +++ b/libs/megaparse/src/megaparse/configs/auto.py @@ -0,0 +1,47 @@ +from enum import Enum + +from pydantic import BaseModel +from pydantic_settings import BaseSettings, SettingsConfigDict + + +class TextDetConfig(BaseModel): + det_arch: str = "fast_base" + batch_size: int = 2 + assume_straight_pages: bool = True + preserve_aspect_ratio: bool = True + symmetric_pad: bool = True + load_in_8_bit: bool = False + + +class AutoStrategyConfig(BaseModel): + auto_page_threshold: float = 0.6 + auto_document_threshold: float = 0.2 + + +class TextRecoConfig(BaseModel): + reco_arch: str = "crnn_vgg16_bn" + batch_size: int = 512 + + +class DeviceEnum(str, Enum): + CPU = "cpu" + CUDA = "cuda" + COREML = "coreml" + + +class MegaParseConfig(BaseSettings): + """ + Configuration for Megaparse. + """ + + model_config = SettingsConfigDict( + env_prefix="MEGAPARSE_", + env_file=(".env.local", ".env"), + env_nested_delimiter="__", + extra="ignore", + use_enum_values=True, + ) + text_det_config: TextDetConfig = TextDetConfig() + text_reco_config: TextRecoConfig = TextRecoConfig() + auto_parse_config: AutoStrategyConfig = AutoStrategyConfig() + device: DeviceEnum = DeviceEnum.CPU diff --git a/libs/megaparse/src/megaparse/examples/parse_file.py b/libs/megaparse/src/megaparse/examples/parse_file.py index b6aa06e..f5cd8bc 100644 --- a/libs/megaparse/src/megaparse/examples/parse_file.py +++ b/libs/megaparse/src/megaparse/examples/parse_file.py @@ -7,7 +7,7 @@ def main(): parser = UnstructuredParser() megaparse = MegaParse(parser=parser) - file_path = "./tests/pdf/ocr/0168126.pdf" + file_path = "./tests/pdf/native/0168029.pdf" parsed_file = megaparse.load(file_path) print(f"\n----- File Response : {file_path} -----\n") diff --git a/libs/megaparse/src/megaparse/megaparse.py b/libs/megaparse/src/megaparse/megaparse.py index f7fa70c..b4580a0 100644 --- a/libs/megaparse/src/megaparse/megaparse.py +++ b/libs/megaparse/src/megaparse/megaparse.py @@ -4,7 +4,7 @@ from pathlib import Path from typing import IO, BinaryIO -from megaparse_sdk.config import MegaParseConfig +from megaparse.configs.auto import DeviceEnum, MegaParseConfig from megaparse_sdk.schema.extensions import FileExtension from megaparse_sdk.schema.parser_config import StrategyEnum @@ -12,28 +12,43 @@ from megaparse.exceptions.base import ParsingException from megaparse.parser.base import BaseParser from megaparse.parser.doctr_parser import DoctrParser -from megaparse.parser.strategy import determine_strategy +from megaparse.parser.strategy import StrategyHandler from megaparse.parser.unstructured_parser import UnstructuredParser logger = logging.getLogger("megaparse") class MegaParse: - config: MegaParseConfig = MegaParseConfig() + config = MegaParseConfig() def __init__( self, - parser: BaseParser = UnstructuredParser(strategy=StrategyEnum.FAST), - ocr_parser: BaseParser = DoctrParser(), + parser: BaseParser | None = None, + ocr_parser: BaseParser | None = None, strategy: StrategyEnum = StrategyEnum.AUTO, format_checker: FormatChecker | None = None, ) -> None: + if not parser: + parser = UnstructuredParser(strategy=StrategyEnum.FAST) + if not ocr_parser: + ocr_parser = DoctrParser( + text_det_config=self.config.text_det_config, + text_reco_config=self.config.text_reco_config, + device=self.config.device, + ) + self.strategy = strategy self.parser = parser self.ocr_parser = ocr_parser self.format_checker = format_checker self.last_parsed_document: str = "" + self.strategy_handler = StrategyHandler( + text_det_config=self.config.text_det_config, + auto_config=self.config.auto_parse_config, + device=self.config.device, + ) + def validate_input( self, file_path: Path | str | None = None, @@ -132,17 +147,11 @@ def _select_parser( if self.strategy != StrategyEnum.AUTO or file_extension != FileExtension.PDF: return self.parser if file: - local_strategy = determine_strategy( - file=file, - threshold_pages_ocr=self.config.auto_document_threshold, - threshold_per_page=self.config.auto_page_threshold, + local_strategy = self.strategy_handler.determine_strategy( + file=file, # type: ignore #FIXME: Careful here on removing BinaryIO (not handled by onnxtr) ) if file_path: - local_strategy = determine_strategy( - file=file_path, - threshold_pages_ocr=self.config.auto_document_threshold, - threshold_per_page=self.config.auto_page_threshold, - ) + local_strategy = self.strategy_handler.determine_strategy(file=file_path) if local_strategy == StrategyEnum.HI_RES: return self.ocr_parser diff --git a/libs/megaparse/src/megaparse/parser/doctr_parser.py b/libs/megaparse/src/megaparse/parser/doctr_parser.py index c009732..38efe08 100644 --- a/libs/megaparse/src/megaparse/parser/doctr_parser.py +++ b/libs/megaparse/src/megaparse/parser/doctr_parser.py @@ -3,6 +3,7 @@ from pathlib import Path from typing import IO, BinaryIO, List +from megaparse.configs.auto import DeviceEnum, TextRecoConfig, TextDetConfig import onnxruntime as rt from megaparse_sdk.schema.extensions import FileExtension from onnxtr.io import DocumentFile @@ -19,16 +20,13 @@ class DoctrParser(BaseParser): def __init__( self, - det_predictor_model: str = "db_resnet50", - reco_predictor_model: str = "crnn_vgg16_bn", - det_bs: int = 2, - reco_bs: int = 512, - assume_straight_pages: bool = True, + text_det_config: TextDetConfig = TextDetConfig(), + text_reco_config: TextRecoConfig = TextRecoConfig(), + device: DeviceEnum = DeviceEnum.CPU, straighten_pages: bool = False, - use_gpu: bool = False, **kwargs, ): - self.use_gpu = use_gpu + self.device = device general_options = rt.SessionOptions() providers = self._get_providers() engine_config = EngineConfig( @@ -37,11 +35,11 @@ def __init__( ) # TODO: set in config or pass as kwargs self.predictor = ocr_predictor( - det_arch=det_predictor_model, - reco_arch=reco_predictor_model, - det_bs=det_bs, - reco_bs=reco_bs, - assume_straight_pages=assume_straight_pages, + det_arch=text_det_config.det_arch, + reco_arch=text_reco_config.reco_arch, + det_bs=text_det_config.batch_size, + reco_bs=text_reco_config.batch_size, + assume_straight_pages=text_det_config.assume_straight_pages, straighten_pages=straighten_pages, # Preprocessing related parameters det_engine_cfg=engine_config, @@ -53,14 +51,27 @@ def __init__( def _get_providers(self) -> List[str]: prov = rt.get_available_providers() logger.info("Available providers:", prov) - if self.use_gpu: + if self.device == DeviceEnum.CUDA: # TODO: support openvino, directml etc if "CUDAExecutionProvider" not in prov: raise ValueError( "onnxruntime can't find CUDAExecutionProvider in list of available providers" ) - return ["CUDAExecutionProvider"] + return ["TensorrtExecutionProvider", "CUDAExecutionProvider"] + elif self.device == DeviceEnum.COREML: + if "CoreMLExecutionProvider" not in prov: + raise ValueError( + "onnxruntime can't find CoreMLExecutionProvider in list of available providers" + ) + return ["CoreMLExecutionProvider"] + elif self.device == DeviceEnum.CPU: + return ["CPUExecutionProvider"] else: + warnings.warn( + "Device not supported, using CPU", + UserWarning, + stacklevel=2, + ) return ["CPUExecutionProvider"] def convert( diff --git a/libs/megaparse/src/megaparse/parser/strategy.py b/libs/megaparse/src/megaparse/parser/strategy.py index 45e56e7..780c634 100644 --- a/libs/megaparse/src/megaparse/parser/strategy.py +++ b/libs/megaparse/src/megaparse/parser/strategy.py @@ -1,108 +1,167 @@ import logging import random +import warnings from pathlib import Path +from typing import Any, List import numpy as np +import onnxruntime as rt import pypdfium2 as pdfium from megaparse_sdk.schema.parser_config import StrategyEnum from onnxtr.io import DocumentFile from onnxtr.models import detection_predictor +from onnxtr.models.engine import EngineConfig from pypdfium2._helpers.page import PdfPage +from megaparse.configs.auto import AutoStrategyConfig, DeviceEnum, TextDetConfig from megaparse.predictor.doctr_layout_detector import LayoutPredictor from megaparse.predictor.models.base import PageLayout logger = logging.getLogger("megaparse") -def get_strategy_page( - pdfium_page: PdfPage, onnxtr_page: PageLayout, threshold: float -) -> StrategyEnum: - # assert ( - # p_width == onnxtr_page.dimensions[1] - # and p_height == onnxtr_page.dimensions[0] - # ), "Page dimensions do not match" - text_coords = [] - # Get all the images in the page - for obj in pdfium_page.get_objects(): - if obj.type == 1: - text_coords.append(obj.get_pos()) - - p_width, p_height = int(pdfium_page.get_width()), int(pdfium_page.get_height()) - - pdfium_canva = np.zeros((int(p_height), int(p_width))) - - for coords in text_coords: - # (left,bottom,right, top) - # 0---l--------------R-> y - # | - # B (x0,y0) - # | - # T (x1,y1) - # ^ - # x - x0, y0, x1, y1 = ( - p_height - coords[3], - coords[0], - p_height - coords[1], - coords[2], +class StrategyHandler: + def __init__( + self, + auto_config: AutoStrategyConfig = AutoStrategyConfig(), + text_det_config: TextDetConfig = TextDetConfig(), + device: DeviceEnum = DeviceEnum.CPU, + ) -> None: + self.config = auto_config + self.device = device + general_options = rt.SessionOptions() + providers = self._get_providers() + engine_config = EngineConfig( + session_options=general_options, + providers=providers, ) - x0 = max(0, min(p_height, int(x0))) - y0 = max(0, min(p_width, int(y0))) - x1 = max(0, min(p_height, int(x1))) - y1 = max(0, min(p_width, int(y1))) - pdfium_canva[x0:x1, y0:y1] = 1 - - onnxtr_canva = np.zeros((int(p_height), int(p_width))) - for block in onnxtr_page.bboxes: - x0, y0 = block.bbox[0] - x1, y1 = block.bbox[1] - x0 = max(0, min(int(x0 * p_width), int(p_width))) - y0 = max(0, min(int(y0 * p_height), int(p_height))) - x1 = max(0, min(int(x1 * p_width), int(p_width))) - y1 = max(0, min(int(y1 * p_height), int(p_height))) - onnxtr_canva[y0:y1, x0:x1] = 1 - - intersection = np.logical_and(pdfium_canva, onnxtr_canva) - union = np.logical_or(pdfium_canva, onnxtr_canva) - iou = np.sum(intersection) / np.sum(union) - if iou < threshold: - return StrategyEnum.HI_RES - return StrategyEnum.FAST - - -def determine_strategy( - file: str - | Path - | bytes, # FIXME : Careful here on removing BinaryIO (not handled by onnxtr) - threshold_pages_ocr: float, - threshold_per_page: float, -) -> StrategyEnum: - logger.info("Determining strategy...") - need_ocr = 0 - - onnxtr_document = DocumentFile.from_pdf(file) - det_predictor = detection_predictor() - layout_predictor = LayoutPredictor(det_predictor) - - pdfium_document = pdfium.PdfDocument(file) - - onnxtr_document_layout = layout_predictor(onnxtr_document) - - for pdfium_page, onnxtr_page in zip( - pdfium_document, onnxtr_document_layout, strict=True - ): - strategy = get_strategy_page( - pdfium_page, onnxtr_page, threshold=threshold_per_page - ) - need_ocr += strategy == StrategyEnum.HI_RES - doc_need_ocr = (need_ocr / len(pdfium_document)) > threshold_pages_ocr - if isinstance(pdfium_document, pdfium.PdfDocument): - pdfium_document.close() + self.det_predictor = detection_predictor( + arch=text_det_config.det_arch, + assume_straight_pages=text_det_config.assume_straight_pages, + preserve_aspect_ratio=text_det_config.preserve_aspect_ratio, + symmetric_pad=text_det_config.symmetric_pad, + batch_size=text_det_config.batch_size, + load_in_8_bit=text_det_config.load_in_8_bit, + engine_cfg=engine_config, + ) - if doc_need_ocr: - logger.info("Using HI_RES strategy") - return StrategyEnum.HI_RES - logger.info("Using FAST strategy") - return StrategyEnum.FAST + def _get_providers(self) -> List[str]: + prov = rt.get_available_providers() + logger.info("Available providers:", prov) + if self.device == DeviceEnum.CUDA: + # TODO: support openvino, directml etc + if "CUDAExecutionProvider" not in prov: + raise ValueError( + "onnxruntime can't find CUDAExecutionProvider in list of available providers" + ) + return ["TensorrtExecutionProvider", "CUDAExecutionProvider"] + elif self.device == DeviceEnum.COREML: + if "CoreMLExecutionProvider" not in prov: + raise ValueError( + "onnxruntime can't find CoreMLExecutionProvider in list of available providers" + ) + return ["CoreMLExecutionProvider"] + elif self.device == DeviceEnum.CPU: + return ["CPUExecutionProvider"] + else: + warnings.warn( + "Device not supported, using CPU", + UserWarning, + stacklevel=2, + ) + return ["CPUExecutionProvider"] + + def get_strategy_page( + self, pdfium_page: PdfPage, onnxtr_page: PageLayout + ) -> StrategyEnum: + # assert ( + # p_width == onnxtr_page.dimensions[1] + # and p_height == onnxtr_page.dimensions[0] + # ), "Page dimensions do not match" + text_coords = [] + # Get all the images in the page + for obj in pdfium_page.get_objects(): + if obj.type == 1: + text_coords.append(obj.get_pos()) + + p_width, p_height = int(pdfium_page.get_width()), int(pdfium_page.get_height()) + + pdfium_canva = np.zeros((int(p_height), int(p_width))) + + for coords in text_coords: + # (left,bottom,right, top) + # 0---l--------------R-> y + # | + # B (x0,y0) + # | + # T (x1,y1) + # ^ + # x + x0, y0, x1, y1 = ( + p_height - coords[3], + coords[0], + p_height - coords[1], + coords[2], + ) + x0 = max(0, min(p_height, int(x0))) + y0 = max(0, min(p_width, int(y0))) + x1 = max(0, min(p_height, int(x1))) + y1 = max(0, min(p_width, int(y1))) + pdfium_canva[x0:x1, y0:y1] = 1 + + onnxtr_canva = np.zeros((int(p_height), int(p_width))) + for block in onnxtr_page.bboxes: + x0, y0 = block.bbox[0] + x1, y1 = block.bbox[1] + x0 = max(0, min(int(x0 * p_width), int(p_width))) + y0 = max(0, min(int(y0 * p_height), int(p_height))) + x1 = max(0, min(int(x1 * p_width), int(p_width))) + y1 = max(0, min(int(y1 * p_height), int(p_height))) + onnxtr_canva[y0:y1, x0:x1] = 1 + + intersection = np.logical_and(pdfium_canva, onnxtr_canva) + union = np.logical_or(pdfium_canva, onnxtr_canva) + iou = np.sum(intersection) / np.sum(union) + if iou < self.config.auto_page_threshold: + return StrategyEnum.HI_RES + return StrategyEnum.FAST + + def determine_strategy( + self, + file: str + | Path + | bytes, # FIXME : Careful here on removing BinaryIO (not handled by onnxtr) + max_samples: int = 5, + ) -> StrategyEnum: + logger.info("Determining strategy...") + need_ocr = 0 + + onnxtr_document = DocumentFile.from_pdf(file) + layout_predictor = LayoutPredictor(self.det_predictor) + pdfium_document = pdfium.PdfDocument(file) + + if len(pdfium_document) > max_samples: + sample_pages_index = random.sample(range(len(onnxtr_document)), max_samples) + onnxtr_document = [onnxtr_document[i] for i in sample_pages_index] + pdfium_document = [pdfium_document[i] for i in sample_pages_index] + + onnxtr_document_layout = layout_predictor(onnxtr_document) + + for pdfium_page, onnxtr_page in zip( + pdfium_document, onnxtr_document_layout, strict=True + ): + strategy = self.get_strategy_page(pdfium_page, onnxtr_page) + need_ocr += strategy == StrategyEnum.HI_RES + + doc_need_ocr = ( + need_ocr / len(pdfium_document) + ) > self.config.auto_document_threshold + if isinstance(pdfium_document, pdfium.PdfDocument): + pdfium_document.close() + + if doc_need_ocr: + logger.info("Using HI_RES strategy") + return StrategyEnum.HI_RES + logger.info("Using FAST strategy") + return StrategyEnum.FAST diff --git a/libs/megaparse/src/megaparse/predictor/doctr_layout_detector.py b/libs/megaparse/src/megaparse/predictor/doctr_layout_detector.py index e7b2564..fc50e5a 100644 --- a/libs/megaparse/src/megaparse/predictor/doctr_layout_detector.py +++ b/libs/megaparse/src/megaparse/predictor/doctr_layout_detector.py @@ -1,12 +1,12 @@ +import logging from typing import Any, List import numpy as np from megaparse.predictor.models.base import ( - BlockLayout, - PageLayout, BBOX, - Point2D, + BlockLayout, BlockType, + PageLayout, ) from onnxtr.models.detection.predictor import DetectionPredictor from onnxtr.models.engine import EngineConfig @@ -14,6 +14,8 @@ from onnxtr.utils.geometry import detach_scores from onnxtr.utils.repr import NestedObject +logger = logging.getLogger("megaparse") + class LayoutPredictor(NestedObject, _OCRPredictor): """Implements an object able to localize and identify text elements in a set of documents @@ -42,6 +44,7 @@ def __init__( preserve_aspect_ratio: bool = True, symmetric_pad: bool = True, detect_orientation: bool = False, + use_gpu: bool = False, clf_engine_cfg: EngineConfig | None = None, **kwargs: Any, ): diff --git a/libs/megaparse/tests/pdf/test_detect_ocr.py b/libs/megaparse/tests/pdf/test_detect_ocr.py index c474108..6b6c57d 100644 --- a/libs/megaparse/tests/pdf/test_detect_ocr.py +++ b/libs/megaparse/tests/pdf/test_detect_ocr.py @@ -1,30 +1,26 @@ import os import pytest -from megaparse.parser.strategy import determine_strategy +from megaparse.parser.strategy import StrategyHandler from megaparse_sdk.schema.parser_config import StrategyEnum -from megaparse_sdk.config import MegaParseConfig ocr_pdfs = os.listdir("./tests/pdf/ocr") native_pdfs = os.listdir("./tests/pdf/native") -config = MegaParseConfig() + +strategy_handler = StrategyHandler() @pytest.mark.parametrize("hi_res_pdf", ocr_pdfs) def test_hi_res_strategy(hi_res_pdf): - strategy = determine_strategy( + strategy = strategy_handler.determine_strategy( f"./tests/pdf/ocr/{hi_res_pdf}", - threshold_per_page=config.auto_page_threshold, - threshold_pages_ocr=config.auto_document_threshold, ) assert strategy == StrategyEnum.HI_RES @pytest.mark.parametrize("native_pdf", native_pdfs) def test_fast_strategy(native_pdf): - strategy = determine_strategy( + strategy = strategy_handler.determine_strategy( f"./tests/pdf/native/{native_pdf}", - threshold_per_page=config.auto_page_threshold, - threshold_pages_ocr=config.auto_document_threshold, ) assert strategy == StrategyEnum.FAST diff --git a/libs/megaparse/tests/pdf/test_pdf_processing.py b/libs/megaparse/tests/pdf/test_pdf_processing.py index 0a1ba16..2b85d2c 100644 --- a/libs/megaparse/tests/pdf/test_pdf_processing.py +++ b/libs/megaparse/tests/pdf/test_pdf_processing.py @@ -2,13 +2,12 @@ import pytest from megaparse.megaparse import MegaParse -from megaparse.parser.strategy import determine_strategy +from megaparse.parser.strategy import StrategyHandler from megaparse.parser.unstructured_parser import UnstructuredParser -from megaparse_sdk.config import MegaParseConfig from megaparse_sdk.schema.extensions import FileExtension from megaparse_sdk.schema.parser_config import StrategyEnum -config = MegaParseConfig() +strategy_handler = StrategyHandler() @pytest.fixture @@ -56,16 +55,12 @@ async def test_megaparse_pdf_processor_file(pdf_name, request): def test_strategy(scanned_pdf, native_pdf): - strategy = determine_strategy( + strategy = strategy_handler.determine_strategy( scanned_pdf, - threshold_per_page=config.auto_page_threshold, - threshold_pages_ocr=config.auto_document_threshold, ) assert strategy == StrategyEnum.HI_RES - strategy = determine_strategy( + strategy = strategy_handler.determine_strategy( native_pdf, - threshold_per_page=config.auto_page_threshold, - threshold_pages_ocr=config.auto_document_threshold, ) assert strategy == StrategyEnum.FAST diff --git a/libs/megaparse_sdk/megaparse_sdk/config.py b/libs/megaparse_sdk/megaparse_sdk/config.py index 4456e66..2dafc86 100644 --- a/libs/megaparse_sdk/megaparse_sdk/config.py +++ b/libs/megaparse_sdk/megaparse_sdk/config.py @@ -14,21 +14,6 @@ class MegaParseSDKConfig(BaseSettings): max_retries: int = 3 -class MegaParseConfig(BaseSettings): - """ - Configuration for Megaparse. - """ - - model_config = SettingsConfigDict( - env_prefix="MEGAPARSE_", - env_file=(".env.local", ".env"), - env_nested_delimiter="__", - extra="ignore", - ) - auto_page_threshold: float = 0.6 - auto_document_threshold: float = 0.2 - - class SSLConfig(BaseModel): ssl_key_file: FilePath ssl_cert_file: FilePath