Skip to content

Commit

Permalink
refactor detect script (#1060)
Browse files Browse the repository at this point in the history
  • Loading branch information
felixdittrich92 authored Sep 15, 2022
1 parent a95baaa commit 42195de
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 63 deletions.
46 changes: 46 additions & 0 deletions .github/workflows/scripts.yml
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,52 @@ jobs:
wget https://github.com/mindee/doctr/releases/download/v0.1.0/sample.pdf
python scripts/analyze.py sample.pdf --noblock
# Smoke-test scripts/detect_text.py against the sample PDF for each
# Python version / deep-learning framework combination in the matrix.
test-detect-text:
  runs-on: ${{ matrix.os }}
  strategy:
    fail-fast: false
    matrix:
      os: [ubuntu-latest]
      python: [3.7, 3.8]
      framework: [tensorflow, pytorch]
  steps:
    # Only takes effect if 'macos-latest' is added to matrix.os above.
    - if: matrix.os == 'macos-latest'
      name: Install MacOS prerequisites
      run: brew install cairo pango gdk-pixbuf libffi
    - uses: actions/checkout@v2
    - name: Set up Python
      uses: actions/setup-python@v2
      with:
        python-version: ${{ matrix.python }}
        architecture: x64
    # One cache step for both frameworks. The framework is part of the key so
    # the TensorFlow and PyTorch jobs do not compete for the same cache entry
    # (with an identical key, only the first job to finish would save its wheels).
    - name: Cache python modules
      uses: actions/cache@v2
      with:
        path: ~/.cache/pip
        key: ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ matrix.framework }}-${{ hashFiles('pyproject.toml') }}
    - if: matrix.framework == 'tensorflow'
      name: Install package (TF)
      run: |
        python -m pip install --upgrade pip
        pip install -e .[tf] --upgrade
    - if: matrix.framework == 'pytorch'
      name: Install package (PT)
      run: |
        python -m pip install --upgrade pip
        pip install -e .[torch] --upgrade
    - name: Run detection script
      run: |
        wget https://github.com/mindee/doctr/releases/download/v0.1.0/sample.pdf
        python scripts/detect_text.py sample.pdf
test-evaluate:
runs-on: ${{ matrix.os }}
strategy:
Expand Down
102 changes: 39 additions & 63 deletions scripts/detect_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,95 +24,71 @@
if any(gpu_devices):
tf.config.experimental.set_memory_growth(gpu_devices[0], True)

# Lower-case file suffixes that are loaded as images.
IMAGE_FILE_EXTENSIONS = [".jpeg", ".jpg", ".png", ".tif", ".tiff", ".bmp"]
# Lower-case file suffixes that are loaded as PDF documents.
OTHER_EXTENSIONS = [".pdf"]

# Identifiers of the JSON and plain-text output formats.
_OUTPUT_CHOICE_JSON = "json"
_OUTPUT_CHOICE_TEXT = "txt"

def _process_file(model, file_path: Path, out_format: str) -> None:

def _process_file(model, file_path: Path, out_format: str) -> str:
if str(file_path).lower().endswith(".pdf"):
if out_format not in ["txt", "json", "xml"]:
raise ValueError(f"Unsupported output format: {out_format}")

if os.path.splitext(file_path)[1] in IMAGE_FILE_EXTENSIONS:
doc = DocumentFile.from_images([file_path])
elif os.path.splitext(file_path)[1] in OTHER_EXTENSIONS:
doc = DocumentFile.from_pdf(file_path)
else:
doc = DocumentFile.from_images(file_path)
print(f"Skip unsupported file type: {file_path}")

out = model(doc)
export = out.export()

if out_format == _OUTPUT_CHOICE_JSON:
out_txt = json.dumps(export, indent=2)
elif out_format == _OUTPUT_CHOICE_TEXT:
out_txt = ""
for page in export["pages"]:
for block in page["blocks"]:
for line in block["lines"]:
for word in line["words"]:
out_txt += word["value"] + " "
out_txt += "\n"
out_txt += "\n\n"

if out_format == "json":
output = json.dumps(out.export(), indent=2)
elif out_format == "txt":
output = out.render()
elif out_format == "xml":
output = out.export_as_xml()

path = Path("output").joinpath(file_path.stem + "." + out_format)
if out_format == "xml":
for i, (xml_bytes, xml_tree) in enumerate(output):
path = Path("output").joinpath(file_path.stem + f"_{i}." + out_format)
xml_tree.write(path, encoding="utf-8", xml_declaration=True)
else:
out_txt = ""
return out_txt
with open(path, "w") as f:
f.write(output)


def main(args):
    """Run OCR on a single file or on every supported file of a directory.

    Args:
        args: parsed command-line namespace providing ``path``, ``detection``,
            ``recognition`` and ``format``.
    """
    model = ocr_predictor(args.detection, args.recognition, pretrained=True)
    path = Path(args.path)

    os.makedirs(name="output", exist_ok=True)

    if not path.is_dir():
        _process_file(model, path, args.format)
        return

    # Build the suffix tuple once instead of once per directory entry.
    supported = tuple(IMAGE_FILE_EXTENSIONS + OTHER_EXTENSIONS)
    # Only regular files: a sub-directory named e.g. "scans.pdf" is not a document.
    to_process = [f for f in path.iterdir() if f.is_file() and str(f).lower().endswith(supported)]
    for file_path in tqdm(to_process):
        _process_file(model, file_path, args.format)


def parse_args(argv=None):
    """Parse the command-line arguments of the detection script.

    Args:
        argv: optional list of argument strings; when ``None`` (the default)
            argparse reads ``sys.argv[1:]``.

    Returns:
        argparse.Namespace with ``path``, ``detection``, ``recognition`` and ``format``.
    """
    parser = argparse.ArgumentParser(
        description="DocTR text detection",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    # Each argument is registered exactly once; registering an option string
    # twice makes argparse raise a conflict error.
    parser.add_argument("path", type=str, help="Path to process: PDF, image, directory")
    parser.add_argument("--detection", type=str, default="db_resnet50", help="Text detection model to use for analysis")
    parser.add_argument(
        "--recognition", type=str, default="crnn_vgg16_bn", help="Text recognition model to use for analysis"
    )
    parser.add_argument("-f", "--format", choices=["txt", "json", "xml"], default="txt", help="Output format")
    return parser.parse_args(argv)


if __name__ == "__main__":
    # Parse the command line and run the script exactly once.
    parsed_args = parse_args()
    main(parsed_args)

0 comments on commit 42195de

Please sign in to comment.