ML 中一些在 Inference Training 下表现不同的算子

| Views: | Total Words: 5k | Reading Time: 5 mins.

前言

这个其实是最近自己试图在一个 IR 上复现 ResNet 这种模型遇到的问题,进行了一些考察并写下此文权当记录了。

如标题所言,有一类算子(或者说 nn Module,取决于 IR 的粒度和设计)在 Inference 模式和 Training 模式下有着不同的表现。当然有的 IR 可能都没有 Inference / Training 的概念,但它也一定要能区分两种情况。大部分框架的解决方式是在调用其时加一个 bool flag(training: bool),或者在整个 Mod 中用一个 flag 表示整个 Mod 当前处于哪种 Mode。例如,Keras 调用 fit() 就是 Training Mode,调用 evaluate() 或者 predict() 就是 Inference Mode。

这类算子的这种特性在有些时候会对我们的 IR / Script 的设计带来困难。比如我们有一个静态图 G,这个算子 op 已经被写死在图中了,我们又不希望在全局加个 flag 表示当前状态,那么我们要如何在同一张图 G 中让 op 根据我们的需求在我们 forward / backward 这张图时表现不同呢?有一种方式是,我们可以写一个图 Pass,将这类 op 的 inference / training 状态根据我们的需要 rewrite 掉。

这里我们就不细讲用什么方法去实现这种 “表现不同” 了,我们来讲讲为什么它们需要有这种不同的表现——从它们后面数学原理的角度。简单地在 torch.nn.functional 下面搜索 training: bool,就可以找到所有有这种特性的算子。具体来讲有三类:

  • Batch Normalization
  • Dropout
  • Randomized Leaky ReLU

Batch Normalization

大名鼎鼎的 batch norm,膜拜一下原论文:Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift

BatchNorm 的目的当然就是 normalization,由于数据本身的规模不同,且经过不同层的分布也不同(被原论文称为 Internal Covariate Shift 的现象),由此我们必须选择较低的学习率以及精细的参数初始化。Why BN works 其实还是一个很难严谨回答的问题,但是可以确定的是,在 Normalization 过后数据分布的均值和方差相同了,在实验上表现出更快的收敛效果以及更好的稳定性。

以 BatchNorm2d 为例,对于一个 4D 的 input (N, C, H, W),它会对 input 求一个经验上的,一个 mini-batch 上的均值和方差(因此叫做 Batch Normalization)。注意,似乎按这个说法,Batch Norm 直接对 batch size N 取平均值就好,但是主流的实现都是对 axis=[N, H, W] 这三维一起取平均值。也就是说,得到的 mean 和 var 都是 (C,) 的 shape。这是因为,对于一个 feature map (H, W) 我们本身也要 normalize 它,这件事可以和对 batch 的 normalize 一起做,因此就这样写了。注意,为了后面方便 broadcast(显然 (N, C, H, W) 是不能和 (C,) broadcast 的),通常会把均值和方差 reshape 成 (1, 1, C, 1) 的样子。

Normalization 的式子就是这样

其中多出来的部分可以先不用管。eps 看这个位置显然是为了 numeric stability 的,后面两个是 trainable 的参数,可以在训练过程中不断调整 BN 的效果(详情请见 affine transform)。

到目前为止貌似还没有涉及到需要 Inference / Training 不同行为的地方。但是我们可以来想一个问题:在 inference 时,通常我们并不会输入一个 mini-batch,可能只有一个 sample。这时候的 Batch Norm 就会显得很怪。自然我们会想到利用 Training 时候收集的数据来代表均值和方差,在假设两者分布是一样的情况下。那我们就有了一个基本的设计:在 BN 中维护两个状态(internal state)表示训练过程中收集到的 mean 和 var,然后在 inference 过程中将其运用于 BN 的正向计算。

更进一步的,由于训练时候我们是一个 batch 一个 batch 地将数据喂进去,我们并不是事先知道所有的数据,所以我们需要随着时间 t 维护一个滑动的 mean 和 var。在 Torch 中被叫做 running_mean / var,TF 中叫做 moving_mean / var。通常采用 EMA(Exponential Moving Average) 的方式维护。当然也有别的方法,比如 FAIR 这里实现的一个叫 precise_bn 的方法。

所以我们可以看到,batch norm 需要的行为就是:

  • Training 时候收集 statistical 相关数据
  • Inference 的时候用它来算

Dropout

Dropout,NN 中著名的奇技淫巧之一,通过随机把数据给丢掉避免 overfit。我当初刚听到这个东西的时候,感觉很奇妙,但是潜意识里感觉它也是有道理的,理解上就像随机给一个系统来点破坏,这样如果你往 overfit 的方向走了,下轮训练当一些节点被 drop 之后你的 loss 又会很大。

当然,这个丢弃数据只有在 Training 时候是 make sense 的(在 Inference 的时候,你的网络都训好了,这时候我们只需要稳定的预测结果就好)。因此理论上在 Inference 时候 Dropout 就可以去掉了,或者说退化成 Identity。这就是 Dropout 行为不同的原因。

不过这里还有个事,就是关于 Drop 之后的 rescaling。我搜索资料发现一般有两种说法,一是在 Training 端 rescale,也就是将数据统一乘上 $\frac{1}{1−p}$ (期望下有 1−$p$ 的数据会留下来);二是在 Inference 端 rescale,也就是将 Identity 换成一个乘以 1−$p$ 。这两种流派分别被称为 Inverted Dropout 和 Vanilla Dropout(听名字也知道后者是先有的,然后发展到前者)。现在主流框架的实现都是采用前者。

Randomized Leaky ReLU

Leaky ReLU 的提出是为了解决 Dead ReLU Problem——0梯度太多造成反向传播就像 “一潭死水”。而 RReLU 则将 Leaky ReLU 里的那个小线性系数 $\alpha$ 变成了一个随机变量,这被证明是更有效的:Empirical Evaluation of Rectified Activations in Convolutional Network

这里的 Training 和 Inference 不同行为其实和 Dropout 也有点像,就是希望消去 Inference 时的随机性。因此在 Training 的过程中, $\alpha$ 从一个 uniform distribution sample 出来,而到了 Inference 时则是直接取期望值 $\frac{\text{lower} + \text{upper}}{2}$。

Reference

Torch 官方文档

Tensorflow API doc

PaddlePaddle API 文档

facebookresearch/fvcore

小小将:BatchNorm避坑指南

https://www.zhihu.com/question/61751133

Batch Normalization: Accelerating Deep Network Training by Reducing Internal Covariate Shift

Empirical Evaluation of Rectified Activations in Convolutional Network

Author: SiriusNEO

Published on: Metric Space

All posts on this blog are licensed under the CC BY-NC-SA 4.0 license unless otherwise noted.