文章标题:快速 Transformer 解码:一个写入头足矣
作者/机构:Noam Shazeer, Google

A1 主要贡献

核心问题与研究目标:Transformer 神经网络序列模型在增量推理(incremental inference)时的速度是一个主要挑战。在现代计算硬件上,其速度受限于为重载注意力层状态所需的大型“键”(keys)和“值”(values)张量而产生的内存带宽。本文旨在提出一种架构变体,以在仅有轻微质量下降的情况下,大幅提升推理速度。

创新点:本文提出了一种名为多查询注意力(multi-query attention)的变体。在这种结构中,键(keys)和值(values)在所有不同的注意力“头”(heads)之间共享。这种设计极大地减小了这些张量的尺寸,从而降低了增量解码过程中的内存带宽需求。实验证明,采用该方法的模型解码速度确实快得多,并且与基线模型相比,模型质量仅有轻微的下降。

A3 背景知识

神经注意力(Neural Attention)

神经注意力的基本功能:由 【Dzmitry Bahdanau, Kyunghyun Cho, and Yoshua Bengio. Neural machine translation by jointly learning to align and translate, 2014】引入的神经注意力是处理可变长度表示的强大工具。一个神经注意力函数接收一个查询向量 q 和一组 m 个不同的(键向量,值向量)对(由矩阵 KV 表示),并生成一个输出向量 y。输出 y 是不同值向量的加权和,其权重通过比较查询向量与各个键向量得出。

点积注意力(Dot-Product Attention)

点积注意力的计算方式:一种常见的实现方式是,权重由查询向量与不同键向量的点积经过 softmax 函数计算得出。以下代码描述了此过程。

def DotProductAttention(q, K, V):
    """ Dot-Product Attention on one query.
    Args:
        q: a vector with shape [k]
        K: a matrix with shape [m, k]
        V: a matrix with shape [m, v]
    Returns:
        y: a vector with shape [v]
    """
    logits = tf.einsum("k,mk->m", q, K)
    weights = tf.softmax(logits)
    return tf.einsum("m,mv->v", weights, V)

Einsum 标记法说明:文中的代码示例使用了在 TensorFlow 和 numpy 中定义的 einsum 标记法,用于任意维度张量间的广义收缩运算。在这种表示法中,一个等式指定了输入和输出张量的维度名称。该计算在数值上等同于将每个输入广播到拥有所有维度的并集,然后按元素相乘,并对所有不在期望输出形状中的维度进行求和。

多头注意力(Multi-head Attention)

多头注意力的并行结构:Transformer 序列到序列模型【Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser, and Illia Polosukhin. Attention is all you need. In NIPS, 2017】并行使用 h 个不同的注意力层(头)。这 h 个不同层的查询向量(query vectors)由输入向量 x 经过 h 个不同的学习线性投影 Pq 得到。同样,键(keys)和值(values)由 m 个不同输入向量的集合 M 经过 h 个不同的学习线性投影 PkPv 得到。这 h 个层的输出本身再通过不同的学习线性投影 Po,然后相加。为简化起见,本文假设输入和输出向量具有相同的维度 d

def MultiheadAttention(x, M, P_q, P_k, P_v, P_o):
    """ Multi-head Attention on one query.
    Args:
        x: a vector with shape [d]
        M: a matrix with shape [m, d]
        P_q: a tensor with shape [h, d, k]
        P_k: a tensor with shape [h, d, k]
        P_v: a tensor with shape [h, d, v]
        P_o: a tensor with shape [h, d, v]
    Returns:
        y: a vector with shape [d]
    """
    q = tf.einsum("d,hdk->hk", x, P_q)
    K = tf.einsum("md,hdk->hmk", M, P_k)
    V = tf.einsum("md,hdv->hmv", M, P_v)
    logits = tf.einsum("hk,hmk->hm", q, K)
    weights = tf.softmax(logits)
    o = tf.einsum("hm,hmv->hv", weights, V)
    y = tf.einsum("hv,hdv->d", o, P_o)
    return y

关于缩放因子的说明:【Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser, and Illia Polosukhin. Attention is all you need. In NIPS, 2017】在 logits 上包含了一个恒定的缩放因子。本文在代码中省略了它,因为它可以被整合到线性投影 PqPk 中。

