【语音识别】Sequence Transduction with Recurrent Neural Networks(RNN-T)
Sequence Transduction with Recurrent Neural Networks(RNN-T) 论文笔记
前言
本文是阅读论文《Sequence Transduction with Recurrent Neural Networks》1 的笔记,主要讨论 RNN-T 在语音识别(ASR)中的应用,包含训练和解码两部分。
CTC2 的作者 Alex Graves 在 2012 年提出 RNN-T,论文中列举了 CTC 的两个问题:(1)不能解决输入序列比输出序列短的问题,如语音合成;(2)CTC 假设输出序列之间独立,并未建模输出之间的依赖关系。RNN-T 用来解决这两个问题。我们按照学习算法(训练)和预测算法(解码)的顺序介绍 RNN-T。
一、学习算法
RNN-T 能够将任何长度输入序列转换为有限的离散输出序列。输入序列
x = ( x 1 , x 2 , … , x T ) ∈ X ∗ , x i ∈ X , 1 ≤ i ≤ T \mathbf{x} = (x_1, x_2, \dots, x_T) \in \mathcal{X}^*, x_i \in \mathcal{X}, 1 ≤ i ≤ T x=(x1,x2,…,xT)∈X∗,xi∈X,1≤i≤T
输出序列
y = ( y 1 , y 2 , … , y U ) ∈ Y ∗ , y j ∈ Y , 1 ≤ j ≤ U \mathbf{y} = (y_1, y_2, \dots, y_U) \in \mathcal{Y}^*, y_j \in \mathcal{Y}, 1 ≤ j ≤ U y=(y1,y2,…,yU)∈Y∗,yj∈Y,1≤j≤U
其中 x t x_t xt 和 y u y_u yu 是实值有限长度向量,比如在 ASR 中, x t x_t xt 是 80 维的 Fbank 特征, y u y_u yu 是对应的独热标签。 T T T 是音频时长, U U U 表示输出的文字长度。
定义拓展的输出空间 Y ‾ = Y ⋃ Ø \mathcal{\overline{Y}}= \mathcal{Y}\bigcup \text{\O} Y=Y⋃Ø, Ø \text{\O} Ø 表示空向量,是一个输出占位符。比如解码后的序列为 ( y 1 , Ø , Ø , y 2 , Ø , y 3 ) ∈ Y ‾ ∗ (y_1, \text{\O},\text{\O},y_2,\text{\O},y_3)\in \mathcal{\overline{Y}^*} (y1,Ø,Ø,y2,Ø,y3)∈Y∗,去除空占位符后得到 ( y 1 , y 2 , y 3 ) ∈ Y ∗ (y_1,y_2,y_3)\in \mathcal{Y^*} (y1,y2,y3)∈Y∗。我们称 a ∈ Y ‾ ∗ \mathbf{a} \in \mathcal{\overline{Y}}^* a∈Y∗ 为一个对齐,给定 x \mathbf{x} x, RNN-T 表示条件概率分布 P ( a ∈ Y ‾ ∗ ∣ x ) \mathbb{P}(\mathbf{a} \in \mathcal{\overline{Y}^*} | \mathbf{x}) P(a∈Y∗∣x)。注意,这个时候的对齐 a \mathbf{a} a 是包含空向量的,真正解码后的文字序列 y ∈ Y ∗ \mathbf{y} \in \mathcal{Y}^* y∈Y∗ 不含空向量,那 P ( y ∈ Y ∗ ∣ x ) \mathbb{P}(\mathbf{y} \in \mathcal{Y}^* | \mathbf{x}) P(y∈Y∗∣x) 怎么算呢?和 HMM、CTC 一样,将所有可能的对齐都加起来
P ( y ∈ Y ∗ ∣ x ) = ∑ a ∈ B − 1 ( y ) P ( a ∣ x ) (1) \mathbb{P}(\mathbf{y} \in \mathcal{Y}^* | \mathbf{x})= \sum_{\mathbf{a}\in \mathcal{B}^{-1}(\mathbf{y})} \mathbb{P}(\mathbf{a}|\mathbf{x}) \tag{1} P(y∈Y∗∣x)=a∈B−1(y)∑P(a∣x)(1)
其中 B : Y ‾ ∗ ↦ Y ∗ \mathcal{B}: \mathcal{\overline{Y}}^* \mapsto \mathcal{Y}^* B:Y∗↦Y∗ 是去除空向量的映射。

