前言

本文是阅读论文《Sequence Transduction with Recurrent Neural Networks1 的笔记,主要讨论 RNN-T 在语音识别(ASR)中的应用,包含训练和解码两部分。
CTC2 的作者 Alex Graves 在 2012 年提出 RNN-T,论文中列举了 CTC 的两个问题:(1)不能解决输入序列比输出序列短的问题,如语音合成;(2)CTC 假设输出序列之间独立,并未建模输出之间的依赖关系。RNN-T 用来解决这两个问题。我们按照学习算法(训练)和预测算法(解码)的顺序介绍 RNN-T。

一、学习算法

RNN-T 能够将任何长度输入序列转换为有限的离散输出序列。输入序列
x = ( x 1 , x 2 , … , x T ) ∈ X ∗ , x i ∈ X , 1 ≤ i ≤ T \mathbf{x} = (x_1, x_2, \dots, x_T) \in \mathcal{X}^*, x_i \in \mathcal{X}, 1 ≤ i ≤ T x=(x1,x2,,xT)X,xiX,1iT
输出序列
y = ( y 1 , y 2 , … , y U ) ∈ Y ∗ , y j ∈ Y , 1 ≤ j ≤ U \mathbf{y} = (y_1, y_2, \dots, y_U) \in \mathcal{Y}^*, y_j \in \mathcal{Y}, 1 ≤ j ≤ U y=(y1,y2,,yU)Y,yjY,1jU
其中 x t x_t xt y u y_u yu 是实值有限长度向量,比如在 ASR 中, x t x_t xt 是 80 维的 Fbank 特征, y u y_u yu 是对应的独热标签。 T T T 是音频时长, U U U 表示输出的文字长度。
定义拓展的输出空间 Y ‾ = Y ⋃ Ø \mathcal{\overline{Y}}= \mathcal{Y}\bigcup \text{\O} Y=YØ Ø \text{\O} Ø 表示空向量,是一个输出占位符。比如解码后的序列为 ( y 1 , Ø , Ø , y 2 , Ø , y 3 ) ∈ Y ‾ ∗ (y_1, \text{\O},\text{\O},y_2,\text{\O},y_3)\in \mathcal{\overline{Y}^*} (y1,Ø,Ø,y2,Ø,y3)Y,去除空占位符后得到 ( y 1 , y 2 , y 3 ) ∈ Y ∗ (y_1,y_2,y_3)\in \mathcal{Y^*} (y1,y2,y3)Y。我们称 a ∈ Y ‾ ∗ \mathbf{a} \in \mathcal{\overline{Y}}^* aY 为一个对齐,给定 x \mathbf{x} x, RNN-T 表示条件概率分布 P ( a ∈ Y ‾ ∗ ∣ x ) \mathbb{P}(\mathbf{a} \in \mathcal{\overline{Y}^*} | \mathbf{x}) P(aYx)。注意,这个时候的对齐 a \mathbf{a} a 是包含空向量的,真正解码后的文字序列 y ∈ Y ∗ \mathbf{y} \in \mathcal{Y}^* yY 不含空向量,那 P ( y ∈ Y ∗ ∣ x ) \mathbb{P}(\mathbf{y} \in \mathcal{Y}^* | \mathbf{x}) P(yYx) 怎么算呢?和 HMMCTC 一样,将所有可能的对齐都加起来
P ( y ∈ Y ∗ ∣ x ) = ∑ a ∈ B − 1 ( y ) P ( a ∣ x ) (1) \mathbb{P}(\mathbf{y} \in \mathcal{Y}^* | \mathbf{x})= \sum_{\mathbf{a}\in \mathcal{B}^{-1}(\mathbf{y})} \mathbb{P}(\mathbf{a}|\mathbf{x}) \tag{1} P(yYx)=aB1(y)P(ax)(1)
其中 B : Y ‾ ∗ ↦ Y ∗ \mathcal{B}: \mathcal{\overline{Y}}^* \mapsto \mathcal{Y}^* B:YY 是去除空向量的映射。

在这里插入图片描述


图1 RNN-T结构3

