From 5cab4c0c33c3b8a9102c5f29be120872e809e203 Mon Sep 17 00:00:00 2001 From: lgcy <865562216@qq.com> Date: Tue, 17 May 2022 20:47:45 +0800 Subject: [PATCH 1/2] update lmdb_dateset for ppocrv3 rec MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit 对lmdb_dataset适配ppocrv3的RecConAug数据增强 --- ppocr/data/lmdb_dataset.py | 24 ++++++++++++++++++++++++ 1 file changed, 24 insertions(+) diff --git a/ppocr/data/lmdb_dataset.py b/ppocr/data/lmdb_dataset.py index e1b49809d1..0c1ea9465c 100644 --- a/ppocr/data/lmdb_dataset.py +++ b/ppocr/data/lmdb_dataset.py @@ -88,6 +88,29 @@ def get_img_data(self, value): if imgori is None: return None return imgori + + def get_ext_data(self): + ext_data_num = 0 + for op in self.ops: + if hasattr(op, 'ext_data_num'): + ext_data_num = getattr(op, 'ext_data_num') + break + load_data_ops = self.ops[:self.ext_op_transform_idx] + ext_data = [] + + while len(ext_data) < ext_data_num: + lmdb_idx, file_idx = self.data_idx_order_list[np.random.randint(self.__len__())] + lmdb_idx = int(lmdb_idx) + file_idx = int(file_idx) + sample_info = self.get_lmdb_sample_info(self.lmdb_sets[lmdb_idx]['txn'], + file_idx) + if sample_info is None: + continue + img, label = sample_info + data = {'image': img, 'label': label} + outs = transform(data, load_data_ops) + ext_data.append(data) + return ext_data def get_lmdb_sample_info(self, txn, index): label_key = 'label-%09d'.encode() % index @@ -109,6 +132,7 @@ def __getitem__(self, idx): return self.__getitem__(np.random.randint(self.__len__())) img, label = sample_info data = {'image': img, 'label': label} + data['ext_data'] = self.get_ext_data() outs = transform(data, self.ops) if outs is None: return self.__getitem__(np.random.randint(self.__len__())) From 0dfb536fa6eec201682d01c6e6dc256e33029db3 Mon Sep 17 00:00:00 2001 From: lgcy <865562216@qq.com> Date: Tue, 17 May 2022 21:00:39 +0800 Subject: [PATCH 2/2] Update lmdb_dataset.py --- ppocr/data/lmdb_dataset.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/ppocr/data/lmdb_dataset.py b/ppocr/data/lmdb_dataset.py index 0c1ea9465c..2b1ccaddce 100644 --- a/ppocr/data/lmdb_dataset.py +++ b/ppocr/data/lmdb_dataset.py @@ -37,6 +37,8 @@ def __init__(self, config, mode, logger, seed=None): if self.do_shuffle: np.random.shuffle(self.data_idx_order_list) self.ops = create_operators(dataset_config['transforms'], global_config) + self.ext_op_transform_idx = dataset_config.get("ext_op_transform_idx", + 2) ratio_list = dataset_config.get("ratio_list", [1.0]) self.need_reset = True in [x < 1 for x in ratio_list]