图1 RNN-T结构3
1.1 RNN-T模型结构
1.1.1 编码器
编码器在论文中称作转写网络 F \mathcal{F} F(Transcription Network),由双向 RNN 构成,它的作用是将 ( x 1 , … , x T ) (x_1, \dots, x_T) (x1,…,xT) 映射到高维表示 f t , 1 ≤ t ≤ T \mathbf{f_t}, 1 ≤ t ≤ T ft,1≤t≤T,每个 f t \mathbf{f_t} ft 有 K + 1 K+1 K+1 个值, K + 1 K+1 K+1 是建模单元的个数加空记号。这部分相当于 CTC 的编码器主体,到如今,具体结构已经不重要,Transformer、Zipformer 等结构均可。
1.1.2 解码器
解码器在论文中称作预测网络 G \mathcal{G} G(Prediction Network),由一层 RNN 组成,输入序列 y ^ = ( Ø , y 1 , … , y U ) \mathbf{\hat{y}} = (\text{\O}, y_1, \dots, y_U) y^=(Ø,y1,…,yU), g = G ( y ^ ) \mathbf{g} = \mathcal{G}(\mathbf{\hat{y}}) g=G(y^), g \mathbf{g} g 是输出序列。 y ^ \mathbf{\hat{y}} y^ 的每个元素也都是向量,假设有 K K K 个标签,比如中文拼音建模, K K K 约是 200。 y i y_i yi 代表一个标签处是 1,其余为 0 的长度为 K K K 的向量, Ø \text{\O} Ø 代表全是 0 的长度为 K K K 的向量,故 y ^ \mathbf{\hat{y}} y^ 其实是一个形状为 K × ( U + 1 ) K × (U+1) K×(U+1) 的张量。解码器的输出则是形状为 ( K + 1 ) × ( U + 1 ) (K+1) × (U+1) (K+1)×(U+1) 的张量,因为每一次的输出节点包含空记号。
1.1.3 融合器
转写网络只处理音频信号,预测网络只处理文字信息,所以我们类比传统 ASR,有时候也将其称为声学模型(Acoustic model, AM)和语言模型(Linguistic model,LM)。传统 ASR 是利用加权有限状态转换器(Weighted Finite-State Transducer,WFST)来进行解码,将两者的信息融合起来。那 RNN-Transducer 怎么融合呢?
图2 传统 ASR 系统框架 4
由于转写向量 f t , 1 ≤ t ≤ T \mathbf{f_t}, 1≤t≤T ft,1≤t≤T 和预测向量 g u , 0 ≤ u ≤ U \mathbf{g_u}, 0≤u≤U gu,0≤u≤U 的长度都是 K + 1 K+1 K+1,确定了 t t t 和 u u u,可以将对应向量加起来,于是有
h ( k , t , u ) = e x p ( f t k + g u k ) (2) h(k, t, u) = exp(f_t^k + g_u^k) \tag{2} h(k,t,u)=exp(ftk+guk)(2)
其中 k k k 代表向量的第 k k k 个元素。于是输出分布为
P ( k ∈ Y ˉ ∣ t , u ) = h ( k , t , u ) ∑ k ′ ∈ Y ˉ h ( k ′ , t , u ) (3) \mathbb{P}(k \in \bar{\mathcal{Y}} | t, u)= \frac{h(k, t, u)}{ \sum_{k'\in \bar{\mathcal{Y}}} h(k', t, u)} \tag{3} P(k∈Yˉ∣t,u)=∑k′∈Yˉh(k′,t,u)h(k,t,u)(3)
原论文只是简单相加进行融合,在2013年的文章5 中,作者将其替换成了神经网络,我们称其为融合器(Joiner,Joint network),这样编码器和解码器的输出维度也不限于 K + 1 K + 1 K+1,利用 softmax 函数来生成(3)中的概率分布。
我们记
y ( t , u ) : = P ( y u + 1 ∣ t , u ) Ø ( t , u ) : = P ( Ø ∣ t , u ) \begin{align} y(t, u) &:= \mathbb{P}(y_{u+1}|t,u) \tag{4} \\ \text{\O}(t,u) &:= \mathbb{P}(\text{\O}|t,u) \tag{5} \end{align} y(t,u)Ø(t,u):=P(yu+1∣t,u):=P(Ø∣t,u)(4)(5)
图3 Transducer概率网格图
我们结合图 3 来说明 y ( t , u ) y(t, u) y(t,u) 和 Ø ( t , u ) \text{\O}(t,u) Ø(t,u) 的含义。图 3 中节点 ( t , u ) (t, u) (t,u) 表示 t t t 时刻输出了 u u u 个非空标签(token)的概率。从节点 ( t , u ) (t, u) (t,u) 出发的横向箭头 Ø ( t , u ) \text{\O}(t,u) Ø(t,u) 表示 ( t , u ) (t, u) (t,u) 后输出空标记,在 Y \mathcal{Y} Y 空间来讲就是什么都没输出。纵向箭头表示 ( t , u ) (t, u) (t,u) 后输出 y u + 1 y_{u+1} yu+1 的概率。底部的黑色实心节点表示空标记。按照解码顺序,网格永远是从左下角出发,最终到达右上角。图 3 中的红色代表一种可能的解码路径。
相比 CTC 的解码图,RNN-T 没有斜向箭头,摆脱了时间 t t t 的限制,就有可能输出任意多的 token,这样就解决了 CTC 的第一个问题。而 RNN-T 的创新是增加了预测网络和融合网络,建立了输出之间的依赖关系,相比 CTC,输出概率分布从 P ( y u + 1 ∣ x 1 , … , x n ) \mathbb{P}(y_{u+1}|x_1,\dots,x_n) P(yu+1∣x1,…,xn) 变为 P ( y u + 1 ∣ x 1 , … , x n , y 1 , … , y u ) \mathbb{P}(y_{u+1}|x_1,\dots,x_n,y_1,\dots,y_u) P(yu+1∣x1,…,xn,y1,…,yu),解决了 CTC 的第二个问题。
1.2 前向后向算法
回顾式(1),我们无法直接写出 P ( y ∈ Y ∗ ∣ x ) \mathbb{P}(\mathbf{y} \in \mathcal{Y}^* | \mathbf{x}) P(y∈Y∗∣x) 的计算公式,还是通过动态规划来算。
1.2.1 前向算法
(1)前向变量
α ( t , u ) \alpha(t, u) α(t,u) 表示 [ 1 : t ] [1:t] [1:t] 时刻输出 token y [ 1 , u ] \mathbf{y}_{[1, u]} y[1,u],就是图 3 中的节点 ( t , u ) (t, u) (t,u)。
(2)初始化
α ( 1 , 0 ) = 1 (6) \alpha(1, 0) = 1 \tag{6} α(1,0)=1(6)
(3)递推公式
每个节点只接收来自左边和下边的箭头。对 ∀ 1 ≤ t ≤ T , 0 ≤ u ≤ U \forall 1≤t≤T, 0≤u≤U ∀1≤t≤T,0≤u≤U,有
α ( t , u ) = α ( t − 1 , u ) Ø ( t − 1 , u ) + α ( t , u − 1 ) y ( t , u − 1 ) (7) \alpha(t, u) = \alpha(t-1, u)\text{\O}(t-1,u) + \alpha(t, u-1)y(t,u-1) \tag{7} α(t,u)=α(t−1,u)Ø(t−1,u)+α(t,u−1)y(t,u−1)(7)
(4)终止
P ( y ∣ x ) = α ( T , U ) Ø ( T , U ) (8) \mathbb{P}(\mathbf{y} | \mathbf{x}) = \alpha(T, U)\text{\O}(T, U) \tag{8} P(y∣x)=α(T,U)Ø(T,U)(8)
如图 3 中节点 ( 4 , 3 ) (4,3) (4,3) 经过 Ø ( 4 , 3 ) \text{\O}(4, 3) Ø(4,3) 到达终止节点。
1.2.2 后向算法
(1)后向变量
β ( t , u ) \beta(t, u) β(t,u) 表示 [ t : T ] [t:T] [t:T] 时刻输出 token y [ u + 1 , U ] \mathbf{y}_{[u+1, U]} y[u+1,U]。
(2)初始化
β ( T , U ) = Ø ( T , U ) (9) \beta(T, U) = \text{\O}(T, U) \tag{9} β(T,U)=Ø(T,U)(9)
(3)递推公式
每个节点只接收来自右边和上边的箭头。对 ∀ 1 ≤ t ≤ T , 0 ≤ u ≤ U \forall 1≤t≤T, 0≤u≤U ∀1≤t≤T,0≤u≤U,有
β ( t , u ) = β ( t + 1 , u ) Ø ( t , u ) + β ( t , u + 1 ) y ( t , u ) (10) \beta(t, u) = \beta(t+1, u)\text{\O}(t,u) + \beta(t, u+1)y(t,u) \tag{10} β(t,u)=β(t+1,u)Ø(t,u)+β(t,u+1)y(t,u)(10)
(4)终止
β ( 1 , 0 ) = β ( 2 , 0 ) Ø ( 1 , 0 ) + β ( 1 , 1 ) y ( 1 , 0 ) (11) \beta(1, 0) = \beta(2, 0)\text{\O}(1,0) + \beta(1, 1)y(1,0) \tag{11} β(1,0)=β(2,0)Ø(1,0)+β(1,1)y(1,0)(11)
1.3 反向传播
已知输入序列 x \mathbf{x} x 和目标输出序列 y ∗ \mathbf{y^*} y∗,训练模型就是减小负对数似然 L = − l n P ( y ∗ ∣ x ) \mathcal{L} = -ln\mathbb{P}(\mathbf{y^*}|\mathbf{x}) L=−lnP(y∗∣x),需要算偏导数 ∂ L ∂ f t k \frac{\partial\mathcal{L}}{\partial{f_t^k}} ∂ftk∂L 和 ∂ L ∂ g u k \frac{\partial\mathcal{L}}{\partial{g_u^k}} ∂guk∂L。
根据前向变量和反向变量的定义,不难得出
P ( y ∗ ∣ x ) = ∑ ( t , u ) : t + u = n α ( t , u ) β ( t , u ) ∀ n : 1 ≤ n ≤ U + T (12) \mathbb{P}(\mathbf{y^*}|\mathbf{x}) = \sum_{(t,u):t+u=n}\alpha(t,u)\beta(t,u) \ \ \ \ \ \forall n:1≤n≤U+T \tag{12} P(y∗∣x)=(t,u):t+u=n∑α(t,u)β(t,u) ∀n:1≤n≤U+T(12)
结合等式(7)(10)(12)和 L \mathcal{L} L 的定义,可以得到
∂ L ∂ P ( k ∣ t , u ) = { − α ( t , u ) β ( t , u + 1 ) P ( y ∗ ∣ x ) k = y u + 1 − α ( t , u ) β ( t + 1 , u ) P ( y ∗ ∣ x ) k = Ø 0 o t h e r w i s e (13) \frac{\partial\mathcal{L}}{\partial \mathbb{P}(k|t,u)} = \left\{ \begin{array}{} -\frac{\alpha(t,u)\beta(t,u+1)}{\mathbb{P}(\mathbf{y^*}|\mathbf{x})} & k = y_{u+1} \\ -\frac{\alpha(t,u)\beta(t+1,u)}{\mathbb{P}(\mathbf{y^*}|\mathbf{x})} & k = \text{\O} \\ 0 & otherwise \end{array}\right. \tag{13} ∂P(k∣t,u)∂L=⎩
⎨
⎧−P(y∗∣x)α(t,u)β(t,u+1)−P(y∗∣x)α(t,u)β(t+1,u)0k=yu+1k=Øotherwise(13)
由(13)又可以推出
∑ u = 0 U ∑ k ′ ∈ Y ˉ ∂ L ∂ P ( k ′ ∣ t , u ) ∂ P ( k ′ ∣ t , u ) ∂ f t k = − 1 P ( y ∗ ∣ x ) ∑ u = 0 U ( ∂ P ( y u + 1 ∣ t , u ) ∂ f t k ( α ( t , u ) β ( t , u + 1 ) ) + ∂ P ( Ø ∣ t , u ) ∂ f t k ( α ( t , u ) β ( t + 1 , u ) ) ) = − 1 P ( y ∗ ∣ x ) ∑ u = 0 U ( ∂ α ( t , u ) β ( t , u ) ∂ f t k ) = − 1 P ( y ∗ ∣ x ) ∂ P ( y ∗ ∣ x ) ∂ f t k = ∂ L ∂ f t k \begin{align*} &\sum_{u=0}^{U}\sum_{k'\in \bar{\mathcal{Y}}}\frac{\partial\mathcal{L}}{\partial \mathbb{P}(k'|t,u)}\frac{\partial \mathbb{P}(k'|t,u)}{\partial{f_t^k}} \\ =& -\frac{1}{\mathbb{P}(\mathbf{y^*}|\mathbf{x})}\sum_{u=0}^{U}\left(\frac{\partial\mathbb{P}(y_{u+1}|t,u)}{\partial{f_t^k}}(\alpha(t,u)\beta(t,u+1)) + \frac{\partial \mathbb{P}(\text{\O}|t,u)}{\partial{f_t^k}}(\alpha(t,u)\beta(t+1,u)) \right)\\ =& -\frac{1}{\mathbb{P}(\mathbf{y^*}|\mathbf{x})}\sum_{u=0}^{U}\left(\frac{\partial\alpha(t,u)\beta(t,u)}{\partial f_t^k} \right)\\ =& -\frac{1}{\mathbb{P}(\mathbf{y^*}|\mathbf{x})}\frac{\partial \mathbb{P}(\mathbf{y^*}|\mathbf{x})}{\partial f_t^k} \\ =& \frac{\partial\mathcal{L}}{\partial{f_t^k}} \end{align*} ====u=0∑Uk′∈Yˉ∑∂P(k′∣t,u)∂L∂ftk∂P(k′∣t,u)−P(y∗∣x)1u=0∑U(∂ftk∂P(yu+1∣t,u)(α(t,u)β(t,u+1))+∂ftk∂P(Ø∣t,u)(α(t,u)β(t+1,u)))−P(y∗∣x)1u=0∑U(∂ftk∂α(t,u)β(t,u))−P(y∗∣x)1∂ftk∂P(y∗∣x)∂ftk∂L
第 2 个等号是因为 ∂ L ∂ P ( k ′ ∣ t , u ) \frac{\partial\mathcal{L}}{\partial \mathbb{P}(k'|t,u)} ∂P(k′∣t,u)∂L 沿 k ′ k' k′ 求和,只有 2 项非零。第 3 个等号利用了等式(10)。于是有
∂ L ∂ f t k = ∑ u = 0 U ∑ k ′ ∈ Y ˉ ∂ L ∂ P ( k ′ ∣ t , u ) ∂ P ( k ′ ∣ t , u ) ∂ f t k (14) \frac{\partial\mathcal{L}}{\partial{f_t^k}} = \sum_{u=0}^{U}\sum_{k'\in \bar{\mathcal{Y}}}\frac{\partial\mathcal{L}}{\partial \mathbb{P}(k'|t,u)}\frac{\partial \mathbb{P}(k'|t,u)}{\partial{f_t^k}} \tag{14} ∂ftk∂L=u=0∑Uk′∈Yˉ∑∂P(k′∣t,u)∂L∂ftk∂P(k′∣t,u)(14)
同理可得
∂ L ∂ g u k = ∑ t = 1 T ∑ k ′ ∈ Y ˉ ∂ L ∂ P ( k ′ ∣ t , u ) ∂ P ( k ′ ∣ t , u ) ∂ g u k (15) \frac{\partial\mathcal{L}}{\partial{g_u^k}} = \sum_{t=1}^{T}\sum_{k'\in \bar{\mathcal{Y}}}\frac{\partial\mathcal{L}}{\partial \mathbb{P}(k'|t,u)}\frac{\partial \mathbb{P}(k'|t,u)}{\partial{g_u^k}} \tag{15} ∂guk∂L=t=1∑Tk′∈Yˉ∑∂P(k′∣t,u)∂L∂guk∂P(k′∣t,u)(15)
由等式(3)可得
∂ P ( k ′ ∣ t , u ) ∂ f t k = ∂ P ( k ′ ∣ t , u ) ∂ g u k = P ( k ′ ∣ t , u ) [ δ k k ′ − P ( k ∣ t , u ) ] (16) \frac{\partial \mathbb{P}(k'|t,u)}{\partial{f_t^k}}=\frac{\partial \mathbb{P}(k'|t,u)}{\partial{g_u^k}} = \mathbb{P}(k'|t,u)[\delta_{kk'}-\mathbb{P}(k|t,u)] \tag{16} ∂ftk∂P(k′∣t,u)=∂guk∂P(k′∣t,u)=P(k′∣t,u)[δkk′−P(k∣t,u)](16)
反向传播对偏导数 ∂ L ∂ f t k \frac{\partial\mathcal{L}}{\partial{f_t^k}} ∂ftk∂L 和 ∂ L ∂ g u k \frac{\partial\mathcal{L}}{\partial{g_u^k}} ∂guk∂L 进行回传。
二、预测算法
模型训好后,通过预测算法(解码)转写出对应的文字。论文中介绍了波束搜索(Beam Search)算法,可以扩展到任意长的序列,并通过控制波束搜索宽度,对计算成本和搜索精度进行平衡。
图4 宽度为3的波束搜索 6
如图 4 所示,每一层代表时刻 t t t 的解码选择,每一层的数字是标签的序号。比如共有 K K K 个标签,当 t = 1 t=1 t=1 时,可能的路径有 K K K 条;当 t = 2 t=2 t=2 时,有 K 2 K^2 K2 条路径;当 t = n t=n t=n 时,有 K n K^n Kn 条路径,显然这会导致内存爆炸。波束搜索是确定搜索宽度 W , 1 ≤ W ≤ K W, 1≤W≤K W,1≤W≤K,每一时刻 t t t,选择路径概率最大的 W W W 个节点保存,在 t + 1 t+1 t+1 时,在这 W W W 个节点的基础上拓展,再选出 t + 1 t+1 t+1 时刻路径概率最大的 W W W 个节点,以此类推,将节点拓展下去,最终概率最大的路径就是转写结果。CTC 中的最佳路径解码就是 W = 1 W=1 W=1 的情形,由此推知,波束搜索也是近似算法。
下面是语音工具包 k2 对论文中波束搜索算法的实现 7。
# k2 中的 beam search 实现
def beam_search(
model: Transducer,
encoder_out: torch.Tensor,
beam: int = 4,
) -> List[int]:
"""
It implements Algorithm 1 in https://arxiv.org/pdf/1211.3711.pdf
Args:
model:
An instance of `Transducer`.
encoder_out:
A tensor of shape (N, T, C) from the encoder. Support only N==1 for now.
beam:
Beam size.
Returns:
Return the decoded result.
"""
assert encoder_out.ndim == 3
# support only batch_size == 1 for now
assert encoder_out.size(0) == 1, encoder_out.size(0)
blank_id = model.decoder.blank_id # 空记号,编号 0
context_size = model.decoder.context_size # 2,bigram
device = model.device
decoder_input = torch.tensor([blank_id] * context_size, device=device).reshape(
1, context_size
)
decoder_out = model.decoder(decoder_input, need_pad=False)
T = encoder_out.size(1) # 解码音频的长度
t = 0
B = HypothesisList()
# HypothesisList 对象 B 存储当前最优的假设,形如 '0_0_48_366_65_66_6':
# Hypothesis(ys=[0, 0, 48, 366, 65, 66, 6], log_prob=tensor([-1.0889])
# 保存路径和对应的概率
B.add(
Hypothesis(
ys=[blank_id] * context_size,
log_prob=torch.zeros(1, dtype=torch.float32, device=device),
)
)
max_sym_per_utt = 20000
sym_per_utt = 0
encoder_out_len = torch.tensor([1])
decoder_out_len = torch.tensor([1])
decoder_cache: Dict[str, torch.Tensor] = {}
# 遍历每个时间步 t, 直至音频结束或符号数量上限
while t < T and sym_per_utt < max_sym_per_utt:
# fmt: off
current_encoder_out = encoder_out[:, t:t+1, :]
# fmt: on
A = B # A 承接了 t - 1 时刻的最优假设,将在此假设基础上继续延拓
B = HypothesisList() # B 中存储 t 时刻最优 beam 个假设
joint_cache: Dict[str, torch.Tensor] = {}
while True:
y_star = A.get_most_probable() # 从 A 中选取最优假设
A.remove(y_star)
decoder_out = run_decoder(
ys=y_star.ys, model=model, decoder_cache=decoder_cache
)
key = "_".join(map(str, y_star.ys[-context_size:]))
key += f"-t-{t}"
# 利用转写向量、预测向量和 joiner 计算联合概率
log_prob = run_joiner(
key=key,
model=model,
encoder_out=current_encoder_out,
decoder_out=decoder_out,
encoder_out_len=encoder_out_len,
decoder_out_len=decoder_out_len,
joint_cache=joint_cache,
)
# 先处理空记号,虽然不输出有效 token,但需要累积概率,此路径仍是潜在最优解
# First, process the blank symbol
skip_log_prob = log_prob[blank_id]
new_y_star_log_prob = y_star.log_prob + skip_log_prob
# 更新 B
# ys[:] returns a copy of ys
B.add(Hypothesis(ys=y_star.ys[:], log_prob=new_y_star_log_prob))
# 处理非空标签,对 A 进行延拓
# Second, process other non-blank labels
values, indices = log_prob.topk(beam + 1)
# 相当于图 4 中第 2 层节点 5 往节点 8 和节点 9 延拓,每次都拓展 beam + 1 个
# 这 beam + 1 个节点可能包括空记号,之所以是 beam + 1, 也是为了保证 B 最终能筛出最大的 beam 个假设
for idx in range(values.size(0)):
i = indices[idx].item()
if i == blank_id:
continue
new_ys = y_star.ys + [i]
new_log_prob = y_star.log_prob + values[idx]
A.add(Hypothesis(ys=new_ys, log_prob=new_log_prob))
# Check whether B contains more than "beam" elements more probable
# than the most probable in A
A_most_probable = A.get_most_probable()
kept_B = B.filter(A_most_probable.log_prob)
# 根据设定的束宽,筛选并保留 B 中最优的假设
if len(kept_B) >= beam:
B = kept_B.topk(beam)
break
# 进入 t + 1 层
t += 1
# 从最终保留的假设中,选取经过长度归一化处理后的最优假设
best_hyp = B.get_most_probable(length_norm=True)
# 移除上下文中的空白符号,返回解码结果
ys = best_hyp.ys[context_size:] # [context_size:] to remove blanks
return ys
总结
论文引入了预测网络和联合网络,解决了 CTC 不合理的独立假设问题。解码时候,可以对每帧输入进行预测输出,使得 RNN-T 天然具有自回归性,可以用于流式 ASR。且 AM + LM 的组合,让建模更加完备合理。
RNN-T 也有缺点,相比 CTC 的输出维度 (N, T, C), RNN-T 则是 (N, T, U, C),U 是解码器的输出长度,可达上百量级。故 RNN-T 内存占用大 ,训练复杂度高。
随着模型结构的进步,编码器中的 RNN 逐渐被各种 Former 类模型(Conformer,Zipformer等)替代。谷歌在深入研究解码器后也发现 3,RNN 结构并不是必须的,直接将 y u − 1 y_{u-1} yu−1 的嵌入(Embedding)送入联合网络,也能达到相近的效果。就像信乐团没有信,飞儿乐队没有飞一样,RNN-T 在如今的很多论文中,都只被称作 Transducer。
参考文献
-
Alex Graves. Sequence Transduction with Recurrent Neural Networks. 2012. ↩︎
-
Alex Graves, et al. Connectionist Temporal Classification: Labelling Unsegmented Sequence Data with Recurrent Neural Networks. 2006. ↩︎
-
Ghodsi, Mohammadreza , et al. “Rnn-Transducer with Stateless Prediction Network.” ICASSP 2020 - 2020 IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP) IEEE, 2020. ↩︎ ↩︎
-
洪青阳,李琳. 《语音识别原理与应用》, 2020. ↩︎
-
Alex Graves, et al. Speech recognition with deep recurrent neural networks. 2013. ↩︎
-
https://en.wikipedia.org/wiki/Beam_search. ↩︎
-
https://github.com/csukuangfj/icefall/blob/master/egs/librispeech/ASR/transducer_stateless/beam_search.py. ↩︎
火山引擎开发者社区是火山引擎打造的AI技术生态平台,聚焦Agent与大模型开发,提供豆包系列模型(图像/视频/视觉)、智能分析与会话工具,并配套评测集、动手实验室及行业案例库。社区通过技术沙龙、挑战赛等活动促进开发者成长,新用户可领50万Tokens权益,助力构建智能应用。
更多推荐
所有评论(0)