如图 1 所示,为了解决 CTC 的问题,RNN-T 在 CTC 的基础上添加了 Prediction RNN 和 Joint network 两个模型。于是 RNN-T 由编码器(Encoder)、解码器(Decoder)、融合器(Joiner)3 部分构成。

1.1 RNN-T模型结构

1.1.1 编码器

编码器在论文中称作转写网络 F \mathcal{F} F(Transcription Network),由双向 RNN 构成,它的作用是将 ( x 1 , … , x T ) (x_1, \dots, x_T) (x1,,xT) 映射到高维表示 f t , 1 ≤ t ≤ T \mathbf{f_t}, 1 ≤ t ≤ T ft,1tT,每个 f t \mathbf{f_t} ft K + 1 K+1 K+1 个值, K + 1 K+1 K+1 是建模单元的个数加空记号。这部分相当于 CTC 的编码器主体,到如今,具体结构已经不重要,Transformer、Zipformer 等结构均可。

1.1.2 解码器

解码器在论文中称作预测网络 G \mathcal{G} G(Prediction Network),由一层 RNN 组成,输入序列 y ^ = ( Ø , y 1 , … , y U ) \mathbf{\hat{y}} = (\text{\O}, y_1, \dots, y_U) y^=(Ø,y1,,yU) g = G ( y ^ ) \mathbf{g} = \mathcal{G}(\mathbf{\hat{y}}) g=G(y^) g \mathbf{g} g 是输出序列。 y ^ \mathbf{\hat{y}} y^ 的每个元素也都是向量,假设有 K K K 个标签,比如中文拼音建模, K K K 约是 200。 y i y_i yi 代表一个标签处是 1,其余为 0 的长度为 K K K 的向量, Ø \text{\O} Ø 代表全是 0 的长度为 K K K 的向量,故 y ^ \mathbf{\hat{y}} y^ 其实是一个形状为 K × ( U + 1 ) K × (U+1) K×(U+1) 的张量。解码器的输出则是形状为 ( K + 1 ) × ( U + 1 ) (K+1) × (U+1) (K+1)×(U+1) 的张量,因为每一次的输出节点包含空记号。

1.1.3 融合器

转写网络只处理音频信号,预测网络只处理文字信息,所以我们类比传统 ASR,有时候也将其称为声学模型(Acoustic model, AM)和语言模型(Linguistic model,LM)。传统 ASR 是利用加权有限状态转换器(Weighted Finite-State Transducer,WFST)来进行解码,将两者的信息融合起来。那 RNN-Transducer 怎么融合呢?
在这里插入图片描述


图2 传统 ASR 系统框架 4

由于转写向量 f t , 1 ≤ t ≤ T \mathbf{f_t}, 1≤t≤T ft,1tT 和预测向量 g u , 0 ≤ u ≤ U \mathbf{g_u}, 0≤u≤U gu,0uU 的长度都是 K + 1 K+1 K+1,确定了 t t t u u u,可以将对应向量加起来,于是有
h ( k , t , u ) = e x p ( f t k + g u k ) (2) h(k, t, u) = exp(f_t^k + g_u^k) \tag{2} h(k,t,u)=exp(ftk+guk)(2)
其中 k k k 代表向量的第 k k k 个元素。于是输出分布为
P ( k ∈ Y ˉ ∣ t , u ) = h ( k , t , u ) ∑ k ′ ∈ Y ˉ h ( k ′ , t , u ) (3) \mathbb{P}(k \in \bar{\mathcal{Y}} | t, u)= \frac{h(k, t, u)}{ \sum_{k'\in \bar{\mathcal{Y}}} h(k', t, u)} \tag{3} P(kYˉt,u)=kYˉh(k,t,u)h(k,t,u)(3)
原论文只是简单相加进行融合,在2013年的文章5 中,作者将其替换成了神经网络,我们称其为融合器(Joiner,Joint network),这样编码器和解码器的输出维度也不限于 K + 1 K + 1 K+1,利用 softmax 函数来生成(3)中的概率分布。
我们记
y ( t , u ) : = P ( y u + 1 ∣ t , u ) Ø ( t , u ) : = P ( Ø ∣ t , u ) \begin{align} y(t, u) &:= \mathbb{P}(y_{u+1}|t,u) \tag{4} \\ \text{\O}(t,u) &:= \mathbb{P}(\text{\O}|t,u) \tag{5} \end{align} y(t,u)Ø(t,u):=P(yu+1t,u):=P(Øt,u)(4)(5)
在这里插入图片描述


