Skip to content

Commit

Permalink
Swap inference to onnx for speed + remove warning
Browse files Browse the repository at this point in the history
  • Loading branch information
VikParuchuri committed Oct 7, 2024
1 parent 2557089 commit 660c010
Show file tree
Hide file tree
Showing 8 changed files with 757 additions and 547 deletions.
7 changes: 3 additions & 4 deletions benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,8 +50,8 @@ def pdfplumber_inference(pdf_path):
return pages


def pdftext_inference(pdf_path, model=None, workers=None):
return paginated_plain_text_output(pdf_path, model=model, workers=workers)
def pdftext_inference(pdf_path, workers=None):
return paginated_plain_text_output(pdf_path, workers=workers)


def compare_docs(doc1: str, doc2: str):
Expand All @@ -78,7 +78,6 @@ def main():
if args.pdftext_only:
times_tools = ["pymupdf", "pdftext"]
alignment_tools = ["pdftext"]
model = get_model()
for i in tqdm(range(len(dataset)), desc="Benchmarking"):
row = dataset[i]
pdf = row["pdf"]
Expand All @@ -88,7 +87,7 @@ def main():
f.seek(0)
pdf_path = f.name

pdftext_inference_model = partial(pdftext_inference, model=model, workers=args.pdftext_workers)
pdftext_inference_model = partial(pdftext_inference, workers=args.pdftext_workers)
inference_funcs = [pymupdf_inference, pdftext_inference_model, pdfplumber_inference]
for tool, inference_func in zip(times_tools, inference_funcs):
start = time.time()
Expand Down
Binary file added models/dt.onnx
Binary file not shown.
31 changes: 17 additions & 14 deletions pdftext/extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,20 @@
from pdftext.settings import settings


def _get_page_range(pdf_path, model, page_range):
pdf_doc = pdfium.PdfDocument(pdf_path)
def _get_page_range(page_range):
text_chars = get_pdfium_chars(pdf_doc, page_range)
pages = inference(text_chars, model)
return pages


def _get_pages(pdf_path, model=None, page_range=None, workers=None):
if model is None:
model = get_model()
def worker_init(pdf_path):
global model
global pdf_doc
pdf_doc = pdfium.PdfDocument(pdf_path)
model = get_model()


