diff --git a/LICENSE b/LICENSE index 1da1952..9f3e082 100644 --- a/LICENSE +++ b/LICENSE @@ -1,21 +1,21 @@ -MIT License - -Copyright (c) 2024 bang123-box - -Permission is hereby granted, free of charge, to any person obtaining a copy -of this software and associated documentation files (the "Software"), to deal -in the Software without restriction, including without limitation the rights -to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -copies of the Software, and to permit persons to whom the Software is -furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all -copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -SOFTWARE. +MIT License + +Copyright (c) 2024 bang123-box + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
diff --git a/configs/charset/36_lowercase.yaml b/configs/charset/36_lowercase.yaml new file mode 100644 index 0000000..ce2a5a0 --- /dev/null +++ b/configs/charset/36_lowercase.yaml @@ -0,0 +1,3 @@ +# @package _global_ +model: + charset_train: "0123456789abcdefghijklmnopqrstuvwxyz" diff --git a/configs/charset/62_mixed-case.yaml b/configs/charset/62_mixed-case.yaml new file mode 100644 index 0000000..07db844 --- /dev/null +++ b/configs/charset/62_mixed-case.yaml @@ -0,0 +1,3 @@ +# @package _global_ +model: + charset_train: "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ" diff --git a/configs/charset/94_full.yaml b/configs/charset/94_full.yaml new file mode 100644 index 0000000..186bf42 --- /dev/null +++ b/configs/charset/94_full.yaml @@ -0,0 +1,3 @@ +# @package _global_ +model: + charset_train: "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~" diff --git a/configs/dataset/real.yaml b/configs/dataset/real.yaml new file mode 100644 index 0000000..786042d --- /dev/null +++ b/configs/dataset/real.yaml @@ -0,0 +1,3 @@ +# @package _global_ +data: + train_dir: real diff --git a/configs/dataset/synth.yaml b/configs/dataset/synth.yaml new file mode 100644 index 0000000..1d56808 --- /dev/null +++ b/configs/dataset/synth.yaml @@ -0,0 +1,3 @@ +# @package _global_ +data: + train_dir: synth diff --git a/configs/dataset/union14m.yaml b/configs/dataset/union14m.yaml new file mode 100644 index 0000000..e0141cb --- /dev/null +++ b/configs/dataset/union14m.yaml @@ -0,0 +1,3 @@ +# @package _global_ +data: + train_dir: Union14M \ No newline at end of file diff --git a/configs/main.yaml b/configs/main.yaml new file mode 100644 index 0000000..13064fa --- /dev/null +++ b/configs/main.yaml @@ -0,0 +1,51 @@ +defaults: + - _self_ + - model: cfe + - charset: 36_lowercase # 94_full, 36_lowercase + - dataset: synth + +model: + _convert_: all + img_size: [32, 128] # [ height, width ] + max_label_length: 25 + # The ordering in charset_train matters. It determines the token IDs assigned to each character. + charset_train: ??? + # For charset_test, ordering doesn't matter. + charset_test: "0123456789abcdefghijklmnopqrstuvwxyz" + batch_size: 384 + weight_decay: 0.0 + warmup_pct: 0.075 # equivalent to 1.5 epochs of warm up + +data: + _target_: strhub.data.module.SceneTextDataModule + root_dir: /home/zbb/data + train_dir: ??? 
+ batch_size: ${model.batch_size} + img_size: ${model.img_size} + charset_train: ${model.charset_train} + charset_test: ${model.charset_test} + max_label_length: ${model.max_label_length} + remove_whitespace: true + normalize_unicode: true + augment: True + num_workers: 12 + +trainer: + _target_: pytorch_lightning.Trainer + _convert_: all + val_check_interval: 2000 + max_epochs: 20 + gradient_clip_val: 20 + accelerator: gpu + devices: 4 + +ckpt_path: null +pretrained: null + +hydra: + output_subdir: config + run: + dir: ./output/${model.name}/${now:%Y-%m-%d}_${now:%H-%M-%S} + sweep: + dir: multirun/${model.name}/${now:%Y-%m-%d}_${now:%H-%M-%S} + subdir: ${hydra.job.override_dirname} diff --git a/configs/model/cfe.yaml b/configs/model/cfe.yaml new file mode 100644 index 0000000..f3e96dd --- /dev/null +++ b/configs/model/cfe.yaml @@ -0,0 +1,45 @@ +name: cfe +_target_: strhub.models.cfe.system.CFE + +# Architecture +num_control_points: 20 +enc_mlp_ratio: 4 +window_size: [[7, 11], [7, 11], [7, 11]] +merge_types: 'Conv' +local_type: 'r2' +prenorm: False +tps: False +use_pe: True +cclossexist: True +cc_weights: 0.2 +fpn_layers: [0,1,2] +dec_mlp_ratio: 4 +dec_depth: 1 + +# base +embed_dim: [128,256,384] +enc_num_heads: [4,8,12] +depth: [3,6,9] +mixer_types: ['Local', 8, "Global", 10] +decoder_dim: 256 +dec_num_heads: 8 + +# small +# embed_dim: [96,192,256] 64,128,256 +# enc_num_heads: [3,6,8], 2,4,8 +# depth: [3,6,6] 3,6,3 +# mixer_types: ['Local', 8, "Global", 7] +# decoder_dim: 192 +# dec_num_heads: 6 + +## tiny +# embed_dim: [64,128,256] +# enc_num_heads: [2,4,8] +# depth: [3,6,3] +# mixer_types: ['Local', 6, "Global", 6] +# decoder_dim: 128 +# dec_num_heads: 4 + +# Training +lr: 5e-4 +dropout: 0.1 \ No newline at end of file diff --git a/requirements/bench.in b/requirements/bench.in new file mode 100644 index 0000000..9a4199c --- /dev/null +++ b/requirements/bench.in @@ -0,0 +1,4 @@ +-c ${CONSTRAINTS} + +hydra-core >=1.2.0 +fvcore >=0.1.5.post20220512 diff --git a/requirements/bench.txt b/requirements/bench.txt new file mode 100644 index 0000000..83d370d --- /dev/null +++ b/requirements/bench.txt @@ -0,0 +1,17 @@ +antlr4-python3-runtime==4.9.3 +fvcore==0.1.5.post20221221 +hydra-core==1.3.2 +importlib-resources==5.12.0 +iopath==0.1.10 +numpy==1.24.3 +omegaconf==2.3.0 +packaging==23.1 +pillow==9.5.0 +portalocker==2.7.0 +pyyaml==6.0 +tabulate==0.9.0 +termcolor==2.3.0 +tqdm==4.65.0 +typing-extensions==4.6.2 +yacs==0.1.8 +zipp==3.15.0 diff --git a/requirements/constraints.txt b/requirements/constraints.txt new file mode 100644 index 0000000..272d212 --- /dev/null +++ b/requirements/constraints.txt @@ -0,0 +1,400 @@ +--extra-index-url https://download.pytorch.org/whl/cpu + +aiohttp==3.8.4 + # via fsspec +aiosignal==1.3.1 + # via + # aiohttp + # ray +antlr4-python3-runtime==4.9.3 + # via + # hydra-core + # omegaconf +asttokens==2.2.1 + # via stack-data +async-timeout==4.0.2 + # via aiohttp +attrs==23.1.0 + # via + # aiohttp + # jsonschema + # ray +ax-platform==0.3.2 + # via -r requirements/tune.in +backcall==0.2.0 + # via ipython +botorch==0.8.5 + # via ax-platform +certifi==2023.5.7 + # via requests +charset-normalizer==3.1.0 + # via + # aiohttp + # requests +click==8.0.4 + # via + # nltk + # ray +comm==0.1.3 + # via ipykernel +contourpy==1.0.7 + # via matplotlib +cycler==0.11.0 + # via matplotlib +debugpy==1.6.7 + # via ipykernel +decorator==5.1.1 + # via ipython +distlib==0.3.6 + # via virtualenv +executing==1.2.0 + # via stack-data +filelock==3.12.0 + # via + # huggingface-hub + # ray + # 
virtualenv +fonttools==4.39.4 + # via matplotlib +frozenlist==1.3.3 + # via + # aiohttp + # aiosignal + # ray +fsspec==2023.5.0 + # via + # huggingface-hub + # pytorch-lightning +fvcore==0.1.5.post20221221 + # via -r requirements/bench.in +gpytorch==1.10 + # via botorch +grpcio==1.43.0 + # via ray +huggingface-hub==0.15.1 + # via timm +hydra-core==1.3.2 + # via + # -r requirements/bench.in + # -r requirements/tune.in +idna==3.4 + # via + # requests + # yarl +imageio==2.30.0 + # via + # imgaug + # scikit-image +imgaug==0.4.0 + # via + # -r requirements/train.in + # -r requirements/tune.in +importlib-metadata==6.6.0 + # via jupyter-client +importlib-resources==5.12.0 + # via + # hydra-core + # jsonschema + # matplotlib +iopath==0.1.10 + # via fvcore +ipykernel==6.23.1 + # via ipywidgets +ipython==8.12.2 + # via + # ipykernel + # ipywidgets +ipywidgets==8.0.6 + # via ax-platform +jedi==0.18.2 + # via ipython +jinja2==3.1.2 + # via ax-platform +joblib==1.2.0 + # via + # nltk + # scikit-learn +jsonschema==4.17.3 + # via ray +jupyter-client==8.2.0 + # via ipykernel +jupyter-core==5.3.0 + # via + # ipykernel + # jupyter-client +jupyterlab-widgets==3.0.7 + # via ipywidgets +kiwisolver==1.4.4 + # via matplotlib +lazy-loader==0.2 + # via scikit-image +lightning-utilities==0.8.0 + # via pytorch-lightning +linear-operator==0.4.0 + # via + # botorch + # gpytorch +lmdb==1.4.1 + # via + # -r requirements/test.in + # -r requirements/tune.in +markupsafe==2.1.2 + # via jinja2 +matplotlib==3.7.1 + # via imgaug +matplotlib-inline==0.1.6 + # via + # ipykernel + # ipython +msgpack==1.0.5 + # via ray +multidict==6.0.4 + # via + # aiohttp + # yarl +multipledispatch==0.6.0 + # via botorch +nest-asyncio==1.5.6 + # via ipykernel +networkx==3.1 + # via scikit-image +nltk==3.8.1 + # via -r requirements/core.in +numpy==1.24.3 + # via + # contourpy + # fvcore + # imageio + # imgaug + # matplotlib + # opencv-python + # opt-einsum + # pandas + # pyro-ppl + # pytorch-lightning + # pywavelets + # ray + # scikit-image + # scikit-learn + # scipy + # shapely + # tensorboardx + # tifffile + # torchmetrics + # torchvision +omegaconf==2.3.0 + # via hydra-core +opencv-python==4.7.0.72 + # via imgaug +opt-einsum==3.3.0 + # via pyro-ppl +packaging==23.1 + # via + # huggingface-hub + # hydra-core + # ipykernel + # lightning-utilities + # matplotlib + # plotly + # pytorch-lightning + # scikit-image + # tensorboardx + # torchmetrics +pandas==2.0.2 + # via + # ax-platform + # ray +parso==0.8.3 + # via jedi +pexpect==4.8.0 + # via ipython +pickleshare==0.7.5 + # via ipython +pillow==9.5.0 + # via + # -r requirements/test.in + # -r requirements/tune.in + # fvcore + # imageio + # imgaug + # matplotlib + # scikit-image + # torchvision +pkgutil-resolve-name==1.3.10 + # via jsonschema +platformdirs==3.5.1 + # via + # jupyter-core + # virtualenv +plotly==5.14.1 + # via ax-platform +portalocker==2.7.0 + # via iopath +prompt-toolkit==3.0.38 + # via ipython +protobuf==3.20.3 + # via + # ray + # tensorboardx +psutil==5.9.5 + # via ipykernel +ptyprocess==0.7.0 + # via pexpect +pure-eval==0.2.2 + # via stack-data +pygments==2.15.1 + # via ipython +pyparsing==3.0.9 + # via matplotlib +pyro-api==0.1.2 + # via pyro-ppl +pyro-ppl==1.8.4 + # via botorch +pyrsistent==0.19.3 + # via jsonschema +python-dateutil==2.8.2 + # via + # jupyter-client + # matplotlib + # pandas +pytorch-lightning==1.9.5 + # via -r requirements/core.in +pytz==2023.3 + # via pandas +pywavelets==1.4.1 + # via scikit-image +pyyaml==6.0 + # via + # -r requirements/core.in + # fvcore + # 
huggingface-hub + # omegaconf + # pytorch-lightning + # ray + # timm + # yacs +pyzmq==25.1.0 + # via + # ipykernel + # jupyter-client +ray==1.13.0 + # via -r requirements/tune.in +regex==2023.5.5 + # via nltk +requests==2.31.0 + # via + # fsspec + # huggingface-hub + # ray + # torchvision +safetensors==0.3.1 + # via timm +scikit-image==0.20.0 + # via imgaug +scikit-learn==1.2.2 + # via + # ax-platform + # gpytorch +scipy==1.9.1 + # via + # ax-platform + # botorch + # imgaug + # linear-operator + # scikit-image + # scikit-learn +shapely==2.0.1 + # via imgaug +six==1.16.0 + # via + # asttokens + # grpcio + # imgaug + # multipledispatch + # python-dateutil +stack-data==0.6.2 + # via ipython +tabulate==0.9.0 + # via + # fvcore + # ray +tenacity==8.2.2 + # via plotly +tensorboardx==2.6 + # via ray +termcolor==2.3.0 + # via fvcore +threadpoolctl==3.1.0 + # via scikit-learn +tifffile==2023.4.12 + # via scikit-image +timm==0.9.2 + # via -r requirements/core.in +torch==1.13.1+cpu + # via + # -r requirements/core.in + # botorch + # linear-operator + # pyro-ppl + # pytorch-lightning + # timm + # torchmetrics + # torchvision +torchmetrics==0.11.4 + # via pytorch-lightning +torchvision==0.14.1+cpu + # via + # -r requirements/core.in + # timm +tornado==6.3.2 + # via + # ipykernel + # jupyter-client +tqdm==4.65.0 + # via + # -r requirements/test.in + # fvcore + # huggingface-hub + # iopath + # nltk + # pyro-ppl + # pytorch-lightning +traitlets==5.9.0 + # via + # comm + # ipykernel + # ipython + # ipywidgets + # jupyter-client + # jupyter-core + # matplotlib-inline +typeguard==2.13.3 + # via ax-platform +typing-extensions==4.6.2 + # via + # huggingface-hub + # iopath + # ipython + # lightning-utilities + # pytorch-lightning + # torch + # torchmetrics + # torchvision +tzdata==2023.3 + # via pandas +urllib3==2.0.2 + # via requests +virtualenv==20.23.0 + # via ray +wcwidth==0.2.6 + # via prompt-toolkit +widgetsnbextension==4.0.7 + # via ipywidgets +yacs==0.1.8 + # via fvcore +yarl==1.9.2 + # via aiohttp +zipp==3.15.0 + # via + # importlib-metadata + # importlib-resources diff --git a/requirements/core.in b/requirements/core.in new file mode 100644 index 0000000..8be3a85 --- /dev/null +++ b/requirements/core.in @@ -0,0 +1,8 @@ +-c ${CONSTRAINTS} + +torch >=1.10.0, <2.0.0 +torchvision >=0.11.0, <0.15.0 +timm >=0.6.5 +pytorch-lightning >=1.7.0, <2.0.0 # TODO: refactor code to separate model from training code. +nltk >=3.7.0 # TODO: refactor/reorganize code. This is a train/test dependency. +PyYAML >=6.0.0 # TODO: can we move this to train/test? 
diff --git a/requirements/core.txt b/requirements/core.txt new file mode 100644 index 0000000..170f4dd --- /dev/null +++ b/requirements/core.txt @@ -0,0 +1,35 @@ + + +aiohttp==3.8.4 +aiosignal==1.3.1 +async-timeout==4.0.2 +attrs==23.1.0 +certifi==2023.5.7 +charset-normalizer==3.1.0 +click==8.0.4 +filelock==3.12.0 +frozenlist==1.3.3 +fsspec[http]==2023.5.0 +huggingface-hub==0.15.1 +idna==3.4 +joblib==1.2.0 +lightning-utilities==0.8.0 +multidict==6.0.4 +nltk==3.8.1 +numpy==1.24.3 +packaging==23.1 +pillow==9.5.0 +pytorch-lightning==1.9.5 +pyyaml==6.0 +regex==2023.5.5 +requests==2.31.0 +safetensors==0.3.1 +timm==0.9.2 +#torch==1.13.1+cpu +#torchmetrics==0.11.4 +torchmetrics +#torchvision==0.14.1+cpu +tqdm==4.65.0 +typing-extensions==4.6.2 +urllib3==2.0.2 +yarl==1.9.2 diff --git a/requirements/test.in b/requirements/test.in new file mode 100644 index 0000000..b4e843e --- /dev/null +++ b/requirements/test.in @@ -0,0 +1,5 @@ +-c ${CONSTRAINTS} + +lmdb >=1.3.0 +Pillow >=9.2.0 +tqdm >=4.64.0 diff --git a/requirements/test.txt b/requirements/test.txt new file mode 100644 index 0000000..76e591e --- /dev/null +++ b/requirements/test.txt @@ -0,0 +1,3 @@ +lmdb==1.4.1 +pillow==9.5.0 +tqdm==4.65.0 diff --git a/requirements/train.in b/requirements/train.in new file mode 100644 index 0000000..41ad6d2 --- /dev/null +++ b/requirements/train.in @@ -0,0 +1,6 @@ +-c ${CONSTRAINTS} + +lmdb >=1.3.0 +Pillow >=9.2.0 +imgaug >=0.4.0 +hydra-core >=1.2.0 diff --git a/requirements/train.txt b/requirements/train.txt new file mode 100644 index 0000000..49f08bd --- /dev/null +++ b/requirements/train.txt @@ -0,0 +1,57 @@ +tensorboard +antlr4-python3-runtime +#==4.9.3 +contourpy +#==1.0.7 +cycler +#==0.11.0 +fonttools +#==4.39.4 +hydra-core +#==1.3.2 +imageio +#==2.30.0 +imgaug +#==0.4.0 +importlib-resources +#==5.12.0 +kiwisolver +#==1.4.4 +lazy-loader +#==0.2 +lmdb +#==1.4.1 +matplotlib +#==3.7.1 +networkx +#==3.1 +numpy +#==1.24.3 +omegaconf +#==2.3.0 +opencv-python +#==4.7.0.72 +packaging +#==23.1 +pillow +#==9.5.0 +pyparsing +#==3.0.9 +python-dateutil +#==2.8.2 +pywavelets +#==1.4.1 +pyyaml +#==6.0 +scikit-image +#==0.20.0 +scipy +#==1.9.1 +shapely +#==2.0.1 +six +#==1.16.0 +tifffile +#==2023.4.12 +zipp +#==3.15.0 diff --git a/requirements/tune.in b/requirements/tune.in new file mode 100644 index 0000000..80a0ec0 --- /dev/null +++ b/requirements/tune.in @@ -0,0 +1,8 @@ +-c ${CONSTRAINTS} + +lmdb >=1.3.0 +Pillow >=9.2.0 +imgaug >=0.4.0 +hydra-core >=1.2.0 +ray[tune] >=1.13.0, <2.0.0 +ax-platform >=0.2.5.1 diff --git a/requirements/tune.txt b/requirements/tune.txt new file mode 100644 index 0000000..6f183ab --- /dev/null +++ b/requirements/tune.txt @@ -0,0 +1,101 @@ +aiosignal==1.3.1 +antlr4-python3-runtime==4.9.3 +asttokens==2.2.1 +attrs==23.1.0 +ax-platform==0.3.2 +backcall==0.2.0 +botorch==0.8.5 +certifi==2023.5.7 +charset-normalizer==3.1.0 +click==8.0.4 +comm==0.1.3 +contourpy==1.0.7 +cycler==0.11.0 +debugpy==1.6.7 +decorator==5.1.1 +distlib==0.3.6 +executing==1.2.0 +filelock==3.12.0 +fonttools==4.39.4 +frozenlist==1.3.3 +gpytorch==1.10 +grpcio==1.43.0 +hydra-core==1.3.2 +idna==3.4 +imageio==2.30.0 +imgaug==0.4.0 +importlib-metadata==6.6.0 +importlib-resources==5.12.0 +ipykernel==6.23.1 +ipython==8.12.2 +ipywidgets==8.0.6 +jedi==0.18.2 +jinja2==3.1.2 +joblib==1.2.0 +jsonschema==4.17.3 +jupyter-client==8.2.0 +jupyter-core==5.3.0 +jupyterlab-widgets==3.0.7 +kiwisolver==1.4.4 +lazy-loader==0.2 +linear-operator==0.4.0 +lmdb==1.4.1 +markupsafe==2.1.2 +matplotlib==3.7.1 +matplotlib-inline==0.1.6 +msgpack==1.0.5 
+multipledispatch==0.6.0 +nest-asyncio==1.5.6 +networkx==3.1 +numpy==1.24.3 +omegaconf==2.3.0 +opencv-python==4.7.0.72 +opt-einsum==3.3.0 +packaging==23.1 +pandas==2.0.2 +parso==0.8.3 +pexpect==4.8.0 +pickleshare==0.7.5 +pillow==9.5.0 +pkgutil-resolve-name==1.3.10 +platformdirs==3.5.1 +plotly==5.14.1 +prompt-toolkit==3.0.38 +protobuf==3.20.3 +psutil==5.9.5 +ptyprocess==0.7.0 +pure-eval==0.2.2 +pygments==2.15.1 +pyparsing==3.0.9 +pyro-api==0.1.2 +pyro-ppl==1.8.4 +pyrsistent==0.19.3 +python-dateutil==2.8.2 +pytz==2023.3 +pywavelets==1.4.1 +pyyaml==6.0 +pyzmq==25.1.0 +ray[tune]==1.13.0 +requests==2.31.0 +scikit-image==0.20.0 +scikit-learn==1.2.2 +scipy==1.9.1 +shapely==2.0.1 +six==1.16.0 +stack-data==0.6.2 +tabulate==0.9.0 +tenacity==8.2.2 +tensorboardx==2.6 +threadpoolctl==3.1.0 +tifffile==2023.4.12 +tornado==6.3.2 +tqdm==4.65.0 +traitlets==5.9.0 +typeguard==2.13.3 +typing-extensions==4.6.2 +tzdata==2023.3 +urllib3==2.0.2 +virtualenv==20.23.0 +wcwidth==0.2.6 +widgetsnbextension==4.0.7 +zipp==3.15.0 diff --git a/strhub/__init__.py b/strhub/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/strhub/__pycache__/__init__.cpython-311.pyc b/strhub/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000..649e0b6 Binary files /dev/null and b/strhub/__pycache__/__init__.cpython-311.pyc differ diff --git a/strhub/__pycache__/__init__.cpython-38.pyc b/strhub/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000..f0bb8af Binary files /dev/null and b/strhub/__pycache__/__init__.cpython-38.pyc differ diff --git a/strhub/__pycache__/__init__.cpython-39.pyc b/strhub/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000..3804979 Binary files /dev/null and b/strhub/__pycache__/__init__.cpython-39.pyc differ diff --git a/strhub/data/__init__.py b/strhub/data/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/strhub/data/__pycache__/__init__.cpython-311.pyc b/strhub/data/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000..6f6f0c2 Binary files /dev/null and b/strhub/data/__pycache__/__init__.cpython-311.pyc differ diff --git a/strhub/data/__pycache__/__init__.cpython-38.pyc b/strhub/data/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000..979e9fb Binary files /dev/null and b/strhub/data/__pycache__/__init__.cpython-38.pyc differ diff --git a/strhub/data/__pycache__/__init__.cpython-39.pyc b/strhub/data/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000..043d5cd Binary files /dev/null and b/strhub/data/__pycache__/__init__.cpython-39.pyc differ diff --git a/strhub/data/__pycache__/aa_overrides.cpython-38.pyc b/strhub/data/__pycache__/aa_overrides.cpython-38.pyc new file mode 100644 index 0000000..ad67ff6 Binary files /dev/null and b/strhub/data/__pycache__/aa_overrides.cpython-38.pyc differ diff --git a/strhub/data/__pycache__/aa_overrides.cpython-39.pyc b/strhub/data/__pycache__/aa_overrides.cpython-39.pyc new file mode 100644 index 0000000..b8fc243 Binary files /dev/null and b/strhub/data/__pycache__/aa_overrides.cpython-39.pyc differ diff --git a/strhub/data/__pycache__/augment.cpython-38.pyc b/strhub/data/__pycache__/augment.cpython-38.pyc new file mode 100644 index 0000000..3bfe403 Binary files /dev/null and b/strhub/data/__pycache__/augment.cpython-38.pyc differ diff --git a/strhub/data/__pycache__/augment.cpython-39.pyc b/strhub/data/__pycache__/augment.cpython-39.pyc new file mode 100644 index 0000000..fcccb99 Binary files /dev/null and 
b/strhub/data/__pycache__/augment.cpython-39.pyc differ diff --git a/strhub/data/__pycache__/dataset.cpython-311.pyc b/strhub/data/__pycache__/dataset.cpython-311.pyc new file mode 100644 index 0000000..aac968b Binary files /dev/null and b/strhub/data/__pycache__/dataset.cpython-311.pyc differ diff --git a/strhub/data/__pycache__/dataset.cpython-38.pyc b/strhub/data/__pycache__/dataset.cpython-38.pyc new file mode 100644 index 0000000..4fe6af3 Binary files /dev/null and b/strhub/data/__pycache__/dataset.cpython-38.pyc differ diff --git a/strhub/data/__pycache__/dataset.cpython-39.pyc b/strhub/data/__pycache__/dataset.cpython-39.pyc new file mode 100644 index 0000000..6f22aa7 Binary files /dev/null and b/strhub/data/__pycache__/dataset.cpython-39.pyc differ diff --git a/strhub/data/__pycache__/module.cpython-311.pyc b/strhub/data/__pycache__/module.cpython-311.pyc new file mode 100644 index 0000000..793293c Binary files /dev/null and b/strhub/data/__pycache__/module.cpython-311.pyc differ diff --git a/strhub/data/__pycache__/module.cpython-38.pyc b/strhub/data/__pycache__/module.cpython-38.pyc new file mode 100644 index 0000000..7f7a1e5 Binary files /dev/null and b/strhub/data/__pycache__/module.cpython-38.pyc differ diff --git a/strhub/data/__pycache__/module.cpython-39.pyc b/strhub/data/__pycache__/module.cpython-39.pyc new file mode 100644 index 0000000..0adddfd Binary files /dev/null and b/strhub/data/__pycache__/module.cpython-39.pyc differ diff --git a/strhub/data/__pycache__/utils.cpython-311.pyc b/strhub/data/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000..e602949 Binary files /dev/null and b/strhub/data/__pycache__/utils.cpython-311.pyc differ diff --git a/strhub/data/__pycache__/utils.cpython-38.pyc b/strhub/data/__pycache__/utils.cpython-38.pyc new file mode 100644 index 0000000..8629db4 Binary files /dev/null and b/strhub/data/__pycache__/utils.cpython-38.pyc differ diff --git a/strhub/data/__pycache__/utils.cpython-39.pyc b/strhub/data/__pycache__/utils.cpython-39.pyc new file mode 100644 index 0000000..9f1851e Binary files /dev/null and b/strhub/data/__pycache__/utils.cpython-39.pyc differ diff --git a/strhub/data/aa_overrides.py b/strhub/data/aa_overrides.py new file mode 100644 index 0000000..7dcba71 --- /dev/null +++ b/strhub/data/aa_overrides.py @@ -0,0 +1,46 @@ +# Scene Text Recognition Model Hub +# Copyright 2022 Darwin Bautista +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +"""Extends default ops to accept optional parameters.""" +from functools import partial + +from timm.data.auto_augment import _LEVEL_DENOM, _randomly_negate, LEVEL_TO_ARG, NAME_TO_OP, rotate + + +def rotate_expand(img, degrees, **kwargs): + """Rotate operation with expand=True to avoid cutting off the characters""" + kwargs['expand'] = True + return rotate(img, degrees, **kwargs) + + +def _level_to_arg(level, hparams, key, default): + magnitude = hparams.get(key, default) + level = (level / _LEVEL_DENOM) * magnitude + level = _randomly_negate(level) + return level, + + +def apply(): + # Overrides + NAME_TO_OP.update({ + 'Rotate': rotate_expand + }) + LEVEL_TO_ARG.update({ + 'Rotate': partial(_level_to_arg, key='rotate_deg', default=30.), + 'ShearX': partial(_level_to_arg, key='shear_x_pct', default=0.3), + 'ShearY': partial(_level_to_arg, key='shear_y_pct', default=0.3), + 'TranslateXRel': partial(_level_to_arg, key='translate_x_pct', default=0.45), + 'TranslateYRel': partial(_level_to_arg, key='translate_y_pct', default=0.45), + }) diff --git a/strhub/data/augment.py b/strhub/data/augment.py new file mode 100644 index 0000000..21dc9d2 --- /dev/null +++ b/strhub/data/augment.py @@ -0,0 +1,111 @@ +# Scene Text Recognition Model Hub +# Copyright 2022 Darwin Bautista +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from functools import partial + +import imgaug.augmenters as iaa +import numpy as np +from PIL import ImageFilter, Image +from timm.data import auto_augment + +from strhub.data import aa_overrides + +aa_overrides.apply() + +_OP_CACHE = {} + + +def _get_op(key, factory): + try: + op = _OP_CACHE[key] + except KeyError: + op = factory() + _OP_CACHE[key] = op + return op + + +def _get_param(level, img, max_dim_factor, min_level=1): + max_level = max(min_level, max_dim_factor * max(img.size)) + return round(min(level, max_level)) + + +def gaussian_blur(img, radius, **__): + radius = _get_param(radius, img, 0.02) + key = 'gaussian_blur_' + str(radius) + op = _get_op(key, lambda: ImageFilter.GaussianBlur(radius)) + return img.filter(op) + + +def motion_blur(img, k, **__): + k = _get_param(k, img, 0.08, 3) | 1 # bin to odd values + key = 'motion_blur_' + str(k) + op = _get_op(key, lambda: iaa.MotionBlur(k)) + return Image.fromarray(op(image=np.asarray(img))) + + +def gaussian_noise(img, scale, **_): + scale = _get_param(scale, img, 0.25) | 1 # bin to odd values + key = 'gaussian_noise_' + str(scale) + op = _get_op(key, lambda: iaa.AdditiveGaussianNoise(scale=scale)) + return Image.fromarray(op(image=np.asarray(img))) + + +def poisson_noise(img, lam, **_): + lam = _get_param(lam, img, 0.2) | 1 # bin to odd values + key = 'poisson_noise_' + str(lam) + op = _get_op(key, lambda: iaa.AdditivePoissonNoise(lam)) + return Image.fromarray(op(image=np.asarray(img))) + + +def _level_to_arg(level, _hparams, max): + level = max * level / auto_augment._LEVEL_DENOM + return level, + + +_RAND_TRANSFORMS = auto_augment._RAND_INCREASING_TRANSFORMS.copy() +_RAND_TRANSFORMS.remove('SharpnessIncreasing') # remove, interferes with *blur ops +_RAND_TRANSFORMS.extend([ + 'GaussianBlur', + # 'MotionBlur', + # 'GaussianNoise', + 'PoissonNoise' +]) +auto_augment.LEVEL_TO_ARG.update({ + 'GaussianBlur': partial(_level_to_arg, max=4), + 'MotionBlur': partial(_level_to_arg, max=20), + 'GaussianNoise': partial(_level_to_arg, max=0.1 * 255), + 'PoissonNoise': partial(_level_to_arg, max=40) +}) +auto_augment.NAME_TO_OP.update({ + 'GaussianBlur': gaussian_blur, + 'MotionBlur': motion_blur, + 'GaussianNoise': gaussian_noise, + 'PoissonNoise': poisson_noise +}) + + +def rand_augment_transform(magnitude=5, num_layers=3): + # These are tuned for magnitude=5, which means that effective magnitudes are half of these values. + hparams = { + 'rotate_deg': 30, + 'shear_x_pct': 0.9, + 'shear_y_pct': 0.2, + 'translate_x_pct': 0.10, + 'translate_y_pct': 0.30 + } + ra_ops = auto_augment.rand_augment_ops(magnitude, hparams=hparams, transforms=_RAND_TRANSFORMS) + # Supply weights to disable replacement in random selection (i.e. avoid applying the same op twice) + choice_weights = [1. / len(ra_ops) for _ in range(len(ra_ops))] + return auto_augment.RandAugment(ra_ops, num_layers, choice_weights) diff --git a/strhub/data/dataset.py b/strhub/data/dataset.py new file mode 100644 index 0000000..9abdcc5 --- /dev/null +++ b/strhub/data/dataset.py @@ -0,0 +1,211 @@ +# Scene Text Recognition Model Hub +# Copyright 2022 Darwin Bautista +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import glob +import cv2 +import os +import io +import numpy as np +import logging +import unicodedata +import torch +from scipy.cluster.vq import * +from pylab import * +from pathlib import Path, PurePath +from typing import Callable, Optional, Union + +import lmdb +import math +import json +from PIL import Image +from torchvision import transforms as T +from torch.utils.data import Dataset, ConcatDataset + +from strhub.data.utils import CharsetAdapter + +log = logging.getLogger(__name__) + + +def build_tree_dataset(root: Union[PurePath, str],*args, **kwargs): + try: + kwargs.pop('root') # prevent 'root' from being passed via kwargs + except KeyError: + pass + root = Path(root).absolute() + log.info(f'dataset root:\t{root}') + datasets = [] + for mdb in glob.glob(str(root / '**/data.mdb'), recursive=True): + mdb = Path(mdb) + ds_name = str(mdb.parent.relative_to(root)) + ds_root = str(mdb.parent.absolute()) + dataset = LmdbDataset(ds_root, *args, **kwargs) + log.info(f'\tlmdb:\t{ds_name}\tnum samples: {len(dataset)}') + datasets.append(dataset) + return ConcatDataset(datasets) + +def lmdb_unsupervised_tree_dataset(roots, *args, **kwargs): + if not isinstance(roots, list): + roots = [roots] + datasets = [] + for root in roots: + root = Path(root).absolute() + log.info(f'dataset root:\t{root}') + for mdb in glob.glob(str(root / '**/data.mdb'), recursive=True): + mdb = Path(mdb) + ds_name = str(mdb.parent.relative_to(root)) + ds_root = str(mdb.parent.absolute()) + dataset = UnSupervised_LmdbDataset(ds_root, *args, **kwargs) + log.info(f'\tlmdb:\t{ds_name}\tnum samples: {len(dataset)}') + datasets.append(dataset) + return ConcatDataset(datasets) + + +def build_json_train_dataset(roots,data_path, *args, **kwargs): + if not isinstance(roots, list): + roots = [roots] + datasets = [] + for root in roots: + root = Path(root).absolute() + log.info(f'dataset root:\t{root}') + for mdb in glob.glob(str(root / '**.json'), recursive=True): + mdb = Path(mdb) + ds_name = str(mdb).split("\\")[-1] + dataset = JSonDataset(data_path, mdb, *args, **kwargs) + log.info(f'\tlmdb:\t{ds_name}\tnum samples: {len(dataset)}') + datasets.append(dataset) + return ConcatDataset(datasets) + + +class LmdbDataset(Dataset): + """Dataset interface to an LMDB database. + + It supports both labelled and unlabelled datasets. For unlabelled datasets, the image index itself is returned + as the label. Unicode characters are normalized by default. Case-sensitivity is inferred from the charset. + Labels are transformed according to the charset. 
+ """ + def __init__(self, root: str, charset: str, max_label_len: int, min_image_dim: int = 0, + remove_whitespace: bool = True, normalize_unicode: bool = True, + unlabelled: bool = False, mask = False, transform: Optional[Callable] = None, train=True, img_size=[32, 128], transpose: bool = True, trans_aug: bool = False, p: float = 0.0): + self.mask = mask + self._env = None + self.root = root + self.unlabelled = unlabelled + self.transform = transform if train == True else None + self.labels = [] + self.train = train + self.img_w, self.img_h = img_size[1], img_size[0] + self.to_tensor = T.Compose([T.ToTensor(),T.Normalize(0.5, 0.5)]) + self.filtered_index_list = [] + self.num_samples = self._preprocess_labels(charset, remove_whitespace, normalize_unicode, + max_label_len, min_image_dim) + + def __del__(self): + if self._env is not None: + self._env.close() + self._env = None + + def _create_env(self): + return lmdb.open(self.root, max_readers=1, readonly=True, create=False, + readahead=False, meminit=False, lock=False) + + def resize_multiscales(self, img, borderType=cv2.BORDER_CONSTANT): + def _resize_ratio(img, ratio, fix_h=True): + if ratio * self.img_w < self.img_h: + if fix_h: + trg_h = self.img_h + else: + trg_h = int(ratio * self.img_w) + trg_w = self.img_w + else: + trg_h, trg_w = self.img_h, int(self.img_h / ratio) + img = cv2.resize(img, (trg_w, trg_h)) + pad_h, pad_w = (self.img_h - trg_h) / 2, (self.img_w - trg_w) / 2 + top, bottom = math.ceil(pad_h), math.floor(pad_h) + left, right = math.ceil(pad_w), math.floor(pad_w) + img = cv2.copyMakeBorder(img, top, bottom, left, right, borderType) + return Image.fromarray(img) + return img.resize((self.img_w,self.img_h)) + + def resize(self, img): + self.resize_multiscales(img, cv2.BORDER_REPLICATE) + + def _process_training(self, image): + if self.transform != None: image = self.transform(image) + return self.resize_multiscales(image, cv2.BORDER_REPLICATE) + + def _process_test(self, image): + return self.resize_multiscales(image, cv2.BORDER_REPLICATE) + + @property + def env(self): + if self._env is None: + self._env = self._create_env() + return self._env + + def _preprocess_labels(self, charset, remove_whitespace, normalize_unicode, max_label_len, min_image_dim): + charset_adapter = CharsetAdapter(charset) + with self._create_env() as env, env.begin() as txn: + num_samples = int(txn.get('num-samples'.encode())) + if self.unlabelled: + return num_samples + for index in range(num_samples): + index += 1 # lmdb starts with 1 + label_key = f'label-{index:09d}'.encode() + label = txn.get(label_key).decode() + # Normally, whitespace is removed from the labels. + if remove_whitespace: + label = ''.join(label.split()) + # Normalize unicode composites (if any) and convert to compatible ASCII characters + if normalize_unicode: + label = unicodedata.normalize('NFKD', label).encode('ascii', 'ignore').decode() + # Filter by length before removing unsupported characters. The original label might be too long. + if len(label) > max_label_len: + continue + label = charset_adapter(label) + # We filter out samples which don't contain any supported characters + if not label: + continue + # Filter images that are too small. 
+ if min_image_dim > 0: + img_key = f'image-{index:09d}'.encode() + buf = io.BytesIO(txn.get(img_key)) + w, h = Image.open(buf).size + if w < self.min_image_dim or h < self.min_image_dim: + continue + self.labels.append(label) + self.filtered_index_list.append(index) + return len(self.labels) + + def __len__(self): + return self.num_samples + + def __getitem__(self, index): + if self.unlabelled: + label = index + else: + label = self.labels[index] + index = self.filtered_index_list[index] + + img_key = f'image-{index:09d}'.encode() + with self.env.begin() as txn: + imgbuf = txn.get(img_key) + buf = io.BytesIO(imgbuf) + img = Image.open(buf).convert('RGB') + + if self.train: + img = self._process_training(img) + else: + img = self._process_test(img) + img = self.to_tensor(img) + return img, label \ No newline at end of file diff --git a/strhub/data/module.py b/strhub/data/module.py new file mode 100644 index 0000000..790b812 --- /dev/null +++ b/strhub/data/module.py @@ -0,0 +1,123 @@ +# Scene Text Recognition Model Hub +# Copyright 2022 Darwin Bautista +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import os +from pathlib import PurePath +from typing import Optional, Callable, Sequence, Tuple + +import pytorch_lightning as pl +from torch.utils.data import DataLoader +from torchvision import transforms as T + +from .dataset import build_tree_dataset, LmdbDataset + + +class SceneTextDataModule(pl.LightningDataModule): + TEST_BENCHMARK_SUB = ('IIIT5k', 'SVT', 'IC13_857', 'IC15_1811', 'SVTP', 'CUTE80') + TEST_BENCHMARK = ('IIIT5k', 'SVT', 'IC13_1015', 'IC15_2077', 'SVTP', 'CUTE80') + TEST_NEW = ('ArT', 'COCOv1.4', 'Uber') + TEST_ALL = tuple(set(TEST_BENCHMARK_SUB + TEST_BENCHMARK + TEST_NEW)) + + def __init__(self, root_dir: str, train_dir: str, img_size: Sequence[int], max_label_length: int, charset_train: str, charset_test: str, batch_size: int, num_workers: int, augment: bool, mask=False, remove_whitespace: bool = True, normalize_unicode: bool = True,min_image_dim: int = 0, rotation: int = 0,transpose: bool = False, trans_aug: bool = False, p: float=0.0, collate_fn: Optional[Callable] = None): #transpose: bool, + super().__init__() + self.p = p + self.mask = mask + self.root_dir = root_dir + self.train_dir = train_dir + self.img_size = tuple(img_size) + self.max_label_length = max_label_length + self.charset_train = charset_train + self.charset_test = charset_test + self.batch_size = batch_size + self.num_workers = num_workers + self.augment = augment + self.remove_whitespace = remove_whitespace + self.normalize_unicode = normalize_unicode + self.min_image_dim = min_image_dim + self.rotation = rotation + self.img_size = img_size + self.collate_fn = collate_fn + self.transpose = transpose + self.trans_aug = trans_aug + self._train_dataset = None + self._val_dataset = None + + @staticmethod + def get_transform(img_size: Tuple[int], augment: bool = False, rotation: int = 0): + transforms = [] + if augment: + from .augment import rand_augment_transform + 
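            # rand_augment_transform() (defined in strhub/data/augment.py above) returns a timm
            # RandAugment policy: 3 layers at magnitude 5, with 'SharpnessIncreasing' removed and
            # GaussianBlur/PoissonNoise added. Effective magnitudes are half the hparams values,
            # e.g. Rotate is sampled within roughly +/-15 degrees (rotate_deg=30).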
transforms.append(rand_augment_transform()) + if rotation: + transforms.append(lambda img: img.rotate(rotation, expand=True)) + ''' + transforms.extend([ + T.Resize(img_size, T.InterpolationMode.BICUBIC), + #T.ToTensor(), + #T.Normalize(0.5, 0.5) + ]) + ''' + return T.Compose(transforms) + + @property + def train_dataset(self): + if self._train_dataset is None: + transform = self.get_transform(self.img_size, self.augment) + if self.train_dir not in ("synth", "real"): + root = PurePath(self.root_dir, 'train', self.train_dir, "Union14M-L-lmdb", "train") + else: + root = PurePath(self.root_dir, 'train', self.train_dir) + ''' + if self.train_dir == "real": + root = PurePath(self.root_dir, 'train', self.train_dir) + else: + root = "../data/train/synth/" + ''' + self._train_dataset = build_tree_dataset(root, self.charset_train, self.max_label_length, + self.min_image_dim, self.remove_whitespace, self.normalize_unicode,mask=self.mask, + transform=transform, train=True, img_size=self.img_size) + return self._train_dataset + + @property + def val_dataset(self): + if self._val_dataset is None: + transform = self.get_transform(self.img_size) + if self.train_dir not in ("synth", "real"): + root = PurePath(self.root_dir, 'train', self.train_dir, "Union14M-L-lmdb", "val") + else: + root = PurePath(self.root_dir, 'val') + self._val_dataset = build_tree_dataset(root, self.charset_test, self.max_label_length, + self.min_image_dim, self.remove_whitespace, self.normalize_unicode,mask=self.mask, + transform=transform, train=False, img_size=self.img_size) + return self._val_dataset + + def train_dataloader(self): + return DataLoader(self.train_dataset, batch_size=self.batch_size, shuffle=True, + num_workers=self.num_workers, persistent_workers=self.num_workers > 0, + pin_memory=True, collate_fn=self.collate_fn) + + def val_dataloader(self): + return DataLoader(self.val_dataset, batch_size=self.batch_size, + num_workers=self.num_workers, persistent_workers=self.num_workers > 0, + pin_memory=True, collate_fn=self.collate_fn) + + def test_dataloaders(self, subset): + transform = self.get_transform(self.img_size, rotation=self.rotation) + root = PurePath(self.root_dir, 'test') + datasets = {s: LmdbDataset(os.path.join(root, s), self.charset_test, self.max_label_length, + self.min_image_dim, self.remove_whitespace, self.normalize_unicode,mask=self.mask, + transform=transform, train=False, img_size=self.img_size) for s in subset} + return {k: DataLoader(v, batch_size=self.batch_size, num_workers=self.num_workers, + pin_memory=True, collate_fn=self.collate_fn) + for k, v in datasets.items()} \ No newline at end of file diff --git a/strhub/data/utils.py b/strhub/data/utils.py new file mode 100644 index 0000000..85d27a1 --- /dev/null +++ b/strhub/data/utils.py @@ -0,0 +1,149 @@ +# Scene Text Recognition Model Hub +# Copyright 2022 Darwin Bautista +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
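For reference, a small usage sketch of the Tokenizer defined further down in this file, assuming the package is importable and the 36-character lowercase charset from the configs:

```python
import torch
from strhub.data.utils import Tokenizer

tok = Tokenizer("0123456789abcdefghijklmnopqrstuvwxyz")  # 36 chars plus [E], [B], [P] specials
ids = tok.encode(["hello", "hi"])        # shape (2, 7): [B] h e l l o [E]; "hi" is padded with [P]
dummy = torch.rand(2, 8, len(tok)).softmax(-1)   # stand-in for model output probabilities
preds, probs = tok.decode(dummy)         # greedy argmax per step, truncated at the first [E]
```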
+ +import re +import os +from abc import ABC, abstractmethod +from itertools import groupby +from typing import List, Optional, Tuple + +import torch +from torch import Tensor +from torch.nn.utils.rnn import pad_sequence + + +class CharsetAdapter: + """Transforms labels according to the target charset.""" + + def __init__(self, target_charset) -> None: + super().__init__() + self.lowercase_only = target_charset == target_charset.lower() + self.uppercase_only = target_charset == target_charset.upper() + self.unsupported = re.compile(f'[^{re.escape(target_charset)}]') + + def __call__(self, label): + if self.lowercase_only: + label = label.lower() + elif self.uppercase_only: + label = label.upper() + # Remove unsupported characters + label = self.unsupported.sub('', label) + return label + + +class BaseTokenizer(ABC): + + def __init__(self, charset: str, specials_first: tuple = (), specials_last: tuple = ()) -> None: + self._itos = specials_first + tuple(charset) + specials_last + self._stoi = {s: i for i, s in enumerate(self._itos)} + + def __len__(self): + return len(self._itos) + + def _tok2ids(self, tokens: str) -> List[int]: + return [self._stoi[s] for s in tokens] + + def _ids2tok(self, token_ids: List[int], join: bool = True) -> str: + tokens = [self._itos[i] for i in token_ids] + return ''.join(tokens) if join else tokens + + @abstractmethod + def encode(self, labels: List[str], device: Optional[torch.device] = None) -> Tensor: + """Encode a batch of labels to a representation suitable for the model. + + Args: + labels: List of labels. Each can be of arbitrary length. + device: Create tensor on this device. + + Returns: + Batched tensor representation padded to the max label length. Shape: N, L + """ + raise NotImplementedError + + @abstractmethod + def _filter(self, probs: Tensor, ids: Tensor) -> Tuple[Tensor, List[int]]: + """Internal method which performs the necessary filtering prior to decoding.""" + raise NotImplementedError + + def decode(self, token_dists: Tensor, raw: bool = False) -> Tuple[List[str], List[Tensor]]: + """Decode a batch of token distributions. + + Args: + token_dists: softmax probabilities over the token distribution. 
Shape: N, L, C + raw: return unprocessed labels (will return list of list of strings) + + Returns: + list of string labels (arbitrary length) and + their corresponding sequence probabilities as a list of Tensors + """ + batch_tokens = [] + batch_probs = [] + for dist in token_dists: + probs, ids = dist.max(-1) # greedy selection + if not raw: + probs, ids = self._filter(probs, ids) + tokens = self._ids2tok(ids, not raw) + batch_tokens.append(tokens) + batch_probs.append(probs) + return batch_tokens, batch_probs + + +class Tokenizer(BaseTokenizer): + BOS = '[B]' + EOS = '[E]' + PAD = '[P]' + + def __init__(self, charset: str) -> None: + specials_first = (self.EOS,) + specials_last = (self.BOS, self.PAD) + super().__init__(charset, specials_first, specials_last) + self.eos_id, self.bos_id, self.pad_id = [self._stoi[s] for s in specials_first + specials_last] + + def encode(self, labels: List[str], device: Optional[torch.device] = None) -> Tensor: + batch = [torch.as_tensor([self.bos_id] + self._tok2ids(y) + [self.eos_id], dtype=torch.long, device=device) + for y in labels] + return pad_sequence(batch, batch_first=True, padding_value=self.pad_id) + + def _filter(self, probs: Tensor, ids: Tensor) -> Tuple[Tensor, List[int]]: + ids = ids.tolist() + try: + eos_idx = ids.index(self.eos_id) + except ValueError: + eos_idx = len(ids) # Nothing to truncate. + # Truncate after EOS + ids = ids[:eos_idx] + probs = probs[:eos_idx + 1] # but include prob. for EOS (if it exists) + return probs, ids + + +class CTCTokenizer(BaseTokenizer): + BLANK = '[B]' + + def __init__(self, charset: str) -> None: + # BLANK uses index == 0 by default + super().__init__(charset, specials_first=(self.BLANK,)) + self.blank_id = self._stoi[self.BLANK] + + def encode(self, labels: List[str], device: Optional[torch.device] = None) -> Tensor: + # We use a padded representation since we don't want to use CUDNN's CTC implementation + batch = [torch.as_tensor(self._tok2ids(y), dtype=torch.long, device=device) for y in labels] + return pad_sequence(batch, batch_first=True, padding_value=self.blank_id) + + def _filter(self, probs: Tensor, ids: Tensor) -> Tuple[Tensor, List[int]]: + # Best path decoding: + ids = list(zip(*groupby(ids.tolist())))[0] # Remove duplicate tokens + ids = [x for x in ids if x != self.blank_id] # Remove BLANKs + # `probs` is just pass-through since all positions are considered part of the path + return probs, ids diff --git a/strhub/models/__init__.py b/strhub/models/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/strhub/models/__pycache__/__init__.cpython-38.pyc b/strhub/models/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000..5cc672b Binary files /dev/null and b/strhub/models/__pycache__/__init__.cpython-38.pyc differ diff --git a/strhub/models/__pycache__/base.cpython-38.pyc b/strhub/models/__pycache__/base.cpython-38.pyc new file mode 100644 index 0000000..001e4f5 Binary files /dev/null and b/strhub/models/__pycache__/base.cpython-38.pyc differ diff --git a/strhub/models/__pycache__/modules.cpython-38.pyc b/strhub/models/__pycache__/modules.cpython-38.pyc new file mode 100644 index 0000000..f7bbe14 Binary files /dev/null and b/strhub/models/__pycache__/modules.cpython-38.pyc differ diff --git a/strhub/models/__pycache__/utils.cpython-38.pyc b/strhub/models/__pycache__/utils.cpython-38.pyc new file mode 100644 index 0000000..a6f1e3a Binary files /dev/null and b/strhub/models/__pycache__/utils.cpython-38.pyc differ diff --git a/strhub/models/base.py 
b/strhub/models/base.py new file mode 100644 index 0000000..b3b6ee2 --- /dev/null +++ b/strhub/models/base.py @@ -0,0 +1,220 @@ +# Scene Text Recognition Model Hub +# Copyright 2022 Darwin Bautista +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +import PIL +import numpy as np +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Optional, Tuple, List + +import pytorch_lightning as pl +import torch +import torch.nn.functional as F +from nltk import edit_distance +from collections import defaultdict +from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT +from timm.optim import create_optimizer_v2 +from torch import Tensor +from torch.optim import Optimizer +from torch.optim.lr_scheduler import OneCycleLR + +from strhub.data.utils import CharsetAdapter, CTCTokenizer, Tokenizer, BaseTokenizer + + +@dataclass +class BatchResult: + num_samples: int + correct: int + ned: float + confidence: float + label_length: int + loss: Tensor + loss_numel: int + pred_str: List[str] + + +class BaseSystem(pl.LightningModule, ABC): + + def __init__(self, tokenizer: BaseTokenizer, charset_test: str, + batch_size: int, lr: float, warmup_pct: float, weight_decay: float) -> None: + super().__init__() + self.tokenizer = tokenizer + self.charset_adapter = CharsetAdapter(charset_test) + self.batch_size = batch_size + self.lr = lr + self.warmup_pct = warmup_pct + self.weight_decay = weight_decay + + @abstractmethod + def forward(self, images: Tensor, max_length: Optional[int] = None) -> Tensor: + """Inference + + Args: + images: Batch of images. Shape: N, Ch, H, W + max_length: Max sequence length of the output. If None, will use default. + + Returns: + logits: N, L, C (L = sequence length, C = number of classes, typically len(charset_train) + num specials) + """ + raise NotImplementedError + + @abstractmethod + def forward_logits_loss(self, images: Tensor, labels: List[str]) -> Tuple[Tensor, Tensor, int]: + """Like forward(), but also computes the loss (calls forward() internally). + + Args: + images: Batch of images. Shape: N, Ch, H, W + labels: Text labels of the images + + Returns: + logits: N, L, C (L = sequence length, C = number of classes, typically len(charset_train) + num specials) + loss: mean loss for the batch + loss_numel: number of elements the loss was calculated from + """ + raise NotImplementedError + + def configure_optimizers(self): + agb = self.trainer.accumulate_grad_batches + # Linear scaling so that the effective learning rate is constant regardless of the number of GPUs used with DDP. + lr_scale = agb * math.sqrt(self.trainer.num_devices) * self.batch_size / 256. 
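        # For example, with the defaults in configs/main.yaml (batch_size=384, lr=5e-4,
        # trainer.devices=4) and accumulate_grad_batches left at 1:
        #   lr_scale = 1 * sqrt(4) * 384 / 256 = 3.0, so the effective lr becomes 1.5e-3.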
+ lr = lr_scale * self.lr + optim = create_optimizer_v2(self, 'adamw', lr, self.weight_decay) + sched = OneCycleLR(optim, lr, self.trainer.estimated_stepping_batches, pct_start=self.warmup_pct, + cycle_momentum=False) + return {'optimizer': optim, 'lr_scheduler': {'scheduler': sched, 'interval': 'step'}} + + def optimizer_zero_grad(self, epoch: int, batch_idx: int, optimizer: Optimizer, optimizer_idx: int): + optimizer.zero_grad(set_to_none=True) + + def _eval_step(self, batch, validation: bool) -> Optional[STEP_OUTPUT]: + if len(batch) > 2: + images, labels, tgt_mask = batch + else: + images, labels = batch + tgt_mask = None + correct = 0 + total = 0 + ned = 0 + confidence = 0 + label_length = 0 + pred_str = [] + if validation: + logits, loss, loss_numel = self.forward_logits_loss(images, labels, tgt_mask) + else: + # At test-time, we shouldn't specify a max_label_length because the test-time charset used + # might be different from the train-time charset. max_label_length in eval_logits_loss() is computed + # based on the transformed label, which could be wrong if the actual gt label contains characters existing + # in the train-time charset but not in the test-time charset. For example, "aishahaleyes.blogspot.com" + # is exactly 25 characters, but if processed by CharsetAdapter for the 36-char set, it becomes 23 characters + # long only, which sets max_label_length = 23. This will cause the model prediction to be truncated. + logits = self.forward(images)[0] + loss = loss_numel = None # Only used for validation; not needed at test-time. + probs = logits.softmax(-1) + preds, probs = self.tokenizer.decode(probs) + + if True: + for pred, prob, gt in zip(preds, probs, labels): + confidence += prob.prod().item() + pred = self.charset_adapter(pred) + # Follow ICDAR 2019 definition of N.E.D. + ned += edit_distance(pred, gt) / max(len(pred), len(gt)) + if pred == gt: + correct += 1 + pred_str.append(pred) + total += 1 + label_length += len(pred) + + return dict(output=BatchResult(total, correct, ned, confidence, label_length, loss, loss_numel, pred_str)) + + @staticmethod + def _aggregate_results(outputs: EPOCH_OUTPUT) -> Tuple[float, float, float]: + if not outputs: + return 0., 0., 0. 
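        # Micro-average over the epoch: word accuracy and NED are normalized by the total number of
        # samples, while the loss is re-weighted by loss_numel (the tokens behind each batch loss).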
+        total_loss = 0
+        total_loss_numel = 0
+        total_n_correct = 0
+        total_norm_ED = 0
+        total_size = 0
+        for result in outputs:
+            result = result['output']
+            total_loss += result.loss_numel * result.loss
+            total_loss_numel += result.loss_numel
+            total_n_correct += result.correct
+            total_norm_ED += result.ned
+            total_size += result.num_samples
+        acc = total_n_correct / total_size
+        ned = (1 - total_norm_ED / total_size)
+        loss = total_loss / total_loss_numel
+        return acc, ned, loss
+
+    def validation_step(self, batch, batch_idx) -> Optional[STEP_OUTPUT]:
+        return self._eval_step(batch, True)
+
+    def validation_epoch_end(self, outputs: EPOCH_OUTPUT) -> None:
+        acc, ned, loss = self._aggregate_results(outputs)
+        self.log('val_accuracy', 100 * acc, sync_dist=True)
+        self.log('val_NED', 100 * ned, sync_dist=True)
+        self.log('val_loss', loss, sync_dist=True)
+        self.log('hp_metric', acc, sync_dist=True)
+
+    def test_step(self, batch, batch_idx) -> Optional[STEP_OUTPUT]:
+        return self._eval_step(batch, False)
+
+
+class CrossEntropySystem(BaseSystem):
+
+    def __init__(self, charset_train: str, charset_test: str,
+                 batch_size: int, lr: float, warmup_pct: float, weight_decay: float) -> None:
+        tokenizer = Tokenizer(charset_train)
+        super().__init__(tokenizer, charset_test, batch_size, lr, warmup_pct, weight_decay)
+        self.bos_id = tokenizer.bos_id
+        self.eos_id = tokenizer.eos_id
+        self.pad_id = tokenizer.pad_id
+
+    def forward_logits_loss(self, images: Tensor, labels: List[str], tgt_mask=None) -> Tuple[Tensor, Tensor, int]:
+        targets = self.tokenizer.encode(labels, self.device)
+        targets = targets[:, 1:]  # Discard <bos>
+        max_len = targets.shape[1] - 1  # exclude <eos> from count
+        result = self.forward(images, max_len)
+        loss = 0
+
+        loss += F.cross_entropy(result[0].flatten(end_dim=1), targets.flatten(), ignore_index=self.pad_id)
+
+        if result[1] is not None and tgt_mask is not None:
+            loss += F.binary_cross_entropy_with_logits(result[1].flatten(), tgt_mask.flatten())
+        loss_numel = (targets != self.pad_id).sum()  # number of valid (non-pad) target tokens
+        return result[0], loss, loss_numel
+
+
+class CTCSystem(BaseSystem):
+
+    def __init__(self, charset_train: str, charset_test: str,
+                 batch_size: int, lr: float, warmup_pct: float, weight_decay: float) -> None:
+        tokenizer = CTCTokenizer(charset_train)
+        super().__init__(tokenizer, charset_test, batch_size, lr, warmup_pct, weight_decay)
+        self.blank_id = tokenizer.blank_id
+
+    def forward_logits_loss(self, images: Tensor, labels: List[str], tgt_mask=None) -> Tuple[Tensor, Tensor, int]:
+        targets = self.tokenizer.encode(labels, self.device)
+        logits, mask = self.forward(images)
+        log_probs = logits.log_softmax(-1).transpose(0, 1)  # swap batch and seq.
dims + T, N, _ = log_probs.shape + input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long, device=self.device) + target_lengths = torch.as_tensor(list(map(len, labels)), dtype=torch.long, device=self.device) + loss = F.ctc_loss(log_probs, targets, input_lengths, target_lengths, blank=self.blank_id, zero_infinity=True) + if mask is not None and tgt_mask is not None: + loss += F.binary_cross_entropy_with_logits(mask.flatten(), tgt_mask.flatten()) + return logits, loss, N diff --git a/strhub/models/cfe/__init__.py b/strhub/models/cfe/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/strhub/models/cfe/__pycache__/__init__.cpython-38.pyc b/strhub/models/cfe/__pycache__/__init__.cpython-38.pyc new file mode 100644 index 0000000..ba330fd Binary files /dev/null and b/strhub/models/cfe/__pycache__/__init__.cpython-38.pyc differ diff --git a/strhub/models/cfe/__pycache__/__init__.cpython-39.pyc b/strhub/models/cfe/__pycache__/__init__.cpython-39.pyc new file mode 100644 index 0000000..c6053c9 Binary files /dev/null and b/strhub/models/cfe/__pycache__/__init__.cpython-39.pyc differ diff --git a/strhub/models/cfe/__pycache__/modules.cpython-38.pyc b/strhub/models/cfe/__pycache__/modules.cpython-38.pyc new file mode 100644 index 0000000..f6ac8a1 Binary files /dev/null and b/strhub/models/cfe/__pycache__/modules.cpython-38.pyc differ diff --git a/strhub/models/cfe/__pycache__/modules.cpython-39.pyc b/strhub/models/cfe/__pycache__/modules.cpython-39.pyc new file mode 100644 index 0000000..d16fe5b Binary files /dev/null and b/strhub/models/cfe/__pycache__/modules.cpython-39.pyc differ diff --git a/strhub/models/cfe/__pycache__/system.cpython-38.pyc b/strhub/models/cfe/__pycache__/system.cpython-38.pyc new file mode 100644 index 0000000..1f71603 Binary files /dev/null and b/strhub/models/cfe/__pycache__/system.cpython-38.pyc differ diff --git a/strhub/models/cfe/__pycache__/system.cpython-39.pyc b/strhub/models/cfe/__pycache__/system.cpython-39.pyc new file mode 100644 index 0000000..6b94cb6 Binary files /dev/null and b/strhub/models/cfe/__pycache__/system.cpython-39.pyc differ diff --git a/strhub/models/cfe/modules.py b/strhub/models/cfe/modules.py new file mode 100644 index 0000000..d80a684 --- /dev/null +++ b/strhub/models/cfe/modules.py @@ -0,0 +1,707 @@ +# Copyright (c) OpenMMLab. All rights reserved. +from typing import Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from timm.models.layers import DropPath +from functools import partial +from strhub.models.utils import init_weights +from timm.models.helpers import named_apply + +def ConvBNLayer(inchanns, outchanns, kernel, stride, padding=0, activation="relu"): + return nn.Sequential( + *[ + nn.Conv2d(inchanns, outchanns, kernel, stride, padding), + nn.BatchNorm2d(outchanns), + nn.ReLU(inplace=True) if activation == "relu" else nn.GELU() + ] + ) + +def rotate_every_two(x): + x1 = x[:, :, :, ::2] + x2 = x[:, :, :, 1::2] + x = torch.stack((-x2, x1), dim=-1) + return x.flatten(-2) # in einsum notation: rearrange(x, '... d j -> ... 
(d j)')
+
+def theta_shift(x, sin, cos):
+    return (x * cos) + (rotate_every_two(x) * sin)
+
+
+class RetNetRelPos(nn.Module):
+    # The code is modified based on the paper: https://arxiv.org/abs/2307.08621
+    def __init__(self, sigma, head_dim, input_shape, local_k, mixer, local_type="r2"):
+        super().__init__()
+        self.h, self.w = input_shape
+        self.hk, self.wk = local_k
+        self.mixer = mixer
+        self.sigma = sigma
+        self.local_type = local_type
+        angle = 1.0 / (10000 ** torch.linspace(0, 1, head_dim // 2))  # dims of each head // 2
+        angle = angle.unsqueeze(-1).repeat(1, 2).flatten()  # 16 * 2 => 32
+        d = torch.log(1 - 2 ** (-5 - torch.arange(sigma, dtype=torch.float)))  # num_heads
+        #decay = torch.tensor(sigma, dtype=torch.float) # sigma
+        self.register_buffer("angle", angle)
+        self.register_buffer("d", d)
+        height, width = self.h, self.w
+        index = torch.arange(height * width)
+        sin = torch.sin(index[:, None] * self.angle[None, :])
+        cos = torch.cos(index[:, None] * self.angle[None, :])
+        self.register_buffer("sin", sin)
+        self.register_buffer("cos", cos)
+        if self.mixer == 'Local':
+            if self.local_type == "r1":
+                self.r1(height, width)
+            elif self.local_type == "r2":
+                self.r2(height, width)
+            elif self.local_type == "r3":
+                self.r3(height, width)
+            else:
+                print("the local type must be one of r1, r2, r3")
+        else:  # "global"
+            mask = torch.ones(
+                [self.sigma, height * width, height, width],
+                dtype=torch.float)
+            mask = mask.flatten(2)
+            mask = mask.unsqueeze(0)
+            decay = torch.zeros((self.sigma, height * width, height * width)).type_as(self.d).unsqueeze(0)
+            self.register_buffer("decay_matrix", mask)
+            self.register_buffer("mask", decay)
+
+    def r1(self, height, width):
+        mask = torch.zeros(
+            [self.sigma, height * width, height, width],
+            dtype=torch.float)
+        for h in range(0, height):
+            for w in range(0, width):
+                i = h - torch.arange(0, height).type_as(mask).unsqueeze(-1)
+                j = w - torch.arange(0, width).type_as(mask).unsqueeze(0)
+                i_j = torch.abs(i) + torch.abs(j)
+                decay_hw = torch.exp(self.d[:, None, None] * i_j[None, :, :])
+                mask[:, h * width + w, :, :] = decay_hw
+        mask = mask.flatten(2)
+        mask = mask.unsqueeze(0)
+        decay = torch.zeros((self.sigma, height * width, height * width)).type_as(self.d).unsqueeze(0)
+        self.register_buffer("decay_matrix", mask)
+        self.register_buffer("mask", decay)
+
+    def r2(self, height, width):
+        mask = torch.zeros(
+            [self.sigma, height * width, height, width],
+            dtype=torch.float)
+        for h in range(0, height):
+            for w in range(0, width):
+                i = h - torch.arange(0, height).type_as(mask).unsqueeze(-1).repeat(1, width)
+                j = w - torch.arange(0, width).type_as(mask).unsqueeze(0).repeat(height, 1)
+                i, j = torch.abs(i), torch.abs(j)
+                i[i < j] = j[i < j]
+                decay_hw = torch.exp(self.d[:, None, None] * i[None, :, :])
+                mask[:, h * width + w, :, :] = decay_hw
+        mask = mask.flatten(2)
+        mask = mask.unsqueeze(0)
+        decay = torch.zeros((self.sigma, height * width, height * width)).type_as(self.d).unsqueeze(0)
+        self.register_buffer("decay_matrix", mask)
+        self.register_buffer("mask", decay)
+
+    def r3(self, height, width):
+        hk, wk = self.hk, self.wk
+        mask = torch.zeros(
+            [self.sigma, height * width, height + hk - 1, width + wk - 1],
+            dtype=torch.float)
+        for h in range(0, height):
+            for w in range(0, width):
+                mask[:, h * width + w, h:h + hk, w:w + wk] = 1.0
+        mask = mask[:, :, hk // 2:height + hk // 2,
+                    wk // 2:width + wk // 2].flatten(2)
+        mask = mask.unsqueeze(0)
+        decay = torch.masked_fill(torch.zeros((self.sigma, height * width, height * width)).type_as(self.d),
mask == -1.0, float("-inf")) + decay = decay.unsqueeze(0) + self.register_buffer("decay_matrix", mask) + self.register_buffer("mask", decay) + +class OverlapPatchEmbed(nn.Module): + """Image to the progressive overlapping Patch Embedding. + + Args: + in_channels (int): Number of input channels. Defaults to 3. + embed_dims (int): The dimensions of embedding. Defaults to 768. + num_layers (int, optional): Number of Conv_BN_Layer. Defaults to 2 and + limit to [2, 3]. + init_cfg (dict or list[dict], optional): Initialization configs. + Defaults to None. + """ + def __init__(self, + in_channels: int = 3, + embed_dims: int = 768, + num_layers: int = 2): + super().__init__() + assert num_layers in [2, 3], \ + 'The number of layers must belong to [2, 3]' + self.net = nn.Sequential() + for num in range(num_layers, 0, -1): + if (num == num_layers): + _input = in_channels + _output = embed_dims // (2**(num - 1)) + self.net.add_module( + f'ConvBNLayer{str(num_layers - num)}', + ConvBNLayer( + inchanns=_input, + outchanns=_output, + kernel=3, + stride=2, + padding=1, + activation="gelu")) + _input = _output + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward function. + Args: + x (Tensor): A Tensor of shape :math:`(N, C, H, W)`. + Returns: + Tensor: A tensor of shape math:`(N, HW//16, C)`. + """ + x = self.net(x).flatten(2).permute(0, 2, 1) + return x + + +class ConvMixer(nn.Module): + """The conv Mixer. + Args: + embed_dims (int): Number of character components. + num_heads (int, optional): Number of heads. Defaults to 8. + input_shape (Tuple[int, int], optional): The shape of input [H, W]. + Defaults to [8, 32]. + local_k (Tuple[int, int], optional): Window size. Defaults to [3, 3]. + init_cfg (dict or list[dict], optional): Initialization configs. + Defaults to None. + """ + def __init__(self, + embed_dims: int, + num_heads: int = 8, + input_shape: Tuple[int, int] = [8, 32], + local_k: Tuple[int, int] = [3, 3], + **kwargs): + super().__init__(**kwargs) + self.input_shape = input_shape + self.embed_dims = embed_dims + self.local_mixer = nn.Conv2d( + in_channels=embed_dims, + out_channels=embed_dims, + kernel_size=local_k, + stride=1, + padding=(local_k[0] // 2, local_k[1] // 2), + groups=num_heads) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward function. + + Args: + x (torch.Tensor): A Tensor of shape :math:`(N, HW, C)`. + + Returns: + torch.Tensor: Tensor: A tensor of shape math:`(N, HW, C)`. + """ + h, w = self.input_shape + x = x.permute(0, 2, 1).reshape([-1, self.embed_dims, h, w]) + x = self.local_mixer(x) + x = x.flatten(2).permute(0, 2, 1) + return x + + +class AttnMixer(nn.Module): + """One of mixer of {'Global', 'Local'}. Defaults to Global Mixer. + Args: + embed_dims (int): Number of character components. + num_heads (int, optional): Number of heads. Defaults to 8. + mixer (str, optional): The mixer type, choices are 'Global' and + 'Local'. Defaults to 'Global'. + input_shape (Tuple[int, int], optional): The shape of input [H, W]. + Defaults to [8, 32]. + local_k (Tuple[int, int], optional): Window size. Defaults to [7, 11]. + qkv_bias (bool, optional): Whether a additive bias is required. + Defaults to False. + qk_scale (float, optional): A scaling factor. Defaults to None. + attn_drop (float, optional): Attn dropout probability. Defaults to 0.0. + proj_drop (float, optional): Proj dropout layer. Defaults to 0.0. + init_cfg (dict or list[dict], optional): Initialization configs. + Defaults to None. 
+ """ + def __init__(self, + embed_dims: int, + num_heads: int = 8, + mixer: str = 'Global', + input_shape: Tuple[int, int] = [8, 32], + local_k: Tuple[int, int] = [7, 11], + qkv_bias: bool = False, + qk_scale: float = None, + attn_drop: float = 0., + proj_drop: float = 0., + local_type: str = "r2", + use_pe: bool = True, + **kwargs): + super().__init__(**kwargs) + assert mixer in {'Global', 'Local'}, \ + "The type of mixer must belong to {'Global', 'Local'}" + self.num_heads = num_heads + head_dim = embed_dims // num_heads + self.scale = qk_scale or head_dim**-0.5 + self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(embed_dims, embed_dims) + self.proj_drop = nn.Dropout(proj_drop) + self.input_shape = input_shape + self.use_pe = use_pe + self.xpos = RetNetRelPos(num_heads, head_dim, input_shape, local_k, mixer, local_type=local_type) + if input_shape is not None: + height, width = input_shape + self.input_size = height * width + self.embed_dims = embed_dims + self.mixer = mixer + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward function. + Args: + x (torch.Tensor): A Tensor of shape :math:`(N, H, W, C)`. + Returns: + torch.Tensor: A Tensor of shape :math:`(N, H, W, C)`. + """ + if self.input_shape is not None: + input_size, embed_dims = self.input_size, self.embed_dims + else: + _, input_size, embed_dims = x.shape + sin, cos, mask, decay = self.xpos.sin, self.xpos.cos, self.xpos.mask, self.xpos.decay_matrix + qkv = self.qkv(x).reshape((-1, input_size, 3, self.num_heads, + embed_dims // self.num_heads)).permute( + (2, 0, 3, 1, 4)) + q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] + if self.use_pe: + qr = theta_shift(q, sin, cos) #bs, num_heads, input_len, dim + kr = theta_shift(k, sin, cos) #bs, num_heads, input_len, dim + else: + qr, kr = q, k + attn = qr.matmul(kr.permute(0, 1, 3, 2)) + + attn = F.softmax(attn, dim=-1) + if self.mixer == 'Local': + attn = attn * decay + attn = self.attn_drop(attn) + + x = attn.matmul(v).permute(0, 2, 1, 3).reshape(-1, input_size, + embed_dims) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class MLP(nn.Module): + """The MLP block. + Args: + in_features (int): The input features. + hidden_features (int, optional): The hidden features. + Defaults to None. + out_features (int, optional): The output features. + Defaults to None. + drop (float, optional): cfg of dropout function. Defaults to 0.0. + init_cfg (dict or list[dict], optional): Initialization configs. + Defaults to None. + """ + def __init__(self, + in_features: int, + hidden_features: int = None, + out_features: int = None, + drop: float = 0., + **kwargs): + super().__init__(**kwargs) + hidden_features = hidden_features or in_features + out_features = out_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = nn.GELU() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward function. + Args: + x (torch.Tensor): A Tensor of shape :math:`(N, H, W, C)`. + Returns: + torch.Tensor: A Tensor of shape :math:`(N, H, W, C)`. + """ + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +class MixingBlock(nn.Module): + """The Mixing block. + Args: + embed_dims (int): Number of character components. + num_heads (int): Number of heads + mixer (str, optional): The mixer type. Defaults to 'Global'. 
+ window_size (Tuple[int ,int], optional): Local window size. + Defaults to [7, 11]. + input_shape (Tuple[int, int], optional): The shape of input [H, W]. + Defaults to [8, 32]. + mlp_ratio (float, optional): The ratio of hidden features to input. + Defaults to 4.0. + qkv_bias (bool, optional): Whether a additive bias is required. + Defaults to False. + qk_scale (float, optional): A scaling factor. Defaults to None. + drop (float, optional): cfg of Dropout. Defaults to 0.. + attn_drop (float, optional): cfg of Dropout. Defaults to 0.0. + drop_path (float, optional): The probability of drop path. + Defaults to 0.0. + pernorm (bool, optional): Whether to place the MxingBlock before norm. + Defaults to True. + init_cfg (dict or list[dict], optional): Initialization configs. + Defaults to None. + """ + def __init__(self, + embed_dims: int, + num_heads: int, + mixer: str = 'Global', + window_size: Tuple[int, int] = [7, 11], + input_shape: Tuple[int, int] = [8, 32], + mlp_ratio: float = 4., + qkv_bias: bool = False, + qk_scale: float = None, + drop: float = 0., + attn_drop: float = 0., + drop_path=0., + prenorm: bool = True, + local_type: str = "r2", + use_pe : bool = True, + **kwargs): + super().__init__(**kwargs) + self.norm1 = nn.LayerNorm(embed_dims, eps=1e-6) + if mixer in {'Global', 'Local'}: + self.mixer = AttnMixer( + embed_dims, + num_heads=num_heads, + mixer=mixer, + input_shape=input_shape, + local_k=window_size, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + attn_drop=attn_drop, + proj_drop=drop, local_type=local_type, use_pe=use_pe) + elif mixer == 'Conv': + self.mixer = ConvMixer( + embed_dims, + num_heads=num_heads, + input_shape=input_shape, + local_k=window_size) + else: + raise TypeError('The mixer must be one of [Global, Local, Conv]') + self.drop_path = DropPath( + drop_path) if drop_path > 0. else nn.Identity() + self.norm2 = nn.LayerNorm(embed_dims, eps=1e-6) + mlp_hidden_dim = int(embed_dims * mlp_ratio) + self.mlp_ratio = mlp_ratio + self.mlp = MLP( + in_features=embed_dims, hidden_features=mlp_hidden_dim, drop=drop) + self.prenorm = prenorm + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward function. + Args: + x (torch.Tensor): A Tensor of shape :math:`(N, H*W, C)`. + Returns: + torch.Tensor: A Tensor of shape :math:`(N, H*W, C)`. + """ + if self.prenorm: + x = self.norm1(x + self.drop_path(self.mixer(x))) + x = self.norm2(x + self.drop_path(self.mlp(x))) + else: + x = x + self.drop_path(self.mixer(self.norm1(x))) + x = x + self.drop_path(self.mlp(self.norm2(x))) + return x + + +class MerigingBlock(nn.Module): + """The last block of any stage, except for the last stage. + Args: + in_channels (int): The channels of input. + out_channels (int): The channels of output. + types (str, optional): Which downsample operation of ['Pool', 'Conv']. + Defaults to 'Pool'. + stride (Union[int, Tuple[int, int]], optional): Stride of the Conv. + Defaults to [2, 1]. + act (bool, optional): activation function. Defaults to None. + init_cfg (dict or list[dict], optional): Initialization configs. + Defaults to None. 
+ """ + def __init__(self, + in_channels: int, + out_channels: int, + types: str = 'Pool', + stride: Union[int, Tuple[int, int]] = [2, 1], + act: bool = None, + **kwargs): + super().__init__(**kwargs) + self.types = types + if types == 'Pool': + self.avgpool = nn.AvgPool2d( + kernel_size=[3, 5], stride=stride, padding=[1, 2]) + self.maxpool = nn.MaxPool2d( + kernel_size=[3, 5], stride=stride, padding=[1, 2]) + self.proj = nn.Linear(in_channels, out_channels) + else: + self.conv = nn.Conv2d( + in_channels, + out_channels, + kernel_size=3, + stride=stride, + padding=1) + self.norm = nn.LayerNorm(out_channels) + if act is not None: + self.act = act() + else: + self.act = None + + def forward(self, x: torch.Tensor) -> torch.Tensor: + """Forward function. + Args: + x (torch.Tensor): A Tensor of shape :math:`(N, H, W, C)`. + Returns: + torch.Tensor: A Tensor of shape :math:`(N, H/2, W, 2C)`. + """ + if self.types == 'Pool': + x = (self.avgpool(x) + self.maxpool(x)) * 0.5 + out = self.proj(x.flatten(2).permute(0, 2, 1)) + else: + x = self.conv(x) + out = x.flatten(2).permute(0, 2, 1) + out = self.norm(out) + if self.act is not None: + out = self.act(out) + return out + + +class CACE(nn.Module): + # This code is modified from https://github.com/open-mmlab/mmocr/blob/main/mmocr/models/textrecog/encoders/svtr_encoder.py + def __init__(self, + img_size: Tuple[int, int] = [32, 128], + in_channels: int = 3, + embed_dims: Tuple[int, int, int] = [128,256,384], + depth: Tuple[int, int, int] = [3, 6, 9], + num_heads: Tuple[int, int, int] = [4, 8, 12], + mixer_types: Tuple[str] = ['Local'] * 8 + ['Global'] * 10, + window_size: Tuple[Tuple[int, int]] = [[7, 11], [7, 11], + [7, 11]], + merging_types: str = 'Conv', + mlp_ratio: int = 4, + qkv_bias: bool = True, + qk_scale: float = None, + drop_rate: float = 0., + last_drop: float = 0.1, + attn_drop_rate: float = 0., + drop_path_rate: float = 0.1, + out_channels: int = 192, + num_layers: int = 2, + prenorm: bool = False, + local_type: str = "r2", + use_pe : bool = True, + **kwargs): + super().__init__(**kwargs) + self.img_size = img_size + self.embed_dims = embed_dims + self.out_channels = out_channels + self.prenorm = prenorm + self.patch_embed = OverlapPatchEmbed( + in_channels=in_channels, + embed_dims=embed_dims[0], + num_layers=num_layers) + num_patches = (img_size[1] // (2**num_layers)) * ( + img_size[0] // (2**num_layers)) + self.input_shape = [ + img_size[0] // (2**num_layers), img_size[1] // (2**num_layers) + ] + self.pos_drop = nn.Dropout(drop_rate) + dpr = np.linspace(0, drop_path_rate, sum(depth)) + + self.blocks1 = nn.ModuleList([ + MixingBlock( + embed_dims=embed_dims[0], + num_heads=num_heads[0], + mixer=mixer_types[0:depth[0]][i], + window_size=window_size[0], + input_shape=self.input_shape, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[0:depth[0]][i], + prenorm=prenorm, local_type=local_type, use_pe=use_pe) for i in range(depth[0]) + ]) + self.downsample1 = MerigingBlock( + in_channels=embed_dims[0], + out_channels=embed_dims[1], + types=merging_types, + stride=[2, 1]) + input_shape = [self.input_shape[0] // 2, self.input_shape[1]] + self.merging_types = merging_types + + self.blocks2 = nn.ModuleList([ + MixingBlock( + embed_dims=embed_dims[1], + num_heads=num_heads[1], + mixer=mixer_types[depth[0]:depth[0] + depth[1]][i], + window_size=window_size[1], + input_shape=input_shape, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + 
attn_drop=attn_drop_rate, + drop_path=dpr[depth[0]:depth[0] + depth[1]][i], + prenorm=prenorm, local_type=local_type) for i in range(depth[1]) + ]) + self.downsample2 = MerigingBlock( + in_channels=embed_dims[1], + out_channels=embed_dims[2], + types=merging_types, + stride=[2, 1]) + input_shape = [self.input_shape[0] // 4, self.input_shape[1]] + + self.blocks3 = nn.ModuleList([ + MixingBlock( + embed_dims=embed_dims[2], + num_heads=num_heads[2], + mixer=mixer_types[depth[0] + depth[1]:][i], + window_size=window_size[2], + input_shape=input_shape, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + qk_scale=qk_scale, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[depth[0] + depth[1]:][i], + prenorm=prenorm, local_type=local_type) for i in range(depth[2]) + ]) + named_apply(partial(init_weights, exclude=[]), self) + + def forward_features(self, x: torch.Tensor) -> torch.Tensor: + """Forward function except the last combing operation. + Args: + x (torch.Tensor): A Tensor of shape :math:`(N, 3, H, W)`. + Returns: + torch.Tensor: A List Tensor of shape :math:[`(N, H/4, W/4, C_1)`, `(N, H/8, W/4, C_2)`, `(N, H/16, W/4, C_3)`]`. + """ + x = self.patch_embed(x) + x = self.pos_drop(x) + fpn = [] + for blk in self.blocks1: + x = blk(x) + fpn.append(x) + x = self.downsample1( + x.permute(0, 2, 1).reshape([ + -1, self.embed_dims[0], self.input_shape[0], + self.input_shape[1] + ])) + + for blk in self.blocks2: + x = blk(x) + fpn.append(x) + x = self.downsample2( + x.permute(0, 2, 1).reshape([ + -1, self.embed_dims[1], self.input_shape[0] // 2, + self.input_shape[1] + ])) + + for blk in self.blocks3: + x = blk(x) + fpn.append(x) + return fpn + + def forward(self, + x: torch.Tensor, + ) -> torch.Tensor: + """Forward function. + Args: + x (torch.Tensor): A Tensor of shape :math:`(N, 3, H, W)`. + Returns: + torch.Tensor: A List Tensor of shape :math:[`(N, H/4, W/4, C_1)`, `(N, H/8, W/4, C_2)`, `(N, H/16, W/4, C_3)`]. + """ + fpn = self.forward_features(x) + return fpn + + +class FusionModule(nn.Module): + def __init__(self, img_size, embed_dims, out_dim, fpn_layers, **kwargs): + super().__init__(**kwargs) + self.h, self.w = img_size[0] // 4, img_size[1]// 4 + self.fpn_layers = fpn_layers + self.linear = nn.ModuleList() + for dim in embed_dims: + self.linear.append(nn.Linear(dim, out_dim)) + + def forward(self, fpn): + fusion = [] + assert len(fpn) == len(self.linear), print("the length of output encoder must \ + equal to the length of embed_dims") + for f, layer in zip(fpn, self.linear): + fusion.append(layer(f)) + return torch.cat([fusion[i] for i in self.fpn_layers], dim=1) + + +class Intra_Inter_ConsistencyLoss(nn.Module): + """Contrastive Center loss. + Reference: + Args: + num_classes (int): number of classes. + feat_dim (int): feature dimension. 
+ """ + def __init__(self, num_classes=94, in_dim=256, out_dim=2048, eps=1, alpha=0.1, start=0): + super().__init__() + self.num_classes = num_classes + self.in_dim = in_dim + self.eps = eps + self.alpha = alpha + self.out_dim = out_dim + self.start = start + + self.linear = nn.Linear(in_dim, out_dim) if out_dim > 0 else nn.Identity() + self.centers = nn.Parameter(torch.randn(self.num_classes, self.out_dim)) if out_dim > 0 else nn.Parameter(torch.randn(self.num_classes, self.in_dim)) + ''' + self.linear = nn.Sequential( + nn.Linear(in_dim, out_dim), + nn.ReLU(inplace=True), + nn.Linear(out_dim, in_dim) + ) if out_dim > 0 else nn.Identity() + self.centers = nn.Parameter(torch.randn(self.num_classes, self.in_dim)) + ''' + nn.init.trunc_normal_(self.centers, mean=0, std=0.02) + + def forward(self, features: torch.Tensor, targets: torch.Tensor, labels: List[str]): + """ + Args: + x: feature matrix with shape (batch_size, feat_dim). + labels: ground truth labels with shape (batch_size). + """ + features = self.linear(features) + new_x, new_t = [] , [] + for f, l, t in zip(features, labels, targets): + new_x.append(f[:len(l)]) + new_t.append(t[:len(l)] - self.start) + x = torch.cat(new_x, dim=0) + labels = torch.cat(new_t, dim=0) + batch_size = x.size(0) + mat = torch.pow(x, 2).sum(dim=1, keepdim=True).expand(batch_size, self.num_classes) + \ + torch.pow(self.centers, 2).sum(dim=1, keepdim=True).expand(self.num_classes, batch_size).t() + distmat = mat - 2 * x @ self.centers.t() + + classes = torch.arange(self.num_classes, device=labels.device).long().unsqueeze(0) + labels = labels.unsqueeze(1).expand(batch_size, self.num_classes) + mask = labels.eq(classes.expand(batch_size, self.num_classes)) + + dist = (distmat * mask.float()).sum(1) + distmat = distmat * (~mask).float() + sum_dist = distmat.sum(1) + self.eps + + ctc = dist / sum_dist + loss = self.alpha * ctc.sum() / 2 + return loss \ No newline at end of file diff --git a/strhub/models/cfe/system.py b/strhub/models/cfe/system.py new file mode 100644 index 0000000..865a818 --- /dev/null +++ b/strhub/models/cfe/system.py @@ -0,0 +1,157 @@ +# Scene Text Recognition Model Hub +# Copyright 2022 Darwin Bautista +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
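+
+# Rough data flow of the CFE system defined below (as read from this file):
+#   image -> optional TPS rectification -> CACE encoder -> FusionModule (multi-scale feature fusion)
+#   -> autoregressive Transformer decoder -> per-character logits,
+# with an optional intra-/inter-consistency loss (IICL, a contrastive center loss) enabled late in training.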
+ +import math +from functools import partial +from itertools import permutations +from typing import Sequence, Any, Optional, Union, List + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor + +from pytorch_lightning.utilities.types import STEP_OUTPUT +from timm.models.helpers import named_apply + +from strhub.models.base import CrossEntropySystem +from strhub.models.utils import init_weights +from .modules import CACE,FusionModule, Intra_Inter_ConsistencyLoss +from strhub.models.modules import DecoderLayer, Decoder, TokenEmbedding, TPS_SpatialTransformerNetwork + + +class CFE(CrossEntropySystem): + def __init__(self, charset_train: str, charset_test: str, max_label_length: int, + batch_size: int, lr: float, warmup_pct: float, weight_decay: float, + img_size: Sequence[int], embed_dim: int, decoder_dim:int, + enc_num_heads: int, enc_mlp_ratio: int, depth: List[int], + dec_num_heads: int, dec_mlp_ratio: int, dec_depth: int, + mixer_types: List[Union[int, str]], merge_types:str,num_control_points:int, + dropout: float, window_size:List[List[int]], iiclexist:bool = True, prenorm:bool = False, tps = False, use_pe = True, + fpn_layers = [0, 1, 2], cc_weights = 0.2, local_type = "r2", **kwargs) -> None: + super().__init__(charset_train, charset_test, batch_size, lr, warmup_pct, weight_decay) + self.save_hyperparameters() + self.max_label_length = max_label_length + mixer_types = [mixer_types[0]] * mixer_types[1] + [mixer_types[2]] * mixer_types[3] + self.transformation = TPS_SpatialTransformerNetwork( + F=num_control_points, I_size=tuple(img_size), I_r_size=tuple(img_size), + I_channel_num=3) if tps else nn.Identity() + + self.encoder = CACE(img_size, 3, embed_dims=embed_dim, depth=depth, num_heads=enc_num_heads, + mixer_types=mixer_types,window_size=window_size, mlp_ratio=enc_mlp_ratio,\ + merging_types=merge_types, prenorm=prenorm, local_type=local_type, use_pe=use_pe) + self.fusion = FusionModule(img_size, embed_dim, decoder_dim, fpn_layers) + decoder_layer = DecoderLayer(decoder_dim, dec_num_heads, decoder_dim * dec_mlp_ratio, dropout) + self.decoder = Decoder(decoder_layer, num_layers=dec_depth, norm=nn.LayerNorm(decoder_dim)) + if iiclexist: + self.iicl = Intra_Inter_ConsistencyLoss(len(charset_train), decoder_dim, 0 * decoder_dim, alpha=cc_weights, \ + start=self.tokenizer._stoi[charset_train[0]]) + print(f"cc_weights: {cc_weights}") + else: + self.iicl = None + + # We predict and + self.head = nn.Linear(decoder_dim, len(self.tokenizer)) + self.text_embed = TokenEmbedding(len(self.tokenizer), decoder_dim) + + # +1 for + self.pos_embedding = nn.Parameter(torch.Tensor(1, max_label_length + 1, decoder_dim)) + self.dropout = nn.Dropout(p=dropout) + named_apply(partial(init_weights, exclude=['encoder']), self) + nn.init.trunc_normal_(self.pos_embedding, std=.02) + + @torch.jit.ignore + def no_weight_decay(self): + param_names = {'text_embed.embedding.weight', 'pos_embedding'} + return param_names + + def encode(self, img: torch.Tensor): + img = self.transformation(img) + fpn = self.encoder(img) + x = self.fusion(fpn) + return x + + def decode(self, tgt: torch.Tensor, memory: torch.Tensor, tgt_mask: Optional[Tensor] = None, + tgt_padding_mask: Optional[Tensor] = None, tgt_query: Optional[Tensor] = None, + tgt_query_mask: Optional[Tensor] = None): + N, L = tgt.shape + + tgt_query = self.pos_embedding[:, :L] + self.text_embed(tgt) + tgt_query = self.dropout(tgt_query) + return self.decoder(tgt_query, memory, tgt_query_mask, 
tgt_mask, tgt_padding_mask)
+
+    def forward(self, images: Tensor, max_length: Optional[int] = None) -> Tensor:
+        testing = max_length is None
+        max_length = self.max_label_length if max_length is None else min(max_length, self.max_label_length)
+        bs = images.shape[0]
+        # +1 for <eos> at end of sequence.
+        num_steps = max_length + 1
+        img = self.transformation(images)
+        fpn = self.encoder(img)
+        memory = self.fusion(fpn)
+
+        tgt_in = torch.full((bs, num_steps), self.pad_id, dtype=torch.long, device=self._device)
+        tgt_in[:, 0] = self.bos_id
+        self_attn_mask = self.get_selfmask(num_steps)
+        logits = []
+        for i in range(num_steps):
+            j = i + 1  # next token index
+            tgt_out = self.decode(tgt_in[:, :j], memory, tgt_query_mask=self_attn_mask[:j, :j])
+            # the next token probability is in the output's ith token position
+            p_i = self.head(tgt_out[0][:, -1, :])
+            logits.append(p_i.unsqueeze(1))
+            if j < num_steps:
+                # greedy decode. add the next token index to the target input
+                tgt_in[:, j] = p_i.argmax(-1)
+
+                if testing and (tgt_in == self.eos_id).any(dim=-1).all():
+                    break
+        logits = torch.cat(logits, dim=1)
+        return logits, tgt_out
+
+    def get_selfmask(self, T: int):
+        return torch.triu(torch.full((T, T), float('-inf'), device=self._device), 1)
+
+    def training_step(self, batch, batch_idx) -> STEP_OUTPUT:
+        images, labels = batch
+        tgt = self.tokenizer.encode(labels, self._device)
+
+        # Encode the source sequence (i.e. the image codes)
+        memory = self.encode(images)
+
+        tgt_in = tgt[:, :-1]
+        tgt_out = tgt[:, 1:]
+
+        self_attn_mask = self.get_selfmask(tgt_in.shape[1])
+
+        loss = 0
+        loss_numel = 0
+        n = (tgt_out != self.pad_id).sum().item()
+
+        out = self.decode(tgt_in, memory, tgt_query_mask=self_attn_mask)[0]
+        logits = self.head(out).flatten(end_dim=1)
+        loss += n * F.cross_entropy(logits, tgt_out.flatten(), ignore_index=self.pad_id)
+        loss_numel += n
+        loss /= loss_numel
+        self.log('loss', loss)
+        if self.iicl is not None:
+            total_steps = self.trainer.estimated_stepping_batches * self.trainer.accumulate_grad_batches
+            #and self.trainer.current_epoch <= int(self.trainer.max_epochs-2)
+            if self.global_step >= 0.75 * total_steps:
+                iicl = self.iicl(out, tgt_out, labels)
+                self.log('iicl', iicl)
+                loss += iicl
+        return loss
\ No newline at end of file
diff --git a/strhub/models/modules.py b/strhub/models/modules.py
new file mode 100644
index 0000000..d173b7b
--- /dev/null
+++ b/strhub/models/modules.py
@@ -0,0 +1,280 @@
+r"""Shared modules used by CRNN and TRBA"""
+import math
+import torch
+from typing import Optional
+from torch import nn as nn, Tensor
+from torch.nn import functional as F
+from torch.nn.modules import transformer
+
+
+class BidirectionalLSTM(nn.Module):
+    """Ref: https://github.com/clovaai/deep-text-recognition-benchmark/blob/master/modules/sequence_modeling.py"""
+
+    def __init__(self, input_size, hidden_size, output_size):
+        super().__init__()
+        self.rnn = nn.LSTM(input_size, hidden_size, bidirectional=True, batch_first=True)
+        self.linear = nn.Linear(hidden_size * 2, output_size)
+
+    def forward(self, input):
+        """
+        input : visual feature [batch_size x T x input_size], T = num_steps.
+        output : contextual feature [batch_size x T x output_size]
+        """
+        recurrent, _ = self.rnn(input)  # batch_size x T x input_size -> batch_size x T x (2*hidden_size)
+        output = self.linear(recurrent)  # batch_size x T x output_size
+        return output
+
+
+class DecoderLayer(nn.Module):
+    """A Transformer decoder layer supporting two-stream attention (XLNet)
+       This implements a pre-LN decoder, as opposed to the post-LN default in PyTorch."""
+
+    def __init__(self, d_model, nhead, dim_feedforward=2048, dropout=0.1, activation='gelu',
+                 layer_norm_eps=1e-5):
+        super().__init__()
+        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)
+        self.cross_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout, batch_first=True)
+        self.linear1 = nn.Linear(d_model, dim_feedforward)
+        self.dropout = nn.Dropout(dropout)
+        self.linear2 = nn.Linear(dim_feedforward, d_model)
+
+        self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
+        self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
+        self.norm_q = nn.LayerNorm(d_model, eps=layer_norm_eps)
+        self.dropout1 = nn.Dropout(dropout)
+        self.dropout2 = nn.Dropout(dropout)
+        self.dropout3 = nn.Dropout(dropout)
+
+        self.activation = transformer._get_activation_fn(activation)
+
+    def __setstate__(self, state):
+        if 'activation' not in state:
+            state['activation'] = F.gelu
+        super().__setstate__(state)
+
+    def forward_stream(self, tgt: Tensor, tgt_norm: Tensor, tgt_kv: Tensor, memory: Tensor,
+                       tgt_mask: Optional[Tensor], tgt_key_padding_mask: Optional[Tensor]):
+
+        tgt2, sa_weights = self.self_attn(tgt_norm, tgt_kv, tgt_kv, attn_mask=tgt_mask,
+                                          key_padding_mask=tgt_key_padding_mask)
+        tgt = tgt + self.dropout1(tgt2)
+
+        tgt2, ca_weights = self.cross_attn(self.norm1(tgt), memory, memory)
+        tgt = tgt + self.dropout2(tgt2)
+
+        tgt2 = self.linear2(self.dropout(self.activation(self.linear1(self.norm2(tgt)))))
+        tgt = tgt + self.dropout3(tgt2)
+        return tgt, sa_weights, ca_weights
+
+    # query is pos_embed + token_embed; memory is the multi-level feature produced by the ViT-style encoder
+    def forward(self, query, memory, query_mask: Optional[Tensor] = None,
+                content_mask: Optional[Tensor] = None,
+                content_key_padding_mask: Optional[Tensor] = None,
+                update_content: bool = False):
+        query_norm = self.norm_q(query)
+        query, sa_weights, ca_weights = self.forward_stream(query, query_norm, query_norm,
+                                                            memory, query_mask, content_key_padding_mask)
+        if update_content:
+            content = self.forward_stream(query, query_norm, query_norm, memory,
+                                          content_mask, content_key_padding_mask)[0]
+        return query, sa_weights, ca_weights
+
+
+class Decoder(nn.Module):
+    __constants__ = ['norm']
+
+    def __init__(self, decoder_layer, num_layers, norm):
+        super().__init__()
+        self.layers = transformer._get_clones(decoder_layer, num_layers)
+        self.num_layers = num_layers
+        self.norm = norm
+
+    def forward(self, query, memory, query_mask: Optional[Tensor] = None, content_mask: Optional[Tensor] = None,
+                content_key_padding_mask: Optional[Tensor] = None):
+        for i, mod in enumerate(self.layers):
+            query, sa_weights, ca_weights = mod(query, memory, query_mask, content_mask, content_key_padding_mask,
+                                                update_content=False)
+        query = self.norm(query)
+        return query, sa_weights, ca_weights
+
+
+class TokenEmbedding(nn.Module):
+
+    def __init__(self, charset_size: int, embed_dim: int):
+        super().__init__()
+        self.embedding = nn.Embedding(charset_size, embed_dim)
+        self.embed_dim = embed_dim
+
+    def forward(self, tokens: torch.Tensor):
+        return math.sqrt(self.embed_dim) * self.embedding(tokens)
+
+
+import
numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class TPS_SpatialTransformerNetwork(nn.Module): + """ Rectification Network of RARE, namely TPS based STN """ + + def __init__(self, F, I_size, I_r_size, I_channel_num=1): + """ Based on RARE TPS + input: + batch_I: Batch Input Image [batch_size x I_channel_num x I_height x I_width] + I_size : (height, width) of the input image I + I_r_size : (height, width) of the rectified image I_r + I_channel_num : the number of channels of the input image I + output: + batch_I_r: rectified image [batch_size x I_channel_num x I_r_height x I_r_width] + """ + super().__init__() + self.F = F + self.I_size = I_size + self.I_r_size = I_r_size # = (I_r_height, I_r_width) + self.I_channel_num = I_channel_num + self.LocalizationNetwork = LocalizationNetwork(self.F, self.I_channel_num) + self.GridGenerator = GridGenerator(self.F, self.I_r_size) + + def forward(self, batch_I): + batch_C_prime = self.LocalizationNetwork(batch_I) # batch_size x K x 2 + # batch_size x n (= I_r_width x I_r_height) x 2 + build_P_prime = self.GridGenerator.build_P_prime(batch_C_prime) + build_P_prime_reshape = build_P_prime.reshape([build_P_prime.size(0), self.I_r_size[0], self.I_r_size[1], 2]) + + if torch.__version__ > "1.2.0": + batch_I_r = F.grid_sample(batch_I, build_P_prime_reshape, padding_mode='border', align_corners=True) + else: + batch_I_r = F.grid_sample(batch_I, build_P_prime_reshape, padding_mode='border') + + return batch_I_r + + +class LocalizationNetwork(nn.Module): + """ Localization Network of RARE, which predicts C' (K x 2) from I (I_width x I_height) """ + + def __init__(self, F, I_channel_num): + super().__init__() + self.F = F + self.I_channel_num = I_channel_num + self.conv = nn.Sequential( + nn.Conv2d(in_channels=self.I_channel_num, out_channels=64, kernel_size=3, stride=1, padding=1, + bias=False), nn.BatchNorm2d(64), nn.ReLU(True), + nn.MaxPool2d(2, 2), # batch_size x 64 x I_height/2 x I_width/2 + nn.Conv2d(64, 128, 3, 1, 1, bias=False), nn.BatchNorm2d(128), nn.ReLU(True), + nn.MaxPool2d(2, 2), # batch_size x 128 x I_height/4 x I_width/4 + nn.Conv2d(128, 256, 3, 1, 1, bias=False), nn.BatchNorm2d(256), nn.ReLU(True), + nn.MaxPool2d(2, 2), # batch_size x 256 x I_height/8 x I_width/8 + nn.Conv2d(256, 512, 3, 1, 1, bias=False), nn.BatchNorm2d(512), nn.ReLU(True), + nn.AdaptiveAvgPool2d(1) # batch_size x 512 + ) + + self.localization_fc1 = nn.Sequential(nn.Linear(512, 256), nn.ReLU(True)) + self.localization_fc2 = nn.Linear(256, self.F * 2) + + # Init fc2 in LocalizationNetwork + self.localization_fc2.weight.data.fill_(0) + """ see RARE paper Fig. 
6 (a) """ + ctrl_pts_x = np.linspace(-1.0, 1.0, int(F / 2)) + ctrl_pts_y_top = np.linspace(0.0, -1.0, num=int(F / 2)) + ctrl_pts_y_bottom = np.linspace(1.0, 0.0, num=int(F / 2)) + ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1) + ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1) + initial_bias = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0) + self.localization_fc2.bias.data = torch.from_numpy(initial_bias).float().view(-1) + + def forward(self, batch_I): + """ + input: batch_I : Batch Input Image [batch_size x I_channel_num x I_height x I_width] + output: batch_C_prime : Predicted coordinates of fiducial points for input batch [batch_size x F x 2] + """ + batch_size = batch_I.size(0) + features = self.conv(batch_I).view(batch_size, -1) + batch_C_prime = self.localization_fc2(self.localization_fc1(features)).view(batch_size, self.F, 2) + return batch_C_prime + + +class GridGenerator(nn.Module): + """ Grid Generator of RARE, which produces P_prime by multipling T with P """ + + def __init__(self, F, I_r_size): + """ Generate P_hat and inv_delta_C for later """ + super().__init__() + self.eps = 1e-6 + self.I_r_height, self.I_r_width = I_r_size + self.F = F + self.C = self._build_C(self.F) # F x 2 + self.P = self._build_P(self.I_r_width, self.I_r_height) + + # num_gpu = torch.cuda.device_count() + # if num_gpu > 1: + # for multi-gpu, you may need register buffer + self.register_buffer("inv_delta_C", torch.tensor( + self._build_inv_delta_C(self.F, self.C)).float()) # F+3 x F+3 + self.register_buffer("P_hat", torch.tensor(self._build_P_hat(self.F, self.C, self.P)).float()) # n x F+3 + # else: + # # for fine-tuning with different image width, you may use below instead of self.register_buffer + # self.inv_delta_C = torch.tensor(self._build_inv_delta_C(self.F, self.C)).float() # F+3 x F+3 + # self.P_hat = torch.tensor(self._build_P_hat(self.F, self.C, self.P)).float() # n x F+3 + + def _build_C(self, F): + """ Return coordinates of fiducial points in I_r; C """ + ctrl_pts_x = np.linspace(-1.0, 1.0, int(F / 2)) + ctrl_pts_y_top = -1 * np.ones(int(F / 2)) + ctrl_pts_y_bottom = np.ones(int(F / 2)) + ctrl_pts_top = np.stack([ctrl_pts_x, ctrl_pts_y_top], axis=1) + ctrl_pts_bottom = np.stack([ctrl_pts_x, ctrl_pts_y_bottom], axis=1) + C = np.concatenate([ctrl_pts_top, ctrl_pts_bottom], axis=0) + return C # F x 2 + + def _build_inv_delta_C(self, F, C): + """ Return inv_delta_C which is needed to calculate T """ + hat_C = np.zeros((F, F), dtype=float) # F x F + for i in range(0, F): + for j in range(i, F): + r = np.linalg.norm(C[i] - C[j]) + hat_C[i, j] = r + hat_C[j, i] = r + np.fill_diagonal(hat_C, 1) + hat_C = (hat_C ** 2) * np.log(hat_C) + # print(C.shape, hat_C.shape) + delta_C = np.concatenate( # F+3 x F+3 + [ + np.concatenate([np.ones((F, 1)), C, hat_C], axis=1), # F x F+3 + np.concatenate([np.zeros((2, 3)), np.transpose(C)], axis=1), # 2 x F+3 + np.concatenate([np.zeros((1, 3)), np.ones((1, F))], axis=1) # 1 x F+3 + ], + axis=0 + ) + inv_delta_C = np.linalg.inv(delta_C) + return inv_delta_C # F+3 x F+3 + + def _build_P(self, I_r_width, I_r_height): + I_r_grid_x = (np.arange(-I_r_width, I_r_width, 2) + 1.0) / I_r_width # self.I_r_width + I_r_grid_y = (np.arange(-I_r_height, I_r_height, 2) + 1.0) / I_r_height # self.I_r_height + P = np.stack( # self.I_r_width x self.I_r_height x 2 + np.meshgrid(I_r_grid_x, I_r_grid_y), + axis=2 + ) + return P.reshape([-1, 2]) # n (= self.I_r_width x self.I_r_height) x 2 + + def _build_P_hat(self, F, C, P): + n = P.shape[0] # n 
(= self.I_r_width x self.I_r_height) + P_tile = np.tile(np.expand_dims(P, axis=1), (1, F, 1)) # n x 2 -> n x 1 x 2 -> n x F x 2 + C_tile = np.expand_dims(C, axis=0) # 1 x F x 2 + P_diff = P_tile - C_tile # n x F x 2 + rbf_norm = np.linalg.norm(P_diff, ord=2, axis=2, keepdims=False) # n x F + rbf = np.multiply(np.square(rbf_norm), np.log(rbf_norm + self.eps)) # n x F + P_hat = np.concatenate([np.ones((n, 1)), P, rbf], axis=1) + return P_hat # n x F+3 + + def build_P_prime(self, batch_C_prime): + """ Generate Grid from batch_C_prime [batch_size x F x 2] """ + batch_size = batch_C_prime.size(0) + batch_inv_delta_C = self.inv_delta_C.repeat(batch_size, 1, 1) + batch_P_hat = self.P_hat.repeat(batch_size, 1, 1) + batch_C_prime_with_zeros = torch.cat((batch_C_prime, batch_C_prime.new_zeros( + (batch_size, 3, 2), dtype=torch.float)), dim=1) # batch_size x F+3 x 2 + batch_T = torch.bmm(batch_inv_delta_C, batch_C_prime_with_zeros) # batch_size x F+3 x 2 + batch_P_prime = torch.bmm(batch_P_hat, batch_T) # batch_size x n x 2 + return batch_P_prime # batch_size x n x 2 diff --git a/strhub/models/utils.py b/strhub/models/utils.py new file mode 100644 index 0000000..e795fe3 --- /dev/null +++ b/strhub/models/utils.py @@ -0,0 +1,129 @@ +from pathlib import PurePath +from typing import Sequence + +import os +import torch +from torch import nn + +import yaml + + +class InvalidModelError(RuntimeError): + """Exception raised for any model-related error (creation, loading)""" + + +_WEIGHTS_URL = { + 'parseq-tiny': 'https://github.com/baudm/parseq/releases/download/v1.0.0/parseq_tiny-e7a21b54.pt', + 'parseq': 'https://github.com/baudm/parseq/releases/download/v1.0.0/parseq-bb5792a6.pt', + 'abinet': 'https://github.com/baudm/parseq/releases/download/v1.0.0/abinet-1d1e373e.pt', + 'trba': 'https://github.com/baudm/parseq/releases/download/v1.0.0/trba-cfaed284.pt', + 'vitstr': 'https://github.com/baudm/parseq/releases/download/v1.0.0/vitstr-26d0fcf4.pt', + 'crnn': 'https://github.com/baudm/parseq/releases/download/v1.0.0/crnn-679d0e31.pt', +} + + +def _get_config(experiment: str, **kwargs): + """Emulates hydra config resolution""" + root = PurePath(__file__).parents[2] + with open(root / 'configs/main.yaml', 'r') as f: + config = yaml.load(f, yaml.Loader)['model'] + with open(root / f'configs/charset/94_full.yaml', 'r') as f: + config.update(yaml.load(f, yaml.Loader)['model']) + with open(root / f'configs/experiment/{experiment}.yaml', 'r') as f: + exp = yaml.load(f, yaml.Loader) + # Apply base model config + model = exp['defaults'][0]['override /model'] + with open(root / f'configs/model/{model}.yaml', 'r') as f: + config.update(yaml.load(f, yaml.Loader)) + # Apply experiment config + if 'model' in exp: + config.update(exp['model']) + config.update(kwargs) + # Workaround for now: manually cast the lr to the correct type. 
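+    # (PyYAML follows YAML 1.1, so scientific-notation values such as 3e-4 are loaded as strings rather than floats.)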
+ config['lr'] = float(config['lr']) + return config + + +def _get_model_class(key): + if 'abinet' in key: + from .abinet.system import ABINet as ModelClass + elif 'crnn' in key: + from .crnn.system import CRNN as ModelClass + elif 'trba' in key: + from .trba.system import TRBA as ModelClass + elif 'trbc' in key: + from .trba.system import TRBC as ModelClass + elif 'vitstr' in key: + from .vitstr.system import ViTSTR as ModelClass + elif 'parseq' in key: + from .parseq.system import PARSeq as ModelClass + elif "cfe" in key: + from .cfe.system import CFE as ModelClass + else: + raise InvalidModelError(f"Unable to find model class for '{key}'") + return ModelClass + + +def get_pretrained_weights(experiment): + if os.path.exists(os.path.join("checkpoint", experiment + ".pth")): + return torch.load(os.path.join("checkpoint", experiment + ".pth"), map_location="cpu") + try: + url = _WEIGHTS_URL[experiment] + except KeyError: + raise InvalidModelError(f"No pretrained weights found for '{experiment}'") from None + return torch.hub.load_state_dict_from_url(url=url, map_location='cpu', check_hash=True) + + +def create_model(experiment: str, pretrained: bool = False, **kwargs): + try: + config = _get_config(experiment, **kwargs) + except FileNotFoundError: + raise InvalidModelError(f"No configuration found for '{experiment}'") from None + ModelClass = _get_model_class(experiment) + model = ModelClass(**config) + if pretrained: + model.load_state_dict(get_pretrained_weights(experiment), strict=False) + return model + + +def load_from_checkpoint(checkpoint_path: str, **kwargs): + if checkpoint_path.startswith('pretrained='): + model_id = checkpoint_path.split('=', maxsplit=1)[1] + model = create_model(model_id, True, **kwargs) + else: + ModelClass = _get_model_class(checkpoint_path) + checkpoint_path = "checkpoint/" + checkpoint_path + ".ckpt" + model = ModelClass.load_from_checkpoint(checkpoint_path, strict=False,**kwargs) + return model + + +def parse_model_args(args): + kwargs = {} + arg_types = {t.__name__: t for t in [int, float, str]} + arg_types['bool'] = lambda v: v.lower() == 'true' # special handling for bool + for arg in args: + name, value = arg.split('=', maxsplit=1) + name, arg_type = name.split(':', maxsplit=1) + kwargs[name] = arg_types[arg_type](value) + return kwargs + + +def init_weights(module: nn.Module, name: str = '', exclude: Sequence[str] = ()): + """Initialize the weights using the typical initialization schemes used in SOTA models.""" + if any(map(name.startswith, exclude)): + return + if isinstance(module, nn.Linear): + nn.init.trunc_normal_(module.weight, std=.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Embedding): + nn.init.trunc_normal_(module.weight, std=.02) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + elif isinstance(module, nn.Conv2d): + nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu') + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, (nn.LayerNorm, nn.BatchNorm2d, nn.GroupNorm)): + nn.init.ones_(module.weight) + nn.init.zeros_(module.bias) diff --git a/test.py b/test.py new file mode 100644 index 0000000..b4b977b --- /dev/null +++ b/test.py @@ -0,0 +1,156 @@ +#!/usr/bin/env python3 +# Scene Text Recognition Model Hub +# Copyright 2022 Darwin Bautista +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import time +import argparse +import string +import sys +import numpy as np +import collections +from dataclasses import dataclass +from typing import List + +import torch + +from tqdm import tqdm + +from strhub.data.module import SceneTextDataModule +from strhub.models.utils import load_from_checkpoint, parse_model_args + + +@dataclass +class Result: + dataset: str + num_samples: int + accuracy: float + ned: float + confidence: float + label_length: float + + +def print_results_table(results: List[Result], file=None): + w = max(map(len, map(getattr, results, ['dataset'] * len(results)))) + w = max(w, len('Dataset'), len('Combined')) + print('| {:<{w}} | # samples | Accuracy | 1 - NED | Confidence | Label Length |'.format('Dataset', w=w), file=file) + print('|:{:-<{w}}:|----------:|---------:|--------:|-----------:|-------------:|'.format('----', w=w), file=file) + c = Result('Combined', 0, 0, 0, 0, 0) + for res in results: + c.num_samples += res.num_samples + c.accuracy += res.num_samples * res.accuracy + c.ned += res.num_samples * res.ned + c.confidence += res.num_samples * res.confidence + c.label_length += res.num_samples * res.label_length + print(f'| {res.dataset:<{w}} | {res.num_samples:>9} | {res.accuracy:>8.2f} | {res.ned:>7.2f} ' + f'| {res.confidence:>10.2f} | {res.label_length:>12.2f} |', file=file) + c.accuracy /= c.num_samples + c.ned /= c.num_samples + c.confidence /= c.num_samples + c.label_length /= c.num_samples + print('|-{:-<{w}}-|-----------|----------|---------|------------|--------------|'.format('----', w=w), file=file) + print(f'| {c.dataset:<{w}} | {c.num_samples:>9} | {c.accuracy:>8.2f} | {c.ned:>7.2f} ' + f'| {c.confidence:>10.2f} | {c.label_length:>12.2f} |', file=file) + +@torch.inference_mode() +def main(): + parser = argparse.ArgumentParser() + parser.add_argument('checkpoint', help="Model checkpoint (or 'pretrained=')") + parser.add_argument('--data_root', required=True) + parser.add_argument('--batch_size', type=int, default=512) + parser.add_argument('--num_workers', type=int, default=6) + parser.add_argument('--test_speed', action='store_true', default=False, help='whethe to test speed') + parser.add_argument('--cased', action='store_true', default=False, help='Cased comparison') + parser.add_argument('--punctuation', action='store_true', default=False, help='Check punctuation') + parser.add_argument('--new', action='store_true', default=False, help='Evaluate on new benchmark datasets') + parser.add_argument('--rotation', type=int, default=0, help='Angle of rotation (counter clockwise) in degrees.') + parser.add_argument('--device', default='cuda') + args, unknown = parser.parse_known_args() + print(args) + kwargs = parse_model_args(unknown) + if args.test_speed: + args.batch_size = 1 + charset_test = string.digits + string.ascii_lowercase + if args.cased: + charset_test += string.ascii_uppercase + if args.punctuation: + charset_test += string.punctuation + kwargs.update({'charset_test': charset_test}) + print(f'Additional keyword arguments: {kwargs}') + + model = load_from_checkpoint(args.checkpoint, 
**kwargs).eval().to(args.device) + hp = model.hparams + datamodule = SceneTextDataModule(args.data_root, '_unused_', hp.img_size, hp.max_label_length, hp.charset_train, + hp.charset_test, args.batch_size, args.num_workers, False, rotation=args.rotation) + if "Union14M" not in args.data_root: + test_set = SceneTextDataModule.TEST_BENCHMARK_SUB + SceneTextDataModule.TEST_BENCHMARK + ("WordArt_train", "WordArt_test") + if args.new: + test_set += SceneTextDataModule.TEST_NEW + result_groups = { + 'Benchmark (Subset)': SceneTextDataModule.TEST_BENCHMARK_SUB, + 'Benchmark': SceneTextDataModule.TEST_BENCHMARK, + 'WordArt': ("WordArt_train", "WordArt_test"), + } + if args.new: + result_groups.update({'New': SceneTextDataModule.TEST_NEW}) + else: + test_set1 = ("curve", "multi_oriented","artistic", "contextless", "salient" , "multi_words","general") + test_set = test_set1 + ("incomplete", ) + result_groups = { + 'Union14M-L': test_set1, + "incomplete" : ("incomplete", ) + } + test_set = sorted(set(test_set)) + print(test_set) + + results = {} + max_width = max(map(len, test_set)) + for name, dataloader in datamodule.test_dataloaders(test_set).items(): + time1 = time.time() + total = 0 + correct = 0 + ned = 0 + confidence = 0 + label_length = 0 + if args.test_speed and 'WordArt_train' not in name: + continue + for imgs, labels in tqdm(iter(dataloader), desc=f'{name:>{max_width}}'): + res = model.test_step((imgs.to(model.device), labels), -1)["output"] + total += res.num_samples + correct += res.correct + ned += res.ned + confidence += res.confidence + label_length += res.label_length + time2 = time.time() + accuracy = 100 * correct / total + mean_ned = 100 * (1 - ned / total) + mean_conf = 100 * confidence / total + mean_label_length = label_length / total + results[name] = Result(name, total, accuracy, mean_ned, mean_conf, mean_label_length) + if args.test_speed and 'WordArt_train' in name: + print(f"rnuning time: {time2 - time1}, samples have {total}, average fps {(time2-time1)/total*1000} ms/img.") + exit() + if "Union14M" in args.data_root: + args.checkpoint = args.checkpoint + "_union" + with open(os.path.join("test_results", args.checkpoint + '.txt'), 'w') as f: + for out in [f, sys.stdout]: + for group, subset in result_groups.items(): + print(f'{group} set:', file=out) + print_results_table([results[s] for s in subset], out) + print('\n', file=out) + + +if __name__ == '__main__': + main() diff --git a/tools/art_converter.py b/tools/art_converter.py new file mode 100644 index 0000000..f61e0b5 --- /dev/null +++ b/tools/art_converter.py @@ -0,0 +1,26 @@ +#!/usr/bin/env python3 + +import json + +with open('train_task2_labels.json', 'r', encoding='utf8') as f: + d = json.load(f) + +with open('gt.txt', 'w', encoding='utf8') as f: + for k, v in d.items(): + if len(v) != 1: + print('error', v) + v = v[0] + if v['language'].lower() != 'latin': + # print('Skipping non-Latin:', v) + continue + if v['illegibility']: + # print('Skipping unreadable:', v) + continue + label = v['transcription'].strip() + if not label: + # print('Skipping blank label') + continue + if '#' in label and label != 'LocaL#3': + # print('Skipping corrupted label') + continue + f.write('\t'.join(['train_task2_images/' + k + '.jpg', label]) + '\n') diff --git a/tools/case_sensitive_str_datasets_converter.py b/tools/case_sensitive_str_datasets_converter.py new file mode 100644 index 0000000..7ce7c0d --- /dev/null +++ b/tools/case_sensitive_str_datasets_converter.py @@ -0,0 +1,28 @@ +#!/usr/bin/env python3 + +import os.path +import 
sys +from pathlib import Path + +d = sys.argv[1] +p = Path(d) + +gt = [] + +num_samples = len(list(p.glob('label/*.txt'))) +ext = 'jpg' if p.joinpath('IMG', '1.jpg').is_file() else 'png' + +for i in range(1, num_samples + 1): + img = p.joinpath('IMG', f'{i}.{ext}') + name = os.path.splitext(img.name)[0] + + with open(p.joinpath('label', f'{i}.txt'), 'r') as f: + label = f.readline() + gt.append((os.path.join('IMG', img.name), label)) + +with open(d + '/lmdb.txt', 'w', encoding='utf-8') as f: + for line in gt: + fname, label = line + fname = fname.strip() + label = label.strip() + f.write('\t'.join([fname, label]) + '\n') diff --git a/tools/coco_2_converter.py b/tools/coco_2_converter.py new file mode 100644 index 0000000..848d9a1 --- /dev/null +++ b/tools/coco_2_converter.py @@ -0,0 +1,126 @@ +#!/usr/bin/env python3 +import argparse +import html +import math +import os +import os.path as osp +from functools import partial + +import mmcv +from PIL import Image +from mmocr.utils.fileio import list_to_file + + +def parse_args(): + parser = argparse.ArgumentParser( + description='Generate training and validation set of TextOCR ' + 'by cropping box image.') + parser.add_argument('root_path', help='Root dir path of TextOCR') + parser.add_argument( + 'n_proc', default=1, type=int, help='Number of processes to run') + args = parser.parse_args() + return args + + +def process_img(args, src_image_root, dst_image_root): + # Dirty hack for multiprocessing + img_idx, img_info, anns = args + src_img = Image.open(osp.join(src_image_root, 'train2014', img_info['file_name'])) + src_w, src_h = src_img.size + labels = [] + for ann_idx, ann in enumerate(anns): + text_label = html.unescape(ann['utf8_string'].strip()) + + # Ignore empty labels + if not text_label or ann['class'] != 'machine printed' or ann['language'] != 'english' or \ + ann['legibility'] != 'legible': + continue + + # Some labels and images with '#' in the middle are actually good, but some aren't, so we just filter them all. 
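+        # (A bare '#' label is kept; only labels that merely contain '#' are dropped, and labels bounded by '*' are treated as unreadable below.)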
+ if text_label != '#' and '#' in text_label: + continue + + # Some labels use '*' to denote unreadable characters + if text_label.startswith('*') or text_label.endswith('*'): + continue + + pad = 2 + x, y, w, h = ann['bbox'] + x, y = max(0, math.floor(x) - pad), max(0, math.floor(y) - pad) + w, h = math.ceil(w), math.ceil(h) + x2, y2 = min(src_w, x + w + 2 * pad), min(src_h, y + h + 2 * pad) + dst_img = src_img.crop((x, y, x2, y2)) + dst_img_name = f'img_{img_idx}_{ann_idx}.jpg' + dst_img_path = osp.join(dst_image_root, dst_img_name) + # Preserve JPEG quality + dst_img.save(dst_img_path, qtables=src_img.quantization) + labels.append(f'{osp.basename(dst_image_root)}/{dst_img_name}' + f' {text_label}') + src_img.close() + return labels + + +def convert_textocr(root_path, + dst_image_path, + dst_label_filename, + annotation_filename, + img_start_idx=0, + nproc=1): + annotation_path = osp.join(root_path, annotation_filename) + if not osp.exists(annotation_path): + raise Exception( + f'{annotation_path} not exists, please check and try again.') + src_image_root = root_path + + # outputs + dst_label_file = osp.join(root_path, dst_label_filename) + dst_image_root = osp.join(root_path, dst_image_path) + os.makedirs(dst_image_root, exist_ok=True) + + annotation = mmcv.load(annotation_path) + split = 'train' if 'train' in dst_label_filename else 'val' + + process_img_with_path = partial( + process_img, + src_image_root=src_image_root, + dst_image_root=dst_image_root) + tasks = [] + for img_idx, img_info in enumerate(annotation['imgs'].values()): + if img_info['set'] != split: + continue + ann_ids = annotation['imgToAnns'][str(img_info['id'])] + anns = [annotation['anns'][str(ann_id)] for ann_id in ann_ids] + tasks.append((img_idx + img_start_idx, img_info, anns)) + + labels_list = mmcv.track_parallel_progress( + process_img_with_path, tasks, keep_order=True, nproc=nproc) + final_labels = [] + for label_list in labels_list: + final_labels += label_list + list_to_file(dst_label_file, final_labels) + return len(annotation['imgs']) + + +def main(): + args = parse_args() + root_path = args.root_path + print('Processing training set...') + num_train_imgs = convert_textocr( + root_path=root_path, + dst_image_path='image', + dst_label_filename='train_label.txt', + annotation_filename='cocotext.v2.json', + nproc=args.n_proc) + print('Processing validation set...') + convert_textocr( + root_path=root_path, + dst_image_path='image_val', + dst_label_filename='val_label.txt', + annotation_filename='cocotext.v2.json', + img_start_idx=num_train_imgs, + nproc=args.n_proc) + print('Finish') + + +if __name__ == '__main__': + main() diff --git a/tools/coco_text_converter.py b/tools/coco_text_converter.py new file mode 100644 index 0000000..09d130d --- /dev/null +++ b/tools/coco_text_converter.py @@ -0,0 +1,15 @@ +#!/usr/bin/env python3 + +for s in ['train', 'val']: + with open('{}_words_gt.txt'.format(s), 'r', encoding='utf8') as f: + d = f.readlines() + + with open('{}_lmdb.txt'.format(s), 'w', encoding='utf8') as f: + for line in d: + try: + fname, label = line.split(',', maxsplit=1) + except ValueError: + continue + fname = '{}_words/{}.jpg'.format(s, fname.strip()) + label = label.strip().strip('|') + f.write('\t'.join([fname, label]) + '\n') diff --git a/tools/create_lmdb_dataset.py b/tools/create_lmdb_dataset.py new file mode 100644 index 0000000..8d6a669 --- /dev/null +++ b/tools/create_lmdb_dataset.py @@ -0,0 +1,132 @@ +#!/usr/bin/env python3 +""" a modified version of CRNN torch repository 
https://github.com/bgshih/crnn/blob/master/tool/create_dataset.py """ +import io +import os + +import fire +import lmdb +import numpy as np +from PIL import Image + + +def checkImageIsValid(imageBin): + if imageBin is None: + return False + img = Image.open(io.BytesIO(imageBin)).convert('RGB') + return np.prod(img.size) > 0 + + +def writeCache(env, cache): + with env.begin(write=True) as txn: + for k, v in cache.items(): + txn.put(k, v) + + +# def createDataset(inputPath, gtFile, outputPath, checkValid=True): +# """ +# Create LMDB dataset for training and evaluation. +# ARGS: +# inputPath : input folder path where starts imagePath +# outputPath : LMDB output path +# gtFile : list of image path and label +# checkValid : if true, check the validity of every image +# """ +# os.makedirs(outputPath, exist_ok=True) +# env = lmdb.open(outputPath, map_size=1099511627776) + +# cache = {} +# cnt = 1 + +# with open(gtFile, 'r', encoding='utf-8') as f: +# data = f.readlines() + +# nSamples = len(data) +# for i, line in enumerate(data): +# imagePath, label = line.strip().split(maxsplit=1) +# imagePath = imagePath.replace("\\", "/") +# imagePath = os.path.join(inputPath, imagePath) +# with open(imagePath, 'rb') as f: +# imageBin = f.read() +# if checkValid: +# try: +# img = Image.open(io.BytesIO(imageBin)).convert('RGB') +# except IOError as e: +# with open(outputPath + '/error_image_log.txt', 'a') as log: +# log.write('{}-th image data occured error: {}, {}\n'.format(i, imagePath, e)) +# continue +# if np.prod(img.size) == 0: +# print('%s is not a valid image' % imagePath) +# continue + +# imageKey = 'image-%09d'.encode() % cnt +# labelKey = 'label-%09d'.encode() % cnt +# cache[imageKey] = imageBin +# cache[labelKey] = label.encode() + +# if cnt % 1000 == 0: +# writeCache(env, cache) +# cache = {} +# print('Written %d / %d' % (cnt, nSamples)) +# cnt += 1 +# nSamples = cnt - 1 +# cache['num-samples'.encode()] = str(nSamples).encode() +# writeCache(env, cache) +# env.close() +# print('Created dataset with %d samples' % nSamples) + +def createDataset(inputPath, outputPath, checkValid=True): + """ + Create LMDB dataset for training and evaluation. 
+ ARGS: + inputPath : input folder path where starts imagePath + outputPath : LMDB output path + gtFile : list of image path and label + checkValid : if true, check the validity of every image + """ + os.makedirs(outputPath, exist_ok=True) + env = lmdb.open(outputPath, map_size=1099511627776) + + cache = {} + cnt = 1 + + data = os.listdir(os.path.join(inputPath, "test_image")) + + nSamples = len(data) + for i, d in enumerate(data): + imagePath = d.replace("\\", "/") + imagePath = os.path.join(inputPath, "test_image", imagePath) + with open(imagePath, 'rb') as f: + imageBin = f.read() + if checkValid: + try: + img = Image.open(io.BytesIO(imageBin)).convert('RGB') + except IOError as e: + with open(outputPath + '/error_image_log.txt', 'a') as log: + log.write('{}-th image data occured error: {}, {}\n'.format(i, imagePath, e)) + continue + if np.prod(img.size) == 0: + print('%s is not a valid image' % imagePath) + continue + + imageKey = 'image-%09d'.encode() % cnt + labelKey = 'label-%09d'.encode() % cnt + pathkey = 'path-%09d'.encode() % cnt + cache[imageKey] = imageBin + cache[labelKey] = "null character".encode() + cache[pathkey] = ("test_image\\" + d).encode() + + if cnt % 1000 == 0: + writeCache(env, cache) + cache = {} + print('Written %d / %d' % (cnt, nSamples)) + cnt += 1 + nSamples = cnt - 1 + cache['num-samples'.encode()] = str(nSamples).encode() + writeCache(env, cache) + env.close() + print('Created dataset with %d samples' % nSamples) + + + +if __name__ == '__main__': + fire.Fire(createDataset) diff --git a/tools/filter_lmdb.py b/tools/filter_lmdb.py new file mode 100644 index 0000000..0d1b445 --- /dev/null +++ b/tools/filter_lmdb.py @@ -0,0 +1,57 @@ +#!/usr/bin/env python3 +import io +import os +from argparse import ArgumentParser + +import numpy as np +import lmdb +from PIL import Image + + +def main(): + parser = ArgumentParser() + parser.add_argument('inputs', nargs='+', help='Path to input LMDBs') + parser.add_argument('--output', help='Path to output LMDB') + parser.add_argument('--min_image_dim', type=int, default=8) + args = parser.parse_args() + + os.makedirs(args.output, exist_ok=True) + with lmdb.open(args.output, map_size=1099511627776) as env_out: + in_samples = 0 + out_samples = 0 + samples_per_chunk = 1000 + for lmdb_in in args.inputs: + with lmdb.open(lmdb_in, readonly=True, max_readers=1, lock=False) as env_in: + with env_in.begin() as txn: + num_samples = int(txn.get('num-samples'.encode())) + in_samples += num_samples + chunks = np.array_split(range(num_samples), num_samples // samples_per_chunk) + for chunk in chunks: + cache = {} + with env_in.begin() as txn: + for index in chunk: + index += 1 # lmdb starts at 1 + image_key = f'image-{index:09d}'.encode() + image_bin = txn.get(image_key) + img = Image.open(io.BytesIO(image_bin)) + w, h = img.size + if w < args.min_image_dim or h < args.min_image_dim: + print(f'Skipping: {index}, w = {w}, h = {h}') + continue + out_samples += 1 # increment. 
start at 1 + label_key = f'label-{index:09d}'.encode() + out_label_key = f'label-{out_samples:09d}'.encode() + out_image_key = f'image-{out_samples:09d}'.encode() + cache[out_label_key] = txn.get(label_key) + cache[out_image_key] = image_bin + with env_out.begin(write=True) as txn: + for k, v in cache.items(): + txn.put(k, v) + print(f'Written samples from {chunk[0]} to {chunk[-1]}') + with env_out.begin(write=True) as txn: + txn.put('num-samples'.encode(), str(out_samples).encode()) + print(f'Written {out_samples} samples to {args.output} out of {in_samples} input samples.') + + +if __name__ == '__main__': + main() diff --git a/tools/lsvt_converter.py b/tools/lsvt_converter.py new file mode 100644 index 0000000..9d19d24 --- /dev/null +++ b/tools/lsvt_converter.py @@ -0,0 +1,107 @@ +#!/usr/bin/env python3 +import argparse +import os +import os.path as osp +import re +from functools import partial + +import mmcv +import numpy as np +from PIL import Image +from mmocr.utils.fileio import list_to_file + + +def parse_args(): + parser = argparse.ArgumentParser( + description='Generate training set of LSVT ' + 'by cropping box image.') + parser.add_argument('root_path', help='Root dir path of LSVT') + parser.add_argument( + 'n_proc', default=1, type=int, help='Number of processes to run') + args = parser.parse_args() + return args + + +def process_img(args, src_image_root, dst_image_root): + # Dirty hack for multiprocessing + img_idx, img_info, anns = args + try: + src_img = Image.open(osp.join(src_image_root, 'train_full_images_0/{}.jpg'.format(img_info))) + except IOError: + src_img = Image.open(osp.join(src_image_root, 'train_full_images_1/{}.jpg'.format(img_info))) + blacklist = ['LOFTINESS*'] + whitelist = ['#Find YOUR Fun#', 'Story #', '*0#'] + labels = [] + for ann_idx, ann in enumerate(anns): + text_label = ann['transcription'] + + # Ignore illegible or words with non-Latin characters + if ann['illegibility'] or re.findall(r'[\u4e00-\u9fff]+', text_label) or text_label in blacklist or \ + ('#' in text_label and text_label not in whitelist): + continue + + points = np.asarray(ann['points']) + x1, y1 = points.min(axis=0) + x2, y2 = points.max(axis=0) + + dst_img = src_img.crop((x1, y1, x2, y2)) + dst_img_name = f'img_{img_idx}_{ann_idx}.jpg' + dst_img_path = osp.join(dst_image_root, dst_img_name) + # Preserve JPEG quality + dst_img.save(dst_img_path, qtables=src_img.quantization) + labels.append(f'{osp.basename(dst_image_root)}/{dst_img_name}' + f' {text_label}') + src_img.close() + return labels + + +def convert_lsvt(root_path, + dst_image_path, + dst_label_filename, + annotation_filename, + img_start_idx=0, + nproc=1): + annotation_path = osp.join(root_path, annotation_filename) + if not osp.exists(annotation_path): + raise Exception( + f'{annotation_path} not exists, please check and try again.') + src_image_root = root_path + + # outputs + dst_label_file = osp.join(root_path, dst_label_filename) + dst_image_root = osp.join(root_path, dst_image_path) + os.makedirs(dst_image_root, exist_ok=True) + + annotation = mmcv.load(annotation_path) + + process_img_with_path = partial( + process_img, + src_image_root=src_image_root, + dst_image_root=dst_image_root) + tasks = [] + for img_idx, (img_info, anns) in enumerate(annotation.items()): + tasks.append((img_idx + img_start_idx, img_info, anns)) + labels_list = mmcv.track_parallel_progress( + process_img_with_path, tasks, keep_order=True, nproc=nproc) + final_labels = [] + for label_list in labels_list: + final_labels += label_list + 
list_to_file(dst_label_file, final_labels) + return len(annotation) + + +def main(): + args = parse_args() + root_path = args.root_path + print('Processing training set...') + convert_lsvt( + root_path=root_path, + dst_image_path='image_train', + dst_label_filename='train_label.txt', + annotation_filename='train_full_labels.json', + nproc=args.n_proc) + print('Finish') + + +if __name__ == '__main__': + main() diff --git a/tools/mlt19_converter.py b/tools/mlt19_converter.py new file mode 100644 index 0000000..665d497 --- /dev/null +++ b/tools/mlt19_converter.py @@ -0,0 +1,15 @@ +#!/usr/bin/env python3 + +import sys + +root = sys.argv[1] + +with open(root + '/gt.txt', 'r') as f: + d = f.readlines() + +with open(root + '/lmdb.txt', 'w') as f: + for line in d: + img, script, label = line.split(',', maxsplit=2) + label = label.strip() + if label and script in ['Latin', 'Symbols']: + f.write('\t'.join([img, label]) + '\n') diff --git a/tools/openvino_converter.py b/tools/openvino_converter.py new file mode 100644 index 0000000..e6ce568 --- /dev/null +++ b/tools/openvino_converter.py @@ -0,0 +1,116 @@ +#!/usr/bin/env python3 +import math +import os +import os.path as osp +from argparse import ArgumentParser +from functools import partial + +import mmcv +from PIL import Image + +from mmocr.utils.fileio import list_to_file + + +def parse_args(): + parser = ArgumentParser(description='Generate training and validation set ' + 'of OpenVINO annotations for Open ' + 'Images by cropping box image.') + parser.add_argument( + 'root_path', help='Root dir containing images and annotations') + parser.add_argument( + 'n_proc', default=1, type=int, help='Number of processes to run') + args = parser.parse_args() + return args + + +def process_img(args, src_image_root, dst_image_root): + # Dirty hack for multiprocessing + img_idx, img_info, anns = args + src_img = Image.open(osp.join(src_image_root, img_info['file_name'])) + labels = [] + for ann_idx, ann in enumerate(anns): + attrs = ann['attributes'] + text_label = attrs['transcription'] + + # Ignore illegible or non-English words + if not attrs['legible'] or attrs['language'] != 'english': + continue + + x, y, w, h = ann['bbox'] + x, y = max(0, math.floor(x)), max(0, math.floor(y)) + w, h = math.ceil(w), math.ceil(h) + dst_img = src_img.crop((x, y, x + w, y + h)) + dst_img_name = f'img_{img_idx}_{ann_idx}.jpg' + dst_img_path = osp.join(dst_image_root, dst_img_name) + # Preserve JPEG quality + dst_img.save(dst_img_path, qtables=src_img.quantization) + labels.append(f'{osp.basename(dst_image_root)}/{dst_img_name}' + f' {text_label}') + src_img.close() + return labels + + +def convert_openimages(root_path, + dst_image_path, + dst_label_filename, + annotation_filename, + img_start_idx=0, + nproc=1): + annotation_path = osp.join(root_path, annotation_filename) + if not osp.exists(annotation_path): + raise Exception( + f'{annotation_path} not exists, please check and try again.') + src_image_root = root_path + + # outputs + dst_label_file = osp.join(root_path, dst_label_filename) + dst_image_root = osp.join(root_path, dst_image_path) + os.makedirs(dst_image_root, exist_ok=True) + + annotation = mmcv.load(annotation_path) + + process_img_with_path = partial( + process_img, + src_image_root=src_image_root, + dst_image_root=dst_image_root) + tasks = [] + anns = {} + for ann in annotation['annotations']: + anns.setdefault(ann['image_id'], []).append(ann) + for img_idx, img_info in enumerate(annotation['images']): + tasks.append((img_idx + img_start_idx, img_info, 
anns[img_info['id']])) + labels_list = mmcv.track_parallel_progress( + process_img_with_path, tasks, keep_order=True, nproc=nproc) + final_labels = [] + for label_list in labels_list: + final_labels += label_list + list_to_file(dst_label_file, final_labels) + return len(annotation['images']) + + +def main(): + args = parse_args() + root_path = args.root_path + print('Processing training set...') + num_train_imgs = 0 + for s in '125f': + num_train_imgs = convert_openimages( + root_path=root_path, + dst_image_path=f'image_{s}', + dst_label_filename=f'train_{s}_label.txt', + annotation_filename=f'text_spotting_openimages_v5_train_{s}.json', + img_start_idx=num_train_imgs, + nproc=args.n_proc) + print('Processing validation set...') + convert_openimages( + root_path=root_path, + dst_image_path='image_val', + dst_label_filename='val_label.txt', + annotation_filename='text_spotting_openimages_v5_validation.json', + img_start_idx=num_train_imgs, + nproc=args.n_proc) + print('Finish') + + +if __name__ == '__main__': + main() diff --git a/tools/test_abinet_lm_acc.py b/tools/test_abinet_lm_acc.py new file mode 100644 index 0000000..a411b39 --- /dev/null +++ b/tools/test_abinet_lm_acc.py @@ -0,0 +1,101 @@ +#!/usr/bin/env python3 +import argparse +import string +import sys + +import torch +import torch.nn.functional as F +from torch import Tensor +from torch.nn.utils.rnn import pad_sequence + +from tqdm import tqdm + +from strhub.data.module import SceneTextDataModule +from strhub.models.abinet.system import ABINet + +sys.path.insert(0, '.') +from hubconf import _get_config +from test import Result, print_results_table + + +class ABINetLM(ABINet): + + def _encode(self, labels): + targets = [torch.arange(self.max_label_length + 1)] # dummy target. used to set pad_sequence() length + lengths = [] + for label in labels: + targets.append(torch.as_tensor([self.tokenizer._stoi[c] for c in label])) + lengths.append(len(label) + 1) + targets = pad_sequence(targets, batch_first=True, padding_value=0)[1:] # exclude dummy target + lengths = torch.as_tensor(lengths, device=self.device) + targets = F.one_hot(targets, len(self.tokenizer._stoi))[..., :len(self.tokenizer._stoi) - 2].float().to(self.device) + return targets, lengths + + def forward(self, labels: Tensor, max_length: int = None) -> Tensor: + targets, lengths = self._encode(labels) + return self.model.language(targets, lengths)['logits'] + + +def main(): + parser = argparse.ArgumentParser(description='Measure the word accuracy of ABINet LM using the ground truth as input') + parser.add_argument('checkpoint', help='Official pretrained weights for ABINet-LV (best-train-abinet.pth)') + parser.add_argument('--data_root', default='data') + parser.add_argument('--batch_size', type=int, default=512) + parser.add_argument('--num_workers', type=int, default=4) + parser.add_argument('--new', action='store_true', default=False, help='Evaluate on new benchmark datasets') + parser.add_argument('--device', default='cuda') + args = parser.parse_args() + + # charset used by original ABINet + charset = string.ascii_lowercase + '1234567890' + ckpt = torch.load(args.checkpoint) + + config = _get_config('abinet', charset_train=charset, charset_test=charset) + model = ABINetLM(**config) + model.model.load_state_dict(ckpt['model']) + + model = model.eval().to(args.device) + model.freeze() # disable autograd + hp = model.hparams + datamodule = SceneTextDataModule(args.data_root, '_unused_', hp.img_size, hp.max_label_length, hp.charset_train, + hp.charset_test, args.batch_size, 
args.num_workers, False) + + test_set = SceneTextDataModule.TEST_BENCHMARK + if args.new: + test_set += SceneTextDataModule.TEST_NEW + test_set = sorted(set(test_set)) + + results = {} + max_width = max(map(len, test_set)) + for name, dataloader in datamodule.test_dataloaders(test_set).items(): + total = 0 + correct = 0 + ned = 0 + confidence = 0 + label_length = 0 + for _, labels in tqdm(iter(dataloader), desc=f'{name:>{max_width}}'): + res = model.test_step((labels, labels), -1)['output'] + total += res.num_samples + correct += res.correct + ned += res.ned + confidence += res.confidence + label_length += res.label_length + accuracy = 100 * correct / total + mean_ned = 100 * (1 - ned / total) + mean_conf = 100 * confidence / total + mean_label_length = label_length / total + results[name] = Result(name, total, accuracy, mean_ned, mean_conf, mean_label_length) + + result_groups = { + 'Benchmark': SceneTextDataModule.TEST_BENCHMARK + } + if args.new: + result_groups.update({'New': SceneTextDataModule.TEST_NEW}) + for group, subset in result_groups.items(): + print(f'{group} set:') + print_results_table([results[s] for s in subset]) + print('\n') + + +if __name__ == '__main__': + main() diff --git a/tools/textocr_converter.py b/tools/textocr_converter.py new file mode 100644 index 0000000..1bb2e9d --- /dev/null +++ b/tools/textocr_converter.py @@ -0,0 +1,142 @@ +#!/usr/bin/env python3 +# Copyright (c) OpenMMLab. All rights reserved. +import argparse +import math +import os +import os.path as osp +from functools import partial + +import mmcv +import numpy as np +from PIL import Image +from mmocr.utils.fileio import list_to_file + + +def parse_args(): + parser = argparse.ArgumentParser( + description='Generate training and validation set of TextOCR ' + 'by cropping box image.') + parser.add_argument('root_path', help='Root dir path of TextOCR') + parser.add_argument( + 'n_proc', default=1, type=int, help='Number of processes to run') + parser.add_argument('--rectify_pose', action='store_true', + help='Fix pose of rotated text to make them horizontal') + args = parser.parse_args() + return args + + +def rectify_image_pose(image, top_left, points): + # Points-based heuristics for determining text orientation w.r.t. 
bounding box + points = np.asarray(points).reshape(-1, 2) + dist = ((points - np.asarray(top_left)) ** 2).sum(axis=1) + left_midpoint = (points[0] + points[-1]) / 2 + right_corner_points = ((points - left_midpoint) ** 2).sum(axis=1).argsort()[-2:] + right_midpoint = points[right_corner_points].sum(axis=0) / 2 + d_x, d_y = abs(right_midpoint - left_midpoint) + + if dist[0] + dist[-1] <= dist[right_corner_points].sum(): + if d_x >= d_y: + rot = 0 + else: + rot = 90 + else: + if d_x >= d_y: + rot = 180 + else: + rot = -90 + if rot: + image = image.rotate(rot, expand=True) + return image + + +def process_img(args, src_image_root, dst_image_root): + # Dirty hack for multiprocessing + img_idx, img_info, anns, rectify_pose = args + src_img = Image.open(osp.join(src_image_root, img_info['file_name'])) + labels = [] + for ann_idx, ann in enumerate(anns): + text_label = ann['utf8_string'] + + # Ignore illegible or non-English words + if text_label == '.': + continue + + x, y, w, h = ann['bbox'] + x, y = max(0, math.floor(x)), max(0, math.floor(y)) + w, h = math.ceil(w), math.ceil(h) + dst_img = src_img.crop((x, y, x + w, y + h)) + if rectify_pose: + dst_img = rectify_image_pose(dst_img, (x, y), ann['points']) + dst_img_name = f'img_{img_idx}_{ann_idx}.jpg' + dst_img_path = osp.join(dst_image_root, dst_img_name) + # Preserve JPEG quality + dst_img.save(dst_img_path, qtables=src_img.quantization) + labels.append(f'{osp.basename(dst_image_root)}/{dst_img_name}' + f' {text_label}') + src_img.close() + return labels + + +def convert_textocr(root_path, + dst_image_path, + dst_label_filename, + annotation_filename, + img_start_idx=0, + nproc=1, + rectify_pose=False): + annotation_path = osp.join(root_path, annotation_filename) + if not osp.exists(annotation_path): + raise Exception( + f'{annotation_path} not exists, please check and try again.') + src_image_root = root_path + + # outputs + dst_label_file = osp.join(root_path, dst_label_filename) + dst_image_root = osp.join(root_path, dst_image_path) + os.makedirs(dst_image_root, exist_ok=True) + + annotation = mmcv.load(annotation_path) + + process_img_with_path = partial( + process_img, + src_image_root=src_image_root, + dst_image_root=dst_image_root) + tasks = [] + for img_idx, img_info in enumerate(annotation['imgs'].values()): + ann_ids = annotation['imgToAnns'][img_info['id']] + anns = [annotation['anns'][ann_id] for ann_id in ann_ids] + tasks.append((img_idx + img_start_idx, img_info, anns, rectify_pose)) + labels_list = mmcv.track_parallel_progress( + process_img_with_path, tasks, keep_order=True, nproc=nproc) + final_labels = [] + for label_list in labels_list: + final_labels += label_list + list_to_file(dst_label_file, final_labels) + return len(annotation['imgs']) + + +def main(): + args = parse_args() + root_path = args.root_path + print('Processing training set...') + num_train_imgs = convert_textocr( + root_path=root_path, + dst_image_path='image', + dst_label_filename='train_label.txt', + annotation_filename='TextOCR_0.1_train.json', + nproc=args.n_proc, + rectify_pose=args.rectify_pose) + print('Processing validation set...') + convert_textocr( + root_path=root_path, + dst_image_path='image', + dst_label_filename='val_label.txt', + annotation_filename='TextOCR_0.1_val.json', + img_start_idx=num_train_imgs, + nproc=args.n_proc, + rectify_pose=args.rectify_pose) + print('Finish') + + +if __name__ == '__main__': + main() diff --git a/train.py b/train.py new file mode 100644 index 0000000..8ebbd07 --- /dev/null +++ b/train.py @@ -0,0 +1,109 @@ 
+#!/usr/bin/env python3 +# Scene Text Recognition Model Hub +# Copyright 2022 Darwin Bautista +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +import torch +from pathlib import Path + +from omegaconf import DictConfig, open_dict +import hydra +from hydra.core.hydra_config import HydraConfig + + +from pytorch_lightning import Trainer +#from swa import StochasticWeightAveraging +from pytorch_lightning.callbacks import ModelCheckpoint, StochasticWeightAveraging +from pytorch_lightning.loggers import TensorBoardLogger +from pytorch_lightning.strategies import DDPStrategy +from pytorch_lightning.utilities.model_summary import summarize + +from strhub.data.module import SceneTextDataModule +from strhub.models.base import BaseSystem +from strhub.models.utils import get_pretrained_weights + + +# Copied from OneCycleLR +def _annealing_cos(start, end, pct): + "Cosine anneal from `start` to `end` as pct goes from 0.0 to 1.0." + cos_out = math.cos(math.pi * pct) + 1 + return end + (start - end) / 2.0 * cos_out + + +def get_swa_lr_factor(warmup_pct, swa_epoch_start, div_factor=25, final_div_factor=1e4) -> float: + """Get the SWA LR factor for the given `swa_epoch_start`. Assumes OneCycleLR Scheduler.""" + total_steps = 1000 # Can be anything. We use 1000 for convenience. 
+    start_step = int(total_steps * warmup_pct) - 1
+    end_step = total_steps - 1
+    step_num = int(total_steps * swa_epoch_start) - 1
+    pct = (step_num - start_step) / (end_step - start_step)
+    return _annealing_cos(1, 1 / (div_factor * final_div_factor), pct)
+
+
+@hydra.main(config_path='configs', config_name='main', version_base='1.2')
+def main(config: DictConfig):
+    trainer_strategy = None
+    with open_dict(config):
+        # Resolve absolute path to data.root_dir
+        config.data.root_dir = hydra.utils.to_absolute_path(config.data.root_dir)
+        # Special handling for GPU-affected config
+        gpu = config.trainer.get('accelerator') == 'gpu'
+        devices = config.trainer.get('devices', 0)
+        if gpu:
+            # Use mixed-precision training
+            config.trainer.precision = 16  # 16-bit precision
+        if gpu and devices > 1:
+            # Use DDP
+            config.trainer.strategy = 'ddp'
+            # DDP optimizations
+            trainer_strategy = DDPStrategy(find_unused_parameters=True, gradient_as_bucket_view=True)
+            # Scale steps-based config: multi-GPU training increases the effective batch size,
+            # so val_check_interval (and max_steps) must be scaled down accordingly.
+            config.trainer.val_check_interval //= devices
+            if config.trainer.get('max_steps', -1) > 0:
+                config.trainer.max_steps //= devices
+
+    # Special handling for PARseq
+    if config.model.get('perm_mirrored', False):
+        assert config.model.perm_num % 2 == 0, 'perm_num should be even if perm_mirrored = True'
+
+    model: BaseSystem = hydra.utils.instantiate(config.model)
+    # If specified, use pretrained weights to initialize the model
+    if config.pretrained is not None:
+        load_state = torch.load(config.pretrained, map_location=model.device)
+        model_state = model.state_dict()
+        new_state = {}
+        # Copy each pretrained weight whose name and shape match the current model;
+        # keep the freshly initialized weight otherwise.
+        for k, v in model_state.items():
+            if k in load_state and v.shape == load_state[k].shape:
+                new_state[k] = load_state[k]
+            else:
+                new_state[k] = v
+        model.load_state_dict(new_state, strict=True)
+
+    print(summarize(model, max_depth=1 if model.hparams.name.startswith('parseq') else 2))
+
+    datamodule: SceneTextDataModule = hydra.utils.instantiate(config.data)
+
+    checkpoint = ModelCheckpoint(dirpath="./output/" + config.model.name, monitor='val_accuracy', mode='max',
+                                 save_top_k=3, save_last=True,
+                                 filename='{epoch}-{step}-{val_accuracy:.4f}-{val_NED:.4f}')
+
+    swa_epoch_start = 0.75
+    swa_lr = config.model.lr * get_swa_lr_factor(config.model.warmup_pct, swa_epoch_start)
+    swa = StochasticWeightAveraging(swa_lr, swa_epoch_start)
+    cwd = HydraConfig.get().runtime.output_dir  # if config.ckpt_path is None else str(Path(config.ckpt_path).parents[1].absolute())
+    trainer: Trainer = hydra.utils.instantiate(config.trainer, logger=TensorBoardLogger(cwd, '', '.'),
+                                               strategy=trainer_strategy, enable_model_summary=False,
+                                               callbacks=[checkpoint, swa])
+    trainer.fit(model, datamodule=datamodule, ckpt_path=config.ckpt_path)
+
+
+if __name__ == '__main__':
+    main()
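
The get_swa_lr_factor helper in train.py mirrors OneCycleLR's cosine annealing so that the constant SWA learning rate picks up roughly where the one-cycle schedule would be at swa_epoch_start. The following is a minimal sketch of that correspondence, assuming only that torch is installed; the helper bodies are copied from train.py above, while max_lr = 7e-4 and warmup_pct = 0.1 are placeholder values rather than the repository's configured ones.

#!/usr/bin/env python3
# Sanity check: get_swa_lr_factor() vs. torch's OneCycleLR with cosine annealing.
import math

import torch
from torch.optim.lr_scheduler import OneCycleLR


# Copied from train.py above.
def _annealing_cos(start, end, pct):
    cos_out = math.cos(math.pi * pct) + 1
    return end + (start - end) / 2.0 * cos_out


def get_swa_lr_factor(warmup_pct, swa_epoch_start, div_factor=25, final_div_factor=1e4):
    total_steps = 1000  # can be anything; 1000 for convenience
    start_step = int(total_steps * warmup_pct) - 1
    end_step = total_steps - 1
    step_num = int(total_steps * swa_epoch_start) - 1
    pct = (step_num - start_step) / (end_step - start_step)
    return _annealing_cos(1, 1 / (div_factor * final_div_factor), pct)


max_lr = 7e-4        # placeholder peak LR
warmup_pct = 0.1     # placeholder warmup fraction
swa_epoch_start = 0.75
total_steps = 1000

opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=max_lr)
sched = OneCycleLR(opt, max_lr=max_lr, total_steps=total_steps, pct_start=warmup_pct,
                   anneal_strategy='cos', div_factor=25, final_div_factor=1e4)

# Advance the scheduler to the step where SWA would start.
for _ in range(int(total_steps * swa_epoch_start) - 1):
    opt.step()
    sched.step()

lr_from_scheduler = sched.get_last_lr()[0]
lr_from_factor = max_lr * get_swa_lr_factor(warmup_pct, swa_epoch_start)
print(lr_from_scheduler, lr_from_factor)  # the two values should be nearly identical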
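The LMDB archives written by create_lmdb_dataset.py and filter_lmdb.py store a 'num-samples' entry plus 1-indexed 'image-%09d' and 'label-%09d' records. Below is a minimal read-back sketch for spot-checking such an archive, assuming lmdb and Pillow are installed; the archive path is supplied on the command line, and the script itself is illustrative rather than part of the repository.

#!/usr/bin/env python3
# Spot-check an LMDB written with the key scheme used by the tools above:
# 'num-samples', plus 1-indexed 'image-%09d' and 'label-%09d' entries.
import io
import sys

import lmdb
from PIL import Image

lmdb_path = sys.argv[1]  # path to the LMDB directory to inspect

with lmdb.open(lmdb_path, readonly=True, lock=False) as env:
    with env.begin() as txn:
        num_samples = int(txn.get('num-samples'.encode()))
        print(f'num-samples: {num_samples}')
        # Decode the first few records to verify that images and labels round-trip.
        for index in range(1, min(num_samples, 5) + 1):
            image_bin = txn.get(f'image-{index:09d}'.encode())
            label = txn.get(f'label-{index:09d}'.encode()).decode()
            img = Image.open(io.BytesIO(image_bin)).convert('RGB')
            print(f'{index}: size={img.size}, label={label!r}')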