-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathlosses.py
30 lines (22 loc) · 1.14 KB
/
losses.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import tensorflow as tf
import keras
class CTCloss(tf.keras.losses.Loss):
    """Keras loss object that computes the CTC batch cost for training.

    The computation is delegated to ``tf.keras.backend.ctc_batch_cost``. Every
    sample in a batch is given the same input length (axis 1 of ``y_pred``)
    and the same label length (axis 1 of ``y_true``), i.e. labels are treated
    as occupying their full padded width.
    """

    def __init__(self, name: str = "CTCloss", reduction=tf.keras.losses.Reduction.SUM) -> None:
        """Create the loss object.

        Args:
            name: Name stored on the loss object.
            reduction: Reduction mode stored on the loss object. NOTE: since
                ``__call__`` is overridden in this class, the stored value is
                not applied to the tensor that ``__call__`` returns.
        """
        # Hand both settings to the base class instead of patching the
        # attributes afterwards, so the base constructor validates and stores
        # them itself.
        super().__init__(reduction=reduction, name=name)
        self.loss_fn = tf.keras.backend.ctc_batch_cost

    def __call__(self, y_true: tf.Tensor, y_pred: tf.Tensor, sample_weight=None) -> tf.Tensor:
        """Compute the CTC loss for one training batch.

        Args:
            y_true: Label tensor; axis 0 is used as the batch size and axis 1
                as the label length of every sample.
            y_pred: Prediction tensor; axis 1 is used as the input (time)
                length of every sample.
            sample_weight: Accepted for compatibility with the Keras loss
                call signature; it is not used by this implementation.

        Returns:
            The tensor returned by ``tf.keras.backend.ctc_batch_cost`` for the
            batch, unreduced (presumably one loss value per sample -- see the
            Keras backend documentation).
        """
        batch_size = tf.cast(tf.shape(y_true)[0], dtype="int64")
        time_steps = tf.cast(tf.shape(y_pred)[1], dtype="int64")
        label_steps = tf.cast(tf.shape(y_true)[1], dtype="int64")

        # ctc_batch_cost expects a length per sample, so broadcast the two
        # scalar lengths into (batch_size, 1) columns.
        ones_column = tf.ones(shape=(batch_size, 1), dtype="int64")
        input_length = time_steps * ones_column
        label_length = label_steps * ones_column

        return self.loss_fn(y_true, y_pred, input_length, label_length)