def _get_pages(pdf_path, page_range=None, workers=None):
pdf_doc = pdfium.PdfDocument(pdf_path)
if page_range is None:
page_range = range(len(pdf_doc))
Expand All @@ -31,30 +34,30 @@ def _get_pages(pdf_path, model=None, page_range=None, workers=None):
workers = min(workers, len(page_range) // settings.WORKER_PAGE_THRESHOLD) # It's inefficient to have too many workers, since we batch in inference

if workers is None or workers <= 1:
model = get_model()
text_chars = get_pdfium_chars(pdf_doc, page_range)
return inference(text_chars, model)

func = partial(_get_page_range, pdf_path, model)
page_range = list(page_range)

pages_per_worker = math.ceil(len(page_range) / workers)
page_range_chunks = [page_range[i * pages_per_worker:(i + 1) * pages_per_worker] for i in range(workers)]

with ProcessPoolExecutor(max_workers=workers) as executor:
pages = list(executor.map(func, page_range_chunks))
with ProcessPoolExecutor(max_workers=workers, initializer=worker_init, initargs=(pdf_path,)) as executor:
pages = list(executor.map(_get_page_range, page_range_chunks))

ordered_pages = [page for sublist in pages for page in sublist]

return ordered_pages


def plain_text_output(pdf_path, sort=False, model=None, hyphens=False, page_range=None, workers=None) -> str:
text = paginated_plain_text_output(pdf_path, sort=sort, model=model, hyphens=hyphens, page_range=page_range, workers=workers)
def plain_text_output(pdf_path, sort=False, hyphens=False, page_range=None, workers=None) -> str:
text = paginated_plain_text_output(pdf_path, sort=sort, hyphens=hyphens, page_range=page_range, workers=workers)
return "\n".join(text)


def paginated_plain_text_output(pdf_path, sort=False, model=None, hyphens=False, page_range=None, workers=None) -> List[str]:
pages = _get_pages(pdf_path, model, page_range, workers=workers)
def paginated_plain_text_output(pdf_path, sort=False, hyphens=False, page_range=None, workers=None) -> List[str]:
pages = _get_pages(pdf_path, page_range, workers=workers)
text = []
for page in pages:
text.append(merge_text(page, sort=sort, hyphens=hyphens).strip())
Expand All @@ -71,8 +74,8 @@ def _process_span(span, page_width, page_height, keep_chars):
char["bbox"] = unnormalize_bbox(char["bbox"], page_width, page_height)


def dictionary_output(pdf_path, sort=False, model=None, page_range=None, keep_chars=False, workers=None):
pages = _get_pages(pdf_path, model, page_range, workers=workers)
def dictionary_output(pdf_path, sort=False, page_range=None, keep_chars=False, workers=None):
pages = _get_pages(pdf_path, page_range, workers=workers)
for page in pages:
page_width, page_height = page["width"], page["height"]
for block in page["blocks"]:
Expand Down
53 changes: 27 additions & 26 deletions pdftext/inference.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from itertools import chain

import numpy as np
import sklearn

from pdftext.pdf.utils import LINE_BREAKS, TABS, SPACES
Expand Down Expand Up @@ -47,27 +48,27 @@ def create_training_row(char_info, prev_char, currblock, currline):

is_space = char in SPACES or char in TABS

training_row = {
"is_newline": char in LINE_BREAKS,
"is_space": is_space,
"x_gap": x_gap,
"y_gap": y_gap,
"font_match": font_match,
"x_outer_gap": char_x2 - prev_x1,
"y_outer_gap": char_y2 - prev_y1,
"line_x_center_gap": char_center_x - currline["center_x"],
"line_y_center_gap": char_center_y - currline["center_y"],
"line_x_gap": char_x1 - currline_bbox[2],
"line_y_gap": char_y1 - currline_bbox[3],
"line_x_start_gap": char_x1 - currline_bbox[0],
"line_y_start_gap": char_y1 - currline_bbox[1],
"block_x_center_gap": char_center_x - currblock["center_x"],
"block_y_center_gap": char_center_y - currblock["center_y"],
"block_x_gap": char_x1 - currblock_bbox[2],
"block_y_gap": char_y1 - currblock_bbox[3],
"block_x_start_gap": char_x1 - currblock_bbox[0],
"block_y_start_gap": char_y1 - currblock_bbox[1]
}
return np.array([
char_center_x - currblock["center_x"],
char_x1 - currblock_bbox[2],
char_x1 - currblock_bbox[0],
char_center_y - currblock["center_y"],
char_y1 - currblock_bbox[3],
char_y1 - currblock_bbox[1],
font_match,
char in LINE_BREAKS,
is_space,
char_center_x - currline["center_x"],
char_x1 - currline_bbox[2],
char_x1 - currline_bbox[0],
char_center_y - currline["center_y"],
char_y1 - currline_bbox[3],
char_y1 - currline_bbox[1],
x_gap,
char_x2 - prev_x1,
y_gap,
char_y2 - prev_y1
], dtype=np.float32)

return training_row

Expand Down Expand Up @@ -135,8 +136,6 @@ def infer_single_page(text_chars, block_threshold=settings.BLOCK_THRESHOLD):
font_info = f"{font['name']}_{font['size']}_{font['weight']}_{font['flags']}_{char_info['rotation']}"
if prev_char:
training_row = create_training_row(char_info, prev_char, block, line)
sorted_keys = sorted(training_row.keys())
training_row = [training_row[key] for key in sorted_keys]

prediction_probs = yield training_row
# First item is probability of same line/block, second is probability of new line, third is probability of new block
Expand Down Expand Up @@ -175,6 +174,8 @@ def inference(text_chars, model):
# Create generators and get first training row from each
generators = [infer_single_page(text_page) for text_page in text_chars]
next_prediction = {}
input_name = model.get_inputs()[0].name
output_name = model.get_outputs()[1].name

page_blocks = {}
while len(page_blocks) < len(generators):
Expand All @@ -199,10 +200,10 @@ def inference(text_chars, model):

training_idxs = sorted(training_data.keys())
training_rows = [training_data[idx] for idx in training_idxs]
training_rows = np.stack(training_rows, axis=0)

# Disable nan, etc, validation for a small speedup
with sklearn.config_context(assume_finite=True):
predictions = model.predict_proba(training_rows)
# Run inference
predictions = model.run([output_name], {input_name: training_rows})[0]
for pred, page_idx in zip(predictions, training_idxs):
next_prediction[page_idx] = pred
sorted_keys = sorted(page_blocks.keys())
Expand Down
6 changes: 3 additions & 3 deletions pdftext/model.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import joblib
from pdftext.settings import settings
import onnxruntime as rt


def get_model(model_path=settings.MODEL_PATH):
model = joblib.load(model_path)
return model
sess = rt.InferenceSession(model_path)
return sess
2 changes: 1 addition & 1 deletion pdftext/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

class Settings(BaseSettings):
BASE_PATH: str = os.path.dirname(os.path.dirname(__file__))
MODEL_PATH: str = os.path.join(BASE_PATH, "models", "dt.joblib")
MODEL_PATH: str = os.path.join(BASE_PATH, "models", "dt.onnx")

# Fonts
FONTNAME_SAMPLE_FREQ: int = 4
Expand Down
Loading

0 comments on commit 660c010

Please sign in to comment.