图3 Transducer概率网格图

我们结合图 3 来说明 y ( t , u ) y(t, u) y(t,u) Ø ( t , u ) \text{\O}(t,u) Ø(t,u) 的含义。图 3 中节点 ( t , u ) (t, u) (t,u) 表示 t t t 时刻输出了 u u u 个非空标签(token)的概率。从节点 ( t , u ) (t, u) (t,u) 出发的横向箭头 Ø ( t , u ) \text{\O}(t,u) Ø(t,u) 表示 ( t , u ) (t, u) (t,u) 后输出空标记,在 Y \mathcal{Y} Y 空间来讲就是什么都没输出。纵向箭头表示 ( t , u ) (t, u) (t,u) 后输出 y u + 1 y_{u+1} yu+1 的概率。底部的黑色实心节点表示空标记。按照解码顺序,网格永远是从左下角出发,最终到达右上角。图 3 中的红色代表一种可能的解码路径。
相比 CTC 的解码图,RNN-T 没有斜向箭头,摆脱了时间 t t t 的限制,就有可能输出任意多的 token,这样就解决了 CTC 的第一个问题。而 RNN-T 的创新是增加了预测网络和融合网络,建立了输出之间的依赖关系,相比 CTC,输出概率分布从 P ( y u + 1 ∣ x 1 , … , x n ) \mathbb{P}(y_{u+1}|x_1,\dots,x_n) P(yu+1x1,,xn) 变为 P ( y u + 1 ∣ x 1 , … , x n , y 1 , … , y u ) \mathbb{P}(y_{u+1}|x_1,\dots,x_n,y_1,\dots,y_u) P(yu+1x1,,xn,y1,,yu),解决了 CTC 的第二个问题。

1.2 前向后向算法

回顾式(1),我们无法直接写出 P ( y ∈ Y ∗ ∣ x ) \mathbb{P}(\mathbf{y} \in \mathcal{Y}^* | \mathbf{x}) P(yYx) 的计算公式,还是通过动态规划来算。

1.2.1 前向算法

(1)前向变量
α ( t , u ) \alpha(t, u) α(t,u) 表示 [ 1 : t ] [1:t] [1:t] 时刻输出 token y [ 1 , u ] \mathbf{y}_{[1, u]} y[1,u],就是图 3 中的节点 ( t , u ) (t, u) (t,u)
(2)初始化
α ( 1 , 0 ) = 1 (6) \alpha(1, 0) = 1 \tag{6} α(1,0)=1(6)
(3)递推公式
每个节点只接收来自左边和下边的箭头。对 ∀ 1 ≤ t ≤ T , 0 ≤ u ≤ U \forall 1≤t≤T, 0≤u≤U ∀1tT,0uU,有
α ( t , u ) = α ( t − 1 , u ) Ø ( t − 1 , u ) + α ( t , u − 1 ) y ( t , u − 1 ) (7) \alpha(t, u) = \alpha(t-1, u)\text{\O}(t-1,u) + \alpha(t, u-1)y(t,u-1) \tag{7} α(t,u)=α(t1,u)Ø(t1,u)+α(t,u1)y(t,u1)(7)
(4)终止
P ( y ∣ x ) = α ( T , U ) Ø ( T , U ) (8) \mathbb{P}(\mathbf{y} | \mathbf{x}) = \alpha(T, U)\text{\O}(T, U) \tag{8} P(yx)=α(T,U)Ø(T,U)(8)
如图 3 中节点 ( 4 , 3 ) (4,3) (4,3) 经过 Ø ( 4 , 3 ) \text{\O}(4, 3) Ø(4,3) 到达终止节点。

1.2.2 后向算法

