From 42195dea8f5ab6929aa585f500f0ed9f23abe4ac Mon Sep 17 00:00:00 2001 From: Felix Dittrich Date: Thu, 15 Sep 2022 22:12:53 +0200 Subject: [PATCH] refactor detect script (#1060) --- .github/workflows/scripts.yml | 46 +++++++++++++++ scripts/detect_text.py | 102 +++++++++++++--------------------- 2 files changed, 85 insertions(+), 63 deletions(-) diff --git a/.github/workflows/scripts.yml b/.github/workflows/scripts.yml index 5adc9d7e73..b485b5e665 100644 --- a/.github/workflows/scripts.yml +++ b/.github/workflows/scripts.yml @@ -53,6 +53,52 @@ jobs: wget https://github.com/mindee/doctr/releases/download/v0.1.0/sample.pdf python scripts/analyze.py sample.pdf --noblock + test-detect-text: + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest] + python: [3.7, 3.8] + framework: [tensorflow, pytorch] + steps: + - if: matrix.os == 'macos-latest' + name: Install MacOS prerequisites + run: brew install cairo pango gdk-pixbuf libffi + - uses: actions/checkout@v2 + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python }} + architecture: x64 + - if: matrix.framework == 'tensorflow' + name: Cache python modules (TF) + uses: actions/cache@v2 + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('pyproject.toml') }} + - if: matrix.framework == 'pytorch' + name: Cache python modules (PT) + uses: actions/cache@v2 + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('pyproject.toml') }} + - if: matrix.framework == 'tensorflow' + name: Install package (TF) + run: | + python -m pip install --upgrade pip + pip install -e .[tf] --upgrade + - if: matrix.framework == 'pytorch' + name: Install package (PT) + run: | + python -m pip install --upgrade pip + pip install -e .[torch] --upgrade + + - name: Run detection script + run: | + wget https://github.com/mindee/doctr/releases/download/v0.1.0/sample.pdf + python scripts/detect_text.py sample.pdf + test-evaluate: runs-on: ${{ matrix.os }} strategy: diff --git a/scripts/detect_text.py b/scripts/detect_text.py index 9d967d3ded..f634fff82a 100644 --- a/scripts/detect_text.py +++ b/scripts/detect_text.py @@ -24,54 +24,55 @@ if any(gpu_devices): tf.config.experimental.set_memory_growth(gpu_devices[0], True) +IMAGE_FILE_EXTENSIONS = [".jpeg", ".jpg", ".png", ".tif", ".tiff", ".bmp"] +OTHER_EXTENSIONS = [".pdf"] -_OUTPUT_CHOICE_JSON = "json" -_OUTPUT_CHOICE_TEXT = "txt" +def _process_file(model, file_path: Path, out_format: str) -> None: -def _process_file(model, file_path: Path, out_format: str) -> str: - if str(file_path).lower().endswith(".pdf"): + if out_format not in ["txt", "json", "xml"]: + raise ValueError(f"Unsupported output format: {out_format}") + + if os.path.splitext(file_path)[1] in IMAGE_FILE_EXTENSIONS: + doc = DocumentFile.from_images([file_path]) + elif os.path.splitext(file_path)[1] in OTHER_EXTENSIONS: doc = DocumentFile.from_pdf(file_path) else: - doc = DocumentFile.from_images(file_path) + print(f"Skip unsupported file type: {file_path}") out = model(doc) - export = out.export() - - if out_format == _OUTPUT_CHOICE_JSON: - out_txt = json.dumps(export, indent=2) - elif out_format == _OUTPUT_CHOICE_TEXT: - out_txt = "" - for page in export["pages"]: - for block in page["blocks"]: - for line in block["lines"]: - for word in line["words"]: - out_txt += word["value"] + " " - out_txt += "\n" - out_txt += "\n\n" + + if out_format == "json": + output = json.dumps(out.export(), indent=2) + elif out_format == "txt": + output = out.render() + elif out_format == "xml": + output = out.export_as_xml() + + path = Path("output").joinpath(file_path.stem + "." + out_format) + if out_format == "xml": + for i, (xml_bytes, xml_tree) in enumerate(output): + path = Path("output").joinpath(file_path.stem + f"_{i}." + out_format) + xml_tree.write(path, encoding="utf-8", xml_declaration=True) else: - out_txt = "" - return out_txt + with open(path, "w") as f: + f.write(output) def main(args): model = ocr_predictor(args.detection, args.recognition, pretrained=True) path = Path(args.path) + + os.makedirs(name="output", exist_ok=True) + if path.is_dir(): - allowed = (".pdf", ".jpeg", ".jpg", ".png", ".tif", ".tiff", ".bmp") - to_process = [f for f in path.iterdir() if str(f).lower().endswith(allowed)] - for filename in tqdm(to_process): - out_path = path.joinpath(f"{filename}.{args.format}") - if out_path.exists(): - continue - in_path = path.joinpath(filename) - # print(in_path) - out_str = _process_file(model, in_path, args.format) - with open(out_path, "w") as fh: - fh.write(out_str) + to_process = [ + f for f in path.iterdir() if str(f).lower().endswith(tuple(IMAGE_FILE_EXTENSIONS + OTHER_EXTENSIONS)) + ] + for file_path in tqdm(to_process): + _process_file(model, file_path, args.format) else: - out_str = _process_file(model, path, args.format) - print(out_str) + _process_file(model, path, args.format) def parse_args(): @@ -79,40 +80,15 @@ def parse_args(): description="DocTR text detection", formatter_class=argparse.ArgumentDefaultsHelpFormatter, ) + parser.add_argument("path", type=str, help="Path to process: PDF, image, directory") + parser.add_argument("--detection", type=str, default="db_resnet50", help="Text detection model to use for analysis") parser.add_argument( - "path", - type=str, - help="Path to process: PDF, image, directory", - ) - parser.add_argument( - "--detection", - type=str, - default="db_resnet50", - help="Text detection model to use for analysis", - ) - parser.add_argument( - "--recognition", - type=str, - default="crnn_vgg16_bn", - help="Text recognition model to use for analysis", - ) - parser.add_argument( - "-f", - "--format", - choices=[ - _OUTPUT_CHOICE_JSON, - _OUTPUT_CHOICE_TEXT, - ], - default=_OUTPUT_CHOICE_TEXT, - help="Output format", + "--recognition", type=str, default="crnn_vgg16_bn", help="Text recognition model to use for analysis" ) + parser.add_argument("-f", "--format", choices=["txt", "json", "xml"], default="txt", help="Output format") return parser.parse_args() if __name__ == "__main__": parsed_args = parse_args() - try: - main(parsed_args) - except KeyboardInterrupt: - print("Cancelled") - pass + main(parsed_args)