Skip to content

Commit

Permalink
fix: Add EngineConfig & StrategyHandler (#211)
Browse files Browse the repository at this point in the history
* fix: Add EngineConfig & StrategyHandler

* fix: Configs settings
  • Loading branch information
chloedia authored Jan 8, 2025
1 parent 03c7ada commit 2e1c6dd
Show file tree
Hide file tree
Showing 9 changed files with 258 additions and 153 deletions.
47 changes: 47 additions & 0 deletions libs/megaparse/src/megaparse/configs/auto.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from enum import Enum

from pydantic import BaseModel
from pydantic_settings import BaseSettings, SettingsConfigDict


class TextDetConfig(BaseModel):
    """Settings for the text-detection stage of the OCR pipeline.

    ``det_arch``, ``batch_size`` and ``assume_straight_pages`` are the values
    handed to the OCR predictor as its detection architecture, detection
    batch size and straight-page assumption.
    """

    # Name of the detection architecture.
    det_arch: str = "fast_base"
    # Detection batch size.
    batch_size: int = 2
    # Whether the predictor may assume pages are not rotated.
    assume_straight_pages: bool = True
    # NOTE(review): the three options below are presumably forwarded to the
    # detection predictor as well -- confirm against the consumer.
    preserve_aspect_ratio: bool = True
    symmetric_pad: bool = True
    load_in_8_bit: bool = False


class AutoStrategyConfig(BaseModel):
    """Thresholds consumed by the automatic parsing-strategy selection.

    NOTE(review): exact semantics live in the strategy handler; presumably
    ``auto_page_threshold`` is a per-page cut-off and
    ``auto_document_threshold`` a document-level one -- confirm there.
    """

    # Per-page threshold.
    auto_page_threshold: float = 0.6
    # Document-level threshold.
    auto_document_threshold: float = 0.2


class TextRecoConfig(BaseModel):
    """Settings for the text-recognition stage of the OCR pipeline.

    Both values are handed to the OCR predictor as its recognition
    architecture and recognition batch size.
    """

    # Name of the recognition architecture.
    reco_arch: str = "crnn_vgg16_bn"
    # Recognition batch size.
    batch_size: int = 512


class DeviceEnum(str, Enum):
    """Compute devices that can be selected for inference.

    Members are also ``str`` instances, so they compare equal to their
    plain-string values (for example ``DeviceEnum.CPU == "cpu"``).
    """

    CPU = "cpu"
    CUDA = "cuda"
    COREML = "coreml"


class MegaParseConfig(BaseSettings):
    """
    Configuration for Megaparse.

    Values come from the defaults below and can be overridden through the
    settings sources configured in ``model_config``: variables prefixed with
    ``MEGAPARSE_`` and the ``.env.local`` / ``.env`` files. Nested fields are
    addressed with the ``__`` delimiter.
    """

    model_config = SettingsConfigDict(
        # Prefix expected on every environment variable.
        env_prefix="MEGAPARSE_",
        # Dotenv files consulted for values.
        env_file=(".env.local", ".env"),
        # Separator between a nested model's name and its field name.
        env_nested_delimiter="__",
        # Unknown keys are dropped rather than rejected.
        extra="ignore",
        # Store enum fields as their underlying values after validation.
        use_enum_values=True,
    )

    # Text-detection settings.
    text_det_config: TextDetConfig = TextDetConfig()
    # Text-recognition settings.
    text_reco_config: TextRecoConfig = TextRecoConfig()
    # Thresholds for the automatic strategy selection.
    auto_parse_config: AutoStrategyConfig = AutoStrategyConfig()
    # Device used for inference.
    device: DeviceEnum = DeviceEnum.CPU
2 changes: 1 addition & 1 deletion libs/megaparse/src/megaparse/examples/parse_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ def main():
parser = UnstructuredParser()
megaparse = MegaParse(parser=parser)

file_path = "./tests/pdf/ocr/0168126.pdf"
file_path = "./tests/pdf/native/0168029.pdf"

parsed_file = megaparse.load(file_path)
print(f"\n----- File Response : {file_path} -----\n")
Expand Down
37 changes: 23 additions & 14 deletions libs/megaparse/src/megaparse/megaparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,36 +4,51 @@
from pathlib import Path
from typing import IO, BinaryIO

from megaparse_sdk.config import MegaParseConfig
from megaparse.configs.auto import DeviceEnum, MegaParseConfig
from megaparse_sdk.schema.extensions import FileExtension
from megaparse_sdk.schema.parser_config import StrategyEnum

from megaparse.checker.format_checker import FormatChecker
from megaparse.exceptions.base import ParsingException
from megaparse.parser.base import BaseParser
from megaparse.parser.doctr_parser import DoctrParser
from megaparse.parser.strategy import determine_strategy
from megaparse.parser.strategy import StrategyHandler
from megaparse.parser.unstructured_parser import UnstructuredParser

logger = logging.getLogger("megaparse")


class MegaParse:
config: MegaParseConfig = MegaParseConfig()
config = MegaParseConfig()

def __init__(
    self,
    parser: BaseParser | None = None,
    ocr_parser: BaseParser | None = None,
    strategy: StrategyEnum = StrategyEnum.AUTO,
    format_checker: FormatChecker | None = None,
) -> None:
    """Set up the parsers and the strategy handler.

    Args:
        parser: Parser stored as ``self.parser``. When ``None``, an
            ``UnstructuredParser`` using ``StrategyEnum.FAST`` is created.
        ocr_parser: Parser stored as ``self.ocr_parser``. When ``None``, a
            ``DoctrParser`` is built from ``self.config``.
        strategy: Parsing strategy stored as ``self.strategy``.
        format_checker: Optional checker stored as ``self.format_checker``.
    """
    # Compare against None explicitly so a caller-supplied parser is never
    # replaced just because it happens to evaluate as falsy.
    if parser is None:
        parser = UnstructuredParser(strategy=StrategyEnum.FAST)
    if ocr_parser is None:
        # Build the default here, per instance, from the shared config.
        ocr_parser = DoctrParser(
            text_det_config=self.config.text_det_config,
            text_reco_config=self.config.text_reco_config,
            device=self.config.device,
        )

    self.strategy = strategy
    self.parser = parser
    self.ocr_parser = ocr_parser
    self.format_checker = format_checker
    self.last_parsed_document: str = ""

    # The handler reuses the detection settings and device from the config.
    self.strategy_handler = StrategyHandler(
        text_det_config=self.config.text_det_config,
        auto_config=self.config.auto_parse_config,
        device=self.config.device,
    )

def validate_input(
self,
file_path: Path | str | None = None,
Expand Down Expand Up @@ -132,17 +147,11 @@ def _select_parser(
if self.strategy != StrategyEnum.AUTO or file_extension != FileExtension.PDF:
return self.parser
if file:
local_strategy = determine_strategy(
file=file,
threshold_pages_ocr=self.config.auto_document_threshold,
threshold_per_page=self.config.auto_page_threshold,
local_strategy = self.strategy_handler.determine_strategy(
file=file, # type: ignore #FIXME: Careful here on removing BinaryIO (not handled by onnxtr)
)
if file_path:
local_strategy = determine_strategy(
file=file_path,
threshold_pages_ocr=self.config.auto_document_threshold,
threshold_per_page=self.config.auto_page_threshold,
)
local_strategy = self.strategy_handler.determine_strategy(file=file_path)

if local_strategy == StrategyEnum.HI_RES:
return self.ocr_parser
Expand Down
39 changes: 25 additions & 14 deletions libs/megaparse/src/megaparse/parser/doctr_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from pathlib import Path
from typing import IO, BinaryIO, List

from megaparse.configs.auto import DeviceEnum, TextRecoConfig, TextDetConfig
import onnxruntime as rt
from megaparse_sdk.schema.extensions import FileExtension
from onnxtr.io import DocumentFile
Expand All @@ -19,16 +20,13 @@ class DoctrParser(BaseParser):

def __init__(
self,
det_predictor_model: str = "db_resnet50",
reco_predictor_model: str = "crnn_vgg16_bn",
det_bs: int = 2,
reco_bs: int = 512,
assume_straight_pages: bool = True,
text_det_config: TextDetConfig = TextDetConfig(),
text_reco_config: TextRecoConfig = TextRecoConfig(),
device: DeviceEnum = DeviceEnum.CPU,
straighten_pages: bool = False,
use_gpu: bool = False,
**kwargs,
):
self.use_gpu = use_gpu
self.device = device
general_options = rt.SessionOptions()
providers = self._get_providers()
engine_config = EngineConfig(
Expand All @@ -37,11 +35,11 @@ def __init__(
)
# TODO: set in config or pass as kwargs
self.predictor = ocr_predictor(
det_arch=det_predictor_model,
reco_arch=reco_predictor_model,
det_bs=det_bs,
reco_bs=reco_bs,
assume_straight_pages=assume_straight_pages,
det_arch=text_det_config.det_arch,
reco_arch=text_reco_config.reco_arch,
det_bs=text_det_config.batch_size,
reco_bs=text_reco_config.batch_size,
assume_straight_pages=text_det_config.assume_straight_pages,
straighten_pages=straighten_pages,
# Preprocessing related parameters
det_engine_cfg=engine_config,
Expand All @@ -53,14 +51,27 @@ def __init__(
def _get_providers(self) -> List[str]:
    """Return the ONNX Runtime execution providers for ``self.device``.

    Returns:
        The list of execution provider names to use.

    Raises:
        ValueError: If CUDA or CoreML is requested but onnxruntime does not
            report the matching execution provider as available.
    """
    prov = rt.get_available_providers()
    # The message needs a %s placeholder; without it logging cannot format
    # the record and the provider list never reaches the log.
    logger.info("Available providers: %s", prov)
    if self.device == DeviceEnum.CUDA:
        # TODO: support openvino, directml etc
        if "CUDAExecutionProvider" not in prov:
            raise ValueError(
                "onnxruntime can't find CUDAExecutionProvider in list of available providers"
            )
        # NOTE(review): TensorRT is listed first without an availability
        # check -- confirm onnxruntime falls back to CUDA when it is absent.
        return ["TensorrtExecutionProvider", "CUDAExecutionProvider"]
    elif self.device == DeviceEnum.COREML:
        if "CoreMLExecutionProvider" not in prov:
            raise ValueError(
                "onnxruntime can't find CoreMLExecutionProvider in list of available providers"
            )
        return ["CoreMLExecutionProvider"]
    elif self.device == DeviceEnum.CPU:
        return ["CPUExecutionProvider"]
    else:
        # Unknown device value: warn and fall back to CPU rather than fail.
        warnings.warn(
            "Device not supported, using CPU",
            UserWarning,
            stacklevel=2,
        )
        return ["CPUExecutionProvider"]

def convert(
Expand Down
Loading

0 comments on commit 2e1c6dd

Please sign in to comment.