批处理多头注意力(Multi-head Attention, Batched)

两种批处理方式:在实践中,将多个查询批处理在一起效率更高。下面的代码增加了两种批处理方式。首先,我们从一个序列的 n 个不同位置生成查询,这些查询都与相同的键和值进行交互。此外,我们一次处理 b 个不同的非交互序列。遵循【Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser, and Illia Polosukhin. Attention is all you need. In NIPS, 2017】的做法,在一个自回归模型中,我们可以通过向 logits 添加一个在非法位置值为 -∞ 的“掩码”(mask)来防止信息的反向流动。

def MultiheadAttentionBatched(X, M, mask, P_q, P_k, P_v, P_o):
    """ Multi-head Attention.
    Args:
        X: a tensor with shape [b, n, d]
        M: a tensor with shape [b, m, d]
        mask: a tensor with shape [b, h, n, m]
        P_q: a tensor with shape [h, d, k]
        P_k: a tensor with shape [h, d, k]
        P_v: a tensor with shape [h, d, v]
        P_o: a tensor with shape [h, d, v]
    Returns:
        Y: a tensor with shape [b, n, d]
    """
    Q = tf.einsum("bnd,hdk->bhnk", X, P_q)
    K = tf.einsum("bmd,hdk->bhmk", M, P_k)
    V = tf.einsum("bmd,hdv->bhmv", M, P_v)
    logits = tf.einsum("bhnk,bhmk->bhnm", Q, K)
    weights = tf.softmax(logits + mask)
    O = tf.einsum("bhnm,bhmv->bhnv", weights, V)
    Y = tf.einsum("bhnv,hdv->bnd", O, P_o)
    return Y

性能分析:为了简化性能分析,本文做出以下几点假设:$m = n$;$k = v = d/h$,如【Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser, and Illia Polosukhin. Attention is all you need. In NIPS, 2017】所建议;$n \le d$。总的算术操作数量为 $\Theta(bnd^2)$(因为在简化假设下,上述每个 tf.einsum 操作的复杂度为 $O(bnd^2)$)。需要访问的内存总量等于所有相关张量大小的总和:$O(bnd + bhn^2 + d^2)$。其中第一项来自 X, M, Q, K, V, OY,第二项来自 logitsweights,第三项来自投影张量 Pq, Pk, PvPo。将两者相除,我们发现内存访问与算术操作的比率为 $O(1/k + 1/(bn))$。这种低比率对于在现代 GPU/TPU 硬件上获得良好性能是必要的,因为这些硬件的计算能力可能比内存带宽高出两个数量级。

增量式多头注意力(Multi-head Attention, Incremental)

增量计算的必要性:在某些情况下,数据依赖性使得无法并行处理来自多个位置的查询。一个例子是自回归语言模型(如 Transformer)中的自注意力层。在每个位置产生的查询会关注到截至该位置(包括该位置)的所有位置产生的键值对。在训练期间,由于真实的(ground-truth)目标序列是已知的,我们可以使用类似于 2.3 节中的高效并行实现。然而,当从训练好的模型生成序列时,特定位置的自注意力层输出会影响下一位置生成的 token,而这个 token 又会影响下一位置该层的输入。这阻止了并行计算。

def MultiheadSelfAttentionIncremental(x, prev_K, prev_V, P_q, P_k, P_v, P_o):
    """ Multi-head Self-Attention (one step).
    Args:
        x: a tensor with shape [b, d]
        prev_K: tensor with shape [b, h, m, k]
        prev_V: tensor with shape [b, h, m, v]
        P_q: a tensor with shape [h, d, k]
        P_k: a tensor with shape [h, d, k]
        P_v: a tensor with shape [h, d, v]
        P_o: a tensor with shape [h, d, v]
    Returns:
        y: a tensor with shape [b, d]
        new_K: tensor with shape [b, h, m+1, k]
        new_V: tensor with shape [b, h, m+1, v]
    """
    q = tf.einsum("bd,hdk->bhk", x, P_q)
    new_K = tf.concat([prev_K, tf.expand_dims(tf.einsum("bd,hdk->bhk", M, P_k), axis=2)], axis=2)
    new_V = tf.concat([prev_V, tf.expand_dims(tf.einsum("bd,hdv->bhv", M, P_v), axis=2)], axis=2)
    logits = tf.einsum("bhk,bhmk->bhm", q, new_K)
    weights = tf.softmax(logits)
    o = tf.einsum("bhm,bhmv->bhv", weights, new_V)
    y = tf.einsum("bhv,hdv->bd", O, P_o)
    return y, new_K, new_V

