Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pdf_reader bug fix #44

Merged
merged 4 commits into from
Jun 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
159 changes: 90 additions & 69 deletions src/pai_rag/integrations/readers/pai_pdf_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
from llama_index.core import Settings
from pai_rag.utils.constants import DEFAULT_EASYOCR_MODEL_DIR
import json
import sys
import unicodedata
import logging
import tempfile
Expand All @@ -33,9 +32,10 @@ class PageItem(TypedDict):

class PaiPDFReader(BaseReader):
"""Read PDF files including texts, tables, images.

Args:
enable_image_ocr (bool): whether load ocr model to process images
model_dir: (str): ocr model path
enable_image_ocr (bool): whether load ocr model to process images
model_dir: (str): ocr model path
"""

def __init__(
Expand All @@ -54,34 +54,48 @@ def __init__(
)
logger.info("finished loading ocr model")

"""剪切图片
def process_pdf_image(self, element: LTFigure, page_object: PageObject) -> str:
"""
Processes an image element from a PDF, crops it out, and performs OCR on the result.

Args:
element (LTFigure): An LTFigure object representing the image in the PDF, containing its coordinates.
page_object (PageObject): A PageObject representing the page in the PDF to be cropped.

def process_image(self, element: LTFigure, page_object: PageObject) -> str:
# Get the coordinates used to crop the image out of the PDF
Returns:
str: The OCR-processed text from the cropped image.
"""
# Retrieve the image's coordinates
[image_left, image_top, image_right, image_bottom] = [
element.x0,
element.y0,
element.x1,
element.y1,
]
# NOTE(review): image_top receives element.y0 and image_bottom receives element.y1.
# If the layout coordinates use a bottom-left origin (y0 = lower edge), these two
# names are swapped relative to their meaning — confirm the crop corners are intended.
# Crop the page using the coordinates (left, bottom, right, top)
# Adjust the page's media box to crop the image based on the coordinates
# The two assignments below modify the passed-in page_object in place.
page_object.mediabox.lower_left = (image_left, image_bottom)
page_object.mediabox.upper_right = (image_right, image_top)
# Save the cropped page as a new PDF
# Save the cropped page as a new PDF file and perform OCR
cropped_pdf_writer = PyPDF2.PdfWriter()
# NOTE(review): the temp file is re-opened by name while still open here; whether
# that works is platform-dependent for NamedTemporaryFile — confirm target platforms.
with tempfile.NamedTemporaryFile(
delete=True, suffix=".pdf"
) as cropped_pdf_file:
cropped_pdf_writer.add_page(page_object)
cropped_pdf_writer.write(cropped_pdf_file)
# Flush so the bytes are on disk before the file is read back by name.
cropped_pdf_file.flush()
return self.convert_to_images(cropped_pdf_file.name)
# Return the OCR-processed text
return self.ocr_pdf(cropped_pdf_file.name)

"""创建一个将PDF内容转换为image的函数
"""
def ocr_pdf(self, input_file: str) -> str:
"""
Function to convert PDF content into an image and then perform OCR (Optical Character Recognition)

def convert_to_images(self, input_file: str) -> str:
Args:
input_file (str): input file path.

Returns:
str: text from ocr.
"""
images = convert_from_path(input_file)
image = images[0]
with tempfile.NamedTemporaryFile(
Expand All @@ -91,32 +105,34 @@ def convert_to_images(self, input_file: str) -> str:
output_image_file.flush()
return self.image_to_text(output_image_file.name)

"""创建从图片中提取文本的函数
"""

def image_to_text(self, image_path: str) -> str:
    """
    Run the OCR reader on an image file and return the recognized text.

    Args:
        image_path (str): path of the image to read.

    Returns:
        str: the recognized text fragments, concatenated without a separator
            in the order the reader returned them.
    """
    fragments = []
    for detection in self.image_reader.readtext(image_path):
        # The recognized string is the second field of each detection.
        fragments.append(detection[1])
    return "".join(fragments)

"""从页面中提取表格内容
"""Function to extract content from table
"""

@staticmethod
def extract_table(pdf: pdfplumber.PDF, page_num: int, table_num: int) -> List[Any]:
    """
    Return one table from one page of an opened pdfplumber document.

    Args:
        pdf (pdfplumber.PDF): the opened document.
        page_num (int): index into ``pdf.pages``.
        table_num (int): index into the list returned by ``extract_tables()``
            for that page.

    Returns:
        List[Any]: the selected table.
    """
    # Every table on the page is extracted before the requested one is picked.
    return pdf.pages[page_num].extract_tables()[table_num]

"""合并分页表格
"""Function to merge paginated tables
"""

@staticmethod
def merge_page_tables(total_tables: List[PageItem]) -> List[PageItem]:
# 合并分页表格
i = len(total_tables) - 1
while i - 1 >= 0:
table = total_tables[i]
Expand All @@ -135,16 +151,14 @@ def merge_page_tables(total_tables: List[PageItem]) -> List[PageItem]:
i -= 1
return total_tables

"""将表格转换为适当的格式
"""Function to parse table
"""

@staticmethod
def parse_table(table: List[List]) -> str:
table_string = ""
# 遍历表格的每一行
for row_num in range(len(table)):
row = table[row_num]
# 从warp的文字删除线路断路器
cleaned_row = [
item.replace("\n", " ")
if item is not None and "\n" in item
Expand All @@ -153,32 +167,29 @@ def parse_table(table: List[List]) -> str:
else item
for item in row
]
# 将表格转换为字符串,注意'|'、'\n'
table_string += "|" + "|".join(cleaned_row) + "|" + "\n"
# 删除最后一个换行符
table_string = table_string.strip()
return table_string

"""为表格生成摘要
"""Function to summarize table
"""

@staticmethod
def tables_summarize(table: List[List]) -> str:
# Ask the configured LLM (Settings.llm) to summarize a table.
# The prompt below is Chinese for "Please generate a summary for the following table:";
# the table is interpolated into it through the f-string.
prompt_text = f"请为以下表格生成一个摘要: {table}"
response = Settings.llm.complete(
prompt_text,
max_tokens=200, # adjust to the desired summary length
n=1, # number of summaries to generate
max_tokens=200,
n=1,
)
# NOTE(review): the completion response is returned unchanged although the
# annotation says str — confirm what callers expect.
summarized_text = response
return summarized_text

"""表格数据转化为json数据
"""Function to convert table data to json
"""

@staticmethod
def table_to_json(table: List[List]) -> str:
# 提取表头
table_info = []
column_name = table[0]
for row in range(1, len(table)):
Expand All @@ -190,23 +201,32 @@ def table_to_json(table: List[List]) -> str:

return json.dumps(table_info, ensure_ascii=False)

"""创建一个文本提取函数
"""Function to process text in pdf
"""

@staticmethod
def text_extraction(elements: List[LTTextBoxHorizontal]) -> List[str]:
# 找到每一行的坐标
"""
Extracts text lines from a list of text boxes and handles line breaks under specific conditions.

Args:
elements: A list of LTTextBoxHorizontal objects representing text boxes on a page.

Returns:
A list containing the extracted text lines with line breaks removed as per defined conditions.
"""
boxes, texts = [], []
# 页面文字的开始和结束坐标
# Initialize the start and end coordinates of the page text
max_x1 = 0
min_x0 = sys.maxsize
min_x0 = float("inf")
for text_box_h in elements:
if isinstance(text_box_h, LTTextBoxHorizontal):
for text_box_h_l in text_box_h:
if isinstance(text_box_h_l, LTTextLineHorizontal):
# Process each text line's coordinates and content
x0, y0, x1, y1 = text_box_h_l.bbox
text = text_box_h_l.get_text()
# 判断这一行是否以标点符号结尾。以标点符号结尾的行的结束位置和正常文字的结束位置不同
# Check if the line ends with punctuation and requires special handling
if not (
text[-1] == "\n"
and len(text) >= 2
Expand All @@ -216,7 +236,7 @@ def text_extraction(elements: List[LTTextBoxHorizontal]) -> List[str]:
min_x0 = min(min_x0, x0)
texts.append(text)
boxes.append((x0, x1))
# 判断是否去除换行符的条件:该行的结尾坐标大于等于除标点符号结尾的行的坐标向下取整 且 下一行的开头坐标小于等于最小文字坐标取整+1
# Remove line breaks based on defined conditions
for cur in range(len(boxes) - 1):
if boxes[cur][1] >= int(max_x1) and boxes[cur + 1][0] <= int(min_x0) + 1:
texts[cur] = texts[cur].replace("\n", "")
Expand Down Expand Up @@ -259,83 +279,84 @@ def load(
# open PDF file

pdfFileObj = open(file_path, "rb")
# 创建一个PDF阅读器对象
# Create a PDF reader object
pdf_read = PyPDF2.PdfReader(pdfFileObj)

total_tables = []
page_items = []
# 打开pdf文件
# Open the PDF and extract pages
pdf = pdfplumber.open(file_path)
# 从PDF中提取页面
for pagenum, page in enumerate(extract_pages(file_path)):
# 初始化从页面中提取文本所需的变量
# Initialize variables for extracting text from the page
page_object = pdf_read.pages[pagenum]
text_elements = []
text_from_images = []
# 初始化检查表的数量
# Initialize table count
table_num = 0
first_element = True
# 查找已检查的页面
# Find the checked page
page_tables = pdf.pages[pagenum]
# 找出本页上的表格数目
# Find the number of tables on the page
tables = page_tables.find_tables()

# 找到所有的元素
# Find all elements on the page
page_elements = [(element.y1, element) for element in page._objs]
# 对页面中出现的所有元素进行排序
# Sort the elements on the page by their y1 coordinate
page_elements.sort(key=lambda a: a[0], reverse=True)

# 查找组成页面的元素
# Iterate through the page's elements
for i, component in enumerate(page_elements):
# 提取页面布局的元素
# Extract text elements
element = component[1]

# 检查该元素是否为文本元素
# Check if the element is a text box
if isinstance(element, LTTextBoxHorizontal):
text_elements.append(element)

# 检查元素中的图像
# Check for images and extract text from them if OCR is enabled
elif isinstance(element, LTFigure) and self.enable_image_ocr:
# 从PDF中提取文字
image_texts = self.process_image(element, page_object)
# Extract text from the PDF image
image_texts = self.process_pdf_image(element, page_object)
text_from_images.append(image_texts)

# 检查表的元素
# Check for table elements
elif isinstance(element, LTRect):
lower_side = sys.maxsize
lower_side = float("inf")
upper_side = 0

# 如果第一个矩形元素
# If it's the first rectangle element
if first_element is True and (table_num + 1) <= len(tables):
# 找到表格的边界框
# Find the bounding box of the table
lower_side = page.bbox[3] - tables[table_num].bbox[3]
upper_side = element.y1
# 从表中提取信息
tabel_text = PaiPDFReader.extract_table(pdf, pagenum, table_num)
# Extract the table data
table_text = PaiPDFReader.extract_table(pdf, pagenum, table_num)

item = PageItem(
page_number=pagenum,
index_id=i,
item_type="table",
element=element,
table_num=table_num,
text=tabel_text,
text=table_text,
)
total_tables.append(item)
# 让它成为另一个元素
# Move to the next element
first_element = False

# 检查我们是否已经从页面中提取了表
# Check if we've extracted a table from the page
if element.y0 >= lower_side and element.y1 <= upper_side:
pass
elif not isinstance(page_elements[i + 1][1], LTRect):
elif i + 1 < len(page_elements) and not isinstance(
page_elements[i + 1][1], LTRect
):
first_element = True
table_num += 1

# 文本处理
# Text extraction from text elements
text_from_texts = PaiPDFReader.text_extraction(text_elements)
page_plain_text = "".join(text_from_texts)
# 图片处理
# Image text extraction
page_image_text = "".join(text_from_images)

page_items.append(
Expand All @@ -351,19 +372,19 @@ def load(
)
)

# 合并分页表格
# Merge tables across pages
total_tables = PaiPDFReader.merge_page_tables(total_tables)

# 构造返回数据
# Construct the returned data
docs = []
for pagenum, item in enumerate(page_items):
page_tables_texts = []
page_tables_summaries = []
page_tables_json = []
for table in total_tables:
# 如果页面匹配
# If the page number matches
if pagenum == table["page_number"]:
# 将表信息转换为结构化字符串格式
# Convert the table data to a structured string
table_string = PaiPDFReader.parse_table(table["text"])
summarized_table_text = PaiPDFReader.tables_summarize(table["text"])
json_data = PaiPDFReader.table_to_json(table["text"])
Expand All @@ -376,7 +397,7 @@ def load(

page_info_text = item[0]["text"] + item[1]["text"] + page_table_text

# if extra_info is not None, check if it is a dictionary
# if `extra_info` is not None, check if it is a dictionary
if extra_info:
if not isinstance(extra_info, dict):
raise TypeError("extra_info must be a dictionary.")
Expand Down
Loading
Loading