(1)后向变量
β ( t , u ) \beta(t, u) β(t,u) 表示 [ t : T ] [t:T] [t:T] 时刻输出 token y [ u + 1 , U ] \mathbf{y}_{[u+1, U]} y[u+1,U]
(2)初始化
β ( T , U ) = Ø ( T , U ) (9) \beta(T, U) = \text{\O}(T, U) \tag{9} β(T,U)=Ø(T,U)(9)
(3)递推公式
每个节点只接收来自右边和上边的箭头。对 ∀ 1 ≤ t ≤ T , 0 ≤ u ≤ U \forall 1≤t≤T, 0≤u≤U ∀1tT,0uU,有
β ( t , u ) = β ( t + 1 , u ) Ø ( t , u ) + β ( t , u + 1 ) y ( t , u ) (10) \beta(t, u) = \beta(t+1, u)\text{\O}(t,u) + \beta(t, u+1)y(t,u) \tag{10} β(t,u)=β(t+1,u)Ø(t,u)+β(t,u+1)y(t,u)(10)
(4)终止
β ( 1 , 0 ) = β ( 2 , 0 ) Ø ( 1 , 0 ) + β ( 1 , 1 ) y ( 1 , 0 ) (11) \beta(1, 0) = \beta(2, 0)\text{\O}(1,0) + \beta(1, 1)y(1,0) \tag{11} β(1,0)=β(2,0)Ø(1,0)+β(1,1)y(1,0)(11)

1.3 反向传播