性能分析与瓶颈:本文采用与 2.3.1 节相同的简化假设。在 n 次调用中,总算术操作数为 $\Theta(bnd^2)$。总内存访问量为 $\Theta(bn^2d + nd^2)$,第一项来自 KV,第二项来自 Pq, Pk, PvPo。将内存除以计算量,我们发现内存访问与算术操作的比率为 $\Theta(n/d + d/(bn))$。当 $n \approx d$ 或 $b \approx 1$ 时,该比率接近 1,导致内存带宽成为现代计算硬件上的主要性能瓶颈。为了使增量生成高效,我们必须将这两个项都减少到远小于 1。$d/(bn)$ 项比较容易处理,只需在内存允许的情况下使用更大的批量大小即可。减少 $n/d$ 项则更难,该项与每一步重新加载代表内存的 KV 张量(大小为 $bhmk = bn d k/d = bn^2$)的开销有关。一个解决方案是限制序列长度 n。另一个是减少被关注的位置数量,可以通过关注局部邻域,或如 【Peter J Liu, Mohammad Saleh, Etienne Pot, Ben Goodrich, Ryan Sepassi, Lukasz Kaiser, and Noam Shazeer. Generating wikipedia by summarizing long sequences. In Proceedings of the International Conference on Learning Representations, 2018】、【Biao Zhang, Deyi Xiong, and Jinsong Su. Accelerating neural transformer via an average attention network, 2018】、【Daniel Povey, Hossein Hadian, Pegah Ghahremani, Ke Li, and Sanjeev Khudanpur. A time-restricted selfattention layer for ASR. In Proceddings of the IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP). IEEE, 2018】 中那样压缩内存位置的数量。本文提出了一种正交的方法来减小 KV 张量的大小,即移除它们的“头”维度,同时在查询中保留“头”维度。

A2 方法细节

多查询注意力(Multi-Query Attention)

多查询注意力的核心设计:我们引入多查询注意力作为【Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser, and Illia Polosukhin. Attention is all you need. In NIPS, 2017】中描述的多头注意力的一种变体。多头注意力由多个并行的注意力层(头)组成,对查询、键、值和输出进行不同的线性变换。多查询注意力与此相同,唯一的区别是不同的头共享同一组键(keys)和值(values)

代码实现:(增量式)多查询(自)注意力的代码与上面列出的多头注意力的代码相同,只是我们从 tf.einsum 方程中代表 KVPkPv 的“头”维度的字母 "h" 移除了。

def MultiqueryAttentionBatched(X, M, mask, P_q, P_k, P_v, P_o):
    """ Multi-Query Attention.
    Args:
        X: a tensor with shape [b, n, d]
        M: a tensor with shape [b, m, d]
        mask: a tensor with shape [b, h, n, m]
        P_q: a tensor with shape [h, d, k]
        P_k: a tensor with shape [d, k]
        P_v: a tensor with shape [d, v]
        P_o: a tensor with shape [h, d, v]
    Returns:
        Y: a tensor with shape [b, n, d]
    """
    Q = tf.einsum("bnd,hdk->bhnk", X, P_q)
    K = tf.einsum("bmd,dk->bmk", M, P_k)
    V = tf.einsum("bmd,dv->bmv", M, P_v)
    logits = tf.einsum("bhnk,bmk->bhnm", Q, K)
    weights = tf.softmax(logits + mask)
    O = tf.einsum("bhnm,bmv->bhnv", weights, V)
    Y = tf.einsum("bhnv,hdv->bnd", O, P_o)
    return Y

