ONNX for inference, speed boost, enable flattening PDFs #8

Merged (8 commits, Oct 7, 2024)
2 changes: 2 additions & 0 deletions README.md

```diff
@@ -28,6 +28,7 @@ pdftext PDF_PATH --out_path output.txt
 - `--keep_hyphens` will keep hyphens in the output (they will be stripped and words joined otherwise)
 - `--pages` will specify pages (comma separated) to extract
 - `--workers` specifies the number of parallel workers to use
+- `--flatten_pdf` merges form fields into the PDF
 
 ## JSON
 
@@ -44,6 +45,7 @@ pdftext PDF_PATH --out_path output.txt --json
 - `--pages` will specify pages (comma separated) to extract
 - `--keep_chars` will keep individual characters in the json output
 - `--workers` specifies the number of parallel workers to use
+- `--flatten_pdf` merges form fields into the PDF
 
 The output will be a json list, with each item in the list corresponding to a single page in the input pdf (in order). Each page will include the following keys:
```
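The same option is exposed through the Python API. A minimal sketch, assuming the updated signatures in `pdftext/extraction.py` below (`example.pdf` is a placeholder path):

```python
from pdftext.extraction import plain_text_output, dictionary_output

# flatten_pdf=True merges form fields and annotations into the page
# contents before text extraction ("example.pdf" is a placeholder).
text = plain_text_output("example.pdf", flatten_pdf=True)
pages = dictionary_output("example.pdf", flatten_pdf=True, workers=2)
```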
7 changes: 3 additions & 4 deletions benchmark.py

```diff
@@ -50,8 +50,8 @@ def pdfplumber_inference(pdf_path):
     return pages
 
 
-def pdftext_inference(pdf_path, model=None, workers=None):
-    return paginated_plain_text_output(pdf_path, model=model, workers=workers)
+def pdftext_inference(pdf_path, workers=None):
+    return paginated_plain_text_output(pdf_path, workers=workers)
 
 
 def compare_docs(doc1: str, doc2: str):
@@ -78,7 +78,6 @@ def main():
     if args.pdftext_only:
         times_tools = ["pymupdf", "pdftext"]
         alignment_tools = ["pdftext"]
-    model = get_model()
     for i in tqdm(range(len(dataset)), desc="Benchmarking"):
         row = dataset[i]
         pdf = row["pdf"]
@@ -88,7 +87,7 @@
             f.seek(0)
             pdf_path = f.name
 
-            pdftext_inference_model = partial(pdftext_inference, model=model, workers=args.pdftext_workers)
+            pdftext_inference_model = partial(pdftext_inference, workers=args.pdftext_workers)
             inference_funcs = [pymupdf_inference, pdftext_inference_model, pdfplumber_inference]
             for tool, inference_func in zip(times_tools, inference_funcs):
                 start = time.time()
```
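The `partial` binding is what keeps the benchmark loop uniform: every entry in `inference_funcs` can be called as `func(pdf_path)`, with tool-specific options pre-bound. A minimal illustration of the pattern (names and values are illustrative):

```python
from functools import partial

def pdftext_inference(pdf_path, workers=None):
    # Stand-in for paginated_plain_text_output(pdf_path, workers=workers).
    return f"extracted {pdf_path} with workers={workers}"

# Pre-bind keyword arguments so the caller only supplies pdf_path.
func = partial(pdftext_inference, workers=4)
print(func("example.pdf"))  # extracted example.pdf with workers=4
```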
5 changes: 3 additions & 2 deletions extract_text.py

```diff
@@ -13,6 +13,7 @@ def main():
     parser.add_argument("--sort", action="store_true", help="Attempt to sort the text by reading order", default=False)
     parser.add_argument("--keep_hyphens", action="store_true", help="Keep hyphens in words", default=False)
     parser.add_argument("--pages", type=str, help="Comma separated pages to extract, like 1,2,3", default=None)
+    parser.add_argument("--flatten_pdf", action="store_true", help="Flatten form fields and annotations into page contents", default=False)
     parser.add_argument("--keep_chars", action="store_true", help="Keep character level information", default=False)
     parser.add_argument("--workers", type=int, help="Number of workers to use for parallel processing", default=None)
     args = parser.parse_args()
@@ -24,10 +25,10 @@
         assert all(p <= len(pdf_doc) for p in pages), "Invalid page number(s) provided"
 
     if args.json:
-        text = dictionary_output(args.pdf_path, sort=args.sort, page_range=pages, keep_chars=args.keep_chars, workers=args.workers)
+        text = dictionary_output(args.pdf_path, sort=args.sort, page_range=pages, flatten_pdf=args.flatten_pdf, keep_chars=args.keep_chars, workers=args.workers)
         text = json.dumps(text)
     else:
-        text = plain_text_output(args.pdf_path, sort=args.sort, hyphens=args.keep_hyphens, page_range=pages, workers=args.workers)
+        text = plain_text_output(args.pdf_path, sort=args.sort, hyphens=args.keep_hyphens, page_range=pages, flatten_pdf=args.flatten_pdf, workers=args.workers)
 
     if args.out_path is None:
         print(text)
```
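Combined with the existing options, a typical invocation of the new flag might look like this (file names are placeholders):

```
pdftext example.pdf --out_path output.json --json --flatten_pdf --workers 4
```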
Binary file added: models/dt.onnx (not shown)
53 changes: 35 additions & 18 deletions pdftext/extraction.py

```diff
@@ -1,4 +1,3 @@
-from functools import partial
 from typing import List
 from concurrent.futures import ProcessPoolExecutor
 import math
@@ -12,49 +11,67 @@
 from pdftext.settings import settings
 
 
-def _get_page_range(pdf_path, model, page_range):
-    pdf_doc = pdfium.PdfDocument(pdf_path)
-    text_chars = get_pdfium_chars(pdf_doc, page_range)
+def _load_pdf(pdf, flatten_pdf):
+    if isinstance(pdf, str):
+        pdf = pdfium.PdfDocument(pdf)
+
+    if not isinstance(pdf, pdfium.PdfDocument):
+        raise TypeError("pdf must be a file path string or a PdfDocument object")
+
+    # Must be called on the parent pdf, before the page was retrieved
+    if flatten_pdf:
+        pdf.init_forms()
+
+    return pdf
+
+
+def _get_page_range(page_range, flatten_pdf=False):
+    text_chars = get_pdfium_chars(pdf_doc, page_range, flatten_pdf)
     pages = inference(text_chars, model)
     return pages
 
 
-def _get_pages(pdf_path, model=None, page_range=None, workers=None):
-    if model is None:
-        model = get_model()
-
-    pdf_doc = pdfium.PdfDocument(pdf_path)
+def worker_init(pdf_path, flatten_pdf):
+    global model
+    global pdf_doc
+
+    pdf_doc = _load_pdf(pdf_path, flatten_pdf)
+    model = get_model()
+
+
+def _get_pages(pdf_path, page_range=None, flatten_pdf=False, workers=None):
+    pdf_doc = _load_pdf(pdf_path, flatten_pdf)
     if page_range is None:
         page_range = range(len(pdf_doc))
 
     if workers is not None:
         workers = min(workers, len(page_range) // settings.WORKER_PAGE_THRESHOLD)  # It's inefficient to have too many workers, since we batch in inference
 
     if workers is None or workers <= 1:
-        text_chars = get_pdfium_chars(pdf_doc, page_range)
+        model = get_model()
+        text_chars = get_pdfium_chars(pdf_doc, page_range, flatten_pdf)
         return inference(text_chars, model)
 
-    func = partial(_get_page_range, pdf_path, model)
     page_range = list(page_range)
 
     pages_per_worker = math.ceil(len(page_range) / workers)
     page_range_chunks = [page_range[i * pages_per_worker:(i + 1) * pages_per_worker] for i in range(workers)]
 
-    with ProcessPoolExecutor(max_workers=workers) as executor:
-        pages = list(executor.map(func, page_range_chunks))
+    with ProcessPoolExecutor(max_workers=workers, initializer=worker_init, initargs=(pdf_path, flatten_pdf)) as executor:
+        pages = list(executor.map(_get_page_range, page_range_chunks))
 
     ordered_pages = [page for sublist in pages for page in sublist]
 
     return ordered_pages
 
 
-def plain_text_output(pdf_path, sort=False, model=None, hyphens=False, page_range=None, workers=None) -> str:
-    text = paginated_plain_text_output(pdf_path, sort=sort, model=model, hyphens=hyphens, page_range=page_range, workers=workers)
+def plain_text_output(pdf_path, sort=False, hyphens=False, page_range=None, flatten_pdf=False, workers=None) -> str:
+    text = paginated_plain_text_output(pdf_path, sort=sort, hyphens=hyphens, page_range=page_range, workers=workers, flatten_pdf=flatten_pdf)
     return "\n".join(text)
 
 
-def paginated_plain_text_output(pdf_path, sort=False, model=None, hyphens=False, page_range=None, workers=None) -> List[str]:
-    pages = _get_pages(pdf_path, model, page_range, workers=workers)
+def paginated_plain_text_output(pdf_path, sort=False, hyphens=False, page_range=None, flatten_pdf=False, workers=None) -> List[str]:
+    pages = _get_pages(pdf_path, page_range, workers=workers, flatten_pdf=flatten_pdf)
     text = []
     for page in pages:
         text.append(merge_text(page, sort=sort, hyphens=hyphens).strip())
@@ -71,8 +88,8 @@ def _process_span(span, page_width, page_height, keep_chars):
             char["bbox"] = unnormalize_bbox(char["bbox"], page_width, page_height)
 
 
-def dictionary_output(pdf_path, sort=False, model=None, page_range=None, keep_chars=False, workers=None):
-    pages = _get_pages(pdf_path, model, page_range, workers=workers)
+def dictionary_output(pdf_path, sort=False, page_range=None, keep_chars=False, flatten_pdf=False, workers=None):
+    pages = _get_pages(pdf_path, page_range, workers=workers, flatten_pdf=flatten_pdf)
     for page in pages:
         page_width, page_height = page["width"], page["height"]
         for block in page["blocks"]:
```
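The parallel path deserves a note: rather than pickling a loaded model into every task, each worker process now builds its own `pdf_doc` and `model` once, via the executor's `initializer`, and `_get_page_range` picks them up as module-level globals. A stripped-down, runnable sketch of that pattern (the string `resource` stands in for the ONNX session and PDF handle):

```python
from concurrent.futures import ProcessPoolExecutor

def worker_init(path):
    # Runs once per worker process; tasks then reuse the global
    # instead of paying the load (or pickling) cost per chunk.
    global resource
    resource = f"loaded:{path}"  # stand-in for get_model() / _load_pdf()

def process_chunk(chunk):
    return [f"{resource}:{page}" for page in chunk]

if __name__ == "__main__":
    chunks = [[0, 1], [2, 3]]
    with ProcessPoolExecutor(max_workers=2, initializer=worker_init,
                             initargs=("example.pdf",)) as executor:
        print(list(executor.map(process_chunk, chunks)))
```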
56 changes: 27 additions & 29 deletions pdftext/inference.py

```diff
@@ -1,6 +1,4 @@
 from itertools import chain
 
-import sklearn
+import numpy as np
 
 from pdftext.pdf.utils import LINE_BREAKS, TABS, SPACES
 from pdftext.settings import settings
@@ -47,27 +45,27 @@ def create_training_row(char_info, prev_char, currblock, currline):
 
     is_space = char in SPACES or char in TABS
 
-    training_row = {
-        "is_newline": char in LINE_BREAKS,
-        "is_space": is_space,
-        "x_gap": x_gap,
-        "y_gap": y_gap,
-        "font_match": font_match,
-        "x_outer_gap": char_x2 - prev_x1,
-        "y_outer_gap": char_y2 - prev_y1,
-        "line_x_center_gap": char_center_x - currline["center_x"],
-        "line_y_center_gap": char_center_y - currline["center_y"],
-        "line_x_gap": char_x1 - currline_bbox[2],
-        "line_y_gap": char_y1 - currline_bbox[3],
-        "line_x_start_gap": char_x1 - currline_bbox[0],
-        "line_y_start_gap": char_y1 - currline_bbox[1],
-        "block_x_center_gap": char_center_x - currblock["center_x"],
-        "block_y_center_gap": char_center_y - currblock["center_y"],
-        "block_x_gap": char_x1 - currblock_bbox[2],
-        "block_y_gap": char_y1 - currblock_bbox[3],
-        "block_x_start_gap": char_x1 - currblock_bbox[0],
-        "block_y_start_gap": char_y1 - currblock_bbox[1]
-    }
-
-    return training_row
+    return np.array([
+        char_center_x - currblock["center_x"],
+        char_x1 - currblock_bbox[2],
+        char_x1 - currblock_bbox[0],
+        char_center_y - currblock["center_y"],
+        char_y1 - currblock_bbox[3],
+        char_y1 - currblock_bbox[1],
+        font_match,
+        char in LINE_BREAKS,
+        is_space,
+        char_center_x - currline["center_x"],
+        char_x1 - currline_bbox[2],
+        char_x1 - currline_bbox[0],
+        char_center_y - currline["center_y"],
+        char_y1 - currline_bbox[3],
+        char_y1 - currline_bbox[1],
+        x_gap,
+        char_x2 - prev_x1,
+        y_gap,
+        char_y2 - prev_y1
+    ], dtype=np.float32)
@@ -135,8 +133,6 @@ def infer_single_page(text_chars, block_threshold=settings.BLOCK_THRESHOLD):
             font_info = f"{font['name']}_{font['size']}_{font['weight']}_{font['flags']}_{char_info['rotation']}"
         if prev_char:
             training_row = create_training_row(char_info, prev_char, block, line)
-            sorted_keys = sorted(training_row.keys())
-            training_row = [training_row[key] for key in sorted_keys]
 
             prediction_probs = yield training_row
             # First item is probability of same line/block, second is probability of new line, third is probability of new block
@@ -175,6 +171,8 @@ def inference(text_chars, model):
     # Create generators and get first training row from each
     generators = [infer_single_page(text_page) for text_page in text_chars]
     next_prediction = {}
+    input_name = model.get_inputs()[0].name
+    output_name = model.get_outputs()[1].name
 
     page_blocks = {}
     while len(page_blocks) < len(generators):
@@ -199,10 +197,10 @@
 
         training_idxs = sorted(training_data.keys())
         training_rows = [training_data[idx] for idx in training_idxs]
+        training_rows = np.stack(training_rows, axis=0)
 
-        # Disable nan, etc, validation for a small speedup
-        with sklearn.config_context(assume_finite=True):
-            predictions = model.predict_proba(training_rows)
+        # Run inference
+        predictions = model.run([output_name], {input_name: training_rows})[0]
         for pred, page_idx in zip(predictions, training_idxs):
             next_prediction[page_idx] = pred
         sorted_keys = sorted(page_blocks.keys())
```
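One subtlety worth calling out: the hard-coded order of the `np.array` features is not arbitrary. It reproduces the alphabetical key order that the deleted `sorted(training_row.keys())` step used to generate, which is presumably the column order the ONNX model was exported with. A quick sanity check of that correspondence:

```python
# Keys of the old training_row dict, sorted the way the deleted code sorted them.
old_keys = sorted([
    "is_newline", "is_space", "x_gap", "y_gap", "font_match",
    "x_outer_gap", "y_outer_gap",
    "line_x_center_gap", "line_y_center_gap", "line_x_gap", "line_y_gap",
    "line_x_start_gap", "line_y_start_gap",
    "block_x_center_gap", "block_y_center_gap", "block_x_gap", "block_y_gap",
    "block_x_start_gap", "block_y_start_gap",
])
# Prints block_x_center_gap, block_x_gap, block_x_start_gap, ..., y_outer_gap:
# exactly the order of the 19 entries in the new np.array above.
print(old_keys)
```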
6 changes: 3 additions & 3 deletions pdftext/model.py

```diff
@@ -1,7 +1,7 @@
-import joblib
 from pdftext.settings import settings
+import onnxruntime as rt
 
 
 def get_model(model_path=settings.MODEL_PATH):
-    model = joblib.load(model_path)
-    return model
+    sess = rt.InferenceSession(model_path)
+    return sess
```
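A minimal sketch of how the returned session is consumed, mirroring `inference()` above. The assumption (flagged in the comments) is that the sklearn-to-ONNX conversion produced two outputs, a label tensor and a probability tensor, which is why `get_outputs()[1]` selects the probabilities:

```python
import numpy as np
import onnxruntime as rt

sess = rt.InferenceSession("models/dt.onnx")
input_name = sess.get_inputs()[0].name
# Assumed output layout from the sklearn-to-ONNX conversion:
# index 0 is the predicted label, index 1 the per-class probabilities.
output_name = sess.get_outputs()[1].name

# One row per character decision: the 19 float32 features
# built by create_training_row (zeros here, just to show the shape).
batch = np.zeros((2, 19), dtype=np.float32)
# Depending on conversion options, the probability output may be an
# ndarray or a sequence of class-to-probability maps; both are iterable.
probs = sess.run([output_name], {input_name: batch})[0]
print(probs[0])
```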
18 changes: 17 additions & 1 deletion pdftext/pdf/chars.py

```diff
@@ -2,6 +2,7 @@
 from typing import Dict, List
 
 import pypdfium2.raw as pdfium_c
+from pypdfium2 import PdfiumError
 
 from pdftext.pdf.utils import get_fontname, pdfium_page_bbox_to_device_bbox, page_bbox_to_device_bbox
 from pdftext.settings import settings
@@ -19,11 +20,26 @@ def update_previous_fonts(char_infos: List, i: int, prev_fontname: str, prev_fon
             char_infos[j]["font"]["flags"] = fontflags
 
 
-def get_pdfium_chars(pdf, page_range, fontname_sample_freq=settings.FONTNAME_SAMPLE_FREQ):
+def flatten(page, flag=pdfium_c.FLAT_NORMALDISPLAY):
+    rc = pdfium_c.FPDFPage_Flatten(page, flag)
+    if rc == pdfium_c.FLATTEN_FAIL:
+        raise PdfiumError("Failed to flatten annotations / form fields.")
+
+
+def get_pdfium_chars(pdf, page_range, flatten_pdf, fontname_sample_freq=settings.FONTNAME_SAMPLE_FREQ):
     blocks = []
 
     for page_idx in page_range:
         page = pdf.get_page(page_idx)
 
+        if flatten_pdf:
+            # Flatten form fields and annotations into page contents.
+            flatten(page)
+
+            # Flattening invalidates existing handles to the page.
+            # It is necessary to re-initialize the page handle after flattening.
+            page = pdf.get_page(page_idx)
+
         text_page = page.get_textpage()
         mediabox = page.get_mediabox()
         page_rotation = page.get_rotation()
```
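For background on the error handling here: `FPDFPage_Flatten` returns one of three codes in pdfium (`FLATTEN_FAIL`, `FLATTEN_SUCCESS`, `FLATTEN_NOTHINGTODO`), and only the failure code raises, so pages with nothing to flatten pass through untouched. A minimal standalone sketch of the same call sequence (`form.pdf` is a placeholder path):

```python
import pypdfium2 as pdfium
import pypdfium2.raw as pdfium_c

pdf = pdfium.PdfDocument("form.pdf")  # placeholder path
pdf.init_forms()  # must be called on the document before pages are loaded
page = pdf.get_page(0)

# FLAT_PRINT is the alternative to FLAT_NORMALDISPLAY.
rc = pdfium_c.FPDFPage_Flatten(page, pdfium_c.FLAT_NORMALDISPLAY)
assert rc != pdfium_c.FLATTEN_FAIL  # FLATTEN_NOTHINGTODO also counts as success

# The old handle is stale after flattening; re-fetch the page.
page = pdf.get_page(0)
```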
2 changes: 1 addition & 1 deletion pdftext/settings.py

```diff
@@ -5,7 +5,7 @@
 
 class Settings(BaseSettings):
     BASE_PATH: str = os.path.dirname(os.path.dirname(__file__))
-    MODEL_PATH: str = os.path.join(BASE_PATH, "models", "dt.joblib")
+    MODEL_PATH: str = os.path.join(BASE_PATH, "models", "dt.onnx")
 
     # Fonts
     FONTNAME_SAMPLE_FREQ: int = 4
```