已知输入序列 x \mathbf{x} x 和目标输出序列 y ∗ \mathbf{y^*} y,训练模型就是减小负对数似然 L = − l n P ( y ∗ ∣ x ) \mathcal{L} = -ln\mathbb{P}(\mathbf{y^*}|\mathbf{x}) L=lnP(yx),需要算偏导数 ∂ L ∂ f t k \frac{\partial\mathcal{L}}{\partial{f_t^k}} ftkL ∂ L ∂ g u k \frac{\partial\mathcal{L}}{\partial{g_u^k}} gukL
根据前向变量和反向变量的定义,不难得出
P ( y ∗ ∣ x ) = ∑ ( t , u ) : t + u = n α ( t , u ) β ( t , u )       ∀ n : 1 ≤ n ≤ U + T (12) \mathbb{P}(\mathbf{y^*}|\mathbf{x}) = \sum_{(t,u):t+u=n}\alpha(t,u)\beta(t,u) \ \ \ \ \ \forall n:1≤n≤U+T \tag{12} P(yx)=(t,u):t+u=nα(t,u)β(t,u)     n:1nU+T(12)
结合等式(7)(10)(12)和 L \mathcal{L} L 的定义,可以得到
∂ L ∂ P ( k ∣ t , u ) = { − α ( t , u ) β ( t , u + 1 ) P ( y ∗ ∣ x ) k = y u + 1 − α ( t , u ) β ( t + 1 , u ) P ( y ∗ ∣ x ) k = Ø 0 o t h e r w i s e (13) \frac{\partial\mathcal{L}}{\partial \mathbb{P}(k|t,u)} = \left\{ \begin{array}{} -\frac{\alpha(t,u)\beta(t,u+1)}{\mathbb{P}(\mathbf{y^*}|\mathbf{x})} & k = y_{u+1} \\ -\frac{\alpha(t,u)\beta(t+1,u)}{\mathbb{P}(\mathbf{y^*}|\mathbf{x})} & k = \text{\O} \\ 0 & otherwise \end{array}\right. \tag{13} P(kt,u)L= P(yx)α(t,u)β(t,u+1)P(yx)α(t,u)β(t+1,u)0k=yu+1k=Øotherwise(13)
由(13)又可以推出
∑ u = 0 U ∑ k ′ ∈ Y ˉ ∂ L ∂ P ( k ′ ∣ t , u ) ∂ P ( k ′ ∣ t , u ) ∂ f t k = − 1 P ( y ∗ ∣ x ) ∑ u = 0 U ( ∂ P ( y u + 1 ∣ t , u ) ∂ f t k ( α ( t , u ) β ( t , u + 1 ) ) + ∂ P ( Ø ∣ t , u ) ∂ f t k ( α ( t , u ) β ( t + 1 , u ) ) ) = − 1 P ( y ∗ ∣ x ) ∑ u = 0 U ( ∂ α ( t , u ) β ( t , u ) ∂ f t k ) = − 1 P ( y ∗ ∣ x ) ∂ P ( y ∗ ∣ x ) ∂ f t k = ∂ L ∂ f t k \begin{align*} &\sum_{u=0}^{U}\sum_{k'\in \bar{\mathcal{Y}}}\frac{\partial\mathcal{L}}{\partial \mathbb{P}(k'|t,u)}\frac{\partial \mathbb{P}(k'|t,u)}{\partial{f_t^k}} \\ =& -\frac{1}{\mathbb{P}(\mathbf{y^*}|\mathbf{x})}\sum_{u=0}^{U}\left(\frac{\partial\mathbb{P}(y_{u+1}|t,u)}{\partial{f_t^k}}(\alpha(t,u)\beta(t,u+1)) + \frac{\partial \mathbb{P}(\text{\O}|t,u)}{\partial{f_t^k}}(\alpha(t,u)\beta(t+1,u)) \right)\\ =& -\frac{1}{\mathbb{P}(\mathbf{y^*}|\mathbf{x})}\sum_{u=0}^{U}\left(\frac{\partial\alpha(t,u)\beta(t,u)}{\partial f_t^k} \right)\\ =& -\frac{1}{\mathbb{P}(\mathbf{y^*}|\mathbf{x})}\frac{\partial \mathbb{P}(\mathbf{y^*}|\mathbf{x})}{\partial f_t^k} \\ =& \frac{\partial\mathcal{L}}{\partial{f_t^k}} \end{align*} ====u=0UkYˉP(kt,u)LftkP(kt,u)P(yx)1u=0U(ftkP(yu+1t,u)(α(t,u)β(t,u+1))+ftkP(Øt,u)(α(t,u)β(t+1,u)))P(yx)1u=0U(ftkα(t,u)β(t,u))P(yx)1ftkP(yx)ftkL
第 2 个等号是因为 ∂ L ∂ P ( k ′ ∣ t , u ) \frac{\partial\mathcal{L}}{\partial \mathbb{P}(k'|t,u)} P(kt,u)L 沿 k ′ k' k 求和,只有 2 项非零。第 3 个等号利用了等式(10)。于是有
∂ L ∂ f t k = ∑ u = 0 U ∑ k ′ ∈ Y ˉ ∂ L ∂ P ( k ′ ∣ t , u ) ∂ P ( k ′ ∣ t , u ) ∂ f t k (14) \frac{\partial\mathcal{L}}{\partial{f_t^k}} = \sum_{u=0}^{U}\sum_{k'\in \bar{\mathcal{Y}}}\frac{\partial\mathcal{L}}{\partial \mathbb{P}(k'|t,u)}\frac{\partial \mathbb{P}(k'|t,u)}{\partial{f_t^k}} \tag{14} ftkL=u=0UkYˉP(kt,u)LftkP(kt,u)(14)
同理可得
∂ L ∂ g u k = ∑ t = 1 T ∑ k ′ ∈ Y ˉ ∂ L ∂ P ( k ′ ∣ t , u ) ∂ P ( k ′ ∣ t , u ) ∂ g u k (15) \frac{\partial\mathcal{L}}{\partial{g_u^k}} = \sum_{t=1}^{T}\sum_{k'\in \bar{\mathcal{Y}}}\frac{\partial\mathcal{L}}{\partial \mathbb{P}(k'|t,u)}\frac{\partial \mathbb{P}(k'|t,u)}{\partial{g_u^k}} \tag{15} gukL=t=1TkYˉP(kt,u)LgukP(kt,u)(15)
由等式(3)可得
∂ P ( k ′ ∣ t , u ) ∂ f t k = ∂ P ( k ′ ∣ t , u ) ∂ g u k = P ( k ′ ∣ t , u ) [ δ k k ′ − P ( k ∣ t , u ) ] (16) \frac{\partial \mathbb{P}(k'|t,u)}{\partial{f_t^k}}=\frac{\partial \mathbb{P}(k'|t,u)}{\partial{g_u^k}} = \mathbb{P}(k'|t,u)[\delta_{kk'}-\mathbb{P}(k|t,u)] \tag{16} ftkP(kt,u)=gukP(kt,u)=P(kt,u)[δkkP(kt,u)](16)
反向传播对偏导数 ∂ L ∂ f t k \frac{\partial\mathcal{L}}{\partial{f_t^k}} ftkL ∂ L ∂ g u k \frac{\partial\mathcal{L}}{\partial{g_u^k}} gukL 进行回传。

二、预测算法

模型训好后,通过预测算法(解码)转写出对应的文字。论文中介绍了波束搜索(Beam Search)算法,可以扩展到任意长的序列,并通过控制波束搜索宽度,对计算成本和搜索精度进行平衡。
请添加图片描述


图4 宽度为3的波束搜索 6

如图 4 所示,每一层代表时刻 t t t 的解码选择,每一层的数字是标签的序号。比如共有 K K K 个标签,当 t = 1 t=1 t=1 时,可能的路径有 K K K 条;当 t = 2 t=2 t=2 时,有 K 2 K^2 K2 条路径;当 t = n t=n t=n 时,有 K n K^n Kn 条路径,显然这会导致内存爆炸。波束搜索是确定搜索宽度 W , 1 ≤ W ≤ K W, 1≤W≤K W,1WK,每一时刻 t t t,选择路径概率最大的 W W W 个节点保存,在 t + 1 t+1 t+1 时,在这 W W W 个节点的基础上拓展,再选出 t + 1 t+1 t+1 时刻路径概率最大的 W W W 个节点,以此类推,将节点拓展下去,最终概率最大的路径就是转写结果。CTC 中的最佳路径解码就是 W = 1 W=1 W=1 的情形,由此推知,波束搜索也是近似算法。
下面是语音工具包 k2 对论文中波束搜索算法的实现 7

# k2 中的 beam search 实现
def beam_search(
    model: Transducer,
    encoder_out: torch.Tensor,
    beam: int = 4,
) -> List[int]:
    """
    It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf
    
    Args:
      model:
        An instance of `Transducer`.
      encoder_out:
        A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
      beam:
        Beam size.
    Returns:
      Return the decoded result.
    """
    assert encoder_out.ndim == 3

    # support only batch_size == 1 for now
    assert encoder_out.size(0) == 1, encoder_out.size(0)
    blank_id = model.decoder.blank_id		# 空记号,编号 0
    context_size = model.decoder.context_size  # 2,bigram

    device = model.device
		
    decoder_input = torch.tensor([blank_id] * context_size, device=device).reshape(
        1, context_size
    )

    decoder_out = model.decoder(decoder_input, need_pad=False)

    T = encoder_out.size(1)  # 解码音频的长度
    t = 0

    B = HypothesisList()
    # HypothesisList 对象 B 存储当前最优的假设,形如 '0_0_48_366_65_66_6': 
    # Hypothesis(ys=[0, 0, 48, 366, 65, 66, 6], log_prob=tensor([-1.0889])
    # 保存路径和对应的概率
    B.add(
        Hypothesis(
            ys=[blank_id] * context_size,
            log_prob=torch.zeros(1, dtype=torch.float32, device=device),
        )
    )

    max_sym_per_utt = 20000

    sym_per_utt = 0

    encoder_out_len = torch.tensor([1])
    decoder_out_len = torch.tensor([1])

    decoder_cache: Dict[str, torch.Tensor] = {}
    
    # 遍历每个时间步 t, 直至音频结束或符号数量上限
    while t < T and sym_per_utt < max_sym_per_utt:
        # fmt: off
        current_encoder_out = encoder_out[:, t:t+1, :]
        # fmt: on
        A = B		# A 承接了 t - 1 时刻的最优假设,将在此假设基础上继续延拓
        B = HypothesisList()		# B 中存储 t 时刻最优 beam 个假设

        joint_cache: Dict[str, torch.Tensor] = {}

        while True:
            y_star = A.get_most_probable() # 从 A 中选取最优假设
            A.remove(y_star)

            decoder_out = run_decoder(
                ys=y_star.ys, model=model, decoder_cache=decoder_cache
            )

            key = "_".join(map(str, y_star.ys[-context_size:]))
            key += f"-t-{t}"
            
            # 利用转写向量、预测向量和 joiner 计算联合概率
            log_prob = run_joiner(
                key=key,
                model=model,
                encoder_out=current_encoder_out,
                decoder_out=decoder_out,
                encoder_out_len=encoder_out_len,
                decoder_out_len=decoder_out_len,
                joint_cache=joint_cache,
            )
					
			# 先处理空记号,虽然不输出有效 token,但需要累积概率,此路径仍是潜在最优解 
            # First, process the blank symbol
            skip_log_prob = log_prob[blank_id]
            new_y_star_log_prob = y_star.log_prob + skip_log_prob
					
			# 更新 B
            # ys[:] returns a copy of ys
            B.add(Hypothesis(ys=y_star.ys[:], log_prob=new_y_star_log_prob))
		
			# 处理非空标签,对 A 进行延拓
            # Second, process other non-blank labels
            values, indices = log_prob.topk(beam + 1)
            # 相当于图 4 中第 2 层节点 5 往节点 8 和节点 9 延拓,每次都拓展 beam + 1 个
            # 这 beam + 1 个节点可能包括空记号,之所以是 beam + 1, 也是为了保证 B 最终能筛出最大的 beam 个假设
            for idx in range(values.size(0)):
                i = indices[idx].item()
                if i == blank_id:
                    continue

                new_ys = y_star.ys + [i]

                new_log_prob = y_star.log_prob + values[idx]
                A.add(Hypothesis(ys=new_ys, log_prob=new_log_prob))

            # Check whether B contains more than "beam" elements more probable
            # than the most probable in A
            A_most_probable = A.get_most_probable()

            kept_B = B.filter(A_most_probable.log_prob)
					
			# 根据设定的束宽,筛选并保留 B 中最优的假设
            if len(kept_B) >= beam:
                B = kept_B.topk(beam)
                break
		# 进入 t + 1 层
        t += 1
	
	# 从最终保留的假设中,选取经过长度归一化处理后的最优假设
    best_hyp = B.get_most_probable(length_norm=True)
    # 移除上下文中的空白符号,返回解码结果
    ys = best_hyp.ys[context_size:]  # [context_size:] to remove blanks
    return ys

总结

论文引入了预测网络和联合网络,解决了 CTC 不合理的独立假设问题。解码时候,可以对每帧输入进行预测输出,使得 RNN-T 天然具有自回归性,可以用于流式 ASR。且 AM + LM 的组合,让建模更加完备合理。
RNN-T 也有缺点,相比 CTC 的输出维度 (N, T, C), RNN-T 则是 (N, T, U, C),U 是解码器的输出长度,可达上百量级。故 RNN-T 内存占用大 ,训练复杂度高。
随着模型结构的进步,编码器中的 RNN 逐渐被各种 Former 类模型(Conformer,Zipformer等)替代。谷歌在深入研究解码器后也发现 3,RNN 结构并不是必须的,直接将 y u − 1 y_{u-1} yu1 的嵌入(Embedding)送入联合网络,也能达到相近的效果。就像信乐团没有信,飞儿乐队没有飞一样,RNN-T 在如今的很多论文中,都只被称作 Transducer。

参考文献

  1. Alex Graves. Sequence Transduction with Recurrent Neural Networks. 2012. ↩︎

  2. Alex Graves, et al. Connectionist Temporal Classification: Labelling Unsegmented Sequence Data with Recurrent Neural Networks. 2006. ↩︎

  3. Ghodsi, Mohammadreza , et al. “Rnn-Transducer with Stateless Prediction Network.” ICASSP 2020 - 2020 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP) IEEE, 2020. ↩︎ ↩︎

  4. 洪青阳,李琳. 《语音识别原理与应用》, 2020. ↩︎

  5. Alex Graves, et al. Speech recognition with deep recurrent neural networks. 2013. ↩︎

  6. https://en.wikipedia.org/wiki/Beam_search. ↩︎

  7. https://github.com/csukuangfj/icefall/blob/master/egs/librispeech/ASR/transducer_stateless/beam_search.py. ↩︎

Logo

火山引擎开发者社区是火山引擎打造的AI技术生态平台,聚焦Agent与大模型开发,提供豆包系列模型(图像/视频/视觉)、智能分析与会话工具,并配套评测集、动手实验室及行业案例库。社区通过技术沙龙、挑战赛等活动促进开发者成长,新用户可领50万Tokens权益,助力构建智能应用。

更多推荐