def MultiquerySelfAttentionIncremental(x, prev_K, prev_V, P_q, P_k, P_v, P_o):
    """ Multi-query Self-Attention (one step).
    Args:
        x: a tensor with shape [b, d]
        prev_K: tensor with shape [b, m, k]
        prev_V: tensor with shape [b, m, v]
        P_q: a tensor with shape [h, d, k]
        P_k: a tensor with shape [d, k]
        P_v: a tensor with shape [d, v]
        P_o: a tensor with shape [h, d, v]
    Returns:
        y: a tensor with shape [b, d]
        new_K: tensor with shape [b, m+1, k]
        new_V: tensor with shape [b, m+1, v]
    """
    q = tf.einsum("bd,hdk->bhk", x, P_q)
    K = tf.concat([prev_K, tf.expand_dims(tf.einsum("bd,dk->bk", M, P_k), axis=2)], axis=2)
    V = tf.concat([prev_V, tf.expand_dims(tf.einsum("bd,dv->bv", M, P_v), axis=2)], axis=2)
    logits = tf.einsum("bhk,bmk->bhm", q, K)
    weights = tf.softmax(logits)
    o = tf.einsum("bhm,bmv->bhv", weights, V)
    y = tf.einsum("bhv,hdv->bd", O, P_o)
    return y, K, V

增量式多查询注意力的性能分析

性能分析:我们采用与 2.3.1 节相同的简化假设。在 n 次调用中,总算术操作数仍为 $\Theta(bnd^2)$。总内存访问量为 $\Theta(bnd + bn^2k + nd^2)$,第一项来自 x, q, oy,第二项来自 KV,第三项来自 Pq, Pk, Pv, Po。将内存除以计算量,我们发现内存访问与算术操作的比率为 $\Theta(1/d + n/(dh) + d/(bn))$。

性能提升:我们已将有问题的 $n/d$ 项减小了 h 倍。理论上,在给定大批量 b 的情况下,这应能显著提高增量生成的性能。在实验部分,我们将展示性能提升是真实的,并且模型质量保持在高水平。

A4 实验环境

机器翻译任务

  • 数据集:WMT 2014 英德翻译任务。
  • 基线模型架构:一个 6 层的编码器-解码器 Transformer 模型。关键参数为:$d_{model} = 1024$, $d_{ff} = 4096$, $h = 8$, $d_k = d_v = 128$。使用学习式位置嵌入,并在 token 嵌入层和输出层之间共享权重。总参数量为 2.11 亿。
  • 多查询模型:将模型中所有的注意力层(编码器自注意力、解码器自注意力、编码器-解码器注意力)替换为多查询注意力。为了使总参数量与基线模型相同,将前馈隐藏层从 4096 扩大到 5440。
  • 其他对比模型
    1. 训练了基线和多查询模型的“局部”注意力版本,其中解码器自注意力层仅关注当前位置及之前的 31 个位置。
    2. 训练了几个通过减少头数 h 或键/值维度 kv 来减小 K 和 V 大小的模型,并同样扩大前馈隐藏层以匹配参数量。
  • 硬件与软件配置
    • 硬件:在 32 核的 TPUv3 集群上进行训练。
    • 软件:使用了 tensor2tensormesh-tensorflow 库的实现。
  • 训练配置:所有模型训练 100,000 步(约 20 个 epoch)。每个训练批次包含 128 个样本,每个样本由一个 256-token 的输入序列和一个 256-token 的目标序列组成。每个模型训练耗时约 2 小时。

语言模型任务

  • 数据集:Billion-Word Language Modeling Benchmark 【Ciprian Chelba, Tomas Mikolov, Mike Schuster, Qi Ge, Thorsten Brants, and Phillipp Koehn. One billion word benchmark for measuring progress in statistical language modeling. CoRR, abs/1312.3005, 2013】。
  • 基线模型架构:一个 6 层的 transformer-decoder 模型。关键参数为:$d_{model} = 1024$, $d_{ff} = 8192$, $h = 8$, $d_k = d_v = 128$。总参数量为 1.92 亿,所有变体模型参数量均与此相同。
  • 硬件与软件配置
    • 硬件:在 32 核的 TPUv3 集群上进行训练。
  • 训练配置:训练了 136,000 步(10 个 epoch),批处理大小为 64,000 个 tokens。每个模型训练耗时约 3 小时。

A4 实验结果

模型质量

  • 机器翻译任务结果(表1)
    • 在开发集上使用贪婪最大似然解码,多查询注意力模型在 BLEU 分数和每个子词 token 的困惑度(perplexity)上都略差于基线模型。
    • 然而,多查询注意力模型的结果远好于其他通过减少 hd_kd_v 的替代方案。
    • 在测试集上使用贪婪解码和集束搜索(beam 4, α = 0.6)进行验证,多查询模型的表现与基线相似,并且在 beam-4 解码下获得了最高的 BLEU 分数(28.5)。

