From 81c313eb7b759bb3c4c6e33547408d2bf3de7771 Mon Sep 17 00:00:00 2001 From: felixdittrich92 Date: Tue, 11 Jan 2022 08:34:52 +0100 Subject: [PATCH 1/7] backup --- .coveragerc | 2 + .flake8 | 5 + .github/ISSUE_TEMPLATE/bug_report.yml | 63 ++ .github/ISSUE_TEMPLATE/config.yml | 5 + .github/ISSUE_TEMPLATE/feature_request.yml | 33 + .github/release.yml | 24 + .github/verify_pr_labels.py | 81 ++ .github/workflows/builds.yml | 51 ++ .github/workflows/demo.yml | 45 ++ .github/workflows/doc-status.yml | 22 + .github/workflows/docker.yml | 30 + .github/workflows/docs.yml | 53 ++ .github/workflows/main.yml | 127 ++++ .github/workflows/pr-labels.yml | 29 + .github/workflows/pull_requests.yml | 34 + .github/workflows/references.yml | 299 ++++++++ .github/workflows/release.yml | 66 ++ .github/workflows/scripts.yml | 114 +++ .github/workflows/style.yml | 94 +++ .gitignore | 140 ++++ .isort.cfg | 5 + .pydocstyle | 3 + CODE_OF_CONDUCT.md | 128 ++++ CONTRIBUTING.md | 122 +++ Dockerfile | 19 + Dockerfile-api | 34 + LICENSE | 201 +++++ Makefile | 33 + README.md | 283 +++++++ api/README.md | 92 +++ api/app/config.py | 13 + api/app/main.py | 46 ++ api/app/routes/detection.py | 22 + api/app/routes/ocr.py | 24 + api/app/routes/recognition.py | 20 + api/app/schemas.py | 21 + api/app/vision.py | 16 + api/requirements-dev.txt | 4 + api/requirements.txt | 4 + api/tests/conftest.py | 28 + api/tests/routes/test_detection.py | 25 + api/tests/routes/test_ocr.py | 29 + api/tests/routes/test_recognition.py | 10 + demo/app.py | 109 +++ demo/requirements.txt | 2 + docker-compose.yml | 12 + docs/Makefile | 20 + docs/build.sh | 44 ++ docs/requirements.txt | 7 + docs/source/_static/css/mindee.css | 43 ++ .../_static/images/Logo-docTR-white.png | Bin 0 -> 7568 bytes docs/source/_static/images/favicon.ico | Bin 0 -> 100942 bytes docs/source/_static/js/custom.js | 108 +++ docs/source/changelog.rst | 38 + docs/source/conf.py | 101 +++ docs/source/datasets.rst | 104 +++ docs/source/index.rst | 91 +++ docs/source/installing.rst | 66 ++ docs/source/io.rst | 94 +++ docs/source/models.rst | 62 ++ docs/source/notebooks.md | 1 + docs/source/transforms.rst | 38 + docs/source/using_model_export.rst | 71 ++ docs/source/using_models.rst | 329 +++++++++ docs/source/utils.rst | 46 ++ doctr/__init__.py | 3 + doctr/datasets/__init__.py | 22 + doctr/datasets/cord.py | 90 +++ doctr/datasets/datasets/__init__.py | 6 + doctr/datasets/datasets/base.py | 118 +++ doctr/datasets/datasets/pytorch.py | 37 + doctr/datasets/datasets/tensorflow.py | 37 + doctr/datasets/detection.py | 64 ++ doctr/datasets/doc_artefacts.py | 79 ++ doctr/datasets/funsd.py | 93 +++ doctr/datasets/generator/__init__.py | 6 + doctr/datasets/generator/base.py | 154 ++++ doctr/datasets/generator/pytorch.py | 54 ++ doctr/datasets/generator/tensorflow.py | 61 ++ doctr/datasets/ic03.py | 108 +++ doctr/datasets/ic13.py | 83 +++ doctr/datasets/iiit5k.py | 93 +++ doctr/datasets/imgur5k.py | 100 +++ doctr/datasets/loader.py | 101 +++ doctr/datasets/ocr.py | 65 ++ doctr/datasets/recognition.py | 55 ++ doctr/datasets/sroie.py | 79 ++ doctr/datasets/svhn.py | 114 +++ doctr/datasets/svt.py | 100 +++ doctr/datasets/synthtext.py | 88 +++ doctr/datasets/utils.py | 163 +++++ doctr/datasets/vocabs.py | 33 + doctr/file_utils.py | 85 +++ doctr/io/__init__.py | 5 + doctr/io/elements.py | 405 ++++++++++ doctr/io/html.py | 26 + doctr/io/image/__init__.py | 8 + doctr/io/image/base.py | 55 ++ doctr/io/image/pytorch.py | 104 +++ doctr/io/image/tensorflow.py | 109 +++ doctr/io/pdf.py | 184 +++++ 
doctr/io/reader.py | 73 ++ doctr/models/__init__.py | 5 + doctr/models/_utils.py | 222 ++++++ doctr/models/artefacts/__init__.py | 2 + doctr/models/artefacts/barcode.py | 77 ++ doctr/models/artefacts/face.py | 59 ++ doctr/models/builder.py | 313 ++++++++ doctr/models/classification/__init__.py | 5 + .../classification/magc_resnet/__init__.py | 6 + .../classification/magc_resnet/pytorch.py | 158 ++++ .../classification/magc_resnet/tensorflow.py | 194 +++++ .../classification/mobilenet/__init__.py | 6 + .../classification/mobilenet/pytorch.py | 204 ++++++ .../classification/mobilenet/tensorflow.py | 385 ++++++++++ .../classification/predictor/__init__.py | 6 + .../classification/predictor/pytorch.py | 55 ++ .../classification/predictor/tensorflow.py | 56 ++ .../models/classification/resnet/__init__.py | 6 + doctr/models/classification/resnet/pytorch.py | 200 +++++ .../classification/resnet/tensorflow.py | 263 +++++++ doctr/models/classification/vgg/__init__.py | 6 + doctr/models/classification/vgg/pytorch.py | 75 ++ doctr/models/classification/vgg/tensorflow.py | 115 +++ doctr/models/classification/zoo.py | 67 ++ doctr/models/core.py | 19 + doctr/models/detection/__init__.py | 3 + doctr/models/detection/_utils/__init__.py | 6 + doctr/models/detection/_utils/pytorch.py | 37 + doctr/models/detection/_utils/tensorflow.py | 34 + doctr/models/detection/core.py | 105 +++ .../differentiable_binarization/__init__.py | 6 + .../differentiable_binarization/base.py | 348 +++++++++ .../differentiable_binarization/pytorch.py | 400 ++++++++++ .../differentiable_binarization/tensorflow.py | 371 ++++++++++ doctr/models/detection/linknet/__init__.py | 6 + doctr/models/detection/linknet/base.py | 177 +++++ doctr/models/detection/linknet/pytorch.py | 236 ++++++ doctr/models/detection/linknet/tensorflow.py | 263 +++++++ doctr/models/detection/predictor/__init__.py | 6 + doctr/models/detection/predictor/pytorch.py | 51 ++ .../models/detection/predictor/tensorflow.py | 52 ++ doctr/models/detection/zoo.py | 76 ++ doctr/models/obj_detection/__init__.py | 1 + .../obj_detection/faster_rcnn/__init__.py | 4 + .../obj_detection/faster_rcnn/pytorch.py | 79 ++ doctr/models/predictor/__init__.py | 6 + doctr/models/predictor/base.py | 95 +++ doctr/models/predictor/pytorch.py | 105 +++ doctr/models/predictor/tensorflow.py | 96 +++ doctr/models/preprocessor/__init__.py | 6 + doctr/models/preprocessor/pytorch.py | 127 ++++ doctr/models/preprocessor/tensorflow.py | 127 ++++ doctr/models/recognition/__init__.py | 4 + doctr/models/recognition/core.py | 61 ++ doctr/models/recognition/crnn/__init__.py | 6 + doctr/models/recognition/crnn/pytorch.py | 308 ++++++++ doctr/models/recognition/crnn/tensorflow.py | 278 +++++++ doctr/models/recognition/master/__init__.py | 6 + doctr/models/recognition/master/base.py | 57 ++ doctr/models/recognition/master/pytorch.py | 295 ++++++++ doctr/models/recognition/master/tensorflow.py | 300 ++++++++ .../models/recognition/predictor/__init__.py | 6 + doctr/models/recognition/predictor/_utils.py | 89 +++ doctr/models/recognition/predictor/pytorch.py | 85 +++ .../recognition/predictor/tensorflow.py | 81 ++ doctr/models/recognition/sar/__init__.py | 6 + doctr/models/recognition/sar/pytorch.py | 323 ++++++++ doctr/models/recognition/sar/tensorflow.py | 361 +++++++++ .../recognition/transformer/__init__.py | 6 + .../models/recognition/transformer/pytorch.py | 91 +++ .../recognition/transformer/tensorflow.py | 265 +++++++ doctr/models/recognition/utils.py | 84 +++ doctr/models/recognition/zoo.py | 56 ++ 
doctr/models/utils/__init__.py | 6 + doctr/models/utils/pytorch.py | 84 +++ doctr/models/utils/tensorflow.py | 123 ++++ doctr/models/zoo.py | 87 +++ doctr/transforms/__init__.py | 1 + doctr/transforms/functional/__init__.py | 6 + doctr/transforms/functional/base.py | 44 ++ doctr/transforms/functional/pytorch.py | 103 +++ doctr/transforms/functional/tensorflow.py | 137 ++++ doctr/transforms/modules/__init__.py | 8 + doctr/transforms/modules/base.py | 191 +++++ doctr/transforms/modules/pytorch.py | 121 +++ doctr/transforms/modules/tensorflow.py | 419 +++++++++++ doctr/utils/__init__.py | 4 + doctr/utils/common_types.py | 18 + doctr/utils/data.py | 109 +++ doctr/utils/fonts.py | 38 + doctr/utils/geometry.py | 262 +++++++ doctr/utils/metrics.py | 692 ++++++++++++++++++ doctr/utils/multithreading.py | 39 + doctr/utils/repr.py | 58 ++ doctr/utils/visualization.py | 338 +++++++++ mypy.ini | 77 ++ notebooks/README.md | 9 + references/classification/README.md | 34 + references/classification/latency.csv | 31 + references/classification/latency_pytorch.py | 64 ++ .../classification/latency_tensorflow.py | 72 ++ references/classification/train_pytorch.py | 391 ++++++++++ references/classification/train_tensorflow.py | 351 +++++++++ references/classification/utils.py | 73 ++ references/detection/README.md | 67 ++ references/detection/evaluate_pytorch.py | 160 ++++ references/detection/evaluate_tensorflow.py | 138 ++++ references/detection/latency.csv | 15 + references/detection/latency_pytorch.py | 64 ++ references/detection/latency_tensorflow.py | 72 ++ references/detection/results.csv | 9 + references/detection/train_pytorch.py | 391 ++++++++++ references/detection/train_tensorflow.py | 340 +++++++++ references/detection/utils.py | 83 +++ references/obj_detection/latency.csv | 3 + references/obj_detection/latency_pytorch.py | 65 ++ references/obj_detection/train_pytorch.py | 364 +++++++++ references/obj_detection/utils.py | 77 ++ references/recognition/README.md | 63 ++ references/recognition/latency.csv | 21 + references/recognition/latency_pytorch.py | 65 ++ references/recognition/latency_tensorflow.py | 74 ++ references/recognition/train_pytorch.py | 381 ++++++++++ references/recognition/train_tensorflow.py | 332 +++++++++ references/recognition/utils.py | 73 ++ references/requirements.txt | 3 + requirements-pt.txt | 16 + requirements.txt | 17 + scripts/analyze.py | 58 ++ scripts/collect_env.py | 353 +++++++++ scripts/detect_artefacts.py | 86 +++ scripts/evaluate.py | 172 +++++ setup.cfg | 3 + setup.py | 202 +++++ tests/common/test_core.py | 13 + tests/common/test_datasets.py | 44 ++ tests/common/test_datasets_utils.py | 74 ++ tests/common/test_headers.py | 46 ++ tests/common/test_io.py | 143 ++++ tests/common/test_io_elements.py | 235 ++++++ tests/common/test_models.py | 105 +++ tests/common/test_models_artefacts.py | 20 + tests/common/test_models_builder.py | 92 +++ tests/common/test_models_detection.py | 75 ++ .../test_models_recognition_predictor.py | 39 + tests/common/test_models_recognition_utils.py | 29 + tests/common/test_requirements.py | 48 ++ tests/common/test_transforms.py | 28 + tests/common/test_utils_fonts.py | 11 + tests/common/test_utils_geometry.py | 111 +++ tests/common/test_utils_metrics.py | 306 ++++++++ tests/common/test_utils_multithreading.py | 20 + tests/common/test_utils_visualization.py | 40 + tests/conftest.py | 574 +++++++++++++++ tests/pytorch/test_datasets_pt.py | 436 +++++++++++ tests/pytorch/test_file_utils_pt.py | 5 + tests/pytorch/test_io_image_pt.py | 50 ++ 
.../pytorch/test_models_classification_pt.py | 91 +++ tests/pytorch/test_models_detection_pt.py | 93 +++ tests/pytorch/test_models_obj_detection_pt.py | 34 + tests/pytorch/test_models_preprocessor_pt.py | 48 ++ tests/pytorch/test_models_recognition_pt.py | 85 +++ tests/pytorch/test_models_utils_pt.py | 30 + tests/pytorch/test_models_zoo_pt.py | 77 ++ tests/pytorch/test_transforms_pt.py | 274 +++++++ tests/requirements.txt | 5 + tests/tensorflow/test_datasets_loader_tf.py | 79 ++ tests/tensorflow/test_datasets_tf.py | 424 +++++++++++ tests/tensorflow/test_file_utils_tf.py | 5 + tests/tensorflow/test_io_image_tf.py | 50 ++ .../test_models_classification_tf.py | 75 ++ tests/tensorflow/test_models_detection_tf.py | 156 ++++ .../tensorflow/test_models_preprocessor_tf.py | 45 ++ .../tensorflow/test_models_recognition_tf.py | 114 +++ tests/tensorflow/test_models_utils_tf.py | 45 ++ tests/tensorflow/test_models_zoo_tf.py | 79 ++ tests/tensorflow/test_transforms_tf.py | 431 +++++++++++ 278 files changed, 27581 insertions(+) create mode 100644 .coveragerc create mode 100644 .flake8 create mode 100644 .github/ISSUE_TEMPLATE/bug_report.yml create mode 100644 .github/ISSUE_TEMPLATE/config.yml create mode 100644 .github/ISSUE_TEMPLATE/feature_request.yml create mode 100644 .github/release.yml create mode 100644 .github/verify_pr_labels.py create mode 100644 .github/workflows/builds.yml create mode 100644 .github/workflows/demo.yml create mode 100644 .github/workflows/doc-status.yml create mode 100644 .github/workflows/docker.yml create mode 100644 .github/workflows/docs.yml create mode 100644 .github/workflows/main.yml create mode 100644 .github/workflows/pr-labels.yml create mode 100644 .github/workflows/pull_requests.yml create mode 100644 .github/workflows/references.yml create mode 100644 .github/workflows/release.yml create mode 100644 .github/workflows/scripts.yml create mode 100644 .github/workflows/style.yml create mode 100644 .gitignore create mode 100644 .isort.cfg create mode 100644 .pydocstyle create mode 100644 CODE_OF_CONDUCT.md create mode 100644 CONTRIBUTING.md create mode 100644 Dockerfile create mode 100644 Dockerfile-api create mode 100644 LICENSE create mode 100644 Makefile create mode 100644 README.md create mode 100644 api/README.md create mode 100644 api/app/config.py create mode 100644 api/app/main.py create mode 100644 api/app/routes/detection.py create mode 100644 api/app/routes/ocr.py create mode 100644 api/app/routes/recognition.py create mode 100644 api/app/schemas.py create mode 100644 api/app/vision.py create mode 100644 api/requirements-dev.txt create mode 100644 api/requirements.txt create mode 100644 api/tests/conftest.py create mode 100644 api/tests/routes/test_detection.py create mode 100644 api/tests/routes/test_ocr.py create mode 100644 api/tests/routes/test_recognition.py create mode 100644 demo/app.py create mode 100644 demo/requirements.txt create mode 100644 docker-compose.yml create mode 100644 docs/Makefile create mode 100644 docs/build.sh create mode 100644 docs/requirements.txt create mode 100644 docs/source/_static/css/mindee.css create mode 100644 docs/source/_static/images/Logo-docTR-white.png create mode 100644 docs/source/_static/images/favicon.ico create mode 100644 docs/source/_static/js/custom.js create mode 100644 docs/source/changelog.rst create mode 100644 docs/source/conf.py create mode 100644 docs/source/datasets.rst create mode 100644 docs/source/index.rst create mode 100644 docs/source/installing.rst create mode 100644 docs/source/io.rst 
create mode 100644 docs/source/models.rst create mode 120000 docs/source/notebooks.md create mode 100644 docs/source/transforms.rst create mode 100644 docs/source/using_model_export.rst create mode 100644 docs/source/using_models.rst create mode 100644 docs/source/utils.rst create mode 100644 doctr/__init__.py create mode 100644 doctr/datasets/__init__.py create mode 100644 doctr/datasets/cord.py create mode 100644 doctr/datasets/datasets/__init__.py create mode 100644 doctr/datasets/datasets/base.py create mode 100644 doctr/datasets/datasets/pytorch.py create mode 100644 doctr/datasets/datasets/tensorflow.py create mode 100644 doctr/datasets/detection.py create mode 100644 doctr/datasets/doc_artefacts.py create mode 100644 doctr/datasets/funsd.py create mode 100644 doctr/datasets/generator/__init__.py create mode 100644 doctr/datasets/generator/base.py create mode 100644 doctr/datasets/generator/pytorch.py create mode 100644 doctr/datasets/generator/tensorflow.py create mode 100644 doctr/datasets/ic03.py create mode 100644 doctr/datasets/ic13.py create mode 100644 doctr/datasets/iiit5k.py create mode 100644 doctr/datasets/imgur5k.py create mode 100644 doctr/datasets/loader.py create mode 100644 doctr/datasets/ocr.py create mode 100644 doctr/datasets/recognition.py create mode 100644 doctr/datasets/sroie.py create mode 100644 doctr/datasets/svhn.py create mode 100644 doctr/datasets/svt.py create mode 100644 doctr/datasets/synthtext.py create mode 100644 doctr/datasets/utils.py create mode 100644 doctr/datasets/vocabs.py create mode 100644 doctr/file_utils.py create mode 100644 doctr/io/__init__.py create mode 100644 doctr/io/elements.py create mode 100644 doctr/io/html.py create mode 100644 doctr/io/image/__init__.py create mode 100644 doctr/io/image/base.py create mode 100644 doctr/io/image/pytorch.py create mode 100644 doctr/io/image/tensorflow.py create mode 100644 doctr/io/pdf.py create mode 100644 doctr/io/reader.py create mode 100644 doctr/models/__init__.py create mode 100644 doctr/models/_utils.py create mode 100644 doctr/models/artefacts/__init__.py create mode 100644 doctr/models/artefacts/barcode.py create mode 100644 doctr/models/artefacts/face.py create mode 100644 doctr/models/builder.py create mode 100644 doctr/models/classification/__init__.py create mode 100644 doctr/models/classification/magc_resnet/__init__.py create mode 100644 doctr/models/classification/magc_resnet/pytorch.py create mode 100644 doctr/models/classification/magc_resnet/tensorflow.py create mode 100644 doctr/models/classification/mobilenet/__init__.py create mode 100644 doctr/models/classification/mobilenet/pytorch.py create mode 100644 doctr/models/classification/mobilenet/tensorflow.py create mode 100644 doctr/models/classification/predictor/__init__.py create mode 100644 doctr/models/classification/predictor/pytorch.py create mode 100644 doctr/models/classification/predictor/tensorflow.py create mode 100644 doctr/models/classification/resnet/__init__.py create mode 100644 doctr/models/classification/resnet/pytorch.py create mode 100644 doctr/models/classification/resnet/tensorflow.py create mode 100644 doctr/models/classification/vgg/__init__.py create mode 100644 doctr/models/classification/vgg/pytorch.py create mode 100644 doctr/models/classification/vgg/tensorflow.py create mode 100644 doctr/models/classification/zoo.py create mode 100644 doctr/models/core.py create mode 100644 doctr/models/detection/__init__.py create mode 100644 doctr/models/detection/_utils/__init__.py create mode 100644 
doctr/models/detection/_utils/pytorch.py create mode 100644 doctr/models/detection/_utils/tensorflow.py create mode 100644 doctr/models/detection/core.py create mode 100644 doctr/models/detection/differentiable_binarization/__init__.py create mode 100644 doctr/models/detection/differentiable_binarization/base.py create mode 100644 doctr/models/detection/differentiable_binarization/pytorch.py create mode 100644 doctr/models/detection/differentiable_binarization/tensorflow.py create mode 100644 doctr/models/detection/linknet/__init__.py create mode 100644 doctr/models/detection/linknet/base.py create mode 100644 doctr/models/detection/linknet/pytorch.py create mode 100644 doctr/models/detection/linknet/tensorflow.py create mode 100644 doctr/models/detection/predictor/__init__.py create mode 100644 doctr/models/detection/predictor/pytorch.py create mode 100644 doctr/models/detection/predictor/tensorflow.py create mode 100644 doctr/models/detection/zoo.py create mode 100644 doctr/models/obj_detection/__init__.py create mode 100644 doctr/models/obj_detection/faster_rcnn/__init__.py create mode 100644 doctr/models/obj_detection/faster_rcnn/pytorch.py create mode 100644 doctr/models/predictor/__init__.py create mode 100644 doctr/models/predictor/base.py create mode 100644 doctr/models/predictor/pytorch.py create mode 100644 doctr/models/predictor/tensorflow.py create mode 100644 doctr/models/preprocessor/__init__.py create mode 100644 doctr/models/preprocessor/pytorch.py create mode 100644 doctr/models/preprocessor/tensorflow.py create mode 100644 doctr/models/recognition/__init__.py create mode 100644 doctr/models/recognition/core.py create mode 100644 doctr/models/recognition/crnn/__init__.py create mode 100644 doctr/models/recognition/crnn/pytorch.py create mode 100644 doctr/models/recognition/crnn/tensorflow.py create mode 100644 doctr/models/recognition/master/__init__.py create mode 100644 doctr/models/recognition/master/base.py create mode 100644 doctr/models/recognition/master/pytorch.py create mode 100644 doctr/models/recognition/master/tensorflow.py create mode 100644 doctr/models/recognition/predictor/__init__.py create mode 100644 doctr/models/recognition/predictor/_utils.py create mode 100644 doctr/models/recognition/predictor/pytorch.py create mode 100644 doctr/models/recognition/predictor/tensorflow.py create mode 100644 doctr/models/recognition/sar/__init__.py create mode 100644 doctr/models/recognition/sar/pytorch.py create mode 100644 doctr/models/recognition/sar/tensorflow.py create mode 100644 doctr/models/recognition/transformer/__init__.py create mode 100644 doctr/models/recognition/transformer/pytorch.py create mode 100644 doctr/models/recognition/transformer/tensorflow.py create mode 100644 doctr/models/recognition/utils.py create mode 100644 doctr/models/recognition/zoo.py create mode 100644 doctr/models/utils/__init__.py create mode 100644 doctr/models/utils/pytorch.py create mode 100644 doctr/models/utils/tensorflow.py create mode 100644 doctr/models/zoo.py create mode 100644 doctr/transforms/__init__.py create mode 100644 doctr/transforms/functional/__init__.py create mode 100644 doctr/transforms/functional/base.py create mode 100644 doctr/transforms/functional/pytorch.py create mode 100644 doctr/transforms/functional/tensorflow.py create mode 100644 doctr/transforms/modules/__init__.py create mode 100644 doctr/transforms/modules/base.py create mode 100644 doctr/transforms/modules/pytorch.py create mode 100644 doctr/transforms/modules/tensorflow.py create mode 100644 
doctr/utils/__init__.py create mode 100644 doctr/utils/common_types.py create mode 100644 doctr/utils/data.py create mode 100644 doctr/utils/fonts.py create mode 100644 doctr/utils/geometry.py create mode 100644 doctr/utils/metrics.py create mode 100644 doctr/utils/multithreading.py create mode 100644 doctr/utils/repr.py create mode 100644 doctr/utils/visualization.py create mode 100644 mypy.ini create mode 100644 notebooks/README.md create mode 100644 references/classification/README.md create mode 100644 references/classification/latency.csv create mode 100644 references/classification/latency_pytorch.py create mode 100644 references/classification/latency_tensorflow.py create mode 100644 references/classification/train_pytorch.py create mode 100644 references/classification/train_tensorflow.py create mode 100644 references/classification/utils.py create mode 100644 references/detection/README.md create mode 100644 references/detection/evaluate_pytorch.py create mode 100644 references/detection/evaluate_tensorflow.py create mode 100644 references/detection/latency.csv create mode 100644 references/detection/latency_pytorch.py create mode 100644 references/detection/latency_tensorflow.py create mode 100644 references/detection/results.csv create mode 100644 references/detection/train_pytorch.py create mode 100644 references/detection/train_tensorflow.py create mode 100644 references/detection/utils.py create mode 100644 references/obj_detection/latency.csv create mode 100644 references/obj_detection/latency_pytorch.py create mode 100644 references/obj_detection/train_pytorch.py create mode 100644 references/obj_detection/utils.py create mode 100644 references/recognition/README.md create mode 100644 references/recognition/latency.csv create mode 100644 references/recognition/latency_pytorch.py create mode 100644 references/recognition/latency_tensorflow.py create mode 100644 references/recognition/train_pytorch.py create mode 100644 references/recognition/train_tensorflow.py create mode 100644 references/recognition/utils.py create mode 100644 references/requirements.txt create mode 100644 requirements-pt.txt create mode 100644 requirements.txt create mode 100644 scripts/analyze.py create mode 100644 scripts/collect_env.py create mode 100644 scripts/detect_artefacts.py create mode 100644 scripts/evaluate.py create mode 100644 setup.cfg create mode 100644 setup.py create mode 100644 tests/common/test_core.py create mode 100644 tests/common/test_datasets.py create mode 100644 tests/common/test_datasets_utils.py create mode 100644 tests/common/test_headers.py create mode 100644 tests/common/test_io.py create mode 100644 tests/common/test_io_elements.py create mode 100644 tests/common/test_models.py create mode 100644 tests/common/test_models_artefacts.py create mode 100644 tests/common/test_models_builder.py create mode 100644 tests/common/test_models_detection.py create mode 100644 tests/common/test_models_recognition_predictor.py create mode 100644 tests/common/test_models_recognition_utils.py create mode 100644 tests/common/test_requirements.py create mode 100644 tests/common/test_transforms.py create mode 100644 tests/common/test_utils_fonts.py create mode 100644 tests/common/test_utils_geometry.py create mode 100644 tests/common/test_utils_metrics.py create mode 100644 tests/common/test_utils_multithreading.py create mode 100644 tests/common/test_utils_visualization.py create mode 100644 tests/conftest.py create mode 100644 tests/pytorch/test_datasets_pt.py create mode 100644 
tests/pytorch/test_file_utils_pt.py create mode 100644 tests/pytorch/test_io_image_pt.py create mode 100644 tests/pytorch/test_models_classification_pt.py create mode 100644 tests/pytorch/test_models_detection_pt.py create mode 100644 tests/pytorch/test_models_obj_detection_pt.py create mode 100644 tests/pytorch/test_models_preprocessor_pt.py create mode 100644 tests/pytorch/test_models_recognition_pt.py create mode 100644 tests/pytorch/test_models_utils_pt.py create mode 100644 tests/pytorch/test_models_zoo_pt.py create mode 100644 tests/pytorch/test_transforms_pt.py create mode 100644 tests/requirements.txt create mode 100644 tests/tensorflow/test_datasets_loader_tf.py create mode 100644 tests/tensorflow/test_datasets_tf.py create mode 100644 tests/tensorflow/test_file_utils_tf.py create mode 100644 tests/tensorflow/test_io_image_tf.py create mode 100644 tests/tensorflow/test_models_classification_tf.py create mode 100644 tests/tensorflow/test_models_detection_tf.py create mode 100644 tests/tensorflow/test_models_preprocessor_tf.py create mode 100644 tests/tensorflow/test_models_recognition_tf.py create mode 100644 tests/tensorflow/test_models_utils_tf.py create mode 100644 tests/tensorflow/test_models_zoo_tf.py create mode 100644 tests/tensorflow/test_transforms_tf.py
diff --git a/.coveragerc b/.coveragerc
new file mode 100644
index 0000000000..3850f9124d
--- /dev/null
+++ b/.coveragerc
@@ -0,0 +1,2 @@
+[run]
+source = doctr
diff --git a/.flake8 b/.flake8
new file mode 100644
index 0000000000..6cd2696bf3
--- /dev/null
+++ b/.flake8
@@ -0,0 +1,5 @@
+[flake8]
+max-line-length = 120
+ignore = E402, E265, F403, W503, W504, E731
+exclude = .circleci, .git, venv*, docs, build
+per-file-ignores = **/__init__.py:F401
diff --git a/.github/ISSUE_TEMPLATE/bug_report.yml b/.github/ISSUE_TEMPLATE/bug_report.yml
new file mode 100644
index 0000000000..1a79a40fb1
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/bug_report.yml
@@ -0,0 +1,63 @@
+name: 🐛 Bug report
+description: Create a report to help us improve the library
+labels: bug
+
+body:
+- type: markdown
+  attributes:
+    value: >
+      #### Before reporting a bug, please check that the issue hasn't already been addressed in [the existing and past issues](https://github.com/mindee/doctr/issues?q=is%3Aissue).
+- type: textarea
+  attributes:
+    label: Bug description
+    description: |
+      A clear and concise description of what the bug is.
+
+      Please explain the result you observed and the behavior you were expecting.
+    placeholder: |
+      A clear and concise description of what the bug is.
+  validations:
+    required: true
+
+- type: textarea
+  attributes:
+    label: Code snippet to reproduce the bug
+    description: |
+      Sample code to reproduce the problem.
+
+      Please wrap your code snippet with ```` ```triple backticks``` ```` for readability.
+    placeholder: |
+      ```python
+      Sample code to reproduce the problem
+      ```
+  validations:
+    required: true
+- type: textarea
+  attributes:
+    label: Error traceback
+    description: |
+      The error message you received running the code snippet, with the full traceback.
+
+      Please wrap your error message with ```` ```triple backticks``` ```` for readability.
+    placeholder: |
+      ```
+      The error message you got, with the full traceback.
+      ```
+  validations:
+    required: true
+- type: textarea
+  attributes:
+    label: Environment
+    description: |
+      Please run the following command and paste the output below.
+      ```sh
+      wget https://raw.githubusercontent.com/mindee/doctr/main/scripts/collect_env.py
+      # For security purposes, please check the contents of collect_env.py before running it.
+      python collect_env.py
+      ```
+  validations:
+    required: true
+- type: markdown
+  attributes:
+    value: >
+      Thanks for helping us improve the library!
\ No newline at end of file
diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml
new file mode 100644
index 0000000000..7670faa78d
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/config.yml
@@ -0,0 +1,5 @@
+blank_issues_enabled: true
+contact_links:
+  - name: Usage questions
+    url: https://github.com/mindee/doctr/discussions
+    about: Ask questions and discuss with other docTR community members
diff --git a/.github/ISSUE_TEMPLATE/feature_request.yml b/.github/ISSUE_TEMPLATE/feature_request.yml
new file mode 100644
index 0000000000..63e35c4a53
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/feature_request.yml
@@ -0,0 +1,33 @@
+name: 🚀 Feature request
+description: Submit a proposal/request for a new feature for docTR
+labels: enhancement
+
+body:
+- type: textarea
+  attributes:
+    label: 🚀 The feature
+    description: >
+      A clear and concise description of the feature proposal
+  validations:
+    required: true
+- type: textarea
+  attributes:
+    label: Motivation, pitch
+    description: >
+      Please outline the motivation for the proposal. Is your feature request related to a specific problem? e.g., *"I'm working on X and would like Y to be possible"*. If this is related to another GitHub issue, please link here too.
+  validations:
+    required: true
+- type: textarea
+  attributes:
+    label: Alternatives
+    description: >
+      A description of any alternative solutions or features you've considered, if any.
+- type: textarea
+  attributes:
+    label: Additional context
+    description: >
+      Add any other context or screenshots about the feature request.
+- type: markdown
+  attributes:
+    value: >
+      Thanks for contributing 🎉
\ No newline at end of file
diff --git a/.github/release.yml b/.github/release.yml
new file mode 100644
index 0000000000..2efdfdfcf6
--- /dev/null
+++ b/.github/release.yml
@@ -0,0 +1,24 @@
+changelog:
+  exclude:
+    labels:
+      - ignore-for-release
+  categories:
+    - title: Breaking Changes 🛠
+      labels:
+        - "type: breaking change"
+    # NEW FEATURES
+    - title: New Features
+      labels:
+        - "type: new feature"
+    # BUG FIXES
+    - title: Bug Fixes
+      labels:
+        - "type: bug"
+    # IMPROVEMENTS
+    - title: Improvements
+      labels:
+        - "type: enhancement"
+    # MISC
+    - title: Miscellaneous
+      labels:
+        - "type: misc"
diff --git a/.github/verify_pr_labels.py b/.github/verify_pr_labels.py
new file mode 100644
index 0000000000..1d2bd96b68
--- /dev/null
+++ b/.github/verify_pr_labels.py
@@ -0,0 +1,81 @@
+"""
+Borrowed & adapted from https://github.com/pytorch/vision/blob/main/.github/process_commit.py
+This script finds the merger responsible for labeling a PR by a commit SHA. It is used by the workflow in
+'.github/workflows/pr-labels.yml'. If there exists no PR associated with the commit or the PR is properly labeled,
+this script is a no-op.
+Note: we ping the merger only, not the reviewers, as the reviewers can sometimes be external to docTR
+with no labeling responsibility, so we don't want to bother them.
+"""
+
+from typing import Any, Set, Tuple
+
+import requests
+
+# For a PR to be properly labeled it should have one primary label and one secondary label
+
+# Should specify the type of change
+PRIMARY_LABELS = {
+    "type: new feature",
+    "type: bug",
+    "type: enhancement",
+    "type: misc",
+}
+
+# Should specify what has been modified
+SECONDARY_LABELS = {
+    "topic: documentation",
+    "module: datasets",
+    "module: io",
+    "module: models",
+    "module: transforms",
+    "module: utils",
+    "ext: api",
+    "ext: demo",
+    "ext: docs",
+    "ext: notebooks",
+    "ext: references",
+    "ext: scripts",
+    "ext: tests",
+    "topic: build",
+    "topic: ci",
+    "topic: docker",
+}
+
+GH_ORG = 'mindee'
+GH_REPO = 'doctr'
+
+
+def query_repo(cmd: str, *, accept) -> Any:
+    response = requests.get(f"https://api.github.com/repos/{GH_ORG}/{GH_REPO}/{cmd}", headers=dict(Accept=accept))
+    return response.json()
+
+
+def get_pr_merger_and_labels(pr_number: int) -> Tuple[str, Set[str]]:
+    # See https://docs.github.com/en/rest/reference/pulls#get-a-pull-request
+    data = query_repo(f"pulls/{pr_number}", accept="application/vnd.github.v3+json")
+    merger = data.get("merged_by", {}).get("login")
+    labels = {label["name"] for label in data["labels"]}
+    return merger, labels
+
+
+def main(args):
+    merger, labels = get_pr_merger_and_labels(args.pr)
+    is_properly_labeled = bool(PRIMARY_LABELS.intersection(labels) and SECONDARY_LABELS.intersection(labels))
+    if isinstance(merger, str) and not is_properly_labeled:
+        print(f"@{merger}")
+
+
+def parse_args():
+    import argparse
+    parser = argparse.ArgumentParser(description='PR label checker',
+                                     formatter_class=argparse.ArgumentDefaultsHelpFormatter)
+
+    parser.add_argument('pr', type=int, help='PR number')
+    args = parser.parse_args()
+
+    return args
+
+
+if __name__ == "__main__":
+    args = parse_args()
+    main(args)
diff --git a/.github/workflows/builds.yml b/.github/workflows/builds.yml
new file mode 100644
index 0000000000..de0ecce7e1
--- /dev/null
+++ b/.github/workflows/builds.yml
@@ -0,0 +1,51 @@
+name: builds
+
+on:
+  push:
+    branches: main
+  pull_request:
+    branches: main
+
+jobs:
+  build:
+    runs-on: ${{ matrix.os }}
+    strategy:
+      fail-fast: false
+      matrix:
+        os: [ubuntu-latest, macos-latest]
+        python: [3.7, 3.8]
+        framework: [tensorflow, pytorch]
+    steps:
+      - uses: actions/checkout@v2
+      - if: matrix.os == 'macos-latest'
+        name: Install MacOS prerequisites
+        run: brew install cairo pango gdk-pixbuf libffi
+      - name: Set up Python
+        uses: actions/setup-python@v2
+        with:
+          python-version: ${{ matrix.python }}
+          architecture: x64
+      - if: matrix.framework == 'tensorflow'
+        name: Cache python modules (TF)
+        uses: actions/cache@v2
+        with:
+          path: ~/.cache/pip
+          key: ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('requirements.txt') }}
+      - if: matrix.framework == 'pytorch'
+        name: Cache python modules (PT)
+        uses: actions/cache@v2
+        with:
+          path: ~/.cache/pip
+          key: ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('requirements-pt.txt') }}
+      - if: matrix.framework == 'tensorflow'
+        name: Install package (TF)
+        run: |
+          python -m pip install --upgrade pip
+          pip install -e .[tf] --upgrade
+      - if: matrix.framework == 'pytorch'
+        name: Install package (PT)
+        run: |
+          python -m pip install --upgrade pip
+          pip install -e .[torch] --upgrade
+      - name: Import package
+        run: python -c "import doctr; print(doctr.__version__)"
diff --git a/.github/workflows/demo.yml b/.github/workflows/demo.yml
new file mode 100644
index 0000000000..0b167e955c
--- /dev/null
+++ b/.github/workflows/demo.yml
@@ -0,0 +1,45 @@
+name: demo
+
+on:
+  push:
+    branches: main
+  pull_request:
+    branches: main
+
+jobs:
+  test-demo:
+    runs-on: ${{ matrix.os }}
+    strategy:
+      fail-fast: false
+      matrix:
+        os: [ubuntu-latest]
+        python: [3.8]
+    steps:
+      - if: matrix.os == 'macos-latest'
+        name: Install MacOS prerequisites
+        run: brew install cairo pango gdk-pixbuf libffi
+      - uses: actions/checkout@v2
+      - name: Set up Python
+        uses: actions/setup-python@v2
+        with:
+          python-version: ${{ matrix.python }}
+          architecture: x64
+      - name: Cache python modules
+        uses: actions/cache@v2
+        with:
+          path: ~/.cache/pip
+          key: ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('requirements.txt') }}-${{ hashFiles('demo/requirements.txt') }}
+          restore-keys: |
+            ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('requirements.txt') }}-
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install -e .[tf] --upgrade
+          pip install -r demo/requirements.txt
+
+      - name: Run demo
+        run: |
+          streamlit --version
+          screen -dm streamlit run demo/app.py
+          sleep 10
+          curl http://localhost:8501/docs
diff --git a/.github/workflows/doc-status.yml b/.github/workflows/doc-status.yml
new file mode 100644
index 0000000000..e6551824bf
--- /dev/null
+++ b/.github/workflows/doc-status.yml
@@ -0,0 +1,22 @@
+name: doc-status
+on:
+  page_build
+
+jobs:
+  see-page-build-payload:
+    runs-on: ubuntu-latest
+    steps:
+      - name: Set up Python 3.7
+        uses: actions/setup-python@v1
+        with:
+          python-version: 3.7
+          architecture: x64
+      - name: check status
+        run: |
+          import os
+          status, errormsg = os.getenv('STATUS'), os.getenv('ERROR')
+          if status != 'built': raise AssertionError(f"There was an error building the page on GitHub pages.\n\nStatus: {status}\n\nError message: {errormsg}")
+        shell: python
+        env:
+          STATUS: ${{ github.event.build.status }}
+          ERROR: ${{ github.event.build.error.message }}
diff --git a/.github/workflows/docker.yml b/.github/workflows/docker.yml
new file mode 100644
index 0000000000..1cdbd22c74
--- /dev/null
+++ b/.github/workflows/docker.yml
@@ -0,0 +1,30 @@
+name: docker
+
+on:
+  push:
+    branches: main
+  pull_request:
+    branches: main
+
+jobs:
+  docker-package:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v2
+      - name: Build docker image
+        run: docker build . -t doctr-py3.8.1-tf2.4-slim
+      - name: Run docker container
+        run: docker run doctr-py3.8.1-tf2.4-slim python -c 'import doctr'
+
+  pytest-api:
+    runs-on: ubuntu-latest
+    steps:
+      - uses: actions/checkout@v2
+      - name: Build & run docker
+        run: PORT=8002 docker-compose up -d --build
+      - name: Install dependencies in docker
+        run: |
+          PORT=8002 docker-compose exec -T web python -m pip install --upgrade pip
+          PORT=8002 docker-compose exec -T web pip install -r requirements-dev.txt
+      - name: Run docker test
+        run: PORT=8002 docker-compose exec -T web pytest .
diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml new file mode 100644 index 0000000000..a438a6ad69 --- /dev/null +++ b/.github/workflows/docs.yml @@ -0,0 +1,53 @@ +name: docs +on: + push: + branches: main + +jobs: + docs-deploy: + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest] + python: [3.7] + steps: + - uses: actions/checkout@v2 + with: + persist-credentials: false + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python }} + architecture: x64 + - name: Cache python modules + uses: actions/cache@v2 + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('requirements.txt') }}-${{ hashFiles('docs/requirements.txt') }} + restore-keys: | + ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('requirements.txt') }}- + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -e .[tf] + pip install -e .[docs] + + - name: Build documentation + run: cd docs && bash build.sh + + - name: Documentation sanity check + run: test -e docs/build/index.html || exit + + - name: Install SSH Client 🔑 + uses: webfactory/ssh-agent@v0.4.1 + with: + ssh-private-key: ${{ secrets.SSH_DEPLOY_KEY }} + + - name: Deploy to Github Pages + uses: JamesIves/github-pages-deploy-action@3.7.1 + with: + BRANCH: gh-pages + FOLDER: 'docs/build' + COMMIT_MESSAGE: '[skip ci] Documentation updates' + CLEAN: true + SSH: true diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml new file mode 100644 index 0000000000..3bfe2998a4 --- /dev/null +++ b/.github/workflows/main.yml @@ -0,0 +1,127 @@ +name: tests + +on: + push: + branches: main + pull_request: + branches: main + +jobs: + pytest-common: + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest] + python: [3.7] + steps: + - uses: actions/checkout@v2 + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python }} + architecture: x64 + - name: Cache python modules + uses: actions/cache@v2 + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('requirements.txt') }}-${{ hashFiles('tests/requirements.txt') }} + restore-keys: | + ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('requirements.txt') }}- + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -e .[tf] --upgrade + pip install -e .[testing] + - name: Run unittests + run: | + coverage run -m pytest tests/common/ + coverage xml -o coverage-common.xml + - uses: actions/upload-artifact@v2 + with: + name: coverage-common + path: ./coverage-common.xml + if-no-files-found: error + + pytest-tf: + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest] + python: [3.7] + steps: + - uses: actions/checkout@v2 + - name: Set up Python + uses: actions/setup-python@v1 + with: + python-version: ${{ matrix.python }} + architecture: x64 + - name: Cache python modules + uses: actions/cache@v2 + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('requirements.txt') }}-${{ hashFiles('tests/requirements.txt') }} + restore-keys: | + ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('requirements.txt') }}- + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -e .[tf] --upgrade + pip install -e .[testing] + - name: Run unittests + run: | + coverage run -m pytest tests/tensorflow/ + coverage xml -o coverage-tf.xml + - 
uses: actions/upload-artifact@v2 + with: + name: coverage-tf + path: ./coverage-tf.xml + if-no-files-found: error + + pytest-torch: + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest] + python: [3.7] + steps: + - uses: actions/checkout@v2 + - name: Set up Python + uses: actions/setup-python@v1 + with: + python-version: ${{ matrix.python }} + architecture: x64 + - name: Cache python modules + uses: actions/cache@v2 + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('requirements-pt.txt') }}-${{ hashFiles('tests/requirements.txt') }} + restore-keys: | + ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('requirements-pt.txt') }}- + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -e .[torch] --upgrade + pip install -e .[testing] + + - name: Run unittests + run: | + coverage run -m pytest tests/pytorch/ + coverage xml -o coverage-pt.xml + + - uses: actions/upload-artifact@v2 + with: + name: coverage-pytorch + path: ./coverage-pt.xml + if-no-files-found: error + + codecov-upload: + runs-on: ubuntu-latest + needs: [ pytest-common, pytest-tf, pytest-torch ] + steps: + - uses: actions/checkout@v2 + - uses: actions/download-artifact@v2 + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v1 + with: + flags: unittests + fail_ci_if_error: true diff --git a/.github/workflows/pr-labels.yml b/.github/workflows/pr-labels.yml new file mode 100644 index 0000000000..202ecdec4c --- /dev/null +++ b/.github/workflows/pr-labels.yml @@ -0,0 +1,29 @@ +name: pr-labels + +on: + pull_request: + branches: main + types: closed + +jobs: + is-properly-labeled: + if: github.event.pull_request.merged == true + runs-on: ubuntu-latest + steps: + - name: Checkout repository + uses: actions/checkout@v2 + - name: Set up python + uses: actions/setup-python@v2 + - name: Install requests + run: pip install requests + - name: Process commit and find merger responsible for labeling + id: commit + run: echo "::set-output name=merger::$(python .github/verify_pr_labels.py ${{ github.event.pull_request.number }})" + - name: 'Comment PR' + uses: actions/github-script@0.3.0 + if: ${{ steps.commit.outputs.merger != '' }} + with: + github-token: ${{ secrets.GITHUB_TOKEN }} + script: | + const { issue: { number: issue_number }, repo: { owner, repo } } = context; + github.issues.createComment({ issue_number, owner, repo, body: 'Hey ${{ steps.commit.outputs.merger }} 👋\nYou merged this PR, but it is not correctly labeled. 
The list of valid labels is available at https://github.com/mindee/doctr/blob/main/.github/verify_pr_labels.py' }); diff --git a/.github/workflows/pull_requests.yml b/.github/workflows/pull_requests.yml new file mode 100644 index 0000000000..007be1bd1a --- /dev/null +++ b/.github/workflows/pull_requests.yml @@ -0,0 +1,34 @@ +name: pull_requests + +on: + pull_request: + branches: main + +jobs: + docs-build: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - name: Set up Python 3.7 + uses: actions/setup-python@v2 + with: + python-version: 3.7 + architecture: x64 + - name: Cache python modules + uses: actions/cache@v2 + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('requirements.txt') }}-${{ hashFiles('docs/requirements.txt') }} + restore-keys: | + ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('requirements.txt') }}- + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -e .[tf] --upgrade + pip install -e .[docs] + + - name: Build documentation + run: cd docs && bash build.sh + + - name: Documentation sanity check + run: test -e docs/build/index.html || exit diff --git a/.github/workflows/references.yml b/.github/workflows/references.yml new file mode 100644 index 0000000000..c0b8627d5d --- /dev/null +++ b/.github/workflows/references.yml @@ -0,0 +1,299 @@ +name: references + +on: + push: + branches: main + pull_request: + branches: main + +jobs: + train-char-classification: + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest] + python: [3.8] + framework: [tensorflow, pytorch] + steps: + - uses: actions/checkout@v2 + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python }} + architecture: x64 + - if: matrix.framework == 'tensorflow' + name: Cache python modules (TF) + uses: actions/cache@v2 + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('requirements.txt') }}-${{ hashFiles('references/requirements.txt') }} + restore-keys: | + ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('requirements.txt') }}- + - if: matrix.framework == 'pytorch' + name: Cache python modules (PT) + uses: actions/cache@v2 + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('requirements-pt.txt') }}-${{ hashFiles('references/requirements.txt') }} + restore-keys: | + ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('requirements-pt.txt') }}- + - if: matrix.framework == 'tensorflow' + name: Install dependencies (TF) + run: | + python -m pip install --upgrade pip + pip install -e .[tf] --upgrade + pip install -r references/requirements.txt + sudo apt-get update && sudo apt-get install fonts-freefont-ttf -y + - if: matrix.framework == 'pytorch' + name: Install dependencies (PT) + run: | + python -m pip install --upgrade pip + pip install -e .[torch] --upgrade + pip install -r references/requirements.txt + sudo apt-get update && sudo apt-get install fonts-freefont-ttf -y + - if: matrix.framework == 'tensorflow' + name: Train for a short epoch (TF) + run: python references/classification/train_tensorflow.py resnet18 -b 32 --val-samples 1 --train-samples 1 --epochs 1 + - if: matrix.framework == 'pytorch' + name: Train for a short epoch (PT) + run: python references/classification/train_pytorch.py mobilenet_v3_small -b 32 --val-samples 1 --train-samples 1 --epochs 1 + + train-text-recognition: + runs-on: ${{ 
matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest] + python: [3.8] + framework: [tensorflow, pytorch] + steps: + - uses: actions/checkout@v2 + - name: Set up Python + uses: actions/setup-python@v1 + with: + python-version: ${{ matrix.python }} + architecture: x64 + - if: matrix.framework == 'tensorflow' + name: Cache python modules (TF) + uses: actions/cache@v2 + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('requirements.txt') }}-${{ hashFiles('references/requirements.txt') }} + restore-keys: | + ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('requirements.txt') }}- + - if: matrix.framework == 'pytorch' + name: Cache python modules (PT) + uses: actions/cache@v2 + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('requirements-pt.txt') }}-${{ hashFiles('references/requirements.txt') }} + restore-keys: | + ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('requirements-pt.txt') }}- + - if: matrix.framework == 'tensorflow' + name: Install dependencies (TF) + run: | + python -m pip install --upgrade pip + pip install -e .[tf] --upgrade + pip install -r references/requirements.txt + - if: matrix.framework == 'pytorch' + name: Install dependencies (PT) + run: | + python -m pip install --upgrade pip + pip install -e .[torch] --upgrade + pip install -r references/requirements.txt + - name: Download and extract toy set + run: | + wget https://github.com/mindee/doctr/releases/download/v0.3.1/toy_recogition_set-036a4d80.zip + sudo apt-get update && sudo apt-get install unzip -y + unzip toy_recogition_set-036a4d80.zip -d reco_set + - if: matrix.framework == 'tensorflow' + name: Train for a short epoch (TF) + run: python references/recognition/train_tensorflow.py ./reco_set ./reco_set crnn_vgg16_bn -b 4 --epochs 1 + - if: matrix.framework == 'pytorch' + name: Train for a short epoch (PT) + run: python references/recognition/train_pytorch.py ./reco_set ./reco_set crnn_mobilenet_v3_small -b 4 --epochs 1 + + latency-text-recognition: + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest] + python: [3.8] + framework: [tensorflow, pytorch] + steps: + - uses: actions/checkout@v2 + - name: Set up Python + uses: actions/setup-python@v1 + with: + python-version: ${{ matrix.python }} + architecture: x64 + - if: matrix.framework == 'tensorflow' + name: Cache python modules (TF) + uses: actions/cache@v2 + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('requirements.txt') }} + - if: matrix.framework == 'pytorch' + name: Cache python modules (PT) + uses: actions/cache@v2 + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('requirements-pt.txt') }} + - if: matrix.framework == 'tensorflow' + name: Install dependencies (TF) + run: | + python -m pip install --upgrade pip + pip install -e .[tf] --upgrade + - if: matrix.framework == 'pytorch' + name: Install dependencies (PT) + run: | + python -m pip install --upgrade pip + pip install -e .[torch] --upgrade + - if: matrix.framework == 'tensorflow' + name: Benchmark latency (TF) + run: python references/recognition/latency_tensorflow.py crnn_vgg16_bn --it 5 + - if: matrix.framework == 'pytorch' + name: Benchmark latency (PT) + run: python references/recognition/latency_pytorch.py crnn_mobilenet_v3_small --it 5 + + train-text-detection: + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + 
matrix: + os: [ubuntu-latest] + python: [3.8] + framework: [tensorflow, pytorch] + steps: + - uses: actions/checkout@v2 + - name: Set up Python + uses: actions/setup-python@v1 + with: + python-version: ${{ matrix.python }} + architecture: x64 + - if: matrix.framework == 'tensorflow' + name: Cache python modules (TF) + uses: actions/cache@v2 + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('requirements.txt') }}-${{ hashFiles('references/requirements.txt') }} + restore-keys: | + ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('requirements.txt') }}- + - if: matrix.framework == 'pytorch' + name: Cache python modules (PT) + uses: actions/cache@v2 + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('requirements-pt.txt') }}-${{ hashFiles('references/requirements.txt') }} + restore-keys: | + ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('requirements-pt.txt') }}- + - if: matrix.framework == 'tensorflow' + name: Install dependencies (TF) + run: | + python -m pip install --upgrade pip + pip install -e .[tf] --upgrade + pip install -r references/requirements.txt + - if: matrix.framework == 'pytorch' + name: Install dependencies (PT) + run: | + python -m pip install --upgrade pip + pip install -e .[torch] --upgrade + pip install -r references/requirements.txt + - name: Download and extract toy set + run: | + wget https://github.com/mindee/doctr/releases/download/v0.3.1/toy_detection_set-bbbb4243.zip + sudo apt-get update && sudo apt-get install unzip -y + unzip toy_detection_set-bbbb4243.zip -d det_set + - if: matrix.framework == 'tensorflow' + name: Train for a short epoch (TF) + run: python references/detection/train_tensorflow.py ./det_set ./det_set db_resnet50 -b 2 --epochs 1 + - if: matrix.framework == 'pytorch' + name: Train for a short epoch (PT) + run: python references/detection/train_pytorch.py ./det_set ./det_set db_mobilenet_v3_large -b 2 --epochs 1 + + latency-text-detection: + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest] + python: [3.8] + framework: [tensorflow, pytorch] + steps: + - uses: actions/checkout@v2 + - name: Set up Python + uses: actions/setup-python@v1 + with: + python-version: ${{ matrix.python }} + architecture: x64 + - if: matrix.framework == 'tensorflow' + name: Cache python modules (TF) + uses: actions/cache@v2 + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('requirements.txt') }} + - if: matrix.framework == 'pytorch' + name: Cache python modules (PT) + uses: actions/cache@v2 + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('requirements-pt.txt') }} + - if: matrix.framework == 'tensorflow' + name: Install dependencies (TF) + run: | + python -m pip install --upgrade pip + pip install -e .[tf] --upgrade + - if: matrix.framework == 'pytorch' + name: Install dependencies (PT) + run: | + python -m pip install --upgrade pip + pip install -e .[torch] --upgrade + - if: matrix.framework == 'tensorflow' + name: Benchmark latency (TF) + run: python references/detection/latency_tensorflow.py linknet_resnet18 --it 5 --size 512 + - if: matrix.framework == 'pytorch' + name: Benchmark latency (PT) + run: python references/detection/latency_pytorch.py linknet_resnet18 --it 5 --size 512 + + latency-object-detection: + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest] + python: [3.8] + 
framework: [pytorch] + steps: + - uses: actions/checkout@v2 + - name: Set up Python + uses: actions/setup-python@v1 + with: + python-version: ${{ matrix.python }} + architecture: x64 + - if: matrix.framework == 'tensorflow' + name: Cache python modules (TF) + uses: actions/cache@v2 + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('requirements.txt') }} + - if: matrix.framework == 'pytorch' + name: Cache python modules (PT) + uses: actions/cache@v2 + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('requirements-pt.txt') }} + - if: matrix.framework == 'tensorflow' + name: Install dependencies (TF) + run: | + python -m pip install --upgrade pip + pip install -e .[tf] --upgrade + - if: matrix.framework == 'pytorch' + name: Install dependencies (PT) + run: | + python -m pip install --upgrade pip + pip install -e .[torch] --upgrade + - if: matrix.framework == 'pytorch' + name: Benchmark latency (PT) + run: python references/obj_detection/latency_pytorch.py fasterrcnn_mobilenet_v3_large_fpn --it 5 --size 512 diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml new file mode 100644 index 0000000000..419fcfb299 --- /dev/null +++ b/.github/workflows/release.yml @@ -0,0 +1,66 @@ +name: pypi-publish + +on: + release: + types: [published] + +jobs: + + pypi-publish: + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest] + python: [3.7] + steps: + - uses: actions/checkout@v2 + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python }} + architecture: x64 + - name: Cache python modules + uses: actions/cache@v2 + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('requirements.txt') }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install setuptools wheel twine --upgrade + - name: Get release tag + id: release_tag + run: | + echo ::set-output name=VERSION::${GITHUB_REF/refs\/tags\//} + - name: Build and publish + env: + TWINE_USERNAME: ${{ secrets.PYPI_USERNAME }} + TWINE_PASSWORD: ${{ secrets.PYPI_PASSWORD }} + VERSION: ${{ steps.release_tag.outputs.VERSION }} + run: | + BUILD_VERSION=$VERSION python setup.py sdist bdist_wheel + twine check dist/* + twine upload dist/* + + pypi-check: + if: "!github.event.release.prerelease" + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest] + python: [3.7] + needs: pypi-publish + steps: + - uses: actions/checkout@v2 + - name: Set up Python 3.7 + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python }} + architecture: x64 + - name: Install package + run: | + python -m pip install --upgrade pip + pip install python-doctr + python -c "import doctr; print(doctr.__version__)" diff --git a/.github/workflows/scripts.yml b/.github/workflows/scripts.yml new file mode 100644 index 0000000000..04bbe4a5d8 --- /dev/null +++ b/.github/workflows/scripts.yml @@ -0,0 +1,114 @@ +name: scripts + +on: + push: + branches: main + pull_request: + branches: main + +jobs: + test-analyze: + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest] + python: [3.7, 3.8] + framework: [tensorflow, pytorch] + steps: + - if: matrix.os == 'macos-latest' + name: Install MacOS prerequisites + run: brew install cairo pango gdk-pixbuf libffi + - uses: actions/checkout@v2 + - name: Set up Python + uses: actions/setup-python@v2 + with: + 
python-version: ${{ matrix.python }} + architecture: x64 + - if: matrix.framework == 'tensorflow' + name: Cache python modules (TF) + uses: actions/cache@v2 + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('requirements.txt') }} + - if: matrix.framework == 'pytorch' + name: Cache python modules (PT) + uses: actions/cache@v2 + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('requirements-pt.txt') }} + - if: matrix.framework == 'tensorflow' + name: Install package (TF) + run: | + python -m pip install --upgrade pip + pip install -e .[tf] --upgrade + - if: matrix.framework == 'pytorch' + name: Install package (PT) + run: | + python -m pip install --upgrade pip + pip install -e .[torch] --upgrade + + - name: Run analysis script + run: | + wget https://github.com/mindee/doctr/releases/download/v0.1.0/sample.pdf + python scripts/analyze.py sample.pdf --noblock + + test-evaluate: + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest] + python: [3.7, 3.8] + framework: [tensorflow, pytorch] + steps: + - if: matrix.os == 'macos-latest' + name: Install MacOS prerequisites + run: brew install cairo pango gdk-pixbuf libffi + - uses: actions/checkout@v2 + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python }} + architecture: x64 + - if: matrix.framework == 'tensorflow' + name: Cache python modules (TF) + uses: actions/cache@v2 + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('requirements.txt') }} + - if: matrix.framework == 'pytorch' + name: Cache python modules (PT) + uses: actions/cache@v2 + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('requirements-pt.txt') }} + - if: matrix.framework == 'tensorflow' + name: Install package (TF) + run: | + python -m pip install --upgrade pip + pip install -e .[tf] --upgrade + - if: matrix.framework == 'pytorch' + name: Install package (PT) + run: | + python -m pip install --upgrade pip + pip install -e .[torch] --upgrade + - name: Run evaluation script + run: python scripts/evaluate.py db_resnet50 crnn_vgg16_bn --samples 10 + + test-collectenv: + runs-on: ${{ matrix.os }} + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest, macos-latest, windows-latest] + python: [3.7, 3.8] + steps: + - uses: actions/checkout@v2 + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python }} + architecture: x64 + - name: Run environment collection script + run: python scripts/collect_env.py diff --git a/.github/workflows/style.yml b/.github/workflows/style.yml new file mode 100644 index 0000000000..befb8e4f3b --- /dev/null +++ b/.github/workflows/style.yml @@ -0,0 +1,94 @@ +name: style + +on: + push: + branches: main + pull_request: + branches: main + +jobs: + flake8-py3: + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest] + python: [3.7] + steps: + - uses: actions/checkout@v2 + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python }} + architecture: x64 + - name: Run flake8 + run: | + pip install flake8 + flake8 --version + flake8 ./ + + isort-py3: + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest] + python: [3.7] + steps: + - uses: actions/checkout@v2 + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python }} + architecture: x64 + - 
name: Run isort + run: | + pip install isort + isort --version + isort . + if [ -n "$(git status --porcelain --untracked-files=no)" ]; then exit 1; else echo "All clear"; fi + + mypy-py3: + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest] + python: [3.7] + steps: + - uses: actions/checkout@v2 + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python }} + architecture: x64 + - name: Cache python modules + uses: actions/cache@v2 + with: + path: ~/.cache/pip + key: ${{ runner.os }}-pkg-deps-${{ matrix.python }}-${{ hashFiles('requirements.txt') }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -e .[tf] --upgrade + pip install mypy + - name: Run mypy + run: | + mypy --version + mypy --config-file mypy.ini doctr/ + + pydocstyle-py3: + runs-on: ${{ matrix.os }} + strategy: + matrix: + os: [ubuntu-latest] + python: [3.7] + steps: + - uses: actions/checkout@v2 + - name: Set up Python + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python }} + architecture: x64 + - name: Run pydocstyle + run: | + pip install pydocstyle + pydocstyle --version + pydocstyle doctr/ diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000..1c285ec6bb --- /dev/null +++ b/.gitignore @@ -0,0 +1,140 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +pip-wheel-metadata/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +.python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# PEP 582; used by e.g. 
github.com/David-OConnor/pyflow +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# Temp files +doctr/version.py +logs/ +wandb/ +.idea/ + +# Checkpoints +*.pt +*.pb +*.index diff --git a/.isort.cfg b/.isort.cfg new file mode 100644 index 0000000000..d98c63384d --- /dev/null +++ b/.isort.cfg @@ -0,0 +1,5 @@ +[settings] +line_length = 120 +src_paths = doctr,tests +skip_glob=**/__init__.py +known_third_party=tensorflow,torch,torchvision,wandb,fastprogress diff --git a/.pydocstyle b/.pydocstyle new file mode 100644 index 0000000000..f81d27efc9 --- /dev/null +++ b/.pydocstyle @@ -0,0 +1,3 @@ +[pydocstyle] +select = D300,D301,D417 +match = .*\.py diff --git a/CODE_OF_CONDUCT.md b/CODE_OF_CONDUCT.md new file mode 100644 index 0000000000..ee84f1d7db --- /dev/null +++ b/CODE_OF_CONDUCT.md @@ -0,0 +1,128 @@ +# Contributor Covenant Code of Conduct + +## Our Pledge + +We as members, contributors, and leaders pledge to make participation in our +community a harassment-free experience for everyone, regardless of age, body +size, visible or invisible disability, ethnicity, sex characteristics, gender +identity and expression, level of experience, education, socio-economic status, +nationality, personal appearance, race, religion, or sexual identity +and orientation. + +We pledge to act and interact in ways that contribute to an open, welcoming, +diverse, inclusive, and healthy community. + +## Our Standards + +Examples of behavior that contributes to a positive environment for our +community include: + +* Demonstrating empathy and kindness toward other people +* Being respectful of differing opinions, viewpoints, and experiences +* Giving and gracefully accepting constructive feedback +* Accepting responsibility and apologizing to those affected by our mistakes, + and learning from the experience +* Focusing on what is best not just for us as individuals, but for the + overall community + +Examples of unacceptable behavior include: + +* The use of sexualized language or imagery, and sexual attention or + advances of any kind +* Trolling, insulting or derogatory comments, and personal or political attacks +* Public or private harassment +* Publishing others' private information, such as a physical or email + address, without their explicit permission +* Other conduct which could reasonably be considered inappropriate in a + professional setting + +## Enforcement Responsibilities + +Community leaders are responsible for clarifying and enforcing our standards of +acceptable behavior and will take appropriate and fair corrective action in +response to any behavior that they deem inappropriate, threatening, offensive, +or harmful. + +Community leaders have the right and responsibility to remove, edit, or reject +comments, commits, code, wiki edits, issues, and other contributions that are +not aligned to this Code of Conduct, and will communicate reasons for moderation +decisions when appropriate. + +## Scope + +This Code of Conduct applies within all community spaces, and also applies when +an individual is officially representing the community in public spaces. 
+Examples of representing our community include using an official e-mail address, +posting via an official social media account, or acting as an appointed +representative at an online or offline event. + +## Enforcement + +Instances of abusive, harassing, or otherwise unacceptable behavior may be +reported to the community leaders responsible for enforcement at +contact@mindee.com. +All complaints will be reviewed and investigated promptly and fairly. + +All community leaders are obligated to respect the privacy and security of the +reporter of any incident. + +## Enforcement Guidelines + +Community leaders will follow these Community Impact Guidelines in determining +the consequences for any action they deem in violation of this Code of Conduct: + +### 1. Correction + +**Community Impact**: Use of inappropriate language or other behavior deemed +unprofessional or unwelcome in the community. + +**Consequence**: A private, written warning from community leaders, providing +clarity around the nature of the violation and an explanation of why the +behavior was inappropriate. A public apology may be requested. + +### 2. Warning + +**Community Impact**: A violation through a single incident or series +of actions. + +**Consequence**: A warning with consequences for continued behavior. No +interaction with the people involved, including unsolicited interaction with +those enforcing the Code of Conduct, for a specified period of time. This +includes avoiding interactions in community spaces as well as external channels +like social media. Violating these terms may lead to a temporary or +permanent ban. + +### 3. Temporary Ban + +**Community Impact**: A serious violation of community standards, including +sustained inappropriate behavior. + +**Consequence**: A temporary ban from any sort of interaction or public +communication with the community for a specified period of time. No public or +private interaction with the people involved, including unsolicited interaction +with those enforcing the Code of Conduct, is allowed during this period. +Violating these terms may lead to a permanent ban. + +### 4. Permanent Ban + +**Community Impact**: Demonstrating a pattern of violation of community +standards, including sustained inappropriate behavior, harassment of an +individual, or aggression toward or disparagement of classes of individuals. + +**Consequence**: A permanent ban from any sort of public interaction within +the community. + +## Attribution + +This Code of Conduct is adapted from the [Contributor Covenant][homepage], +version 2.0, available at +https://www.contributor-covenant.org/version/2/0/code_of_conduct.html. + +Community Impact Guidelines were inspired by [Mozilla's code of conduct +enforcement ladder](https://github.com/mozilla/diversity). + +[homepage]: https://www.contributor-covenant.org + +For answers to common questions about this code of conduct, see the FAQ at +https://www.contributor-covenant.org/faq. Translations are available at +https://www.contributor-covenant.org/translations. diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md new file mode 100644 index 0000000000..485e9c68d4 --- /dev/null +++ b/CONTRIBUTING.md @@ -0,0 +1,122 @@ +# Contributing to docTR + +Everything you need to know to contribute efficiently to the project. 
+ + + +## Codebase structure + +- [doctr](https://github.com/mindee/doctr/blob/main/doctr) - The package codebase +- [tests](https://github.com/mindee/doctr/blob/main/tests) - Python unit tests +- [docs](https://github.com/mindee/doctr/blob/main/docs) - Library documentation building +- [scripts](https://github.com/mindee/doctr/blob/main/scripts) - Example scripts +- [references](https://github.com/mindee/doctr/blob/main/references) - Reference training scripts +- [demo](https://github.com/mindee/doctr/blob/main/demo) - Small demo app to showcase docTR capabilities +- [api](https://github.com/mindee/doctr/blob/main/api) - A minimal template to deploy a REST API with docTR + + +## Continuous Integration + +This project uses the following integrations to ensure proper codebase maintenance: + +- [Github Worklow](https://help.github.com/en/actions/configuring-and-managing-workflows/configuring-a-workflow) - run jobs for package build and coverage +- [Codecov](https://codecov.io/) - reports back coverage results + +As a contributor, you will only have to ensure coverage of your code by adding appropriate unit testing of your code. + + + +## Feedback + +### Feature requests & bug report + +Whether you encountered a problem, or you have a feature suggestion, your input has value and can be used by contributors to reference it in their developments. For this purpose, we advise you to use Github [issues](https://github.com/mindee/doctr/issues). + +First, check whether the topic wasn't already covered in an open / closed issue. If not, feel free to open a new one! When doing so, use issue templates whenever possible and provide enough information for other contributors to jump in. + +### Questions + +If you are wondering how to do something with docTR, or a more general question, you should consider checking out Github [discussions](https://github.com/mindee/doctr/discussions). See it as a Q&A forum, or the docTR-specific StackOverflow! + + +## Developing docTR + +### Developer mode installation + +Install all additional dependencies with the following command: + +```shell +pip install -e .[dev] +``` + +### Commits + +- **Code**: ensure to provide docstrings to your Python code. In doing so, please follow [Google-style](https://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html) so it can ease the process of documentation later. +- **Commit message**: please follow [Udacity guide](http://udacity.github.io/git-styleguide/) + + +### Unit tests + +In order to run the same unit tests as the CI workflows, you can run unittests locally: + +```shell +make test +``` + +### Code quality + +To run all quality checks together + +```shell +make quality +``` + +#### Lint verification + +To ensure that your incoming PR complies with the lint settings, you need to install [flake8](https://flake8.pycqa.org/en/latest/) and run the following command from the repository's root folder: + +```shell +flake8 ./ +``` +This will read the `.flake8` setting file and let you know whether your commits need some adjustments. + +#### Import order + +In order to ensure there is a common import order convention, run [isort](https://github.com/PyCQA/isort) as follows: + +```shell +isort **/*.py +``` +This will reorder the imports of your local files. + +#### Annotation typing + +Additionally, to catch type-related issues and have a cleaner codebase, annotation typing are expected. 
After installing [mypy](https://github.com/python/mypy), you can run the verifications as follows: + +```shell +mypy --config-file mypy.ini doctr/ +``` +The `mypy.ini` file will be read to check your typing. + +#### Docstring format + +To keep a sane docstring structure, if you install [pydocstyle](https://github.com/PyCQA/pydocstyle), you can verify your docstrings as follows: + +```shell +pydocstyle doctr/ +``` +The `.pydocstyle` file will be read to configure this operation. + + +### Modifying the documentation + +In order to check locally your modifications to the documentation: +```shell +make docs-single-version +``` +You can now open your local version of the documentation located at `docs/_build/index.html` in your browser + + +## Let's connect + +Should you wish to connect somewhere else than on GitHub, feel free to join us on [Slack](https://join.slack.com/t/mindee-community/shared_invite/zt-uzgmljfl-MotFVfH~IdEZxjp~0zldww), where you will find a `#doctr` channel! diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000000..994f49f244 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,19 @@ +FROM python:3.8.1-slim + +ENV PYTHONUNBUFFERED 1 +ENV PYTHONDONTWRITEBYTECODE 1 + + +COPY ./requirements.txt /tmp/requirements.txt +COPY ./README.md /tmp/README.md +COPY ./setup.py /tmp/setup.py +COPY ./doctr /tmp/doctr + +RUN apt-get update \ + && apt-get install --no-install-recommends ffmpeg libsm6 libxext6 -y \ + && pip install --upgrade pip setuptools wheel \ + && pip install -e /tmp/.[tf] \ + && pip cache purge \ + && apt-get autoremove -y \ + && rm -rf /var/lib/apt/lists/* \ + && rm -rf /root/.cache/pip diff --git a/Dockerfile-api b/Dockerfile-api new file mode 100644 index 0000000000..79c51b99ea --- /dev/null +++ b/Dockerfile-api @@ -0,0 +1,34 @@ +FROM tiangolo/uvicorn-gunicorn-fastapi:python3.8-slim + +WORKDIR /app + +# set environment variables +ENV PYTHONDONTWRITEBYTECODE 1 +ENV PYTHONUNBUFFERED 1 +ENV PYTHONPATH "${PYTHONPATH}:/app" + +# copy requirements file +COPY api/requirements.txt /app/api-requirements.txt +COPY ./requirements.txt /tmp/requirements.txt + +RUN apt-get update \ + && apt-get install --no-install-recommends ffmpeg libsm6 libxext6 -y \ + && pip install --upgrade pip setuptools wheel \ + && pip install -r /app/api-requirements.txt \ + && pip install -r /tmp/requirements.txt \ + && pip cache purge \ + && apt-get autoremove -y \ + && rm -rf /var/lib/apt/lists/* \ + && rm -rf /root/.cache/pip + +# install doctr +COPY ./README.md /tmp/README.md +COPY ./setup.py /tmp/setup.py +COPY ./doctr /tmp/doctr + +RUN pip install -e /tmp/.[tf] \ + && pip cache purge \ + && rm -rf /root/.cache/pip + +# copy project +COPY api /app diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000..261eeb9e9f --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. 
For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. 
Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. 
This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
diff --git a/Makefile b/Makefile new file mode 100644 index 0000000000..4a8ba3bf9f --- /dev/null +++ b/Makefile @@ -0,0 +1,33 @@ +# this target runs checks on all files +quality: + isort . -c + flake8 ./ + mypy doctr/ + pydocstyle doctr/ + +# this target runs checks on all files and potentially modifies some of them +style: + isort . + +# Run tests for the library +test: + coverage run -m pytest tests/common/ + USE_TF='1' coverage run -m pytest tests/tensorflow/ + USE_TORCH='1' coverage run -m pytest tests/pytorch/ + +test-common: + coverage run -m pytest tests/common/ + +test-tf: + USE_TF='1' coverage run -m pytest tests/tensorflow/ + +test-torch: + USE_TORCH='1' coverage run -m pytest tests/pytorch/ + +# Check that docs can build +docs-single-version: + sphinx-build docs/source docs/_build -a + +# Check that docs can build +docs: + cd docs && bash build.sh diff --git a/README.md b/README.md new file mode 100644 index 0000000000..c82103bd1e --- /dev/null +++ b/README.md @@ -0,0 +1,283 @@ +

+ +[![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](LICENSE) ![Build Status](https://github.com/mindee/doctr/workflows/builds/badge.svg) [![codecov](https://codecov.io/gh/mindee/doctr/branch/main/graph/badge.svg?token=577MO567NM)](https://codecov.io/gh/mindee/doctr) [![CodeFactor](https://www.codefactor.io/repository/github/mindee/doctr/badge?s=bae07db86bb079ce9d6542315b8c6e70fa708a7e)](https://www.codefactor.io/repository/github/mindee/doctr) [![Codacy Badge](https://api.codacy.com/project/badge/Grade/340a76749b634586a498e1c0ab998f08)](https://app.codacy.com/gh/mindee/doctr?utm_source=github.com&utm_medium=referral&utm_content=mindee/doctr&utm_campaign=Badge_Grade) [![Doc Status](https://github.com/mindee/doctr/workflows/doc-status/badge.svg)](https://mindee.github.io/doctr) [![Pypi](https://img.shields.io/badge/pypi-v0.5.0-blue.svg)](https://pypi.org/project/python-doctr/) [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/mindee/doctr) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/mindee/notebooks/blob/main/doctr/quicktour.ipynb) + + +**Optical Character Recognition made seamless & accessible to anyone, powered by TensorFlow 2 & PyTorch** + + +What you can expect from this repository: +- efficient ways to parse textual information (localize and identify each word) from your documents +- guidance on how to integrate this in your current architecture + +![OCR_example](https://github.com/mindee/doctr/releases/download/v0.2.0/ocr.png) + +## Quick Tour + +### Getting your pretrained model + +End-to-End OCR is achieved in docTR using a two-stage approach: text detection (localizing words), then text recognition (identifying all characters in the word). +As such, you can select the architecture used for [text detection](https://mindee.github.io/doctr/latest/models.html#doctr-models-detection), and the one for [text recognition](https://mindee.github.io/doctr/latest/models.html#doctr-models-recognition) from the list of available implementations. + +```python +from doctr.models import ocr_predictor + +model = ocr_predictor(det_arch='db_resnet50', reco_arch='crnn_vgg16_bn', pretrained=True) +``` + +### Reading files + +Documents can be interpreted from PDF or images: + +```python +from doctr.io import DocumentFile +# PDF +pdf_doc = DocumentFile.from_pdf("path/to/your/doc.pdf").as_images() +# Image +single_img_doc = DocumentFile.from_images("path/to/your/img.jpg") +# Webpage +webpage_doc = DocumentFile.from_url("https://www.yoursite.com").as_images() +# Multiple page images +multi_img_doc = DocumentFile.from_images(["path/to/page1.jpg", "path/to/page2.jpg"]) +``` + +### Putting it together +Let's use the default pretrained model for an example: +```python +from doctr.io import DocumentFile +from doctr.models import ocr_predictor + +model = ocr_predictor(pretrained=True) +# PDF +doc = DocumentFile.from_pdf("path/to/your/doc.pdf").as_images() +# Analyze +result = model(doc) +``` + +### Dealing with rotated documents +Should you use docTR on documents that include rotated pages, or pages with multiple box orientations, +you have several options to handle it: + +- If you only use straight document pages with straight words (horizontal, same reading direction), +consider passing `assume_straight_pages=True` to the ocr_predictor.
It will directly fit straight boxes +on your page and return straight boxes, which makes it the fastest option. + +- If you want the predictor to output straight boxes (no matter the orientation of your pages, the final localizations +will be converted to straight boxes), you need to pass `export_as_straight_boxes=True` in the predictor. Otherwise, if `assume_straight_pages=False`, it will return rotated bounding boxes (potentially with an angle of 0°). + +If both options are set to False, the predictor will always fit and return rotated boxes. + + +To interpret your model's predictions, you can visualize them interactively as follows: + +```python +result.show(doc) +``` + +![Visualization sample](https://github.com/mindee/doctr/releases/download/v0.1.1/doctr_example_script.gif) + +Or even rebuild the original document from its predictions: + +```python +import matplotlib.pyplot as plt + +synthetic_pages = result.synthesize() +plt.imshow(synthetic_pages[0]); plt.axis('off'); plt.show() +``` + +![Synthesis sample](https://github.com/mindee/doctr/releases/download/v0.3.1/synthesized_sample.png) + + +The `ocr_predictor` returns a `Document` object with a nested structure (with `Page`, `Block`, `Line`, `Word`, `Artefact`). +To get a better understanding of our document model, check our [documentation](https://mindee.github.io/doctr/io.html#document-structure): + +You can also export them as a nested dict, more appropriate for JSON format: + +```python +json_output = result.export() +``` +For examples & further details about the export format, please refer to [this section](https://mindee.github.io/doctr/models.html#export-model-output) of the documentation + +## Installation + +### Prerequisites + +Python 3.6 (or higher) and [pip](https://pip.pypa.io/en/stable/) are required to install docTR. + +Since we use [weasyprint](https://weasyprint.readthedocs.io/), you will need extra dependencies if you are not running Linux. + +For MacOS users, you can install them as follows: +```shell +brew install cairo pango gdk-pixbuf libffi +``` + +For Windows users, those dependencies are included in GTK. You can find the latest installer over [here](https://github.com/tschoonj/GTK-for-Windows-Runtime-Environment-Installer/releases). + +### Latest release + +You can then install the latest release of the package using [pypi](https://pypi.org/project/python-doctr/) as follows: + +```shell +pip install python-doctr +``` +> :warning: Please note that the basic installation is not standalone, as it does not provide a deep learning framework, which is required for the package to run. + +We try to keep framework-specific dependencies to a minimum. You can install framework-specific builds as follows: + +```shell +# for TensorFlow +pip install "python-doctr[tf]" +# for PyTorch +pip install "python-doctr[torch]" +``` + +### Developer mode +Alternatively, you can install it from source, which will require you to install [Git](https://git-scm.com/book/en/v2/Getting-Started-Installing-Git). +First clone the project repository: + +```shell +git clone https://github.com/mindee/doctr.git +pip install -e doctr/. +``` + +Again, if you prefer to avoid the risk of missing dependencies, you can install the TensorFlow or the PyTorch build: +```shell +# for TensorFlow +pip install -e doctr/.[tf] +# for PyTorch +pip install -e doctr/.[torch] +``` + + +## Models architectures +Credits where it's due: this repository is implementing, among others, architectures from published research papers. 
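As a quick illustration of how the architectures listed below plug into the predictor, here is a minimal sketch: the pairing of `linknet_resnet18` with `sar_resnet31` is an arbitrary example, and the keyword arguments simply reuse the options described in the Quick Tour above.

```python
from doctr.io import DocumentFile
from doctr.models import ocr_predictor

# Any detection architecture can be combined with any recognition architecture
model = ocr_predictor(
    det_arch='linknet_resnet18',     # text detection (LinkNet)
    reco_arch='sar_resnet31',        # text recognition (SAR)
    pretrained=True,
    assume_straight_pages=False,     # pages may be rotated
    export_as_straight_boxes=True,   # but still output straight boxes
)

doc = DocumentFile.from_images("path/to/page.jpg")
result = model(doc)
```

Swapping either architecture only changes the corresponding string; the rest of the pipeline (reading files, visualization, export) stays the same.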
+ +### Text Detection +- DBNet: [Real-time Scene Text Detection with Differentiable Binarization](https://arxiv.org/pdf/1911.08947.pdf). +- LinkNet: [LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation](https://arxiv.org/pdf/1707.03718.pdf) + +### Text Recognition +- CRNN: [An End-to-End Trainable Neural Network for Image-based Sequence Recognition and Its Application to Scene Text Recognition](https://arxiv.org/pdf/1507.05717.pdf). +- SAR: [Show, Attend and Read:A Simple and Strong Baseline for Irregular Text Recognition](https://arxiv.org/pdf/1811.00751.pdf). +- MASTER: [MASTER: Multi-Aspect Non-local Network for Scene Text Recognition](https://arxiv.org/pdf/1910.02562.pdf). + + +## More goodies + +### Documentation + +The full package documentation is available [here](https://mindee.github.io/doctr/) for detailed specifications. + + +### Demo app + +A minimal demo app is provided for you to play with our end-to-end OCR models! + +![Demo app](https://github.com/mindee/doctr/releases/download/v0.3.0/demo_update.png) + +#### Live demo + +Courtesy of :hugs: [HuggingFace](https://huggingface.co/) :hugs:, docTR has now a fully deployed version available on [Spaces](https://huggingface.co/spaces)! +Check it out [![Hugging Face Spaces](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Spaces-blue)](https://huggingface.co/spaces/mindee/doctr) + +#### Running it locally + +If you prefer to use it locally, there is an extra dependency ([Streamlit](https://streamlit.io/)) that is required: +```shell +pip install -r demo/requirements.txt +``` +Then run your app in your default browser with: + +```shell +streamlit run demo/app.py +``` + +#### TensorFlow.js + +Instead of having your demo actually running Python, you would prefer to run everything in your web browser? +Check out our [TensorFlow.js demo](https://github.com/mindee/doctr-tfjs-demo) to get started! + +![TFJS demo](https://github.com/mindee/doctr-tfjs-demo/releases/download/v0.1-models/demo_illustration_mini.png) + + +### Docker container + +If you wish to deploy containerized environments, you can use the provided Dockerfile to build a docker image: + +```shell +docker build . -t +``` + +### Example script + +An example script is provided for a simple documentation analysis of a PDF or image file: + +```shell +python scripts/analyze.py path/to/your/doc.pdf +``` +All script arguments can be checked using `python scripts/analyze.py --help` + + +### Minimal API integration + +Looking to integrate docTR into your API? Here is a template to get you started with a fully working API using the wonderful [FastAPI](https://github.com/tiangolo/fastapi) framework. + +#### Deploy your API locally +Specific dependencies are required to run the API template, which you can install as follows: +```shell +pip install -r api/requirements.txt +``` +You can now run your API locally: + +```shell +uvicorn --reload --workers 1 --host 0.0.0.0 --port=8002 --app-dir api/ app.main:app +``` + +Alternatively, you can run the same server on a docker container if you prefer using: +```shell +PORT=8002 docker-compose up -d --build +``` + +#### What you have deployed + +Your API should now be running locally on your port 8002. Access your automatically-built documentation at [http://localhost:8002/redoc](http://localhost:8002/redoc) and enjoy your three functional routes ("/detection", "/recognition", "/ocr"). 
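For instance, the text detection route alone can be queried as follows (a minimal sketch mirroring the snippets in `api/README.md`, assuming the server is running on the port 8002 used above):

```python
import requests

with open('/path/to/your/doc.jpg', 'rb') as f:
    data = f.read()

# Returns one box (relative coordinates) per detected word
print(requests.post("http://localhost:8002/detection", files={'file': data}).json())
```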
Here is an example with Python to send a request to the OCR route: + +```python +import requests +with open('/path/to/your/doc.jpg', 'rb') as f: + data = f.read() +response = requests.post("http://localhost:8002/ocr", files={'file': data}).json() +``` + +### Example notebooks +Looking for more illustrations of docTR features? You might want to check the [Jupyter notebooks](https://github.com/mindee/doctr/tree/main/notebooks) designed to give you a broader overview. + + +## Citation + +If you wish to cite this project, feel free to use this [BibTeX](http://www.bibtex.org/) reference: + +```bibtex +@misc{doctr2021, + title={docTR: Document Text Recognition}, + author={Mindee}, + year={2021}, + publisher = {GitHub}, + howpublished = {\url{https://github.com/mindee/doctr}} +} +``` + + +## Contributing + +If you scrolled down to this section, you most likely appreciate open source. Do you feel like extending the range of our supported characters? Or perhaps submitting a paper implementation? Or contributing in any other way? + +You're in luck, we compiled a short guide (cf. [`CONTRIBUTING`](CONTRIBUTING.md)) for you to easily do so! + + +## License + +Distributed under the Apache 2.0 License. See [`LICENSE`](LICENSE) for more information. + diff --git a/api/README.md b/api/README.md new file mode 100644 index 0000000000..dcbab9264e --- /dev/null +++ b/api/README.md @@ -0,0 +1,92 @@ +# Template for your OCR API using docTR + +## Installation + +You will only need to install [Git](https://git-scm.com/book/en/v2/Getting-Started-Installing-Git) and [Docker](https://docs.docker.com/get-docker/). The container environment will be self-sufficient and install the remaining dependencies on its own. + +## Usage + +### Starting your web server + +You will need to clone the repository first: +```shell +git clone https://github.com/mindee/doctr.git +``` +then from the repo root folder, you can start your container: + +```shell +PORT=8050 docker-compose up -d --build +``` +Once completed, your [FastAPI](https://fastapi.tiangolo.com/) server should be running on port 8050 (feel free to change this in the previous command). + +### Documentation and swagger + +FastAPI comes with many advantages including speed and OpenAPI features. 
For instance, once your server is running, you can access the automatically built documentation and swagger in your browser at: http://localhost:8050/docs + + +### Using the routes + +You will find detailed instructions in the live documentation when your server is up, but here are some examples to use your available API routes: + +#### Text detection + +Using the following image: + + +with this snippet: + +```python +import requests +with open('/path/to/your/img.jpg', 'rb') as f: + data = f.read() +print(requests.post("http://localhost:8050/detection", files={'file': data}).json()) +``` + +should yield +``` +[{'box': [0.826171875, 0.185546875, 0.90234375, 0.201171875]}, + {'box': [0.75390625, 0.185546875, 0.8173828125, 0.201171875]}] +``` + + +#### Text recognition + +Using the following image: +![recognition-sample](https://user-images.githubusercontent.com/76527547/117133599-c073fa00-ada4-11eb-831b-412de4d28341.jpeg) + +with this snippet: + +```python +import requests +with open('/path/to/your/img.jpg', 'rb') as f: + data = f.read() +print(requests.post("http://localhost:8050/recognition", files={'file': data}).json()) +``` + +should yield +``` +{'value': 'invite'} +``` + + +#### End-to-end OCR + +Using the following image: + + +with this snippet: + +```python +import requests +with open('/path/to/your/img.jpg', 'rb') as f: + data = f.read() +print(requests.post("http://localhost:8050/ocr", files={'file': data}).json()) +``` + +should yield +``` +[{'box': [0.75390625, 0.185546875, 0.8173828125, 0.201171875], + 'value': 'Hello'}, + {'box': [0.826171875, 0.185546875, 0.90234375, 0.201171875], + 'value': 'world!'}] +``` diff --git a/api/app/config.py b/api/app/config.py new file mode 100644 index 0000000000..dae28df4ba --- /dev/null +++ b/api/app/config.py @@ -0,0 +1,13 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + +import os + +import doctr + +PROJECT_NAME: str = 'docTR API template' +PROJECT_DESCRIPTION: str = 'Template API for Optical Character Recognition' +VERSION: str = doctr.__version__ +DEBUG: bool = os.environ.get('DEBUG', '') != 'False' diff --git a/api/app/main.py b/api/app/main.py new file mode 100644 index 0000000000..ca97f7c311 --- /dev/null +++ b/api/app/main.py @@ -0,0 +1,46 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. 
+ +import time + +from app import config as cfg +from app.routes import detection, ocr, recognition +from fastapi import FastAPI, Request +from fastapi.openapi.utils import get_openapi + +app = FastAPI(title=cfg.PROJECT_NAME, description=cfg.PROJECT_DESCRIPTION, debug=cfg.DEBUG, version=cfg.VERSION) + + +# Routing +app.include_router(recognition.router, prefix="/recognition", tags=["recognition"]) +app.include_router(detection.router, prefix="/detection", tags=["detection"]) +app.include_router(ocr.router, prefix="/ocr", tags=["ocr"]) + + +# Middleware +@app.middleware("http") +async def add_process_time_header(request: Request, call_next): + start_time = time.time() + response = await call_next(request) + process_time = time.time() - start_time + response.headers["X-Process-Time"] = str(process_time) + return response + + +# Docs +def custom_openapi(): + if app.openapi_schema: + return app.openapi_schema + openapi_schema = get_openapi( + title=cfg.PROJECT_NAME, + version=cfg.VERSION, + description=cfg.PROJECT_DESCRIPTION, + routes=app.routes, + ) + app.openapi_schema = openapi_schema + return app.openapi_schema + + +app.openapi = custom_openapi diff --git a/api/app/routes/detection.py b/api/app/routes/detection.py new file mode 100644 index 0000000000..d41beff65c --- /dev/null +++ b/api/app/routes/detection.py @@ -0,0 +1,22 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + +from typing import List + +from app.schemas import DetectionOut +from app.vision import det_predictor +from fastapi import APIRouter, File, UploadFile, status + +from doctr.io import decode_img_as_tensor + +router = APIRouter() + + +@router.post("/", response_model=List[DetectionOut], status_code=status.HTTP_200_OK, summary="Perform text detection") +async def text_detection(file: UploadFile = File(...)): + """Runs docTR text detection model to analyze the input image""" + img = decode_img_as_tensor(file.file.read()) + boxes = det_predictor([img])[0] + return [DetectionOut(box=box.tolist()) for box in boxes[:, :-1]] diff --git a/api/app/routes/ocr.py b/api/app/routes/ocr.py new file mode 100644 index 0000000000..20997aaeb7 --- /dev/null +++ b/api/app/routes/ocr.py @@ -0,0 +1,24 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + +from typing import List + +from app.schemas import OCROut +from app.vision import predictor +from fastapi import APIRouter, File, UploadFile, status + +from doctr.io import decode_img_as_tensor + +router = APIRouter() + + +@router.post("/", response_model=List[OCROut], status_code=status.HTTP_200_OK, summary="Perform OCR") +async def perform_ocr(file: UploadFile = File(...)): + """Runs docTR OCR model to analyze the input image""" + img = decode_img_as_tensor(file.file.read()) + out = predictor([img]) + + return [OCROut(box=(*word.geometry[0], *word.geometry[1]), value=word.value) + for word in out.pages[0].blocks[0].lines[0].words] diff --git a/api/app/routes/recognition.py b/api/app/routes/recognition.py new file mode 100644 index 0000000000..c558854442 --- /dev/null +++ b/api/app/routes/recognition.py @@ -0,0 +1,20 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. 
+ +from app.schemas import RecognitionOut +from app.vision import reco_predictor +from fastapi import APIRouter, File, UploadFile, status + +from doctr.io import decode_img_as_tensor + +router = APIRouter() + + +@router.post("/", response_model=RecognitionOut, status_code=status.HTTP_200_OK, summary="Perform text recognition") +async def text_recognition(file: UploadFile = File(...)): + """Runs docTR text recognition model to analyze the input image""" + img = decode_img_as_tensor(file.file.read()) + out = reco_predictor([img]) + return RecognitionOut(value=out[0][0]) diff --git a/api/app/schemas.py b/api/app/schemas.py new file mode 100644 index 0000000000..11d813b5f8 --- /dev/null +++ b/api/app/schemas.py @@ -0,0 +1,21 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + +from typing import Tuple + +from pydantic import BaseModel, Field + + +# Recognition output +class RecognitionOut(BaseModel): + value: str = Field(..., example="Hello") + + +class DetectionOut(BaseModel): + box: Tuple[float, float, float, float] + + +class OCROut(RecognitionOut, DetectionOut): + pass diff --git a/api/app/vision.py b/api/app/vision.py new file mode 100644 index 0000000000..cec2f9e3fb --- /dev/null +++ b/api/app/vision.py @@ -0,0 +1,16 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + +import tensorflow as tf + +gpu_devices = tf.config.experimental.list_physical_devices('GPU') +if any(gpu_devices): + tf.config.experimental.set_memory_growth(gpu_devices[0], True) + +from doctr.models import ocr_predictor + +predictor = ocr_predictor(pretrained=True) +det_predictor = predictor.det_predictor +reco_predictor = predictor.reco_predictor diff --git a/api/requirements-dev.txt b/api/requirements-dev.txt new file mode 100644 index 0000000000..250adc4063 --- /dev/null +++ b/api/requirements-dev.txt @@ -0,0 +1,4 @@ +pytest>=5.3.2 +pytest-asyncio>=0.14.0 +asyncpg>=0.20.0 +httpx>=0.16.1,<0.20.0 diff --git a/api/requirements.txt b/api/requirements.txt new file mode 100644 index 0000000000..d522a6f682 --- /dev/null +++ b/api/requirements.txt @@ -0,0 +1,4 @@ +fastapi>=0.65.2 +uvicorn>=0.11.1 +python-multipart==0.0.5 +python-doctr>=0.2.0 diff --git a/api/tests/conftest.py b/api/tests/conftest.py new file mode 100644 index 0000000000..6c87422688 --- /dev/null +++ b/api/tests/conftest.py @@ -0,0 +1,28 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + +import pytest +import requests +from app.main import app +from httpx import AsyncClient + + +@pytest.fixture(scope="session") +def mock_recognition_image(tmpdir_factory): + url = 'https://user-images.githubusercontent.com/76527547/117133599-c073fa00-ada4-11eb-831b-412de4d28341.jpeg' + return requests.get(url).content + + +@pytest.fixture(scope="session") +def mock_detection_image(tmpdir_factory): + url = 'https://user-images.githubusercontent.com/76527547/117319856-fc35bf00-ae8b-11eb-9b51-ca5aba673466.jpg' + return requests.get(url).content + + +@pytest.fixture(scope="function") +async def test_app_asyncio(): + # for httpx>=20, follow_redirects=True (cf. 
https://github.com/encode/httpx/releases/tag/0.20.0) + async with AsyncClient(app=app, base_url="http://test") as ac: + yield ac # testing happens here diff --git a/api/tests/routes/test_detection.py b/api/tests/routes/test_detection.py new file mode 100644 index 0000000000..03641b4754 --- /dev/null +++ b/api/tests/routes/test_detection.py @@ -0,0 +1,25 @@ +import numpy as np +import pytest +from scipy.optimize import linear_sum_assignment + +from doctr.utils.metrics import box_iou + + +@pytest.mark.asyncio +async def test_text_detection(test_app_asyncio, mock_detection_image): + + response = await test_app_asyncio.post("/detection", files={'file': mock_detection_image}) + assert response.status_code == 200 + json_response = response.json() + + gt_boxes = np.array([[1240, 430, 1355, 470], [1360, 430, 1495, 470]], dtype=np.float32) + gt_boxes[:, [0, 2]] = gt_boxes[:, [0, 2]] / 1654 + gt_boxes[:, [1, 3]] = gt_boxes[:, [1, 3]] / 2339 + + # Check that IoU with GT if reasonable + assert isinstance(json_response, list) and len(json_response) == gt_boxes.shape[0] + pred_boxes = np.array([elt['box'] for elt in json_response]) + iou_mat = box_iou(gt_boxes, pred_boxes) + gt_idxs, pred_idxs = linear_sum_assignment(-iou_mat) + is_kept = iou_mat[gt_idxs, pred_idxs] >= 0.8 + assert gt_idxs[is_kept].shape[0] == gt_boxes.shape[0] diff --git a/api/tests/routes/test_ocr.py b/api/tests/routes/test_ocr.py new file mode 100644 index 0000000000..0bdb3449f0 --- /dev/null +++ b/api/tests/routes/test_ocr.py @@ -0,0 +1,29 @@ +import numpy as np +import pytest +from scipy.optimize import linear_sum_assignment + +from doctr.utils.metrics import box_iou + + +@pytest.mark.asyncio +async def test_perform_ocr(test_app_asyncio, mock_detection_image): + + response = await test_app_asyncio.post("/ocr", files={'file': mock_detection_image}) + assert response.status_code == 200 + json_response = response.json() + + gt_boxes = np.array([[1240, 430, 1355, 470], [1360, 430, 1495, 470]], dtype=np.float32) + gt_boxes[:, [0, 2]] = gt_boxes[:, [0, 2]] / 1654 + gt_boxes[:, [1, 3]] = gt_boxes[:, [1, 3]] / 2339 + gt_labels = ["Hello", "world!"] + + # Check that IoU with GT if reasonable + assert isinstance(json_response, list) and len(json_response) == gt_boxes.shape[0] + pred_boxes = np.array([elt['box'] for elt in json_response]) + pred_labels = np.array([elt['value'] for elt in json_response]) + iou_mat = box_iou(gt_boxes, pred_boxes) + gt_idxs, pred_idxs = linear_sum_assignment(-iou_mat) + is_kept = iou_mat[gt_idxs, pred_idxs] >= 0.8 + gt_idxs, pred_idxs = gt_idxs[is_kept], pred_idxs[is_kept] + assert gt_idxs.shape[0] == gt_boxes.shape[0] + assert all(gt_labels[gt_idx] == pred_labels[pred_idx] for gt_idx, pred_idx in zip(gt_idxs, pred_idxs)) diff --git a/api/tests/routes/test_recognition.py b/api/tests/routes/test_recognition.py new file mode 100644 index 0000000000..9848c786e9 --- /dev/null +++ b/api/tests/routes/test_recognition.py @@ -0,0 +1,10 @@ +import pytest + + +@pytest.mark.asyncio +async def test_text_recognition(test_app_asyncio, mock_recognition_image): + + response = await test_app_asyncio.post("/recognition", files={'file': mock_recognition_image}) + assert response.status_code == 200 + + assert response.json() == {"value": "invite"} diff --git a/demo/app.py b/demo/app.py new file mode 100644 index 0000000000..cf822ed3ac --- /dev/null +++ b/demo/app.py @@ -0,0 +1,109 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. 
+# See LICENSE or go to for full license details. + +import os + +import matplotlib.pyplot as plt +import streamlit as st + +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" + +import cv2 +import tensorflow as tf + +gpu_devices = tf.config.experimental.list_physical_devices('GPU') +if any(gpu_devices): + tf.config.experimental.set_memory_growth(gpu_devices[0], True) + +from doctr.io import DocumentFile +from doctr.models import ocr_predictor +from doctr.utils.visualization import visualize_page + +DET_ARCHS = ["db_resnet50", "db_mobilenet_v3_large"] +RECO_ARCHS = ["crnn_vgg16_bn", "crnn_mobilenet_v3_small", "master", "sar_resnet31"] + + +def main(): + + # Wide mode + st.set_page_config(layout="wide") + + # Designing the interface + st.title("docTR: Document Text Recognition") + # For newline + st.write('\n') + # Instructions + st.markdown("*Hint: click on the top-right corner of an image to enlarge it!*") + # Set the columns + cols = st.columns((1, 1, 1, 1)) + cols[0].subheader("Input page") + cols[1].subheader("Segmentation heatmap") + cols[2].subheader("OCR output") + cols[3].subheader("Page reconstitution") + + # Sidebar + # File selection + st.sidebar.title("Document selection") + # Disabling warning + st.set_option('deprecation.showfileUploaderEncoding', False) + # Choose your own image + uploaded_file = st.sidebar.file_uploader("Upload files", type=['pdf', 'png', 'jpeg', 'jpg']) + if uploaded_file is not None: + if uploaded_file.name.endswith('.pdf'): + doc = DocumentFile.from_pdf(uploaded_file.read()).as_images() + else: + doc = DocumentFile.from_images(uploaded_file.read()) + page_idx = st.sidebar.selectbox("Page selection", [idx + 1 for idx in range(len(doc))]) - 1 + cols[0].image(doc[page_idx]) + + # Model selection + st.sidebar.title("Model selection") + det_arch = st.sidebar.selectbox("Text detection model", DET_ARCHS) + reco_arch = st.sidebar.selectbox("Text recognition model", RECO_ARCHS) + + # For newline + st.sidebar.write('\n') + + if st.sidebar.button("Analyze page"): + + if uploaded_file is None: + st.sidebar.write("Please upload a document") + + else: + with st.spinner('Loading model...'): + predictor = ocr_predictor(det_arch, reco_arch, pretrained=True) + + with st.spinner('Analyzing...'): + + # Forward the image to the model + processed_batches = predictor.det_predictor.pre_processor([doc[page_idx]]) + out = predictor.det_predictor.model(processed_batches[0], return_preds=True) + seg_map = out["out_map"] + seg_map = tf.squeeze(seg_map[0, ...], axis=[2]) + seg_map = cv2.resize(seg_map.numpy(), (doc[page_idx].shape[1], doc[page_idx].shape[0]), + interpolation=cv2.INTER_LINEAR) + # Plot the raw heatmap + fig, ax = plt.subplots() + ax.imshow(seg_map) + ax.axis('off') + cols[1].pyplot(fig) + + # Plot OCR output + out = predictor([doc[page_idx]]) + fig = visualize_page(out.pages[0].export(), doc[page_idx], interactive=False) + cols[2].pyplot(fig) + + # Page reconsitution under input page + page_export = out.pages[0].export() + img = out.pages[0].synthesize() + cols[3].image(img, clamp=True) + + # Display JSON + st.markdown("\nHere are your analysis results in JSON format:") + st.json(page_export) + + +if __name__ == '__main__': + main() diff --git a/demo/requirements.txt b/demo/requirements.txt new file mode 100644 index 0000000000..f16f3aa066 --- /dev/null +++ b/demo/requirements.txt @@ -0,0 +1,2 @@ +-e git+https://github.com/mindee/doctr.git#egg=python-doctr[tf] +streamlit>=1.0.0 diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 
0000000000..2902accb24 --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,12 @@ +version: '3.7' + +services: + web: + build: + context: . + dockerfile: Dockerfile-api + command: uvicorn app.main:app --reload --workers 1 --host 0.0.0.0 --port 8080 + volumes: + - ./api/:/usr/src/app/ + ports: + - ${PORT}:8080 diff --git a/docs/Makefile b/docs/Makefile new file mode 100644 index 0000000000..92dd33a1a4 --- /dev/null +++ b/docs/Makefile @@ -0,0 +1,20 @@ +# Minimal makefile for Sphinx documentation +# + +# You can set these variables from the command line, and also +# from the environment for the first two. +SPHINXOPTS ?= +SPHINXBUILD ?= sphinx-build +SOURCEDIR = source +BUILDDIR = _build + +# Put it first so that "make" without argument is like "make help". +help: + @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +.PHONY: help Makefile + +# Catch-all target: route all unknown targets to Sphinx using the new +# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). +%: Makefile + @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/docs/build.sh b/docs/build.sh new file mode 100644 index 0000000000..2446b5f739 --- /dev/null +++ b/docs/build.sh @@ -0,0 +1,44 @@ +function deploy_doc(){ + if [ ! -z "$1" ] + then + git checkout $1 + fi + COMMIT=$(git rev-parse --short HEAD) + echo "Creating doc at commit" $COMMIT "and pushing to folder $2" + pip install -U .. + if [ ! -z "$2" ] + then + if [ "$2" == "latest" ]; then + echo "Pushing main" + sphinx-build source _build -a && mkdir build && mkdir build/$2 && cp -a _build/* build/$2/ + elif [ -d build/$2 ]; then + echo "Directory" $2 "already exists" + else + echo "Pushing version" $2 + cp -r _static source/ && cp _conf.py source/conf.py + sphinx-build source _build -a + mkdir build/$2 && cp -a _build/* build/$2/ && git checkout source/ && git clean -f source/ + fi + else + echo "Pushing stable" + cp -r _static source/ && cp _conf.py source/conf.py + sphinx-build source build -a && git checkout source/ && git clean -f source/ + fi +} + +# You can find the commit for each tag on https://github.com/mindee/doctr/tags +if [ -d build ]; then rm -Rf build; fi +cp -r source/_static . 
+cp source/conf.py _conf.py +git fetch --all --tags --unshallow +deploy_doc "" latest +deploy_doc "571af3dc" v0.1.0 +deploy_doc "6248b0bd" v0.1.1 +deploy_doc "650c4ad4" v0.2.0 +deploy_doc "1bbdb072" v0.2.1 +deploy_doc "3f051346" v0.3.0 +deploy_doc "369a787d" v0.3.1 +deploy_doc "51663ddf" v0.4.0 +deploy_doc "74ff9ffb" v0.4.1 +deploy_doc "b9d8feb1" # v0.5.0 Latest stable release +rm -rf _build _static _conf.py diff --git a/docs/requirements.txt b/docs/requirements.txt new file mode 100644 index 0000000000..310c98c486 --- /dev/null +++ b/docs/requirements.txt @@ -0,0 +1,7 @@ +sphinx<3.5.0 +sphinx-rtd-theme==0.4.3 +sphinxemoji>=0.1.8 +sphinx-copybutton>=0.3.1 +docutils<0.18 +recommonmark>=0.7.1 +sphinx-markdown-tables>=0.0.15 diff --git a/docs/source/_static/css/mindee.css b/docs/source/_static/css/mindee.css new file mode 100644 index 0000000000..a17612f34c --- /dev/null +++ b/docs/source/_static/css/mindee.css @@ -0,0 +1,43 @@ +/* Version control */ + +.version-button { + color: white; + border: none; + padding: 5px; + font-size: 15px; + cursor: pointer; +} + +.version-button:hover, .version-button:focus { + background-color: #5eb2e6; +} + +.version-dropdown { + display: none; + min-width: 160px; + overflow: auto; + font-size: 15px; +} + +.version-dropdown a { + color: white; + padding: 3px 4px; + text-decoration: none; + display: block; +} + +.version-dropdown a:hover { + background-color: #5eb2e6; +} + +.version-show { + display: block; +} + +.wy-nav-content { + max-width: 1000px !important; +} + +h1, h2, h3, h4, h5, .caption-text { + font-family: "Helvetica Neue", Arial, sans-serif; +} diff --git a/docs/source/_static/images/Logo-docTR-white.png b/docs/source/_static/images/Logo-docTR-white.png new file mode 100644 index 0000000000000000000000000000000000000000..988eb2cf7f86413d2b10e389d286b1c329ec0585 GIT binary patch literal 7568 zcmd^E_g7O*w+<16&>=Jd=?0_<(vjYK57JQq>AeetUS6uA^d6;mkrJv9x)h~JM-U0U z*9hDw@At#K>#q9`+_lchnaQ5aZnIBjo+nmEOPL6t79Rir5UHvt=m7v&i4Ksvv9NhqaSKFmLmstxxlDzmS|F{VFLw?zlBM zKd+{0q#5V^9|~ox=JY@ME8lzEW7{Z;F(A}^Q=ss$vFM?ai>Av%h&;_%#s<3lY6QLO zyXPG=<-B)=hDqJ{?zK%F-ALc8xA5MaJlQhdy8jr~^)@o@q&Yr^KS)DYQ?R;FWi|Kf?NRl|+YtT_xwu%l`<|J449bR_>1x?b&8wdLf_*`OerzhTW=3 zY>iHA4J#Rt=m9Y*%J-utrq_wR$-Im-CGnkh079&QkVpC0*?zcQ-7PbJdRD*_Qn!`L z=oUwzzj|nNfLg705I+{2{wfX)z!N=~>R$Mi`Bx7i2NwGx*22C^!(TIGBr#KM{n$0R z)pPf42zW2V#ZP4JuNnCan5jflKShp&9ODnP(+VtmR1X>3NKxVQH)ji=WUu3> zj_{W1ovaVz_Nvw^HVX*5lzZG(62ol0j?r%wDk98c{7RZh<9(k0jqaE-_FFpkd`5Xb zx%K22KIJ2=Dr8R$W>tLL7x_0vC?y#ijI=75} z%O5+u*&Ax1jZjpAL&iVwCBky{6~gEfzTHIZsA6TaC?{aPM9{of4iSGNKY z^08c0W&E^avdHKEQ!1)V->MiTUIgqu{As`CCU%?rT(%l~T^Qr!mvt5QBoXoiqdz5r zxnyjCDX5cO@|SB!I_18t+k9v+Nrb*oZGR1iYqidV&UXE4_6i;o@m#n@|5#O|3- z4A>QyobOM;h92VWTe=9&c!sk$T0K4TJeS}u9yLY^@)tXs+vYjAbJqupRz-Jf=cDs8 zKhU#)4;lsBuY>BrBtl7$m8E|RMm2=ZA2$GyFn!cSf-N^DxwzAEfmwt~Jwvyy+)V`x zoIz--efgzn@t&!Pydn|ug|FBI$>s}{yBX4c6hLqz1TVQSvBjy|^6pgBylg#~1_PqqwYn1D`4_S+6p7Q#ex; zs(>%q+dRZ~ng2#|bb)<6vUa3S&!IRf-s3DJM>hedjPorBNWOVs2mOi@2LY(ADIYSm z^DQ5>>^C1+s?(lF(|B9V-3#KogbJRrgrC4|+^OPSH(E@&)_?fYE_JBP!RAbUC4FXc zBOd;e<)^5^>O!(Plzh&xM+$_c*g__*$uH(MNsS|uRT3vKBfwU5tz5Z z6~d`W_uim_BRC*waCE-^8(v#;8==%-inh>zDtTWyq0skaa@-C#xS(IW2Gn?J?Zzp* z5ObS*sqG|f8R8SCl4E06m5z!Q-$j8&xJncJT`NdBm5E=I;N=LBToYJgZc~I&`|M{4u!hHVJhPhm=Ayc# zGc`iM*%qYvxly)o=tH#L%!88MDdEYWIjvQrTxyd^f+!y+30a+sXui)OX zyOFSrr@V@3Wn6gErVLmO<%#HIg1vQR<#nV)W(-W^4M5_i7W-7pOG6XO;mwGVV_6H_ z#1D0SDrA?kEAQBEd?qtVz-wsY&9UL~)Io70rz59O8p8YFxjKvtg{rl6T9j5w^If^G ziR_?L4oS&?9MeyRbPB#g_=NTaW4klVmN57cd{MvneUD$u<%;+Boh4V~hG9vryvrLW 
z^+#!Bd64JEm(iWb&LYu+yTE$?9&FKAXW4i9TvZLVlb>Jkg^gpEK`aHgpq<(=Cnt*G z$o1aC1c$o;U>gBhbGhu?gL!0S3SZIv!wjCfOLU=YC7-60pK{;k(jBFcl(Ycim+(dH zP(YIN`8i~kpb(vd*lYID{k*dE?i@i_dC35qf;dHg{%V|H*6xq}eceLIfY}o-xAQvp zD-WHBPwz2eHzOhO!_L`nkUmud58k=1HC-01(Q9Wy8T0xLRZ;G3q`8-y;pHlJ4dlc; zLfG9;3_p?qeoj`aZEcIv*Nyzf-4_jM)R2D{_)0dB>6hi0i{z>8hB@#hX+VF_tG0s* z*0g5!uPqV^mdYclTf>4neDXAYp^;HMql!i-S(<9dFMa$Nv`QfAAjg6%$8_|?oY0Px zSM`3mkXFN-=&?cAcb-Sd84KGvel5BObtWLYV)hV$m!#6+H?O`%f5ovIUrWJmvQd?l zid#Z|?I>g(!;uAn7<)4qeoIf8U) z6Z^0f7~V{0nN}#R^s@8$*Ld~KZ_ezJ;sqj)+lVGv1^ATXNZYJd)iz)8Mn67D|CA%A zC?9;;hf7=6PCJSAWP}vqAdAvIR3y2?wo~VXP%-2FKpSkL(yjC(j<8N^ADc-;xq0sE ztRFo)ppKpx?WDA%&S;Um$(NquaEr!Bf zo-MH%NdG#BL+DVfD{>za!njPoO06aQV&mk6Vs725D|WV5*mYmfF0h*%zQu_-gT8)u|D zor^PZCKMRc<8L?pey);uP%51BfGg3`hbzT>hPht#W6sBzH^C+JvEjq^>X;G3Vm6bD z(}l&prg`o&*Y$Hb+3{lmtxbs^`-A#o8~Ykv7^A8!bW7|Y8(CuhQg5fp8l-zpHBm48 zgD9-d+ym{-MXFKD((w%?#b-_y?BAv0E1w2RS0|YZG~HZ60`AUrecW(-?-cT zdUf&rNfLv?T3-})$j(0~XsDFaqA7M!?cQ%GPHXl>4gm`bu~4Y1U1R4zu;{DA?c*ch z6?at9*1-YvI8w*bc;{IG^I%ZjYyu~Y(D}Tp#HKlTML8BNiCSx>C!`CqK6hMcB!Q&H z@a#0{7J60SS&=@25W2M4Km!A|A3R-=`(GS5gLN$fH5WrjqZE2GRv)Nsvs{ z9K9@i<-s&Cf4*P2y*>%8M78FlWar#{gU;@{5(m9s4-jP_J2~|j`keaib+7At+bq!u z9+w&dbCRVpUjoeiws8o|{i!+6C${IWfH}LbyMiSKE0PP%3|@!F8iI@P$pY=0u7*G( zf}HtRKU+#5bvyA@;kOuhF+N74NW3^U@K%uUYVslH0+oP2at%f$@? z0g0@T{)v_sReKU zTqOhibyEK=5l`VdC~bH{q5HGG=d3NiL}-_qg1@zX!;OQSQ#bv!aS@lfrSi`MY-5zfad#2lJ6 zms%O#@%#7RcpdA-ODLUlZf*65CaPz_cuWalYH~ zUlB!8Tl4oDpGteq4ph8B&i%7D8zQ}8WR>;LMr&>YP6($L10@AP$SM9J#xIU}gS zz1%g)bi*4$sBQP0_*%P`hLu>q>Vao-yCQ(o+*1y!ErL-2yJ4}(Nm-a zR;dXrjSi4~kMsQA|I=Oz&#-+_AmWu7-RjhX0)K(hkm)m=(1Go(4EDFL1=HT2Ak6Dc zwD29Rd`ZjtHav!!F2F_B#8PDqwU3Lag+H{ZSZX@Ipy+!tgw9W3WAm3~ zUZqS!d8C*9l#1m3-8hd$8-KjGCI*E}I%r?srWZEDXIqJo!mETAt$ZVetG&d_nrP5g~3Bll;IklgR9nuMbW*Nyyt z@FNRAd65ai!Ug;`y)xQ{temp)J>MBsszqFEgPp+HelHIEY$CDc$oGm2kprPcTnh)3 z6W-b~_vx^|+X_w2^&SpN@<*NZRH&PL3f2piGPnHB(=E037-4Mr?6W+a+V;HiRZpJ{ z#}afRwvi|rHle~~UE&}2&k9u1kp3}@*lrYx)nKeCe_e`6`rZI-f6O7cNrqNpr4h5x zaIpf~0v0w@?{n4hd`}6|+%p{IHq2HKz_kevW!s+)fp)~-S)Rl?wbx`X=An7IdkR3S zk@K~B_eeiv8fg-~#_8~1x>KqAS{Og3+(=V$B0Fd(Cw3rv+cu8-t+XeRvBUxIa}=Kt zc#JMhZg;1xwTt3MMWPsWPm=l8#=mI* zi^0%-eM*N(cc;|{-~OC8cpj#ehjk_x9Gb}}Ze|2W-?1tMFgh2V2RtT_nr*1bWRxzX zy>gvt6##Yf{S?}@#U;q$8<98KYLHs1Gd>n@(d1OWw}9KuR%-SopvO<|%jHD5U1Jp4 zs6v{S^Mr*CyLDU9*f$yAwxc%bRB32ILfHlnzs7Rh(0@xM__6fnN8C3i-h_$ zZ?Y{3Dt#Xbsu2b0@3jKxu-Dx7l?j@s&1ATJ$mBzUcKD$%{vwr3!%WM<7ZB)^VM(eaqbY^u z=bMvUIMbXZ!4K}Am_-Hf@ws^fr7wxAKo#h?vI!a3p!fSNIXL=-4m_>Y*F2=I*Y`_3 zL!$<%%-`;}sa%LU`wJI$Yg-k+PG%V~0w1v6e-l~BhdG}Z%-$D&;llxirD`uzl%3Pi z3MgBs`U4z;)n z%qC%96cx8HqQfEv5S`@cngWQ*Xs9x+K2gucNwO0yl;e_dBBs$>I_#9QTPyC`4GYqo z72}4K@-o!w+IJDNuPyebbfvQW;QfC7U_5#?QeR=_WW|kmB?XX|>t196jgpMC> zE$$M$H=2CLr9|{NEHE>Di1q_2^7PCp-lYjs+Wi*v0OOngjeKS5_nIA_Il9{NROuv(g8ai>xFSDri2T61sK13 z2bolSBBE=nf86!H?m_i*Pcj7Vq^04TW;V@oHX)S7=ql4uYxnHLQV{J8dp<+y{jD+= zHY4N^W`OcB$6?H)zgKcQyB{!n{@5ddNOwWKTy6p$z;lu+>_z(xZzX_Y;Xb3Q?1@8P z*_sk(wgMq;VF7!|msl+gx{VEE88_w`BBdqM~m8ojzmTZ9ehxXjdU5?Vku6goRO3^vM0&Y%ZDvz4D zM8S%u=ghkY7wl1v97mR@V7r+%mr2ZFf@{;TdJcnZO>6}T*TfE4!M-G83W1sYdixn!vwkhf9rZ0$VHcmbjMbdyC>m9!gg5kfUdFwO070VVMj2;euF4!qMx|40#A(A7=>|Xxiin9pZ}dQ#8;JBrICYYI+hm5;|n|2NZ6ZZKW*>l5!I`3uZ663gRNCFYG4e2sVlXtazPkuD{*wZEk5N2COr`p%;#$|B?@5vKZ#!RDOAZ>h5Z*lX)lKW literal 0 HcmV?d00001 diff --git a/docs/source/_static/images/favicon.ico b/docs/source/_static/images/favicon.ico new file mode 100644 index 0000000000000000000000000000000000000000..d9bf77d4a068038a8c10ea3a5d11bef3e4ee666c GIT binary patch literal 100942 zcmeHQe{38_6`pHC>@B6t%KiKyIqUiA-r? 
zrzCL*ZG|c&Qb1Hze~3ST6gAjINN!P@;^0V;keif7CM{`VgVh!cw2rLQDRnsB^WEC( zyWN}F-RJd=?>l+CyR&a*-hAJCZ)W%Ac9lx0Wr{y3Rjo#rD%DE+U0vDZ)#WL5=f{+4 zY|I_sq15)3DOFpWJ$~*pO8xAcRI)mET&Yz3!zoosTdr>l)%VV$DfQ!TtzWnD+D}|d zRj+LRTGIwT%c-23zW002QJYJTG&kLTcgJ7IV&5~tb*|4H< zMx0fC0{zmmbe(~MU|NQV5+J5!o{-d>h8y-CSo%e=2Pj5c))FVH+`NNiH zCLbRjxa*-8Z@r=6H&yNbe7Rxj>F)DGw{|?+u(3k@WtA$wZvOv{ZN6<}d1i3%>e#AJ zPQI|a|GK?{gQL9P-dcaOs{O$yQZEfZb8ufz)rpCJuRrzv^J}+V=-hW@>za?No7In& z$MSdKXjT6NVRfi(Vsh+;O|Ot19}U%XRHZt)Th0ugAF6rwjgu=zmv+^r=Kb%5r@DWC z&uxGI)4*8j-06=7PIY%4-+k)QwZA*Q?~#Wtrnhf8^wRj&U0=G`Th;#d{cAqh(Yvvt zV>!wFp>AU5<%27B9XosJ?AWR&Dd$Z0n;TTst|OzzhHCmwFIV?p6U*P)cd!2K-LX}h z_D*xT?%oe?sq1B>Tc60RsTR;~KN<#k{{6Wj{(br!ce5;50(oUJA``*is-9_hi z_D${>-g#v70-MQR>gM|SH&gNived53+vf(q+*5VwXMkAVPUQdb@i*V7 zX?yL$;I_}Ix{1zKvXGI^r&Z-mvHX3Ieqi0*XNOz9zJnr{YX0i_rk*?Q`{8QZ@8S#! zYaI8oAC7TzKX=}8HM2<5`MEs}e?<74U!)ifAAkZ=(|cH{#!Xm zfd9k)U0xi17yb|b$Nb+q8$kR={0|@WT_q9!t@C|c|ChG!gMaXk@xQcf0Q`f0*DNml zF6Muj|Ka-IIvYUzNBj>T^j#$p|E=?VT>r)I`+LYwniG`jT`nAHm-YF4?WMuL?F>fKAN)td z&-T*b-*yJ0=@0&+;b(hk@ULf(Fsh`~CY0#uamo3Be?1q3(T$W$)~%Sx+^BCcJ_+!z z#~S>L*}jQ?uuP)Xe}RASPc~p`yE-|yz`uQ*NooxKCDF96F7R(3XObF&e@Qg$s|)H`1q zaz=QItRtd&u>MQGMl5mbzlb-JwNxfDSpOT7>k?;p{(}ZFhqp`9$%rA0h0O!1!-3cM_U| ze+d-rtpohq%bkSg;9mkod+PxI_HrkoIrx`A(cU`1zrEZ^Xb%1*P_(xW@NX}75}Jd5 z2^8(E1N_^|orLD#UjjvY>j3}uawnlV_?JM@-a5d)z1&G?4*n%jw6~6g^1mmi7eAXM z2}$I6L~JG0AAbK^68hxmz*q-an^TIiNqqlXDe|yTeFp!f9V5U$_!pVv(OApRNL;8m zhIYeY5C0E`+hVad@DKhG|4Y{f;Qyt|J&$4FAN(W!m#z)K|4Wy9#Q)Of9{huU#Q)N^ z0q_t0G5;@J8-V|pF87%K8FHTukf}_j%N;6l?2kt}GxNE^nBh2+nNq5OP?^C4$8J2v zv#wRa_^|Dat7-qL;C=q!coq-#CpTVM|AIZa?Wxe;vUW7JfBc$@&uThs zq|`+zrK;&aeHy3#&-oYg`lq%zje8uBK5a4U#B<8H68i7KxjPH|!~SFR$GH67X$SuA zc6N___&@wVf&Sk_I`Mm_qbBWp;;9XOzID`gxO@{OC+ z`W!aU>2Qed78F7M!M}m6IrgpF!u|_9EO-X|AMww?ClUXHyH-Jo0OKEj-=s?>_icD? z!{d9g#LxH9wZ$^U>fm$QITFbjy(g^zxR4pt9?$D^~&Ra3`AN<4q;r|*N zK>X80V8YL2{RiT|21nD=IlOgql0`fJ(fF$~pTYI7sX-NyBdmGu>R{@(#r01S3{6iF z&ba@nw>3W-%+H7an;KpbIru-~Ux~;6lK%dyNN5m0mBjH+ypu&LME{dF|0zP3_$lxY z{td?3{CTe5vB>;a+{lYm2>Tbce_l(>*OO6`$$g|Gu7CFPi=g>Z_kV-7n-qEUzevpD zQ{-bN_rH_EXHY%BKlm?-0PG+BPc{%VuGk`f$8JoI1KI1T)PKgWRoQsVb`1x&)To~J_fm^{xbOj}-M&_40% z{s`j#8r9`@@6E|Cc%B@UF?pUH*X_lpJI{BVaQ8RNey(5cw!HSyei(2K%1Gf)_0&_! zUvE{fQ$etOukX3E-;~q37TE_TY7FGX8m$lgdN;c@9Rt73%t1E$OGuM8=)UC~uAo%e(8K{rqmsj4o-kP3Q=i246 z?y->fId57I_c^Ce&X&$O>sh?;*RDqhTAt)1=}+Yy9_H{z`KI{!m5;=qgQwS)bn}`g wMzvq3AI{c`Kj)mu)~I}{i{e8rrRg5*r@oj?mE}rZI+RVk@8q+LBhTCa0|tTeyZ`_I literal 0 HcmV?d00001 diff --git a/docs/source/_static/js/custom.js b/docs/source/_static/js/custom.js new file mode 100644 index 0000000000..2a0a4ec4b0 --- /dev/null +++ b/docs/source/_static/js/custom.js @@ -0,0 +1,108 @@ +// Based on https://github.com/huggingface/transformers/blob/master/docs/source/_static/js/custom.js + + +// These two things need to be updated at each release for the version selector. +// Last stable version +const stableVersion = "v0.5.0" +// Dictionary doc folder to label. The last stable version should have an empty key. 
+const versionMapping = { + "latest": "latest", + "": "v0.5.0 (stable)", + "v0.4.1": "v0.4.1", + "v0.4.0": "v0.4.0", + "v0.3.1": "v0.3.1", + "v0.3.0": "v0.3.0", + "v0.2.1": "v0.2.1", + "v0.2.0": "v0.2.0", + "v0.1.1": "v0.1.1", + "v0.1.0": "v0.1.0", +} + +function addGithubButton() { + const div = ` + + `; + document.querySelector(".wy-side-nav-search .icon-home").insertAdjacentHTML('afterend', div); +} + +function addVersionControl() { + // To grab the version currently in view, we parse the url + const parts = location.toString().split('#')[0].split('/'); + let versionIndex = parts.length - 2; + // Index page may not have a last part with filename.html so we need to go up + if (parts[parts.length - 1] != "" && ! parts[parts.length - 1].match(/\.html$|^search.html?/)) { + versionIndex = parts.length - 1; + } + const version = parts[versionIndex]; + + // Menu with all the links, + const versionMenu = document.createElement("div"); + + const htmlLines = []; + for (const [key, value] of Object.entries(versionMapping)) { + let baseUrlIndex = (version == "doctr") ? versionIndex + 1: versionIndex; + var urlParts = parts.slice(0, baseUrlIndex); + if (key != "") { + urlParts = urlParts.concat([key]); + } + urlParts = urlParts.concat(parts.slice(versionIndex+1)); + htmlLines.push(`${value}`); + } + + versionMenu.classList.add("version-dropdown"); + versionMenu.innerHTML = htmlLines.join('\n'); + + // Button for version selection + const versionButton = document.createElement("div"); + versionButton.classList.add("version-button"); + let label = (version == "doctr") ? stableVersion : version + versionButton.innerText = label.concat(" ▼"); + + // Toggle the menu when we click on the button + versionButton.addEventListener("click", () => { + versionMenu.classList.toggle("version-show"); + }); + + // Hide the menu when we click elsewhere + window.addEventListener("click", (event) => { + if (event.target != versionButton){ + versionMenu.classList.remove('version-show'); + } + }); + + // Container + const div = document.createElement("div"); + div.appendChild(versionButton); + div.appendChild(versionMenu); + div.style.paddingTop = '5px'; + div.style.paddingBottom = '5px'; + div.style.display = 'block'; + div.style.textAlign = 'center'; + + const scrollDiv = document.querySelector(".wy-side-nav-search"); + scrollDiv.insertBefore(div, scrollDiv.children[1]); +} + +/*! 
+ * github-buttons v2.2.10 + * (c) 2019 なつき + * @license BSD-2-Clause + */ +/** + * modified to run programmatically + */ +function parseGithubButtons (){"use strict";var e=window.document,t=e.location,o=window.encodeURIComponent,r=window.decodeURIComponent,n=window.Math,a=window.HTMLElement,i=window.XMLHttpRequest,l="https://unpkg.com/github-buttons@2.2.10/dist/buttons.html",c=i&&i.prototype&&"withCredentials"in i.prototype,d=c&&a&&a.prototype.attachShadow&&!a.prototype.attachShadow.prototype,s=function(e,t,o){e.addEventListener?e.addEventListener(t,o):e.attachEvent("on"+t,o)},u=function(e,t,o){e.removeEventListener?e.removeEventListener(t,o):e.detachEvent("on"+t,o)},h=function(e,t,o){var r=function(n){return u(e,t,r),o(n)};s(e,t,r)},f=function(e,t,o){var r=function(n){if(t.test(e.readyState))return u(e,"readystatechange",r),o(n)};s(e,"readystatechange",r)},p=function(e){return function(t,o,r){var n=e.createElement(t);if(o)for(var a in o){var i=o[a];null!=i&&(null!=n[a]?n[a]=i:n.setAttribute(a,i))}if(r)for(var l=0,c=r.length;l'},eye:{width:16,height:16,path:''},star:{width:14,height:16,path:''},"repo-forked":{width:10,height:16,path:''},"issue-opened":{width:14,height:16,path:''},"cloud-download":{width:16,height:16,path:''}},w={},x=function(e,t,o){var r=p(e.ownerDocument),n=e.appendChild(r("style",{type:"text/css"}));n.styleSheet?n.styleSheet.cssText=m:n.appendChild(e.ownerDocument.createTextNode(m));var a,l,d=r("a",{className:"btn",href:t.href,target:"_blank",innerHTML:(a=t["data-icon"],l=/^large$/i.test(t["data-size"])?16:14,a=(""+a).toLowerCase().replace(/^octicon-/,""),{}.hasOwnProperty.call(v,a)||(a="mark-github"),'"),"aria-label":t["aria-label"]||void 0},[" ",r("span",{},[t["data-text"]||""])]);/\.github\.com$/.test("."+d.hostname)?/^https?:\/\/((gist\.)?github\.com\/[^\/?#]+\/[^\/?#]+\/archive\/|github\.com\/[^\/?#]+\/[^\/?#]+\/releases\/download\/|codeload\.github\.com\/)/.test(d.href)&&(d.target="_top"):(d.href="#",d.target="_self");var u,h,g,x,y=e.appendChild(r("div",{className:"widget"+(/^large$/i.test(t["data-size"])?" 
lg":"")},[d]));/^(true|1)$/i.test(t["data-show-count"])&&"github.com"===d.hostname&&(u=d.pathname.replace(/^(?!\/)/,"/").match(/^\/([^\/?#]+)(?:\/([^\/?#]+)(?:\/(?:(subscription)|(fork)|(issues)|([^\/?#]+)))?)?(?:[\/?#]|$)/))&&!u[6]?(u[2]?(h="/repos/"+u[1]+"/"+u[2],u[3]?(x="subscribers_count",g="watchers"):u[4]?(x="forks_count",g="network"):u[5]?(x="open_issues_count",g="issues"):(x="stargazers_count",g="stargazers")):(h="/users/"+u[1],g=x="followers"),function(e,t){var o=w[e]||(w[e]=[]);if(!(o.push(t)>1)){var r=b(function(){for(delete w[e];t=o.shift();)t.apply(null,arguments)});if(c){var n=new i;s(n,"abort",r),s(n,"error",r),s(n,"load",function(){var e;try{e=JSON.parse(n.responseText)}catch(e){return void r(e)}r(200!==n.status,e)}),n.open("GET",e),n.send()}else{var a=this||window;a._=function(e){a._=null,r(200!==e.meta.status,e.data)};var l=p(a.document)("script",{async:!0,src:e+(/\?/.test(e)?"&":"?")+"callback=_"}),d=function(){a._&&a._({meta:{}})};s(l,"load",d),s(l,"error",d),l.readyState&&f(l,/de|m/,d),a.document.getElementsByTagName("head")[0].appendChild(l)}}}.call(this,"https://api.github.com"+h,function(e,t){if(!e){var n=t[x];y.appendChild(r("a",{className:"social-count",href:t.html_url+"/"+g,target:"_blank","aria-label":n+" "+x.replace(/_count$/,"").replace("_"," ").slice(0,n<2?-1:void 0)+" on GitHub"},[r("b"),r("i"),r("span",{},[(""+n).replace(/\B(?=(\d{3})+(?!\d))/g,",")])]))}o&&o(y)})):o&&o(y)},y=window.devicePixelRatio||1,C=function(e){return(y>1?n.ceil(n.round(e*y)/y*2)/2:n.ceil(e))||0},F=function(e,t){e.style.width=t[0]+"px",e.style.height=t[1]+"px"},k=function(t,r){if(null!=t&&null!=r)if(t.getAttribute&&(t=function(e){for(var t={href:e.href,title:e.title,"aria-label":e.getAttribute("aria-label")},o=["icon","text","size","show-count"],r=0,n=o.length;r`_ + +v0.4.1 (2021-11-22) +------------------- +Release note: `v0.4.1 `_ + +v0.4.0 (2021-10-01) +------------------- +Release note: `v0.4.0 `_ + +v0.3.1 (2021-08-27) +------------------- +Release note: `v0.3.1 `_ + +v0.3.0 (2021-07-02) +------------------- +Release note: `v0.3.0 `_ + +v0.2.1 (2021-05-28) +------------------- +Release note: `v0.2.1 `_ + +v0.2.0 (2021-05-11) +------------------- +Release note: `v0.2.0 `_ + +v0.1.1 (2021-03-18) +------------------- +Release note: `v0.1.1 `_ + +v0.1.0 (2021-03-05) +------------------- +Release note: `v0.1.0 `_ diff --git a/docs/source/conf.py b/docs/source/conf.py new file mode 100644 index 0000000000..c8110bc5da --- /dev/null +++ b/docs/source/conf.py @@ -0,0 +1,101 @@ +# Configuration file for the Sphinx documentation builder. +# +# This file only contains a selection of the most common options. For a full +# list see the documentation: +# https://www.sphinx-doc.org/en/master/usage/configuration.html + +# -- Path setup -------------------------------------------------------------- + +# If extensions (or modules to document with autodoc) are in another directory, +# add these directories to sys.path here. If the directory is relative to the +# documentation root, use os.path.abspath to make it absolute, like shown here. 
+# +import os +import sys +from datetime import datetime + +import sphinx_rtd_theme + +sys.path.insert(0, os.path.abspath('../..')) +import doctr + +# -- Project information ----------------------------------------------------- + +master_doc = 'index' +project = 'docTR' +_copyright_str = f"-{datetime.now().year}" if datetime.now().year > 2021 else "" +copyright = f"2021{_copyright_str}, Mindee" +author = 'François-Guillaume Fernandez, Charles Gaillard' + +# The full version, including alpha/beta/rc tags +version = doctr.__version__ +release = doctr.__version__ + '-git' + +# -- General configuration --------------------------------------------------- + +# Add any Sphinx extension module names here, as strings. They can be +# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom +# ones. +extensions = [ + 'sphinx.ext.autodoc', + 'sphinx.ext.napoleon', + 'sphinx.ext.viewcode', + 'sphinx.ext.coverage', + 'sphinx.ext.mathjax', + 'sphinx.ext.autosectionlabel', + 'sphinxemoji.sphinxemoji', # cf. https://sphinxemojicodes.readthedocs.io/en/stable/ + 'sphinx_copybutton', + 'recommonmark', + 'sphinx_markdown_tables', +] + +napoleon_use_ivar = True + +# Add any paths that contain templates here, relative to this directory. +templates_path = ['_templates'] + +# List of patterns, relative to source directory, that match files and +# directories to ignore when looking for source files. +# This pattern also affects html_static_path and html_extra_path. +exclude_patterns = [u'_build', 'Thumbs.db', '.DS_Store', 'notebooks/*.rst'] + + +# The name of the Pygments (syntax highlighting) style to use. +pygments_style = 'sphinx' +highlight_language = 'python3' + +# -- Options for HTML output ------------------------------------------------- + +# The theme to use for HTML and HTML Help pages. See the documentation for +# a list of builtin themes. +# +html_theme = 'sphinx_rtd_theme' +html_theme_path = [sphinx_rtd_theme.get_html_theme_path()] + +# Theme options are theme-specific and customize the look and feel of a theme +# further. For a list of options available for each theme, see the +# documentation. +# +html_theme_options = { + 'collapse_navigation': False, + 'display_version': False, + 'logo_only': False, + 'analytics_id': 'G-40DVRMX8T4', +} + +html_logo = '_static/images/Logo-docTR-white.png' +html_favicon = '_static/images/favicon.ico' + + + +# Add any paths that contain custom static files (such as style sheets) here, +# relative to this directory. They are copied after the builtin static files, +# so a file named "default.css" will overwrite the builtin "default.css". +html_static_path = ['_static'] + +# A list of files that should not be packed into the epub file. +epub_exclude_files = ['search.html'] + +def setup(app): + app.add_css_file('css/mindee.css') + app.add_js_file('js/custom.js') diff --git a/docs/source/datasets.rst b/docs/source/datasets.rst new file mode 100644 index 0000000000..e40b1c506a --- /dev/null +++ b/docs/source/datasets.rst @@ -0,0 +1,104 @@ +doctr.datasets +============== + +.. currentmodule:: doctr.datasets + +Whether it is for training or for evaluation, having predefined objects to access datasets in your prefered framework +can be a significant save of time. + + +.. _datasets: + +Available Datasets +------------------ +Here are all datasets that are available through docTR: + + +Public datasets +^^^^^^^^^^^^^^^ + +.. autoclass:: FUNSD +.. autoclass:: SROIE +.. autoclass:: CORD +.. autoclass:: IIIT5K +.. autoclass:: SVT +.. autoclass:: SVHN +.. 
autoclass:: SynthText +.. autoclass:: IC03 +.. autoclass:: IC13 +.. autoclass:: IMGUR5K + +docTR synthetic datasets +^^^^^^^^^^^^^^^^^^^^^^^^ + +.. autoclass:: DocArtefacts +.. autoclass:: CharacterGenerator +.. autoclass:: WordGenerator + +docTR private datasets +^^^^^^^^^^^^^^^^^^^^^^ + +Since many documents include sensitive / personal information, we are not able to share all the data that has been used for this project. However, we provide some guidance on how to format your own dataset into the same format so that you can use all docTR tools all the same. + +.. autoclass:: DetectionDataset +.. autoclass:: RecognitionDataset +.. autoclass:: OCRDataset + + +Data Loading +------------ +Each dataset has its specific way to load a sample, but handling batch aggregation and the underlying iterator is a task deferred to another object in docTR. + +.. autoclass:: doctr.datasets.loader.DataLoader + + +.. _vocabs: + +Supported Vocabs +---------------- + +Since textual content has to be encoded properly for models to interpret them efficiently, docTR supports multiple sets +of vocabs. + +.. list-table:: docTR Vocabs + :widths: 20 5 50 + :header-rows: 1 + + * - Name + - size + - characters + * - digits + - 10 + - 0123456789 + * - ascii_letters + - 52 + - abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ + * - punctuation + - 32 + - !"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~ + * - currency + - 5 + - £€¥¢฿ + * - latin + - 94 + - 0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&'()*+,-./:;<=>?@[\]^_`{|}~ + * - english + - 100 + - 0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&'()*+,-./:;<=>?@[\]^_`{|}~°£€¥¢฿ + * - legacy_french + - 123 + - 0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&\'()*+,-./:;<=>?@[\\]^_`{|}~°àâéèêëîïôùûçÀÂÉÈËÎÏÔÙÛÇ£€¥¢฿ + * - french + - 126 + - 0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&'()*+,-./:;<=>?@[\]^_`{|}~°£€¥¢฿àâéèêëîïôùûüçÀÂÉÈÊËÎÏÔÙÛÜÇ + * - portuguese + - 131 + - 0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&'()*+,-./:;<=>?@[\]^_`{|}~°£€¥¢฿áàâãéêëíïóôõúüçÁÀÂÃÉËÍÏÓÔÕÚÜÇ¡¿ + * - spanish + - 116 + - 0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&'()*+,-./:;<=>?@[\]^_`{|}~°£€¥¢฿áéíóúüñÁÉÍÓÚÜÑ¡¿ + * - german + - 108 + - 0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!"#$%&'()*+,-./:;<=>?@[\]^_`{|}~°£€¥¢฿äöüßÄÖÜẞ + +.. autofunction:: encode_sequences diff --git a/docs/source/index.rst b/docs/source/index.rst new file mode 100644 index 0000000000..2be367403c --- /dev/null +++ b/docs/source/index.rst @@ -0,0 +1,91 @@ +docTR: Document Text Recognition +================================ + +State-of-the-art Optical Character Recognition made seamless & accessible to anyone, powered by TensorFlow 2 & PyTorch + +.. image:: https://github.com/mindee/doctr/releases/download/v0.2.0/ocr.png + :align: center + + +DocTR provides an easy and powerful way to extract valuable information from your documents: + +* |:receipt:| **for automation**: seemlessly process documents for Natural Language Understanding tasks: we provide OCR predictors to parse textual information (localize and identify each word) from your documents. +* |:woman_scientist:| **for research**: quickly compare your own architectures speed & performances with state-of-art models on public datasets. 
+ + +Main Features +------------- + +* |:robot:| Robust 2-stage (detection + recognition) OCR predictors with pretrained parameters +* |:zap:| User-friendly, 3 lines of code to load a document and extract text with a predictor +* |:rocket:| State-of-the-art performances on public document datasets, comparable with GoogleVision/AWS Textract +* |:zap:| Optimized for inference speed on both CPU & GPU +* |:bird:| Light package, minimal dependencies +* |:tools:| Actively maintained by Mindee +* |:factory:| Easy integration (available templates for browser demo & API deployment) + + +.. toctree:: + :maxdepth: 2 + :caption: Getting started + :hidden: + + installing + notebooks + + +Model zoo +^^^^^^^^^ + +Text detection models +""""""""""""""""""""" + * DBNet from `"Real-time Scene Text Detection with Differentiable Binarization" `_ + * LinkNet from `"LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation" `_ + +Text recognition models +""""""""""""""""""""""" + * SAR from `"Show, Attend and Read: A Simple and Strong Baseline for Irregular Text Recognition" `_ + * CRNN from `"An End-to-End Trainable Neural Network for Image-based Sequence Recognition and Its Application to Scene Text Recognition" `_ + * MASTER from `"MASTER: Multi-Aspect Non-local Network for Scene Text Recognition" `_ + + +Supported datasets +^^^^^^^^^^^^^^^^^^ + * FUNSD from `"FUNSD: A Dataset for Form Understanding in Noisy Scanned Documents" `_. + * CORD from `"CORD: A Consolidated Receipt Dataset forPost-OCR Parsing" `_. + * SROIE from `ICDAR 2019 `_. + * IIIT-5k from `CVIT `_. + * Street View Text from `"End-to-End Scene Text Recognition" `_. + * SynthText from `Visual Geometry Group `_. + * SVHN from `"Reading Digits in Natural Images with Unsupervised Feature Learning" `_. + * IC03 from `ICDAR 2003 `_. + * IC13 from `ICDAR 2013 `_. + + +.. toctree:: + :maxdepth: 2 + :caption: Using docTR + :hidden: + + using_models + using_model_export + + +.. toctree:: + :maxdepth: 2 + :caption: Package Reference + :hidden: + + datasets + io + models + transforms + utils + + +.. toctree:: + :maxdepth: 2 + :caption: Notes + :hidden: + + changelog diff --git a/docs/source/installing.rst b/docs/source/installing.rst new file mode 100644 index 0000000000..8197df660d --- /dev/null +++ b/docs/source/installing.rst @@ -0,0 +1,66 @@ + +************ +Installation +************ + +This library requires `Python `_ 3.6 or higher. + + +Prerequisites +============= + +Whichever OS you are running, you will need to install at least TensorFlow or PyTorch. You can refer to their corresponding installation pages to do so: + +* `TensorFlow 2 `_ +* `PyTorch `_ + +If you are running another OS than Linux, you will need a few extra dependencies. + +For MacOS users, you can install them using `Homebrew `_ as follows: + +.. code:: shell + + brew install cairo pango gdk-pixbuf libffi + +For Windows users, those dependencies are included in GTK. You can find the latest installer over `here `_. + + +Via Python Package +================== + +Install the last stable release of the package using `pip `_: + +.. code:: bash + + pip install python-doctr + + +We strive towards reducing framework-specific dependencies to a minimum, but some necessary features are developed by third-parties for specific frameworks. To avoid missing some dependencies for a specific framework, you can install specific builds as follows: + +.. 
code:: bash + + # for TensorFlow + pip install "python-doctr[tf]" + # for PyTorch + pip install "python-doctr[torch]" + + +Via Git +======= + +Install the library in developper mode: + +.. code:: bash + + git clone https://github.com/mindee/doctr.git + pip install -e doctr/. + +Again, for framework-specific builds: + +.. code:: bash + + git clone https://github.com/mindee/doctr.git + # for TensorFlow + pip install -e doctr/.[tf] + # for PyTorch + pip install -e doctr/.[torch] diff --git a/docs/source/io.rst b/docs/source/io.rst new file mode 100644 index 0000000000..8fa887e9f9 --- /dev/null +++ b/docs/source/io.rst @@ -0,0 +1,94 @@ +doctr.io +======== + + +.. currentmodule:: doctr.io + +The io module enables users to easily access content from documents and export analysis +results to structured formats. + +.. _document_structure: + +Document structure +------------------ + +Structural organization of the documents. + +Word +^^^^ +A Word is an uninterrupted sequence of characters. + +.. autoclass:: Word + +Line +^^^^ +A Line is a collection of Words aligned spatially and meant to be read together (on a two-column page, on the same horizontal, we will consider that there are two Lines). + +.. autoclass:: Line + +Artefact +^^^^^^^^ + +An Artefact is a non-textual element (e.g. QR code, picture, chart, signature, logo, etc.). + +.. autoclass:: Artefact + +Block +^^^^^ +A Block is a collection of Lines (e.g. an address written on several lines) and Artefacts (e.g. a graph with its title underneath). + +.. autoclass:: Block + +Page +^^^^ + +A Page is a collection of Blocks that were on the same physical page. + +.. autoclass:: Page + + .. automethod:: show + + +Document +^^^^^^^^ + +A Document is a collection of Pages. + +.. autoclass:: Document + + .. automethod:: show + + +File reading +------------ + +High-performance file reading and conversion to processable structured data. + +.. autofunction:: read_pdf + +.. autofunction:: read_img_as_numpy + +.. autofunction:: read_img_as_tensor + +.. autofunction:: decode_img_as_tensor + +.. autofunction:: read_html + + +.. autoclass:: DocumentFile + + .. automethod:: from_pdf + + .. automethod:: from_url + + .. automethod:: from_images + +.. autoclass:: PDF + + .. automethod:: as_images + + .. automethod:: get_words + + .. automethod:: get_lines + + .. automethod:: get_artefacts diff --git a/docs/source/models.rst b/docs/source/models.rst new file mode 100644 index 0000000000..d4f36df9bb --- /dev/null +++ b/docs/source/models.rst @@ -0,0 +1,62 @@ +doctr.models +============ + +.. currentmodule:: doctr.models + + +doctr.models.classification +---------------------- + +.. autofunction:: doctr.models.classification.vgg16_bn_r + +.. autofunction:: doctr.models.classification.resnet18 + +.. autofunction:: doctr.models.classification.resnet31 + +.. autofunction:: doctr.models.classification.mobilenet_v3_small + +.. autofunction:: doctr.models.classification.mobilenet_v3_large + +.. autofunction:: doctr.models.classification.mobilenet_v3_small_r + +.. autofunction:: doctr.models.classification.mobilenet_v3_large_r + +.. autofunction:: doctr.models.classification.mobilenet_v3_small_orientation + +.. autofunction:: doctr.models.classification.magc_resnet31 + +.. autofunction:: doctr.models.classification.crop_orientation_predictor + + +doctr.models.detection +---------------------- + +.. autofunction:: doctr.models.detection.linknet_resnet18 + +.. autofunction:: doctr.models.detection.db_resnet50 + +.. 
autofunction:: doctr.models.detection.db_mobilenet_v3_large + +.. autofunction:: doctr.models.detection.detection_predictor + + +doctr.models.recognition +------------------------ + +.. autofunction:: doctr.models.recognition.crnn_vgg16_bn + +.. autofunction:: doctr.models.recognition.crnn_mobilenet_v3_small + +.. autofunction:: doctr.models.recognition.crnn_mobilenet_v3_large + +.. autofunction:: doctr.models.recognition.sar_resnet31 + +.. autofunction:: doctr.models.recognition.master + +.. autofunction:: doctr.models.recognition.recognition_predictor + + +doctr.models.zoo +---------------- + +.. autofunction:: doctr.models.ocr_predictor diff --git a/docs/source/notebooks.md b/docs/source/notebooks.md new file mode 120000 index 0000000000..1ffa21de25 --- /dev/null +++ b/docs/source/notebooks.md @@ -0,0 +1 @@ +../../notebooks/README.md \ No newline at end of file diff --git a/docs/source/transforms.rst b/docs/source/transforms.rst new file mode 100644 index 0000000000..1e5b1d8b93 --- /dev/null +++ b/docs/source/transforms.rst @@ -0,0 +1,38 @@ +doctr.transforms +================ + +.. currentmodule:: doctr.transforms + +Data transformations are part of both training and inference procedure. Drawing inspiration from the design of `torchvision `_, we express transformations as composable modules. + + +Supported transformations +------------------------- +Here are all transformations that are available through docTR: + +.. autoclass:: Resize +.. autoclass:: Normalize +.. autoclass:: LambdaTransformation +.. autoclass:: ToGray +.. autoclass:: ColorInversion +.. autoclass:: RandomBrightness +.. autoclass:: RandomContrast +.. autoclass:: RandomSaturation +.. autoclass:: RandomHue +.. autoclass:: RandomGamma +.. autoclass:: RandomJpegQuality +.. autoclass:: RandomRotate +.. autoclass:: RandomCrop +.. autoclass:: GaussianBlur +.. autoclass:: ChannelShuffle +.. autoclass:: GaussianNoise +.. autoclass:: RandomHorizontalFlip + + +Composing transformations +--------------------------------------------- +It is common to require several transformations to be performed consecutively. + +.. autoclass:: Compose +.. autoclass:: OneOf +.. autoclass:: RandomApply diff --git a/docs/source/using_model_export.rst b/docs/source/using_model_export.rst new file mode 100644 index 0000000000..992f4e9866 --- /dev/null +++ b/docs/source/using_model_export.rst @@ -0,0 +1,71 @@ +Preparing your model for inference +================================== + +A well-trained model is a good achievement but you might want to tune a few things to make it production-ready! + +.. currentmodule:: doctr.models.export + + +Model compression +----------------- + +This section is meant to help you perform inference with compressed versions of your model. + + +TensorFlow Lite +^^^^^^^^^^^^^^^ + +TensorFlow provides utilities packaged as TensorFlow Lite to take resource constraints into account. 
You can easily convert any Keras model into a serialized TFLite version as follows:
+
+    >>> import tensorflow as tf
+    >>> from tensorflow.keras import Sequential
+    >>> from doctr.models import conv_sequence
+    >>> model = Sequential(conv_sequence(32, 'relu', True, kernel_size=3, input_shape=(224, 224, 3)))
+    >>> converter = tf.lite.TFLiteConverter.from_keras_model(model)
+    >>> serialized_model = converter.convert()
+
+Half-precision
+^^^^^^^^^^^^^^
+
+If you want to convert your model to half-precision, reuse the same TFLite converter:
+
+    >>> converter.optimizations = [tf.lite.Optimize.DEFAULT]
+    >>> converter.target_spec.supported_types = [tf.float16]
+    >>> serialized_model = converter.convert()
+
+
+Post-training quantization
+^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Finally, if you wish to quantize the model with your TFLite converter:
+
+    >>> import numpy as np
+    >>> input_shape = (224, 224, 3)  # same input shape as the model above
+    >>> converter.optimizations = [tf.lite.Optimize.DEFAULT]
+    >>> # Float fallback for operators that do not have an integer implementation
+    >>> def representative_dataset():
+    >>>     for _ in range(100): yield [np.random.rand(1, *input_shape).astype(np.float32)]
+    >>> converter.representative_dataset = representative_dataset
+    >>> converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
+    >>> converter.inference_input_type = tf.int8
+    >>> converter.inference_output_type = tf.int8
+    >>> serialized_model = converter.convert()
+
+
+Using SavedModel
+----------------
+
+Additionally, models in docTR inherit TensorFlow 2 model properties and can be exported to
+`SavedModel `_ format as follows:
+
+
+    >>> import tensorflow as tf
+    >>> from doctr.models import db_resnet50
+    >>> model = db_resnet50(pretrained=True)
+    >>> input_t = tf.random.uniform(shape=[1, 1024, 1024, 3], maxval=1, dtype=tf.float32)
+    >>> _ = model(input_t, training=False)
+    >>> tf.saved_model.save(model, 'path/to/your/folder/db_resnet50/')
+
+And loaded just as easily:
+
+
+    >>> import tensorflow as tf
+    >>> model = tf.saved_model.load('path/to/your/folder/db_resnet50/')
diff --git a/docs/source/using_models.rst b/docs/source/using_models.rst
new file mode 100644
index 0000000000..1c0752463f
--- /dev/null
+++ b/docs/source/using_models.rst
@@ -0,0 +1,329 @@
+Choosing the right model
+========================
+
+The full Optical Character Recognition task can be seen as two consecutive tasks: text detection and text recognition.
+Whether they are performed at once or separately, each task corresponds to a type of deep learning architecture.
+
+.. currentmodule:: doctr.models
+
+For a given task, docTR provides a Predictor, which is composed of two components:
+
+* PreProcessor: a module in charge of making inputs directly usable by the deep learning model.
+* Model: a deep learning model, implemented with all supported deep learning backends (TensorFlow & PyTorch) along with its specific post-processor to make outputs structured and reusable.
+
+
+Text Detection
+--------------
+
+The task consists of localizing textual elements in a given image.
+While those text elements can represent many things, in docTR, we will consider uninterrupted character sequences (words). Additionally, the localization can take several forms: from straight bounding boxes (delimited by the 2D coordinates of the top-left and bottom-right corners), to polygons, or binary segmentation (flagging which pixels belong to this element, and which don't).
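+
+As a rough illustration of the straight-box format (a minimal sketch with made-up numbers, not tied to any particular docTR API), such a box can be stored as relative (xmin, ymin, xmax, ymax) coordinates and scaled back to pixel positions:
+
+    >>> import numpy as np
+    >>> # hypothetical relative box (xmin, ymin, xmax, ymax), values in [0, 1]
+    >>> rel_box = np.array([0.21, 0.32, 0.44, 0.38], dtype=np.float32)
+    >>> h, w = 800, 600  # page height and width in pixels
+    >>> abs_box = (rel_box * np.array([w, h, w, h])).round().astype(int)  # back to pixel coordinates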
+ +Available architectures +^^^^^^^^^^^^^^^^^^^^^^^ + +The following architectures are currently supported: + +* `linknet_resnet18 `_ +* `db_resnet50 `_ +* `db_mobilenet_v3_large `_ + +For a comprehensive comparison, we have compiled a detailed benchmark on publicly available datasets: + + ++------------------------------------------------------------------+----------------------------+----------------------------+---------+ +| | FUNSD | CORD | | ++=================================+=================+==============+============+===============+============+===============+=========+ +| **Architecture** | **Input shape** | **# params** | **Recall** | **Precision** | **Recall** | **Precision** | **FPS** | ++---------------------------------+-----------------+--------------+------------+---------------+------------+---------------+---------+ +| db_resnet50 | (1024, 1024, 3) | 25.2 M | 82.14 | 87.64 | 92.49 | 89.66 | 2.1 | ++---------------------------------+-----------------+--------------+------------+---------------+------------+---------------+---------+ +| db_mobilenet_v3_large | (1024, 1024, 3) | 4.2 M | 79.35 | 84.03 | 81.14 | 66.85 | | ++---------------------------------+-----------------+--------------+------------+---------------+------------+---------------+---------+ + + +All text detection models above have been evaluated using both the training and evaluation sets of FUNSD and CORD (cf. :ref:`datasets`). +Explanations about the metrics being used are available in :ref:`metrics`. + +*Disclaimer: both FUNSD subsets combined have 199 pages which might not be representative enough of the model capabilities* + +FPS (Frames per second) is computed after a warmup phase of 100 tensors (where the batch size is 1), by measuring the average number of processed tensors per second over 1000 samples. Those results were obtained on a `c5.x12large `_ AWS instance (CPU Xeon Platinum 8275L). + + +Detection predictors +^^^^^^^^^^^^^^^^^^^^ + +`detection_predictor `_ wraps your detection model to make it easily useable with your favorite deep learning framework seamlessly. + + >>> import numpy as np + >>> from doctr.models import detection_predictor + >>> predictor = detection_predictor('db_resnet50') + >>> dummy_img = (255 * np.random.rand(800, 600, 3)).astype(np.uint8) + >>> out = model([dummy_img]) + + +Text Recognition +---------------- + +The task consists of transcribing the character sequence in a given image. + + +Available architectures +^^^^^^^^^^^^^^^^^^^^^^^ + +The following architectures are currently supported: + +* `crnn_vgg16_bn `_ +* `crnn_mobilenet_v3_small `_ +* `crnn_mobilenet_v3_large `_ +* `sar_resnet31 `_ +* `master `_ + + +For a comprehensive comparison, we have compiled a detailed benchmark on publicly available datasets: + + +.. list-table:: Text recognition model zoo + :header-rows: 1 + + * - Architecture + - Input shape + - # params + - FUNSD + - CORD + - FPS + * - crnn_vgg16_bn + - (32, 128, 3) + - 15.8M + - 87.18 + - 92.93 + - 12.8 + * - crnn_mobilenet_v3_small + - (32, 128, 3) + - 2.1M + - 86.21 + - 90.56 + - + * - crnn_mobilenet_v3_large + - (32, 128, 3) + - 4.5M + - 86.95 + - 92.03 + - + * - sar_resnet31 + - (32, 128, 3) + - 56.2M + - **87.70** + - **93.41** + - 2.7 + * - master + - (32, 128, 3) + - 67.7M + - 87.62 + - 93.27 + - + +All text recognition models above have been evaluated using both the training and evaluation sets of FUNSD and CORD (cf. :ref:`datasets`). +Explanations about the metric being used (exact match) are available in :ref:`metrics`. 
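+
+As a minimal illustration of what "exact match" means here (a sketch only; see :ref:`metrics` for the exact signatures), a transcription is counted as correct only if the predicted string equals the ground truth:
+
+    >>> from doctr.utils.metrics import TextMatch
+    >>> metric = TextMatch()
+    >>> metric.update(['Hello', 'world'], ['hello', 'world'])  # ground truths, then predictions
+    >>> metric.summary()  # ratio of exact matches (plus relaxed variants such as caseless matching)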
+ +While most of our recognition models were trained on our french vocab (cf. :ref:`vocabs`), you can easily access the vocab of any model as follows: + + >>> from doctr.models import recognition_predictor + >>> predictor = recognition_predictor('crnn_vgg16_bn') + >>> print(predictor.model.cfg['vocab']) + + +*Disclaimer: both FUNSD subsets combine have 30595 word-level crops which might not be representative enough of the model capabilities* + +FPS (Frames per second) is computed after a warmup phase of 100 tensors (where the batch size is 1), by measuring the average number of processed tensors per second over 1000 samples. Those results were obtained on a `c5.x12large `_ AWS instance (CPU Xeon Platinum 8275L). + + +Recognition predictors +^^^^^^^^^^^^^^^^^^^^^^ +`recognition_predictor `_ wraps your recognition model to make it easily useable with your favorite deep learning framework seamlessly. + + >>> import numpy as np + >>> from doctr.models import recognition_predictor + >>> predictor = recognition_predictor('crnn_vgg16_bn') + >>> dummy_img = (255 * np.random.rand(50, 150, 3)).astype(np.uint8) + >>> out = model([dummy_img]) + + +End-to-End OCR +-------------- + +The task consists of both localizing and transcribing textual elements in a given image. + +Available architectures +^^^^^^^^^^^^^^^^^^^^^^^ + +You can use any combination of detection and recognition models supporte by docTR. + +For a comprehensive comparison, we have compiled a detailed benchmark on publicly available datasets: + ++----------------------------------------+--------------------------------------+--------------------------------------+ +| | FUNSD | CORD | ++========================================+============+===============+=========+============+===============+=========+ +| **Architecture** | **Recall** | **Precision** | **FPS** | **Recall** | **Precision** | **FPS** | ++----------------------------------------+------------+---------------+---------+------------+---------------+---------+ +| db_resnet50 + crnn_vgg16_bn | 71.25 | 76.02 | 0.85 | 84.00 | 81.42 | 1.6 | ++----------------------------------------+------------+---------------+---------+------------+---------------+---------+ +| db_resnet50 + master | 71.03 | 76.06 | | 84.49 | 81.94 | | ++----------------------------------------+------------+---------------+---------+------------+---------------+---------+ +| db_resnet50 + sar_resnet31 | 71.25 | 76.29 | 0.27 | 84.50 | **81.96** | 0.83 | ++----------------------------------------+------------+---------------+---------+------------+---------------+---------+ +| db_resnet50 + crnn_mobilenet_v3_small | 69.85 | 74.80 | | 80.85 | 78.42 | 0.83 | ++----------------------------------------+------------+---------------+---------+------------+---------------+---------+ +| db_resnet50 + crnn_mobilenet_v3_large | 70.57 | 75.57 | | 82.57 | 80.08 | 0.83 | ++----------------------------------------+------------+---------------+---------+------------+---------------+---------+ +| db_mobilenet_v3_large + crnn_vgg16_bn | 67.73 | 71.73 | | 71.65 | 59.03 | | ++----------------------------------------+------------+---------------+---------+------------+---------------+---------+ +| Gvision text detection | 59.50 | 62.50 | | 75.30 | 70.00 | | ++----------------------------------------+------------+---------------+---------+------------+---------------+---------+ +| Gvision doc. 
text detection | 64.00 | 53.30 | | 68.90 | 61.10 | | ++----------------------------------------+------------+---------------+---------+------------+---------------+---------+ +| AWS textract | **78.10** | **83.00** | | **87.50** | 66.00 | | ++----------------------------------------+------------+---------------+---------+------------+---------------+---------+ + +All OCR models above have been evaluated using both the training and evaluation sets of FUNSD and CORD (cf. :ref:`datasets`). +Explanations about the metrics being used are available in :ref:`metrics`. + +*Disclaimer: both FUNSD subsets combine have 199 pages which might not be representative enough of the model capabilities* + +FPS (Frames per second) is computed after a warmup phase of 100 tensors (where the batch size is 1), by measuring the average number of processed frames per second over 1000 samples. Those results were obtained on a `c5.x12large `_ AWS instance (CPU Xeon Platinum 8275L). + +Since you may be looking for specific use cases, we also performed this benchmark on private datasets with various document types below. Unfortunately, we are not able to share those at the moment since they contain sensitive information. + + ++----------------------------------------------+----------------------------+----------------------------+----------------------------+----------------------------+----------------------------+----------------------------+ +| | Receipts | Invoices | IDs | US Tax Forms | Resumes | Road Fines | ++==============================================+============+===============+============+===============+============+===============+============+===============+============+===============+============+===============+ +| **Architecture** | **Recall** | **Precision** | **Recall** | **Precision** | **Recall** | **Precision** | **Recall** | **Precision** | **Recall** | **Precision** | **Recall** | **Precision** | ++----------------------------------------------+------------+---------------+------------+---------------+------------+---------------+------------+---------------+------------+---------------+------------+---------------+ +| db_resnet50 + crnn_vgg16_bn (ours) | 78.70 | 81.12 | 65.80 | 70.70 | 50.25 | 51.78 | 79.08 | 92.83 | | | | | ++----------------------------------------------+------------+---------------+------------+---------------+------------+---------------+------------+---------------+------------+---------------+------------+---------------+ +| db_resnet50 + master (ours) | **79.00** | **81.42** | 65.57 | 69.86 | 51.34 | 52.90 | 78.86 | 92.57 | | | | | ++----------------------------------------------+------------+---------------+------------+---------------+------------+---------------+------------+---------------+------------+---------------+------------+---------------+ +| db_resnet50 + sar_resnet31 (ours) | 78.94 | 81.37 | 65.89 | **70.79** | **51.78** | **53.35** | 79.04 | 92.78 | | | | | ++----------------------------------------------+------------+---------------+------------+---------------+------------+---------------+------------+---------------+------------+---------------+------------+---------------+ +| db_resnet50 + crnn_mobilenet_v3_small (ours) | 76.81 | 79.15 | 64.89 | 69.61 | 45.03 | 46.38 | 78.96 | 92.11 | 85.91 | 87.20 | 84.85 | 85.86 | ++----------------------------------------------+------------+---------------+------------+---------------+------------+---------------+------------+---------------+------------+---------------+------------+---------------+ +| 
db_resnet50 + crnn_mobilenet_v3_large (ours) | 78.01 | 80.39 | 65.36 | 70.11 | 48.00 | 49.43 | 79.39 | 92.62 | 87.68 | 89.00 | 85.65 | 86.67 | ++----------------------------------------------+------------+---------------+------------+---------------+------------+---------------+------------+---------------+------------+---------------+------------+---------------+ +| db_mobilenet_v3_large + crnn_vgg16_bn (ours) | 78.36 | 74.93 | 63.04 | 68.41 | 39.36 | 41.75 | 72.14 | 89.97 | | | | | ++----------------------------------------------+------------+---------------+------------+---------------+------------+---------------+------------+---------------+------------+---------------+------------+---------------+ +| Gvision doc. text detection | 68.91 | 59.89 | 63.20 | 52.85 | 43.70 | 29.21 | 69.79 | 65.68 | | | | | ++----------------------------------------------+------------+---------------+------------+---------------+------------+---------------+------------+---------------+------------+---------------+------------+---------------+ +| AWS textract | 75.77 | 77.70 | **70.47** | 69.13 | 46.39 | 43.32 | **84.31** | **98.11** | | | | | ++----------------------------------------------+------------+---------------+------------+---------------+------------+---------------+------------+---------------+------------+---------------+------------+---------------+ + + +Two-stage approaches +^^^^^^^^^^^^^^^^^^^^ +Those architectures involve one stage of text detection, and one stage of text recognition. The text detection will be used to produces cropped images that will be passed into the text recognition block. Everything is wrapped up with `ocr_predictor `_. + + >>> import numpy as np + >>> from doctr.models import ocr_predictor + >>> model = ocr_predictor('db_resnet50', 'crnn_vgg16_bn', pretrained=True) + >>> input_page = (255 * np.random.rand(800, 600, 3)).astype(np.uint8) + >>> out = model([input_page]) + + +What should I do with the output? +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +The ocr_predictor returns a `Document` object with a nested structure (with `Page`, `Block`, `Line`, `Word`, `Artefact`). 
+To get a better understanding of our document model, check our :ref:`document_structure` section + +Here is a typical `Document` layout:: + + Document( + (pages): [Page( + dimensions=(340, 600) + (blocks): [Block( + (lines): [Line( + (words): [ + Word(value='No.', confidence=0.91), + Word(value='RECEIPT', confidence=0.99), + Word(value='DATE', confidence=0.96), + ] + )] + (artefacts): [] + )] + )] + ) + +You can also export them as a nested dict, more appropriate for JSON format:: + + json_output = result.export() + +For reference, here is the JSON export for the same `Document` as above:: + + { + 'pages': [ + { + 'page_idx': 0, + 'dimensions': (340, 600), + 'orientation': {'value': None, 'confidence': None}, + 'language': {'value': None, 'confidence': None}, + 'blocks': [ + { + 'geometry': ((0.1357421875, 0.0361328125), (0.8564453125, 0.8603515625)), + 'lines': [ + { + 'geometry': ((0.1357421875, 0.0361328125), (0.8564453125, 0.8603515625)), + 'words': [ + { + 'value': 'No.', + 'confidence': 0.914085328578949, + 'geometry': ((0.5478515625, 0.06640625), (0.5810546875, 0.0966796875)) + }, + { + 'value': 'RECEIPT', + 'confidence': 0.9949972033500671, + 'geometry': ((0.1357421875, 0.0361328125), (0.51171875, 0.1630859375)) + }, + { + 'value': 'DATE', + 'confidence': 0.9578408598899841, + 'geometry': ((0.1396484375, 0.3232421875), (0.185546875, 0.3515625)) + } + ] + } + ], + 'artefacts': [] + } + ] + } + ] + } + +To export the outpout as XML (hocr-format) you can use the `export_as_xml` method:: + + xml_output = result.export_as_xml() + for output in xml_output: + xml_bytes_string = output[0] + xml_element = output[1] + +For reference, here is a sample XML byte string output:: + + + + + docTR - hOCR + + + + + +
+     <body>
+       <div class="ocr_page" id="page_1" title="image; bbox ...; ppageno 0">
+         <div class="ocr_carea" id="block_1_1" title="bbox ...">
+           <p class="ocr_par" id="par_1_1" title="bbox ...">
+             <span class="ocr_line" id="line_1_1" title="bbox ...; baseline ...">
+               <span class="ocrx_word" id="word_1_1" title="bbox ...; x_wconf ...">Hello</span>
+               <span class="ocrx_word" id="word_1_2" title="bbox ...; x_wconf ...">XML</span>
+               <span class="ocrx_word" id="word_1_3" title="bbox ...; x_wconf ...">World</span>
+             </span>
+           </p>
+         </div>
+       </div>
+     </body>
+   </html>
+ + \ No newline at end of file diff --git a/docs/source/utils.rst b/docs/source/utils.rst new file mode 100644 index 0000000000..ac0b13d9df --- /dev/null +++ b/docs/source/utils.rst @@ -0,0 +1,46 @@ +doctr.utils +=========== + +This module regroups non-core features that are complementary to the rest of the package. + +.. currentmodule:: doctr.utils + + +Visualization +------------- +Easy-to-use functions to make sense of your model's predictions. + +.. currentmodule:: doctr.utils.visualization + +.. autofunction:: visualize_page + +.. autofunction:: synthesize_page + + +.. _metrics: + +Task evaluation +--------------- +Implementations of task-specific metrics to easily assess your model performances. + +.. currentmodule:: doctr.utils.metrics + +.. autoclass:: TextMatch + + .. automethod:: update + .. automethod:: summary + +.. autoclass:: LocalizationConfusion + + .. automethod:: update + .. automethod:: summary + +.. autoclass:: OCRMetric + + .. automethod:: update + .. automethod:: summary + +.. autoclass:: DetectionMetric + + .. automethod:: update + .. automethod:: summary diff --git a/doctr/__init__.py b/doctr/__init__.py new file mode 100644 index 0000000000..14390c4cd1 --- /dev/null +++ b/doctr/__init__.py @@ -0,0 +1,3 @@ +from . import datasets, io, models, transforms, utils +from .file_utils import is_tf_available, is_torch_available +from .version import __version__ # noqa: F401 diff --git a/doctr/datasets/__init__.py b/doctr/datasets/__init__.py new file mode 100644 index 0000000000..cd187271b1 --- /dev/null +++ b/doctr/datasets/__init__.py @@ -0,0 +1,22 @@ +from doctr.file_utils import is_tf_available + +from .generator import * +from .cord import * +from .detection import * +from .doc_artefacts import * +from .funsd import * +from .ic03 import * +from .ic13 import * +from .iiit5k import * +from .imgur5k import * +from .ocr import * +from .recognition import * +from .sroie import * +from .svhn import * +from .svt import * +from .synthtext import * +from .utils import * +from .vocabs import * + +if is_tf_available(): + from .loader import * diff --git a/doctr/datasets/cord.py b/doctr/datasets/cord.py new file mode 100644 index 0000000000..6740913d63 --- /dev/null +++ b/doctr/datasets/cord.py @@ -0,0 +1,90 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + +import json +import os +from pathlib import Path +from typing import Any, Dict, List, Tuple + +import numpy as np + +from .datasets import VisionDataset +from .utils import convert_target_to_relative + +__all__ = ['CORD'] + + +class CORD(VisionDataset): + """CORD dataset from `"CORD: A Consolidated Receipt Dataset forPost-OCR Parsing" + `_. + + Example:: + >>> from doctr.datasets import CORD + >>> train_set = CORD(train=True, download=True) + >>> img, target = train_set[0] + + Args: + train: whether the subset should be the training one + use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones) + **kwargs: keyword arguments from `VisionDataset`. 
+ """ + TRAIN = ('https://github.com/mindee/doctr/releases/download/v0.1.1/cord_train.zip', + '45f9dc77f126490f3e52d7cb4f70ef3c57e649ea86d19d862a2757c9c455d7f8') + + TEST = ('https://github.com/mindee/doctr/releases/download/v0.1.1/cord_test.zip', + '8c895e3d6f7e1161c5b7245e3723ce15c04d84be89eaa6093949b75a66fb3c58') + + def __init__( + self, + train: bool = True, + use_polygons: bool = False, + **kwargs: Any, + ) -> None: + + url, sha256 = self.TRAIN if train else self.TEST + super().__init__(url, None, sha256, True, pre_transforms=convert_target_to_relative, **kwargs) + + # # List images + tmp_root = os.path.join(self.root, 'image') + self.data: List[Tuple[str, Dict[str, Any]]] = [] + self.train = train + np_dtype = np.float32 + for img_path in os.listdir(tmp_root): + # File existence check + if not os.path.exists(os.path.join(tmp_root, img_path)): + raise FileNotFoundError(f"unable to locate {os.path.join(tmp_root, img_path)}") + + stem = Path(img_path).stem + _targets = [] + with open(os.path.join(self.root, 'json', f"{stem}.json"), 'rb') as f: + label = json.load(f) + for line in label["valid_line"]: + for word in line["words"]: + if len(word["text"]) > 0: + x = word["quad"]["x1"], word["quad"]["x2"], word["quad"]["x3"], word["quad"]["x4"] + y = word["quad"]["y1"], word["quad"]["y2"], word["quad"]["y3"], word["quad"]["y4"] + if use_polygons: + # (x, y) coordinates of top left, top right, bottom right, bottom left corners + box = np.array([ + [x[0], y[0]], + [x[1], y[1]], + [x[2], y[2]], + [x[3], y[3]], + ], dtype=np_dtype) + else: + # Reduce 8 coords to 4 -> xmin, ymin, xmax, ymax + box = [min(x), min(y), max(x), max(y)] + _targets.append((word['text'], box)) + + text_targets, box_targets = zip(*_targets) + + self.data.append(( + img_path, + dict(boxes=np.asarray(box_targets, dtype=int).clip(min=0), labels=list(text_targets)) + )) + self.root = tmp_root + + def extra_repr(self) -> str: + return f"train={self.train}" diff --git a/doctr/datasets/datasets/__init__.py b/doctr/datasets/datasets/__init__.py new file mode 100644 index 0000000000..059f261e82 --- /dev/null +++ b/doctr/datasets/datasets/__init__.py @@ -0,0 +1,6 @@ +from doctr.file_utils import is_tf_available, is_torch_available + +if is_tf_available(): + from .tensorflow import * +elif is_torch_available(): + from .pytorch import * # type: ignore[misc] diff --git a/doctr/datasets/datasets/base.py b/doctr/datasets/datasets/base.py new file mode 100644 index 0000000000..e2704af523 --- /dev/null +++ b/doctr/datasets/datasets/base.py @@ -0,0 +1,118 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. 
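As a rough usage sketch for the CORD dataset above (assuming the archive can be downloaded), each sample pairs an image with a dict of boxes and transcriptions; the boxes come back in relative coordinates thanks to the `convert_target_to_relative` pre-transform::

    from doctr.datasets import CORD

    train_set = CORD(train=True, download=True)
    img, target = train_set[0]
    # target['boxes']: numpy array of shape (N, 4), or (N, 4, 2) when use_polygons=True
    # target['labels']: list of N transcriptions
    print(target['boxes'].shape, len(target['labels']))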
+ +import os +import shutil +from pathlib import Path +from typing import Any, Callable, List, Optional, Tuple, Union + +from doctr.io.image import get_img_shape +from doctr.utils.data import download_from_url + +__all__ = ['_AbstractDataset', '_VisionDataset'] + + +class _AbstractDataset: + + data: List[Any] = [] + _pre_transforms: Optional[Callable[[Any, Any], Tuple[Any, Any]]] = None + + def __init__( + self, + root: Union[str, Path], + img_transforms: Optional[Callable[[Any], Any]] = None, + sample_transforms: Optional[Callable[[Any, Any], Tuple[Any, Any]]] = None, + pre_transforms: Optional[Callable[[Any, Any], Tuple[Any, Any]]] = None, + ) -> None: + + if not Path(root).is_dir(): + raise ValueError(f'expected a path to a reachable folder: {root}') + + self.root = root + self.img_transforms = img_transforms + self.sample_transforms = sample_transforms + self._pre_transforms = pre_transforms + self._get_img_shape = get_img_shape + + def __len__(self) -> int: + return len(self.data) + + def _read_sample(self, index: int) -> Tuple[Any, Any]: + raise NotImplementedError + + def __getitem__( + self, + index: int + ) -> Tuple[Any, Any]: + + # Read image + img, target = self._read_sample(index) + # Pre-transforms (format conversion at run-time etc.) + if self._pre_transforms is not None: + img, target = self._pre_transforms(img, target) + + if self.img_transforms is not None: + # typing issue cf. https://github.com/python/mypy/issues/5485 + img = self.img_transforms(img) # type: ignore[call-arg] + + if self.sample_transforms is not None: + img, target = self.sample_transforms(img, target) + + return img, target + + def extra_repr(self) -> str: + return "" + + def __repr__(self) -> str: + return f"{self.__class__.__name__}({self.extra_repr()})" + + +class _VisionDataset(_AbstractDataset): + """Implements an abstract dataset + + Args: + url: URL of the dataset + file_name: name of the file once downloaded + file_hash: expected SHA256 of the file + extract_archive: whether the downloaded file is an archive to be extracted + download: whether the dataset should be downloaded if not present on disk + overwrite: whether the archive should be re-extracted + cache_dir: cache directory + cache_subdir: subfolder to use in the cache + """ + + def __init__( + self, + url: str, + file_name: Optional[str] = None, + file_hash: Optional[str] = None, + extract_archive: bool = False, + download: bool = False, + overwrite: bool = False, + cache_dir: Optional[str] = None, + cache_subdir: Optional[str] = None, + **kwargs: Any, + ) -> None: + + cache_dir = os.path.join(os.path.expanduser('~'), '.cache', 'doctr') if cache_dir is None else cache_dir + cache_subdir = 'datasets' if cache_subdir is None else cache_subdir + + file_name = file_name if isinstance(file_name, str) else os.path.basename(url) + # Download the file if not present + archive_path: Union[str, Path] = os.path.join(cache_dir, cache_subdir, file_name) + + if not os.path.exists(archive_path) and not download: + raise ValueError("the dataset needs to be downloaded first with download=True") + + archive_path = download_from_url(url, file_name, file_hash, cache_dir=cache_dir, cache_subdir=cache_subdir) + + # Extract the archive + if extract_archive: + archive_path = Path(archive_path) + dataset_path = archive_path.parent.joinpath(archive_path.stem) + if not dataset_path.is_dir() or overwrite: + shutil.unpack_archive(archive_path, dataset_path) + + super().__init__(dataset_path if extract_archive else archive_path, **kwargs) diff --git 
a/doctr/datasets/datasets/pytorch.py b/doctr/datasets/datasets/pytorch.py new file mode 100644 index 0000000000..6dd43a5161 --- /dev/null +++ b/doctr/datasets/datasets/pytorch.py @@ -0,0 +1,37 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + +import os +from typing import Any, List, Tuple + +import torch + +from doctr.io import read_img_as_tensor + +from .base import _AbstractDataset, _VisionDataset + +__all__ = ['AbstractDataset', 'VisionDataset'] + + +class AbstractDataset(_AbstractDataset): + + def _read_sample(self, index: int) -> Tuple[torch.Tensor, Any]: + img_name, target = self.data[index] + # Read image + img = read_img_as_tensor(os.path.join(self.root, img_name), dtype=torch.float32) + + return img, target + + @staticmethod + def collate_fn(samples: List[Tuple[torch.Tensor, Any]]) -> Tuple[torch.Tensor, List[Any]]: + + images, targets = zip(*samples) + images = torch.stack(images, dim=0) + + return images, list(targets) + + +class VisionDataset(AbstractDataset, _VisionDataset): + pass diff --git a/doctr/datasets/datasets/tensorflow.py b/doctr/datasets/datasets/tensorflow.py new file mode 100644 index 0000000000..96821c8cc3 --- /dev/null +++ b/doctr/datasets/datasets/tensorflow.py @@ -0,0 +1,37 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + +import os +from typing import Any, List, Tuple + +import tensorflow as tf + +from doctr.io import read_img_as_tensor + +from .base import _AbstractDataset, _VisionDataset + +__all__ = ['AbstractDataset', 'VisionDataset'] + + +class AbstractDataset(_AbstractDataset): + + def _read_sample(self, index: int) -> Tuple[tf.Tensor, Any]: + img_name, target = self.data[index] + # Read image + img = read_img_as_tensor(os.path.join(self.root, img_name), dtype=tf.float32) + + return img, target + + @staticmethod + def collate_fn(samples: List[Tuple[tf.Tensor, Any]]) -> Tuple[tf.Tensor, List[Any]]: + + images, targets = zip(*samples) + images = tf.stack(images, axis=0) + + return images, list(targets) + + +class VisionDataset(AbstractDataset, _VisionDataset): + pass diff --git a/doctr/datasets/detection.py b/doctr/datasets/detection.py new file mode 100644 index 0000000000..80f64c63e6 --- /dev/null +++ b/doctr/datasets/detection.py @@ -0,0 +1,64 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + +import json +import os +from typing import Any, List, Tuple + +import numpy as np + +from doctr.io.image import get_img_shape +from doctr.utils.geometry import convert_to_relative_coords + +from .datasets import AbstractDataset + +__all__ = ["DetectionDataset"] + + +class DetectionDataset(AbstractDataset): + """Implements a text detection dataset + + Example:: + >>> from doctr.datasets import DetectionDataset + >>> train_set = DetectionDataset(img_folder="/path/to/images", label_path="/path/to/labels.json") + >>> img, target = train_set[0] + + Args: + img_folder: folder with all the images of the dataset + label_path: path to the annotations of each image + use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones) + **kwargs: keyword arguments from `AbstractDataset`. 
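A minimal sketch of the label file this dataset expects, matching the way `label['polygons']` is parsed below (hypothetical image name; coordinates are assumed to be absolute pixels, since they are converted to relative coordinates at load time)::

        {
            "img_1.jpg": {
                "polygons": [
                    [[126, 140], [402, 140], [402, 187], [126, 187]],
                    [[84, 310], [230, 310], [230, 356], [84, 356]]
                ]
            }
        }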
+ """ + + def __init__( + self, + img_folder: str, + label_path: str, + use_polygons: bool = False, + **kwargs: Any, + ) -> None: + super().__init__( + img_folder, + pre_transforms=lambda img, boxes: (img, convert_to_relative_coords(boxes, get_img_shape(img))), + **kwargs + ) + + # File existence check + if not os.path.exists(label_path): + raise FileNotFoundError(f"unable to locate {label_path}") + with open(label_path, 'rb') as f: + labels = json.load(f) + + self.data: List[Tuple[str, np.ndarray]] = [] + np_dtype = np.float32 + for img_name, label in labels.items(): + # File existence check + if not os.path.exists(os.path.join(self.root, img_name)): + raise FileNotFoundError(f"unable to locate {os.path.join(self.root, img_name)}") + + polygons = np.asarray(label['polygons'], dtype=np_dtype) + geoms = polygons if use_polygons else np.concatenate((polygons.min(axis=1), polygons.max(axis=1)), axis=1) + + self.data.append((img_name, np.asarray(geoms, dtype=np_dtype))) diff --git a/doctr/datasets/doc_artefacts.py b/doctr/datasets/doc_artefacts.py new file mode 100644 index 0000000000..044af0b93a --- /dev/null +++ b/doctr/datasets/doc_artefacts.py @@ -0,0 +1,79 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + +import json +import os +from typing import Any, Dict, List, Tuple + +import numpy as np + +from .datasets import VisionDataset + +__all__ = ['DocArtefacts'] + + +class DocArtefacts(VisionDataset): + """Object detection dataset for non-textual elements in documents. + The dataset includes a variety of synthetic document pages with non-textual elements. + + Example:: + >>> from doctr.datasets import DocArtefacts + >>> train_set = DocArtefacts(download=True) + >>> img, target = train_set[0] + + Args: + train: whether the subset should be the training one + use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones) + **kwargs: keyword arguments from `VisionDataset`. 
+ """ + + URL = 'https://github.com/mindee/doctr/releases/download/v0.4.0/artefact_detection-13fab8ce.zip' + SHA256 = '13fab8ced7f84583d9dccd0c634f046c3417e62a11fe1dea6efbbaba5052471b' + CLASSES = ["background", "qr_code", "bar_code", "logo", "photo"] + + def __init__( + self, + train: bool = True, + use_polygons: bool = False, + **kwargs: Any, + ) -> None: + + super().__init__(self.URL, None, self.SHA256, True, **kwargs) + self.train = train + + # Update root + self.root = os.path.join(self.root, "train" if train else "val") + # List images + tmp_root = os.path.join(self.root, 'images') + with open(os.path.join(self.root, "labels.json"), "rb") as f: + labels = json.load(f) + self.data: List[Tuple[str, Dict[str, Any]]] = [] + img_list = os.listdir(tmp_root) + if len(labels) != len(img_list): + raise AssertionError('the number of images and labels do not match') + np_dtype = np.float32 + for img_name, label in labels.items(): + # File existence check + if not os.path.exists(os.path.join(tmp_root, img_name)): + raise FileNotFoundError(f"unable to locate {os.path.join(tmp_root, img_name)}") + + # xmin, ymin, xmax, ymax + boxes = np.asarray([obj['geometry'] for obj in label], dtype=np_dtype) + classes = np.asarray([self.CLASSES.index(obj['label']) for obj in label], dtype=np.int64) + if use_polygons: + # (x, y) coordinates of top left, top right, bottom right, bottom left corners + boxes = np.stack( + [ + np.stack([boxes[:, 0], boxes[:, 1]], axis=-1), + np.stack([boxes[:, 2], boxes[:, 1]], axis=-1), + np.stack([boxes[:, 2], boxes[:, 3]], axis=-1), + np.stack([boxes[:, 0], boxes[:, 3]], axis=-1), + ], axis=1 + ) + self.data.append((img_name, dict(boxes=boxes, labels=classes))) + self.root = tmp_root + + def extra_repr(self) -> str: + return f"train={self.train}" diff --git a/doctr/datasets/funsd.py b/doctr/datasets/funsd.py new file mode 100644 index 0000000000..b0de69a7d0 --- /dev/null +++ b/doctr/datasets/funsd.py @@ -0,0 +1,93 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + +import json +import os +from pathlib import Path +from typing import Any, Dict, List, Tuple + +import numpy as np + +from .datasets import VisionDataset +from .utils import convert_target_to_relative + +__all__ = ['FUNSD'] + + +class FUNSD(VisionDataset): + """FUNSD dataset from `"FUNSD: A Dataset for Form Understanding in Noisy Scanned Documents" + `_. + + Example:: + >>> from doctr.datasets import FUNSD + >>> train_set = FUNSD(train=True, download=True) + >>> img, target = train_set[0] + + Args: + train: whether the subset should be the training one + use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones) + **kwargs: keyword arguments from `VisionDataset`. 
+ """ + + URL = 'https://guillaumejaume.github.io/FUNSD/dataset.zip' + SHA256 = 'c31735649e4f441bcbb4fd0f379574f7520b42286e80b01d80b445649d54761f' + FILE_NAME = 'funsd.zip' + + def __init__( + self, + train: bool = True, + use_polygons: bool = False, + **kwargs: Any, + ) -> None: + + super().__init__( + self.URL, + self.FILE_NAME, + self.SHA256, + True, + pre_transforms=convert_target_to_relative, + **kwargs + ) + self.train = train + np_dtype = np.float32 + + # Use the subset + subfolder = os.path.join('dataset', 'training_data' if train else 'testing_data') + + # # List images + tmp_root = os.path.join(self.root, subfolder, 'images') + self.data: List[Tuple[str, Dict[str, Any]]] = [] + for img_path in os.listdir(tmp_root): + # File existence check + if not os.path.exists(os.path.join(tmp_root, img_path)): + raise FileNotFoundError(f"unable to locate {os.path.join(tmp_root, img_path)}") + + stem = Path(img_path).stem + with open(os.path.join(self.root, subfolder, 'annotations', f"{stem}.json"), 'rb') as f: + data = json.load(f) + + _targets = [(word['text'], word['box']) for block in data['form'] + for word in block['words'] if len(word['text']) > 0] + text_targets, box_targets = zip(*_targets) + if use_polygons: + # xmin, ymin, xmax, ymax -> (x, y) coordinates of top left, top right, bottom right, bottom left corners + box_targets = [ + [ + [box[0], box[1]], + [box[2], box[1]], + [box[2], box[3]], + [box[0], box[3]], + ] for box in box_targets + ] + + self.data.append(( + img_path, + dict(boxes=np.asarray(box_targets, dtype=np_dtype), labels=list(text_targets)), + )) + + self.root = tmp_root + + def extra_repr(self) -> str: + return f"train={self.train}" diff --git a/doctr/datasets/generator/__init__.py b/doctr/datasets/generator/__init__.py new file mode 100644 index 0000000000..059f261e82 --- /dev/null +++ b/doctr/datasets/generator/__init__.py @@ -0,0 +1,6 @@ +from doctr.file_utils import is_tf_available, is_torch_available + +if is_tf_available(): + from .tensorflow import * +elif is_torch_available(): + from .pytorch import * # type: ignore[misc] diff --git a/doctr/datasets/generator/base.py b/doctr/datasets/generator/base.py new file mode 100644 index 0000000000..4259962fb9 --- /dev/null +++ b/doctr/datasets/generator/base.py @@ -0,0 +1,154 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. 
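The straight-box to polygon conversion used when `use_polygons=True` (in FUNSD above, and in IC13 further down) is a simple four-corner expansion; here is a small stand-alone sketch of it, with a hypothetical helper name::

    import numpy as np

    def xyxy_to_polygon(box):
        # (xmin, ymin, xmax, ymax) -> top left, top right, bottom right, bottom left corners
        xmin, ymin, xmax, ymax = box
        return np.array(
            [[xmin, ymin], [xmax, ymin], [xmax, ymax], [xmin, ymax]],
            dtype=np.float32,
        )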
+ +import random +from typing import Any, Callable, List, Optional, Tuple, Union + +from PIL import Image, ImageDraw + +from doctr.io.image import tensor_from_pil +from doctr.utils.fonts import get_font + +from ..datasets import AbstractDataset + + +def synthesize_text_img( + text: str, + font_size: int = 32, + font_family: Optional[str] = None, + background_color: Optional[Tuple[int, int, int]] = None, + text_color: Optional[Tuple[int, int, int]] = None, +) -> Image: + """Generate a synthetic text image + + Args: + text: the text to render as an image + font_size: the size of the font + font_family: the font family (has to be installed on your system) + background_color: background color of the final image + text_color: text color on the final image + + Returns: + PIL image of the text + """ + + background_color = (0, 0, 0) if background_color is None else background_color + text_color = (255, 255, 255) if text_color is None else text_color + + font = get_font(font_family, font_size) + text_w, text_h = font.getsize(text) + h, w = int(round(1.3 * text_h)), int(round(1.1 * text_w)) + # If single letter, make the image square, otherwise expand to meet the text size + img_size = (h, w) if len(text) > 1 else (max(h, w), max(h, w)) + + img = Image.new('RGB', img_size[::-1], color=background_color) + d = ImageDraw.Draw(img) + + # Offset so that the text is centered + text_pos = (int(round((img_size[1] - text_w) / 2)), int(round((img_size[0] - text_h) / 2))) + # Draw the text + d.text(text_pos, text, font=font, fill=text_color) + return img + + +class _CharacterGenerator(AbstractDataset): + + def __init__( + self, + vocab: str, + num_samples: int, + cache_samples: bool = False, + font_family: Optional[Union[str, List[str]]] = None, + img_transforms: Optional[Callable[[Any], Any]] = None, + sample_transforms: Optional[Callable[[Any, Any], Tuple[Any, Any]]] = None, + ) -> None: + self.vocab = vocab + self._num_samples = num_samples + self.font_family = font_family if isinstance(font_family, list) else [font_family] # type: ignore[list-item] + # Validate fonts + if isinstance(font_family, list): + for font in self.font_family: + try: + _ = get_font(font, 10) + except OSError: + raise ValueError(f"unable to locate font: {font}") + self.img_transforms = img_transforms + self.sample_transforms = sample_transforms + + self._data: List[Image.Image] = [] + if cache_samples: + self._data = [ + (synthesize_text_img(char, font_family=font), idx) + for idx, char in enumerate(self.vocab) for font in self.font_family + ] + + def __len__(self) -> int: + return self._num_samples + + def _read_sample(self, index: int) -> Tuple[Any, int]: + # Samples are already cached + if len(self._data) > 0: + idx = index % len(self._data) + pil_img, target = self._data[idx] + else: + target = index % len(self.vocab) + pil_img = synthesize_text_img(self.vocab[target], font_family=random.choice(self.font_family)) + img = tensor_from_pil(pil_img) + + return img, target + + +class _WordGenerator(AbstractDataset): + + def __init__( + self, + vocab: str, + min_chars: int, + max_chars: int, + num_samples: int, + cache_samples: bool = False, + font_family: Optional[Union[str, List[str]]] = None, + img_transforms: Optional[Callable[[Any], Any]] = None, + sample_transforms: Optional[Callable[[Any, Any], Tuple[Any, Any]]] = None, + ) -> None: + self.vocab = vocab + self.wordlen_range = (min_chars, max_chars) + self._num_samples = num_samples + self.font_family = font_family if isinstance(font_family, list) else [font_family] # type: 
ignore[list-item] + # Validate fonts + if isinstance(font_family, list): + for font in self.font_family: + try: + _ = get_font(font, 10) + except OSError: + raise ValueError(f"unable to locate font: {font}") + self.img_transforms = img_transforms + self.sample_transforms = sample_transforms + + self._data: List[Image.Image] = [] + if cache_samples: + _words = [self._generate_string(*self.wordlen_range) for _ in range(num_samples)] + self._data = [ + (synthesize_text_img(text, font_family=random.choice(self.font_family)), text) + for text in _words + ] + + def _generate_string(self, min_chars: int, max_chars: int) -> str: + num_chars = random.randint(min_chars, max_chars) + return "".join(random.choice(self.vocab) for _ in range(num_chars)) + + def __len__(self) -> int: + return self._num_samples + + def _read_sample(self, index: int) -> Tuple[Any, str]: + # Samples are already cached + if len(self._data) > 0: + pil_img, target = self._data[index] + else: + target = self._generate_string(*self.wordlen_range) + pil_img = synthesize_text_img(target, font_family=random.choice(self.font_family)) + img = tensor_from_pil(pil_img) + + return img, target diff --git a/doctr/datasets/generator/pytorch.py b/doctr/datasets/generator/pytorch.py new file mode 100644 index 0000000000..f0e4141d5d --- /dev/null +++ b/doctr/datasets/generator/pytorch.py @@ -0,0 +1,54 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + +from torch.utils.data._utils.collate import default_collate + +from .base import _CharacterGenerator, _WordGenerator + +__all__ = ['CharacterGenerator', 'WordGenerator'] + + +class CharacterGenerator(_CharacterGenerator): + """Implements a character image generation dataset + + Example:: + >>> from doctr.datasets import CharacterGenerator + >>> ds = CharacterGenerator(vocab='abdef') + >>> img, target = ds[0] + + Args: + vocab: vocabulary to take the character from + num_samples: number of samples that will be generated iterating over the dataset + cache_samples: whether generated images should be cached firsthand + font_family: font to use to generate the text images + img_transforms: composable transformations that will be applied to each image + sample_transforms: composable transformations that will be applied to both the image and the target + """ + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + setattr(self, 'collate_fn', default_collate) + + +class WordGenerator(_WordGenerator): + """Implements a character image generation dataset + + Example:: + >>> from doctr.datasets import WordGenerator + >>> ds = WordGenerator(vocab='abdef') + >>> img, target = ds[0] + + Args: + vocab: vocabulary to take the character from + min_chars: minimum number of characters in a word + max_chars: maximum number of characters in a word + num_samples: number of samples that will be generated iterating over the dataset + cache_samples: whether generated images should be cached firsthand + font_family: font to use to generate the text images + img_transforms: composable transformations that will be applied to each image + sample_transforms: composable transformations that will be applied to both the image and the target + """ + + pass diff --git a/doctr/datasets/generator/tensorflow.py b/doctr/datasets/generator/tensorflow.py new file mode 100644 index 0000000000..bb6d09c081 --- /dev/null +++ b/doctr/datasets/generator/tensorflow.py @@ -0,0 +1,61 @@ +# Copyright (C) 
2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + +import tensorflow as tf + +from .base import _CharacterGenerator, _WordGenerator + +__all__ = ['CharacterGenerator', 'WordGenerator'] + + +class CharacterGenerator(_CharacterGenerator): + """Implements a character image generation dataset + + Example:: + >>> from doctr.datasets import CharacterGenerator + >>> ds = CharacterGenerator(vocab='abdef') + >>> img, target = ds[0] + + Args: + vocab: vocabulary to take the character from + num_samples: number of samples that will be generated iterating over the dataset + cache_samples: whether generated images should be cached firsthand + font_family: font to use to generate the text images + img_transforms: composable transformations that will be applied to each image + sample_transforms: composable transformations that will be applied to both the image and the target + """ + + def __init__(self, *args, **kwargs) -> None: + super().__init__(*args, **kwargs) + + @staticmethod + def collate_fn(samples): + + images, targets = zip(*samples) + images = tf.stack(images, axis=0) + + return images, tf.convert_to_tensor(targets) + + +class WordGenerator(_WordGenerator): + """Implements a character image generation dataset + + Example:: + >>> from doctr.datasets import WordGenerator + >>> ds = WordGenerator(vocab='abdef') + >>> img, target = ds[0] + + Args: + vocab: vocabulary to take the character from + min_chars: minimum number of characters in a word + max_chars: maximum number of characters in a word + num_samples: number of samples that will be generated iterating over the dataset + cache_samples: whether generated images should be cached firsthand + font_family: font to use to generate the text images + img_transforms: composable transformations that will be applied to each image + sample_transforms: composable transformations that will be applied to both the image and the target + """ + + pass diff --git a/doctr/datasets/ic03.py b/doctr/datasets/ic03.py new file mode 100644 index 0000000000..2073604e21 --- /dev/null +++ b/doctr/datasets/ic03.py @@ -0,0 +1,108 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + +import os +from typing import Any, Dict, List, Tuple + +import defusedxml.ElementTree as ET +import numpy as np + +from .datasets import VisionDataset + +__all__ = ['IC03'] + + +class IC03(VisionDataset): + """IC03 dataset from `"ICDAR 2003 Robust Reading Competitions: Entries, Results and Future Directions" + `_. + + Example:: + >>> from doctr.datasets import IC03 + >>> train_set = IC03(train=True, download=True) + >>> img, target = train_set[0] + + Args: + train: whether the subset should be the training one + use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones) + **kwargs: keyword arguments from `VisionDataset`. 
+ """ + + TRAIN = ('http://www.iapr-tc11.org/dataset/ICDAR2003_RobustReading/TrialTrain/scene.zip', + '9d86df514eb09dd693fb0b8c671ef54a0cfe02e803b1bbef9fc676061502eb94', + 'ic03_train.zip') + TEST = ('http://www.iapr-tc11.org/dataset/ICDAR2003_RobustReading/TrialTest/scene.zip', + 'dbc4b5fd5d04616b8464a1b42ea22db351ee22c2546dd15ac35611857ea111f8', + 'ic03_test.zip') + + def __init__( + self, + train: bool = True, + use_polygons: bool = False, + **kwargs: Any, + ) -> None: + + url, sha256, file_name = self.TRAIN if train else self.TEST + super().__init__(url, file_name, sha256, True, **kwargs) + self.train = train + self.data: List[Tuple[str, Dict[str, Any]]] = [] + np_dtype = np.float32 + + # Load xml data + tmp_root = os.path.join( + self.root, 'SceneTrialTrain' if self.train else 'SceneTrialTest') if sha256 else self.root + xml_tree = ET.parse(os.path.join(tmp_root, 'words.xml')) + xml_root = xml_tree.getroot() + + for image in xml_root: + name, resolution, rectangles = image + + # File existence check + if not os.path.exists(os.path.join(tmp_root, name.text)): + raise FileNotFoundError(f"unable to locate {os.path.join(tmp_root, name.text)}") + + if use_polygons: + # (x, y) coordinates of top left, top right, bottom right, bottom left corners + _boxes = [ + [ + [float(rect.attrib['x']), float(rect.attrib['y'])], + [float(rect.attrib['x']) + float(rect.attrib['width']), float(rect.attrib['y'])], + [ + float(rect.attrib['x']) + float(rect.attrib['width']), + float(rect.attrib['y']) + float(rect.attrib['height']) + ], + [float(rect.attrib['x']), float(rect.attrib['y']) + float(rect.attrib['height'])], + ] + for rect in rectangles + ] + else: + # x_min, y_min, x_max, y_max + _boxes = [ + [float(rect.attrib['x']), float(rect.attrib['y']), # type: ignore[list-item] + float(rect.attrib['x']) + float(rect.attrib['width']), # type: ignore[list-item] + float(rect.attrib['y']) + float(rect.attrib['height'])] # type: ignore[list-item] + for rect in rectangles + ] + + # filter images without boxes + if len(_boxes) > 0: + # Convert them to relative + w, h = int(resolution.attrib['x']), int(resolution.attrib['y']) + boxes = np.asarray(_boxes, dtype=np_dtype) + if use_polygons: + boxes[:, :, 0] /= w + boxes[:, :, 1] /= h + else: + boxes[:, [0, 2]] /= w + boxes[:, [1, 3]] /= h + + # Get the labels + labels = [lab.text for rect in rectangles for lab in rect if lab.text] + + self.data.append((name.text, dict(boxes=boxes, labels=labels))) + + self.root = tmp_root + + def extra_repr(self) -> str: + return f"train={self.train}" diff --git a/doctr/datasets/ic13.py b/doctr/datasets/ic13.py new file mode 100644 index 0000000000..404dfaee17 --- /dev/null +++ b/doctr/datasets/ic13.py @@ -0,0 +1,83 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + +import csv +import os +from pathlib import Path +from typing import Any, Dict, List, Tuple + +import numpy as np + +from .datasets import AbstractDataset +from .utils import convert_target_to_relative + +__all__ = ["IC13"] + + +class IC13(AbstractDataset): + """IC13 dataset from `"ICDAR 2013 Robust Reading Competition" `_. + + Example:: + >>> # NOTE: You need to download both image and label parts from Focused Scene Text challenge Task2.1 2013-2015. 
+ >>> from doctr.datasets import IC13 + >>> train_set = IC13(img_folder="/path/to/Challenge2_Training_Task12_Images", + >>> label_folder="/path/to/Challenge2_Training_Task1_GT") + >>> img, target = train_set[0] + >>> test_set = IC13(img_folder="/path/to/Challenge2_Test_Task12_Images", + >>> label_folder="/path/to/Challenge2_Test_Task1_GT") + >>> img, target = test_set[0] + + Args: + img_folder: folder with all the images of the dataset + label_folder: folder with all annotation files for the images + use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones) + **kwargs: keyword arguments from `AbstractDataset`. + """ + + def __init__( + self, + img_folder: str, + label_folder: str, + use_polygons: bool = False, + **kwargs: Any, + ) -> None: + super().__init__(img_folder, pre_transforms=convert_target_to_relative, **kwargs) + + # File existence check + if not os.path.exists(label_folder) or not os.path.exists(img_folder): + raise FileNotFoundError( + f"unable to locate {label_folder if not os.path.exists(label_folder) else img_folder}") + + self.data: List[Tuple[Path, Dict[str, Any]]] = [] + np_dtype = np.float32 + + img_names = os.listdir(img_folder) + + for img_name in img_names: + + img_path = Path(img_folder, img_name) + label_path = Path(label_folder, "gt_" + Path(img_name).stem + ".txt") + + with open(label_path, newline='\n') as f: + _lines = [ + [val[:-1] if val.endswith(",") else val for val in row] + for row in csv.reader(f, delimiter=' ', quotechar="'") + ] + labels = [line[-1] for line in _lines] + # xmin, ymin, xmax, ymax + box_targets = np.array([list(map(int, line[:4])) for line in _lines], dtype=np_dtype) + if use_polygons: + # (x, y) coordinates of top left, top right, bottom right, bottom left corners + box_targets = np.array( + [ + [ + [coords[0], coords[1]], + [coords[2], coords[1]], + [coords[2], coords[3]], + [coords[0], coords[3]], + ] for coords in box_targets + ], dtype=np_dtype + ) + self.data.append((img_path, dict(boxes=box_targets, labels=labels))) diff --git a/doctr/datasets/iiit5k.py b/doctr/datasets/iiit5k.py new file mode 100644 index 0000000000..55cb1e91fc --- /dev/null +++ b/doctr/datasets/iiit5k.py @@ -0,0 +1,93 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + +import os +from pathlib import Path +from typing import Any, Dict, List, Tuple + +import numpy as np +import scipy.io as sio + +from .datasets import VisionDataset +from .utils import convert_target_to_relative + +__all__ = ['IIIT5K'] + + +class IIIT5K(VisionDataset): + """IIIT-5K character-level localization dataset from + `"BMVC 2012 Scene Text Recognition using Higher Order Language Priors" + `_. + + Example:: + >>> # NOTE: this dataset is for character-level localization + >>> from doctr.datasets import IIIT5K + >>> train_set = IIIT5K(train=True, download=True) + >>> img, target = train_set[0] + + Args: + train: whether the subset should be the training one + use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones) + **kwargs: keyword arguments from `VisionDataset`. 
+ """ + + URL = 'https://cvit.iiit.ac.in/images/Projects/SceneTextUnderstanding/IIIT5K-Word_V3.0.tar.gz' + SHA256 = '7872c9efbec457eb23f3368855e7738f72ce10927f52a382deb4966ca0ffa38e' + + def __init__( + self, + train: bool = True, + use_polygons: bool = False, + **kwargs: Any, + ) -> None: + + super().__init__( + self.URL, + None, + file_hash=self.SHA256, + extract_archive=True, + pre_transforms=convert_target_to_relative, + **kwargs + ) + self.train = train + + # Load mat data + tmp_root = os.path.join(self.root, 'IIIT5K') if self.SHA256 else self.root + mat_file = 'trainCharBound' if self.train else 'testCharBound' + mat_data = sio.loadmat(os.path.join(tmp_root, f'{mat_file}.mat'))[mat_file][0] + + self.data: List[Tuple[Path, Dict[str, Any]]] = [] + np_dtype = np.float32 + + for img_path, label, box_targets in mat_data: + _raw_path = img_path[0] + _raw_label = label[0] + + # File existence check + if not os.path.exists(os.path.join(tmp_root, _raw_path)): + raise FileNotFoundError(f"unable to locate {os.path.join(tmp_root, _raw_path)}") + + if use_polygons: + # (x, y) coordinates of top left, top right, bottom right, bottom left corners + box_targets = [ + [ + [box[0], box[1]], + [box[0] + box[2], box[1]], + [box[0] + box[2], box[1] + box[3]], + [box[0], box[1] + box[3]], + ] for box in box_targets + ] + else: + # xmin, ymin, xmax, ymax + box_targets = [[box[0], box[1], box[0] + box[2], box[1] + box[3]] for box in box_targets] + + # label are casted to list where each char corresponds to the character's bounding box + self.data.append((_raw_path, dict(boxes=np.asarray( + box_targets, dtype=np_dtype), labels=list(_raw_label)))) + + self.root = tmp_root + + def extra_repr(self) -> str: + return f"train={self.train}" diff --git a/doctr/datasets/imgur5k.py b/doctr/datasets/imgur5k.py new file mode 100644 index 0000000000..c75d83d408 --- /dev/null +++ b/doctr/datasets/imgur5k.py @@ -0,0 +1,100 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + +import json +import os +from pathlib import Path +from typing import Any, Dict, List, Tuple + +import cv2 +import numpy as np + +from .datasets import AbstractDataset +from .utils import convert_target_to_relative + +__all__ = ["IMGUR5K"] + + +class IMGUR5K(AbstractDataset): + """IMGUR5K dataset from `"TextStyleBrush: Transfer of Text Aesthetics from a Single Example" + `_ | + `"repository" `_. + + Example:: + >>> # NOTE: You need to download/generate the dataset from the repository. + >>> from doctr.datasets import IMGUR5K + >>> train_set = IMGUR5K(train=True, img_folder="/path/to/IMGUR5K-Handwriting-Dataset/images", + >>> label_path="/path/to/IMGUR5K-Handwriting-Dataset/dataset_info/imgur5k_annotations.json") + >>> img, target = train_set[0] + >>> test_set = IMGUR5K(train=False, img_folder="/path/to/IMGUR5K-Handwriting-Dataset/images", + >>> label_path="/path/to/IMGUR5K-Handwriting-Dataset/dataset_info/imgur5k_annotations.json") + >>> img, target = test_set[0] + + Args: + img_folder: folder with all the images of the dataset + label_path: path to the annotations file of the dataset + train: whether the subset should be the training one + use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones) + **kwargs: keyword arguments from `AbstractDataset`. 
+ """ + + def __init__( + self, + img_folder: str, + label_path: str, + train: bool = True, + use_polygons: bool = False, + **kwargs: Any, + ) -> None: + super().__init__(img_folder, pre_transforms=convert_target_to_relative, **kwargs) + + # File existence check + if not os.path.exists(label_path) or not os.path.exists(img_folder): + raise FileNotFoundError( + f"unable to locate {label_path if not os.path.exists(label_path) else img_folder}") + + self.data: List[Tuple[Path, Dict[str, Any]]] = [] + self.train = train + np_dtype = np.float32 + + img_names = os.listdir(img_folder) + train_samples = int(len(img_names) * 0.9) + set_slice = slice(train_samples) if self.train else slice(train_samples, None) + + with open(label_path) as f: + annotation_file = json.load(f) + + for img_name in img_names[set_slice]: + img_path = Path(img_folder, img_name) + img_id = img_name.split(".")[0] + + # File existence check + if not os.path.exists(os.path.join(self.root, img_name)): + raise FileNotFoundError(f"unable to locate {os.path.join(self.root, img_name)}") + + # some files have no annotations which are marked with only a dot in the 'word' key + # ref: https://github.com/facebookresearch/IMGUR5K-Handwriting-Dataset/blob/main/README.md + if img_id not in annotation_file['index_to_ann_map'].keys(): + continue + ann_ids = annotation_file['index_to_ann_map'][img_id] + annotations = [annotation_file['ann_id'][a_id] for a_id in ann_ids] + + labels = [ann['word'] for ann in annotations if ann['word'] != '.'] + # x_center, y_center, width, height, angle + _boxes = [list(map(float, ann['bounding_box'].strip('[ ]').split(', '))) + for ann in annotations if ann['word'] != '.'] + # (x, y) coordinates of top left, top right, bottom right, bottom left corners + box_targets = [cv2.boxPoints(((box[0], box[1]), (box[2], box[3]), box[4])) for box in _boxes] + + if not use_polygons: + # xmin, ymin, xmax, ymax + box_targets = [np.concatenate((points.min(0), points.max(0)), axis=-1) for points in box_targets] + + # filter images without boxes + if len(box_targets) > 0: + self.data.append((img_path, dict(boxes=np.asarray(box_targets, dtype=np_dtype), labels=labels))) + + def extra_repr(self) -> str: + return f"train={self.train}" diff --git a/doctr/datasets/loader.py b/doctr/datasets/loader.py new file mode 100644 index 0000000000..a435bc345f --- /dev/null +++ b/doctr/datasets/loader.py @@ -0,0 +1,101 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. 
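IMGUR5K above stores each annotation as `(x_center, y_center, width, height, angle)` and expands it to corner points with OpenCV; a minimal sketch of that conversion in isolation (toy values)::

    import cv2
    import numpy as np

    x, y, w, h, angle = 120.0, 80.0, 60.0, 20.0, 15.0   # toy rotated box
    points = cv2.boxPoints(((x, y), (w, h), angle))      # (4, 2) array of corner points
    # straight-box fallback, as done when use_polygons=False
    straight = np.concatenate((points.min(0), points.max(0)), axis=-1)  # xmin, ymin, xmax, ymax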
+ +import math +from typing import Callable, Optional + +import numpy as np +import tensorflow as tf + +from doctr.utils.multithreading import multithread_exec + +__all__ = ["DataLoader"] + + +def default_collate(samples): + """Collate multiple elements into batches + + Args: + samples: list of N tuples containing M elements + + Returns: + Tuple of M sequences contianing N elements each + """ + + batch_data = zip(*samples) + + tf_data = tuple(tf.stack(elt, axis=0) for elt in batch_data) + + return tf_data + + +class DataLoader: + """Implements a dataset wrapper for fast data loading + + Example:: + >>> from doctr.datasets import FUNSD, DataLoader + >>> train_set = CORD(train=True, download=True) + >>> train_loader = DataLoader(train_set, batch_size=32) + >>> train_iter = iter(train_loader) + >>> images, targets = next(train_iter) + + Args: + dataset: the dataset + shuffle: whether the samples should be shuffled before passing it to the iterator + batch_size: number of elements in each batch + drop_last: if `True`, drops the last batch if it isn't full + num_workers: number of workers to use for data loading + collate_fn: function to merge samples into a batch + """ + + def __init__( + self, + dataset, + shuffle: bool = True, + batch_size: int = 1, + drop_last: bool = False, + num_workers: Optional[int] = None, + collate_fn: Optional[Callable] = None, + ) -> None: + self.dataset = dataset + self.shuffle = shuffle + self.batch_size = batch_size + nb = len(self.dataset) / batch_size + self.num_batches = math.floor(nb) if drop_last else math.ceil(nb) + if collate_fn is None: + self.collate_fn = self.dataset.collate_fn if hasattr(self.dataset, 'collate_fn') else default_collate + else: + self.collate_fn = collate_fn + self.num_workers = num_workers + self.reset() + + def __len__(self) -> int: + return self.num_batches + + def reset(self) -> None: + # Updates indices after each epoch + self._num_yielded = 0 + self.indices = np.arange(len(self.dataset)) + if self.shuffle is True: + np.random.shuffle(self.indices) + + def __iter__(self): + self.reset() + return self + + def __next__(self): + if self._num_yielded < self.num_batches: + # Get next indices + idx = self._num_yielded * self.batch_size + indices = self.indices[idx: min(len(self.dataset), idx + self.batch_size)] + + samples = multithread_exec(self.dataset.__getitem__, indices, threads=self.num_workers) + + batch_data = self.collate_fn(samples) + + self._num_yielded += 1 + return batch_data + else: + raise StopIteration diff --git a/doctr/datasets/ocr.py b/doctr/datasets/ocr.py new file mode 100644 index 0000000000..0e9b78960c --- /dev/null +++ b/doctr/datasets/ocr.py @@ -0,0 +1,65 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + +import json +import os +from pathlib import Path +from typing import Any, Dict, List, Tuple + +import numpy as np + +from .datasets import AbstractDataset + +__all__ = ['OCRDataset'] + + +class OCRDataset(AbstractDataset): + """Implements an OCR dataset + + Args: + img_folder: local path to image folder (all jpg at the root) + label_file: local path to the label file + use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones) + **kwargs: keyword arguments from `AbstractDataset`. 
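A minimal sketch of the label file this dataset expects, matching the `typed_words` entries parsed below (hypothetical image name; the first four values of each `geometry` are read as xmin, ymin, xmax, ymax, assumed here to be relative coordinates)::

        {
            "img_1.jpg": {
                "typed_words": [
                    {"value": "Hello", "geometry": [0.10, 0.20, 0.28, 0.26]},
                    {"value": "world", "geometry": [0.32, 0.20, 0.52, 0.26]}
                ]
            }
        }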
+ """ + + def __init__( + self, + img_folder: str, + label_file: str, + use_polygons: bool = False, + **kwargs: Any, + ) -> None: + super().__init__(img_folder, **kwargs) + + # List images + self.data: List[Tuple[str, Dict[str, Any]]] = [] + np_dtype = np.float32 + with open(label_file, 'rb') as f: + data = json.load(f) + + for img_name, annotations in data.items(): + # Get image path + img_name = Path(img_name) + # File existence check + if not os.path.exists(os.path.join(self.root, img_name)): + raise FileNotFoundError(f"unable to locate {os.path.join(self.root, img_name)}") + + # handle empty images + if len(annotations["typed_words"]) == 0: + self.data.append((img_name, dict(boxes=np.zeros((0, 4), dtype=np_dtype), labels=[]))) + continue + # Unpack the straight boxes (xmin, ymin, xmax, ymax) + geoms = [list(map(float, obj['geometry'][:4])) for obj in annotations['typed_words']] + if use_polygons: + # (x, y) coordinates of top left, top right, bottom right, bottom left corners + geoms = [ + [geom[:2], [geom[2], geom[1]], geom[2:], [geom[0], geom[3]]] # type: ignore[list-item] + for geom in geoms + ] + + text_targets = [obj['value'] for obj in annotations['typed_words']] + + self.data.append((img_name, dict(boxes=np.asarray(geoms, dtype=np_dtype), labels=text_targets))) diff --git a/doctr/datasets/recognition.py b/doctr/datasets/recognition.py new file mode 100644 index 0000000000..5c9be584d5 --- /dev/null +++ b/doctr/datasets/recognition.py @@ -0,0 +1,55 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + +import json +import os +from pathlib import Path +from typing import Any, List, Tuple + +from .datasets import AbstractDataset + +__all__ = ["RecognitionDataset"] + + +class RecognitionDataset(AbstractDataset): + """Dataset implementation for text recognition tasks + + Example:: + >>> from doctr.datasets import RecognitionDataset + >>> train_set = RecognitionDataset(img_folder="/path/to/images", labels_path="/path/to/labels.json") + >>> img, target = train_set[0] + + Args: + img_folder: path to the images folder + labels_path: pathe to the json file containing all labels (character sequences) + **kwargs: keyword arguments from `AbstractDataset`. + """ + + def __init__( + self, + img_folder: str, + labels_path: str, + **kwargs: Any, + ) -> None: + super().__init__(img_folder, **kwargs) + + self.data: List[Tuple[str, str]] = [] + with open(labels_path) as f: + labels = json.load(f) + + for img_name, label in labels.items(): + if not os.path.exists(os.path.join(self.root, img_name)): + raise FileNotFoundError(f"unable to locate {os.path.join(self.root, img_name)}") + + self.data.append((img_name, label)) + + def merge_dataset(self, ds: AbstractDataset) -> None: + # Update data with new root for self + self.data = [(str(Path(self.root).joinpath(img_path)), label) for img_path, label in self.data] + # Define new root + self.root = Path("/") + # Merge with ds data + for img_path, label in ds.data: + self.data.append((str(Path(ds.root).joinpath(img_path)), label)) diff --git a/doctr/datasets/sroie.py b/doctr/datasets/sroie.py new file mode 100644 index 0000000000..95e94e69b5 --- /dev/null +++ b/doctr/datasets/sroie.py @@ -0,0 +1,79 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. 
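For the recognition dataset above, the label file simply maps image names to transcriptions, and two datasets can be merged in place; a short sketch with hypothetical paths::

    # labels.json: {"img_1.png": "Hello", "img_2.png": "world"}
    from doctr.datasets import RecognitionDataset

    train_set = RecognitionDataset(img_folder="/path/to/train", labels_path="/path/to/train/labels.json")
    extra_set = RecognitionDataset(img_folder="/path/to/extra", labels_path="/path/to/extra/labels.json")
    train_set.merge_dataset(extra_set)   # train_set now spans both folders
    img, label = train_set[0]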
+ +import csv +import os +from pathlib import Path +from typing import Any, Dict, List, Tuple + +import numpy as np + +from .datasets import VisionDataset +from .utils import convert_target_to_relative + +__all__ = ['SROIE'] + + +class SROIE(VisionDataset): + """SROIE dataset from `"ICDAR2019 Competition on Scanned Receipt OCR and Information Extraction" + `_. + + Example:: + >>> from doctr.datasets import SROIE + >>> train_set = SROIE(train=True, download=True) + >>> img, target = train_set[0] + + Args: + train: whether the subset should be the training one + use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones) + **kwargs: keyword arguments from `VisionDataset`. + """ + + TRAIN = ('https://github.com/mindee/doctr/releases/download/v0.1.1/sroie2019_train_task1.zip', + 'd4fa9e60abb03500d83299c845b9c87fd9c9430d1aeac96b83c5d0bb0ab27f6f') + TEST = ('https://github.com/mindee/doctr/releases/download/v0.1.1/sroie2019_test.zip', + '41b3c746a20226fddc80d86d4b2a903d43b5be4f521dd1bbe759dbf8844745e2') + + def __init__( + self, + train: bool = True, + use_polygons: bool = False, + **kwargs: Any, + ) -> None: + + url, sha256 = self.TRAIN if train else self.TEST + super().__init__(url, None, sha256, True, pre_transforms=convert_target_to_relative, **kwargs) + self.train = train + + tmp_root = os.path.join(self.root, 'images') + self.data: List[Tuple[str, Dict[str, Any]]] = [] + np_dtype = np.float32 + + for img_path in os.listdir(tmp_root): + + # File existence check + if not os.path.exists(os.path.join(tmp_root, img_path)): + raise FileNotFoundError(f"unable to locate {os.path.join(tmp_root, img_path)}") + + stem = Path(img_path).stem + with open(os.path.join(self.root, 'annotations', f"{stem}.txt"), encoding='latin') as f: + _rows = [row for row in list(csv.reader(f, delimiter=',')) if len(row) > 0] + + labels = [",".join(row[8:]) for row in _rows] + # reorder coordinates (8 -> (4,2) -> + # (x, y) coordinates of top left, top right, bottom right, bottom left corners) and filter empty lines + coords = np.stack([np.array(list(map(int, row[:8])), dtype=np_dtype).reshape((4, 2)) + for row in _rows], axis=0) + + if not use_polygons: + # xmin, ymin, xmax, ymax + coords = np.concatenate((coords.min(axis=1), coords.max(axis=1)), axis=1) + + self.data.append((img_path, dict(boxes=coords, labels=labels))) + + self.root = tmp_root + + def extra_repr(self) -> str: + return f"train={self.train}" diff --git a/doctr/datasets/svhn.py b/doctr/datasets/svhn.py new file mode 100644 index 0000000000..3a0096f652 --- /dev/null +++ b/doctr/datasets/svhn.py @@ -0,0 +1,114 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + +import os +from typing import Any, Dict, List, Tuple + +import h5py +import numpy as np +from tqdm import tqdm + +from .datasets import VisionDataset +from .utils import convert_target_to_relative + +__all__ = ['SVHN'] + + +class SVHN(VisionDataset): + """SVHN dataset from `"The Street View House Numbers (SVHN) Dataset" + `_. + + Example:: + >>> from doctr.datasets import SVHN + >>> train_set = SVHN(train=True, download=True) + >>> img, target = train_set[0] + + Args: + train: whether the subset should be the training one + use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones) + **kwargs: keyword arguments from `VisionDataset`. 
+ """ + TRAIN = ('http://ufldl.stanford.edu/housenumbers/train.tar.gz', + '4b17bb33b6cd8f963493168f80143da956f28ec406cc12f8e5745a9f91a51898', + 'svhn_train.tar') + + TEST = ('http://ufldl.stanford.edu/housenumbers/test.tar.gz', + '57ac9ceb530e4aa85b55d991be8fc49c695b3d71c6f6a88afea86549efde7fb5', + 'svhn_test.tar') + + def __init__( + self, + train: bool = True, + use_polygons: bool = False, + **kwargs: Any, + ) -> None: + + url, sha256, name = self.TRAIN if train else self.TEST + super().__init__( + url, + file_name=name, + file_hash=sha256, + extract_archive=True, + pre_transforms=convert_target_to_relative, + **kwargs + ) + self.train = train + self.data: List[Tuple[str, Dict[str, Any]]] = [] + np_dtype = np.float32 + + tmp_root = os.path.join(self.root, 'train' if train else 'test') + + # Load mat data (matlab v7.3 - can not be loaded with scipy) + with h5py.File(os.path.join(tmp_root, 'digitStruct.mat'), 'r') as f: + img_refs = f['digitStruct/name'] + box_refs = f['digitStruct/bbox'] + for img_ref, box_ref in tqdm(iterable=zip(img_refs, box_refs), desc='Unpacking SVHN', total=len(img_refs)): + # convert ascii matrix to string + img_name = "".join(map(chr, f[img_ref[0]][()].flatten())) + + # File existence check + if not os.path.exists(os.path.join(tmp_root, img_name)): + raise FileNotFoundError(f"unable to locate {os.path.join(tmp_root, img_name)}") + + # Unpack the information + box = f[box_ref[0]] + if box['left'].shape[0] == 1: + box_dict = {k: [int(vals[0][0])] for k, vals in box.items()} + else: + box_dict = {k: [int(f[v[0]][()].item()) for v in vals] for k, vals in box.items()} + + # Convert it to the right format + coords = np.array([ + box_dict['left'], + box_dict['top'], + box_dict['width'], + box_dict['height'] + ], dtype=np_dtype).transpose() + label_targets = list(map(str, box_dict['label'])) + + if use_polygons: + # (x, y) coordinates of top left, top right, bottom right, bottom left corners + box_targets = np.stack( + [ + np.stack([coords[:, 0], coords[:, 1]], axis=-1), + np.stack([coords[:, 0] + coords[:, 2], coords[:, 1]], axis=-1), + np.stack([coords[:, 0] + coords[:, 2], coords[:, 1] + coords[:, 3]], axis=-1), + np.stack([coords[:, 0], coords[:, 1] + coords[:, 3]], axis=-1), + ], axis=1 + ) + else: + # x, y, width, height -> xmin, ymin, xmax, ymax + box_targets = np.stack([ + coords[:, 0], + coords[:, 1], + coords[:, 0] + coords[:, 2], + coords[:, 1] + coords[:, 3], + ], axis=-1) + self.data.append((img_name, dict(boxes=box_targets, labels=label_targets))) + + self.root = tmp_root + + def extra_repr(self) -> str: + return f"train={self.train}" diff --git a/doctr/datasets/svt.py b/doctr/datasets/svt.py new file mode 100644 index 0000000000..65d5455723 --- /dev/null +++ b/doctr/datasets/svt.py @@ -0,0 +1,100 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + +import os +from typing import Any, Dict, List, Tuple + +import defusedxml.ElementTree as ET +import numpy as np + +from .datasets import VisionDataset + +__all__ = ['SVT'] + + +class SVT(VisionDataset): + """SVT dataset from `"The Street View Text Dataset - UCSD Computer Vision" + `_. 
+ + Example:: + >>> from doctr.datasets import SVT + >>> train_set = SVT(train=True, download=True) + >>> img, target = train_set[0] + + Args: + train: whether the subset should be the training one + use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones) + **kwargs: keyword arguments from `VisionDataset`. + """ + + URL = 'http://vision.ucsd.edu/~kai/svt/svt.zip' + SHA256 = '63b3d55e6b6d1e036e2a844a20c034fe3af3c32e4d914d6e0c4a3cd43df3bebf' + + def __init__( + self, + train: bool = True, + use_polygons: bool = False, + **kwargs: Any, + ) -> None: + + super().__init__(self.URL, None, self.SHA256, True, **kwargs) + self.train = train + self.data: List[Tuple[str, Dict[str, Any]]] = [] + np_dtype = np.float32 + + # Load xml data + tmp_root = os.path.join(self.root, 'svt1') if self.SHA256 else self.root + xml_tree = ET.parse(os.path.join(tmp_root, 'train.xml')) if self.train else ET.parse( + os.path.join(tmp_root, 'test.xml')) + xml_root = xml_tree.getroot() + + for image in xml_root: + name, _, _, resolution, rectangles = image + + # File existence check + if not os.path.exists(os.path.join(tmp_root, name.text)): + raise FileNotFoundError(f"unable to locate {os.path.join(tmp_root, name.text)}") + + if use_polygons: + # (x, y) coordinates of top left, top right, bottom right, bottom left corners + _boxes = [ + [ + [float(rect.attrib['x']), float(rect.attrib['y'])], + [float(rect.attrib['x']) + float(rect.attrib['width']), float(rect.attrib['y'])], + [ + float(rect.attrib['x']) + float(rect.attrib['width']), + float(rect.attrib['y']) + float(rect.attrib['height']) + ], + [float(rect.attrib['x']), float(rect.attrib['y']) + float(rect.attrib['height'])], + ] + for rect in rectangles + ] + else: + # x_min, y_min, x_max, y_max + _boxes = [ + [float(rect.attrib['x']), float(rect.attrib['y']), # type: ignore[list-item] + float(rect.attrib['x']) + float(rect.attrib['width']), # type: ignore[list-item] + float(rect.attrib['y']) + float(rect.attrib['height'])] # type: ignore[list-item] + for rect in rectangles + ] + # Convert them to relative + w, h = int(resolution.attrib['x']), int(resolution.attrib['y']) + boxes = np.asarray(_boxes, dtype=np_dtype) + if use_polygons: + boxes[:, :, 0] /= w + boxes[:, :, 1] /= h + else: + boxes[:, [0, 2]] /= w + boxes[:, [1, 3]] /= h + + # Get the labels + labels = [lab.text for rect in rectangles for lab in rect] + + self.data.append((name.text, dict(boxes=boxes, labels=labels))) + + self.root = tmp_root + + def extra_repr(self) -> str: + return f"train={self.train}" diff --git a/doctr/datasets/synthtext.py b/doctr/datasets/synthtext.py new file mode 100644 index 0000000000..f83f75a743 --- /dev/null +++ b/doctr/datasets/synthtext.py @@ -0,0 +1,88 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + +import os +from typing import Any, Dict, List, Tuple + +import numpy as np +from scipy import io as sio +from tqdm import tqdm + +from .datasets import VisionDataset +from .utils import convert_target_to_relative + +__all__ = ['SynthText'] + + +class SynthText(VisionDataset): + """SynthText dataset from `"Synthetic Data for Text Localisation in Natural Images" + `_ | `"repository" `_ | + `"website" `_. 
+ + Example:: + >>> from doctr.datasets import SynthText + >>> train_set = SynthText(train=True, download=True) + >>> img, target = train_set[0] + + Args: + train: whether the subset should be the training one + use_polygons: whether polygons should be considered as rotated bounding box (instead of straight ones) + **kwargs: keyword arguments from `VisionDataset`. + """ + + URL = 'https://thor.robots.ox.ac.uk/~vgg/data/scenetext/SynthText.zip' + SHA256 = '28ab030485ec8df3ed612c568dd71fb2793b9afbfa3a9d9c6e792aef33265bf1' + + def __init__( + self, + train: bool = True, + use_polygons: bool = False, + **kwargs: Any, + ) -> None: + + super().__init__( + self.URL, + None, + file_hash=None, + extract_archive=True, + pre_transforms=convert_target_to_relative, + **kwargs + ) + self.train = train + + # Load mat data + tmp_root = os.path.join(self.root, 'SynthText') if self.SHA256 else self.root + mat_data = sio.loadmat(os.path.join(tmp_root, 'gt.mat')) + train_samples = int(len(mat_data['imnames'][0]) * 0.9) + set_slice = slice(train_samples) if self.train else slice(train_samples, None) + paths = mat_data['imnames'][0][set_slice] + boxes = mat_data['wordBB'][0][set_slice] + labels = mat_data['txt'][0][set_slice] + del mat_data + + self.data: List[Tuple[str, Dict[str, Any]]] = [] + np_dtype = np.float32 + + for img_path, word_boxes, txt in tqdm(iterable=zip(paths, boxes, labels), + desc='Unpacking SynthText', total=len(paths)): + # File existence check + if not os.path.exists(os.path.join(tmp_root, img_path[0])): + raise FileNotFoundError(f"unable to locate {os.path.join(tmp_root, img_path[0])}") + + labels = [elt for word in txt.tolist() for elt in word.split()] + # (x, y) coordinates of top left, top right, bottom right, bottom left corners + word_boxes = word_boxes.transpose(2, 1, 0) if word_boxes.ndim == 3 else np.expand_dims( + word_boxes.transpose(1, 0), axis=0) + + if not use_polygons: + # xmin, ymin, xmax, ymax + word_boxes = np.concatenate((word_boxes.min(axis=1), word_boxes.max(axis=1)), axis=1) + + self.data.append((img_path[0], dict(boxes=np.asarray(word_boxes, dtype=np_dtype), labels=labels))) + + self.root = tmp_root + + def extra_repr(self) -> str: + return f"train={self.train}" diff --git a/doctr/datasets/utils.py b/doctr/datasets/utils.py new file mode 100644 index 0000000000..e32709c2cf --- /dev/null +++ b/doctr/datasets/utils.py @@ -0,0 +1,163 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + +import string +import unicodedata +from collections.abc import Sequence +from functools import partial +from typing import Any, Dict, List, Optional +from typing import Sequence as SequenceType +from typing import Tuple, TypeVar, Union + +import numpy as np + +from doctr.io.image import get_img_shape +from doctr.utils.geometry import convert_to_relative_coords + +from .vocabs import VOCABS + +__all__ = ['translate', 'encode_string', 'decode_sequence', 'encode_sequences'] + +ImageTensor = TypeVar('ImageTensor') + + +def translate( + input_string: str, + vocab_name: str, + unknown_char: str = '■', +) -> str: + """Translate a string input in a given vocabulary + + Args: + input_string: input string to translate + vocab_name: vocabulary to use (french, latin, ...) 
+ unknown_char: unknown character for non-translatable characters + + Returns: + A string translated in a given vocab""" + + if VOCABS.get(vocab_name) is None: + raise KeyError("output vocabulary must be in vocabs dictionnary") + + translated = '' + for char in input_string: + if char not in VOCABS[vocab_name]: + # we need to translate char into a vocab char + if char in string.whitespace: + # remove whitespaces + continue + # normalize character if it is not in vocab + char = unicodedata.normalize('NFD', char).encode('ascii', 'ignore').decode('ascii') + if char == '' or char not in VOCABS[vocab_name]: + # if normalization fails or char still not in vocab, return unknown character) + char = unknown_char + translated += char + return translated + + +def encode_string( + input_string: str, + vocab: str, +) -> List[int]: + """Given a predefined mapping, encode the string to a sequence of numbers + + Args: + input_string: string to encode + vocab: vocabulary (string), the encoding is given by the indexing of the character sequence + + Returns: + A list encoding the input_string""" + + return list(map(vocab.index, input_string)) # type: ignore[arg-type] + + +def decode_sequence( + input_seq: Union[np.array, SequenceType[int]], + mapping: str, +) -> str: + """Given a predefined mapping, decode the sequence of numbers to a string + + Args: + input_seq: array to decode + mapping: vocabulary (string), the encoding is given by the indexing of the character sequence + + Returns: + A string, decoded from input_seq + """ + + if not isinstance(input_seq, (Sequence, np.ndarray)): + raise TypeError("Invalid sequence type") + if isinstance(input_seq, np.ndarray) and (input_seq.dtype != np.int_ or input_seq.max() >= len(mapping)): + raise AssertionError("Input must be an array of int, with max less than mapping size") + + return ''.join(map(mapping.__getitem__, input_seq)) + + +def encode_sequences( + sequences: List[str], + vocab: str, + target_size: Optional[int] = None, + eos: int = -1, + sos: Optional[int] = None, + pad: Optional[int] = None, + dynamic_seq_length: bool = False, + **kwargs: Any, +) -> np.ndarray: + """Encode character sequences using a given vocab as mapping + + Args: + sequences: the list of character sequences of size N + vocab: the ordered vocab to use for encoding + target_size: maximum length of the encoded data + eos: encoding of End Of String + sos: optional encoding of Start Of String + pad: optional encoding for padding. 
In case of padding, all sequences are followed by 1 EOS then PAD + dynamic_seq_length: if `target_size` is specified, uses it as upper bound and enables dynamic sequence size + + Returns: + the padded encoded data as a tensor + """ + + if 0 <= eos < len(vocab): + raise ValueError("argument 'eos' needs to be outside of vocab possible indices") + + if not isinstance(target_size, int) or dynamic_seq_length: + # Maximum string length + EOS + max_length = max(len(w) for w in sequences) + 1 + if isinstance(sos, int): + max_length += 1 + if isinstance(pad, int): + max_length += 1 + target_size = max_length if not isinstance(target_size, int) else min(max_length, target_size) + + # Pad all sequences + if isinstance(pad, int): # pad with padding symbol + if 0 <= pad < len(vocab): + raise ValueError("argument 'pad' needs to be outside of vocab possible indices") + # In that case, add EOS at the end of the word before padding + default_symbol = pad + else: # pad with eos symbol + default_symbol = eos + encoded_data = np.full([len(sequences), target_size], default_symbol, dtype=np.int32) + + # Encode the strings + for idx, seq in enumerate(map(partial(encode_string, vocab=vocab), sequences)): + if isinstance(pad, int): # add eos at the end of the sequence + seq.append(eos) + encoded_data[idx, :min(len(seq), target_size)] = seq[:min(len(seq), target_size)] + + if isinstance(sos, int): # place sos symbol at the beginning of each sequence + if 0 <= sos < len(vocab): + raise ValueError("argument 'sos' needs to be outside of vocab possible indices") + encoded_data = np.roll(encoded_data, 1) + encoded_data[:, 0] = sos + + return encoded_data + + +def convert_target_to_relative(img: ImageTensor, target: Dict[str, Any]) -> Tuple[ImageTensor, Dict[str, Any]]: + + target['boxes'] = convert_to_relative_coords(target['boxes'], get_img_shape(img)) + return img, target diff --git a/doctr/datasets/vocabs.py b/doctr/datasets/vocabs.py new file mode 100644 index 0000000000..9d0186dacb --- /dev/null +++ b/doctr/datasets/vocabs.py @@ -0,0 +1,33 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. 
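A minimal round-trip sketch for the encoding helpers above (the sample strings are illustrative; it assumes the `french` and `latin` entries of the `VOCABS` mapping defined in `doctr/datasets/vocabs.py` just below)::

    >>> from doctr.datasets.vocabs import VOCABS
    >>> from doctr.datasets.utils import encode_string, decode_sequence, translate
    >>> vocab = VOCABS['french']
    >>> seq = encode_string('docTR', vocab)   # character -> index, via vocab.index
    >>> decode_sequence(seq, vocab)           # index -> character, inverse mapping
    'docTR'
    >>> translate('crème', 'latin')           # 'è' is NFD-normalized down to 'e'
    'creme'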
+ +import string +from typing import Dict + +__all__ = ['VOCABS'] + + +VOCABS: Dict[str, str] = { + 'digits': string.digits, + 'ascii_letters': string.ascii_letters, + 'punctuation': string.punctuation, + 'currency': '£€¥¢฿', + 'ancient_greek': 'αβγδεζηθικλμνξοπρστυφχψωΑΒΓΔΕΖΗΘΙΚΛΜΝΞΟΠΡΣΤΥΦΧΨΩ', + 'arabic_letters': 'ءآأؤإئابةتثجحخدذرزسشصضطظعغـفقكلمنهوىي', + 'persian_letters': 'پچڢڤگ', + 'hindi_digits': '٠١٢٣٤٥٦٧٨٩', + 'arabic_diacritics': 'ًٌٍَُِّْ', + 'arabic_punctuation': '؟؛«»—' +} + +VOCABS['latin'] = VOCABS['digits'] + VOCABS['ascii_letters'] + VOCABS['punctuation'] +VOCABS['english'] = VOCABS['latin'] + '°' + VOCABS['currency'] +VOCABS['legacy_french'] = VOCABS['latin'] + '°' + 'àâéèêëîïôùûçÀÂÉÈËÎÏÔÙÛÇ' + VOCABS['currency'] +VOCABS['french'] = VOCABS['english'] + 'àâéèêëîïôùûüçÀÂÉÈÊËÎÏÔÙÛÜÇ' +VOCABS['portuguese'] = VOCABS['english'] + 'áàâãéêíïóôõúüçÁÀÂÃÉÊÍÏÓÔÕÚÜÇ' +VOCABS['spanish'] = VOCABS['english'] + 'áéíóúüñÁÉÍÓÚÜÑ' + '¡¿' +VOCABS['german'] = VOCABS['english'] + 'äöüßÄÖÜẞ' +VOCABS['arabic'] = (VOCABS['digits'] + VOCABS['hindi_digits'] + VOCABS['arabic_letters'] + VOCABS['persian_letters'] + + VOCABS['arabic_diacritics'] + VOCABS['arabic_punctuation'] + VOCABS['punctuation']) diff --git a/doctr/file_utils.py b/doctr/file_utils.py new file mode 100644 index 0000000000..98213c7391 --- /dev/null +++ b/doctr/file_utils.py @@ -0,0 +1,85 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + +# Adapted from https://github.com/huggingface/transformers/blob/master/src/transformers/file_utils.py + +import importlib.util +import logging +import os +import sys + +if sys.version_info < (3, 8): + import importlib_metadata +else: + import importlib.metadata as importlib_metadata + + +__all__ = ['is_tf_available', 'is_torch_available'] + +ENV_VARS_TRUE_VALUES = {"1", "ON", "YES", "TRUE"} +ENV_VARS_TRUE_AND_AUTO_VALUES = ENV_VARS_TRUE_VALUES.union({"AUTO"}) + +USE_TF = os.environ.get("USE_TF", "AUTO").upper() +USE_TORCH = os.environ.get("USE_TORCH", "AUTO").upper() + + +if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VALUES: + _torch_available = importlib.util.find_spec("torch") is not None + if _torch_available: + try: + _torch_version = importlib_metadata.version("torch") + logging.info(f"PyTorch version {_torch_version} available.") + except importlib_metadata.PackageNotFoundError: + _torch_available = False +else: + logging.info("Disabling PyTorch because USE_TF is set") + _torch_available = False + + +if USE_TF in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TORCH not in ENV_VARS_TRUE_VALUES: + _tf_available = importlib.util.find_spec("tensorflow") is not None + if _tf_available: + candidates = ( + "tensorflow", + "tensorflow-cpu", + "tensorflow-gpu", + "tf-nightly", + "tf-nightly-cpu", + "tf-nightly-gpu", + "intel-tensorflow", + "tensorflow-rocm", + "tensorflow-macos", + ) + _tf_version = None + # For the metadata, we have to look for both tensorflow and tensorflow-cpu + for pkg in candidates: + try: + _tf_version = importlib_metadata.version(pkg) + break + except importlib_metadata.PackageNotFoundError: + pass + _tf_available = _tf_version is not None + if _tf_available: + if int(_tf_version.split('.')[0]) < 2: # type: ignore[union-attr] + logging.info(f"TensorFlow found but with version {_tf_version}. 
DocTR requires version 2 minimum.") + _tf_available = False + else: + logging.info(f"TensorFlow version {_tf_version} available.") +else: + logging.info("Disabling Tensorflow because USE_TORCH is set") + _tf_available = False + + +if not _torch_available and not _tf_available: + raise ModuleNotFoundError("DocTR requires either TensorFlow or PyTorch to be installed. Please ensure one of them" + " is installed and that either USE_TF or USE_TORCH is enabled.") + + +def is_torch_available(): + return _torch_available + + +def is_tf_available(): + return _tf_available diff --git a/doctr/io/__init__.py b/doctr/io/__init__.py new file mode 100644 index 0000000000..6eab8c2406 --- /dev/null +++ b/doctr/io/__init__.py @@ -0,0 +1,5 @@ +from .elements import * +from .html import * +from .image import * +from .pdf import * +from .reader import * diff --git a/doctr/io/elements.py b/doctr/io/elements.py new file mode 100644 index 0000000000..c93d42ffce --- /dev/null +++ b/doctr/io/elements.py @@ -0,0 +1,405 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + +from typing import Any, Dict, List, Optional, Tuple, Union + +from defusedxml import defuse_stdlib + +defuse_stdlib() +from xml.etree import ElementTree as ET +from xml.etree.ElementTree import Element as ETElement +from xml.etree.ElementTree import SubElement + +import matplotlib.pyplot as plt +import numpy as np + +import doctr +from doctr.utils.common_types import BoundingBox +from doctr.utils.geometry import resolve_enclosing_bbox, resolve_enclosing_rbbox +from doctr.utils.repr import NestedObject +from doctr.utils.visualization import synthesize_page, visualize_page + +__all__ = ['Element', 'Word', 'Artefact', 'Line', 'Block', 'Page', 'Document'] + + +class Element(NestedObject): + """Implements an abstract document element with exporting and text rendering capabilities""" + + _children_names: List[str] = [] + _exported_keys: List[str] = [] + + def __init__(self, **kwargs: Any) -> None: + for k, v in kwargs.items(): + if k in self._children_names: + setattr(self, k, v) + else: + raise KeyError(f"{self.__class__.__name__} object does not have any attribute named '{k}'") + + def export(self) -> Dict[str, Any]: + """Exports the object into a nested dict format""" + + export_dict = {k: getattr(self, k) for k in self._exported_keys} + for children_name in self._children_names: + export_dict[children_name] = [c.export() for c in getattr(self, children_name)] + + return export_dict + + @classmethod + def from_dict(cls, save_dict: Dict[str, Any], **kwargs): + raise NotImplementedError + + def render(self) -> str: + raise NotImplementedError + + +class Word(Element): + """Implements a word element + + Args: + value: the text string of the word + confidence: the confidence associated with the text prediction + geometry: bounding box of the word in format ((xmin, ymin), (xmax, ymax)) where coordinates are relative to + the page's size + """ + + _exported_keys: List[str] = ["value", "confidence", "geometry"] + _children_names: List[str] = [] + + def __init__(self, value: str, confidence: float, geometry: Union[BoundingBox, np.ndarray]) -> None: + super().__init__() + self.value = value + self.confidence = confidence + self.geometry = geometry + + def render(self) -> str: + """Renders the full text of the element""" + return self.value + + def extra_repr(self) -> str: + return f"value='{self.value}', confidence={self.confidence:.2}" + + @classmethod + 
def from_dict(cls, save_dict: Dict[str, Any], **kwargs): + kwargs = {k: save_dict[k] for k in cls._exported_keys} + return cls(**kwargs) + + +class Artefact(Element): + """Implements a non-textual element + + Args: + artefact_type: the type of artefact + confidence: the confidence of the type prediction + geometry: bounding box of the word in format ((xmin, ymin), (xmax, ymax)) where coordinates are relative to + the page's size. + """ + + _exported_keys: List[str] = ["geometry", "type", "confidence"] + _children_names: List[str] = [] + + def __init__(self, artefact_type: str, confidence: float, geometry: BoundingBox) -> None: + super().__init__() + self.geometry = geometry + self.type = artefact_type + self.confidence = confidence + + def render(self) -> str: + """Renders the full text of the element""" + return f"[{self.type.upper()}]" + + def extra_repr(self) -> str: + return f"type='{self.type}', confidence={self.confidence:.2}" + + @classmethod + def from_dict(cls, save_dict: Dict[str, Any], **kwargs): + kwargs = {k: save_dict[k] for k in cls._exported_keys} + return cls(**kwargs) + + +class Line(Element): + """Implements a line element as a collection of words + + Args: + words: list of word elements + geometry: bounding box of the word in format ((xmin, ymin), (xmax, ymax)) where coordinates are relative to + the page's size. If not specified, it will be resolved by default to the smallest bounding box enclosing + all words in it. + """ + + _exported_keys: List[str] = ["geometry"] + _children_names: List[str] = ['words'] + words: List[Word] = [] + + def __init__( + self, + words: List[Word], + geometry: Optional[Union[BoundingBox, np.ndarray]] = None, + ) -> None: + # Resolve the geometry using the smallest enclosing bounding box + if geometry is None: + # Check whether this is a rotated or straight box + box_resolution_fn = resolve_enclosing_rbbox if len(words[0].geometry) == 4 else resolve_enclosing_bbox + geometry = box_resolution_fn([w.geometry for w in words]) # type: ignore[operator, misc] + + super().__init__(words=words) + self.geometry = geometry + + def render(self) -> str: + """Renders the full text of the element""" + return " ".join(w.render() for w in self.words) + + @classmethod + def from_dict(cls, save_dict: Dict[str, Any], **kwargs): + kwargs = {k: save_dict[k] for k in cls._exported_keys} + kwargs.update({ + 'words': [Word.from_dict(_dict) for _dict in save_dict['words']], + }) + return cls(**kwargs) + + +class Block(Element): + """Implements a block element as a collection of lines and artefacts + + Args: + lines: list of line elements + artefacts: list of artefacts + geometry: bounding box of the word in format ((xmin, ymin), (xmax, ymax)) where coordinates are relative to + the page's size. If not specified, it will be resolved by default to the smallest bounding box enclosing + all lines and artefacts in it. 
+ """ + + _exported_keys: List[str] = ["geometry"] + _children_names: List[str] = ['lines', 'artefacts'] + lines: List[Line] = [] + artefacts: List[Artefact] = [] + + def __init__( + self, + lines: List[Line] = [], + artefacts: List[Artefact] = [], + geometry: Optional[Union[BoundingBox, np.ndarray]] = None, + ) -> None: + # Resolve the geometry using the smallest enclosing bounding box + if geometry is None: + line_boxes = [word.geometry for line in lines for word in line.words] + artefact_boxes = [artefact.geometry for artefact in artefacts] + box_resolution_fn = resolve_enclosing_rbbox if isinstance( + lines[0].geometry, np.ndarray + ) else resolve_enclosing_bbox + geometry = box_resolution_fn(line_boxes + artefact_boxes) # type: ignore[operator, arg-type] + + super().__init__(lines=lines, artefacts=artefacts) + self.geometry = geometry + + def render(self, line_break: str = '\n') -> str: + """Renders the full text of the element""" + return line_break.join(line.render() for line in self.lines) + + @classmethod + def from_dict(cls, save_dict: Dict[str, Any], **kwargs): + kwargs = {k: save_dict[k] for k in cls._exported_keys} + kwargs.update({ + 'lines': [Line.from_dict(_dict) for _dict in save_dict['lines']], + 'artefacts': [Artefact.from_dict(_dict) for _dict in save_dict['artefacts']], + }) + return cls(**kwargs) + + +class Page(Element): + """Implements a page element as a collection of blocks + + Args: + blocks: list of block elements + page_idx: the index of the page in the input raw document + dimensions: the page size in pixels in format (height, width) + orientation: a dictionary with the value of the rotation angle in degress and confidence of the prediction + language: a dictionary with the language value and confidence of the prediction + """ + + _exported_keys: List[str] = ["page_idx", "dimensions", "orientation", "language"] + _children_names: List[str] = ['blocks'] + blocks: List[Block] = [] + + def __init__( + self, + blocks: List[Block], + page_idx: int, + dimensions: Tuple[int, int], + orientation: Optional[Dict[str, Any]] = None, + language: Optional[Dict[str, Any]] = None, + ) -> None: + super().__init__(blocks=blocks) + self.page_idx = page_idx + self.dimensions = dimensions + self.orientation = orientation if isinstance(orientation, dict) else dict(value=None, confidence=None) + self.language = language if isinstance(language, dict) else dict(value=None, confidence=None) + + def render(self, block_break: str = '\n\n') -> str: + """Renders the full text of the element""" + return block_break.join(b.render() for b in self.blocks) + + def extra_repr(self) -> str: + return f"dimensions={self.dimensions}" + + def show( + self, page: np.ndarray, interactive: bool = True, preserve_aspect_ratio: bool = False, **kwargs + ) -> None: + """Overlay the result on a given image + + Args: + page: image encoded as a numpy array in uint8 + interactive: whether the display should be interactive + preserve_aspect_ratio: pass True if you passed True to the predictor + """ + visualize_page(self.export(), page, interactive=interactive, preserve_aspect_ratio=preserve_aspect_ratio) + plt.show(**kwargs) + + def synthesize(self, **kwargs) -> np.ndarray: + """Synthesize the page from the predictions + + Returns: + synthesized page + """ + + return synthesize_page(self.export(), **kwargs) + + def export_as_xml(self, file_title: str = 'docTR - XML export (hOCR)') -> Tuple[bytes, ET.ElementTree]: + """Export the page as XML (hOCR-format) + convention: 
https://github.com/kba/hocr-spec/blob/master/1.2/spec.md + + Args: + file_title: the title of the XML file + + Returns: + a tuple of the XML byte string, and its ElementTree + """ + p_idx = self.page_idx + block_count: int = 1 + line_count: int = 1 + word_count: int = 1 + height, width = self.dimensions + language = self.language if 'language' in self.language.keys() else 'en' + # Create the XML root element + page_hocr = ETElement('html', attrib={'xmlns': 'http://www.w3.org/1999/xhtml', 'xml:lang': str(language)}) + # Create the header / SubElements of the root element + head = SubElement(page_hocr, 'head') + SubElement(head, 'title').text = file_title + SubElement(head, 'meta', attrib={'http-equiv': 'Content-Type', 'content': 'text/html; charset=utf-8'}) + SubElement(head, 'meta', attrib={'name': 'ocr-system', 'content': f"python-doctr {doctr.__version__}"}) + SubElement(head, 'meta', attrib={'name': 'ocr-capabilities', + 'content': 'ocr_page ocr_carea ocr_par ocr_line ocrx_word'}) + # Create the body + body = SubElement(page_hocr, 'body') + SubElement(body, 'div', attrib={ + 'class': 'ocr_page', + 'id': f'page_{p_idx + 1}', + 'title': f'image; bbox 0 0 {width} {height}; ppageno 0' + }) + # iterate over the blocks / lines / words and create the XML elements in body line by line with the attributes + for block in self.blocks: + if len(block.geometry) != 2: + raise TypeError("XML export is only available for straight bounding boxes for now.") + (xmin, ymin), (xmax, ymax) = block.geometry # type: ignore[misc] + block_div = SubElement(body, 'div', attrib={ + 'class': 'ocr_carea', + 'id': f'block_{block_count}', + 'title': f'bbox {int(round(xmin * width))} {int(round(ymin * height))} \ + {int(round(xmax * width))} {int(round(ymax * height))}' + }) + paragraph = SubElement(block_div, 'p', attrib={ + 'class': 'ocr_par', + 'id': f'par_{block_count}', + 'title': f'bbox {int(round(xmin * width))} {int(round(ymin * height))} \ + {int(round(xmax * width))} {int(round(ymax * height))}' + }) + block_count += 1 + for line in block.lines: + (xmin, ymin), (xmax, ymax) = line.geometry # type: ignore[misc] + # NOTE: baseline, x_size, x_descenders, x_ascenders is currently initalized to 0 + line_span = SubElement(paragraph, 'span', attrib={ + 'class': 'ocr_line', + 'id': f'line_{line_count}', + 'title': f'bbox {int(round(xmin * width))} {int(round(ymin * height))} \ + {int(round(xmax * width))} {int(round(ymax * height))}; \ + baseline 0 0; x_size 0; x_descenders 0; x_ascenders 0' + }) + line_count += 1 + for word in line.words: + (xmin, ymin), (xmax, ymax) = word.geometry # type: ignore[misc] + conf = word.confidence + word_div = SubElement(line_span, 'span', attrib={ + 'class': 'ocrx_word', + 'id': f'word_{word_count}', + 'title': f'bbox {int(round(xmin * width))} {int(round(ymin * height))} \ + {int(round(xmax * width))} {int(round(ymax * height))}; \ + x_wconf {int(round(conf * 100))}' + }) + # set the text + word_div.text = word.value + word_count += 1 + + return (ET.tostring(page_hocr, encoding='utf-8', method='xml'), ET.ElementTree(page_hocr)) + + @classmethod + def from_dict(cls, save_dict: Dict[str, Any], **kwargs): + kwargs = {k: save_dict[k] for k in cls._exported_keys} + kwargs.update({'blocks': [Block.from_dict(block_dict) for block_dict in save_dict['blocks']]}) + return cls(**kwargs) + + +class Document(Element): + """Implements a document element as a collection of pages + + Args: + pages: list of page elements + """ + + _children_names: List[str] = ['pages'] + pages: List[Page] = [] + + 
def __init__( + self, + pages: List[Page], + ) -> None: + super().__init__(pages=pages) + + def render(self, page_break: str = '\n\n\n\n') -> str: + """Renders the full text of the element""" + return page_break.join(p.render() for p in self.pages) + + def show(self, pages: List[np.ndarray], **kwargs) -> None: + """Overlay the result on a given image + + Args: + pages: list of images encoded as numpy arrays in uint8 + """ + for img, result in zip(pages, self.pages): + result.show(img, **kwargs) + + def synthesize(self, **kwargs) -> List[np.ndarray]: + """Synthesize all pages from their predictions + + Returns: + list of synthesized pages + """ + + return [page.synthesize() for page in self.pages] + + def export_as_xml(self, **kwargs) -> List[Tuple[bytes, ET.ElementTree]]: + """Export the document as XML (hOCR-format) + + Args: + **kwargs: additional keyword arguments passed to the Page.export_as_xml method + + Returns: + list of tuple of (bytes, ElementTree) + """ + return [page.export_as_xml(**kwargs) for page in self.pages] + + @classmethod + def from_dict(cls, save_dict: Dict[str, Any], **kwargs): + kwargs = {k: save_dict[k] for k in cls._exported_keys} + kwargs.update({'pages': [Page.from_dict(page_dict) for page_dict in save_dict['pages']]}) + return cls(**kwargs) diff --git a/doctr/io/html.py b/doctr/io/html.py new file mode 100644 index 0000000000..0ae81888e9 --- /dev/null +++ b/doctr/io/html.py @@ -0,0 +1,26 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + +from typing import Any + +from weasyprint import HTML + +__all__ = ['read_html'] + + +def read_html(url: str, **kwargs: Any) -> bytes: + """Read a PDF file and convert it into an image in numpy format + + Example:: + >>> from doctr.documents import read_html + >>> doc = read_html("https://www.yoursite.com") + + Args: + url: URL of the target web page + Returns: + decoded PDF file as a bytes stream + """ + + return HTML(url, **kwargs).write_pdf() diff --git a/doctr/io/image/__init__.py b/doctr/io/image/__init__.py new file mode 100644 index 0000000000..1950176a6d --- /dev/null +++ b/doctr/io/image/__init__.py @@ -0,0 +1,8 @@ +from doctr.file_utils import is_tf_available, is_torch_available + +from .base import * + +if is_tf_available(): + from .tensorflow import * +elif is_torch_available(): + from .pytorch import * # type: ignore[misc] diff --git a/doctr/io/image/base.py b/doctr/io/image/base.py new file mode 100644 index 0000000000..14a8856f73 --- /dev/null +++ b/doctr/io/image/base.py @@ -0,0 +1,55 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + +from pathlib import Path +from typing import Optional, Tuple + +import cv2 +import numpy as np + +from doctr.utils.common_types import AbstractFile + +__all__ = ['read_img_as_numpy'] + + +def read_img_as_numpy( + file: AbstractFile, + output_size: Optional[Tuple[int, int]] = None, + rgb_output: bool = True, +) -> np.ndarray: + """Read an image file into numpy format + + Example:: + >>> from doctr.documents import read_img + >>> page = read_img("path/to/your/doc.jpg") + + Args: + file: the path to the image file + output_size: the expected output size of each page in format H x W + rgb_output: whether the output ndarray channel order should be RGB instead of BGR. 
+ Returns: + the page decoded as numpy ndarray of shape H x W x 3 + """ + + if isinstance(file, (str, Path)): + if not Path(file).is_file(): + raise FileNotFoundError(f"unable to access {file}") + img = cv2.imread(str(file), cv2.IMREAD_COLOR) + elif isinstance(file, bytes): + file = np.frombuffer(file, np.uint8) + img = cv2.imdecode(file, cv2.IMREAD_COLOR) + else: + raise TypeError("unsupported object type for argument 'file'") + + # Validity check + if img is None: + raise ValueError("unable to read file.") + # Resizing + if isinstance(output_size, tuple): + img = cv2.resize(img, output_size[::-1], interpolation=cv2.INTER_LINEAR) + # Switch the channel order + if rgb_output: + img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) + return img diff --git a/doctr/io/image/pytorch.py b/doctr/io/image/pytorch.py new file mode 100644 index 0000000000..483d08aac8 --- /dev/null +++ b/doctr/io/image/pytorch.py @@ -0,0 +1,104 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + +from io import BytesIO +from typing import Tuple + +import numpy as np +import torch +from PIL import Image +from torchvision.transforms.functional import to_tensor + +from doctr.utils.common_types import AbstractPath + +__all__ = ['tensor_from_pil', 'read_img_as_tensor', 'decode_img_as_tensor', 'tensor_from_numpy', 'get_img_shape'] + + +def tensor_from_pil(pil_img: Image, dtype: torch.dtype = torch.float32) -> torch.Tensor: + """Convert a PIL Image to a PyTorch tensor + + Args: + pil_img: a PIL image + dtype: the output tensor data type + + Returns: + decoded image as tensor + """ + + if dtype == torch.float32: + img = to_tensor(pil_img) + else: + img = tensor_from_numpy(np.array(pil_img, np.uint8, copy=True), dtype) + + return img + + +def read_img_as_tensor(img_path: AbstractPath, dtype: torch.dtype = torch.float32) -> torch.Tensor: + """Read an image file as a PyTorch tensor + + Args: + img_path: location of the image file + dtype: the desired data type of the output tensor. If it is float-related, values will be divided by 255. + + Returns: + decoded image as a tensor + """ + + if dtype not in (torch.uint8, torch.float16, torch.float32): + raise ValueError("insupported value for dtype") + + pil_img = Image.open(img_path, mode='r').convert('RGB') + + return tensor_from_pil(pil_img, dtype) + + +def decode_img_as_tensor(img_content: bytes, dtype: torch.dtype = torch.float32) -> torch.Tensor: + """Read a byte stream as a PyTorch tensor + + Args: + img_content: bytes of a decoded image + dtype: the desired data type of the output tensor. If it is float-related, values will be divided by 255. + + Returns: + decoded image as a tensor + """ + + if dtype not in (torch.uint8, torch.float16, torch.float32): + raise ValueError("insupported value for dtype") + + pil_img = Image.open(BytesIO(img_content), mode='r').convert('RGB') + + return tensor_from_pil(pil_img, dtype) + + +def tensor_from_numpy(npy_img: np.ndarray, dtype: torch.dtype = torch.float32) -> torch.Tensor: + """Read an image file as a PyTorch tensor + + Args: + img: image encoded as a numpy array of shape (H, W, C) in np.uint8 + dtype: the desired data type of the output tensor. If it is float-related, values will be divided by 255. 
+ + Returns: + same image as a tensor of shape (C, H, W) + """ + + if dtype not in (torch.uint8, torch.float16, torch.float32): + raise ValueError("insupported value for dtype") + + if dtype == torch.float32: + img = to_tensor(npy_img) + else: + img = torch.from_numpy(npy_img) + # put it from HWC to CHW format + img = img.permute((2, 0, 1)).contiguous() + if dtype == torch.float16: + # Switch to FP16 + img = img.to(dtype=torch.float16).div(255) + + return img + + +def get_img_shape(img: torch.Tensor) -> Tuple[int, int]: + return img.shape[-2:] # type: ignore[return-value] diff --git a/doctr/io/image/tensorflow.py b/doctr/io/image/tensorflow.py new file mode 100644 index 0000000000..d8352bfe74 --- /dev/null +++ b/doctr/io/image/tensorflow.py @@ -0,0 +1,109 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + +from typing import Tuple + +import numpy as np +import tensorflow as tf +from PIL import Image + +if tf.__version__ >= '2.6.0': + from tensorflow.keras.utils import img_to_array +else: + from tensorflow.keras.preprocessing.image import img_to_array + +from doctr.utils.common_types import AbstractPath + +__all__ = ['tensor_from_pil', 'read_img_as_tensor', 'decode_img_as_tensor', 'tensor_from_numpy', 'get_img_shape'] + + +def tensor_from_pil(pil_img: Image, dtype: tf.dtypes.DType = tf.float32) -> tf.Tensor: + """Convert a PIL Image to a TensorFlow tensor + + Args: + pil_img: a PIL image + dtype: the output tensor data type + + Returns: + decoded image as tensor + """ + + npy_img = img_to_array(pil_img) + + return tensor_from_numpy(npy_img, dtype) + + +def read_img_as_tensor(img_path: AbstractPath, dtype: tf.dtypes.DType = tf.float32) -> tf.Tensor: + """Read an image file as a TensorFlow tensor + + Args: + img_path: location of the image file + dtype: the desired data type of the output tensor. If it is float-related, values will be divided by 255. + + Returns: + decoded image as a tensor + """ + + if dtype not in (tf.uint8, tf.float16, tf.float32): + raise ValueError("insupported value for dtype") + + img = tf.io.read_file(img_path) + img = tf.image.decode_jpeg(img, channels=3) + + if dtype != tf.uint8: + img = tf.image.convert_image_dtype(img, dtype=dtype) + img = tf.clip_by_value(img, 0, 1) + + return img + + +def decode_img_as_tensor(img_content: bytes, dtype: tf.dtypes.DType = tf.float32) -> tf.Tensor: + """Read a byte stream as a TensorFlow tensor + + Args: + img_content: bytes of a decoded image + dtype: the desired data type of the output tensor. If it is float-related, values will be divided by 255. + + Returns: + decoded image as a tensor + """ + + if dtype not in (tf.uint8, tf.float16, tf.float32): + raise ValueError("insupported value for dtype") + + img = tf.io.decode_image(img_content, channels=3) + + if dtype != tf.uint8: + img = tf.image.convert_image_dtype(img, dtype=dtype) + img = tf.clip_by_value(img, 0, 1) + + return img + + +def tensor_from_numpy(npy_img: np.ndarray, dtype: tf.dtypes.DType = tf.float32) -> tf.Tensor: + """Read an image file as a TensorFlow tensor + + Args: + img: image encoded as a numpy array of shape (H, W, C) in np.uint8 + dtype: the desired data type of the output tensor. If it is float-related, values will be divided by 255. 
+ + Returns: + same image as a tensor of shape (H, W, C) + """ + + if dtype not in (tf.uint8, tf.float16, tf.float32): + raise ValueError("insupported value for dtype") + + if dtype == tf.uint8: + img = tf.convert_to_tensor(npy_img, dtype=dtype) + else: + img = tf.image.convert_image_dtype(npy_img, dtype=dtype) + img = tf.clip_by_value(img, 0, 1) + + return img + + +def get_img_shape(img: tf.Tensor) -> Tuple[int, int]: + return img.shape[:2] diff --git a/doctr/io/pdf.py b/doctr/io/pdf.py new file mode 100644 index 0000000000..0b1f221a67 --- /dev/null +++ b/doctr/io/pdf.py @@ -0,0 +1,184 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + +import cv2 +import fitz +import numpy as np + +from doctr.utils.common_types import AbstractFile, Bbox + +__all__ = ['read_pdf', 'PDF'] + + +def read_pdf(file: AbstractFile, **kwargs: Any) -> fitz.Document: + """Read a PDF file and convert it into an image in numpy format + + Example:: + >>> from doctr.documents import read_pdf + >>> doc = read_pdf("path/to/your/doc.pdf") + + Args: + file: the path to the PDF file + Returns: + the list of pages decoded as numpy ndarray of shape H x W x 3 + """ + + if isinstance(file, (str, Path)) and not Path(file).is_file(): + raise FileNotFoundError(f"unable to access {file}") + + fitz_args: Dict[str, AbstractFile] = {} + + if isinstance(file, (str, Path)): + fitz_args['filename'] = file + elif isinstance(file, bytes): + fitz_args['stream'] = file + else: + raise TypeError("unsupported object type for argument 'file'") + + # Read pages with fitz and convert them to numpy ndarrays + return fitz.open(**fitz_args, filetype="pdf", **kwargs) + + +def convert_page_to_numpy( + page: fitz.fitz.Page, + output_size: Optional[Tuple[int, int]] = None, + bgr_output: bool = False, + default_scales: Tuple[float, float] = (2, 2), +) -> np.ndarray: + """Convert a fitz page to a numpy-formatted image + + Args: + page: the page of a file read with PyMuPDF + output_size: the expected output size of each page in format H x W. Default goes to 840 x 595 for A4 pdf, + if you want to increase the resolution while preserving the original A4 aspect ratio can pass (1024, 726) + rgb_output: whether the output ndarray channel order should be RGB instead of BGR. + default_scales: spatial scaling to be applied when output_size is not specified where (1, 1) + corresponds to 72 dpi rendering. 
+ + Returns: + the rendered image in numpy format + """ + + # If no output size is specified, keep the origin one + if output_size is not None: + scales = (output_size[1] / page.MediaBox[2], output_size[0] / page.MediaBox[3]) + else: + # Default 72 DPI (scales of (1, 1)) is unnecessarily low + scales = default_scales + + transform_matrix = fitz.Matrix(*scales) + + # Generate the pixel map using the transformation matrix + pixmap = page.get_pixmap(matrix=transform_matrix) + # Decode it into a numpy + img = np.frombuffer(pixmap.samples, dtype=np.uint8).reshape(pixmap.height, pixmap.width, 3) + + # Switch the channel order + if bgr_output: + img = cv2.cvtColor(img, cv2.COLOR_RGB2BGR) + + return img + + +class PDF: + """PDF document template + + Args: + doc: input PDF document + """ + def __init__(self, doc: fitz.Document) -> None: + self.doc = doc + + def as_images(self, **kwargs) -> List[np.ndarray]: + """Convert all document pages to images + + Example:: + >>> from doctr.documents import DocumentFile + >>> pages = DocumentFile.from_pdf("path/to/your/doc.pdf").as_images() + + Args: + kwargs: keyword arguments of `convert_page_to_numpy` + Returns: + the list of pages decoded as numpy ndarray of shape H x W x 3 + """ + return [convert_page_to_numpy(page, **kwargs) for page in self.doc] + + def get_page_lines(self, idx, **kwargs) -> List[Tuple[Bbox, str]]: + """Get the annotations for all lines of a given page""" + lines: List[Tuple[Bbox, str]] = [] + prev_block, prev_line = -1, -1 + current_line = [] + xmin, ymin, xmax, ymax = 0, 0, 0, 0 + # xmin, ymin, xmax, ymax, value, block_idx, line_idx, word_idx + for info in self.doc[idx].get_text_words(**kwargs): + if prev_block == info[-3] and prev_line == info[-2]: + current_line.append(info[4]) + xmin, ymin = min(xmin, info[0]), min(ymin, info[1]) + xmax, ymax = max(xmax, info[2]), max(ymax, info[3]) + else: + if len(current_line) > 0: + lines.append(((xmin, ymin, xmax, ymax), " ".join(current_line))) + current_line = [info[4]] + prev_block, prev_line = info[-3], info[-2] + xmin, ymin, xmax, ymax = info[:4] + + if len(current_line) > 0: + lines.append(((xmin, ymin, xmax, ymax), " ".join(current_line))) + + return lines + + def get_lines(self, **kwargs) -> List[List[Tuple[Bbox, str]]]: + """Get the annotations for all lines in the document + + Example:: + >>> from doctr.documents import DocumentFile + >>> lines = DocumentFile.from_pdf("path/to/your/doc.pdf").get_lines() + + Args: + kwargs: keyword arguments of `fitz.Page.get_text_words` + Returns: + the list of pages annotations, represented as a list of tuple (bounding box, value) + """ + return [self.get_page_lines(idx, **kwargs) for idx in range(len(self.doc))] + + def get_page_words(self, idx, **kwargs) -> List[Tuple[Bbox, str]]: + """Get the annotations for all words of a given page""" + + # xmin, ymin, xmax, ymax, value, block_idx, line_idx, word_idx + return [(info[:4], info[4]) for info in self.doc[idx].get_text_words(**kwargs)] + + def get_words(self, **kwargs) -> List[List[Tuple[Bbox, str]]]: + """Get the annotations for all words in the document + + Example:: + >>> from doctr.documents import DocumentFile + >>> words = DocumentFile.from_pdf("path/to/your/doc.pdf").get_words() + + Args: + kwargs: keyword arguments of `fitz.Page.get_text_words` + Returns: + the list of pages annotations, represented as a list of tuple (bounding box, value) + """ + return [self.get_page_words(idx, **kwargs) for idx in range(len(self.doc))] + + def get_page_artefacts(self, idx) -> List[Tuple[float, float, 
float, float]]: + return [tuple(self.doc[idx].get_image_bbox(artefact)) # type: ignore[misc] + for artefact in self.doc[idx].get_images(full=True)] + + def get_artefacts(self) -> List[List[Tuple[float, float, float, float]]]: + """Get the artefacts for the entire document + + Example:: + >>> from doctr.documents import DocumentFile + >>> artefacts = DocumentFile.from_pdf("path/to/your/doc.pdf").get_artefacts() + + Returns: + the list of pages artefacts, represented as a list of bounding boxes + """ + + return [self.get_page_artefacts(idx) for idx in range(len(self.doc))] diff --git a/doctr/io/reader.py b/doctr/io/reader.py new file mode 100644 index 0000000000..6d4b55e084 --- /dev/null +++ b/doctr/io/reader.py @@ -0,0 +1,73 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + +from pathlib import Path +from typing import List, Sequence, Union + +import numpy as np + +from doctr.utils.common_types import AbstractFile + +from .html import read_html +from .image import read_img_as_numpy +from .pdf import PDF, read_pdf + +__all__ = ['DocumentFile'] + + +class DocumentFile: + """Read a document from multiple extensions""" + + @classmethod + def from_pdf(cls, file: AbstractFile, **kwargs) -> PDF: + """Read a PDF file + + Example:: + >>> from doctr.documents import DocumentFile + >>> doc = DocumentFile.from_pdf("path/to/your/doc.pdf") + + Args: + file: the path to the PDF file or a binary stream + Returns: + a PDF document + """ + + doc = read_pdf(file, **kwargs) + + return PDF(doc) + + @classmethod + def from_url(cls, url: str, **kwargs) -> PDF: + """Interpret a web page as a PDF document + + Example:: + >>> from doctr.documents import DocumentFile + >>> doc = DocumentFile.from_url("https://www.yoursite.com") + + Args: + url: the URL of the target web page + Returns: + a PDF document + """ + pdf_stream = read_html(url) + return cls.from_pdf(pdf_stream, **kwargs) + + @classmethod + def from_images(cls, files: Union[Sequence[AbstractFile], AbstractFile], **kwargs) -> List[np.ndarray]: + """Read an image file (or a collection of image files) and convert it into an image in numpy format + + Example:: + >>> from doctr.documents import DocumentFile + >>> pages = DocumentFile.from_images(["path/to/your/page1.png", "path/to/your/page2.png"]) + + Args: + files: the path to the image file or a binary stream, or a collection of those + Returns: + the list of pages decoded as numpy ndarray of shape H x W x 3 + """ + if isinstance(files, (str, Path, bytes)): + files = [files] + + return [read_img_as_numpy(file, **kwargs) for file in files] diff --git a/doctr/models/__init__.py b/doctr/models/__init__.py new file mode 100644 index 0000000000..520b296236 --- /dev/null +++ b/doctr/models/__init__.py @@ -0,0 +1,5 @@ +from . import artefacts +from .classification import * +from .detection import * +from .recognition import * +from .zoo import * diff --git a/doctr/models/_utils.py b/doctr/models/_utils.py new file mode 100644 index 0000000000..1aabac72d0 --- /dev/null +++ b/doctr/models/_utils.py @@ -0,0 +1,222 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. 
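A short usage sketch for the `DocumentFile` entry point above, as exported from `doctr.io` in this patch (file paths are placeholders)::

    >>> from doctr.io import DocumentFile
    >>> pdf_doc = DocumentFile.from_pdf("path/to/your/doc.pdf")      # returns a PDF wrapper
    >>> pages = pdf_doc.as_images()                                  # list of H x W x 3 uint8 arrays
    >>> imgs = DocumentFile.from_images(["path/to/page1.png", "path/to/page2.png"])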
+ +from math import floor +from statistics import median_low +from typing import List + +import cv2 +import numpy as np + +__all__ = ['estimate_orientation', 'extract_crops', 'extract_rcrops', 'get_bitmap_angle'] + + +def extract_crops(img: np.ndarray, boxes: np.ndarray, channels_last: bool = True) -> List[np.ndarray]: + """Created cropped images from list of bounding boxes + + Args: + img: input image + boxes: bounding boxes of shape (N, 4) where N is the number of boxes, and the relative + coordinates (xmin, ymin, xmax, ymax) + channels_last: whether the channel dimensions is the last one instead of the last one + + Returns: + list of cropped images + """ + if boxes.shape[0] == 0: + return [] + if boxes.shape[1] != 4: + raise AssertionError("boxes are expected to be relative and in order (xmin, ymin, xmax, ymax)") + + # Project relative coordinates + _boxes = boxes.copy() + h, w = img.shape[:2] if channels_last else img.shape[-2:] + if _boxes.dtype != int: + _boxes[:, [0, 2]] *= w + _boxes[:, [1, 3]] *= h + _boxes = _boxes.round().astype(int) + # Add last index + _boxes[2:] += 1 + if channels_last: + return [img[box[1]: box[3], box[0]: box[2]] for box in _boxes] + else: + return [img[:, box[1]: box[3], box[0]: box[2]] for box in _boxes] + + +def extract_rcrops( + img: np.ndarray, + polys: np.ndarray, + dtype=np.float32, + channels_last: bool = True +) -> List[np.ndarray]: + """Created cropped images from list of rotated bounding boxes + + Args: + img: input image + polys: bounding boxes of shape (N, 4, 2) + dtype: target data type of bounding boxes + channels_last: whether the channel dimensions is the last one instead of the last one + + Returns: + list of cropped images + """ + if polys.shape[0] == 0: + return [] + if polys.shape[1:] != (4, 2): + raise AssertionError("polys are expected to be quadrilateral, of shape (N, 4, 2)") + + # Project relative coordinates + _boxes = polys.copy() + height, width = img.shape[:2] if channels_last else img.shape[-2:] + if _boxes.dtype != np.int: + _boxes[:, :, 0] *= width + _boxes[:, :, 1] *= height + + src_pts = _boxes[:, 1:].astype(np.float32) + # Preserve size + d1 = np.linalg.norm(src_pts[:, 0] - src_pts[:, 1], axis=-1) + d2 = np.linalg.norm(src_pts[:, 1] - src_pts[:, 2], axis=-1) + # (N, 3, 2) + dst_pts = np.zeros((_boxes.shape[0], 3, 2), dtype=dtype) + dst_pts[:, 1, 0] = dst_pts[:, 2, 0] = d1 - 1 + dst_pts[:, 2, 1] = d2 - 1 + # Use a warp transformation to extract the crop + crops = [ + cv2.warpAffine( + img if channels_last else img.transpose(1, 2, 0), + # Transformation matrix + cv2.getAffineTransform(src_pts[idx], dst_pts[idx]), + (int(d1[idx]), int(d2[idx])), + ) + for idx in range(_boxes.shape[0]) + ] + return crops + + +def get_max_width_length_ratio(contour: np.ndarray) -> float: + """ + Get the maximum shape ratio of a contour. + Args: + contour: the contour from cv2.findContour + + Returns: the maximum shape ratio + + """ + _, (w, h), _ = cv2.minAreaRect(contour) + return max(w / h, h / w) + + +def estimate_orientation(img: np.ndarray, n_ct: int = 50, ratio_threshold_for_lines: float = 5) -> float: + """Estimate the angle of the general document orientation based on the + lines of the document and the assumption that they should be horizontal. 
+ + Args: + img: the img to analyze + n_ct: the number of contours used for the orientation estimation + ratio_threshold_for_lines: this is the ratio w/h used to discriminates lines + Returns: + the angle of the general document orientation + """ + gray_img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + gray_img = cv2.medianBlur(gray_img, 5) + thresh = cv2.threshold(gray_img, thresh=0, maxval=255, type=cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)[1] + + # try to merge words in lines + (h, w) = img.shape[:2] + k_x = max(1, (floor(w / 100))) + k_y = max(1, (floor(h / 100))) + kernel = cv2.getStructuringElement(cv2.MORPH_RECT, (k_x, k_y)) + thresh = cv2.dilate(thresh, kernel, iterations=1) + + # extract contours + contours, _ = cv2.findContours(thresh, cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE) + + # Sort contours + contours = sorted(contours, key=get_max_width_length_ratio, reverse=True) + + angles = [] + for contour in contours[:n_ct]: + _, (w, h), angle = cv2.minAreaRect(contour) + if w / h > ratio_threshold_for_lines: # select only contours with ratio like lines + angles.append(angle) + elif w / h < 1 / ratio_threshold_for_lines: # if lines are vertical, substract 90 degree + angles.append(angle - 90) + + if len(angles) == 0: + return 0 # in case no angles is found + else: + return -median_low(angles) + + +def get_bitmap_angle(bitmap: np.ndarray, n_ct: int = 20, std_max: float = 3.) -> float: + """From a binarized segmentation map, find contours and fit min area rectangles to determine page angle + + Args: + bitmap: binarized segmentation map + n_ct: number of contours to use to fit page angle + std_max: maximum deviation of the angle distribution to consider the mean angle reliable + + Returns: + The angle of the page + """ + # Find all contours on binarized seg map + contours, _ = cv2.findContours(bitmap.astype(np.uint8), cv2.RETR_LIST, cv2.CHAIN_APPROX_SIMPLE) + # Sort contours + contours = sorted(contours, key=cv2.contourArea, reverse=True) + + # Find largest contours and fit angles + # Track heights and widths to find aspect ratio (determine is rotation is clockwise) + angles, heights, widths = [], [], [] + for ct in contours[:n_ct]: + _, (w, h), alpha = cv2.minAreaRect(ct) + widths.append(w) + heights.append(h) + angles.append(alpha) + + if np.std(angles) > std_max: + # Edge case with angles of both 0 and 90°, or multi_oriented docs + angle = 0. 
+ else: + angle = -np.mean(angles) + # Determine rotation direction (clockwise/counterclockwise) + # Angle coverage: [-90°, +90°], half of the quadrant + if np.sum(widths) < np.sum(heights): # CounterClockwise + angle = 90 + angle + + return angle + + +def rectify_crops( + crops: List[np.ndarray], + orientations: List[int], +) -> List[np.ndarray]: + """Rotate each crop of the list according to the predicted orientation: + 0: already straight, no rotation + 1: 90 ccw, rotate 3 times ccw + 2: 180, rotate 2 times ccw + 3: 270 ccw, rotate 1 time ccw + """ + # Inverse predictions (if angle of +90 is detected, rotate by -90) + orientations = [4 - pred if pred != 0 else 0 for pred in orientations] + return [ + crop if orientation == 0 else np.rot90(crop, orientation) + for orientation, crop in zip(orientations, crops) + ] if len(orientations) > 0 else [] + + +def rectify_loc_preds( + page_loc_preds: np.ndarray, + orientations: List[int], +) -> np.ndarray: + """Orient the quadrangle (Polygon4P) according to the predicted orientation, + so that the points are in this order: top L, top R, bot R, bot L if the crop is readable + """ + return np.stack( + [np.roll( + page_loc_pred, + orientation, + axis=0) for orientation, page_loc_pred in zip(orientations, page_loc_preds)], + axis=0 + ) if len(orientations) > 0 else None diff --git a/doctr/models/artefacts/__init__.py b/doctr/models/artefacts/__init__.py new file mode 100644 index 0000000000..875f48a875 --- /dev/null +++ b/doctr/models/artefacts/__init__.py @@ -0,0 +1,2 @@ +from .barcode import * +from .face import * diff --git a/doctr/models/artefacts/barcode.py b/doctr/models/artefacts/barcode.py new file mode 100644 index 0000000000..f39776cb59 --- /dev/null +++ b/doctr/models/artefacts/barcode.py @@ -0,0 +1,77 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + +from typing import List, Tuple + +import cv2 +import numpy as np + +__all__ = ['BarCodeDetector'] + + +class BarCodeDetector: + + """ Implements a Bar-code detector. + For now, only horizontal (or with a small angle) bar-codes are supported + + Args: + min_size: minimum relative size of a barcode on the page + canny_minval: lower bound for canny hysteresis + canny_maxval: upper-bound for canny hysteresis + """ + def __init__( + self, + min_size: float = 1 / 6, + canny_minval: int = 50, + canny_maxval: int = 150 + ) -> None: + self.min_size = min_size + self.canny_minval = canny_minval + self.canny_maxval = canny_maxval + + def __call__( + self, + img: np.array, + ) -> List[Tuple[float, float, float, float]]: + """Detect Barcodes on the image + Args: + img: np image + + Returns: + A list of tuples: [(xmin, ymin, xmax, ymax), ...] containing barcodes rel. coordinates + """ + # get image size and define parameters + height, width = img.shape[:2] + k = (1 + int(width / 512)) * 10 # spatial extension of kernels, 512 -> 20, 1024 -> 30, ... 
+ min_w = int(width * self.min_size) # minimal size of a possible barcode + + # Detect edges + gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + edges = cv2.Canny(gray, self.canny_minval, self.canny_maxval, apertureSize=3) + + # Horizontal dilation to aggregate bars of the potential barcode + # without aggregating text lines of the page vertically + edges = cv2.dilate(edges, np.ones((1, k), np.uint8)) + + # Instantiate a barcode-shaped kernel and erode to keep only vertical-bar structures + bar_code_kernel = np.zeros((k, 3), np.uint8) + bar_code_kernel[..., [0, 2]] = 1 + edges = cv2.erode(edges, bar_code_kernel, iterations=1) + + # Opening to remove noise + edges = cv2.morphologyEx(edges, cv2.MORPH_OPEN, np.ones((k, k), np.uint8)) + + # Dilation to retrieve vertical length (lost at the first dilation) + edges = cv2.dilate(edges, np.ones((k, 1), np.uint8)) + + # Find contours, and keep the widest as barcodes + contours, _ = cv2.findContours(edges, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE) + barcodes = [] + for contour in contours: + x, y, w, h = cv2.boundingRect(contour) + if w >= min_w: + barcodes.append((x / width, y / height, (x + w) / width, (y + h) / height)) + + return barcodes diff --git a/doctr/models/artefacts/face.py b/doctr/models/artefacts/face.py new file mode 100644 index 0000000000..7b858d9b87 --- /dev/null +++ b/doctr/models/artefacts/face.py @@ -0,0 +1,59 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + +from typing import List, Tuple + +import cv2 +import numpy as np + +from doctr.utils.repr import NestedObject + +__all__ = ['FaceDetector'] + + +class FaceDetector(NestedObject): + + """ Implements a face detector to detect profile pictures on resumes, IDS, driving licenses, passports... + Based on open CV CascadeClassifier (haarcascades) + + Args: + n_faces: maximal number of faces to detect on a single image, default = 1 + """ + + def __init__( + self, + n_faces: int = 1, + ) -> None: + self.n_faces = n_faces + # Instantiate classifier + self.detector = cv2.CascadeClassifier(cv2.data.haarcascades + 'haarcascade_frontalface_default.xml') + + def extra_repr(self) -> str: + return f"n_faces={self.n_faces}" + + def __call__( + self, + img: np.array, + ) -> List[Tuple[float, float, float, float]]: + """Detect n_faces on the img + + Args: + img: image to detect faces on + + Returns: + A list of size n_faces, each face is a tuple of relative xmin, ymin, xmax, ymax + """ + height, width = img.shape[:2] + gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY) + + faces = self.detector.detectMultiScale(gray, 1.5, 3) + # If faces are detected, keep only the biggest ones + rel_faces = [] + if len(faces) > 0: + x, y, w, h = sorted(faces, key=lambda x: x[2] + x[3])[-min(self.n_faces, len(faces))] + xmin, ymin, xmax, ymax = x / width, y / height, (x + w) / width, (y + h) / height + rel_faces.append((xmin, ymin, xmax, ymax)) + + return rel_faces diff --git a/doctr/models/builder.py b/doctr/models/builder.py new file mode 100644 index 0000000000..02c534031b --- /dev/null +++ b/doctr/models/builder.py @@ -0,0 +1,313 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. 
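A minimal sketch of the artefact detectors defined above, run here on a blank page purely for illustration (a real BGR page image would be passed in practice; `FaceDetector` additionally relies on the haarcascade data shipped with `opencv-python`)::

    >>> import numpy as np
    >>> from doctr.models.artefacts import BarCodeDetector, FaceDetector
    >>> page = np.zeros((595, 842, 3), dtype=np.uint8)        # dummy page, H x W x 3
    >>> barcodes = BarCodeDetector(min_size=1 / 6)(page)      # list of relative (xmin, ymin, xmax, ymax)
    >>> faces = FaceDetector(n_faces=1)(page)                 # same format, at most n_faces entries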
+ + +from typing import Dict, List, Tuple + +import numpy as np +from scipy.cluster.hierarchy import fclusterdata + +from doctr.io.elements import Block, Document, Line, Page, Word +from doctr.utils.geometry import estimate_page_angle, resolve_enclosing_bbox, resolve_enclosing_rbbox, rotate_boxes +from doctr.utils.repr import NestedObject + +__all__ = ['DocumentBuilder'] + + +class DocumentBuilder(NestedObject): + """Implements a document builder + + Args: + resolve_lines: whether words should be automatically grouped into lines + resolve_blocks: whether lines should be automatically grouped into blocks + paragraph_break: relative length of the minimum space separating paragraphs + export_as_straight_boxes: if True, force straight boxes in the export (fit a rectangle + box to all rotated boxes). Else, keep the boxes format unchanged, no matter what it is. + """ + + def __init__( + self, + resolve_lines: bool = True, + resolve_blocks: bool = True, + paragraph_break: float = 0.035, + export_as_straight_boxes: bool = False, + ) -> None: + + self.resolve_lines = resolve_lines + self.resolve_blocks = resolve_blocks + self.paragraph_break = paragraph_break + self.export_as_straight_boxes = export_as_straight_boxes + + @staticmethod + def _sort_boxes(boxes: np.ndarray) -> np.ndarray: + """Sort bounding boxes from top to bottom, left to right + + Args: + boxes: bounding boxes of shape (N, 4) or (N, 4, 2) (in case of rotated bbox) + + Returns: + tuple: indices of ordered boxes of shape (N,), boxes + If straight boxes are passed tpo the function, boxes are unchanged + else: boxes returned are straight boxes fitted to the straightened rotated boxes + so that we fit the lines afterwards to the straigthened page + """ + if boxes.ndim == 3: + boxes = rotate_boxes( + loc_preds=boxes, + angle=-estimate_page_angle(boxes), + orig_shape=(1024, 1024), + min_angle=5., + ) + boxes = np.concatenate((boxes.min(1), boxes.max(1)), -1) + return (boxes[:, 0] + 2 * boxes[:, 3] / np.median(boxes[:, 3] - boxes[:, 1])).argsort(), boxes + + def _resolve_sub_lines(self, boxes: np.ndarray, word_idcs: List[int]) -> List[List[int]]: + """Split a line in sub_lines + + Args: + boxes: bounding boxes of shape (N, 4) + word_idcs: list of indexes for the words of the line + + Returns: + A list of (sub-)lines computed from the original line (words) + """ + lines = [] + # Sort words horizontally + word_idcs = [word_idcs[idx] for idx in boxes[word_idcs, 0].argsort().tolist()] + + # Eventually split line horizontally + if len(word_idcs) < 2: + lines.append(word_idcs) + else: + sub_line = [word_idcs[0]] + for i in word_idcs[1:]: + horiz_break = True + + prev_box = boxes[sub_line[-1]] + # Compute distance between boxes + dist = boxes[i, 0] - prev_box[2] + # If distance between boxes is lower than paragraph break, same sub-line + if dist < self.paragraph_break: + horiz_break = False + + if horiz_break: + lines.append(sub_line) + sub_line = [] + + sub_line.append(i) + lines.append(sub_line) + + return lines + + def _resolve_lines(self, boxes: np.ndarray) -> List[List[int]]: + """Order boxes to group them in lines + + Args: + boxes: bounding boxes of shape (N, 4) or (N, 4, 2) in case of rotated bbox + + Returns: + nested list of box indices + """ + + # Sort boxes, and straighten the boxes if they are rotated + idxs, boxes = self._sort_boxes(boxes) + + # Compute median for boxes heights + y_med = np.median(boxes[:, 3] - boxes[:, 1]) + + lines = [] + words = [idxs[0]] # Assign the top-left word to the first line + # Define a mean y-center for 
the line + y_center_sum = boxes[idxs[0]][[1, 3]].mean() + + for idx in idxs[1:]: + vert_break = True + + # Compute y_dist + y_dist = abs(boxes[idx][[1, 3]].mean() - y_center_sum / len(words)) + # If y-center of the box is close enough to mean y-center of the line, same line + if y_dist < y_med / 2: + vert_break = False + + if vert_break: + # Compute sub-lines (horizontal split) + lines.extend(self._resolve_sub_lines(boxes, words)) + words = [] + y_center_sum = 0 + + words.append(idx) + y_center_sum += boxes[idx][[1, 3]].mean() + + # Use the remaining words to form the last(s) line(s) + if len(words) > 0: + # Compute sub-lines (horizontal split) + lines.extend(self._resolve_sub_lines(boxes, words)) + + return lines + + @staticmethod + def _resolve_blocks(boxes: np.ndarray, lines: List[List[int]]) -> List[List[List[int]]]: + """Order lines to group them in blocks + + Args: + boxes: bounding boxes of shape (N, 4) or (N, 4, 2) + lines: list of lines, each line is a list of idx + + Returns: + nested list of box indices + """ + # Resolve enclosing boxes of lines + if boxes.ndim == 3: + box_lines = np.asarray([ + resolve_enclosing_rbbox([tuple(boxes[idx, :, :]) for idx in line]) + for line in lines # type: ignore[misc] + ]) + else: + _box_lines = [ + resolve_enclosing_bbox([ + (tuple(boxes[idx, :2]), tuple(boxes[idx, 2:])) for idx in line # type: ignore[misc] + ]) + for line in lines + ] + box_lines = np.asarray([(x1, y1, x2, y2) for ((x1, y1), (x2, y2)) in _box_lines]) + + # Compute geometrical features of lines to clusterize + # Clusterizing only with box centers yield to poor results for complex documents + if boxes.ndim == 3: + box_features = np.stack( + ( + (box_lines[:, 0, 0] + box_lines[:, 0, 1]) / 2, + (box_lines[:, 0, 0] + box_lines[:, 2, 0]) / 2, + (box_lines[:, 0, 0] + box_lines[:, 2, 1]) / 2, + (box_lines[:, 0, 1] + box_lines[:, 2, 1]) / 2, + (box_lines[:, 0, 1] + box_lines[:, 2, 0]) / 2, + (box_lines[:, 2, 0] + box_lines[:, 2, 1]) / 2, + ), axis=-1 + ) + else: + box_features = np.stack( + ( + (box_lines[:, 0] + box_lines[:, 3]) / 2, + (box_lines[:, 1] + box_lines[:, 2]) / 2, + (box_lines[:, 0] + box_lines[:, 2]) / 2, + (box_lines[:, 1] + box_lines[:, 3]) / 2, + box_lines[:, 0], + box_lines[:, 1], + ), axis=-1 + ) + # Compute clusters + clusters = fclusterdata(box_features, t=0.1, depth=4, criterion='distance', metric='euclidean') + + _blocks: Dict[int, List[int]] = {} + # Form clusters + for line_idx, cluster_idx in enumerate(clusters): + if cluster_idx in _blocks.keys(): + _blocks[cluster_idx].append(line_idx) + else: + _blocks[cluster_idx] = [line_idx] + + # Retrieve word-box level to return a fully nested structure + blocks = [[lines[idx] for idx in block] for block in _blocks.values()] + + return blocks + + def _build_blocks(self, boxes: np.ndarray, word_preds: List[Tuple[str, float]]) -> List[Block]: + """Gather independent words in structured blocks + + Args: + boxes: bounding boxes of all detected words of the page, of shape (N, 5) or (N, 4, 2) + word_preds: list of all detected words of the page, of shape N + + Returns: + list of block elements + """ + + if boxes.shape[0] != len(word_preds): + raise ValueError(f"Incompatible argument lengths: {boxes.shape[0]}, {len(word_preds)}") + + if boxes.shape[0] == 0: + return [] + + # Decide whether we try to form lines + _boxes = boxes + if self.resolve_lines: + lines = self._resolve_lines(_boxes if _boxes.ndim == 3 else _boxes[:, :4]) + # Decide whether we try to form blocks + if self.resolve_blocks and len(lines) > 1: + _blocks = 
self._resolve_blocks(_boxes if _boxes.ndim == 3 else _boxes[:, :4], lines) + else: + _blocks = [lines] + else: + # Sort bounding boxes, one line for all boxes, one block for the line + lines = [self._sort_boxes(_boxes if _boxes.ndim == 3 else _boxes[:, :4])[0]] + _blocks = [lines] + + blocks = [ + Block( + [Line( + [ + Word( + *word_preds[idx], + tuple([tuple(pt) for pt in boxes[idx].tolist()]) + ) if boxes.ndim == 3 else + Word( + *word_preds[idx], + ((boxes[idx, 0], boxes[idx, 1]), (boxes[idx, 2], boxes[idx, 3])) + ) for idx in line + ] + ) for line in lines] + ) for lines in _blocks + ] + + return blocks + + def extra_repr(self) -> str: + return (f"resolve_lines={self.resolve_lines}, resolve_blocks={self.resolve_blocks}, " + f"paragraph_break={self.paragraph_break}, " + f"export_as_straight_boxes={self.export_as_straight_boxes}") + + def __call__( + self, + boxes: List[np.ndarray], + text_preds: List[List[Tuple[str, float]]], + page_shapes: List[Tuple[int, int]] + ) -> Document: + """Re-arrange detected words into structured blocks + + Args: + boxes: list of N elements, where each element represents the localization predictions, of shape (*, 5) + or (*, 6) for all words for a given page + text_preds: list of N elements, where each element is the list of all word prediction (text + confidence) + page_shape: shape of each page, of size N + + Returns: + document object + """ + if len(boxes) != len(text_preds) or len(boxes) != len(page_shapes): + raise ValueError("All arguments are expected to be lists of the same size") + + if self.export_as_straight_boxes and len(boxes) > 0: + # If boxes are already straight OK, else fit a bounding rect + if boxes[0].ndim == 3: + straight_boxes = [] + # Iterate over pages + for p_boxes in boxes: + # Iterate over boxes of the pages + straight_boxes.append(np.concatenate((p_boxes.min(1), p_boxes.max(1)), 1)) + boxes = straight_boxes + + _pages = [ + Page( + self._build_blocks( + page_boxes, + word_preds, + ), + _idx, + shape, + ) + for _idx, shape, page_boxes, word_preds in zip(range(len(boxes)), page_shapes, boxes, text_preds) + ] + + return Document(_pages) diff --git a/doctr/models/classification/__init__.py b/doctr/models/classification/__init__.py new file mode 100644 index 0000000000..2f0109fd44 --- /dev/null +++ b/doctr/models/classification/__init__.py @@ -0,0 +1,5 @@ +from .mobilenet import * +from .resnet import * +from .vgg import * +from .magc_resnet import * +from .zoo import * diff --git a/doctr/models/classification/magc_resnet/__init__.py b/doctr/models/classification/magc_resnet/__init__.py new file mode 100644 index 0000000000..059f261e82 --- /dev/null +++ b/doctr/models/classification/magc_resnet/__init__.py @@ -0,0 +1,6 @@ +from doctr.file_utils import is_tf_available, is_torch_available + +if is_tf_available(): + from .tensorflow import * +elif is_torch_available(): + from .pytorch import * # type: ignore[misc] diff --git a/doctr/models/classification/magc_resnet/pytorch.py b/doctr/models/classification/magc_resnet/pytorch.py new file mode 100644 index 0000000000..93c0451dad --- /dev/null +++ b/doctr/models/classification/magc_resnet/pytorch.py @@ -0,0 +1,158 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. 
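The `DocumentBuilder` above is the glue between raw detection/recognition outputs and the structured `Document` object. A minimal usage sketch (assuming the `doctr.io` elements behave as in this patch; the box coordinates and words below are made up, and `export()` is used only to dump the nested pages -> blocks -> lines -> words structure):

```python
import numpy as np

from doctr.models.builder import DocumentBuilder

builder = DocumentBuilder(resolve_lines=True, resolve_blocks=True, paragraph_break=0.035)

# One page: three words in relative (xmin, ymin, xmax, ymax, score) format.
# The first two sit on the same visual line, the third one lower on the page.
boxes = np.array([
    [0.05, 0.10, 0.20, 0.15, 0.99],
    [0.22, 0.10, 0.40, 0.15, 0.98],
    [0.05, 0.30, 0.25, 0.35, 0.97],
])
words = [("Hello", 0.99), ("world", 0.98), ("Bye", 0.97)]

doc = builder([boxes], [words], [(1024, 1024)])
print(doc.export())  # nested dict: pages -> blocks -> lines -> words
```

The first two boxes end up in the same line because their vertical centers differ by less than half the median box height, and their horizontal gap (0.02) stays below `paragraph_break`.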
+ + +import math +from functools import partial +from typing import Any, Dict, List, Optional, Tuple + +import torch +from torch import nn + +from doctr.datasets import VOCABS + +from ...utils.pytorch import load_pretrained_params +from ..resnet.pytorch import ResNet + +__all__ = ['magc_resnet31'] + + +default_cfgs: Dict[str, Dict[str, Any]] = { + 'magc_resnet31': { + 'mean': (0.694, 0.695, 0.693), + 'std': (0.299, 0.296, 0.301), + 'input_shape': (3, 32, 32), + 'classes': list(VOCABS['french']), + 'url': 'https://github.com/mindee/doctr/releases/download/v0.4.1/magc_resnet31-857391d8.pt', + }, +} + + +class MAGC(nn.Module): + """Implements the Multi-Aspect Global Context Attention, as described in + `_. + + Args: + inplanes: input channels + headers: number of headers to split channels + attn_scale: if True, re-scale attention to counteract the variance distibutions + ratio: bottleneck ratio + **kwargs + """ + + def __init__( + self, + inplanes: int, + headers: int = 8, + attn_scale: bool = False, + ratio: float = 0.0625, # bottleneck ratio of 1/16 as described in paper + ) -> None: + super().__init__() + + self.headers = headers + self.inplanes = inplanes + self.attn_scale = attn_scale + self.planes = int(inplanes * ratio) + + self.single_header_inplanes = int(inplanes / headers) + + self.conv_mask = nn.Conv2d(self.single_header_inplanes, 1, kernel_size=1) + self.softmax = nn.Softmax(dim=1) + + self.transform = nn.Sequential( + nn.Conv2d(self.inplanes, self.planes, kernel_size=1), + nn.LayerNorm([self.planes, 1, 1]), + nn.ReLU(inplace=True), + nn.Conv2d(self.planes, self.inplanes, kernel_size=1) + ) + + def forward(self, inputs: torch.Tensor) -> torch.Tensor: + + batch, _, height, width = inputs.size() + # (N * headers, C / headers, H , W) + x = inputs.view(batch * self.headers, self.single_header_inplanes, height, width) + shortcut = x + # (N * headers, C / headers, H * W) + shortcut = shortcut.view(batch * self.headers, self.single_header_inplanes, height * width) + + # (N * headers, 1, H, W) + context_mask = self.conv_mask(x) + # (N * headers, H * W) + context_mask = context_mask.view(batch * self.headers, -1) + + # scale variance + if self.attn_scale and self.headers > 1: + context_mask = context_mask / math.sqrt(self.single_header_inplanes) + + # (N * headers, H * W) + context_mask = self.softmax(context_mask) + + # (N * headers, C / headers) + context = (shortcut * context_mask.unsqueeze(1)).sum(-1) + + # (N, C, 1, 1) + context = context.view(batch, self.headers * self.single_header_inplanes, 1, 1) + + # Transform: B, C, 1, 1 -> B, C, 1, 1 + transformed = self.transform(context) + return inputs + transformed + + +def _magc_resnet( + arch: str, + pretrained: bool, + num_blocks: List[int], + output_channels: List[int], + stage_conv: List[bool], + stage_pooling: List[Optional[Tuple[int, int]]], + **kwargs: Any, +) -> ResNet: + + kwargs['num_classes'] = kwargs.get('num_classes', len(default_cfgs[arch]['classes'])) + + # Build the model + model = ResNet( + num_blocks, + output_channels, + stage_conv, + stage_pooling, + partial(MAGC, headers=8, attn_scale=True), + **kwargs, + ) + # Load pretrained parameters + if pretrained: + load_pretrained_params(model, default_cfgs[arch]['url']) + + return model + + +def magc_resnet31(pretrained: bool = False, **kwargs: Any) -> ResNet: + """Resnet31 architecture with Multi-Aspect Global Context Attention as described in + `"MASTER: Multi-Aspect Non-local Network for Scene Text Recognition", + `_. 
+ + Example:: + >>> import torch + >>> from doctr.models import magc_resnet31 + >>> model = magc_resnet31(pretrained=False) + >>> input_tensor = torch.rand((1, 3, 224, 224), dtype=tf.float32) + >>> out = model(input_tensor) + + Args: + pretrained: boolean, True if model is pretrained + + Returns: + A feature extractor model + """ + + return _magc_resnet( + 'magc_resnet31', + pretrained, + [1, 2, 5, 3], + [256, 256, 512, 512], + [True] * 4, + [(2, 2), (2, 1), None, None], + **kwargs, + ) diff --git a/doctr/models/classification/magc_resnet/tensorflow.py b/doctr/models/classification/magc_resnet/tensorflow.py new file mode 100644 index 0000000000..7d7fb8781d --- /dev/null +++ b/doctr/models/classification/magc_resnet/tensorflow.py @@ -0,0 +1,194 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + + +import math +from functools import partial +from typing import Any, Dict, List, Optional, Tuple + +import tensorflow as tf +from tensorflow.keras import layers +from tensorflow.keras.models import Sequential + +from doctr.datasets import VOCABS + +from ...utils import load_pretrained_params +from ..resnet.tensorflow import ResNet + +__all__ = ['magc_resnet31'] + + +default_cfgs: Dict[str, Dict[str, Any]] = { + 'magc_resnet31': { + 'mean': (0.5, 0.5, 0.5), + 'std': (1., 1., 1.), + 'input_shape': (32, 32, 3), + 'classes': list(VOCABS['french']), + 'url': None, + }, +} + + +class MAGC(layers.Layer): + """Implements the Multi-Aspect Global Context Attention, as described in + `_. + + Args: + inplanes: input channels + headers: number of headers to split channels + attn_scale: if True, re-scale attention to counteract the variance distibutions + ratio: bottleneck ratio + **kwargs + """ + + def __init__( + self, + inplanes: int, + headers: int = 8, + attn_scale: bool = False, + ratio: float = 0.0625, # bottleneck ratio of 1/16 as described in paper + **kwargs + ) -> None: + super().__init__(**kwargs) + + self.headers = headers # h + self.inplanes = inplanes # C + self.attn_scale = attn_scale + self.planes = int(inplanes * ratio) + + self.single_header_inplanes = int(inplanes / headers) # C / h + + self.conv_mask = layers.Conv2D( + filters=1, + kernel_size=1, + kernel_initializer=tf.initializers.he_normal() + ) + + self.transform = Sequential( + [ + layers.Conv2D( + filters=self.planes, + kernel_size=1, + kernel_initializer=tf.initializers.he_normal() + ), + layers.LayerNormalization([1, 2, 3]), + layers.ReLU(), + layers.Conv2D( + filters=self.inplanes, + kernel_size=1, + kernel_initializer=tf.initializers.he_normal() + ), + ], + name='transform' + ) + + def context_modeling(self, inputs: tf.Tensor) -> tf.Tensor: + b, h, w, c = (tf.shape(inputs)[i] for i in range(4)) + + # B, H, W, C -->> B*h, H, W, C/h + x = tf.reshape(inputs, shape=(b, h, w, self.headers, self.single_header_inplanes)) + x = tf.transpose(x, perm=(0, 3, 1, 2, 4)) + x = tf.reshape(x, shape=(b * self.headers, h, w, self.single_header_inplanes)) + + # Compute shorcut + shortcut = x + # B*h, 1, H*W, C/h + shortcut = tf.reshape(shortcut, shape=(b * self.headers, 1, h * w, self.single_header_inplanes)) + # B*h, 1, C/h, H*W + shortcut = tf.transpose(shortcut, perm=[0, 1, 3, 2]) + + # Compute context mask + # B*h, H, W, 1 + context_mask = self.conv_mask(x) + # B*h, 1, H*W, 1 + context_mask = tf.reshape(context_mask, shape=(b * self.headers, 1, h * w, 1)) + # scale variance + if self.attn_scale and self.headers > 1: + context_mask = 
context_mask / math.sqrt(self.single_header_inplanes) + # B*h, 1, H*W, 1 + context_mask = tf.keras.activations.softmax(context_mask, axis=2) + + # Compute context + # B*h, 1, C/h, 1 + context = tf.matmul(shortcut, context_mask) + context = tf.reshape(context, shape=(b, 1, c, 1)) + # B, 1, 1, C + context = tf.transpose(context, perm=(0, 1, 3, 2)) + # Set shape to resolve shape when calling this module in the Sequential MAGCResnet + batch, chan = inputs.get_shape().as_list()[0], inputs.get_shape().as_list()[-1] + context.set_shape([batch, 1, 1, chan]) + return context + + def call(self, inputs: tf.Tensor, **kwargs) -> tf.Tensor: + # Context modeling: B, H, W, C -> B, 1, 1, C + context = self.context_modeling(inputs) + # Transform: B, 1, 1, C -> B, 1, 1, C + transformed = self.transform(context) + return inputs + transformed + + +def _magc_resnet( + arch: str, + pretrained: bool, + num_blocks: List[int], + output_channels: List[int], + stage_downsample: List[bool], + stage_conv: List[bool], + stage_pooling: List[Optional[Tuple[int, int]]], + origin_stem: bool = True, + **kwargs: Any, +) -> ResNet: + + kwargs['num_classes'] = kwargs.get('num_classes', len(default_cfgs[arch]['classes'])) + kwargs['input_shape'] = kwargs.get('input_shape', default_cfgs[arch]['input_shape']) + + # Build the model + model = ResNet( + num_blocks, + output_channels, + stage_downsample, + stage_conv, + stage_pooling, + origin_stem, + partial(MAGC, headers=8, attn_scale=True), + **kwargs, + ) + # Load pretrained parameters + if pretrained: + load_pretrained_params(model, default_cfgs[arch]['url']) + + return model + + +def magc_resnet31(pretrained: bool = False, **kwargs: Any) -> ResNet: + """Resnet31 architecture with Multi-Aspect Global Context Attention as described in + `"MASTER: Multi-Aspect Non-local Network for Scene Text Recognition", + `_. + + Example:: + >>> import torch + >>> from doctr.models import magc_resnet31 + >>> model = magc_resnet31(pretrained=False) + >>> input_tensor = torch.rand((1, 3, 224, 224), dtype=tf.float32) + >>> out = model(input_tensor) + + Args: + pretrained: boolean, True if model is pretrained + + Returns: + A feature extractor model + """ + + return _magc_resnet( + 'magc_resnet31', + pretrained, + [1, 2, 5, 3], + [256, 256, 512, 512], + [False] * 4, + [True] * 4, + [(2, 2), (2, 1), None, None], + False, + **kwargs, + ) diff --git a/doctr/models/classification/mobilenet/__init__.py b/doctr/models/classification/mobilenet/__init__.py new file mode 100644 index 0000000000..059f261e82 --- /dev/null +++ b/doctr/models/classification/mobilenet/__init__.py @@ -0,0 +1,6 @@ +from doctr.file_utils import is_tf_available, is_torch_available + +if is_tf_available(): + from .tensorflow import * +elif is_torch_available(): + from .pytorch import * # type: ignore[misc] diff --git a/doctr/models/classification/mobilenet/pytorch.py b/doctr/models/classification/mobilenet/pytorch.py new file mode 100644 index 0000000000..f2a89672d5 --- /dev/null +++ b/doctr/models/classification/mobilenet/pytorch.py @@ -0,0 +1,204 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. 
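Since the MAGC block is purely residual (its transform branch is added back onto the input), it can be dropped into a backbone without changing tensor shapes. A quick PyTorch sanity check on the layer defined above (a sketch; it only assumes the module path introduced in this patch and a working torch install):

```python
import torch

from doctr.models.classification.magc_resnet.pytorch import MAGC

# 64-channel feature map split across 8 attention headers
magc = MAGC(inplanes=64, headers=8, attn_scale=True)
feat = torch.rand(2, 64, 8, 32)  # (N, C, H, W)

out = magc(feat)
assert out.shape == feat.shape  # residual connection preserves the input shape
```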
+ +# Greatly inspired by https://github.com/pytorch/vision/blob/master/torchvision/models/mobilenetv3.py + +from typing import Any, Dict, List, Optional + +from torchvision.models import mobilenetv3 + +from doctr.datasets import VOCABS + +from ...utils import load_pretrained_params + +__all__ = ["mobilenet_v3_small", "mobilenet_v3_small_r", "mobilenet_v3_large", + "mobilenet_v3_large_r", "mobilenet_v3_small_orientation"] + +default_cfgs: Dict[str, Dict[str, Any]] = { + 'mobilenet_v3_large': { + 'mean': (0.694, 0.695, 0.693), + 'std': (0.299, 0.296, 0.301), + 'input_shape': (3, 32, 32), + 'classes': list(VOCABS['french']), + 'url': 'https://github.com/mindee/doctr/releases/download/v0.4.1/mobilenet_v3_large-11fc8cb9.pt', + }, + 'mobilenet_v3_large_r': { + 'mean': (0.694, 0.695, 0.693), + 'std': (0.299, 0.296, 0.301), + 'input_shape': (3, 32, 32), + 'classes': list(VOCABS['french']), + 'url': 'https://github.com/mindee/doctr/releases/download/v0.4.1/mobilenet_v3_large_r-74a22066.pt', + }, + 'mobilenet_v3_small': { + 'mean': (0.694, 0.695, 0.693), + 'std': (0.299, 0.296, 0.301), + 'input_shape': (3, 32, 32), + 'classes': list(VOCABS['french']), + 'url': 'https://github.com/mindee/doctr/releases/download/v0.4.1/mobilenet_v3_small-6a4bfa6b.pt', + }, + 'mobilenet_v3_small_r': { + 'mean': (0.694, 0.695, 0.693), + 'std': (0.299, 0.296, 0.301), + 'input_shape': (3, 32, 32), + 'classes': list(VOCABS['french']), + 'url': 'https://github.com/mindee/doctr/releases/download/v0.4.1/mobilenet_v3_small_r-1a8a3530.pt', + }, + 'mobilenet_v3_small_orientation': { + 'mean': (0.694, 0.695, 0.693), + 'std': (0.299, 0.296, 0.301), + 'input_shape': (3, 128, 128), + 'classes': [0, 90, 180, 270], + 'url': 'https://github.com/mindee/doctr/releases/download/v0.4.1/classif_mobilenet_v3_small-24f8ff57.pt' + }, +} + + +def _mobilenet_v3( + arch: str, + pretrained: bool, + rect_strides: Optional[List[str]] = None, + **kwargs: Any +) -> mobilenetv3.MobileNetV3: + + kwargs['num_classes'] = kwargs.get('num_classes', len(default_cfgs[arch]['classes'])) + + if arch.startswith("mobilenet_v3_small"): + model = mobilenetv3.mobilenet_v3_small(**kwargs) + else: + model = mobilenetv3.mobilenet_v3_large(**kwargs) + + # Rectangular strides + if isinstance(rect_strides, list): + for layer_name in rect_strides: + m = model + for child in layer_name.split('.'): + m = getattr(m, child) + m.stride = (2, 1) + + # Load pretrained parameters + if pretrained: + load_pretrained_params(model, default_cfgs[arch]['url']) + + model.cfg = default_cfgs[arch] + + return model + + +def mobilenet_v3_small(pretrained: bool = False, **kwargs: Any) -> mobilenetv3.MobileNetV3: + """MobileNetV3-Small architecture as described in + `"Searching for MobileNetV3", + `_. + + Example:: + >>> import torch + >>> from doctr.models import mobilenet_v3_small + >>> model = mobilenetv3_small(pretrained=False) + >>> input_tensor = torch.rand((1, 3, 32, 32), dtype=torch.float32) + >>> out = model(input_tensor) + + Args: + pretrained: boolean, True if model is pretrained + + Returns: + a torch.nn.Module + """ + + return _mobilenet_v3('mobilenet_v3_small', pretrained, **kwargs) + + +def mobilenet_v3_small_r(pretrained: bool = False, **kwargs: Any) -> mobilenetv3.MobileNetV3: + """MobileNetV3-Small architecture as described in + `"Searching for MobileNetV3", + `_, with rectangular pooling. 
+ + Example:: + >>> import torch + >>> from doctr.models import mobilenet_v3_small_r + >>> model = mobilenet_v3_small_r(pretrained=False) + >>> input_tensor = torch.rand((1, 3, 32, 32), dtype=torch.float32) + >>> out = model(input_tensor) + + Args: + pretrained: boolean, True if model is pretrained + + Returns: + a torch.nn.Module + """ + + return _mobilenet_v3( + 'mobilenet_v3_small_r', + pretrained, + ['features.2.block.1.0', 'features.4.block.1.0', 'features.9.block.1.0'], + **kwargs + ) + + +def mobilenet_v3_large(pretrained: bool = False, **kwargs: Any) -> mobilenetv3.MobileNetV3: + """MobileNetV3-Large architecture as described in + `"Searching for MobileNetV3", + `_. + + Example:: + >>> import torch + >>> from doctr.models import mobilenetv3_large + >>> model = mobilenetv3_large(pretrained=False) + >>> input_tensor = torch.rand((1, 3, 32, 32), dtype=torch.float32) + >>> out = model(input_tensor) + + Args: + pretrained: boolean, True if model is pretrained + + Returns: + a torch.nn.Module + """ + return _mobilenet_v3('mobilenet_v3_large', pretrained, **kwargs) + + +def mobilenet_v3_large_r(pretrained: bool = False, **kwargs: Any) -> mobilenetv3.MobileNetV3: + """MobileNetV3-Large architecture as described in + `"Searching for MobileNetV3", + `_, with rectangular pooling. + + Example:: + >>> import torch + >>> from doctr.models import mobilenet_v3_large_r + >>> model = mobilenet_v3_large_r(pretrained=False) + >>> input_tensor = torch.rand((1, 3, 32, 32), dtype=torch.float32) + >>> out = model(input_tensor) + + Args: + pretrained: boolean, True if model is pretrained + + Returns: + a torch.nn.Module + """ + return _mobilenet_v3( + 'mobilenet_v3_large_r', + pretrained, + ['features.4.block.1.0', 'features.7.block.1.0', 'features.13.block.1.0'], + **kwargs + ) + + +def mobilenet_v3_small_orientation(pretrained: bool = False, **kwargs: Any) -> mobilenetv3.MobileNetV3: + """MobileNetV3-Small architecture as described in + `"Searching for MobileNetV3", + `_. + + Example:: + >>> import tensorflow as tf + >>> from doctr.models import mobilenet_v3_small_orientation + >>> model = mobilenet_v3_small_orientation(pretrained=False) + >>> input_tensor = tf.random.uniform(shape=[1, 512, 512, 3], maxval=1, dtype=tf.float32) + >>> out = model(input_tensor) + + Args: + pretrained: boolean, True if model is pretrained + + Returns: + a torch.nn.Module + """ + + return _mobilenet_v3('mobilenet_v3_small_orientation', pretrained, **kwargs) diff --git a/doctr/models/classification/mobilenet/tensorflow.py b/doctr/models/classification/mobilenet/tensorflow.py new file mode 100644 index 0000000000..d1328327a3 --- /dev/null +++ b/doctr/models/classification/mobilenet/tensorflow.py @@ -0,0 +1,385 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. 
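The `_r` variants above rely on a small but important trick: after the torchvision model is built, the stride of selected depthwise convolutions is rewritten to `(2, 1)`, so feature maps are downsampled vertically while keeping their horizontal resolution, which suits wide text crops. The patching itself is plain attribute navigation, as in this standalone sketch (layer names copied from `mobilenet_v3_small_r` above; `num_classes` is arbitrary here):

```python
from torchvision.models import mobilenetv3

model = mobilenetv3.mobilenet_v3_small(num_classes=10)

# Replace the square stride of three depthwise convolutions with a rectangular one
for layer_name in ['features.2.block.1.0', 'features.4.block.1.0', 'features.9.block.1.0']:
    m = model
    for child in layer_name.split('.'):
        m = getattr(m, child)  # walk down to the depthwise Conv2d
    m.stride = (2, 1)          # stride 2 vertically, 1 horizontally
```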
+ +# Greatly inspired by https://github.com/pytorch/vision/blob/master/torchvision/models/mobilenetv3.py + +from copy import deepcopy +from typing import Any, Dict, List, Optional, Tuple, Union + +import tensorflow as tf +from tensorflow.keras import layers +from tensorflow.keras.models import Sequential + +from ....datasets import VOCABS +from ...utils import conv_sequence, load_pretrained_params + +__all__ = ["MobileNetV3", "mobilenet_v3_small", "mobilenet_v3_small_r", "mobilenet_v3_large", + "mobilenet_v3_large_r", "mobilenet_v3_small_orientation"] + + +default_cfgs: Dict[str, Dict[str, Any]] = { + 'mobilenet_v3_large': { + 'mean': (0.694, 0.695, 0.693), + 'std': (0.299, 0.296, 0.301), + 'input_shape': (32, 32, 3), + 'classes': list(VOCABS['french']), + 'url': 'https://github.com/mindee/doctr/releases/download/v0.4.1/mobilenet_v3_large-47d25d7e.zip', + }, + 'mobilenet_v3_large_r': { + 'mean': (0.694, 0.695, 0.693), + 'std': (0.299, 0.296, 0.301), + 'input_shape': (32, 32, 3), + 'classes': list(VOCABS['french']), + 'url': 'https://github.com/mindee/doctr/releases/download/v0.4.1/mobilenet_v3_large_r-a108e192.zip', + }, + 'mobilenet_v3_small': { + 'mean': (0.694, 0.695, 0.693), + 'std': (0.299, 0.296, 0.301), + 'input_shape': (32, 32, 3), + 'classes': list(VOCABS['french']), + 'url': 'https://github.com/mindee/doctr/releases/download/v0.4.1/mobilenet_v3_small-8a32c32c.zip', + }, + 'mobilenet_v3_small_r': { + 'mean': (0.694, 0.695, 0.693), + 'std': (0.299, 0.296, 0.301), + 'input_shape': (32, 32, 3), + 'classes': list(VOCABS['french']), + 'url': 'https://github.com/mindee/doctr/releases/download/v0.4.1/mobilenet_v3_small_r-3d61452e.zip', + }, + 'mobilenet_v3_small_orientation': { + 'mean': (0.694, 0.695, 0.693), + 'std': (0.299, 0.296, 0.301), + 'input_shape': (128, 128, 3), + 'classes': [0, 90, 180, 270], + 'url': 'https://github.com/mindee/doctr/releases/download/v0.4.1/classif_mobilenet_v3_small-1ea8db03.zip', + }, +} + + +def hard_swish(x: tf.Tensor) -> tf.Tensor: + return x * tf.nn.relu6(x + 3.) / 6.0 + + +def _make_divisible(v: float, divisor: int, min_value: Optional[int] = None) -> int: + if min_value is None: + min_value = divisor + new_v = max(min_value, int(v + divisor / 2) // divisor * divisor) + # Make sure that round down does not go down by more than 10%. + if new_v < 0.9 * v: + new_v += divisor + return new_v + + +class SqueezeExcitation(Sequential): + """Squeeze and Excitation. 
+ """ + def __init__(self, chan: int, squeeze_factor: int = 4) -> None: + super().__init__( + [ + layers.GlobalAveragePooling2D(), + layers.Dense(chan // squeeze_factor, activation='relu'), + layers.Dense(chan, activation='hard_sigmoid'), + layers.Reshape((1, 1, chan)) + ] + ) + + def call(self, inputs: tf.Tensor, **kwargs: Any) -> tf.Tensor: + x = super().call(inputs, **kwargs) + x = tf.math.multiply(inputs, x) + return x + + +class InvertedResidualConfig: + def __init__( + self, + input_channels: int, + kernel: int, + expanded_channels: int, + out_channels: int, + use_se: bool, + activation: str, + stride: Union[int, Tuple[int, int]], + width_mult: float = 1, + ) -> None: + self.input_channels = self.adjust_channels(input_channels, width_mult) + self.kernel = kernel + self.expanded_channels = self.adjust_channels(expanded_channels, width_mult) + self.out_channels = self.adjust_channels(out_channels, width_mult) + self.use_se = use_se + self.use_hs = activation == "HS" + self.stride = stride + + @staticmethod + def adjust_channels(channels: int, width_mult: float): + return _make_divisible(channels * width_mult, 8) + + +class InvertedResidual(layers.Layer): + """InvertedResidual for mobilenet + + Args: + conf: configuration object for inverted residual + """ + def __init__( + self, + conf: InvertedResidualConfig, + **kwargs: Any, + ) -> None: + _kwargs = {'input_shape': kwargs.pop('input_shape')} if isinstance(kwargs.get('input_shape'), tuple) else {} + super().__init__(**kwargs) + + act_fn = hard_swish if conf.use_hs else tf.nn.relu + + _is_s1 = (isinstance(conf.stride, tuple) and conf.stride == (1, 1)) or conf.stride == 1 + self.use_res_connect = _is_s1 and conf.input_channels == conf.out_channels + + _layers = [] + # expand + if conf.expanded_channels != conf.input_channels: + _layers.extend(conv_sequence(conf.expanded_channels, act_fn, kernel_size=1, bn=True, **_kwargs)) + + # depth-wise + _layers.extend(conv_sequence( + conf.expanded_channels, act_fn, kernel_size=conf.kernel, strides=conf.stride, bn=True, + groups=conf.expanded_channels, + )) + + if conf.use_se: + _layers.append(SqueezeExcitation(conf.expanded_channels)) + + # project + _layers.extend(conv_sequence( + conf.out_channels, None, kernel_size=1, bn=True, + )) + + self.block = Sequential(_layers) + + def call( + self, + inputs: tf.Tensor, + **kwargs: Any, + ) -> tf.Tensor: + + out = self.block(inputs, **kwargs) + if self.use_res_connect: + out = tf.add(out, inputs) + + return out + + +class MobileNetV3(Sequential): + """Implements MobileNetV3, inspired from both: + `_. + and `_. 
+ """ + + def __init__( + self, + layout: List[InvertedResidualConfig], + include_top: bool = True, + head_chans: int = 1024, + num_classes: int = 1000, + cfg: Optional[Dict[str, Any]] = None, + input_shape: Optional[Tuple[int, int, int]] = None, + ) -> None: + + _layers = [ + Sequential(conv_sequence(layout[0].input_channels, hard_swish, True, kernel_size=3, strides=2, + input_shape=input_shape), name="stem") + ] + + for idx, conf in enumerate(layout): + _layers.append( + InvertedResidual(conf, name=f"inverted_{idx}"), + ) + + _layers.append( + Sequential( + conv_sequence(6 * layout[-1].out_channels, hard_swish, True, kernel_size=1), + name="final_block" + ) + ) + + if include_top: + _layers.extend([ + layers.GlobalAveragePooling2D(), + layers.Dense(head_chans, activation=hard_swish), + layers.Dropout(0.2), + layers.Dense(num_classes), + ]) + + super().__init__(_layers) + self.cfg = cfg + + +def _mobilenet_v3( + arch: str, + pretrained: bool, + rect_strides: bool = False, + **kwargs: Any +) -> MobileNetV3: + _cfg = deepcopy(default_cfgs[arch]) + _cfg['input_shape'] = kwargs.get('input_shape', default_cfgs[arch]['input_shape']) + _cfg['num_classes'] = kwargs.get('num_classes', len(default_cfgs[arch]['classes'])) + + # cf. Table 1 & 2 of the paper + if arch.startswith("mobilenet_v3_small"): + inverted_residual_setting = [ + InvertedResidualConfig(16, 3, 16, 16, True, "RE", 2), # C1 + InvertedResidualConfig(16, 3, 72, 24, False, "RE", (2, 1) if rect_strides else 2), # C2 + InvertedResidualConfig(24, 3, 88, 24, False, "RE", 1), + InvertedResidualConfig(24, 5, 96, 40, True, "HS", (2, 1) if rect_strides else 2), # C3 + InvertedResidualConfig(40, 5, 240, 40, True, "HS", 1), + InvertedResidualConfig(40, 5, 240, 40, True, "HS", 1), + InvertedResidualConfig(40, 5, 120, 48, True, "HS", 1), + InvertedResidualConfig(48, 5, 144, 48, True, "HS", 1), + InvertedResidualConfig(48, 5, 288, 96, True, "HS", (2, 1) if rect_strides else 2), # C4 + InvertedResidualConfig(96, 5, 576, 96, True, "HS", 1), + InvertedResidualConfig(96, 5, 576, 96, True, "HS", 1), + ] + head_chans = 1024 + else: + inverted_residual_setting = [ + InvertedResidualConfig(16, 3, 16, 16, False, "RE", 1), + InvertedResidualConfig(16, 3, 64, 24, False, "RE", 2), # C1 + InvertedResidualConfig(24, 3, 72, 24, False, "RE", 1), + InvertedResidualConfig(24, 5, 72, 40, True, "RE", (2, 1) if rect_strides else 2), # C2 + InvertedResidualConfig(40, 5, 120, 40, True, "RE", 1), + InvertedResidualConfig(40, 5, 120, 40, True, "RE", 1), + InvertedResidualConfig(40, 3, 240, 80, False, "HS", (2, 1) if rect_strides else 2), # C3 + InvertedResidualConfig(80, 3, 200, 80, False, "HS", 1), + InvertedResidualConfig(80, 3, 184, 80, False, "HS", 1), + InvertedResidualConfig(80, 3, 184, 80, False, "HS", 1), + InvertedResidualConfig(80, 3, 480, 112, True, "HS", 1), + InvertedResidualConfig(112, 3, 672, 112, True, "HS", 1), + InvertedResidualConfig(112, 5, 672, 160, True, "HS", (2, 1) if rect_strides else 2), # C4 + InvertedResidualConfig(160, 5, 960, 160, True, "HS", 1), + InvertedResidualConfig(160, 5, 960, 160, True, "HS", 1), + ] + head_chans = 1280 + + kwargs['num_classes'] = _cfg['num_classes'] + kwargs['input_shape'] = _cfg['input_shape'] + + # Build the model + model = MobileNetV3( + inverted_residual_setting, + head_chans=head_chans, + cfg=_cfg, + **kwargs, + ) + # Load pretrained parameters + if pretrained: + load_pretrained_params(model, default_cfgs[arch]['url']) + + return model + + +def mobilenet_v3_small(pretrained: bool = False, **kwargs: Any) -> 
MobileNetV3: + """MobileNetV3-Small architecture as described in + `"Searching for MobileNetV3", + `_. + + Example:: + >>> import tensorflow as tf + >>> from doctr.models import mobilenetv3_large + >>> model = mobilenetv3_small(pretrained=False) + >>> input_tensor = tf.random.uniform(shape=[1, 512, 512, 3], maxval=1, dtype=tf.float32) + >>> out = model(input_tensor) + + Args: + pretrained: boolean, True if model is pretrained + + Returns: + a keras.Model + """ + + return _mobilenet_v3('mobilenet_v3_small', pretrained, False, **kwargs) + + +def mobilenet_v3_small_r(pretrained: bool = False, **kwargs: Any) -> MobileNetV3: + """MobileNetV3-Small architecture as described in + `"Searching for MobileNetV3", + `_, with rectangular pooling. + + Example:: + >>> import tensorflow as tf + >>> from doctr.models import mobilenet_v3_small_r + >>> model = mobilenet_v3_small_r(pretrained=False) + >>> input_tensor = tf.random.uniform(shape=[1, 512, 512, 3], maxval=1, dtype=tf.float32) + >>> out = model(input_tensor) + + Args: + pretrained: boolean, True if model is pretrained + + Returns: + a keras.Model + """ + + return _mobilenet_v3('mobilenet_v3_small_r', pretrained, True, **kwargs) + + +def mobilenet_v3_large(pretrained: bool = False, **kwargs: Any) -> MobileNetV3: + """MobileNetV3-Large architecture as described in + `"Searching for MobileNetV3", + `_. + + Example:: + >>> import tensorflow as tf + >>> from doctr.models import mobilenetv3_large + >>> model = mobilenetv3_large(pretrained=False) + >>> input_tensor = tf.random.uniform(shape=[1, 512, 512, 3], maxval=1, dtype=tf.float32) + >>> out = model(input_tensor) + + Args: + pretrained: boolean, True if model is pretrained + + Returns: + a keras.Model + """ + return _mobilenet_v3('mobilenet_v3_large', pretrained, False, **kwargs) + + +def mobilenet_v3_large_r(pretrained: bool = False, **kwargs: Any) -> MobileNetV3: + """MobileNetV3-Large architecture as described in + `"Searching for MobileNetV3", + `_. + + Example:: + >>> import tensorflow as tf + >>> from doctr.models import mobilenet_v3_large_r + >>> model = mobilenet_v3_large_r(pretrained=False) + >>> input_tensor = tf.random.uniform(shape=[1, 512, 512, 3], maxval=1, dtype=tf.float32) + >>> out = model(input_tensor) + + Args: + pretrained: boolean, True if model is pretrained + + Returns: + a keras.Model + """ + return _mobilenet_v3('mobilenet_v3_large_r', pretrained, True, **kwargs) + + +def mobilenet_v3_small_orientation(pretrained: bool = False, **kwargs: Any) -> MobileNetV3: + """MobileNetV3-Small architecture as described in + `"Searching for MobileNetV3", + `_. 
+ + Example:: + >>> import tensorflow as tf + >>> from doctr.models import mobilenet_v3_small_orientation + >>> model = mobilenet_v3_small_orientation(pretrained=False) + >>> input_tensor = tf.random.uniform(shape=[1, 512, 512, 3], maxval=1, dtype=tf.float32) + >>> out = model(input_tensor) + + Args: + pretrained: boolean, True if model is pretrained + + Returns: + a keras.Model + """ + + return _mobilenet_v3('mobilenet_v3_small_orientation', pretrained, include_top=True, **kwargs) diff --git a/doctr/models/classification/predictor/__init__.py b/doctr/models/classification/predictor/__init__.py new file mode 100644 index 0000000000..059f261e82 --- /dev/null +++ b/doctr/models/classification/predictor/__init__.py @@ -0,0 +1,6 @@ +from doctr.file_utils import is_tf_available, is_torch_available + +if is_tf_available(): + from .tensorflow import * +elif is_torch_available(): + from .pytorch import * # type: ignore[misc] diff --git a/doctr/models/classification/predictor/pytorch.py b/doctr/models/classification/predictor/pytorch.py new file mode 100644 index 0000000000..138a5ba579 --- /dev/null +++ b/doctr/models/classification/predictor/pytorch.py @@ -0,0 +1,55 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + +from typing import List, Union + +import numpy as np +import torch +from torch import nn + +from doctr.models.preprocessor import PreProcessor + +__all__ = ['CropOrientationPredictor'] + + +class CropOrientationPredictor(nn.Module): + """Implements an object able to detect the reading direction of a text box. + 4 possible orientations: 0, 90, 180, 270 degrees counter clockwise. + + Args: + pre_processor: transform inputs for easier batched model inference + model: core classification architecture (backbone + classification head) + """ + + def __init__( + self, + pre_processor: PreProcessor, + model: nn.Module, + ) -> None: + + super().__init__() + self.pre_processor = pre_processor + self.model = model.eval() + + @torch.no_grad() + def forward( + self, + crops: List[Union[np.ndarray, torch.Tensor]], + ) -> List[int]: + + # Dimension check + if any(crop.ndim != 3 for crop in crops): + raise ValueError("incorrect input shape: all pages are expected to be multi-channel 2D images.") + + processed_batches = self.pre_processor(crops) + predicted_batches = [ + self.model(batch) + for batch in processed_batches + ] + + # Postprocess predictions + predicted_batches = [out_batch.argmax(dim=1).cpu().detach().numpy() for out_batch in predicted_batches] + + return [int(pred) for batch in predicted_batches for pred in batch] diff --git a/doctr/models/classification/predictor/tensorflow.py b/doctr/models/classification/predictor/tensorflow.py new file mode 100644 index 0000000000..d4954c6600 --- /dev/null +++ b/doctr/models/classification/predictor/tensorflow.py @@ -0,0 +1,56 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + +from typing import List, Union + +import numpy as np +import tensorflow as tf +from tensorflow import keras + +from doctr.models.preprocessor import PreProcessor +from doctr.utils.repr import NestedObject + +__all__ = ['CropOrientationPredictor'] + + +class CropOrientationPredictor(NestedObject): + """Implements an object able to detect the reading direction of a text box. + 4 possible orientations: 0, 90, 180, 270 degrees counter clockwise. 
+ + Args: + pre_processor: transform inputs for easier batched model inference + model: core classification architecture (backbone + classification head) + """ + + _children_names: List[str] = ['pre_processor', 'model'] + + def __init__( + self, + pre_processor: PreProcessor, + model: keras.Model, + ) -> None: + + self.pre_processor = pre_processor + self.model = model + + def __call__( + self, + crops: List[Union[np.ndarray, tf.Tensor]], + ) -> List[int]: + + # Dimension check + if any(crop.ndim != 3 for crop in crops): + raise ValueError("incorrect input shape: all crops are expected to be multi-channel 2D images.") + + processed_batches = self.pre_processor(crops) + predicted_batches = [ + self.model(batch, training=False) + for batch in processed_batches + ] + + # Postprocess predictions + predicted_batches = [out_batch.numpy().argmax(1) for out_batch in predicted_batches] + + return [int(pred) for batch in predicted_batches for pred in batch] diff --git a/doctr/models/classification/resnet/__init__.py b/doctr/models/classification/resnet/__init__.py new file mode 100644 index 0000000000..059f261e82 --- /dev/null +++ b/doctr/models/classification/resnet/__init__.py @@ -0,0 +1,6 @@ +from doctr.file_utils import is_tf_available, is_torch_available + +if is_tf_available(): + from .tensorflow import * +elif is_torch_available(): + from .pytorch import * # type: ignore[misc] diff --git a/doctr/models/classification/resnet/pytorch.py b/doctr/models/classification/resnet/pytorch.py new file mode 100644 index 0000000000..530e970a9b --- /dev/null +++ b/doctr/models/classification/resnet/pytorch.py @@ -0,0 +1,200 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + + +from typing import Any, Callable, Dict, List, Optional, Tuple + +from torch import nn +from torchvision.models.resnet import BasicBlock +from torchvision.models.resnet import ResNet as TVResNet +from torchvision.models.resnet import resnet18 as tv_resnet18 + +from doctr.datasets import VOCABS + +from ...utils import conv_sequence_pt, load_pretrained_params + +__all__ = ['ResNet', 'resnet18', 'resnet31', 'resnet_stage'] + + +default_cfgs: Dict[str, Dict[str, Any]] = { + 'resnet18': { + 'mean': (0.694, 0.695, 0.693), + 'std': (0.299, 0.296, 0.301), + 'input_shape': (3, 32, 32), + 'classes': list(VOCABS['french']), + 'url': 'https://github.com/mindee/doctr/releases/download/v0.4.1/resnet18-244bf390.pt', + }, + 'resnet31': { + 'mean': (0.694, 0.695, 0.693), + 'std': (0.299, 0.296, 0.301), + 'input_shape': (3, 32, 32), + 'classes': list(VOCABS['french']), + 'url': 'https://github.com/mindee/doctr/releases/download/v0.4.1/resnet31-1056cc5c.pt', + }, +} + + +def resnet_stage(in_channels: int, out_channels: int, num_blocks: int) -> List[nn.Module]: + _layers: List[nn.Module] = [] + + in_chan = in_channels + for _ in range(num_blocks): + downsample = None + if in_chan != out_channels: + downsample = nn.Sequential(*conv_sequence_pt(in_chan, out_channels, False, True, kernel_size=1)) + + _layers.append(BasicBlock(in_chan, out_channels, downsample=downsample)) + in_chan = out_channels + + return _layers + + +class ResNet(nn.Sequential): + """Implements a ResNet-31 architecture from `"Show, Attend and Read:A Simple and Strong Baseline for Irregular + Text Recognition" `_. 
+ + Args: + num_blocks: number of resnet block in each stage + output_channels: number of channels in each stage + stage_conv: whether to add a conv_sequence after each stage + stage_pooling: pooling to add after each stage (if None, no pooling) + attn_module: attention module to use in each stage + include_top: whether the classifier head should be instantiated + num_classes: number of output classes + """ + + def __init__( + self, + num_blocks: List[int], + output_channels: List[int], + stage_conv: List[bool], + stage_pooling: List[Optional[Tuple[int, int]]], + attn_module: Optional[Callable[[int], nn.Module]] = None, + include_top: bool = True, + num_classes: int = 1000, + ) -> None: + + _layers: List[nn.Module] = [ + *conv_sequence_pt(3, 64, True, True, kernel_size=3, padding=1), + *conv_sequence_pt(64, 128, True, True, kernel_size=3, padding=1), + nn.MaxPool2d(2), + ] + in_chans = [128] + output_channels[:-1] + for in_chan, out_chan, n_blocks, conv, pool in zip(in_chans, output_channels, num_blocks, stage_conv, + stage_pooling): + _stage = resnet_stage(in_chan, out_chan, n_blocks) + if attn_module is not None: + _stage.append(attn_module(out_chan)) + if conv: + _stage.extend(conv_sequence_pt(out_chan, out_chan, True, True, kernel_size=3, padding=1)) + if pool is not None: + _stage.append(nn.MaxPool2d(pool)) + _layers.append(nn.Sequential(*_stage)) + + if include_top: + _layers.extend([ + nn.AdaptiveAvgPool2d(1), + nn.Flatten(1), + nn.Linear(output_channels[-1], num_classes, bias=True), + ]) + + super().__init__(*_layers) + + for m in self.modules(): + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + +def _resnet( + arch: str, + pretrained: bool, + num_blocks: List[int], + output_channels: List[int], + stage_conv: List[bool], + stage_pooling: List[Optional[Tuple[int, int]]], + **kwargs: Any, +) -> ResNet: + + kwargs['num_classes'] = kwargs.get('num_classes', len(default_cfgs[arch]['classes'])) + + # Build the model + model = ResNet(num_blocks, output_channels, stage_conv, stage_pooling, **kwargs) + # Load pretrained parameters + if pretrained: + load_pretrained_params(model, default_cfgs[arch]['url']) + + return model + + +def _tv_resnet( + arch: str, + pretrained: bool, + arch_fn, + **kwargs: Any, +) -> TVResNet: + + kwargs['num_classes'] = kwargs.get('num_classes', len(default_cfgs[arch]['classes'])) + + # Build the model + model = arch_fn(**kwargs) + # Load pretrained parameters + if pretrained: + load_pretrained_params(model, default_cfgs[arch]['url']) + + return model + + +def resnet18(pretrained: bool = False, **kwargs: Any) -> TVResNet: + """ResNet-18 architecture as described in `"Deep Residual Learning for Image Recognition", + `_. + + Example:: + >>> import torch + >>> from doctr.models import resnet18 + >>> model = resnet18(pretrained=False) + >>> input_tensor = torch.rand((1, 3, 224, 224), dtype=tf.float32) + >>> out = model(input_tensor) + + Args: + pretrained: boolean, True if model is pretrained + + Returns: + A resnet18 model + """ + + return _tv_resnet('resnet18', pretrained, tv_resnet18, **kwargs) + + +def resnet31(pretrained: bool = False, **kwargs: Any) -> ResNet: + """Resnet31 architecture with rectangular pooling windows as described in + `"Show, Attend and Read:A Simple and Strong Baseline for Irregular Text Recognition", + `_. 
Downsizing: (H, W) --> (H/8, W/4) + + Example:: + >>> import torch + >>> from doctr.models import resnet31 + >>> model = resnet31(pretrained=False) + >>> input_tensor = torch.rand((1, 3, 224, 224), dtype=tf.float32) + >>> out = model(input_tensor) + + Args: + pretrained: boolean, True if model is pretrained + + Returns: + A resnet31 model + """ + + return _resnet( + 'resnet31', + pretrained, + [1, 2, 5, 3], + [256, 256, 512, 512], + [True] * 4, + [(2, 2), (2, 1), None, None], + **kwargs, + ) diff --git a/doctr/models/classification/resnet/tensorflow.py b/doctr/models/classification/resnet/tensorflow.py new file mode 100644 index 0000000000..4c691c416f --- /dev/null +++ b/doctr/models/classification/resnet/tensorflow.py @@ -0,0 +1,263 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + +from typing import Any, Callable, Dict, List, Optional, Tuple + +import tensorflow as tf +from tensorflow.keras import layers +from tensorflow.keras.models import Sequential + +from doctr.datasets import VOCABS + +from ...utils import conv_sequence, load_pretrained_params + +__all__ = ['ResNet', 'resnet18', 'resnet31'] + + +default_cfgs: Dict[str, Dict[str, Any]] = { + 'resnet18': { + 'mean': (0.5, 0.5, 0.5), + 'std': (1., 1., 1.), + 'input_shape': (32, 32, 3), + 'classes': list(VOCABS['french']), + 'url': 'https://github.com/mindee/doctr/releases/download/v0.4.1/resnet18-d4634669.zip', + }, + 'resnet31': { + 'mean': (0.5, 0.5, 0.5), + 'std': (1., 1., 1.), + 'input_shape': (32, 32, 3), + 'classes': list(VOCABS['french']), + 'url': None, + }, +} + + +class ResnetBlock(layers.Layer): + + """Implements a resnet31 block with shortcut + + Args: + conv_shortcut: Use of shortcut + output_channels: number of channels to use in Conv2D + kernel_size: size of square kernels + strides: strides to use in the first convolution of the block + """ + def __init__( + self, + output_channels: int, + conv_shortcut: bool, + strides: int = 1, + **kwargs + ) -> None: + + super().__init__(**kwargs) + if conv_shortcut: + self.shortcut = Sequential( + [ + layers.Conv2D( + filters=output_channels, + strides=strides, + padding='same', + kernel_size=1, + use_bias=False, + kernel_initializer='he_normal' + ), + layers.BatchNormalization() + ] + ) + else: + self.shortcut = layers.Lambda(lambda x: x) + self.conv_block = Sequential( + self.conv_resnetblock(output_channels, 3, strides) + ) + self.act = layers.Activation('relu') + + @staticmethod + def conv_resnetblock( + output_channels: int, + kernel_size: int, + strides: int = 1, + ) -> List[layers.Layer]: + return [ + *conv_sequence(output_channels, 'relu', bn=True, strides=strides, kernel_size=kernel_size), + *conv_sequence(output_channels, None, bn=True, kernel_size=kernel_size), + ] + + def call( + self, + inputs: tf.Tensor + ) -> tf.Tensor: + clone = self.shortcut(inputs) + conv_out = self.conv_block(inputs) + out = self.act(clone + conv_out) + + return out + + +def resnet_stage( + num_blocks: int, + out_channels: int, + shortcut: bool = False, + downsample: bool = False +) -> List[layers.Layer]: + _layers: List[layers.Layer] = [ + ResnetBlock(out_channels, conv_shortcut=shortcut, strides=2 if downsample else 1) + ] + + for _ in range(1, num_blocks): + _layers.append(ResnetBlock(out_channels, conv_shortcut=False)) + + return _layers + + +class ResNet(Sequential): + """Implements a ResNet architecture + + Args: + num_blocks: number of resnet block in each stage + output_channels: 
number of channels in each stage + stage_downsample: whether the first residual block of a stage should downsample + stage_conv: whether to add a conv_sequence after each stage + stage_pooling: pooling to add after each stage (if None, no pooling) + origin_stem: whether to use the orginal ResNet stem or ResNet-31's + attn_module: attention module to use in each stage + include_top: whether the classifier head should be instantiated + num_classes: number of output classes + input_shape: shape of inputs + """ + + def __init__( + self, + num_blocks: List[int], + output_channels: List[int], + stage_downsample: List[bool], + stage_conv: List[bool], + stage_pooling: List[Optional[Tuple[int, int]]], + origin_stem: bool = True, + attn_module: Optional[Callable[[int], layers.Layer]] = None, + include_top: bool = True, + num_classes: int = 1000, + input_shape: Optional[Tuple[int, int, int]] = None, + ) -> None: + + if origin_stem: + _layers = [ + *conv_sequence(64, 'relu', True, kernel_size=7, strides=2, input_shape=input_shape), + layers.MaxPool2D(pool_size=(3, 3), strides=2, padding='same'), + ] + inplanes = 64 + else: + _layers = [ + *conv_sequence(64, 'relu', True, kernel_size=3, input_shape=input_shape), + *conv_sequence(128, 'relu', True, kernel_size=3), + layers.MaxPool2D(pool_size=2, strides=2, padding='valid'), + ] + inplanes = 128 + + for n_blocks, out_chan, down, conv, pool in zip(num_blocks, output_channels, stage_downsample, stage_conv, + stage_pooling): + _layers.extend(resnet_stage(n_blocks, out_chan, out_chan != inplanes, down)) + if attn_module is not None: + _layers.append(attn_module(out_chan)) + if conv: + _layers.extend(conv_sequence(out_chan, activation='relu', bn=True, kernel_size=3)) + if pool: + _layers.append(layers.MaxPool2D(pool_size=pool, strides=pool, padding='valid')) + inplanes = out_chan + + if include_top: + _layers.extend([ + layers.GlobalAveragePooling2D(), + layers.Dense(num_classes), + ]) + + super().__init__(_layers) + + +def _resnet( + arch: str, + pretrained: bool, + num_blocks: List[int], + output_channels: List[int], + stage_downsample: List[bool], + stage_conv: List[bool], + stage_pooling: List[Optional[Tuple[int, int]]], + origin_stem: bool = True, + **kwargs: Any +) -> ResNet: + + kwargs['num_classes'] = kwargs.get('num_classes', len(default_cfgs[arch]['classes'])) + kwargs['input_shape'] = kwargs.get('input_shape', default_cfgs[arch]['input_shape']) + + # Build the model + model = ResNet(num_blocks, output_channels, stage_downsample, stage_conv, stage_pooling, origin_stem, **kwargs) + # Load pretrained parameters + if pretrained: + load_pretrained_params(model, default_cfgs[arch]['url']) + + return model + + +def resnet18(pretrained: bool = False, **kwargs: Any) -> ResNet: + """Resnet-18 architecture as described in `"Deep Residual Learning for Image Recognition", + `_. 
+ + Example:: + >>> import tensorflow as tf + >>> from doctr.models import resnet18 + >>> model = resnet18(pretrained=False) + >>> input_tensor = tf.random.uniform(shape=[1, 224, 224, 3], maxval=1, dtype=tf.float32) + >>> out = model(input_tensor) + + Args: + pretrained: boolean, True if model is pretrained + + Returns: + A classification model + """ + + return _resnet( + 'resnet18', + pretrained, + [2, 2, 2, 2], + [64, 128, 256, 512], + [False, True, True, True], + [False] * 4, + [None] * 4, + True, + **kwargs, + ) + + +def resnet31(pretrained: bool = False, **kwargs: Any) -> ResNet: + """Resnet31 architecture with rectangular pooling windows as described in + `"Show, Attend and Read:A Simple and Strong Baseline for Irregular Text Recognition", + `_. Downsizing: (H, W) --> (H/8, W/4) + + Example:: + >>> import tensorflow as tf + >>> from doctr.models import resnet31 + >>> model = resnet31(pretrained=False) + >>> input_tensor = tf.random.uniform(shape=[1, 224, 224, 3], maxval=1, dtype=tf.float32) + >>> out = model(input_tensor) + + Args: + pretrained: boolean, True if model is pretrained + + Returns: + A classification model + """ + + return _resnet( + 'resnet31', + pretrained, + [1, 2, 5, 3], + [256, 256, 512, 512], + [False] * 4, + [True] * 4, + [(2, 2), (2, 1), None, None], + False, + **kwargs, + ) diff --git a/doctr/models/classification/vgg/__init__.py b/doctr/models/classification/vgg/__init__.py new file mode 100644 index 0000000000..059f261e82 --- /dev/null +++ b/doctr/models/classification/vgg/__init__.py @@ -0,0 +1,6 @@ +from doctr.file_utils import is_tf_available, is_torch_available + +if is_tf_available(): + from .tensorflow import * +elif is_torch_available(): + from .pytorch import * # type: ignore[misc] diff --git a/doctr/models/classification/vgg/pytorch.py b/doctr/models/classification/vgg/pytorch.py new file mode 100644 index 0000000000..abf133e687 --- /dev/null +++ b/doctr/models/classification/vgg/pytorch.py @@ -0,0 +1,75 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. 
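Because `resnet31` keeps rectangular pooling windows, a recognition-style crop is downsampled to (H/8, W/4) while the channel count grows to 512. A shape-only sketch with the TensorFlow implementation above (untrained weights; the expected output shape is an assumption derived from the stage pooling configuration, not something asserted by this patch):

```python
import tensorflow as tf

from doctr.models import resnet31

# Headless backbone on a 32x128 text-crop-shaped input
model = resnet31(pretrained=False, include_top=False, input_shape=(32, 128, 3))
feat = model(tf.random.uniform((1, 32, 128, 3)), training=False)
print(feat.shape)  # expected: (1, 4, 32, 512), i.e. H/8, W/4 with 512 output channels
```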
+ +from typing import Any, Dict + +from torch import nn +from torchvision.models import vgg as tv_vgg + +from doctr.datasets import VOCABS + +from ...utils import load_pretrained_params + +__all__ = ['vgg16_bn_r'] + + +default_cfgs: Dict[str, Dict[str, Any]] = { + 'vgg16_bn_r': { + 'mean': (0.694, 0.695, 0.693), + 'std': (0.299, 0.296, 0.301), + 'input_shape': (3, 32, 32), + 'classes': list(VOCABS['french']), + 'url': 'https://github.com/mindee/doctr/releases/download/v0.4.1/vgg16_bn_r-d108c19c.pt', + }, +} + + +def _vgg( + arch: str, + pretrained: bool, + tv_arch: str, + num_rect_pools: int = 3, + **kwargs: Any +) -> tv_vgg.VGG: + + kwargs['num_classes'] = kwargs.get('num_classes', len(default_cfgs[arch]['classes'])) + + # Build the model + model = tv_vgg.__dict__[tv_arch](**kwargs) + # List the MaxPool2d + pool_idcs = [idx for idx, m in enumerate(model.features) if isinstance(m, nn.MaxPool2d)] + # Replace their kernel with rectangular ones + for idx in pool_idcs[-num_rect_pools:]: + model.features[idx] = nn.MaxPool2d((2, 1)) + # Patch average pool & classification head + model.avgpool = nn.AdaptiveAvgPool2d((1, 1)) + model.classifier = nn.Linear(512, kwargs['num_classes']) + # Load pretrained parameters + if pretrained: + load_pretrained_params(model, default_cfgs[arch]['url']) + + return model + + +def vgg16_bn_r(pretrained: bool = False, **kwargs: Any) -> tv_vgg.VGG: + """VGG-16 architecture as described in `"Very Deep Convolutional Networks for Large-Scale Image Recognition" + `_, modified by adding batch normalization, rectangular pooling and a simpler + classification head. + + Example:: + >>> import torch + >>> from doctr.models import vgg16_bn_r + >>> model = vgg16_bn_r(pretrained=False) + >>> input_tensor = torch.rand((1, 3, 224, 224), dtype=torch.float32) + >>> out = model(input_tensor) + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + + Returns: + VGG feature extractor + """ + + return _vgg('vgg16_bn_r', pretrained, 'vgg16_bn', 3, **kwargs) diff --git a/doctr/models/classification/vgg/tensorflow.py b/doctr/models/classification/vgg/tensorflow.py new file mode 100644 index 0000000000..39e4f940ea --- /dev/null +++ b/doctr/models/classification/vgg/tensorflow.py @@ -0,0 +1,115 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + +from typing import Any, Dict, List, Optional, Tuple + +from tensorflow.keras import layers +from tensorflow.keras.models import Sequential + +from doctr.datasets import VOCABS + +from ...utils import conv_sequence, load_pretrained_params + +__all__ = ['VGG', 'vgg16_bn_r'] + + +default_cfgs: Dict[str, Dict[str, Any]] = { + 'vgg16_bn_r': { + 'mean': (0.5, 0.5, 0.5), + 'std': (1., 1., 1.), + 'input_shape': (32, 32, 3), + 'classes': list(VOCABS['french']), + 'url': 'https://github.com/mindee/doctr/releases/download/v0.4.1/vgg16_bn_r-c5836cea.zip', + }, +} + + +class VGG(Sequential): + """Implements the VGG architecture from `"Very Deep Convolutional Networks for Large-Scale Image Recognition" + `_. 
+ + Args: + num_blocks: number of convolutional block in each stage + planes: number of output channels in each stage + rect_pools: whether pooling square kernels should be replace with rectangular ones + include_top: whether the classifier head should be instantiated + num_classes: number of output classes + input_shape: shapes of the input tensor + """ + def __init__( + self, + num_blocks: List[int], + planes: List[int], + rect_pools: List[bool], + include_top: bool = False, + num_classes: int = 1000, + input_shape: Optional[Tuple[int, int, int]] = None, + ) -> None: + + _layers = [] + # Specify input_shape only for the first layer + kwargs = {"input_shape": input_shape} + for nb_blocks, out_chan, rect_pool in zip(num_blocks, planes, rect_pools): + for _ in range(nb_blocks): + _layers.extend(conv_sequence(out_chan, 'relu', True, kernel_size=3, **kwargs)) # type: ignore[arg-type] + kwargs = {} + _layers.append(layers.MaxPooling2D((2, 1 if rect_pool else 2))) + + if include_top: + _layers.extend([ + layers.GlobalAveragePooling2D(), + layers.Dense(num_classes) + ]) + super().__init__(_layers) + + +def _vgg( + arch: str, + pretrained: bool, + num_blocks: List[int], + planes: List[int], + rect_pools: List[bool], + **kwargs: Any +) -> VGG: + + kwargs['num_classes'] = kwargs.get("num_classes", len(default_cfgs[arch]['classes'])) + kwargs['input_shape'] = kwargs.get("input_shape", default_cfgs[arch]['input_shape']) + + # Build the model + model = VGG(num_blocks, planes, rect_pools, **kwargs) + # Load pretrained parameters + if pretrained: + load_pretrained_params(model, default_cfgs[arch]['url']) + + return model + + +def vgg16_bn_r(pretrained: bool = False, **kwargs: Any) -> VGG: + """VGG-16 architecture as described in `"Very Deep Convolutional Networks for Large-Scale Image Recognition" + `_, modified by adding batch normalization, rectangular pooling and a simpler + classification head. + + Example:: + >>> import tensorflow as tf + >>> from doctr.models import vgg16_bn_r + >>> model = vgg16_bn_r(pretrained=False) + >>> input_tensor = tf.random.uniform(shape=[1, 224, 224, 3], maxval=1, dtype=tf.float32) + >>> out = model(input_tensor) + + Args: + pretrained (bool): If True, returns a model pre-trained on ImageNet + + Returns: + VGG feature extractor + """ + + return _vgg( + 'vgg16_bn_r', + pretrained, + [2, 2, 3, 3, 3], + [64, 128, 256, 512, 512], + [False, False, True, True, True], + **kwargs + ) diff --git a/doctr/models/classification/zoo.py b/doctr/models/classification/zoo.py new file mode 100644 index 0000000000..8450129a14 --- /dev/null +++ b/doctr/models/classification/zoo.py @@ -0,0 +1,67 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + +from typing import Any + +from doctr.file_utils import is_tf_available, is_torch_available + +from .. 
import classification +from ..preprocessor import PreProcessor +from .predictor import CropOrientationPredictor + +__all__ = ["crop_orientation_predictor"] + + +if is_tf_available(): + ARCHS = ['mobilenet_v3_small_orientation'] +elif is_torch_available(): + ARCHS = ['mobilenet_v3_small_orientation'] + + +def _crop_orientation_predictor( + arch: str, + pretrained: bool, + **kwargs: Any +) -> CropOrientationPredictor: + + if arch not in ARCHS: + raise ValueError(f"unknown architecture '{arch}'") + + # Load directly classifier from backbone + _model = classification.__dict__[arch](pretrained=pretrained) + kwargs['mean'] = kwargs.get('mean', _model.cfg['mean']) + kwargs['std'] = kwargs.get('std', _model.cfg['std']) + kwargs['batch_size'] = kwargs.get('batch_size', 64) + input_shape = _model.cfg['input_shape'][:-1] if is_tf_available() else _model.cfg['input_shape'][1:] + predictor = CropOrientationPredictor( + PreProcessor(input_shape, preserve_aspect_ratio=True, symmetric_pad=True, **kwargs), + _model + ) + return predictor + + +def crop_orientation_predictor( + arch: str = 'mobilenet_v3_small_orientation', + pretrained: bool = False, + **kwargs: Any +) -> CropOrientationPredictor: + """Orientation classification architecture. + + Example:: + >>> import numpy as np + >>> from doctr.models import crop_orientation_predictor + >>> model = crop_orientation_predictor(arch='classif_mobilenet_v3_small', pretrained=True) + >>> input_crop = (255 * np.random.rand(600, 800, 3)).astype(np.uint8) + >>> out = model([input_crop]) + + Args: + arch: name of the architecture to use (e.g. 'mobilenet_v3_small') + pretrained: If True, returns a model pre-trained on our recognition crops dataset + + Returns: + CropOrientationPredictor + """ + + return _crop_orientation_predictor(arch, pretrained, **kwargs) diff --git a/doctr/models/core.py b/doctr/models/core.py new file mode 100644 index 0000000000..4bfb3fa70b --- /dev/null +++ b/doctr/models/core.py @@ -0,0 +1,19 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + + +from typing import Any, Dict, Optional + +from doctr.utils.repr import NestedObject + +__all__ = ['BaseModel'] + + +class BaseModel(NestedObject): + """Implements abstract DetectionModel class""" + + def __init__(self, cfg: Optional[Dict[str, Any]] = None) -> None: + super().__init__() + self.cfg = cfg diff --git a/doctr/models/detection/__init__.py b/doctr/models/detection/__init__.py new file mode 100644 index 0000000000..e2fafbadba --- /dev/null +++ b/doctr/models/detection/__init__.py @@ -0,0 +1,3 @@ +from .differentiable_binarization import * +from .linknet import * +from .zoo import * diff --git a/doctr/models/detection/_utils/__init__.py b/doctr/models/detection/_utils/__init__.py new file mode 100644 index 0000000000..6a3fee30ac --- /dev/null +++ b/doctr/models/detection/_utils/__init__.py @@ -0,0 +1,6 @@ +from doctr.file_utils import is_tf_available + +if is_tf_available(): + from .tensorflow import * +else: + from .pytorch import * # type: ignore[misc] diff --git a/doctr/models/detection/_utils/pytorch.py b/doctr/models/detection/_utils/pytorch.py new file mode 100644 index 0000000000..efbdc411e0 --- /dev/null +++ b/doctr/models/detection/_utils/pytorch.py @@ -0,0 +1,37 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. 
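The `crop_orientation_predictor` factory above wires the orientation classifier to a `PreProcessor`, so it can be fed raw uint8 crops directly. A short usage sketch mirroring its docstring (with `pretrained=False` to avoid a weight download, so the predictions themselves are meaningless):

```python
import numpy as np

from doctr.models import crop_orientation_predictor

predictor = crop_orientation_predictor(arch='mobilenet_v3_small_orientation', pretrained=False)

# A few random HWC uint8 crops standing in for detected word boxes
crops = [(255 * np.random.rand(64, 128, 3)).astype(np.uint8) for _ in range(3)]
class_idxs = predictor(crops)  # one class index per crop, mapping into [0, 90, 180, 270]
print(class_idxs)
```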
+ +from torch import Tensor +from torch.nn.functional import max_pool2d + +__all__ = ['erode', 'dilate'] + + +def erode(x: Tensor, kernel_size: int) -> Tensor: + """Performs erosion on a given tensor + + Args: + x: boolean tensor of shape (N, C, H, W) + kernel_size: the size of the kernel to use for erosion + Returns: + the eroded tensor + """ + _pad = (kernel_size - 1) // 2 + + return 1 - max_pool2d(1 - x, kernel_size, stride=1, padding=_pad) + + +def dilate(x: Tensor, kernel_size: int) -> Tensor: + """Performs dilation on a given tensor + + Args: + x: boolean tensor of shape (N, C, H, W) + kernel_size: the size of the kernel to use for dilation + Returns: + the dilated tensor + """ + _pad = (kernel_size - 1) // 2 + + return max_pool2d(x, kernel_size, stride=1, padding=_pad) diff --git a/doctr/models/detection/_utils/tensorflow.py b/doctr/models/detection/_utils/tensorflow.py new file mode 100644 index 0000000000..5fd85437c3 --- /dev/null +++ b/doctr/models/detection/_utils/tensorflow.py @@ -0,0 +1,34 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + +import tensorflow as tf + +__all__ = ['erode', 'dilate'] + + +def erode(x: tf.Tensor, kernel_size: int) -> tf.Tensor: + """Performs erosion on a given tensor + + Args: + x: boolean tensor of shape (N, H, W, C) + kernel_size: the size of the kernel to use for erosion + Returns: + the eroded tensor + """ + + return 1 - tf.nn.max_pool2d(1 - x, kernel_size, strides=1, padding="SAME") + + +def dilate(x: tf.Tensor, kernel_size: int) -> tf.Tensor: + """Performs dilation on a given tensor + + Args: + x: boolean tensor of shape (N, H, W, C) + kernel_size: the size of the kernel to use for dilation + Returns: + the dilated tensor + """ + + return tf.nn.max_pool2d(x, kernel_size, strides=1, padding="SAME") diff --git a/doctr/models/detection/core.py b/doctr/models/detection/core.py new file mode 100644 index 0000000000..fd6bf489ae --- /dev/null +++ b/doctr/models/detection/core.py @@ -0,0 +1,105 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. 
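Both backends above implement morphology with nothing more than max-pooling: dilation is a stride-1 max-pool, and erosion is the complement of a max-pool applied to the complement. A quick sketch with the PyTorch variant (assuming a torch install; the numbers are easy to check by hand):

    >>> import torch
    >>> from doctr.models.detection._utils import erode, dilate
    >>> x = torch.zeros(1, 1, 5, 5)
    >>> x[0, 0, 2, 2] = 1.
    >>> int(dilate(x, 3).sum().item())  # the single foreground pixel grows into a 3x3 patch
    9
    >>> int(erode(dilate(x, 3), 3).sum().item())  # eroding with the same kernel shrinks it back to the centre
    1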
+ +from typing import List + +import cv2 +import numpy as np + +from doctr.utils.repr import NestedObject + +__all__ = ['DetectionPostProcessor'] + + +class DetectionPostProcessor(NestedObject): + """Abstract class to postprocess the raw output of the model + + Args: + box_thresh (float): minimal objectness score to consider a box + bin_thresh (float): threshold to apply to segmentation raw heatmap + assume straight_pages (bool): if True, fit straight boxes only + """ + + def __init__( + self, + box_thresh: float = 0.5, + bin_thresh: float = 0.5, + assume_straight_pages: bool = True + ) -> None: + + self.box_thresh = box_thresh + self.bin_thresh = bin_thresh + self.assume_straight_pages = assume_straight_pages + self._opening_kernel = np.ones((3, 3), dtype=np.uint8) + + def extra_repr(self) -> str: + return f"bin_thresh={self.bin_thresh}, box_thresh={self.box_thresh}" + + @staticmethod + def box_score( + pred: np.ndarray, + points: np.ndarray, + assume_straight_pages: bool = True + ) -> float: + """Compute the confidence score for a polygon : mean of the p values on the polygon + + Args: + pred (np.ndarray): p map returned by the model + + Returns: + polygon objectness + """ + h, w = pred.shape[:2] + + if assume_straight_pages: + xmin = np.clip(np.floor(points[:, 0].min()).astype(np.int32), 0, w - 1) + xmax = np.clip(np.ceil(points[:, 0].max()).astype(np.int32), 0, w - 1) + ymin = np.clip(np.floor(points[:, 1].min()).astype(np.int32), 0, h - 1) + ymax = np.clip(np.ceil(points[:, 1].max()).astype(np.int32), 0, h - 1) + return pred[ymin:ymax + 1, xmin:xmax + 1].mean() + + else: + mask = np.zeros((h, w), np.int32) + cv2.fillPoly(mask, [points.astype(np.int32)], 1.0) + product = pred * mask + return np.sum(product) / np.count_nonzero(product) + + def bitmap_to_boxes( + self, + pred: np.ndarray, + bitmap: np.ndarray, + ) -> np.ndarray: + raise NotImplementedError + + def __call__( + self, + proba_map, + ) -> List[List[np.ndarray]]: + """Performs postprocessing for a list of model outputs + + Args: + proba_map: probability map of shape (N, H, W, C) + + Returns: + list of N class predictions (for each input sample), where each class predictions is a list of C tensors + of shape (*, 5) or (*, 6) + """ + + if proba_map.ndim != 4: + raise AssertionError(f"arg `proba_map` is expected to be 4-dimensional, got {proba_map.ndim}.") + + # Erosion + dilation on the binary map + bin_map = [ + [ + cv2.morphologyEx(bmap[..., idx], cv2.MORPH_OPEN, self._opening_kernel) + for idx in range(proba_map.shape[-1]) + ] + for bmap in (proba_map >= self.bin_thresh).astype(np.uint8) + ] + + return [ + [self.bitmap_to_boxes(pmaps[..., idx], bmaps[idx]) for idx in range(proba_map.shape[-1])] + for pmaps, bmaps in zip(proba_map, bin_map) + ] diff --git a/doctr/models/detection/differentiable_binarization/__init__.py b/doctr/models/detection/differentiable_binarization/__init__.py new file mode 100644 index 0000000000..059f261e82 --- /dev/null +++ b/doctr/models/detection/differentiable_binarization/__init__.py @@ -0,0 +1,6 @@ +from doctr.file_utils import is_tf_available, is_torch_available + +if is_tf_available(): + from .tensorflow import * +elif is_torch_available(): + from .pytorch import * # type: ignore[misc] diff --git a/doctr/models/detection/differentiable_binarization/base.py b/doctr/models/detection/differentiable_binarization/base.py new file mode 100644 index 0000000000..a2eb28ba03 --- /dev/null +++ b/doctr/models/detection/differentiable_binarization/base.py @@ -0,0 +1,348 @@ +# Copyright (C) 2021-2022, 
Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + +# Credits: post-processing adapted from https://github.com/xuannianz/DifferentiableBinarization + +from typing import List, Tuple + +import cv2 +import numpy as np +import pyclipper +from shapely.geometry import Polygon + +from ..core import DetectionPostProcessor + +__all__ = ['DBPostProcessor'] + + +class DBPostProcessor(DetectionPostProcessor): + """Implements a post processor for DBNet adapted from the implementation of `xuannianz + `_. + + Args: + unclip ratio: ratio used to unshrink polygons + min_size_box: minimal length (pix) to keep a box + max_candidates: maximum boxes to consider in a single page + box_thresh: minimal objectness score to consider a box + bin_thresh: threshold used to binzarized p_map at inference time + + """ + def __init__( + self, + box_thresh: float = 0.1, + bin_thresh: float = 0.3, + assume_straight_pages: bool = True, + ) -> None: + + super().__init__( + box_thresh, + bin_thresh, + assume_straight_pages + ) + self.unclip_ratio = 1.5 if assume_straight_pages else 2.2 + + def polygon_to_box( + self, + points: np.ndarray, + ) -> np.ndarray: + """Expand a polygon (points) by a factor unclip_ratio, and returns a polygon + + Args: + points: The first parameter. + + Returns: + a box in absolute coordinates (xmin, ymin, xmax, ymax) or (4, 2) array (quadrangle) + """ + if not self.assume_straight_pages: + # Compute the rectangle polygon enclosing the raw polygon + rect = cv2.minAreaRect(points) + points = cv2.boxPoints(rect) + # Add 1 pixel to correct cv2 approx + area = (rect[1][0] + 1) * (1 + rect[1][1]) + length = 2 * (rect[1][0] + rect[1][1]) + 2 + else: + poly = Polygon(points) + area = poly.area + length = poly.length + distance = area * self.unclip_ratio / length # compute distance to expand polygon + offset = pyclipper.PyclipperOffset() + offset.AddPath(points, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON) + _points = offset.Execute(distance) + # Take biggest stack of points + idx = 0 + if len(_points) > 1: + max_size = 0 + for _idx, p in enumerate(_points): + if len(p) > max_size: + idx = _idx + max_size = len(p) + # We ensure that _points can be correctly casted to a ndarray + _points = [_points[idx]] + expanded_points = np.asarray(_points) # expand polygon + if len(expanded_points) < 1: + return None + return cv2.boundingRect(expanded_points) if self.assume_straight_pages else np.roll( + cv2.boxPoints(cv2.minAreaRect(expanded_points)), -1, axis=0 + ) + + def bitmap_to_boxes( + self, + pred: np.ndarray, + bitmap: np.ndarray, + ) -> np.ndarray: + """Compute boxes from a bitmap/pred_map + + Args: + pred: Pred map from differentiable binarization output + bitmap: Bitmap map computed from pred (binarized) + angle_tol: Comparison tolerance of the angle with the median angle across the page + ratio_tol: Under this limit aspect ratio, we cannot resolve the direction of the crop + + Returns: + np tensor boxes for the bitmap, each box is a 5-element list + containing x, y, w, h, score for the box + """ + height, width = bitmap.shape[:2] + min_size_box = 1 + int(height / 512) + boxes = [] + # get contours from connected components on the bitmap + contours, _ = cv2.findContours(bitmap.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + for contour in contours: + # Check whether smallest enclosing bounding box is not too small + if np.any(contour[:, 0].max(axis=0) - contour[:, 0].min(axis=0) < min_size_box): + continue + # 
Compute objectness + if self.assume_straight_pages: + x, y, w, h = cv2.boundingRect(contour) + points = np.array([[x, y], [x, y + h], [x + w, y + h], [x + w, y]]) + score = self.box_score(pred, points, assume_straight_pages=True) + else: + score = self.box_score(pred, contour, assume_straight_pages=False) + + if score < self.box_thresh: # remove polygons with a weak objectness + continue + + if self.assume_straight_pages: + _box = self.polygon_to_box(points) + else: + _box = self.polygon_to_box(np.squeeze(contour)) + + # Remove too small boxes + if self.assume_straight_pages: + if _box is None or _box[2] < min_size_box or _box[3] < min_size_box: + continue + elif np.linalg.norm(_box[2, :] - _box[0, :], axis=-1) < min_size_box: + continue + + if self.assume_straight_pages: + x, y, w, h = _box # type: ignore[misc] + # compute relative polygon to get rid of img shape + xmin, ymin, xmax, ymax = x / width, y / height, (x + w) / width, (y + h) / height + boxes.append([xmin, ymin, xmax, ymax, score]) + else: + # compute relative box to get rid of img shape, in that case _box is a 4pt polygon + if not isinstance(_box, np.ndarray) and _box.shape == (4, 2): + raise AssertionError("When assume straight pages is false a box is a (4, 2) array (polygon)") + _box[:, 0] /= width + _box[:, 1] /= height + boxes.append(_box) + + if not self.assume_straight_pages: + return np.clip(np.asarray(boxes), 0, 1) if len(boxes) > 0 else np.zeros((0, 4, 2), dtype=pred.dtype) + else: + return np.clip(np.asarray(boxes), 0, 1) if len(boxes) > 0 else np.zeros((0, 5), dtype=pred.dtype) + + +class _DBNet: + """DBNet as described in `"Real-time Scene Text Detection with Differentiable Binarization" + `_. + + Args: + feature extractor: the backbone serving as feature extractor + fpn_channels: number of channels each extracted feature maps is mapped to + """ + + shrink_ratio = 0.4 + thresh_min = 0.3 + thresh_max = 0.7 + min_size_box = 3 + assume_straight_pages: bool = True + + @staticmethod + def compute_distance( + xs: np.array, + ys: np.array, + a: np.array, + b: np.array, + eps: float = 1e-7, + ) -> float: + """Compute the distance for each point of the map (xs, ys) to the (a, b) segment + + Args: + xs : map of x coordinates (height, width) + ys : map of y coordinates (height, width) + a: first point defining the [ab] segment + b: second point defining the [ab] segment + + Returns: + The computed distance + + """ + square_dist_1 = np.square(xs - a[0]) + np.square(ys - a[1]) + square_dist_2 = np.square(xs - b[0]) + np.square(ys - b[1]) + square_dist = np.square(a[0] - b[0]) + np.square(a[1] - b[1]) + cosin = (square_dist - square_dist_1 - square_dist_2) / (2 * np.sqrt(square_dist_1 * square_dist_2) + eps) + square_sin = 1 - np.square(cosin) + square_sin = np.nan_to_num(square_sin) + result = np.sqrt(square_dist_1 * square_dist_2 * square_sin / square_dist) + result[cosin < 0] = np.sqrt(np.fmin(square_dist_1, square_dist_2))[cosin < 0] + return result + + def draw_thresh_map( + self, + polygon: np.array, + canvas: np.array, + mask: np.array, + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """Draw a polygon treshold map on a canvas, as described in the DB paper + + Args: + polygon : array of coord., to draw the boundary of the polygon + canvas : threshold map to fill with polygons + mask : mask for training on threshold polygons + """ + if polygon.ndim != 2 or polygon.shape[1] != 2: + raise AttributeError("polygon should be a 2 dimensional array of coords") + + # Augment polygon by shrink_ratio + polygon_shape = 
Polygon(polygon) + distance = polygon_shape.area * (1 - np.power(self.shrink_ratio, 2)) / polygon_shape.length + subject = [tuple(coor) for coor in polygon] # Get coord as list of tuples + padding = pyclipper.PyclipperOffset() + padding.AddPath(subject, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON) + padded_polygon = np.array(padding.Execute(distance)[0]) + + # Fill the mask with 1 on the new padded polygon + cv2.fillPoly(mask, [padded_polygon.astype(np.int32)], 1.0) + + # Get min/max to recover polygon after distance computation + xmin = padded_polygon[:, 0].min() + xmax = padded_polygon[:, 0].max() + ymin = padded_polygon[:, 1].min() + ymax = padded_polygon[:, 1].max() + width = xmax - xmin + 1 + height = ymax - ymin + 1 + # Get absolute polygon for distance computation + polygon[:, 0] = polygon[:, 0] - xmin + polygon[:, 1] = polygon[:, 1] - ymin + # Get absolute padded polygon + xs = np.broadcast_to(np.linspace(0, width - 1, num=width).reshape(1, width), (height, width)) + ys = np.broadcast_to(np.linspace(0, height - 1, num=height).reshape(height, 1), (height, width)) + + # Compute distance map to fill the padded polygon + distance_map = np.zeros((polygon.shape[0], height, width), dtype=polygon.dtype) + for i in range(polygon.shape[0]): + j = (i + 1) % polygon.shape[0] + absolute_distance = self.compute_distance(xs, ys, polygon[i], polygon[j]) + distance_map[i] = np.clip(absolute_distance / distance, 0, 1) + distance_map = np.min(distance_map, axis=0) + + # Clip the padded polygon inside the canvas + xmin_valid = min(max(0, xmin), canvas.shape[1] - 1) + xmax_valid = min(max(0, xmax), canvas.shape[1] - 1) + ymin_valid = min(max(0, ymin), canvas.shape[0] - 1) + ymax_valid = min(max(0, ymax), canvas.shape[0] - 1) + + # Fill the canvas with the distances computed inside the valid padded polygon + canvas[ymin_valid:ymax_valid + 1, xmin_valid:xmax_valid + 1] = np.fmax( + 1 - distance_map[ + ymin_valid - ymin:ymax_valid - ymin + 1, + xmin_valid - xmin:xmax_valid - xmin + 1 + ], + canvas[ymin_valid:ymax_valid + 1, xmin_valid:xmax_valid + 1] + ) + + return polygon, canvas, mask + + def build_target( + self, + target: List[np.ndarray], + output_shape: Tuple[int, int, int], + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: + + if any(t.dtype != np.float32 for t in target): + raise AssertionError("the expected dtype of target 'boxes' entry is 'np.float32'.") + if any(np.any((t[:, :4] > 1) | (t[:, :4] < 0)) for t in target): + raise ValueError("the 'boxes' entry of the target is expected to take values between 0 & 1.") + + input_dtype = target[0].dtype if len(target) > 0 else np.float32 + + seg_target = np.zeros(output_shape, dtype=np.uint8) + seg_mask = np.ones(output_shape, dtype=bool) + thresh_target = np.zeros(output_shape, dtype=np.float32) + thresh_mask = np.ones(output_shape, dtype=np.uint8) + + for idx, _target in enumerate(target): + # Draw each polygon on gt + if _target.shape[0] == 0: + # Empty image, full masked + seg_mask[idx] = False + + # Absolute bounding boxes + abs_boxes = _target.copy() + if abs_boxes.ndim == 3: + abs_boxes[:, :, 0] *= output_shape[-1] + abs_boxes[:, :, 1] *= output_shape[-2] + polys = abs_boxes + boxes_size = np.linalg.norm(abs_boxes[:, 2, :] - abs_boxes[:, 0, :], axis=-1) + abs_boxes = np.concatenate((abs_boxes.min(1), abs_boxes.max(1)), -1).round().astype(np.int32) + else: + abs_boxes[:, [0, 2]] *= output_shape[-1] + abs_boxes[:, [1, 3]] *= output_shape[-2] + abs_boxes = abs_boxes.round().astype(np.int32) + polys = np.stack([ + abs_boxes[:, [0, 
1]], + abs_boxes[:, [0, 3]], + abs_boxes[:, [2, 3]], + abs_boxes[:, [2, 1]], + ], axis=1) + boxes_size = np.minimum(abs_boxes[:, 2] - abs_boxes[:, 0], abs_boxes[:, 3] - abs_boxes[:, 1]) + + for box, box_size, poly in zip(abs_boxes, boxes_size, polys): + # Mask boxes that are too small + if box_size < self.min_size_box: + seg_mask[idx, box[1]: box[3] + 1, box[0]: box[2] + 1] = False + continue + + # Negative shrink for gt, as described in paper + polygon = Polygon(poly) + distance = polygon.area * (1 - np.power(self.shrink_ratio, 2)) / polygon.length + subject = [tuple(coor) for coor in poly] + padding = pyclipper.PyclipperOffset() + padding.AddPath(subject, pyclipper.JT_ROUND, pyclipper.ET_CLOSEDPOLYGON) + shrinked = padding.Execute(-distance) + + # Draw polygon on gt if it is valid + if len(shrinked) == 0: + seg_mask[idx, box[1]: box[3] + 1, box[0]: box[2] + 1] = False + continue + shrinked = np.array(shrinked[0]).reshape(-1, 2) + if shrinked.shape[0] <= 2 or not Polygon(shrinked).is_valid: + seg_mask[idx, box[1]: box[3] + 1, box[0]: box[2] + 1] = False + continue + cv2.fillPoly(seg_target[idx], [shrinked.astype(np.int32)], 1) + + # Draw on both thresh map and thresh mask + poly, thresh_target[idx], thresh_mask[idx] = self.draw_thresh_map(poly, thresh_target[idx], + thresh_mask[idx]) + + thresh_target = thresh_target.astype(input_dtype) * (self.thresh_max - self.thresh_min) + self.thresh_min + + seg_target = seg_target.astype(input_dtype) + seg_mask = seg_mask.astype(bool) + thresh_target = thresh_target.astype(input_dtype) + thresh_mask = thresh_mask.astype(bool) + + return seg_target, seg_mask, thresh_target, thresh_mask diff --git a/doctr/models/detection/differentiable_binarization/pytorch.py b/doctr/models/detection/differentiable_binarization/pytorch.py new file mode 100644 index 0000000000..dbe49b0be9 --- /dev/null +++ b/doctr/models/detection/differentiable_binarization/pytorch.py @@ -0,0 +1,400 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. 
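The gt shrinking in `build_target` above offsets each polygon inward by `area * (1 - shrink_ratio**2) / perimeter` before handing it to pyclipper, while the post-processor's `polygon_to_box` performs the reverse expansion with its own `unclip_ratio`. A worked number for a hypothetical 100x20 px box with the default `shrink_ratio = 0.4` (not part of the patch):

    >>> area, perimeter, shrink_ratio = 100 * 20, 2 * (100 + 20), 0.4
    >>> offset = area * (1 - shrink_ratio ** 2) / perimeter  # offset in px handed to pyclipper's Execute
    >>> round(offset, 2)
    7.0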
+ +from typing import Any, Callable, Dict, List, Optional + +import numpy as np +import torch +from torch import nn +from torch.nn import functional as F +from torchvision.models import resnet34, resnet50 +from torchvision.models._utils import IntermediateLayerGetter +from torchvision.ops.deform_conv import DeformConv2d + +from ...classification import mobilenet_v3_large +from ...utils import load_pretrained_params +from .base import DBPostProcessor, _DBNet + +__all__ = ['DBNet', 'db_resnet50', 'db_resnet34', 'db_mobilenet_v3_large', 'db_resnet50_rotation'] + + +default_cfgs: Dict[str, Dict[str, Any]] = { + 'db_resnet50': { + 'input_shape': (3, 1024, 1024), + 'mean': (0.798, 0.785, 0.772), + 'std': (0.264, 0.2749, 0.287), + 'url': 'https://github.com/mindee/doctr/releases/download/v0.3.1/db_resnet50-ac60cadc.pt', + }, + 'db_resnet34': { + 'input_shape': (3, 1024, 1024), + 'mean': (.5, .5, .5), + 'std': (1., 1., 1.), + 'url': None, + }, + 'db_mobilenet_v3_large': { + 'input_shape': (3, 1024, 1024), + 'mean': (0.798, 0.785, 0.772), + 'std': (0.264, 0.2749, 0.287), + 'url': 'https://github.com/mindee/doctr/releases/download/v0.3.1/db_mobilenet_v3_large-fd62154b.pt', + }, + 'db_resnet50_rotation': { + 'input_shape': (3, 1024, 1024), + 'mean': (0.798, 0.785, 0.772), + 'std': (0.264, 0.2749, 0.287), + 'url': 'https://github.com/mindee/doctr/releases/download/v0.4.1/db_resnet50-1138863a.pt', + }, +} + + +class FeaturePyramidNetwork(nn.Module): + def __init__( + self, + in_channels: List[int], + out_channels: int, + deform_conv: bool = False, + ) -> None: + + super().__init__() + + out_chans = out_channels // len(in_channels) + + conv_layer = DeformConv2d if deform_conv else nn.Conv2d + + self.in_branches = nn.ModuleList([ + nn.Sequential( + conv_layer(chans, out_channels, 1, bias=False), + nn.BatchNorm2d(out_channels), + nn.ReLU(inplace=True), + ) for idx, chans in enumerate(in_channels) + ]) + self.upsample = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) + self.out_branches = nn.ModuleList([ + nn.Sequential( + conv_layer(out_channels, out_chans, 3, padding=1, bias=False), + nn.BatchNorm2d(out_chans), + nn.ReLU(inplace=True), + nn.Upsample(scale_factor=2 ** idx, mode='bilinear', align_corners=True), + ) for idx, chans in enumerate(in_channels) + ]) + + def forward(self, x: List[torch.Tensor]) -> torch.Tensor: + if len(x) != len(self.out_branches): + raise AssertionError + # Conv1x1 to get the same number of channels + _x: List[torch.Tensor] = [branch(t) for branch, t in zip(self.in_branches, x)] + out: List[torch.Tensor] = [_x[-1]] + for t in _x[:-1][::-1]: + out.append(self.upsample(out[-1]) + t) + + # Conv and final upsampling + out = [branch(t) for branch, t in zip(self.out_branches, out[::-1])] + + return torch.cat(out, dim=1) + + +class DBNet(_DBNet, nn.Module): + + def __init__( + self, + feat_extractor: IntermediateLayerGetter, + head_chans: int = 256, + deform_conv: bool = False, + num_classes: int = 1, + assume_straight_pages: bool = True, + cfg: Optional[Dict[str, Any]] = None, + ) -> None: + + super().__init__() + self.cfg = cfg + + conv_layer = DeformConv2d if deform_conv else nn.Conv2d + + self.assume_straight_pages = assume_straight_pages + + self.feat_extractor = feat_extractor + # Identify the number of channels for the head initialization + _is_training = self.feat_extractor.training + self.feat_extractor = self.feat_extractor.eval() + with torch.no_grad(): + out = self.feat_extractor(torch.zeros((1, 3, 224, 224))) + fpn_channels = [v.shape[1] for _, v 
in out.items()] + + if _is_training: + self.feat_extractor = self.feat_extractor.train() + + self.fpn = FeaturePyramidNetwork(fpn_channels, head_chans, deform_conv) + # Conv1 map to channels + + self.prob_head = nn.Sequential( + conv_layer(head_chans, head_chans // 4, 3, padding=1, bias=False), + nn.BatchNorm2d(head_chans // 4), + nn.ReLU(inplace=True), + nn.ConvTranspose2d(head_chans // 4, head_chans // 4, 2, stride=2, bias=False), + nn.BatchNorm2d(head_chans // 4), + nn.ReLU(inplace=True), + nn.ConvTranspose2d(head_chans // 4, num_classes, 2, stride=2), + ) + self.thresh_head = nn.Sequential( + conv_layer(head_chans, head_chans // 4, 3, padding=1, bias=False), + nn.BatchNorm2d(head_chans // 4), + nn.ReLU(inplace=True), + nn.ConvTranspose2d(head_chans // 4, head_chans // 4, 2, stride=2, bias=False), + nn.BatchNorm2d(head_chans // 4), + nn.ReLU(inplace=True), + nn.ConvTranspose2d(head_chans // 4, num_classes, 2, stride=2), + ) + + self.postprocessor = DBPostProcessor(assume_straight_pages=assume_straight_pages) + + for n, m in self.named_modules(): + # Don't override the initialization of the backbone + if n.startswith('feat_extractor.'): + continue + if isinstance(m, (nn.Conv2d, DeformConv2d)): + nn.init.kaiming_normal_(m.weight.data, mode='fan_out', nonlinearity='relu') + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1.0) + m.bias.data.zero_() + + def forward( + self, + x: torch.Tensor, + target: Optional[List[np.ndarray]] = None, + return_model_output: bool = False, + return_preds: bool = False, + ) -> Dict[str, torch.Tensor]: + # Extract feature maps at different stages + feats = self.feat_extractor(x) + feats = [feats[str(idx)] for idx in range(len(feats))] + # Pass through the FPN + feat_concat = self.fpn(feats) + logits = self.prob_head(feat_concat) + + out: Dict[str, Any] = {} + if return_model_output or target is None or return_preds: + prob_map = torch.sigmoid(logits) + + if return_model_output: + out["out_map"] = prob_map + + if target is None or return_preds: + # Post-process boxes (keep only text predictions) + out["preds"] = [ + preds[0] for preds in self.postprocessor(prob_map.detach().cpu().permute((0, 2, 3, 1)).numpy()) + ] + + if target is not None: + thresh_map = self.thresh_head(feat_concat) + loss = self.compute_loss(logits, thresh_map, target) + out['loss'] = loss + + return out + + def compute_loss( + self, + out_map: torch.Tensor, + thresh_map: torch.Tensor, + target: List[np.ndarray] + ) -> torch.Tensor: + """Compute a batch of gts, masks, thresh_gts, thresh_masks from a list of boxes + and a list of masks for each image. From there it computes the loss with the model output + + Args: + out_map: output feature map of the model of shape (N, C, H, W) + thresh_map: threshold map of shape (N, C, H, W) + target: list of dictionary where each dict has a `boxes` and a `flags` entry + + Returns: + A loss tensor + """ + + prob_map = torch.sigmoid(out_map.squeeze(1)) + thresh_map = torch.sigmoid(thresh_map.squeeze(1)) + + targets = self.build_target(target, prob_map.shape) # type: ignore[arg-type] + + seg_target, seg_mask = torch.from_numpy(targets[0]), torch.from_numpy(targets[1]) + seg_target, seg_mask = seg_target.to(out_map.device), seg_mask.to(out_map.device) + thresh_target, thresh_mask = torch.from_numpy(targets[2]), torch.from_numpy(targets[3]) + thresh_target, thresh_mask = thresh_target.to(out_map.device), thresh_mask.to(out_map.device) + + # Compute balanced BCE loss for proba_map + bce_scale = 5. 
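+        # The remainder mirrors the DB training objective: a hard-negative-mined ("balanced") BCE on the
+        # probability map (at most 3 negatives kept per positive pixel), a weighted dice loss on the
+        # approximate binary map, and an L1 term on the threshold map, combined as 5 * BCE + dice + 10 * L1.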
+ balanced_bce_loss = torch.zeros(1, device=out_map.device) + dice_loss = torch.zeros(1, device=out_map.device) + l1_loss = torch.zeros(1, device=out_map.device) + if torch.any(seg_mask): + bce_loss = F.binary_cross_entropy_with_logits(out_map.squeeze(1), seg_target, reduction='none')[seg_mask] + + neg_target = 1 - seg_target[seg_mask] + positive_count = seg_target[seg_mask].sum() + negative_count = torch.minimum(neg_target.sum(), 3. * positive_count) + negative_loss = bce_loss * neg_target + negative_loss = negative_loss.sort().values[-int(negative_count.item()):] + sum_losses = torch.sum(bce_loss * seg_target[seg_mask]) + torch.sum(negative_loss) + balanced_bce_loss = sum_losses / (positive_count + negative_count + 1e-6) + + # Compute dice loss for approxbin_map + bin_map = 1 / (1 + torch.exp(-50. * (prob_map[seg_mask] - thresh_map[seg_mask]))) + + bce_min = bce_loss.min() + weights = (bce_loss - bce_min) / (bce_loss.max() - bce_min) + 1. + inter = torch.sum(bin_map * seg_target[seg_mask] * weights) + union = torch.sum(bin_map) + torch.sum(seg_target[seg_mask]) + 1e-8 + dice_loss = 1 - 2.0 * inter / union + + # Compute l1 loss for thresh_map + l1_scale = 10. + if torch.any(thresh_mask): + l1_loss = torch.mean(torch.abs(thresh_map[thresh_mask] - thresh_target[thresh_mask])) + + return l1_scale * l1_loss + bce_scale * balanced_bce_loss + dice_loss + + +def _dbnet( + arch: str, + pretrained: bool, + backbone_fn: Callable[[bool], nn.Module], + fpn_layers: List[str], + backbone_submodule: Optional[str] = None, + pretrained_backbone: bool = True, + **kwargs: Any, +) -> DBNet: + + # Starting with Imagenet pretrained params introduces some NaNs in layer3 & layer4 of resnet50 + pretrained_backbone = pretrained_backbone and not arch.split('_')[1].startswith('resnet') + pretrained_backbone = pretrained_backbone and not pretrained + + # Feature extractor + backbone = backbone_fn(pretrained_backbone) + if isinstance(backbone_submodule, str): + backbone = getattr(backbone, backbone_submodule) + feat_extractor = IntermediateLayerGetter( + backbone, + {layer_name: str(idx) for idx, layer_name in enumerate(fpn_layers)}, + ) + + # Build the model + model = DBNet(feat_extractor, cfg=default_cfgs[arch], **kwargs) + # Load pretrained parameters + if pretrained: + load_pretrained_params(model, default_cfgs[arch]['url']) + + return model + + +def db_resnet34(pretrained: bool = False, **kwargs: Any) -> DBNet: + """DBNet as described in `"Real-time Scene Text Detection with Differentiable Binarization" + `_, using a ResNet-34 backbone. + + Example:: + >>> import torch + >>> from doctr.models import db_resnet34 + >>> model = db_resnet34(pretrained=True) + >>> input_tensor = torch.rand((1, 3, 1024, 1024), dtype=torch.float32) + >>> out = model(input_tensor) + + Args: + pretrained (bool): If True, returns a model pre-trained on our text detection dataset + + Returns: + text detection architecture + """ + + return _dbnet( + 'db_resnet34', + pretrained, + resnet34, + ['layer1', 'layer2', 'layer3', 'layer4'], + None, + **kwargs, + ) + + +def db_resnet50(pretrained: bool = False, **kwargs: Any) -> DBNet: + """DBNet as described in `"Real-time Scene Text Detection with Differentiable Binarization" + `_, using a ResNet-50 backbone. 
+ + Example:: + >>> import torch + >>> from doctr.models import db_resnet50 + >>> model = db_resnet50(pretrained=True) + >>> input_tensor = torch.rand((1, 3, 1024, 1024), dtype=torch.float32) + >>> out = model(input_tensor) + + Args: + pretrained (bool): If True, returns a model pre-trained on our text detection dataset + + Returns: + text detection architecture + """ + + return _dbnet( + 'db_resnet50', + pretrained, + resnet50, + ['layer1', 'layer2', 'layer3', 'layer4'], + None, + **kwargs, + ) + + +def db_mobilenet_v3_large(pretrained: bool = False, **kwargs: Any) -> DBNet: + """DBNet as described in `"Real-time Scene Text Detection with Differentiable Binarization" + `_, using a MobileNet V3 Large backbone. + + Example:: + >>> import torch + >>> from doctr.models import db_mobilenet_v3_large + >>> model = db_mobilenet_v3_large(pretrained=True) + >>> input_tensor = torch.rand((1, 3, 1024, 1024), dtype=torch.float32) + >>> out = model(input_tensor) + + Args: + pretrained (bool): If True, returns a model pre-trained on our text detection dataset + + Returns: + text detection architecture + """ + + return _dbnet( + 'db_mobilenet_v3_large', + pretrained, + mobilenet_v3_large, + ['3', '6', '12', '16'], + 'features', + **kwargs, + ) + + +def db_resnet50_rotation(pretrained: bool = False, **kwargs: Any) -> DBNet: + """DBNet as described in `"Real-time Scene Text Detection with Differentiable Binarization" + `_, using a ResNet-50 backbone. + This model is trained with rotated documents + + Example:: + >>> import torch + >>> from doctr.models import db_resnet50_rotation + >>> model = db_resnet50_rotation(pretrained=True) + >>> input_tensor = torch.rand((1, 3, 1024, 1024), dtype=torch.float32) + >>> out = model(input_tensor) + + Args: + pretrained (bool): If True, returns a model pre-trained on our text detection dataset + + Returns: + text detection architecture + """ + + return _dbnet( + 'db_resnet50_rotation', + pretrained, + resnet50, + ['layer1', 'layer2', 'layer3', 'layer4'], + None, + **kwargs, + ) diff --git a/doctr/models/detection/differentiable_binarization/tensorflow.py b/doctr/models/detection/differentiable_binarization/tensorflow.py new file mode 100644 index 0000000000..184f84cb59 --- /dev/null +++ b/doctr/models/detection/differentiable_binarization/tensorflow.py @@ -0,0 +1,371 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. 
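For intuition about what `_dbnet` wires into the FPN, here is a small sketch (not part of the patch; torchvision assumed, sizes worth re-checking against your version) of the four intermediate maps a ResNet-34 backbone exposes for a 1024x1024 input; these channel counts are exactly what `DBNet.__init__` reads off to size its FPN:

    >>> import torch
    >>> from torchvision.models import resnet34
    >>> from torchvision.models._utils import IntermediateLayerGetter
    >>> return_layers = {name: str(idx) for idx, name in enumerate(['layer1', 'layer2', 'layer3', 'layer4'])}
    >>> body = IntermediateLayerGetter(resnet34(), return_layers)
    >>> feats = body(torch.rand(1, 3, 1024, 1024))
    >>> [tuple(f.shape[1:]) for f in feats.values()]
    [(64, 256, 256), (128, 128, 128), (256, 64, 64), (512, 32, 32)]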
+
+# Credits: post-processing adapted from https://github.com/xuannianz/DifferentiableBinarization
+
+from copy import deepcopy
+from typing import Any, Dict, List, Optional, Tuple
+
+import numpy as np
+import tensorflow as tf
+from tensorflow import keras
+from tensorflow.keras import layers
+from tensorflow.keras.applications import ResNet50
+
+from doctr.models.utils import IntermediateLayerGetter, conv_sequence, load_pretrained_params
+from doctr.utils.repr import NestedObject
+
+from ...classification import mobilenet_v3_large
+from .base import DBPostProcessor, _DBNet
+
+__all__ = ['DBNet', 'db_resnet50', 'db_mobilenet_v3_large']
+
+
+default_cfgs: Dict[str, Dict[str, Any]] = {
+    'db_resnet50': {
+        'mean': (0.798, 0.785, 0.772),
+        'std': (0.264, 0.2749, 0.287),
+        'input_shape': (1024, 1024, 3),
+        'url': 'https://github.com/mindee/doctr/releases/download/v0.2.0/db_resnet50-adcafc63.zip',
+    },
+    'db_mobilenet_v3_large': {
+        'mean': (0.798, 0.785, 0.772),
+        'std': (0.264, 0.2749, 0.287),
+        'input_shape': (1024, 1024, 3),
+        'url': 'https://github.com/mindee/doctr/releases/download/v0.3.1/db_mobilenet_v3_large-8c16d5bf.zip',
+    },
+}
+
+
+class FeaturePyramidNetwork(layers.Layer, NestedObject):
+    """Feature Pyramid Network as described in `"Feature Pyramid Networks for Object Detection"
+    `_.
+
+    Args:
+        channels: number of channels to output
+    """
+
+    def __init__(
+        self,
+        channels: int,
+    ) -> None:
+        super().__init__()
+        self.channels = channels
+        self.upsample = layers.UpSampling2D(size=(2, 2), interpolation='nearest')
+        self.inner_blocks = [layers.Conv2D(channels, 1, strides=1, kernel_initializer='he_normal') for _ in range(4)]
+        self.layer_blocks = [self.build_upsampling(channels, dilation_factor=2 ** idx) for idx in range(4)]
+
+    @staticmethod
+    def build_upsampling(
+        channels: int,
+        dilation_factor: int = 1,
+    ) -> layers.Layer:
+        """Module which performs a 3x3 convolution followed by up-sampling
+
+        Args:
+            channels: number of output channels
+            dilation_factor (int): dilation factor to scale the convolution output before concatenation
+
+        Returns:
+            a keras.layers.Layer object, wrapping these operations in a sequential module
+
+        """
+
+        _layers = conv_sequence(channels, 'relu', True, kernel_size=3)
+
+        if dilation_factor > 1:
+            _layers.append(layers.UpSampling2D(size=(dilation_factor, dilation_factor), interpolation='nearest'))
+
+        module = keras.Sequential(_layers)
+
+        return module
+
+    def extra_repr(self) -> str:
+        return f"channels={self.channels}"
+
+    def call(
+        self,
+        x: List[tf.Tensor],
+        **kwargs: Any,
+    ) -> tf.Tensor:
+
+        # Channel mapping
+        results = [block(fmap, **kwargs) for block, fmap in zip(self.inner_blocks, x)]
+        # Upsample & sum (top-down pathway, from the deepest map back to the shallowest)
+        for idx in range(len(results) - 2, -1, -1):
+            results[idx] += self.upsample(results[idx + 1])
+        # Conv & upsample
+        results = [block(fmap, **kwargs) for block, fmap in zip(self.layer_blocks, results)]
+
+        return layers.concatenate(results)
+
+
+class DBNet(_DBNet, keras.Model, NestedObject):
+    """DBNet as described in `"Real-time Scene Text Detection with Differentiable Binarization"
+    `_.
+ + Args: + feature extractor: the backbone serving as feature extractor + fpn_channels: number of channels each extracted feature maps is mapped to + num_classes: number of output channels in the segmentation map + assume_straight_pages: if True, fit straight bounding boxes only + cfg: the configuration dict of the model + """ + + _children_names: List[str] = ['feat_extractor', 'fpn', 'probability_head', 'threshold_head', 'postprocessor'] + + def __init__( + self, + feature_extractor: IntermediateLayerGetter, + fpn_channels: int = 128, # to be set to 256 to represent the author's initial idea + num_classes: int = 1, + assume_straight_pages: bool = True, + cfg: Optional[Dict[str, Any]] = None, + ) -> None: + + super().__init__() + self.cfg = cfg + + self.feat_extractor = feature_extractor + self.assume_straight_pages = assume_straight_pages + + self.fpn = FeaturePyramidNetwork(channels=fpn_channels) + # Initialize kernels + _inputs = [layers.Input(shape=in_shape[1:]) for in_shape in self.feat_extractor.output_shape] + output_shape = tuple(self.fpn(_inputs).shape) + + self.probability_head = keras.Sequential( + [ + *conv_sequence(64, 'relu', True, kernel_size=3, input_shape=output_shape[1:]), + layers.Conv2DTranspose(64, 2, strides=2, use_bias=False, kernel_initializer='he_normal'), + layers.BatchNormalization(), + layers.Activation('relu'), + layers.Conv2DTranspose(num_classes, 2, strides=2, kernel_initializer='he_normal'), + ] + ) + self.threshold_head = keras.Sequential( + [ + *conv_sequence(64, 'relu', True, kernel_size=3, input_shape=output_shape[1:]), + layers.Conv2DTranspose(64, 2, strides=2, use_bias=False, kernel_initializer='he_normal'), + layers.BatchNormalization(), + layers.Activation('relu'), + layers.Conv2DTranspose(num_classes, 2, strides=2, kernel_initializer='he_normal'), + ] + ) + + self.postprocessor = DBPostProcessor(assume_straight_pages=assume_straight_pages) + + def compute_loss( + self, + out_map: tf.Tensor, + thresh_map: tf.Tensor, + target: List[np.ndarray] + ) -> tf.Tensor: + """Compute a batch of gts, masks, thresh_gts, thresh_masks from a list of boxes + and a list of masks for each image. From there it computes the loss with the model output + + Args: + out_map: output feature map of the model of shape (N, H, W, C) + thresh_map: threshold map of shape (N, H, W, C) + target: list of dictionary where each dict has a `boxes` and a `flags` entry + + Returns: + A loss tensor + """ + + prob_map = tf.math.sigmoid(tf.squeeze(out_map, axis=[-1])) + thresh_map = tf.math.sigmoid(tf.squeeze(thresh_map, axis=[-1])) + + seg_target, seg_mask, thresh_target, thresh_mask = self.build_target(target, out_map.shape[:3]) + seg_target = tf.convert_to_tensor(seg_target, dtype=out_map.dtype) + seg_mask = tf.convert_to_tensor(seg_mask, dtype=tf.bool) + thresh_target = tf.convert_to_tensor(thresh_target, dtype=out_map.dtype) + thresh_mask = tf.convert_to_tensor(thresh_mask, dtype=tf.bool) + + # Compute balanced BCE loss for proba_map + bce_scale = 5. + bce_loss = tf.keras.losses.binary_crossentropy(seg_target[..., None], out_map, from_logits=True)[seg_mask] + + neg_target = 1 - seg_target[seg_mask] + positive_count = tf.math.reduce_sum(seg_target[seg_mask]) + negative_count = tf.math.reduce_min([tf.math.reduce_sum(neg_target), 3. 
* positive_count]) + negative_loss = bce_loss * neg_target + negative_loss, _ = tf.nn.top_k(negative_loss, tf.cast(negative_count, tf.int32)) + sum_losses = tf.math.reduce_sum(bce_loss * seg_target[seg_mask]) + tf.math.reduce_sum(negative_loss) + balanced_bce_loss = sum_losses / (positive_count + negative_count + 1e-6) + + # Compute dice loss for approxbin_map + bin_map = 1 / (1 + tf.exp(-50. * (prob_map[seg_mask] - thresh_map[seg_mask]))) + + bce_min = tf.math.reduce_min(bce_loss) + weights = (bce_loss - bce_min) / (tf.math.reduce_max(bce_loss) - bce_min) + 1. + inter = tf.math.reduce_sum(bin_map * seg_target[seg_mask] * weights) + union = tf.math.reduce_sum(bin_map) + tf.math.reduce_sum(seg_target[seg_mask]) + 1e-8 + dice_loss = 1 - 2.0 * inter / union + + # Compute l1 loss for thresh_map + l1_scale = 10. + if tf.reduce_any(thresh_mask): + l1_loss = tf.math.reduce_mean(tf.math.abs(thresh_map[thresh_mask] - thresh_target[thresh_mask])) + else: + l1_loss = tf.constant(0.) + + return l1_scale * l1_loss + bce_scale * balanced_bce_loss + dice_loss + + def call( + self, + x: tf.Tensor, + target: Optional[List[np.ndarray]] = None, + return_model_output: bool = False, + return_preds: bool = False, + **kwargs: Any, + ) -> Dict[str, Any]: + + feat_maps = self.feat_extractor(x, **kwargs) + feat_concat = self.fpn(feat_maps, **kwargs) + logits = self.probability_head(feat_concat, **kwargs) + + out: Dict[str, tf.Tensor] = {} + if return_model_output or target is None or return_preds: + prob_map = tf.math.sigmoid(logits) + + if return_model_output: + out["out_map"] = prob_map + + if target is None or return_preds: + # Post-process boxes (keep only text predictions) + out["preds"] = [preds[0] for preds in self.postprocessor(prob_map.numpy())] + + if target is not None: + thresh_map = self.threshold_head(feat_concat, **kwargs) + loss = self.compute_loss(logits, thresh_map, target) + out['loss'] = loss + + return out + + +def _db_resnet( + arch: str, + pretrained: bool, + backbone_fn, + fpn_layers: List[str], + pretrained_backbone: bool = True, + input_shape: Optional[Tuple[int, int, int]] = None, + **kwargs: Any, +) -> DBNet: + + pretrained_backbone = pretrained_backbone and not pretrained + + # Patch the config + _cfg = deepcopy(default_cfgs[arch]) + _cfg['input_shape'] = input_shape or _cfg['input_shape'] + + # Feature extractor + feat_extractor = IntermediateLayerGetter( + backbone_fn( + weights='imagenet' if pretrained_backbone else None, + include_top=False, + pooling=None, + input_shape=_cfg['input_shape'], + ), + fpn_layers, + ) + + # Build the model + model = DBNet(feat_extractor, cfg=_cfg, **kwargs) + # Load pretrained parameters + if pretrained: + load_pretrained_params(model, _cfg['url']) + + return model + + +def _db_mobilenet( + arch: str, + pretrained: bool, + backbone_fn, + fpn_layers: List[str], + pretrained_backbone: bool = True, + input_shape: Optional[Tuple[int, int, int]] = None, + **kwargs: Any, +) -> DBNet: + + pretrained_backbone = pretrained_backbone and not pretrained + + # Patch the config + _cfg = deepcopy(default_cfgs[arch]) + _cfg['input_shape'] = input_shape or _cfg['input_shape'] + + # Feature extractor + feat_extractor = IntermediateLayerGetter( + backbone_fn( + input_shape=_cfg['input_shape'], + include_top=False, + pretrained=pretrained_backbone, + ), + fpn_layers, + ) + + # Build the model + model = DBNet(feat_extractor, cfg=_cfg, **kwargs) + # Load pretrained parameters + if pretrained: + load_pretrained_params(model, _cfg['url']) + + return model + + +def 
db_resnet50(pretrained: bool = False, **kwargs: Any) -> DBNet: + """DBNet as described in `"Real-time Scene Text Detection with Differentiable Binarization" + `_, using a ResNet-50 backbone. + + Example:: + >>> import tensorflow as tf + >>> from doctr.models import db_resnet50 + >>> model = db_resnet50(pretrained=True) + >>> input_tensor = tf.random.uniform(shape=[1, 1024, 1024, 3], maxval=1, dtype=tf.float32) + >>> out = model(input_tensor) + + Args: + pretrained (bool): If True, returns a model pre-trained on our text detection dataset + + Returns: + text detection architecture + """ + + return _db_resnet( + 'db_resnet50', + pretrained, + ResNet50, + ["conv2_block3_out", "conv3_block4_out", "conv4_block6_out", "conv5_block3_out"], + **kwargs, + ) + + +def db_mobilenet_v3_large(pretrained: bool = False, **kwargs: Any) -> DBNet: + """DBNet as described in `"Real-time Scene Text Detection with Differentiable Binarization" + `_, using a mobilenet v3 large backbone. + + Example:: + >>> import tensorflow as tf + >>> from doctr.models import db_mobilenet_v3_large + >>> model = db_mobilenet_v3_large(pretrained=True) + >>> input_tensor = tf.random.uniform(shape=[1, 1024, 1024, 3], maxval=1, dtype=tf.float32) + >>> out = model(input_tensor) + + Args: + pretrained (bool): If True, returns a model pre-trained on our text detection dataset + + Returns: + text detection architecture + """ + + return _db_mobilenet( + 'db_mobilenet_v3_large', + pretrained, + mobilenet_v3_large, + ["inverted_2", "inverted_5", "inverted_11", "final_block"], + **kwargs, + ) diff --git a/doctr/models/detection/linknet/__init__.py b/doctr/models/detection/linknet/__init__.py new file mode 100644 index 0000000000..059f261e82 --- /dev/null +++ b/doctr/models/detection/linknet/__init__.py @@ -0,0 +1,6 @@ +from doctr.file_utils import is_tf_available, is_torch_available + +if is_tf_available(): + from .tensorflow import * +elif is_torch_available(): + from .pytorch import * # type: ignore[misc] diff --git a/doctr/models/detection/linknet/base.py b/doctr/models/detection/linknet/base.py new file mode 100644 index 0000000000..6e4136bcdc --- /dev/null +++ b/doctr/models/detection/linknet/base.py @@ -0,0 +1,177 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + +# Credits: post-processing adapted from https://github.com/xuannianz/DifferentiableBinarization + +from typing import List, Tuple + +import cv2 +import numpy as np + +from doctr.file_utils import is_tf_available +from doctr.models.core import BaseModel + +from ..core import DetectionPostProcessor + +__all__ = ['_LinkNet', 'LinkNetPostProcessor'] + + +class LinkNetPostProcessor(DetectionPostProcessor): + """Implements a post processor for LinkNet model. 
+ + Args: + bin_thresh: threshold used to binzarized p_map at inference time + box_thresh: minimal objectness score to consider a box + assume_straight_pages: whether the inputs were expected to have horizontal text elements + """ + def __init__( + self, + bin_thresh: float = 0.5, + box_thresh: float = 0.1, + assume_straight_pages: bool = True, + ) -> None: + super().__init__( + box_thresh, + bin_thresh, + assume_straight_pages + ) + + def bitmap_to_boxes( + self, + pred: np.ndarray, + bitmap: np.ndarray, + ) -> np.ndarray: + """Compute boxes from a bitmap/pred_map: find connected components then filter boxes + + Args: + pred: Pred map from differentiable linknet output + bitmap: Bitmap map computed from pred (binarized) + angle_tol: Comparison tolerance of the angle with the median angle across the page + ratio_tol: Under this limit aspect ratio, we cannot resolve the direction of the crop + + Returns: + np tensor boxes for the bitmap, each box is a 6-element list + containing x, y, w, h, alpha, score for the box + """ + height, width = bitmap.shape[:2] + min_size_box = 1 + int(height / 512) + boxes = [] + # get contours from connected components on the bitmap + contours, _ = cv2.findContours(bitmap.astype(np.uint8), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE) + for contour in contours: + # Check whether smallest enclosing bounding box is not too small + if np.any(contour[:, 0].max(axis=0) - contour[:, 0].min(axis=0) < min_size_box): + continue + # Compute objectness + if self.assume_straight_pages: + x, y, w, h = cv2.boundingRect(contour) + points = np.array([[x, y], [x, y + h], [x + w, y + h], [x + w, y]]) + score = self.box_score(pred, points, assume_straight_pages=True) + else: + score = self.box_score(pred, contour, assume_straight_pages=False) + + if score < self.box_thresh: # remove polygons with a weak objectness + continue + + if self.assume_straight_pages: + # compute relative polygon to get rid of img shape + xmin, ymin, xmax, ymax = x / width, y / height, (x + w) / width, (y + h) / height + boxes.append([xmin, ymin, xmax, ymax, score]) + else: + _box = cv2.boxPoints(cv2.minAreaRect(contour)) + # compute relative box to get rid of img shape + _box[:, 0] /= width + _box[:, 1] /= height + boxes.append(_box) + + if not self.assume_straight_pages: + return np.clip(np.asarray(boxes), 0, 1) if len(boxes) > 0 else np.zeros((0, 4, 2), dtype=pred.dtype) + else: + return np.clip(np.asarray(boxes), 0, 1) if len(boxes) > 0 else np.zeros((0, 5), dtype=pred.dtype) + + +class _LinkNet(BaseModel): + """LinkNet as described in `"LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation" + `_. 
+ + Args: + out_chan: number of channels for the output + """ + + min_size_box: int = 3 + assume_straight_pages: bool = True + + def build_target( + self, + target: List[np.ndarray], + output_shape: Tuple[int, int], + ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + + if any(t.dtype != np.float32 for t in target): + raise AssertionError("the expected dtype of target 'boxes' entry is 'np.float32'.") + if any(np.any((t[:, :4] > 1) | (t[:, :4] < 0)) for t in target): + raise ValueError("the 'boxes' entry of the target is expected to take values between 0 & 1.") + + h, w = output_shape + target_shape = (len(target), h, w, 1) + + if self.assume_straight_pages: + seg_target = np.zeros(target_shape, dtype=bool) + edge_mask = np.zeros(target_shape, dtype=bool) + else: + seg_target = np.zeros(target_shape, dtype=np.uint8) + + seg_mask = np.ones(target_shape, dtype=bool) + + for idx, _target in enumerate(target): + # Draw each polygon on gt + if _target.shape[0] == 0: + # Empty image, full masked + seg_mask[idx] = False + + # Absolute bounding boxes + abs_boxes = _target.copy() + + if abs_boxes.ndim == 3: + abs_boxes[:, :, 0] *= w + abs_boxes[:, :, 1] *= h + abs_boxes = abs_boxes.round().astype(np.int32) + polys = abs_boxes + boxes_size = np.linalg.norm(abs_boxes[:, 2, :] - abs_boxes[:, 0, :], axis=-1) + else: + abs_boxes[:, [0, 2]] *= w + abs_boxes[:, [1, 3]] *= h + abs_boxes = abs_boxes.round().astype(np.int32) + polys = [None] * abs_boxes.shape[0] # Unused + boxes_size = np.minimum(abs_boxes[:, 2] - abs_boxes[:, 0], abs_boxes[:, 3] - abs_boxes[:, 1]) + + for poly, box, box_size in zip(polys, abs_boxes, boxes_size): + # Mask boxes that are too small + if box_size < self.min_size_box: + seg_mask[idx, box[1]: box[3] + 1, box[0]: box[2] + 1] = False + continue + # Fill polygon with 1 + if not self.assume_straight_pages: + cv2.fillPoly(seg_target[idx], [poly.astype(np.int32)], 1) + else: + if box.shape == (4, 2): + box = [np.min(box[:, 0]), np.min(box[:, 1]), np.max(box[:, 0]), np.max(box[:, 1])] + seg_target[idx, box[1]: box[3] + 1, box[0]: box[2] + 1] = True + # top edge + edge_mask[idx, box[1], box[0]: min(box[2] + 1, w)] = True + # bot edge + edge_mask[idx, min(box[3], h - 1), box[0]: min(box[2] + 1, w)] = True + # left edge + edge_mask[idx, box[1]: min(box[3] + 1, h), box[0]] = True + # right edge + edge_mask[idx, box[1]: min(box[3] + 1, h), min(box[2], w - 1)] = True + + # Don't forget to switch back to channel first if PyTorch is used + if not is_tf_available(): + seg_target = seg_target.transpose(0, 3, 1, 2) + seg_mask = seg_mask.transpose(0, 3, 1, 2) + edge_mask = edge_mask.transpose(0, 3, 1, 2) + + return seg_target, seg_mask, edge_mask diff --git a/doctr/models/detection/linknet/pytorch.py b/doctr/models/detection/linknet/pytorch.py new file mode 100644 index 0000000000..fc50691bd3 --- /dev/null +++ b/doctr/models/detection/linknet/pytorch.py @@ -0,0 +1,236 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. 
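Both `build_target` implementations above validate the same target convention before rasterising anything: one `np.float32` array per page, with coordinates relative to the page size, either straight `(xmin, ymin, xmax, ymax)` boxes or `(4, 2)` polygons. A minimal sketch of a straight-box target that passes those checks (coordinates are made up):

    >>> import numpy as np
    >>> page_target = np.array([[0.12, 0.30, 0.43, 0.36],
    ...                         [0.50, 0.30, 0.81, 0.36]], dtype=np.float32)
    >>> page_target.dtype == np.float32, bool(np.all((page_target[:, :4] >= 0) & (page_target[:, :4] <= 1)))
    (True, True)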
+ +from typing import Any, Callable, Dict, List, Optional, Tuple + +import numpy as np +import torch +from torch import nn +from torch.nn import functional as F +from torchvision.models import resnet18 +from torchvision.models._utils import IntermediateLayerGetter + +from ...utils import load_pretrained_params +from .base import LinkNetPostProcessor, _LinkNet + +__all__ = ['LinkNet', 'linknet_resnet18'] + + +default_cfgs: Dict[str, Dict[str, Any]] = { + 'linknet_resnet18': { + 'input_shape': (3, 1024, 1024), + 'mean': (.5, .5, .5), + 'std': (1., 1., 1.), + 'url': None, + }, +} + + +class LinkNetFPN(nn.Module): + def __init__(self, layer_shapes: List[Tuple[int, int, int]]) -> None: + super().__init__() + strides = [ + 1 if (in_shape[-1] == out_shape[-1]) else 2 + for in_shape, out_shape in zip(layer_shapes[:-1], layer_shapes[1:]) + ] + + chans = [shape[0] for shape in layer_shapes] + + _decoder_layers = [ + self.decoder_block(ochan, ichan, stride) for ichan, ochan, stride in zip(chans[:-1], chans[1:], strides) + ] + + self.decoders = nn.ModuleList(_decoder_layers) + + @staticmethod + def decoder_block(in_chan: int, out_chan: int, stride: int) -> nn.Sequential: + """Creates a LinkNet decoder block""" + + mid_chan = in_chan // 4 + return nn.Sequential( + nn.Conv2d(in_chan, mid_chan, kernel_size=1, bias=False), + nn.BatchNorm2d(mid_chan), + nn.ReLU(inplace=True), + nn.ConvTranspose2d(mid_chan, mid_chan, 3, padding=1, output_padding=stride - 1, stride=stride, bias=False), + nn.BatchNorm2d(mid_chan), + nn.ReLU(inplace=True), + nn.Conv2d(mid_chan, out_chan, kernel_size=1, bias=False), + nn.BatchNorm2d(out_chan), + nn.ReLU(inplace=True), + ) + + def forward(self, feats: List[torch.Tensor]) -> torch.Tensor: + + out = feats[-1] + for decoder, fmap in zip(self.decoders[::-1], feats[:-1][::-1]): + out = decoder(out) + fmap + + out = self.decoders[0](out) + + return out + + +class LinkNet(nn.Module, _LinkNet): + + def __init__( + self, + feat_extractor: IntermediateLayerGetter, + num_classes: int = 1, + assume_straight_pages: bool = True, + cfg: Optional[Dict[str, Any]] = None, + ) -> None: + + super().__init__() + self.cfg = cfg + + self.feat_extractor = feat_extractor + # Identify the number of channels for the FPN initialization + self.feat_extractor.eval() + with torch.no_grad(): + in_shape = (3, 512, 512) + out = self.feat_extractor(torch.zeros((1, *in_shape))) + # Get the shapes of the extracted feature maps + _shapes = [v.shape[1:] for _, v in out.items()] + # Prepend the expected shapes of the first encoder + _shapes = [(_shapes[0][0], in_shape[1] // 4, in_shape[2] // 4)] + _shapes + self.feat_extractor.train() + + self.fpn = LinkNetFPN(_shapes) + + self.classifier = nn.Sequential( + nn.ConvTranspose2d(_shapes[0][0], 32, kernel_size=3, padding=1, output_padding=1, stride=2, bias=False), + nn.BatchNorm2d(32), + nn.ReLU(inplace=True), + nn.Conv2d(32, 32, kernel_size=3, padding=1, bias=False), + nn.BatchNorm2d(32), + nn.ReLU(inplace=True), + nn.ConvTranspose2d(32, num_classes, kernel_size=2, stride=2), + ) + + self.postprocessor = LinkNetPostProcessor(assume_straight_pages=assume_straight_pages) + + for n, m in self.named_modules(): + # Don't override the initialization of the backbone + if n.startswith('feat_extractor.'): + continue + if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)): + nn.init.kaiming_normal_(m.weight.data, mode='fan_out', nonlinearity='relu') + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1.0) + m.bias.data.zero_() + + 
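+    # Like the DBNet above, forward() serves both training and inference: passing `target` adds a 'loss'
+    # entry to the output dict, `return_model_output` exposes the sigmoid probability map, and
+    # `return_preds` (or calling without a target) runs the post-processor to return relative boxes.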
def forward( + self, + x: torch.Tensor, + target: Optional[List[np.ndarray]] = None, + return_model_output: bool = False, + return_preds: bool = False, + **kwargs: Any, + ) -> Dict[str, Any]: + + feats = self.feat_extractor(x) + logits = self.fpn([feats[str(idx)] for idx in range(len(feats))]) + logits = self.classifier(logits) + + out: Dict[str, Any] = {} + if return_model_output or target is None or return_preds: + prob_map = torch.sigmoid(logits) + if return_model_output: + out["out_map"] = prob_map + + if target is None or return_preds: + # Post-process boxes + out["preds"] = [ + preds[0] for preds in self.postprocessor(prob_map.detach().cpu().permute((0, 2, 3, 1)).numpy()) + ] + + if target is not None: + loss = self.compute_loss(logits, target) + out['loss'] = loss + + return out + + def compute_loss( + self, + out_map: torch.Tensor, + target: List[np.ndarray], + edge_factor: float = 2., + ) -> torch.Tensor: + """Compute linknet loss, BCE with boosted box edges or focal loss. Focal loss implementation based on + `_. + + Args: + out_map: output feature map of the model of shape (N, 1, H, W) + target: list of dictionary where each dict has a `boxes` and a `flags` entry + edge_factor: boost factor for box edges (in case of BCE) + + Returns: + A loss tensor + """ + seg_target, seg_mask, edge_mask = self.build_target(target, out_map.shape[-2:]) # type: ignore[arg-type] + + seg_target, seg_mask = torch.from_numpy(seg_target).to(dtype=out_map.dtype), torch.from_numpy(seg_mask) + seg_target, seg_mask = seg_target.to(out_map.device), seg_mask.to(out_map.device) + if edge_factor > 0: + edge_mask = torch.from_numpy(edge_mask).to(dtype=out_map.dtype, device=out_map.device) + + # Get the cross_entropy for each entry + loss = F.binary_cross_entropy_with_logits(out_map, seg_target, reduction='none') + + # Compute BCE loss with highlighted edges + if edge_factor > 0: + loss = ((1 + (edge_factor - 1) * edge_mask) * loss) + # Only consider contributions overlaping the mask + return loss[seg_mask].mean() + + +def _linknet( + arch: str, + pretrained: bool, + backbone_fn: Callable[[bool], nn.Module], + fpn_layers: List[str], + pretrained_backbone: bool = False, + **kwargs: Any +) -> LinkNet: + + pretrained_backbone = pretrained_backbone and not pretrained + + # Build the feature extractor + backbone = backbone_fn(pretrained_backbone) + feat_extractor = IntermediateLayerGetter( + backbone, + {layer_name: str(idx) for idx, layer_name in enumerate(fpn_layers)}, + ) + + # Build the model + model = LinkNet(feat_extractor, cfg=default_cfgs[arch], **kwargs) + # Load pretrained parameters + if pretrained: + load_pretrained_params(model, default_cfgs[arch]['url']) + + return model + + +def linknet_resnet18(pretrained: bool = False, **kwargs: Any) -> LinkNet: + """LinkNet as described in `"LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation" + `_. 
+ + Example:: + >>> import torch + >>> from doctr.models import linknet_resnet18 + >>> model = linknet_resnet18(pretrained=True).eval() + >>> input_tensor = torch.rand((1, 3, 1024, 1024), dtype=torch.float32) + >>> with torch.no_grad(): out = model(input_tensor) + + Args: + pretrained (bool): If True, returns a model pre-trained on our text detection dataset + + Returns: + text detection architecture + """ + + return _linknet('linknet_resnet18', pretrained, resnet18, ['layer1', 'layer2', 'layer3', 'layer4'], **kwargs) diff --git a/doctr/models/detection/linknet/tensorflow.py b/doctr/models/detection/linknet/tensorflow.py new file mode 100644 index 0000000000..792e503c7b --- /dev/null +++ b/doctr/models/detection/linknet/tensorflow.py @@ -0,0 +1,263 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + +# Credits: post-processing adapted from https://github.com/xuannianz/DifferentiableBinarization + +from copy import deepcopy +from typing import Any, Dict, List, Optional, Tuple + +import numpy as np +import tensorflow as tf +from tensorflow import keras +from tensorflow.keras import Model, Sequential, layers + +from doctr.models.classification import resnet18 +from doctr.models.utils import IntermediateLayerGetter, conv_sequence, load_pretrained_params +from doctr.utils.repr import NestedObject + +from .base import LinkNetPostProcessor, _LinkNet + +__all__ = ['LinkNet', 'linknet_resnet18'] + + +default_cfgs: Dict[str, Dict[str, Any]] = { + 'linknet_resnet18': { + 'mean': (0.798, 0.785, 0.772), + 'std': (0.264, 0.2749, 0.287), + 'input_shape': (1024, 1024, 3), + 'url': None, + }, +} + + +def decoder_block(in_chan: int, out_chan: int, stride: int, **kwargs: Any) -> Sequential: + """Creates a LinkNet decoder block""" + + return Sequential([ + *conv_sequence(in_chan // 4, 'relu', True, kernel_size=1, **kwargs), + layers.Conv2DTranspose( + filters=in_chan // 4, + kernel_size=3, + strides=stride, + padding="same", + use_bias=False, + kernel_initializer='he_normal' + ), + layers.BatchNormalization(), + layers.Activation('relu'), + *conv_sequence(out_chan, 'relu', True, kernel_size=1), + ]) + + +class LinkNetFPN(Model, NestedObject): + """LinkNet Decoder module""" + + def __init__( + self, + out_chans: int, + in_shapes: List[Tuple[int, ...]], + ) -> None: + + super().__init__() + self.out_chans = out_chans + strides = [2] * (len(in_shapes) - 1) + [1] + i_chans = [s[-1] for s in in_shapes[::-1]] + o_chans = i_chans[1:] + [out_chans] + self.decoders = [ + decoder_block(in_chan, out_chan, s, input_shape=in_shape) + for in_chan, out_chan, s, in_shape in zip(i_chans, o_chans, strides, in_shapes[::-1]) + ] + + def call( + self, + x: List[tf.Tensor] + ) -> tf.Tensor: + out = 0 + for decoder, fmap in zip(self.decoders, x[::-1]): + out = decoder(out + fmap) + return out + + def extra_repr(self) -> str: + return f"out_chans={self.out_chans}" + + +class LinkNet(_LinkNet, keras.Model): + """LinkNet as described in `"LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation" + `_. 
+ + Args: + num_classes: number of channels for the output + """ + + _children_names: List[str] = ['feat_extractor', 'fpn', 'classifier', 'postprocessor'] + + def __init__( + self, + feat_extractor: IntermediateLayerGetter, + fpn_channels: int = 64, + num_classes: int = 1, + assume_straight_pages: bool = True, + cfg: Optional[Dict[str, Any]] = None, + ) -> None: + super().__init__(cfg=cfg) + + self.assume_straight_pages = assume_straight_pages + + self.feat_extractor = feat_extractor + + self.fpn = LinkNetFPN(fpn_channels, [_shape[1:] for _shape in self.feat_extractor.output_shape]) + self.fpn.build(self.feat_extractor.output_shape) + + self.classifier = Sequential([ + layers.Conv2DTranspose( + filters=32, + kernel_size=3, + strides=2, + padding="same", + use_bias=False, + kernel_initializer='he_normal', + input_shape=self.fpn.decoders[-1].output_shape[1:], + ), + layers.BatchNormalization(), + layers.Activation('relu'), + *conv_sequence(32, 'relu', True, kernel_size=3, strides=1), + layers.Conv2DTranspose( + filters=num_classes, + kernel_size=2, + strides=2, + padding="same", + use_bias=True, + kernel_initializer='he_normal' + ), + ]) + + self.postprocessor = LinkNetPostProcessor(assume_straight_pages=assume_straight_pages) + + def compute_loss( + self, + out_map: tf.Tensor, + target: List[np.ndarray], + edge_factor: float = 2., + ) -> tf.Tensor: + """Compute linknet loss, BCE with boosted box edges or focal loss. Focal loss implementation based on + `_. + + Args: + out_map: output feature map of the model of shape N x H x W x 1 + target: list of dictionary where each dict has a `boxes` and a `flags` entry + edge_factor: boost factor for box edges (in case of BCE) + + Returns: + A loss tensor + """ + seg_target, seg_mask, edge_mask = self.build_target(target, out_map.shape[1:3]) + + seg_target = tf.convert_to_tensor(seg_target, dtype=out_map.dtype) + seg_mask = tf.convert_to_tensor(seg_mask, dtype=tf.bool) + if edge_factor > 0: + edge_mask = tf.convert_to_tensor(edge_mask, dtype=tf.bool) + + # Get the cross_entropy for each entry + loss = tf.keras.losses.binary_crossentropy(seg_target, out_map, from_logits=True)[..., None] + + # Compute BCE loss with highlighted edges + if edge_factor > 0: + loss = tf.math.multiply( + 1 + (edge_factor - 1) * tf.cast(edge_mask, out_map.dtype), + loss + ) + + return tf.reduce_mean(loss[seg_mask]) + + def call( + self, + x: tf.Tensor, + target: Optional[List[np.ndarray]] = None, + return_model_output: bool = False, + return_preds: bool = False, + **kwargs: Any, + ) -> Dict[str, Any]: + + feat_maps = self.feat_extractor(x, **kwargs) + logits = self.fpn(feat_maps, **kwargs) + logits = self.classifier(logits, **kwargs) + + out: Dict[str, tf.Tensor] = {} + if return_model_output or target is None or return_preds: + prob_map = tf.math.sigmoid(logits) + if return_model_output: + out["out_map"] = prob_map + + if target is None or return_preds: + # Post-process boxes + out["preds"] = [preds[0] for preds in self.postprocessor(prob_map.numpy())] + + if target is not None: + loss = self.compute_loss(logits, target) + out['loss'] = loss + + return out + + +def _linknet( + arch: str, + pretrained: bool, + backbone_fn, + fpn_layers: List[str], + pretrained_backbone: bool = True, + input_shape: Optional[Tuple[int, int, int]] = None, + **kwargs: Any +) -> LinkNet: + + pretrained_backbone = pretrained_backbone and not pretrained + + # Patch the config + _cfg = deepcopy(default_cfgs[arch]) + _cfg['input_shape'] = input_shape or default_cfgs[arch]['input_shape'] + + # Feature 
extractor + feat_extractor = IntermediateLayerGetter( + backbone_fn( + pretrained=pretrained_backbone, + include_top=False, + input_shape=_cfg['input_shape'], + ), + fpn_layers, + ) + + # Build the model + model = LinkNet(feat_extractor, cfg=_cfg, **kwargs) + # Load pretrained parameters + if pretrained: + load_pretrained_params(model, _cfg['url']) + + return model + + +def linknet_resnet18(pretrained: bool = False, **kwargs: Any) -> LinkNet: + """LinkNet as described in `"LinkNet: Exploiting Encoder Representations for Efficient Semantic Segmentation" + `_. + + Example:: + >>> import tensorflow as tf + >>> from doctr.models import linknet_resnet18 + >>> model = linknet_resnet18(pretrained=True) + >>> input_tensor = tf.random.uniform(shape=[1, 1024, 1024, 3], maxval=1, dtype=tf.float32) + >>> out = model(input_tensor) + + Args: + pretrained (bool): If True, returns a model pre-trained on our text detection dataset + + Returns: + text detection architecture + """ + + return _linknet( + 'linknet_resnet18', + pretrained, + resnet18, + ['resnet_block_1', 'resnet_block_3', 'resnet_block_5', 'resnet_block_7'], + **kwargs, + ) diff --git a/doctr/models/detection/predictor/__init__.py b/doctr/models/detection/predictor/__init__.py new file mode 100644 index 0000000000..6a3fee30ac --- /dev/null +++ b/doctr/models/detection/predictor/__init__.py @@ -0,0 +1,6 @@ +from doctr.file_utils import is_tf_available + +if is_tf_available(): + from .tensorflow import * +else: + from .pytorch import * # type: ignore[misc] diff --git a/doctr/models/detection/predictor/pytorch.py b/doctr/models/detection/predictor/pytorch.py new file mode 100644 index 0000000000..4eecde6135 --- /dev/null +++ b/doctr/models/detection/predictor/pytorch.py @@ -0,0 +1,51 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + +from typing import Any, List, Union + +import numpy as np +import torch +from torch import nn + +from doctr.models.preprocessor import PreProcessor + +__all__ = ['DetectionPredictor'] + + +class DetectionPredictor(nn.Module): + """Implements an object able to localize text elements in a document + + Args: + pre_processor: transform inputs for easier batched model inference + model: core detection architecture + """ + + def __init__( + self, + pre_processor: PreProcessor, + model: nn.Module, + ) -> None: + + super().__init__() + self.pre_processor = pre_processor + self.model = model.eval() + + @torch.no_grad() + def forward( + self, + pages: List[Union[np.ndarray, torch.Tensor]], + **kwargs: Any, + ) -> List[np.ndarray]: + + # Dimension check + if any(page.ndim != 3 for page in pages): + raise ValueError("incorrect input shape: all pages are expected to be multi-channel 2D images.") + + processed_batches = self.pre_processor(pages) + predicted_batches = [ + self.model(batch, return_preds=True, **kwargs)['preds'] # type:ignore[operator] + for batch in processed_batches + ] + return [pred for batch in predicted_batches for pred in batch] diff --git a/doctr/models/detection/predictor/tensorflow.py b/doctr/models/detection/predictor/tensorflow.py new file mode 100644 index 0000000000..8bb39fbcfa --- /dev/null +++ b/doctr/models/detection/predictor/tensorflow.py @@ -0,0 +1,52 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. 
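For context, a minimal sketch of wiring the PyTorch DetectionPredictor above by hand. This assumes the PyTorch backend is active and a pretrained db_resnet50 checkpoint is available; the detection_predictor factory further below performs the same assembly from the architecture's default config.

import numpy as np

from doctr.models import detection
from doctr.models.detection.predictor import DetectionPredictor  # resolves to the PyTorch class above when TF is unavailable
from doctr.models.preprocessor import PreProcessor

det_model = detection.db_resnet50(pretrained=True)      # any detection arch from this module (assumption: weights are reachable)
pre_processor = PreProcessor(
    det_model.cfg['input_shape'][1:],                    # (H, W) slice of the PyTorch (C, H, W) config
    batch_size=2,
    mean=det_model.cfg['mean'],
    std=det_model.cfg['std'],
)
predictor = DetectionPredictor(pre_processor, det_model)

pages = [(255 * np.random.rand(600, 800, 3)).astype(np.uint8)]
boxes = predictor(pages)                                 # one array of relative box coordinates (with a confidence column) per page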
+ +from typing import Any, List, Union + +import numpy as np +import tensorflow as tf +from tensorflow import keras + +from doctr.models.preprocessor import PreProcessor +from doctr.utils.repr import NestedObject + +__all__ = ['DetectionPredictor'] + + +class DetectionPredictor(NestedObject): + """Implements an object able to localize text elements in a document + + Args: + pre_processor: transform inputs for easier batched model inference + model: core detection architecture + """ + + _children_names: List[str] = ['pre_processor', 'model'] + + def __init__( + self, + pre_processor: PreProcessor, + model: keras.Model, + ) -> None: + + self.pre_processor = pre_processor + self.model = model + + def __call__( + self, + pages: List[Union[np.ndarray, tf.Tensor]], + **kwargs: Any, + ) -> List[np.ndarray]: + + # Dimension check + if any(page.ndim != 3 for page in pages): + raise ValueError("incorrect input shape: all pages are expected to be multi-channel 2D images.") + + processed_batches = self.pre_processor(pages) + predicted_batches = [ + self.model(batch, return_preds=True, training=False, **kwargs)['preds'] # type:ignore[operator] + for batch in processed_batches + ] + return [pred for batch in predicted_batches for pred in batch] diff --git a/doctr/models/detection/zoo.py b/doctr/models/detection/zoo.py new file mode 100644 index 0000000000..8eb7225bff --- /dev/null +++ b/doctr/models/detection/zoo.py @@ -0,0 +1,76 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + +from typing import Any + +from doctr.file_utils import is_tf_available, is_torch_available + +from .. import detection +from ..preprocessor import PreProcessor +from .predictor import DetectionPredictor + +__all__ = ["detection_predictor"] + + +if is_tf_available(): + ARCHS = ['db_resnet50', 'db_mobilenet_v3_large', 'linknet_resnet18'] + ROT_ARCHS = [] +elif is_torch_available(): + ARCHS = ['db_resnet34', 'db_resnet50', 'db_mobilenet_v3_large', 'linknet_resnet18', 'db_resnet50_rotation'] + ROT_ARCHS = ['db_resnet50_rotation'] + + +def _predictor( + arch: str, + pretrained: bool, + assume_straight_pages: bool = True, + **kwargs: Any +) -> DetectionPredictor: + + if arch not in ARCHS: + raise ValueError(f"unknown architecture '{arch}'") + + if arch not in ROT_ARCHS and not assume_straight_pages: + raise AssertionError("You are trying to use a model trained on straight pages while not assuming" + " your pages are straight. If you have only straight documents, don't pass" + f" assume_straight_pages=False, otherwise you should use one of these archs: {ROT_ARCHS}") + + # Detection + _model = detection.__dict__[arch](pretrained=pretrained, assume_straight_pages=assume_straight_pages) + kwargs['mean'] = kwargs.get('mean', _model.cfg['mean']) + kwargs['std'] = kwargs.get('std', _model.cfg['std']) + kwargs['batch_size'] = kwargs.get('batch_size', 1) + predictor = DetectionPredictor( + PreProcessor(_model.cfg['input_shape'][:-1] if is_tf_available() else _model.cfg['input_shape'][1:], **kwargs), + _model + ) + return predictor + + +def detection_predictor( + arch: str = 'db_resnet50', + pretrained: bool = False, + assume_straight_pages: bool = True, + **kwargs: Any +) -> DetectionPredictor: + """Text detection architecture. 
+ + Example:: + >>> import numpy as np + >>> from doctr.models import detection_predictor + >>> model = detection_predictor(arch='db_resnet50', pretrained=True) + >>> input_page = (255 * np.random.rand(600, 800, 3)).astype(np.uint8) + >>> out = model([input_page]) + + Args: + arch: name of the architecture to use (e.g. 'db_resnet50') + pretrained: If True, returns a model pre-trained on our text detection dataset + assume_straight_pages: If True, fit straight boxes to the page + + Returns: + Detection predictor + """ + + return _predictor(arch, pretrained, assume_straight_pages, **kwargs) diff --git a/doctr/models/obj_detection/__init__.py b/doctr/models/obj_detection/__init__.py new file mode 100644 index 0000000000..e7e08d7ac8 --- /dev/null +++ b/doctr/models/obj_detection/__init__.py @@ -0,0 +1 @@ +from .faster_rcnn import * diff --git a/doctr/models/obj_detection/faster_rcnn/__init__.py b/doctr/models/obj_detection/faster_rcnn/__init__.py new file mode 100644 index 0000000000..6748076aaf --- /dev/null +++ b/doctr/models/obj_detection/faster_rcnn/__init__.py @@ -0,0 +1,4 @@ +from doctr.file_utils import is_tf_available, is_torch_available + +if not is_tf_available() and is_torch_available(): + from .pytorch import * # type: ignore[misc] diff --git a/doctr/models/obj_detection/faster_rcnn/pytorch.py b/doctr/models/obj_detection/faster_rcnn/pytorch.py new file mode 100644 index 0000000000..c2df57ca06 --- /dev/null +++ b/doctr/models/obj_detection/faster_rcnn/pytorch.py @@ -0,0 +1,79 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + +from typing import Any, Dict + +from torchvision.models.detection import FasterRCNN, faster_rcnn + +from ...utils import load_pretrained_params + +__all__ = ['fasterrcnn_mobilenet_v3_large_fpn'] + + +default_cfgs: Dict[str, Dict[str, Any]] = { + 'fasterrcnn_mobilenet_v3_large_fpn': { + 'input_shape': (3, 1024, 1024), + 'mean': (0.485, 0.456, 0.406), + 'std': (0.229, 0.224, 0.225), + 'anchor_sizes': [32, 64, 128, 256, 512], + 'anchor_aspect_ratios': (0.5, 1., 2.), + 'num_classes': 5, + 'url': 'https://github.com/mindee/doctr/releases/download/v0.4.1/fasterrcnn_mobilenet_v3_large_fpn-d5b2490d.pt', + }, +} + + +def _fasterrcnn(arch: str, pretrained: bool, **kwargs: Any) -> FasterRCNN: + + _kwargs = { + "image_mean": default_cfgs[arch]['mean'], + "image_std": default_cfgs[arch]['std'], + "box_detections_per_img": 150, + "box_score_thresh": 0.15, + "box_positive_fraction": 0.35, + "box_nms_thresh": 0.2, + "rpn_nms_thresh": 0.2, + "num_classes": default_cfgs[arch]['num_classes'], + } + + # Build the model + _kwargs.update(kwargs) + model = faster_rcnn.__dict__[arch](pretrained=False, pretrained_backbone=False, **_kwargs) + + if pretrained: + # Load pretrained parameters + load_pretrained_params(model, default_cfgs[arch]['url']) + else: + # Filter keys + state_dict = { + k: v for k, v in faster_rcnn.__dict__[arch](pretrained=True).state_dict().items() + if not k.startswith('roi_heads.') + } + + # Load state dict + model.load_state_dict(state_dict, strict=False) + + return model + + +def fasterrcnn_mobilenet_v3_large_fpn(pretrained: bool = False, **kwargs: Any) -> FasterRCNN: + """Faster-RCNN architecture with a MobileNet V3 backbone as described in `"Faster R-CNN: Towards Real-Time + Object Detection with Region Proposal Networks" `_. 
+ + Example:: + >>> import torch + >>> from doctr.models.obj_detection import fasterrcnn_mobilenet_v3_large_fpn + >>> model = fasterrcnn_mobilenet_v3_large_fpn(pretrained=True).eval() + >>> input_tensor = torch.rand((1, 3, 1024, 1024), dtype=torch.float32) + >>> with torch.no_grad(): out = model(input_tensor) + + Args: + pretrained (bool): If True, returns a model pre-trained on our object detection dataset + + Returns: + object detection architecture + """ + + return _fasterrcnn('fasterrcnn_mobilenet_v3_large_fpn', pretrained, **kwargs) diff --git a/doctr/models/predictor/__init__.py b/doctr/models/predictor/__init__.py new file mode 100644 index 0000000000..6a3fee30ac --- /dev/null +++ b/doctr/models/predictor/__init__.py @@ -0,0 +1,6 @@ +from doctr.file_utils import is_tf_available + +if is_tf_available(): + from .tensorflow import * +else: + from .pytorch import * # type: ignore[misc] diff --git a/doctr/models/predictor/base.py b/doctr/models/predictor/base.py new file mode 100644 index 0000000000..2bb22f6a9b --- /dev/null +++ b/doctr/models/predictor/base.py @@ -0,0 +1,95 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + +from typing import List, Tuple + +import numpy as np + +from doctr.models.builder import DocumentBuilder + +from .._utils import extract_crops, extract_rcrops, rectify_crops, rectify_loc_preds +from ..classification import crop_orientation_predictor + +__all__ = ['_OCRPredictor'] + + +class _OCRPredictor: + """Implements an object able to localize and identify text elements in a set of documents + + Args: + det_predictor: detection module + reco_predictor: recognition module + """ + + doc_builder: DocumentBuilder + + def __init__(self) -> None: + self.crop_orientation_predictor = crop_orientation_predictor(pretrained=True) + + @staticmethod + def _generate_crops( + pages: List[np.ndarray], + loc_preds: List[np.ndarray], + channels_last: bool, + assume_straight_pages: bool = False, + ) -> List[List[np.ndarray]]: + + extraction_fn = extract_crops if assume_straight_pages else extract_rcrops + + crops = [ + extraction_fn(page, _boxes[:, :4], channels_last=channels_last) # type: ignore[operator] + for page, _boxes in zip(pages, loc_preds) + ] + return crops + + @staticmethod + def _prepare_crops( + pages: List[np.ndarray], + loc_preds: List[np.ndarray], + channels_last: bool, + assume_straight_pages: bool = False, + ) -> Tuple[List[List[np.ndarray]], List[np.ndarray]]: + + crops = _OCRPredictor._generate_crops(pages, loc_preds, channels_last, assume_straight_pages) + + # Avoid sending zero-sized crops + is_kept = [[all(s > 0 for s in crop.shape) for crop in page_crops] for page_crops in crops] + crops = [ + [crop for crop, _kept in zip(page_crops, page_kept) if _kept] + for page_crops, page_kept in zip(crops, is_kept) + ] + loc_preds = [_boxes[_kept] for _boxes, _kept in zip(loc_preds, is_kept)] + + return crops, loc_preds + + def _rectify_crops( + self, + crops: List[List[np.ndarray]], + loc_preds: List[np.ndarray], + ) -> Tuple[List[List[np.ndarray]], List[np.ndarray]]: + # Work at a page level + orientations = [self.crop_orientation_predictor(page_crops) for page_crops in crops] + rect_crops = [rectify_crops(page_crops, orientation) for page_crops, orientation in zip(crops, orientations)] + rect_loc_preds = [ + rectify_loc_preds(page_loc_preds, orientation) if len(page_loc_preds) > 0 else page_loc_preds + for page_loc_preds, orientation in zip(loc_preds, 
orientations) + ] + return rect_crops, rect_loc_preds + + @staticmethod + def _process_predictions( + loc_preds: List[np.ndarray], + word_preds: List[Tuple[str, float]], + ) -> Tuple[List[np.ndarray], List[List[Tuple[str, float]]]]: + + text_preds = [] + if len(loc_preds) > 0: + # Text + _idx = 0 + for page_boxes in loc_preds: + text_preds.append(word_preds[_idx: _idx + page_boxes.shape[0]]) + _idx += page_boxes.shape[0] + + return loc_preds, text_preds diff --git a/doctr/models/predictor/pytorch.py b/doctr/models/predictor/pytorch.py new file mode 100644 index 0000000000..848157951c --- /dev/null +++ b/doctr/models/predictor/pytorch.py @@ -0,0 +1,105 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + +from typing import Any, List, Union + +import numpy as np +import torch +from torch import nn + +from doctr.io.elements import Document +from doctr.models._utils import estimate_orientation +from doctr.models.builder import DocumentBuilder +from doctr.models.detection.predictor import DetectionPredictor +from doctr.models.recognition.predictor import RecognitionPredictor +from doctr.utils.geometry import rotate_boxes, rotate_image + +from ..classification import crop_orientation_predictor +from .base import _OCRPredictor + +__all__ = ['OCRPredictor'] + + +class OCRPredictor(nn.Module, _OCRPredictor): + """Implements an object able to localize and identify text elements in a set of documents + + Args: + det_predictor: detection module + reco_predictor: recognition module + assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages + without rotated textual elements. + export_as_straight_boxes: when assume_straight_pages is set to False, export final predictions + (potentially rotated) as straight bounding boxes. + straighten_pages: if True, estimates the page general orientation based on the median line orientation. + Then, rotates page before passing it to the deep learning modules. The final predictions will be remapped + accordingly. Doing so will improve performances for documents with page-uniform rotations. 
+ + """ + + def __init__( + self, + det_predictor: DetectionPredictor, + reco_predictor: RecognitionPredictor, + assume_straight_pages: bool = True, + export_as_straight_boxes: bool = False, + straighten_pages: bool = False, + ) -> None: + + super().__init__() + self.det_predictor = det_predictor.eval() # type: ignore[attr-defined] + self.reco_predictor = reco_predictor.eval() # type: ignore[attr-defined] + self.doc_builder = DocumentBuilder(export_as_straight_boxes=export_as_straight_boxes) + self.assume_straight_pages = assume_straight_pages + self.straighten_pages = straighten_pages + self.crop_orientation_predictor = crop_orientation_predictor(pretrained=True) + + @torch.no_grad() + def forward( + self, + pages: List[Union[np.ndarray, torch.Tensor]], + **kwargs: Any, + ) -> Document: + + # Dimension check + if any(page.ndim != 3 for page in pages): + raise ValueError("incorrect input shape: all pages are expected to be multi-channel 2D images.") + + # Detect document rotation and rotate pages + if self.straighten_pages: + origin_page_orientations = [estimate_orientation(page) for page in pages] + pages = [rotate_image(page, -angle, expand=True) for page, angle in zip(pages, origin_page_orientations)] + + # Localize text elements + loc_preds = self.det_predictor(pages, **kwargs) + # Check whether crop mode should be switched to channels first + channels_last = len(pages) == 0 or isinstance(pages[0], np.ndarray) + # Crop images + crops, loc_preds = self._prepare_crops( + pages, loc_preds, channels_last=channels_last, assume_straight_pages=self.assume_straight_pages + ) + # Rectify crop orientation + if not self.assume_straight_pages: + crops, loc_preds = self._rectify_crops(crops, loc_preds) + # Identify character sequences + word_preds = self.reco_predictor([crop for page_crops in crops for crop in page_crops], **kwargs) + + boxes, text_preds = self._process_predictions(loc_preds, word_preds) + + # Rotate back pages and boxes while keeping original image size + if self.straighten_pages: + boxes = [rotate_boxes(page_boxes, + angle, + orig_shape=page.shape[:2] if isinstance(page, np.ndarray) else page.shape[-2:] + ) for page_boxes, page, angle in zip(boxes, pages, origin_page_orientations)] + + out = self.doc_builder( + boxes, + text_preds, + [ + page.shape[:2] if channels_last else page.shape[-2:] # type: ignore[misc] + for page in pages + ] + ) + return out diff --git a/doctr/models/predictor/tensorflow.py b/doctr/models/predictor/tensorflow.py new file mode 100644 index 0000000000..0244a4b118 --- /dev/null +++ b/doctr/models/predictor/tensorflow.py @@ -0,0 +1,96 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. 
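Along the same lines, a minimal sketch of composing the PyTorch OCRPredictor above from the detection and recognition zoo helpers. The architecture names and pretrained flags are assumptions chosen from the configs elsewhere in this patch, and the imports assume the PyTorch backend is the one selected by the backend switches above.

import numpy as np

from doctr.models.detection.zoo import detection_predictor
from doctr.models.predictor import OCRPredictor          # resolves to the PyTorch class above when TF is unavailable
from doctr.models.recognition.zoo import recognition_predictor

predictor = OCRPredictor(
    detection_predictor('db_resnet50', pretrained=True),
    recognition_predictor('crnn_vgg16_bn', pretrained=True),
    assume_straight_pages=True,
)

page = (255 * np.random.rand(1024, 1024, 3)).astype(np.uint8)
document = predictor([page])        # doctr.io.elements.Document
result = document.export()          # nested export of pages / blocks / lines / words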
+ +from typing import Any, List, Union + +import numpy as np +import tensorflow as tf + +from doctr.io.elements import Document +from doctr.models._utils import estimate_orientation +from doctr.models.builder import DocumentBuilder +from doctr.models.detection.predictor import DetectionPredictor +from doctr.models.recognition.predictor import RecognitionPredictor +from doctr.utils.geometry import rotate_boxes, rotate_image +from doctr.utils.repr import NestedObject + +from ..classification import crop_orientation_predictor +from .base import _OCRPredictor + +__all__ = ['OCRPredictor'] + + +class OCRPredictor(NestedObject, _OCRPredictor): + """Implements an object able to localize and identify text elements in a set of documents + + Args: + det_predictor: detection module + reco_predictor: recognition module + assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages + without rotated textual elements. + export_as_straight_boxes: when assume_straight_pages is set to False, export final predictions + (potentially rotated) as straight bounding boxes. + straighten_pages: if True, estimates the page general orientation based on the median line orientation. + Then, rotates page before passing it to the deep learning modules. The final predictions will be remapped + accordingly. Doing so will improve performances for documents with page-uniform rotations. + """ + _children_names = ['det_predictor', 'reco_predictor'] + + def __init__( + self, + det_predictor: DetectionPredictor, + reco_predictor: RecognitionPredictor, + assume_straight_pages: bool = True, + export_as_straight_boxes: bool = False, + straighten_pages: bool = False, + ) -> None: + + super().__init__() + self.det_predictor = det_predictor + self.reco_predictor = reco_predictor + self.doc_builder = DocumentBuilder(export_as_straight_boxes=export_as_straight_boxes) + self.assume_straight_pages = assume_straight_pages + self.straighten_pages = straighten_pages + self.crop_orientation_predictor = crop_orientation_predictor(pretrained=True) + + def __call__( + self, + pages: List[Union[np.ndarray, tf.Tensor]], + **kwargs: Any, + ) -> Document: + + # Dimension check + if any(page.ndim != 3 for page in pages): + raise ValueError("incorrect input shape: all pages are expected to be multi-channel 2D images.") + + origin_page_shapes = [page.shape[:2] for page in pages] + + # Detect document rotation and rotate pages + if self.straighten_pages: + origin_page_orientations = [estimate_orientation(page) for page in pages] + pages = [rotate_image(page, -angle, expand=True) for page, angle in zip(pages, origin_page_orientations)] + + # Localize text elements + loc_preds = self.det_predictor(pages, **kwargs) + # Crop images + crops, loc_preds = self._prepare_crops( + pages, loc_preds, channels_last=True, assume_straight_pages=self.assume_straight_pages + ) + # Rectify crop orientation + if not self.assume_straight_pages: + crops, loc_preds = self._rectify_crops(crops, loc_preds) + + # Identify character sequences + word_preds = self.reco_predictor([crop for page_crops in crops for crop in page_crops], **kwargs) + + boxes, text_preds = self._process_predictions(loc_preds, word_preds) + + # Rotate back pages and boxes while keeping original image size + if self.straighten_pages: + boxes = [rotate_boxes(page_boxes, angle, orig_shape=page.shape[:2]) for + page_boxes, page, angle in zip(boxes, pages, origin_page_orientations)] + + out = self.doc_builder(boxes, text_preds, origin_page_shapes) # type: ignore[misc] + 
return out diff --git a/doctr/models/preprocessor/__init__.py b/doctr/models/preprocessor/__init__.py new file mode 100644 index 0000000000..059f261e82 --- /dev/null +++ b/doctr/models/preprocessor/__init__.py @@ -0,0 +1,6 @@ +from doctr.file_utils import is_tf_available, is_torch_available + +if is_tf_available(): + from .tensorflow import * +elif is_torch_available(): + from .pytorch import * # type: ignore[misc] diff --git a/doctr/models/preprocessor/pytorch.py b/doctr/models/preprocessor/pytorch.py new file mode 100644 index 0000000000..52549bde36 --- /dev/null +++ b/doctr/models/preprocessor/pytorch.py @@ -0,0 +1,127 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + +import math +from typing import Any, List, Tuple, Union + +import numpy as np +import torch +from torch import nn +from torchvision.transforms import functional as F +from torchvision.transforms import transforms as T + +from doctr.transforms import Resize +from doctr.utils.multithreading import multithread_exec + +__all__ = ['PreProcessor'] + + +class PreProcessor(nn.Module): + """Implements an abstract preprocessor object which performs casting, resizing, batching and normalization. + + Args: + output_size: expected size of each page in format (H, W) + batch_size: the size of page batches + mean: mean value of the training distribution by channel + std: standard deviation of the training distribution by channel + """ + + def __init__( + self, + output_size: Tuple[int, int], + batch_size: int, + mean: Tuple[float, float, float] = (.5, .5, .5), + std: Tuple[float, float, float] = (1., 1., 1.), + fp16: bool = False, + **kwargs: Any, + ) -> None: + super().__init__() + self.batch_size = batch_size + self.resize: T.Resize = Resize(output_size, **kwargs) + # Perform the division by 255 at the same time + self.normalize = T.Normalize(mean, std) + + def batch_inputs( + self, + samples: List[torch.Tensor] + ) -> List[torch.Tensor]: + """Gather samples into batches for inference purposes + + Args: + samples: list of samples of shape (C, H, W) + + Returns: + list of batched samples (*, C, H, W) + """ + + num_batches = int(math.ceil(len(samples) / self.batch_size)) + batches = [ + torch.stack(samples[idx * self.batch_size: min((idx + 1) * self.batch_size, len(samples))], dim=0) + for idx in range(int(num_batches)) + ] + + return batches + + def sample_transforms(self, x: Union[np.ndarray, torch.Tensor]) -> torch.Tensor: + if x.ndim != 3: + raise AssertionError("expected list of 3D Tensors") + if isinstance(x, np.ndarray): + if x.dtype not in (np.uint8, np.float32): + raise TypeError("unsupported data type for numpy.ndarray") + x = torch.from_numpy(x.copy()).permute(2, 0, 1) + elif x.dtype not in (torch.uint8, torch.float16, torch.float32): + raise TypeError("unsupported data type for torch.Tensor") + # Resizing + x = self.resize(x) + # Data type + if x.dtype == torch.uint8: + x = x.to(dtype=torch.float32).div(255).clip(0, 1) + x = x.to(dtype=torch.float32) + + return x + + def __call__( + self, + x: Union[torch.Tensor, np.ndarray, List[Union[torch.Tensor, np.ndarray]]] + ) -> List[torch.Tensor]: + """Prepare document data for model forwarding + + Args: + x: list of images (np.array) or tensors (already resized and batched) + Returns: + list of page batches + """ + + # Input type check + if isinstance(x, (np.ndarray, torch.Tensor)): + if x.ndim != 4: + raise AssertionError("expected 4D Tensor") + if isinstance(x, np.ndarray): + 
if x.dtype not in (np.uint8, np.float32): + raise TypeError("unsupported data type for numpy.ndarray") + x = torch.from_numpy(x.copy()).permute(0, 3, 1, 2) + elif x.dtype not in (torch.uint8, torch.float16, torch.float32): + raise TypeError("unsupported data type for torch.Tensor") + # Resizing + if x.shape[-2] != self.resize.size[0] or x.shape[-1] != self.resize.size[1]: + x = F.resize(x, self.resize.size, interpolation=self.resize.interpolation) + # Data type + if x.dtype == torch.uint8: + x = x.to(dtype=torch.float32).div(255).clip(0, 1) + x = x.to(dtype=torch.float32) + batches = [x] + + elif isinstance(x, list) and all(isinstance(sample, (np.ndarray, torch.Tensor)) for sample in x): + # Sample transform (to tensor, resize) + samples = multithread_exec(self.sample_transforms, x) + # Batching + batches = self.batch_inputs(samples) # type: ignore[arg-type] + else: + raise TypeError(f"invalid input type: {type(x)}") + + # Batch transforms (normalize) + batches = multithread_exec(self.normalize, batches) # type: ignore[assignment] + + return batches diff --git a/doctr/models/preprocessor/tensorflow.py b/doctr/models/preprocessor/tensorflow.py new file mode 100644 index 0000000000..642568b166 --- /dev/null +++ b/doctr/models/preprocessor/tensorflow.py @@ -0,0 +1,127 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + +import math +from typing import Any, List, Tuple, Union + +import numpy as np +import tensorflow as tf + +from doctr.transforms import Normalize, Resize +from doctr.utils.multithreading import multithread_exec +from doctr.utils.repr import NestedObject + +__all__ = ['PreProcessor'] + + +class PreProcessor(NestedObject): + """Implements an abstract preprocessor object which performs casting, resizing, batching and normalization. 
+ + Args: + output_size: expected size of each page in format (H, W) + batch_size: the size of page batches + mean: mean value of the training distribution by channel + std: standard deviation of the training distribution by channel + """ + + _children_names: List[str] = ['resize', 'normalize'] + + def __init__( + self, + output_size: Tuple[int, int], + batch_size: int, + mean: Tuple[float, float, float] = (.5, .5, .5), + std: Tuple[float, float, float] = (1., 1., 1.), + fp16: bool = False, + **kwargs: Any, + ) -> None: + + self.batch_size = batch_size + self.resize = Resize(output_size, **kwargs) + # Perform the division by 255 at the same time + self.normalize = Normalize(mean, std) + + def batch_inputs( + self, + samples: List[tf.Tensor] + ) -> List[tf.Tensor]: + """Gather samples into batches for inference purposes + + Args: + samples: list of samples (tf.Tensor) + + Returns: + list of batched samples + """ + + num_batches = int(math.ceil(len(samples) / self.batch_size)) + batches = [ + tf.stack(samples[idx * self.batch_size: min((idx + 1) * self.batch_size, len(samples))], axis=0) + for idx in range(int(num_batches)) + ] + + return batches + + def sample_transforms(self, x: Union[np.ndarray, tf.Tensor]) -> tf.Tensor: + if x.ndim != 3: + raise AssertionError("expected list of 3D Tensors") + if isinstance(x, np.ndarray): + if x.dtype not in (np.uint8, np.float32): + raise TypeError("unsupported data type for numpy.ndarray") + x = tf.convert_to_tensor(x) + elif x.dtype not in (tf.uint8, tf.float16, tf.float32): + raise TypeError("unsupported data type for torch.Tensor") + # Data type & 255 division + if x.dtype == tf.uint8: + x = tf.image.convert_image_dtype(x, dtype=tf.float32) + # Resizing + x = self.resize(x) + + return x + + def __call__( + self, + x: Union[tf.Tensor, np.ndarray, List[Union[tf.Tensor, np.ndarray]]] + ) -> List[tf.Tensor]: + """Prepare document data for model forwarding + + Args: + x: list of images (np.array) or tensors (already resized and batched) + Returns: + list of page batches + """ + + # Input type check + if isinstance(x, (np.ndarray, tf.Tensor)): + if x.ndim != 4: + raise AssertionError("expected 4D Tensor") + if isinstance(x, np.ndarray): + if x.dtype not in (np.uint8, np.float32): + raise TypeError("unsupported data type for numpy.ndarray") + x = tf.convert_to_tensor(x) + elif x.dtype not in (tf.uint8, tf.float16, tf.float32): + raise TypeError("unsupported data type for torch.Tensor") + + # Data type & 255 division + if x.dtype == tf.uint8: + x = tf.image.convert_image_dtype(x, dtype=tf.float32) + # Resizing + if x.shape[1] != self.resize.output_size[0] or x.shape[2] != self.resize.output_size[1]: + x = tf.image.resize(x, self.resize.output_size, method=self.resize.method) + + batches = [x] + + elif isinstance(x, list) and all(isinstance(sample, (np.ndarray, tf.Tensor)) for sample in x): + # Sample transform (to tensor, resize) + samples = multithread_exec(self.sample_transforms, x) + # Batching + batches = self.batch_inputs(samples) # type: ignore[arg-type] + else: + raise TypeError(f"invalid input type: {type(x)}") + + # Batch transforms (normalize) + batches = multithread_exec(self.normalize, batches) # type: ignore[assignment] + + return batches diff --git a/doctr/models/recognition/__init__.py b/doctr/models/recognition/__init__.py new file mode 100644 index 0000000000..9fc57e6d40 --- /dev/null +++ b/doctr/models/recognition/__init__.py @@ -0,0 +1,4 @@ +from .crnn import * +from .master import * +from .sar import * +from .zoo import * diff --git 
a/doctr/models/recognition/core.py b/doctr/models/recognition/core.py new file mode 100644 index 0000000000..4203611c19 --- /dev/null +++ b/doctr/models/recognition/core.py @@ -0,0 +1,61 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + +from typing import List, Tuple + +import numpy as np + +from doctr.datasets import encode_sequences +from doctr.utils.repr import NestedObject + +__all__ = ['RecognitionPostProcessor', 'RecognitionModel'] + + +class RecognitionModel(NestedObject): + """Implements abstract RecognitionModel class""" + + vocab: str + max_length: int + + def build_target( + self, + gts: List[str], + ) -> Tuple[np.ndarray, List[int]]: + """Encode a list of gts sequences into a np array and gives the corresponding* + sequence lengths. + + Args: + gts: list of ground-truth labels + + Returns: + A tuple of 2 tensors: Encoded labels and sequence lengths (for each entry of the batch) + """ + encoded = encode_sequences( + sequences=gts, + vocab=self.vocab, + target_size=self.max_length, + eos=len(self.vocab) + ) + seq_len = [len(word) for word in gts] + return encoded, seq_len + + +class RecognitionPostProcessor(NestedObject): + """Abstract class to postprocess the raw output of the model + + Args: + vocab: string containing the ordered sequence of supported characters + """ + + def __init__( + self, + vocab: str, + ) -> None: + + self.vocab = vocab + self._embedding = list(self.vocab) + [''] + + def extra_repr(self) -> str: + return f"vocab_size={len(self.vocab)}" diff --git a/doctr/models/recognition/crnn/__init__.py b/doctr/models/recognition/crnn/__init__.py new file mode 100644 index 0000000000..059f261e82 --- /dev/null +++ b/doctr/models/recognition/crnn/__init__.py @@ -0,0 +1,6 @@ +from doctr.file_utils import is_tf_available, is_torch_available + +if is_tf_available(): + from .tensorflow import * +elif is_torch_available(): + from .pytorch import * # type: ignore[misc] diff --git a/doctr/models/recognition/crnn/pytorch.py b/doctr/models/recognition/crnn/pytorch.py new file mode 100644 index 0000000000..693a963c33 --- /dev/null +++ b/doctr/models/recognition/crnn/pytorch.py @@ -0,0 +1,308 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. 
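To make the target encoding performed by RecognitionModel.build_target above concrete, a small self-contained sketch; the toy vocabulary and labels are made up for illustration.

from doctr.datasets import encode_sequences

vocab = "abc"
gts = ["ab", "cab"]

encoded = encode_sequences(sequences=gts, vocab=vocab, target_size=4, eos=len(vocab))
seq_len = [len(word) for word in gts]

# encoded is a (2, 4) integer array of character indices; positions past the end of a
# word are filled with the EOS index (3 here, since no dedicated pad symbol is passed),
# while seq_len keeps the true lengths [2, 3] for the loss computations downstream.
print(encoded)
print(seq_len)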
+ +from copy import deepcopy +from itertools import groupby +from typing import Any, Callable, Dict, List, Optional, Tuple + +import torch +from torch import nn +from torch.nn import functional as F + +from doctr.datasets import VOCABS, decode_sequence + +from ...classification import mobilenet_v3_large_r, mobilenet_v3_small_r, vgg16_bn_r +from ...utils.pytorch import load_pretrained_params +from ..core import RecognitionModel, RecognitionPostProcessor + +__all__ = ['CRNN', 'crnn_vgg16_bn', 'crnn_mobilenet_v3_small', + 'crnn_mobilenet_v3_large'] + +default_cfgs: Dict[str, Dict[str, Any]] = { + 'crnn_vgg16_bn': { + 'mean': (0.694, 0.695, 0.693), + 'std': (0.299, 0.296, 0.301), + 'input_shape': (3, 32, 128), + 'vocab': VOCABS['legacy_french'], + 'url': 'https://github.com/mindee/doctr/releases/download/v0.3.1/crnn_vgg16_bn-9762b0b0.pt', + }, + 'crnn_mobilenet_v3_small': { + 'mean': (0.694, 0.695, 0.693), + 'std': (0.299, 0.296, 0.301), + 'input_shape': (3, 32, 128), + 'vocab': VOCABS['french'], + 'url': "https://github.com/mindee/doctr/releases/download/v0.3.1/crnn_mobilenet_v3_small_pt-3b919a02.pt", + }, + 'crnn_mobilenet_v3_large': { + 'mean': (0.694, 0.695, 0.693), + 'std': (0.299, 0.296, 0.301), + 'input_shape': (3, 32, 128), + 'vocab': VOCABS['french'], + 'url': "https://github.com/mindee/doctr/releases/download/v0.3.1/crnn_mobilenet_v3_large_pt-f5259ec2.pt", + }, +} + + +class CTCPostProcessor(RecognitionPostProcessor): + """ + Postprocess raw prediction of the model (logits) to a list of words using CTC decoding + + Args: + vocab: string containing the ordered sequence of supported characters + """ + @staticmethod + def ctc_best_path( + logits: torch.Tensor, vocab: str = VOCABS['french'], blank: int = 0, + ) -> List[Tuple[str, float]]: + """Implements best path decoding as shown by Graves (Dissertation, p63), highly inspired from + `_. + + Args: + logits: model output, shape: N x T x C + vocab: vocabulary to use + blank: index of blank label + + Returns: + A list of tuples: (word, confidence) + """ + + # Gather the most confident characters, and assign the smallest conf among those to the sequence prob + probs = F.softmax(logits, dim=-1).max(dim=-1).values.min(dim=1).values + + # collapse best path (using itertools.groupby), map to chars, join char list to string + words = [ + decode_sequence([k for k, _ in groupby(seq.tolist()) if k != blank], vocab) + for seq in torch.argmax(logits, dim=-1) + ] + + return list(zip(words, probs.tolist())) + + def __call__( # type: ignore[override] + self, + logits: torch.Tensor + ) -> List[Tuple[str, float]]: + """ + Performs decoding of raw output with CTC and decoding of CTC predictions + with label_to_idx mapping dictionnary + + Args: + logits: raw output of the model, shape (N, C + 1, seq_len) + + Returns: + A tuple of 2 lists: a list of str (words) and a list of float (probs) + + """ + # Decode CTC + return self.ctc_best_path(logits=logits, vocab=self.vocab, blank=len(self.vocab)) + + +class CRNN(RecognitionModel, nn.Module): + """Implements a CRNN architecture as described in `"An End-to-End Trainable Neural Network for Image-based + Sequence Recognition and Its Application to Scene Text Recognition" `_. 
+ + Args: + feature_extractor: the backbone serving as feature extractor + vocab: vocabulary used for encoding + rnn_units: number of units in the LSTM layers + cfg: configuration dictionary + """ + + _children_names: List[str] = ['feat_extractor', 'decoder', 'linear', 'postprocessor'] + + def __init__( + self, + feature_extractor: nn.Module, + vocab: str, + rnn_units: int = 128, + input_shape: Tuple[int, int, int] = (3, 32, 128), + cfg: Optional[Dict[str, Any]] = None, + ) -> None: + super().__init__() + self.vocab = vocab + self.cfg = cfg + self.max_length = 32 + self.feat_extractor = feature_extractor + + # Resolve the input_size of the LSTM + self.feat_extractor.eval() + with torch.no_grad(): + out_shape = self.feat_extractor(torch.zeros((1, *input_shape))).shape + lstm_in = out_shape[1] * out_shape[2] + # Switch back to original mode + self.feat_extractor.train() + + self.decoder = nn.LSTM( + input_size=lstm_in, hidden_size=rnn_units, batch_first=True, num_layers=2, bidirectional=True, + ) + + # features units = 2 * rnn_units because bidirectional layers + self.linear = nn.Linear(in_features=2 * rnn_units, out_features=len(vocab) + 1) + + self.postprocessor = CTCPostProcessor(vocab=vocab) + + for n, m in self.named_modules(): + # Don't override the initialization of the backbone + if n.startswith('feat_extractor.'): + continue + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight.data, mode='fan_out', nonlinearity='relu') + if m.bias is not None: + m.bias.data.zero_() + elif isinstance(m, nn.BatchNorm2d): + m.weight.data.fill_(1.0) + m.bias.data.zero_() + + def compute_loss( + self, + model_output: torch.Tensor, + target: List[str], + ) -> torch.Tensor: + """Compute CTC loss for the model. + + Args: + gt: the encoded tensor with gt labels + model_output: predicted logits of the model + seq_len: lengths of each gt word inside the batch + + Returns: + The loss of the model on the batch + """ + gt, seq_len = self.build_target(target) + batch_len = model_output.shape[0] + input_length = model_output.shape[1] * torch.ones(size=(batch_len,), dtype=torch.int32) + # N x T x C -> T x N x C + logits = model_output.permute(1, 0, 2) + probs = F.log_softmax(logits, dim=-1) + ctc_loss = F.ctc_loss( + probs, + torch.from_numpy(gt), + input_length, + torch.tensor(seq_len, dtype=torch.int), + len(self.vocab), + zero_infinity=True, + ) + + return ctc_loss + + def forward( + self, + x: torch.Tensor, + target: Optional[List[str]] = None, + return_model_output: bool = False, + return_preds: bool = False, + ) -> Dict[str, Any]: + + features = self.feat_extractor(x) + # B x C x H x W --> B x C*H x W --> B x W x C*H + c, h, w = features.shape[1], features.shape[2], features.shape[3] + features_seq = torch.reshape(features, shape=(-1, h * c, w)) + features_seq = torch.transpose(features_seq, 1, 2) + logits, _ = self.decoder(features_seq) + logits = self.linear(logits) + + out: Dict[str, Any] = {} + if return_model_output: + out["out_map"] = logits + + if target is None or return_preds: + # Post-process boxes + out["preds"] = self.postprocessor(logits) + + if target is not None: + out['loss'] = self.compute_loss(logits, target) + + return out + + +def _crnn( + arch: str, + pretrained: bool, + backbone_fn: Callable[[Any], nn.Module], + pretrained_backbone: bool = True, + **kwargs: Any, +) -> CRNN: + + pretrained_backbone = pretrained_backbone and not pretrained + + # Feature extractor + feat_extractor = backbone_fn(pretrained=pretrained_backbone).features # type: ignore[call-arg] + + kwargs['vocab'] 
= kwargs.get('vocab', default_cfgs[arch]['vocab']) + kwargs['input_shape'] = kwargs.get('input_shape', default_cfgs[arch]['input_shape']) + + _cfg = deepcopy(default_cfgs[arch]) + _cfg['vocab'] = kwargs['vocab'] + _cfg['input_shape'] = kwargs['input_shape'] + + # Build the model + model = CRNN(feat_extractor, cfg=_cfg, **kwargs) # type: ignore[arg-type] + # Load pretrained parameters + if pretrained: + load_pretrained_params(model, _cfg['url']) + + return model + + +def crnn_vgg16_bn(pretrained: bool = False, **kwargs: Any) -> CRNN: + """CRNN with a VGG-16 backbone as described in `"An End-to-End Trainable Neural Network for Image-based + Sequence Recognition and Its Application to Scene Text Recognition" `_. + + Example:: + >>> import torch + >>> from doctr.models import crnn_vgg16_bn + >>> model = crnn_vgg16_bn(pretrained=True) + >>> input_tensor = torch.rand(1, 3, 32, 128) + >>> out = model(input_tensor) + + Args: + pretrained (bool): If True, returns a model pre-trained on our text recognition dataset + + Returns: + text recognition architecture + """ + + return _crnn('crnn_vgg16_bn', pretrained, vgg16_bn_r, **kwargs) + + +def crnn_mobilenet_v3_small(pretrained: bool = False, **kwargs: Any) -> CRNN: + """CRNN with a MobileNet V3 Small backbone as described in `"An End-to-End Trainable Neural Network for Image-based + Sequence Recognition and Its Application to Scene Text Recognition" `_. + + Example:: + >>> import torch + >>> from doctr.models import crnn_mobilenet_v3_small + >>> model = crnn_mobilenet_v3_small(pretrained=True) + >>> input_tensor = torch.rand(1, 3, 32, 128) + >>> out = model(input_tensor) + + Args: + pretrained (bool): If True, returns a model pre-trained on our text recognition dataset + + Returns: + text recognition architecture + """ + + return _crnn('crnn_mobilenet_v3_small', pretrained, mobilenet_v3_small_r, **kwargs) + + +def crnn_mobilenet_v3_large(pretrained: bool = False, **kwargs: Any) -> CRNN: + """CRNN with a MobileNet V3 Large backbone as described in `"An End-to-End Trainable Neural Network for Image-based + Sequence Recognition and Its Application to Scene Text Recognition" `_. + + Example:: + >>> import torch + >>> from doctr.models import crnn_mobilenet_v3_large + >>> model = crnn_mobilenet_v3_large(pretrained=True) + >>> input_tensor = torch.rand(1, 3, 32, 128) + >>> out = model(input_tensor) + + Args: + pretrained (bool): If True, returns a model pre-trained on our text recognition dataset + + Returns: + text recognition architecture + """ + + return _crnn('crnn_mobilenet_v3_large', pretrained, mobilenet_v3_large_r, **kwargs) diff --git a/doctr/models/recognition/crnn/tensorflow.py b/doctr/models/recognition/crnn/tensorflow.py new file mode 100644 index 0000000000..89b44f2a8e --- /dev/null +++ b/doctr/models/recognition/crnn/tensorflow.py @@ -0,0 +1,278 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. 
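The best-path decoding used by the PyTorch CTCPostProcessor above can be illustrated on a toy example with plain numpy and itertools (the logit values are made up; the TensorFlow variant below reaches the same result through ctc_beam_search_decoder with beam_width=1).

from itertools import groupby

import numpy as np

vocab = "ab"
blank = len(vocab)                                   # blank symbol is indexed right after the vocab, as above

# Toy logits for a single sequence of 5 timesteps over {a, b, blank}
logits = np.array([
    [5., 0., 0.],   # a
    [5., 0., 0.],   # a (repeated, collapsed by the decoding)
    [0., 0., 5.],   # blank (separator)
    [0., 5., 0.],   # b
    [0., 0., 5.],   # blank
])

best_path = logits.argmax(-1)                                    # [0, 0, 2, 1, 2]
collapsed = [k for k, _ in groupby(best_path) if k != blank]     # [0, 1]
word = ''.join(vocab[idx] for idx in collapsed)                  # "ab"

# Confidence, as in ctc_best_path: max softmax probability per timestep, min over time
probs = np.exp(logits) / np.exp(logits).sum(-1, keepdims=True)
confidence = float(probs.max(-1).min())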
+ +from copy import deepcopy +from typing import Any, Dict, List, Optional, Tuple + +import tensorflow as tf +from tensorflow.keras import layers +from tensorflow.keras.models import Model, Sequential + +from doctr.datasets import VOCABS + +from ...classification import mobilenet_v3_large_r, mobilenet_v3_small_r, vgg16_bn_r +from ...utils.tensorflow import load_pretrained_params +from ..core import RecognitionModel, RecognitionPostProcessor + +__all__ = ['CRNN', 'crnn_vgg16_bn', 'crnn_mobilenet_v3_small', + 'crnn_mobilenet_v3_large'] + +default_cfgs: Dict[str, Dict[str, Any]] = { + 'crnn_vgg16_bn': { + 'mean': (0.694, 0.695, 0.693), + 'std': (0.299, 0.296, 0.301), + 'input_shape': (32, 128, 3), + 'vocab': VOCABS['legacy_french'], + 'url': 'https://github.com/mindee/doctr/releases/download/v0.3.0/crnn_vgg16_bn-76b7f2c6.zip', + }, + 'crnn_mobilenet_v3_small': { + 'mean': (0.694, 0.695, 0.693), + 'std': (0.299, 0.296, 0.301), + 'input_shape': (32, 128, 3), + 'vocab': VOCABS['french'], + 'url': 'https://github.com/mindee/doctr/releases/download/v0.3.1/crnn_mobilenet_v3_small-7f36edec.zip', + }, + 'crnn_mobilenet_v3_large': { + 'mean': (0.694, 0.695, 0.693), + 'std': (0.299, 0.296, 0.301), + 'input_shape': (32, 128, 3), + 'vocab': VOCABS['french'], + 'url': None, + }, +} + + +class CTCPostProcessor(RecognitionPostProcessor): + """ + Postprocess raw prediction of the model (logits) to a list of words using CTC decoding + + Args: + vocab: string containing the ordered sequence of supported characters + ignore_case: if True, ignore case of letters + ignore_accents: if True, ignore accents of letters + """ + + def __call__( + self, + logits: tf.Tensor + ) -> List[Tuple[str, float]]: + """ + Performs decoding of raw output with CTC and decoding of CTC predictions + with label_to_idx mapping dictionnary + + Args: + logits: raw output of the model, shape BATCH_SIZE X SEQ_LEN X NUM_CLASSES + 1 + + Returns: + A list of decoded words of length BATCH_SIZE + + """ + # Decode CTC + _decoded, _log_prob = tf.nn.ctc_beam_search_decoder( + tf.transpose(logits, perm=[1, 0, 2]), + tf.fill(logits.shape[0], logits.shape[1]), + beam_width=1, top_paths=1, + ) + out_idxs = tf.sparse.to_dense(_decoded[0], default_value=len(self.vocab)) + probs = tf.math.exp(tf.squeeze(_log_prob, axis=1)) + + # Map it to characters + _decoded_strings_pred = tf.strings.reduce_join( + inputs=tf.nn.embedding_lookup(tf.constant(self._embedding, dtype=tf.string), out_idxs), + axis=-1 + ) + _decoded_strings_pred = tf.strings.split(_decoded_strings_pred, "") + decoded_strings_pred = tf.sparse.to_dense(_decoded_strings_pred.to_sparse(), default_value='not valid')[:, 0] + word_values = [word.decode() for word in decoded_strings_pred.numpy().tolist()] + + return list(zip(word_values, probs.numpy().tolist())) + + +class CRNN(RecognitionModel, Model): + """Implements a CRNN architecture as described in `"An End-to-End Trainable Neural Network for Image-based + Sequence Recognition and Its Application to Scene Text Recognition" `_. 
+ + Args: + feature_extractor: the backbone serving as feature extractor + vocab: vocabulary used for encoding + rnn_units: number of units in the LSTM layers + cfg: configuration dictionary + """ + + _children_names: List[str] = ['feat_extractor', 'decoder', 'postprocessor'] + + def __init__( + self, + feature_extractor: tf.keras.Model, + vocab: str, + rnn_units: int = 128, + cfg: Optional[Dict[str, Any]] = None, + ) -> None: + # Initialize kernels + h, w, c = feature_extractor.output_shape[1:] + + super().__init__() + self.vocab = vocab + self.max_length = w + self.cfg = cfg + self.feat_extractor = feature_extractor + + self.decoder = Sequential( + [ + layers.Bidirectional(layers.LSTM(units=rnn_units, return_sequences=True)), + layers.Bidirectional(layers.LSTM(units=rnn_units, return_sequences=True)), + layers.Dense(units=len(vocab) + 1) + ] + ) + self.decoder.build(input_shape=(None, w, h * c)) + + self.postprocessor = CTCPostProcessor(vocab=vocab) + + def compute_loss( + self, + model_output: tf.Tensor, + target: List[str], + ) -> tf.Tensor: + """Compute CTC loss for the model. + + Args: + model_output: predicted logits of the model + target: lengths of each gt word inside the batch + + Returns: + The loss of the model on the batch + """ + gt, seq_len = self.build_target(target) + batch_len = model_output.shape[0] + input_length = tf.fill((batch_len,), model_output.shape[1]) + ctc_loss = tf.nn.ctc_loss( + gt, model_output, seq_len, input_length, logits_time_major=False, blank_index=len(self.vocab) + ) + return ctc_loss + + def call( + self, + x: tf.Tensor, + target: Optional[List[str]] = None, + return_model_output: bool = False, + return_preds: bool = False, + **kwargs: Any, + ) -> Dict[str, Any]: + + features = self.feat_extractor(x, **kwargs) + # B x H x W x C --> B x W x H x C + transposed_feat = tf.transpose(features, perm=[0, 2, 1, 3]) + w, h, c = transposed_feat.get_shape().as_list()[1:] + # B x W x H x C --> B x W x H * C + features_seq = tf.reshape(transposed_feat, shape=(-1, w, h * c)) + logits = self.decoder(features_seq, **kwargs) + + out: Dict[str, tf.Tensor] = {} + if return_model_output: + out["out_map"] = logits + + if target is None or return_preds: + # Post-process boxes + out["preds"] = self.postprocessor(logits) + + if target is not None: + out['loss'] = self.compute_loss(logits, target) + + return out + + +def _crnn( + arch: str, + pretrained: bool, + backbone_fn, + pretrained_backbone: bool = True, + input_shape: Optional[Tuple[int, int, int]] = None, + **kwargs: Any +) -> CRNN: + + pretrained_backbone = pretrained_backbone and not pretrained + + kwargs['vocab'] = kwargs.get('vocab', default_cfgs[arch]['vocab']) + + _cfg = deepcopy(default_cfgs[arch]) + _cfg['vocab'] = kwargs['vocab'] + _cfg['input_shape'] = input_shape or default_cfgs[arch]['input_shape'] + + feat_extractor = backbone_fn( + input_shape=_cfg['input_shape'], + include_top=False, + pretrained=pretrained_backbone, + ) + + # Build the model + model = CRNN(feat_extractor, cfg=_cfg, **kwargs) + # Load pretrained parameters + if pretrained: + load_pretrained_params(model, _cfg['url']) + + return model + + +def crnn_vgg16_bn(pretrained: bool = False, **kwargs: Any) -> CRNN: + """CRNN with a VGG-16 backbone as described in `"An End-to-End Trainable Neural Network for Image-based + Sequence Recognition and Its Application to Scene Text Recognition" `_. 
+ + Example:: + >>> import tensorflow as tf + >>> from doctr.models import crnn_vgg16_bn + >>> model = crnn_vgg16_bn(pretrained=True) + >>> input_tensor = tf.random.uniform(shape=[1, 32, 128, 3], maxval=1, dtype=tf.float32) + >>> out = model(input_tensor) + + Args: + pretrained (bool): If True, returns a model pre-trained on our text recognition dataset + + Returns: + text recognition architecture + """ + + return _crnn('crnn_vgg16_bn', pretrained, vgg16_bn_r, **kwargs) + + +def crnn_mobilenet_v3_small(pretrained: bool = False, **kwargs: Any) -> CRNN: + """CRNN with a MobileNet V3 Small backbone as described in `"An End-to-End Trainable Neural Network for Image-based + Sequence Recognition and Its Application to Scene Text Recognition" `_. + + Example:: + >>> import tensorflow as tf + >>> from doctr.models import crnn_mobilenet_v3_small + >>> model = crnn_mobilenet_v3_small(pretrained=True) + >>> input_tensor = tf.random.uniform(shape=[1, 32, 128, 3], maxval=1, dtype=tf.float32) + >>> out = model(input_tensor) + + Args: + pretrained (bool): If True, returns a model pre-trained on our text recognition dataset + + Returns: + text recognition architecture + """ + + return _crnn('crnn_mobilenet_v3_small', pretrained, mobilenet_v3_small_r, **kwargs) + + +def crnn_mobilenet_v3_large(pretrained: bool = False, **kwargs: Any) -> CRNN: + """CRNN with a MobileNet V3 Large backbone as described in `"An End-to-End Trainable Neural Network for Image-based + Sequence Recognition and Its Application to Scene Text Recognition" `_. + + Example:: + >>> import tensorflow as tf + >>> from doctr.models import crnn_mobilenet_v3_large + >>> model = crnn_mobilenet_v3_large(pretrained=True) + >>> input_tensor = tf.random.uniform(shape=[1, 32, 128, 3], maxval=1, dtype=tf.float32) + >>> out = model(input_tensor) + + Args: + pretrained (bool): If True, returns a model pre-trained on our text recognition dataset + + Returns: + text recognition architecture + """ + + return _crnn('crnn_mobilenet_v3_large', pretrained, mobilenet_v3_large_r, **kwargs) diff --git a/doctr/models/recognition/master/__init__.py b/doctr/models/recognition/master/__init__.py new file mode 100644 index 0000000000..059f261e82 --- /dev/null +++ b/doctr/models/recognition/master/__init__.py @@ -0,0 +1,6 @@ +from doctr.file_utils import is_tf_available, is_torch_available + +if is_tf_available(): + from .tensorflow import * +elif is_torch_available(): + from .pytorch import * # type: ignore[misc] diff --git a/doctr/models/recognition/master/base.py b/doctr/models/recognition/master/base.py new file mode 100644 index 0000000000..d8d04729d1 --- /dev/null +++ b/doctr/models/recognition/master/base.py @@ -0,0 +1,57 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + +from typing import List, Tuple + +import numpy as np + +from ....datasets import encode_sequences +from ..core import RecognitionPostProcessor + + +class _MASTER: + + vocab: str + max_length: int + + def build_target( + self, + gts: List[str], + ) -> Tuple[np.ndarray, List[int]]: + """Encode a list of gts sequences into a np array and gives the corresponding* + sequence lengths. 
+ + Args: + gts: list of ground-truth labels + + Returns: + A tuple of 2 tensors: Encoded labels and sequence lengths (for each entry of the batch) + """ + encoded = encode_sequences( + sequences=gts, + vocab=self.vocab, + target_size=self.max_length, + eos=len(self.vocab), + sos=len(self.vocab) + 1, + pad=len(self.vocab) + 2, + ) + seq_len = [len(word) for word in gts] + return encoded, seq_len + + +class _MASTERPostProcessor(RecognitionPostProcessor): + """Abstract class to postprocess the raw output of the model + + Args: + vocab: string containing the ordered sequence of supported characters + """ + + def __init__( + self, + vocab: str, + ) -> None: + + super().__init__(vocab) + self._embedding = list(vocab) + [''] + [''] + [''] diff --git a/doctr/models/recognition/master/pytorch.py b/doctr/models/recognition/master/pytorch.py new file mode 100644 index 0000000000..6b9b37d005 --- /dev/null +++ b/doctr/models/recognition/master/pytorch.py @@ -0,0 +1,295 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + +from copy import deepcopy +from typing import Any, Callable, Dict, List, Optional, Tuple + +import torch +from torch import nn +from torch.nn import functional as F +from torchvision.models._utils import IntermediateLayerGetter + +from doctr.datasets import VOCABS +from doctr.models.classification import magc_resnet31 + +from ...utils.pytorch import load_pretrained_params +from ..transformer.pytorch import Decoder, positional_encoding +from .base import _MASTER, _MASTERPostProcessor + +__all__ = ['MASTER', 'master'] + + +default_cfgs: Dict[str, Dict[str, Any]] = { + 'master': { + 'mean': (.5, .5, .5), + 'std': (1., 1., 1.), + 'input_shape': (3, 48, 160), + 'vocab': VOCABS['french'], + 'url': None, + }, +} + + +class MASTER(_MASTER, nn.Module): + """Implements MASTER as described in paper: `_. + Implementation based on the official Pytorch implementation: `_. 
+ + Args: + feature_extractor: the backbone serving as feature extractor + vocab: vocabulary, (without EOS, SOS, PAD) + d_model: d parameter for the transformer decoder + dff: depth of the pointwise feed-forward layer + num_heads: number of heads for the mutli-head attention module + num_layers: number of decoder layers to stack + max_length: maximum length of character sequence handled by the model + dropout: dropout probability of the decoder + input_shape: size of the image inputs + cfg: dictionary containing information about the model + """ + + feature_pe: torch.Tensor + + def __init__( + self, + feature_extractor: nn.Module, + vocab: str, + d_model: int = 512, + dff: int = 2048, + num_heads: int = 8, # number of heads in the transformer decoder + num_layers: int = 3, + max_length: int = 50, + dropout: float = 0.2, + input_shape: Tuple[int, int, int] = (3, 48, 160), + cfg: Optional[Dict[str, Any]] = None, + ) -> None: + super().__init__() + + self.max_length = max_length + self.vocab = vocab + self.cfg = cfg + self.vocab_size = len(vocab) + self.num_heads = num_heads + + self.feat_extractor = feature_extractor + self.seq_embedding = nn.Embedding(self.vocab_size + 3, d_model) # 3 more for EOS/SOS/PAD + + self.decoder = Decoder( + num_layers=num_layers, + d_model=d_model, + num_heads=num_heads, + dff=dff, + vocab_size=self.vocab_size, + maximum_position_encoding=max_length, + dropout=dropout, + ) + self.register_buffer('feature_pe', positional_encoding(input_shape[1] * input_shape[2], d_model)) + self.linear = nn.Linear(d_model, self.vocab_size + 3) + + self.postprocessor = MASTERPostProcessor(vocab=self.vocab) + + for n, m in self.named_modules(): + # Don't override the initialization of the backbone + if n.startswith('feat_extractor.'): + continue + if isinstance(m, nn.Conv2d): + nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') + elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)): + nn.init.constant_(m.weight, 1) + nn.init.constant_(m.bias, 0) + + def make_mask(self, target: torch.Tensor) -> torch.Tensor: + size = target.size(1) + look_ahead_mask = ~ (torch.triu(torch.ones(size, size, device=target.device)) == 1).transpose(0, 1)[:, None] + target_padding_mask = torch.eq(target, self.vocab_size + 2) # Pad symbol + combined_mask = target_padding_mask | look_ahead_mask + return torch.tile(combined_mask.permute(1, 0, 2), (self.num_heads, 1, 1)) + + @staticmethod + def compute_loss( + model_output: torch.Tensor, + gt: torch.Tensor, + seq_len: torch.Tensor, + ) -> torch.Tensor: + """Compute categorical cross-entropy loss for the model. + Sequences are masked after the EOS character. + + Args: + gt: the encoded tensor with gt labels + model_output: predicted logits of the model + seq_len: lengths of each gt word inside the batch + + Returns: + The loss of the model on the batch + """ + # Input length : number of timesteps + input_len = model_output.shape[1] + # Add one for additional token (sos disappear in shift!) + seq_len = seq_len + 1 + # Compute loss: don't forget to shift gt! Otherwise the model learns to output the gt[t-1]! + # The "masked" first gt char is . Delete last logit of the model output. 
+ cce = F.cross_entropy(model_output[:, :-1, :].permute(0, 2, 1), gt[:, 1:], reduction='none') + # Compute mask, remove 1 timestep here as well + mask_2d = torch.arange(input_len - 1, device=model_output.device)[None, :] >= seq_len[:, None] + cce[mask_2d] = 0 + + ce_loss = cce.sum(1) / seq_len.to(dtype=model_output.dtype) + return ce_loss.mean() + + def forward( + self, + x: torch.Tensor, + target: Optional[List[str]] = None, + return_model_output: bool = False, + return_preds: bool = False, + ) -> Dict[str, Any]: + """Call function for training + + Args: + x: images + target: list of str labels + return_model_output: if True, return logits + return_preds: if True, decode logits + + Returns: + A torch tensor, containing logits + """ + + # Encode + features = self.feat_extractor(x)['features'] + b, c, h, w = features.shape[:4] + # --> (N, H * W, C) + features = features.reshape((b, c, h * w)).permute(0, 2, 1) + encoded = features + self.feature_pe[:, :h * w, :] + + out: Dict[str, Any] = {} + + if target is not None: + # Compute target: tensor of gts and sequence lengths + _gt, _seq_len = self.build_target(target) + gt, seq_len = torch.from_numpy(_gt).to(dtype=torch.long), torch.tensor(_seq_len) + gt, seq_len = gt.to(x.device), seq_len.to(x.device) + + if self.training: + if target is None: + raise AssertionError("In training mode, you need to pass a value to 'target'") + tgt_mask = self.make_mask(gt) + # Compute logits + output = self.decoder(gt, encoded, tgt_mask, None) + logits = self.linear(output) + + else: + logits = self.decode(encoded) + + if target is not None: + out['loss'] = self.compute_loss(logits, gt, seq_len) + + if return_model_output: + out['out_map'] = logits + + if return_preds: + predictions = self.postprocessor(logits) + out['preds'] = predictions + + return out + + def decode(self, encoded: torch.Tensor) -> torch.Tensor: + """Decode function for prediction + + Args: + encoded: input tensor + + Return: + A Tuple of torch.Tensor: predictions, logits + """ + b = encoded.size(0) + + # Padding symbol + SOS at the beginning + ys = torch.full((b, self.max_length), self.vocab_size + 2, dtype=torch.long, device=encoded.device) + ys[:, 0] = self.vocab_size + 1 + + # Final dimension include EOS/SOS/PAD + for i in range(self.max_length - 1): + ys_mask = self.make_mask(ys) + output = self.decoder(ys, encoded, ys_mask, None) + logits = self.linear(output) + prob = F.softmax(logits, dim=-1) + next_word = torch.max(prob, dim=-1).indices + ys[:, i + 1] = next_word[:, i + 1] + + # Shape (N, max_length, vocab_size + 1) + return logits + + +class MASTERPostProcessor(_MASTERPostProcessor): + """Post processor for MASTER architectures + """ + + def __call__( + self, + logits: torch.Tensor, + ) -> List[Tuple[str, float]]: + # compute pred with argmax for attention models + out_idxs = logits.argmax(-1) + # N x L + probs = torch.gather(torch.softmax(logits, -1), -1, out_idxs.unsqueeze(-1)).squeeze(-1) + # Take the minimum confidence of the sequence + probs = probs.min(dim=1).values.detach().cpu() + + # Manual decoding + word_values = [ + ''.join(self._embedding[idx] for idx in encoded_seq).split("")[0] + for encoded_seq in out_idxs.cpu().numpy() + ] + + return list(zip(word_values, probs.numpy().tolist())) + + +def _master( + arch: str, + pretrained: bool, + backbone_fn: Callable[[bool], nn.Module], + layer: str, + pretrained_backbone: bool = True, + **kwargs: Any +) -> MASTER: + + pretrained_backbone = pretrained_backbone and not pretrained + + # Patch the config + _cfg = 
deepcopy(default_cfgs[arch]) + _cfg['input_shape'] = kwargs.get('input_shape', _cfg['input_shape']) + _cfg['vocab'] = kwargs.get('vocab', _cfg['vocab']) + + kwargs['vocab'] = _cfg['vocab'] + kwargs['input_shape'] = _cfg['input_shape'] + + # Build the model + feat_extractor = IntermediateLayerGetter( + backbone_fn(pretrained_backbone), + {layer: 'features'}, + ) + model = MASTER(feat_extractor, cfg=_cfg, **kwargs) + # Load pretrained parameters + if pretrained: + load_pretrained_params(model, default_cfgs[arch]['url']) + + return model + + +def master(pretrained: bool = False, **kwargs: Any) -> MASTER: + """MASTER as described in paper: `_. + Example:: + >>> import torch + >>> from doctr.models import master + >>> model = master(pretrained=False) + >>> input_tensor = torch.rand((1, 3, 48, 160)) + >>> out = model(input_tensor) + Args: + pretrained (bool): If True, returns a model pre-trained on our text recognition dataset + Returns: + text recognition architecture + """ + + return _master('master', pretrained, magc_resnet31, '10', **kwargs) diff --git a/doctr/models/recognition/master/tensorflow.py b/doctr/models/recognition/master/tensorflow.py new file mode 100644 index 0000000000..30a1ccc662 --- /dev/null +++ b/doctr/models/recognition/master/tensorflow.py @@ -0,0 +1,300 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + +from copy import deepcopy +from typing import Any, Dict, List, Optional, Tuple + +import tensorflow as tf +from tensorflow.keras import Model, layers + +from doctr.datasets import VOCABS +from doctr.models.classification import magc_resnet31 + +from ...utils.tensorflow import load_pretrained_params +from ..transformer.tensorflow import Decoder, create_look_ahead_mask, create_padding_mask, positional_encoding +from .base import _MASTER, _MASTERPostProcessor + +__all__ = ['MASTER', 'master'] + + +default_cfgs: Dict[str, Dict[str, Any]] = { + 'master': { + 'mean': (0.694, 0.695, 0.693), + 'std': (0.299, 0.296, 0.301), + 'input_shape': (32, 128, 3), + 'vocab': VOCABS['legacy_french'], + 'url': 'https://github.com/mindee/doctr/releases/download/v0.3.0/master-bade6eae.zip', + }, +} + + +class MASTER(_MASTER, Model): + + """Implements MASTER as described in paper: `_. + Implementation based on the official TF implementation: `_. 
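The builder above patches the default config with any 'vocab' or 'input_shape' passed as keyword arguments before instantiating MASTER. A rough usage sketch for the PyTorch factory (untrained weights, so the decoded strings are meaningless; shapes follow the default config):

import torch
from doctr.datasets import VOCABS
from doctr.models import master

model = master(pretrained=False, vocab=VOCABS['french'], input_shape=(3, 48, 160))
x = torch.rand((1, 3, 48, 160))
out = model(x, target=["sample"], return_preds=True)
print(out['loss'], out['preds'])  # scalar loss tensor, [(word, confidence)]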
+ + Args: + feature_extractor: the backbone serving as feature extractor + vocab: vocabulary, (without EOS, SOS, PAD) + d_model: d parameter for the transformer decoder + dff: depth of the pointwise feed-forward layer + num_heads: number of heads for the mutli-head attention module + num_layers: number of decoder layers to stack + max_length: maximum length of character sequence handled by the model + dropout: dropout probability of the decoder + input_shape: size of the image inputs + cfg: dictionary containing information about the model + """ + + def __init__( + self, + feature_extractor: tf.keras.Model, + vocab: str, + d_model: int = 512, + dff: int = 2048, + num_heads: int = 8, # number of heads in the transformer decoder + num_layers: int = 3, + max_length: int = 50, + dropout: float = 0.2, + input_shape: Tuple[int, int, int] = (32, 128, 3), + cfg: Optional[Dict[str, Any]] = None, + ) -> None: + super().__init__() + + self.vocab = vocab + self.max_length = max_length + self.cfg = cfg + self.vocab_size = len(vocab) + + self.feat_extractor = feature_extractor + self.seq_embedding = layers.Embedding(self.vocab_size + 3, d_model) # 3 more classes: EOS/PAD/SOS + + self.decoder = Decoder( + num_layers=num_layers, + d_model=d_model, + num_heads=num_heads, + dff=dff, + vocab_size=self.vocab_size, + maximum_position_encoding=max_length, + dropout=dropout, + ) + self.feature_pe = positional_encoding(input_shape[0] * input_shape[1], d_model) + self.linear = layers.Dense(self.vocab_size + 3, kernel_initializer=tf.initializers.he_uniform()) + + self.postprocessor = MASTERPostProcessor(vocab=self.vocab) + + def make_mask(self, target: tf.Tensor) -> tf.Tensor: + look_ahead_mask = create_look_ahead_mask(tf.shape(target)[1]) + target_padding_mask = create_padding_mask(target, self.vocab_size + 2) # Pad symbol + combined_mask = tf.maximum(target_padding_mask, look_ahead_mask) + return combined_mask + + @staticmethod + def compute_loss( + model_output: tf.Tensor, + gt: tf.Tensor, + seq_len: List[int], + ) -> tf.Tensor: + """Compute categorical cross-entropy loss for the model. + Sequences are masked after the EOS character. + + Args: + gt: the encoded tensor with gt labels + model_output: predicted logits of the model + seq_len: lengths of each gt word inside the batch + + Returns: + The loss of the model on the batch + """ + # Input length : number of timesteps + input_len = tf.shape(model_output)[1] + # Add one for additional token (sos disappear in shift!) + seq_len = tf.cast(seq_len, tf.int32) + 1 + # One-hot gt labels + oh_gt = tf.one_hot(gt, depth=model_output.shape[2]) + # Compute loss: don't forget to shift gt! Otherwise the model learns to output the gt[t-1]! + # The "masked" first gt char is . Delete last logit of the model output. 
+ cce = tf.nn.softmax_cross_entropy_with_logits(oh_gt[:, 1:, :], model_output[:, :-1, :]) + # Compute mask + mask_values = tf.zeros_like(cce) + mask_2d = tf.sequence_mask(seq_len, input_len - 1) # delete the last mask timestep as well + masked_loss = tf.where(mask_2d, cce, mask_values) + ce_loss = tf.math.divide(tf.reduce_sum(masked_loss, axis=1), tf.cast(seq_len, model_output.dtype)) + + return tf.expand_dims(ce_loss, axis=1) + + def call( + self, + x: tf.Tensor, + target: Optional[List[str]] = None, + return_model_output: bool = False, + return_preds: bool = False, + **kwargs: Any, + ) -> Dict[str, Any]: + """Call function for training + + Args: + x: images + target: list of str labels + return_model_output: if True, return logits + return_preds: if True, decode logits + + Return: + A dictionnary containing eventually loss, logits and predictions. + """ + + # Encode + feature = self.feat_extractor(x, **kwargs) + b, h, w, c = (tf.shape(feature)[i] for i in range(4)) + feature = tf.reshape(feature, shape=(b, h * w, c)) + encoded = feature + tf.cast(self.feature_pe[:, :h * w, :], dtype=feature.dtype) + + out: Dict[str, tf.Tensor] = {} + + if target is not None: + # Compute target: tensor of gts and sequence lengths + gt, seq_len = self.build_target(target) + + if kwargs.get('training', False): + if target is None: + raise AssertionError("In training mode, you need to pass a value to 'target'") + tgt_mask = self.make_mask(gt) + # Compute logits + output = self.decoder(gt, encoded, tgt_mask, None, **kwargs) + logits = self.linear(output, **kwargs) + + else: + # When not training, we want to compute logits in with the decoder, although + # we have access to gts (we need gts to compute the loss, but not in the decoder) + logits = self.decode(encoded, **kwargs) + + if target is not None: + out['loss'] = self.compute_loss(logits, gt, seq_len) + + if return_model_output: + out['out_map'] = logits + + if return_preds: + predictions = self.postprocessor(logits) + out['preds'] = predictions + + return out + + def decode(self, encoded: tf.Tensor, **kwargs: Any) -> tf.Tensor: + """Decode function for prediction + + Args: + encoded: encoded features + + Return: + A Tuple of tf.Tensor: predictions, logits + """ + b = tf.shape(encoded)[0] + max_len = tf.constant(self.max_length, dtype=tf.int32) + start_symbol = tf.constant(self.vocab_size + 1, dtype=tf.int32) # SOS + padding_symbol = tf.constant(self.vocab_size + 2, dtype=tf.int32) # PAD + + ys = tf.fill(dims=(b, max_len - 1), value=padding_symbol) + start_vector = tf.fill(dims=(b, 1), value=start_symbol) + ys = tf.concat([start_vector, ys], axis=-1) + + logits = tf.zeros(shape=(b, max_len - 1, self.vocab_size + 3), dtype=encoded.dtype) # 3 symbols + # max_len = len + 2 (sos + eos) + for i in range(self.max_length - 1): + ys_mask = self.make_mask(ys) + output = self.decoder(ys, encoded, ys_mask, None, **kwargs) + logits = self.linear(output, **kwargs) + prob = tf.nn.softmax(logits, axis=-1) + next_word = tf.argmax(prob, axis=-1, output_type=ys.dtype) + # ys.shape = B, T + i_mesh, j_mesh = tf.meshgrid(tf.range(b), tf.range(max_len), indexing='ij') + indices = tf.stack([i_mesh[:, i + 1], j_mesh[:, i + 1]], axis=1) + + ys = tf.tensor_scatter_nd_update(ys, indices, next_word[:, i + 1]) + + # final_logits of shape (N, max_length - 1, vocab_size + 1) (whithout sos) + return logits + + +class MASTERPostProcessor(_MASTERPostProcessor): + """Post processor for MASTER architectures + + Args: + vocab: string containing the ordered sequence of supported characters + 
ignore_case: if True, ignore case of letters + ignore_accents: if True, ignore accents of letters + """ + + def __call__( + self, + logits: tf.Tensor, + ) -> List[Tuple[str, float]]: + # compute pred with argmax for attention models + out_idxs = tf.math.argmax(logits, axis=2) + # N x L + probs = tf.gather(tf.nn.softmax(logits, axis=-1), out_idxs, axis=-1, batch_dims=2) + # Take the minimum confidence of the sequence + probs = tf.math.reduce_min(probs, axis=1) + + # decode raw output of the model with tf_label_to_idx + out_idxs = tf.cast(out_idxs, dtype='int32') + embedding = tf.constant(self._embedding, dtype=tf.string) + decoded_strings_pred = tf.strings.reduce_join(inputs=tf.nn.embedding_lookup(embedding, out_idxs), axis=-1) + decoded_strings_pred = tf.strings.split(decoded_strings_pred, "") + decoded_strings_pred = tf.sparse.to_dense(decoded_strings_pred.to_sparse(), default_value='not valid')[:, 0] + word_values = [word.decode() for word in decoded_strings_pred.numpy().tolist()] + + return list(zip(word_values, probs.numpy().tolist())) + + +def _master( + arch: str, + pretrained: bool, + backbone_fn, + pretrained_backbone: bool = True, + **kwargs: Any +) -> MASTER: + + pretrained_backbone = pretrained_backbone and not pretrained + + # Patch the config + _cfg = deepcopy(default_cfgs[arch]) + _cfg['input_shape'] = kwargs.get('input_shape', _cfg['input_shape']) + _cfg['vocab'] = kwargs.get('vocab', _cfg['vocab']) + + kwargs['vocab'] = _cfg['vocab'] + kwargs['input_shape'] = _cfg['input_shape'] + + # Build the model + model = MASTER( + backbone_fn(pretrained=pretrained_backbone, input_shape=_cfg['input_shape'], include_top=False), + cfg=_cfg, + **kwargs, + ) + # Load pretrained parameters + if pretrained: + load_pretrained_params(model, default_cfgs[arch]['url']) + + return model + + +def master(pretrained: bool = False, **kwargs: Any) -> MASTER: + """MASTER as described in paper: `_. + + Example:: + >>> import tensorflow as tf + >>> from doctr.models import master + >>> model = master(pretrained=False) + >>> input_tensor = tf.random.uniform(shape=[1, 48, 160, 3], maxval=1, dtype=tf.float32) + >>> out = model(input_tensor) + + Args: + pretrained (bool): If True, returns a model pre-trained on our text recognition dataset + + Returns: + text recognition architecture + """ + + return _master('master', pretrained, magc_resnet31, **kwargs) diff --git a/doctr/models/recognition/predictor/__init__.py b/doctr/models/recognition/predictor/__init__.py new file mode 100644 index 0000000000..6a3fee30ac --- /dev/null +++ b/doctr/models/recognition/predictor/__init__.py @@ -0,0 +1,6 @@ +from doctr.file_utils import is_tf_available + +if is_tf_available(): + from .tensorflow import * +else: + from .pytorch import * # type: ignore[misc] diff --git a/doctr/models/recognition/predictor/_utils.py b/doctr/models/recognition/predictor/_utils.py new file mode 100644 index 0000000000..b201202ddf --- /dev/null +++ b/doctr/models/recognition/predictor/_utils.py @@ -0,0 +1,89 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. 
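The special tokens of the post-processor embedding appear stripped in this patch rendering; assuming the usual '<eos>'/'<sos>'/'<pad>' markers, decoding boils down to mapping argmax indices through the embedding, cutting the string at the first EOS, and keeping the minimum per-character probability as the word confidence. A plain-Python sketch with made-up values:

vocab = "abc"
embedding = list(vocab) + ['<eos>', '<sos>', '<pad>']
out_idxs = [2, 0, 1, 3, 5, 5]                  # argmax over the class axis
probs = [0.9, 0.8, 0.95, 0.99, 1.0, 1.0]
word = ''.join(embedding[i] for i in out_idxs).split('<eos>')[0]
print(word, min(probs))                        # -> cab 0.8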
+ +from typing import List, Tuple, Union + +import numpy as np + +from ..utils import merge_multi_strings + +__all__ = ['split_crops', 'remap_preds'] + + +def split_crops( + crops: List[np.ndarray], + max_ratio: float, + target_ratio: int, + dilation: float, + channels_last: bool = True, +) -> Tuple[List[np.ndarray], List[Union[int, Tuple[int, int]]], bool]: + """Chunk crops horizontally to match a given aspect ratio + + Args: + crops: list of numpy array of shape (H, W, 3) if channels_last or (3, H, W) otherwise + max_ratio: the maximum aspect ratio that won't trigger the chunk + target_ratio: when crops are chunked, they will be chunked to match this aspect ratio + dilation: the width dilation of final chunks (to provide some overlaps) + channels_last: whether the numpy array has dimensions in channels last order + + Returns: + a tuple with the new crops, their mapping, and a boolean specifying whether any remap is required + """ + + _remap_required = False + crop_map: List[Union[int, Tuple[int, int]]] = [] + new_crops: List[np.ndarray] = [] + for crop in crops: + h, w = crop.shape[:2] if channels_last else crop.shape[-2:] + aspect_ratio = w / h + if aspect_ratio > max_ratio: + # Determine the number of crops, reference aspect ratio = 4 = 128 / 32 + num_subcrops = int(aspect_ratio // target_ratio) + # Find the new widths, additional dilation factor to overlap crops + width = dilation * w / num_subcrops + centers = [(w / num_subcrops) * (1 / 2 + idx) for idx in range(num_subcrops)] + # Get the crops + if channels_last: + _crops = [ + crop[:, max(0, int(round(center - width / 2))): min(w - 1, int(round(center + width / 2))), :] + for center in centers + ] + else: + _crops = [ + crop[:, :, max(0, int(round(center - width / 2))): min(w - 1, int(round(center + width / 2)))] + for center in centers + ] + # Avoid sending zero-sized crops + _crops = [crop for crop in _crops if all(s > 0 for s in crop.shape)] + # Record the slice of crops + crop_map.append((len(new_crops), len(new_crops) + len(_crops))) + new_crops.extend(_crops) + # At least one crop will require merging + _remap_required = True + else: + crop_map.append(len(new_crops)) + new_crops.append(crop) + + return new_crops, crop_map, _remap_required + + +def remap_preds( + preds: List[Tuple[str, float]], + crop_map: List[Union[int, Tuple[int, int]]], + dilation: float +) -> List[Tuple[str, float]]: + remapped_out = [] + for _idx in crop_map: + # Crop hasn't been split + if isinstance(_idx, int): + remapped_out.append(preds[_idx]) + else: + # unzip + vals, probs = zip(*preds[_idx[0]: _idx[1]]) + # Merge the string values + remapped_out.append( + (merge_multi_strings(vals, dilation), min(probs)) # type: ignore[arg-type] + ) + return remapped_out diff --git a/doctr/models/recognition/predictor/pytorch.py b/doctr/models/recognition/predictor/pytorch.py new file mode 100644 index 0000000000..28233e8695 --- /dev/null +++ b/doctr/models/recognition/predictor/pytorch.py @@ -0,0 +1,85 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. 
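A quick numerical sketch of split_crops above (NumPy only): a 32x512 crop has an aspect ratio of 16, which exceeds a max_ratio of 8, so it is split into int(16 // 4) = 4 overlapping chunks (target_ratio=4 matches the 128/32 reference ratio mentioned in the comments).

import numpy as np

from doctr.models.recognition.predictor._utils import split_crops

crop = np.zeros((32, 512, 3), dtype=np.uint8)
new_crops, crop_map, remapped = split_crops([crop], max_ratio=8, target_ratio=4, dilation=1.4)
print(len(new_crops), crop_map, remapped)  # -> 4 [(0, 4)] True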
+ +from typing import Any, List, Tuple, Union + +import numpy as np +import torch +from torch import nn + +from doctr.models.preprocessor import PreProcessor + +from ._utils import remap_preds, split_crops + +__all__ = ['RecognitionPredictor'] + + +class RecognitionPredictor(nn.Module): + """Implements an object able to identify character sequences in images + + Args: + pre_processor: transform inputs for easier batched model inference + model: core detection architecture + split_wide_crops: wether to use crop splitting for high aspect ratio crops + """ + + def __init__( + self, + pre_processor: PreProcessor, + model: nn.Module, + split_wide_crops: bool = True, + ) -> None: + + super().__init__() + self.pre_processor = pre_processor + self.model = model.eval() + self.split_wide_crops = split_wide_crops + self.critical_ar = 8 # Critical aspect ratio + self.dil_factor = 1.4 # Dilation factor to overlap the crops + self.target_ar = 6 # Target aspect ratio + + @torch.no_grad() + def forward( + self, + crops: List[Union[np.ndarray, torch.Tensor]], + **kwargs: Any, + ) -> List[Tuple[str, float]]: + + if len(crops) == 0: + return [] + # Dimension check + if any(crop.ndim != 3 for crop in crops): + raise ValueError("incorrect input shape: all crops are expected to be multi-channel 2D images.") + + # Split crops that are too wide + remapped = False + if self.split_wide_crops: + new_crops, crop_map, remapped = split_crops( + crops, + self.critical_ar, + self.target_ar, + self.dil_factor, + isinstance(crops[0], np.ndarray) + ) + if remapped: + crops = new_crops + + # Resize & batch them + processed_batches = self.pre_processor(crops) + + # Forward it + raw = [ + self.model(batch, return_preds=True, **kwargs)['preds'] # type: ignore[operator] + for batch in processed_batches + ] + + # Process outputs + out = [charseq for batch in raw for charseq in batch] + + # Remap crops + if self.split_wide_crops and remapped: + out = remap_preds(out, crop_map, self.dil_factor) + + return out diff --git a/doctr/models/recognition/predictor/tensorflow.py b/doctr/models/recognition/predictor/tensorflow.py new file mode 100644 index 0000000000..75dd937888 --- /dev/null +++ b/doctr/models/recognition/predictor/tensorflow.py @@ -0,0 +1,81 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. 
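A rough end-to-end sketch of the PyTorch predictor above. The PreProcessor constructor arguments (output size, batch size) are assumed here rather than taken from this patch, and the weights are untrained, so the returned strings are meaningless; the point is only the input/output contract.

import numpy as np

from doctr.models import crnn_vgg16_bn
from doctr.models.preprocessor import PreProcessor
from doctr.models.recognition.predictor import RecognitionPredictor

reco_model = crnn_vgg16_bn(pretrained=False)
pre_processor = PreProcessor((32, 128), batch_size=32)  # assumed signature
predictor = RecognitionPredictor(pre_processor, reco_model)
crops = [np.zeros((32, 100, 3), dtype=np.uint8)]
print(predictor(crops))  # -> [(word, confidence)]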
+ +from typing import Any, List, Tuple, Union + +import numpy as np +import tensorflow as tf + +from doctr.models.preprocessor import PreProcessor +from doctr.utils.repr import NestedObject + +from ..core import RecognitionModel +from ._utils import remap_preds, split_crops + +__all__ = ['RecognitionPredictor'] + + +class RecognitionPredictor(NestedObject): + """Implements an object able to identify character sequences in images + + Args: + pre_processor: transform inputs for easier batched model inference + model: core detection architecture + split_wide_crops: wether to use crop splitting for high aspect ratio crops + """ + + _children_names: List[str] = ['pre_processor', 'model'] + + def __init__( + self, + pre_processor: PreProcessor, + model: RecognitionModel, + split_wide_crops: bool = True, + ) -> None: + + super().__init__() + self.pre_processor = pre_processor + self.model = model + self.split_wide_crops = split_wide_crops + self.critical_ar = 8 # Critical aspect ratio + self.dil_factor = 1.4 # Dilation factor to overlap the crops + self.target_ar = 6 # Target aspect ratio + + def __call__( + self, + crops: List[Union[np.ndarray, tf.Tensor]], + **kwargs: Any, + ) -> List[Tuple[str, float]]: + + if len(crops) == 0: + return [] + # Dimension check + if any(crop.ndim != 3 for crop in crops): + raise ValueError("incorrect input shape: all crops are expected to be multi-channel 2D images.") + + # Split crops that are too wide + remapped = False + if self.split_wide_crops: + new_crops, crop_map, remapped = split_crops(crops, self.critical_ar, self.target_ar, self.dil_factor) + if remapped: + crops = new_crops + + # Resize & batch them + processed_batches = self.pre_processor(crops) + + # Forward it + raw = [ + self.model(batch, return_preds=True, training=False, **kwargs)['preds'] # type: ignore[operator] + for batch in processed_batches + ] + + # Process outputs + out = [charseq for batch in raw for charseq in batch] + + # Remap crops + if self.split_wide_crops and remapped: + out = remap_preds(out, crop_map, self.dil_factor) + + return out diff --git a/doctr/models/recognition/sar/__init__.py b/doctr/models/recognition/sar/__init__.py new file mode 100644 index 0000000000..059f261e82 --- /dev/null +++ b/doctr/models/recognition/sar/__init__.py @@ -0,0 +1,6 @@ +from doctr.file_utils import is_tf_available, is_torch_available + +if is_tf_available(): + from .tensorflow import * +elif is_torch_available(): + from .pytorch import * # type: ignore[misc] diff --git a/doctr/models/recognition/sar/pytorch.py b/doctr/models/recognition/sar/pytorch.py new file mode 100644 index 0000000000..a316cb8140 --- /dev/null +++ b/doctr/models/recognition/sar/pytorch.py @@ -0,0 +1,323 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. 
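Once the chunks produced by split_crops have been recognised independently, remap_preds stitches them back together: overlapping strings are merged and the lowest chunk confidence is kept. A sketch with made-up predictions for a single crop that was split into three chunks:

from doctr.models.recognition.predictor._utils import remap_preds

preds = [('abcd', 0.95), ('cdef', 0.90), ('efgh', 0.85)]    # per-chunk outputs
print(remap_preds(preds, crop_map=[(0, 3)], dilation=1.4))  # -> [('abcdefgh', 0.85)]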
+ +from copy import deepcopy +from typing import Any, Callable, Dict, List, Optional, Tuple + +import torch +from torch import nn +from torch.nn import functional as F +from torchvision.models._utils import IntermediateLayerGetter + +from doctr.datasets import VOCABS + +from ...classification import resnet31 +from ...utils.pytorch import load_pretrained_params +from ..core import RecognitionModel, RecognitionPostProcessor + +__all__ = ['SAR', 'sar_resnet31'] + +default_cfgs: Dict[str, Dict[str, Any]] = { + 'sar_resnet31': { + 'mean': (.5, .5, .5), + 'std': (1., 1., 1.), + 'input_shape': (3, 32, 128), + 'vocab': VOCABS['legacy_french'], + 'url': None, + }, +} + + +class AttentionModule(nn.Module): + + def __init__(self, feat_chans: int, state_chans: int, attention_units: int) -> None: + super().__init__() + self.feat_conv = nn.Conv2d(feat_chans, attention_units, 3, padding=1) + # No need to add another bias since both tensors are summed together + self.state_conv = nn.Conv2d(state_chans, attention_units, 1, bias=False) + self.attention_projector = nn.Conv2d(attention_units, 1, 1, bias=False) + + def forward(self, features: torch.Tensor, hidden_state: torch.Tensor) -> torch.Tensor: + # shape (N, vgg_units, H, W) -> (N, attention_units, H, W) + feat_projection = self.feat_conv(features) + # shape (N, rnn_units, 1, 1) -> (N, attention_units, 1, 1) + state_projection = self.state_conv(hidden_state) + projection = torch.tanh(feat_projection + state_projection) + # shape (N, attention_units, H, W) -> (N, 1, H, W) + attention = self.attention_projector(projection) + # shape (N, 1, H, W) -> (N, H * W) + attention = torch.flatten(attention, 1) + # shape (N, H * W) -> (N, 1, H, W) + attention = torch.softmax(attention, 1).reshape(-1, 1, features.shape[-2], features.shape[-1]) + + glimpse = (features * attention).sum(dim=(2, 3)) + + return glimpse + + +class SARDecoder(nn.Module): + """Implements decoder module of the SAR model + + Args: + rnn_units: number of hidden units in recurrent cells + max_length: maximum length of a sequence + vocab_size: number of classes in the model alphabet + embedding_units: number of hidden embedding units + attention_units: number of hidden attention units + num_decoder_layers: number of LSTM layers to stack + + """ + def __init__( + self, + rnn_units: int, + max_length: int, + vocab_size: int, + embedding_units: int, + attention_units: int, + num_decoder_layers: int = 2, + feat_chans: int = 512, + ) -> None: + + super().__init__() + self.vocab_size = vocab_size + self.lstm_cells = nn.ModuleList([ + nn.LSTMCell(rnn_units, rnn_units) for _ in range(num_decoder_layers) + ]) + self.embed = nn.Linear(self.vocab_size + 1, embedding_units, bias=False) + self.attention_module = AttentionModule(feat_chans, rnn_units, attention_units) + self.output_dense = nn.Linear(2 * rnn_units, vocab_size + 1) + self.max_length = max_length + + def forward( + self, + features: torch.Tensor, + holistic: torch.Tensor, + gt: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + + # initialize states (each of shape (N, rnn_units)) + hx = [None, None] + # Initialize with the index of virtual START symbol (placed after so that the one-hot is only zeros) + symbol = torch.zeros((features.shape[0], self.vocab_size + 1), device=features.device, dtype=features.dtype) + logits_list = [] + for t in range(self.max_length + 1): # keep 1 step for + + # one-hot symbol with depth vocab_size + 1 + # embeded_symbol: shape (N, embedding_units) + embeded_symbol = self.embed(symbol) + + hx[0] = 
self.lstm_cells[0](embeded_symbol, hx[0]) + hx[1] = self.lstm_cells[1](hx[0][0], hx[1]) # type: ignore[index] + logits, _ = hx[1] # type: ignore[misc] + + glimpse = self.attention_module( + features, logits.unsqueeze(-1).unsqueeze(-1), # type: ignore[has-type] + ) + # logits: shape (N, rnn_units), glimpse: shape (N, 1) + logits = torch.cat([logits, glimpse], 1) # type: ignore[has-type] + # shape (N, rnn_units + 1) -> (N, vocab_size + 1) + logits = self.output_dense(logits) + # update symbol with predicted logits for t+1 step + if gt is not None: + _symbol = gt[:, t] # type: ignore[index] + else: + _symbol = logits.argmax(-1) + symbol = F.one_hot(_symbol, self.vocab_size + 1).to(dtype=features.dtype) + logits_list.append(logits) + outputs = torch.stack(logits_list, 1) # shape (N, max_length + 1, vocab_size + 1) + + return outputs + + +class SAR(nn.Module, RecognitionModel): + """Implements a SAR architecture as described in `"Show, Attend and Read:A Simple and Strong Baseline for + Irregular Text Recognition" `_. + + Args: + feature_extractor: the backbone serving as feature extractor + vocab: vocabulary used for encoding + rnn_units: number of hidden units in both encoder and decoder LSTM + embedding_units: number of embedding units + attention_units: number of hidden units in attention module + max_length: maximum word length handled by the model + num_decoders: number of LSTM to stack in decoder layer + dropout_prob: dropout probability of the encoder LSTM + cfg: default setup dict of the model + """ + + def __init__( + self, + feature_extractor, + vocab: str, + rnn_units: int = 512, + embedding_units: int = 512, + attention_units: int = 512, + max_length: int = 30, + num_decoders: int = 2, + dropout_prob: float = 0., + input_shape: Tuple[int, int, int] = (3, 32, 128), + cfg: Optional[Dict[str, Any]] = None, + ) -> None: + + super().__init__() + self.vocab = vocab + self.cfg = cfg + + self.max_length = max_length + 1 # Add 1 timestep for EOS after the longest word + + self.feat_extractor = feature_extractor + + # Size the LSTM + self.feat_extractor.eval() + with torch.no_grad(): + out_shape = self.feat_extractor(torch.zeros((1, *input_shape)))['features'].shape + # Switch back to original mode + self.feat_extractor.train() + + self.encoder = nn.LSTM(out_shape[-1], rnn_units, 2, batch_first=True, dropout=dropout_prob) + + self.decoder = SARDecoder( + rnn_units, max_length, len(vocab), embedding_units, attention_units, num_decoders, out_shape[1], + ) + + self.postprocessor = SARPostProcessor(vocab=vocab) + + def forward( + self, + x: torch.Tensor, + target: Optional[List[str]] = None, + return_model_output: bool = False, + return_preds: bool = False, + ) -> Dict[str, Any]: + + features = self.feat_extractor(x)['features'] + pooled_features = features.max(dim=-2).values # vertical max pooling + _, (encoded, _) = self.encoder(pooled_features) + encoded = encoded[-1] + if target is not None: + _gt, _seq_len = self.build_target(target) + gt, seq_len = torch.from_numpy(_gt).to(dtype=torch.long), torch.tensor(_seq_len) # type: ignore[assignment] + gt, seq_len = gt.to(x.device), seq_len.to(x.device) + decoded_features = self.decoder(features, encoded, gt=None if target is None else gt) + + out: Dict[str, Any] = {} + if return_model_output: + out["out_map"] = decoded_features + + if target is None or return_preds: + # Post-process boxes + out["preds"] = self.postprocessor(decoded_features) + + if target is not None: + out['loss'] = self.compute_loss(decoded_features, gt, seq_len) # type: 
ignore[arg-type] + + return out + + @staticmethod + def compute_loss( + model_output: torch.Tensor, + gt: torch.Tensor, + seq_len: torch.Tensor, + ) -> torch.Tensor: + """Compute categorical cross-entropy loss for the model. + Sequences are masked after the EOS character. + + Args: + gt: the encoded tensor with gt labels + model_output: predicted logits of the model + seq_len: lengths of each gt word inside the batch + + Returns: + The loss of the model on the batch + """ + # Input length : number of timesteps + input_len = model_output.shape[1] + # Add one for additional token + seq_len = seq_len + 1 + # Compute loss + cce = F.cross_entropy(model_output.permute(0, 2, 1), gt, reduction='none') + # Compute mask + mask_2d = torch.arange(input_len, device=model_output.device)[None, :] < seq_len[:, None] + cce[mask_2d] = 0 + + ce_loss = cce.sum(1) / seq_len.to(dtype=model_output.dtype) + return ce_loss.mean() + + +class SARPostProcessor(RecognitionPostProcessor): + """Post processor for SAR architectures""" + + def __call__( + self, + logits: torch.Tensor, + ) -> List[Tuple[str, float]]: + # compute pred with argmax for attention models + out_idxs = logits.argmax(-1) + # N x L + probs = torch.gather(torch.softmax(logits, -1), -1, out_idxs.unsqueeze(-1)).squeeze(-1) + # Take the minimum confidence of the sequence + probs = probs.min(dim=1).values.detach().cpu() + + # Manual decoding + word_values = [ + ''.join(self._embedding[idx] for idx in encoded_seq).split("")[0] + for encoded_seq in out_idxs.detach().cpu().numpy() + ] + + return list(zip(word_values, probs.numpy().tolist())) + + +def _sar( + arch: str, + pretrained: bool, + backbone_fn: Callable[[bool], nn.Module], + layer: str, + pretrained_backbone: bool = True, + **kwargs: Any +) -> SAR: + + pretrained_backbone = pretrained_backbone and not pretrained + + # Patch the config + _cfg = deepcopy(default_cfgs[arch]) + _cfg['vocab'] = kwargs.get('vocab', _cfg['vocab']) + _cfg['input_shape'] = kwargs.get('input_shape', _cfg['input_shape']) + + # Feature extractor + feat_extractor = IntermediateLayerGetter( + backbone_fn(pretrained_backbone), + {layer: 'features'}, + ) + kwargs['vocab'] = _cfg['vocab'] + kwargs['input_shape'] = _cfg['input_shape'] + + # Build the model + model = SAR(feat_extractor, cfg=_cfg, **kwargs) + # Load pretrained parameters + if pretrained: + load_pretrained_params(model, default_cfgs[arch]['url']) + + return model + + +def sar_resnet31(pretrained: bool = False, **kwargs: Any) -> SAR: + """SAR with a resnet-31 feature extractor as described in `"Show, Attend and Read:A Simple and Strong + Baseline for Irregular Text Recognition" `_. + + Example: + >>> import torch + >>> from doctr.models import sar_resnet31 + >>> model = sar_resnet31(pretrained=False) + >>> input_tensor = torch.rand((1, 3, 32, 128)) + >>> out = model(input_tensor) + + Args: + pretrained (bool): If True, returns a model pre-trained on our text recognition dataset + + Returns: + text recognition architecture + """ + + return _sar('sar_resnet31', pretrained, resnet31, '10', **kwargs) diff --git a/doctr/models/recognition/sar/tensorflow.py b/doctr/models/recognition/sar/tensorflow.py new file mode 100644 index 0000000000..f8d641752d --- /dev/null +++ b/doctr/models/recognition/sar/tensorflow.py @@ -0,0 +1,361 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. 
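A short sketch of a supervised forward pass with the PyTorch SAR above: passing targets yields the masked cross-entropy loss alongside the decoded predictions (untrained weights, so the strings themselves are meaningless).

import torch
from doctr.models import sar_resnet31

model = sar_resnet31(pretrained=False)
x = torch.rand((2, 3, 32, 128))
out = model(x, target=["hello", "world"], return_preds=True)
print(out['loss'].item(), out['preds'])  # scalar loss, [(word, confidence), ...]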
+ +from copy import deepcopy +from typing import Any, Dict, List, Optional, Tuple + +import tensorflow as tf +from tensorflow.keras import Model, Sequential, layers + +from doctr.datasets import VOCABS +from doctr.utils.repr import NestedObject + +from ...classification import resnet31 +from ...utils.tensorflow import load_pretrained_params +from ..core import RecognitionModel, RecognitionPostProcessor + +__all__ = ['SAR', 'sar_resnet31'] + +default_cfgs: Dict[str, Dict[str, Any]] = { + 'sar_resnet31': { + 'mean': (0.694, 0.695, 0.693), + 'std': (0.299, 0.296, 0.301), + 'input_shape': (32, 128, 3), + 'vocab': VOCABS['legacy_french'], + 'url': 'https://github.com/mindee/doctr/releases/download/v0.3.0/sar_resnet31-9ee49970.zip', + }, +} + + +class AttentionModule(layers.Layer, NestedObject): + """Implements attention module of the SAR model + + Args: + attention_units: number of hidden attention units + + """ + def __init__( + self, + attention_units: int + ) -> None: + + super().__init__() + self.hidden_state_projector = layers.Conv2D( + attention_units, 1, strides=1, use_bias=False, padding='same', kernel_initializer='he_normal', + ) + self.features_projector = layers.Conv2D( + attention_units, 3, strides=1, use_bias=True, padding='same', kernel_initializer='he_normal', + ) + self.attention_projector = layers.Conv2D( + 1, 1, strides=1, use_bias=False, padding="same", kernel_initializer='he_normal', + ) + self.flatten = layers.Flatten() + + def call( + self, + features: tf.Tensor, + hidden_state: tf.Tensor, + **kwargs: Any, + ) -> tf.Tensor: + + [H, W] = features.get_shape().as_list()[1:3] + # shape (N, 1, 1, rnn_units) -> (N, 1, 1, attention_units) + hidden_state_projection = self.hidden_state_projector(hidden_state, **kwargs) + # shape (N, H, W, vgg_units) -> (N, H, W, attention_units) + features_projection = self.features_projector(features, **kwargs) + projection = tf.math.tanh(hidden_state_projection + features_projection) + # shape (N, H, W, attention_units) -> (N, H, W, 1) + attention = self.attention_projector(projection, **kwargs) + # shape (N, H, W, 1) -> (N, H * W) + attention = self.flatten(attention) + attention = tf.nn.softmax(attention) + # shape (N, H * W) -> (N, H, W, 1) + attention_map = tf.reshape(attention, [-1, H, W, 1]) + glimpse = tf.math.multiply(features, attention_map) + # shape (N, H * W) -> (N, 1) + glimpse = tf.reduce_sum(glimpse, axis=[1, 2]) + return glimpse + + +class SARDecoder(layers.Layer, NestedObject): + """Implements decoder module of the SAR model + + Args: + rnn_units: number of hidden units in recurrent cells + max_length: maximum length of a sequence + vocab_size: number of classes in the model alphabet + embedding_units: number of hidden embedding units + attention_units: number of hidden attention units + num_decoder_layers: number of LSTM layers to stack + + """ + def __init__( + self, + rnn_units: int, + max_length: int, + vocab_size: int, + embedding_units: int, + attention_units: int, + num_decoder_layers: int = 2, + input_shape: Optional[List[Tuple[Optional[int]]]] = None, + ) -> None: + + super().__init__() + self.vocab_size = vocab_size + self.lstm_decoder = layers.StackedRNNCells( + [layers.LSTMCell(rnn_units, implementation=1) for _ in range(num_decoder_layers)] + ) + self.embed = layers.Dense(embedding_units, use_bias=False, input_shape=(None, self.vocab_size + 1)) + self.attention_module = AttentionModule(attention_units) + self.output_dense = layers.Dense(vocab_size + 1, use_bias=True, input_shape=(None, 2 * rnn_units)) + 
self.max_length = max_length + + # Initialize kernels + if input_shape is not None: + self.attention_module.call(layers.Input(input_shape[0][1:]), layers.Input((1, 1, rnn_units))) + + def call( + self, + features: tf.Tensor, + holistic: tf.Tensor, + gt: Optional[tf.Tensor] = None, + **kwargs: Any, + ) -> tf.Tensor: + + # initialize states (each of shape (N, rnn_units)) + states = self.lstm_decoder.get_initial_state( + inputs=None, batch_size=features.shape[0], dtype=features.dtype + ) + # run first step of lstm + # holistic: shape (N, rnn_units) + _, states = self.lstm_decoder(holistic, states, **kwargs) + # Initialize with the index of virtual START symbol (placed after so that the one-hot is only zeros) + symbol = tf.fill(features.shape[0], self.vocab_size + 1) + logits_list = [] + if kwargs.get('training') and gt is None: + raise ValueError('Need to provide labels during training for teacher forcing') + for t in range(self.max_length + 1): # keep 1 step for + # one-hot symbol with depth vocab_size + 1 + # embeded_symbol: shape (N, embedding_units) + embeded_symbol = self.embed(tf.one_hot(symbol, depth=self.vocab_size + 1), **kwargs) + logits, states = self.lstm_decoder(embeded_symbol, states, **kwargs) + glimpse = self.attention_module( + features, tf.expand_dims(tf.expand_dims(logits, axis=1), axis=1), **kwargs, + ) + # logits: shape (N, rnn_units), glimpse: shape (N, 1) + logits = tf.concat([logits, glimpse], axis=-1) + # shape (N, rnn_units + 1) -> (N, vocab_size + 1) + logits = self.output_dense(logits, **kwargs) + # update symbol with predicted logits for t+1 step + if kwargs.get('training'): + symbol = gt[:, t] # type: ignore[index] + else: + symbol = tf.argmax(logits, axis=-1) + logits_list.append(logits) + outputs = tf.stack(logits_list, axis=1) # shape (N, max_length + 1, vocab_size + 1) + + return outputs + + +class SAR(Model, RecognitionModel): + """Implements a SAR architecture as described in `"Show, Attend and Read:A Simple and Strong Baseline for + Irregular Text Recognition" `_. 
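For orientation, a small shape sketch of the encoder path used by the TF SAR below: the backbone feature map is max-pooled over its vertical axis, then two stacked LSTMs reduce the sequence to a single holistic vector per image that seeds the attention decoder. The feature-map shape here is assumed for illustration.

import tensorflow as tf

features = tf.random.uniform((2, 4, 32, 512))  # assumed backbone output (N, H', W', C)
pooled = tf.reduce_max(features, axis=1)       # (2, 32, 512): vertical max pooling
encoder = tf.keras.Sequential([
    tf.keras.layers.LSTM(512, return_sequences=True),
    tf.keras.layers.LSTM(512, return_sequences=False),
])
holistic = encoder(pooled)                     # (2, 512)
print(pooled.shape, holistic.shape)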
+ + Args: + feature_extractor: the backbone serving as feature extractor + vocab: vocabulary used for encoding + rnn_units: number of hidden units in both encoder and decoder LSTM + embedding_units: number of embedding units + attention_units: number of hidden units in attention module + max_length: maximum word length handled by the model + num_decoders: number of LSTM to stack in decoder layer + + """ + + _children_names: List[str] = ['feat_extractor', 'encoder', 'decoder', 'postprocessor'] + + def __init__( + self, + feature_extractor, + vocab: str, + rnn_units: int = 512, + embedding_units: int = 512, + attention_units: int = 512, + max_length: int = 30, + num_decoders: int = 2, + cfg: Optional[Dict[str, Any]] = None, + ) -> None: + + super().__init__() + self.vocab = vocab + self.cfg = cfg + + self.max_length = max_length + 1 # Add 1 timestep for EOS after the longest word + + self.feat_extractor = feature_extractor + + self.encoder = Sequential( + [ + layers.LSTM(units=rnn_units, return_sequences=True), + layers.LSTM(units=rnn_units, return_sequences=False) + ] + ) + # Initialize the kernels (watch out for reduce_max) + self.encoder.build(input_shape=(None,) + self.feat_extractor.output_shape[2:]) + + self.decoder = SARDecoder( + rnn_units, max_length, len(vocab), embedding_units, attention_units, num_decoders, + input_shape=[self.feat_extractor.output_shape, self.encoder.output_shape] + ) + + self.postprocessor = SARPostProcessor(vocab=vocab) + + @staticmethod + def compute_loss( + model_output: tf.Tensor, + gt: tf.Tensor, + seq_len: tf.Tensor, + ) -> tf.Tensor: + """Compute categorical cross-entropy loss for the model. + Sequences are masked after the EOS character. + + Args: + gt: the encoded tensor with gt labels + model_output: predicted logits of the model + seq_len: lengths of each gt word inside the batch + + Returns: + The loss of the model on the batch + """ + # Input length : number of timesteps + input_len = tf.shape(model_output)[1] + # Add one for additional token + seq_len = seq_len + 1 + # One-hot gt labels + oh_gt = tf.one_hot(gt, depth=model_output.shape[2]) + # Compute loss + cce = tf.nn.softmax_cross_entropy_with_logits(oh_gt, model_output) + # Compute mask + mask_values = tf.zeros_like(cce) + mask_2d = tf.sequence_mask(seq_len, input_len) + masked_loss = tf.where(mask_2d, cce, mask_values) + ce_loss = tf.math.divide(tf.reduce_sum(masked_loss, axis=1), tf.cast(seq_len, model_output.dtype)) + return tf.expand_dims(ce_loss, axis=1) + + def call( + self, + x: tf.Tensor, + target: Optional[List[str]] = None, + return_model_output: bool = False, + return_preds: bool = False, + **kwargs: Any, + ) -> Dict[str, Any]: + + features = self.feat_extractor(x, **kwargs) + pooled_features = tf.reduce_max(features, axis=1) # vertical max pooling + encoded = self.encoder(pooled_features, **kwargs) + if target is not None: + gt, seq_len = self.build_target(target) + seq_len = tf.cast(seq_len, tf.int32) + decoded_features = self.decoder(features, encoded, gt=None if target is None else gt, **kwargs) + + out: Dict[str, tf.Tensor] = {} + if return_model_output: + out["out_map"] = decoded_features + + if target is None or return_preds: + # Post-process boxes + out["preds"] = self.postprocessor(decoded_features) + + if target is not None: + out['loss'] = self.compute_loss(decoded_features, gt, seq_len) + + return out + + +class SARPostProcessor(RecognitionPostProcessor): + """Post processor for SAR architectures + + Args: + vocab: string containing the ordered sequence of supported 
characters + ignore_case: if True, ignore case of letters + ignore_accents: if True, ignore accents of letters + """ + + def __call__( + self, + logits: tf.Tensor, + ) -> List[Tuple[str, float]]: + # compute pred with argmax for attention models + out_idxs = tf.math.argmax(logits, axis=2) + # N x L + probs = tf.gather(tf.nn.softmax(logits, axis=-1), out_idxs, axis=-1, batch_dims=2) + # Take the minimum confidence of the sequence + probs = tf.math.reduce_min(probs, axis=1) + + # decode raw output of the model with tf_label_to_idx + out_idxs = tf.cast(out_idxs, dtype='int32') + embedding = tf.constant(self._embedding, dtype=tf.string) + decoded_strings_pred = tf.strings.reduce_join(inputs=tf.nn.embedding_lookup(embedding, out_idxs), axis=-1) + decoded_strings_pred = tf.strings.split(decoded_strings_pred, "") + decoded_strings_pred = tf.sparse.to_dense(decoded_strings_pred.to_sparse(), default_value='not valid')[:, 0] + word_values = [word.decode() for word in decoded_strings_pred.numpy().tolist()] + + return list(zip(word_values, probs.numpy().tolist())) + + +def _sar( + arch: str, + pretrained: bool, + backbone_fn, + pretrained_backbone: bool = True, + input_shape: Optional[Tuple[int, int, int]] = None, + **kwargs: Any +) -> SAR: + + pretrained_backbone = pretrained_backbone and not pretrained + + # Patch the config + _cfg = deepcopy(default_cfgs[arch]) + _cfg['input_shape'] = input_shape or _cfg['input_shape'] + _cfg['vocab'] = kwargs.get('vocab', _cfg['vocab']) + + # Feature extractor + feat_extractor = backbone_fn( + pretrained=pretrained_backbone, + input_shape=_cfg['input_shape'], + include_top=False, + ) + + kwargs['vocab'] = _cfg['vocab'] + + # Build the model + model = SAR(feat_extractor, cfg=_cfg, **kwargs) + # Load pretrained parameters + if pretrained: + load_pretrained_params(model, default_cfgs[arch]['url']) + + return model + + +def sar_resnet31(pretrained: bool = False, **kwargs: Any) -> SAR: + """SAR with a resnet-31 feature extractor as described in `"Show, Attend and Read:A Simple and Strong + Baseline for Irregular Text Recognition" `_. + + Example: + >>> import tensorflow as tf + >>> from doctr.models import sar_resnet31 + >>> model = sar_resnet31(pretrained=False) + >>> input_tensor = tf.random.uniform(shape=[1, 64, 256, 3], maxval=1, dtype=tf.float32) + >>> out = model(input_tensor) + + Args: + pretrained (bool): If True, returns a model pre-trained on our text recognition dataset + + Returns: + text recognition architecture + """ + + return _sar('sar_resnet31', pretrained, resnet31, **kwargs) diff --git a/doctr/models/recognition/transformer/__init__.py b/doctr/models/recognition/transformer/__init__.py new file mode 100644 index 0000000000..059f261e82 --- /dev/null +++ b/doctr/models/recognition/transformer/__init__.py @@ -0,0 +1,6 @@ +from doctr.file_utils import is_tf_available, is_torch_available + +if is_tf_available(): + from .tensorflow import * +elif is_torch_available(): + from .pytorch import * # type: ignore[misc] diff --git a/doctr/models/recognition/transformer/pytorch.py b/doctr/models/recognition/transformer/pytorch.py new file mode 100644 index 0000000000..d1ad696a19 --- /dev/null +++ b/doctr/models/recognition/transformer/pytorch.py @@ -0,0 +1,91 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. 
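The Decoder below consumes a look-ahead mask that blocks attention to future positions; this is the same masking pattern built by MASTER.make_mask earlier, shown here without the extra batch dimension and head tiling (True marks a forbidden position):

import torch

size = 4
look_ahead = ~(torch.triu(torch.ones(size, size)) == 1).transpose(0, 1)
print(look_ahead.int())
# tensor([[0, 1, 1, 1],
#         [0, 0, 1, 1],
#         [0, 0, 0, 1],
#         [0, 0, 0, 0]], dtype=torch.int32)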
+ +import math +from typing import Optional + +import torch +from torch import nn + +__all__ = ['Decoder', 'positional_encoding'] + + +def positional_encoding(position: int, d_model: int = 512, dtype=torch.float32) -> torch.Tensor: + """Implementation borrowed from this pytorch tutorial: + `_. + + Args: + position: Number of positions to encode + d_model: depth of the encoding + + Returns: + 2D positional encoding as described in Transformer paper. + """ + pe = torch.zeros(position, d_model) + pos = torch.arange(0, position, dtype=dtype).unsqueeze(1) + div_term = torch.exp(torch.arange(0, d_model, 2, dtype=dtype) * (-math.log(10000.0) / d_model)) + pe[:, 0::2] = torch.sin(pos * div_term) + pe[:, 1::2] = torch.cos(pos * div_term) + return pe.unsqueeze(0) + + +class Decoder(nn.Module): + + pos_encoding: torch.Tensor + + def __init__( + self, + num_layers: int = 3, + d_model: int = 512, + num_heads: int = 8, + dff: int = 2048, + vocab_size: int = 120, + maximum_position_encoding: int = 50, + dropout: float = 0.2, + ) -> None: + super().__init__() + + self.d_model = d_model + self.num_layers = num_layers + + self.embedding = nn.Embedding(vocab_size + 3, d_model) # 3 more classes EOS/SOS/PAD + self.register_buffer('pos_encoding', positional_encoding(maximum_position_encoding, d_model)) + + self.dec_layers = nn.ModuleList([ + nn.TransformerDecoderLayer( + d_model=d_model, + nhead=num_heads, + dim_feedforward=dff, + dropout=dropout, + activation='relu', + batch_first=True, + ) for _ in range(num_layers) + ]) + + self.dropout = nn.Dropout(dropout) + + def forward( + self, + x: torch.Tensor, + enc_output: torch.Tensor, + look_ahead_mask: torch.Tensor, + padding_mask: Optional[torch.Tensor] = None, + ) -> torch.Tensor: + + seq_len = x.shape[1] # Batch first = True + + x = self.embedding(x) # (batch_size, target_seq_len, d_model) + x *= math.sqrt(self.d_model) + x += self.pos_encoding[:, :seq_len, :] + x = self.dropout(x) + + # Batch first = True in decoder + for i in range(self.num_layers): + x = self.dec_layers[i]( + tgt=x, memory=enc_output, tgt_mask=look_ahead_mask, memory_mask=padding_mask + ) + + # shape (batch_size, target_seq_len, d_model) + return x diff --git a/doctr/models/recognition/transformer/tensorflow.py b/doctr/models/recognition/transformer/tensorflow.py new file mode 100644 index 0000000000..dfba6f2bde --- /dev/null +++ b/doctr/models/recognition/transformer/tensorflow.py @@ -0,0 +1,265 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + +# This module 'transformer.py' is 100% inspired from this Tensorflow tutorial: +# https://www.tensorflow.org/text/tutorials/transformer + + +from typing import Any, Tuple + +import numpy as np +import tensorflow as tf + +__all__ = ['Decoder', 'positional_encoding', 'create_look_ahead_mask', 'create_padding_mask'] + + +def get_angles(pos: np.array, i: np.array, d_model: int = 512) -> np.array: + """This function compute the 2D array of angles for sinusoidal positional encoding. 
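A quick numeric check of the PyTorch positional encoding above: the returned tensor has shape (1, position, d_model); even channels carry the sine terms and odd channels the cosines.

import math

import torch

from doctr.models.recognition.transformer.pytorch import positional_encoding

pe = positional_encoding(50, d_model=512)
print(pe.shape)                                      # torch.Size([1, 50, 512])
pos, i = 3, 10                                       # an even channel index
expected = math.sin(pos * math.exp(i * (-math.log(10000.0) / 512)))
print(abs(pe[0, pos, i].item() - expected) < 1e-6)   # True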
+ + Args: + pos: range of positions to encode + i: range of depth to encode positions + d_model: depth parameter of the model + + Returns: + 2D array of angles, len(pos) x len(i) + """ + angle_rates = 1 / np.power(10000, (2 * (i // 2)) / np.float32(d_model)) + return pos * angle_rates + + +def positional_encoding(position: int, d_model: int = 512, dtype=tf.float32) -> tf.Tensor: + """This function computes the 2D positional encoding of the position, on a depth d_model + + Args: + position: Number of positions to encode + d_model: depth of the encoding + + Returns: + 2D positional encoding as described in Transformer paper. + """ + angle_rads = get_angles( + np.arange(position)[:, np.newaxis], + np.arange(d_model)[np.newaxis, :], + d_model, + ) + # apply sin to even indices in the array; 2i + angle_rads[:, 0::2] = np.sin(angle_rads[:, 0::2]) + # apply cos to odd indices in the array; 2i+1 + angle_rads[:, 1::2] = np.cos(angle_rads[:, 1::2]) + pos_encoding = angle_rads[np.newaxis, ...] + return tf.cast(pos_encoding, dtype=dtype) + + +@tf.function +def create_padding_mask(seq: tf.Tensor, padding: int = 0, dtype=tf.float32) -> tf.Tensor: + seq = tf.cast(tf.math.equal(seq, padding), dtype) + # add extra dimensions to add the padding to the attention logits. + return seq[:, tf.newaxis, tf.newaxis, :] # (batch_size, 1, 1, seq_len) + + +@tf.function +def create_look_ahead_mask(size: int) -> tf.Tensor: + mask = 1 - tf.linalg.band_part(tf.ones((size, size)), -1, 0) + return mask # (seq_len, seq_len) + + +@tf.function +def scaled_dot_product_attention( + q: tf.Tensor, k: tf.Tensor, v: tf.Tensor, mask: tf.Tensor +) -> Tuple[tf.Tensor, tf.Tensor]: + + """Calculate the attention weights. + q, k, v must have matching leading dimensions. + k, v must have matching penultimate dimension, i.e.: seq_len_k = seq_len_v. + The mask has different shapes depending on its type(padding or look ahead) + but it must be broadcastable for addition. + Args: + q: query shape == (..., seq_len_q, depth) + k: key shape == (..., seq_len_k, depth) + v: value shape == (..., seq_len_v, depth_v) + mask: Float tensor with shape broadcastable to (..., seq_len_q, seq_len_k). Defaults to None. + Returns: + output, attention_weights + """ + + matmul_qk = tf.matmul(q, k, transpose_b=True) # (..., seq_len_q, seq_len_k) + # scale matmul_qk + dk = tf.cast(tf.shape(k)[-1], q.dtype) + scaled_attention_logits = matmul_qk / tf.math.sqrt(dk) + # add the mask to the scaled tensor. + if mask is not None: + scaled_attention_logits += (tf.cast(mask, dtype=q.dtype) * -1e9) + # softmax is normalized on the last axis (seq_len_k) so that the scores + # add up to 1. 
+ attention_weights = tf.nn.softmax(scaled_attention_logits, axis=-1) # (..., seq_len_q, seq_len_k) + output = tf.matmul(attention_weights, v) # (..., seq_len_q, depth_v) + return output + + +class MultiHeadAttention(tf.keras.layers.Layer): + + def __init__(self, d_model: int = 512, num_heads: int = 8) -> None: + super(MultiHeadAttention, self).__init__() + self.num_heads = num_heads + self.d_model = d_model + + assert d_model % self.num_heads == 0 + + self.depth = d_model // self.num_heads + + self.wq = tf.keras.layers.Dense(d_model, kernel_initializer=tf.initializers.he_uniform()) + self.wk = tf.keras.layers.Dense(d_model, kernel_initializer=tf.initializers.he_uniform()) + self.wv = tf.keras.layers.Dense(d_model, kernel_initializer=tf.initializers.he_uniform()) + + self.dense = tf.keras.layers.Dense(d_model, kernel_initializer=tf.initializers.he_uniform()) + + def split_heads(self, x: tf.Tensor, batch_size: int) -> tf.Tensor: + """Split the last dimension into (num_heads, depth). + Transpose the result such that the shape is (batch_size, num_heads, seq_len, depth) + """ + x = tf.reshape(x, (batch_size, -1, self.num_heads, self.depth)) + return tf.transpose(x, perm=[0, 2, 1, 3]) + + def call( + self, + v: tf.Tensor, + k: tf.Tensor, + q: tf.Tensor, + mask: tf.Tensor, + **kwargs: Any, + ) -> Tuple[tf.Tensor, tf.Tensor]: + + batch_size = tf.shape(q)[0] + + q = self.wq(q, **kwargs) # (batch_size, seq_len, d_model) + k = self.wk(k, **kwargs) # (batch_size, seq_len, d_model) + v = self.wv(v, **kwargs) # (batch_size, seq_len, d_model) + + q = self.split_heads(q, batch_size) # (batch_size, num_heads, seq_len_q, depth) + k = self.split_heads(k, batch_size) # (batch_size, num_heads, seq_len_k, depth) + v = self.split_heads(v, batch_size) # (batch_size, num_heads, seq_len_v, depth) + + # scaled_attention.shape == (batch_size, num_heads, seq_len_q, depth) + # attention_weights.shape == (batch_size, num_heads, seq_len_q, seq_len_k) + scaled_attention = scaled_dot_product_attention(q, k, v, mask) + + scaled_attention = tf.transpose(scaled_attention, perm=[0, 2, 1, 3]) # (batch, seq_len_q, num_heads, depth) + + concat_attention = tf.reshape(scaled_attention, + (batch_size, -1, self.d_model)) # (batch_size, seq_len_q, d_model) + + output = self.dense(concat_attention, **kwargs) # (batch_size, seq_len_q, d_model) + + return output + + +def point_wise_feed_forward_network(d_model: int = 512, dff: int = 2048) -> tf.keras.Sequential: + return tf.keras.Sequential([ + tf.keras.layers.Dense( + dff, activation='relu', kernel_initializer=tf.initializers.he_uniform() + ), # (batch, seq_len, dff) + tf.keras.layers.Dense(d_model, kernel_initializer=tf.initializers.he_uniform()) # (batch, seq_len, d_model) + ]) + + +class DecoderLayer(tf.keras.layers.Layer): + + def __init__( + self, + d_model: int = 512, + num_heads: int = 8, + dff: int = 2048, + dropout: float = 0.2, + ) -> None: + super(DecoderLayer, self).__init__() + + self.mha1 = MultiHeadAttention(d_model, num_heads) + self.mha2 = MultiHeadAttention(d_model, num_heads) + + self.ffn = point_wise_feed_forward_network(d_model, dff) + + self.layernorm1 = tf.keras.layers.LayerNormalization(epsilon=1e-6) + self.layernorm2 = tf.keras.layers.LayerNormalization(epsilon=1e-6) + self.layernorm3 = tf.keras.layers.LayerNormalization(epsilon=1e-6) + + self.dropout1 = tf.keras.layers.Dropout(dropout) + self.dropout2 = tf.keras.layers.Dropout(dropout) + self.dropout3 = tf.keras.layers.Dropout(dropout) + + def call( + self, + x: tf.Tensor, + enc_output: tf.Tensor, + 
look_ahead_mask: tf.Tensor, + padding_mask: tf.Tensor, + **kwargs: Any, + ) -> Tuple[tf.Tensor, tf.Tensor, tf.Tensor]: + # enc_output.shape == (batch_size, input_seq_len, d_model) + + attn1 = self.mha1(x, x, x, look_ahead_mask, **kwargs) # (batch_size, target_seq_len, d_model) + attn1 = self.dropout1(attn1, **kwargs) + out1 = self.layernorm1(attn1 + x, **kwargs) + + attn2 = self.mha2(enc_output, enc_output, out1, padding_mask, **kwargs) # (batch_size, target_seq_len, d_model) + attn2 = self.dropout2(attn2, **kwargs) + out2 = self.layernorm2(attn2 + out1, **kwargs) # (batch_size, target_seq_len, d_model) + + ffn_output = self.ffn(out2, **kwargs) # (batch_size, target_seq_len, d_model) + ffn_output = self.dropout3(ffn_output, **kwargs) + out3 = self.layernorm3(ffn_output + out2, **kwargs) # (batch_size, target_seq_len, d_model) + + return out3 + + +class Decoder(tf.keras.layers.Layer): + + def __init__( + self, + num_layers: int = 3, + d_model: int = 512, + num_heads: int = 8, + dff: int = 2048, + vocab_size: int = 120, + maximum_position_encoding: int = 50, + dropout: float = 0.2, + ) -> None: + super(Decoder, self).__init__() + + self.d_model = d_model + self.num_layers = num_layers + + self.embedding = tf.keras.layers.Embedding(vocab_size + 3, d_model) # 3 more classes EOS/SOS/PAD + self.pos_encoding = positional_encoding(maximum_position_encoding, d_model) + + self.dec_layers = [DecoderLayer(d_model, num_heads, dff, dropout) + for _ in range(num_layers)] + + self.dropout = tf.keras.layers.Dropout(dropout) + + def call( + self, + x: tf.Tensor, + enc_output: tf.Tensor, + look_ahead_mask: tf.Tensor, + padding_mask: tf.Tensor, + **kwargs: Any, + ) -> Tuple[tf.Tensor, tf.Tensor]: + + seq_len = tf.shape(x)[1] + + x = self.embedding(x, **kwargs) # (batch_size, target_seq_len, d_model) + x *= tf.math.sqrt(tf.cast(self.d_model, x.dtype)) + x += tf.cast(self.pos_encoding[:, :seq_len, :], dtype=x.dtype) + + x = self.dropout(x, **kwargs) + + for i in range(self.num_layers): + x = self.dec_layers[i]( + x, enc_output, look_ahead_mask, padding_mask, **kwargs + ) + + # x.shape == (batch_size, target_seq_len, d_model) + return x diff --git a/doctr/models/recognition/utils.py b/doctr/models/recognition/utils.py new file mode 100644 index 0000000000..d5bf5cc883 --- /dev/null +++ b/doctr/models/recognition/utils.py @@ -0,0 +1,84 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + +from typing import List + +from rapidfuzz.string_metric import levenshtein + +__all__ = ['merge_strings', 'merge_multi_strings'] + + +def merge_strings(a: str, b: str, dil_factor: float) -> str: + """Merges 2 character sequences in the best way to maximize the alignment of their overlapping characters. + + Args: + a: first char seq, suffix should be similar to b's prefix. + b: second char seq, prefix should be similar to a's suffix. + dil_factor: dilation factor of the boxes to overlap, should be > 1. This parameter is + only used when the mother sequence is splitted on a character repetition + + Returns: + A merged character sequence. 
+
+    Example::
+        >>> from doctr.models.recognition.utils import merge_strings
+        >>> merge_strings('abcd', 'cdefgh', 1.4)
+        'abcdefgh'
+        >>> merge_strings('abcdi', 'cdefgh', 1.4)
+        'abcdefgh'
+    """
+    seq_len = min(len(a), len(b))
+    if seq_len == 0:  # One sequence is empty, return the other
+        return b if len(a) == 0 else a
+
+    # Initialize merging index and corresponding score (mean Levenshtein)
+    min_score, index = 1., 0  # No overlap, just concatenate
+
+    scores = [levenshtein(a[-i:], b[:i], processor=None) / i for i in range(1, seq_len + 1)]
+
+    # Edge case (split in the middle of char repetitions): if it starts with 2 or more 0
+    if len(scores) > 1 and (scores[0], scores[1]) == (0, 0):
+        # Compute n_overlap (number of overlapping chars, geometrically determined)
+        n_overlap = round(len(b) * (dil_factor - 1) / dil_factor)
+        # Find the number of consecutive zeros in the scores list
+        # Impossible to have a zero after a non-zero score in that case
+        n_zeros = sum(val == 0 for val in scores)
+        # Index is bounded by the geometrical overlap to avoid collapsing repetitions
+        min_score, index = 0, min(n_zeros, n_overlap)
+
+    else:  # Common case: choose the min score index
+        for i, score in enumerate(scores):
+            if score < min_score:
+                min_score, index = score, i + 1  # Add one because first index is an overlap of 1 char
+
+    # Merge with correct overlap
+    if index == 0:
+        return a + b
+    return a[:-1] + b[index - 1:]
+
+
+def merge_multi_strings(seq_list: List[str], dil_factor: float) -> str:
+    """Recursively merges consecutive string sequences with overlapping characters.
+
+    Args:
+        seq_list: list of sequences to merge. Sequences need to be ordered from left to right.
+        dil_factor: dilation factor of the boxes to overlap, should be > 1. This parameter is
+            only used when the mother sequence is split on a character repetition
+
+    Returns:
+        A merged character sequence
+
+    Example::
+        >>> from doctr.models.recognition.utils import merge_multi_strings
+        >>> merge_multi_strings(['abc', 'bcdef', 'difghi', 'aijkl'], 1.4)
+        'abcdefghijkl'
+    """
+    def _recursive_merge(a: str, seq_list: List[str], dil_factor: float) -> str:
+        # Recursive version of compute_overlap
+        if len(seq_list) == 1:
+            return merge_strings(a, seq_list[0], dil_factor)
+        return _recursive_merge(merge_strings(a, seq_list[0], dil_factor), seq_list[1:], dil_factor)
+
+    return _recursive_merge("", seq_list, dil_factor)
diff --git a/doctr/models/recognition/zoo.py b/doctr/models/recognition/zoo.py
new file mode 100644
index 0000000000..ce697b6f58
--- /dev/null
+++ b/doctr/models/recognition/zoo.py
@@ -0,0 +1,56 @@
+# Copyright (C) 2021-2022, Mindee.
+
+# This program is licensed under the Apache License version 2.
+# See LICENSE or go to for full license details.
+
+from typing import Any
+
+from doctr.file_utils import is_tf_available
+from doctr.models.preprocessor import PreProcessor
+
+from ..
import recognition +from .predictor import RecognitionPredictor + +__all__ = ["recognition_predictor"] + + +ARCHS = ['crnn_vgg16_bn', 'crnn_mobilenet_v3_small', 'crnn_mobilenet_v3_large', 'sar_resnet31', 'master'] + + +def _predictor(arch: str, pretrained: bool, **kwargs: Any) -> RecognitionPredictor: + + if arch not in ARCHS: + raise ValueError(f"unknown architecture '{arch}'") + + _model = recognition.__dict__[arch](pretrained=pretrained) + kwargs['mean'] = kwargs.get('mean', _model.cfg['mean']) + kwargs['std'] = kwargs.get('std', _model.cfg['std']) + kwargs['batch_size'] = kwargs.get('batch_size', 32) + input_shape = _model.cfg['input_shape'][:2] if is_tf_available() else _model.cfg['input_shape'][-2:] + predictor = RecognitionPredictor( + PreProcessor(input_shape, preserve_aspect_ratio=True, **kwargs), + _model + ) + + return predictor + + +def recognition_predictor(arch: str = 'crnn_vgg16_bn', pretrained: bool = False, **kwargs: Any) -> RecognitionPredictor: + """Text recognition architecture. + + Example:: + >>> import numpy as np + >>> from doctr.models import recognition_predictor + >>> model = recognition_predictor(pretrained=True) + >>> input_page = (255 * np.random.rand(32, 128, 3)).astype(np.uint8) + >>> out = model([input_page]) + + Args: + arch: name of the architecture to use (e.g. 'crnn_vgg16_bn') + pretrained: If True, returns a model pre-trained on our text recognition dataset + + Returns: + Recognition predictor + """ + + return _predictor(arch, pretrained, **kwargs) diff --git a/doctr/models/utils/__init__.py b/doctr/models/utils/__init__.py new file mode 100644 index 0000000000..059f261e82 --- /dev/null +++ b/doctr/models/utils/__init__.py @@ -0,0 +1,6 @@ +from doctr.file_utils import is_tf_available, is_torch_available + +if is_tf_available(): + from .tensorflow import * +elif is_torch_available(): + from .pytorch import * # type: ignore[misc] diff --git a/doctr/models/utils/pytorch.py b/doctr/models/utils/pytorch.py new file mode 100644 index 0000000000..d1e084883b --- /dev/null +++ b/doctr/models/utils/pytorch.py @@ -0,0 +1,84 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. 
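+
+# Illustrative sketch (added for clarity; an assumption, not part of the original file): typical
+# use of the helpers defined below.
+#   >>> from torch import nn
+#   >>> backbone = nn.Sequential(*conv_sequence_pt(3, 32, relu=True, bn=True, kernel_size=3))
+#   >>> load_pretrained_params(backbone, url=None)  # no URL: keeps the default initialization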
+
+import logging
+from typing import Any, List, Optional
+
+import torch
+from torch import nn
+
+from doctr.utils.data import download_from_url
+
+__all__ = ['load_pretrained_params', 'conv_sequence_pt']
+
+
+def load_pretrained_params(
+    model: nn.Module,
+    url: Optional[str] = None,
+    hash_prefix: Optional[str] = None,
+    overwrite: bool = False,
+    **kwargs: Any,
+) -> None:
+    """Load a set of parameters onto a model
+
+    Example::
+        >>> from doctr.models import load_pretrained_params
+        >>> load_pretrained_params(model, "https://yoursource.com/yourcheckpoint-yourhash.zip")
+
+    Args:
+        model: the PyTorch module to be loaded
+        url: URL of the zipped set of parameters
+        hash_prefix: first characters of SHA256 expected hash
+        overwrite: should the zip extraction be enforced if the archive has already been extracted
+    """
+
+    if url is None:
+        logging.warning("Invalid model URL, using default initialization.")
+    else:
+        archive_path = download_from_url(url, hash_prefix=hash_prefix, cache_subdir='models', **kwargs)
+
+        # Read state_dict
+        state_dict = torch.load(archive_path, map_location='cpu')
+
+        # Load weights
+        model.load_state_dict(state_dict)
+
+
+def conv_sequence_pt(
+    in_channels: int,
+    out_channels: int,
+    relu: bool = False,
+    bn: bool = False,
+    **kwargs: Any,
+) -> List[nn.Module]:
+    """Builds a convolutional-based layer sequence
+
+    Example::
+        >>> from doctr.models.utils import conv_sequence_pt
+        >>> from torch.nn import Sequential
+        >>> module = Sequential(*conv_sequence_pt(3, 32, True, True, kernel_size=3))
+
+    Args:
+        in_channels: number of input channels
+        out_channels: number of output channels
+        relu: whether ReLU should be used
+        bn: should a batch normalization layer be added
+
+    Returns:
+        list of layers
+    """
+    # No bias before Batch norm
+    kwargs['bias'] = kwargs.get('bias', not bn)
+    # Conv first; optional BN and activation are appended as separate modules below
+    conv_seq: List[nn.Module] = [
+        nn.Conv2d(in_channels, out_channels, **kwargs)
+    ]
+
+    if bn:
+        conv_seq.append(nn.BatchNorm2d(out_channels))
+
+    if relu:
+        conv_seq.append(nn.ReLU(inplace=True))
+
+    return conv_seq
diff --git a/doctr/models/utils/tensorflow.py b/doctr/models/utils/tensorflow.py
new file mode 100644
index 0000000000..1205336f78
--- /dev/null
+++ b/doctr/models/utils/tensorflow.py
@@ -0,0 +1,123 @@
+# Copyright (C) 2021-2022, Mindee.
+
+# This program is licensed under the Apache License version 2.
+# See LICENSE or go to for full license details.
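+
+# Illustrative sketch (added for clarity; an assumption, not part of the original file): typical
+# use of the helpers defined below.
+#   >>> from tensorflow.keras import Sequential
+#   >>> backbone = Sequential(conv_sequence(32, 'relu', True, kernel_size=3, input_shape=(64, 64, 3)))
+#   >>> load_pretrained_params(backbone, url=None)  # no URL: keeps the default initialization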
+ +import logging +import os +from typing import Any, Callable, List, Optional, Union +from zipfile import ZipFile + +from tensorflow.keras import Model, layers + +from doctr.utils.data import download_from_url + +logging.getLogger("tensorflow").setLevel(logging.DEBUG) + + +__all__ = ['load_pretrained_params', 'conv_sequence', 'IntermediateLayerGetter'] + + +def load_pretrained_params( + model: Model, + url: Optional[str] = None, + hash_prefix: Optional[str] = None, + overwrite: bool = False, + internal_name: str = 'weights', + **kwargs: Any, +) -> None: + """Load a set of parameters onto a model + + Example:: + >>> from doctr.models import load_pretrained_params + >>> load_pretrained_params(model, "https://yoursource.com/yourcheckpoint-yourhash.zip") + + Args: + model: the keras model to be loaded + url: URL of the zipped set of parameters + hash_prefix: first characters of SHA256 expected hash + overwrite: should the zip extraction be enforced if the archive has already been extracted + internal_name: name of the ckpt files + """ + + if url is None: + logging.warning("Invalid model URL, using default initialization.") + else: + archive_path = download_from_url(url, hash_prefix=hash_prefix, cache_subdir='models', **kwargs) + + # Unzip the archive + params_path = archive_path.parent.joinpath(archive_path.stem) + if not params_path.is_dir() or overwrite: + with ZipFile(archive_path, 'r') as f: + f.extractall(path=params_path) + + # Load weights + model.load_weights(f"{params_path}{os.sep}{internal_name}") + + +def conv_sequence( + out_channels: int, + activation: Union[str, Callable] = None, + bn: bool = False, + padding: str = 'same', + kernel_initializer: str = 'he_normal', + **kwargs: Any, +) -> List[layers.Layer]: + """Builds a convolutional-based layer sequence + + Example:: + >>> from doctr.models import conv_sequence + >>> from tensorflow.keras import Sequential + >>> module = Sequential(conv_sequence(32, 'relu', True, kernel_size=3, input_shape=[224, 224, 3])) + + Args: + out_channels: number of output channels + activation: activation to be used (default: no activation) + bn: should a batch normalization layer be added + padding: padding scheme + kernel_initializer: kernel initializer + + Returns: + list of layers + """ + # No bias before Batch norm + kwargs['use_bias'] = kwargs.get('use_bias', not(bn)) + # Add activation directly to the conv if there is no BN + kwargs['activation'] = activation if not bn else None + conv_seq = [ + layers.Conv2D(out_channels, padding=padding, kernel_initializer=kernel_initializer, **kwargs) + ] + + if bn: + conv_seq.append(layers.BatchNormalization()) + + if (isinstance(activation, str) or callable(activation)) and bn: + # Activation function can either be a string or a function ('relu' or tf.nn.relu) + conv_seq.append(layers.Activation(activation)) + + return conv_seq + + +class IntermediateLayerGetter(Model): + """Implements an intermediate layer getter + + Example:: + >>> from doctr.models import IntermediateLayerGetter + >>> from tensorflow.keras.applications import ResNet50 + >>> target_layers = ["conv2_block3_out", "conv3_block4_out", "conv4_block6_out", "conv5_block3_out"] + >>> feat_extractor = IntermediateLayerGetter(ResNet50(include_top=False, pooling=False), target_layers) + + Args: + model: the model to extract feature maps from + layer_names: the list of layers to retrieve the feature map from + """ + def __init__( + self, + model: Model, + layer_names: List[str] + ) -> None: + intermediate_fmaps = 
[model.get_layer(layer_name).get_output_at(0) for layer_name in layer_names] + super().__init__(model.input, outputs=intermediate_fmaps) + + def __repr__(self) -> str: + return f"{self.__class__.__name__}()" diff --git a/doctr/models/zoo.py b/doctr/models/zoo.py new file mode 100644 index 0000000000..a7de8f8b6e --- /dev/null +++ b/doctr/models/zoo.py @@ -0,0 +1,87 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + +from typing import Any + +from .detection.zoo import detection_predictor +from .predictor import OCRPredictor +from .recognition.zoo import recognition_predictor + +__all__ = ["ocr_predictor"] + + +def _predictor( + det_arch: str, + reco_arch: str, + pretrained: bool, + assume_straight_pages: bool = True, + preserve_aspect_ratio: bool = False, + det_bs: int = 2, + reco_bs: int = 128, + **kwargs, +) -> OCRPredictor: + + # Detection + det_predictor = detection_predictor( + det_arch, + pretrained=pretrained, + batch_size=det_bs, + assume_straight_pages=assume_straight_pages, + preserve_aspect_ratio=preserve_aspect_ratio, + ) + + # Recognition + reco_predictor = recognition_predictor(reco_arch, pretrained=pretrained, batch_size=reco_bs) + + return OCRPredictor( + det_predictor, + reco_predictor, + assume_straight_pages=assume_straight_pages, + **kwargs + ) + + +def ocr_predictor( + det_arch: str = 'db_resnet50', + reco_arch: str = 'crnn_vgg16_bn', + pretrained: bool = False, + assume_straight_pages: bool = True, + export_as_straight_boxes: bool = False, + preserve_aspect_ratio: bool = False, + **kwargs: Any +) -> OCRPredictor: + """End-to-end OCR architecture using one model for localization, and another for text recognition. + + Example:: + >>> import numpy as np + >>> from doctr.models import ocr_predictor + >>> model = ocr_predictor('db_resnet50', 'crnn_vgg16_bn', pretrained=True) + >>> input_page = (255 * np.random.rand(600, 800, 3)).astype(np.uint8) + >>> out = model([input_page]) + + Args: + det_arch: name of the detection architecture to use (e.g. 'db_resnet50', 'db_mobilenet_v3_large') + reco_arch: name of the recognition architecture to use (e.g. 'crnn_vgg16_bn', 'sar_resnet31') + pretrained: If True, returns a model pre-trained on our OCR dataset + assume_straight_pages: if True, speeds up the inference by assuming you only pass straight pages + without rotated textual elements. + export_as_straight_boxes: when assume_straight_pages is set to False, export final predictions + (potentially rotated) as straight bounding boxes. + preserve_aspect_ratio: If True, pad the input document image to preserve the aspect ratio before + running the detection model on it. 
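+        kwargs: keyword args passed to the `OCRPredictor` constructor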
+ + Returns: + OCR predictor + """ + + return _predictor( + det_arch, + reco_arch, + pretrained, + assume_straight_pages=assume_straight_pages, + export_as_straight_boxes=export_as_straight_boxes, + preserve_aspect_ratio=preserve_aspect_ratio, + **kwargs, + ) diff --git a/doctr/transforms/__init__.py b/doctr/transforms/__init__.py new file mode 100644 index 0000000000..270dcebaa5 --- /dev/null +++ b/doctr/transforms/__init__.py @@ -0,0 +1 @@ +from .modules import * diff --git a/doctr/transforms/functional/__init__.py b/doctr/transforms/functional/__init__.py new file mode 100644 index 0000000000..64556e403a --- /dev/null +++ b/doctr/transforms/functional/__init__.py @@ -0,0 +1,6 @@ +from doctr.file_utils import is_tf_available, is_torch_available + +if is_tf_available(): + from .tensorflow import * +elif is_torch_available(): + from .pytorch import * diff --git a/doctr/transforms/functional/base.py b/doctr/transforms/functional/base.py new file mode 100644 index 0000000000..c5ce39828b --- /dev/null +++ b/doctr/transforms/functional/base.py @@ -0,0 +1,44 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + +from typing import Tuple, Union + +import numpy as np + +__all__ = ["crop_boxes"] + + +def crop_boxes( + boxes: np.ndarray, + crop_box: Union[Tuple[int, int, int, int], Tuple[float, float, float, float]], +) -> np.ndarray: + """Crop localization boxes + + Args: + boxes: ndarray of shape (N, 4) in relative or abs coordinates + crop_box: box (xmin, ymin, xmax, ymax) to crop the image, in the same coord format that the boxes + + Returns: + the cropped boxes + """ + is_box_rel = boxes.max() <= 1 + is_crop_rel = max(crop_box) <= 1 + + if is_box_rel ^ is_crop_rel: + raise AssertionError("both the boxes and the crop need to have the same coordinate convention") + + xmin, ymin, xmax, ymax = crop_box + # Clip boxes & correct offset + boxes[:, [0, 2]] = boxes[:, [0, 2]].clip(xmin, xmax) - xmin + boxes[:, [1, 3]] = boxes[:, [1, 3]].clip(ymin, ymax) - ymin + # Rescale relative coords + if is_box_rel: + boxes[:, [0, 2]] /= (xmax - xmin) + boxes[:, [1, 3]] /= (ymax - ymin) + + # Remove 0-sized boxes + is_valid = np.logical_and(boxes[:, 1] < boxes[:, 3], boxes[:, 0] < boxes[:, 2]) + + return boxes[is_valid] diff --git a/doctr/transforms/functional/pytorch.py b/doctr/transforms/functional/pytorch.py new file mode 100644 index 0000000000..6e82f0ce49 --- /dev/null +++ b/doctr/transforms/functional/pytorch.py @@ -0,0 +1,103 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. 
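+
+# Illustrative sketch (added for clarity; an assumption, not part of the original file): the
+# functional transforms below operate on CHW tensors, with boxes in relative coordinates.
+#   >>> import numpy as np, torch
+#   >>> img = torch.rand(3, 64, 128)
+#   >>> boxes = np.array([[.1, .1, .4, .5]], dtype=np.float32)  # (xmin, ymin, xmax, ymax)
+#   >>> inv_img = invert_colors(img, min_val=0.6)
+#   >>> r_img, r_polys = rotate_sample(img, boxes, angle=5., expand=False)  # polys of shape (N, 4, 2)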
+ +from copy import deepcopy +from typing import Tuple + +import numpy as np +import torch +from torchvision.transforms import functional as F + +from doctr.utils.geometry import rotate_abs_geoms + +from .base import crop_boxes + +__all__ = ["invert_colors", "rotate_sample", "crop_detection"] + + +def invert_colors(img: torch.Tensor, min_val: float = 0.6) -> torch.Tensor: + out = F.rgb_to_grayscale(img, num_output_channels=3) + # Random RGB shift + shift_shape = [img.shape[0], 3, 1, 1] if img.ndim == 4 else [3, 1, 1] + rgb_shift = min_val + (1 - min_val) * torch.rand(shift_shape) + # Inverse the color + if out.dtype == torch.uint8: + out = (out.to(dtype=rgb_shift.dtype) * rgb_shift).to(dtype=torch.uint8) + else: + out = out * rgb_shift.to(dtype=out.dtype) + # Inverse the color + out = 255 - out if out.dtype == torch.uint8 else 1 - out + return out + + +def rotate_sample( + img: torch.Tensor, + geoms: np.ndarray, + angle: float, + expand: bool = False, +) -> Tuple[torch.Tensor, np.ndarray]: + """Rotate image around the center, interpolation=NEAREST, pad with 0 (black) + + Args: + img: image to rotate + geoms: array of geometries of shape (N, 4) or (N, 4, 2) + angle: angle in degrees. +: counter-clockwise, -: clockwise + expand: whether the image should be padded before the rotation + + Returns: + A tuple of rotated img (tensor), rotated geometries of shape (N, 4, 2) + """ + rotated_img = F.rotate(img, angle=angle, fill=0, expand=expand) # Interpolation NEAREST by default + rotated_img = rotated_img[:3] # when expand=True, it expands to RGBA channels + # Get absolute coords + _geoms = deepcopy(geoms) + if _geoms.shape[1:] == (4,): + if np.max(_geoms) <= 1: + _geoms[:, [0, 2]] *= img.shape[-1] + _geoms[:, [1, 3]] *= img.shape[-2] + elif _geoms.shape[1:] == (4, 2): + if np.max(_geoms) <= 1: + _geoms[..., 0] *= img.shape[-1] + _geoms[..., 1] *= img.shape[-2] + else: + raise AssertionError("invalid format for arg `geoms`") + + # Rotate the boxes: xmin, ymin, xmax, ymax or polygons --> (4, 2) polygon + rotated_geoms = rotate_abs_geoms(_geoms, angle, img.shape[1:], expand).astype(np.float32) # type: ignore[arg-type] + + # Always return relative boxes to avoid label confusions when resizing is performed aferwards + rotated_geoms[..., 0] = rotated_geoms[..., 0] / rotated_img.shape[2] + rotated_geoms[..., 1] = rotated_geoms[..., 1] / rotated_img.shape[1] + + return rotated_img, rotated_geoms + + +def crop_detection( + img: torch.Tensor, + boxes: np.ndarray, + crop_box: Tuple[float, float, float, float] +) -> Tuple[torch.Tensor, np.ndarray]: + """Crop and image and associated bboxes + + Args: + img: image to crop + boxes: array of boxes to clip, absolute (int) or relative (float) + crop_box: box (xmin, ymin, xmax, ymax) to crop the image. Relative coords. + + Returns: + A tuple of cropped image, cropped boxes, where the image is not resized. 
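+
+    Example (illustrative, added for clarity)::
+        >>> import numpy as np
+        >>> import torch
+        >>> img = torch.rand(3, 64, 128)
+        >>> boxes = np.array([[.1, .1, .4, .5]], dtype=np.float32)  # relative coords
+        >>> c_img, c_boxes = crop_detection(img, boxes, (0., 0., .5, .5))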
+ """ + if any(val < 0 or val > 1 for val in crop_box): + raise AssertionError("coordinates of arg `crop_box` should be relative") + h, w = img.shape[-2:] + xmin, ymin = int(round(crop_box[0] * (w - 1))), int(round(crop_box[1] * (h - 1))) + xmax, ymax = int(round(crop_box[2] * (w - 1))), int(round(crop_box[3] * (h - 1))) + cropped_img = F.crop( + img, ymin, xmin, ymax - ymin, xmax - xmin + ) + # Crop the box + boxes = crop_boxes(boxes, crop_box if boxes.max() <= 1 else (xmin, ymin, xmax, ymax)) + + return cropped_img, boxes diff --git a/doctr/transforms/functional/tensorflow.py b/doctr/transforms/functional/tensorflow.py new file mode 100644 index 0000000000..e3d30cb931 --- /dev/null +++ b/doctr/transforms/functional/tensorflow.py @@ -0,0 +1,137 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + +import math +from copy import deepcopy +from typing import Tuple + +import numpy as np +import tensorflow as tf +import tensorflow_addons as tfa + +from doctr.utils.geometry import compute_expanded_shape, rotate_abs_geoms + +from .base import crop_boxes + +__all__ = ["invert_colors", "rotate_sample", "crop_detection"] + + +def invert_colors(img: tf.Tensor, min_val: float = 0.6) -> tf.Tensor: + out = tf.image.rgb_to_grayscale(img) # Convert to gray + # Random RGB shift + shift_shape = [img.shape[0], 1, 1, 3] if img.ndim == 4 else [1, 1, 3] + rgb_shift = tf.random.uniform(shape=shift_shape, minval=min_val, maxval=1) + # Inverse the color + if out.dtype == tf.uint8: + out = tf.cast(tf.cast(out, dtype=rgb_shift.dtype) * rgb_shift, dtype=tf.uint8) + else: + out *= tf.cast(rgb_shift, dtype=out.dtype) + # Inverse the color + out = 255 - out if out.dtype == tf.uint8 else 1 - out + return out + + +def rotated_img_tensor(img: tf.Tensor, angle: float, expand: bool = False) -> tf.Tensor: + """Rotate image around the center, interpolation=NEAREST, pad with 0 (black) + + Args: + img: image to rotate + angle: angle in degrees. +: counter-clockwise, -: clockwise + expand: whether the image should be padded before the rotation + + Returns: + the rotated image (tensor) + """ + # Compute the expanded padding + h_crop, w_crop = 0, 0 + if expand: + exp_h, exp_w = compute_expanded_shape(img.shape[:-1], angle) + h_diff, w_diff = int(math.ceil(exp_h - img.shape[0])), int(math.ceil(exp_w - img.shape[1])) + h_pad, w_pad = max(h_diff, 0), max(w_diff, 0) + exp_img = tf.pad(img, tf.constant([[h_pad // 2, h_pad - h_pad // 2], [w_pad // 2, w_pad - w_pad // 2], [0, 0]])) + h_crop, w_crop = int(round(max(exp_img.shape[0] - exp_h, 0))), int(round(min(exp_img.shape[1] - exp_w, 0))) + else: + exp_img = img + # Rotate the padded image + rotated_img = tfa.image.rotate(exp_img, angle * math.pi / 180) # Interpolation NEAREST by default + # Crop the rest + if h_crop > 0 or w_crop > 0: + h_slice = slice(h_crop // 2, -h_crop // 2) if h_crop > 0 else slice(rotated_img.shape[0]) + w_slice = slice(-w_crop // 2, -w_crop // 2) if w_crop > 0 else slice(rotated_img.shape[1]) + rotated_img = rotated_img[h_slice, w_slice] + + return rotated_img + + +def rotate_sample( + img: tf.Tensor, + geoms: np.ndarray, + angle: float, + expand: bool = False, +) -> Tuple[tf.Tensor, np.ndarray]: + """Rotate image around the center, interpolation=NEAREST, pad with 0 (black) + + Args: + img: image to rotate + geoms: array of geometries of shape (N, 4) or (N, 4, 2) + angle: angle in degrees. 
+: counter-clockwise, -: clockwise + expand: whether the image should be padded before the rotation + + Returns: + A tuple of rotated img (tensor), rotated boxes (np array) + """ + # Rotated the image + rotated_img = rotated_img_tensor(img, angle, expand) + + # Get absolute coords + _geoms = deepcopy(geoms) + if _geoms.shape[1:] == (4,): + if np.max(_geoms) <= 1: + _geoms[:, [0, 2]] *= img.shape[1] + _geoms[:, [1, 3]] *= img.shape[0] + elif _geoms.shape[1:] == (4, 2): + if np.max(_geoms) <= 1: + _geoms[..., 0] *= img.shape[1] + _geoms[..., 1] *= img.shape[0] + else: + raise AssertionError + + # Rotate the boxes: xmin, ymin, xmax, ymax or polygons --> (4, 2) polygon + rotated_geoms = rotate_abs_geoms(_geoms, angle, img.shape[:-1], expand).astype(np.float32) + + # Always return relative boxes to avoid label confusions when resizing is performed aferwards + rotated_geoms[..., 0] = rotated_geoms[..., 0] / rotated_img.shape[1] + rotated_geoms[..., 1] = rotated_geoms[..., 1] / rotated_img.shape[0] + + return rotated_img, rotated_geoms + + +def crop_detection( + img: tf.Tensor, + boxes: np.ndarray, + crop_box: Tuple[float, float, float, float] +) -> Tuple[tf.Tensor, np.ndarray]: + """Crop and image and associated bboxes + + Args: + img: image to crop + boxes: array of boxes to clip, absolute (int) or relative (float) + crop_box: box (xmin, ymin, xmax, ymax) to crop the image. Relative coords. + + Returns: + A tuple of cropped image, cropped boxes, where the image is not resized. + """ + if any(val < 0 or val > 1 for val in crop_box): + raise AssertionError("coordinates of arg `crop_box` should be relative") + h, w = img.shape[:2] + xmin, ymin = int(round(crop_box[0] * (w - 1))), int(round(crop_box[1] * (h - 1))) + xmax, ymax = int(round(crop_box[2] * (w - 1))), int(round(crop_box[3] * (h - 1))) + cropped_img = tf.image.crop_to_bounding_box( + img, ymin, xmin, ymax - ymin, xmax - xmin + ) + # Crop the box + boxes = crop_boxes(boxes, crop_box if boxes.max() <= 1 else (xmin, ymin, xmax, ymax)) + + return cropped_img, boxes diff --git a/doctr/transforms/modules/__init__.py b/doctr/transforms/modules/__init__.py new file mode 100644 index 0000000000..1950176a6d --- /dev/null +++ b/doctr/transforms/modules/__init__.py @@ -0,0 +1,8 @@ +from doctr.file_utils import is_tf_available, is_torch_available + +from .base import * + +if is_tf_available(): + from .tensorflow import * +elif is_torch_available(): + from .pytorch import * # type: ignore[misc] diff --git a/doctr/transforms/modules/base.py b/doctr/transforms/modules/base.py new file mode 100644 index 0000000000..81fcb568ac --- /dev/null +++ b/doctr/transforms/modules/base.py @@ -0,0 +1,191 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + +import math +import random +from typing import Any, Callable, Dict, List, Tuple + +import numpy as np + +from doctr.utils.repr import NestedObject + +from .. 
import functional as F + +__all__ = ['SampleCompose', 'ImageTransform', 'ColorInversion', 'OneOf', 'RandomApply', 'RandomRotate', 'RandomCrop'] + + +class SampleCompose(NestedObject): + """Implements a wrapper that will apply transformations sequentially on both image and target + Example:: + >>> from doctr.transforms import SampleCompose, ImageTransform, ColorInversion, RandomRotate + >>> import tensorflow as tf + >>> import numpy as np + >>> transfos = SampleCompose([ImageTransform(ColorInversion((32, 32))), RandomRotate(30)]) + >>> out, out_boxes = transfos(tf.random.uniform(shape=[64, 64, 3], minval=0, maxval=1), np.zeros((2, 4))) + Args: + transforms: list of transformation modules + """ + + _children_names: List[str] = ['sample_transforms'] + + def __init__(self, transforms: List[Callable[[Any, Any], Tuple[Any, Any]]]) -> None: + self.sample_transforms = transforms + + def __call__(self, x: Any, target: Any) -> Tuple[Any, Any]: + for t in self.sample_transforms: + x, target = t(x, target) + + return x, target + + +class ImageTransform(NestedObject): + """Implements a transform wrapper to turn an image-only transformation into an image+target transform + Example:: + >>> from doctr.transforms import ImageTransform, ColorInversion + >>> import tensorflow as tf + >>> transfo = ImageTransform(ColorInversion((32, 32))) + >>> out, _ = transfo(tf.random.uniform(shape=[64, 64, 3], minval=0, maxval=1), None) + Args: + transform: the image transformation module to wrap + """ + + _children_names: List[str] = ['img_transform'] + + def __init__(self, transform: Callable[[Any], Any]) -> None: + self.img_transform = transform + + def __call__(self, img: Any, target: Any) -> Tuple[Any, Any]: + img = self.img_transform(img) + return img, target + + +class ColorInversion(NestedObject): + """Applies the following tranformation to a tensor (image or batch of images): + convert to grayscale, colorize (shift 0-values randomly), and then invert colors + + Example:: + >>> from doctr.transforms import Normalize + >>> import tensorflow as tf + >>> transfo = ColorInversion(min_val=0.6) + >>> out = transfo(tf.random.uniform(shape=[8, 64, 64, 3], minval=0, maxval=1)) + + Args: + min_val: range [min_val, 1] to colorize RGB pixels + """ + def __init__(self, min_val: float = 0.5) -> None: + self.min_val = min_val + + def extra_repr(self) -> str: + return f"min_val={self.min_val}" + + def __call__(self, img: Any) -> Any: + return F.invert_colors(img, self.min_val) + + +class OneOf(NestedObject): + """Randomly apply one of the input transformations + + Example:: + >>> from doctr.transforms import Normalize + >>> import tensorflow as tf + >>> transfo = OneOf([JpegQuality(), Gamma()]) + >>> out = transfo(tf.random.uniform(shape=[64, 64, 3], minval=0, maxval=1)) + + Args: + transforms: list of transformations, one only will be picked + """ + + _children_names: List[str] = ['transforms'] + + def __init__(self, transforms: List[Callable[[Any], Any]]) -> None: + self.transforms = transforms + + def __call__(self, img: Any) -> Any: + # Pick transformation + transfo = self.transforms[int(random.random() * len(self.transforms))] + # Apply + return transfo(img) + + +class RandomApply(NestedObject): + """Apply with a probability p the input transformation + + Example:: + >>> from doctr.transforms import Normalize + >>> import tensorflow as tf + >>> transfo = RandomApply(Gamma(), p=.5) + >>> out = transfo(tf.random.uniform(shape=[64, 64, 3], minval=0, maxval=1)) + + Args: + transform: transformation to apply + p: probability to 
apply + """ + def __init__(self, transform: Callable[[Any], Any], p: float = .5) -> None: + self.transform = transform + self.p = p + + def extra_repr(self) -> str: + return f"transform={self.transform}, p={self.p}" + + def __call__(self, img: Any) -> Any: + if random.random() < self.p: + return self.transform(img) + return img + + +class RandomRotate(NestedObject): + """Randomly rotate a tensor image and its boxes + + .. image:: https://github.com/mindee/doctr/releases/download/v0.4.0/rotation_illustration.png + :align: center + + Args: + max_angle: maximum angle for rotation, in degrees. Angles will be uniformly picked in + [-max_angle, max_angle] + expand: whether the image should be padded before the rotation + """ + def __init__(self, max_angle: float = 5., expand: bool = False) -> None: + self.max_angle = max_angle + self.expand = expand + + def extra_repr(self) -> str: + return f"max_angle={self.max_angle}, expand={self.expand}" + + def __call__(self, img: Any, target: np.ndarray) -> Tuple[Any, np.ndarray]: + angle = random.uniform(-self.max_angle, self.max_angle) + r_img, r_polys = F.rotate_sample(img, target, angle, self.expand) + # Removes deleted boxes + is_kept = (r_polys.max(1) > r_polys.min(1)).sum(1) == 2 + return r_img, r_polys[is_kept] + + +class RandomCrop(NestedObject): + """Randomly crop a tensor image and its boxes + + Args: + scale: tuple of floats, relative (min_area, max_area) of the crop + ratio: tuple of float, relative (min_ratio, max_ratio) where ratio = h/w + """ + def __init__(self, scale: Tuple[float, float] = (0.08, 1.), ratio: Tuple[float, float] = (0.75, 1.33)) -> None: + self.scale = scale + self.ratio = ratio + + def extra_repr(self) -> str: + return f"scale={self.scale}, ratio={self.ratio}" + + def __call__(self, img: Any, target: Dict[str, np.ndarray]) -> Tuple[Any, Dict[str, np.ndarray]]: + scale = random.uniform(self.scale[0], self.scale[1]) + ratio = random.uniform(self.ratio[0], self.ratio[1]) + # Those might overflow + crop_h = math.sqrt(scale * ratio) + crop_w = math.sqrt(scale / ratio) + xmin, ymin = random.uniform(0, 1 - crop_w), random.uniform(0, 1 - crop_h) + xmax, ymax = xmin + crop_w, ymin + crop_h + # Clip them + xmin, ymin = max(xmin, 0), max(ymin, 0) + xmax, ymax = min(xmax, 1), min(ymax, 1) + + croped_img, crop_boxes = F.crop_detection(img, target["boxes"], (xmin, ymin, xmax, ymax)) + return croped_img, dict(boxes=crop_boxes) diff --git a/doctr/transforms/modules/pytorch.py b/doctr/transforms/modules/pytorch.py new file mode 100644 index 0000000000..045cc3967e --- /dev/null +++ b/doctr/transforms/modules/pytorch.py @@ -0,0 +1,121 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. 
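+
+# Illustrative sketch (added for clarity; an assumption, not part of the original file): these
+# modules follow the torchvision interface, so they can be chained with torchvision's Compose.
+#   >>> import torch
+#   >>> from torchvision.transforms import Compose
+#   >>> transfo = Compose([Resize((32, 128), preserve_aspect_ratio=True), GaussianNoise(0., .1)])
+#   >>> out = transfo(torch.rand(3, 64, 256))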
+ +import math +from typing import Any, Dict, Tuple, Union + +import torch +from PIL.Image import Image +from torch.nn.functional import pad +from torchvision.transforms import functional as F +from torchvision.transforms import transforms as T + +__all__ = ['Resize', 'GaussianNoise', 'ChannelShuffle', 'RandomHorizontalFlip'] + + +class Resize(T.Resize): + def __init__( + self, + size: Tuple[int, int], + interpolation=F.InterpolationMode.BILINEAR, + preserve_aspect_ratio: bool = False, + symmetric_pad: bool = False, + ) -> None: + super().__init__(size, interpolation) + self.preserve_aspect_ratio = preserve_aspect_ratio + self.symmetric_pad = symmetric_pad + + def forward(self, img: torch.Tensor) -> torch.Tensor: + target_ratio = self.size[0] / self.size[1] + actual_ratio = img.shape[-2] / img.shape[-1] + if not self.preserve_aspect_ratio or (target_ratio == actual_ratio): + return super().forward(img) + else: + # Resize + if actual_ratio > target_ratio: + tmp_size = (self.size[0], max(int(self.size[0] / actual_ratio), 1)) + else: + tmp_size = (max(int(self.size[1] * actual_ratio), 1), self.size[1]) + + # Scale image + img = F.resize(img, tmp_size, self.interpolation) + # Pad (inverted in pytorch) + _pad = (0, self.size[1] - img.shape[-1], 0, self.size[0] - img.shape[-2]) + if self.symmetric_pad: + half_pad = (math.ceil(_pad[1] / 2), math.ceil(_pad[3] / 2)) + _pad = (half_pad[0], _pad[1] - half_pad[0], half_pad[1], _pad[3] - half_pad[1]) + return pad(img, _pad) + + def __repr__(self) -> str: + interpolate_str = self.interpolation.value + _repr = f"output_size={self.size}, interpolation='{interpolate_str}'" + if self.preserve_aspect_ratio: + _repr += f", preserve_aspect_ratio={self.preserve_aspect_ratio}, symmetric_pad={self.symmetric_pad}" + return f"{self.__class__.__name__}({_repr})" + + +class GaussianNoise(torch.nn.Module): + """Adds Gaussian Noise to the input tensor + + Example:: + >>> from doctr.transforms import GaussianNoise + >>> import torch + >>> transfo = GaussianNoise(0., 1.) + >>> out = transfo(torch.rand((3, 224, 224))) + + Args: + mean : mean of the gaussian distribution + std : std of the gaussian distribution + """ + def __init__(self, mean: float = 0., std: float = 1.) -> None: + super().__init__() + self.std = std + self.mean = mean + + def forward(self, x: torch.Tensor) -> torch.Tensor: + # Reshape the distribution + noise = self.mean + 2 * self.std * torch.rand(x.shape, device=x.device) - self.std + if x.dtype == torch.uint8: + return (x + 255 * noise).round().clamp(0, 255).to(dtype=torch.uint8) + else: + return (x + noise.to(dtype=x.dtype)).clamp(0, 1) + + def extra_repr(self) -> str: + return f"mean={self.mean}, std={self.std}" + + +class ChannelShuffle(torch.nn.Module): + """Randomly shuffle channel order of a given image""" + + def __init__(self): + super().__init__() + + def forward(self, img: torch.Tensor) -> torch.Tensor: + # Get a random order + chan_order = torch.rand(img.shape[0]).argsort() + return img[chan_order] + + +class RandomHorizontalFlip(T.RandomHorizontalFlip): + + def forward( + self, + img: Union[torch.Tensor, Image], + target: Dict[str, Any] + ) -> Tuple[Union[torch.Tensor, Image], Dict[str, Any]]: + """ + Args: + img: Image to be flipped. 
+ target: Dictionary with boxes (in relative coordinates of shape (N, 4)) and labels as keys + Returns: + Tuple of PIL Image or Tensor and target + """ + if torch.rand(1) < self.p: + _img = F.hflip(img) + _target = target.copy() + # Changing the relative bbox coordinates + _target["boxes"][:, ::2] = 1 - target["boxes"][:, [2, 0]] + return _img, _target + return img, target diff --git a/doctr/transforms/modules/tensorflow.py b/doctr/transforms/modules/tensorflow.py new file mode 100644 index 0000000000..7dc9bfc408 --- /dev/null +++ b/doctr/transforms/modules/tensorflow.py @@ -0,0 +1,419 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + +import random +from typing import Any, Callable, Dict, Iterable, List, Tuple, Union + +import numpy as np +import tensorflow as tf +import tensorflow_addons as tfa + +from doctr.utils.repr import NestedObject + +__all__ = ['Compose', 'Resize', 'Normalize', 'LambdaTransformation', 'ToGray', 'RandomBrightness', + 'RandomContrast', 'RandomSaturation', 'RandomHue', 'RandomGamma', 'RandomJpegQuality', 'GaussianBlur', + 'ChannelShuffle', 'GaussianNoise', 'RandomHorizontalFlip'] + + +class Compose(NestedObject): + """Implements a wrapper that will apply transformations sequentially + + Example:: + >>> from doctr.transforms import Compose, Resize + >>> import tensorflow as tf + >>> transfos = Compose([Resize((32, 32))]) + >>> out = transfos(tf.random.uniform(shape=[64, 64, 3], minval=0, maxval=1)) + + Args: + transforms: list of transformation modules + """ + + _children_names: List[str] = ['transforms'] + + def __init__(self, transforms: List[Callable[[Any], Any]]) -> None: + self.transforms = transforms + + def __call__(self, x: Any) -> Any: + for t in self.transforms: + x = t(x) + + return x + + +class Resize(NestedObject): + """Resizes a tensor to a target size + + Example:: + >>> from doctr.transforms import Resize + >>> import tensorflow as tf + >>> transfo = Resize((32, 32)) + >>> out = transfo(tf.random.uniform(shape=[64, 64, 3], minval=0, maxval=1)) + + Args: + output_size: expected output size + method: interpolation method + preserve_aspect_ratio: if `True`, preserve aspect ratio and pad the rest with zeros + symmetric_pad: if `True` while preserving aspect ratio, the padding will be done symmetrically + """ + def __init__( + self, + output_size: Tuple[int, int], + method: str = 'bilinear', + preserve_aspect_ratio: bool = False, + symmetric_pad: bool = False, + ) -> None: + self.output_size = output_size + self.method = method + self.preserve_aspect_ratio = preserve_aspect_ratio + self.symmetric_pad = symmetric_pad + + def extra_repr(self) -> str: + _repr = f"output_size={self.output_size}, method='{self.method}'" + if self.preserve_aspect_ratio: + _repr += f", preserve_aspect_ratio={self.preserve_aspect_ratio}, symmetric_pad={self.symmetric_pad}" + return _repr + + def __call__(self, img: tf.Tensor) -> tf.Tensor: + input_dtype = img.dtype + img = tf.image.resize(img, self.output_size, self.method, self.preserve_aspect_ratio) + if self.preserve_aspect_ratio: + # pad width + if not self.symmetric_pad: + offset = (0, 0) + elif self.output_size[0] == img.shape[0]: + offset = (0, int((self.output_size[1] - img.shape[1]) / 2)) + else: + offset = (int((self.output_size[0] - img.shape[0]) / 2), 0) + img = tf.image.pad_to_bounding_box(img, *offset, *self.output_size) + return tf.cast(img, dtype=input_dtype) + + +class Normalize(NestedObject): + """Normalize 
a tensor to a Gaussian distribution for each channel + + Example:: + >>> from doctr.transforms import Normalize + >>> import tensorflow as tf + >>> transfo = Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) + >>> out = transfo(tf.random.uniform(shape=[8, 64, 64, 3], minval=0, maxval=1)) + + Args: + mean: average value per channel + std: standard deviation per channel + """ + def __init__(self, mean: Tuple[float, float, float], std: Tuple[float, float, float]) -> None: + self.mean = tf.constant(mean) + self.std = tf.constant(std) + + def extra_repr(self) -> str: + return f"mean={self.mean.numpy().tolist()}, std={self.std.numpy().tolist()}" + + def __call__(self, img: tf.Tensor) -> tf.Tensor: + img -= tf.cast(self.mean, dtype=img.dtype) + img /= tf.cast(self.std, dtype=img.dtype) + return img + + +class LambdaTransformation(NestedObject): + """Normalize a tensor to a Gaussian distribution for each channel + + Example:: + >>> from doctr.transforms import LambdaTransformation + >>> import tensorflow as tf + >>> transfo = LambdaTransformation(lambda x: x/ 255.) + >>> out = transfo(tf.random.uniform(shape=[8, 64, 64, 3], minval=0, maxval=1)) + + Args: + fn: the function to be applied to the input tensor + """ + def __init__(self, fn: Callable[[tf.Tensor], tf.Tensor]) -> None: + self.fn = fn + + def __call__(self, img: tf.Tensor) -> tf.Tensor: + return self.fn(img) + + +class ToGray(NestedObject): + """Convert a RGB tensor (batch of images or image) to a 3-channels grayscale tensor + + Example:: + >>> from doctr.transforms import Normalize + >>> import tensorflow as tf + >>> transfo = ToGray() + >>> out = transfo(tf.random.uniform(shape=[8, 64, 64, 3], minval=0, maxval=1)) + """ + def __init__(self, num_output_channels: int = 1): + self.num_output_channels = num_output_channels + + def __call__(self, img: tf.Tensor) -> tf.Tensor: + img = tf.image.rgb_to_grayscale(img) + return img if self.num_output_channels == 1 else tf.repeat(img, self.num_output_channels, axis=-1) + + +class RandomBrightness(NestedObject): + """Randomly adjust brightness of a tensor (batch of images or image) by adding a delta + to all pixels + + Example: + >>> from doctr.transforms import Normalize + >>> import tensorflow as tf + >>> transfo = Brightness() + >>> out = transfo(tf.random.uniform(shape=[8, 64, 64, 3], minval=0, maxval=1)) + + Args: + max_delta: offset to add to each pixel is randomly picked in [-max_delta, max_delta] + p: probability to apply transformation + """ + def __init__(self, max_delta: float = 0.3) -> None: + self.max_delta = max_delta + + def extra_repr(self) -> str: + return f"max_delta={self.max_delta}" + + def __call__(self, img: tf.Tensor) -> tf.Tensor: + return tf.image.random_brightness(img, max_delta=self.max_delta) + + +class RandomContrast(NestedObject): + """Randomly adjust contrast of a tensor (batch of images or image) by adjusting + each pixel: (img - mean) * contrast_factor + mean. 
+ + Example: + >>> from doctr.transforms import Normalize + >>> import tensorflow as tf + >>> transfo = Contrast() + >>> out = transfo(tf.random.uniform(shape=[8, 64, 64, 3], minval=0, maxval=1)) + + Args: + delta: multiplicative factor is picked in [1-delta, 1+delta] (reduce contrast if factor<1) + """ + def __init__(self, delta: float = .3) -> None: + self.delta = delta + + def extra_repr(self) -> str: + return f"delta={self.delta}" + + def __call__(self, img: tf.Tensor) -> tf.Tensor: + return tf.image.random_contrast(img, lower=1 - self.delta, upper=1 / (1 - self.delta)) + + +class RandomSaturation(NestedObject): + """Randomly adjust saturation of a tensor (batch of images or image) by converting to HSV and + increasing saturation by a factor. + + Example: + >>> from doctr.transforms import Normalize + >>> import tensorflow as tf + >>> transfo = Saturation() + >>> out = transfo(tf.random.uniform(shape=[8, 64, 64, 3], minval=0, maxval=1)) + + Args: + delta: multiplicative factor is picked in [1-delta, 1+delta] (reduce saturation if factor<1) + """ + def __init__(self, delta: float = .5) -> None: + self.delta = delta + + def extra_repr(self) -> str: + return f"delta={self.delta}" + + def __call__(self, img: tf.Tensor) -> tf.Tensor: + return tf.image.random_saturation(img, lower=1 - self.delta, upper=1 + self.delta) + + +class RandomHue(NestedObject): + """Randomly adjust hue of a tensor (batch of images or image) by converting to HSV and adding a delta + + Example:: + >>> from doctr.transforms import Normalize + >>> import tensorflow as tf + >>> transfo = Hue() + >>> out = transfo(tf.random.uniform(shape=[8, 64, 64, 3], minval=0, maxval=1)) + + Args: + max_delta: offset to add to each pixel is randomly picked in [-max_delta, max_delta] + """ + def __init__(self, max_delta: float = 0.3) -> None: + self.max_delta = max_delta + + def extra_repr(self) -> str: + return f"max_delta={self.max_delta}" + + def __call__(self, img: tf.Tensor) -> tf.Tensor: + return tf.image.random_hue(img, max_delta=self.max_delta) + + +class RandomGamma(NestedObject): + """randomly performs gamma correction for a tensor (batch of images or image) + + Example: + >>> from doctr.transforms import Normalize + >>> import tensorflow as tf + >>> transfo = Gamma() + >>> out = transfo(tf.random.uniform(shape=[8, 64, 64, 3], minval=0, maxval=1)) + + Args: + min_gamma: non-negative real number, lower bound for gamma param + max_gamma: non-negative real number, upper bound for gamma + min_gain: lower bound for constant multiplier + max_gain: upper bound for constant multiplier + """ + def __init__( + self, + min_gamma: float = 0.5, + max_gamma: float = 1.5, + min_gain: float = 0.8, + max_gain: float = 1.2, + ) -> None: + self.min_gamma = min_gamma + self.max_gamma = max_gamma + self.min_gain = min_gain + self.max_gain = max_gain + + def extra_repr(self) -> str: + return f"""gamma_range=({self.min_gamma}, {self.max_gamma}), + gain_range=({self.min_gain}, {self.max_gain})""" + + def __call__(self, img: tf.Tensor) -> tf.Tensor: + gamma = random.uniform(self.min_gamma, self.max_gamma) + gain = random.uniform(self.min_gain, self.max_gain) + return tf.image.adjust_gamma(img, gamma=gamma, gain=gain) + + +class RandomJpegQuality(NestedObject): + """Randomly adjust jpeg quality of a 3 dimensional RGB image + + Example:: + >>> from doctr.transforms import Normalize + >>> import tensorflow as tf + >>> transfo = JpegQuality() + >>> out = transfo(tf.random.uniform(shape=[64, 64, 3], minval=0, maxval=1)) + + Args: + min_quality: int between 
[0, 100] + max_quality: int between [0, 100] + """ + def __init__(self, min_quality: int = 60, max_quality: int = 100) -> None: + self.min_quality = min_quality + self.max_quality = max_quality + + def extra_repr(self) -> str: + return f"min_quality={self.min_quality}" + + def __call__(self, img: tf.Tensor) -> tf.Tensor: + return tf.image.random_jpeg_quality( + img, min_jpeg_quality=self.min_quality, max_jpeg_quality=self.max_quality + ) + + +class GaussianBlur(NestedObject): + """Randomly adjust jpeg quality of a 3 dimensional RGB image + + Example:: + >>> from doctr.transforms import GaussianBlur + >>> import tensorflow as tf + >>> transfo = GaussianBlur(3, (.1, 5)) + >>> out = transfo(tf.random.uniform(shape=[64, 64, 3], minval=0, maxval=1)) + + Args: + kernel_shape: size of the blurring kernel + std: min and max value of the standard deviation + """ + def __init__(self, kernel_shape: Union[int, Iterable[int]], std: Tuple[float, float]) -> None: + self.kernel_shape = kernel_shape + self.std = std + + def extra_repr(self) -> str: + return f"kernel_shape={self.kernel_shape}, std={self.std}" + + @tf.function + def __call__(self, img: tf.Tensor) -> tf.Tensor: + sigma = random.uniform(self.std[0], self.std[1]) + return tfa.image.gaussian_filter2d( + img, filter_shape=self.kernel_shape, sigma=sigma, + ) + + +class ChannelShuffle(NestedObject): + """Randomly shuffle channel order of a given image""" + + def __init__(self): + pass + + def __call__(self, img: tf.Tensor) -> tf.Tensor: + return tf.transpose(tf.random.shuffle(tf.transpose(img, perm=[2, 0, 1])), perm=[1, 2, 0]) + + +class GaussianNoise(NestedObject): + """Adds Gaussian Noise to the input tensor + + Example:: + >>> from doctr.transforms import GaussianNoise + >>> import tensorflow as tf + >>> transfo = GaussianNoise(0., 1.) + >>> out = transfo(tf.random.uniform(shape=[64, 64, 3], minval=0, maxval=1)) + + Args: + mean : mean of the gaussian distribution + std : std of the gaussian distribution + """ + def __init__(self, mean: float = 0., std: float = 1.) -> None: + super().__init__() + self.std = std + self.mean = mean + + def __call__(self, x: tf.Tensor) -> tf.Tensor: + # Reshape the distribution + noise = self.mean + 2 * self.std * tf.random.uniform(x.shape) - self.std + if x.dtype == tf.uint8: + return tf.cast( + tf.clip_by_value(tf.math.round(tf.cast(x, dtype=tf.float32) + 255 * noise), 0, 255), + dtype=tf.uint8 + ) + else: + return tf.cast(tf.clip_by_value(x + noise, 0, 1), dtype=x.dtype) + + def extra_repr(self) -> str: + return f"mean={self.mean}, std={self.std}" + + +class RandomHorizontalFlip(NestedObject): + """Adds random horizontal flip to the input tensor/np.ndarray + + Example:: + >>> from doctr.transforms import RandomHorizontalFlip + >>> import tensorflow as tf + >>> transfo = RandomHorizontalFlip(p=0.5) + >>> image = tf.random.uniform(shape=[64, 64, 3], minval=0, maxval=1) + >>> target = { + "boxes": np.array([[0.1, 0.1, 0.4, 0.5] ], dtype= np.float32), + "labels": np.ones(1, dtype= np.int64) + } + >>> out = transfo(image, target) + + Args: + p : probability of Horizontal Flip""" + def __init__(self, p: float) -> None: + super().__init__() + self.p = p + + def __call__( + self, + img: Union[tf.Tensor, np.ndarray], + target: Dict[str, Any] + ) -> Tuple[tf.Tensor, Dict[str, Any]]: + """ + Args: + img: Image to be flipped. 
+ target: Dictionary with boxes (in relative coordinates of shape (N, 4)) and labels as keys + Returns: + Tuple of numpy nd-array or Tensor and target + """ + if np.random.rand(1) <= self.p: + _img = tf.image.flip_left_right(img) + _target = target.copy() + # Changing the relative bbox coordinates + _target["boxes"][:, ::2] = 1 - target["boxes"][:, [2, 0]] + return _img, _target + return img, target diff --git a/doctr/utils/__init__.py b/doctr/utils/__init__.py new file mode 100644 index 0000000000..eeb9b15920 --- /dev/null +++ b/doctr/utils/__init__.py @@ -0,0 +1,4 @@ +from .common_types import * +from .data import * +from .geometry import * +from .metrics import * diff --git a/doctr/utils/common_types.py b/doctr/utils/common_types.py new file mode 100644 index 0000000000..d3cc476f05 --- /dev/null +++ b/doctr/utils/common_types.py @@ -0,0 +1,18 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + +from pathlib import Path +from typing import List, Tuple, Union + +__all__ = ['Point2D', 'BoundingBox', 'Polygon4P', 'Polygon', 'Bbox'] + + +Point2D = Tuple[float, float] +BoundingBox = Tuple[Point2D, Point2D] +Polygon4P = Tuple[Point2D, Point2D, Point2D, Point2D] +Polygon = List[Point2D] +AbstractPath = Union[str, Path] +AbstractFile = Union[AbstractPath, bytes] +Bbox = Tuple[float, float, float, float] diff --git a/doctr/utils/data.py b/doctr/utils/data.py new file mode 100644 index 0000000000..b3aff3398b --- /dev/null +++ b/doctr/utils/data.py @@ -0,0 +1,109 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + +# Adapted from https://github.com/pytorch/vision/blob/master/torchvision/datasets/utils.py + +import hashlib +import logging +import os +import re +import urllib +import urllib.error +import urllib.request +from pathlib import Path +from typing import Optional, Union + +from tqdm.auto import tqdm + +__all__ = ['download_from_url'] + + +# matches bfd8deac from resnet18-bfd8deac.ckpt +HASH_REGEX = re.compile(r'-([a-f0-9]*)\.') +USER_AGENT = "mindee/doctr" + + +def _urlretrieve(url: str, filename: Union[Path, str], chunk_size: int = 1024) -> None: + with open(filename, "wb") as fh: + with urllib.request.urlopen(urllib.request.Request(url, headers={"User-Agent": USER_AGENT})) as response: + with tqdm(total=response.length) as pbar: + for chunk in iter(lambda: response.read(chunk_size), ""): + if not chunk: + break + pbar.update(chunk_size) + fh.write(chunk) + + +def _check_integrity(file_path: Union[str, Path], hash_prefix: str) -> bool: + with open(file_path, 'rb') as f: + sha_hash = hashlib.sha256(f.read()).hexdigest() + + return sha_hash[:len(hash_prefix)] == hash_prefix + + +def download_from_url( + url: str, + file_name: Optional[str] = None, + hash_prefix: Optional[str] = None, + cache_dir: Optional[str] = None, + cache_subdir: Optional[str] = None, +) -> Path: + """Download a file using its URL + + Example:: + >>> from doctr.models import download_from_url + >>> download_from_url("https://yoursource.com/yourcheckpoint-yourhash.zip") + + Args: + url: the URL of the file to download + file_name: optional name of the file once downloaded + hash_prefix: optional expected SHA256 hash of the file + cache_dir: cache directory + cache_subdir: subfolder to use in the cache + + Returns: + the location of the downloaded file + """ + + if not isinstance(file_name, str): + file_name = 
url.rpartition('/')[-1] + + if not isinstance(cache_dir, str): + cache_dir = os.path.join(os.path.expanduser('~'), '.cache', 'doctr') + + # Check hash in file name + if hash_prefix is None: + r = HASH_REGEX.search(file_name) + hash_prefix = r.group(1) if r else None + + folder_path = Path(cache_dir) if cache_subdir is None else Path(cache_dir, cache_subdir) + file_path = folder_path.joinpath(file_name) + # Check file existence + if file_path.is_file() and (hash_prefix is None or _check_integrity(file_path, hash_prefix)): + logging.info(f"Using downloaded & verified file: {file_path}") + return file_path + + # Create folder hierarchy + folder_path.mkdir(parents=True, exist_ok=True) + # Download the file + try: + print(f"Downloading {url} to {file_path}") + _urlretrieve(url, file_path) + except (urllib.error.URLError, IOError) as e: # type: ignore[attr-defined] + if url[:5] == 'https': + url = url.replace('https:', 'http:') + print('Failed download. Trying https -> http instead.' + f" Downloading {url} to {file_path}") + _urlretrieve(url, file_path) + else: + raise e + + # Remove corrupted files + if isinstance(hash_prefix, str) and not _check_integrity(file_path, hash_prefix): + # Remove file + os.remove(file_path) + raise ValueError(f"corrupted download, the hash of {url} does not match its expected value") + + return file_path diff --git a/doctr/utils/fonts.py b/doctr/utils/fonts.py new file mode 100644 index 0000000000..51769ba74a --- /dev/null +++ b/doctr/utils/fonts.py @@ -0,0 +1,38 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + +import logging +import platform +from typing import Optional + +from PIL import ImageFont + +__all__ = ['get_font'] + + +def get_font(font_family: Optional[str] = None, font_size: int = 13) -> ImageFont.ImageFont: + """Resolves a compatible ImageFont for the system + + Args: + font_family: the font family to use + font_size: the size of the font upon rendering + + Returns: + the Pillow font + """ + + # Font selection + if font_family is None: + try: + font = ImageFont.truetype("FreeMono.ttf" if platform.system() == "Linux" else "Arial.ttf", font_size) + except OSError: + font = ImageFont.load_default() + logging.warning("unable to load recommended font family. Loading default PIL font," + "font size issues may be expected." + "To prevent this, it is recommended to specify the value of 'font_family'.") + else: + font = ImageFont.truetype(font_family, font_size) + + return font diff --git a/doctr/utils/geometry.py b/doctr/utils/geometry.py new file mode 100644 index 0000000000..ce52ed93bb --- /dev/null +++ b/doctr/utils/geometry.py @@ -0,0 +1,262 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. 
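+
+# Illustrative sketch (added for clarity; an assumption, not part of the original file): the
+# relative BoundingBox convention ((xmin, ymin), (xmax, ymax)) used by the helpers below.
+#   >>> bbox = ((0.1, 0.2), (0.4, 0.5))
+#   >>> polygon_to_bbox(bbox_to_polygon(bbox))  # round-trips to the same box
+#   ((0.1, 0.2), (0.4, 0.5))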
+ +from math import ceil +from typing import List, Tuple, Union + +import cv2 +import numpy as np + +from .common_types import BoundingBox, Polygon4P + +__all__ = ['bbox_to_polygon', 'polygon_to_bbox', 'resolve_enclosing_bbox', 'resolve_enclosing_rbbox', + 'rotate_boxes', 'compute_expanded_shape', 'rotate_image', 'estimate_page_angle', + 'convert_to_relative_coords', 'rotate_abs_geoms'] + + +def bbox_to_polygon(bbox: BoundingBox) -> Polygon4P: + return bbox[0], (bbox[1][0], bbox[0][1]), (bbox[0][0], bbox[1][1]), bbox[1] + + +def polygon_to_bbox(polygon: Polygon4P) -> BoundingBox: + x, y = zip(*polygon) + return (min(x), min(y)), (max(x), max(y)) + + +def resolve_enclosing_bbox(bboxes: Union[List[BoundingBox], np.ndarray]) -> Union[BoundingBox, np.ndarray]: + """Compute enclosing bbox either from: + + - an array of boxes: (*, 5), where boxes have this shape: + (xmin, ymin, xmax, ymax, score) + + - a list of BoundingBox + + Return a (1, 5) array (enclosing boxarray), or a BoundingBox + """ + if isinstance(bboxes, np.ndarray): + xmin, ymin, xmax, ymax, score = np.split(bboxes, 5, axis=1) + return np.array([xmin.min(), ymin.min(), xmax.max(), ymax.max(), score.mean()]) + else: + x, y = zip(*[point for box in bboxes for point in box]) + return (min(x), min(y)), (max(x), max(y)) + + +def resolve_enclosing_rbbox(rbboxes: List[np.ndarray], intermed_size: int = 1024) -> np.ndarray: + cloud = np.concatenate(rbboxes, axis=0) + # Convert to absolute for minAreaRect + cloud *= intermed_size + rect = cv2.minAreaRect(cloud.astype(np.int32)) + return cv2.boxPoints(rect) / intermed_size + + +def rotate_abs_points(points: np.ndarray, angle: float = 0.) -> np.ndarray: + """Rotate points counter-clockwise. + Points: array of size (N, 2) + """ + + angle_rad = angle * np.pi / 180. # compute radian angle for np functions + rotation_mat = np.array([ + [np.cos(angle_rad), -np.sin(angle_rad)], + [np.sin(angle_rad), np.cos(angle_rad)] + ], dtype=points.dtype) + return np.matmul(points, rotation_mat.T) + + +def compute_expanded_shape(img_shape: Tuple[int, int], angle: float) -> Tuple[int, int]: + """Compute the shape of an expanded rotated image + + Args: + img_shape: the height and width of the image + angle: angle between -90 and +90 degrees + + Returns: + the height and width of the rotated image + """ + + points = np.array([ + [img_shape[1] / 2, img_shape[0] / 2], + [-img_shape[1] / 2, img_shape[0] / 2], + ]) + + rotated_points = rotate_abs_points(points, angle) + + wh_shape = 2 * np.abs(rotated_points).max(axis=0) + return wh_shape[1], wh_shape[0] + + +def rotate_abs_geoms( + geoms: np.ndarray, + angle: float, + img_shape: Tuple[int, int], + expand: bool = True, +) -> np.ndarray: + """Rotate a batch of bounding boxes or polygons by an angle around the + image center. 
+ + Args: + geoms: (N, 4) or (N, 4, 2) array of ABSOLUTE coordinate boxes + angle: anti-clockwise rotation angle in degrees + img_shape: the height and width of the image + expand: whether the image should be padded to avoid information loss + + Returns: + A batch of rotated polygons (N, 4, 2) + """ + + # Switch to polygons + polys = np.stack( + [geoms[:, [0, 1]], geoms[:, [2, 1]], geoms[:, [2, 3]], geoms[:, [0, 3]]], + axis=1 + ) if geoms.ndim == 2 else geoms + polys = polys.astype(np.float32) + + # Switch to image center as referential + polys[..., 0] -= img_shape[1] / 2 + polys[..., 1] = img_shape[0] / 2 - polys[..., 1] + + # Rotate them around the image center + rotated_polys = rotate_abs_points(polys.reshape(-1, 2), angle).reshape(-1, 4, 2) + # Switch back to top-left corner as referential + target_shape = compute_expanded_shape(img_shape, angle) if expand else img_shape + # Clip coords to fit the target shape + rotated_polys[..., 0] = (rotated_polys[..., 0] + target_shape[1] / 2).clip(0, target_shape[1]) + rotated_polys[..., 1] = (target_shape[0] / 2 - rotated_polys[..., 1]).clip(0, target_shape[0]) + + return rotated_polys + + +def rotate_boxes( + loc_preds: np.ndarray, + angle: float, + orig_shape: Tuple[int, int], + min_angle: float = 1., +) -> np.ndarray: + """Rotate a batch of straight bounding boxes (xmin, ymin, xmax, ymax, c) or rotated bounding boxes + (4, 2) by an angle around the center of the page. The rotation is only applied when + min_angle <= abs(angle) <= 90 - min_angle; otherwise the boxes are returned unchanged, simply + converted to the (N, 4, 2) format. + + Args: + loc_preds: (N, 5) or (N, 4, 2) array of RELATIVE boxes + angle: angle between -90 and +90 degrees + orig_shape: shape of the origin image + min_angle: minimum angle to rotate boxes + + Returns: + A batch of rotated boxes (N, 4, 2) + """ + + # Change format of the boxes to rotated boxes + _boxes = loc_preds.copy() + if _boxes.ndim == 2: + _boxes = np.stack( + [ + _boxes[:, [0, 1]], + _boxes[:, [2, 1]], + _boxes[:, [2, 3]], + _boxes[:, [0, 3]], + ], + axis=1 + ) + # If small angle, return boxes (no rotation) + if abs(angle) < min_angle or abs(angle) > 90 - min_angle: + return _boxes + # Compute rotation matrix + angle_rad = angle * np.pi / 180. # compute radian angle for np functions + rotation_mat = np.array([ + [np.cos(angle_rad), -np.sin(angle_rad)], + [np.sin(angle_rad), np.cos(angle_rad)] + ], dtype=_boxes.dtype) + # Rotate absolute points + points = np.stack((_boxes[:, :, 0] * orig_shape[1], _boxes[:, :, 1] * orig_shape[0]), axis=-1) + image_center = (orig_shape[1] / 2, orig_shape[0] / 2) + rotated_points = image_center + np.matmul(points - image_center, rotation_mat) + rotated_boxes = np.stack( + (rotated_points[:, :, 0] / orig_shape[1], rotated_points[:, :, 1] / orig_shape[0]), axis=-1 + ) + return rotated_boxes + + +def rotate_image( + image: np.ndarray, + angle: float, + expand: bool = False, + preserve_origin_shape: bool = False, +) -> np.ndarray: + """Rotate an image counterclockwise by a given angle. + + Args: + image: numpy tensor to rotate + angle: rotation angle in degrees, between -90 and +90 + expand: whether the image should be padded before the rotation + preserve_origin_shape: if expand is set to True, resizes the final output to the original image size + + Returns: + Rotated array, padded by 0 by default.
+ """ + + # Compute the expanded padding + if expand: + exp_shape = compute_expanded_shape(image.shape[:-1], angle) + h_pad, w_pad = int(max(0, ceil(exp_shape[0] - image.shape[0]))), int( + max(0, ceil(exp_shape[1] - image.shape[1]))) + exp_img = np.pad(image, ((h_pad // 2, h_pad - h_pad // 2), (w_pad // 2, w_pad - w_pad // 2), (0, 0))) + else: + exp_img = image + + height, width = exp_img.shape[:2] + rot_mat = cv2.getRotationMatrix2D((width / 2, height / 2), angle, 1.0) + rot_img = cv2.warpAffine(exp_img, rot_mat, (width, height)) + if expand: + # Pad to get the same aspect ratio + if (image.shape[0] / image.shape[1]) != (rot_img.shape[0] / rot_img.shape[1]): + # Pad width + if (rot_img.shape[0] / rot_img.shape[1]) > (image.shape[0] / image.shape[1]): + h_pad, w_pad = 0, int(rot_img.shape[0] * image.shape[1] / image.shape[0] - rot_img.shape[1]) + # Pad height + else: + h_pad, w_pad = int(rot_img.shape[1] * image.shape[0] / image.shape[1] - rot_img.shape[0]), 0 + rot_img = np.pad(rot_img, ((h_pad // 2, h_pad - h_pad // 2), (w_pad // 2, w_pad - w_pad // 2), (0, 0))) + if preserve_origin_shape: + # rescale + rot_img = cv2.resize(rot_img, image.shape[:-1][::-1], interpolation=cv2.INTER_LINEAR) + + return rot_img + + +def estimate_page_angle(polys: np.ndarray) -> float: + """Takes a batch of rotated previously ORIENTED polys (N, 4, 2) (rectified by the classifier) and return the + estimated angle ccw in degrees + """ + return np.median(np.arctan( + (polys[:, 0, 1] - polys[:, 1, 1]) / # Y axis from top to bottom! + (polys[:, 1, 0] - polys[:, 0, 0]) + )) * 180 / np.pi + + +def convert_to_relative_coords(geoms: np.ndarray, img_shape: Tuple[int, int]) -> np.ndarray: + """Convert a geometry to relative coordinates + + Args: + geoms: a set of polygons of shape (N, 4, 2) or of straight boxes of shape (N, 4) + img_shape: the height and width of the image + + Returns: + the updated geometry + """ + + # Polygon + if geoms.ndim == 3 and geoms.shape[1:] == (4, 2): + polygons = np.empty(geoms.shape, dtype=np.float32) + polygons[..., 0] = geoms[..., 0] / img_shape[1] + polygons[..., 1] = geoms[..., 1] / img_shape[0] + return polygons.clip(0, 1) + if geoms.ndim == 2 and geoms.shape[1] == 4: + boxes = np.empty(geoms.shape, dtype=np.float32) + boxes[:, ::2] = geoms[:, ::2] / img_shape[1] + boxes[:, 1::2] = geoms[:, 1::2] / img_shape[0] + return boxes.clip(0, 1) + + raise ValueError(f"invalid format for arg `geoms`: {geoms.shape}") diff --git a/doctr/utils/metrics.py b/doctr/utils/metrics.py new file mode 100644 index 0000000000..4c01574a15 --- /dev/null +++ b/doctr/utils/metrics.py @@ -0,0 +1,692 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. 
+ +from typing import Dict, List, Optional, Tuple + +import cv2 +import numpy as np +from scipy.optimize import linear_sum_assignment +from unidecode import unidecode + +__all__ = ['TextMatch', 'box_iou', 'box_ioa', 'mask_iou', 'polygon_iou', + 'nms', 'LocalizationConfusion', 'OCRMetric', 'DetectionMetric'] + + +def string_match(word1: str, word2: str) -> Tuple[bool, bool, bool, bool]: + """Performs string comparison with multiple levels of tolerance + + Args: + word1: a string + word2: another string + + Returns: + a tuple with booleans specifying respectively whether the raw strings, their lower-case counterparts, their + unidecode counterparts and their lower-case unidecode counterparts match + """ + raw_match = (word1 == word2) + caseless_match = (word1.lower() == word2.lower()) + unidecode_match = (unidecode(word1) == unidecode(word2)) + + # Warning: the order is important here otherwise the pair ("EUR", "€") cannot be matched + unicase_match = (unidecode(word1).lower() == unidecode(word2).lower()) + + return raw_match, caseless_match, unidecode_match, unicase_match + + +class TextMatch: + r"""Implements text match metric (word-level accuracy) for recognition task. + + The raw aggregated metric is computed as follows: + + .. math:: + \forall X, Y \in \mathcal{W}^N, + TextMatch(X, Y) = \frac{1}{N} \sum\limits_{i=1}^N f_{Y_i}(X_i) + + with the indicator function :math:`f_{a}` defined as: + + .. math:: + \forall a, x \in \mathcal{W}, + f_a(x) = \left\{ + \begin{array}{ll} + 1 & \mbox{if } x = a \\ + 0 & \mbox{otherwise.} + \end{array} + \right. + + where :math:`\mathcal{W}` is the set of all possible character sequences, + :math:`N` is a strictly positive integer. + + Example:: + >>> from doctr.utils import TextMatch + >>> metric = TextMatch() + >>> metric.update(['Hello', 'world'], ['hello', 'world']) + >>> metric.summary() + """ + + def __init__(self) -> None: + self.reset() + + def update( + self, + gt: List[str], + pred: List[str], + ) -> None: + """Update the state of the metric with new predictions + + Args: + gt: list of groung-truth character sequences + pred: list of predicted character sequences + """ + + if len(gt) != len(pred): + raise AssertionError("prediction size does not match with ground-truth labels size") + + for gt_word, pred_word in zip(gt, pred): + _raw, _caseless, _unidecode, _unicase = string_match(gt_word, pred_word) + self.raw += int(_raw) + self.caseless += int(_caseless) + self.unidecode += int(_unidecode) + self.unicase += int(_unicase) + + self.total += len(gt) + + def summary(self) -> Dict[str, float]: + """Computes the aggregated metrics + + Returns: + a dictionary with the exact match score for the raw data, its lower-case counterpart, its unidecode + counterpart and its lower-case unidecode counterpart + """ + if self.total == 0: + raise AssertionError("you need to update the metric before getting the summary") + + return dict( + raw=self.raw / self.total, + caseless=self.caseless / self.total, + unidecode=self.unidecode / self.total, + unicase=self.unicase / self.total, + ) + + def reset(self) -> None: + self.raw = 0 + self.caseless = 0 + self.unidecode = 0 + self.unicase = 0 + self.total = 0 + + +def box_iou(boxes_1: np.ndarray, boxes_2: np.ndarray) -> np.ndarray: + """Computes the IoU between two sets of bounding boxes + + Args: + boxes_1: bounding boxes of shape (N, 4) in format (xmin, ymin, xmax, ymax) + boxes_2: bounding boxes of shape (M, 4) in format (xmin, ymin, xmax, ymax) + Returns: + the IoU matrix of shape (N, M) + """ + + iou_mat = 
np.zeros((boxes_1.shape[0], boxes_2.shape[0]), dtype=np.float32) + + if boxes_1.shape[0] > 0 and boxes_2.shape[0] > 0: + l1, t1, r1, b1 = np.split(boxes_1, 4, axis=1) + l2, t2, r2, b2 = np.split(boxes_2, 4, axis=1) + + left = np.maximum(l1, l2.T) + top = np.maximum(t1, t2.T) + right = np.minimum(r1, r2.T) + bot = np.minimum(b1, b2.T) + + intersection = np.clip(right - left, 0, np.Inf) * np.clip(bot - top, 0, np.Inf) + union = (r1 - l1) * (b1 - t1) + ((r2 - l2) * (b2 - t2)).T - intersection + iou_mat = intersection / union + + return iou_mat + + +def box_ioa(boxes_1: np.ndarray, boxes_2: np.ndarray) -> np.ndarray: + """Computes the IoA (intersection over area) between two sets of bounding boxes: + ioa(i, j) = inter(i, j) / area(i) + + Args: + boxes_1: bounding boxes of shape (N, 4) in format (xmin, ymin, xmax, ymax) + boxes_2: bounding boxes of shape (M, 4) in format (xmin, ymin, xmax, ymax) + Returns: + the IoA matrix of shape (N, M) + """ + + ioa_mat = np.zeros((boxes_1.shape[0], boxes_2.shape[0]), dtype=np.float32) + + if boxes_1.shape[0] > 0 and boxes_2.shape[0] > 0: + l1, t1, r1, b1 = np.split(boxes_1, 4, axis=1) + l2, t2, r2, b2 = np.split(boxes_2, 4, axis=1) + + left = np.maximum(l1, l2.T) + top = np.maximum(t1, t2.T) + right = np.minimum(r1, r2.T) + bot = np.minimum(b1, b2.T) + + intersection = np.clip(right - left, 0, np.Inf) * np.clip(bot - top, 0, np.Inf) + area = (r1 - l1) * (b1 - t1) + ioa_mat = intersection / area + + return ioa_mat + + +def mask_iou(masks_1: np.ndarray, masks_2: np.ndarray) -> np.ndarray: + """Computes the IoU between two sets of boolean masks + + Args: + masks_1: boolean masks of shape (N, H, W) + masks_2: boolean masks of shape (M, H, W) + + Returns: + the IoU matrix of shape (N, M) + """ + + if masks_1.shape[1:] != masks_2.shape[1:]: + raise AssertionError("both boolean masks should have the same spatial shape") + + iou_mat = np.zeros((masks_1.shape[0], masks_2.shape[0]), dtype=np.float32) + + if masks_1.shape[0] > 0 and masks_2.shape[0] > 0: + axes = tuple(range(2, masks_1.ndim + 1)) + intersection = np.logical_and(masks_1[:, None, ...], masks_2[None, ...]).sum(axis=axes) + union = np.logical_or(masks_1[:, None, ...], masks_2[None, ...]).sum(axis=axes) + iou_mat = intersection / union + + return iou_mat + + +def polygon_iou( + polys_1: np.ndarray, + polys_2: np.ndarray, + mask_shape: Tuple[int, int], + use_broadcasting: bool = False +) -> np.ndarray: + """Computes the IoU between two sets of rotated bounding boxes + + Args: + polys_1: rotated bounding boxes of shape (N, 4, 2) + polys_2: rotated bounding boxes of shape (M, 4, 2) + mask_shape: spatial shape of the intermediate masks + use_broadcasting: if set to True, leverage broadcasting speedup by consuming more memory + + Returns: + the IoU matrix of shape (N, M) + """ + + if polys_1.ndim != 3 or polys_2.ndim != 3: + raise AssertionError("expects boxes to be in format (N, 4, 2)") + + iou_mat = np.zeros((polys_1.shape[0], polys_2.shape[0]), dtype=np.float32) + + if polys_1.shape[0] > 0 and polys_2.shape[0] > 0: + if use_broadcasting: + masks_1 = rbox_to_mask(polys_1, shape=mask_shape) + masks_2 = rbox_to_mask(polys_2, shape=mask_shape) + iou_mat = mask_iou(masks_1, masks_2) + else: + # Save memory by doing the computation for each pair + for idx, b1 in enumerate(polys_1): + m1 = _rbox_to_mask(b1, mask_shape) + for _idx, b2 in enumerate(polys_2): + m2 = _rbox_to_mask(b2, mask_shape) + iou_mat[idx, _idx] = np.logical_and(m1, m2).sum() / np.logical_or(m1, m2).sum() + + return iou_mat + + +def 
_rbox_to_mask(box: np.ndarray, shape: Tuple[int, int]) -> np.ndarray: + """Converts a rotated bounding box to a boolean mask + + Args: + box: rotated bounding box of shape (4, 2) + shape: spatial shapes of the output masks + + Returns: + the boolean mask of the specified shape + """ + + mask = np.zeros(shape, dtype=np.uint8) + # Get absolute coords + if box.dtype != int: + abs_box = box.copy() + abs_box[:, 0] = abs_box[:, 0] * shape[1] + abs_box[:, 1] = abs_box[:, 1] * shape[0] + abs_box = abs_box.round().astype(int) + else: + abs_box = box + abs_box[2:] = abs_box[2:] + 1 + cv2.fillPoly(mask, [abs_box - 1], 1) + + return mask.astype(bool) + + +def rbox_to_mask(boxes: np.ndarray, shape: Tuple[int, int]) -> np.ndarray: + """Converts rotated bounding boxes to boolean masks + + Args: + boxes: rotated bounding boxes of shape (N, 4, 2) + shape: spatial shapes of the output masks + + Returns: + the boolean masks of shape (N, H, W) + """ + + masks = np.zeros((boxes.shape[0], *shape), dtype=np.uint8) + + if boxes.shape[0] > 0: + # Get absolute coordinates + if boxes.dtype != int: + abs_boxes = boxes.copy() + abs_boxes[:, :, 0] = abs_boxes[:, :, 0] * shape[1] + abs_boxes[:, :, 1] = abs_boxes[:, :, 1] * shape[0] + abs_boxes = abs_boxes.round().astype(int) + else: + abs_boxes = boxes + abs_boxes[:, 2:] = abs_boxes[:, 2:] + 1 + + # TODO: optimize slicing to improve vectorization + for idx, _box in enumerate(abs_boxes): + cv2.fillPoly(masks[idx], [_box - 1], 1) + return masks.astype(bool) + + +def nms(boxes: np.ndarray, thresh: float = .5) -> List[int]: + """Perform non-max suppression on a set of straight scored boxes. + + Args: + boxes: np array of straight boxes: (*, 5), (xmin, ymin, xmax, ymax, score) + thresh: iou threshold to perform box suppression. + + Returns: + A list of box indexes to keep + """ + x1 = boxes[:, 0] + y1 = boxes[:, 1] + x2 = boxes[:, 2] + y2 = boxes[:, 3] + scores = boxes[:, 4] + + areas = (x2 - x1) * (y2 - y1) + order = scores.argsort()[::-1] + + keep = [] + while order.size > 0: + i = order[0] + keep.append(i) + xx1 = np.maximum(x1[i], x1[order[1:]]) + yy1 = np.maximum(y1[i], y1[order[1:]]) + xx2 = np.minimum(x2[i], x2[order[1:]]) + yy2 = np.minimum(y2[i], y2[order[1:]]) + + w = np.maximum(0.0, xx2 - xx1) + h = np.maximum(0.0, yy2 - yy1) + inter = w * h + ovr = inter / (areas[i] + areas[order[1:]] - inter) + + inds = np.where(ovr <= thresh)[0] + order = order[inds + 1] + return keep + + +class LocalizationConfusion: + r"""Implements common confusion metrics and mean IoU for localization evaluation. + + The aggregated metrics are computed as follows: + + .. math:: + \forall Y \in \mathcal{B}^N, \forall X \in \mathcal{B}^M, \\ + Recall(X, Y) = \frac{1}{N} \sum\limits_{i=1}^N g_{X}(Y_i) \\ + Precision(X, Y) = \frac{1}{M} \sum\limits_{i=1}^M g_{X}(Y_i) \\ + meanIoU(X, Y) = \frac{1}{M} \sum\limits_{i=1}^M \max\limits_{j \in [1, N]} IoU(X_i, Y_j) + + with the function :math:`IoU(x, y)` being the Intersection over Union between bounding boxes :math:`x` and + :math:`y`, and the function :math:`g_{X}` defined as: + + .. math:: + \forall y \in \mathcal{B}, + g_X(y) = \left\{ + \begin{array}{ll} + 1 & \mbox{if } y\mbox{ has been assigned to any }(X_i)_i\mbox{ with an }IoU \geq 0.5 \\ + 0 & \mbox{otherwise.} + \end{array} + \right. + + where :math:`\mathcal{B}` is the set of possible bounding boxes, + :math:`N` (number of ground truths) and :math:`M` (number of predictions) are strictly positive integers.
+ + Example:: + >>> import numpy as np + >>> from doctr.utils import LocalizationConfusion + >>> metric = LocalizationConfusion(iou_thresh=0.5) + >>> metric.update(np.asarray([[0, 0, 100, 100]]), np.asarray([[0, 0, 70, 70], [110, 95, 200, 150]])) + >>> metric.summary() + + Args: + iou_thresh: minimum IoU to consider a pair of prediction and ground truth as a match + use_polygons: if set to True, predictions and targets will be expected to have rotated format + mask_shape: if use_polygons is True, describes the spatial shape of the image used + use_broadcasting: if use_polygons is True, use broadcasting for IoU computation by consuming more memory + """ + + def __init__( + self, + iou_thresh: float = 0.5, + use_polygons: bool = False, + mask_shape: Tuple[int, int] = (1024, 1024), + use_broadcasting: bool = True, + ) -> None: + self.iou_thresh = iou_thresh + self.use_polygons = use_polygons + self.mask_shape = mask_shape + self.use_broadcasting = use_broadcasting + self.reset() + + def update(self, gts: np.ndarray, preds: np.ndarray) -> None: + """Updates the metric + + Args: + gts: a set of relative bounding boxes either of shape (N, 4) or (N, 5) if they are rotated ones + preds: a set of relative bounding boxes either of shape (M, 4) or (M, 5) if they are rotated ones + """ + + if preds.shape[0] > 0: + # Compute IoU + if self.use_polygons: + iou_mat = polygon_iou(gts, preds, self.mask_shape, self.use_broadcasting) + else: + iou_mat = box_iou(gts, preds) + self.tot_iou += float(iou_mat.max(axis=0).sum()) + + # Assign pairs + gt_indices, pred_indices = linear_sum_assignment(-iou_mat) + self.matches += int((iou_mat[gt_indices, pred_indices] >= self.iou_thresh).sum()) + + # Update counts + self.num_gts += gts.shape[0] + self.num_preds += preds.shape[0] + + def summary(self) -> Tuple[Optional[float], Optional[float], Optional[float]]: + """Computes the aggregated metrics + + Returns: + a tuple with the recall, precision and meanIoU scores + """ + + # Recall + recall = self.matches / self.num_gts if self.num_gts > 0 else None + + # Precision + precision = self.matches / self.num_preds if self.num_preds > 0 else None + + # mean IoU + mean_iou = self.tot_iou / self.num_preds if self.num_preds > 0 else None + + return recall, precision, mean_iou + + def reset(self) -> None: + self.num_gts = 0 + self.num_preds = 0 + self.matches = 0 + self.tot_iou = 0. + + +class OCRMetric: + r"""Implements an end-to-end OCR metric. + + The aggregated metrics are computed as follows: + + .. math:: + \forall (B, L) \in \mathcal{B}^N \times \mathcal{L}^N, + \forall (\hat{B}, \hat{L}) \in \mathcal{B}^M \times \mathcal{L}^M, \\ + Recall(B, \hat{B}, L, \hat{L}) = \frac{1}{N} \sum\limits_{i=1}^N h_{B,L}(\hat{B}_i, \hat{L}_i) \\ + Precision(B, \hat{B}, L, \hat{L}) = \frac{1}{M} \sum\limits_{i=1}^M h_{B,L}(\hat{B}_i, \hat{L}_i) \\ + meanIoU(B, \hat{B}) = \frac{1}{M} \sum\limits_{i=1}^M \max\limits_{j \in [1, N]} IoU(\hat{B}_i, B_j) + + with the function :math:`IoU(x, y)` being the Intersection over Union between bounding boxes :math:`x` and + :math:`y`, and the function :math:`h_{B, L}` defined as: + + .. math:: + \forall (b, l) \in \mathcal{B} \times \mathcal{L}, + h_{B,L}(b, l) = \left\{ + \begin{array}{ll} + 1 & \mbox{if } b\mbox{ has been assigned to a given }B_j\mbox{ with an } \\ + & IoU \geq 0.5 \mbox{ and that for this assignment, } l = L_j\\ + 0 & \mbox{otherwise.} + \end{array} + \right. 
+ + where :math:`\mathcal{B}` is the set of possible bounding boxes, + :math:`\mathcal{L}` is the set of possible character sequences, + :math:`N` (number of ground truths) and :math:`M` (number of predictions) are strictly positive integers. + + Example:: + >>> import numpy as np + >>> from doctr.utils import OCRMetric + >>> metric = OCRMetric(iou_thresh=0.5) + >>> metric.update(np.asarray([[0, 0, 100, 100]]), np.asarray([[0, 0, 70, 70], [110, 95, 200, 150]]), + ['hello'], ['hello', 'world']) + >>> metric.summary() + + Args: + iou_thresh: minimum IoU to consider a pair of prediction and ground truth as a match + use_polygons: if set to True, predictions and targets will be expected to have rotated format + mask_shape: if use_polygons is True, describes the spatial shape of the image used + use_broadcasting: if use_polygons is True, use broadcasting for IoU computation by consuming more memory + """ + + def __init__( + self, + iou_thresh: float = 0.5, + use_polygons: bool = False, + mask_shape: Tuple[int, int] = (1024, 1024), + use_broadcasting: bool = True, + ) -> None: + self.iou_thresh = iou_thresh + self.use_polygons = use_polygons + self.mask_shape = mask_shape + self.use_broadcasting = use_broadcasting + self.reset() + + def update( + self, + gt_boxes: np.ndarray, + pred_boxes: np.ndarray, + gt_labels: List[str], + pred_labels: List[str], + ) -> None: + """Updates the metric + + Args: + gt_boxes: a set of relative bounding boxes either of shape (N, 4) or (N, 5) if they are rotated ones + pred_boxes: a set of relative bounding boxes either of shape (M, 4) or (M, 5) if they are rotated ones + gt_labels: a list of N string labels + pred_labels: a list of M string labels + """ + + if gt_boxes.shape[0] != len(gt_labels) or pred_boxes.shape[0] != len(pred_labels): + raise AssertionError("there should be the same number of boxes and string both for the ground truth " + "and the predictions") + + # Compute IoU + if pred_boxes.shape[0] > 0: + if self.use_polygons: + iou_mat = polygon_iou(gt_boxes, pred_boxes, self.mask_shape, self.use_broadcasting) + else: + iou_mat = box_iou(gt_boxes, pred_boxes) + + self.tot_iou += float(iou_mat.max(axis=0).sum()) + + # Assign pairs + gt_indices, pred_indices = linear_sum_assignment(-iou_mat) + is_kept = iou_mat[gt_indices, pred_indices] >= self.iou_thresh + # String comparison + for gt_idx, pred_idx in zip(gt_indices[is_kept], pred_indices[is_kept]): + _raw, _caseless, _unidecode, _unicase = string_match(gt_labels[gt_idx], pred_labels[pred_idx]) + self.raw_matches += int(_raw) + self.caseless_matches += int(_caseless) + self.unidecode_matches += int(_unidecode) + self.unicase_matches += int(_unicase) + + self.num_gts += gt_boxes.shape[0] + self.num_preds += pred_boxes.shape[0] + + def summary(self) -> Tuple[Dict[str, Optional[float]], Dict[str, Optional[float]], Optional[float]]: + """Computes the aggregated metrics + + Returns: + a tuple with the recall & precision for each string comparison and the mean IoU + """ + + # Recall + recall = dict( + raw=self.raw_matches / self.num_gts if self.num_gts > 0 else None, + caseless=self.caseless_matches / self.num_gts if self.num_gts > 0 else None, + unidecode=self.unidecode_matches / self.num_gts if self.num_gts > 0 else None, + unicase=self.unicase_matches / self.num_gts if self.num_gts > 0 else None, + ) + + # Precision + precision = dict( + raw=self.raw_matches / self.num_preds if self.num_preds > 0 else None, + caseless=self.caseless_matches / self.num_preds if self.num_preds > 0 else None, + 
unidecode=self.unidecode_matches / self.num_preds if self.num_preds > 0 else None, + unicase=self.unicase_matches / self.num_preds if self.num_preds > 0 else None, + ) + + # mean IoU (overall detected boxes) + mean_iou = self.tot_iou / self.num_preds if self.num_preds > 0 else None + + return recall, precision, mean_iou + + def reset(self) -> None: + self.num_gts = 0 + self.num_preds = 0 + self.tot_iou = 0. + self.raw_matches = 0 + self.caseless_matches = 0 + self.unidecode_matches = 0 + self.unicase_matches = 0 + + +class DetectionMetric: + r"""Implements an object detection metric. + + The aggregated metrics are computed as follows: + + .. math:: + \forall (B, C) \in \mathcal{B}^N \times \mathcal{C}^N, + \forall (\hat{B}, \hat{C}) \in \mathcal{B}^M \times \mathcal{C}^M, \\ + Recall(B, \hat{B}, C, \hat{C}) = \frac{1}{N} \sum\limits_{i=1}^N h_{B,C}(\hat{B}_i, \hat{C}_i) \\ + Precision(B, \hat{B}, C, \hat{C}) = \frac{1}{M} \sum\limits_{i=1}^M h_{B,C}(\hat{B}_i, \hat{C}_i) \\ + meanIoU(B, \hat{B}) = \frac{1}{M} \sum\limits_{i=1}^M \max\limits_{j \in [1, N]} IoU(\hat{B}_i, B_j) + + with the function :math:`IoU(x, y)` being the Intersection over Union between bounding boxes :math:`x` and + :math:`y`, and the function :math:`h_{B, C}` defined as: + + .. math:: + \forall (b, c) \in \mathcal{B} \times \mathcal{C}, + h_{B,C}(b, c) = \left\{ + \begin{array}{ll} + 1 & \mbox{if } b\mbox{ has been assigned to a given }B_j\mbox{ with an } \\ + & IoU \geq 0.5 \mbox{ and that for this assignment, } c = C_j\\ + 0 & \mbox{otherwise.} + \end{array} + \right. + + where :math:`\mathcal{B}` is the set of possible bounding boxes, + :math:`\mathcal{C}` is the set of possible class indices, + :math:`N` (number of ground truths) and :math:`M` (number of predictions) are strictly positive integers. 
+ + Example:: + >>> import numpy as np + >>> from doctr.utils import DetectionMetric + >>> metric = DetectionMetric(iou_thresh=0.5) + >>> metric.update(np.asarray([[0, 0, 100, 100]]), np.asarray([[0, 0, 70, 70], [110, 95, 200, 150]]), + np.zeros(1, dtype=np.int64), np.array([0, 1], dtype=np.int64)) + >>> metric.summary() + + Args: + iou_thresh: minimum IoU to consider a pair of prediction and ground truth as a match + use_polygons: if set to True, predictions and targets will be expected to have rotated format + mask_shape: if use_polygons is True, describes the spatial shape of the image used + use_broadcasting: if use_polygons is True, use broadcasting for IoU computation by consuming more memory + """ + + def __init__( + self, + iou_thresh: float = 0.5, + use_polygons: bool = False, + mask_shape: Tuple[int, int] = (1024, 1024), + use_broadcasting: bool = True, + ) -> None: + self.iou_thresh = iou_thresh + self.use_polygons = use_polygons + self.mask_shape = mask_shape + self.use_broadcasting = use_broadcasting + self.reset() + + def update( + self, + gt_boxes: np.ndarray, + pred_boxes: np.ndarray, + gt_labels: np.ndarray, + pred_labels: np.ndarray, + ) -> None: + """Updates the metric + + Args: + gt_boxes: a set of relative bounding boxes either of shape (N, 4) or (N, 5) if they are rotated ones + pred_boxes: a set of relative bounding boxes either of shape (M, 4) or (M, 5) if they are rotated ones + gt_labels: an array of class indices of shape (N,) + pred_labels: an array of class indices of shape (M,) + """ + + if gt_boxes.shape[0] != gt_labels.shape[0] or pred_boxes.shape[0] != pred_labels.shape[0]: + raise AssertionError("there should be the same number of boxes and string both for the ground truth " + "and the predictions") + + # Compute IoU + if pred_boxes.shape[0] > 0: + if self.use_polygons: + iou_mat = polygon_iou(gt_boxes, pred_boxes, self.mask_shape, self.use_broadcasting) + else: + iou_mat = box_iou(gt_boxes, pred_boxes) + + self.tot_iou += float(iou_mat.max(axis=0).sum()) + + # Assign pairs + gt_indices, pred_indices = linear_sum_assignment(-iou_mat) + is_kept = iou_mat[gt_indices, pred_indices] >= self.iou_thresh + # Category comparison + self.num_matches += int((gt_labels[gt_indices[is_kept]] == pred_labels[pred_indices[is_kept]]).sum()) + + self.num_gts += gt_boxes.shape[0] + self.num_preds += pred_boxes.shape[0] + + def summary(self) -> Tuple[Optional[float], Optional[float], Optional[float]]: + """Computes the aggregated metrics + + Returns: + a tuple with the recall & precision for each class prediction and the mean IoU + """ + + # Recall + recall = self.num_matches / self.num_gts if self.num_gts > 0 else None + + # Precision + precision = self.num_matches / self.num_preds if self.num_preds > 0 else None + + # mean IoU (overall detected boxes) + mean_iou = self.tot_iou / self.num_preds if self.num_preds > 0 else None + + return recall, precision, mean_iou + + def reset(self) -> None: + self.num_gts = 0 + self.num_preds = 0 + self.tot_iou = 0. + self.num_matches = 0 diff --git a/doctr/utils/multithreading.py b/doctr/utils/multithreading.py new file mode 100644 index 0000000000..51af0c75af --- /dev/null +++ b/doctr/utils/multithreading.py @@ -0,0 +1,39 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. 
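Since `box_iou` and `nms` carry no usage example in their docstrings, here is a quick sketch on toy boxes (values chosen by hand, not taken from the library):

```python
# Toy example for the standalone helpers defined above.
import numpy as np
from doctr.utils.metrics import box_iou, nms

gts = np.array([[0., 0., 1., 1.]])
preds = np.array([[0., 0., .5, 1.], [2., 2., 3., 3.]])
print(box_iou(gts, preds))  # IoU of 0.5 with the first prediction, 0 with the second

# (xmin, ymin, xmax, ymax, score): the two overlapping boxes are reduced to the highest-scoring one
scored = np.array([[0., 0., 1., 1., .9], [0., 0., .95, 1., .8], [2., 2., 3., 3., .7]])
print(nms(scored, thresh=.5))  # -> [0, 2]
```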
+ + +import multiprocessing as mp +from multiprocessing.pool import ThreadPool +from typing import Any, Callable, Iterable, Optional + +__all__ = ['multithread_exec'] + + +def multithread_exec(func: Callable[[Any], Any], seq: Iterable[Any], threads: Optional[int] = None) -> Iterable[Any]: + """Execute a given function in parallel for each element of a given sequence + + Example:: + >>> from doctr.utils.multithreading import multithread_exec + >>> entries = [1, 4, 8] + >>> results = multithread_exec(lambda x: x ** 2, entries) + + Args: + func: function to be executed on each element of the iterable + seq: iterable + threads: number of workers to be used for multiprocessing + + Returns: + iterable of the function's results using the iterable as inputs + """ + + threads = threads if isinstance(threads, int) else min(16, mp.cpu_count()) + # Single-thread + if threads < 2: + results = map(func, seq) + # Multi-threading + else: + with ThreadPool(threads) as tp: + results = tp.map(func, seq) # type: ignore[assignment] + return results diff --git a/doctr/utils/repr.py b/doctr/utils/repr.py new file mode 100644 index 0000000000..f0876dcc8f --- /dev/null +++ b/doctr/utils/repr.py @@ -0,0 +1,58 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + +# Adapted from https://github.com/pytorch/torch/blob/master/torch/nn/modules/module.py + +__all__ = ['NestedObject'] + + +def _addindent(s_, num_spaces): + s = s_.split('\n') + # don't do anything for single-line stuff + if len(s) == 1: + return s_ + first = s.pop(0) + s = [(num_spaces * ' ') + line for line in s] + s = '\n'.join(s) + s = first + '\n' + s + return s + + +class NestedObject: + def extra_repr(self) -> str: + return '' + + def __repr__(self): + # We treat the extra repr like the sub-object, one item per line + extra_lines = [] + extra_repr = self.extra_repr() + # empty string will be split into list [''] + if extra_repr: + extra_lines = extra_repr.split('\n') + child_lines = [] + if hasattr(self, '_children_names'): + for key in self._children_names: + child = getattr(self, key) + if isinstance(child, list) and len(child) > 0: + child_str = ",\n".join([repr(subchild) for subchild in child]) + if len(child) > 1: + child_str = _addindent(f"\n{child_str},", 2) + '\n' + child_str = f"[{child_str}]" + else: + child_str = repr(child) + child_str = _addindent(child_str, 2) + child_lines.append('(' + key + '): ' + child_str) + lines = extra_lines + child_lines + + main_str = self.__class__.__name__ + '(' + if lines: + # simple one-liner info, which most builtin Modules will use + if len(extra_lines) == 1 and not child_lines: + main_str += extra_lines[0] + else: + main_str += '\n ' + '\n '.join(lines) + '\n' + + main_str += ')' + return main_str diff --git a/doctr/utils/visualization.py b/doctr/utils/visualization.py new file mode 100644 index 0000000000..85760c2b31 --- /dev/null +++ b/doctr/utils/visualization.py @@ -0,0 +1,338 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. 
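A tiny sketch of how `NestedObject` builds its nested repr; the `Parent`/`Child` classes are hypothetical and only exist for illustration:

```python
# Hypothetical classes: children are declared through the `_children_names` attribute.
from doctr.utils.repr import NestedObject

class Child(NestedObject):
    def extra_repr(self) -> str:
        return "kernel=3"

class Parent(NestedObject):
    _children_names = ['child']

    def __init__(self) -> None:
        self.child = Child()

print(Parent())
# Parent(
#   (child): Child(kernel=3)
# )
```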
+ +from copy import deepcopy +from typing import Any, Dict, List, Optional, Tuple, Union + +import cv2 +import matplotlib.patches as patches +import matplotlib.pyplot as plt +import mplcursors +import numpy as np +from matplotlib.figure import Figure +from PIL import Image, ImageDraw +from unidecode import unidecode + +from .common_types import BoundingBox, Polygon4P +from .fonts import get_font + +__all__ = ['visualize_page', 'synthesize_page', 'draw_boxes'] + + +def rect_patch( + geometry: BoundingBox, + page_dimensions: Tuple[int, int], + label: Optional[str] = None, + color: Tuple[float, float, float] = (0, 0, 0), + alpha: float = 0.3, + linewidth: int = 2, + fill: bool = True, + preserve_aspect_ratio: bool = False +) -> patches.Rectangle: + """Create a matplotlib rectangular patch for the element + + Args: + geometry: bounding box of the element + page_dimensions: dimensions of the Page in format (height, width) + label: label to display when hovered + color: color to draw box + alpha: opacity parameter to fill the boxes, 0 = transparent + linewidth: line width + fill: whether the patch should be filled + preserve_aspect_ratio: pass True if you passed True to the predictor + + Returns: + a rectangular Patch + """ + + if len(geometry) != 2 or any(not isinstance(elt, tuple) or len(elt) != 2 for elt in geometry): + raise ValueError("invalid geometry format") + + # Unpack + height, width = page_dimensions + (xmin, ymin), (xmax, ymax) = geometry + # Switch to absolute coords + if preserve_aspect_ratio: + width = height = max(height, width) + xmin, w = xmin * width, (xmax - xmin) * width + ymin, h = ymin * height, (ymax - ymin) * height + + return patches.Rectangle( + (xmin, ymin), + w, + h, + fill=fill, + linewidth=linewidth, + edgecolor=(*color, alpha), + facecolor=(*color, alpha), + label=label, + ) + + +def polygon_patch( + geometry: np.ndarray, + page_dimensions: Tuple[int, int], + label: Optional[str] = None, + color: Tuple[float, float, float] = (0, 0, 0), + alpha: float = 0.3, + linewidth: int = 2, + fill: bool = True, + preserve_aspect_ratio: bool = False +) -> patches.Polygon: + """Create a matplotlib polygon patch for the element + + Args: + geometry: bounding box of the element + page_dimensions: dimensions of the Page in format (height, width) + label: label to display when hovered + color: color to draw box + alpha: opacity parameter to fill the boxes, 0 = transparent + linewidth: line width + fill: whether the patch should be filled + preserve_aspect_ratio: pass True if you passed True to the predictor + + Returns: + a polygon Patch + """ + + if not geometry.shape == (4, 2): + raise ValueError("invalid geometry format") + + # Unpack + height, width = page_dimensions + geometry[:, 0] = geometry[:, 0] * (max(width, height) if preserve_aspect_ratio else width) + geometry[:, 1] = geometry[:, 1] * (max(width, height) if preserve_aspect_ratio else height) + + return patches.Polygon( + geometry, + fill=fill, + linewidth=linewidth, + edgecolor=(*color, alpha), + facecolor=(*color, alpha), + label=label, + ) + + +def create_obj_patch( + geometry: Union[BoundingBox, Polygon4P, np.ndarray], + page_dimensions: Tuple[int, int], + **kwargs: Any, +) -> patches.Patch: + """Create a matplotlib patch for the element + + Args: + geometry: bounding box (straight or rotated) of the element + page_dimensions: dimensions of the page in format (height, width) + + Returns: + a matplotlib Patch + """ + if isinstance(geometry, tuple): + if len(geometry) == 2: # straight word BB (2 pts) + return 
rect_patch(geometry, page_dimensions, **kwargs) # type: ignore[arg-type] + elif len(geometry) == 4: # rotated word BB (4 pts) + return polygon_patch(np.asarray(geometry), page_dimensions, **kwargs) # type: ignore[arg-type] + elif isinstance(geometry, np.ndarray) and geometry.shape == (4, 2): # rotated line + return polygon_patch(geometry, page_dimensions, **kwargs) # type: ignore[arg-type] + raise ValueError("invalid geometry format") + + +def visualize_page( + page: Dict[str, Any], + image: np.ndarray, + words_only: bool = True, + display_artefacts: bool = True, + scale: float = 10, + interactive: bool = True, + add_labels: bool = True, + **kwargs: Any, +) -> Figure: + """Visualize a full page with predicted blocks, lines and words + + Example:: + >>> import numpy as np + >>> import matplotlib.pyplot as plt + >>> from doctr.utils.visualization import visualize_page + >>> from doctr.models import ocr_db_crnn + >>> model = ocr_db_crnn(pretrained=True) + >>> input_page = (255 * np.random.rand(600, 800, 3)).astype(np.uint8) + >>> out = model([[input_page]]) + >>> visualize_page(out[0].pages[0].export(), input_page) + >>> plt.show() + + Args: + page: the exported Page of a Document + image: np array of the page, needs to have the same shape than page['dimensions'] + words_only: whether only words should be displayed + display_artefacts: whether artefacts should be displayed + scale: figsize of the largest windows side + interactive: whether the plot should be interactive + add_labels: for static plot, adds text labels on top of bounding box + """ + # Get proper scale and aspect ratio + h, w = image.shape[:2] + size = (scale * w / h, scale) if h > w else (scale, h / w * scale) + fig, ax = plt.subplots(figsize=size) + # Display the image + ax.imshow(image) + # hide both axis + ax.axis('off') + + if interactive: + artists: List[patches.Patch] = [] # instantiate an empty list of patches (to be drawn on the page) + + for block in page['blocks']: + if not words_only: + rect = create_obj_patch(block['geometry'], page['dimensions'], + label='block', color=(0, 1, 0), linewidth=1, **kwargs) + # add patch on figure + ax.add_patch(rect) + if interactive: + # add patch to cursor's artists + artists.append(rect) + + for line in block['lines']: + if not words_only: + rect = create_obj_patch(line['geometry'], page['dimensions'], + label='line', color=(1, 0, 0), linewidth=1, **kwargs) + ax.add_patch(rect) + if interactive: + artists.append(rect) + + for word in line['words']: + rect = create_obj_patch(word['geometry'], page['dimensions'], + label=f"{word['value']} (confidence: {word['confidence']:.2%})", + color=(0, 0, 1), **kwargs) + ax.add_patch(rect) + if interactive: + artists.append(rect) + elif add_labels: + if len(word['geometry']) == 5: + text_loc = ( + int(page['dimensions'][1] * (word['geometry'][0] - word['geometry'][2] / 2)), + int(page['dimensions'][0] * (word['geometry'][1] - word['geometry'][3] / 2)) + ) + else: + text_loc = ( + int(page['dimensions'][1] * word['geometry'][0][0]), + int(page['dimensions'][0] * word['geometry'][0][1]) + ) + ax.text( + *text_loc, + word['value'], + size=10, + alpha=0.5, + color=(0, 0, 1), + ) + + if display_artefacts: + for artefact in block['artefacts']: + rect = create_obj_patch( + artefact['geometry'], + page['dimensions'], + label='artefact', + color=(0.5, 0.5, 0.5), + linewidth=1, + **kwargs + ) + ax.add_patch(rect) + if interactive: + artists.append(rect) + + if interactive: + # Create mlp Cursor to hover patches in artists + mplcursors.Cursor(artists, 
hover=2).connect("add", lambda sel: sel.annotation.set_text(sel.artist.get_label())) + fig.tight_layout(pad=0.) + + return fig + + +def synthesize_page( + page: Dict[str, Any], + draw_proba: bool = False, + font_size: int = 13, + font_family: Optional[str] = None, +) -> np.ndarray: + """Draw a the content of the element page (OCR response) on a blank page. + + Args: + page: exported Page object to represent + draw_proba: if True, draw words in colors to represent confidence. Blue: p=1, red: p=0 + font_size: size of the font, default font = 13 + font_family: family of the font + + Return: + the synthesized page + """ + + # Draw template + h, w = page["dimensions"] + response = 255 * np.ones((h, w, 3), dtype=np.int32) + + # Draw each word + for block in page["blocks"]: + for line in block["lines"]: + for word in line["words"]: + # Get aboslute word geometry + (xmin, ymin), (xmax, ymax) = word["geometry"] + xmin, xmax = int(round(w * xmin)), int(round(w * xmax)) + ymin, ymax = int(round(h * ymin)), int(round(h * ymax)) + + # White drawing context adapted to font size, 0.75 factor to convert pts --> pix + font = get_font(font_family, int(0.75 * (ymax - ymin))) + img = Image.new('RGB', (xmax - xmin, ymax - ymin), color=(255, 255, 255)) + d = ImageDraw.Draw(img) + # Draw in black the value of the word + try: + d.text((0, 0), word["value"], font=font, fill=(0, 0, 0)) + except UnicodeEncodeError: + # When character cannot be encoded, use its unidecode version + d.text((0, 0), unidecode(word["value"]), font=font, fill=(0, 0, 0)) + + # Colorize if draw_proba + if draw_proba: + p = int(255 * word["confidence"]) + mask = np.where(np.array(img) == 0, 1, 0) + proba = np.array([255 - p, 0, p]) + color = mask * proba[np.newaxis, np.newaxis, :] + white_mask = 255 * (1 - mask) + img = color + white_mask + + # Write to response page + response[ymin:ymax, xmin:xmax, :] = np.array(img) + + return response + + +def draw_boxes( + boxes: np.ndarray, + image: np.ndarray, + color: Optional[Tuple[int, int, int]] = None, + **kwargs +) -> None: + """Draw an array of relative straight boxes on an image + + Args: + boxes: array of relative boxes, of shape (*, 4) + image: np array, float32 or uint8 + color: color to use for bounding box edges + """ + h, w = image.shape[:2] + # Convert boxes to absolute coords + _boxes = deepcopy(boxes) + _boxes[:, [0, 2]] *= w + _boxes[:, [1, 3]] *= h + _boxes = _boxes.astype(np.int32) + for box in _boxes.tolist(): + xmin, ymin, xmax, ymax = box + image = cv2.rectangle( + image, + (xmin, ymin), + (xmax, ymax), + color=color if isinstance(color, tuple) else (0, 0, 255), + thickness=2 + ) + plt.imshow(image) + plt.plot(**kwargs) diff --git a/mypy.ini b/mypy.ini new file mode 100644 index 0000000000..f8b02f70ca --- /dev/null +++ b/mypy.ini @@ -0,0 +1,77 @@ +[mypy] + +files = doctr/*.py +show_error_codes = True +pretty = True + +[mypy-numpy.*] + +ignore_missing_imports = True + +[mypy-tensorflow.*] + +ignore_missing_imports = True + +[mypy-fitz.*] + +ignore_missing_imports = True + +[mypy-cv2.*] + +ignore_missing_imports = True + +[mypy-shapely.*] + +ignore_missing_imports = True + +[mypy-pyclipper.*] + +ignore_missing_imports = True + +[mypy-scipy.*] + +ignore_missing_imports = True + +[mypy-rapidfuzz.*] + +ignore_missing_imports = True + +[mypy-matplotlib.*] + +ignore_missing_imports = True + +[mypy-mplcursors.*] + +ignore_missing_imports = True + +[mypy-weasyprint.*] + +ignore_missing_imports = True + +[mypy-torchvision.*] + +ignore_missing_imports = True + +[mypy-torch.*] + 
+ignore_missing_imports = True + +[mypy-PIL.*] + +ignore_missing_imports = True + +[mypy-tqdm.*] + +ignore_missing_imports = True + +[mypy-tensorflow_addons.*] + +ignore_missing_imports = True + +[mypy-defusedxml.*] + +ignore_missing_imports = True + +[mypy-h5py.*] + +ignore_missing_imports = True diff --git a/notebooks/README.md b/notebooks/README.md new file mode 100644 index 0000000000..ea43ac0f39 --- /dev/null +++ b/notebooks/README.md @@ -0,0 +1,9 @@ +# docTR Notebooks + +Here are some notebooks compiled for users to better leverage the library capabilities: + +| Notebook | Description | | +|:----------|:-------------|------:| +| [Quicktour](https://github.com/mindee/notebooks/blob/main/doctr/quicktour.ipynb) | A presentation of the main features of docTR | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/mindee/notebooks/blob/main/doctr/quicktour.ipynb) | +| [Export as PDF/A](https://github.com/mindee/notebooks/blob/main/doctr/export_as_pdfa.ipynb) | Produce searchable PDFs from docTR results | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/mindee/notebooks/blob/main/doctr/export_as_pdfa.ipynb) | +[Artefact detection](https://github.com/mindee/notebooks/blob/main/doctr/artefact_detection.ipynb) | Object detection for artefacts in documents | [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/mindee/notebooks/blob/main/doctr/artefact_detection.ipynb) | diff --git a/references/classification/README.md b/references/classification/README.md new file mode 100644 index 0000000000..6f612aa60d --- /dev/null +++ b/references/classification/README.md @@ -0,0 +1,34 @@ +# Character classification + +The sample training script was made to train a character classification model with docTR. + +## Setup + +First, you need to install `doctr` (with pip, for instance) + +```shell +pip install -e . --upgrade +pip install -r references/requirements.txt +``` + +## Usage + +You can start your training in TensorFlow: + +```shell +python references/classification/train_tensorflow.py mobilenet_v3_large --epochs 5 +``` +or PyTorch: + +```shell +python references/classification/train_pytorch.py mobilenet_v3_large --epochs 5 --device 0 +``` + + +## Advanced options + +Feel free to inspect the multiple script option to customize your training to your own needs! 
+ +```python +python references/classification/train_tensorflow.py --help +``` diff --git a/references/classification/latency.csv b/references/classification/latency.csv new file mode 100644 index 0000000000..b5b7fcdb1e --- /dev/null +++ b/references/classification/latency.csv @@ -0,0 +1,31 @@ +arch,input_shape,framework,hardware,mean,std +mobilenet_v3_small,"(32, 32)",pytorch,cpu,16.53,0.4 +mobilenet_v3_small,"(32, 32)",pytorch,gpu,4.48,0.28 +mobilenet_v3_small_r,"(32, 32)",pytorch,cpu,12.36,0.34 +mobilenet_v3_small_r,"(32, 32)",pytorch,gpu,4.54,0.17 +mobilenet_v3_large,"(32, 32)",pytorch,cpu,28.63,0.59 +mobilenet_v3_large,"(32, 32)",pytorch,gpu,5.2,0.24 +mobilenet_v3_large_r,"(32, 32)",pytorch,cpu,22.91,0.65 +mobilenet_v3_large_r,"(32, 32)",pytorch,gpu,5.72,0.19 +vgg16_bn_r,"(32, 32)",pytorch,cpu,160.66,1.89 +vgg16_bn_r,"(32, 32)",pytorch,gpu,9.2,1.54 +resnet31,"(32, 32)",pytorch,cpu,, +resnet31,"(32, 32)",pytorch,gpu,25.1,0.94 +magc_resnet31,"(32, 32)",pytorch,cpu,, +magc_resnet31,"(32, 32)",pytorch,gpu,25.82,0.43 +mobilenet_v3_small,"(32, 32)",tensorflow,cpu,85.35,1.06 +mobilenet_v3_small,"(32, 32)",tensorflow,gpu,30.61,0.54 +mobilenet_v3_small_r,"(32, 32)",tensorflow,cpu,243.89,2.04 +mobilenet_v3_small_r,"(32, 32)",tensorflow,gpu,28.19,0.67 +mobilenet_v3_large,"(32, 32)",tensorflow,cpu,184.7,1.63 +mobilenet_v3_large,"(32, 32)",tensorflow,gpu,34.59,0.73 +mobilenet_v3_large_r,"(32, 32)",tensorflow,cpu,437.52,4.55 +mobilenet_v3_large_r,"(32, 32)",tensorflow,gpu,32.59,0.71 +resnet18,"(32, 32)",tensorflow,cpu,27.96,0.43 +resnet18,"(32, 32)",tensorflow,gpu,12.35,0.39 +vgg16_bn_r,"(32, 32)",tensorflow,cpu,182.68,18.75 +vgg16_bn_r,"(32, 32)",tensorflow,gpu,9.04,1.1 +resnet31,"(32, 32)",tensorflow,cpu,, +resnet31,"(32, 32)",tensorflow,gpu,28.35,0.84 +magc_resnet31,"(32, 32)",tensorflow,cpu,, +magc_resnet31,"(32, 32)",tensorflow,gpu,32.71,2.66 diff --git a/references/classification/latency_pytorch.py b/references/classification/latency_pytorch.py new file mode 100644 index 0000000000..ac0b25e5be --- /dev/null +++ b/references/classification/latency_pytorch.py @@ -0,0 +1,64 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. 
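The latency table above is easier to scan with a small helper; this snippet is purely illustrative and not part of the patch:

```python
# Illustrative only: print each benchmarked configuration from references/classification/latency.csv.
import csv

with open("references/classification/latency.csv") as f:
    for row in csv.DictReader(f):
        if row["mean"]:  # a few CPU runs were not benchmarked and are left empty
            print(f"{row['arch']:>20} | {row['framework']:>10} | {row['hardware']:>3} | {row['mean']} ms")
```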
+ +""" +Image classification latency benchmark +""" + +import argparse +import os +import time + +import numpy as np +import torch + +os.environ['USE_TORCH'] = '1' + +from doctr.models import classification + + +@torch.no_grad() +def main(args): + + device = torch.device("cuda:0" if args.gpu else "cpu") + + # Pretrained imagenet model + model = classification.__dict__[args.arch]( + pretrained=args.pretrained, + ).eval().to(device=device) + + # Input + img_tensor = torch.rand((args.batch_size, 3, args.size, args.size)).to(device=device) + + # Warmup + for _ in range(10): + _ = model(img_tensor) + + timings = [] + + # Evaluation runs + for _ in range(args.it): + start_ts = time.perf_counter() + _ = model(img_tensor) + timings.append(time.perf_counter() - start_ts) + + _timings = np.array(timings) + print(f"{args.arch} ({args.it} runs on ({args.size}, {args.size}) inputs in batches of {args.batch_size})") + print(f"mean {1000 * _timings.mean():.2f}ms, std {1000 * _timings.std():.2f}ms") + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='docTR latency benchmark for image classification (PyTorch)', + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument("arch", type=str, help="Architecture to use") + parser.add_argument("--size", type=int, default=32, help="The image input size") + parser.add_argument("--batch-size", "-b", type=int, default=64, help="The batch_size") + parser.add_argument("--gpu", dest="gpu", help='Should the benchmark be performed on GPU', action="store_true") + parser.add_argument("--it", type=int, default=100, help="Number of iterations to run") + parser.add_argument("--pretrained", dest="pretrained", help="Use pre-trained models from the modelzoo", + action="store_true") + args = parser.parse_args() + + main(args) diff --git a/references/classification/latency_tensorflow.py b/references/classification/latency_tensorflow.py new file mode 100644 index 0000000000..3130dde701 --- /dev/null +++ b/references/classification/latency_tensorflow.py @@ -0,0 +1,72 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + +""" +Image classification latency benchmark +""" + +import argparse +import os +import time + +import numpy as np +import tensorflow as tf + +os.environ['USE_TF'] = '1' +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" + +from doctr.models import classification + + +def main(args): + + if args.gpu: + gpu_devices = tf.config.experimental.list_physical_devices('GPU') + if any(gpu_devices): + tf.config.experimental.set_memory_growth(gpu_devices[0], True) + else: + raise AssertionError("TensorFlow cannot access your GPU. 
Please investigate!") + else: + os.environ['CUDA_VISIBLE_DEVICES'] = "" + + # Pretrained imagenet model + model = classification.__dict__[args.arch]( + pretrained=args.pretrained, + input_shape=(args.size, args.size, 3), + ) + + # Input + img_tensor = tf.random.uniform(shape=[args.batch_size, args.size, args.size, 3], maxval=1, dtype=tf.float32) + + # Warmup + for _ in range(10): + _ = model(img_tensor, training=False) + + timings = [] + + # Evaluation runs + for _ in range(args.it): + start_ts = time.perf_counter() + _ = model(img_tensor, training=False) + timings.append(time.perf_counter() - start_ts) + + _timings = np.array(timings) + print(f"{args.arch} ({args.it} runs on ({args.size}, {args.size}) inputs)") + print(f"mean {1000 * _timings.mean():.2f}ms, std {1000 * _timings.std():.2f}ms") + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='docTR latency benchmark for imag classification (TensorFlow)', + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument("arch", type=str, help="Architecture to use") + parser.add_argument("--size", type=int, default=32, help="The image input size") + parser.add_argument("--batch-size", "-b", type=int, default=64, help="The batch_size") + parser.add_argument("--gpu", dest="gpu", help='Should the benchmark be performed on GPU', action="store_true") + parser.add_argument("--it", type=int, default=100, help="Number of iterations to run") + parser.add_argument("--pretrained", dest="pretrained", help="Use pre-trained models from the modelzoo", + action="store_true") + args = parser.parse_args() + + main(args) diff --git a/references/classification/train_pytorch.py b/references/classification/train_pytorch.py new file mode 100644 index 0000000000..9336cdcd16 --- /dev/null +++ b/references/classification/train_pytorch.py @@ -0,0 +1,391 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + +import os + +os.environ['USE_TORCH'] = '1' + +import datetime +import logging +import multiprocessing as mp +import time + +import numpy as np +import torch +import wandb +from fastprogress.fastprogress import master_bar, progress_bar +from torch.nn.functional import cross_entropy +from torch.optim.lr_scheduler import CosineAnnealingLR, MultiplicativeLR, OneCycleLR +from torch.utils.data import DataLoader, RandomSampler, SequentialSampler +from torchvision.transforms import (ColorJitter, Compose, GaussianBlur, Grayscale, InterpolationMode, Normalize, + RandomRotation) + +from doctr import transforms as T +from doctr.datasets import VOCABS, CharacterGenerator +from doctr.models import classification +from utils import plot_recorder, plot_samples + + +def record_lr( + model: torch.nn.Module, + train_loader: DataLoader, + batch_transforms, + optimizer, + start_lr: float = 1e-7, + end_lr: float = 1, + num_it: int = 100, + amp: bool = False, +): + """Gridsearch the optimal learning rate for the training. 
+ Adapted from https://github.com/frgfm/Holocron/blob/master/holocron/trainer/core.py + """ + + if num_it > len(train_loader): + raise ValueError("the value of `num_it` needs to be lower than the number of available batches") + + model = model.train() + # Update param groups & LR + optimizer.defaults['lr'] = start_lr + for pgroup in optimizer.param_groups: + pgroup['lr'] = start_lr + + gamma = (end_lr / start_lr) ** (1 / (num_it - 1)) + scheduler = MultiplicativeLR(optimizer, lambda step: gamma) + + lr_recorder = [start_lr * gamma ** idx for idx in range(num_it)] + loss_recorder = [] + + if amp: + scaler = torch.cuda.amp.GradScaler() + + for batch_idx, (images, targets) in enumerate(train_loader): + if torch.cuda.is_available(): + images = images.cuda() + targets = targets.cuda() + + images = batch_transforms(images) + + # Forward, Backward & update + optimizer.zero_grad() + if amp: + with torch.cuda.amp.autocast(): + out = model(images) + train_loss = cross_entropy(out, targets) + scaler.scale(train_loss).backward() + # Update the params + scaler.step(optimizer) + scaler.update() + else: + out = model(images) + train_loss = cross_entropy(out, targets) + train_loss.backward() + optimizer.step() + # Update LR + scheduler.step() + + # Record + if not torch.isfinite(train_loss): + if batch_idx == 0: + raise ValueError("loss value is NaN or inf.") + else: + break + loss_recorder.append(train_loss.item()) + # Stop after the number of iterations + if batch_idx + 1 == num_it: + break + + return lr_recorder[:len(loss_recorder)], loss_recorder + + +def fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, mb, amp=False): + + if amp: + scaler = torch.cuda.amp.GradScaler() + + model.train() + # Iterate over the batches of the dataset + for images, targets in progress_bar(train_loader, parent=mb): + + if torch.cuda.is_available(): + images = images.cuda() + targets = targets.cuda() + + images = batch_transforms(images) + + optimizer.zero_grad() + if amp: + with torch.cuda.amp.autocast(): + out = model(images) + train_loss = cross_entropy(out, targets) + scaler.scale(train_loss).backward() + # Update the params + scaler.step(optimizer) + scaler.update() + else: + out = model(images) + train_loss = cross_entropy(out, targets) + train_loss.backward() + optimizer.step() + scheduler.step() + + mb.child.comment = f'Training loss: {train_loss.item():.6}' + + +@torch.no_grad() +def evaluate(model, val_loader, batch_transforms, amp=False): + # Model in eval mode + model.eval() + # Validation loop + val_loss, correct, samples, batch_cnt = 0, 0, 0, 0 + for images, targets in val_loader: + images = batch_transforms(images) + + if torch.cuda.is_available(): + images = images.cuda() + targets = targets.cuda() + + if amp: + with torch.cuda.amp.autocast(): + out = model(images) + loss = cross_entropy(out, targets) + else: + out = model(images) + loss = cross_entropy(out, targets) + # Compute metric + correct += (out.argmax(dim=1) == targets).sum().item() + + val_loss += loss.item() + batch_cnt += 1 + samples += images.shape[0] + + val_loss /= batch_cnt + acc = correct / samples + return val_loss, acc + + +def main(args): + + print(args) + + if not isinstance(args.workers, int): + args.workers = min(16, mp.cpu_count()) + + torch.backends.cudnn.benchmark = True + + vocab = VOCABS[args.vocab] + + fonts = args.font.split(",") + + # Load val data generator + st = time.time() + val_set = CharacterGenerator( + vocab=vocab, + num_samples=args.val_samples * len(vocab), + cache_samples=True, + 
img_transforms=Compose([ + T.Resize((args.input_size, args.input_size)), + # Ensure we have a 90% split of white-background images + T.RandomApply(T.ColorInversion(), .9), + ]), + font_family=fonts, + ) + val_loader = DataLoader( + val_set, + batch_size=args.batch_size, + drop_last=False, + num_workers=args.workers, + sampler=SequentialSampler(val_set), + pin_memory=torch.cuda.is_available(), + ) + print(f"Validation set loaded in {time.time() - st:.4}s ({len(val_set)} samples in " + f"{len(val_loader)} batches)") + + batch_transforms = Normalize(mean=(0.694, 0.695, 0.693), std=(0.299, 0.296, 0.301)) + + # Load doctr model + model = classification.__dict__[args.arch](pretrained=args.pretrained, num_classes=len(vocab)) + + # Resume weights + if isinstance(args.resume, str): + print(f"Resuming {args.resume}") + checkpoint = torch.load(args.resume, map_location='cpu') + model.load_state_dict(checkpoint) + + # GPU + if isinstance(args.device, int): + if not torch.cuda.is_available(): + raise AssertionError("PyTorch cannot access your GPU. Please investigate!") + if args.device >= torch.cuda.device_count(): + raise ValueError("Invalid device index") + # Silent default switch to GPU if available + elif torch.cuda.is_available(): + args.device = 0 + else: + logging.warning("No accessible GPU, targe device set to CPU.") + if torch.cuda.is_available(): + torch.cuda.set_device(args.device) + model = model.cuda() + + if args.test_only: + print("Running evaluation") + val_loss, acc = evaluate(model, val_loader, batch_transforms) + print(f"Validation loss: {val_loss:.6} (Acc: {acc:.2%})") + return + + st = time.time() + + # Load train data generator + train_set = CharacterGenerator( + vocab=vocab, + num_samples=args.train_samples * len(vocab), + cache_samples=True, + img_transforms=Compose([ + T.Resize((args.input_size, args.input_size)), + # Augmentations + T.RandomApply(T.ColorInversion(), .9), + # GaussianNoise + T.RandomApply(Grayscale(3), .1), + ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.02), + T.RandomApply(GaussianBlur(kernel_size=(3, 3), sigma=(0.1, 3)), .3), + RandomRotation(15, interpolation=InterpolationMode.BILINEAR), + ]), + font_family=fonts, + ) + + train_loader = DataLoader( + train_set, + batch_size=args.batch_size, + drop_last=True, + num_workers=args.workers, + sampler=RandomSampler(train_set), + pin_memory=torch.cuda.is_available(), + ) + print(f"Train set loaded in {time.time() - st:.4}s ({len(train_set)} samples in " + f"{len(train_loader)} batches)") + + if args.show_samples: + x, target = next(iter(train_loader)) + plot_samples(x, list(map(vocab.__getitem__, target))) + return + + # Optimizer + optimizer = torch.optim.Adam([p for p in model.parameters() if p.requires_grad], args.lr, + betas=(0.95, 0.99), eps=1e-6, weight_decay=args.weight_decay) + + # LR Finder + if args.find_lr: + lrs, losses = record_lr(model, train_loader, batch_transforms, optimizer, amp=args.amp) + plot_recorder(lrs, losses) + return + # Scheduler + if args.sched == 'cosine': + scheduler = CosineAnnealingLR(optimizer, args.epochs * len(train_loader), eta_min=args.lr / 25e4) + elif args.sched == 'onecycle': + scheduler = OneCycleLR(optimizer, args.lr, args.epochs * len(train_loader)) + + # Training monitoring + current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") + exp_name = f"{args.arch}_{current_time}" if args.name is None else args.name + + # W&B + if args.wb: + + run = wandb.init( + name=exp_name, + project="character-classification", + config={ + "learning_rate": 
args.lr, + "epochs": args.epochs, + "weight_decay": args.weight_decay, + "batch_size": args.batch_size, + "architecture": args.arch, + "input_size": args.input_size, + "optimizer": "adam", + "framework": "pytorch", + "vocab": args.vocab, + "scheduler": args.sched, + "pretrained": args.pretrained, + } + ) + + # Create loss queue + min_loss = np.inf + # Training loop + mb = master_bar(range(args.epochs)) + for epoch in mb: + fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, mb) + + # Validation loop at the end of each epoch + val_loss, acc = evaluate(model, val_loader, batch_transforms) + if val_loss < min_loss: + print(f"Validation loss decreased {min_loss:.6} --> {val_loss:.6}: saving state...") + torch.save(model.state_dict(), f"./{exp_name}.pt") + min_loss = val_loss + mb.write(f"Epoch {epoch + 1}/{args.epochs} - Validation loss: {val_loss:.6} (Acc: {acc:.2%})") + # W&B + if args.wb: + wandb.log({ + 'val_loss': val_loss, + 'acc': acc, + }) + + if args.wb: + run.finish() + + +def parse_args(): + import argparse + parser = argparse.ArgumentParser(description='DocTR training script for character classification (PyTorch)', + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + + parser.add_argument('arch', type=str, help='text-recognition model to train') + parser.add_argument('--name', type=str, default=None, help='Name of your training experiment') + parser.add_argument('--epochs', type=int, default=10, help='number of epochs to train the model on') + parser.add_argument('-b', '--batch_size', type=int, default=64, help='batch size for training') + parser.add_argument('--device', default=None, type=int, help='device') + parser.add_argument('--input_size', type=int, default=32, help='input size H for the model, W = H') + parser.add_argument('--lr', type=float, default=0.001, help='learning rate for the optimizer (Adam)') + parser.add_argument('--wd', '--weight-decay', default=0, type=float, help='weight decay', dest='weight_decay') + parser.add_argument('-j', '--workers', type=int, default=None, help='number of workers used for dataloading') + parser.add_argument('--resume', type=str, default=None, help='Path to your checkpoint') + parser.add_argument( + '--font', + type=str, + default="FreeMono.ttf,FreeSans.ttf,FreeSerif.ttf", + help='Font family to be used' + ) + parser.add_argument('--vocab', type=str, default="french", help='Vocab to be used for training') + parser.add_argument( + '--train-samples', + dest='train_samples', + type=int, + default=1000, + help='Multiplied by the vocab length gets you the number of training samples that will be used.' + ) + parser.add_argument( + '--val-samples', + dest='val_samples', + type=int, + default=20, + help='Multiplied by the vocab length gets you the number of validation samples that will be used.' 
+ ) + parser.add_argument("--test-only", dest='test_only', action='store_true', help="Run the validation loop") + parser.add_argument('--show-samples', dest='show_samples', action='store_true', + help='Display unormalized training samples') + parser.add_argument('--wb', dest='wb', action='store_true', + help='Log to Weights & Biases') + parser.add_argument('--pretrained', dest='pretrained', action='store_true', + help='Load pretrained parameters before starting the training') + parser.add_argument('--sched', type=str, default='cosine', help='scheduler to use') + parser.add_argument("--amp", dest="amp", help="Use Automatic Mixed Precision", action="store_true") + parser.add_argument('--find-lr', action='store_true', help='Gridsearch the optimal LR') + args = parser.parse_args() + + return args + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/references/classification/train_tensorflow.py b/references/classification/train_tensorflow.py new file mode 100644 index 0000000000..62c40822a9 --- /dev/null +++ b/references/classification/train_tensorflow.py @@ -0,0 +1,351 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + +import os + +os.environ['USE_TF'] = '1' +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" + +import datetime +import multiprocessing as mp +import time + +import numpy as np +import tensorflow as tf +import wandb +from fastprogress.fastprogress import master_bar, progress_bar +from tensorflow.keras import mixed_precision + +gpu_devices = tf.config.experimental.list_physical_devices('GPU') +if any(gpu_devices): + tf.config.experimental.set_memory_growth(gpu_devices[0], True) + +from doctr import transforms as T +from doctr.datasets import VOCABS, CharacterGenerator, DataLoader +from doctr.models import classification +from utils import plot_recorder, plot_samples + + +def record_lr( + model: tf.keras.Model, + train_loader: DataLoader, + batch_transforms, + optimizer, + start_lr: float = 1e-7, + end_lr: float = 1, + num_it: int = 100, + amp: bool = False, +): + """Gridsearch the optimal learning rate for the training. 
+ Adapted from https://github.com/frgfm/Holocron/blob/master/holocron/trainer/core.py + """ + + if num_it > len(train_loader): + raise ValueError("the value of `num_it` needs to be lower than the number of available batches") + + # Update param groups & LR + gamma = (end_lr / start_lr) ** (1 / (num_it - 1)) + optimizer.learning_rate = start_lr + + lr_recorder = [start_lr * gamma ** idx for idx in range(num_it)] + loss_recorder = [] + + for batch_idx, (images, targets) in enumerate(train_loader): + + images = batch_transforms(images) + + # Forward, Backward & update + with tf.GradientTape() as tape: + out = model(images, training=True) + train_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(targets, out) + grads = tape.gradient(train_loss, model.trainable_weights) + + if amp: + grads = optimizer.get_unscaled_gradients(grads) + optimizer.apply_gradients(zip(grads, model.trainable_weights)) + + optimizer.learning_rate = optimizer.learning_rate * gamma + + # Record + train_loss = train_loss.numpy() + if np.any(np.isnan(train_loss)): + if batch_idx == 0: + raise ValueError("loss value is NaN or inf.") + else: + break + loss_recorder.append(train_loss.mean()) + # Stop after the number of iterations + if batch_idx + 1 == num_it: + break + + return lr_recorder[:len(loss_recorder)], loss_recorder + + +def fit_one_epoch(model, train_loader, batch_transforms, optimizer, mb, amp=False): + # Iterate over the batches of the dataset + for images, targets in progress_bar(train_loader, parent=mb): + images = batch_transforms(images) + + with tf.GradientTape() as tape: + out = model(images, training=True) + train_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(targets, out) + grads = tape.gradient(train_loss, model.trainable_weights) + if amp: + grads = optimizer.get_unscaled_gradients(grads) + optimizer.apply_gradients(zip(grads, model.trainable_weights)) + + mb.child.comment = f'Training loss: {train_loss.numpy().mean():.6}' + + +def evaluate(model, val_loader, batch_transforms): + # Validation loop + val_loss, correct, samples, batch_cnt = 0, 0, 0, 0 + val_iter = iter(val_loader) + for images, targets in val_iter: + images = batch_transforms(images) + out = model(images, training=False) + loss = tf.nn.sparse_softmax_cross_entropy_with_logits(targets, out) + # Compute metric + correct += int((out.numpy().argmax(1) == targets.numpy()).sum()) + + val_loss += loss.numpy().mean() + batch_cnt += 1 + samples += images.shape[0] + + val_loss /= batch_cnt + acc = correct / samples + return val_loss, acc + + +def collate_fn(samples): + + images, targets = zip(*samples) + images = tf.stack(images, axis=0) + + return images, tf.convert_to_tensor(targets) + + +def main(args): + + print(args) + + if not isinstance(args.workers, int): + args.workers = min(16, mp.cpu_count()) + + vocab = VOCABS[args.vocab] + + fonts = args.font.split(",") + + # AMP + if args.amp: + mixed_precision.set_global_policy('mixed_float16') + + # Load val data generator + st = time.time() + val_set = CharacterGenerator( + vocab=vocab, + num_samples=args.val_samples * len(vocab), + cache_samples=True, + img_transforms=T.Compose([ + T.Resize((args.input_size, args.input_size)), + # Ensure we have a 90% split of white-background images + T.RandomApply(T.ColorInversion(), .9), + ]), + font_family=fonts, + ) + val_loader = DataLoader( + val_set, + batch_size=args.batch_size, + shuffle=False, + drop_last=False, + num_workers=args.workers, + collate_fn=collate_fn, + ) + print(f"Validation set loaded in {time.time() - st:.4}s 
({len(val_set)} samples in " + f"{val_loader.num_batches} batches)") + + # Load doctr model + model = classification.__dict__[args.arch]( + pretrained=args.pretrained, + input_shape=(args.input_size, args.input_size, 3), + num_classes=len(vocab), + include_top=True, + ) + + # Resume weights + if isinstance(args.resume, str): + model.load_weights(args.resume) + + batch_transforms = T.Compose([ + T.Normalize(mean=(0.694, 0.695, 0.693), std=(0.299, 0.296, 0.301)), + ]) + + if args.test_only: + print("Running evaluation") + val_loss, acc = evaluate(model, val_loader, batch_transforms) + print(f"Validation loss: {val_loss:.6} (Acc: {acc:.2%})") + return + + st = time.time() + + # Load train data generator + train_set = CharacterGenerator( + vocab=vocab, + num_samples=args.train_samples * len(vocab), + cache_samples=True, + img_transforms=T.Compose([ + T.Resize((args.input_size, args.input_size)), + # Augmentations + T.RandomApply(T.ColorInversion(), .9), + T.RandomApply(T.ToGray(3), .1), + T.RandomJpegQuality(60), + T.RandomSaturation(.3), + T.RandomContrast(.3), + T.RandomBrightness(.3), + # Blur + T.RandomApply(T.GaussianBlur(kernel_shape=(3, 3), std=(0.1, 3)), .3), + ]), + font_family=fonts, + ) + train_loader = DataLoader( + train_set, + batch_size=args.batch_size, + shuffle=True, + drop_last=True, + num_workers=args.workers, + collate_fn=collate_fn, + ) + print(f"Train set loaded in {time.time() - st:.4}s ({len(train_set)} samples in " + f"{train_loader.num_batches} batches)") + + if args.show_samples: + x, target = next(iter(train_loader)) + plot_samples(x, list(map(vocab.__getitem__, target))) + return + + # Optimizer + scheduler = tf.keras.optimizers.schedules.ExponentialDecay( + args.lr, + decay_steps=args.epochs * len(train_loader), + decay_rate=1 / (1e3), # final lr as a fraction of initial lr + staircase=False + ) + optimizer = tf.keras.optimizers.Adam( + learning_rate=scheduler, + beta_1=0.95, + beta_2=0.99, + epsilon=1e-6, + ) + if args.amp: + optimizer = mixed_precision.LossScaleOptimizer(optimizer) + + # LR Finder + if args.find_lr: + lrs, losses = record_lr(model, train_loader, batch_transforms, optimizer, amp=args.amp) + plot_recorder(lrs, losses) + return + + # Tensorboard to monitor training + current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") + exp_name = f"{args.arch}_{current_time}" if args.name is None else args.name + + # W&B + if args.wb: + + run = wandb.init( + name=exp_name, + project="character-classification", + config={ + "learning_rate": args.lr, + "epochs": args.epochs, + "weight_decay": 0., + "batch_size": args.batch_size, + "architecture": args.arch, + "input_size": args.input_size, + "optimizer": "adam", + "framework": "tensorflow", + "vocab": args.vocab, + "scheduler": "exp_decay", + "pretrained": args.pretrained, + } + ) + + # Create loss queue + min_loss = np.inf + + # Training loop + mb = master_bar(range(args.epochs)) + for epoch in mb: + fit_one_epoch(model, train_loader, batch_transforms, optimizer, mb, args.amp) + + # Validation loop at the end of each epoch + val_loss, acc = evaluate(model, val_loader, batch_transforms) + if val_loss < min_loss: + print(f"Validation loss decreased {min_loss:.6} --> {val_loss:.6}: saving state...") + model.save_weights(f'./{exp_name}/weights') + min_loss = val_loss + mb.write(f"Epoch {epoch + 1}/{args.epochs} - Validation loss: {val_loss:.6} (Acc: {acc:.2%})") + # W&B + if args.wb: + wandb.log({ + 'val_loss': val_loss, + 'acc': acc, + }) + + if args.wb: + run.finish() + + +def parse_args(): + import 
argparse + parser = argparse.ArgumentParser(description='DocTR training script for character classification (TensorFlow)', + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + + parser.add_argument('arch', type=str, help='text-recognition model to train') + parser.add_argument('--name', type=str, default=None, help='Name of your training experiment') + parser.add_argument('--epochs', type=int, default=10, help='number of epochs to train the model on') + parser.add_argument('-b', '--batch_size', type=int, default=64, help='batch size for training') + parser.add_argument('--input_size', type=int, default=32, help='input size H for the model, W = 4*H') + parser.add_argument('--lr', type=float, default=0.001, help='learning rate for the optimizer (Adam)') + parser.add_argument('-j', '--workers', type=int, default=None, help='number of workers used for dataloading') + parser.add_argument('--resume', type=str, default=None, help='Path to your checkpoint') + parser.add_argument( + '--font', + type=str, + default="FreeMono.ttf,FreeSans.ttf,FreeSerif.ttf", + help='Font family to be used' + ) + parser.add_argument('--vocab', type=str, default="french", help='Vocab to be used for training') + parser.add_argument( + '--train-samples', + dest='train_samples', + type=int, + default=1000, + help='Multiplied by the vocab length gets you the number of training samples that will be used.' + ) + parser.add_argument( + '--val-samples', + dest='val_samples', + type=int, + default=20, + help='Multiplied by the vocab length gets you the number of validation samples that will be used.' + ) + parser.add_argument("--test-only", dest='test_only', action='store_true', help="Run the validation loop") + parser.add_argument('--show-samples', dest='show_samples', action='store_true', + help='Display unormalized training samples') + parser.add_argument('--wb', dest='wb', action='store_true', + help='Log to Weights & Biases') + parser.add_argument('--pretrained', dest='pretrained', action='store_true', + help='Load pretrained parameters before starting the training') + parser.add_argument("--amp", dest="amp", help="Use Automatic Mixed Precision", action="store_true") + parser.add_argument('--find-lr', action='store_true', help='Gridsearch the optimal LR') + args = parser.parse_args() + + return args + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/references/classification/utils.py b/references/classification/utils.py new file mode 100644 index 0000000000..cc950cc20e --- /dev/null +++ b/references/classification/utils.py @@ -0,0 +1,73 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. 
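+
+# Plotting helpers shared by the classification reference scripts: `plot_samples` previews a grid of
+# training crops with their labels, while `plot_recorder` draws the smoothed loss curve produced by
+# the learning-rate finder on a log-scaled LR axis.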
+ +import math + +import matplotlib.pyplot as plt +import numpy as np + + +def plot_samples(images, targets): + # Unnormalize image + num_samples = min(len(images), 12) + num_cols = min(len(images), 8) + num_rows = int(math.ceil(num_samples / num_cols)) + _, axes = plt.subplots(num_rows, num_cols, figsize=(20, 5)) + for idx in range(num_samples): + img = (255 * images[idx].numpy()).round().clip(0, 255).astype(np.uint8) + if img.shape[0] == 3 and img.shape[2] != 3: + img = img.transpose(1, 2, 0) + + row_idx = idx // num_cols + col_idx = idx % num_cols + + ax = axes[row_idx] if num_rows > 1 else axes + ax = ax[col_idx] if num_cols > 1 else ax + + ax.imshow(img) + ax.set_title(targets[idx]) + # Disable axis + for ax in axes.ravel(): + ax.axis('off') + plt.show() + + +def plot_recorder(lr_recorder, loss_recorder, beta: float = 0.95, **kwargs) -> None: + """Display the results of the LR grid search. + Adapted from https://github.com/frgfm/Holocron/blob/master/holocron/trainer/core.py + + Args: + lr_recorder: list of LR values + loss_recorder: list of loss values + beta (float, optional): smoothing factor + """ + + if len(lr_recorder) != len(loss_recorder) or len(lr_recorder) == 0: + raise AssertionError("Both `lr_recorder` and `loss_recorder` should have the same length") + + # Exp moving average of loss + smoothed_losses = [] + avg_loss = 0. + for idx, loss in enumerate(loss_recorder): + avg_loss = beta * avg_loss + (1 - beta) * loss + smoothed_losses.append(avg_loss / (1 - beta ** (idx + 1))) + + # Properly rescale Y-axis + data_slice = slice( + min(len(loss_recorder) // 10, 10), + -min(len(loss_recorder) // 20, 5) if len(loss_recorder) >= 20 else len(loss_recorder) + ) + vals = np.array(smoothed_losses[data_slice]) + min_idx = vals.argmin() + max_val = vals.max() if min_idx is None else vals[:min_idx + 1].max() # type: ignore[misc] + delta = max_val - vals[min_idx] + + plt.plot(lr_recorder[data_slice], smoothed_losses[data_slice]) + plt.xscale('log') + plt.xlabel('Learning Rate') + plt.ylabel('Training loss') + plt.ylim(vals[min_idx] - 0.1 * delta, max_val + 0.2 * delta) + plt.grid(True, linestyle='--', axis='x') + plt.show(**kwargs) diff --git a/references/detection/README.md b/references/detection/README.md new file mode 100644 index 0000000000..9ae20005e6 --- /dev/null +++ b/references/detection/README.md @@ -0,0 +1,67 @@ +# Text detection + +The sample training script was made to train text detection model with docTR. + +## Setup + +First, you need to install `doctr` (with pip, for instance) + +```shell +pip install -e . --upgrade +pip install -r references/requirements.txt +``` + +## Usage + +You can start your training in TensorFlow: + +```shell +python references/detection/train_tensorflow.py path/to/your/train_set path/to/your/val_set db_resnet50 --epochs 5 +``` +or PyTorch: + +```shell +python references/detection/train_pytorch.py path/to/your/train_set path/to/your/val_set db_resnet50 --epochs 5 --device 0 +``` + +## Data format + +You need to provide both `train_path` and `val_path` arguments to start training. +Each path must lead to folder with 1 subfolder and 1 file: + +```shell +├── images +│ ├── sample_img_01.png +│ ├── sample_img_02.png +│ ├── sample_img_03.png +│ └── ... +└── labels.json +``` + +Each JSON file must be a dictionary, where the keys are the image file names and the value is a dictionary with 3 entries: `img_dimensions` (spatial shape of the image), `img_hash` (SHA256 of the image file), `polygons` (the set of 2D points forming the localization polygon). 
+The order of the points does not matter inside a polygon. Points are (x, y) absolute coordinates. + +labels.json +```shell +{ + "sample_img_01.png": { + 'img_dimensions': (900, 600), + 'img_hash': "theimagedumpmyhash", + 'polygons': [[[x1, y1], [x2, y2], [x3, y3], [x4, y4]], ...] + }, + "sample_img_02.png": { + 'img_dimensions': (900, 600), + 'img_hash': "thisisahash", + 'polygons': [[[x1, y1], [x2, y2], [x3, y3], [x4, y4]], ...] + } + ... +} +``` + +## Advanced options + +Feel free to inspect the multiple script options to customize your training to your own needs! + +```shell +python references/detection/train_tensorflow.py --help +``` diff --git a/references/detection/evaluate_pytorch.py b/references/detection/evaluate_pytorch.py new file mode 100644 index 0000000000..d928cdd395 --- /dev/null +++ b/references/detection/evaluate_pytorch.py @@ -0,0 +1,160 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + +import os + +os.environ['USE_TORCH'] = '1' + +import logging +import multiprocessing as mp +import time +from pathlib import Path + +import torch +from torch.utils.data import DataLoader, SequentialSampler +from torchvision.transforms import Normalize +from tqdm import tqdm + +from doctr import datasets +from doctr import transforms as T +from doctr.models import detection +from doctr.utils.metrics import LocalizationConfusion + + +@torch.no_grad() +def evaluate(model, val_loader, batch_transforms, val_metric, amp=False): + # Model in eval mode + model.eval() + # Reset val metric + val_metric.reset() + # Validation loop + val_loss, batch_cnt = 0, 0 + for images, targets in tqdm(val_loader): + if torch.cuda.is_available(): + images = images.cuda() + images = batch_transforms(images) + targets = [t['boxes'] for t in targets] + if amp: + with torch.cuda.amp.autocast(): + out = model(images, targets, return_boxes=True) + else: + out = model(images, targets, return_boxes=True) + # Compute metric + loc_preds = out['preds'] + for boxes_gt, boxes_pred in zip(targets, loc_preds): + # Remove scores + val_metric.update(gts=boxes_gt, preds=boxes_pred[:, :-1]) + + val_loss += out['loss'].item() + batch_cnt += 1 + + val_loss /= batch_cnt + recall, precision, mean_iou = val_metric.summary() + return val_loss, recall, precision, mean_iou + + +def main(args): + + print(args) + + if not isinstance(args.workers, int): + args.workers = min(16, mp.cpu_count()) + + torch.backends.cudnn.benchmark = True + + # Load docTR model + model = detection.__dict__[args.arch]( + pretrained=not isinstance(args.resume, str), + assume_straight_pages=not args.rotation + ).eval() + + if isinstance(args.size, int): + input_shape = (args.size, args.size) + else: + input_shape = model.cfg['input_shape'][-2:] + mean, std = model.cfg['mean'], model.cfg['std'] + + st = time.time() + ds = datasets.__dict__[args.dataset]( + train=True, + download=True, + rotated_bbox=args.rotation, + sample_transforms=T.Resize(input_shape), + ) + # Monkeypatch + subfolder = ds.root.split("/")[-2:] + ds.root = str(Path(ds.root).parent.parent) + ds.data = [(os.path.join(*subfolder, name), target) for name, target in ds.data] + _ds = datasets.__dict__[args.dataset](train=False, rotated_bbox=args.rotation) + subfolder = _ds.root.split("/")[-2:] + ds.data.extend([(os.path.join(*subfolder, name), target) for name, target in _ds.data]) + + test_loader = DataLoader( + ds, + batch_size=args.batch_size, + drop_last=False, + num_workers=args.workers, +
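+        # Deterministic evaluation order: sequential sampling, no shuffling, and no dropped batches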
sampler=SequentialSampler(ds), + pin_memory=torch.cuda.is_available(), + collate_fn=ds.collate_fn, + ) + print(f"Test set loaded in {time.time() - st:.4}s ({len(ds)} samples in " + f"{len(test_loader)} batches)") + + batch_transforms = Normalize(mean=mean, std=std) + + # Resume weights + if isinstance(args.resume, str): + print(f"Resuming {args.resume}") + checkpoint = torch.load(args.resume, map_location='cpu') + model.load_state_dict(checkpoint) + + # GPU + if isinstance(args.device, int): + if not torch.cuda.is_available(): + raise AssertionError("PyTorch cannot access your GPU. Please investigate!") + if args.device >= torch.cuda.device_count(): + raise ValueError("Invalid device index") + # Silent default switch to GPU if available + elif torch.cuda.is_available(): + args.device = 0 + else: + logging.warning("No accessible GPU, targe device set to CPU.") + if torch.cuda.is_available(): + torch.cuda.set_device(args.device) + model = model.cuda() + + # Metrics + metric = LocalizationConfusion(rotated_bbox=args.rotation, mask_shape=input_shape) + + print("Running evaluation") + val_loss, recall, precision, mean_iou = evaluate(model, test_loader, batch_transforms, metric, amp=args.amp) + print(f"Validation loss: {val_loss:.6} (Recall: {recall:.2%} | Precision: {precision:.2%} | " + f"Mean IoU: {mean_iou:.2%})") + + +def parse_args(): + import argparse + parser = argparse.ArgumentParser(description='docTR evaluation script for text detection (PyTorch)', + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + + parser.add_argument('arch', type=str, help='text-detection model to evaluate') + parser.add_argument('--dataset', type=str, default="FUNSD", help='Dataset to evaluate on') + parser.add_argument('-b', '--batch_size', type=int, default=2, help='batch size for evaluation') + parser.add_argument('--device', default=None, type=int, help='device') + parser.add_argument('--size', type=int, default=None, help='model input size, H = W') + parser.add_argument('-j', '--workers', type=int, default=None, help='number of workers used for dataloading') + parser.add_argument('--rotation', dest='rotation', action='store_true', + help='inference with rotated bbox') + parser.add_argument('--resume', type=str, default=None, help='Checkpoint to resume') + parser.add_argument("--amp", dest="amp", help="Use Automatic Mixed Precision", action="store_true") + args = parser.parse_args() + + return args + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/references/detection/evaluate_tensorflow.py b/references/detection/evaluate_tensorflow.py new file mode 100644 index 0000000000..786f2d4664 --- /dev/null +++ b/references/detection/evaluate_tensorflow.py @@ -0,0 +1,138 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. 
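+
+# Standalone evaluation for docTR text-detection models (TensorFlow): the chosen public dataset is
+# reloaded (train and test splits merged), inference runs batch by batch, and the script reports the
+# loss alongside recall, precision and mean IoU from LocalizationConfusion.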
+ +import os + +os.environ['USE_TF'] = '1' +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" + +import multiprocessing as mp +import time +from pathlib import Path + +import tensorflow as tf +from tensorflow.keras import mixed_precision +from tqdm import tqdm + +gpu_devices = tf.config.experimental.list_physical_devices('GPU') +if any(gpu_devices): + tf.config.experimental.set_memory_growth(gpu_devices[0], True) + +from doctr import datasets +from doctr import transforms as T +from doctr.datasets import DataLoader +from doctr.models import detection +from doctr.utils.metrics import LocalizationConfusion + + +def evaluate(model, val_loader, batch_transforms, val_metric): + # Reset val metric + val_metric.reset() + # Validation loop + val_loss, batch_cnt = 0, 0 + for images, targets in tqdm(val_loader): + images = batch_transforms(images) + targets = [t['boxes'] for t in targets] + out = model(images, targets, training=False, return_boxes=True) + # Compute metric + loc_preds = out['preds'] + for boxes_gt, boxes_pred in zip(targets, loc_preds): + # Remove scores + val_metric.update(gts=boxes_gt, preds=boxes_pred[:, :-1]) + + val_loss += out['loss'].numpy() + batch_cnt += 1 + + val_loss /= batch_cnt + recall, precision, mean_iou = val_metric.summary() + return val_loss, recall, precision, mean_iou + + +def main(args): + + print(args) + + if not isinstance(args.workers, int): + args.workers = min(16, mp.cpu_count()) + + # AMP + if args.amp: + mixed_precision.set_global_policy('mixed_float16') + + input_shape = (args.size, args.size, 3) if isinstance(args.size, int) else None + + # Load docTR model + model = detection.__dict__[args.arch]( + pretrained=isinstance(args.resume, str), + assume_straight_pages=not args.rotation, + input_shape=input_shape, + ) + + # Resume weights + if isinstance(args.resume, str): + print(f"Resuming {args.resume}") + model.load_weights(args.resume).expect_partial() + + input_shape = model.cfg['input_shape'] if input_shape is None else input_shape + mean, std = model.cfg['mean'], model.cfg['std'] + + st = time.time() + ds = datasets.__dict__[args.dataset]( + train=True, + download=True, + rotated_bbox=args.rotation, + sample_transforms=T.Resize(input_shape[:2]), + ) + # Monkeypatch + subfolder = ds.root.split("/")[-2:] + ds.root = str(Path(ds.root).parent.parent) + ds.data = [(os.path.join(*subfolder, name), target) for name, target in ds.data] + _ds = datasets.__dict__[args.dataset](train=False, rotated_bbox=args.rotation) + subfolder = _ds.root.split("/")[-2:] + ds.data.extend([(os.path.join(*subfolder, name), target) for name, target in _ds.data]) + + test_loader = DataLoader( + ds, + batch_size=args.batch_size, + drop_last=False, + num_workers=args.workers, + shuffle=False, + ) + print(f"Test set loaded in {time.time() - st:.4}s ({len(ds)} samples in " + f"{len(test_loader)} batches)") + + batch_transforms = T.Normalize(mean=mean, std=std) + + # Metrics + metric = LocalizationConfusion(rotated_bbox=args.rotation, mask_shape=input_shape[:2]) + + print("Running evaluation") + val_loss, recall, precision, mean_iou = evaluate(model, test_loader, batch_transforms, metric) + print(f"Validation loss: {val_loss:.6} (Recall: {recall:.2%} | Precision: {precision:.2%} | " + f"Mean IoU: {mean_iou:.2%})") + + +def parse_args(): + import argparse + parser = argparse.ArgumentParser(description='docTR evaluation script for text detection (TensorFlow)', + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + + parser.add_argument('arch', type=str, help='text-detection model to 
evaluate') + parser.add_argument('--dataset', type=str, default="FUNSD", help='Dataset to evaluate on') + parser.add_argument('-b', '--batch_size', type=int, default=2, help='batch size for evaluation') + parser.add_argument('--size', type=int, default=None, help='model input size, H = W') + parser.add_argument('-j', '--workers', type=int, default=None, help='number of workers used for dataloading') + parser.add_argument('--rotation', dest='rotation', action='store_true', + help='inference with rotated bbox') + parser.add_argument('--resume', type=str, default=None, help='Checkpoint to resume') + parser.add_argument("--amp", dest="amp", help="Use Automatic Mixed Precision", action="store_true") + args = parser.parse_args() + + return args + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/references/detection/latency.csv b/references/detection/latency.csv new file mode 100644 index 0000000000..cad1954d43 --- /dev/null +++ b/references/detection/latency.csv @@ -0,0 +1,15 @@ +arch,input_shape,framework,hardware,mean,std +linknet_resnet18,"(1024, 1024)",pytorch,cpu,473.73,31.55 +linknet_resnet18,"(1024, 1024)",pytorch,gpu,23.92,0.46 +db_resnet34,"(1024, 1024)",pytorch,cpu,955.65,153.92 +db_resnet34,"(1024, 1024)",pytorch,gpu,44.95,0.38 +db_resnet50,"(1024, 1024)",pytorch,cpu,1257.69,112.12 +db_resnet50,"(1024, 1024)",pytorch,gpu,65.11,0.3 +db_mobilenet_v3_large,"(1024, 1024)",pytorch,cpu,576.56,27.48 +db_mobilenet_v3_large,"(1024, 1024)",pytorch,gpu,40.48,0.75 +linknet_resnet18,"(1024, 1024)",tensorflow,cpu,642.86,13.24 +linknet_resnet18,"(1024, 1024)",tensorflow,gpu,31.62,1.56 +db_resnet50,"(1024, 1024)",tensorflow,cpu,1251.3,138.18 +db_resnet50,"(1024, 1024)",tensorflow,gpu,80.21,0.74 +db_mobilenet_v3_large,"(1024, 1024)",tensorflow,cpu,1641.06,144.8 +db_mobilenet_v3_large,"(1024, 1024)",tensorflow,gpu,179.02,3.4 diff --git a/references/detection/latency_pytorch.py b/references/detection/latency_pytorch.py new file mode 100644 index 0000000000..ea0d251a06 --- /dev/null +++ b/references/detection/latency_pytorch.py @@ -0,0 +1,64 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. 
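+
+# Benchmark protocol: 10 warm-up forward passes, then `--it` timed runs on a random
+# (1, 3, size, size) tensor; mean and standard deviation are reported in milliseconds.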
+ +""" +Text detection latency benchmark +""" + +import argparse +import os +import time + +import numpy as np +import torch + +os.environ['USE_TORCH'] = '1' + +from doctr.models import detection + + +@torch.no_grad() +def main(args): + + device = torch.device("cuda:0" if args.gpu else "cpu") + + # Pretrained imagenet model + model = detection.__dict__[args.arch]( + pretrained=args.pretrained, + pretrained_backbone=False + ).eval().to(device=device) + + # Input + img_tensor = torch.rand((1, 3, args.size, args.size)).to(device=device) + + # Warmup + for _ in range(10): + _ = model(img_tensor) + + timings = [] + + # Evaluation runs + for _ in range(args.it): + start_ts = time.perf_counter() + _ = model(img_tensor) + timings.append(time.perf_counter() - start_ts) + + _timings = np.array(timings) + print(f"{args.arch} ({args.it} runs on ({args.size}, {args.size}) inputs)") + print(f"mean {1000 * _timings.mean():.2f}ms, std {1000 * _timings.std():.2f}ms") + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='docTR latency benchmark for text detection (PyTorch)', + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument("arch", type=str, help="Architecture to use") + parser.add_argument("--size", type=int, default=1024, help="The image input size") + parser.add_argument("--gpu", dest="gpu", help='Should the benchmark be performed on GPU', action="store_true") + parser.add_argument("--it", type=int, default=100, help="Number of iterations to run") + parser.add_argument("--pretrained", dest="pretrained", help="Use pre-trained models from the modelzoo", + action="store_true") + args = parser.parse_args() + + main(args) diff --git a/references/detection/latency_tensorflow.py b/references/detection/latency_tensorflow.py new file mode 100644 index 0000000000..c2d228bd06 --- /dev/null +++ b/references/detection/latency_tensorflow.py @@ -0,0 +1,72 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + +""" +Text detection latency benchmark +""" + +import argparse +import os +import time + +import numpy as np +import tensorflow as tf + +os.environ['USE_TF'] = '1' +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" + +from doctr.models import detection + + +def main(args): + + if args.gpu: + gpu_devices = tf.config.experimental.list_physical_devices('GPU') + if any(gpu_devices): + tf.config.experimental.set_memory_growth(gpu_devices[0], True) + else: + raise AssertionError("TensorFlow cannot access your GPU. 
Please investigate!") + else: + os.environ['CUDA_VISIBLE_DEVICES'] = "" + + # Pretrained imagenet model + model = detection.__dict__[args.arch]( + pretrained=args.pretrained, + pretrained_backbone=False, + input_shape=(args.size, args.size, 3), + ) + + # Input + img_tensor = tf.random.uniform(shape=[1, args.size, args.size, 3], maxval=1, dtype=tf.float32) + + # Warmup + for _ in range(10): + _ = model(img_tensor, training=False) + + timings = [] + + # Evaluation runs + for _ in range(args.it): + start_ts = time.perf_counter() + _ = model(img_tensor, training=False) + timings.append(time.perf_counter() - start_ts) + + _timings = np.array(timings) + print(f"{args.arch} ({args.it} runs on ({args.size}, {args.size}) inputs)") + print(f"mean {1000 * _timings.mean():.2f}ms, std {1000 * _timings.std():.2f}ms") + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='docTR latency benchmark for text detection (TensorFlow)', + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument("arch", type=str, help="Architecture to use") + parser.add_argument("--size", type=int, default=1024, help="The image input size") + parser.add_argument("--gpu", dest="gpu", help='Should the benchmark be performed on GPU', action="store_true") + parser.add_argument("--it", type=int, default=100, help="Number of iterations to run") + parser.add_argument("--pretrained", dest="pretrained", help="Use pre-trained models from the modelzoo", + action="store_true") + args = parser.parse_args() + + main(args) diff --git a/references/detection/results.csv b/references/detection/results.csv new file mode 100644 index 0000000000..3eb6cd2da4 --- /dev/null +++ b/references/detection/results.csv @@ -0,0 +1,9 @@ +architecture,input_shape,framework,test_set,recall,precision,mean_iou +db_resnet50,"(1024, 1024)",tensorflow,funsd,0.8121,0.8665,0.6681 +db_resnet50,"(1024, 1024)",tensorflow,cord,0.9245,0.8962,0.7457 +db_mobilenet_v3_large,"(1024, 1024)",tensorflow,funsd,0.783,0.828,0.6396 +db_mobilenet_v3_large,"(1024, 1024)",tensorflow,cord,0.8098,0.6657,0.5978 +db_resnet50,"(1024, 1024)",pytorch,funsd,0.7917,0.863,0.6652 +db_mobilenet_v3_large,"(1024, 1024)",pytorch,funsd,0.8006,0.841,0.6476 +db_resnet50,"(1024, 1024)",pytorch,cord,0.9296,0.9123,0.7654 +db_mobilenet_v3_large,"(1024, 1024)",pytorch,cord,0.8053,0.6653,0.5976 diff --git a/references/detection/train_pytorch.py b/references/detection/train_pytorch.py new file mode 100644 index 0000000000..63eaca1fb2 --- /dev/null +++ b/references/detection/train_pytorch.py @@ -0,0 +1,391 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. 
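+
+# Reference training script for docTR text detection (PyTorch): builds DetectionDataset splits from
+# `images/` folders and `labels.json` files, optionally runs a learning-rate finder, then trains with
+# Adam under a cosine or one-cycle schedule, saving a checkpoint whenever the validation loss improves
+# and logging to Weights & Biases when `--wb` is passed.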
+ +import os + +os.environ['USE_TORCH'] = '1' + +import datetime +import hashlib +import logging +import multiprocessing as mp +import time + +import numpy as np +import torch +import wandb +from fastprogress.fastprogress import master_bar, progress_bar +from torch.optim.lr_scheduler import CosineAnnealingLR, MultiplicativeLR, OneCycleLR +from torch.utils.data import DataLoader, RandomSampler, SequentialSampler +from torchvision.transforms import ColorJitter, Compose, Normalize + +from doctr import transforms as T +from doctr.datasets import DetectionDataset +from doctr.models import detection +from doctr.utils.metrics import LocalizationConfusion +from utils import plot_recorder, plot_samples + + +def record_lr( + model: torch.nn.Module, + train_loader: DataLoader, + batch_transforms, + optimizer, + start_lr: float = 1e-7, + end_lr: float = 1, + num_it: int = 100, + amp: bool = False, +): + """Gridsearch the optimal learning rate for the training. + Adapted from https://github.com/frgfm/Holocron/blob/master/holocron/trainer/core.py + """ + + if num_it > len(train_loader): + raise ValueError("the value of `num_it` needs to be lower than the number of available batches") + + model = model.train() + # Update param groups & LR + optimizer.defaults['lr'] = start_lr + for pgroup in optimizer.param_groups: + pgroup['lr'] = start_lr + + gamma = (end_lr / start_lr) ** (1 / (num_it - 1)) + scheduler = MultiplicativeLR(optimizer, lambda step: gamma) + + lr_recorder = [start_lr * gamma ** idx for idx in range(num_it)] + loss_recorder = [] + + if amp: + scaler = torch.cuda.amp.GradScaler() + + for batch_idx, (images, targets) in enumerate(train_loader): + if torch.cuda.is_available(): + images = images.cuda() + + images = batch_transforms(images) + + # Forward, Backward & update + optimizer.zero_grad() + if amp: + with torch.cuda.amp.autocast(): + train_loss = model(images, targets)['loss'] + scaler.scale(train_loss).backward() + # Gradient clipping + scaler.unscale_(optimizer) + torch.nn.utils.clip_grad_norm_(model.parameters(), 5) + # Update the params + scaler.step(optimizer) + scaler.update() + else: + train_loss = model(images, targets)['loss'] + train_loss.backward() + torch.nn.utils.clip_grad_norm_(model.parameters(), 5) + optimizer.step() + # Update LR + scheduler.step() + + # Record + if not torch.isfinite(train_loss): + if batch_idx == 0: + raise ValueError("loss value is NaN or inf.") + else: + break + loss_recorder.append(train_loss.item()) + # Stop after the number of iterations + if batch_idx + 1 == num_it: + break + + return lr_recorder[:len(loss_recorder)], loss_recorder + + +def fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, mb, amp=False): + + if amp: + scaler = torch.cuda.amp.GradScaler() + + model.train() + # Iterate over the batches of the dataset + for images, targets in progress_bar(train_loader, parent=mb): + + if torch.cuda.is_available(): + images = images.cuda() + images = batch_transforms(images) + + optimizer.zero_grad() + if amp: + with torch.cuda.amp.autocast(): + train_loss = model(images, targets)['loss'] + scaler.scale(train_loss).backward() + # Gradient clipping + scaler.unscale_(optimizer) + torch.nn.utils.clip_grad_norm_(model.parameters(), 5) + # Update the params + scaler.step(optimizer) + scaler.update() + else: + train_loss = model(images, targets)['loss'] + train_loss.backward() + torch.nn.utils.clip_grad_norm_(model.parameters(), 5) + optimizer.step() + + scheduler.step() + + mb.child.comment = f'Training loss: 
{train_loss.item():.6}' + + +@torch.no_grad() +def evaluate(model, val_loader, batch_transforms, val_metric, amp=False): + # Model in eval mode + model.eval() + # Reset val metric + val_metric.reset() + # Validation loop + val_loss, batch_cnt = 0, 0 + for images, targets in val_loader: + if torch.cuda.is_available(): + images = images.cuda() + images = batch_transforms(images) + if amp: + with torch.cuda.amp.autocast(): + out = model(images, targets, return_preds=True) + else: + out = model(images, targets, return_preds=True) + # Compute metric + loc_preds = out['preds'] + for boxes_gt, boxes_pred in zip(targets, loc_preds): + # Remove scores + val_metric.update(gts=boxes_gt, preds=boxes_pred[:, :4]) + + val_loss += out['loss'].item() + batch_cnt += 1 + + val_loss /= batch_cnt + recall, precision, mean_iou = val_metric.summary() + return val_loss, recall, precision, mean_iou + + +def main(args): + + print(args) + + if not isinstance(args.workers, int): + args.workers = min(16, mp.cpu_count()) + + torch.backends.cudnn.benchmark = True + + st = time.time() + val_set = DetectionDataset( + img_folder=os.path.join(args.val_path, 'images'), + label_path=os.path.join(args.val_path, 'labels.json'), + img_transforms=T.Resize((args.input_size, args.input_size)), + use_polygons=args.rotation, + ) + val_loader = DataLoader( + val_set, + batch_size=args.batch_size, + drop_last=False, + num_workers=args.workers, + sampler=SequentialSampler(val_set), + pin_memory=torch.cuda.is_available(), + collate_fn=val_set.collate_fn, + ) + print(f"Validation set loaded in {time.time() - st:.4}s ({len(val_set)} samples in " + f"{len(val_loader)} batches)") + with open(os.path.join(args.val_path, 'labels.json'), 'rb') as f: + val_hash = hashlib.sha256(f.read()).hexdigest() + + batch_transforms = Normalize(mean=(0.798, 0.785, 0.772), std=(0.264, 0.2749, 0.287)) + + # Load doctr model + model = detection.__dict__[args.arch](pretrained=args.pretrained, assume_straight_pages=not args.rotation) + + # Resume weights + if isinstance(args.resume, str): + print(f"Resuming {args.resume}") + checkpoint = torch.load(args.resume, map_location='cpu') + model.load_state_dict(checkpoint) + + # GPU + if isinstance(args.device, int): + if not torch.cuda.is_available(): + raise AssertionError("PyTorch cannot access your GPU. 
Please investigate!") + if args.device >= torch.cuda.device_count(): + raise ValueError("Invalid device index") + # Silent default switch to GPU if available + elif torch.cuda.is_available(): + args.device = 0 + else: + logging.warning("No accessible GPU, target device set to CPU.") + if torch.cuda.is_available(): + torch.cuda.set_device(args.device) + model = model.cuda() + + # Metrics + val_metric = LocalizationConfusion(use_polygons=args.rotation, mask_shape=(args.input_size, args.input_size)) + + if args.test_only: + print("Running evaluation") + val_loss, recall, precision, mean_iou = evaluate(model, val_loader, batch_transforms, val_metric, amp=args.amp) + print(f"Validation loss: {val_loss:.6} (Recall: {recall:.2%} | Precision: {precision:.2%} | " + f"Mean IoU: {mean_iou:.2%})") + return + + st = time.time() + # Load both train and val data generators + train_set = DetectionDataset( + img_folder=os.path.join(args.train_path, 'images'), + label_path=os.path.join(args.train_path, 'labels.json'), + img_transforms=Compose( + ([T.Resize((args.input_size, args.input_size))] if not args.rotation else []) + + [ + # Augmentations + T.RandomApply(T.ColorInversion(), .1), + ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.02), + ] + ), + sample_transforms=T.SampleCompose([ + T.RandomRotate(90, expand=True), + T.ImageTransform(T.Resize((args.input_size, args.input_size))), + ]) if args.rotation else None, + use_polygons=args.rotation, + ) + + train_loader = DataLoader( + train_set, + batch_size=args.batch_size, + drop_last=True, + num_workers=args.workers, + sampler=RandomSampler(train_set), + pin_memory=torch.cuda.is_available(), + collate_fn=train_set.collate_fn, + ) + print(f"Train set loaded in {time.time() - st:.4}s ({len(train_set)} samples in " + f"{len(train_loader)} batches)") + with open(os.path.join(args.train_path, 'labels.json'), 'rb') as f: + train_hash = hashlib.sha256(f.read()).hexdigest() + + if args.show_samples: + x, target = next(iter(train_loader)) + plot_samples(x, target) + return + + # Backbone freezing + if args.freeze_backbone: + for p in model.feat_extractor.parameters(): + p.requires_grad_(False) + + # Optimizer + optimizer = torch.optim.Adam([p for p in model.parameters() if p.requires_grad], args.lr, + betas=(0.95, 0.99), eps=1e-6, weight_decay=args.weight_decay) + # LR Finder + if args.find_lr: + lrs, losses = record_lr(model, train_loader, batch_transforms, optimizer, amp=args.amp) + plot_recorder(lrs, losses) + return + # Scheduler + if args.sched == 'cosine': + scheduler = CosineAnnealingLR(optimizer, args.epochs * len(train_loader), eta_min=args.lr / 25e4) + elif args.sched == 'onecycle': + scheduler = OneCycleLR(optimizer, args.lr, args.epochs * len(train_loader)) + + # Training monitoring + current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") + exp_name = f"{args.arch}_{current_time}" if args.name is None else args.name + + # W&B + if args.wb: + + run = wandb.init( + name=exp_name, + project="text-detection", + config={ + "learning_rate": args.lr, + "epochs": args.epochs, + "weight_decay": args.weight_decay, + "batch_size": args.batch_size, + "architecture": args.arch, + "input_size": args.input_size, + "optimizer": "adam", + "framework": "pytorch", + "scheduler": args.sched, + "train_hash": train_hash, + "val_hash": val_hash, + "pretrained": args.pretrained, + "rotation": args.rotation, + "amp": args.amp, + } + ) + + # Create loss queue + min_loss = np.inf + + # Training loop + mb = master_bar(range(args.epochs)) + for epoch in mb:
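+        # One training pass over the dataset, then validation; the checkpoint is refreshed whenever the validation loss improves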
+ fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, mb, amp=args.amp) + # Validation loop at the end of each epoch + val_loss, recall, precision, mean_iou = evaluate(model, val_loader, batch_transforms, val_metric, amp=args.amp) + if val_loss < min_loss: + print(f"Validation loss decreased {min_loss:.6} --> {val_loss:.6}: saving state...") + torch.save(model.state_dict(), f"./{exp_name}.pt") + min_loss = val_loss + log_msg = f"Epoch {epoch + 1}/{args.epochs} - Validation loss: {val_loss:.6} " + if any(val is None for val in (recall, precision, mean_iou)): + log_msg += "(Undefined metric value, caused by empty GTs or predictions)" + else: + log_msg += f"(Recall: {recall:.2%} | Precision: {precision:.2%} | Mean IoU: {mean_iou:.2%})" + mb.write(log_msg) + # W&B + if args.wb: + wandb.log({ + 'val_loss': val_loss, + 'recall': recall, + 'precision': precision, + 'mean_iou': mean_iou, + }) + + if args.wb: + run.finish() + + +def parse_args(): + import argparse + parser = argparse.ArgumentParser(description='DocTR training script for text detection (PyTorch)', + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + + parser.add_argument('train_path', type=str, help='path to training data folder') + parser.add_argument('val_path', type=str, help='path to validation data folder') + parser.add_argument('arch', type=str, help='text-detection model to train') + parser.add_argument('--name', type=str, default=None, help='Name of your training experiment') + parser.add_argument('--epochs', type=int, default=10, help='number of epochs to train the model on') + parser.add_argument('-b', '--batch_size', type=int, default=2, help='batch size for training') + parser.add_argument('--device', default=None, type=int, help='device') + parser.add_argument('--input_size', type=int, default=1024, help='model input size, H = W') + parser.add_argument('--lr', type=float, default=0.001, help='learning rate for the optimizer (Adam)') + parser.add_argument('--wd', '--weight-decay', default=0, type=float, help='weight decay', dest='weight_decay') + parser.add_argument('-j', '--workers', type=int, default=None, help='number of workers used for dataloading') + parser.add_argument('--resume', type=str, default=None, help='Path to your checkpoint') + parser.add_argument("--test-only", dest='test_only', action='store_true', help="Run the validation loop") + parser.add_argument('--freeze-backbone', dest='freeze_backbone', action='store_true', + help='freeze model backbone for fine-tuning') + parser.add_argument('--show-samples', dest='show_samples', action='store_true', + help='Display unormalized training samples') + parser.add_argument('--wb', dest='wb', action='store_true', + help='Log to Weights & Biases') + parser.add_argument('--pretrained', dest='pretrained', action='store_true', + help='Load pretrained parameters before starting the training') + parser.add_argument('--rotation', dest='rotation', action='store_true', + help='train with rotated documents') + parser.add_argument('--sched', type=str, default='cosine', help='scheduler to use') + parser.add_argument("--amp", dest="amp", help="Use Automatic Mixed Precision", action="store_true") + parser.add_argument('--find-lr', action='store_true', help='Gridsearch the optimal LR') + args = parser.parse_args() + + return args + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/references/detection/train_tensorflow.py b/references/detection/train_tensorflow.py new file mode 100644 index 0000000000..5d3fa9585a --- 
/dev/null +++ b/references/detection/train_tensorflow.py @@ -0,0 +1,340 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + +import os + +os.environ['USE_TF'] = '1' +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" + +import datetime +import hashlib +import multiprocessing as mp +import time + +import numpy as np +import tensorflow as tf +import wandb +from fastprogress.fastprogress import master_bar, progress_bar +from tensorflow.keras import mixed_precision + +gpu_devices = tf.config.experimental.list_physical_devices('GPU') +if any(gpu_devices): + tf.config.experimental.set_memory_growth(gpu_devices[0], True) + +from doctr import transforms as T +from doctr.datasets import DataLoader, DetectionDataset +from doctr.models import detection +from doctr.utils.metrics import LocalizationConfusion +from utils import plot_recorder, plot_samples + + +def record_lr( + model: tf.keras.Model, + train_loader: DataLoader, + batch_transforms, + optimizer, + start_lr: float = 1e-7, + end_lr: float = 1, + num_it: int = 100, + amp: bool = False, +): + """Gridsearch the optimal learning rate for the training. + Adapted from https://github.com/frgfm/Holocron/blob/master/holocron/trainer/core.py + """ + + if num_it > len(train_loader): + raise ValueError("the value of `num_it` needs to be lower than the number of available batches") + + # Update param groups & LR + gamma = (end_lr / start_lr) ** (1 / (num_it - 1)) + optimizer.learning_rate = start_lr + + lr_recorder = [start_lr * gamma ** idx for idx in range(num_it)] + loss_recorder = [] + + for batch_idx, (images, targets) in enumerate(train_loader): + + images = batch_transforms(images) + + # Forward, Backward & update + with tf.GradientTape() as tape: + train_loss = model(images, targets, training=True)['loss'] + grads = tape.gradient(train_loss, model.trainable_weights) + + if amp: + grads = optimizer.get_unscaled_gradients(grads) + optimizer.apply_gradients(zip(grads, model.trainable_weights)) + + optimizer.learning_rate = optimizer.learning_rate * gamma + + # Record + train_loss = train_loss.numpy() + if np.any(np.isnan(train_loss)): + if batch_idx == 0: + raise ValueError("loss value is NaN or inf.") + else: + break + loss_recorder.append(train_loss.mean()) + # Stop after the number of iterations + if batch_idx + 1 == num_it: + break + + return lr_recorder[:len(loss_recorder)], loss_recorder + + +def fit_one_epoch(model, train_loader, batch_transforms, optimizer, mb, amp=False): + train_iter = iter(train_loader) + # Iterate over the batches of the dataset + for images, targets in progress_bar(train_iter, parent=mb): + + images = batch_transforms(images) + + with tf.GradientTape() as tape: + train_loss = model(images, targets, training=True)['loss'] + grads = tape.gradient(train_loss, model.trainable_weights) + if amp: + grads = optimizer.get_unscaled_gradients(grads) + optimizer.apply_gradients(zip(grads, model.trainable_weights)) + + mb.child.comment = f'Training loss: {train_loss.numpy():.6}' + + +def evaluate(model, val_loader, batch_transforms, val_metric): + # Reset val metric + val_metric.reset() + # Validation loop + val_loss, batch_cnt = 0, 0 + val_iter = iter(val_loader) + for images, targets in val_iter: + images = batch_transforms(images) + out = model(images, targets, training=False, return_preds=True) + # Compute metric + loc_preds = out['preds'] + for boxes_gt, boxes_pred in zip(targets, loc_preds): + # Remove scores + 
val_metric.update(gts=boxes_gt, preds=boxes_pred[:, :4]) + + val_loss += out['loss'].numpy() + batch_cnt += 1 + + val_loss /= batch_cnt + recall, precision, mean_iou = val_metric.summary() + return val_loss, recall, precision, mean_iou + + +def main(args): + + print(args) + + if not isinstance(args.workers, int): + args.workers = min(16, mp.cpu_count()) + + # AMP + if args.amp: + mixed_precision.set_global_policy('mixed_float16') + + st = time.time() + val_set = DetectionDataset( + img_folder=os.path.join(args.val_path, 'images'), + label_path=os.path.join(args.val_path, 'labels.json'), + img_transforms=T.Resize((args.input_size, args.input_size)), + use_polygons=args.rotation, + ) + val_loader = DataLoader( + val_set, + batch_size=args.batch_size, + shuffle=False, + drop_last=False, + num_workers=args.workers, + ) + print(f"Validation set loaded in {time.time() - st:.4}s ({len(val_set)} samples in " + f"{val_loader.num_batches} batches)") + with open(os.path.join(args.val_path, 'labels.json'), 'rb') as f: + val_hash = hashlib.sha256(f.read()).hexdigest() + + batch_transforms = T.Compose([ + T.Normalize(mean=(0.798, 0.785, 0.772), std=(0.264, 0.2749, 0.287)), + ]) + + # Load doctr model + model = detection.__dict__[args.arch]( + pretrained=args.pretrained, + input_shape=(args.input_size, args.input_size, 3), + assume_straight_pages=not args.rotation, + ) + + # Resume weights + if isinstance(args.resume, str): + model.load_weights(args.resume) + + # Metrics + val_metric = LocalizationConfusion(use_polygons=args.rotation, mask_shape=(args.input_size, args.input_size)) + + if args.test_only: + print("Running evaluation") + val_loss, recall, precision, mean_iou = evaluate(model, val_loader, batch_transforms, val_metric) + print(f"Validation loss: {val_loss:.6} (Recall: {recall:.2%} | Precision: {precision:.2%} | " + f"Mean IoU: {mean_iou:.2%})") + return + + st = time.time() + # Load both train and val data generators + train_set = DetectionDataset( + img_folder=os.path.join(args.train_path, 'images'), + label_path=os.path.join(args.train_path, 'labels.json'), + img_transforms=T.Compose( + ([T.Resize((args.input_size, args.input_size))] if not args.rotation else []) + + [ + # Augmentations + T.RandomApply(T.ColorInversion(), .1), + T.RandomJpegQuality(60), + T.RandomSaturation(.3), + T.RandomContrast(.3), + T.RandomBrightness(.3), + ] + ), + sample_transforms=T.SampleCompose([ + T.RandomRotate(90, expand=True), + T.ImageTransform(T.Resize((args.input_size, args.input_size))), + ]) if args.rotation else None, + use_polygons=args.rotation, + ) + train_loader = DataLoader( + train_set, + batch_size=args.batch_size, + shuffle=True, + drop_last=True, + num_workers=args.workers, + ) + print(f"Train set loaded in {time.time() - st:.4}s ({len(train_set)} samples in " + f"{train_loader.num_batches} batches)") + with open(os.path.join(args.train_path, 'labels.json'), 'rb') as f: + train_hash = hashlib.sha256(f.read()).hexdigest() + + if args.show_samples: + x, target = next(iter(train_loader)) + plot_samples(x, target) + return + + # Optimizer + scheduler = tf.keras.optimizers.schedules.ExponentialDecay( + args.lr, + decay_steps=args.epochs * len(train_loader), + decay_rate=1 / (25e4), # final lr as a fraction of initial lr + staircase=False + ) + optimizer = tf.keras.optimizers.Adam( + learning_rate=scheduler, + beta_1=0.95, + beta_2=0.99, + epsilon=1e-6, + clipnorm=5 + ) + if args.amp: + optimizer = mixed_precision.LossScaleOptimizer(optimizer) + # LR Finder + if args.find_lr: + lrs, losses = 
record_lr(model, train_loader, batch_transforms, optimizer, amp=args.amp) + plot_recorder(lrs, losses) + return + + # Tensorboard to monitor training + current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") + exp_name = f"{args.arch}_{current_time}" if args.name is None else args.name + + # W&B + if args.wb: + + run = wandb.init( + name=exp_name, + project="text-detection", + config={ + "learning_rate": args.lr, + "epochs": args.epochs, + "weight_decay": 0., + "batch_size": args.batch_size, + "architecture": args.arch, + "input_size": args.input_size, + "optimizer": "adam", + "framework": "tensorflow", + "scheduler": "exp_decay", + "train_hash": train_hash, + "val_hash": val_hash, + "pretrained": args.pretrained, + "rotation": args.rotation, + } + ) + + if args.freeze_backbone: + for layer in model.feat_extractor.layers: + layer.trainable = False + + min_loss = np.inf + + # Training loop + mb = master_bar(range(args.epochs)) + for epoch in mb: + fit_one_epoch(model, train_loader, batch_transforms, optimizer, mb, args.amp) + # Validation loop at the end of each epoch + val_loss, recall, precision, mean_iou = evaluate(model, val_loader, batch_transforms, val_metric) + if val_loss < min_loss: + print(f"Validation loss decreased {min_loss:.6} --> {val_loss:.6}: saving state...") + model.save_weights(f'./{exp_name}/weights') + min_loss = val_loss + log_msg = f"Epoch {epoch + 1}/{args.epochs} - Validation loss: {val_loss:.6} " + if any(val is None for val in (recall, precision, mean_iou)): + log_msg += "(Undefined metric value, caused by empty GTs or predictions)" + else: + log_msg += f"(Recall: {recall:.2%} | Precision: {precision:.2%} | Mean IoU: {mean_iou:.2%})" + mb.write(log_msg) + # W&B + if args.wb: + wandb.log({ + 'val_loss': val_loss, + 'recall': recall, + 'precision': precision, + 'mean_iou': mean_iou, + }) + + if args.wb: + run.finish() + + +def parse_args(): + import argparse + parser = argparse.ArgumentParser(description='DocTR training script for text detection (TensorFlow)', + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + + parser.add_argument('train_path', type=str, help='path to training data folder') + parser.add_argument('val_path', type=str, help='path to validation data folder') + parser.add_argument('arch', type=str, help='text-detection model to train') + parser.add_argument('--name', type=str, default=None, help='Name of your training experiment') + parser.add_argument('--epochs', type=int, default=10, help='number of epochs to train the model on') + parser.add_argument('-b', '--batch_size', type=int, default=2, help='batch size for training') + parser.add_argument('--input_size', type=int, default=1024, help='model input size, H = W') + parser.add_argument('--lr', type=float, default=0.001, help='learning rate for the optimizer (Adam)') + parser.add_argument('-j', '--workers', type=int, default=None, help='number of workers used for dataloading') + parser.add_argument('--resume', type=str, default=None, help='Path to your checkpoint') + parser.add_argument("--test-only", dest='test_only', action='store_true', help="Run the validation loop") + parser.add_argument('--freeze-backbone', dest='freeze_backbone', action='store_true', + help='freeze model backbone for fine-tuning') + parser.add_argument('--show-samples', dest='show_samples', action='store_true', + help='Display unormalized training samples') + parser.add_argument('--wb', dest='wb', action='store_true', + help='Log to Weights & Biases') + parser.add_argument('--pretrained', 
dest='pretrained', action='store_true', + help='Load pretrained parameters before starting the training') + parser.add_argument('--rotation', dest='rotation', action='store_true', + help='train with rotated documents') + parser.add_argument("--amp", dest="amp", help="Use Automatic Mixed Precision", action="store_true") + parser.add_argument('--find-lr', action='store_true', help='Gridsearch the optimal LR') + args = parser.parse_args() + + return args + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/references/detection/utils.py b/references/detection/utils.py new file mode 100644 index 0000000000..eb306abb7c --- /dev/null +++ b/references/detection/utils.py @@ -0,0 +1,83 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + +from typing import Dict, List + +import cv2 +import matplotlib.pyplot as plt +import numpy as np + + +def plot_samples(images, targets: List[Dict[str, np.ndarray]]) -> None: + # Unnormalize image + nb_samples = min(len(images), 4) + _, axes = plt.subplots(2, nb_samples, figsize=(20, 5)) + for idx in range(nb_samples): + img = (255 * images[idx].numpy()).round().clip(0, 255).astype(np.uint8) + if img.shape[0] == 3 and img.shape[2] != 3: + img = img.transpose(1, 2, 0) + + target = np.zeros(img.shape[:2], np.uint8) + boxes = targets[idx].copy() + boxes[:, [0, 2]] = boxes[:, [0, 2]] * img.shape[1] + boxes[:, [1, 3]] = boxes[:, [1, 3]] * img.shape[0] + boxes[:, :4] = boxes[:, :4].round().astype(int) + + for box in boxes: + if boxes.ndim == 3: + cv2.fillPoly(target, [np.int0(box)], 1) + else: + target[int(box[1]): int(box[3]) + 1, int(box[0]): int(box[2]) + 1] = 1 + if nb_samples > 1: + axes[0][idx].imshow(img) + axes[1][idx].imshow(target.astype(bool)) + else: + axes[0].imshow(img) + axes[1].imshow(target.astype(bool)) + + # Disable axis + for ax in axes.ravel(): + ax.axis('off') + plt.show() + + +def plot_recorder(lr_recorder, loss_recorder, beta: float = 0.95, **kwargs) -> None: + """Display the results of the LR grid search. + Adapted from https://github.com/frgfm/Holocron/blob/master/holocron/trainer/core.py + + Args: + lr_recorder: list of LR values + loss_recorder: list of loss values + beta (float, optional): smoothing factor + """ + + if len(lr_recorder) != len(loss_recorder) or len(lr_recorder) == 0: + raise AssertionError("Both `lr_recorder` and `loss_recorder` should have the same length") + + # Exp moving average of loss + smoothed_losses = [] + avg_loss = 0. 
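+    # Smooth the raw losses with a bias-corrected exponential moving average:
+    # ema_t = beta * ema_(t-1) + (1 - beta) * loss_t, then divided by (1 - beta ** (t + 1))
+    # so that the first points are not dragged towards the zero initialisation.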
+ for idx, loss in enumerate(loss_recorder): + avg_loss = beta * avg_loss + (1 - beta) * loss + smoothed_losses.append(avg_loss / (1 - beta ** (idx + 1))) + + # Properly rescale Y-axis + data_slice = slice( + min(len(loss_recorder) // 10, 10), + # -min(len(loss_recorder) // 20, 5) if len(loss_recorder) >= 20 else len(loss_recorder) + len(loss_recorder) + ) + vals = np.array(smoothed_losses[data_slice]) + min_idx = vals.argmin() + max_val = vals.max() if min_idx is None else vals[:min_idx + 1].max() # type: ignore[misc] + delta = max_val - vals[min_idx] + + plt.plot(lr_recorder[data_slice], smoothed_losses[data_slice]) + plt.xscale('log') + plt.xlabel('Learning Rate') + plt.ylabel('Training loss') + plt.ylim(vals[min_idx] - 0.1 * delta, max_val + 0.2 * delta) + plt.grid(True, linestyle='--', axis='x') + plt.show(**kwargs) diff --git a/references/obj_detection/latency.csv b/references/obj_detection/latency.csv new file mode 100644 index 0000000000..ef92e78695 --- /dev/null +++ b/references/obj_detection/latency.csv @@ -0,0 +1,3 @@ +arch,input_shape,framework,hardware,mean,std +fasterrcnn_mobilenet_v3_large_fpn,"(1024, 1024)",pytorch,cpu,257.85,14.28 +fasterrcnn_mobilenet_v3_large_fpn,"(1024, 1024)",pytorch,gpu,17.76,0.68 diff --git a/references/obj_detection/latency_pytorch.py b/references/obj_detection/latency_pytorch.py new file mode 100644 index 0000000000..f39f3fe62c --- /dev/null +++ b/references/obj_detection/latency_pytorch.py @@ -0,0 +1,65 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + +""" +Object detection latency benchmark +""" + +import argparse +import os +import time + +import numpy as np +import torch + +os.environ['USE_TORCH'] = '1' + +from doctr.models import obj_detection + + +@torch.no_grad() +def main(args): + + device = torch.device("cuda:0" if args.gpu else "cpu") + + # Pretrained imagenet model + model = obj_detection.__dict__[args.arch]( + pretrained=args.pretrained, + min_size=args.size, + max_size=args.size, + ).eval().to(device=device) + + # Input + img_tensor = torch.rand((1, 3, args.size, args.size)).to(device=device) + + # Warmup + for _ in range(10): + _ = model(img_tensor) + + timings = [] + + # Evaluation runs + for _ in range(args.it): + start_ts = time.perf_counter() + _ = model(img_tensor) + timings.append(time.perf_counter() - start_ts) + + _timings = np.array(timings) + print(f"{args.arch} ({args.it} runs on ({args.size}, {args.size}) inputs)") + print(f"mean {1000 * _timings.mean():.2f}ms, std {1000 * _timings.std():.2f}ms") + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='docTR latency benchmark for object detection (PyTorch)', + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument("arch", type=str, help="Architecture to use") + parser.add_argument("--size", type=int, default=1024, help="The image input size") + parser.add_argument("--gpu", dest="gpu", help='Should the benchmark be performed on GPU', action="store_true") + parser.add_argument("--it", type=int, default=100, help="Number of iterations to run") + parser.add_argument("--pretrained", dest="pretrained", help="Use pre-trained models from the modelzoo", + action="store_true") + args = parser.parse_args() + + main(args) diff --git a/references/obj_detection/train_pytorch.py b/references/obj_detection/train_pytorch.py new file mode 100644 index 0000000000..4929a72ab2 --- /dev/null +++ b/references/obj_detection/train_pytorch.py @@ 
-0,0 +1,364 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + +import os + +os.environ['USE_TORCH'] = '1' + +import datetime +import logging +import multiprocessing as mp +import time + +import numpy as np +import torch +import torch.optim as optim +import wandb +from fastprogress.fastprogress import master_bar, progress_bar +from torch.optim.lr_scheduler import MultiplicativeLR, StepLR +from torch.utils.data import DataLoader, RandomSampler, SequentialSampler +from torchvision.transforms import ColorJitter, Compose, GaussianBlur + +from doctr import transforms as T +from doctr.datasets import DocArtefacts +from doctr.models import obj_detection +from doctr.utils import DetectionMetric +from utils import plot_recorder, plot_samples + + +def record_lr( + model: torch.nn.Module, + train_loader: DataLoader, + optimizer, + start_lr: float = 1e-7, + end_lr: float = 1, + num_it: int = 100, + amp: bool = False, +): + """Gridsearch the optimal learning rate for the training. + Adapted from https://github.com/frgfm/Holocron/blob/master/holocron/trainer/core.py + """ + + if num_it > len(train_loader): + raise ValueError("the value of `num_it` needs to be lower than the number of available batches") + + model = model.train() + # Update param groups & LR + optimizer.defaults['lr'] = start_lr + for pgroup in optimizer.param_groups: + pgroup['lr'] = start_lr + + gamma = (end_lr / start_lr) ** (1 / (num_it - 1)) + scheduler = MultiplicativeLR(optimizer, lambda step: gamma) + + lr_recorder = [start_lr * gamma ** idx for idx in range(num_it)] + loss_recorder = [] + + if amp: + scaler = torch.cuda.amp.GradScaler() + + for batch_idx, (images, targets) in enumerate(train_loader): + targets = convert_to_abs_coords(targets, images.shape) + if torch.cuda.is_available(): + images = images.cuda() + targets = [{k: v.cuda() for k, v in t.items()} for t in targets] + + # Forward, Backward & update + optimizer.zero_grad() + if amp: + with torch.cuda.amp.autocast(): + loss_dict = model(images, targets) + train_loss = sum(v for v in loss_dict.values()) + scaler.scale(train_loss).backward() + # Update the params + scaler.step(optimizer) + scaler.update() + else: + loss_dict = model(images, targets) + train_loss = sum(v for v in loss_dict.values()) + train_loss.backward() + optimizer.step() + # Update LR + scheduler.step() + + # Record + if not torch.isfinite(train_loss): + if batch_idx == 0: + raise ValueError("loss value is NaN or inf.") + else: + break + loss_recorder.append(train_loss.item()) + # Stop after the number of iterations + if batch_idx + 1 == num_it: + break + + return lr_recorder[:len(loss_recorder)], loss_recorder + + +def convert_to_abs_coords(targets, img_shape): + height, width = img_shape[-2:] + for idx, t in enumerate(targets): + targets[idx]['boxes'][:, 0::2] = (t['boxes'][:, 0::2] * width).round() + targets[idx]['boxes'][:, 1::2] = (t['boxes'][:, 1::2] * height).round() + + targets = [{ + "boxes": torch.from_numpy(t['boxes']).to(dtype=torch.float32), + "labels": torch.tensor(t['labels']).to(dtype=torch.long)} + for t in targets + ] + + return targets + + +def fit_one_epoch(model, train_loader, optimizer, scheduler, mb, amp=False): + if amp: + scaler = torch.cuda.amp.GradScaler() + + model.train() + # Iterate over the batches of the dataset + for images, targets in progress_bar(train_loader, parent=mb): + + targets = convert_to_abs_coords(targets, images.shape) + if 
torch.cuda.is_available(): + images = images.cuda() + targets = [{k: v.cuda() for k, v in t.items()} for t in targets] + + optimizer.zero_grad() + if amp: + with torch.cuda.amp.autocast(): + loss_dict = model(images, targets) + loss = sum(v for v in loss_dict.values()) + scaler.scale(loss).backward() + # Update the params + scaler.step(optimizer) + scaler.update() + else: + loss_dict = model(images, targets) + loss = sum(v for v in loss_dict.values()) + loss.backward() + optimizer.step() + + mb.child.comment = f'Training loss: {loss.item()}' + scheduler.step() + + +@torch.no_grad() +def evaluate(model, val_loader, metric, amp=False): + model.eval() + metric.reset() + for images, targets in val_loader: + targets = convert_to_abs_coords(targets, images.shape) + if torch.cuda.is_available(): + images = images.cuda() + + if amp: + with torch.cuda.amp.autocast(): + output = model(images) + else: + output = model(images) + + # Compute metric + pred_labels = np.concatenate([o['labels'].cpu().numpy() for o in output]) + pred_boxes = np.concatenate([o['boxes'].cpu().numpy() for o in output]) + gt_boxes = np.concatenate([o['boxes'].cpu().numpy() for o in targets]) + gt_labels = np.concatenate([o['labels'].cpu().numpy() for o in targets]) + metric.update(gt_boxes, pred_boxes, gt_labels, pred_labels) + + return metric.summary() + + +def main(args): + print(args) + + if not isinstance(args.workers, int): + args.workers = min(16, mp.cpu_count()) + + torch.backends.cudnn.benchmark = True + + st = time.time() + val_set = DocArtefacts( + train=False, + download=True, + img_transforms=T.Resize((args.input_size, args.input_size)), + ) + val_loader = DataLoader( + val_set, + batch_size=args.batch_size, + drop_last=False, + num_workers=args.workers, + sampler=SequentialSampler(val_set), + pin_memory=torch.cuda.is_available(), + collate_fn=val_set.collate_fn, + ) + print(f"Validation set loaded in {time.time() - st:.4}s ({len(val_set)} samples in " + f"{len(val_loader)} batches)") + + # Load doctr model + model = obj_detection.__dict__[args.arch](pretrained=args.pretrained, num_classes=5) + + # Resume weights + if isinstance(args.resume, str): + print(f"Resuming {args.resume}") + checkpoint = torch.load(args.resume, map_location='cpu') + model.load_state_dict(checkpoint) + + # GPU + if isinstance(args.device, int): + if not torch.cuda.is_available(): + raise AssertionError("PyTorch cannot access your GPU. 
Please investigate!") + if args.device >= torch.cuda.device_count(): + raise ValueError("Invalid device index") + # Silent default switch to GPU if available + elif torch.cuda.is_available(): + args.device = 0 + else: + logging.warning("No accessible GPU, target device set to CPU.") + if torch.cuda.is_available(): + torch.cuda.set_device(args.device) + model = model.cuda() + + # Metrics + metric = DetectionMetric(iou_thresh=0.5) + + if args.test_only: + print("Running evaluation") + recall, precision, mean_iou = evaluate(model, val_loader, metric, amp=args.amp) + print(f"Recall: {recall:.2%} | Precision: {precision:.2%} | Mean IoU: {mean_iou:.2%}") + return + + st = time.time() + # Load train data generators + train_set = DocArtefacts( + train=True, + download=True, + img_transforms=Compose([ + T.Resize((args.input_size, args.input_size)), + T.RandomApply(T.GaussianNoise(0., 0.25), p=0.5), + ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.02), + T.RandomApply(GaussianBlur(kernel_size=(3, 3), sigma=(0.1, 3)), .3), + ]), + sample_transforms=T.RandomHorizontalFlip(p=0.5), + ) + train_loader = DataLoader( + train_set, + batch_size=args.batch_size, + drop_last=True, + num_workers=args.workers, + sampler=RandomSampler(train_set), + pin_memory=torch.cuda.is_available(), + collate_fn=train_set.collate_fn, + ) + print(f"Train set loaded in {time.time() - st:.4}s ({len(train_set)} samples in " + f"{len(train_loader)} batches)") + + if args.show_samples: + images, targets = next(iter(train_loader)) + targets = convert_to_abs_coords(targets, images.shape) + plot_samples(images, targets, train_set.CLASSES) + return + + # Backbone freezing + if args.freeze_backbone: + for p in model.backbone.parameters(): + p.requires_grad_(False) + + # Optimizer + optimizer = optim.SGD([p for p in model.parameters() if p.requires_grad], + lr=args.lr, weight_decay=args.weight_decay) + # LR Finder + if args.find_lr: + lrs, losses = record_lr(model, train_loader, optimizer, amp=args.amp) + plot_recorder(lrs, losses) + return + # Scheduler + scheduler = StepLR(optimizer, step_size=8, gamma=0.7) + + # Training monitoring + current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") + exp_name = f"{args.arch}_{current_time}" if args.name is None else args.name + + # W&B + if args.wb: + run = wandb.init( + name=exp_name, + project="object-detection", + config={ + "learning_rate": args.lr, + "epochs": args.epochs, + "weight_decay": args.weight_decay, + "batch_size": args.batch_size, + "architecture": args.arch, + "input_size": args.input_size, + "optimizer": "sgd", + "framework": "pytorch", + "scheduler": "step", + "pretrained": args.pretrained, + "amp": args.amp, + } + ) + + mb = master_bar(range(args.epochs)) + max_score = 0. + + for epoch in mb: + fit_one_epoch(model, train_loader, optimizer, scheduler, mb, amp=args.amp) + # Validation loop at the end of each epoch + recall, precision, mean_iou = evaluate(model, val_loader, metric, amp=args.amp) + f1_score = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0.
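+        # Checkpoint on the best validation F1 score (harmonic mean of precision and recall)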
+ + if f1_score > max_score: + print(f"Validation metric increased {max_score:.6} --> {f1_score:.6}: saving state...") + torch.save(model.state_dict(), f"./{exp_name}.pt") + max_score = f1_score + log_msg = f"Epoch {epoch + 1}/{args.epochs} - " + if any(val is None for val in (recall, precision, mean_iou)): + log_msg += "Undefined metric value, caused by empty GTs or predictions" + else: + log_msg += f"Recall: {recall:.2%} | Precision: {precision:.2%} | Mean IoU: {mean_iou:.2%}" + mb.write(log_msg) + # W&B + if args.wb: + wandb.log({ + 'recall': recall, + 'precision': precision, + 'mean_iou': mean_iou, + }) + + if args.wb: + run.finish() + + +def parse_args(): + import argparse + parser = argparse.ArgumentParser(description='DocTR training script for object detection (PyTorch)', + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument('arch', type=str, help='object detection model to train') + parser.add_argument('--name', type=str, default=None, help='Name of your training experiment') + parser.add_argument('--epochs', type=int, default=20, help='number of epochs to train the model on') + parser.add_argument('-b', '--batch_size', type=int, default=2, help='batch size for training') + parser.add_argument('--device', default=None, type=int, help='device') + parser.add_argument('--input_size', type=int, default=1024, help='model input size, H = W') + parser.add_argument('--lr', type=float, default=0.001, help='learning rate for the optimizer (SGD)') + parser.add_argument('--wd', '--weight-decay', default=0, type=float, help='weight decay', dest='weight_decay') + parser.add_argument('-j', '--workers', type=int, default=None, help='number of workers used for dataloading') + parser.add_argument('--resume', type=str, default=None, help='Path to your checkpoint') + parser.add_argument("--test-only", dest='test_only', action='store_true', help="Run the validation loop") + parser.add_argument('--show-samples', dest='show_samples', action='store_true', + help='Display unnormalized training samples') + parser.add_argument('--freeze-backbone', dest='freeze_backbone', action='store_true', + help='freeze model backbone for fine-tuning') + parser.add_argument('--wb', dest='wb', action='store_true', + help='Log to Weights & Biases') + parser.add_argument('--pretrained', dest='pretrained', action='store_true', + help='Load pretrained parameters before starting the training') + parser.add_argument("--amp", dest="amp", help="Use Automatic Mixed Precision", action="store_true") + parser.add_argument('--find-lr', action='store_true', help='Gridsearch the optimal LR') + args = parser.parse_args() + return args + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/references/obj_detection/utils.py b/references/obj_detection/utils.py new file mode 100644 index 0000000000..8d287c8bb4 --- /dev/null +++ b/references/obj_detection/utils.py @@ -0,0 +1,77 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details.
+ +from typing import Dict, List + +import cv2 +import matplotlib.pyplot as plt +import numpy as np +from matplotlib.cm import get_cmap + + +def plot_samples(images, targets: List[Dict[str, np.ndarray]], classes: List[str]) -> None: + cmap = get_cmap('gist_rainbow', len(classes)) + # Unnormalize image + nb_samples = min(len(images), 4) + _, axes = plt.subplots(1, nb_samples, figsize=(20, 5)) + for idx in range(nb_samples): + img = (255 * images[idx].numpy()).round().clip(0, 255).astype(np.uint8) + if img.shape[0] == 3 and img.shape[2] != 3: + img = img.transpose(1, 2, 0) + target = img.copy() + for box, class_idx in zip(targets[idx]['boxes'].numpy(), targets[idx]['labels']): + r, g, b, _ = cmap(class_idx.numpy()) + color = int(round(255 * r)), int(round(255 * g)), int(round(255 * b)) + cv2.rectangle(target, (int(box[0]), int(box[1])), (int(box[2]), int(box[3])), color, 2) + text_size, _ = cv2.getTextSize(classes[class_idx], cv2.FONT_HERSHEY_SIMPLEX, 1, 2) + text_w, text_h = text_size + cv2.rectangle(target, (int(box[0]), int(box[1])), (int(box[0]) + text_w, int(box[1]) - text_h), color, -1) + cv2.putText(target, classes[class_idx], (int(box[0]), int(box[1])), cv2.FONT_HERSHEY_SIMPLEX, 1, + (255, 255, 255), 2) + + axes[idx].imshow(target) + # Disable axis + for ax in axes.ravel(): + ax.axis('off') + plt.show() + + +def plot_recorder(lr_recorder, loss_recorder, beta: float = 0.95, **kwargs) -> None: + """Display the results of the LR grid search. + Adapted from https://github.com/frgfm/Holocron/blob/master/holocron/trainer/core.py + + Args: + lr_recorder: list of LR values + loss_recorder: list of loss values + beta (float, optional): smoothing factor + """ + + if len(lr_recorder) != len(loss_recorder) or len(lr_recorder) == 0: + raise AssertionError("Both `lr_recorder` and `loss_recorder` should have the same length") + + # Exp moving average of loss + smoothed_losses = [] + avg_loss = 0. + for idx, loss in enumerate(loss_recorder): + avg_loss = beta * avg_loss + (1 - beta) * loss + smoothed_losses.append(avg_loss / (1 - beta ** (idx + 1))) + + # Properly rescale Y-axis + data_slice = slice( + min(len(loss_recorder) // 10, 10), + -min(len(loss_recorder) // 20, 5) if len(loss_recorder) >= 20 else len(loss_recorder) + ) + vals = np.array(smoothed_losses[data_slice]) + min_idx = vals.argmin() + max_val = vals.max() if min_idx is None else vals[:min_idx + 1].max() # type: ignore[misc] + delta = max_val - vals[min_idx] + + plt.plot(lr_recorder[data_slice], smoothed_losses[data_slice]) + plt.xscale('log') + plt.xlabel('Learning Rate') + plt.ylabel('Training loss') + plt.ylim(vals[min_idx] - 0.1 * delta, max_val + 0.2 * delta) + plt.grid(True, linestyle='--', axis='x') + plt.show(**kwargs) diff --git a/references/recognition/README.md b/references/recognition/README.md new file mode 100644 index 0000000000..eaee629c5e --- /dev/null +++ b/references/recognition/README.md @@ -0,0 +1,63 @@ +# Text recognition + +The sample training script was made to train text recognition model with docTR. + +## Setup + +First, you need to install `doctr` (with pip, for instance) + +```shell +pip install -e . 
--upgrade +pip install -r references/requirements.txt +``` + +## Usage + +You can start your training in TensorFlow: + +```shell +python references/recognition/train_tensorflow.py path/to/your/train_set path/to/your/val_set crnn_vgg16_bn --epochs 5 +``` +or PyTorch: + +```shell +python references/recognition/train_pytorch.py path/to/your/train_set path/to/your/val_set crnn_vgg16_bn --epochs 5 --device 0 +``` + + + +## Data format + +You need to provide both `train_path` and `val_path` arguments to start training. +Each of these paths must lead to a folder with the following two elements: + +```shell +├── images + ├── img_1.jpg + ├── img_2.jpg + ├── img_3.jpg + └── ... +├── labels.json +``` + +The JSON file must map each image name to its word label (a string). +The order of the entries in the JSON does not matter. + +```python +labels = { + 'img_1.jpg': 'I', + 'img_2.jpg': 'am', + 'img_3.jpg': 'a', + 'img_4.jpg': 'Jedi', + 'img_5.jpg': '!', + ... +} +``` + +## Advanced options + +Feel free to inspect the available script options to customize your training to your own needs! + +```shell +python references/recognition/train_pytorch.py --help +``` diff --git a/references/recognition/latency.csv b/references/recognition/latency.csv new file mode 100644 index 0000000000..593664ef9e --- /dev/null +++ b/references/recognition/latency.csv @@ -0,0 +1,21 @@ +arch,input_shape,framework,hardware,mean,std +crnn_vgg16_bn,"(32, 128)",pytorch,cpu,687.93,93.79 +crnn_vgg16_bn,"(32, 128)",pytorch,gpu,32.34,0.25 +crnn_mobilenet_v3_small,"(32, 128)",pytorch,cpu,64.74,7.97 +crnn_mobilenet_v3_small,"(32, 128)",pytorch,gpu,8.43,0.57 +crnn_mobilenet_v3_large,"(32, 128)",pytorch,cpu,138.81,9.27 +crnn_mobilenet_v3_large,"(32, 128)",pytorch,gpu,12.93,1.1 +sar_resnet31,"(32, 128)",pytorch,cpu,, +sar_resnet31,"(32, 128)",pytorch,gpu,256.21,1.89 +master,"(32, 128)",pytorch,cpu,, +master,"(32, 128)",pytorch,gpu,, +crnn_vgg16_bn,"(32, 128)",tensorflow,cpu,826.57,20.15 +crnn_vgg16_bn,"(32, 128)",tensorflow,gpu,62.92,1.32 +crnn_mobilenet_v3_small,"(32, 128)",tensorflow,cpu,901.53,5.14 +crnn_mobilenet_v3_small,"(32, 128)",tensorflow,gpu,67.98,1.71 +crnn_mobilenet_v3_large,"(32, 128)",tensorflow,cpu,1487.1,26.14 +crnn_mobilenet_v3_large,"(32, 128)",tensorflow,gpu,75.67,1.34 +sar_resnet31,"(32, 128)",tensorflow,cpu,, +sar_resnet31,"(32, 128)",tensorflow,gpu,258.95,2.2 +master,"(32, 128)",tensorflow,cpu,, +master,"(32, 128)",tensorflow,gpu,1180.08,18.38 diff --git a/references/recognition/latency_pytorch.py b/references/recognition/latency_pytorch.py new file mode 100644 index 0000000000..1880da1228 --- /dev/null +++ b/references/recognition/latency_pytorch.py @@ -0,0 +1,65 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details.
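+# Example usage (crnn_vgg16_bn is one of the available recognition architectures):
+#   python references/recognition/latency_pytorch.py crnn_vgg16_bn --it 100 -b 64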
+ +""" +Text recognition latency benchmark +""" + +import argparse +import os +import time + +import numpy as np +import torch + +os.environ['USE_TORCH'] = '1' + +from doctr.models import recognition + + +@torch.no_grad() +def main(args): + + device = torch.device("cuda:0" if args.gpu else "cpu") + + # Pretrained imagenet model + model = recognition.__dict__[args.arch]( + pretrained=args.pretrained, + pretrained_backbone=False, + ).eval().to(device=device) + + # Input + img_tensor = torch.rand((args.batch_size, 3, args.size, 4 * args.size)).to(device=device) + + # Warmup + for _ in range(10): + _ = model(img_tensor) + + timings = [] + + # Evaluation runs + for _ in range(args.it): + start_ts = time.perf_counter() + _ = model(img_tensor) + timings.append(time.perf_counter() - start_ts) + + _timings = np.array(timings) + print(f"{args.arch} ({args.it} runs on ({args.size}, {4 * args.size}) inputs in batches of {args.batch_size})") + print(f"mean {1000 * _timings.mean():.2f}ms, std {1000 * _timings.std():.2f}ms") + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='docTR latency benchmark for text recognition (PyTorch)', + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument("arch", type=str, help="Architecture to use") + parser.add_argument("--batch-size", "-b", type=int, default=64, help="The batch_size") + parser.add_argument("--size", type=int, default=32, help="The image input size") + parser.add_argument("--gpu", dest="gpu", help='Should the benchmark be performed on GPU', action="store_true") + parser.add_argument("--it", type=int, default=100, help="Number of iterations to run") + parser.add_argument("--pretrained", dest="pretrained", help="Use pre-trained models from the modelzoo", + action="store_true") + args = parser.parse_args() + + main(args) diff --git a/references/recognition/latency_tensorflow.py b/references/recognition/latency_tensorflow.py new file mode 100644 index 0000000000..96be738c45 --- /dev/null +++ b/references/recognition/latency_tensorflow.py @@ -0,0 +1,74 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + +""" +Text recognition latency benchmark +""" + +import argparse +import os +import time + +import numpy as np +import tensorflow as tf + +os.environ['USE_TF'] = '1' +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" + +from doctr.models import recognition + + +def main(args): + + if args.gpu: + gpu_devices = tf.config.experimental.list_physical_devices('GPU') + if any(gpu_devices): + tf.config.experimental.set_memory_growth(gpu_devices[0], True) + else: + raise AssertionError("TensorFlow cannot access your GPU. 
Please investigate!") + else: + os.environ['CUDA_VISIBLE_DEVICES'] = "" + + spatial_shape = (args.size, 4 * args.size) + # Pretrained imagenet model + model = recognition.__dict__[args.arch]( + pretrained=args.pretrained, + pretrained_backbone=False, + input_shape=(*spatial_shape, 3), + ) + + # Input + img_tensor = tf.random.uniform(shape=[args.batch_size, *spatial_shape, 3], maxval=1, dtype=tf.float32) + + # Warmup + for _ in range(10): + _ = model(img_tensor, training=False) + + timings = [] + + # Evaluation runs + for _ in range(args.it): + start_ts = time.perf_counter() + _ = model(img_tensor, training=False) + timings.append(time.perf_counter() - start_ts) + + _timings = np.array(timings) + print(f"{args.arch} ({args.it} runs on {spatial_shape} inputs in batches of {args.batch_size})") + print(f"mean {1000 * _timings.mean():.2f}ms, std {1000 * _timings.std():.2f}ms") + + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='docTR latency benchmark for text recognition (TensorFlow)', + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument("arch", type=str, help="Architecture to use") + parser.add_argument("--batch-size", "-b", type=int, default=64, help="The batch_size") + parser.add_argument("--size", type=int, default=32, help="The image input size") + parser.add_argument("--gpu", dest="gpu", help='Should the benchmark be performed on GPU', action="store_true") + parser.add_argument("--it", type=int, default=100, help="Number of iterations to run") + parser.add_argument("--pretrained", dest="pretrained", help="Use pre-trained models from the modelzoo", + action="store_true") + args = parser.parse_args() + + main(args) diff --git a/references/recognition/train_pytorch.py b/references/recognition/train_pytorch.py new file mode 100644 index 0000000000..14c709ba51 --- /dev/null +++ b/references/recognition/train_pytorch.py @@ -0,0 +1,381 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + +import os + +os.environ['USE_TORCH'] = '1' + +import datetime +import hashlib +import logging +import multiprocessing as mp +import time +from pathlib import Path + +import numpy as np +import torch +import wandb +from fastprogress.fastprogress import master_bar, progress_bar +from torch.optim.lr_scheduler import CosineAnnealingLR, MultiplicativeLR, OneCycleLR +from torch.utils.data import DataLoader, RandomSampler, SequentialSampler +from torchvision.transforms import ColorJitter, Compose, Normalize + +from doctr import transforms as T +from doctr.datasets import VOCABS, RecognitionDataset +from doctr.models import recognition +from doctr.utils.metrics import TextMatch +from utils import plot_recorder, plot_samples + + +def record_lr( + model: torch.nn.Module, + train_loader: DataLoader, + batch_transforms, + optimizer, + start_lr: float = 1e-7, + end_lr: float = 1, + num_it: int = 100, + amp: bool = False, +): + """Gridsearch the optimal learning rate for the training. 
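+    The learning rate is increased exponentially from `start_lr` to `end_lr` over `num_it` batches
+    while the training loss is recorded, so that a sensible starting LR can be read off the curve.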
+ Adapted from https://github.com/frgfm/Holocron/blob/master/holocron/trainer/core.py + """ + + if num_it > len(train_loader): + raise ValueError("the value of `num_it` needs to be lower than the number of available batches") + + model = model.train() + # Update param groups & LR + optimizer.defaults['lr'] = start_lr + for pgroup in optimizer.param_groups: + pgroup['lr'] = start_lr + + gamma = (end_lr / start_lr) ** (1 / (num_it - 1)) + scheduler = MultiplicativeLR(optimizer, lambda step: gamma) + + lr_recorder = [start_lr * gamma ** idx for idx in range(num_it)] + loss_recorder = [] + + if amp: + scaler = torch.cuda.amp.GradScaler() + + for batch_idx, (images, targets) in enumerate(train_loader): + if torch.cuda.is_available(): + images = images.cuda() + + images = batch_transforms(images) + + # Forward, Backward & update + optimizer.zero_grad() + if amp: + with torch.cuda.amp.autocast(): + train_loss = model(images, targets)['loss'] + scaler.scale(train_loss).backward() + # Gradient clipping + scaler.unscale_(optimizer) + torch.nn.utils.clip_grad_norm_(model.parameters(), 5) + # Update the params + scaler.step(optimizer) + scaler.update() + else: + train_loss = model(images, targets)['loss'] + train_loss.backward() + torch.nn.utils.clip_grad_norm_(model.parameters(), 5) + optimizer.step() + # Update LR + scheduler.step() + + # Record + if not torch.isfinite(train_loss): + if batch_idx == 0: + raise ValueError("loss value is NaN or inf.") + else: + break + loss_recorder.append(train_loss.item()) + # Stop after the number of iterations + if batch_idx + 1 == num_it: + break + + return lr_recorder[:len(loss_recorder)], loss_recorder + + +def fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, mb, amp=False): + + if amp: + scaler = torch.cuda.amp.GradScaler() + + model.train() + # Iterate over the batches of the dataset + for images, targets in progress_bar(train_loader, parent=mb): + + if torch.cuda.is_available(): + images = images.cuda() + images = batch_transforms(images) + + train_loss = model(images, targets)['loss'] + + optimizer.zero_grad() + if amp: + with torch.cuda.amp.autocast(): + train_loss = model(images, targets)['loss'] + scaler.scale(train_loss).backward() + # Gradient clipping + scaler.unscale_(optimizer) + torch.nn.utils.clip_grad_norm_(model.parameters(), 5) + # Update the params + scaler.step(optimizer) + scaler.update() + else: + train_loss = model(images, targets)['loss'] + train_loss.backward() + torch.nn.utils.clip_grad_norm_(model.parameters(), 5) + optimizer.step() + + scheduler.step() + + mb.child.comment = f'Training loss: {train_loss.item():.6}' + + +@torch.no_grad() +def evaluate(model, val_loader, batch_transforms, val_metric, amp=False): + # Model in eval mode + model.eval() + # Reset val metric + val_metric.reset() + # Validation loop + val_loss, batch_cnt = 0, 0 + for images, targets in val_loader: + if torch.cuda.is_available(): + images = images.cuda() + images = batch_transforms(images) + if amp: + with torch.cuda.amp.autocast(): + out = model(images, targets, return_preds=True) + else: + out = model(images, targets, return_preds=True) + # Compute metric + if len(out['preds']): + words, _ = zip(*out['preds']) + else: + words = [] + val_metric.update(targets, words) + + val_loss += out['loss'].item() + batch_cnt += 1 + + val_loss /= batch_cnt + result = val_metric.summary() + return val_loss, result['raw'], result['unicase'] + + +def main(args): + + print(args) + + if not isinstance(args.workers, int): + args.workers = min(16, 
mp.cpu_count()) + + torch.backends.cudnn.benchmark = True + + # Load val data generator + st = time.time() + val_set = RecognitionDataset( + img_folder=os.path.join(args.val_path, 'images'), + labels_path=os.path.join(args.val_path, 'labels.json'), + img_transforms=T.Resize((args.input_size, 4 * args.input_size), preserve_aspect_ratio=True), + ) + val_loader = DataLoader( + val_set, + batch_size=args.batch_size, + drop_last=False, + num_workers=args.workers, + sampler=SequentialSampler(val_set), + pin_memory=torch.cuda.is_available(), + collate_fn=val_set.collate_fn, + ) + print(f"Validation set loaded in {time.time() - st:.4}s ({len(val_set)} samples in " + f"{len(val_loader)} batches)") + with open(os.path.join(args.val_path, 'labels.json'), 'rb') as f: + val_hash = hashlib.sha256(f.read()).hexdigest() + + batch_transforms = Normalize(mean=(0.694, 0.695, 0.693), std=(0.299, 0.296, 0.301)) + + # Load doctr model + model = recognition.__dict__[args.arch](pretrained=args.pretrained, vocab=VOCABS[args.vocab]) + + # Resume weights + if isinstance(args.resume, str): + print(f"Resuming {args.resume}") + checkpoint = torch.load(args.resume, map_location='cpu') + model.load_state_dict(checkpoint) + + # GPU + if isinstance(args.device, int): + if not torch.cuda.is_available(): + raise AssertionError("PyTorch cannot access your GPU. Please investigate!") + if args.device >= torch.cuda.device_count(): + raise ValueError("Invalid device index") + # Silent default switch to GPU if available + elif torch.cuda.is_available(): + args.device = 0 + else: + logging.warning("No accessible GPU, targe device set to CPU.") + if torch.cuda.is_available(): + torch.cuda.set_device(args.device) + model = model.cuda() + + # Metrics + val_metric = TextMatch() + + if args.test_only: + print("Running evaluation") + val_loss, exact_match, partial_match = evaluate(model, val_loader, batch_transforms, val_metric, amp=args.amp) + print(f"Validation loss: {val_loss:.6} (Exact: {exact_match:.2%} | Partial: {partial_match:.2%})") + return + + st = time.time() + + # Load train data generator + base_path = Path(args.train_path) + parts = [base_path] if base_path.joinpath('labels.json').is_file() else [ + base_path.joinpath(sub) for sub in os.listdir(base_path) + ] + train_set = RecognitionDataset( + parts[0].joinpath('images'), + parts[0].joinpath('labels.json'), + img_transforms=Compose([ + T.Resize((args.input_size, 4 * args.input_size), preserve_aspect_ratio=True), + # Augmentations + T.RandomApply(T.ColorInversion(), .1), + ColorJitter(brightness=0.3, contrast=0.3, saturation=0.3, hue=0.02), + ]), + ) + if len(parts) > 1: + for subfolder in parts[1:]: + train_set.merge_dataset(RecognitionDataset(subfolder.joinpath('images'), subfolder.joinpath('labels.json'))) + + train_loader = DataLoader( + train_set, + batch_size=args.batch_size, + drop_last=True, + num_workers=args.workers, + sampler=RandomSampler(train_set), + pin_memory=torch.cuda.is_available(), + collate_fn=train_set.collate_fn, + ) + print(f"Train set loaded in {time.time() - st:.4}s ({len(train_set)} samples in " + f"{len(train_loader)} batches)") + with open(parts[0].joinpath('labels.json'), 'rb') as f: + train_hash = hashlib.sha256(f.read()).hexdigest() + + if args.show_samples: + x, target = next(iter(train_loader)) + plot_samples(x, target) + return + + # Optimizer + optimizer = torch.optim.Adam([p for p in model.parameters() if p.requires_grad], args.lr, + betas=(0.95, 0.99), eps=1e-6, weight_decay=args.weight_decay) + # LR Finder + if args.find_lr: + lrs, 
losses = record_lr(model, train_loader, batch_transforms, optimizer, amp=args.amp) + plot_recorder(lrs, losses) + return + # Scheduler + if args.sched == 'cosine': + scheduler = CosineAnnealingLR(optimizer, args.epochs * len(train_loader), eta_min=args.lr / 25e4) + elif args.sched == 'onecycle': + scheduler = OneCycleLR(optimizer, args.lr, args.epochs * len(train_loader)) + + # Training monitoring + current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") + exp_name = f"{args.arch}_{current_time}" if args.name is None else args.name + + # W&B + if args.wb: + + run = wandb.init( + name=exp_name, + project="text-recognition", + config={ + "learning_rate": args.lr, + "epochs": args.epochs, + "weight_decay": args.weight_decay, + "batch_size": args.batch_size, + "architecture": args.arch, + "input_size": args.input_size, + "optimizer": "adam", + "framework": "pytorch", + "scheduler": args.sched, + "vocab": args.vocab, + "train_hash": train_hash, + "val_hash": val_hash, + "pretrained": args.pretrained, + } + ) + + # Create loss queue + min_loss = np.inf + # Training loop + mb = master_bar(range(args.epochs)) + for epoch in mb: + fit_one_epoch(model, train_loader, batch_transforms, optimizer, scheduler, mb, amp=args.amp) + + # Validation loop at the end of each epoch + val_loss, exact_match, partial_match = evaluate(model, val_loader, batch_transforms, val_metric, amp=args.amp) + if val_loss < min_loss: + print(f"Validation loss decreased {min_loss:.6} --> {val_loss:.6}: saving state...") + torch.save(model.state_dict(), f"./{exp_name}.pt") + min_loss = val_loss + mb.write(f"Epoch {epoch + 1}/{args.epochs} - Validation loss: {val_loss:.6} " + f"(Exact: {exact_match:.2%} | Partial: {partial_match:.2%})") + # W&B + if args.wb: + wandb.log({ + 'val_loss': val_loss, + 'exact_match': exact_match, + 'partial_match': partial_match, + }) + + if args.wb: + run.finish() + + +def parse_args(): + import argparse + parser = argparse.ArgumentParser(description='DocTR training script for text recognition (PyTorch)', + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + + parser.add_argument('train_path', type=str, help='path to train data folder(s)') + parser.add_argument('val_path', type=str, help='path to val data folder') + parser.add_argument('arch', type=str, help='text-recognition model to train') + parser.add_argument('--name', type=str, default=None, help='Name of your training experiment') + parser.add_argument('--epochs', type=int, default=10, help='number of epochs to train the model on') + parser.add_argument('-b', '--batch_size', type=int, default=64, help='batch size for training') + parser.add_argument('--device', default=None, type=int, help='device') + parser.add_argument('--input_size', type=int, default=32, help='input size H for the model, W = 4*H') + parser.add_argument('--lr', type=float, default=0.001, help='learning rate for the optimizer (Adam)') + parser.add_argument('--wd', '--weight-decay', default=0, type=float, help='weight decay', dest='weight_decay') + parser.add_argument('-j', '--workers', type=int, default=None, help='number of workers used for dataloading') + parser.add_argument('--resume', type=str, default=None, help='Path to your checkpoint') + parser.add_argument('--vocab', type=str, default="french", help='Vocab to be used for training') + parser.add_argument("--test-only", dest='test_only', action='store_true', help="Run the validation loop") + parser.add_argument('--show-samples', dest='show_samples', action='store_true', + help='Display unormalized 
training samples') + parser.add_argument('--wb', dest='wb', action='store_true', + help='Log to Weights & Biases') + parser.add_argument('--pretrained', dest='pretrained', action='store_true', + help='Load pretrained parameters before starting the training') + parser.add_argument('--sched', type=str, default='cosine', help='scheduler to use') + parser.add_argument("--amp", dest="amp", help="Use Automatic Mixed Precision", action="store_true") + parser.add_argument('--find-lr', action='store_true', help='Gridsearch the optimal LR') + args = parser.parse_args() + + return args + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/references/recognition/train_tensorflow.py b/references/recognition/train_tensorflow.py new file mode 100644 index 0000000000..aaff786bee --- /dev/null +++ b/references/recognition/train_tensorflow.py @@ -0,0 +1,332 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + +import os + +os.environ['USE_TF'] = '1' +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" + +import datetime +import hashlib +import multiprocessing as mp +import time +from pathlib import Path + +import numpy as np +import tensorflow as tf +import wandb +from fastprogress.fastprogress import master_bar, progress_bar +from tensorflow.keras import mixed_precision + +gpu_devices = tf.config.experimental.list_physical_devices('GPU') +if any(gpu_devices): + tf.config.experimental.set_memory_growth(gpu_devices[0], True) + +from doctr import transforms as T +from doctr.datasets import VOCABS, DataLoader, RecognitionDataset +from doctr.models import recognition +from doctr.utils.metrics import TextMatch +from utils import plot_recorder, plot_samples + + +def record_lr( + model: tf.keras.Model, + train_loader: DataLoader, + batch_transforms, + optimizer, + start_lr: float = 1e-7, + end_lr: float = 1, + num_it: int = 100, + amp: bool = False, +): + """Gridsearch the optimal learning rate for the training. 
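+    The LR is swept exponentially from `start_lr` to `end_lr` across `num_it` batches and the
+    corresponding training losses are recorded to help pick a suitable starting value.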
+ Adapted from https://github.com/frgfm/Holocron/blob/master/holocron/trainer/core.py + """ + + if num_it > len(train_loader): + raise ValueError("the value of `num_it` needs to be lower than the number of available batches") + + # Update param groups & LR + gamma = (end_lr / start_lr) ** (1 / (num_it - 1)) + optimizer.learning_rate = start_lr + + lr_recorder = [start_lr * gamma ** idx for idx in range(num_it)] + loss_recorder = [] + + for batch_idx, (images, targets) in enumerate(train_loader): + + images = batch_transforms(images) + + # Forward, Backward & update + with tf.GradientTape() as tape: + train_loss = model(images, targets, training=True)['loss'] + grads = tape.gradient(train_loss, model.trainable_weights) + + if amp: + grads = optimizer.get_unscaled_gradients(grads) + optimizer.apply_gradients(zip(grads, model.trainable_weights)) + + optimizer.learning_rate = optimizer.learning_rate * gamma + + # Record + train_loss = train_loss.numpy() + if np.any(np.isnan(train_loss)): + if batch_idx == 0: + raise ValueError("loss value is NaN or inf.") + else: + break + loss_recorder.append(train_loss.mean()) + # Stop after the number of iterations + if batch_idx + 1 == num_it: + break + + return lr_recorder[:len(loss_recorder)], loss_recorder + + +def fit_one_epoch(model, train_loader, batch_transforms, optimizer, mb, amp=False): + train_iter = iter(train_loader) + # Iterate over the batches of the dataset + for images, targets in progress_bar(train_iter, parent=mb): + + images = batch_transforms(images) + + with tf.GradientTape() as tape: + train_loss = model(images, targets, training=True)['loss'] + grads = tape.gradient(train_loss, model.trainable_weights) + if amp: + grads = optimizer.get_unscaled_gradients(grads) + optimizer.apply_gradients(zip(grads, model.trainable_weights)) + + mb.child.comment = f'Training loss: {train_loss.numpy().mean():.6}' + + +def evaluate(model, val_loader, batch_transforms, val_metric): + # Reset val metric + val_metric.reset() + # Validation loop + val_loss, batch_cnt = 0, 0 + val_iter = iter(val_loader) + for images, targets in val_iter: + images = batch_transforms(images) + out = model(images, targets, return_preds=True, training=False) + # Compute metric + if len(out['preds']): + words, _ = zip(*out['preds']) + else: + words = [] + val_metric.update(targets, words) + + val_loss += out['loss'].numpy().mean() + batch_cnt += 1 + + val_loss /= batch_cnt + result = val_metric.summary() + return val_loss, result['raw'], result['unicase'] + + +def main(args): + + print(args) + + if not isinstance(args.workers, int): + args.workers = min(16, mp.cpu_count()) + + # AMP + if args.amp: + mixed_precision.set_global_policy('mixed_float16') + + # Load val data generator + st = time.time() + val_set = RecognitionDataset( + img_folder=os.path.join(args.val_path, 'images'), + labels_path=os.path.join(args.val_path, 'labels.json'), + img_transforms=T.Resize((args.input_size, 4 * args.input_size), preserve_aspect_ratio=True), + ) + val_loader = DataLoader( + val_set, + batch_size=args.batch_size, + shuffle=False, + drop_last=False, + num_workers=args.workers, + ) + print(f"Validation set loaded in {time.time() - st:.4}s ({len(val_set)} samples in " + f"{val_loader.num_batches} batches)") + with open(os.path.join(args.val_path, 'labels.json'), 'rb') as f: + val_hash = hashlib.sha256(f.read()).hexdigest() + + # Load doctr model + model = recognition.__dict__[args.arch]( + pretrained=args.pretrained, + input_shape=(args.input_size, 4 * args.input_size, 3), + 
vocab=VOCABS[args.vocab] + ) + # Resume weights + if isinstance(args.resume, str): + model.load_weights(args.resume) + + # Metrics + val_metric = TextMatch() + + batch_transforms = T.Compose([ + T.Normalize(mean=(0.694, 0.695, 0.693), std=(0.299, 0.296, 0.301)), + ]) + + if args.test_only: + print("Running evaluation") + val_loss, exact_match, partial_match = evaluate(model, val_loader, batch_transforms, val_metric) + print(f"Validation loss: {val_loss:.6} (Exact: {exact_match:.2%} | Partial: {partial_match:.2%})") + return + + st = time.time() + + # Load train data generator + base_path = Path(args.train_path) + parts = [base_path] if base_path.joinpath('labels.json').is_file() else [ + base_path.joinpath(sub) for sub in os.listdir(base_path) + ] + train_set = RecognitionDataset( + parts[0].joinpath('images'), + parts[0].joinpath('labels.json'), + img_transforms=T.Compose([ + T.RandomApply(T.ColorInversion(), .1), + T.Resize((args.input_size, 4 * args.input_size), preserve_aspect_ratio=True), + # Augmentations + T.RandomJpegQuality(60), + T.RandomSaturation(.3), + T.RandomContrast(.3), + T.RandomBrightness(.3), + ]), + ) + + if len(parts) > 1: + for subfolder in parts[1:]: + train_set.merge_dataset(RecognitionDataset(subfolder.joinpath('images'), subfolder.joinpath('labels.json'))) + + train_loader = DataLoader( + train_set, + batch_size=args.batch_size, + shuffle=True, + drop_last=True, + num_workers=args.workers, + ) + print(f"Train set loaded in {time.time() - st:.4}s ({len(train_set)} samples in " + f"{train_loader.num_batches} batches)") + with open(parts[0].joinpath('labels.json'), 'rb') as f: + train_hash = hashlib.sha256(f.read()).hexdigest() + + if args.show_samples: + x, target = next(iter(train_loader)) + plot_samples(x, target) + return + + # Optimizer + scheduler = tf.keras.optimizers.schedules.ExponentialDecay( + args.lr, + decay_steps=args.epochs * len(train_loader), + decay_rate=1 / (25e4), # final lr as a fraction of initial lr + staircase=False + ) + optimizer = tf.keras.optimizers.Adam( + learning_rate=scheduler, + beta_1=0.95, + beta_2=0.99, + epsilon=1e-6, + clipnorm=5 + ) + if args.amp: + optimizer = mixed_precision.LossScaleOptimizer(optimizer) + # LR Finder + if args.find_lr: + lrs, losses = record_lr(model, train_loader, batch_transforms, optimizer, amp=args.amp) + plot_recorder(lrs, losses) + return + + # Tensorboard to monitor training + current_time = datetime.datetime.now().strftime("%Y%m%d-%H%M%S") + exp_name = f"{args.arch}_{current_time}" if args.name is None else args.name + + # W&B + if args.wb: + + run = wandb.init( + name=exp_name, + project="text-recognition", + config={ + "learning_rate": args.lr, + "epochs": args.epochs, + "weight_decay": 0., + "batch_size": args.batch_size, + "architecture": args.arch, + "input_size": args.input_size, + "optimizer": "adam", + "framework": "tensorflow", + "scheduler": "exp_decay", + "vocab": args.vocab, + "train_hash": train_hash, + "val_hash": val_hash, + "pretrained": args.pretrained, + } + ) + + min_loss = np.inf + + # Training loop + mb = master_bar(range(args.epochs)) + for epoch in mb: + fit_one_epoch(model, train_loader, batch_transforms, optimizer, mb, args.amp) + + # Validation loop at the end of each epoch + val_loss, exact_match, partial_match = evaluate(model, val_loader, batch_transforms, val_metric) + if val_loss < min_loss: + print(f"Validation loss decreased {min_loss:.6} --> {val_loss:.6}: saving state...") + model.save_weights(f'./{exp_name}/weights') + min_loss = val_loss + mb.write(f"Epoch {epoch 
+ 1}/{args.epochs} - Validation loss: {val_loss:.6} " + f"(Exact: {exact_match:.2%} | Partial: {partial_match:.2%})") + # W&B + if args.wb: + wandb.log({ + 'val_loss': val_loss, + 'exact_match': exact_match, + 'partial_match': partial_match, + }) + + if args.wb: + run.finish() + + +def parse_args(): + import argparse + parser = argparse.ArgumentParser(description='DocTR training script for text recognition (TensorFlow)', + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + + parser.add_argument('train_path', type=str, help='path to train data folder(s)') + parser.add_argument('val_path', type=str, help='path to val data folder') + parser.add_argument('arch', type=str, help='text-recognition model to train') + parser.add_argument('--name', type=str, default=None, help='Name of your training experiment') + parser.add_argument('--epochs', type=int, default=10, help='number of epochs to train the model on') + parser.add_argument('-b', '--batch_size', type=int, default=64, help='batch size for training') + parser.add_argument('--input_size', type=int, default=32, help='input size H for the model, W = 4*H') + parser.add_argument('--lr', type=float, default=0.001, help='learning rate for the optimizer (Adam)') + parser.add_argument('-j', '--workers', type=int, default=None, help='number of workers used for dataloading') + parser.add_argument('--resume', type=str, default=None, help='Path to your checkpoint') + parser.add_argument('--vocab', type=str, default="french", help='Vocab to be used for training') + parser.add_argument("--test-only", dest='test_only', action='store_true', help="Run the validation loop") + parser.add_argument('--show-samples', dest='show_samples', action='store_true', + help='Display unormalized training samples') + parser.add_argument('--wb', dest='wb', action='store_true', + help='Log to Weights & Biases') + parser.add_argument('--pretrained', dest='pretrained', action='store_true', + help='Load pretrained parameters before starting the training') + parser.add_argument("--amp", dest="amp", help="Use Automatic Mixed Precision", action="store_true") + parser.add_argument('--find-lr', action='store_true', help='Gridsearch the optimal LR') + args = parser.parse_args() + + return args + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/references/recognition/utils.py b/references/recognition/utils.py new file mode 100644 index 0000000000..91ec6fdb97 --- /dev/null +++ b/references/recognition/utils.py @@ -0,0 +1,73 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. 
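+# Plotting helpers shared by the recognition reference scripts (sample grids and LR-finder curves)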
+ +import math + +import matplotlib.pyplot as plt +import numpy as np + + +def plot_samples(images, targets): + # Unnormalize image + num_samples = min(len(images), 12) + num_cols = min(len(images), 4) + num_rows = int(math.ceil(num_samples / num_cols)) + _, axes = plt.subplots(num_rows, num_cols, figsize=(20, 5)) + for idx in range(num_samples): + img = (255 * images[idx].numpy()).round().clip(0, 255).astype(np.uint8) + if img.shape[0] == 3 and img.shape[2] != 3: + img = img.transpose(1, 2, 0) + + row_idx = idx // num_cols + col_idx = idx % num_cols + ax = axes[row_idx] if num_rows > 1 else axes + ax = ax[col_idx] if num_cols > 1 else ax + + ax.imshow(img) + ax.set_title(targets[idx]) + # Disable axis + for ax in axes.ravel(): + ax.axis('off') + + plt.show() + + +def plot_recorder(lr_recorder, loss_recorder, beta: float = 0.95, **kwargs) -> None: + """Display the results of the LR grid search. + Adapted from https://github.com/frgfm/Holocron/blob/master/holocron/trainer/core.py. + + Args: + lr_recorder: list of LR values + loss_recorder: list of loss values + beta (float, optional): smoothing factor + """ + + if len(lr_recorder) != len(loss_recorder) or len(lr_recorder) == 0: + raise AssertionError("Both `lr_recorder` and `loss_recorder` should have the same length") + + # Exp moving average of loss + smoothed_losses = [] + avg_loss = 0. + for idx, loss in enumerate(loss_recorder): + avg_loss = beta * avg_loss + (1 - beta) * loss + smoothed_losses.append(avg_loss / (1 - beta ** (idx + 1))) + + # Properly rescale Y-axis + data_slice = slice( + min(len(loss_recorder) // 10, 10), + -min(len(loss_recorder) // 20, 5) if len(loss_recorder) >= 20 else len(loss_recorder) + ) + vals = np.array(smoothed_losses[data_slice]) + min_idx = vals.argmin() + max_val = vals.max() if min_idx is None else vals[:min_idx + 1].max() # type: ignore[misc] + delta = max_val - vals[min_idx] + + plt.plot(lr_recorder[data_slice], smoothed_losses[data_slice]) + plt.xscale('log') + plt.xlabel('Learning Rate') + plt.ylabel('Training loss') + plt.ylim(vals[min_idx] - 0.1 * delta, max_val + 0.2 * delta) + plt.grid(True, linestyle='--', axis='x') + plt.show(**kwargs) diff --git a/references/requirements.txt b/references/requirements.txt new file mode 100644 index 0000000000..66fce93272 --- /dev/null +++ b/references/requirements.txt @@ -0,0 +1,3 @@ +-e . +fastprogress>=0.1.21 +wandb>=0.10.31 diff --git a/requirements-pt.txt b/requirements-pt.txt new file mode 100644 index 0000000000..95d17d6f34 --- /dev/null +++ b/requirements-pt.txt @@ -0,0 +1,16 @@ +numpy>=1.16.0 +scipy>=1.4.0 +h5py>=3.1.0 +opencv-python>=3.4.5.20 +PyMuPDF>=1.16.0,!=1.18.11,!=1.18.12 +pyclipper>=1.2.0 +shapely>=1.6.0 +matplotlib>=3.1.0,<3.4.3 +mplcursors>=0.3 +weasyprint>=52.2,<53.0 +unidecode>=1.0.0 +torch>=1.8.0 +torchvision>=0.9.0 +Pillow>=8.3.2 +tqdm>=4.30.0 +rapidfuzz>=1.6.0 diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000..5f0621b870 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,17 @@ +numpy>=1.16.0 +scipy>=1.4.0 +h5py>=3.1.0 +opencv-python>=3.4.5.20 +PyMuPDF>=1.16.0,!=1.18.11,!=1.18.12 +pyclipper>=1.2.0 +shapely>=1.6.0 +matplotlib>=3.1.0,<3.4.3 +mplcursors>=0.3 +weasyprint>=52.2,<53.0 +unidecode>=1.0.0 +tensorflow>=2.4.0 +Pillow>=8.3.2 +tqdm>=4.30.0 +tensorflow-addons>=0.13.0 +rapidfuzz>=1.6.0 +keras<2.7.0 diff --git a/scripts/analyze.py b/scripts/analyze.py new file mode 100644 index 0000000000..9332aa1bf0 --- /dev/null +++ b/scripts/analyze.py @@ -0,0 +1,58 @@ +# Copyright (C) 2021-2022, Mindee. 
+ +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + + +import os + +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" + +from doctr.file_utils import is_tf_available +from doctr.io import DocumentFile +from doctr.models import ocr_predictor + +# Enable GPU growth if using TF +if is_tf_available(): + import tensorflow as tf + gpu_devices = tf.config.experimental.list_physical_devices('GPU') + if any(gpu_devices): + tf.config.experimental.set_memory_growth(gpu_devices[0], True) + + +def main(args): + + model = ocr_predictor(args.detection, args.recognition, pretrained=True) + + if args.path.endswith(".pdf"): + doc = DocumentFile.from_pdf(args.path).as_images() + else: + doc = DocumentFile.from_images(args.path) + + out = model(doc) + + for page, img in zip(out.pages, doc): + page.show(img, block=not args.noblock, interactive=not args.static) + + +def parse_args(): + import argparse + parser = argparse.ArgumentParser(description='DocTR end-to-end analysis', + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + + parser.add_argument('path', type=str, help='Path to the input document (PDF or image)') + parser.add_argument('--detection', type=str, default='db_resnet50', + help='Text detection model to use for analysis') + parser.add_argument('--recognition', type=str, default='crnn_vgg16_bn', + help='Text recognition model to use for analysis') + parser.add_argument("--noblock", dest="noblock", help="Disables blocking visualization. Used only for CI.", + action="store_true") + parser.add_argument("--static", dest="static", help="Switches to static visualization", action="store_true") + args = parser.parse_args() + + return args + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/scripts/collect_env.py b/scripts/collect_env.py new file mode 100644 index 0000000000..a7fda9d9ee --- /dev/null +++ b/scripts/collect_env.py @@ -0,0 +1,353 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + +""" +Based on https://github.com/pytorch/pytorch/blob/master/torch/utils/collect_env.py +This script outputs relevant system environment info +Run it with `python collect_env.py`. 
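+The report covers the docTR, TensorFlow, PyTorch/torchvision and OpenCV versions, plus OS, Python and CUDA/cuDNN details.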
+""" + +from __future__ import absolute_import, division, print_function, unicode_literals + +import locale +import os +import re +import subprocess +import sys +from collections import namedtuple + +os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3' + +try: + import doctr + DOCTR_AVAILABLE = True +except (ImportError, NameError, AttributeError, OSError): + DOCTR_AVAILABLE = False + +try: + import tensorflow as tf + TF_AVAILABLE = True +except (ImportError, NameError, AttributeError, OSError): + TF_AVAILABLE = False + +try: + import torch + TORCH_AVAILABLE = True +except (ImportError, NameError, AttributeError, OSError): + TORCH_AVAILABLE = False + +try: + import torchvision + TV_AVAILABLE = True +except (ImportError, NameError, AttributeError, OSError): + TV_AVAILABLE = False + +try: + import cv2 + CV2_AVAILABLE = True +except (ImportError, NameError, AttributeError, OSError): + CV2_AVAILABLE = False + +PY3 = sys.version_info >= (3, 0) + + +# System Environment Information +SystemEnv = namedtuple('SystemEnv', [ + 'doctr_version', + 'tf_version', + 'torch_version', + 'torchvision_version', + 'cv2_version', + 'os', + 'python_version', + 'is_cuda_available_tf', + 'is_cuda_available_torch', + 'cuda_runtime_version', + 'nvidia_driver_version', + 'nvidia_gpu_models', + 'cudnn_version', +]) + + +def run(command): + """Returns (return-code, stdout, stderr)""" + p = subprocess.Popen(command, stdout=subprocess.PIPE, + stderr=subprocess.PIPE, shell=True) + output, err = p.communicate() + rc = p.returncode + if PY3: + enc = locale.getpreferredencoding() + output = output.decode(enc) + err = err.decode(enc) + return rc, output.strip(), err.strip() + + +def run_and_read_all(run_lambda, command): + """Runs command using run_lambda; reads and returns entire output if rc is 0""" + rc, out, _ = run_lambda(command) + if rc != 0: + return None + return out + + +def run_and_parse_first_match(run_lambda, command, regex): + """Runs command using run_lambda, returns the first regex match if it exists""" + rc, out, _ = run_lambda(command) + if rc != 0: + return None + match = re.search(regex, out) + if match is None: + return None + return match.group(1) + + +def get_nvidia_driver_version(run_lambda): + if get_platform() == 'darwin': + cmd = 'kextstat | grep -i cuda' + return run_and_parse_first_match(run_lambda, cmd, + r'com[.]nvidia[.]CUDA [(](.*?)[)]') + smi = get_nvidia_smi() + return run_and_parse_first_match(run_lambda, smi, r'Driver Version: (.*?) ') + + +def get_gpu_info(run_lambda): + if get_platform() == 'darwin': + if TF_AVAILABLE and any(tf.config.list_physical_devices('GPU')): + return tf.config.list_physical_devices('GPU')[0].name + return None + smi = get_nvidia_smi() + uuid_regex = re.compile(r' \(UUID: .+?\)') + rc, out, _ = run_lambda(smi + ' -L') + if rc != 0: + return None + # Anonymize GPUs by removing their UUID + return re.sub(uuid_regex, '', out) + + +def get_running_cuda_version(run_lambda): + return run_and_parse_first_match(run_lambda, 'nvcc --version', r'release .+ V(.*)') + + +def get_cudnn_version(run_lambda): + """This will return a list of libcudnn.so; it's hard to tell which one is being used""" + if get_platform() == 'win32': + cudnn_cmd = 'where /R "%CUDA_PATH%\\bin" cudnn*.dll' + elif get_platform() == 'darwin': + # CUDA libraries and drivers can be found in /usr/local/cuda/. 
See + # https://docs.nvidia.com/cuda/cuda-installation-guide-mac-os-x/index.html#install + # https://docs.nvidia.com/deeplearning/sdk/cudnn-install/index.html#installmac + # Use CUDNN_LIBRARY when cudnn library is installed elsewhere. + cudnn_cmd = 'ls /usr/local/cuda/lib/libcudnn*' + else: + cudnn_cmd = 'ldconfig -p | grep libcudnn | rev | cut -d" " -f1 | rev' + rc, out, _ = run_lambda(cudnn_cmd) + # find will return 1 if there are permission errors or if not found + if len(out) == 0 or (rc != 1 and rc != 0): + lib = os.environ.get('CUDNN_LIBRARY') + if lib is not None and os.path.isfile(lib): + return os.path.realpath(lib) + return None + files = set() + for fn in out.split('\n'): + fn = os.path.realpath(fn) # eliminate symbolic links + if os.path.isfile(fn): + files.add(fn) + if not files: + return None + # Alphabetize the result because the order is non-deterministic otherwise + files = list(sorted(files)) + if len(files) == 1: + return files[0] + result = '\n'.join(files) + return 'Probably one of the following:\n{}'.format(result) + + +def get_nvidia_smi(): + # Note: nvidia-smi is currently available only on Windows and Linux + smi = 'nvidia-smi' + if get_platform() == 'win32': + smi = '"C:\\Program Files\\NVIDIA Corporation\\NVSMI\\%s"' % smi + return smi + + +def get_platform(): + if sys.platform.startswith('linux'): + return 'linux' + elif sys.platform.startswith('win32'): + return 'win32' + elif sys.platform.startswith('cygwin'): + return 'cygwin' + elif sys.platform.startswith('darwin'): + return 'darwin' + else: + return sys.platform + + +def get_mac_version(run_lambda): + return run_and_parse_first_match(run_lambda, 'sw_vers -productVersion', r'(.*)') + + +def get_windows_version(run_lambda): + return run_and_read_all(run_lambda, 'wmic os get Caption | findstr /v Caption') + + +def get_lsb_version(run_lambda): + return run_and_parse_first_match(run_lambda, 'lsb_release -a', r'Description:\t(.*)') + + +def check_release_file(run_lambda): + return run_and_parse_first_match(run_lambda, 'cat /etc/*-release', + r'PRETTY_NAME="(.*)"') + + +def get_os(run_lambda): + platform = get_platform() + + if platform == 'win32' or platform == 'cygwin': + return get_windows_version(run_lambda) + + if platform == 'darwin': + version = get_mac_version(run_lambda) + if version is None: + return None + return 'Mac OSX {}'.format(version) + + if platform == 'linux': + # Ubuntu/Debian based + desc = get_lsb_version(run_lambda) + if desc is not None: + return desc + + # Try reading /etc/*-release + desc = check_release_file(run_lambda) + if desc is not None: + return desc + + return platform + + # Unknown platform + return platform + + +def get_env_info(): + run_lambda = run + + doctr_str = doctr.__version__ if DOCTR_AVAILABLE else 'N/A' + + if TF_AVAILABLE: + tf_str = tf.__version__ + tf_cuda_available_str = any(tf.config.list_physical_devices('GPU')) + else: + tf_str = tf_cuda_available_str = 'N/A' + + if TORCH_AVAILABLE: + torch_str = torch.__version__ + torch_cuda_available_str = torch.cuda.is_available() + else: + torch_str = torch_cuda_available_str = 'N/A' + + tv_str = torchvision.__version__ if TV_AVAILABLE else 'N/A' + + cv2_str = cv2.__version__ if CV2_AVAILABLE else 'N/A' + + return SystemEnv( + doctr_version=doctr_str, + tf_version=tf_str, + torch_version=torch_str, + torchvision_version=tv_str, + cv2_version=cv2_str, + python_version=".".join(map(str, sys.version_info[:3])), + is_cuda_available_tf=tf_cuda_available_str, + is_cuda_available_torch=torch_cuda_available_str, + 
cuda_runtime_version=get_running_cuda_version(run_lambda), + nvidia_gpu_models=get_gpu_info(run_lambda), + nvidia_driver_version=get_nvidia_driver_version(run_lambda), + cudnn_version=get_cudnn_version(run_lambda), + os=get_os(run_lambda), + ) + + +env_info_fmt = """ +DocTR version: {doctr_version} +TensorFlow version: {tf_version} +PyTorch version: {torch_version} (torchvision {torchvision_version}) +OpenCV version: {cv2_version} +OS: {os} +Python version: {python_version} +Is CUDA available (TensorFlow): {is_cuda_available_tf} +Is CUDA available (PyTorch): {is_cuda_available_torch} +CUDA runtime version: {cuda_runtime_version} +GPU models and configuration: {nvidia_gpu_models} +Nvidia driver version: {nvidia_driver_version} +cuDNN version: {cudnn_version} +""".strip() + + +def pretty_str(envinfo): + def replace_nones(dct, replacement='Could not collect'): + for key in dct.keys(): + if dct[key] is not None: + continue + dct[key] = replacement + return dct + + def replace_bools(dct, true='Yes', false='No'): + for key in dct.keys(): + if dct[key] is True: + dct[key] = true + elif dct[key] is False: + dct[key] = false + return dct + + def maybe_start_on_next_line(string): + # If `string` is multiline, prepend a \n to it. + if string is not None and len(string.split('\n')) > 1: + return '\n{}\n'.format(string) + return string + + mutable_dict = envinfo._asdict() + + # If nvidia_gpu_models is multiline, start on the next line + mutable_dict['nvidia_gpu_models'] = \ + maybe_start_on_next_line(envinfo.nvidia_gpu_models) + + # If the machine doesn't have CUDA, report some fields as 'No CUDA' + dynamic_cuda_fields = [ + 'cuda_runtime_version', + 'nvidia_gpu_models', + 'nvidia_driver_version', + ] + all_cuda_fields = dynamic_cuda_fields + ['cudnn_version'] + all_dynamic_cuda_fields_missing = all( + mutable_dict[field] is None for field in dynamic_cuda_fields) + if TF_AVAILABLE and not any(tf.config.list_physical_devices('GPU')) and all_dynamic_cuda_fields_missing: + for field in all_cuda_fields: + mutable_dict[field] = 'No CUDA' + + # Replace True with Yes, False with No + mutable_dict = replace_bools(mutable_dict) + + # Replace all None objects with 'Could not collect' + mutable_dict = replace_nones(mutable_dict) + + return env_info_fmt.format(**mutable_dict) + + +def get_pretty_env_info(): + """Collects environment information for debugging purposes + Returns: + str: environment information + """ + return pretty_str(get_env_info()) + + +def main(): + print("Collecting environment information...\n") + output = get_pretty_env_info() + print(output) + + +if __name__ == '__main__': + main() diff --git a/scripts/detect_artefacts.py b/scripts/detect_artefacts.py new file mode 100644 index 0000000000..af565854a7 --- /dev/null +++ b/scripts/detect_artefacts.py @@ -0,0 +1,86 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. 
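+
+# Usage sketch (the architecture name below is only an example; any model exposed by
+# doctr.models.obj_detection can be passed, together with a path to an image):
+#   python scripts/detect_artefacts.py fasterrcnn_mobilenet_v3_large_fpn path/to/page.jpg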
+ + +import os + +os.environ['USE_TORCH'] = '1' + +import argparse +import logging + +import cv2 +import matplotlib.pyplot as plt +import torch + +from doctr.io.image import read_img_as_tensor +from doctr.models import obj_detection + +CLASSES = ["__background__", "QR Code", "Barcode", "Logo", "Photo"] +CM = [(255, 255, 255), (0, 0, 150), (0, 0, 0), (0, 150, 0), (150, 0, 0)] + + +def plot_predictions(image, boxes, labels): + for box, label in zip(boxes, labels): + # Bounding box around artefacts + cv2.rectangle(image, (box[0], box[1]), (box[2], box[3]), + CM[label], 2) + text_size, _ = cv2.getTextSize(CLASSES[label], cv2.FONT_HERSHEY_SIMPLEX, 2, 2) + text_w, text_h = text_size + # Filled rectangle above bounding box + cv2.rectangle(image, (box[0], box[1]), (box[0] + text_w, box[1] - text_h), + CM[label], -1) + # Text bearing the name of the artefact detected + cv2.putText(image, CLASSES[label], (int(box[0]), int(box[1])), + cv2.FONT_HERSHEY_SIMPLEX, 2, (255, 255, 255), 3) + plt.axis('off') + plt.imshow(image) + plt.show() + + +@torch.no_grad() +def main(args): + print(args) + + model = obj_detection.__dict__[args.arch](pretrained=True, num_classes=5).eval() + # GPU + if isinstance(args.device, int): + if not torch.cuda.is_available(): + raise AssertionError("PyTorch cannot access your GPU. Please investigate!") + if args.device >= torch.cuda.device_count(): + raise ValueError("Invalid device index") + # Silent default switch to GPU if available + elif torch.cuda.is_available(): + args.device = 0 + else: + logging.warning("No accessible GPU, target device set to CPU.") + img = read_img_as_tensor(args.img_path).unsqueeze(0) + if torch.cuda.is_available(): + torch.cuda.set_device(args.device) + model = model.cuda() + img = img.cuda() + + pred = model(img) + labels = pred[0]['labels'].detach().cpu().numpy() + labels = labels.round().astype(int) + boxes = pred[0]['boxes'].detach().cpu().numpy() + boxes = boxes.round().astype(int) + img = img.cpu().permute(0, 2, 3, 1).numpy()[0].copy() + plot_predictions(img, boxes, labels) + + +def parse_args(): + parser = argparse.ArgumentParser(description="Artefact detection model to use", + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + parser.add_argument('arch', type=str, help='Artefact detection model to use') + parser.add_argument('img_path', type=str, help='path to the image') + parser.add_argument('--device', default=None, type=int, help='device') + args = parser.parse_args() + return args + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/scripts/evaluate.py b/scripts/evaluate.py new file mode 100644 index 0000000000..dcf4da1e0f --- /dev/null +++ b/scripts/evaluate.py @@ -0,0 +1,172 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. 
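+
+# Usage sketch (any supported detection/recognition pair works; FUNSD is the default benchmark):
+#   python scripts/evaluate.py db_resnet50 crnn_vgg16_bn --dataset FUNSD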
+ +import os + +os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2" + +import numpy as np +from tqdm import tqdm + +from doctr import datasets +from doctr.file_utils import is_tf_available +from doctr.models import ocr_predictor +from doctr.models._utils import extract_crops +from doctr.utils.metrics import LocalizationConfusion, OCRMetric, TextMatch + +# Enable GPU growth if using TF +if is_tf_available(): + import tensorflow as tf + gpu_devices = tf.config.experimental.list_physical_devices('GPU') + if any(gpu_devices): + tf.config.experimental.set_memory_growth(gpu_devices[0], True) +else: + import torch + + +def _pct(val): + return "N/A" if val is None else f"{val:.2%}" + + +def main(args): + + predictor = ocr_predictor(args.detection, args.recognition, pretrained=True, reco_bs=args.batch_size) + + if args.img_folder and args.label_file: + testset = datasets.OCRDataset( + img_folder=args.img_folder, + label_file=args.label_file, + ) + sets = [testset] + else: + train_set = datasets.__dict__[args.dataset](train=True, download=True, use_polygons=args.rotation) + val_set = datasets.__dict__[args.dataset](train=False, download=True, use_polygons=args.rotation) + sets = [train_set, val_set] + + reco_metric = TextMatch() + if args.rotation and args.mask_shape: + det_metric = LocalizationConfusion( + iou_thresh=args.iou, + use_polygons=args.rotation, + mask_shape=(args.mask_shape, args.mask_shape) + ) + e2e_metric = OCRMetric( + iou_thresh=args.iou, + use_polygons=args.rotation, + mask_shape=(args.mask_shape, args.mask_shape) + ) + else: + det_metric = LocalizationConfusion(iou_thresh=args.iou, use_polygons=args.rotation) + e2e_metric = OCRMetric(iou_thresh=args.iou, use_polygons=args.rotation) + + sample_idx = 0 + for dataset in sets: + for page, target in tqdm(dataset): + # GT + gt_boxes = target['boxes'] + gt_labels = target['labels'] + + if args.img_folder and args.label_file: + x, y, w, h = gt_boxes[:, 0], gt_boxes[:, 1], gt_boxes[:, 2], gt_boxes[:, 3] + xmin, ymin = np.clip(x - w / 2, 0, 1), np.clip(y - h / 2, 0, 1) + xmax, ymax = np.clip(x + w / 2, 0, 1), np.clip(y + h / 2, 0, 1) + gt_boxes = np.stack([xmin, ymin, xmax, ymax], axis=-1) + + # Forward + if is_tf_available(): + out = predictor(page[None, ...]) + crops = extract_crops(page, gt_boxes) + reco_out = predictor.reco_predictor(crops) + else: + with torch.no_grad(): + out = predictor(page[None, ...]) + # We directly crop on PyTorch tensors, which are in channels_first + crops = extract_crops(page, gt_boxes, channels_last=False) + reco_out = predictor.reco_predictor(crops) + + if len(reco_out): + reco_words, _ = zip(*reco_out) + else: + reco_words = [] + + # Unpack preds + pred_boxes = [] + pred_labels = [] + for page in out.pages: + height, width = page.dimensions + for block in page.blocks: + for line in block.lines: + for word in line.words: + if not args.rotation: + (a, b), (c, d) = word.geometry + else: + [x1, y1], [x2, y2], [x3, y3], [x4, y4], = word.geometry + if gt_boxes.dtype == int: + if not args.rotation: + pred_boxes.append([int(a * width), int(b * height), + int(c * width), int(d * height)]) + else: + pred_boxes.append( + [ + [int(x1 * width), int(y1 * height)], + [int(x2 * width), int(y2 * height)], + [int(x3 * width), int(y3 * height)], + [int(x4 * width), int(y4 * height)], + ] + ) + else: + if not args.rotation: + pred_boxes.append([a, b, c, d]) + else: + pred_boxes.append([[x1, y1], [x2, y2], [x3, y3], [x4, y4]]) + pred_labels.append(word.value) + + # Update the metric + det_metric.update(gt_boxes, np.asarray(pred_boxes)) + 
reco_metric.update(gt_labels, reco_words) + e2e_metric.update(gt_boxes, np.asarray(pred_boxes), gt_labels, pred_labels) + + # Loop break + sample_idx += 1 + if isinstance(args.samples, int) and args.samples == sample_idx: + break + if isinstance(args.samples, int) and args.samples == sample_idx: + break + + # Unpack aggregated metrics + print(f"Model Evaluation (model= {args.detection} + {args.recognition}, " + f"dataset={'OCRDataset' if args.img_folder else args.dataset})") + recall, precision, mean_iou = det_metric.summary() + print(f"Text Detection - Recall: {_pct(recall)}, Precision: {_pct(precision)}, Mean IoU: {_pct(mean_iou)}") + acc = reco_metric.summary() + print(f"Text Recognition - Accuracy: {_pct(acc['raw'])} (unicase: {_pct(acc['unicase'])})") + recall, precision, mean_iou = e2e_metric.summary() + print(f"OCR - Recall: {_pct(recall['raw'])} (unicase: {_pct(recall['unicase'])}), " + f"Precision: {_pct(precision['raw'])} (unicase: {_pct(precision['unicase'])}), Mean IoU: {_pct(mean_iou)}") + + +def parse_args(): + import argparse + parser = argparse.ArgumentParser(description='DocTR end-to-end evaluation', + formatter_class=argparse.ArgumentDefaultsHelpFormatter) + + parser.add_argument('detection', type=str, help='Text detection model to use for analysis') + parser.add_argument('recognition', type=str, help='Text recognition model to use for analysis') + parser.add_argument('--iou', type=float, default=0.5, help='IoU threshold to match a pair of boxes') + parser.add_argument('--dataset', type=str, default='FUNSD', help='choose a dataset: FUNSD, CORD') + parser.add_argument('--img_folder', type=str, default=None, help='Only for local sets, path to images') + parser.add_argument('--label_file', type=str, default=None, help='Only for local sets, path to labels') + parser.add_argument('--rotation', dest='rotation', action='store_true', help='evaluate with rotated bbox') + parser.add_argument('-b', '--batch_size', type=int, default=32, help='batch size for recognition') + parser.add_argument('--mask_shape', type=int, default=None, help='mask shape for mask iou (only for rotation)') + parser.add_argument('--samples', type=int, default=None, help='evaluate only on the N first samples') + args = parser.parse_args() + + return args + + +if __name__ == "__main__": + args = parse_args() + main(args) diff --git a/setup.cfg b/setup.cfg new file mode 100644 index 0000000000..f48fdadb8a --- /dev/null +++ b/setup.cfg @@ -0,0 +1,3 @@ +[metadata] +description-file = README.md +license_file = LICENSE diff --git a/setup.py b/setup.py new file mode 100644 index 0000000000..0556bf26ff --- /dev/null +++ b/setup.py @@ -0,0 +1,202 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. 
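+
+# The extras declared below map onto the usual install commands, e.g.:
+#   pip install "python-doctr[tf]"     # TensorFlow backend
+#   pip install "python-doctr[torch]"  # PyTorch backend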
+ +""" +Package installation setup +""" + +import os +import re +import subprocess +from pathlib import Path + +from setuptools import find_packages, setup + +version = "0.5.1a0" +sha = 'Unknown' +src_folder = 'doctr' +package_index = 'python-doctr' + +cwd = Path(__file__).parent.absolute() + +if os.getenv('BUILD_VERSION'): + version = os.getenv('BUILD_VERSION') +elif sha != 'Unknown': + try: + sha = subprocess.check_output(['git', 'rev-parse', 'HEAD'], cwd=cwd).decode('ascii').strip() + except Exception: + pass + version += '+' + sha[:7] +print(f"Building wheel {package_index}-{version}") + +with open(cwd.joinpath(src_folder, 'version.py'), 'w') as f: + f.write(f"__version__ = '{version}'\n") + +with open('README.md', 'r') as f: + readme = f.read() + +# Borrowed from https://github.com/huggingface/transformers/blob/master/setup.py +_deps = [ + "importlib_metadata", + "numpy>=1.16.0", + "scipy>=1.4.0", + "h5py>=3.1.0", + "opencv-python>=3.4.5.20", + "tensorflow>=2.4.0", + "PyMuPDF>=1.16.0,!=1.18.11,!=1.18.12", # 18.11 and 18.12 fail (issue #222) + "pyclipper>=1.2.0", + "shapely>=1.6.0", + "matplotlib>=3.1.0,<3.4.3", + "mplcursors>=0.3", + "weasyprint>=52.2,<53.0", + "unidecode>=1.0.0", + "tensorflow-cpu>=2.4.0", + "torch>=1.8.0", + "torchvision>=0.9.0", + "Pillow>=8.3.2", # cf. https://github.com/advisories/GHSA-98vv-pw6r-q6q4 + "tqdm>=4.30.0", + "tensorflow-addons>=0.13.0", + "rapidfuzz>=1.6.0", + "keras<2.7.0", + # Testing + "pytest>=5.3.2", + "coverage>=4.5.4", + "hdf5storage>=0.1.18", + "requests>=2.20.0", + "requirements-parser==0.2.0", + # Quality + "flake8>=3.9.0", + "isort>=5.7.0", + "mypy>=0.812", + "pydocstyle>=6.1.1", + # Docs + "sphinx<3.5.0", + "sphinx-rtd-theme==0.4.3", + "sphinxemoji>=0.1.8", + "sphinx-copybutton>=0.3.1", + "docutils<0.18", + "recommonmark>=0.7.1", + "sphinx-markdown-tables>=0.0.15", +] + +deps = {b: a for a, b in (re.findall(r"^(([^!=<>]+)(?:[!=<>].*)?$)", x)[0] for x in _deps)} + + +def deps_list(*pkgs): + return [deps[pkg] for pkg in pkgs] + + +install_requires = [ + deps["importlib_metadata"] + ";python_version<'3.8'", # importlib_metadata for Python versions that don't have it + deps["numpy"], + deps["scipy"], + deps["h5py"], + deps["opencv-python"], + deps["PyMuPDF"], + deps["pyclipper"], + deps["shapely"], + deps["matplotlib"], + deps["mplcursors"], + deps["weasyprint"], + deps["unidecode"], + deps["Pillow"], + deps["tqdm"], + deps["rapidfuzz"], +] + +extras = {} +extras["tf"] = deps_list( + "tensorflow", + "tensorflow-addons", + "keras", +) + +extras["tf-cpu"] = deps_list( + "tensorflow-cpu", + "tensorflow-addons", + "keras", +) + +extras["torch"] = deps_list( + "torch", + "torchvision", +) + +extras["all"] = ( + extras["tf"] + + extras["torch"] +) + +extras["testing"] = deps_list( + "pytest", + "coverage", + "requests", + "hdf5storage", + "requirements-parser", +) + +extras["quality"] = deps_list( + "flake8", + "isort", + "mypy", + "pydocstyle", +) + +extras["docs_specific"] = deps_list( + "sphinx", + "sphinx-rtd-theme", + "sphinxemoji", + "sphinx-copybutton", + "docutils", + "recommonmark", + "sphinx-markdown-tables", +) + +extras["docs"] = extras["all"] + extras["docs_specific"] + +extras["dev"] = ( + extras["all"] + + extras["testing"] + + extras["quality"] + + extras["docs_specific"] +) + +setup( + # Metadata + name=package_index, + version=version, + author='Mindee', + author_email='contact@mindee.com', + maintainer='François-Guillaume Fernandez, Charles Gaillard', + description='Document Text Recognition (docTR): deep Learning for 
high-performance OCR on documents.', + long_description=readme, + long_description_content_type="text/markdown", + url='https://github.com/mindee/doctr', + download_url='https://github.com/mindee/doctr/tags', + license='Apache', + classifiers=[ + 'Development Status :: 4 - Beta', + 'Intended Audience :: Developers', + "Intended Audience :: Education", + 'Intended Audience :: Science/Research', + 'License :: OSI Approved :: Apache Software License', + 'Natural Language :: English', + 'Operating System :: OS Independent', + 'Programming Language :: Python :: 3', + 'Programming Language :: Python :: 3.6', + 'Programming Language :: Python :: 3.7', + 'Topic :: Scientific/Engineering :: Artificial Intelligence', + ], + keywords=['OCR', 'deep learning', 'computer vision', 'tensorflow', 'pytorch', 'text detection', 'text recognition'], + + # Package info + packages=find_packages(exclude=('tests',)), + zip_safe=True, + python_requires='>=3.6.0', + include_package_data=True, + install_requires=install_requires, + extras_require=extras, + package_data={'': ['LICENSE']} +) diff --git a/tests/common/test_core.py b/tests/common/test_core.py new file mode 100644 index 0000000000..6731af80a9 --- /dev/null +++ b/tests/common/test_core.py @@ -0,0 +1,13 @@ +import doctr + + +def test_version(): + assert len(doctr.__version__.split('.')) == 3 + + +def test_is_tf_available(): + assert doctr.is_tf_available() + + +def test_is_torch_available(): + assert not doctr.is_torch_available() diff --git a/tests/common/test_datasets.py b/tests/common/test_datasets.py new file mode 100644 index 0000000000..914a48e1e4 --- /dev/null +++ b/tests/common/test_datasets.py @@ -0,0 +1,44 @@ +from pathlib import Path + +import numpy as np +import pytest + +from doctr import datasets + + +def test_visiondataset(): + url = 'https://data.deepai.org/mnist.zip' + with pytest.raises(ValueError): + datasets.datasets.VisionDataset(url, download=False) + + dataset = datasets.datasets.VisionDataset(url, download=True, extract_archive=True) + assert len(dataset) == 0 + assert repr(dataset) == 'VisionDataset()' + + +def test_abstractdataset(mock_image_path): + + with pytest.raises(ValueError): + datasets.datasets.AbstractDataset('my/fantasy/folder') + + # Check transforms + path = Path(mock_image_path) + ds = datasets.datasets.AbstractDataset(path.parent) + # Patch some data + ds.data = [(path.name, 0)] + + # Fetch the img + img, target = ds[0] + assert isinstance(target, int) and target == 0 + + # Check img_transforms + ds.img_transforms = lambda x: 1 - x + img2, target2 = ds[0] + assert np.all(img2.numpy() == 1 - img.numpy()) + assert target == target2 + + # Check sample_transforms + ds.img_transforms = None + ds.sample_transforms = lambda x, y: (x, y + 1) + img3, target3 = ds[0] + assert np.all(img3.numpy() == img.numpy()) and (target3 == (target + 1)) diff --git a/tests/common/test_datasets_utils.py b/tests/common/test_datasets_utils.py new file mode 100644 index 0000000000..b69c5d1daf --- /dev/null +++ b/tests/common/test_datasets_utils.py @@ -0,0 +1,74 @@ +import numpy as np +import pytest + +from doctr.datasets import utils + + +@pytest.mark.parametrize( + "input_str, vocab, output_str", + [ + ['f orêt', 'latin', 'foret'], + ['f or êt', 'french', 'forêt'], + ['¢¾©téØßřůž', 'french', '¢■■té■■ruz'], + ['Ûæëð', 'french', 'Û■ë■'], + ['Ûæë<àð', 'latin', 'U■erêi`l""" + encoded = utils.encode_string(input_str, mapping) + decoded = utils.decode_sequence(encoded, mapping) + assert decoded == input_str + + +def test_decode_sequence(): + 
mapping = "abcdef" + with pytest.raises(TypeError): + utils.decode_sequence(123, mapping) + with pytest.raises(AssertionError): + utils.decode_sequence(np.array([2, 10]), mapping) + with pytest.raises(AssertionError): + utils.decode_sequence(np.array([2, 4.5]), mapping) + + assert utils.decode_sequence([3, 4, 3, 4], mapping) == "dede" + + +@pytest.mark.parametrize( + "sequences, vocab, target_size, sos, eos, pad, dynamic_len, error, out_shape, gts", + [ + [['cba'], 'abcdef', None, None, 1, None, False, True, (1, 3), [[2, 1, 0]]], # eos in vocab + [['cba', 'a'], 'abcdef', None, None, -1, None, False, False, (2, 4), [[2, 1, 0, -1], [0, -1, -1, -1]]], + [['cba', 'a'], 'abcdef', None, None, 6, None, False, False, (2, 4), [[2, 1, 0, 6], [0, 6, 6, 6]]], + [['cba', 'a'], 'abcdef', 2, None, -1, None, False, False, (2, 2), [[2, 1], [0, -1]]], + [['cba', 'a'], 'abcdef', 4, None, -1, None, False, False, (2, 4), [[2, 1, 0, -1], [0, -1, -1, -1]]], + [['cba', 'a'], 'abcdef', 5, 7, -1, None, False, False, (2, 5), [[7, 2, 1, 0, -1], [7, 0, -1, -1, -1]]], + [['cba', 'a'], 'abcdef', 6, 7, -1, None, True, False, (2, 5), [[7, 2, 1, 0, -1], [7, 0, -1, -1, -1]]], + [['cba', 'a'], 'abcdef', None, 7, -1, 9, False, False, (2, 6), [[7, 2, 1, 0, -1, 9], [7, 0, -1, 9, 9, 9]]], + ], +) +def test_encode_sequences(sequences, vocab, target_size, sos, eos, pad, dynamic_len, error, out_shape, gts): + if error: + with pytest.raises(ValueError): + utils.encode_sequences(sequences, vocab, target_size, eos, sos, pad, dynamic_len) + else: + out = utils.encode_sequences(sequences, vocab, target_size, eos, sos, pad, dynamic_len) + assert isinstance(out, np.ndarray) + assert out.shape == out_shape + assert np.all(out == np.asarray(gts)), print(out, gts) diff --git a/tests/common/test_headers.py b/tests/common/test_headers.py new file mode 100644 index 0000000000..1454a9aafc --- /dev/null +++ b/tests/common/test_headers.py @@ -0,0 +1,46 @@ +from datetime import datetime +from pathlib import Path + + +def test_headers(): + + shebang = ["#!usr/bin/python\n"] + blank_line = "\n" + + _copyright_str = f"-{datetime.now().year}" if datetime.now().year > 2021 else "" + copyright_notice = [f"# Copyright (C) 2021{_copyright_str}, Mindee.\n"] + license_notice = [ + "# This program is licensed under the Apache License version 2.\n", + "# See LICENSE or go to for full license details.\n" + ] + + # Define all header options + headers = [ + shebang + [blank_line] + copyright_notice + [blank_line] + license_notice, + copyright_notice + [blank_line] + license_notice + ] + + excluded_files = ["version.py", "__init__.py"] + invalid_files = [] + + # For every python file in the repository + folders_to_check = [["doctr"], ["api", "app"], ["demo"], ["scripts"], ["references"]] + for folder in folders_to_check: + for source_path in Path(__file__).parent.parent.parent.joinpath(*folder).rglob('*.py'): + if source_path.name not in excluded_files: + # Parse header + header_length = max(len(option) for option in headers) + current_header = [] + with open(source_path) as f: + for idx, line in enumerate(f): + current_header.append(line) + if idx == header_length - 1: + break + # Validate it + if not any( + "".join(current_header[:min(len(option), len(current_header))]) == "".join(option) + for option in headers + ): + invalid_files.append(source_path) + + assert len(invalid_files) == 0, f"Invalid header in the following files: {invalid_files}" diff --git a/tests/common/test_io.py b/tests/common/test_io.py new file mode 100644 index 0000000000..04ecf6ec4e --- 
/dev/null +++ b/tests/common/test_io.py @@ -0,0 +1,143 @@ +from io import BytesIO + +import fitz +import numpy as np +import pytest +import requests + +from doctr import io + + +def test_convert_page_to_numpy(mock_pdf): + pdf = fitz.open(mock_pdf) + # Check correct read + rgb_page = io.pdf.convert_page_to_numpy(pdf[0], default_scales=(1, 1)) + assert isinstance(rgb_page, np.ndarray) + assert rgb_page.shape == (842, 595, 3) + + # Check channel order + bgr_page = io.pdf.convert_page_to_numpy(pdf[0], default_scales=(1, 1), bgr_output=True) + assert np.all(bgr_page == rgb_page[..., ::-1]) + + # Check resizing + resized_page = io.pdf.convert_page_to_numpy(pdf[0], output_size=(396, 306)) + assert resized_page.shape == (396, 306, 3) + + # Check rescaling + rgb_page = io.pdf.convert_page_to_numpy(pdf[0]) + assert isinstance(rgb_page, np.ndarray) + assert rgb_page.shape == (1684, 1190, 3) + + +def _check_doc_content(doc_tensors, num_pages): + # 1 doc of 8 pages + assert(len(doc_tensors) == num_pages) + assert all(isinstance(page, np.ndarray) for page in doc_tensors) + assert all(page.dtype == np.uint8 for page in doc_tensors) + + +def test_read_pdf(mock_pdf): + doc = io.read_pdf(mock_pdf) + assert isinstance(doc, fitz.Document) + + with open(mock_pdf, 'rb') as f: + doc = io.read_pdf(f.read()) + assert isinstance(doc, fitz.Document) + + # Wrong input type + with pytest.raises(TypeError): + _ = io.read_pdf(123) + + # Wrong path + with pytest.raises(FileNotFoundError): + _ = io.read_pdf("my_imaginary_file.pdf") + + +def test_read_img_as_numpy(tmpdir_factory, mock_pdf): + + # Wrong input type + with pytest.raises(TypeError): + _ = io.read_img_as_numpy(123) + + # Non-existing file + with pytest.raises(FileNotFoundError): + io.read_img_as_numpy("my_imaginary_file.jpg") + + # Invalid image + with pytest.raises(ValueError): + io.read_img_as_numpy(str(mock_pdf)) + + # From path + url = 'https://github.com/mindee/doctr/releases/download/v0.2.1/Grace_Hopper.jpg' + file = BytesIO(requests.get(url).content) + tmp_path = str(tmpdir_factory.mktemp("data").join("mock_img_file.jpg")) + with open(tmp_path, 'wb') as f: + f.write(file.getbuffer()) + + # Path & stream + with open(tmp_path, 'rb') as f: + page_stream = io.read_img_as_numpy(f.read()) + + for page in (io.read_img_as_numpy(tmp_path), page_stream): + # Data type + assert isinstance(page, np.ndarray) + assert page.dtype == np.uint8 + # Shape + assert page.shape == (606, 517, 3) + + # RGB + bgr_page = io.read_img_as_numpy(tmp_path, rgb_output=False) + assert np.all(page == bgr_page[..., ::-1]) + + # Resize + target_size = (200, 150) + resized_page = io.read_img_as_numpy(tmp_path, target_size) + assert resized_page.shape[:2] == target_size + + +def test_read_html(): + url = "https://www.google.com" + pdf_stream = io.read_html(url) + assert isinstance(pdf_stream, bytes) + + +def test_document_file(mock_pdf, mock_image_stream): + pages = io.DocumentFile.from_images(mock_image_stream) + _check_doc_content(pages, 1) + + assert isinstance(io.DocumentFile.from_pdf(mock_pdf).doc, fitz.Document) + assert isinstance(io.DocumentFile.from_url("https://www.google.com").doc, fitz.Document) + + +def test_pdf(mock_pdf): + + doc = io.DocumentFile.from_pdf(mock_pdf) + + # As images + pages = doc.as_images() + num_pages = 2 + _check_doc_content(pages, num_pages) + + # Get words + words = doc.get_words() + assert isinstance(words, list) and len(words) == num_pages + assert len([word for page_words in words for word in page_words]) == 9 + assert all(isinstance(bbox, tuple) and 
isinstance(value, str) + for page_words in words for (bbox, value) in page_words) + assert all(all(isinstance(coord, float) for coord in bbox) for page_words in words for (bbox, value) in page_words) + + # Get lines + lines = doc.get_lines() + assert isinstance(lines, list) and len(lines) == num_pages + assert len([line for page_lines in lines for line in page_lines]) == 2 + assert all(isinstance(bbox, tuple) and isinstance(value, str) + for page_lines in lines for (bbox, value) in page_lines) + assert all(all(isinstance(coord, float) for coord in bbox) for page_lines in lines for (bbox, value) in page_lines) + + # Get artefacts + artefacts = doc.get_artefacts() + assert isinstance(artefacts, list) and len(artefacts) == num_pages + assert len([art for page_art in artefacts for art in page_art]) == 0 + assert all(isinstance(bbox, tuple) for page_artefacts in artefacts for bbox in page_artefacts) + assert all(all(isinstance(coord, float) for coord in bbox) + for page_artefacts in artefacts for bbox in page_artefacts) diff --git a/tests/common/test_io_elements.py b/tests/common/test_io_elements.py new file mode 100644 index 0000000000..b602b2b83d --- /dev/null +++ b/tests/common/test_io_elements.py @@ -0,0 +1,235 @@ +from xml.etree.ElementTree import ElementTree + +import numpy as np +import pytest + +from doctr.io import elements + + +def _mock_words(size=(1., 1.), offset=(0, 0), confidence=0.9): + return [ + elements.Word("hello", confidence, ( + (offset[0], offset[1]), + (size[0] / 2 + offset[0], size[1] / 2 + offset[1]) + )), + elements.Word("world", confidence, ( + (size[0] / 2 + offset[0], size[1] / 2 + offset[1]), + (size[0] + offset[0], size[1] + offset[1]) + )) + ] + + +def _mock_artefacts(size=(1, 1), offset=(0, 0), confidence=0.8): + sub_size = (size[0] / 2, size[1] / 2) + return [ + elements.Artefact("qr_code", confidence, ( + (offset[0], offset[1]), + (sub_size[0] + offset[0], sub_size[1] + offset[1]) + )), + elements.Artefact("qr_code", confidence, ( + (sub_size[0] + offset[0], sub_size[1] + offset[1]), + (size[0] + offset[0], size[1] + offset[1]) + )), + ] + + +def _mock_lines(size=(1, 1), offset=(0, 0)): + sub_size = (size[0] / 2, size[1] / 2) + return [ + elements.Line(_mock_words(size=sub_size, offset=offset)), + elements.Line(_mock_words(size=sub_size, offset=(offset[0] + sub_size[0], offset[1] + sub_size[1]))), + ] + + +def _mock_blocks(size=(1, 1), offset=(0, 0)): + sub_size = (size[0] / 4, size[1] / 4) + return [ + elements.Block( + _mock_lines(size=sub_size, offset=offset), + _mock_artefacts(size=sub_size, offset=(offset[0] + sub_size[0], offset[1] + sub_size[1])) + ), + elements.Block( + _mock_lines(size=sub_size, offset=(offset[0] + 2 * sub_size[0], offset[1] + 2 * sub_size[1])), + _mock_artefacts(size=sub_size, offset=(offset[0] + 3 * sub_size[0], offset[1] + 3 * sub_size[1])), + ), + ] + + +def _mock_pages(block_size=(1, 1), block_offset=(0, 0)): + return [ + elements.Page(_mock_blocks(block_size, block_offset), 0, (300, 200), + {"value": 0., "confidence": 1.}, {"value": "EN", "confidence": 0.8}), + elements.Page(_mock_blocks(block_size, block_offset), 1, (500, 1000), + {"value": 0.15, "confidence": 0.8}, {"value": "FR", "confidence": 0.7}), + ] + + +def test_element(): + with pytest.raises(KeyError): + elements.Element(sub_elements=[1]) + + +def test_word(): + word_str = "hello" + conf = 0.8 + geom = ((0, 0), (1, 1)) + word = elements.Word(word_str, conf, geom) + + # Attribute checks + assert word.value == word_str + assert word.confidence == conf + assert 
word.geometry == geom + + # Render + assert word.render() == word_str + + # Export + assert word.export() == {"value": word_str, "confidence": conf, "geometry": geom} + + # Repr + assert word.__repr__() == f"Word(value='hello', confidence={conf:.2})" + + # Class method + state_dict = {"value": "there", "confidence": 0.1, "geometry": ((0, 0), (.5, .5))} + word = elements.Word.from_dict(state_dict) + assert word.export() == state_dict + + +def test_line(): + geom = ((0, 0), (0.5, 0.5)) + words = _mock_words(size=geom[1], offset=geom[0]) + line = elements.Line(words) + + # Attribute checks + assert len(line.words) == len(words) + assert all(isinstance(w, elements.Word) for w in line.words) + assert line.geometry == geom + + # Render + assert line.render() == "hello world" + + # Export + assert line.export() == {"words": [w.export() for w in words], "geometry": geom} + + # Repr + words_str = ' ' * 4 + ',\n '.join(repr(word) for word in words) + ',' + assert line.__repr__() == f"Line(\n (words): [\n{words_str}\n ]\n)" + + # Ensure that words repr does't span on several lines when there are none + assert repr(elements.Line([], ((0, 0), (1, 1)))) == "Line(\n (words): []\n)" + + # from dict + state_dict = { + "words": [{"value": "there", "confidence": 0.1, "geometry": ((0, 0), (1., 1.))}], + "geometry": ((0, 0), (1., 1.)) + } + line = elements.Line.from_dict(state_dict) + assert line.export() == state_dict + + +def test_artefact(): + artefact_type = "qr_code" + conf = 0.8 + geom = ((0, 0), (1, 1)) + artefact = elements.Artefact(artefact_type, conf, geom) + + # Attribute checks + assert artefact.type == artefact_type + assert artefact.confidence == conf + assert artefact.geometry == geom + + # Render + assert artefact.render() == "[QR_CODE]" + + # Export + assert artefact.export() == {"type": artefact_type, "confidence": conf, "geometry": geom} + + # Repr + assert artefact.__repr__() == f"Artefact(type='{artefact_type}', confidence={conf:.2})" + + +def test_block(): + geom = ((0, 0), (1, 1)) + sub_size = (geom[1][0] / 2, geom[1][0] / 2) + lines = _mock_lines(size=sub_size, offset=geom[0]) + artefacts = _mock_artefacts(size=sub_size, offset=sub_size) + block = elements.Block(lines, artefacts) + + # Attribute checks + assert len(block.lines) == len(lines) + assert len(block.artefacts) == len(artefacts) + assert all(isinstance(w, elements.Line) for w in block.lines) + assert all(isinstance(a, elements.Artefact) for a in block.artefacts) + assert block.geometry == geom + + # Render + assert block.render() == "hello world\nhello world" + + # Export + assert block.export() == {"lines": [line.export() for line in lines], + "artefacts": [artefact.export() for artefact in artefacts], "geometry": geom} + + +def test_page(): + page_idx = 0 + page_size = (300, 200) + orientation = {"value": 0., "confidence": 0.} + language = {"value": "EN", "confidence": 0.8} + blocks = _mock_blocks() + page = elements.Page(blocks, page_idx, page_size, orientation, language) + + # Attribute checks + assert len(page.blocks) == len(blocks) + assert all(isinstance(b, elements.Block) for b in page.blocks) + assert page.page_idx == page_idx + assert page.dimensions == page_size + assert page.orientation == orientation + assert page.language == language + + # Render + assert page.render() == "hello world\nhello world\n\nhello world\nhello world" + + # Export + assert page.export() == {"blocks": [b.export() for b in blocks], "page_idx": page_idx, "dimensions": page_size, + "orientation": orientation, "language": language} + + # 
Export XML + assert isinstance(page.export_as_xml(), tuple) and isinstance( + page.export_as_xml()[0], (bytes, bytearray)) and isinstance(page.export_as_xml()[1], ElementTree) + + # Repr + assert '\n'.join(repr(page).split('\n')[:2]) == f'Page(\n dimensions={repr(page_size)}' + + # Show + page.show(np.zeros((256, 256, 3), dtype=np.uint8), block=False) + + # Synthesize + img = page.synthesize() + assert isinstance(img, np.ndarray) + assert img.shape == (*page_size, 3) + + +def test_document(): + pages = _mock_pages() + doc = elements.Document(pages) + + # Attribute checks + assert len(doc.pages) == len(pages) + assert all(isinstance(p, elements.Page) for p in doc.pages) + + # Render + page_export = "hello world\nhello world\n\nhello world\nhello world" + assert doc.render() == f"{page_export}\n\n\n\n{page_export}" + + # Export + assert doc.export() == {"pages": [p.export() for p in pages]} + + # Export XML + assert isinstance(doc.export_as_xml(), list) and len(doc.export_as_xml()) == len(pages) + + # Show + doc.show([np.zeros((256, 256, 3), dtype=np.uint8) for _ in range(len(pages))], block=False) + + # Synthesize + img_list = doc.synthesize() + assert isinstance(img_list, list) and len(img_list) == len(pages) diff --git a/tests/common/test_models.py b/tests/common/test_models.py new file mode 100644 index 0000000000..05e8f125b6 --- /dev/null +++ b/tests/common/test_models.py @@ -0,0 +1,105 @@ +from copy import deepcopy +from io import BytesIO + +import cv2 +import numpy as np +import pytest +import requests + +from doctr.io import DocumentFile, reader +from doctr.models._utils import estimate_orientation, extract_crops, extract_rcrops, get_bitmap_angle +from doctr.utils import geometry + + +def test_extract_crops(mock_pdf): # noqa: F811 + doc_img = DocumentFile.from_pdf(mock_pdf).as_images()[0] + num_crops = 2 + rel_boxes = np.array([[idx / num_crops, idx / num_crops, (idx + 1) / num_crops, (idx + 1) / num_crops] + for idx in range(num_crops)], dtype=np.float32) + abs_boxes = np.array([[int(idx * doc_img.shape[1] / num_crops), + int(idx * doc_img.shape[0]) / num_crops, + int((idx + 1) * doc_img.shape[1] / num_crops), + int((idx + 1) * doc_img.shape[0] / num_crops)] + for idx in range(num_crops)], dtype=np.float32) + + with pytest.raises(AssertionError): + extract_crops(doc_img, np.zeros((1, 5))) + + for boxes in (rel_boxes, abs_boxes): + croped_imgs = extract_crops(doc_img, boxes) + # Number of crops + assert len(croped_imgs) == num_crops + # Data type and shape + assert all(isinstance(crop, np.ndarray) for crop in croped_imgs) + assert all(crop.ndim == 3 for crop in croped_imgs) + + # Identity + assert np.all(doc_img == extract_crops(doc_img, np.array([[0, 0, 1, 1]], dtype=np.float32), channels_last=True)[0]) + torch_img = np.transpose(doc_img, axes=(-1, 0, 1)) + assert np.all(torch_img == np.transpose( + extract_crops(doc_img, np.array([[0, 0, 1, 1]], dtype=np.float32), channels_last=False)[0], + axes=(-1, 0, 1) + )) + + # No box + assert extract_crops(doc_img, np.zeros((0, 4))) == [] + + +def test_extract_rcrops(mock_pdf): # noqa: F811 + doc_img = DocumentFile.from_pdf(mock_pdf).as_images()[0] + num_crops = 2 + rel_boxes = np.array([[[idx / num_crops, idx / num_crops], + [idx / num_crops + .1, idx / num_crops], + [idx / num_crops + .1, idx / num_crops + .1], + [idx / num_crops, idx / num_crops]] + for idx in range(num_crops)], dtype=np.float32) + abs_boxes = deepcopy(rel_boxes) + abs_boxes[:, :, 0] *= doc_img.shape[1] + abs_boxes[:, :, 1] *= doc_img.shape[0] + abs_boxes = 
abs_boxes.astype(np.int) + + with pytest.raises(AssertionError): + extract_rcrops(doc_img, np.zeros((1, 8))) + for boxes in (rel_boxes, abs_boxes): + croped_imgs = extract_rcrops(doc_img, boxes) + # Number of crops + assert len(croped_imgs) == num_crops + # Data type and shape + assert all(isinstance(crop, np.ndarray) for crop in croped_imgs) + assert all(crop.ndim == 3 for crop in croped_imgs) + + # No box + assert extract_rcrops(doc_img, np.zeros((0, 4, 2))) == [] + + +@pytest.fixture(scope="function") +def mock_image(tmpdir_factory): + url = 'https://github.com/mindee/doctr/releases/download/v0.2.1/bitmap30.png' + file = BytesIO(requests.get(url).content) + tmp_path = str(tmpdir_factory.mktemp("data").join("mock_bitmap.jpg")) + with open(tmp_path, 'wb') as f: + f.write(file.getbuffer()) + image = reader.read_img_as_numpy(tmp_path) + return image + + +@pytest.fixture(scope="function") +def mock_bitmap(mock_image): + bitmap = np.squeeze(cv2.cvtColor(mock_image, cv2.COLOR_BGR2GRAY) / 255.) + return bitmap + + +def test_get_bitmap_angle(mock_bitmap): + angle = get_bitmap_angle(mock_bitmap) + assert abs(angle - 30.) < 1. + + +def test_estimate_orientation(mock_image): + assert estimate_orientation(mock_image * 0) == 0 + + angle = estimate_orientation(mock_image) + assert abs(angle - 30.) < 1. + + rotated = geometry.rotate_image(mock_image, -angle) + angle_rotated = estimate_orientation(rotated) + assert abs(angle_rotated) < 1. diff --git a/tests/common/test_models_artefacts.py b/tests/common/test_models_artefacts.py new file mode 100644 index 0000000000..b5baaf185c --- /dev/null +++ b/tests/common/test_models_artefacts.py @@ -0,0 +1,20 @@ +import os + +from doctr.io import DocumentFile +from doctr.models.artefacts import BarCodeDetector, FaceDetector + + +def test_qr_code_detector(mock_image_folder): + detector = BarCodeDetector() + for img in os.listdir(mock_image_folder): + image = DocumentFile.from_images(os.path.join(mock_image_folder, img))[0] + barcode = detector(image) + assert len(barcode) == 0 + + +def test_face_detector(mock_image_folder): + detector = FaceDetector(n_faces=1) + for img in os.listdir(mock_image_folder): + image = DocumentFile.from_images(os.path.join(mock_image_folder, img))[0] + faces = detector(image) + assert len(faces) <= 1 diff --git a/tests/common/test_models_builder.py b/tests/common/test_models_builder.py new file mode 100644 index 0000000000..ec468c310c --- /dev/null +++ b/tests/common/test_models_builder.py @@ -0,0 +1,92 @@ +import numpy as np +import pytest + +from doctr.io import Document +from doctr.models import builder + + +def test_documentbuilder(): + + words_per_page = 10 + num_pages = 2 + + # Don't resolve lines + doc_builder = builder.DocumentBuilder(resolve_lines=False, resolve_blocks=False) + boxes = np.random.rand(words_per_page, 6) + boxes[:2] *= boxes[2:4] + + # Arg consistency check + with pytest.raises(ValueError): + doc_builder([boxes, boxes], [('hello', 1.0)] * 3, [(100, 200), (100, 200)]) + out = doc_builder([boxes, boxes], [[('hello', 1.0)] * words_per_page] * num_pages, [(100, 200), (100, 200)]) + assert isinstance(out, Document) + assert len(out.pages) == num_pages + # 1 Block & 1 line per page + assert len(out.pages[0].blocks) == 1 and len(out.pages[0].blocks[0].lines) == 1 + assert len(out.pages[0].blocks[0].lines[0].words) == words_per_page + + # Resolve lines + doc_builder = builder.DocumentBuilder(resolve_lines=True, resolve_blocks=True) + out = doc_builder([boxes, boxes], [[('hello', 1.0)] * words_per_page] * num_pages, 
[(100, 200), (100, 200)]) + + # No detection + boxes = np.zeros((0, 5)) + out = doc_builder([boxes, boxes], [[], []], [(100, 200), (100, 200)]) + assert len(out.pages[0].blocks) == 0 + + # Rotated boxes to export as straight boxes + boxes = np.array([ + [[0.1, 0.1], [0.2, 0.2], [0.15, 0.25], [0.05, 0.15]], + [[0.5, 0.5], [0.6, 0.6], [0.55, 0.65], [0.45, 0.55]], + ]) + doc_builder_2 = builder.DocumentBuilder( + resolve_blocks=False, + resolve_lines=False, + export_as_straight_boxes=True + ) + out = doc_builder_2([boxes], [[("hello", 0.99), ("world", 0.99)]], [(100, 100)]) + assert out.pages[0].blocks[0].lines[0].words[-1].geometry == ((0.45, 0.5), (0.6, 0.65)) + + # Repr + assert repr(doc_builder) == "DocumentBuilder(resolve_lines=True, " \ + "resolve_blocks=True, paragraph_break=0.035, export_as_straight_boxes=False)" + + +@pytest.mark.parametrize( + "input_boxes, sorted_idxs", + [ + [[[0, 0.5, 0.1, 0.6], [0, 0.3, 0.2, 0.4], [0, 0, 0.1, 0.1]], [2, 1, 0]], # vertical + [[[0.7, 0.5, 0.85, 0.6], [0.2, 0.3, 0.4, 0.4], [0, 0, 0.1, 0.1]], [2, 1, 0]], # diagonal + [[[0, 0.5, 0.1, 0.6], [0.15, 0.5, 0.25, 0.6], [0.5, 0.5, 0.6, 0.6]], [0, 1, 2]], # same line, 2p + [[[0, 0.5, 0.1, 0.6], [0.2, 0.49, 0.35, 0.59], [0.8, 0.52, 0.9, 0.63]], [0, 1, 2]], # ~same line + [[[0, 0.3, 0.4, 0.45], [0.5, 0.28, 0.75, 0.42], [0, 0.45, 0.1, 0.55]], [0, 1, 2]], # 2 lines + [[[0, 0.3, 0.4, 0.35], [0.75, 0.28, 0.95, 0.42], [0, 0.45, 0.1, 0.55]], [0, 1, 2]], # 2 lines + [[[[.1, .1], [.2, .2], [.15, .25], [.05, .15]], [[.5, .5], [.6, .6], [.55, .65], [.45, .55]]], [0, 1]], # rot + ], +) +def test_sort_boxes(input_boxes, sorted_idxs): + + doc_builder = builder.DocumentBuilder() + assert doc_builder._sort_boxes(np.asarray(input_boxes))[0].tolist() == sorted_idxs + + +@pytest.mark.parametrize( + "input_boxes, lines", + [ + [[[0, 0.5, 0.1, 0.6], [0, 0.3, 0.2, 0.4], [0, 0, 0.1, 0.1]], [[2], [1], [0]]], # vertical + [[[0.7, 0.5, 0.85, 0.6], [0.2, 0.3, 0.4, 0.4], [0, 0, 0.1, 0.1]], [[2], [1], [0]]], # diagonal + [[[0, 0.5, 0.14, 0.6], [0.15, 0.5, 0.25, 0.6], [0.5, 0.5, 0.6, 0.6]], [[0, 1], [2]]], # same line, 2p + [[[0, 0.5, 0.18, 0.6], [0.2, 0.48, 0.35, 0.58], [0.8, 0.52, 0.9, 0.63]], [[0, 1], [2]]], # ~same line + [[[0, 0.3, 0.48, 0.45], [0.5, 0.28, 0.75, 0.42], [0, 0.45, 0.1, 0.55]], [[0, 1], [2]]], # 2 lines + [[[0, 0.3, 0.4, 0.35], [0.75, 0.28, 0.95, 0.42], [0, 0.45, 0.1, 0.55]], [[0], [1], [2]]], # 2 lines + [ + [[[.1, .1], [.2, .2], [.15, .25], [.05, .15]], + [[.5, .5], [.6, .6], [.55, .65], [.45, .55]]], + [[0], [1]] + ], # rot + ], +) +def test_resolve_lines(input_boxes, lines): + + doc_builder = builder.DocumentBuilder() + assert doc_builder._resolve_lines(np.asarray(input_boxes)) == lines diff --git a/tests/common/test_models_detection.py b/tests/common/test_models_detection.py new file mode 100644 index 0000000000..4038445ee0 --- /dev/null +++ b/tests/common/test_models_detection.py @@ -0,0 +1,75 @@ +import numpy as np +import pytest + +from doctr.models.detection.differentiable_binarization.base import DBPostProcessor +from doctr.models.detection.linknet.base import LinkNetPostProcessor + + +def test_dbpostprocessor(): + postprocessor = DBPostProcessor(assume_straight_pages=True) + r_postprocessor = DBPostProcessor(assume_straight_pages=False) + with pytest.raises(AssertionError): + postprocessor(np.random.rand(2, 512, 512).astype(np.float32)) + mock_batch = np.random.rand(2, 512, 512, 1).astype(np.float32) + out = postprocessor(mock_batch) + r_out = r_postprocessor(mock_batch) + # Batch composition + assert 
isinstance(out, list) + assert len(out) == 2 + assert all(isinstance(sample, list) and all(isinstance(v, np.ndarray) for v in sample) for sample in out) + assert all(all(v.shape[1] == 5 for v in sample) for sample in out) + assert all(all(v.shape[1] == 4 and v.shape[2] == 2 for v in sample) for sample in r_out) + # Relative coords + assert all(all(np.all(np.logical_and(v[:, :4] >= 0, v[:, :4] <= 1)) for v in sample) for sample in out) + assert all(all(np.all(np.logical_and(v[:, :4] >= 0, v[:, :4] <= 1)) for v in sample) for sample in r_out) + # Repr + assert repr(postprocessor) == 'DBPostProcessor(bin_thresh=0.3, box_thresh=0.1)' + # Edge case when the expanded points of the polygon has two lists + issue_points = np.array([ + [869, 561], + [923, 581], + [925, 595], + [915, 583], + [889, 583], + [905, 593], + [882, 601], + [901, 595], + [904, 604], + [876, 608], + [915, 614], + [911, 605], + [925, 601], + [930, 616], + [911, 617], + [900, 636], + [931, 637], + [904, 649], + [932, 649], + [932, 628], + [918, 627], + [934, 624], + [935, 573], + [909, 569], + [934, 562]], dtype=np.int32) + out = postprocessor.polygon_to_box(issue_points) + r_out = r_postprocessor.polygon_to_box(issue_points) + assert isinstance(out, tuple) and len(out) == 4 + assert isinstance(r_out, np.ndarray) and r_out.shape == (4, 2) + + +def test_linknet_postprocessor(): + postprocessor = LinkNetPostProcessor() + r_postprocessor = LinkNetPostProcessor(assume_straight_pages=False) + with pytest.raises(AssertionError): + postprocessor(np.random.rand(2, 512, 512).astype(np.float32)) + mock_batch = np.random.rand(2, 512, 512, 1).astype(np.float32) + out = postprocessor(mock_batch) + r_out = r_postprocessor(mock_batch) + # Batch composition + assert isinstance(out, list) + assert len(out) == 2 + assert all(isinstance(sample, list) and all(isinstance(v, np.ndarray) for v in sample) for sample in out) + assert all(all(v.shape[1] == 5 for v in sample) for sample in out) + assert all(all(v.shape[1] == 4 and v.shape[2] == 2 for v in sample) for sample in r_out) + # Relative coords + assert all(all(np.all(np.logical_and(v[:4] >= 0, v[:4] <= 1)) for v in sample) for sample in out) diff --git a/tests/common/test_models_recognition_predictor.py b/tests/common/test_models_recognition_predictor.py new file mode 100644 index 0000000000..92bfb084f0 --- /dev/null +++ b/tests/common/test_models_recognition_predictor.py @@ -0,0 +1,39 @@ +import numpy as np +import pytest + +from doctr.models.recognition.predictor._utils import remap_preds, split_crops + + +@pytest.mark.parametrize( + "crops, max_ratio, target_ratio, dilation, channels_last, num_crops", + [ + # No split required + [[np.zeros((32, 128, 3), dtype=np.uint8)], 8, 4, 1.4, True, 1], + [[np.zeros((3, 32, 128), dtype=np.uint8)], 8, 4, 1.4, False, 1], + # Split required + [[np.zeros((32, 1024, 3), dtype=np.uint8)], 8, 6, 1.4, True, 5], + [[np.zeros((3, 32, 1024), dtype=np.uint8)], 8, 6, 1.4, False, 5], + ], +) +def test_split_crops(crops, max_ratio, target_ratio, dilation, channels_last, num_crops): + new_crops, crop_map, should_remap = split_crops(crops, max_ratio, target_ratio, dilation, channels_last) + assert len(new_crops) == num_crops + assert len(crop_map) == len(crops) + assert should_remap == (len(crops) != len(new_crops)) + + +@pytest.mark.parametrize( + "preds, crop_map, dilation, pred", + [ + # Nothing to remap + [[('hello', 0.5)], [0], 1.4, [('hello', 0.5)]], + # Merge + [[('hellowo', 0.5), ('loworld', 0.6)], [(0, 2)], 1.4, [('helloworld', 0.5)]], + ], +) +def 
test_remap_preds(preds, crop_map, dilation, pred): + preds = remap_preds(preds, crop_map, dilation) + assert len(preds) == len(pred) + assert preds == pred + assert all(isinstance(pred, tuple) for pred in preds) + assert all(isinstance(pred[0], str) and isinstance(pred[1], float) for pred in preds) diff --git a/tests/common/test_models_recognition_utils.py b/tests/common/test_models_recognition_utils.py new file mode 100644 index 0000000000..640a7b0cb9 --- /dev/null +++ b/tests/common/test_models_recognition_utils.py @@ -0,0 +1,29 @@ +import pytest + +from doctr.models.recognition.utils import merge_multi_strings, merge_strings + + +@pytest.mark.parametrize( + "a, b, merged", + [ + ['abc', 'def', 'abcdef'], + ['abcd', 'def', 'abcdef'], + ['abcde', 'def', 'abcdef'], + ['abcdef', 'def', 'abcdef'], + ['abcccc', 'cccccc', 'abcccccccc'], + ], +) +def test_merge_strings(a, b, merged): + assert merged == merge_strings(a, b, 1.4) + + +@pytest.mark.parametrize( + "seq_list, merged", + [ + [['abc', 'def'], 'abcdef'], + [['abcd', 'def', 'efgh', 'ijk'], 'abcdefghijk'], + [['abcdi', 'defk', 'efghi', 'aijk'], 'abcdefghijk'], + ], +) +def test_merge_multi_strings(seq_list, merged): + assert merged == merge_multi_strings(seq_list, 1.4) diff --git a/tests/common/test_requirements.py b/tests/common/test_requirements.py new file mode 100644 index 0000000000..79fed1f84d --- /dev/null +++ b/tests/common/test_requirements.py @@ -0,0 +1,48 @@ +from pathlib import Path + +import requirements +from requirements.requirement import Requirement + + +def test_deps_consistency(): + + IGNORE = ["flake8", "isort", "mypy", "pydocstyle", "importlib_metadata", "tensorflow-cpu"] + # Collect the deps from all requirements.txt + REQ_FILES = ["requirements.txt", "requirements-pt.txt", "tests/requirements.txt", "docs/requirements.txt"] + folder = Path(__file__).parent.parent.parent.absolute() + req_deps = {} + for file in REQ_FILES: + with open(folder.joinpath(file), 'r') as f: + _deps = [(req.name, req.specs) for req in requirements.parse(f)] + + for _dep in _deps: + lib, specs = _dep + assert req_deps.get(lib, specs) == specs, f"conflicting deps for {lib}" + req_deps[lib] = specs + + # Collect the one from setup.py + setup_deps = {} + with open(folder.joinpath("setup.py"), 'r') as f: + setup = f.readlines() + lines = setup[setup.index("_deps = [\n") + 1:] + lines = [_dep.strip() for _dep in lines[:lines.index("]\n")]] + lines = [_dep.split('"')[1] for _dep in lines if _dep.startswith('"')] + _reqs = [Requirement.parse(_line) for _line in lines] + _deps = [(req.name, req.specs) for req in _reqs] + for _dep in _deps: + lib, specs = _dep + assert setup_deps.get(lib) is None, f"conflicting deps for {lib}" + setup_deps[lib] = specs + + # Remove ignores + for k in IGNORE: + if isinstance(req_deps.get(k), list): + del req_deps[k] + if isinstance(setup_deps.get(k), list): + del setup_deps[k] + + # Compare them + assert len(req_deps) == len(setup_deps) + for k, v in setup_deps.items(): + assert isinstance(req_deps.get(k), list) + assert req_deps[k] == v, f"Mismatch on dependency {k}: {v} from setup.py, {req_deps[k]} from requirements.txt" diff --git a/tests/common/test_transforms.py b/tests/common/test_transforms.py new file mode 100644 index 0000000000..8ee011b461 --- /dev/null +++ b/tests/common/test_transforms.py @@ -0,0 +1,28 @@ +from doctr.transforms import modules as T + + +def test_imagetransform(): + + transfo = T.ImageTransform(lambda x: 1 - x) + assert transfo(0, 1) == (1, 1) + + +def test_samplecompose(): + + transfos = 
[lambda x, y: (1 - x, y), lambda x, y: (x, 2 * y)] + transfo = T.SampleCompose(transfos) + assert transfo(0, 1) == (1, 2) + + +def test_oneof(): + transfos = [lambda x: 1 - x, lambda x: x + 10] + transfo = T.OneOf(transfos) + out = transfo(1) + assert out == 0 or out == 11 + + +def test_randomapply(): + transfo = T.RandomApply(lambda x: 1 - x) + out = transfo(1) + assert out == 0 or out == 1 + assert repr(transfo).endswith(", p=0.5)") diff --git a/tests/common/test_utils_fonts.py b/tests/common/test_utils_fonts.py new file mode 100644 index 0000000000..ae268dd1df --- /dev/null +++ b/tests/common/test_utils_fonts.py @@ -0,0 +1,11 @@ +from PIL.ImageFont import FreeTypeFont, ImageFont + +from doctr.utils.fonts import get_font + + +def test_get_font(): + + # Attempts to load recommended OS font + font = get_font() + + assert isinstance(font, (ImageFont, FreeTypeFont)) diff --git a/tests/common/test_utils_geometry.py b/tests/common/test_utils_geometry.py new file mode 100644 index 0000000000..7cb91f81d4 --- /dev/null +++ b/tests/common/test_utils_geometry.py @@ -0,0 +1,111 @@ +import numpy as np +import pytest + +from doctr.utils import geometry + + +def test_bbox_to_polygon(): + assert geometry.bbox_to_polygon(((0, 0), (1, 1))) == ((0, 0), (1, 0), (0, 1), (1, 1)) + + +def test_polygon_to_bbox(): + assert geometry.polygon_to_bbox(((0, 0), (1, 0), (0, 1), (1, 1))) == ((0, 0), (1, 1)) + + +def test_resolve_enclosing_bbox(): + assert geometry.resolve_enclosing_bbox([((0, 0.5), (1, 0)), ((0.5, 0), (1, 0.25))]) == ((0, 0), (1, 0.5)) + pred = geometry.resolve_enclosing_bbox(np.array([[0.1, 0.1, 0.2, 0.2, 0.9], [0.15, 0.15, 0.2, 0.2, 0.8]])) + assert pred.all() == np.array([0.1, 0.1, 0.2, 0.2, 0.85]).all() + + +def test_resolve_enclosing_rbbox(): + pred = geometry.resolve_enclosing_rbbox([ + np.asarray([[.1, .1], [.2, .2], [.15, .25], [.05, .15]]), + np.asarray([[.5, .5], [.6, .6], [.55, .65], [.45, .55]]) + ]) + target1 = np.asarray([[.55, .65], [.05, .15], [.1, .1], [.6, .6]]) + target2 = np.asarray([[.05, .15], [.1, .1], [.6, .6], [.55, .65]]) + assert np.all(target1 - pred <= 1e-3) or np.all(target2 - pred <= 1e-3) + + +def test_rotate_boxes(): + boxes = np.array([[0.1, 0.1, 0.8, 0.3, 0.5]]) + rboxes = np.array([[0.1, 0.1], [0.8, 0.1], [0.8, 0.3], [0.1, 0.3]]) + # Angle = 0 + rotated = geometry.rotate_boxes(boxes, angle=0., orig_shape=(1, 1)) + assert np.all(rotated == rboxes) + # Angle < 1: + rotated = geometry.rotate_boxes(boxes, angle=0.5, orig_shape=(1, 1)) + assert np.all(rotated == rboxes) + # Angle = 30 + rotated = geometry.rotate_boxes(boxes, angle=30, orig_shape=(1, 1)) + assert rotated.shape == (1, 4, 2) + + boxes = np.array([[0., 0., 0.6, 0.2, 0.5]]) + # Angle = -90: + rotated = geometry.rotate_boxes(boxes, angle=-90, orig_shape=(1, 1), min_angle=0) + assert np.allclose(rotated, np.array([[[1, 0.], [1, 0.6], [0.8, 0.6], [0.8, 0.]]])) + # Angle = 90 + rotated = geometry.rotate_boxes(boxes, angle=+90, orig_shape=(1, 1), min_angle=0) + assert np.allclose(rotated, np.array([[[0, 1.], [0, 0.4], [0.2, 0.4], [0.2, 1.]]])) + + +def test_rotate_image(): + img = np.ones((32, 64, 3), dtype=np.float32) + rotated = geometry.rotate_image(img, 30.) 
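+ # Without expand=True the canvas keeps its original (32, 64) shape, so the rotated corners are cut off and the padded area stays at zero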
+ assert rotated.shape[:-1] == (32, 64) + assert rotated[0, 0, 0] == 0 + assert rotated[0, :, 0].sum() > 1 + + # Expand + rotated = geometry.rotate_image(img, 30., expand=True) + assert rotated.shape[:-1] == (60, 120) + assert rotated[0, :, 0].sum() <= 1 + + # Expand + rotated = geometry.rotate_image(img, 30., expand=True, preserve_origin_shape=True) + assert rotated.shape[:-1] == (32, 64) + assert rotated[0, :, 0].sum() <= 1 + + # Expand with 90° rotation + rotated = geometry.rotate_image(img, 90., expand=True) + assert rotated.shape[:-1] == (64, 128) + assert rotated[0, :, 0].sum() <= 1 + + +@pytest.mark.parametrize( + "abs_geoms, img_size, rel_geoms", + [ + # Full image (boxes) + [ + np.array([[0, 0, 32, 32]]), + (32, 32), + np.array([[0, 0, 1, 1]], dtype=np.float32) + ], + # Full image (polygons) + [ + np.array([[[0, 0], [32, 0], [32, 32], [0, 32]]]), + (32, 32), + np.array([[[0, 0], [1, 0], [1, 1], [0, 1]]], dtype=np.float32) + ], + # Quarter image (boxes) + [ + np.array([[0, 0, 16, 16]]), + (32, 32), + np.array([[0, 0, .5, .5]], dtype=np.float32) + ], + # Quarter image (polygons) + [ + np.array([[[0, 0], [16, 0], [16, 16], [0, 16]]]), + (32, 32), + np.array([[[0, 0], [.5, 0], [.5, .5], [0, .5]]], dtype=np.float32) + ], + ], +) +def test_convert_to_relative_coords(abs_geoms, img_size, rel_geoms): + + assert np.all(geometry.convert_to_relative_coords(abs_geoms, img_size) == rel_geoms) + + # Wrong format + with pytest.raises(ValueError): + geometry.convert_to_relative_coords(np.zeros((3, 5)), (32, 32)) diff --git a/tests/common/test_utils_metrics.py b/tests/common/test_utils_metrics.py new file mode 100644 index 0000000000..74a15c0064 --- /dev/null +++ b/tests/common/test_utils_metrics.py @@ -0,0 +1,306 @@ +import numpy as np +import pytest + +from doctr.utils import metrics + + +@pytest.mark.parametrize( + "gt, pred, raw, caseless, unidecode, unicase", + [ + [['grass', '56', 'True', 'EUR'], ['grass', '56', 'true', '€'], .5, .75, .75, 1], + [['éléphant', 'ça'], ['elephant', 'ca'], 0, 0, 1, 1], + ], +) +def test_text_match(gt, pred, raw, caseless, unidecode, unicase): + metric = metrics.TextMatch() + with pytest.raises(AssertionError): + metric.summary() + + with pytest.raises(AssertionError): + metric.update(['a', 'b'], ['c']) + + metric.update(gt, pred) + assert metric.summary() == dict(raw=raw, caseless=caseless, unidecode=unidecode, unicase=unicase) + + metric.reset() + assert metric.raw == metric.caseless == metric.unidecode == metric.unicase == metric.total == 0 + + +@pytest.mark.parametrize( + "box1, box2, iou, abs_tol", + [ + [[[0, 0, .5, .5]], [[0, 0, .5, .5]], 1, 0], # Perfect match + [[[0, 0, .5, .5]], [[.5, .5, 1, 1]], 0, 0], # No match + [[[0, 0, 1, 1]], [[.5, .5, 1, 1]], 0.25, 0], # Partial match + [[[.2, .2, .6, .6]], [[.4, .4, .8, .8]], 4 / 28, 1e-7], # Partial match + [[[0, 0, .1, .1]], [[.9, .9, 1, 1]], 0, 0], # Boxes far from each other + [np.zeros((0, 4)), [[0, 0, .5, .5]], 0, 0], # Zero-sized inputs + [[[0, 0, .5, .5]], np.zeros((0, 4)), 0, 0], # Zero-sized inputs + ], +) +def test_box_iou(box1, box2, iou, abs_tol): + iou_mat = metrics.box_iou(np.asarray(box1), np.asarray(box2)) + assert iou_mat.shape == (len(box1), len(box2)) + if iou_mat.size > 0: + assert abs(iou_mat - iou) <= abs_tol + + +@pytest.mark.parametrize( + "mask1, mask2, iou, abs_tol", + [ + [ + [[[True, True, False], [True, True, False]]], + [[[True, True, False], [True, True, False]]], + 1, + 0 + ], # Perfect match + [ + [[[True, False, False], [False, False, False]]], + [[[True, True, False], [True, 
True, False]]], + 0.25, + 0 + ], # Partial match + ], +) +def test_mask_iou(mask1, mask2, iou, abs_tol): + iou_mat = metrics.mask_iou(np.asarray(mask1), np.asarray(mask2)) + assert iou_mat.shape == (len(mask1), len(mask2)) + if iou_mat.size > 0: + assert abs(iou_mat - iou) <= abs_tol + + # Incompatible spatial shapes + with pytest.raises(AssertionError): + metrics.mask_iou(np.zeros((2, 3, 5), dtype=bool), np.ones((3, 2, 5), dtype=bool)) + + +@pytest.mark.parametrize( + "rbox1, rbox2, iou, abs_tol", + [ + [[[[0, 0], [.5, 0], [.5, .5], [0, .5]]], [[[0, 0], [.5, 0], [.5, .5], [0, .5]]], 1, 0], # Perfect match + [[[[0, 0], [.5, 0], [.5, .5], [0, .5]]], [[[.5, .5], [1, .5], [1, 1], [.5, 1]]], 0, 1e-4], # No match + [[[[0, 0], [1., 0], [1., 1.], [0, 1.]]], [[[.5, .5], [1, .5], [1., 1.], [.5, 1]]], 0.25, 5e-3], # Partial match + [ + [[[.2, .2], [.6, .2], [.6, .6], [.2, .6]]], + [[[.4, .4], [.8, .4], [.8, .8], [.4, .8]]], 4 / 28, 7e-3 + ], # Partial match + [ + [[[0, 0], [.05, 0], [.05, .05], [0, .05]]], + [[[.5, .5], [1, .5], [1, 1], [.5, 1]]], 0, 0 + ], # Boxes far from each other + [np.zeros((0, 4, 2)), [[[0, 0], [.05, 0], [.05, .05], [0, .05]]], 0, 0], # Zero-sized inputs + [[[[0, 0], [.05, 0], [.05, .05], [0, .05]]], np.zeros((0, 4, 2)), 0, 0], # Zero-sized inputs + ], +) +def test_polygon_iou(rbox1, rbox2, iou, abs_tol): + mask_shape = (256, 256) + iou_mat = metrics.polygon_iou(np.asarray(rbox1), np.asarray(rbox2), mask_shape) + assert iou_mat.shape == (len(rbox1), len(rbox2)) + if iou_mat.size > 0: + assert abs(iou_mat - iou) <= abs_tol + + # Ensure broadcasting doesn't change the result + iou_matbis = metrics.polygon_iou(np.asarray(rbox1), np.asarray(rbox2), mask_shape, use_broadcasting=False) + assert np.all((iou_mat - iou_matbis) <= 1e-7) + + # Incorrect boxes + with pytest.raises(AssertionError): + metrics.polygon_iou(np.zeros((2, 5), dtype=float), np.ones((3, 4), dtype=float), mask_shape) + + +@pytest.mark.parametrize( + "box, shape, mask", + [ + [ + [[0, 0], [.5, 0], [.5, .5], [0, .5]], (2, 2), + [[True, False], [False, False]], + ], + ], +) +def test_rbox_to_mask(box, shape, mask): + masks = metrics.rbox_to_mask(np.asarray(box)[None, ...], shape) + assert masks.shape == (1, *shape) + assert np.all(masks[0] == np.asarray(mask, dtype=bool)) + + +@pytest.mark.parametrize( + "gts, preds, iou_thresh, recall, precision, mean_iou", + [ + [[[[0, 0, .5, .5]]], [[[0, 0, .5, .5]]], 0.5, 1, 1, 1], # Perfect match + [[[[0, 0, 1, 1]]], [[[0, 0, .5, .5], [.6, .6, .7, .7]]], 0.2, 1, 0.5, 0.13], # Bad match + [[[[0, 0, 1, 1]]], [[[0, 0, .5, .5], [.6, .6, .7, .7]]], 0.5, 0, 0, 0.13], # Bad match + [[[[0, 0, .5, .5]], [[0, 0, .5, .5]]], [[[0, 0, .5, .5]], None], 0.5, 0.5, 1, 1], # No preds on 2nd sample + ], +) +def test_localization_confusion(gts, preds, iou_thresh, recall, precision, mean_iou): + + metric = metrics.LocalizationConfusion(iou_thresh) + for _gts, _preds in zip(gts, preds): + metric.update(np.asarray(_gts), np.zeros((0, 4)) if _preds is None else np.asarray(_preds)) + assert metric.summary() == (recall, precision, mean_iou) + metric.reset() + assert metric.num_gts == metric.num_preds == metric.matches == metric.tot_iou == 0 + + +@pytest.mark.parametrize( + "gts, preds, iou_thresh, recall, precision, mean_iou", + [ + [ + [[[[.05, .05], [.15, .05], [.15, .15], [.05, .15]]]], + [[[[.05, .05], [.15, .05], [.15, .15], [.05, .15]]]], 0.5, 1, 1, 1 + ], # Perfect match + [ + [[[[.1, .05], [.2, .05], [.2, .15], [.1, .15]]]], + [[[[.1, .05], [.3, .05], [.3, .15], [.1, .15]], [[.6, .6], [.8, .6], 
[.8, .8], [.6, .8]]]], + 0.2, 1, 0.5, 0.25 + ], # Bad match + [ + [[[[.05, .05], [.15, .05], [.15, .15], [.05, .15]]], + [[[.25, .25], [.35, .25], [35, .35], [.25, .35]]]], + [[[[.05, .05], [.15, .05], [.15, .15], [.05, .15]]], None], + 0.5, 0.5, 1, 1, + ], # Empty + ], +) +def test_r_localization_confusion(gts, preds, iou_thresh, recall, precision, mean_iou): + + metric = metrics.LocalizationConfusion(iou_thresh, use_polygons=True, mask_shape=(1000, 1000)) + for _gts, _preds in zip(gts, preds): + metric.update(np.asarray(_gts), np.zeros((0, 5)) if _preds is None else np.asarray(_preds)) + assert metric.summary()[:2] == (recall, precision) + assert abs(metric.summary()[2] - mean_iou) <= 5e-3 + metric.reset() + assert metric.num_gts == metric.num_preds == metric.matches == metric.tot_iou == 0 + + +@pytest.mark.parametrize( + "gt_boxes, gt_words, pred_boxes, pred_words, iou_thresh, recall, precision, mean_iou", + [ + [ # Perfect match + [[[0, 0, .5, .5]]], [["elephant"]], + [[[0, 0, .5, .5]]], [["elephant"]], + 0.5, + {"raw": 1, "caseless": 1, "unidecode": 1, "unicase": 1}, + {"raw": 1, "caseless": 1, "unidecode": 1, "unicase": 1}, + 1, + ], + [ # Bad match + [[[0, 0, .5, .5]]], [["elefant"]], + [[[0, 0, .5, .5]]], [["elephant"]], + 0.5, + {"raw": 0, "caseless": 0, "unidecode": 0, "unicase": 0}, + {"raw": 0, "caseless": 0, "unidecode": 0, "unicase": 0}, + 1, + ], + [ # Good match + [[[0, 0, 1, 1]]], [["EUR"]], + [[[0, 0, .5, .5], [.6, .6, .7, .7]]], [["€", "e"]], + 0.2, + {"raw": 0, "caseless": 0, "unidecode": 1, "unicase": 1}, + {"raw": 0, "caseless": 0, "unidecode": .5, "unicase": .5}, + 0.13, + ], + [ # No preds on 2nd sample + [[[0, 0, .5, .5]], [[0, 0, .5, .5]]], [["Elephant"], ["elephant"]], + [[[0, 0, .5, .5]], None], [["elephant"], []], + 0.5, + {"raw": 0, "caseless": .5, "unidecode": 0, "unicase": .5}, + {"raw": 0, "caseless": 1, "unidecode": 0, "unicase": 1}, + 1, + ], + ], +) +def test_ocr_metric( + gt_boxes, gt_words, pred_boxes, pred_words, iou_thresh, recall, precision, mean_iou +): + metric = metrics.OCRMetric(iou_thresh) + for _gboxes, _gwords, _pboxes, _pwords in zip(gt_boxes, gt_words, pred_boxes, pred_words): + metric.update( + np.asarray(_gboxes), + np.zeros((0, 4)) if _pboxes is None else np.asarray(_pboxes), + _gwords, + _pwords + ) + _recall, _precision, _mean_iou = metric.summary() + assert _recall == recall + assert _precision == precision + assert _mean_iou == mean_iou + metric.reset() + assert metric.num_gts == metric.num_preds == metric.tot_iou == 0 + assert metric.raw_matches == metric.caseless_matches == metric.unidecode_matches == metric.unicase_matches == 0 + # Shape check + with pytest.raises(AssertionError): + metric.update( + np.asarray(_gboxes), + np.zeros((0, 4)), + _gwords, + ["I", "have", "a", "bad", "feeling", "about", "this"], + ) + + +@pytest.mark.parametrize( + "gt_boxes, gt_classes, pred_boxes, pred_classes, iou_thresh, recall, precision, mean_iou", + [ + [ # Perfect match + [[[0, 0, .5, .5]]], [[0]], + [[[0, 0, .5, .5]]], [[0]], + 0.5, 1, 1, 1, + ], + [ # Bad match + [[[0, 0, .5, .5]]], [[0]], + [[[0, 0, .5, .5]]], [[1]], + 0.5, 0, 0, 1, + ], + [ # No preds on 2nd sample + [[[0, 0, .5, .5]], [[0, 0, .5, .5]]], [[0], [1]], + [[[0, 0, .5, .5]], None], [[0], []], + 0.5, .5, 1, 1, + ], + ], +) +def test_detection_metric( + gt_boxes, gt_classes, pred_boxes, pred_classes, iou_thresh, recall, precision, mean_iou +): + metric = metrics.DetectionMetric(iou_thresh) + for _gboxes, _gclasses, _pboxes, _pclasses in zip(gt_boxes, gt_classes, pred_boxes, 
pred_classes): + metric.update( + np.asarray(_gboxes), + np.zeros((0, 4)) if _pboxes is None else np.asarray(_pboxes), + np.array(_gclasses, dtype=np.int64), + np.array(_pclasses, dtype=np.int64), + ) + _recall, _precision, _mean_iou = metric.summary() + assert _recall == recall + assert _precision == precision + assert _mean_iou == mean_iou + metric.reset() + assert metric.num_gts == metric.num_preds == metric.tot_iou == 0 + assert metric.num_matches == 0 + # Shape check + with pytest.raises(AssertionError): + metric.update( + np.asarray(_gboxes), + np.zeros((0, 4)), + np.array(_gclasses, dtype=np.int64), + np.array([1, 2], dtype=np.int64) + ) + + +def test_nms(): + boxes = [ + [0.1, 0.1, 0.2, 0.2, 0.95], + [0.15, 0.15, 0.19, 0.2, 0.90], # to suppress + [0.5, 0.5, 0.6, 0.55, 0.90], + [0.55, 0.5, 0.7, 0.55, 0.85], # to suppress + ] + to_keep = metrics.nms(np.asarray(boxes), thresh=0.2) + assert to_keep == [0, 2] + + +def test_box_ioa(): + boxes = [ + [0.1, 0.1, 0.2, 0.2], + [0.15, 0.15, 0.2, 0.2], + ] + mat = metrics.box_ioa(np.array(boxes), np.array(boxes)) + assert mat[1, 0] == mat[0, 0] == mat[1, 1] == 1. + assert abs(mat[0, 1] - .25) <= 1e-7 diff --git a/tests/common/test_utils_multithreading.py b/tests/common/test_utils_multithreading.py new file mode 100644 index 0000000000..6eac627251 --- /dev/null +++ b/tests/common/test_utils_multithreading.py @@ -0,0 +1,20 @@ +import pytest + +from doctr.utils.multithreading import multithread_exec + + +@pytest.mark.parametrize( + "input_seq, func, output_seq", + [ + [[1, 2, 3], lambda x: 2 * x, [2, 4, 6]], + [[1, 2, 3], lambda x: x ** 2, [1, 4, 9]], + [ + ['this is', 'show me', 'I know'], + lambda x: x + ' the way', + ['this is the way', 'show me the way', 'I know the way'] + ], + ], +) +def test_multithread_exec(input_seq, func, output_seq): + assert multithread_exec(func, input_seq) == output_seq + assert list(multithread_exec(func, input_seq, 0)) == output_seq diff --git a/tests/common/test_utils_visualization.py b/tests/common/test_utils_visualization.py new file mode 100644 index 0000000000..b1e9998698 --- /dev/null +++ b/tests/common/test_utils_visualization.py @@ -0,0 +1,40 @@ +import numpy as np +import pytest +from test_io_elements import _mock_pages + +from doctr.utils import visualization + + +def test_visualize_page(): + pages = _mock_pages() + image = np.ones((300, 200, 3)) + visualization.visualize_page(pages[0].export(), image, words_only=False) + visualization.visualize_page(pages[0].export(), image, words_only=True, interactive=False) + # geometry checks + with pytest.raises(ValueError): + visualization.create_obj_patch([1, 2], (100, 100)) + + with pytest.raises(ValueError): + visualization.create_obj_patch((1, 2), (100, 100)) + + with pytest.raises(ValueError): + visualization.create_obj_patch((1, 2, 3, 4, 5), (100, 100)) + + +def test_synthesize_page(): + pages = _mock_pages() + visualization.synthesize_page(pages[0].export(), draw_proba=False) + render = visualization.synthesize_page(pages[0].export(), draw_proba=True) + assert isinstance(render, np.ndarray) + assert render.shape == (*pages[0].dimensions, 3) + + +def test_draw_boxes(): + image = np.ones((256, 256, 3), dtype=np.float32) + boxes = [ + [0.1, 0.1, 0.2, 0.2], + [0.15, 0.15, 0.19, 0.2], # to suppress + [0.5, 0.5, 0.6, 0.55], + [0.55, 0.5, 0.7, 0.55], # to suppress + ] + visualization.draw_boxes(boxes=np.array(boxes), image=image, block=False) diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000000..3c241557d7 --- /dev/null +++ 
b/tests/conftest.py @@ -0,0 +1,574 @@ +import json +import shutil +import tempfile +from io import BytesIO + +import fitz +import hdf5storage +import numpy as np +import pytest +import requests +import scipy.io as sio + + +@pytest.fixture(scope="session") +def mock_vocab(): + return ('3K}7eé;5àÎYho]QwV6qU~W"XnbBvcADfËmy.9ÔpÛ*{CôïE%M4#ÈR:g@T$x?0î£|za1ù8,OG€P-kçHëÀÂ2É/ûIJ\'j' + '(LNÙFut[)èZs+&°Sd=Ï!<â_Ç>rêi`l') + + +@pytest.fixture(scope="session") +def mock_pdf(tmpdir_factory): + + doc = fitz.open() + + page = doc.new_page() + page.insert_text(fitz.Point(50, 100), "I am a jedi!", fontsize=20) + page = doc.new_page() + page.insert_text(fitz.Point(50, 100), "No, I am your father.", fontsize=20) + + # Save the PDF + fn = tmpdir_factory.mktemp("data").join("mock_pdf_file.pdf") + with open(fn, 'wb') as f: + doc.save(f) + + return str(fn) + + +@pytest.fixture(scope="session") +def mock_text_box_stream(): + url = 'https://www.pngitem.com/pimgs/m/357-3579845_love-neon-loveislove-word-text-typography-freetoedit-picsart.png' + return requests.get(url).content + + +@pytest.fixture(scope="session") +def mock_text_box(mock_text_box_stream, tmpdir_factory): + file = BytesIO(mock_text_box_stream) + fn = tmpdir_factory.mktemp("data").join("mock_text_box_file.png") + with open(fn, 'wb') as f: + f.write(file.getbuffer()) + return str(fn) + + +@pytest.fixture(scope="session") +def mock_image_stream(): + url = "https://miro.medium.com/max/3349/1*mk1-6aYaf_Bes1E3Imhc0A.jpeg" + return requests.get(url).content + + +@pytest.fixture(scope="session") +def mock_image_path(mock_image_stream, tmpdir_factory): + file = BytesIO(mock_image_stream) + folder = tmpdir_factory.mktemp("images") + fn = folder.join("mock_image_file.jpeg") + with open(fn, 'wb') as f: + f.write(file.getbuffer()) + return str(fn) + + +@pytest.fixture(scope="session") +def mock_image_folder(mock_image_stream, tmpdir_factory): + file = BytesIO(mock_image_stream) + folder = tmpdir_factory.mktemp("images") + for i in range(5): + fn = folder.join("mock_image_file_" + str(i) + ".jpeg") + with open(fn, 'wb') as f: + f.write(file.getbuffer()) + return str(folder) + + +@pytest.fixture(scope="session") +def mock_detection_label(tmpdir_factory): + folder = tmpdir_factory.mktemp("labels") + labels = {} + for idx in range(5): + labels[f"mock_image_file_{idx}.jpeg"] = { + "img_dimensions": (800, 600), + "img_hash": "dummy_hash", + "polygons": [ + [[1, 2], [1, 3], [2, 1], [2, 3]], + [[10, 20], [10, 30], [20, 10], [20, 30]], + [[3, 2], [3, 3], [4, 1], [4, 3]], + [[30, 20], [30, 30], [40, 10], [40, 30]], + ], + } + + labels_path = folder.join('labels.json') + with open(labels_path, 'w') as f: + json.dump(labels, f) + return str(labels_path) + + +@pytest.fixture(scope="session") +def mock_recognition_label(tmpdir_factory): + label_file = tmpdir_factory.mktemp("labels").join("labels.json") + label = { + "mock_image_file_0.jpeg": "I", + "mock_image_file_1.jpeg": "am", + "mock_image_file_2.jpeg": "a", + "mock_image_file_3.jpeg": "jedi", + "mock_image_file_4.jpeg": "!", + } + with open(label_file, 'w') as f: + json.dump(label, f) + return str(label_file) + + +@pytest.fixture(scope="session") +def mock_ocrdataset(tmpdir_factory, mock_image_stream): + root = tmpdir_factory.mktemp("dataset") + label_file = root.join("labels.json") + label = { + "mock_image_file_0.jpg": { + "typed_words": [ + {'value': 'I', 'geometry': (.2, .2, .1, .1, 0)}, + {'value': 'am', 'geometry': (.5, .5, .1, .1, 0)}, + ] + }, + "mock_image_file_1.jpg": { + "typed_words": [ + {'value': 'a', 
'geometry': (.2, .2, .1, .1, 0)}, + {'value': 'jedi', 'geometry': (.5, .5, .1, .1, 0)}, + ] + }, + "mock_image_file_2.jpg": { + "typed_words": [ + {'value': '!', 'geometry': (.2, .2, .1, .1, 0)}, + ] + } + } + with open(label_file, 'w') as f: + json.dump(label, f) + + file = BytesIO(mock_image_stream) + image_folder = tmpdir_factory.mktemp("images") + for i in range(3): + fn = image_folder.join(f"mock_image_file_{i}.jpg") + with open(fn, 'wb') as f: + f.write(file.getbuffer()) + + return str(image_folder), str(label_file) + + +@pytest.fixture(scope="session") +def mock_ic13(tmpdir_factory, mock_image_stream): + file = BytesIO(mock_image_stream) + image_folder = tmpdir_factory.mktemp("images") + label_folder = tmpdir_factory.mktemp("labels") + labels = ["1309, 2240, 1440, 2341, 'I'\n", + "800, 2240, 1440, 2341, 'am'\n", + "500, 2240, 1440, 2341, 'a'\n", + "900, 2240, 1440, 2341, 'jedi'\n", + "400, 2240, 1440, 2341, '!'"] + for i in range(5): + fn_l = label_folder.join(f"gt_mock_image_file_{i}.txt") + with open(fn_l, 'w') as f: + f.writelines(labels) + fn_i = image_folder.join(f"mock_image_file_{i}.jpg") + with open(fn_i, 'wb') as f: + f.write(file.getbuffer()) + return str(image_folder), str(label_folder) + + +@pytest.fixture(scope="session") +def mock_imgur5k(tmpdir_factory, mock_image_stream): + file = BytesIO(mock_image_stream) + image_folder = tmpdir_factory.mktemp("images") + label_folder = tmpdir_factory.mktemp("dataset_info") + labels = { + "index_id": { + "YsaVkzl": { + "image_url": "https://i.imgur.com/YsaVkzl.jpg", + "image_path": "/path/to/IMGUR5K-Handwriting-Dataset/images/YsaVkzl.jpg", + "image_hash": "993a7cbb04a7c854d1d841b065948369" + }, + "wz3wHhN": { + "image_url": "https://i.imgur.com/wz3wHhN.jpg", + "image_path": "/path/to/IMGUR5K-Handwriting-Dataset/images/wz3wHhN.jpg", + "image_hash": "9157426a98ee52f3e1e8d41fa3a99175" + }, + "BRHSP23": { + "image_url": "https://i.imgur.com/BRHSP23.jpg", + "image_path": "/path/to/IMGUR5K-Handwriting-Dataset/images/BRHSP23.jpg", + "image_hash": "aab01f7ac82ae53845b01674e9e34167" + } + }, + "index_to_ann_map": { + "YsaVkzl": [ + "YsaVkzl_0", + "YsaVkzl_1", + "YsaVkzl_2"], + "wz3wHhN": [ + "wz3wHhN_0", + "wz3wHhN_1"], + "BRHSP23": [ + "BRHSP23_0"] + }, + "ann_id": { + "YsaVkzl_0": { + "word": "I", + "bounding_box": "[605.33, 1150.67, 614.33, 226.33, 81.0]" + }, + "YsaVkzl_1": { + "word": "am", + "bounding_box": "[783.67, 654.67, 521.0, 222.33, 56.67]" + }, + "YsaVkzl_2": { + "word": "a", + "bounding_box": "[959.0, 437.0, 76.67, 201.0, 38.33]" + }, + "wz3wHhN_0": { + "word": "jedi", + "bounding_box": "[783.67, 654.67, 521.0, 222.33, 56.67]" + }, + "wz3wHhN_1": { + "word": "!", + "bounding_box": "[959.0, 437.0, 76.67, 201.0, 38.33]" + }, + "BRHSP23_0": { + "word": "jedi", + "bounding_box": "[783.67, 654.67, 521.0, 222.33, 56.67]" + } + } + } + label_file = label_folder.join("imgur5k_annotations.json") + with open(label_file, 'w') as f: + json.dump(labels, f) + for index_id in ['YsaVkzl', 'wz3wHhN', 'BRHSP23']: + fn_i = image_folder.join(f"{index_id}.jpg") + with open(fn_i, 'wb') as f: + f.write(file.getbuffer()) + return str(image_folder), str(label_file) + + +@pytest.fixture(scope="session") +def mock_svhn_dataset(tmpdir_factory, mock_image_stream): + root = tmpdir_factory.mktemp('datasets') + svhn_root = root.mkdir('svhn') + file = BytesIO(mock_image_stream) + # ascii image names + first = np.array([[49], [46], [112], [110], [103]], dtype=np.int16) # 1.png + second = np.array([[50], [46], [112], [110], [103]], dtype=np.int16) # 2.png + 
third = np.array([[51], [46], [112], [110], [103]], dtype=np.int16) # 3.png + # labels: label is also ascii + label = {'height': [35, 35, 35, 35], 'label': [1, 1, 3, 7], + 'left': [116, 128, 137, 151], 'top': [27, 29, 29, 26], + 'width': [15, 10, 17, 17]} + + matcontent = {'digitStruct': {'name': [first, second, third], 'bbox': [label, label, label]}} + # Mock train data + train_root = svhn_root.mkdir('train') + hdf5storage.write(matcontent, filename=train_root.join('digitStruct.mat')) + for i in range(3): + fn = train_root.join(f'{i+1}.png') + with open(fn, 'wb') as f: + f.write(file.getbuffer()) + # Packing data into an archive to simulate the real data set and bypass archive extraction + archive_path = root.join('svhn_train.tar') + shutil.make_archive(root.join('svhn_train'), 'tar', str(svhn_root)) + return str(archive_path) + + +@pytest.fixture(scope="session") +def mock_sroie_dataset(tmpdir_factory, mock_image_stream): + root = tmpdir_factory.mktemp('datasets') + sroie_root = root.mkdir('sroie2019_train_task1') + annotations_folder = sroie_root.mkdir('annotations') + image_folder = sroie_root.mkdir("images") + labels = ["72, 25, 326, 25, 326, 64, 72, 64, 'I'\n", + "50, 82, 440, 82, 440, 121, 50, 121, 'am'\n", + "205, 121, 285, 121, 285, 139, 205, 139, 'a'\n", + "18, 250, 440, 320, 250, 64, 85, 121, 'jedi'\n", + "400, 112, 252, 84, 112, 84, 75, 88, '!'"] + + file = BytesIO(mock_image_stream) + for i in range(3): + fn_i = image_folder.join(f"{i}.jpg") + with open(fn_i, 'wb') as f: + f.write(file.getbuffer()) + fn_l = annotations_folder.join(f"{i}.txt") + with open(fn_l, 'w') as f: + f.writelines(labels) + + # Packing data into an archive to simulate the real data set and bypass archive extraction + archive_path = root.join('sroie2019_train_task1.zip') + shutil.make_archive(root.join('sroie2019_train_task1'), 'zip', str(sroie_root)) + return str(archive_path) + + +@pytest.fixture(scope="session") +def mock_funsd_dataset(tmpdir_factory, mock_image_stream): + root = tmpdir_factory.mktemp('datasets') + funsd_root = root.mkdir('funsd') + sub_dataset_root = funsd_root.mkdir('dataset') + train_root = sub_dataset_root.mkdir('training_data') + image_folder = train_root.mkdir("images") + annotations_folder = train_root.mkdir("annotations") + labels = { + "form": [{ + "box": [84, 109, 136, 119], + "text": "I", + "label": "question", + "words": [{"box": [84, 109, 136, 119], "text": "I"}], + "linking": [[0, 37]], + "id": 0 + }, + { + "box": [85, 110, 145, 120], + "text": "am", + "label": "answer", + "words": [{"box": [85, 110, 145, 120], "text": "am"}], + "linking": [[1, 38]], + "id": 1 + }, + { + "box": [86, 115, 150, 125], + "text": "Luke", + "label": "answer", + "words": [{"box": [86, 115, 150, 125], "text": "Luke"}], + "linking": [[2, 44]], + "id": 2 + }] + } + + file = BytesIO(mock_image_stream) + for i in range(3): + fn_i = image_folder.join(f"{i}.png") + with open(fn_i, 'wb') as f: + f.write(file.getbuffer()) + fn_l = annotations_folder.join(f"{i}.json") + with open(fn_l, 'w') as f: + json.dump(labels, f) + + # Packing data into an archive to simulate the real data set and bypass archive extraction + archive_path = root.join('funsd.zip') + shutil.make_archive(root.join('funsd'), 'zip', str(funsd_root)) + return str(archive_path) + + +@pytest.fixture(scope="session") +def mock_cord_dataset(tmpdir_factory, mock_image_stream): + root = tmpdir_factory.mktemp('datasets') + cord_root = root.mkdir('cord_train') + image_folder = cord_root.mkdir("image") + annotations_folder = cord_root.mkdir("json") 
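+ # Minimal CORD-style annotation: each valid line holds a word with its quadrilateral coordinates (x1..x4 / y1..y4 keys in "quad") and its text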
+ labels = { + "dontcare": [], + "valid_line": [ + { + "words": [ + { + "quad": + {"x2": 270, "y3": 390, "x3": 270, "y4": 390, "x1": 256, "y1": 374, "x4": 256, "y2": 374}, + "is_key": 0, + "row_id": 2179893, + "text": "I" + } + ], + "category": "menu.cnt", + "group_id": 3 + }, + { + "words": [ + { + "quad": + {"x2": 270, "y3": 418, "x3": 270, "y4": 418, "x1": 258, "y1": 402, "x4": 258, "y2": 402}, + "is_key": 0, + "row_id": 2179894, + "text": "am" + } + ], + "category": "menu.cnt", + "group_id": 4 + }, + { + "words": [ + { + "quad": + {"x2": 272, "y3": 444, "x3": 272, "y4": 444, "x1": 258, "y1": 428, "x4": 258, "y2": 428}, + "is_key": 0, + "row_id": 2179895, + "text": "Luke" + } + ], + "category": "menu.cnt", + "group_id": 5 + }] + } + + file = BytesIO(mock_image_stream) + for i in range(3): + fn_i = image_folder.join(f"receipt_{i}.png") + with open(fn_i, 'wb') as f: + f.write(file.getbuffer()) + fn_l = annotations_folder.join(f"receipt_{i}.json") + with open(fn_l, 'w') as f: + json.dump(labels, f) + + # Packing data into an archive to simulate the real data set and bypass archive extraction + archive_path = root.join('cord_train.zip') + shutil.make_archive(root.join('cord_train'), 'zip', str(cord_root)) + return str(archive_path) + + +@pytest.fixture(scope="session") +def mock_synthtext_dataset(tmpdir_factory, mock_image_stream): + root = tmpdir_factory.mktemp('datasets') + synthtext_root = root.mkdir('SynthText') + image_folder = synthtext_root.mkdir("8") + annotation_file = synthtext_root.join('gt.mat') + labels = { + "imnames": [[["8/ballet_106_0.jpg"], ["8/ballet_106_1.jpg"], ["8/ballet_106_2.jpg"]]], + "wordBB": [[np.random.randint(1000, size=(2, 4, 5)) for _ in range(3)]], + "txt": [np.array([['I ', 'am\na ', 'Jedi ', '!'] for _ in range(3)])], + } + # hacky trick to write file into a LocalPath object with scipy.io.savemat + with tempfile.NamedTemporaryFile(mode='wb', delete=True) as f: + sio.savemat(f.name, labels) + shutil.copy(f.name, str(annotation_file)) + + file = BytesIO(mock_image_stream) + for i in range(3): + fn_i = image_folder.join(f"ballet_106_{i}.jpg") + with open(fn_i, 'wb') as f: + f.write(file.getbuffer()) + + # Packing data into an archive to simulate the real data set and bypass archive extraction + archive_path = root.join('SynthText.zip') + shutil.make_archive(root.join('SynthText'), 'zip', str(synthtext_root)) + return str(archive_path) + + +@pytest.fixture(scope="session") +def mock_doc_artefacts(tmpdir_factory, mock_image_stream): + root = tmpdir_factory.mktemp('datasets') + doc_root = root.mkdir('artefact_detection') + labels = { + '0.jpg': + [ + {'geometry': [0.94375, 0.4013671875, 0.99375, 0.4365234375], + 'label': 'bar_code'}, + {'geometry': [0.03125, 0.6923828125, 0.07875, 0.7294921875], + 'label': 'qr_code'}, + {'geometry': [0.1975, 0.1748046875, 0.39875, 0.2216796875], + 'label': 'bar_code'} + ], + '1.jpg': [ + {'geometry': [0.94375, 0.4013671875, 0.99375, 0.4365234375], + 'label': 'bar_code'}, + {'geometry': [0.03125, 0.6923828125, 0.07875, 0.7294921875], + 'label': 'qr_code'}, + {'geometry': [0.1975, 0.1748046875, 0.39875, 0.2216796875], + 'label': 'background'} + ], + '2.jpg': [ + {'geometry': [0.94375, 0.4013671875, 0.99375, 0.4365234375], + 'label': 'logo'}, + {'geometry': [0.03125, 0.6923828125, 0.07875, 0.7294921875], + 'label': 'qr_code'}, + {'geometry': [0.1975, 0.1748046875, 0.39875, 0.2216796875], + 'label': 'photo'} + ] + } + train_root = doc_root.mkdir('train') + label_file = train_root.join("labels.json") + + with open(label_file, 'w') as 
f: + json.dump(labels, f) + + image_folder = train_root.mkdir("images") + file = BytesIO(mock_image_stream) + for i in range(3): + fn = image_folder.join(f"{i}.jpg") + with open(fn, 'wb') as f: + f.write(file.getbuffer()) + # Packing data into an archive to simulate the real data set and bypass archive extraction + archive_path = root.join('artefact_detection.zip') + shutil.make_archive(root.join('artefact_detection'), 'zip', str(doc_root)) + return str(archive_path) + + +@pytest.fixture(scope="session") +def mock_iiit5k_dataset(tmpdir_factory, mock_image_stream): + root = tmpdir_factory.mktemp('datasets') + iiit5k_root = root.mkdir('IIIT5K') + image_folder = iiit5k_root.mkdir('train') + annotation_file = iiit5k_root.join('trainCharBound.mat') + labels = {'trainCharBound': + {"ImgName": ["train/0.png"], "chars": ["I"], "charBB": np.random.randint(50, size=(1, 4))}, + } + + # hacky trick to write file into a LocalPath object with scipy.io.savemat + with tempfile.NamedTemporaryFile(mode='wb', delete=True) as f: + sio.savemat(f.name, labels) + shutil.copy(f.name, str(annotation_file)) + + file = BytesIO(mock_image_stream) + for i in range(1): + fn_i = image_folder.join(f"{i}.png") + with open(fn_i, 'wb') as f: + f.write(file.getbuffer()) + + # Packing data into an archive to simulate the real data set and bypass archive extraction + archive_path = root.join('IIIT5K-Word-V3.tar') + shutil.make_archive(root.join('IIIT5K-Word-V3'), 'tar', str(iiit5k_root)) + return str(archive_path) + + +@pytest.fixture(scope="session") +def mock_svt_dataset(tmpdir_factory, mock_image_stream): + root = tmpdir_factory.mktemp('datasets') + svt_root = root.mkdir('svt1') + labels = """img/00_00.jpg +
341 Southwest 10th Avenue Portland OR
LIVING,ROOM,THEATERS + + LIVING + img/00_01.jpg +
1100 Southwest 6th Avenue Portland OR
LULA + HOUSE + img/00_02.jpg +
341 Southwest 10th Avenue Portland OR
LIVING,ROOM,THEATERS + COST +
""" + + with open(svt_root.join("train.xml"), "w") as f: + f.write(labels) + + image_folder = svt_root.mkdir("img") + file = BytesIO(mock_image_stream) + for i in range(3): + fn = image_folder.join(f"00_0{i}.jpg") + with open(fn, 'wb') as f: + f.write(file.getbuffer()) + # Packing data into an archive to simulate the real data set and bypass archive extraction + archive_path = root.join('svt.zip') + shutil.make_archive(root.join('svt'), 'zip', str(svt_root)) + return str(archive_path) + + +@pytest.fixture(scope="session") +def mock_ic03_dataset(tmpdir_factory, mock_image_stream): + root = tmpdir_factory.mktemp('datasets') + ic03_root = root.mkdir('SceneTrialTrain') + labels = """images/0.jpg + LIVING + images/1.jpg + + + HOUSEimages/2.jpg + + COST + """ + + with open(ic03_root.join("words.xml"), "w") as f: + f.write(labels) + + image_folder = ic03_root.mkdir("images") + file = BytesIO(mock_image_stream) + for i in range(3): + fn = image_folder.join(f"{i}.jpg") + with open(fn, 'wb') as f: + f.write(file.getbuffer()) + # Packing data into an archive to simulate the real data set and bypass archive extraction + archive_path = root.join('ic03_train.zip') + shutil.make_archive(root.join('ic03_train'), 'zip', str(ic03_root)) + return str(archive_path) diff --git a/tests/pytorch/test_datasets_pt.py b/tests/pytorch/test_datasets_pt.py new file mode 100644 index 0000000000..a5298c0cfb --- /dev/null +++ b/tests/pytorch/test_datasets_pt.py @@ -0,0 +1,436 @@ +import os +from shutil import move + +import numpy as np +import pytest +import torch +from torch.utils.data import DataLoader, RandomSampler + +from doctr import datasets +from doctr.transforms import Resize + + +def _validate_dataset(ds, input_size, batch_size=2, class_indices=False, is_polygons=False): + + # Fetch one sample + img, target = ds[0] + assert isinstance(img, torch.Tensor) + assert img.shape == (3, *input_size) + assert img.dtype == torch.float32 + assert isinstance(target, dict) + assert isinstance(target['boxes'], np.ndarray) and target['boxes'].dtype == np.float32 + if is_polygons: + assert target['boxes'].ndim == 3 and target['boxes'].shape[1:] == (4, 2) + else: + assert target['boxes'].ndim == 2 and target['boxes'].shape[1:] == (4,) + assert np.all(np.logical_and(target['boxes'] <= 1, target['boxes'] >= 0)) + if class_indices: + assert isinstance(target['labels'], np.ndarray) and target['labels'].dtype == np.int64 + else: + assert isinstance(target['labels'], list) and all(isinstance(s, str) for s in target['labels']) + assert len(target['labels']) == len(target['boxes']) + + # Check batching + loader = DataLoader( + ds, batch_size=batch_size, drop_last=True, sampler=RandomSampler(ds), num_workers=0, pin_memory=True, + collate_fn=ds.collate_fn) + + images, targets = next(iter(loader)) + assert isinstance(images, torch.Tensor) and images.shape == (batch_size, 3, *input_size) + assert isinstance(targets, list) and all(isinstance(elt, dict) for elt in targets) + + +def test_visiondataset(): + url = 'https://data.deepai.org/mnist.zip' + with pytest.raises(ValueError): + datasets.datasets.VisionDataset(url, download=False) + + dataset = datasets.datasets.VisionDataset(url, download=True, extract_archive=True) + assert len(dataset) == 0 + assert repr(dataset) == 'VisionDataset()' + + +def test_detection_dataset(mock_image_folder, mock_detection_label): + + input_size = (1024, 1024) + + ds = datasets.DetectionDataset( + img_folder=mock_image_folder, + label_path=mock_detection_label, + img_transforms=Resize(input_size), + ) + + 
assert len(ds) == 5 + img, target = ds[0] + assert isinstance(img, torch.Tensor) + assert img.dtype == torch.float32 + assert img.shape[-2:] == input_size + # Bounding boxes + assert isinstance(target, np.ndarray) and target.dtype == np.float32 + assert np.all(np.logical_and(target[:, :4] >= 0, target[:, :4] <= 1)) + assert target.shape[1] == 4 + + loader = DataLoader(ds, batch_size=2, collate_fn=ds.collate_fn) + images, targets = next(iter(loader)) + assert isinstance(images, torch.Tensor) and images.shape == (2, 3, *input_size) + assert isinstance(targets, list) and all(isinstance(elt, np.ndarray) for elt in targets) + + # Rotated DS + rotated_ds = datasets.DetectionDataset( + img_folder=mock_image_folder, + label_path=mock_detection_label, + img_transforms=Resize(input_size), + use_polygons=True + ) + _, r_target = rotated_ds[0] + assert r_target.shape[1:] == (4, 2) + + # File existence check + img_name, _ = ds.data[0] + move(os.path.join(ds.root, img_name), os.path.join(ds.root, "tmp_file")) + with pytest.raises(FileNotFoundError): + datasets.DetectionDataset(mock_image_folder, mock_detection_label) + move(os.path.join(ds.root, "tmp_file"), os.path.join(ds.root, img_name)) + + +def test_recognition_dataset(mock_image_folder, mock_recognition_label): + input_size = (32, 128) + ds = datasets.RecognitionDataset( + img_folder=mock_image_folder, + labels_path=mock_recognition_label, + img_transforms=Resize(input_size, preserve_aspect_ratio=True), + ) + assert len(ds) == 5 + image, label = ds[0] + assert isinstance(image, torch.Tensor) + assert image.shape[-2:] == input_size + assert image.dtype == torch.float32 + assert isinstance(label, str) + + loader = DataLoader(ds, batch_size=2, collate_fn=ds.collate_fn) + images, labels = next(iter(loader)) + assert isinstance(images, torch.Tensor) and images.shape == (2, 3, *input_size) + assert isinstance(labels, list) and all(isinstance(elt, str) for elt in labels) + + # File existence check + img_name, _ = ds.data[0] + move(os.path.join(ds.root, img_name), os.path.join(ds.root, "tmp_file")) + with pytest.raises(FileNotFoundError): + datasets.RecognitionDataset(mock_image_folder, mock_recognition_label) + move(os.path.join(ds.root, "tmp_file"), os.path.join(ds.root, img_name)) + + +@pytest.mark.parametrize( + "use_polygons", [False, True], +) +def test_ocrdataset(mock_ocrdataset, use_polygons): + + input_size = (512, 512) + + ds = datasets.OCRDataset( + *mock_ocrdataset, + img_transforms=Resize(input_size), + use_polygons=use_polygons, + ) + + assert len(ds) == 3 + _validate_dataset(ds, input_size, is_polygons=use_polygons) + + # File existence check + img_name, _ = ds.data[0] + move(os.path.join(ds.root, img_name), os.path.join(ds.root, "tmp_file")) + with pytest.raises(FileNotFoundError): + datasets.OCRDataset(*mock_ocrdataset) + move(os.path.join(ds.root, "tmp_file"), os.path.join(ds.root, img_name)) + + +def test_charactergenerator(): + + input_size = (32, 32) + vocab = 'abcdef' + + ds = datasets.CharacterGenerator( + vocab=vocab, + num_samples=10, + cache_samples=True, + img_transforms=Resize(input_size), + ) + + assert len(ds) == 10 + image, label = ds[0] + assert isinstance(image, torch.Tensor) + assert image.shape[-2:] == input_size + assert image.dtype == torch.float32 + assert isinstance(label, int) and label < len(vocab) + + loader = DataLoader(ds, batch_size=2, collate_fn=ds.collate_fn) + images, targets = next(iter(loader)) + assert isinstance(images, torch.Tensor) and images.shape == (2, 3, *input_size) + assert isinstance(targets, 
torch.Tensor) and targets.shape == (2,) + assert targets.dtype == torch.int64 + + +def test_wordgenerator(): + + input_size = (32, 128) + wordlen_range = (1, 10) + vocab = 'abcdef' + + ds = datasets.WordGenerator( + vocab=vocab, + min_chars=wordlen_range[0], + max_chars=wordlen_range[1], + num_samples=10, + cache_samples=True, + img_transforms=Resize(input_size), + ) + + assert len(ds) == 10 + image, target = ds[0] + assert isinstance(image, torch.Tensor) + assert image.shape[-2:] == input_size + assert image.dtype == torch.float32 + assert isinstance(target, str) and len(target) >= wordlen_range[0] and len(target) <= wordlen_range[1] + assert all(char in vocab for char in target) + + loader = DataLoader(ds, batch_size=2, collate_fn=ds.collate_fn) + images, targets = next(iter(loader)) + assert isinstance(images, torch.Tensor) and images.shape == (2, 3, *input_size) + assert isinstance(targets, list) and len(targets) == 2 and all(isinstance(t, str) for t in targets) + + +@pytest.mark.parametrize( + "num_samples, rotate", + [ + [5, True], # Actual set has 229 train and 233 test samples + [5, False] + + ], +) +def test_ic13_dataset(num_samples, rotate, mock_ic13): + input_size = (512, 512) + ds = datasets.IC13( + *mock_ic13, + img_transforms=Resize(input_size), + use_polygons=rotate, + ) + + assert len(ds) == num_samples + _validate_dataset(ds, input_size, is_polygons=rotate) + + +@pytest.mark.parametrize( + "num_samples, rotate", + [ + [3, True], # Actual set has 7149 train and 796 test samples + [3, False] + + ], +) +def test_imgur5k_dataset(num_samples, rotate, mock_imgur5k): + input_size = (512, 512) + ds = datasets.IMGUR5K( + *mock_imgur5k, + train=True, + img_transforms=Resize(input_size), + use_polygons=rotate, + ) + + assert len(ds) == num_samples - 1 # -1 because of the test set 90 / 10 split + assert repr(ds) == f"IMGUR5K(train={True})" + _validate_dataset(ds, input_size, is_polygons=rotate) + + +@pytest.mark.parametrize( + "input_size, num_samples, rotate", + [ + [[32, 128], 3, True], # Actual set has 33402 training samples and 13068 test samples + [[32, 128], 3, False], + ], +) +def test_svhn(input_size, num_samples, rotate, mock_svhn_dataset): + # monkeypatch the path to temporary dataset + datasets.SVHN.TRAIN = (mock_svhn_dataset, None, "svhn_train.tar") + + ds = datasets.SVHN( + train=True, download=True, img_transforms=Resize(input_size), use_polygons=rotate, + cache_dir="/".join(mock_svhn_dataset.split("/")[:-2]), cache_subdir=mock_svhn_dataset.split("/")[-2], + ) + + assert len(ds) == num_samples + assert repr(ds) == f"SVHN(train={True})" + _validate_dataset(ds, input_size, is_polygons=rotate) + + +@pytest.mark.parametrize( + "input_size, num_samples, rotate", + [ + [[512, 512], 3, True], # Actual set has 626 training samples and 360 test samples + [[512, 512], 3, False], + ], +) +def test_sroie(input_size, num_samples, rotate, mock_sroie_dataset): + # monkeypatch the path to temporary dataset + datasets.SROIE.TRAIN = (mock_sroie_dataset, None) + + ds = datasets.SROIE( + train=True, download=True, img_transforms=Resize(input_size), use_polygons=rotate, + cache_dir="/".join(mock_sroie_dataset.split("/")[:-2]), cache_subdir=mock_sroie_dataset.split("/")[-2], + ) + + assert len(ds) == num_samples + assert repr(ds) == f"SROIE(train={True})" + _validate_dataset(ds, input_size, is_polygons=rotate) + + +@pytest.mark.parametrize( + "input_size, num_samples, rotate", + [ + [[512, 512], 3, True], # Actual set has 149 training samples and 50 test samples + [[512, 512], 3, False], + ], 
+) +def test_funsd(input_size, num_samples, rotate, mock_funsd_dataset): + # monkeypatch the path to temporary dataset + datasets.FUNSD.URL = mock_funsd_dataset + datasets.FUNSD.SHA256 = None + datasets.FUNSD.FILE_NAME = "funsd.zip" + + ds = datasets.FUNSD( + train=True, download=True, img_transforms=Resize(input_size), use_polygons=rotate, + cache_dir="/".join(mock_funsd_dataset.split("/")[:-2]), cache_subdir=mock_funsd_dataset.split("/")[-2], + ) + + assert len(ds) == num_samples + assert repr(ds) == f"FUNSD(train={True})" + _validate_dataset(ds, input_size, is_polygons=rotate) + + +@pytest.mark.parametrize( + "input_size, num_samples, rotate", + [ + [[512, 512], 3, True], # Actual set has 800 training samples and 100 test samples + [[512, 512], 3, False], + ], +) +def test_cord(input_size, num_samples, rotate, mock_cord_dataset): + # monkeypatch the path to temporary dataset + datasets.CORD.TRAIN = (mock_cord_dataset, None) + + ds = datasets.CORD( + train=True, download=True, img_transforms=Resize(input_size), use_polygons=rotate, + cache_dir="/".join(mock_cord_dataset.split("/")[:-2]), cache_subdir=mock_cord_dataset.split("/")[-2], + ) + + assert len(ds) == num_samples + assert repr(ds) == f"CORD(train={True})" + _validate_dataset(ds, input_size, is_polygons=rotate) + + +@pytest.mark.parametrize( + "input_size, num_samples, rotate", + [ + [[512, 512], 2, True], # Actual set has 772875 training samples and 85875 test samples + [[512, 512], 2, False], + ], +) +def test_synthtext(input_size, num_samples, rotate, mock_synthtext_dataset): + # monkeypatch the path to temporary dataset + datasets.SynthText.URL = mock_synthtext_dataset + datasets.SynthText.SHA256 = None + + ds = datasets.SynthText( + train=True, download=True, img_transforms=Resize(input_size), use_polygons=rotate, + cache_dir="/".join(mock_synthtext_dataset.split("/")[:-2]), cache_subdir=mock_synthtext_dataset.split("/")[-2], + ) + + assert len(ds) == num_samples + assert repr(ds) == f"SynthText(train={True})" + _validate_dataset(ds, input_size, is_polygons=rotate) + + +@pytest.mark.parametrize( + "input_size, num_samples, rotate", + [ + [[512, 512], 3, True], # Actual set has 2700 training samples and 300 test samples + [[512, 512], 3, False], + ], +) +def test_artefact_detection(input_size, num_samples, rotate, mock_doc_artefacts): + # monkeypatch the path to temporary dataset + datasets.DocArtefacts.URL = mock_doc_artefacts + datasets.DocArtefacts.SHA256 = None + + ds = datasets.DocArtefacts( + train=True, download=True, img_transforms=Resize(input_size), use_polygons=rotate, + cache_dir="/".join(mock_doc_artefacts.split("/")[:-2]), cache_subdir=mock_doc_artefacts.split("/")[-2], + ) + + assert len(ds) == num_samples + assert repr(ds) == f"DocArtefacts(train={True})" + _validate_dataset(ds, input_size, class_indices=True, is_polygons=rotate) + + +@pytest.mark.parametrize( + "input_size, num_samples, rotate", + [ + [[32, 128], 1, True], # Actual set has 2000 training samples and 3000 test samples + [[32, 128], 1, False], + ], +) +def test_iiit5k(input_size, num_samples, rotate, mock_iiit5k_dataset): + # monkeypatch the path to temporary dataset + datasets.IIIT5K.URL = mock_iiit5k_dataset + datasets.IIIT5K.SHA256 = None + + ds = datasets.IIIT5K( + train=True, download=True, img_transforms=Resize(input_size), use_polygons=rotate, + cache_dir="/".join(mock_iiit5k_dataset.split("/")[:-2]), cache_subdir=mock_iiit5k_dataset.split("/")[-2], + ) + + assert len(ds) == num_samples + assert repr(ds) == f"IIIT5K(train={True})" + 
_validate_dataset(ds, input_size, batch_size=1, is_polygons=rotate) + + +@pytest.mark.parametrize( + "input_size, num_samples, rotate", + [ + [[512, 512], 3, True], # Actual set has 100 training samples and 249 test samples + [[512, 512], 3, False], + ], +) +def test_svt(input_size, num_samples, rotate, mock_svt_dataset): + # monkeypatch the path to temporary dataset + datasets.SVT.URL = mock_svt_dataset + datasets.SVT.SHA256 = None + + ds = datasets.SVT( + train=True, download=True, img_transforms=Resize(input_size), use_polygons=rotate, + cache_dir="/".join(mock_svt_dataset.split("/")[:-2]), cache_subdir=mock_svt_dataset.split("/")[-2], + ) + + assert len(ds) == num_samples + assert repr(ds) == f"SVT(train={True})" + _validate_dataset(ds, input_size, is_polygons=rotate) + + +@pytest.mark.parametrize( + "input_size, num_samples, rotate", + [ + [[512, 512], 3, True], # Actual set has 246 training samples and 249 test samples + [[512, 512], 3, False], + ], +) +def test_ic03(input_size, num_samples, rotate, mock_ic03_dataset): + # monkeypatch the path to temporary dataset + datasets.IC03.TRAIN = (mock_ic03_dataset, None, "ic03_train.zip") + + ds = datasets.IC03( + train=True, download=True, img_transforms=Resize(input_size), use_polygons=rotate, + cache_dir="/".join(mock_ic03_dataset.split("/")[:-2]), cache_subdir=mock_ic03_dataset.split("/")[-2], + ) + + assert len(ds) == num_samples + assert repr(ds) == f"IC03(train={True})" + _validate_dataset(ds, input_size, is_polygons=rotate) diff --git a/tests/pytorch/test_file_utils_pt.py b/tests/pytorch/test_file_utils_pt.py new file mode 100644 index 0000000000..7b36789561 --- /dev/null +++ b/tests/pytorch/test_file_utils_pt.py @@ -0,0 +1,5 @@ +from doctr.file_utils import is_torch_available + + +def test_file_utils(): + assert is_torch_available() diff --git a/tests/pytorch/test_io_image_pt.py b/tests/pytorch/test_io_image_pt.py new file mode 100644 index 0000000000..e6e4b4cb99 --- /dev/null +++ b/tests/pytorch/test_io_image_pt.py @@ -0,0 +1,50 @@ +import numpy as np +import pytest +import torch + +from doctr.io import decode_img_as_tensor, read_img_as_tensor, tensor_from_numpy + + +def test_read_img_as_tensor(mock_image_path): + + img = read_img_as_tensor(mock_image_path) + + assert isinstance(img, torch.Tensor) + assert img.dtype == torch.float32 + assert img.shape == (3, 900, 1200) + + img = read_img_as_tensor(mock_image_path, dtype=torch.float16) + assert img.dtype == torch.float16 + img = read_img_as_tensor(mock_image_path, dtype=torch.uint8) + assert img.dtype == torch.uint8 + + +def test_decode_img_as_tensor(mock_image_stream): + + img = decode_img_as_tensor(mock_image_stream) + + assert isinstance(img, torch.Tensor) + assert img.dtype == torch.float32 + assert img.shape == (3, 900, 1200) + + img = decode_img_as_tensor(mock_image_stream, dtype=torch.float16) + assert img.dtype == torch.float16 + img = decode_img_as_tensor(mock_image_stream, dtype=torch.uint8) + assert img.dtype == torch.uint8 + + +def test_tensor_from_numpy(mock_image_stream): + + with pytest.raises(ValueError): + tensor_from_numpy(np.zeros((256, 256, 3)), torch.int64) + + out = tensor_from_numpy(np.zeros((256, 256, 3), dtype=np.uint8)) + + assert isinstance(out, torch.Tensor) + assert out.dtype == torch.float32 + assert out.shape == (3, 256, 256) + + out = tensor_from_numpy(np.zeros((256, 256, 3), dtype=np.uint8), dtype=torch.float16) + assert out.dtype == torch.float16 + out = tensor_from_numpy(np.zeros((256, 256, 3), dtype=np.uint8), dtype=torch.uint8) + assert out.dtype 
== torch.uint8 diff --git a/tests/pytorch/test_models_classification_pt.py b/tests/pytorch/test_models_classification_pt.py new file mode 100644 index 0000000000..48ef463d26 --- /dev/null +++ b/tests/pytorch/test_models_classification_pt.py @@ -0,0 +1,91 @@ +import cv2 +import numpy as np +import pytest +import torch + +from doctr.models import classification +from doctr.models.classification.predictor import CropOrientationPredictor + + +@pytest.mark.parametrize( + "arch_name, input_shape, output_size", + [ + ["vgg16_bn_r", (3, 32, 32), (126,)], + ["resnet18", (3, 32, 32), (126,)], + ["resnet31", (3, 32, 32), (126,)], + ["magc_resnet31", (3, 32, 32), (126,)], + ["mobilenet_v3_small", (3, 32, 32), (126,)], + ["mobilenet_v3_large", (3, 32, 32), (126,)], + ], +) +def test_classification_architectures(arch_name, input_shape, output_size): + # Model + batch_size = 2 + model = classification.__dict__[arch_name](pretrained=True).eval() + # Forward + with torch.no_grad(): + out = model(torch.rand((batch_size, *input_shape), dtype=torch.float32)) + # Output checks + assert isinstance(out, torch.Tensor) + assert out.dtype == torch.float32 + assert out.numpy().shape == (batch_size, *output_size) + # Check FP16 + if torch.cuda.is_available(): + model = model.half().cuda() + with torch.no_grad(): + out = model(torch.rand((batch_size, *input_shape), dtype=torch.float16).cuda()) + assert out.dtype == torch.float16 + + +@pytest.mark.parametrize( + "arch_name, input_shape", + [ + ["mobilenet_v3_small_orientation", (3, 128, 128)], + ], +) +def test_classification_models(arch_name, input_shape): + batch_size = 8 + model = classification.__dict__[arch_name](pretrained=False, input_shape=input_shape).eval() + assert isinstance(model, torch.nn.Module) + input_tensor = torch.rand((batch_size, *input_shape)) + + if torch.cuda.is_available(): + model.cuda() + input_tensor = input_tensor.cuda() + out = model(input_tensor) + assert isinstance(out, torch.Tensor) + assert out.shape == (8, 4) + + +@pytest.mark.parametrize( + "arch_name", + [ + "mobilenet_v3_small_orientation", + ], +) +def test_classification_zoo(arch_name): + batch_size = 16 + # Model + predictor = classification.zoo.crop_orientation_predictor(arch_name, pretrained=False) + predictor.model.eval() + # object check + assert isinstance(predictor, CropOrientationPredictor) + input_tensor = torch.rand((batch_size, 3, 128, 128)) + if torch.cuda.is_available(): + predictor.model.cuda() + input_tensor = input_tensor.cuda() + + with torch.no_grad(): + out = predictor(input_tensor) + out = predictor(input_tensor) + assert isinstance(out, list) and len(out) == batch_size + assert all(isinstance(pred, int) for pred in out) + + +def test_crop_orientation_model(mock_text_box): + text_box_0 = cv2.imread(mock_text_box) + text_box_90 = np.rot90(text_box_0, 1) + text_box_180 = np.rot90(text_box_0, 2) + text_box_270 = np.rot90(text_box_0, 3) + classifier = classification.crop_orientation_predictor("mobilenet_v3_small_orientation", pretrained=True) + assert classifier([text_box_0, text_box_90, text_box_180, text_box_270]) == [0, 1, 2, 3] diff --git a/tests/pytorch/test_models_detection_pt.py b/tests/pytorch/test_models_detection_pt.py new file mode 100644 index 0000000000..0b1316e6f4 --- /dev/null +++ b/tests/pytorch/test_models_detection_pt.py @@ -0,0 +1,93 @@ +import numpy as np +import pytest +import torch + +from doctr.models import detection +from doctr.models.detection._utils import dilate, erode +from doctr.models.detection.predictor import DetectionPredictor + + 
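+# Each detection architecture below is checked for output map shape/dtype, normalized box predictions and a training loss; only the DB models are expected to produce probability maps in [0, 1]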
+@pytest.mark.parametrize( + "arch_name, input_shape, output_size, out_prob", + [ + ["db_resnet34", (3, 512, 512), (1, 512, 512), True], + ["db_resnet50", (3, 512, 512), (1, 512, 512), True], + ["db_mobilenet_v3_large", (3, 512, 512), (1, 512, 512), True], + ["linknet_resnet18", (3, 512, 512), (1, 512, 512), False], + ], +) +def test_detection_models(arch_name, input_shape, output_size, out_prob): + batch_size = 2 + model = detection.__dict__[arch_name](pretrained=False).eval() + assert isinstance(model, torch.nn.Module) + input_tensor = torch.rand((batch_size, *input_shape)) + target = [ + np.array([[.5, .5, 1, 1], [.5, .5, .8, .8]], dtype=np.float32), + np.array([[.5, .5, 1, 1], [.5, .5, .8, .9]], dtype=np.float32), + ] + if torch.cuda.is_available(): + model.cuda() + input_tensor = input_tensor.cuda() + out = model(input_tensor, target, return_model_output=True, return_preds=True) + assert isinstance(out, dict) + assert len(out) == 3 + # Check proba map + assert out['out_map'].shape == (batch_size, *output_size) + assert out['out_map'].dtype == torch.float32 + if out_prob: + assert torch.all((out['out_map'] >= 0) & (out['out_map'] <= 1)) + # Check boxes + for boxes in out['preds']: + assert boxes.shape[1] == 5 + assert np.all(boxes[:, :2] < boxes[:, 2:4]) + assert np.all(boxes[:, :4] >= 0) and np.all(boxes[:, :4] <= 1) + # Check loss + assert isinstance(out['loss'], torch.Tensor) + # Check the rotated case (same targets) + target = [ + np.array([[[.5, .5], [1, .5], [1, 1], [.5, 1]], [[.5, .5], [.8, .5], [.8, .8], [.5, .8]]], dtype=np.float32), + np.array([[[.5, .5], [1, .5], [1, 1], [.5, 1]], [[.5, .5], [.8, .5], [.8, .9], [.5, .9]]], dtype=np.float32), + ] + loss = model(input_tensor, target)['loss'] + assert isinstance(loss, torch.Tensor) and ((loss - out['loss']).abs() / loss).item() < 1e-1 + + +@pytest.mark.parametrize( + "arch_name", + [ + "db_resnet34", + "db_resnet50", + "db_mobilenet_v3_large", + "linknet_resnet18", + ], +) +def test_detection_zoo(arch_name): + # Model + predictor = detection.zoo.detection_predictor(arch_name, pretrained=False) + predictor.model.eval() + # object check + assert isinstance(predictor, DetectionPredictor) + input_tensor = torch.rand((2, 3, 1024, 1024)) + if torch.cuda.is_available(): + predictor.model.cuda() + input_tensor = input_tensor.cuda() + + with torch.no_grad(): + out = predictor(input_tensor) + assert all(isinstance(boxes, np.ndarray) and boxes.shape[1] == 5 for boxes in out) + + +def test_erode(): + x = torch.zeros((1, 1, 3, 3)) + x[..., 1, 1] = 1 + expected = torch.zeros((1, 1, 3, 3)) + out = erode(x, 3) + assert torch.equal(out, expected) + + +def test_dilate(): + x = torch.zeros((1, 1, 3, 3)) + x[..., 1, 1] = 1 + expected = torch.ones((1, 1, 3, 3)) + out = dilate(x, 3) + assert torch.equal(out, expected) diff --git a/tests/pytorch/test_models_obj_detection_pt.py b/tests/pytorch/test_models_obj_detection_pt.py new file mode 100644 index 0000000000..6b15c06f12 --- /dev/null +++ b/tests/pytorch/test_models_obj_detection_pt.py @@ -0,0 +1,34 @@ +import pytest +import torch + +from doctr.models import obj_detection + + +@pytest.mark.parametrize( + "arch_name, input_shape, pretrained", + [ + ["fasterrcnn_mobilenet_v3_large_fpn", (3, 512, 512), True], + ["fasterrcnn_mobilenet_v3_large_fpn", (3, 512, 512), False], + ], +) +def test_detection_models(arch_name, input_shape, pretrained): + batch_size = 2 + model = obj_detection.__dict__[arch_name](pretrained=pretrained).eval() + assert isinstance(model, torch.nn.Module) + input_tensor = 
torch.rand((batch_size, *input_shape)) + if torch.cuda.is_available(): + model.cuda() + input_tensor = input_tensor.cuda() + out = model(input_tensor) + assert isinstance(out, list) and all(isinstance(det, dict) for det in out) + + # Train mode + model = model.train() + target = [ + dict(boxes=torch.tensor([[.5, .5, 1, 1]], dtype=torch.float32), labels=torch.tensor((0,), dtype=torch.long)), + dict(boxes=torch.tensor([[.5, .5, 1, 1]], dtype=torch.float32), labels=torch.tensor((0,), dtype=torch.long)), + ] + if torch.cuda.is_available(): + target = [{k: v.cuda() for k, v in t.items()} for t in target] + out = model(input_tensor, target) + assert isinstance(out, dict) and all(isinstance(v, torch.Tensor) for v in out.values()) diff --git a/tests/pytorch/test_models_preprocessor_pt.py b/tests/pytorch/test_models_preprocessor_pt.py new file mode 100644 index 0000000000..991f3ba49e --- /dev/null +++ b/tests/pytorch/test_models_preprocessor_pt.py @@ -0,0 +1,48 @@ +import numpy as np +import pytest +import torch + +from doctr.models.preprocessor import PreProcessor + + +@pytest.mark.parametrize( + "batch_size, output_size, input_tensor, expected_batches, expected_value", + [ + [2, (128, 128), np.full((3, 256, 128, 3), 255, dtype=np.uint8), 1, .5], # numpy uint8 + [2, (128, 128), np.ones((3, 256, 128, 3), dtype=np.float32), 1, .5], # numpy fp32 + [2, (128, 128), torch.full((3, 3, 256, 128), 255, dtype=torch.uint8), 1, .5], # torch uint8 + [2, (128, 128), torch.ones((3, 3, 256, 128), dtype=torch.float32), 1, .5], # torch fp32 + [2, (128, 128), torch.ones((3, 3, 256, 128), dtype=torch.float16), 1, .5], # torch fp16 + [2, (128, 128), [np.full((256, 128, 3), 255, dtype=np.uint8)] * 3, 2, .5], # list of numpy uint8 + [2, (128, 128), [np.ones((256, 128, 3), dtype=np.float32)] * 3, 2, .5], # list of numpy fp32 + [2, (128, 128), [torch.full((3, 256, 128), 255, dtype=torch.uint8)] * 3, 2, .5], # list of torch uint8 + [2, (128, 128), [torch.ones((3, 256, 128), dtype=torch.float32)] * 3, 2, .5], # list of torch fp32 + [2, (128, 128), [torch.ones((3, 256, 128), dtype=torch.float16)] * 3, 2, .5], # list of torch fp32 + ], +) +def test_preprocessor(batch_size, output_size, input_tensor, expected_batches, expected_value): + + processor = PreProcessor(output_size, batch_size) + + # Invalid input type + with pytest.raises(TypeError): + processor(42) + # 4D check + with pytest.raises(AssertionError): + processor(np.full((256, 128, 3), 255, dtype=np.uint8)) + with pytest.raises(TypeError): + processor(np.full((1, 256, 128, 3), 255, dtype=np.int32)) + # 3D check + with pytest.raises(AssertionError): + processor([np.full((3, 256, 128, 3), 255, dtype=np.uint8)]) + with pytest.raises(TypeError): + processor([np.full((256, 128, 3), 255, dtype=np.int32)]) + + with torch.no_grad(): + out = processor(input_tensor) + assert isinstance(out, list) and len(out) == expected_batches + assert all(isinstance(b, torch.Tensor) for b in out) + assert all(b.dtype == torch.float32 for b in out) + assert all(b.shape[-2:] == output_size for b in out) + assert all(torch.all(b == expected_value) for b in out) + assert len(repr(processor).split('\n')) == 4 diff --git a/tests/pytorch/test_models_recognition_pt.py b/tests/pytorch/test_models_recognition_pt.py new file mode 100644 index 0000000000..8d3ab192f2 --- /dev/null +++ b/tests/pytorch/test_models_recognition_pt.py @@ -0,0 +1,85 @@ +import pytest +import torch + +from doctr.models import recognition +from doctr.models.recognition.crnn.pytorch import CTCPostProcessor +from 
doctr.models.recognition.master.pytorch import MASTERPostProcessor +from doctr.models.recognition.predictor import RecognitionPredictor + + +@pytest.mark.parametrize( + "arch_name, input_shape", + [ + ["crnn_vgg16_bn", (3, 32, 128)], + ["crnn_mobilenet_v3_small", (3, 32, 128)], + ["crnn_mobilenet_v3_large", (3, 32, 128)], + ["sar_resnet31", (3, 32, 128)], + ["master", (3, 48, 160)], + ], +) +def test_recognition_models(arch_name, input_shape, mock_vocab): + batch_size = 4 + model = recognition.__dict__[arch_name](vocab=mock_vocab, pretrained=False, input_shape=input_shape).eval() + assert isinstance(model, torch.nn.Module) + input_tensor = torch.rand((batch_size, *input_shape)) + target = ["i", "am", "a", "jedi"] + + if torch.cuda.is_available(): + model.cuda() + input_tensor = input_tensor.cuda() + out = model(input_tensor, target, return_model_output=True, return_preds=True) + assert isinstance(out, dict) + assert len(out) == 3 + assert isinstance(out['preds'], list) + assert len(out['preds']) == batch_size + assert all(isinstance(word, str) and isinstance(conf, float) and 0 <= conf <= 1 for word, conf in out['preds']) + assert isinstance(out['out_map'], torch.Tensor) + assert out['out_map'].dtype == torch.float32 + assert isinstance(out['loss'], torch.Tensor) + + +@pytest.mark.parametrize( + "post_processor, input_shape", + [ + [CTCPostProcessor, [2, 119, 30]], + [MASTERPostProcessor, [2, 119, 30]], + ], +) +def test_reco_postprocessors(post_processor, input_shape, mock_vocab): + processor = post_processor(mock_vocab) + decoded = processor(torch.rand(*input_shape)) + assert isinstance(decoded, list) + assert all(isinstance(word, str) and isinstance(conf, float) and 0 <= conf <= 1 for word, conf in decoded) + assert len(decoded) == input_shape[0] + assert all(char in mock_vocab for word, _ in decoded for char in word) + # Repr + assert repr(processor) == f'{post_processor.__name__}(vocab_size={len(mock_vocab)})' + + +@pytest.mark.parametrize( + "arch_name", + [ + "crnn_vgg16_bn", + "crnn_mobilenet_v3_small", + "crnn_mobilenet_v3_large", + "sar_resnet31", + "master" + ], +) +def test_recognition_zoo(arch_name): + batch_size = 2 + # Model + predictor = recognition.zoo.recognition_predictor(arch_name, pretrained=False) + predictor.model.eval() + # object check + assert isinstance(predictor, RecognitionPredictor) + input_tensor = torch.rand((batch_size, 3, 128, 128)) + if torch.cuda.is_available(): + predictor.model.cuda() + input_tensor = input_tensor.cuda() + + with torch.no_grad(): + out = predictor(input_tensor) + out = predictor(input_tensor) + assert isinstance(out, list) and len(out) == batch_size + assert all(isinstance(word, str) and isinstance(conf, float) for word, conf in out) diff --git a/tests/pytorch/test_models_utils_pt.py b/tests/pytorch/test_models_utils_pt.py new file mode 100644 index 0000000000..4ab3eb3586 --- /dev/null +++ b/tests/pytorch/test_models_utils_pt.py @@ -0,0 +1,30 @@ +import os + +import pytest +from torch import nn + +from doctr.models.utils import conv_sequence_pt, load_pretrained_params + + +def test_load_pretrained_params(tmpdir_factory): + + model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 4)) + # Retrieve this URL + url = "https://github.com/mindee/doctr/releases/download/v0.2.1/tmp_checkpoint-6f0ce0e6.pt" + # Temp cache dir + cache_dir = tmpdir_factory.mktemp("cache") + # Pass an incorrect hash + with pytest.raises(ValueError): + load_pretrained_params(model, url, "mywronghash", cache_dir=str(cache_dir)) + # Let it resolve the hash
from the file name + load_pretrained_params(model, url, cache_dir=str(cache_dir)) + # Check that the file was downloaded & the archive extracted + assert os.path.exists(cache_dir.join('models').join(url.rpartition("/")[-1])) + + +def test_conv_sequence(): + + assert len(conv_sequence_pt(3, 8, kernel_size=3)) == 1 + assert len(conv_sequence_pt(3, 8, True, kernel_size=3)) == 2 + assert len(conv_sequence_pt(3, 8, False, True, kernel_size=3)) == 2 + assert len(conv_sequence_pt(3, 8, True, True, kernel_size=3)) == 3 diff --git a/tests/pytorch/test_models_zoo_pt.py b/tests/pytorch/test_models_zoo_pt.py new file mode 100644 index 0000000000..ae10ef85ee --- /dev/null +++ b/tests/pytorch/test_models_zoo_pt.py @@ -0,0 +1,77 @@ +import numpy as np +import pytest + +from doctr import models +from doctr.io import Document, DocumentFile +from doctr.models import detection, recognition +from doctr.models.detection.predictor import DetectionPredictor +from doctr.models.predictor import OCRPredictor +from doctr.models.preprocessor import PreProcessor +from doctr.models.recognition.predictor import RecognitionPredictor + + +@pytest.mark.parametrize( + "assume_straight_pages, straighten_pages", + [ + [True, False], + [True, True], + ] +) +def test_ocrpredictor(mock_pdf, mock_vocab, assume_straight_pages, straighten_pages): + det_bsize = 4 + det_predictor = DetectionPredictor( + PreProcessor(output_size=(512, 512), batch_size=det_bsize), + detection.db_mobilenet_v3_large(pretrained=False, pretrained_backbone=False) + ) + + assert not det_predictor.model.training + + reco_bsize = 32 + reco_predictor = RecognitionPredictor( + PreProcessor(output_size=(32, 128), batch_size=reco_bsize, preserve_aspect_ratio=True), + recognition.crnn_vgg16_bn(pretrained=False, pretrained_backbone=False, vocab=mock_vocab) + ) + + assert not reco_predictor.model.training + + doc = DocumentFile.from_pdf(mock_pdf).as_images() + + predictor = OCRPredictor( + det_predictor, + reco_predictor, + assume_straight_pages=assume_straight_pages, + straighten_pages=straighten_pages, + ) + + out = predictor(doc) + assert isinstance(out, Document) + assert len(out.pages) == 2 + # Dimension check + with pytest.raises(ValueError): + input_page = (255 * np.random.rand(1, 256, 512, 3)).astype(np.uint8) + _ = predictor([input_page]) + + +@pytest.mark.parametrize( + "det_arch, reco_arch", + [ + ["db_mobilenet_v3_large", "crnn_mobilenet_v3_large"], + ], +) +def test_zoo_models(det_arch, reco_arch): + # Model + predictor = models.ocr_predictor(det_arch, reco_arch, pretrained=True) + # Output checks + assert isinstance(predictor, OCRPredictor) + + doc = [np.zeros((512, 512, 3), dtype=np.uint8)] + out = predictor(doc) + # Document + assert isinstance(out, Document) + + # The input doc has 1 page + assert len(out.pages) == 1 + # Dimension check + with pytest.raises(ValueError): + input_page = (255 * np.random.rand(1, 256, 512, 3)).astype(np.uint8) + _ = predictor([input_page]) diff --git a/tests/pytorch/test_transforms_pt.py b/tests/pytorch/test_transforms_pt.py new file mode 100644 index 0000000000..2725a0e544 --- /dev/null +++ b/tests/pytorch/test_transforms_pt.py @@ -0,0 +1,274 @@ +import math + +import numpy as np +import pytest +import torch + +from doctr.transforms import (ChannelShuffle, ColorInversion, GaussianNoise, RandomCrop, RandomHorizontalFlip, + RandomRotate, Resize) +from doctr.transforms.functional import crop_detection, rotate_sample + + +def test_resize(): + output_size = (32, 32) + transfo = Resize(output_size) + input_t = torch.ones((3, 
64, 64), dtype=torch.float32) + out = transfo(input_t) + + assert torch.all(out == 1) + assert out.shape[-2:] == output_size + assert repr(transfo) == f"Resize(output_size={output_size}, interpolation='bilinear')" + + transfo = Resize(output_size, preserve_aspect_ratio=True) + input_t = torch.ones((3, 32, 64), dtype=torch.float32) + out = transfo(input_t) + + assert out.shape[-2:] == output_size + assert not torch.all(out == 1) + # Asymmetric padding + assert torch.all(out[:, -1] == 0) and torch.all(out[:, 0] == 1) + + # Symmetric padding + transfo = Resize(output_size, preserve_aspect_ratio=True, symmetric_pad=True) + assert repr(transfo) == (f"Resize(output_size={output_size}, interpolation='bilinear', " + f"preserve_aspect_ratio=True, symmetric_pad=True)") + out = transfo(input_t) + assert out.shape[-2:] == output_size + # symmetric padding + assert torch.all(out[:, -1] == 0) and torch.all(out[:, 0] == 0) + + # Inverse aspect ratio + input_t = torch.ones((3, 64, 32), dtype=torch.float32) + out = transfo(input_t) + + assert not torch.all(out == 1) + assert out.shape[-2:] == output_size + + # Same aspect ratio + output_size = (32, 128) + transfo = Resize(output_size, preserve_aspect_ratio=True) + out = transfo(torch.ones((3, 16, 64), dtype=torch.float32)) + assert out.shape[-2:] == output_size + + # FP16 + input_t = torch.ones((3, 64, 64), dtype=torch.float16) + out = transfo(input_t) + assert out.dtype == torch.float16 + + +@pytest.mark.parametrize( + "rgb_min", + [ + 0.2, + 0.4, + 0.6, + ], +) +def test_invert_colorize(rgb_min): + + transfo = ColorInversion(min_val=rgb_min) + input_t = torch.ones((8, 3, 32, 32), dtype=torch.float32) + out = transfo(input_t) + assert torch.all(out <= 1 - rgb_min + 1e-4) + assert torch.all(out >= 0) + + input_t = torch.full((8, 3, 32, 32), 255, dtype=torch.uint8) + out = transfo(input_t) + assert torch.all(out <= int(math.ceil(255 * (1 - rgb_min + 1e-4)))) + assert torch.all(out >= 0) + + # FP16 + input_t = torch.ones((8, 3, 32, 32), dtype=torch.float16) + out = transfo(input_t) + assert out.dtype == torch.float16 + + +def test_rotate_sample(): + img = torch.ones((3, 200, 100), dtype=torch.float32) + boxes = np.array([0, 0, 100, 200])[None, ...] + polys = np.stack((boxes[..., [0, 1]], boxes[..., [2, 1]], boxes[..., [2, 3]], boxes[..., [0, 3]]), axis=1) + rel_boxes = np.array([0, 0, 1, 1], dtype=np.float32)[None, ...] + rel_polys = np.stack( + (rel_boxes[..., [0, 1]], rel_boxes[..., [2, 1]], rel_boxes[..., [2, 3]], rel_boxes[..., [0, 3]]), + axis=1 + ) + + # No angle + rotated_img, rotated_geoms = rotate_sample(img, boxes, 0, False) + assert torch.all(rotated_img == img) and np.all(rotated_geoms == rel_polys) + rotated_img, rotated_geoms = rotate_sample(img, boxes, 0, True) + assert torch.all(rotated_img == img) and np.all(rotated_geoms == rel_polys) + rotated_img, rotated_geoms = rotate_sample(img, polys, 0, False) + assert torch.all(rotated_img == img) and np.all(rotated_geoms == rel_polys) + rotated_img, rotated_geoms = rotate_sample(img, polys, 0, True) + assert torch.all(rotated_img == img) and np.all(rotated_geoms == rel_polys) + + # No expansion + expected_img = torch.zeros((3, 200, 100), dtype=torch.float32) + expected_img[:, 50: 150] = 1 + expected_polys = np.array([[0, .75], [0, .25], [1, .25], [1, .75]])[None, ...]
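+ # 90° rotation without expansion: the canvas keeps its (200, 100) size, the rotated content fills the central rows, and the geometries are returned as relative polygons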
+ rotated_img, rotated_geoms = rotate_sample(img, boxes, 90, False) + assert torch.all(rotated_img == expected_img) and np.all(rotated_geoms == expected_polys) + rotated_img, rotated_geoms = rotate_sample(img, polys, 90, False) + assert torch.all(rotated_img == expected_img) and np.all(rotated_geoms == expected_polys) + rotated_img, rotated_geoms = rotate_sample(img, rel_boxes, 90, False) + assert torch.all(rotated_img == expected_img) and np.all(rotated_geoms == expected_polys) + rotated_img, rotated_geoms = rotate_sample(img, rel_polys, 90, False) + assert torch.all(rotated_img == expected_img) and np.all(rotated_geoms == expected_polys) + + # Expansion + expected_img = torch.ones((3, 100, 200), dtype=torch.float32) + expected_polys = np.array([[0, 1], [0, 0], [1, 0], [1, 1]], dtype=np.float32)[None, ...] + rotated_img, rotated_geoms = rotate_sample(img, boxes, 90, True) + assert torch.all(rotated_img == expected_img) and np.all(rotated_geoms == expected_polys) + rotated_img, rotated_geoms = rotate_sample(img, polys, 90, True) + assert torch.all(rotated_img == expected_img) and np.all(rotated_geoms == expected_polys) + rotated_img, rotated_geoms = rotate_sample(img, rel_boxes, 90, True) + assert torch.all(rotated_img == expected_img) and np.all(rotated_geoms == expected_polys) + rotated_img, rotated_geoms = rotate_sample(img, rel_polys, 90, True) + assert torch.all(rotated_img == expected_img) and np.all(rotated_geoms == expected_polys) + + with pytest.raises(AssertionError): + rotate_sample(img, boxes[None, ...], 90, False) + + +def test_random_rotate(): + rotator = RandomRotate(max_angle=10., expand=False) + input_t = torch.ones((3, 50, 50), dtype=torch.float32) + boxes = np.array([ + [15, 20, 35, 30] + ]) + r_img, r_boxes = rotator(input_t, boxes) + assert r_img.shape == input_t.shape + + rotator = RandomRotate(max_angle=10., expand=True) + r_img, r_boxes = rotator(input_t, boxes) + assert r_img.shape != input_t.shape + + # FP16 (only on GPU) + if torch.cuda.is_available(): + input_t = torch.ones((3, 50, 50), dtype=torch.float16).cuda() + r_img, _ = rotator(input_t, boxes) + assert r_img.dtype == torch.float16 + + +def test_crop_detection(): + img = torch.ones((3, 50, 50), dtype=torch.float32) + abs_boxes = np.array([ + [15, 20, 35, 30], + [5, 10, 10, 20], + ]) + crop_box = (12 / 50, 23 / 50, 50 / 50, 50 / 50) + c_img, c_boxes = crop_detection(img, abs_boxes, crop_box) + assert c_img.shape == (3, 26, 37) + assert c_boxes.shape == (1, 4) + assert np.all(c_boxes == np.array([15 - 12, 0, 35 - 12, 30 - 23])[None, ...]) + + rel_boxes = np.array([ + [.3, .4, .7, .6], + [.1, .2, .2, .4], + ]) + crop_box = (0.24, 0.46, 1.0, 1.0) + c_img, c_boxes = crop_detection(img, rel_boxes, crop_box) + assert c_img.shape == (3, 26, 37) + assert c_boxes.shape == (1, 4) + assert np.abs(c_boxes - np.array([.06 / .76, 0., .46 / .76, .14 / .54])[None, ...]).mean() < 1e-7 + + # FP16 + img = torch.ones((3, 50, 50), dtype=torch.float16) + c_img, _ = crop_detection(img, abs_boxes, crop_box) + assert c_img.dtype == torch.float16 + + with pytest.raises(AssertionError): + crop_detection(img, abs_boxes, (2, 6, 24, 56)) + + +def test_random_crop(): + cropper = RandomCrop(scale=(0.5, 1.), ratio=(0.75, 1.33)) + input_t = torch.ones((3, 50, 50), dtype=torch.float32) + boxes = np.array([ + [15, 20, 35, 30] + ]) + img, target = cropper(input_t, dict(boxes=boxes)) + # Check the scale + assert img.shape[-1] * img.shape[-2] >= 0.4 * input_t.shape[-1] * input_t.shape[-2] + # Check aspect ratio + assert 0.65 <= img.shape[-2] / 
img.shape[-1] <= 1.5 + # Check the target + assert np.all(target['boxes'] >= 0) + assert np.all(target['boxes'][:, [0, 2]] <= img.shape[-1]) and np.all(target['boxes'][:, [1, 3]] <= img.shape[-2]) + + +@pytest.mark.parametrize( + "input_dtype, input_size", + [ + [torch.float32, (3, 32, 32)], + [torch.uint8, (3, 32, 32)], + ], +) +def test_channel_shuffle(input_dtype, input_size): + transfo = ChannelShuffle() + input_t = torch.rand(input_size, dtype=torch.float32) + if input_dtype == torch.uint8: + input_t = (255 * input_t).round() + input_t = input_t.to(dtype=input_dtype) + out = transfo(input_t) + assert isinstance(out, torch.Tensor) + assert out.shape == input_size + assert out.dtype == input_dtype + # Ensure that nothing has changed apart from channel order + if input_dtype == torch.uint8: + assert torch.all(input_t.sum(0) == out.sum(0)) + else: + # Float approximation + assert (input_t.sum(0) - out.sum(0)).abs().mean() < 1e-7 + + +@pytest.mark.parametrize( + "input_dtype,input_shape", + [ + [torch.float32, (3, 32, 32)], + [torch.uint8, (3, 32, 32)], + ] +) +def test_gaussian_noise(input_dtype, input_shape): + transform = GaussianNoise(0., 1.) + input_t = torch.rand(input_shape, dtype=torch.float32) + if input_dtype == torch.uint8: + input_t = (255 * input_t).round() + input_t = input_t.to(dtype=input_dtype) + transformed = transform(input_t) + assert isinstance(transformed, torch.Tensor) + assert transformed.shape == input_shape + assert transformed.dtype == input_dtype + assert torch.any(transformed != input_t) + assert torch.all(transformed >= 0) + if input_dtype == torch.uint8: + assert torch.all(transformed <= 255) + else: + assert torch.all(transformed <= 1.) + + +@pytest.mark.parametrize("p", [1, 0]) +def test_randomhorizontalflip(p): + # testing for 2 cases, with flip probability 1 and 0. 
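+ # with p=1 the relative box x-coordinates are mirrored ([0.1, 0.3] -> [0.7, 0.9]) and the left/right image halves swap; with p=0 image and target must come back unchanged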
+ transform = RandomHorizontalFlip(p) + input_t = torch.ones((3, 32, 32), dtype=torch.float32) + input_t[..., :16] = 0 + target = {"boxes": np.array([[0.1, 0.1, 0.3, 0.4]], dtype=np.float32), "labels": np.ones(1, dtype=np.int64)} + transformed, _target = transform(input_t, target) + assert isinstance(transformed, torch.Tensor) + assert transformed.shape == input_t.shape + assert transformed.dtype == input_t.dtype + # integrity check of targets + assert isinstance(_target, dict) + assert all(isinstance(val, np.ndarray) for val in _target.values()) + assert _target["boxes"].dtype == np.float32 + assert _target["labels"].dtype == np.int64 + if p == 1: + assert np.all(_target["boxes"] == np.array([[0.7, 0.1, 0.9, 0.4]], dtype=np.float32)) + assert torch.all(transformed.mean((0, 1)) == torch.tensor([1] * 16 + [0] * 16, dtype=torch.float32)) + elif p == 0: + assert np.all(_target["boxes"] == np.array([[0.1, 0.1, 0.3, 0.4]], dtype=np.float32)) + assert torch.all(transformed.mean((0, 1)) == torch.tensor([0] * 16 + [1] * 16, dtype=torch.float32)) + assert np.all(_target["labels"] == np.ones(1, dtype=np.int64)) diff --git a/tests/requirements.txt b/tests/requirements.txt new file mode 100644 index 0000000000..99025de2e0 --- /dev/null +++ b/tests/requirements.txt @@ -0,0 +1,5 @@ +pytest>=5.3.2 +requests>=2.20.0 +hdf5storage>=0.1.18 +coverage>=4.5.4 +requirements-parser==0.2.0 diff --git a/tests/tensorflow/test_datasets_loader_tf.py b/tests/tensorflow/test_datasets_loader_tf.py new file mode 100644 index 0000000000..5e5e86ab04 --- /dev/null +++ b/tests/tensorflow/test_datasets_loader_tf.py @@ -0,0 +1,79 @@ +from typing import List, Tuple + +import tensorflow as tf + +from doctr.datasets import DataLoader + + +class MockDataset: + def __init__(self, input_size): + + self.data: List[Tuple[float, bool]] = [ + (1, True), + (0, False), + (0.5, True), + ] + self.input_size = input_size + + def __len__(self): + return len(self.data) + + def __getitem__(self, index): + val, label = self.data[index] + return tf.cast(tf.fill(self.input_size, val), dtype=tf.float32), tf.constant(label, dtype=tf.bool) + + +class MockDatasetBis(MockDataset): + + @staticmethod + def collate_fn(samples): + x, y = zip(*samples) + return tf.stack(x, axis=0), list(y) + + +def test_dataloader(): + + loader = DataLoader( + MockDataset((32, 32)), + shuffle=True, + batch_size=2, + drop_last=True, + ) + + ds_iter = iter(loader) + num_batches = 0 + for x, y in ds_iter: + num_batches += 1 + assert len(loader) == 1 + assert num_batches == 1 + assert isinstance(x, tf.Tensor) and isinstance(y, tf.Tensor) + assert x.shape == (2, 32, 32) + assert y.shape == (2,) + + # Drop last + loader = DataLoader( + MockDataset((32, 32)), + shuffle=True, + batch_size=2, + drop_last=False, + ) + ds_iter = iter(loader) + num_batches = 0 + for x, y in ds_iter: + num_batches += 1 + assert loader.num_batches == 2 + assert num_batches == 2 + + # Custom collate + loader = DataLoader( + MockDatasetBis((32, 32)), + shuffle=True, + batch_size=2, + drop_last=False, + ) + + ds_iter = iter(loader) + x, y = next(ds_iter) + assert isinstance(x, tf.Tensor) and isinstance(y, list) + assert x.shape == (2, 32, 32) + assert len(y) == 2 diff --git a/tests/tensorflow/test_datasets_tf.py b/tests/tensorflow/test_datasets_tf.py new file mode 100644 index 0000000000..2933baa3b2 --- /dev/null +++ b/tests/tensorflow/test_datasets_tf.py @@ -0,0 +1,424 @@ +import os +from shutil import move + +import numpy as np +import pytest +import tensorflow as tf + +from doctr import datasets +from 
doctr.datasets import DataLoader +from doctr.transforms import Resize + + +def _validate_dataset(ds, input_size, batch_size=2, class_indices=False, is_polygons=False): + + # Fetch one sample + img, target = ds[0] + assert isinstance(img, tf.Tensor) + assert img.shape == (*input_size, 3) + assert img.dtype == tf.float32 + assert isinstance(target, dict) + assert isinstance(target['boxes'], np.ndarray) and target['boxes'].dtype == np.float32 + if is_polygons: + assert target['boxes'].ndim == 3 and target['boxes'].shape[1:] == (4, 2) + else: + assert target['boxes'].ndim == 2 and target['boxes'].shape[1:] == (4,) + assert np.all(np.logical_and(target['boxes'] <= 1, target['boxes'] >= 0)) + if class_indices: + assert isinstance(target['labels'], np.ndarray) and target['labels'].dtype == np.int64 + else: + assert isinstance(target['labels'], list) and all(isinstance(s, str) for s in target['labels']) + assert len(target['labels']) == len(target['boxes']) + + # Check batching + loader = DataLoader(ds, batch_size=batch_size) + + images, targets = next(iter(loader)) + assert isinstance(images, tf.Tensor) and images.shape == (batch_size, *input_size, 3) + assert isinstance(targets, list) and all(isinstance(elt, dict) for elt in targets) + + +def test_detection_dataset(mock_image_folder, mock_detection_label): + + input_size = (1024, 1024) + + ds = datasets.DetectionDataset( + img_folder=mock_image_folder, + label_path=mock_detection_label, + img_transforms=Resize(input_size), + ) + + assert len(ds) == 5 + img, target = ds[0] + assert isinstance(img, tf.Tensor) + assert img.shape[:2] == input_size + assert img.dtype == tf.float32 + # Bounding boxes + assert isinstance(target, np.ndarray) and target.dtype == np.float32 + assert np.all(np.logical_and(target[:, :4] >= 0, target[:, :4] <= 1)) + assert target.shape[1] == 4 + + loader = DataLoader(ds, batch_size=2) + images, targets = next(iter(loader)) + assert isinstance(images, tf.Tensor) and images.shape == (2, *input_size, 3) + assert isinstance(targets, list) and all(isinstance(elt, np.ndarray) for elt in targets) + + # Rotated DS + rotated_ds = datasets.DetectionDataset( + img_folder=mock_image_folder, + label_path=mock_detection_label, + img_transforms=Resize(input_size), + use_polygons=True + ) + _, r_target = rotated_ds[0] + assert r_target.shape[1:] == (4, 2) + + # File existence check + img_name, _ = ds.data[0] + move(os.path.join(ds.root, img_name), os.path.join(ds.root, "tmp_file")) + with pytest.raises(FileNotFoundError): + datasets.DetectionDataset(mock_image_folder, mock_detection_label) + move(os.path.join(ds.root, "tmp_file"), os.path.join(ds.root, img_name)) + + +def test_recognition_dataset(mock_image_folder, mock_recognition_label): + input_size = (32, 128) + ds = datasets.RecognitionDataset( + img_folder=mock_image_folder, + labels_path=mock_recognition_label, + img_transforms=Resize(input_size, preserve_aspect_ratio=True), + ) + assert len(ds) == 5 + image, label = ds[0] + assert isinstance(image, tf.Tensor) + assert image.shape[:2] == input_size + assert image.dtype == tf.float32 + assert isinstance(label, str) + + loader = DataLoader(ds, batch_size=2) + images, labels = next(iter(loader)) + assert isinstance(images, tf.Tensor) and images.shape == (2, *input_size, 3) + assert isinstance(labels, list) and all(isinstance(elt, str) for elt in labels) + + # File existence check + img_name, _ = ds.data[0] + move(os.path.join(ds.root, img_name), os.path.join(ds.root, "tmp_file")) + with pytest.raises(FileNotFoundError): + 
datasets.RecognitionDataset(mock_image_folder, mock_recognition_label) + move(os.path.join(ds.root, "tmp_file"), os.path.join(ds.root, img_name)) + + +@pytest.mark.parametrize( + "use_polygons", [False, True], +) +def test_ocrdataset(mock_ocrdataset, use_polygons): + + input_size = (512, 512) + + ds = datasets.OCRDataset( + *mock_ocrdataset, + img_transforms=Resize(input_size), + use_polygons=use_polygons, + ) + assert len(ds) == 3 + _validate_dataset(ds, input_size, is_polygons=use_polygons) + + # File existence check + img_name, _ = ds.data[0] + move(os.path.join(ds.root, img_name), os.path.join(ds.root, "tmp_file")) + with pytest.raises(FileNotFoundError): + datasets.OCRDataset(*mock_ocrdataset) + move(os.path.join(ds.root, "tmp_file"), os.path.join(ds.root, img_name)) + + +def test_charactergenerator(): + + input_size = (32, 32) + vocab = 'abcdef' + + ds = datasets.CharacterGenerator( + vocab=vocab, + num_samples=10, + cache_samples=True, + img_transforms=Resize(input_size), + ) + + assert len(ds) == 10 + image, label = ds[0] + assert isinstance(image, tf.Tensor) + assert image.shape[:2] == input_size + assert image.dtype == tf.float32 + assert isinstance(label, int) and label < len(vocab) + + loader = DataLoader(ds, batch_size=2, collate_fn=ds.collate_fn) + images, targets = next(iter(loader)) + assert isinstance(images, tf.Tensor) and images.shape == (2, *input_size, 3) + assert isinstance(targets, tf.Tensor) and targets.shape == (2,) + assert targets.dtype == tf.int32 + + +def test_wordgenerator(): + + input_size = (32, 128) + wordlen_range = (1, 10) + vocab = 'abcdef' + + ds = datasets.WordGenerator( + vocab=vocab, + min_chars=wordlen_range[0], + max_chars=wordlen_range[1], + num_samples=10, + cache_samples=True, + img_transforms=Resize(input_size), + ) + + assert len(ds) == 10 + image, target = ds[0] + assert isinstance(image, tf.Tensor) + assert image.shape[:2] == input_size + assert image.dtype == tf.float32 + assert isinstance(target, str) and len(target) >= wordlen_range[0] and len(target) <= wordlen_range[1] + assert all(char in vocab for char in target) + + loader = DataLoader(ds, batch_size=2, collate_fn=ds.collate_fn) + images, targets = next(iter(loader)) + assert isinstance(images, tf.Tensor) and images.shape == (2, *input_size, 3) + assert isinstance(targets, list) and len(targets) == 2 and all(isinstance(t, str) for t in targets) + + +@pytest.mark.parametrize( + "num_samples, rotate", + [ + [5, True], # Actual set has 229 train and 233 test samples + [5, False] + + ], +) +def test_ic13_dataset(mock_ic13, num_samples, rotate): + input_size = (512, 512) + ds = datasets.IC13( + *mock_ic13, + img_transforms=Resize(input_size), + use_polygons=rotate, + ) + + assert len(ds) == num_samples + _validate_dataset(ds, input_size, is_polygons=rotate) + + +@pytest.mark.parametrize( + "num_samples, rotate", + [ + [3, True], # Actual set has 7149 train and 796 test samples + [3, False] + + ], +) +def test_imgur5k_dataset(num_samples, rotate, mock_imgur5k): + input_size = (512, 512) + ds = datasets.IMGUR5K( + *mock_imgur5k, + train=True, + img_transforms=Resize(input_size), + use_polygons=rotate, + ) + + assert len(ds) == num_samples - 1 # -1 because of the test set 90 / 10 split + assert repr(ds) == f"IMGUR5K(train={True})" + _validate_dataset(ds, input_size, is_polygons=rotate) + + +@pytest.mark.parametrize( + "input_size, num_samples, rotate", + [ + [[32, 128], 3, True], # Actual set has 33402 training samples and 13068 test samples + [[32, 128], 3, False], + ], +) +def 
test_svhn(input_size, num_samples, rotate, mock_svhn_dataset): + # monkeypatch the path to temporary dataset + datasets.SVHN.TRAIN = (mock_svhn_dataset, None, "svhn_train.tar") + + ds = datasets.SVHN( + train=True, download=True, img_transforms=Resize(input_size), use_polygons=rotate, + cache_dir="/".join(mock_svhn_dataset.split("/")[:-2]), cache_subdir=mock_svhn_dataset.split("/")[-2], + ) + + assert len(ds) == num_samples + assert repr(ds) == f"SVHN(train={True})" + _validate_dataset(ds, input_size, is_polygons=rotate) + + +@pytest.mark.parametrize( + "input_size, num_samples, rotate", + [ + [[512, 512], 3, True], # Actual set has 626 training samples and 360 test samples + [[512, 512], 3, False], + ], +) +def test_sroie(input_size, num_samples, rotate, mock_sroie_dataset): + # monkeypatch the path to temporary dataset + datasets.SROIE.TRAIN = (mock_sroie_dataset, None) + + ds = datasets.SROIE( + train=True, download=True, img_transforms=Resize(input_size), use_polygons=rotate, + cache_dir="/".join(mock_sroie_dataset.split("/")[:-2]), cache_subdir=mock_sroie_dataset.split("/")[-2], + ) + + assert len(ds) == num_samples + assert repr(ds) == f"SROIE(train={True})" + _validate_dataset(ds, input_size, is_polygons=rotate) + + +@pytest.mark.parametrize( + "input_size, num_samples, rotate", + [ + [[512, 512], 3, True], # Actual set has 149 training samples and 50 test samples + [[512, 512], 3, False], + ], +) +def test_funsd(input_size, num_samples, rotate, mock_funsd_dataset): + # monkeypatch the path to temporary dataset + datasets.FUNSD.URL = mock_funsd_dataset + datasets.FUNSD.SHA256 = None + datasets.FUNSD.FILE_NAME = "funsd.zip" + + ds = datasets.FUNSD( + train=True, download=True, img_transforms=Resize(input_size), use_polygons=rotate, + cache_dir="/".join(mock_funsd_dataset.split("/")[:-2]), cache_subdir=mock_funsd_dataset.split("/")[-2], + ) + + assert len(ds) == num_samples + assert repr(ds) == f"FUNSD(train={True})" + _validate_dataset(ds, input_size, is_polygons=rotate) + + +@pytest.mark.parametrize( + "input_size, num_samples, rotate", + [ + [[512, 512], 3, True], # Actual set has 800 training samples and 100 test samples + [[512, 512], 3, False], + ], +) +def test_cord(input_size, num_samples, rotate, mock_cord_dataset): + # monkeypatch the path to temporary dataset + datasets.CORD.TRAIN = (mock_cord_dataset, None) + + ds = datasets.CORD( + train=True, download=True, img_transforms=Resize(input_size), use_polygons=rotate, + cache_dir="/".join(mock_cord_dataset.split("/")[:-2]), cache_subdir=mock_cord_dataset.split("/")[-2], + ) + + assert len(ds) == num_samples + assert repr(ds) == f"CORD(train={True})" + _validate_dataset(ds, input_size, is_polygons=rotate) + + +@pytest.mark.parametrize( + "input_size, num_samples, rotate", + [ + [[512, 512], 2, True], # Actual set has 772875 training samples and 85875 test samples + [[512, 512], 2, False], + ], +) +def test_synthtext(input_size, num_samples, rotate, mock_synthtext_dataset): + # monkeypatch the path to temporary dataset + datasets.SynthText.URL = mock_synthtext_dataset + datasets.SynthText.SHA256 = None + + ds = datasets.SynthText( + train=True, download=True, img_transforms=Resize(input_size), use_polygons=rotate, + cache_dir="/".join(mock_synthtext_dataset.split("/")[:-2]), cache_subdir=mock_synthtext_dataset.split("/")[-2], + ) + + assert len(ds) == num_samples + assert repr(ds) == f"SynthText(train={True})" + _validate_dataset(ds, input_size, is_polygons=rotate) + + +@pytest.mark.parametrize( + "input_size, num_samples, 
rotate", + [ + [[512, 512], 3, True], # Actual set has 2700 training samples and 300 test samples + [[512, 512], 3, False], + ], +) +def test_artefact_detection(input_size, num_samples, rotate, mock_doc_artefacts): + # monkeypatch the path to temporary dataset + datasets.DocArtefacts.URL = mock_doc_artefacts + datasets.DocArtefacts.SHA256 = None + + ds = datasets.DocArtefacts( + train=True, download=True, img_transforms=Resize(input_size), use_polygons=rotate, + cache_dir="/".join(mock_doc_artefacts.split("/")[:-2]), cache_subdir=mock_doc_artefacts.split("/")[-2], + ) + + assert len(ds) == num_samples + assert repr(ds) == f"DocArtefacts(train={True})" + _validate_dataset(ds, input_size, class_indices=True, is_polygons=rotate) + + +@pytest.mark.parametrize( + "input_size, num_samples, rotate", + [ + [[32, 128], 1, True], # Actual set has 2000 training samples and 3000 test samples + [[32, 128], 1, False], + ], +) +def test_iiit5k(input_size, num_samples, rotate, mock_iiit5k_dataset): + # monkeypatch the path to temporary dataset + datasets.IIIT5K.URL = mock_iiit5k_dataset + datasets.IIIT5K.SHA256 = None + + ds = datasets.IIIT5K( + train=True, download=True, img_transforms=Resize(input_size), use_polygons=rotate, + cache_dir="/".join(mock_iiit5k_dataset.split("/")[:-2]), cache_subdir=mock_iiit5k_dataset.split("/")[-2], + ) + + assert len(ds) == num_samples + assert repr(ds) == f"IIIT5K(train={True})" + img, target = ds[0] + _validate_dataset(ds, input_size, batch_size=1, is_polygons=rotate) + + +@pytest.mark.parametrize( + "input_size, num_samples, rotate", + [ + [[512, 512], 3, True], # Actual set has 100 training samples and 249 test samples + [[512, 512], 3, False], + ], +) +def test_svt(input_size, num_samples, rotate, mock_svt_dataset): + # monkeypatch the path to temporary dataset + datasets.SVT.URL = mock_svt_dataset + datasets.SVT.SHA256 = None + + ds = datasets.SVT( + train=True, download=True, img_transforms=Resize(input_size), use_polygons=rotate, + cache_dir="/".join(mock_svt_dataset.split("/")[:-2]), cache_subdir=mock_svt_dataset.split("/")[-2], + ) + + assert len(ds) == num_samples + assert repr(ds) == f"SVT(train={True})" + _validate_dataset(ds, input_size, is_polygons=rotate) + + +@pytest.mark.parametrize( + "input_size, num_samples, rotate", + [ + [[512, 512], 3, True], # Actual set has 246 training samples and 249 test samples + [[512, 512], 3, False], + ], +) +def test_ic03(input_size, num_samples, rotate, mock_ic03_dataset): + # monkeypatch the path to temporary dataset + datasets.IC03.TRAIN = (mock_ic03_dataset, None, "ic03_train.zip") + + ds = datasets.IC03( + train=True, download=True, img_transforms=Resize(input_size), use_polygons=rotate, + cache_dir="/".join(mock_ic03_dataset.split("/")[:-2]), cache_subdir=mock_ic03_dataset.split("/")[-2], + ) + + assert len(ds) == num_samples + assert repr(ds) == f"IC03(train={True})" + _validate_dataset(ds, input_size, is_polygons=rotate) diff --git a/tests/tensorflow/test_file_utils_tf.py b/tests/tensorflow/test_file_utils_tf.py new file mode 100644 index 0000000000..a28709de4b --- /dev/null +++ b/tests/tensorflow/test_file_utils_tf.py @@ -0,0 +1,5 @@ +from doctr.file_utils import is_tf_available + + +def test_file_utils(): + assert is_tf_available() diff --git a/tests/tensorflow/test_io_image_tf.py b/tests/tensorflow/test_io_image_tf.py new file mode 100644 index 0000000000..d69caa07c5 --- /dev/null +++ b/tests/tensorflow/test_io_image_tf.py @@ -0,0 +1,50 @@ +import numpy as np +import pytest +import tensorflow as tf + +from 
doctr.io import decode_img_as_tensor, read_img_as_tensor, tensor_from_numpy + + +def test_read_img_as_tensor(mock_image_path): + + img = read_img_as_tensor(mock_image_path) + + assert isinstance(img, tf.Tensor) + assert img.dtype == tf.float32 + assert img.shape == (900, 1200, 3) + + img = read_img_as_tensor(mock_image_path, dtype=tf.float16) + assert img.dtype == tf.float16 + img = read_img_as_tensor(mock_image_path, dtype=tf.uint8) + assert img.dtype == tf.uint8 + + +def test_decode_img_as_tensor(mock_image_stream): + + img = decode_img_as_tensor(mock_image_stream) + + assert isinstance(img, tf.Tensor) + assert img.dtype == tf.float32 + assert img.shape == (900, 1200, 3) + + img = decode_img_as_tensor(mock_image_stream, dtype=tf.float16) + assert img.dtype == tf.float16 + img = decode_img_as_tensor(mock_image_stream, dtype=tf.uint8) + assert img.dtype == tf.uint8 + + +def test_tensor_from_numpy(mock_image_stream): + + with pytest.raises(ValueError): + tensor_from_numpy(np.zeros((256, 256, 3)), tf.int64) + + out = tensor_from_numpy(np.zeros((256, 256, 3), dtype=np.uint8)) + + assert isinstance(out, tf.Tensor) + assert out.dtype == tf.float32 + assert out.shape == (256, 256, 3) + + out = tensor_from_numpy(np.zeros((256, 256, 3), dtype=np.uint8), dtype=tf.float16) + assert out.dtype == tf.float16 + out = tensor_from_numpy(np.zeros((256, 256, 3), dtype=np.uint8), dtype=tf.uint8) + assert out.dtype == tf.uint8 diff --git a/tests/tensorflow/test_models_classification_tf.py b/tests/tensorflow/test_models_classification_tf.py new file mode 100644 index 0000000000..1ea19e2c79 --- /dev/null +++ b/tests/tensorflow/test_models_classification_tf.py @@ -0,0 +1,75 @@ +import cv2 +import numpy as np +import pytest +import tensorflow as tf + +from doctr.models import classification +from doctr.models.classification.predictor import CropOrientationPredictor + + +@pytest.mark.parametrize( + "arch_name, input_shape, output_size", + [ + ["vgg16_bn_r", (32, 32, 3), (126,)], + ["resnet18", (32, 32, 3), (126,)], + ["resnet31", (32, 32, 3), (126,)], + ["magc_resnet31", (32, 32, 3), (126,)], + ["mobilenet_v3_small", (32, 32, 3), (126,)], + ["mobilenet_v3_large", (32, 32, 3), (126,)], + ], +) +def test_classification_architectures(arch_name, input_shape, output_size): + # Model + batch_size = 2 + tf.keras.backend.clear_session() + model = classification.__dict__[arch_name](pretrained=True, include_top=True, input_shape=input_shape) + # Forward + out = model(tf.random.uniform(shape=[batch_size, *input_shape], maxval=1, dtype=tf.float32)) + # Output checks + assert isinstance(out, tf.Tensor) + assert out.dtype == tf.float32 + assert out.numpy().shape == (batch_size, *output_size) + + +@pytest.mark.parametrize( + "arch_name, input_shape", + [ + ["mobilenet_v3_small_orientation", (128, 128, 3)], + ], +) +def test_classification_models(arch_name, input_shape): + batch_size = 8 + reco_model = classification.__dict__[arch_name](pretrained=True, input_shape=input_shape) + assert isinstance(reco_model, tf.keras.Model) + input_tensor = tf.random.uniform(shape=[batch_size, *input_shape], minval=0, maxval=1) + + out = reco_model(input_tensor) + assert isinstance(out, tf.Tensor) + assert out.shape.as_list() == [8, 4] + + +@pytest.mark.parametrize( + "arch_name", + [ + "mobilenet_v3_small_orientation", + ], +) +def test_classification_zoo(arch_name): + batch_size = 16 + # Model + predictor = classification.zoo.crop_orientation_predictor(arch_name, pretrained=False) + # object check + assert isinstance(predictor, 
CropOrientationPredictor) + input_tensor = tf.random.uniform(shape=[batch_size, 128, 128, 3], minval=0, maxval=1) + out = predictor(input_tensor) + assert isinstance(out, list) and len(out) == batch_size + assert all(isinstance(pred, int) for pred in out) + + +def test_crop_orientation_model(mock_text_box): + text_box_0 = cv2.imread(mock_text_box) + text_box_90 = np.rot90(text_box_0, 1) + text_box_180 = np.rot90(text_box_0, 2) + text_box_270 = np.rot90(text_box_0, 3) + classifier = classification.crop_orientation_predictor("mobilenet_v3_small_orientation", pretrained=True) + assert classifier([text_box_0, text_box_90, text_box_180, text_box_270]) == [0, 1, 2, 3] diff --git a/tests/tensorflow/test_models_detection_tf.py b/tests/tensorflow/test_models_detection_tf.py new file mode 100644 index 0000000000..549e987226 --- /dev/null +++ b/tests/tensorflow/test_models_detection_tf.py @@ -0,0 +1,156 @@ +import numpy as np +import pytest +import tensorflow as tf + +from doctr.io import DocumentFile +from doctr.models import detection +from doctr.models.detection._utils import dilate, erode +from doctr.models.detection.predictor import DetectionPredictor +from doctr.models.preprocessor import PreProcessor + + +@pytest.mark.parametrize( + "arch_name, input_shape, output_size, out_prob", + [ + ["db_resnet50", (512, 512, 3), (512, 512, 1), True], + ["db_mobilenet_v3_large", (512, 512, 3), (512, 512, 1), True], + ["linknet_resnet18", (512, 512, 3), (512, 512, 1), False], + ], +) +def test_detection_models(arch_name, input_shape, output_size, out_prob): + batch_size = 2 + tf.keras.backend.clear_session() + model = detection.__dict__[arch_name](pretrained=True, input_shape=input_shape) + assert isinstance(model, tf.keras.Model) + input_tensor = tf.random.uniform(shape=[batch_size, *input_shape], minval=0, maxval=1) + target = [ + np.array([[.5, .5, 1, 1], [0.5, 0.5, .8, .8]], dtype=np.float32), + np.array([[.5, .5, 1, 1], [0.5, 0.5, .8, .9]], dtype=np.float32), + ] + # test training model + out = model(input_tensor, target, return_model_output=True, return_preds=True, training=True) + assert isinstance(out, dict) + assert len(out) == 3 + # Check proba map + assert isinstance(out['out_map'], tf.Tensor) + assert out['out_map'].dtype == tf.float32 + seg_map = out['out_map'].numpy() + assert seg_map.shape == (batch_size, *output_size) + if out_prob: + assert np.all(np.logical_and(seg_map >= 0, seg_map <= 1)) + # Check boxes + for boxes in out['preds']: + assert boxes.shape[1] == 5 + assert np.all(boxes[:, :2] < boxes[:, 2:4]) + assert np.all(boxes[:, :4] >= 0) and np.all(boxes[:, :4] <= 1) + # Check loss + assert isinstance(out['loss'], tf.Tensor) + # Target checks + target = [ + np.array([[0, 0, 1, 1]], dtype=np.uint8), + np.array([[0, 0, 1, 1]], dtype=np.uint8), + ] + with pytest.raises(AssertionError): + out = model(input_tensor, target, training=True) + + target = [ + np.array([[0, 0, 1.5, 1.5]], dtype=np.float32), + np.array([[-.2, -.3, 1, 1]], dtype=np.float32), + ] + with pytest.raises(ValueError): + out = model(input_tensor, target, training=True) + + # Check the rotated case + target = [ + np.array([[.75, .75, .5, .5, 0], [.65, .65, .3, .3, 0]], dtype=np.float32), + np.array([[.75, .75, .5, .5, 0], [.65, .7, .3, .4, 0]], dtype=np.float32), + ] + loss = model(input_tensor, target, training=True)['loss'] + assert isinstance(loss, tf.Tensor) and ((loss - out['loss']) / loss).numpy() < 21e-2 + + +@pytest.fixture(scope="session") +def test_detectionpredictor(mock_pdf): # noqa: F811 + + batch_size = 4 + 
predictor = DetectionPredictor( + PreProcessor(output_size=(512, 512), batch_size=batch_size), + detection.db_resnet50(input_shape=(512, 512, 3)) + ) + + pages = DocumentFile.from_pdf(mock_pdf).as_images() + out = predictor(pages) + # The input PDF has 2 pages + assert len(out) == 2 + + # Dimension check + with pytest.raises(ValueError): + input_page = (255 * np.random.rand(1, 256, 512, 3)).astype(np.uint8) + _ = predictor([input_page]) + + return predictor + + +@pytest.fixture(scope="session") +def test_rotated_detectionpredictor(mock_pdf): # noqa: F811 + + batch_size = 4 + predictor = DetectionPredictor( + PreProcessor(output_size=(512, 512), batch_size=batch_size), + detection.db_resnet50(assume_straight_pages=False, input_shape=(512, 512, 3)) + ) + + pages = DocumentFile.from_pdf(mock_pdf).as_images() + out = predictor(pages) + + # The input PDF has 2 pages + assert len(out) == 2 + + # Dimension check + with pytest.raises(ValueError): + input_page = (255 * np.random.rand(1, 256, 512, 3)).astype(np.uint8) + _ = predictor([input_page]) + + return predictor + + +@pytest.mark.parametrize( + "arch_name", + [ + "db_resnet50", + "db_mobilenet_v3_large", + "linknet_resnet18", + ], +) +def test_detection_zoo(arch_name): + # Model + tf.keras.backend.clear_session() + predictor = detection.zoo.detection_predictor(arch_name, pretrained=False) + # object check + assert isinstance(predictor, DetectionPredictor) + input_tensor = tf.random.uniform(shape=[2, 1024, 1024, 3], minval=0, maxval=1) + out = predictor(input_tensor) + assert all(isinstance(boxes, np.ndarray) and boxes.shape[1] == 5 for boxes in out) + + +def test_detection_zoo_error(): + with pytest.raises(ValueError): + _ = detection.zoo.detection_predictor("my_fancy_model", pretrained=False) + + +def test_erode(): + x = np.zeros((1, 3, 3, 1), dtype=np.float32) + x[:, 1, 1] = 1 + x = tf.convert_to_tensor(x) + expected = tf.zeros((1, 3, 3, 1)) + out = erode(x, 3) + assert tf.math.reduce_all(out == expected) + + +def test_dilate(): + x = np.zeros((1, 3, 3, 1), dtype=np.float32) + x[:, 1, 1] = 1 + x = tf.convert_to_tensor(x) + expected = tf.ones((1, 3, 3, 1)) + out = dilate(x, 3) + assert tf.math.reduce_all(out == expected) diff --git a/tests/tensorflow/test_models_preprocessor_tf.py b/tests/tensorflow/test_models_preprocessor_tf.py new file mode 100644 index 0000000000..efba8084f3 --- /dev/null +++ b/tests/tensorflow/test_models_preprocessor_tf.py @@ -0,0 +1,45 @@ +import numpy as np +import pytest +import tensorflow as tf + +from doctr.models.preprocessor import PreProcessor + + +@pytest.mark.parametrize( + "batch_size, output_size, input_tensor, expected_batches, expected_value", + [ + [2, (128, 128), np.full((3, 256, 128, 3), 255, dtype=np.uint8), 1, .5], # numpy uint8 + [2, (128, 128), np.ones((3, 256, 128, 3), dtype=np.float32), 1, .5], # numpy fp32 + [2, (128, 128), tf.cast(tf.fill((3, 256, 128, 3), 255), dtype=tf.uint8), 1, .5], # tf uint8 + [2, (128, 128), tf.ones((3, 128, 128, 3), dtype=tf.float32), 1, .5], # tf fp32 + [2, (128, 128), [np.full((256, 128, 3), 255, dtype=np.uint8)] * 3, 2, .5], # list of numpy uint8 + [2, (128, 128), [np.ones((256, 128, 3), dtype=np.float32)] * 3, 2, .5], # list of numpy fp32 + [2, (128, 128), [tf.cast(tf.fill((256, 128, 3), 255), dtype=tf.uint8)] * 3, 2, .5], # list of tf uint8 + [2, (128, 128), [tf.ones((128, 128, 3), dtype=tf.float32)] * 3, 2, .5], # list of tf fp32 + ], +) +def test_preprocessor(batch_size, output_size, input_tensor, expected_batches, expected_value): + + processor = 
PreProcessor(output_size, batch_size) + + # Invalid input type + with pytest.raises(TypeError): + processor(42) + # 4D check + with pytest.raises(AssertionError): + processor(np.full((256, 128, 3), 255, dtype=np.uint8)) + with pytest.raises(TypeError): + processor(np.full((1, 256, 128, 3), 255, dtype=np.int32)) + # 3D check + with pytest.raises(AssertionError): + processor([np.full((3, 256, 128, 3), 255, dtype=np.uint8)]) + with pytest.raises(TypeError): + processor([np.full((256, 128, 3), 255, dtype=np.int32)]) + + out = processor(input_tensor) + assert isinstance(out, list) and len(out) == expected_batches + assert all(isinstance(b, tf.Tensor) for b in out) + assert all(b.dtype == tf.float32 for b in out) + assert all(b.shape[1:3] == output_size for b in out) + assert all(tf.math.reduce_all(b == expected_value) for b in out) + assert len(repr(processor).split('\n')) == 4 diff --git a/tests/tensorflow/test_models_recognition_tf.py b/tests/tensorflow/test_models_recognition_tf.py new file mode 100644 index 0000000000..06b1f613ae --- /dev/null +++ b/tests/tensorflow/test_models_recognition_tf.py @@ -0,0 +1,114 @@ +import numpy as np +import pytest +import tensorflow as tf + +from doctr.io import DocumentFile +from doctr.models import recognition +from doctr.models._utils import extract_crops +from doctr.models.preprocessor import PreProcessor +from doctr.models.recognition.crnn.tensorflow import CTCPostProcessor +from doctr.models.recognition.master.tensorflow import MASTERPostProcessor +from doctr.models.recognition.predictor import RecognitionPredictor +from doctr.models.recognition.sar.tensorflow import SARPostProcessor + + +@pytest.mark.parametrize( + "arch_name, input_shape", + [ + ["crnn_vgg16_bn", (32, 128, 3)], + ["crnn_mobilenet_v3_small", (32, 128, 3)], + ["crnn_mobilenet_v3_large", (32, 128, 3)], + ["sar_resnet31", (32, 128, 3)], + ["master", (32, 128, 3)], + ], +) +def test_recognition_models(arch_name, input_shape): + batch_size = 4 + reco_model = recognition.__dict__[arch_name](pretrained=True, input_shape=input_shape) + assert isinstance(reco_model, tf.keras.Model) + input_tensor = tf.random.uniform(shape=[batch_size, *input_shape], minval=0, maxval=1) + target = ["i", "am", "a", "jedi"] + + out = reco_model(input_tensor, target, return_model_output=True, return_preds=True) + assert isinstance(out, dict) + assert len(out) == 3 + assert isinstance(out['out_map'], tf.Tensor) + assert out['out_map'].dtype == tf.float32 + assert isinstance(out['preds'], list) + assert len(out['preds']) == batch_size + assert all(isinstance(word, str) and isinstance(conf, float) and 0 <= conf <= 1 for word, conf in out['preds']) + assert isinstance(out['loss'], tf.Tensor) + + +@pytest.mark.parametrize( + "post_processor, input_shape", + [ + [SARPostProcessor, [2, 30, 119]], + [CTCPostProcessor, [2, 30, 119]], + [MASTERPostProcessor, [2, 30, 119]], + ], +) +def test_reco_postprocessors(post_processor, input_shape, mock_vocab): + processor = post_processor(mock_vocab) + decoded = processor(tf.random.uniform(shape=input_shape, minval=0, maxval=1, dtype=tf.float32)) + assert isinstance(decoded, list) + assert all(isinstance(word, str) and isinstance(conf, float) and 0 <= conf <= 1 for word, conf in decoded) + assert len(decoded) == input_shape[0] + assert all(char in mock_vocab for word, _ in decoded for char in word) + # Repr + assert repr(processor) == f'{post_processor.__name__}(vocab_size={len(mock_vocab)})' + + +@pytest.fixture(scope="session") +def test_recognitionpredictor(mock_pdf, 
mock_vocab): # noqa: F811 + + batch_size = 4 + predictor = RecognitionPredictor( + PreProcessor(output_size=(32, 128), batch_size=batch_size, preserve_aspect_ratio=True), + recognition.crnn_vgg16_bn(vocab=mock_vocab, input_shape=(32, 128, 3)) + ) + + pages = DocumentFile.from_pdf(mock_pdf).as_images() + # Create bounding boxes + boxes = np.array([[.5, .5, 0.75, 0.75], [0.5, 0.5, 1., 1.]], dtype=np.float32) + crops = extract_crops(pages[0], boxes) + + out = predictor(crops) + + # One prediction per crop + assert len(out) == boxes.shape[0] + assert all(isinstance(val, str) and isinstance(conf, float) for val, conf in out) + + # Dimension check + with pytest.raises(ValueError): + input_crop = (255 * np.random.rand(1, 128, 64, 3)).astype(np.uint8) + _ = predictor([input_crop]) + + return predictor + + +@pytest.mark.parametrize( + "arch_name", + [ + "crnn_vgg16_bn", + "crnn_mobilenet_v3_small", + "crnn_mobilenet_v3_large", + "sar_resnet31", + "master" + ], +) +def test_recognition_zoo(arch_name): + batch_size = 2 + # Model + predictor = recognition.zoo.recognition_predictor(arch_name, pretrained=False) + # object check + assert isinstance(predictor, RecognitionPredictor) + input_tensor = tf.random.uniform(shape=[batch_size, 128, 128, 3], minval=0, maxval=1) + out = predictor(input_tensor) + assert isinstance(out, list) and len(out) == batch_size + assert all(isinstance(word, str) and isinstance(conf, float) for word, conf in out) + + +def test_recognition_zoo_error(): + with pytest.raises(ValueError): + _ = recognition.zoo.recognition_predictor("my_fancy_model", pretrained=False) diff --git a/tests/tensorflow/test_models_utils_tf.py b/tests/tensorflow/test_models_utils_tf.py new file mode 100644 index 0000000000..b72001daa0 --- /dev/null +++ b/tests/tensorflow/test_models_utils_tf.py @@ -0,0 +1,45 @@ +import os + +import pytest +import tensorflow as tf +from tensorflow.keras import Sequential, layers +from tensorflow.keras.applications import ResNet50 + +from doctr.models.utils import IntermediateLayerGetter, conv_sequence, load_pretrained_params + + +def test_load_pretrained_params(tmpdir_factory): + + model = Sequential([layers.Dense(8, activation='relu', input_shape=(4,)), layers.Dense(4)]) + # Retrieve this URL + url = "https://github.com/mindee/doctr/releases/download/v0.1-models/tmp_checkpoint-4a98e492.zip" + # Temp cache dir + cache_dir = tmpdir_factory.mktemp("cache") + # Pass an incorrect hash + with pytest.raises(ValueError): + load_pretrained_params(model, url, "mywronghash", cache_dir=str(cache_dir), internal_name='') + # Let it resolve the hash from the file name + load_pretrained_params(model, url, cache_dir=str(cache_dir), internal_name='') + # Check that the file was downloaded & the archive extracted + assert os.path.exists(cache_dir.join('models').join("tmp_checkpoint-4a98e492")) + # Check that the archive is still present in the cache dir + assert os.path.exists(cache_dir.join('models').join("tmp_checkpoint-4a98e492.zip")) + + +def test_conv_sequence(): + + assert len(conv_sequence(8, kernel_size=3)) == 1 + assert len(conv_sequence(8, 'relu', kernel_size=3)) == 1 + assert len(conv_sequence(8, None, True, kernel_size=3)) == 2 + assert len(conv_sequence(8, 'relu', True, kernel_size=3)) == 3 + + +def test_intermediate_layer_getter(): + backbone = ResNet50(include_top=False, weights=None, pooling=None) + feat_extractor = IntermediateLayerGetter(backbone, ["conv2_block3_out", "conv3_block4_out"]) + # Check num of output features + input_tensor = tf.random.uniform(shape=[1, 224, 224, 3], minval=0,
maxval=1) + assert len(feat_extractor(input_tensor)) == 2 + + # Repr + assert repr(feat_extractor) == "IntermediateLayerGetter()" diff --git a/tests/tensorflow/test_models_zoo_tf.py b/tests/tensorflow/test_models_zoo_tf.py new file mode 100644 index 0000000000..c3c3b94d8f --- /dev/null +++ b/tests/tensorflow/test_models_zoo_tf.py @@ -0,0 +1,79 @@ +import numpy as np +import pytest + +from doctr import models +from doctr.io import Document, DocumentFile +from doctr.models import detection, recognition +from doctr.models.detection.predictor import DetectionPredictor +from doctr.models.predictor import OCRPredictor +from doctr.models.preprocessor import PreProcessor +from doctr.models.recognition.predictor import RecognitionPredictor + + +@pytest.mark.parametrize( + "assume_straight_pages, straighten_pages", + [ + [True, False], + [False, False], + [True, True], + ] +) +def test_ocrpredictor(mock_pdf, mock_vocab, assume_straight_pages, straighten_pages): + det_bsize = 4 + det_predictor = DetectionPredictor( + PreProcessor(output_size=(512, 512), batch_size=det_bsize), + detection.db_mobilenet_v3_large( + pretrained=True, + pretrained_backbone=False, + input_shape=(512, 512, 3), + assume_straight_pages=assume_straight_pages, + ) + ) + + reco_bsize = 16 + reco_predictor = RecognitionPredictor( + PreProcessor(output_size=(32, 128), batch_size=reco_bsize, preserve_aspect_ratio=True), + recognition.crnn_vgg16_bn(pretrained=False, pretrained_backbone=False, vocab=mock_vocab) + ) + + doc = DocumentFile.from_pdf(mock_pdf).as_images() + + predictor = OCRPredictor( + det_predictor, + reco_predictor, + assume_straight_pages=assume_straight_pages, + straighten_pages=straighten_pages, + ) + + out = predictor(doc) + assert isinstance(out, Document) + assert len(out.pages) == 2 + # Dimension check + with pytest.raises(ValueError): + input_page = (255 * np.random.rand(1, 256, 512, 3)).astype(np.uint8) + _ = predictor([input_page]) + + +@pytest.mark.parametrize( + "det_arch, reco_arch", + [ + ["db_mobilenet_v3_large", "crnn_vgg16_bn"], + ], +) +def test_zoo_models(det_arch, reco_arch): + # Model + predictor = models.ocr_predictor(det_arch, reco_arch, pretrained=True) + # Output checks + assert isinstance(predictor, OCRPredictor) + + doc = [np.zeros((512, 512, 3), dtype=np.uint8)] + out = predictor(doc) + # Document + assert isinstance(out, Document) + + # The input doc has 1 page + assert len(out.pages) == 1 + # Dimension check + with pytest.raises(ValueError): + input_page = (255 * np.random.rand(1, 256, 512, 3)).astype(np.uint8) + _ = predictor([input_page]) diff --git a/tests/tensorflow/test_transforms_tf.py b/tests/tensorflow/test_transforms_tf.py new file mode 100644 index 0000000000..5182be854a --- /dev/null +++ b/tests/tensorflow/test_transforms_tf.py @@ -0,0 +1,431 @@ +import math + +import numpy as np +import pytest +import tensorflow as tf + +from doctr import transforms as T +from doctr.transforms.functional import crop_detection, rotate_sample + + +def test_resize(): + output_size = (32, 32) + transfo = T.Resize(output_size) + input_t = tf.cast(tf.fill([64, 64, 3], 1), dtype=tf.float32) + out = transfo(input_t) + + assert tf.reduce_all(out == 1) + assert out.shape[:2] == output_size + assert repr(transfo) == f"Resize(output_size={output_size}, method='bilinear')" + + transfo = T.Resize(output_size, preserve_aspect_ratio=True) + input_t = tf.cast(tf.fill([32, 64, 3], 1), dtype=tf.float32) + out = transfo(input_t) + + assert not tf.reduce_all(out == 1) + # Asymmetric padding + assert 
tf.reduce_all(out[-1] == 0) and tf.reduce_all(out[0] == 1) + assert out.shape[:2] == output_size + + # Symmetric padding + transfo = T.Resize(output_size, preserve_aspect_ratio=True, symmetric_pad=True) + assert repr(transfo) == (f"Resize(output_size={output_size}, method='bilinear', " + f"preserve_aspect_ratio=True, symmetric_pad=True)") + out = transfo(input_t) + # Symmetric padding: both sides are padded + assert tf.reduce_all(out[-1] == 0) and tf.reduce_all(out[0] == 0) + + # Inverse aspect ratio + input_t = tf.cast(tf.fill([64, 32, 3], 1), dtype=tf.float32) + out = transfo(input_t) + + assert not tf.reduce_all(out == 1) + assert out.shape[:2] == output_size + + # FP16 + input_t = tf.cast(tf.fill([64, 64, 3], 1), dtype=tf.float16) + out = transfo(input_t) + assert out.dtype == tf.float16 + + +def test_compose(): + + output_size = (16, 16) + transfo = T.Compose([T.Resize((32, 32)), T.Resize(output_size)]) + input_t = tf.random.uniform(shape=[64, 64, 3], minval=0, maxval=1) + out = transfo(input_t) + + assert out.shape[:2] == output_size + assert len(repr(transfo).split("\n")) == 6 + + +@pytest.mark.parametrize( + "input_shape", + [ + [8, 32, 32, 3], + [32, 32, 3], + [32, 3], + ], +) +def test_normalize(input_shape): + mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5] + transfo = T.Normalize(mean, std) + input_t = tf.cast(tf.fill(input_shape, 1), dtype=tf.float32) + + out = transfo(input_t) + + assert tf.reduce_all(out == 1) + assert repr(transfo) == f"Normalize(mean={mean}, std={std})" + + # FP16 + input_t = tf.cast(tf.fill(input_shape, 1), dtype=tf.float16) + out = transfo(input_t) + assert out.dtype == tf.float16 + + +def test_lambdatransformation(): + + transfo = T.LambdaTransformation(lambda x: x / 2) + input_t = tf.cast(tf.fill([8, 32, 32, 3], 1), dtype=tf.float32) + out = transfo(input_t) + + assert tf.reduce_all(out == 0.5) + + +def test_togray(): + + transfo = T.ToGray() + r = tf.fill([8, 32, 32, 1], 0.2) + g = tf.fill([8, 32, 32, 1], 0.6) + b = tf.fill([8, 32, 32, 1], 0.7) + input_t = tf.cast(tf.concat([r, g, b], axis=-1), dtype=tf.float32) + out = transfo(input_t) + + assert tf.reduce_all(out <= .51) + assert tf.reduce_all(out >= .49) + + # FP16 + input_t = tf.cast(tf.concat([r, g, b], axis=-1), dtype=tf.float16) + out = transfo(input_t) + assert out.dtype == tf.float16 + + +@pytest.mark.parametrize( + "rgb_min", + [ + 0.2, + 0.4, + 0.6, + ], +) +def test_invert_colorize(rgb_min): + + transfo = T.ColorInversion(min_val=rgb_min) + input_t = tf.cast(tf.fill([8, 32, 32, 3], 1), dtype=tf.float32) + out = transfo(input_t) + assert tf.reduce_all(out <= 1 - rgb_min + 1e-4) + assert tf.reduce_all(out >= 0) + + input_t = tf.cast(tf.fill([8, 32, 32, 3], 255), dtype=tf.uint8) + out = transfo(input_t) + assert tf.reduce_all(out <= int(math.ceil(255 * (1 - rgb_min)))) + assert tf.reduce_all(out >= 0) + + # FP16 + input_t = tf.cast(tf.fill([8, 32, 32, 3], 1), dtype=tf.float16) + out = transfo(input_t) + assert out.dtype == tf.float16 + + +def test_brightness(): + + transfo = T.RandomBrightness(max_delta=.1) + input_t = tf.cast(tf.fill([8, 32, 32, 3], .5), dtype=tf.float32) + out = transfo(input_t) + + assert tf.reduce_all(out >= .4) + assert tf.reduce_all(out <= .6) + + # FP16 + input_t = tf.cast(tf.fill([8, 32, 32, 3], .5), dtype=tf.float16) + out = transfo(input_t) + assert out.dtype == tf.float16 + + +def test_contrast(): + transfo = T.RandomContrast(delta=.2) + input_t = tf.cast(tf.fill([8, 32, 32, 3], .5), dtype=tf.float32) + out = transfo(input_t) + + assert tf.reduce_all(out == .5) + + # FP16 + if 
any(tf.config.list_physical_devices('GPU')): + input_t = tf.cast(tf.fill([8, 32, 32, 3], .5), dtype=tf.float16) + out = transfo(input_t) + assert out.dtype == tf.float16 + + +def test_saturation(): + + transfo = T.RandomSaturation(delta=.2) + input_t = tf.cast(tf.fill([8, 32, 32, 3], .5), dtype=tf.float32) + input_t = tf.image.hsv_to_rgb(input_t) + out = transfo(input_t) + hsv = tf.image.rgb_to_hsv(out) + + assert tf.reduce_all(hsv[:, :, :, 1] >= .4) + assert tf.reduce_all(hsv[:, :, :, 1] <= .6) + + # FP16 + if any(tf.config.list_physical_devices('GPU')): + input_t = tf.cast(tf.fill([8, 32, 32, 3], .5), dtype=tf.float16) + out = transfo(input_t) + assert out.dtype == tf.float16 + + +def test_hue(): + + transfo = T.RandomHue(max_delta=.2) + input_t = tf.cast(tf.fill([8, 32, 32, 3], .5), dtype=tf.float32) + input_t = tf.image.hsv_to_rgb(input_t) + out = transfo(input_t) + hsv = tf.image.rgb_to_hsv(out) + + assert tf.reduce_all(hsv[:, :, :, 0] <= .7) + assert tf.reduce_all(hsv[:, :, :, 0] >= .3) + + # FP16 + if any(tf.config.list_physical_devices('GPU')): + input_t = tf.cast(tf.fill([8, 32, 32, 3], .5), dtype=tf.float16) + out = transfo(input_t) + assert out.dtype == tf.float16 + + +def test_gamma(): + + transfo = T.RandomGamma(min_gamma=1., max_gamma=2., min_gain=.8, max_gain=1.) + input_t = tf.cast(tf.fill([8, 32, 32, 3], 2.), dtype=tf.float32) + out = transfo(input_t) + + assert tf.reduce_all(out >= 1.6) + assert tf.reduce_all(out <= 4.) + + # FP16 + input_t = tf.cast(tf.fill([8, 32, 32, 3], 2.), dtype=tf.float16) + out = transfo(input_t) + assert out.dtype == tf.float16 + + +def test_jpegquality(): + + transfo = T.RandomJpegQuality(min_quality=50) + input_t = tf.cast(tf.fill([32, 32, 3], 1), dtype=tf.float32) + out = transfo(input_t) + assert out.shape == input_t.shape + + # FP16 + input_t = tf.cast(tf.fill([32, 32, 3], 1), dtype=tf.float16) + out = transfo(input_t) + assert out.dtype == tf.float16 + + +def test_rotate_sample(): + img = tf.ones((200, 100, 3), dtype=tf.float32) + boxes = np.array([0, 0, 100, 200])[None, ...] + polys = np.stack((boxes[..., [0, 1]], boxes[..., [2, 1]], boxes[..., [2, 3]], boxes[..., [0, 3]]), axis=1) + rel_boxes = np.array([0, 0, 1, 1], dtype=np.float32)[None, ...] + rel_polys = np.stack( + (rel_boxes[..., [0, 1]], rel_boxes[..., [2, 1]], rel_boxes[..., [2, 3]], rel_boxes[..., [0, 3]]), + axis=1 + ) + + # No angle + rotated_img, rotated_geoms = rotate_sample(img, boxes, 0, False) + assert tf.math.reduce_all(rotated_img == img) and np.all(rotated_geoms == rel_polys) + rotated_img, rotated_geoms = rotate_sample(img, boxes, 0, True) + assert tf.math.reduce_all(rotated_img == img) and np.all(rotated_geoms == rel_polys) + rotated_img, rotated_geoms = rotate_sample(img, polys, 0, False) + assert tf.math.reduce_all(rotated_img == img) and np.all(rotated_geoms == rel_polys) + rotated_img, rotated_geoms = rotate_sample(img, polys, 0, True) + assert tf.math.reduce_all(rotated_img == img) and np.all(rotated_geoms == rel_polys) + + # No expansion + expected_img = np.zeros((200, 100, 3), dtype=np.float32) + expected_img[50: 150] = 1 + expected_img = tf.convert_to_tensor(expected_img) + expected_polys = np.array([[0, .75], [0, .25], [1, .25], [1, .75]])[None, ...] 
+ rotated_img, rotated_geoms = rotate_sample(img, boxes, 90, False) + assert tf.math.reduce_all(rotated_img == expected_img) and np.all(rotated_geoms == expected_polys) + rotated_img, rotated_geoms = rotate_sample(img, polys, 90, False) + assert tf.math.reduce_all(rotated_img == expected_img) and np.all(rotated_geoms == expected_polys) + rotated_img, rotated_geoms = rotate_sample(img, rel_boxes, 90, False) + assert tf.math.reduce_all(rotated_img == expected_img) and np.all(rotated_geoms == expected_polys) + rotated_img, rotated_geoms = rotate_sample(img, rel_polys, 90, False) + assert tf.math.reduce_all(rotated_img == expected_img) and np.all(rotated_geoms == expected_polys) + + # Expansion + expected_img = tf.ones((100, 200, 3), dtype=tf.float32) + expected_polys = np.array([[0, 1], [0, 0], [1, 0], [1, 1]], dtype=np.float32)[None, ...] + rotated_img, rotated_geoms = rotate_sample(img, boxes, 90, True) + # import ipdb; ipdb.set_trace() + assert tf.math.reduce_all(rotated_img == expected_img) and np.all(rotated_geoms == expected_polys) + rotated_img, rotated_geoms = rotate_sample(img, polys, 90, True) + assert tf.math.reduce_all(rotated_img == expected_img) and np.all(rotated_geoms == expected_polys) + rotated_img, rotated_geoms = rotate_sample(img, rel_boxes, 90, True) + assert tf.math.reduce_all(rotated_img == expected_img) and np.all(rotated_geoms == expected_polys) + rotated_img, rotated_geoms = rotate_sample(img, rel_polys, 90, True) + assert tf.math.reduce_all(rotated_img == expected_img) and np.all(rotated_geoms == expected_polys) + + with pytest.raises(AssertionError): + rotate_sample(img, boxes[None, ...], 90, False) + + +def test_random_rotate(): + rotator = T.RandomRotate(max_angle=10., expand=False) + input_t = tf.ones((50, 50, 3), dtype=tf.float32) + boxes = np.array([ + [15, 20, 35, 30] + ]) + r_img, r_boxes = rotator(input_t, boxes) + assert r_img.shape == input_t.shape + + rotator = T.RandomRotate(max_angle=10., expand=True) + r_img, r_boxes = rotator(input_t, boxes) + assert r_img.shape != input_t.shape + + # FP16 + input_t = tf.ones((50, 50, 3), dtype=tf.float16) + r_img, _ = rotator(input_t, boxes) + assert r_img.dtype == tf.float16 + + +def test_crop_detection(): + img = tf.ones((50, 50, 3), dtype=tf.float32) + abs_boxes = np.array([ + [15, 20, 35, 30], + [5, 10, 10, 20], + ]) + crop_box = (12 / 50, 23 / 50, 1., 1.) 
+ c_img, c_boxes = crop_detection(img, abs_boxes, crop_box) + assert c_img.shape == (26, 37, 3) + assert c_boxes.shape == (1, 4) + assert np.all(c_boxes == np.array([15 - 12, 0, 35 - 12, 30 - 23])[None, ...]) + + rel_boxes = np.array([ + [.3, .4, .7, .6], + [.1, .2, .2, .4], + ]) + c_img, c_boxes = crop_detection(img, rel_boxes, crop_box) + assert c_img.shape == (26, 37, 3) + assert c_boxes.shape == (1, 4) + assert np.abs(c_boxes - np.array([.06 / .76, 0., .46 / .76, .14 / .54])[None, ...]).mean() < 1e-7 + + # FP16 + img = tf.ones((50, 50, 3), dtype=tf.float16) + c_img, _ = crop_detection(img, rel_boxes, crop_box) + assert c_img.dtype == tf.float16 + + with pytest.raises(AssertionError): + crop_detection(img, abs_boxes, (2, 6, 24, 56)) + + +def test_random_crop(): + transfo = T.RandomCrop(scale=(0.5, 1.), ratio=(0.75, 1.33)) + input_t = tf.ones((50, 50, 3), dtype=tf.float32) + boxes = np.array([ + [15, 20, 35, 30] + ]) + img, target = transfo(input_t, dict(boxes=boxes)) + # Check the scale (take a margin) + assert img.shape[0] * img.shape[1] >= 0.4 * input_t.shape[0] * input_t.shape[1] + # Check aspect ratio (take a margin) + assert 0.65 <= img.shape[0] / img.shape[1] <= 1.5 + # Check the target + assert np.all(target['boxes'] >= 0) + assert np.all(target['boxes'][:, [0, 2]] <= img.shape[1]) and np.all(target['boxes'][:, [1, 3]] <= img.shape[0]) + + +def test_gaussian_blur(): + blur = T.GaussianBlur(3, (.1, 3)) + input_t = np.ones((31, 31, 3), dtype=np.float32) + input_t[15, 15] = 0 + blur_img = blur(tf.convert_to_tensor(input_t)).numpy() + assert blur_img.shape == input_t.shape + assert np.all(blur_img[15, 15] > 0) + + +@pytest.mark.parametrize( + "input_dtype, input_size", + [ + [tf.float32, (32, 32, 3)], + [tf.uint8, (32, 32, 3)], + ], +) +def test_channel_shuffle(input_dtype, input_size): + transfo = T.ChannelShuffle() + input_t = tf.random.uniform(input_size, dtype=tf.float32) + if input_dtype == tf.uint8: + input_t = tf.math.round(255 * input_t) + input_t = tf.cast(input_t, dtype=input_dtype) + out = transfo(input_t) + assert isinstance(out, tf.Tensor) + assert out.shape == input_size + assert out.dtype == input_dtype + # Ensure that nothing has changed apart from channel order + assert tf.math.reduce_all(tf.math.reduce_sum(input_t, -1) == tf.math.reduce_sum(out, -1)) + + +@pytest.mark.parametrize( + "input_dtype,input_shape", + [ + [tf.float32, (32, 32, 3)], + [tf.uint8, (32, 32, 3)], + ] +) +def test_gaussian_noise(input_dtype, input_shape): + transform = T.GaussianNoise(0., 1.) + input_t = tf.random.uniform(input_shape, dtype=tf.float32) + if input_dtype == tf.uint8: + input_t = tf.math.round((255 * input_t)) + input_t = tf.cast(input_t, dtype=input_dtype) + transformed = transform(input_t) + assert isinstance(transformed, tf.Tensor) + assert transformed.shape == input_shape + assert transformed.dtype == input_dtype + assert tf.math.reduce_any(transformed != input_t) + assert tf.math.reduce_all(transformed >= 0) + if input_dtype == tf.uint8: + assert tf.reduce_all(transformed <= 255) + else: + assert tf.reduce_all(transformed <= 1.) + + +@pytest.mark.parametrize("p", [1, 0]) +def test_randomhorizontalflip(p): + # testing for 2 cases, with flip probability 1 and 0. 
+ transform = T.RandomHorizontalFlip(p) + input_t = np.ones((32, 32, 3)) + input_t[:, :16, :] = 0 + input_t = tf.convert_to_tensor(input_t) + target = {"boxes": np.array([[0.1, 0.1, 0.3, 0.4]], dtype=np.float32), "labels": np.ones(1, dtype=np.int64)} + transformed, _target = transform(input_t, target) + assert isinstance(transformed, tf.Tensor) + assert transformed.shape == input_t.shape + assert transformed.dtype == input_t.dtype + # integrity check of targets + assert isinstance(_target, dict) + assert all(isinstance(val, np.ndarray) for val in _target.values()) + assert _target["boxes"].dtype == np.float32 + assert _target["labels"].dtype == np.int64 + if p == 1: + assert np.all(_target["boxes"] == np.array([[0.7, 0.1, 0.9, 0.4]], dtype=np.float32)) + assert tf.reduce_all( + tf.math.reduce_mean(transformed, (0, 2)) == tf.constant([1] * 16 + [0] * 16, dtype=tf.float64) + ) + elif p == 0: + assert np.all(_target["boxes"] == np.array([[0.1, 0.1, 0.3, 0.4]], dtype=np.float32)) + assert tf.reduce_all( + tf.math.reduce_mean(transformed, (0, 2)) == tf.constant([0] * 16 + [1] * 16, dtype=tf.float64) + ) + assert np.all(_target["labels"] == np.ones(1, dtype=np.int64)) From 5eebf06db5f9543c6d77ec7129a71f6cc79ff9a3 Mon Sep 17 00:00:00 2001 From: felixdittrich92 Date: Mon, 21 Feb 2022 12:04:31 +0100 Subject: [PATCH 2/7] add mjsynth loader --- docs/source/modules/datasets.rst | 1 + doctr/datasets/__init__.py | 1 + doctr/datasets/mjsynth.py | 70 ++++++++++++++++++++++++++++ tests/conftest.py | 26 +++++++++++ tests/pytorch/test_datasets_pt.py | 15 ++++++ tests/tensorflow/test_datasets_tf.py | 15 ++++++ 6 files changed, 128 insertions(+) create mode 100644 doctr/datasets/mjsynth.py diff --git a/docs/source/modules/datasets.rst b/docs/source/modules/datasets.rst index e40b1c506a..d9b07df3e0 100644 --- a/docs/source/modules/datasets.rst +++ b/docs/source/modules/datasets.rst @@ -27,6 +27,7 @@ Public datasets .. autoclass:: IC03 .. autoclass:: IC13 .. autoclass:: IMGUR5K +.. autoclass:: MJSynth docTR synthetic datasets ^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/doctr/datasets/__init__.py b/doctr/datasets/__init__.py index cd187271b1..92e1ff6831 100644 --- a/doctr/datasets/__init__.py +++ b/doctr/datasets/__init__.py @@ -9,6 +9,7 @@ from .ic13 import * from .iiit5k import * from .imgur5k import * +from .mjsynth import * from .ocr import * from .recognition import * from .sroie import * diff --git a/doctr/datasets/mjsynth.py b/doctr/datasets/mjsynth.py new file mode 100644 index 0000000000..1989ea7285 --- /dev/null +++ b/doctr/datasets/mjsynth.py @@ -0,0 +1,70 @@ +# Copyright (C) 2021-2022, Mindee. + +# This program is licensed under the Apache License version 2. +# See LICENSE or go to for full license details. + +import os +from typing import Any, Dict, List, Tuple + +from tqdm import tqdm + +from .datasets import AbstractDataset + +__all__ = ["MJSynth"] + + +class MJSynth(AbstractDataset): + """MJSynth dataset from `"Synthetic Data and Artificial Neural Networks for Natural Scene Text Recognition" + `_. + + Example:: + >>> # NOTE: This is a pure recognition dataset without bounding box labels. + >>> # NOTE: You need to download the dataset. 
+ >>> from doctr.datasets import MJSynth + >>> train_set = MJSynth(img_folder="/path/to/mjsynth/mnt/ramdisk/max/90kDICT32px", + >>> label_folder="/path/to/mjsynth/mnt/ramdisk/max/90kDICT32px/imlist.txt", + >>> train=True) + >>> img, target = train_set[0] + >>> test_set = MJSynth(img_folder="/path/to/mjsynth/mnt/ramdisk/max/90kDICT32px", + >>> labels_path="/path/to/mjsynth/mnt/ramdisk/max/90kDICT32px/imlist.txt") + >>> train=False) + >>> img, target = test_set[0] + + Args: + img_folder: folder with all the images of the dataset + labels_path: folder with all annotation files for the images + train: whether the subset should be the training one + **kwargs: keyword arguments from `AbstractDataset`. + """ + + def __init__( + self, + img_folder: str, + labels_path: str, + train: bool = True, + **kwargs: Any, + ) -> None: + super().__init__(img_folder, **kwargs) + + # File existence check + if not os.path.exists(labels_path) or not os.path.exists(img_folder): + raise FileNotFoundError( + f"unable to locate {labels_path if not os.path.exists(labels_path) else img_folder}") + + self.data: List[Tuple[str, Dict[str, str]]] = [] + self.train = train + + with open(labels_path) as f: + img_paths = f.readlines() + + train_samples = int(len(img_paths) * 0.85) + set_slice = slice(train_samples) if self.train else slice(train_samples, None) + + for path in tqdm(iterable=img_paths[set_slice], desc='Unpacking MJSynth', total=len(img_paths[set_slice])): + label = path.split('_')[1] + img_path = os.path.join(img_folder, path[2:]).strip() + + self.data.append((img_path, dict(labels=label))) + + def extra_repr(self) -> str: + return f"train={self.train}" diff --git a/tests/conftest.py b/tests/conftest.py index 1d34a9566f..f6cbb42f7d 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -599,3 +599,29 @@ def mock_ic03_dataset(tmpdir_factory, mock_image_stream): archive_path = root.join('ic03_train.zip') shutil.make_archive(root.join('ic03_train'), 'zip', str(ic03_root)) return str(archive_path) + + +@pytest.fixture(scope="session") +def mock_mjsynth_dataset(tmpdir_factory, mock_image_stream): + root = tmpdir_factory.mktemp('datasets') + mjsynth_root = root.mkdir('mjsynth') + image_folder = mjsynth_root.mkdir("images") + label_file = mjsynth_root.join("imlist.txt") + labels = [ + "./mjsynth/images/12_I_34.jpg\n", + "./mjsynth/images/12_am_34.jpg\n", + "./mjsynth/images/12_a_34.jpg\n", + "./mjsynth/images/12_Jedi_34.jpg\n", + "./mjsynth/images/12_!_34.jpg\n", + ] + + with open(label_file, "w") as f: + for label in labels: + f.write(label) + + file = BytesIO(mock_image_stream) + for i in ['I', 'am', 'a', 'Jedi', '!']: + fn = image_folder.join(f"12_{i}_34.jpg") + with open(fn, 'wb') as f: + f.write(file.getbuffer()) + return str(root), str(label_file) diff --git a/tests/pytorch/test_datasets_pt.py b/tests/pytorch/test_datasets_pt.py index a8f015a177..1301a87367 100644 --- a/tests/pytorch/test_datasets_pt.py +++ b/tests/pytorch/test_datasets_pt.py @@ -505,3 +505,18 @@ def test_ic03(input_size, num_samples, rotate, recognition, mock_ic03_dataset): _validate_dataset_recognition_part(ds, input_size) else: _validate_dataset(ds, input_size, is_polygons=rotate) + + +# NOTE: following datasets are only for recognition task + +def test_mjsynth_dataset(mock_mjsynth_dataset): + input_size = (32, 128) + ds = datasets.MJSynth( + *mock_mjsynth_dataset, + img_transforms=Resize(input_size, preserve_aspect_ratio=True), + ) + + assert len(ds) == 4 # Actual set has 7581382 train and 1337891 test samples + image, target = ds[0] + 
assert repr(ds) == f"MJSynth(train={True})" + _validate_dataset_recognition_part(ds, input_size) diff --git a/tests/tensorflow/test_datasets_tf.py b/tests/tensorflow/test_datasets_tf.py index 4d32ec7c50..d7ea9a6a5a 100644 --- a/tests/tensorflow/test_datasets_tf.py +++ b/tests/tensorflow/test_datasets_tf.py @@ -490,3 +490,18 @@ def test_ic03(input_size, num_samples, rotate, recognition, mock_ic03_dataset): _validate_dataset_recognition_part(ds, input_size) else: _validate_dataset(ds, input_size, is_polygons=rotate) + + +# NOTE: following datasets are only for recognition task + +def test_mjsynth_dataset(mock_mjsynth_dataset): + input_size = (32, 128) + ds = datasets.MJSynth( + *mock_mjsynth_dataset, + img_transforms=Resize(input_size, preserve_aspect_ratio=True), + ) + + assert len(ds) == 4 # Actual set has 7581382 train and 1337891 test samples + image, target = ds[0] + assert repr(ds) == f"MJSynth(train={True})" + _validate_dataset_recognition_part(ds, input_size) From c1b714e7ee5398278ece95de90d15a64e029e9b9 Mon Sep 17 00:00:00 2001 From: felixdittrich92 Date: Mon, 21 Feb 2022 13:49:29 +0100 Subject: [PATCH 3/7] apply changes --- doctr/datasets/mjsynth.py | 24 ++++++++++++------------ tests/pytorch/test_datasets_pt.py | 18 +++++++++--------- tests/tensorflow/test_datasets_tf.py | 18 +++++++++--------- 3 files changed, 30 insertions(+), 30 deletions(-) diff --git a/doctr/datasets/mjsynth.py b/doctr/datasets/mjsynth.py index 1989ea7285..7245907975 100644 --- a/doctr/datasets/mjsynth.py +++ b/doctr/datasets/mjsynth.py @@ -22,17 +22,17 @@ class MJSynth(AbstractDataset): >>> # NOTE: You need to download the dataset. >>> from doctr.datasets import MJSynth >>> train_set = MJSynth(img_folder="/path/to/mjsynth/mnt/ramdisk/max/90kDICT32px", - >>> label_folder="/path/to/mjsynth/mnt/ramdisk/max/90kDICT32px/imlist.txt", - >>> train=True) + >>> label_path="/path/to/mjsynth/mnt/ramdisk/max/90kDICT32px/imlist.txt", + >>> train=True) >>> img, target = train_set[0] >>> test_set = MJSynth(img_folder="/path/to/mjsynth/mnt/ramdisk/max/90kDICT32px", - >>> labels_path="/path/to/mjsynth/mnt/ramdisk/max/90kDICT32px/imlist.txt") - >>> train=False) + >>> label_path="/path/to/mjsynth/mnt/ramdisk/max/90kDICT32px/imlist.txt") + >>> train=False) >>> img, target = test_set[0] Args: img_folder: folder with all the images of the dataset - labels_path: folder with all annotation files for the images + label_path: path to the file with the labels train: whether the subset should be the training one **kwargs: keyword arguments from `AbstractDataset`. 
""" @@ -40,28 +40,28 @@ class MJSynth(AbstractDataset): def __init__( self, img_folder: str, - labels_path: str, + label_path: str, train: bool = True, **kwargs: Any, ) -> None: super().__init__(img_folder, **kwargs) # File existence check - if not os.path.exists(labels_path) or not os.path.exists(img_folder): + if not os.path.exists(label_path) or not os.path.exists(img_folder): raise FileNotFoundError( - f"unable to locate {labels_path if not os.path.exists(labels_path) else img_folder}") + f"unable to locate {label_path if not os.path.exists(label_path) else img_folder}") - self.data: List[Tuple[str, Dict[str, str]]] = [] + self.data: List[Tuple[str, Dict[str, Any]]] = [] self.train = train - with open(labels_path) as f: + with open(label_path) as f: img_paths = f.readlines() - train_samples = int(len(img_paths) * 0.85) + train_samples = int(len(img_paths) * 0.9) set_slice = slice(train_samples) if self.train else slice(train_samples, None) for path in tqdm(iterable=img_paths[set_slice], desc='Unpacking MJSynth', total=len(img_paths[set_slice])): - label = path.split('_')[1] + label = [path.split('_')[1]] img_path = os.path.join(img_folder, path[2:]).strip() self.data.append((img_path, dict(labels=label))) diff --git a/tests/pytorch/test_datasets_pt.py b/tests/pytorch/test_datasets_pt.py index 1301a87367..86e9f0a943 100644 --- a/tests/pytorch/test_datasets_pt.py +++ b/tests/pytorch/test_datasets_pt.py @@ -10,7 +10,7 @@ from doctr.transforms import Resize -def _validate_dataset(ds, input_size, batch_size=2, class_indices=False, is_polygons=False): +def _validate_dataset(ds, input_size, batch_size=2, class_indices=False, is_polygons=False, has_boxes=True): # Fetch one sample img, target = ds[0] @@ -18,17 +18,18 @@ def _validate_dataset(ds, input_size, batch_size=2, class_indices=False, is_poly assert img.shape == (3, *input_size) assert img.dtype == torch.float32 assert isinstance(target, dict) - assert isinstance(target['boxes'], np.ndarray) and target['boxes'].dtype == np.float32 - if is_polygons: - assert target['boxes'].ndim == 3 and target['boxes'].shape[1:] == (4, 2) - else: - assert target['boxes'].ndim == 2 and target['boxes'].shape[1:] == (4,) - assert np.all(np.logical_and(target['boxes'] <= 1, target['boxes'] >= 0)) + if has_boxes: + assert isinstance(target['boxes'], np.ndarray) and target['boxes'].dtype == np.float32 + if is_polygons: + assert target['boxes'].ndim == 3 and target['boxes'].shape[1:] == (4, 2) + else: + assert target['boxes'].ndim == 2 and target['boxes'].shape[1:] == (4,) + assert np.all(np.logical_and(target['boxes'] <= 1, target['boxes'] >= 0)) + assert len(target['labels']) == len(target['boxes']) if class_indices: assert isinstance(target['labels'], np.ndarray) and target['labels'].dtype == np.int64 else: assert isinstance(target['labels'], list) and all(isinstance(s, str) for s in target['labels']) - assert len(target['labels']) == len(target['boxes']) # Check batching loader = DataLoader( @@ -517,6 +518,5 @@ def test_mjsynth_dataset(mock_mjsynth_dataset): ) assert len(ds) == 4 # Actual set has 7581382 train and 1337891 test samples - image, target = ds[0] assert repr(ds) == f"MJSynth(train={True})" _validate_dataset_recognition_part(ds, input_size) diff --git a/tests/tensorflow/test_datasets_tf.py b/tests/tensorflow/test_datasets_tf.py index d7ea9a6a5a..e2fb2e4529 100644 --- a/tests/tensorflow/test_datasets_tf.py +++ b/tests/tensorflow/test_datasets_tf.py @@ -10,7 +10,7 @@ from doctr.transforms import Resize -def _validate_dataset(ds, input_size, 
batch_size=2, class_indices=False, is_polygons=False): +def _validate_dataset(ds, input_size, batch_size=2, class_indices=False, is_polygons=False, has_boxes=True): # Fetch one sample img, target = ds[0] @@ -18,17 +18,18 @@ def _validate_dataset(ds, input_size, batch_size=2, class_indices=False, is_poly assert img.shape == (*input_size, 3) assert img.dtype == tf.float32 assert isinstance(target, dict) - assert isinstance(target['boxes'], np.ndarray) and target['boxes'].dtype == np.float32 - if is_polygons: - assert target['boxes'].ndim == 3 and target['boxes'].shape[1:] == (4, 2) - else: - assert target['boxes'].ndim == 2 and target['boxes'].shape[1:] == (4,) - assert np.all(np.logical_and(target['boxes'] <= 1, target['boxes'] >= 0)) + if has_boxes: + assert isinstance(target['boxes'], np.ndarray) and target['boxes'].dtype == np.float32 + if is_polygons: + assert target['boxes'].ndim == 3 and target['boxes'].shape[1:] == (4, 2) + else: + assert target['boxes'].ndim == 2 and target['boxes'].shape[1:] == (4,) + assert np.all(np.logical_and(target['boxes'] <= 1, target['boxes'] >= 0)) + assert len(target['labels']) == len(target['boxes']) if class_indices: assert isinstance(target['labels'], np.ndarray) and target['labels'].dtype == np.int64 else: assert isinstance(target['labels'], list) and all(isinstance(s, str) for s in target['labels']) - assert len(target['labels']) == len(target['boxes']) # Check batching loader = DataLoader(ds, batch_size=batch_size) @@ -502,6 +503,5 @@ def test_mjsynth_dataset(mock_mjsynth_dataset): ) assert len(ds) == 4 # Actual set has 7581382 train and 1337891 test samples - image, target = ds[0] assert repr(ds) == f"MJSynth(train={True})" _validate_dataset_recognition_part(ds, input_size) From d7bbe819849fb4d773fa191d75f50688b8ac0f8a Mon Sep 17 00:00:00 2001 From: felixdittrich92 Date: Mon, 21 Feb 2022 21:29:11 +0100 Subject: [PATCH 4/7] rename --- tests/pytorch/test_datasets_pt.py | 4 ++-- tests/tensorflow/test_datasets_tf.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/tests/pytorch/test_datasets_pt.py b/tests/pytorch/test_datasets_pt.py index 86e9f0a943..813664bb99 100644 --- a/tests/pytorch/test_datasets_pt.py +++ b/tests/pytorch/test_datasets_pt.py @@ -10,7 +10,7 @@ from doctr.transforms import Resize -def _validate_dataset(ds, input_size, batch_size=2, class_indices=False, is_polygons=False, has_boxes=True): +def _validate_dataset(ds, input_size, batch_size=2, class_indices=False, is_polygons=False, target_includes_boxes=True): # Fetch one sample img, target = ds[0] @@ -18,7 +18,7 @@ def _validate_dataset(ds, input_size, batch_size=2, class_indices=False, is_poly assert img.shape == (3, *input_size) assert img.dtype == torch.float32 assert isinstance(target, dict) - if has_boxes: + if target_includes_boxes: assert isinstance(target['boxes'], np.ndarray) and target['boxes'].dtype == np.float32 if is_polygons: assert target['boxes'].ndim == 3 and target['boxes'].shape[1:] == (4, 2) diff --git a/tests/tensorflow/test_datasets_tf.py b/tests/tensorflow/test_datasets_tf.py index e2fb2e4529..2977536f3a 100644 --- a/tests/tensorflow/test_datasets_tf.py +++ b/tests/tensorflow/test_datasets_tf.py @@ -10,7 +10,7 @@ from doctr.transforms import Resize -def _validate_dataset(ds, input_size, batch_size=2, class_indices=False, is_polygons=False, has_boxes=True): +def _validate_dataset(ds, input_size, batch_size=2, class_indices=False, is_polygons=False, target_includes_boxes=True): # Fetch one sample img, target = ds[0] @@ -18,7 +18,7 @@ def 
_validate_dataset(ds, input_size, batch_size=2, class_indices=False, is_poly assert img.shape == (*input_size, 3) assert img.dtype == tf.float32 assert isinstance(target, dict) - if has_boxes: + if target_includes_boxes: assert isinstance(target['boxes'], np.ndarray) and target['boxes'].dtype == np.float32 if is_polygons: assert target['boxes'].ndim == 3 and target['boxes'].shape[1:] == (4, 2) From 4f45239090f0dd954005d4b52b0145264e4645cb Mon Sep 17 00:00:00 2001 From: felixdittrich92 Date: Fri, 25 Mar 2022 11:15:45 +0100 Subject: [PATCH 5/7] update --- docs/source/index.rst | 1 + doctr/datasets/mjsynth.py | 23 +++++++++++------------ 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/docs/source/index.rst b/docs/source/index.rst index 980aa2e3a8..e0c3fc78fb 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -62,6 +62,7 @@ Supported datasets * IC03 from `ICDAR 2003 `_. * IC13 from `ICDAR 2013 `_. * IMGUR5K from `"TextStyleBrush: Transfer of Text Aesthetics from a Single Example" `_. +* MJSynth from `"Synthetic Data and Artificial Neural Networks for Natural Scene Text Recognition" `_. .. toctree:: diff --git a/doctr/datasets/mjsynth.py b/doctr/datasets/mjsynth.py index 7245907975..820e06a7b5 100644 --- a/doctr/datasets/mjsynth.py +++ b/doctr/datasets/mjsynth.py @@ -17,18 +17,17 @@ class MJSynth(AbstractDataset): """MJSynth dataset from `"Synthetic Data and Artificial Neural Networks for Natural Scene Text Recognition" `_. - Example:: - >>> # NOTE: This is a pure recognition dataset without bounding box labels. - >>> # NOTE: You need to download the dataset. - >>> from doctr.datasets import MJSynth - >>> train_set = MJSynth(img_folder="/path/to/mjsynth/mnt/ramdisk/max/90kDICT32px", - >>> label_path="/path/to/mjsynth/mnt/ramdisk/max/90kDICT32px/imlist.txt", - >>> train=True) - >>> img, target = train_set[0] - >>> test_set = MJSynth(img_folder="/path/to/mjsynth/mnt/ramdisk/max/90kDICT32px", - >>> label_path="/path/to/mjsynth/mnt/ramdisk/max/90kDICT32px/imlist.txt") - >>> train=False) - >>> img, target = test_set[0] + >>> # NOTE: This is a pure recognition dataset without bounding box labels. + >>> # NOTE: You need to download the dataset. 
+ >>> from doctr.datasets import MJSynth + >>> train_set = MJSynth(img_folder="/path/to/mjsynth/mnt/ramdisk/max/90kDICT32px", + >>> label_path="/path/to/mjsynth/mnt/ramdisk/max/90kDICT32px/imlist.txt", + >>> train=True) + >>> img, target = train_set[0] + >>> test_set = MJSynth(img_folder="/path/to/mjsynth/mnt/ramdisk/max/90kDICT32px", + >>> label_path="/path/to/mjsynth/mnt/ramdisk/max/90kDICT32px/imlist.txt") + >>> train=False) + >>> img, target = test_set[0] Args: img_folder: folder with all the images of the dataset From 7877ac5897df1cd5d3707e496fbbb9f8f21c51ab Mon Sep 17 00:00:00 2001 From: felix Date: Wed, 27 Apr 2022 21:59:28 +0200 Subject: [PATCH 6/7] update --- tests/pytorch/test_datasets_pt.py | 17 ++++++++--------- tests/tensorflow/test_datasets_tf.py | 17 ++++++++--------- 2 files changed, 16 insertions(+), 18 deletions(-) diff --git a/tests/pytorch/test_datasets_pt.py b/tests/pytorch/test_datasets_pt.py index 813664bb99..d7c4b92c67 100644 --- a/tests/pytorch/test_datasets_pt.py +++ b/tests/pytorch/test_datasets_pt.py @@ -10,7 +10,7 @@ from doctr.transforms import Resize -def _validate_dataset(ds, input_size, batch_size=2, class_indices=False, is_polygons=False, target_includes_boxes=True): +def _validate_dataset(ds, input_size, batch_size=2, class_indices=False, is_polygons=False): # Fetch one sample img, target = ds[0] @@ -18,14 +18,13 @@ def _validate_dataset(ds, input_size, batch_size=2, class_indices=False, is_poly assert img.shape == (3, *input_size) assert img.dtype == torch.float32 assert isinstance(target, dict) - if target_includes_boxes: - assert isinstance(target['boxes'], np.ndarray) and target['boxes'].dtype == np.float32 - if is_polygons: - assert target['boxes'].ndim == 3 and target['boxes'].shape[1:] == (4, 2) - else: - assert target['boxes'].ndim == 2 and target['boxes'].shape[1:] == (4,) - assert np.all(np.logical_and(target['boxes'] <= 1, target['boxes'] >= 0)) - assert len(target['labels']) == len(target['boxes']) + assert isinstance(target['boxes'], np.ndarray) and target['boxes'].dtype == np.float32 + if is_polygons: + assert target['boxes'].ndim == 3 and target['boxes'].shape[1:] == (4, 2) + else: + assert target['boxes'].ndim == 2 and target['boxes'].shape[1:] == (4,) + assert np.all(np.logical_and(target['boxes'] <= 1, target['boxes'] >= 0)) + assert len(target['labels']) == len(target['boxes']) if class_indices: assert isinstance(target['labels'], np.ndarray) and target['labels'].dtype == np.int64 else: diff --git a/tests/tensorflow/test_datasets_tf.py b/tests/tensorflow/test_datasets_tf.py index 2977536f3a..f49c6e0fd0 100644 --- a/tests/tensorflow/test_datasets_tf.py +++ b/tests/tensorflow/test_datasets_tf.py @@ -10,7 +10,7 @@ from doctr.transforms import Resize -def _validate_dataset(ds, input_size, batch_size=2, class_indices=False, is_polygons=False, target_includes_boxes=True): +def _validate_dataset(ds, input_size, batch_size=2, class_indices=False, is_polygons=False): # Fetch one sample img, target = ds[0] @@ -18,14 +18,13 @@ def _validate_dataset(ds, input_size, batch_size=2, class_indices=False, is_poly assert img.shape == (*input_size, 3) assert img.dtype == tf.float32 assert isinstance(target, dict) - if target_includes_boxes: - assert isinstance(target['boxes'], np.ndarray) and target['boxes'].dtype == np.float32 - if is_polygons: - assert target['boxes'].ndim == 3 and target['boxes'].shape[1:] == (4, 2) - else: - assert target['boxes'].ndim == 2 and target['boxes'].shape[1:] == (4,) - assert np.all(np.logical_and(target['boxes'] <= 1, 
target['boxes'] >= 0)) - assert len(target['labels']) == len(target['boxes']) + assert isinstance(target['boxes'], np.ndarray) and target['boxes'].dtype == np.float32 + if is_polygons: + assert target['boxes'].ndim == 3 and target['boxes'].shape[1:] == (4, 2) + else: + assert target['boxes'].ndim == 2 and target['boxes'].shape[1:] == (4,) + assert np.all(np.logical_and(target['boxes'] <= 1, target['boxes'] >= 0)) + assert len(target['labels']) == len(target['boxes']) if class_indices: assert isinstance(target['labels'], np.ndarray) and target['labels'].dtype == np.int64 else: From d42214e69c162d88774f39c3ed7e6d1cad162c72 Mon Sep 17 00:00:00 2001 From: felix Date: Wed, 27 Apr 2022 22:01:54 +0200 Subject: [PATCH 7/7] fix tests --- tests/pytorch/test_datasets_pt.py | 2 +- tests/tensorflow/test_datasets_tf.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/tests/pytorch/test_datasets_pt.py b/tests/pytorch/test_datasets_pt.py index d7c4b92c67..f4b96ae389 100644 --- a/tests/pytorch/test_datasets_pt.py +++ b/tests/pytorch/test_datasets_pt.py @@ -24,11 +24,11 @@ def _validate_dataset(ds, input_size, batch_size=2, class_indices=False, is_poly else: assert target['boxes'].ndim == 2 and target['boxes'].shape[1:] == (4,) assert np.all(np.logical_and(target['boxes'] <= 1, target['boxes'] >= 0)) - assert len(target['labels']) == len(target['boxes']) if class_indices: assert isinstance(target['labels'], np.ndarray) and target['labels'].dtype == np.int64 else: assert isinstance(target['labels'], list) and all(isinstance(s, str) for s in target['labels']) + assert len(target['labels']) == len(target['boxes']) # Check batching loader = DataLoader( diff --git a/tests/tensorflow/test_datasets_tf.py b/tests/tensorflow/test_datasets_tf.py index f49c6e0fd0..c7f032713a 100644 --- a/tests/tensorflow/test_datasets_tf.py +++ b/tests/tensorflow/test_datasets_tf.py @@ -24,11 +24,11 @@ def _validate_dataset(ds, input_size, batch_size=2, class_indices=False, is_poly else: assert target['boxes'].ndim == 2 and target['boxes'].shape[1:] == (4,) assert np.all(np.logical_and(target['boxes'] <= 1, target['boxes'] >= 0)) - assert len(target['labels']) == len(target['boxes']) if class_indices: assert isinstance(target['labels'], np.ndarray) and target['labels'].dtype == np.int64 else: assert isinstance(target['labels'], list) and all(isinstance(s, str) for s in target['labels']) + assert len(target['labels']) == len(target['boxes']) # Check batching loader = DataLoader(ds, batch_size=batch_size)
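
For context, a minimal usage sketch of the MJSynth loader introduced in this series, wired into doctr's DataLoader with the same Resize transform used in the tests above. It is not part of the patches: the img_folder and label_path values are placeholders for a locally extracted MJSynth copy, and the batch size is an illustrative assumption.

# Minimal sketch (not part of the patches): iterate the MJSynth recognition set with doctr's DataLoader.
from doctr.datasets import DataLoader, MJSynth
from doctr.transforms import Resize

# Placeholder paths to a locally downloaded and extracted MJSynth archive
train_set = MJSynth(
    img_folder="/path/to/mjsynth/mnt/ramdisk/max/90kDICT32px",
    label_path="/path/to/mjsynth/mnt/ramdisk/max/90kDICT32px/imlist.txt",
    train=True,
    img_transforms=Resize((32, 128), preserve_aspect_ratio=True),
)

train_loader = DataLoader(train_set, batch_size=64)
for images, targets in train_loader:
    # images: batch of resized word crops; targets: one dict per crop with a 'labels' entry
    break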