Commit

add re predict

WenmuZhou committed Sep 20, 2022
1 parent 0a59848 commit 0619452
Showing 22 changed files with 482 additions and 88 deletions.
4 changes: 2 additions & 2 deletions configs/kie/layoutlm_series/re_layoutxlm_xfund_zh.yml
@@ -68,6 +68,7 @@ Train:
- VQAReTokenRelation:
- VQAReTokenChunk:
max_seq_len: *max_seq_len
- TensorizeEntitiesRelations:
- Resize:
size: [224,224]
- NormalizeImage:
@@ -83,7 +84,6 @@ Train:
drop_last: False
batch_size_per_card: 2
num_workers: 8
collate_fn: ListCollator

Eval:
dataset:
@@ -105,6 +105,7 @@ Eval:
- VQAReTokenRelation:
- VQAReTokenChunk:
max_seq_len: *max_seq_len
- TensorizeEntitiesRelations:
- Resize:
size: [224,224]
- NormalizeImage:
@@ -120,4 +121,3 @@ Eval:
drop_last: False
batch_size_per_card: 8
num_workers: 8
collate_fn: ListCollator
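
Aside (not part of this diff): the `TensorizeEntitiesRelations` transform added above pads the variable-length entity and relation lists into fixed-shape arrays, which appears to be why the custom `collate_fn: ListCollator` can be dropped here; fixed-shape samples can be stacked by the default collator. A rough sketch of the idea with hypothetical data:

```python
import numpy as np

# Two samples with different numbers of entity start indices (hypothetical values).
starts = [[1, 7], [3]]
max_seq_len = 8  # the real configs use a much larger max_seq_len

def pad_starts(s, max_len):
    # Fixed shape per sample: row 0 holds the count, unused slots stay -1.
    arr = np.full((max_len + 1,), -1, dtype='int64')
    arr[0] = len(s)
    arr[1:1 + len(s)] = s
    return arr

batch = np.stack([pad_starts(s, max_seq_len) for s in starts])
print(batch.shape)  # (2, 9) -- stackable without a list-based collator
```
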
5 changes: 2 additions & 3 deletions configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh.yml
@@ -73,6 +73,7 @@ Train:
- VQAReTokenRelation:
- VQAReTokenChunk:
max_seq_len: *max_seq_len
- TensorizeEntitiesRelations:
- Resize:
size: [224,224]
- NormalizeImage:
@@ -88,7 +89,6 @@ Train:
drop_last: False
batch_size_per_card: 2
num_workers: 4
collate_fn: ListCollator

Eval:
dataset:
@@ -112,6 +112,7 @@ Eval:
- VQAReTokenRelation:
- VQAReTokenChunk:
max_seq_len: *max_seq_len
- TensorizeEntitiesRelations:
- Resize:
size: [224,224]
- NormalizeImage:
@@ -127,5 +128,3 @@ Eval:
drop_last: False
batch_size_per_card: 8
num_workers: 8
collate_fn: ListCollator

2 changes: 2 additions & 0 deletions configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh_udml.yml
@@ -116,6 +116,7 @@ Train:
- VQAReTokenRelation:
- VQAReTokenChunk:
max_seq_len: *max_seq_len
- TensorizeEntitiesRelations:
- Resize:
size: [224,224]
- NormalizeImage:
@@ -155,6 +156,7 @@ Eval:
- VQAReTokenRelation:
- VQAReTokenChunk:
max_seq_len: *max_seq_len
- TensorizeEntitiesRelations:
- Resize:
size: [224,224]
- NormalizeImage:
34 changes: 31 additions & 3 deletions doc/doc_ch/algorithm_kie_layoutxlm.md
@@ -30,7 +30,7 @@
|Model|Backbone|Task|Config|Hmean|Download link|
| --- | --- |--|--- | --- | --- |
|LayoutXLM|LayoutXLM-base|SER |[ser_layoutxlm_xfund_zh.yml](../../configs/kie/layoutlm_series/ser_layoutxlm_xfund_zh.yml)|90.38%|[trained model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar)/[inference model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh_infer.tar)|
|LayoutXLM|LayoutXLM-base|RE | [re_layoutxlm_xfund_zh.yml](../../configs/kie/layoutlm_series/re_layoutxlm_xfund_zh.yml)|74.83%|[trained model](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar)/[inference model (coming soon)]()|
|LayoutXLM|LayoutXLM-base|RE | [re_layoutxlm_xfund_zh.yml](../../configs/kie/layoutlm_series/re_layoutxlm_xfund_zh.yml)|74.83%|[trained model](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar)/[inference model](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh_infer.tar)|

<a name="2"></a>

@@ -52,14 +52,14 @@

### 4.1 Python Inference

**Note:** The RE task inference process is still being adapted. The following takes the SER task as an example to introduce key information extraction based on the LayoutXLM model.
- SER

First, convert the trained model into an inference model. Taking the LayoutXLM model trained on the XFUND_zh dataset as an example ([model download link](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar)), the conversion can be done with the following command.

``` bash
wget https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar
tar -xf ser_LayoutXLM_xfun_zh.tar
python3 tools/export_model.py -c configs/kie/layoutlm_series/ser_layoutxlm_xfund_zh.yml -o Architecture.Backbone.checkpoints=./ser_LayoutXLM_xfun_zh/best_accuracy Global.save_inference_dir=./inference/ser_layoutxlm
python3 tools/export_model.py -c configs/kie/layoutlm_series/ser_layoutxlm_xfund_zh.yml -o Architecture.Backbone.checkpoints=./ser_LayoutXLM_xfun_zh Global.save_inference_dir=./inference/ser_layoutxlm_infer
```

To run inference with the LayoutXLM model on the SER task, execute the following command:
@@ -80,6 +80,34 @@ The SER visualization results are saved in the `./output` folder by default; an example result is shown below
<img src="../../ppstructure/docs/kie/result_ser/zh_val_42_ser.jpg" width="800">
</div>

- RE

First, convert the trained model into an inference model. Taking the LayoutXLM model trained on the XFUND_zh dataset as an example ([model download link](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar)), the conversion can be done with the following command.

``` bash
wget https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar
tar -xf re_LayoutXLM_xfun_zh.tar
python3 tools/export_model.py -c configs/kie/layoutlm_series/re_layoutxlm_xfund_zh.yml -o Architecture.Backbone.checkpoints=./re_LayoutXLM_xfun_zh Global.save_inference_dir=./inference/re_layoutxlm_infer
```

To run inference with the LayoutXLM model on the RE task, execute the following command:

```bash
cd ppstructure
python3 kie/predict_kie_token_ser_re.py \
--kie_algorithm=LayoutXLM \
--re_model_dir=../inference/re_layoutxlm_infer \
--ser_model_dir=../inference/ser_layoutxlm_infer \
--image_dir=./docs/kie/input/zh_val_42.jpg \
--ser_dict_path=../train_data/XFUND/class_list_xfun.txt \
--vis_font_path=../doc/fonts/simfang.ttf
```

The RE visualization results are saved in the `./output` folder by default; an example result is shown below:

<div align="center">
<img src="../../ppstructure/docs/kie/result_re/zh_val_42_re.jpg" width="800">
</div>

<a name="4-2"></a>
### 4.2 C++推理部署
34 changes: 32 additions & 2 deletions doc/doc_ch/algorithm_kie_vi_layoutxlm.md
@@ -23,7 +23,7 @@ VI-LayoutXLM is an improvement on LayoutXLM; during downstream task training, ...
|Model|Backbone|Task|Config|Hmean|Download link|
| --- | --- |---| --- | --- | --- |
|VI-LayoutXLM |VI-LayoutXLM-base | SER |[ser_vi_layoutxlm_xfund_zh_udml.yml](../../configs/kie/vi_layoutxlm/ser_vi_layoutxlm_xfund_zh_udml.yml)|93.19%|[trained model](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_pretrained.tar)/[inference model](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_infer.tar)|
|VI-LayoutXLM |VI-LayoutXLM-base |RE | [re_vi_layoutxlm_xfund_zh_udml.yml](../../configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh_udml.yml)|83.92%|[trained model](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_pretrained.tar)/[inference model (coming soon)]()|
|VI-LayoutXLM |VI-LayoutXLM-base |RE | [re_vi_layoutxlm_xfund_zh_udml.yml](../../configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh_udml.yml)|83.92%|[trained model](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_pretrained.tar)/[inference model](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_infer.tar)|

<a name="2"></a>

@@ -45,7 +45,7 @@ VI-LayoutXLM is an improvement on LayoutXLM; during downstream task training, ...

### 4.1 Python Inference

**Note:** The RE task inference process is still being adapted. The following takes the SER task as an example to introduce key information extraction based on the VI-LayoutXLM model.
- SER

First, convert the trained model into an inference model. Taking the VI-LayoutXLM model trained on the XFUND_zh dataset as an example ([model download link](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_pretrained.tar)), the conversion can be done with the following command.

@@ -74,6 +74,36 @@ The SER visualization results are saved in the `./output` folder by default; an example result is shown below
<img src="../../ppstructure/docs/kie/result_ser/zh_val_42_ser.jpg" width="800">
</div>

- RE

First, convert the trained model into an inference model. Taking the VI-LayoutXLM model trained on the XFUND_zh dataset as an example ([model download link](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_pretrained.tar)), the conversion can be done with the following command.

``` bash
wget https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_pretrained.tar
tar -xf re_vi_layoutxlm_xfund_pretrained.tar
python3 tools/export_model.py -c configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh.yml -o Architecture.Backbone.checkpoints=./re_vi_layoutxlm_xfund_pretrained/best_accuracy Global.save_inference_dir=./inference/re_vi_layoutxlm_infer
```

To run inference with the VI-LayoutXLM model on the RE task, execute the following command:

```bash
cd ppstructure
python3 kie/predict_kie_token_ser_re.py \
--kie_algorithm=LayoutXLM \
--re_model_dir=../inference/re_vi_layoutxlm_infer \
--ser_model_dir=../inference/ser_vi_layoutxlm_infer \
--use_visual_backbone=False \
--image_dir=./docs/kie/input/zh_val_42.jpg \
--ser_dict_path=../train_data/XFUND/class_list_xfun.txt \
--vis_font_path=../doc/fonts/simfang.ttf \
--ocr_order_method="tb-yx"
```

The RE visualization results are saved in the `./output` folder by default; an example result is shown below:

<div align="center">
<img src="../../ppstructure/docs/kie/result_re/zh_val_42_re.jpg" width="800">
</div>

<a name="4-2"></a>
### 4.2 C++推理部署
38 changes: 35 additions & 3 deletions doc/doc_en/algorithm_kie_layoutxlm_en.md
@@ -28,7 +28,7 @@ On XFUND_zh dataset, the algorithm reproduction Hmean is as follows.
|Model|Backbone|Task |Config|Hmean|Download link|
| --- | --- |--|--- | --- | --- |
|LayoutXLM|LayoutXLM-base|SER |[ser_layoutxlm_xfund_zh.yml](../../configs/kie/layoutlm_series/ser_layoutxlm_xfund_zh.yml)|90.38%|[trained model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar)/[inference model](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh_infer.tar)|
|LayoutXLM|LayoutXLM-base|RE | [re_layoutxlm_xfund_zh.yml](../../configs/kie/layoutlm_series/re_layoutxlm_xfund_zh.yml)|74.83%|[trained model](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar)/[inference model(coming soon)]()|
|LayoutXLM|LayoutXLM-base|RE | [re_layoutxlm_xfund_zh.yml](../../configs/kie/layoutlm_series/re_layoutxlm_xfund_zh.yml)|74.83%|[trained model](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar)/[inference model](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh_infer.tar)|


## 2. Environment
@@ -46,15 +46,15 @@ Please refer to [KIE tutorial](./kie_en.md). PaddleOCR has modularized the code

### 4.1 Python Inference

**Note:** Currently, the RE model inference process is still in the process of adaptation. We take SER model as an example to introduce the KIE process based on LayoutXLM model.
- SER

First, we need to export the trained model into an inference model. Take the LayoutXLM model trained on XFUND_zh as an example ([trained model download link](https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar)). Use the following command to export.


``` bash
wget https://paddleocr.bj.bcebos.com/pplayout/ser_LayoutXLM_xfun_zh.tar
tar -xf ser_LayoutXLM_xfun_zh.tar
python3 tools/export_model.py -c configs/kie/layoutlm_series/ser_layoutxlm_xfund_zh.yml -o Architecture.Backbone.checkpoints=./ser_LayoutXLM_xfun_zh/best_accuracy Global.save_inference_dir=./inference/ser_layoutxlm
python3 tools/export_model.py -c configs/kie/layoutlm_series/ser_layoutxlm_xfund_zh.yml -o Architecture.Backbone.checkpoints=./ser_LayoutXLM_xfun_zh Global.save_inference_dir=./inference/ser_layoutxlm_infer
```

Use the following command to infer using the LayoutXLM SER model.
@@ -77,6 +77,38 @@ The SER visualization results are saved in the `./output` directory by default.
</div>


- RE

First, we need to export the trained model into an inference model. Take the LayoutXLM model trained on XFUND_zh as an example ([trained model download link](https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar)). Use the following command to export.


``` bash
wget https://paddleocr.bj.bcebos.com/pplayout/re_LayoutXLM_xfun_zh.tar
tar -xf re_LayoutXLM_xfun_zh.tar
python3 tools/export_model.py -c configs/kie/layoutlm_series/re_layoutxlm_xfund_zh.yml -o Architecture.Backbone.checkpoints=./re_LayoutXLM_xfun_zh Global.save_inference_dir=./inference/re_layoutxlm_infer
```

Use the following command to infer using the LayoutXLM RE model.


```bash
cd ppstructure
python3 kie/predict_kie_token_ser_re.py \
--kie_algorithm=LayoutXLM \
--re_model_dir=../inference/re_layoutxlm_infer \
--ser_model_dir=../inference/ser_layoutxlm_infer \
--image_dir=./docs/kie/input/zh_val_42.jpg \
--ser_dict_path=../train_data/XFUND/class_list_xfun.txt \
--vis_font_path=../doc/fonts/simfang.ttf
```
The RE visualization results are saved in the `./output` directory by default. The results are as follows.


<div align="center">
<img src="../../ppstructure/docs/kie/result_re/zh_val_42_re.jpg" width="800">
</div>


### 4.2 C++ Inference

Not supported
39 changes: 37 additions & 2 deletions doc/doc_en/algorithm_kie_vi_layoutxlm_en.md
@@ -22,7 +22,7 @@ On XFUND_zh dataset, the algorithm reproduction Hmean is as follows.
|Model|Backbone|Task |Config|Hmean|Download link|
| --- | --- |---| --- | --- | --- |
|VI-LayoutXLM |VI-LayoutXLM-base | SER |[ser_vi_layoutxlm_xfund_zh_udml.yml](../../configs/kie/vi_layoutxlm/ser_vi_layoutxlm_xfund_zh_udml.yml)|93.19%|[trained model](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_pretrained.tar)/[inference model](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_infer.tar)|
|VI-LayoutXLM |VI-LayoutXLM-base |RE | [re_vi_layoutxlm_xfund_zh_udml.yml](../../configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh_udml.yml)|83.92%|[trained model](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_pretrained.tar)/[inference model(coming soon)]()|
|VI-LayoutXLM |VI-LayoutXLM-base |RE | [re_vi_layoutxlm_xfund_zh_udml.yml](../../configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh_udml.yml)|83.92%|[trained model](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_pretrained.tar)/[inference model](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_infer.tar)|


Please refer to ["Environment Preparation"](./environment_en.md) to configure the PaddleOCR environment, and refer to ["Project Clone"](./clone_en.md) to clone the project code.
@@ -37,7 +37,7 @@ Please refer to [KIE tutorial](./kie_en.md). PaddleOCR has modularized the code

### 4.1 Python Inference

**Note:** Currently, the RE model inference process is still in the process of adaptation. We take SER model as an example to introduce the KIE process based on VI-LayoutXLM model.
- SER

First, we need to export the trained model into an inference model. Take the VI-LayoutXLM model trained on XFUND_zh as an example ([trained model download link](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/ser_vi_layoutxlm_xfund_pretrained.tar)). Use the following command to export.

@@ -70,6 +70,41 @@ The SER visualization results are saved in the `./output` folder by default. The
</div>


- RE

First, we need to export the trained model into an inference model. Take the VI-LayoutXLM model trained on XFUND_zh as an example ([trained model download link](https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_pretrained.tar)). Use the following command to export.


``` bash
wget https://paddleocr.bj.bcebos.com/ppstructure/models/vi_layoutxlm/re_vi_layoutxlm_xfund_pretrained.tar
tar -xf re_vi_layoutxlm_xfund_pretrained.tar
python3 tools/export_model.py -c configs/kie/vi_layoutxlm/re_vi_layoutxlm_xfund_zh.yml -o Architecture.Backbone.checkpoints=./re_vi_layoutxlm_xfund_pretrained/best_accuracy Global.save_inference_dir=./inference/re_vi_layoutxlm_infer
```

Use the following command to infer using the VI-LayoutXLM RE model.


```bash
cd ppstructure
python3 kie/predict_kie_token_ser_re.py \
--kie_algorithm=LayoutXLM \
--re_model_dir=../inference/re_vi_layoutxlm_infer \
--ser_model_dir=../inference/ser_vi_layoutxlm_infer \
--use_visual_backbone=False \
--image_dir=./docs/kie/input/zh_val_42.jpg \
--ser_dict_path=../train_data/XFUND/class_list_xfun.txt \
--vis_font_path=../doc/fonts/simfang.ttf \
--ocr_order_method="tb-yx"
```

The RE visualization results are saved in the `./output` folder by default. The results are as follows.


<div align="center">
<img src="../../ppstructure/docs/kie/result_re/zh_val_42_re.jpg" width="800">
</div>


### 4.2 C++ Inference

Not supported
8 changes: 3 additions & 5 deletions ppocr/data/imaug/vqa/__init__.py
@@ -12,11 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from .token import VQATokenPad, VQASerTokenChunk, VQAReTokenChunk, VQAReTokenRelation
from .token import VQATokenPad, VQASerTokenChunk, VQAReTokenChunk, VQAReTokenRelation, TensorizeEntitiesRelations

__all__ = [
'VQATokenPad',
'VQASerTokenChunk',
'VQAReTokenChunk',
'VQAReTokenRelation',
'VQATokenPad', 'VQASerTokenChunk', 'VQAReTokenChunk', 'VQAReTokenRelation',
'TensorizeEntitiesRelations'
]
1 change: 1 addition & 0 deletions ppocr/data/imaug/vqa/token/__init__.py
@@ -15,3 +15,4 @@
from .vqa_token_chunk import VQASerTokenChunk, VQAReTokenChunk
from .vqa_token_pad import VQATokenPad
from .vqa_token_relation import VQAReTokenRelation
from .vqa_re_convert import TensorizeEntitiesRelations
51 changes: 51 additions & 0 deletions ppocr/data/imaug/vqa/token/vqa_re_convert.py
@@ -0,0 +1,51 @@
# copyright (c) 2022 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np


class TensorizeEntitiesRelations(object):
    """Pad the variable-length entity/relation dicts into fixed-shape int64
    arrays so that samples can be batched by the default collator.

    Row 0 of each output array stores the element counts; the following rows
    hold the values, and unused slots are filled with -1.
    """

    def __init__(self, max_seq_len=512, infer_mode=False, **kwargs):
        self.max_seq_len = max_seq_len
        self.infer_mode = infer_mode

    def __call__(self, data):
        entities = data['entities']
        relations = data['relations']

        # entities: [max_seq_len + 1, 3] -> columns are (start, end, label)
        entities_new = np.full(
            shape=[self.max_seq_len + 1, 3], fill_value=-1, dtype='int64')
        entities_new[0, 0] = len(entities['start'])
        entities_new[0, 1] = len(entities['end'])
        entities_new[0, 2] = len(entities['label'])
        entities_new[1:len(entities['start']) + 1, 0] = np.array(entities[
            'start'])
        entities_new[1:len(entities['end']) + 1, 1] = np.array(entities['end'])
        entities_new[1:len(entities['label']) + 1, 2] = np.array(entities[
            'label'])

        # relations: [max_seq_len^2 + 1, 2] -> columns are (head, tail)
        relations_new = np.full(
            shape=[self.max_seq_len * self.max_seq_len + 1, 2],
            fill_value=-1,
            dtype='int64')
        relations_new[0, 0] = len(relations['head'])
        relations_new[0, 1] = len(relations['tail'])
        relations_new[1:len(relations['head']) + 1, 0] = np.array(relations[
            'head'])
        relations_new[1:len(relations['tail']) + 1, 1] = np.array(relations[
            'tail'])

        data['entities'] = entities_new
        data['relations'] = relations_new
        return data
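
A minimal smoke test for the transform above (hypothetical inputs; not part of the commit, and it assumes the class and numpy are importable). Row 0 of each output carries the counts, the following rows carry the padded values, and -1 fills the unused slots:

```python
import numpy as np

transform = TensorizeEntitiesRelations(max_seq_len=512)
data = {
    'entities': {'start': [1, 10], 'end': [5, 14], 'label': [1, 2]},
    'relations': {'head': [0], 'tail': [1]},
}
out = transform(data)

assert out['entities'].shape == (513, 3)             # max_seq_len + 1 rows, (start, end, label) columns
assert out['relations'].shape == (512 * 512 + 1, 2)  # (head, tail) pairs
assert out['entities'][0, 0] == 2                    # entity count stored in row 0
assert (out['entities'][1:3, 0] == np.array([1, 10])).all()  # padded start indices
assert (out['entities'][3:, 0] == -1).all()          # everything else is -1 padding
```
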