表1:WMT14 EN-DE 结果
表1:WMT14 EN-DE 结果

  • 语言模型任务结果(表3)
    • 结果与翻译任务相似。多查询注意力模型在开发集上的每个词(非子词)困惑度略差于基线模型,但显著优于其他替代方案。

表3:Billion-Word LM Benchmark 结果
表3:Billion-Word LM Benchmark 结果

速度

  • 训练与推理速度评估环境:在一台 TPUv2(8核)上进行。
  • 训练速度(表2)
    • 一个训练步骤(包含 32,768 个输入 token 和 32,768 个目标 token)耗时:
      • 基线模型:433ms,摊销为每个(输入+目标)token 13.2µs。
      • 多查询模型:425ms,速度与基线模型相当。
  • 推理速度(表2)
    • 使用贪婪增量推理,批大小为 1024 个序列,源序列和目标序列长度均为 128。
    • 基线模型:编码器耗时 222ms(摊销 1.7µs/token),解码器每步耗时 47ms(摊销 46µs/token)。
    • 多查询模型:编码器耗时 195ms(摊销 1.5µs/token),解码器每步耗时 3.9ms(摊销 3.8µs/token)。
    • 结论:多查询模型在解码器部分的推理速度比基线模型快了超过一个数量级。集束搜索也得到了类似的结果。

表2:WMT14 EN-DE 翻译任务在序列长度为128时的摊销训练和推理成本。数值单位为每个输出token的TPUv2-微秒。
表2:WMT14 EN-DE 翻译任务在序列长度为128时的摊销训练和推理成本。数值单位为每个输出token的TPUv2-微秒。

A5 结论

本文提出了多查询注意力,作为多头注意力的替代方案,它在增量推理场景下具有低得多的内存带宽需求。我们相信,这将使得基于注意力的序列模型能够在对推理性能要求严苛的应用中得到更广泛的采用。

参考文献

  • 【Bahdanau et al., 2014】 Dzmitry Bahdanau, Kyunghyun Cho, and Yoshua Bengio. Neural machine translation by jointly learning to align and translate, 2014. (在第2节背景知识中引用,用于介绍神经注意力的起源。)
  • 【Chelba et al., 2013】 Ciprian Chelba, Tomas Mikolov, Mike Schuster, Qi Ge, Thorsten Brants, and Phillipp Koehn. One billion word benchmark for measuring progress in statistical language modeling. CoRR, abs/1312.3005, 2013. URL http://arxiv.org/abs/1312.3005. (在4.1节实验设置中引用,作为语言模型任务的数据集。)
  • 【Liu et al., 2018】 Peter J Liu, Mohammad Saleh, Etienne Pot, Ben Goodrich, Ryan Sepassi, Lukasz Kaiser, and Noam Shazeer. Generating wikipedia by summarizing long sequences. In Proceedings of the International Conference on Learning Representations, 2018. (在2.4.1节性能分析中引用,作为一种减少注意力位置数量的现有方法。)
  • 【Povey et al., 2018】 Daniel Povey, Hossein Hadian, Pegah Ghahremani, Ke Li, and Sanjeev Khudanpur. A time-restricted selfattention layer for ASR. In Proceddings of the IEEE International Conference on Acoustics, Speech and Signal Processing (ICASSP). IEEE, 2018. (在2.4.1节性能分析中引用,作为一种减少注意力位置数量的现有方法。)
  • 【Vaswani et al., 2017】 Ashish Vaswani, Noam Shazeer, Niki Parmar, Jakob Uszkoreit, Llion Jones, Aidan N. Gomez, Lukasz Kaiser, and Illia Polosukhin. Attention is all you need. In NIPS, 2017. (在全文多处引用,作为 Transformer 模型和多头注意力的基础工作。)
  • 【Zhang et al., 2018】 Biao Zhang, Deyi Xiong, and Jinsong Su. Accelerating neural transformer via an average attention network, 2018. (在2.4.1节性能分析中引用,作为一种减少注意力位置数量的现有方法。)