TensorFlow函数教程:tf.nn.ctc_loss

2019-01-31 13:45 更新

tf.nn.ctc_loss函数

tf.nn.ctc_loss(
    labels,
    inputs,
    sequence_length,
    preprocess_collapse_repeated=False,
    ctc_merge_repeated=True,
    ignore_longer_outputs_than_inputs=False,
    time_major=True
)

定义在:tensorflow/python/ops/ctc_ops.py.

参见指南:神经网络>连接时间分类(CTC)

计算CTC(连接时间分类)loss.

输入要求:

sequence_length(b) <= time for all b

max(labels.indices(labels.indices[:, 1] == b, 2))
  <= sequence_length(b) for all b.

笔记:

此类为您执行softmax操作,因此输入应该是例如LSTM对输出的线性预测.

inputs张量的最内层的维度大小,num_classes,代表num_labels + 1类别,其中num_labels是实际的标签的数量,而最大的值(num_classes - 1)是为空白标签保留的.

例如,对于包含3个标签[a, b, c]的词汇表,num_classes = 4,并且标签索引是{a: 0, b: 1, c: 2, blank: 3}.

关于参数preprocess_collapse_repeatedctc_merge_repeated

如果preprocess_collapse_repeated为True,则在loss计算之前运行预处理步骤,其中传递给loss的重复标签会合并为单个标签.如果训练标签来自,例如强制对齐,并因此具有不必要的重复,则这是有用的.

如果ctc_merge_repeated设置为False,则在CTC计算的深处,重复的非空白标签将不会合并,并被解释为单个标签.这是CTC的简化(非标准)版本.

以下是(大致)预期的第一顺序行为表:

  • preprocess_collapse_repeated=Falsectc_merge_repeated=True

典型的CTC行为:输出实际的重复类,其间有空白,还可以输出中间没有空白的重复类,这需要由解码器折叠.

  • preprocess_collapse_repeated=Truectc_merge_repeated=False

不要得知输出重复的类,因为它们在训练之前在输入标签中折叠.

  • preprocess_collapse_repeated=Falsectc_merge_repeated=False

输出中间有空白的重复类,但通常不需要解码器折叠/合并重复的类.

  • preprocess_collapse_repeated=Truectc_merge_repeated=True

未经测试,很可能不会得知输出重复的类.

ignore_longer_outputs_than_inputs选项允许在处理输出长于输入的序列时指定CTCLoss的行为.如果为true,则CTCLoss将仅为这些项返回零梯度,否则返回InvalidArgument错误,停止训练.

参数:

  • labels:一个int32SparseTensorlabels.indices[i, :] == [b, t]表示labels.values[i]存储(batch b, time t)的id;labels.values[i]必须采用[0, num_labels)中的值.
  • inputs:3-D float Tensor如果time_major == False,这将是一个Tensor,形状:[batch_size, max_time, num_classes]如果time_major == True(默认值),这将是一个Tensor,形状:[max_time, batch_size, num_classes];是logits.
  • sequence_length:1-Dint32向量,大小为[batch_size]序列长度.
  • preprocess_collapse_repeatedBoolean,默认值:False;如果为True,则在CTC计算之前折叠重复的标签.
  • ctc_merge_repeatedBoolean,默认值:True.
  • ignore_longer_outputs_than_inputs:Boolean,默认值:False;如果为True,则输出比输入长的序列将被忽略.
  • time_majorinputs张量的形状格式如果是True,那些Tensors必须具有形状[max_time, batch_size, num_classes]如果为False,则Tensors必须具有形状[batch_size, max_time, num_classes]使用time_major = True(默认)更有效,因为它避免了在ctc_loss计算开始时的转置.但是,大多数TensorFlow数据都是批处理为主的,因此通过此函数还可以接受以批处理为主的形式的输入.

返回:

1-DfloatTensor,大小为[batch]包含负对数概率.

可能引发的异常:

  • TypeError:如果标签不是SparseTensor.
以上内容是否对您有帮助:
在线笔记
App下载
App下载

扫描二维码

下载编程狮App

公众号
微信公众号

编程狮公众号