Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[refactor] detection script #1060

Merged
merged 1 commit into from
Sep 15, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions .github/workflows/scripts.yml
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,52 @@ jobs:
wget https://github.com/mindee/doctr/releases/download/v0.1.0/sample.pdf
python scripts/analyze.py sample.pdf --noblock

test-detect-text:
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
matrix:
os: [ubuntu-latest]
python: [3.7, 3.8]
framework: [tensorflow, pytorch]
steps:
- if: matrix.os == 'macos-latest'
name: Install MacOS prerequisites
run: brew install cairo pango gdk-pixbuf libffi
- uses: actions/checkout@v2
- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: ${{ matrix.python }}
architecture: x64
- if: matrix.framework == 'tensorflow'
name: Cache python modules (TF)
uses: actions/cache@v2
with:
path: ~/.cache/pip
key: ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('pyproject.toml') }}
- if: matrix.framework == 'pytorch'
name: Cache python modules (PT)
uses: actions/cache@v2
with:
path: ~/.cache/pip
key: ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('pyproject.toml') }}
- if: matrix.framework == 'tensorflow'
name: Install package (TF)
run: |
python -m pip install --upgrade pip
pip install -e .[tf] --upgrade
- if: matrix.framework == 'pytorch'
name: Install package (PT)
run: |
python -m pip install --upgrade pip
pip install -e .[torch] --upgrade

- name: Run detection script
run: |
wget https://github.com/mindee/doctr/releases/download/v0.1.0/sample.pdf
python scripts/detect_text.py sample.pdf

test-evaluate:
runs-on: ${{ matrix.os }}
strategy:
Expand Down
102 changes: 39 additions & 63 deletions scripts/detect_text.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,95 +24,71 @@
if any(gpu_devices):
tf.config.experimental.set_memory_growth(gpu_devices[0], True)

IMAGE_FILE_EXTENSIONS = [".jpeg", ".jpg", ".png", ".tif", ".tiff", ".bmp"]
OTHER_EXTENSIONS = [".pdf"]

_OUTPUT_CHOICE_JSON = "json"
_OUTPUT_CHOICE_TEXT = "txt"

def _process_file(model, file_path: Path, out_format: str) -> None:

def _process_file(model, file_path: Path, out_format: str) -> str:
if str(file_path).lower().endswith(".pdf"):
if out_format not in ["txt", "json", "xml"]:
raise ValueError(f"Unsupported output format: {out_format}")

if os.path.splitext(file_path)[1] in IMAGE_FILE_EXTENSIONS:
doc = DocumentFile.from_images([file_path])
elif os.path.splitext(file_path)[1] in OTHER_EXTENSIONS:
doc = DocumentFile.from_pdf(file_path)
else:
doc = DocumentFile.from_images(file_path)
print(f"Skip unsupported file type: {file_path}")

out = model(doc)
export = out.export()

if out_format == _OUTPUT_CHOICE_JSON:
out_txt = json.dumps(export, indent=2)
elif out_format == _OUTPUT_CHOICE_TEXT:
out_txt = ""
for page in export["pages"]:
for block in page["blocks"]:
for line in block["lines"]:
for word in line["words"]:
out_txt += word["value"] + " "
out_txt += "\n"
out_txt += "\n\n"

if out_format == "json":
output = json.dumps(out.export(), indent=2)
elif out_format == "txt":
output = out.render()
elif out_format == "xml":
output = out.export_as_xml()

path = Path("output").joinpath(file_path.stem + "." + out_format)
if out_format == "xml":
for i, (xml_bytes, xml_tree) in enumerate(output):
felixdittrich92 marked this conversation as resolved.
Show resolved Hide resolved
path = Path("output").joinpath(file_path.stem + f"_{i}." + out_format)
xml_tree.write(path, encoding="utf-8", xml_declaration=True)
else:
out_txt = ""
return out_txt
with open(path, "w") as f:
f.write(output)


def main(args):
model = ocr_predictor(args.detection, args.recognition, pretrained=True)
path = Path(args.path)

os.makedirs(name="output", exist_ok=True)

if path.is_dir():
allowed = (".pdf", ".jpeg", ".jpg", ".png", ".tif", ".tiff", ".bmp")
to_process = [f for f in path.iterdir() if str(f).lower().endswith(allowed)]
for filename in tqdm(to_process):
out_path = path.joinpath(f"{filename}.{args.format}")
if out_path.exists():
continue
in_path = path.joinpath(filename)
# print(in_path)
out_str = _process_file(model, in_path, args.format)
with open(out_path, "w") as fh:
fh.write(out_str)
to_process = [
f for f in path.iterdir() if str(f).lower().endswith(tuple(IMAGE_FILE_EXTENSIONS + OTHER_EXTENSIONS))
]
for file_path in tqdm(to_process):
_process_file(model, file_path, args.format)
else:
out_str = _process_file(model, path, args.format)
print(out_str)
_process_file(model, path, args.format)


def parse_args():
parser = argparse.ArgumentParser(
description="DocTR text detection",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
parser.add_argument("path", type=str, help="Path to process: PDF, image, directory")
parser.add_argument("--detection", type=str, default="db_resnet50", help="Text detection model to use for analysis")
parser.add_argument(
"path",
type=str,
help="Path to process: PDF, image, directory",
)
parser.add_argument(
"--detection",
type=str,
default="db_resnet50",
help="Text detection model to use for analysis",
)
parser.add_argument(
"--recognition",
type=str,
default="crnn_vgg16_bn",
help="Text recognition model to use for analysis",
)
parser.add_argument(
"-f",
"--format",
choices=[
_OUTPUT_CHOICE_JSON,
_OUTPUT_CHOICE_TEXT,
],
default=_OUTPUT_CHOICE_TEXT,
help="Output format",
"--recognition", type=str, default="crnn_vgg16_bn", help="Text recognition model to use for analysis"
)
parser.add_argument("-f", "--format", choices=["txt", "json", "xml"], default="txt", help="Output format")
return parser.parse_args()


if __name__ == "__main__":
parsed_args = parse_args()
try:
main(parsed_args)
except KeyboardInterrupt:
print("Cancelled")
pass
main(parsed_args)
felixdittrich92 marked this conversation as resolved.
Show resolved Hide resolved