From a944f79b69c552eedca4d5161e93246f5e7ece7f Mon Sep 17 00:00:00 2001 From: DL Date: Mon, 19 Aug 2024 21:56:02 +0800 Subject: [PATCH] Add support for Gemini Gemini PRO image parser --- requirements.txt | 18 +- src/llmsearch/config.py | 11 +- src/llmsearch/parsers/images/gemini_parser.py | 13 +- src/llmsearch/parsers/images/generic.py | 139 ++++++---- src/llmsearch/parsers/pdf.py | 4 +- src/llmsearch/parsers/tables/generic.py | 148 ++++++----- src/llmsearch/parsers/tables/gmft_parser.py | 239 ++++++++++++------ tests/test_table_splitting.py | 105 ++++++++ 8 files changed, 469 insertions(+), 208 deletions(-) create mode 100644 tests/test_table_splitting.py diff --git a/requirements.txt b/requirements.txt index 3911a3d..87a1598 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,8 +1,8 @@ -llama-cpp-python==0.2.78 -chromadb~=0.5 -langchain~=0.2.4 -langchain-community~=0.2.4 -langchain-openai~=0.1.8 +llama-cpp-python==0.2.88 +chromadb~=0.5.5 +langchain~=0.2.14 +langchain-community~=0.2.12 +langchain-openai~=0.1.22 langchain-huggingface~=0.0.3 pydantic~=2.7 transformers~=4.41 @@ -16,14 +16,14 @@ python-dotenv accelerate~=0.33 protobuf==3.20.2 termcolor -openai~=1.34.0 +openai~=1.41 einops # required for Mosaic models click bitsandbytes==0.43.1 # auto-gptq==0.2.0 InstructorEmbedding==1.0.1 unstructured~=0.14.5 -pymupdf==1.22.5 +pymupdf==1.24.9 streamlit~=1.28 python-docx~=1.1 six==1.16.0 ; python_version >= "3.10" and python_version < "4.0" @@ -36,4 +36,6 @@ threadpoolctl==3.1.0 ; python_version >= "3.10" and python_version < "4.0" tiktoken==0.7.0 ; python_version >= "3.10" and python_version < "4.0" tokenizers==0.19.1; python_version >= "3.10" and python_version < "4.0" tqdm==4.65.0 ; python_version >= "3.10" and python_version < "4.0" -# transformers==4.29.2 ; python_version >= "3.10" and python_version < "4.0" \ No newline at end of file +# transformers==4.29.2 ; python_version >= "3.10" and python_version < "4.0" +gmft==0.2.1 +google-generativeai~=0.7 \ No newline at end of file diff --git a/src/llmsearch/config.py b/src/llmsearch/config.py index 33b75ec..ac40bbe 100644 --- a/src/llmsearch/config.py +++ b/src/llmsearch/config.py @@ -62,7 +62,16 @@ class PDFTableParser(str, Enum): class PDFImageParser(str, Enum): GEMINI_15_FLASH = "gemini-1.5-flash" + GEMINI_15_PRO= "gemini-1.5-pro" +class PDFImageParseSettings(BaseModel): + image_parser: PDFImageParser + system_instruction: str = """You are an research assistant. You analyze the image to extract detailed information. Response must be a Markdown string in the follwing format: +- First line is a heading with image caption, starting with '# ' +- Second line is empty +- From the third line on - detailed data points and related metadata, extracted from the image, in Markdown format. Don't use Markdown tables. +""" + user_instruction: str = """From the image, extract detailed quantitative and qualitative data points.""" class EmbeddingModelType(str, Enum): huggingface = "huggingface" @@ -92,7 +101,7 @@ class DocumentPathSettings(BaseModel): pdf_table_parser: Optional[PDFTableParser] = None """If enabled, will parse tables in pdf files using a specific of a parser.""" - pdf_image_parser: Optional[PDFImageParser] = None + pdf_image_parser: Optional[PDFImageParseSettings] = None """If enabled, will parse images in pdf files using a specific of a parser.""" additional_parser_settings: Dict[str, Any] = Field(default_factory=dict) diff --git a/src/llmsearch/parsers/images/gemini_parser.py b/src/llmsearch/parsers/images/gemini_parser.py index dce8de8..5033057 100644 --- a/src/llmsearch/parsers/images/gemini_parser.py +++ b/src/llmsearch/parsers/images/gemini_parser.py @@ -19,18 +19,15 @@ class GeminiImageAnalyzer: def __init__( self, model_name: str, - instruction: str = """From the image, extract detailed quantitative and qualitative data points.""", + system_instruction: str, + user_instruction: str ): self.model_name = model_name - self.instruction = instruction + self.instruction = user_instruction + print(system_instruction, user_instruction) self.model = genai.GenerativeModel( model_name, - system_instruction="""You are an research assistant. You analyze the image to extract detailed information. Response must be a Markdown string in the follwing format: - -- First line is a heading with image caption, starting with '# ' -- Second line is empty -- From the third line on - detailed data points and related metadata, extracted from the image, in Markdown format. Don't use Markdown tables. -""", + system_instruction = system_instruction, generation_config=genai.types.GenerationConfig( # Only one candidate for now. candidate_count=1, diff --git a/src/llmsearch/parsers/images/generic.py b/src/llmsearch/parsers/images/generic.py index f9c2943..3b59764 100644 --- a/src/llmsearch/parsers/images/generic.py +++ b/src/llmsearch/parsers/images/generic.py @@ -1,8 +1,9 @@ from collections import defaultdict +import importlib import io from multiprocessing.pool import ThreadPool from pathlib import Path -from typing import Dict, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple, Callable import PIL.Image import pymupdf @@ -10,9 +11,40 @@ from pydantic import BaseModel from tenacity import retry, stop_after_attempt, wait_random_exponential -from llmsearch.config import PDFImageParser +from llmsearch.config import PDFImageParseSettings, PDFImageParser from llmsearch.parsers.markdown import markdown_splitter +# Define a mapping of PDFImageParser to corresponding analyzer classes and config +ANALYZER_MAPPING: Dict[PDFImageParser, Any] = { + PDFImageParser.GEMINI_15_FLASH: { + "import_path": "llmsearch.parsers.images.gemini_parser", # Import path for lazy loading + "class_name": "GeminiImageAnalyzer", + "params": {"model_name": "gemini-1.5-flash"}, + }, + + PDFImageParser.GEMINI_15_PRO: { + "import_path": "llmsearch.parsers.images.gemini_parser", # Import path for lazy loading + "class_name": "GeminiImageAnalyzer", + "params": {"model_name": "gemini-1.5-pro"}, + }, +} + + +def create_analyzer(image_analyzer: PDFImageParser, **additional_params): + analyzer_info = ANALYZER_MAPPING.get(image_analyzer) + + if analyzer_info is None: + raise ValueError(f"Unsupported image analyzer type: {image_analyzer}") + + # Lazy load the module + module = importlib.import_module(analyzer_info["import_path"]) + analyzer_class = getattr(module, analyzer_info["class_name"]) + analyzer_params = analyzer_info["params"] + + params = {**analyzer_params, **additional_params} + + return analyzer_class(**params) + class PDFImage(BaseModel): image_fn: Path @@ -26,49 +58,41 @@ def __init__( self, pdf_fn: Path, temp_folder: Path, - image_analyzer, - save_output=True, + image_analyzer: Callable, + save_output: bool = True, max_base_width: int = 1280, min_width: int = 640, min_height: int = 200, ): self.pdf_fn = pdf_fn - self.max_base_width = max_base_width self.temp_folder = temp_folder - self.min_width = min_width - self.min_height = min_height self.image_analyzer = image_analyzer self.save_output = save_output + self.max_base_width = max_base_width + self.min_width = min_width + self.min_height = min_height def prepare_and_clean_folder(self): - # Check if the folder exists if not self.temp_folder.exists(): - # Create the folder if it doesn't exist self.temp_folder.mkdir(parents=True, exist_ok=True) logger.info(f"Created folder: {self.temp_folder}") else: for file in self.temp_folder.iterdir(): if file.is_file(): - file.unlink() # Delete the file - logger.info(f"Deleted file: {file}") + file.unlink() + logger.debug(f"Deleted file: {file}") def extract_images(self) -> List[PDFImage]: self.prepare_and_clean_folder() - doc = pymupdf.open(self.pdf_fn) out_images = [] for page in doc: - page_images = page.get_images() - for img in page_images: + for img in page.get_images(): xref = img[0] - data = doc.extract_image(xref=xref) - out_fn = self._resize_and_save_image( - data=data, - page_num=page.number, - xref_num=xref, - ) - if out_fn is not None: + data = doc.extract_image(xref) + out_fn = self._resize_and_save_image(data, page.number, xref) + if out_fn: out_images.append( PDFImage( image_fn=out_fn, @@ -76,7 +100,6 @@ def extract_images(self) -> List[PDFImage]: bbox=(img[1], img[2], img[3], img[4]), ) ) - return out_images def _resize_and_save_image( @@ -85,30 +108,32 @@ def _resize_and_save_image( page_num: int, xref_num: int, ) -> Optional[Path]: - - image = data.get("image", None) - if image is None: + image_data = data.get("image") + if not image_data: return - with PIL.Image.open(io.BytesIO(image)) as img: + with PIL.Image.open(io.BytesIO(image_data)) as img: if img.size[1] < self.min_height or img.size[0] < self.min_width: - logger.info( + logger.debug( f"Image on page {page_num}, xref {xref_num} is too small. Skipping extraction..." ) return None - wpercent = self.max_base_width / float(img.size[0]) - # Resize the image, if needed + wpercent = self.max_base_width / float(img.size[0]) if wpercent < 1: - hsize = int((float(img.size[1]) * float(wpercent))) + hsize = int(float(img.size[1]) * wpercent) img = img.resize( (self.max_base_width, hsize), PIL.Image.Resampling.LANCZOS ) - out_fn = self.temp_folder / (str(self.pdf_fn.stem) + f"_page_{page_num}_xref_{xref_num}.png") - logger.info(f"Saving file: {out_fn}") - img.save(out_fn, mode="wb") - return Path(out_fn) + out_fn = ( + self.temp_folder + / f"{self.pdf_fn.stem}_page_{page_num}_xref_{xref_num}.png" + ) + logger.debug(f"Saving file: {out_fn}") + img.save(out_fn) + + return out_fn def analyze_images_threaded( self, extracted_images: List[PDFImage], max_threads: int = 10 @@ -117,22 +142,25 @@ def analyze_images_threaded( results = pool.starmap( analyze_single_image, [ - (pdf_image, self.image_analyzer, i) - for i, pdf_image in enumerate(extracted_images) + (img, self.image_analyzer, i) + for i, img in enumerate(extracted_images) ], ) if self.save_output: - for r in results: - with open(str(r.image_fn)[:-3] + ".md", "w") as file: - file.write(r.markdown) + for result in results: + with open(str(result.image_fn).replace(".png", ".md"), "w") as file: + file.write(result.markdown) return results def log_attempt_number(retry_state): - """return the result of the last call attempt""" - logger.error(f"API call attempt failed. Retrying: {retry_state.attempt_number}...") + error_message = str(retry_state.outcome.exception()) + logger.error( + f"API call attempt {retry_state.attempt_number} failed with error: {error_message}. Retrying..." + ) + # logger.error(f"API call attempt failed. Retrying: {retry_state.attempt_number}...") @retry( @@ -140,26 +168,30 @@ def log_attempt_number(retry_state): stop=stop_after_attempt(6), after=log_attempt_number, ) -def analyze_single_image(pdf_image: PDFImage, image_analyzer, i: int) -> PDFImage: - fn = pdf_image.image_fn - pdf_image.markdown = image_analyzer.analyze(fn) +def analyze_single_image( + pdf_image: PDFImage, image_analyzer: Callable, i: int +) -> PDFImage: + pdf_image.markdown = image_analyzer.analyze(pdf_image.image_fn) return pdf_image def get_image_chunks( - path: Path, max_size: int, image_analyzer: PDFImageParser, cache_folder: Path + path: Path, + max_size: int, + image_parse_setting: PDFImageParseSettings, + cache_folder: Path, ) -> Tuple[List[dict], Dict[int, List[Tuple[float]]]]: - if image_analyzer is PDFImageParser.GEMINI_15_FLASH: - from llmsearch.parsers.images.gemini_parser import GeminiImageAnalyzer - analyzer = GeminiImageAnalyzer(model_name="gemini-1.5-flash") + analyzer = create_analyzer( + image_parse_setting.image_parser, + system_instruction=image_parse_setting.system_instruction, + user_instruction=image_parse_setting.user_instruction, + ) image_parser = GenericPDFImageParser( pdf_fn=path, temp_folder=cache_folder / "pdf_images_temp", image_analyzer=analyzer, - # image_analyzer=GeminiImageAnalyzer(model_name="gemini-1.5-pro-exp-0801") ) - extracted_images = image_parser.extract_images() parsed_images = image_parser.analyze_images_threaded(extracted_images) @@ -167,8 +199,9 @@ def get_image_chunks( img_bboxes = defaultdict(list) for img in parsed_images: - print(str(img.image_fn) + ".md") - out_blocks += markdown_splitter(path=str(img.image_fn)[:-3] + ".md", max_chunk_size=max_size) + out_blocks += markdown_splitter( + path=str(img.image_fn).replace(".png", ".md"), max_chunk_size=max_size + ) img_bboxes[img.page_num].append(img.bbox) return out_blocks, img_bboxes @@ -179,7 +212,7 @@ def get_image_chunks( res = get_image_chunks( path=Path("/home/snexus/Downloads/Graph_Example2.pdf"), max_size=1024, - image_analyzer=PDFImageParser.GEMINI_15_FLASH, + image_parse_setting=PDFImageParseSettings(image_parser= PDFImageParser.GEMINI_15_PRO), cache_folder=Path("./output_images"), ) diff --git a/src/llmsearch/parsers/pdf.py b/src/llmsearch/parsers/pdf.py index 419baf8..911405f 100644 --- a/src/llmsearch/parsers/pdf.py +++ b/src/llmsearch/parsers/pdf.py @@ -6,7 +6,7 @@ from loguru import logger from langchain_text_splitters import CharacterTextSplitter -from llmsearch.parsers.tables.generic import boxes_intersect +from llmsearch.parsers.tables.generic import do_boxes_intersect class PDFSplitter: @@ -149,7 +149,7 @@ def filter_blocks(blocks: List[Tuple[float, float, float, float, str]], skip_block = False for filter_bbox in page_table_bboxes: - if boxes_intersect(filter_bbox, block_bbox): + if do_boxes_intersect(filter_bbox, block_bbox): # We found an intersection, set the flag and break the inner loop skip_block = True # print(f"SKipping block: {block}") diff --git a/src/llmsearch/parsers/tables/generic.py b/src/llmsearch/parsers/tables/generic.py index fba65c5..fd1f62c 100644 --- a/src/llmsearch/parsers/tables/generic.py +++ b/src/llmsearch/parsers/tables/generic.py @@ -1,13 +1,37 @@ from collections import defaultdict +import importlib from pathlib import Path -from typing import Dict, List, Tuple +from typing import Dict, List, Tuple, Any import pandas as pd from loguru import logger - from abc import ABC, abstractmethod - from llmsearch.config import PDFTableParser +# Define a mapping of PDFImageParser to corresponding analyzer classes and config +PARSER_MAPPING: Dict[PDFTableParser, Any] = { + PDFTableParser.GMFT: { + "import_path": "llmsearch.parsers.tables.gmft_parser", # Import path for lazy loading + "class_name": "GMFTParser", + "params": {}, + }, + # Add more analyzers here as needed + # PDFImageParser.ANOTHER_TYPE: {'import_path': 'another.module.path', 'class_name': 'AnotherAnalyzer', 'params': {'param1': value1, 'param2': value2}}, +} + + +def create_table_parser(table_parser: PDFTableParser, filename: Path): + parser_info = PARSER_MAPPING.get(table_parser) + + if parser_info is None: + raise ValueError(f"Unsupported table parser type: {table_parser}") + + # Lazy load the module + module = importlib.import_module(parser_info["import_path"]) + parser_class = getattr(module, parser_info["class_name"]) + additional_parser_params = parser_info["params"] + + return parser_class(fn = filename, **additional_parser_params) + class GenericParsedTable(ABC): def __init__(self, page_number: int, bbox: Tuple[float, float, float, float]): @@ -17,42 +41,41 @@ def __init__(self, page_number: int, bbox: Tuple[float, float, float, float]): @property @abstractmethod def df(self) -> pd.DataFrame: - """Returns Pandas DF corresponding to a table""" + """Returns a Pandas DataFrame corresponding to a table.""" pass @property @abstractmethod def caption(self) -> str: - """Returns caption of the table""" + """Returns the caption of the table.""" pass @property @abstractmethod def xml(self) -> List[str]: - """Returns xml representation of the table""" + """Returns XML representation of the table.""" pass def pandas_df_to_xml(df: pd.DataFrame) -> List[str]: - """Converts Pandas df to a simplified xml representation digestible by LLMs + """Converts a Pandas DataFrame to a simplified XML representation. Args: - df (pd.DataFrame): Pandas df + df (pd.DataFrame): The DataFrame to convert. Returns: - str: List of xml row strings representing the dataframe + List[str]: List of XML row strings representing the DataFrame. """ def func(row): xml = [""] for field in row.index: - xml.append(' {1}'.format(field, row[field])) + xml.append(f' {row[field]}') xml.append("") return "\n".join(xml) items = df.apply(func, axis=1).tolist() return items - # return "\n".join(items) def pdf_table_splitter( @@ -60,15 +83,25 @@ def pdf_table_splitter( max_size: int, include_caption: bool = True, max_caption_size_ratio: int = 4, -): +) -> List[Dict[str, Any]]: + """Splits a parsed table into manageable chunks. + + Args: + parsed_table (GenericParsedTable): The parsed table instance. + max_size (int): Maximum size for each chunk. + include_caption (bool): Whether to include the table caption. + max_caption_size_ratio (int): Ratio to determine allowable caption size. + + Returns: + List[Dict[str, Any]]: List of text chunks with metadata. + """ xml_elements = parsed_table.xml caption = parsed_table.caption metadata = {"page": parsed_table.page_num, "source_chunk_type": "table"} - all_chunks = [] - # If caption is too long, trim it down, so there is some space for actual data + # Trim caption if it's too long if len(caption) > max_size / max_caption_size_ratio: logger.warning( "Caption is too large compared to max char size, trimming down..." @@ -79,25 +112,19 @@ def pdf_table_splitter( if include_caption and caption: header = f"Table below contains information about: {caption}\n" + header - footer = f"```" - + footer = "```" current_text = header for el in xml_elements: - - # If new element is too big, trim it (shouldn't happen) if len(el) > max_size: logger.warning( - "xml element is larger than allowed max char size. Flushing.." - ) - # el = el[:max_size-len(header)-3] - all_chunks.append( - {"text": current_text + el + footer, "metadata": metadata} + "XML element is larger than allowed max char size. Flushing.." ) + all_chunks.append({"text": current_text + footer, "metadata": metadata}) + all_chunks.append({"text": header + el + footer, "metadata": metadata}) current_text = header - - # if current text is already large and doesn't fit the new element, flush it elif len(current_text + el) >= max_size: - all_chunks.append({"text": current_text + footer, "metadata": metadata}) + if current_text != header: + all_chunks.append({"text": current_text + footer, "metadata": metadata}) current_text = header + el + "\n" else: current_text += el + "\n" @@ -106,66 +133,63 @@ def pdf_table_splitter( all_chunks.append({"text": current_text + footer, "metadata": metadata}) return all_chunks -def boxes_intersect(box1: Tuple[float, float, float, float], box2: Tuple[float, float, float, float]) -> bool: - """ - Check if two bounding boxes intersect. - Parameters: - box1: Tuple (x1_min, y1_min, x1_max, y1_max) - box2: Tuple (x2_min, y2_min, x2_max, y2_max) +def do_boxes_intersect( + box1: Tuple[float, float, float, float], box2: Tuple[float, float, float, float] +) -> bool: + """Check if two bounding boxes intersect. + + Args: + box1 (Tuple[float, float, float, float]): First bounding box. + box2 (Tuple[float, float, float, float]): Second bounding box. Returns: - True if the boxes intersect, False otherwise. + bool: True if the boxes intersect, False otherwise. """ - - # Unpack the box coordinates x1_min, y1_min, x1_max, y1_max = box1 x2_min, y2_min, x2_max, y2_max = box2 - # Check for non-intersection - if x1_max < x2_min or x2_max < x1_min: - return False - if y1_max < y2_min or y2_max < y1_min: - return False + return not ( + x1_max < x2_min or x2_max < x1_min or y1_max < y2_min or y2_max < y1_min + ) - # If none of the non-intersection conditions are met, they must intersect - return True def get_table_chunks( - path: Path, max_size: int, table_parser: PDFTableParser, format_extensions = ("pdf",) -) -> Tuple[List[dict], Dict[int, List[Tuple[float]]]]: - """Parses tables from the document using specified table_splitter + path: Path, + max_size: int, + table_parser: PDFTableParser, + format_extensions: Tuple[str, ...] = (".pdf",), +) -> Tuple[List[Dict[str, Any]], Dict[int, List[Tuple[float, float, float, float]]]]: + """Parses tables from a document and splits them into chunks. Args: - path (Path): document path - max_size (int): Maximum chunk size to split by - table_splitter (PDFTableParser): name of the table splitter - """ + path (Path): Document path. + max_size (int): Maximum chunk size to split by. + table_parser (PDFTableParser): Table parser to use. + format_extensions (Tuple[str, ...]): Supported file formats for parsing. + Returns: + Tuple[List[Dict[str, Any]], Dict[int, List[Tuple[float, float, float, float]]]]: + A tuple with the list of table chunks and a dictionary of bounding boxes. + """ table_chunks = [] - extension = str(path).strip("/")[-3:] - if extension not in format_extensions: + extension = path.suffix.lower() + if extension not in format_extensions: logger.info(f"Format {extension} doesn't support table parsing..Skipping..") - return list(), dict() + return [], {} - if table_parser is PDFTableParser.GMFT: - from llmsearch.parsers.tables.gmft_parser import GMFTParser - parser = GMFTParser(fn=path) - splitter = pdf_table_splitter - else: - raise TypeError(f"Unknown table parser: {table_parser}") + parser = create_table_parser(table_parser, filename=path) logger.info("Parsing tables..") - parsed_tables = parser.parsed_tables logger.info(f"Parsed {len(parsed_tables)} tables. Chunking...") for parsed_table in parsed_tables: - table_chunks += splitter(parsed_table, max_size=max_size) + table_chunks.extend(pdf_table_splitter(parsed_table, max_size=max_size)) - # Extract tables bounding boxes and store in a convenient data structure. + # Extract bounding boxes table_bboxes = defaultdict(list) for table in parsed_tables: table_bboxes[table.page_num].append(table.bbox) - return table_chunks, table_bboxes \ No newline at end of file + return table_chunks, table_bboxes diff --git a/src/llmsearch/parsers/tables/gmft_parser.py b/src/llmsearch/parsers/tables/gmft_parser.py index e1d29bb..2c11226 100644 --- a/src/llmsearch/parsers/tables/gmft_parser.py +++ b/src/llmsearch/parsers/tables/gmft_parser.py @@ -5,7 +5,6 @@ from gmft import ( CroppedTable, TableDetector, - AutoFormatConfig, AutoTableFormatter, ) from pathlib import Path @@ -15,132 +14,224 @@ from llmsearch.parsers.tables.generic import ( pandas_df_to_xml, GenericParsedTable, - pdf_table_splitter, ) -class TableFormatterSingleton: - """Singleton for table formatter""" +class XMLConverter: + """Converts Pandas DataFrames to XML format.""" + + @staticmethod + def convert(df: pd.DataFrame) -> List[str]: + """Converts a DataFrame to a list of XML strings. + + Args: + df (pd.DataFrame): The DataFrame to convert. + + Returns: + List[str]: A list of XML strings representing the DataFrame. + """ + return pandas_df_to_xml(df) + - _instance = None +class ExtractionError(Exception): + """Custom exception for extraction failures.""" + pass + + +@dataclass +class PageTables: + """Holds cropped tables extracted from a specific page of a document.""" + page_num: int + cropped_tables: List[CroppedTable] + + @property + def n_tables(self) -> int: + """Returns the number of cropped tables extracted from the page.""" + return len(self.cropped_tables) + + +class TableFormatterSingleton: + """Singleton class for managing a single instance of AutoTableFormatter.""" + + _instance: Optional['TableFormatterSingleton'] = None formatter = None def __new__(cls, *args, **kwargs): - if not cls._instance: + """Creates a new instance if one does not already exist.""" + if cls._instance is None: logger.info("Initializing AutoTableFormatter...") - cls._instance = super(TableFormatterSingleton, cls).__new__(cls) + cls._instance = super().__new__(cls) cls._instance.formatter = AutoTableFormatter() return cls._instance class GMFTParsedTable(GenericParsedTable): - def __init__(self, table: CroppedTable, page_num: int) -> None: - super().__init__( - page_number=page_num, bbox=table.bbox - ) # Initialize the field from the abstract class - self._table = table - self.failed = False - self.formatter: AutoTableFormatter = TableFormatterSingleton().formatter + """Represents a parsed table with its metadata and data extraction logic.""" - # Formatter is passed externally - # self.formatter = formatter + def __init__(self, table: CroppedTable, page_num: int, formatter: AutoTableFormatter) -> None: + """Initializes the parsed table with a cropped table, page number, and formatter. + + Args: + table (CroppedTable): The cropped table to parse. + page_num (int): The page number where the table is found. + formatter (AutoTableFormatter): The formatter to be used for extraction. + """ + super().__init__(page_number=page_num, bbox=table.bbox) + self._table = table # Store the cropped table + self.failed = False # Track extraction failures + self.formatter = formatter # Formatter for extracting data @cached_property def _captions(self) -> List[str]: - # return "" + """Caches and returns a list of non-empty captions from the table.""" return [c for c in self._table.captions() if c.strip()] @cached_property def caption(self) -> str: + """Returns a unique string of all captions, combined into one.""" return "\n".join(set(self._captions)) @property def df(self) -> Optional[pd.DataFrame]: - ft = self.formatter.extract(self._table) + """Attempts to extract a DataFrame from the cropped table. + + Returns: + Optional[pd.DataFrame]: The extracted DataFrame or None if extraction fails. + + Raises: + ExtractionError: If extraction fails, this error will be raised. + """ + ft = self.formatter.extract(self._table) # Use the formatter to extract the table try: - df = ft.df() + return ft.df() # Return the DataFrame except ValueError as ex: logger.error(f"Couldn't extract df on page {self.page_num}: {str(ex)}") self.failed = True return None - - # config = AutoFormatConfig() - # config.total_overlap_reject_threshold = 0.8 - # config.large_table_threshold = 0 - - # try: - # logger.info("\tTrying to reover") - # df = ft.df(config_overrides = config) - # except ValueError: - # logger.error(f"\tCouldn't recover, page {self.page_num}: {str(ex)}") - # return None - - return df + # raise ExtractionError(f"Extraction failed on page {self.page_num}") @property def xml(self) -> List[str]: + """Converts the extracted DataFrame to XML format. + + Returns: + List[str]: A list of XML strings. Returns an empty list if df extraction failed. + """ if self.df is None: - return list() - return pandas_df_to_xml(self.df) + return [] + return XMLConverter.convert(self.df) -@dataclass -class PageTables: - page_num: int - cropped_tables: List[CroppedTable] +class DocumentHandler: + """Handles loading a PDF document and providing access to its pages.""" - @property - def n_tables(self): - return len(self.cropped_tables) + def __init__(self, path: Path): + """Initializes the DocumentHandler with a path to a PDF. + Args: + path (Path): The file path to the PDF document. + """ + self.doc = PyPDFium2Document(path) # Load the document using PyPDFium2 -class GMFTParser: - def __init__(self, fn: Path) -> None: - self.fn = fn - self._doc = None - self._parsed_tables = None + def get_pages(self) -> Any: + """Returns an iterable of pages from the loaded document.""" + return self.doc - # logger.info("Initializing Table Formatter.") - # self.formatter = AutoTableFormatter() - def detect_page_tables(self) -> Tuple[List[PageTables], Any]: - """Detects tables in a document and returns list of page tables""" +class TableDetectorHelper: + """Facilitates detection of tables within document pages.""" - logger.info("Detecting tables...") - doc = PyPDFium2Document(self.fn) - detector = TableDetector() - pt = [] + def __init__(self): + """Initializes the TableDetector to find tables.""" + self.detector = TableDetector() - for page in doc: - pt.append( - PageTables( - page_num=page.page_number, cropped_tables=detector.extract(page) - ) - ) + def detect_tables(self, page: Any) -> List[CroppedTable]: + """Detects and returns cropped tables from a given page. - return pt, doc + Args: + page (Any): The page from which to detect tables. - @property - def parsed_tables(self) -> List[GenericParsedTable]: - if self._parsed_tables is None: - page_tables, self._doc = self.detect_page_tables() - logger.info("Parsing tables ...") + Returns: + List[CroppedTable]: A list of detected cropped tables. + """ + return self.detector.extract(page) + + +class TableParser: + """Parses cropped tables into GMFTParsedTable objects.""" + + def __init__(self, formatter: AutoTableFormatter): + """Initializes the TableParser with a formatter. + + Args: + formatter (AutoTableFormatter): Formatter used for parsing tables. + """ + self.formatter = formatter - out_tables = [] + def parse(self, cropped_table: CroppedTable, page_num: int) -> GMFTParsedTable: + """Parses a cropped table into a GMFTParsedTable instance. - for page_table in page_tables: - for cropped_table in page_table.cropped_tables: - out_tables.append( - GMFTParsedTable(cropped_table, page_table.page_num) - ) - self._parsed_tables = out_tables + Args: + cropped_table (CroppedTable): The cropped table to parse. + page_num (int): The page number where the table is found. + + Returns: + GMFTParsedTable: An instance of GMFTParsedTable containing the parsed data. + """ + return GMFTParsedTable(cropped_table, page_num, self.formatter) + + +class GMFTParser: + """Main class for handling the parsing of tables from a PDF document.""" + + def __init__(self, fn: Path) -> None: + """Initializes the parser with a PDF file path and prepares components. + + Args: + fn (Path): The file path to the PDF document. + """ + self.fn = fn + self.document_handler = DocumentHandler(fn) # Load the document + self.formatter = TableFormatterSingleton().formatter # Get the formatter + self.table_detector = TableDetectorHelper() # Initialize table detector + self.table_parser = TableParser(self.formatter) # Initialize table parser + self._parsed_tables: Optional[List[GMFTParsedTable]] = None # Cache for parsed tables + + def detect_and_parse_tables(self) -> List[GMFTParsedTable]: + """Detects and parses tables from the PDF document. + + Returns: + List[GMFTParsedTable]: A list of parsed tables. + """ + logger.info("Detecting and parsing tables...") + detected_tables = [] + + # Iterate through the pages in the document + for page in self.document_handler.get_pages(): + cropped_tables = self.table_detector.detect_tables(page) # Detect tables on the page + # Parse each cropped table found on the page + for cropped_table in cropped_tables: + parsed_table = self.table_parser.parse(cropped_table, page.page_number) + detected_tables.append(parsed_table) # Store the parsed table + + return detected_tables + + @property + def parsed_tables(self) -> List[GMFTParsedTable]: + """Lazy-loads the parsed tables when requested. + + Returns: + List[GMFTParsedTable]: A list of parsed tables from the document. + """ + if self._parsed_tables is None: + self._parsed_tables = self.detect_and_parse_tables() # Detect and parse tables if not done already return self._parsed_tables if __name__ == "__main__": # fn = Path("/home/snexus/Downloads/ws90.pdf") # fn = Path("/home/snexus/Downloads/SSRN-id2741701.pdf") - fn = Path("/home/snexus/Downloads/Table_Example1.pdf") + fn = Path("/home/snexus/Downloads/ws90.pdf") parser = GMFTParser(fn=fn) for p in parser.parsed_tables: diff --git a/tests/test_table_splitting.py b/tests/test_table_splitting.py new file mode 100644 index 0000000..b01e837 --- /dev/null +++ b/tests/test_table_splitting.py @@ -0,0 +1,105 @@ +import pytest +from unittest.mock import MagicMock +from llmsearch.parsers.tables.generic import pdf_table_splitter # Replace with the actual module name + +@pytest.fixture +def setup_parsed_table(): + """Fixture to create a mock parsed table for testing.""" + dummy_bbox = (0.0, 0.0, 100.0, 100.0) # Dummy bounding box + parsed_table = MagicMock() + parsed_table.page_num = 1 + parsed_table.bbox = dummy_bbox + parsed_table.caption = "" + return parsed_table + +def test_basic_functionality(setup_parsed_table): + parsed_table = setup_parsed_table + parsed_table.xml = [ + "1", + "2" + ] + expected_output = [ + { + "text": "```xml table:\n1\n2\n```", + "metadata": {"page": 1, "source_chunk_type": "table"} + } + ] + + result = pdf_table_splitter(parsed_table, max_size=100) # Adjust max size as needed + print(result) + assert result == expected_output + +def test_caption_inclusion(setup_parsed_table): + parsed_table = setup_parsed_table + parsed_table.xml = ["1"] + parsed_table.caption = "This is a test caption." + + expected_output = [ + { + "text": "Table below contains information about: This is a test caption.\n```xml table:\n1\n```", + "metadata": {"page": 1, "source_chunk_type": "table"} + } + ] + + result = pdf_table_splitter(parsed_table, max_size=100) + assert result == expected_output + +def test_caption_trimming(setup_parsed_table): + parsed_table = setup_parsed_table + parsed_table.xml = ["1"] + parsed_table.caption = "A very long caption that exceeds the size limit." + + expected_output = [ + { + "text": "Table below contains information about: A very long capt\n```xml table:\n1\n```", + "metadata": {"page": 1, "source_chunk_type": "table"} + } + ] + + result = pdf_table_splitter(parsed_table, max_size=50, max_caption_size_ratio=3) + print(result) + assert result == expected_output + +def test_element_larger_than_max_size(setup_parsed_table): + parsed_table = setup_parsed_table + parsed_table.xml = [ + "1", + "2" + ] + long_element = "" + "" + "X" * 200 + "" # Very long element + parsed_table.xml.append(long_element) + + result = pdf_table_splitter(parsed_table, max_size=100) + print(result) + # There should be one chunk for the first two elements and a separate chunk for the long element + assert len(result) == 3 + +def test_empty_input(setup_parsed_table): + parsed_table = setup_parsed_table + parsed_table.xml = [] + parsed_table.caption = "" + + result = pdf_table_splitter(parsed_table, max_size=100) + print(result) + assert result == [ + { + "text": "```xml table:\n```", + "metadata": {"page": 1, "source_chunk_type": "table"} + } + ] + +def test_single_element(setup_parsed_table): + parsed_table = setup_parsed_table + parsed_table.xml = ["1"] + + result = pdf_table_splitter(parsed_table, max_size=150) + assert len(result) == 1 + +def test_multiple_elements_within_limit(setup_parsed_table): + parsed_table = setup_parsed_table + parsed_table.xml = [ + "1", + "2" + ] + result = pdf_table_splitter(parsed_table, max_size=250) + assert len(result) == 1 \ No newline at end of file