forked from PaddlePaddle/PaddleOCR
* add sr model
* update for eval
* submit sr
* polish code
* polish code
* polish code
* update sr model
* update doc
* update doc
* update doc
* fix typo
* format code
* update metric
* fix export
Showing 24 changed files with 1,719 additions and 19 deletions.
@@ -0,0 +1,85 @@
Global:
  use_gpu: true
  epoch_num: 500
  log_smooth_window: 20
  print_batch_step: 10
  save_model_dir: ./output/sr/sr_tsrn_transformer_strock/
  save_epoch_step: 3
  # evaluation is run every 1000 iterations
  eval_batch_step: [0, 1000]
  cal_metric_during_train: False
  pretrained_model:
  checkpoints:
  save_inference_dir: sr_output
  use_visualdl: False
  infer_img: doc/imgs_words_en/word_52.png
  # for data or label process
  character_dict_path: ./train_data/srdata/english_decomposition.txt
  max_text_length: 100
  infer_mode: False
  use_space_char: False
  save_res_path: ./output/sr/predicts_gestalt.txt

Optimizer:
  name: Adam
  beta1: 0.5
  beta2: 0.999
  clip_norm: 0.25
  lr:
    learning_rate: 0.0001

Architecture:
  model_type: sr
  algorithm: Gestalt
  Transform:
    name: TSRN
    STN: True
    infer_mode: False

Loss:
  name: StrokeFocusLoss
  character_dict_path: ./train_data/srdata/english_decomposition.txt

PostProcess:
  name: None

Metric:
  name: SRMetric
  main_indicator: all

Train:
  dataset:
    name: LMDBDataSetSR
    data_dir: ./train_data/srdata/train
    transforms:
      - SRResize:
          imgH: 32
          imgW: 128
          down_sample_scale: 2
      - SRLabelEncode: # Class handling label
      - KeepKeys:
          keep_keys: ['img_lr', 'img_hr', 'length', 'input_tensor', 'label'] # dataloader will return list in this order
  loader:
    shuffle: False
    batch_size_per_card: 16
    drop_last: True
    num_workers: 4

Eval:
  dataset:
    name: LMDBDataSetSR
    data_dir: ./train_data/srdata/test
    transforms:
      - SRResize:
          imgH: 32
          imgW: 128
          down_sample_scale: 2
      - SRLabelEncode: # Class handling label
      - KeepKeys:
          keep_keys: ['img_lr', 'img_hr', 'length', 'input_tensor', 'label'] # dataloader will return list in this order
  loader:
    shuffle: False
    drop_last: False
    batch_size_per_card: 16
    num_workers: 4
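The `eval_batch_step: [0, 1000]` entry in the config above pairs a start iteration with an interval. The following is a minimal sketch of how such a schedule can be read, assuming evaluation fires on every multiple of the interval after the start iteration. `should_evaluate` is a hypothetical helper written for illustration, not a PaddleOCR function.

```python
def should_evaluate(global_step, eval_batch_step):
    """Return True when an evaluation pass is due at this training iteration.

    eval_batch_step is a [start_step, interval] pair, as in the config.
    Assumption: evaluation runs every `interval` iterations after `start_step`.
    """
    start_step, interval = eval_batch_step
    if global_step <= start_step:
        return False
    return (global_step - start_step) % interval == 0


# Iterations at which evaluation would run during the first 3000 steps.
eval_steps = [step for step in range(1, 3001) if should_evaluate(step, [0, 1000])]
print(eval_steps)
```

With `[0, 1000]`, the sketch schedules evaluation at iterations 1000, 2000, and 3000.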
@@ -0,0 +1,127 @@
# Text Gestalt

- [1. Introduction](#1)
- [2. Environment](#2)
- [3. Model Training / Evaluation / Prediction](#3)
    - [3.1 Training](#3-1)
    - [3.2 Evaluation](#3-2)
    - [3.3 Prediction](#3-3)
- [4. Inference and Deployment](#4)
    - [4.1 Python Inference](#4-1)
    - [4.2 C++ Inference](#4-2)
    - [4.3 Serving](#4-3)
    - [4.4 More Inference and Deployment](#4-4)
- [5. FAQ](#5)

<a name="1"></a>
## 1. Introduction

Paper:
> [Text Gestalt: Stroke-Aware Scene Text Image Super-Resolution](https://arxiv.org/pdf/2112.08171.pdf)
> Chen, Jingye and Yu, Haiyang and Ma, Jianqi and Li, Bin and Xue, Xiangyang
> AAAI, 2022

Following the data download instructions of [FudanOCR](https://github.com/FudanVI/FudanOCR/tree/main/text-gestalt), the super-resolution results on the TextZoom test set are as follows:

|Model|Backbone|PSNR_Avg|SSIM_Avg|Config|Download link|
|---|---|---|---|---|---|
|Text Gestalt|tsrn|19.28|0.6560| [configs/sr/sr_tsrn_transformer_strock.yml](../../configs/sr/sr_tsrn_transformer_strock.yml)|[trained model](https://paddleocr.bj.bcebos.com/sr_tsrn_transformer_strock_train.tar)|

<a name="2"></a>
## 2. Environment
Please first refer to ["Environment Preparation"](./environment.md) to configure the PaddleOCR environment, and to ["Project Clone"](./clone.md) to clone the project code.

<a name="3"></a>
## 3. Model Training / Evaluation / Prediction

Please refer to the [Text Recognition Training Tutorial](./recognition.md). PaddleOCR modularizes its code, so training a different recognition model only requires **changing the configuration file**.

<a name="3-1"></a>
### 3.1 Training

Once the data is prepared, start training with the following commands:

```
# Single-GPU training (long training time, not recommended)
python3 tools/train.py -c configs/sr/sr_tsrn_transformer_strock.yml
# Multi-GPU training; specify the GPU IDs with the --gpus parameter
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/sr/sr_tsrn_transformer_strock.yml
```

<a name="3-2"></a>
### 3.2 Evaluation

```
# GPU evaluation; Global.pretrained_model is the weights to evaluate
python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/sr/sr_tsrn_transformer_strock.yml -o Global.pretrained_model={path/to/weights}/best_accuracy
```

<a name="3-3"></a>
### 3.3 Prediction

```
# The configuration file used for prediction must match the one used for training
python3 tools/infer_sr.py -c configs/sr/sr_tsrn_transformer_strock.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.infer_img=doc/imgs_words_en/word_52.png
```

![](../imgs_words_en/word_52.png)

After executing the command, the super-resolution result of the above image is as follows:

![](../imgs_results/sr_word_52.png)

<a name="4"></a>
## 4. Inference and Deployment

<a name="4-1"></a>
### 4.1 Python Inference

First, convert the model saved during text super-resolution training into an inference model. Taking the [model](https://paddleocr.bj.bcebos.com/sr_tsrn_transformer_strock_train.tar) trained with Text-Gestalt as an example, the conversion command is:
```shell
python3 tools/export_model.py -c configs/sr/sr_tsrn_transformer_strock.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.save_inference_dir=./inference/sr_out
```
To run inference with the Text-Gestalt text super-resolution model, execute:
```
python3 tools/infer/predict_sr.py --sr_model_dir=./inference/sr_out --image_dir=doc/imgs_words_en/word_52.png --sr_image_shape=3,32,128
```

After executing the command, the super-resolution result of the image is as follows:

![](../imgs_results/sr_word_52.png)

<a name="4-2"></a>
### 4.2 C++ Inference

Not supported yet.

<a name="4-3"></a>
### 4.3 Serving

Not supported yet.

<a name="4-4"></a>
### 4.4 More Inference and Deployment

Not supported yet.

<a name="5"></a>
## 5. FAQ

## Citation

```bibtex
@inproceedings{chen2022text,
  title={Text gestalt: Stroke-aware scene text image super-resolution},
  author={Chen, Jingye and Yu, Haiyang and Ma, Jianqi and Li, Bin and Xue, Xiangyang},
  booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
  volume={36},
  number={1},
  pages={285--293},
  year={2022}
}
```
@@ -0,0 +1,136 @@
# Text Gestalt

- [1. Introduction](#1)
- [2. Environment](#2)
- [3. Model Training / Evaluation / Prediction](#3)
    - [3.1 Training](#3-1)
    - [3.2 Evaluation](#3-2)
    - [3.3 Prediction](#3-3)
- [4. Inference and Deployment](#4)
    - [4.1 Python Inference](#4-1)
    - [4.2 C++ Inference](#4-2)
    - [4.3 Serving](#4-3)
    - [4.4 More](#4-4)
- [5. FAQ](#5)

<a name="1"></a>
## 1. Introduction

Paper:
> [Text Gestalt: Stroke-Aware Scene Text Image Super-Resolution](https://arxiv.org/pdf/2112.08171.pdf)
> Chen, Jingye and Yu, Haiyang and Ma, Jianqi and Li, Bin and Xue, Xiangyang
> AAAI, 2022

Following the data download instructions of [FudanOCR](https://github.com/FudanVI/FudanOCR/tree/main/text-gestalt), the super-resolution results on the TextZoom test set are as follows:

|Model|Backbone|PSNR_Avg|SSIM_Avg|Config|Download link|
|---|---|---|---|---|---|
|Text Gestalt|tsrn|19.28|0.6560| [configs/sr/sr_tsrn_transformer_strock.yml](../../configs/sr/sr_tsrn_transformer_strock.yml)|[trained model](https://paddleocr.bj.bcebos.com/sr_tsrn_transformer_strock_train.tar)|
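
The 19.28 and 0.6560 figures in the table are the average PSNR and SSIM on the test set. As a reminder of what the first number measures, here is a minimal sketch of PSNR between a reference image and a restored image. It assumes pixel values in `[0, 1]` stored as flat lists; `psnr` is an illustrative helper, not the `SRMetric` implementation used by the config.

```python
import math


def psnr(reference, restored, max_value=1.0):
    """Peak signal-to-noise ratio in dB between two equally sized pixel lists."""
    if not reference or len(reference) != len(restored):
        raise ValueError("inputs must be non-empty and have the same length")
    # Mean squared error over all pixels.
    mse = sum((r - s) ** 2 for r, s in zip(reference, restored)) / len(reference)
    if mse == 0:
        return float("inf")  # identical images
    return 10.0 * math.log10(max_value ** 2 / mse)


# Toy data (hypothetical): every pixel is off by 0.1, so the MSE is 0.01.
hr_pixels = [0.0, 0.5, 1.0, 0.25]
sr_pixels = [0.1, 0.4, 0.9, 0.35]
print(round(psnr(hr_pixels, sr_pixels), 2))
```

Higher values mean the restored image is closer to the high-resolution reference; a dataset-level score such as PSNR_Avg is an average of per-image values.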

<a name="2"></a>
## 2. Environment
Please refer to ["Environment Preparation"](./environment_en.md) to configure the PaddleOCR environment, and refer to ["Project Clone"](./clone_en.md) to clone the project code.

<a name="3"></a>
## 3. Model Training / Evaluation / Prediction

Please refer to the [Text Recognition Tutorial](./recognition_en.md). PaddleOCR modularizes its code, so training a different model only requires **changing the configuration file**.

<a name="3-1"></a>
### 3.1 Training

Once the data is prepared, start training with the following commands:

```
# Single-GPU training (long training time, not recommended)
python3 tools/train.py -c configs/sr/sr_tsrn_transformer_strock.yml
# Multi-GPU training; specify the GPU IDs with the --gpus parameter
python3 -m paddle.distributed.launch --gpus '0,1,2,3' tools/train.py -c configs/sr/sr_tsrn_transformer_strock.yml
```

<a name="3-2"></a>
### 3.2 Evaluation

```
# GPU evaluation; Global.pretrained_model is the weights to evaluate
python3 -m paddle.distributed.launch --gpus '0' tools/eval.py -c configs/sr/sr_tsrn_transformer_strock.yml -o Global.pretrained_model={path/to/weights}/best_accuracy
```
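
The `-o Global.pretrained_model=...` argument in the command above overrides a single config entry addressed by a dotted path. The idea can be sketched on a plain dictionary as follows. `apply_overrides` is a hypothetical helper for illustration and not PaddleOCR's own argument parser; it keeps every value as a string, and the weights path is only an example.

```python
def apply_overrides(config, overrides):
    """Apply 'Section.key=value' overrides to a nested config dictionary."""
    for item in overrides:
        dotted_key, _, value = item.partition("=")
        keys = dotted_key.split(".")
        node = config
        # Walk (or create) the intermediate sections, then set the leaf.
        for key in keys[:-1]:
            node = node.setdefault(key, {})
        node[keys[-1]] = value
    return config


# A small slice of the config, with one entry overridden from the command line.
config = {
    "Global": {
        "pretrained_model": None,
        "infer_img": "doc/imgs_words_en/word_52.png",
    }
}
apply_overrides(
    config,
    ["Global.pretrained_model=./output/sr/sr_tsrn_transformer_strock/best_accuracy"],
)
print(config["Global"]["pretrained_model"])
```

Entries that are not named in an override, such as `Global.infer_img` here, keep the value from the configuration file.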

<a name="3-3"></a>
### 3.3 Prediction

```
# The configuration file used for prediction must match the one used for training
python3 tools/infer_sr.py -c configs/sr/sr_tsrn_transformer_strock.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.infer_img=doc/imgs_words_en/word_52.png
```

![](../imgs_words_en/word_52.png)

After executing the command, the super-resolution result of the above image is as follows:

![](../imgs_results/sr_word_52.png)

<a name="4"></a>
## 4. Inference and Deployment

<a name="4-1"></a>
### 4.1 Python Inference

First, convert the model saved during training into an inference model. Taking the released Text Gestalt [trained model](https://paddleocr.bj.bcebos.com/sr_tsrn_transformer_strock_train.tar) as an example, the conversion command is:

```shell
python3 tools/export_model.py -c configs/sr/sr_tsrn_transformer_strock.yml -o Global.pretrained_model={path/to/weights}/best_accuracy Global.save_inference_dir=./inference/sr_out
```

To run inference with the exported Text-Gestalt super-resolution model, execute:

```
python3 tools/infer/predict_sr.py --sr_model_dir=./inference/sr_out --image_dir=doc/imgs_words_en/word_52.png --sr_image_shape=3,32,128
```

After executing the command, the super-resolution result of the above image is as follows:

![](../imgs_results/sr_word_52.png)

<a name="4-2"></a>
### 4.2 C++ Inference

Not supported yet.

<a name="4-3"></a>
### 4.3 Serving

Not supported yet.

<a name="4-4"></a>
### 4.4 More

Not supported yet.

<a name="5"></a>
## 5. FAQ

## Citation

```bibtex
@inproceedings{chen2022text,
  title={Text gestalt: Stroke-aware scene text image super-resolution},
  author={Chen, Jingye and Yu, Haiyang and Ma, Jianqi and Li, Bin and Xue, Xiangyang},
  booktitle={Proceedings of the AAAI Conference on Artificial Intelligence},
  volume={36},
  number={1},
  pages={285--293},
  year={2022}
}
```