Skip to content

Commit

Permalink
Code for CRNN
Browse files Browse the repository at this point in the history
training code for chinese text recognition
  • Loading branch information
diaomin committed Apr 6, 2018
1 parent f7d4636 commit f2006ba
Show file tree
Hide file tree
Showing 16 changed files with 923 additions and 0 deletions.
Empty file added data_utils/__init__.py
Empty file.
266 changes: 266 additions & 0 deletions data_utils/data_iter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,266 @@
from __future__ import print_function

import os
from PIL import Image
import numpy as np
import mxnet as mx
import random

def write_txt_file():
root_path = "D:/Data/VOCtrainval_11-May-2012/test/"

dirs = os.listdir(os.path.join(root_path,"images"))
content = []
for d in dirs:
files = os.listdir(os.path.join(root_path,"images", d))
for f in files:
content.append(d+"/"+f+" "+d+"\n")

random.shuffle(content)

train_f = open(os.path.join(root_path,"train.txt"),"w")
test_f = open(os.path.join(root_path, "test.txt"), "w")

for i,c in enumerate(content):
if i < 0.8*len(content):
train_f.write(c)
else:
test_f.write(c)
train_f.close()
test_f.close()

def write_mx_lst(data_type="train"):
txt_file = "D:/BaiduNetdiskDownload/Synthetic_Chinese_String_Dataset/"
in_f = open(os.path.join(txt_file, data_type+".txt"), "r")
out_f = open(os.path.join(txt_file, data_type+".lst"), "w")
lines = in_f.readlines()
random.shuffle(lines)
for idx, line in enumerate(lines):
new_line = str(idx)+"\t"
lst = line.strip().split(" ")
for i in range(len(lst)-1):
new_line = new_line+lst[i+1]+"\t"
new_line = new_line+"images/"+lst[0]+"\n"
out_f.write(new_line)
in_f.close()
out_f.close()



class SimpleBatch(object):
def __init__(self, data_names, data, label_names=list(), label=list()):
self._data = data
self._label = label
self._data_names = data_names
self._label_names = label_names

self.pad = 0
self.index = None # TODO: what is index?

@property
def data(self):
return self._data

@property
def label(self):
return self._label

@property
def data_names(self):
return self._data_names

@property
def label_names(self):
return self._label_names

@property
def provide_data(self):
return [(n, x.shape) for n, x in zip(self._data_names, self._data)]

@property
def provide_label(self):
return [(n, x.shape) for n, x in zip(self._label_names, self._label)]


class ImageIter(mx.io.DataIter):

"""
Iterator class for generating captcha image data
"""
def __init__(self, data_root, data_list, batch_size, data_shape, num_label, name=None):
"""
Parameters
----------
data_root: str
root directory of images
data_list: str
a .txt file stores the image name and corresponding labels for each line
batch_size: int
name: str
"""
super(ImageIter, self).__init__()
self.batch_size = batch_size
self.data_shape = data_shape
self.num_label = num_label

self.data_root = data_root
self.dataset_lst_file = open(data_list)

self.provide_data = [('data', (batch_size, 1, data_shape[1], data_shape[0]))]
self.provide_label = [('label', (self.batch_size, self.num_label))]
self.name = name

def __iter__(self):
data = []
label = []
cnt = 0
for m_line in self.dataset_lst_file:
img_lst = m_line.strip().split(' ')
img_path = os.path.join(self.data_root, img_lst[0])

cnt += 1
img = Image.open(img_path).resize(self.data_shape, Image.BILINEAR).convert('L')
img = np.array(img).reshape((1, self.data_shape[1], self.data_shape[0]))
data.append(img)

ret = np.zeros(self.num_label, int)
for idx in range(1, len(img_lst)):
ret[idx-1] = int(img_lst[idx])

label.append(ret)
if cnt % self.batch_size == 0:
data_all = [mx.nd.array(data)]
label_all = [mx.nd.array(label)]
data_names = ['data']
label_names = ['label']
data.clear()
label.clear()
yield SimpleBatch(data_names, data_all, label_names, label_all)
continue


def reset(self):
if self.dataset_lst_file.seekable():
self.dataset_lst_file.seek(0)

class ImageIterLstm(mx.io.DataIter):

"""
Iterator class for generating captcha image data
"""

def __init__(self, data_root, data_list, batch_size, data_shape, num_label, lstm_init_states, name=None):
"""
Parameters
----------
data_root: str
root directory of images
data_list: str
a .txt file stores the image name and corresponding labels for each line
batch_size: int
name: str
"""
super(ImageIterLstm, self).__init__()
self.batch_size = batch_size
self.data_shape = data_shape
self.num_label = num_label

self.init_states = lstm_init_states
self.init_state_arrays = [mx.nd.zeros(x[1]) for x in lstm_init_states]

