张卫强

个人信息Personal Information

教师英文名称:Wei-Qiang Zhang

教师拼音名称:Zhang Wei Qiang

电子邮箱:

办公地点:电子工程馆5-111

联系方式:010-62781847

学位:博士学位

毕业院校:清华大学

学科:信号与信息处理

教师博客

当前位置: 中文主页 >> 教师博客

任务无关的语音预训练模型结构化剪枝

点击次数:

结构化剪枝是一种广泛使用的模型压缩方式,可以从预训练的模型中移除线性层维度或是整个注意力层等参数组,压缩得到的模型不依赖特定硬件就能实现加速。本文介绍了清华大学语音与音频技术实验室在INTERSPEECH 2023上发表的文章《Task-Agnostic Structured Pruning of Speech Representation Models》, 作者对WavLM模型进行了任务无关的结构化剪枝,去除72%参数,推理速度加快一倍。

H. Wang, S. Wang, W.-Q. Zhang, H. Suo, and Y. Wan, “Task-agnostic structured pruning of speech representation models,” in Proc. Interspeech, 2023, pp. 231–235. doi: 10.21437/Interspeech.2023-1442.

基于L0正则化的剪枝

神经网络剪枝的一个核心问题是判断哪些权重是重要的,值得保留,哪些权重又是不重要的,可以从网络中删除,而常用的判断标准包括权重的大小、权重的梯度大小等等。Louizos等人提出了一种可学习的权重重要性判断标准,根据损失函数让网络自行决定保留哪些参数,并通过L0正则化来限制被保留的参数的数量。

具体来说,这种方法在训练阶段将一个可学习的掩码与权重相乘,如果掩码在学习后变为0,则移除这个参数,否则保留这个参数。掩码被建模为一个伯努利分布的随机变量z,学习的内容就是这个分布的参数p。

然而,由于从随机分布中采样的过程不可导,这个掩码不能通过梯度下降的方式进行优化。Louizos等人使用了一种重参数化技巧来解决这个问题。可学习的随机变量z通过两个变量计算得到,一个变量提供“随机性”,是从固定分布中采样得的;另一个变量提供“可学习性”,它的计算过程完全可导。具体来说,掩码通过如下方式计算:

图片1.jpg

其中α就是可学习的变量,u从均匀分布中采样得到。变量首先被缩放到一个下界略小于0、上界略大于1的空间,再通过tanh函数截断以使得z能精确地取0或1。训练结束后,若z=0,那么对应参数被删除,若为1或其他值,则参数被保留。下图展示了变量z和中间变量s的分布。

图片2.png

掩码(橘色)和中间变量(绿色)的分布

多尺度结构化剪枝

非结构化剪枝将矩阵变为稀疏矩阵,需要特殊的硬件进行加速;而非结构化剪枝则减少矩阵的维度或是将矩阵整个移除,在任意硬件上均可以加速。

基于L0正则化的方法可以应用于结构化剪枝。为网络中各种需要剪枝的结构(例如线性层维度、注意力头或是整个注意力层)分配一个剪枝掩码,即可进行结构化的剪枝。

Transformer网络中的结构具有不同尺度,线性层维度、注意力头等结构是小尺度结构,而整个注意力层或线性层则是大尺度结构。

在去除同样数量的参数的情况下,集中地移除少数大尺度结构比均匀地移除多个小尺度结构能获得更快的计算速度。因此,普林斯顿大学的陈丹琦等人提出了CoFi剪枝方法,并行地移除网络中的小尺度结构和大尺度结构,在提供更多剪枝自由度的同时显式地鼓励整个剪除大尺度结构。

使用直通估计优化多尺度剪枝

剪枝掩码z计算的过程需要通过tanh函数来保证能精准地取0或1,而tanh函数的截断却又使得梯度无法向后传递。

图片3.png

上图展示了z与logα的函数关系。可以看到,如果logα取值过大或过小,tanh函数就会进入截断区,梯度无法向后传播,logα也就无法被更新;换句话说,网络无法评估其剪枝决策的合理性。一般来说,大尺度的参数组更加重要,对应的logα也就会比较大,被更新的概率就会比较小。实验中,整个注意力层或线性层很少被直接移除。为了解决这个问题,我们将直通估计(Straight Trough Estimator,STE)引入计算过程,让梯度直接通过tanh函数,保证可学习参数能够被顺畅地更新。

图片4.jpg

上图展示了加入直通估计前后网络中剩余的注意力头的分布,深蓝色的方块表示留下的注意力头。可以看到,加入直通估计后,网络中的注意力头被整个移除,而实际测试中,加入STE后,剪枝得到的网络计算速度增加47%,比剪枝前的原始网络快2倍。

图片5.jpg

上图展示模型在SUPERB评测榜上的得分与参数量的关系。可以看到,我们的模型比基线的DistilHuBERT系列模型有更好的效果,同时在参数量显著小的情况下,比Wav2vec2.0 base模型有更好的平均表现。

在进行剪枝训练时,我们将网络剪枝和知识蒸馏结合在一起,用未经剪枝的网络指导剪枝训练,使用两个模型表征的MSE损失进行优化。因此,我们的方法是一种任务无关的方法,一次剪枝得到的模型可以被直接用于各种下游任务。

结论

我们提出了一种语音预训练模型的任务无关剪枝方法,对WavLM模型进行了压缩。通过在L0正则化方法中引入直通估计,我们保证了可学习剪枝掩码能被顺利更新,从而加快剪枝网络的计算速度。我们实现的压缩模型比原始模型减少72%的参数,计算速度加快一倍。