推导思路

我们被要求推导公式(5)的由来。根据原文,公式(5)用于计算在接纳新请求r后,新批次R’在接下来的P(r)P(r)P(r)步中的Estimated TPOT(每令牌处理时间)。

回顾公式(4):

ITL(∣R∣,Lavg(R))=α⋅∣R∣⋅Lavg(R)+β⋅∣R∣+γ⋅Lavg(R)+δITL(|R|, Lavg(R)) = α · |R| · Lavg(R) + β · |R| + γ · Lavg(R) + δITL(R,Lavg(R))=αRLavg(R)+βR+γLavg(R)+δ

其中,ITL表示批处理大小为|R|,平均序列长度为Lavg(R)Lavg(R)Lavg(R)时的词元间延迟(即处理一个令牌所需的时间)。

在公式(5)中:

EstimatedTPOT(∣R′∣,Lavg(R′),P)=ε×(α⋅∣R′∣+γ)⋅(Lavg(R′)+P(r)/2)+β⋅∣R′∣+δEstimatedTPOT(|R'|, Lavg(R'), P) = ε × { (α·|R′| + γ) · (Lavg(R′) + P(r)/2) + β·|R′| + δ }EstimatedTPOT(R,Lavg(R),P)=ε×(αR′∣+γ)(Lavg(R)+P(r)/2)+βR′∣+δ

注意:TPOT(Time Per Output Token)即每输出令牌的时间,这里实际上就是ITL,因为ITL就是处理一个令牌的时间。但是,由于新请求加入后,批次中的请求的序列长度会随着生成令牌而增加,因此平均序列长度会动态变化。公式(5)的推导基于以下关键点:

  1. 保守假设:假设当前批次中的所有请求(包括新加入的请求)都将持续生成P(r)P(r)P(r)个令牌(即新请求的预测输出长度)。这意味着在接下来的P(r)P(r)P(r)步中,批次大小保持为∣R′∣|R'|R,但平均序列长度会逐渐增加。

  2. 平均序列长度的变化:在生成令牌的过程中,每个请求的序列长度都在增加。因此,在接下来的P(r)P(r)P(r)步中,平均序列长度并不是一个固定值,而是从当前的Lavg(R′)Lavg(R')Lavg(R)开始,每步增加(因为每个请求每一步都生成一个令牌,所以整个批次的平均序列长度每一步增加1)。因此,在P(r)P(r)P(r)步内,平均序列长度是一个线性增长的过程。

  3. 对ITL的积分:由于ITL依赖于平均序列长度LavgLavgLavg,而LavgLavgLavg随时间变化,所以我们需要计算在接下来的P(r)P(r)P(r)步中,处理整个批次所有令牌的总时间,然后除以令牌总数(即批次大小乘以步数)得到平均的每令牌处理时间(即TPOT)。但是原文采用了一种近似方法:用平均序列长度的平均值来估算这段时间内的平均ITL。

设当前时刻(迭代t)的平均序列长度为L0=Lavg(R′)L_0 = Lavg(R')L0=Lavg(R)

在接下来的第1步,平均序列长度变为L0+1L_0+1L0+1(因为每个请求都生成了一个新令牌,所以平均序列长度增加1)。

在接下来的第k步,平均序列长度为L0+kL_0 + kL0+k

因此,在接下来的P(r)P(r)P(r)步中,平均序列长度从L_0线性增长到L0+P(r)L_0+P(r)L0+P(r)

那么,在这P(r)P(r)P(r)步中,每一步的平均序列长度可以取开始和结束的平均值:

Lavg=[L0+(L0+P(r))]/2=L0+P(r)/2L_{avg} = [L_0 + (L_0 + P(r))] / 2 = L_0 + P(r)/2Lavg=[L0+(L0+P(r))]/2=L0+P(r)/2

因此,我们可以用这个平均序列长度Lavg=L0+P(r)/2L_{avg} = L_0 + P(r)/2Lavg=L0+P(r)/2 来代表整个P(r)P(r)P(r)步过程中的平均序列长度。

然后,将当前批处理大小|R’|和这个平均序列长度LavgL_{avg}Lavg代入公式(4)来估算每一步的ITL(即每令牌处理时间):

ITLavg=α⋅∣R′∣⋅(L0+P(r)/2)+β⋅∣R′∣+γ⋅(L0+P(r)/2)+δITL_avg = α · |R'| · (L_0 + P(r)/2) + β · |R'| + γ · (L_0 + P(r)/2) + δITLavg=αR(L0+P(r)/2)+βR+γ(L0+P(r)/2)+δ

但是注意,这个ITLavgITL_{avg}ITLavg是每一步(即处理一个令牌)的时间,所以整个批次处理P®步(即每个请求生成P(r)P(r)P(r)个令牌)的总时间应该是:

TotalTime=ITLavg∗P(r)TotalTime = ITL_{avg} * P(r)TotalTime=ITLavgP(r) (因为每一步处理一个批次的所有令牌,但时间是一个ITL_avg,所以P(r)P(r)P(r)步的总时间就是ITL_avg乘以步数P(r)P(r)P(r)

然而,这里我们要求的TPOT(每令牌处理时间)应该是总时间除以总处理的令牌数。总处理的令牌数为:∣R′∣∗P(r)|R'| * P(r)RP(r)(因为每个步骤生成|R’|个令牌,共P(r)P(r)P(r)步)。

因此,TPOT=TotalTime/(∣R′∣∗P(r))=(ITLavg∗P(r))/(∣R′∣∗P(r))=ITLavg/∣R′∣TPOT = TotalTime / (|R'| * P(r)) = (ITL_{avg} * P(r)) / (|R'| * P(r)) = ITL_{avg} / |R'|TPOT=TotalTime/(RP(r))=(ITLavgP(r))/(RP(r))=ITLavg/∣R

但是,注意这个结果并不是我们想要的,因为这样计算出来的TPOT实际上是每一步中每个令牌的平均处理时间(即ITLavgITL_{avg}ITLavg除以批大小?),这与我们之前对ITL的定义(整个批次处理一个令牌的时间)不一致。实际上,ITLavgITL_{avg}ITLavg本身就是整个批次处理一个令牌所需的时间(即每步的时间),所以它已经是我们通常意义上的TPOT(每输出令牌的时间).

Logo

火山引擎开发者社区是火山引擎打造的AI技术生态平台,聚焦Agent与大模型开发,提供豆包系列模型(图像/视频/视觉)、智能分析与会话工具,并配套评测集、动手实验室及行业案例库。社区通过技术沙龙、挑战赛等活动促进开发者成长,新用户可领50万Tokens权益,助力构建智能应用。

更多推荐