self.data_root = data_root
self.dataset_lines = open(data_list).readlines()

self.provide_data = [('data', (batch_size, 1, data_shape[1], data_shape[0]))] + lstm_init_states
self.provide_label = [('label', (self.batch_size, self.num_label))]
self.name = name

def __iter__(self):
init_state_names = [x[0] for x in self.init_states]
data = []
label = []
cnt = 0
for m_line in self.dataset_lines:
img_lst = m_line.strip().split(' ')
img_path = os.path.join(self.data_root, img_lst[0])

cnt += 1
img = Image.open(img_path).resize(self.data_shape, Image.BILINEAR).convert('L')
img = np.array(img).reshape((1, self.data_shape[1], self.data_shape[0]))
data.append(img)

ret = np.zeros(self.num_label, int)
for idx in range(1, len(img_lst)):
ret[idx - 1] = int(img_lst[idx])

label.append(ret)
if cnt % self.batch_size == 0:
data_all = [mx.nd.array(data)] + self.init_state_arrays
label_all = [mx.nd.array(label)]
data_names = ['data'] + init_state_names
label_names = ['label']
data = []
label = []
yield SimpleBatch(data_names, data_all, label_names, label_all)
continue

def reset(self):
# if self.dataset_lst_file.seekable():
# self.dataset_lst_file.seek(0)
random.shuffle(self.dataset_lines)

# def get_label(buf):
# ret = np.zeros(10)
# for i in range(len(buf)):
# ret[i] = 1 + int(buf[i])
# if len(buf) == 9:
# ret[3] = 0
# return ret

# class OCRIter(mx.io.DataIter):
# """
# Iterator class for generating captcha image data
# """
#
# def __init__(self, count, batch_size, captcha, name):
# """
# Parameters
# ----------
# count: int
# Number of batches to produce for one epoch
# batch_size: int
#
# captcha MPCaptcha
# Captcha image generator. Can be MPCaptcha or any other class providing .shape and .get() interface
# name: str
# """
# super(OCRIter, self).__init__()
# self.batch_size = batch_size
# self.count = count
#
# self.data_shape = captcha.shape
# print(self.data_shape)
# self.provide_data = [('data', (batch_size, 1, self.data_shape[0], self.data_shape[1]))]
# self.provide_label = [('label', (self.batch_size, 10))]
# self.mp_captcha = captcha
# self.name = name
#
# def __iter__(self):
# for k in range(self.count):
# data = []
# label = []
# for i in range(self.batch_size):
# img, num = self.mp_captcha.get()
# img = np.array(img).reshape((1, self.data_shape[0], self.data_shape[1]))
# data.append(img)
# label.append(get_label(num))
# data_all = [mx.nd.array(data)]
# label_all = [mx.nd.array(label)]
# data_names = ['data']
# label_names = ['label']
#
# data_batch = SimpleBatch(data_names, data_all, label_names, label_all)
# yield data_batch

if __name__=="__main__":
write_mx_lst("test")
Empty file added fit/__init__.py
Empty file.
33 changes: 33 additions & 0 deletions fit/ctc_loss.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import mxnet as mx

def _add_warp_ctc_loss(pred, seq_len, num_label, label):
""" Adds Symbol.contrib.ctc_loss on top of pred symbol and returns the resulting symbol """
label = mx.sym.Reshape(data=label, shape=(-1,))
label = mx.sym.Cast(data=label, dtype='int32')
return mx.sym.WarpCTC(data=pred, label=label, label_length=num_label, input_length=seq_len)


def _add_mxnet_ctc_loss(pred, seq_len, label):
""" Adds Symbol.WapCTC on top of pred symbol and returns the resulting symbol """
pred_ctc = mx.sym.Reshape(data=pred, shape=(-4, seq_len, -1, 0))

loss = mx.sym.contrib.ctc_loss(data=pred_ctc, label=label)
ctc_loss = mx.sym.MakeLoss(loss)

softmax_class = mx.symbol.SoftmaxActivation(data=pred)
softmax_loss = mx.sym.MakeLoss(softmax_class)
softmax_loss = mx.sym.BlockGrad(softmax_loss)
return mx.sym.Group([softmax_loss, ctc_loss])


def add_ctc_loss(pred, seq_len, num_label, loss_type):
""" Adds CTC loss on top of pred symbol and returns the resulting symbol """
label = mx.sym.Variable('label')
if loss_type == 'warpctc':
print("Using WarpCTC Loss")
sm = _add_warp_ctc_loss(pred, seq_len, num_label, label)
else:
print("Using MXNet CTC Loss")
assert loss_type == 'ctc'
sm = _add_mxnet_ctc_loss(pred, seq_len, label)
return sm
Loading

0 comments on commit f2006ba

Please sign in to comment.