[Paper Notes] Meta Pseudo Labels
Intro
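Meta Pseudo Labels is a CVPR 2021 paper from Google. It proposes a pseudo-label-based training method in which both the Teacher and the Student are optimized during training. Its main selling point is performance: at the time of writing it has the highest ImageNet accuracy among models of comparable parameter count, with a top-1 accuracy of 90.2%.
The main contributions of the paper are:
- A formalized distillation method that uses pseudo labels to update the Teacher network and the Student network at the same time.
- Very strong results: among models of comparable size, it is the first to push ImageNet accuracy above 90%.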
Model
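Getting straight to the point: the method is called Meta Pseudo Labels, so compared with conventional Pseudo Labels it adds a meta-modeling step. Conventional pseudo-label distillation starts from a pre-trained Teacher and trains the Student on the pseudo labels that the Teacher provides as targets. In this paper, the Student's performance (its loss) on labeled data is fed back to help optimize the Teacher. Unlike many other semi-supervised models, the Teacher here is not updated by EMA but by gradients.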
Method
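For conventional distillation, optimization with pseudo labels can be written as:
\[\theta_{S}^{\mathrm{PL}}=\underset{\theta_{S}}{\operatorname{argmin}} \underbrace{\mathbb{E}_{x_{u}}\left[\operatorname{CE}\left(T\left(x_{u} ; \theta_{T}\right), S\left(x_{u} ; \theta_{S}\right)\right)\right]}_{:=\mathcal{L}_{u}\left(\theta_{T}, \theta_{S}\right)} \]
where \(T\) and \(S\) denote the Teacher and Student networks with parameters \(\theta_T\) and \(\theta_S\), and \(x_u\) is an unlabeled sample.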
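To make the pseudo-label objective \(\mathcal{L}_u\) concrete, here is a minimal sketch in plain Python (standard library only, not PyTorch). It assumes the Teacher and Student outputs are already available as logits; the function names and the toy logits are made up for illustration and do not come from the paper.

```python
import math

def softmax(logits):
    """Numerically stable softmax over a list of floats."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def cross_entropy(target_probs, pred_probs, eps=1e-12):
    """CE(q, p) = -sum_k q_k * log(p_k) for two probability vectors."""
    return -sum(q * math.log(p + eps) for q, p in zip(target_probs, pred_probs))

def pseudo_label_loss(teacher_logits_batch, student_logits_batch):
    """Batch estimate of L_u: mean CE between Teacher targets and Student predictions."""
    losses = [
        cross_entropy(softmax(t), softmax(s))
        for t, s in zip(teacher_logits_batch, student_logits_batch)
    ]
    return sum(losses) / len(losses)

# Hypothetical logits for two unlabeled samples and three classes.
teacher_logits = [[2.0, 0.0, -1.0], [0.5, 1.5, 0.0]]
student_logits = [[1.0, 0.2, -0.5], [0.0, 0.0, 0.0]]

loss_u = pseudo_label_loss(teacher_logits, student_logits)
print(f"L_u estimate: {loss_u:.4f}")
```

In conventional pseudo labeling only the Student logits would receive gradients from this loss; the Teacher is kept fixed.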
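The idea of the paper is to use the Student's performance on labeled data to update the Teacher. Mathematically, that "performance" is simply the Student's loss on labeled data, so it can be written as:
\[\mathbb{E}_{x_{l}, y_{l}}\left[\operatorname{CE}\left(y_{l}, S\left(x_{l} ; \theta_{S}^{\mathrm{PL}}\right)\right)\right]:=\mathcal{L}_{l}\left(\theta_{S}^{\mathrm{PL}}\right) \]
This expression takes \(\theta_{S}^{\mathrm{PL}}\) as its argument, and \(\theta_{S}^{\mathrm{PL}}\) is itself the solution of the first equation, so it depends on \(\theta_T\).
This already completes the meta-level formulation: the parameters of one model are treated as the input of a function, the parameters of the other model as the parameters of that function, and the optimal parameters are obtained by optimizing the loss expressed through it.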
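Therefore, to update the Teacher in this process, we minimize the loss in the second formula:
\[\begin{aligned}\min _{\theta_{T}} & \mathcal{L}_{l}\left(\theta_{S}^{\mathrm{PL}}\left(\theta_{T}\right)\right), \\\text { where } & \theta_{S}^{\mathrm{PL}}\left(\theta_{T}\right)=\underset{\theta_{S}}{\operatorname{argmin}} \mathcal{L}_{u}\left(\theta_{T}, \theta_{S}\right)\end{aligned} \]
Clearly, the argmin in this expression cannot be handled directly by gradient-based optimization: one would have to wait for \(\theta_S\) to reach its optimum before taking the next step, so training could not be end-to-end. The paper makes a one-step approximation: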
\[\theta_{S}^{\mathrm{PL}}\left(\theta_{T}\right) \approx \theta_{S}-\eta_{S} \cdot \nabla_{\theta_{S}} \mathcal{L}_{u}\left(\theta_{T}, \theta_{S}\right) \]
With this, the optimization objective becomes:
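\[\min _{\theta_{T}} \quad \mathcal{L}_{l}\left(\theta_{S}-\eta_{S} \cdot \nabla_{\theta_{S}} \mathcal{L}_{u}\left(\theta_{T}, \theta_{S}\right)\right) \]
If we can take the gradient of this expression with respect to \(\theta_T\), we can optimize end-to-end with gradient descent. Denote the expression by \(R\); the derivation goes as follows: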
\[\underbrace{\frac{\partial R}{\partial \theta_{T}}}_{1 \times|T|}=\frac{\partial}{\partial \theta_{T}} \operatorname{CE}\left(y_{l}, S\left(x_{l} ; \mathbb{E}_{\widehat{y}_{u} \sim T\left(x_{u} ; \theta_{T}\right)}\left[\theta_{S}-\eta_{S} \nabla_{\theta_{S}} \operatorname{CE}\left(\widehat{y}_{u}, S\left(x_{u} ; \theta_{S}\right)\right)\right]\right)\right) \]
To simplify notation, define:
\[\underbrace{\bar{\theta}_{S}^{\prime}}_{|S| \times 1}=\mathbb{E}_{\widehat{y}_{u} \sim T\left(x_{u} ; \theta_{T}\right)}\left[\theta_{S}-\eta_{S} \nabla_{\theta_{S}} \operatorname{CE}\left(\widehat{y}_{u}, S\left(x_{u} ; \theta_{S}\right)\right)\right] \]
so the expression above can be written as:
\[\begin{aligned}\underbrace{\frac{\partial R}{\partial \theta_{T}}}_{1 \times|T|} &=\frac{\partial}{\partial \theta_{T}} \operatorname{CE}\left(y_{l}, S\left(x_{l} ; \mathbb{E}_{\widehat{y}_{u} \sim T\left(x_{u} ; \theta_{T}\right)}\left[\theta_{S}-\eta_{S} \nabla_{\theta_{S}} \operatorname{CE}\left(\widehat{y}_{u}, S\left(x_{u} ; \theta_{S}\right)\right)\right]\right)\right) \\&=\frac{\partial}{\partial \theta_{T}} \operatorname{CE}\left(y_{l}, S\left(x_{l} ; \bar{\theta}_{S}^{\prime}\right)\right) \\&=\underbrace{\left.\frac{\partial \operatorname{CE}\left(y_{l}, S\left(x_{l} ; \bar{\theta}_{S}^{\prime}\right)\right)}{\partial \theta_{S}}\right|_{\theta_{S}=\bar{\theta}_{S}^{\prime}}}_{1 \times|S|} \cdot \underbrace{\frac{\partial \bar{\theta}_{S}^{\prime}}{\partial \theta_{T}}}_{|S| \times|T|}\end{aligned} \]
The left factor is easy to obtain with gradient descent, because the gradient can be recovered by subtracting the Student's parameters before and after an update on the labeled data:
Since
\[\theta_{S}^*=\theta_{S}-\eta_{S} \nabla_{\theta_{S}} \operatorname{CE}\left(y_{l}, S\left(x_{l} ; \theta_{S}\right)\right) \]
the gradient follows directly from the difference of \(\theta_S\) before and after the update:
\[\eta_{S} \nabla_{\theta_{S}} \operatorname{CE}\left(y_{l}, S\left(x_{l} ; \theta_{S}\right)\right) =\theta_{S}-\theta_{S}^* \]
So we now focus on the right factor of the earlier expression, \(\frac{\partial \bar{\theta}_{S}^{\prime}}{\partial \theta_{T}}\).
Expanding this term:
\[\begin{aligned}\underbrace{\frac{\partial \bar{\theta}_{S}^{\prime}}{\partial \theta_{T}}}_{|S| \times|T|} &=\frac{\partial}{\partial \theta_{T}} \mathbb{E}_{\widehat{y}_{u} \sim T\left(x_{u} ; \theta_{T}\right)}\left[\theta_{S}-\eta_{S} \nabla_{\theta_{S}} \operatorname{CE}\left(\widehat{y}_{u}, S\left(x_{u} ; \theta_{S}\right)\right)\right] \\&=\frac{\partial}{\partial \theta_{T}} \mathbb{E}_{\widehat{y}_{u} \sim T\left(x_{u} ; \theta_{T}\right)}\left[\theta_{S}-\eta_{S} \cdot\left(\left.\frac{\partial \operatorname{CE}\left(\widehat{y}_{u}, S\left(x_{u} ; \theta_{S}\right)\right)}{\partial \theta_{S}}\right|_{\theta_{S}=\theta_{S}}\right)^{\top}\right]\end{aligned} \]
To simplify notation again, define:
\[\underbrace{g_{S}\left(\widehat{y}_{u}\right)}_{|S| \times 1}=\left(\left.\frac{\partial \operatorname{CE}\left(\widehat{y}_{u}, S\left(x_{u} ; \theta_{S}\right)\right)}{\partial \theta_{S}}\right|_{\theta_{S}=\theta_{S}}\right)^{\top} \]
and the expression becomes:
\[\underbrace{\frac{\partial \bar{\theta}_{S}^{\prime}}{\partial \theta_{T}}}_{|S| \times|T|}=-\eta_{S} \cdot \frac{\partial}{\partial \theta_{T}} \mathbb{E}_{\widehat{y}_{u} \sim T\left(x_{u} ; \theta_{T}\right)}[\underbrace{g_{S}\left(\widehat{y}_{u}\right)}_{|S| \times 1}] \]
Here \(g_{S}\left(\widehat{y}_{u}\right)\) itself does not depend on \(\theta_T\); only \(\widehat{y}_{u}\) depends on the Teacher's parameters, because the pseudo label is drawn from the Teacher's output distribution. As an example, consider:
\[\frac{\partial }{\partial \theta} \mathbb{E}_{x \sim p(x;\theta)}\left[f(x)\right] \]
Its gradient can be computed as follows:
\[\begin{aligned}\frac{\partial}{\partial \theta} \mathbb{E}_{x \sim p(x;\theta)}\left[f(x)\right] &= \frac{\partial}{\partial \theta} \int p(x;\theta) f(x)\, dx \\ &= \int \frac{\partial p(x;\theta)}{\partial \theta} f(x)\, dx \\ &= \int p(x;\theta)\, \nabla_\theta \log p(x;\theta)\, f(x)\, dx \\ &= \mathbb{E}_{x \sim p(x;\theta)}\left[f(x)\, \nabla_\theta \log p(x;\theta)\right]\end{aligned} \]
By the same argument, the expression above can be written as:
\[\begin{aligned}\frac{\partial}{\partial \theta_T} \mathbb{E}_{\widehat{y}_u \sim T(x_u;\theta_T)}\left[g_S(\widehat{y}_u)\right] &= \frac{\partial}{\partial \theta_T} \sum_{\widehat{y}_u} P(\widehat{y}_u \mid x_u;\theta_T)\, g_S(\widehat{y}_u) \\ &= \sum_{\widehat{y}_u} g_S(\widehat{y}_u)\, \frac{\partial P(\widehat{y}_u \mid x_u;\theta_T)}{\partial \theta_T} \\ &= \sum_{\widehat{y}_u} P(\widehat{y}_u \mid x_u;\theta_T)\, g_S(\widehat{y}_u)\, \frac{\partial \log P(\widehat{y}_u \mid x_u;\theta_T)}{\partial \theta_T} \\ &= \mathbb{E}_{\widehat{y}_u \sim T(x_u;\theta_T)}\left[g_S(\widehat{y}_u)\, \frac{\partial \log P(\widehat{y}_u \mid x_u;\theta_T)}{\partial \theta_T}\right]\end{aligned} \]
Therefore:
\[\begin{aligned}\underbrace{\frac{\partial \bar{\theta}_{S}^{\prime}}{\partial \theta_{T}}}_{|S| \times|T|} &=-\eta_{S} \cdot \frac{\partial}{\partial \theta_{T}} \mathbb{E}_{\widehat{y}_{u} \sim T\left(x_{u} ; \theta_{T}\right)}\left[g_{S}\left(\widehat{y}_{u}\right)\right] \\&=-\eta_{S} \cdot \mathbb{E}_{\widehat{y}_{u} \sim T\left(x_{u} ; \theta_{T}\right)}[\underbrace{g_{S}\left(\widehat{y}_{u}\right)}_{|S| \times 1} \cdot \underbrace{\frac{\partial \log P\left(\widehat{y}_{u} \mid x_{u} ; \theta_{T}\right)}{\partial \theta_{T}}}_{1 \times|T|}]\\&=\eta_{S} \cdot \mathbb{E}_{\widehat{y}_{u} \sim T\left(x_{u} ; \theta_{T}\right)}[\underbrace{g_{S}\left(\widehat{y}_{u}\right)}_{|S| \times 1} \cdot \underbrace{\frac{\partial \operatorname{CE}\left(\widehat{y}_{u}, T\left(x_{u} ; \theta_{T}\right)\right)}{\partial \theta_{T}}}_{1 \times|T|}]\end{aligned} \]
The last step uses the fact that, for a hard pseudo label, \(\operatorname{CE}\left(\widehat{y}_{u}, T\left(x_{u} ; \theta_{T}\right)\right)=-\log P\left(\widehat{y}_{u} \mid x_{u} ; \theta_{T}\right)\), which flips the sign. At this point, this part of the gradient can be computed from a cross-entropy loss term.
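The score-function identity used above can be checked numerically. The sketch below is plain Python under simplifying assumptions that are not from the paper: the "Teacher" is just a softmax over a logit vector `theta_t` with no input, and \(g_S\) is replaced by an arbitrary scalar per class, `f_values`. It compares \(\sum_k p_k f_k \nabla \log p_k\) with a central finite difference of \(\mathbb{E}[f]\).

```python
import math

def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def expected_value(theta, f):
    """E_{k ~ softmax(theta)}[f_k], computed exactly by summing over classes."""
    return sum(p * fk for p, fk in zip(softmax(theta), f))

def score_function_gradient(theta, f):
    """sum_k p_k * f_k * d log p_k / d theta, where d log p_k / d theta_j = 1[j == k] - p_j."""
    p = softmax(theta)
    n = len(theta)
    grad = [0.0] * n
    for k in range(n):
        for j in range(n):
            dlogp = (1.0 if j == k else 0.0) - p[j]
            grad[j] += p[k] * f[k] * dlogp
    return grad

def finite_difference_gradient(theta, f, step=1e-6):
    """Central finite difference of expected_value with respect to each logit."""
    grad = []
    for j in range(len(theta)):
        plus = list(theta)
        plus[j] += step
        minus = list(theta)
        minus[j] -= step
        grad.append((expected_value(plus, f) - expected_value(minus, f)) / (2 * step))
    return grad

theta_t = [0.3, -0.2, 1.1]    # hypothetical Teacher logits
f_values = [0.5, -1.0, 2.0]   # stand-in for g_S(y_hat), one scalar per class

analytic = score_function_gradient(theta_t, f_values)
numeric = finite_difference_gradient(theta_t, f_values)
max_gap = max(abs(a - b) for a, b in zip(analytic, numeric))
print("largest gap between the two gradients:", max_gap)
```

Because \(\nabla \log p_k = -\nabla \operatorname{CE}(k, \cdot)\) for a one-hot target, the same quantity can be obtained from a cross-entropy backward pass with the sign flipped.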
Now, putting the left and right factors mentioned above together:
\[\begin{aligned}\underbrace{\frac{\partial R}{\partial \theta_{T}}}_{1 \times|T|} &=\underbrace{\left.\frac{\partial \operatorname{CE}\left(y_{l}, S\left(x_{l} ; \bar{\theta}_{S}^{\prime}\right)\right)}{\partial \theta_{S}}\right|_{\theta_{S}=\bar{\theta}_{S}^{\prime}}}_{1 \times|S|} \underbrace{\frac{\partial \bar{\theta}_{S}^{\prime}}{\partial \theta_{T}}}_{|S| \times|T|} \\&=\eta_{S} \cdot \underbrace{\left.\frac{\partial \operatorname{CE}\left(y_{l}, S\left(x_{l} ; \bar{\theta}_{S}^{\prime}\right)\right)}{\partial \theta_{S}}\right|_{\theta_{S}=\bar{\theta}_{S}^{\prime}}}_{1 \times|S|} \cdot \mathbb{E}_{\widehat{y}_{u} \sim T\left(x_{u} ; \theta_{T}\right)}[\underbrace{g_{S}\left(\widehat{y}_{u}\right)}_{|S| \times 1} \cdot \underbrace{\frac{\partial \operatorname{CE}\left(\widehat{y}_{u}, T\left(x_{u} ; \theta_{T}\right)\right)}{\partial \theta_{T}}}_{1 \times|T|}]\end{aligned} \]
The expectation has to be estimated by sampling (in practice, by evaluating it over the samples in a batch). For a single sample in the batch, the gradient is:
\[\begin{aligned}\nabla_{\theta_{T}} \mathcal{L}_{l} &=\eta_{S} \cdot \underbrace{\frac{\partial \operatorname{CE}\left(y_{l}, S\left(x_{l} ; \theta_{S}^{\prime}\right)\right)}{\partial \theta_{S}}}_{1 \times|S|} \cdot \underbrace{\left(\left.\frac{\partial \operatorname{CE}\left(\widehat{y}_{u}, S\left(x_{u} ; \theta_{S}\right)\right)}{\partial \theta_{S}}\right|_{\theta_{S}=\theta_{S}}\right)^{\top}}_{|S| \times 1} \cdot \underbrace{\frac{\partial \operatorname{CE}\left(\widehat{y}_{u}, T\left(x_{u} ; \theta_{T}\right)\right)}{\partial \theta_{T}}}_{1 \times|T|} \\&=\underbrace{\eta_{S} \cdot\left(\left(\nabla_{\theta_{S}^{\prime}} \operatorname{CE}\left(y_{l}, S\left(x_{l} ; \theta_{S}^{\prime}\right)\right)\right)^{\top} \cdot \nabla_{\theta_{S}} \operatorname{CE}\left(\widehat{y}_{u}, S\left(x_{u} ; \theta_{S}\right)\right)\right)}_{\text {A scalar }:=h} \cdot \nabla_{\theta_{T}} \operatorname{CE}\left(\widehat{y}_{u}, T\left(x_{u} ; \theta_{T}\right)\right)\end{aligned} \]
The left part, after the matrix products, is already a scalar, and the right part is a vector.
That concludes the theoretical derivation.
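As a sanity check on the scalar \(h\), the following plain-Python sketch compares the exact dot product with the difference of the labeled loss before and after one Student step, which agrees with it to first order in \(\eta_S\). The setup is an assumption made for illustration only: the Student is a bare logit vector `theta_s` with a softmax on top, and the labels and step size are arbitrary.

```python
import math

def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def ce_loss(theta, label):
    """Cross entropy between a one-hot label and softmax(theta)."""
    return -math.log(softmax(theta)[label])

def ce_grad(theta, label):
    """Gradient of ce_loss with respect to theta: softmax(theta) - onehot(label)."""
    return [p - (1.0 if j == label else 0.0) for j, p in enumerate(softmax(theta))]

def student_step(theta_s, pseudo_label, lr):
    """One SGD step of the Student on a hard pseudo label; also returns the gradient used."""
    g_u = ce_grad(theta_s, pseudo_label)
    theta_new = [t - lr * g for t, g in zip(theta_s, g_u)]
    return theta_new, g_u

theta_s = [0.2, -0.4, 0.1]   # hypothetical Student logits
eta_s = 0.1                  # Student learning rate
y_l, y_hat_u = 0, 2          # true label and (wrong) pseudo label

theta_s_new, g_u = student_step(theta_s, y_hat_u, eta_s)
g_l_new = ce_grad(theta_s_new, y_l)

# Exact scalar: eta_S * grad L_l(theta_S')^T . grad CE(y_hat_u, S(theta_S))
h_exact = eta_s * sum(a * b for a, b in zip(g_l_new, g_u))
# First-order approximation: labeled loss before the step minus after the step
h_approx = ce_loss(theta_s, y_l) - ce_loss(theta_s_new, y_l)

print(h_exact, h_approx)
```

Here the pseudo label disagrees with the true label, so the Student step raises the labeled loss and \(h\) comes out negative, which pushes the Teacher away from that pseudo label.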
Algorithm
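The algorithm itself is quite simple and is used together with the UDA loss. Each iteration has two phases: first, the Teacher provides pseudo labels and the Student is updated on them; then the scalar \(h\) is computed with the formula above and plugged into the gradient term to update the Teacher. The Teacher and the Student are optimized alternately.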
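The alternating procedure can be sketched on a toy problem. This is not the paper's training setup: it assumes a single unlabeled input whose true class is also the labeled target, bare logit vectors in place of networks, no UDA term, and an exact expectation over pseudo labels instead of sampling, which keeps the run deterministic. All names and hyperparameters are made up.

```python
import math

def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]

def onehot_grad(probs, label):
    """Gradient of CE(onehot(label), softmax(theta)) with respect to theta."""
    return [p - (1.0 if j == label else 0.0) for j, p in enumerate(probs)]

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def mpl_toy(num_classes=3, true_label=0, steps=200, eta_s=0.5, eta_t=1.0):
    theta_t = [0.0] * num_classes  # Teacher logits for the unlabeled input
    theta_s = [0.0] * num_classes  # Student logits
    for _ in range(steps):
        p_t = softmax(theta_t)
        p_s = softmax(theta_s)

        # Phase 1: Student update on pseudo labels, in expectation over
        # y_hat ~ Teacher: E[p_s - onehot(y_hat)] = p_s - p_t.
        g_s = [onehot_grad(p_s, k) for k in range(num_classes)]  # g_S(k) at the old theta_s
        theta_s = [t - eta_s * (ps - pt) for t, ps, pt in zip(theta_s, p_s, p_t)]

        # Phase 2: Teacher update from the labeled loss at the updated Student.
        g_l = onehot_grad(softmax(theta_s), true_label)
        h = [eta_s * dot(g_l, g_s[k]) for k in range(num_classes)]
        grad_t = [0.0] * num_classes
        for k in range(num_classes):
            g_t_k = onehot_grad(p_t, k)  # grad of CE(k, Teacher) w.r.t. theta_t
            for j in range(num_classes):
                grad_t[j] += p_t[k] * h[k] * g_t_k[j]
        theta_t = [t - eta_t * g for t, g in zip(theta_t, grad_t)]
    return softmax(theta_t), softmax(theta_s)

teacher_probs, student_probs = mpl_toy()
print("Teacher:", teacher_probs)
print("Student:", student_probs)
```

Starting from uniform predictions, the Teacher's probability of the true class rises above 1/3 because pseudo labels that lower the Student's labeled loss get a positive \(h\) and the others a negative one.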
Experiments
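For the experiments, the method is first compared on small datasets such as CIFAR-10:
It is essentially the best among the compared methods.
The "fully supervised" ImageNet experiment is in fact also run in a semi-supervised fashion: all ImageNet images are treated as labeled data, and 130 million images from the JFT dataset are used as unlabeled data.
The results:
A model with 300M+ parameters already reaches 90% accuracy.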
Code
Ref: https://github.com/kekmodel/MPL-pytorch
# Excerpt from the body of one training step in the referenced repo. The models,
# optimizers, schedulers, GradScalers (s_scaler, t_scaler), criterion, args,
# reduce_tensor and the batch tensors (images_l, images_uw, images_us, targets)
# are defined elsewhere in that repo.
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda import amp

# Teacher forward pass on labeled, weakly augmented and strongly augmented images.
batch_size = images_l.shape[0]
t_images = torch.cat((images_l, images_uw, images_us))
t_logits = teacher_model(t_images)
t_logits_l = t_logits[:batch_size]
t_logits_uw, t_logits_us = t_logits[batch_size:].chunk(2)
del t_logits

# Teacher supervised loss plus UDA consistency loss (confidence-masked).
t_loss_l = criterion(t_logits_l, targets)
soft_pseudo_label = torch.softmax(t_logits_uw.detach() / args.temperature, dim=-1)
max_probs, hard_pseudo_label = torch.max(soft_pseudo_label, dim=-1)
mask = max_probs.ge(args.threshold).float()
t_loss_u = torch.mean(
    -(soft_pseudo_label * torch.log_softmax(t_logits_us, dim=-1)).sum(dim=-1) * mask
)
weight_u = args.lambda_u * min(1., (step + 1) / args.uda_steps)  # linear warm-up
t_loss_uda = t_loss_l + weight_u * t_loss_u

# Student forward pass; it is trained only on the Teacher's hard pseudo labels.
s_images = torch.cat((images_l, images_us))
s_logits = student_model(s_images)
s_logits_l = s_logits[:batch_size]
s_logits_us = s_logits[batch_size:]
del s_logits

# Labeled loss of the Student before its update (no gradient flows through it).
s_loss_l_old = F.cross_entropy(s_logits_l.detach(), targets)
s_loss = criterion(s_logits_us, hard_pseudo_label)

# Student update.
s_scaler.scale(s_loss).backward()
if args.grad_clip > 0:
    s_scaler.unscale_(s_optimizer)
    nn.utils.clip_grad_norm_(student_model.parameters(), args.grad_clip)
s_scaler.step(s_optimizer)
s_scaler.update()
s_scheduler.step()
if args.ema > 0:
    avg_student_model.update_parameters(student_model)

with amp.autocast(enabled=args.amp):
    # Labeled loss of the Student after its update.
    with torch.no_grad():
        s_logits_l = student_model(images_l)
    s_loss_l_new = F.cross_entropy(s_logits_l.detach(), targets)

    # dot_product = s_loss_l_new - s_loss_l_old
    # test
    # old - new is the first-order approximation of the scalar h in the derivation.
    dot_product = s_loss_l_old - s_loss_l_new
    # moving_dot_product = moving_dot_product * 0.99 + dot_product * 0.01
    # dot_product = dot_product - moving_dot_product

    # MPL term: h times the Teacher's cross entropy on its own hard labels.
    _, hard_pseudo_label = torch.max(t_logits_us.detach(), dim=-1)
    t_loss_mpl = dot_product * F.cross_entropy(t_logits_us, hard_pseudo_label)
    t_loss = t_loss_uda + t_loss_mpl

# Teacher update.
t_scaler.scale(t_loss).backward()
if args.grad_clip > 0:
    t_scaler.unscale_(t_optimizer)
    nn.utils.clip_grad_norm_(teacher_model.parameters(), args.grad_clip)
t_scaler.step(t_optimizer)
t_scaler.update()
t_scheduler.step()

teacher_model.zero_grad()
student_model.zero_grad()

# Average the logged values across processes in distributed training.
if args.world_size > 1:
    s_loss = reduce_tensor(s_loss.detach(), args.world_size)
    t_loss = reduce_tensor(t_loss.detach(), args.world_size)
    t_loss_l = reduce_tensor(t_loss_l.detach(), args.world_size)
    t_loss_u = reduce_tensor(t_loss_u.detach(), args.world_size)
    t_loss_mpl = reduce_tensor(t_loss_mpl.detach(), args.world_size)
    mask = reduce_tensor(mask, args.world_size)
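The UDA part of the Teacher loss in the excerpt above (temperature-sharpened targets, confidence mask, linear warm-up of the weight) can be mirrored in plain Python for a small batch. This is only an illustrative sketch: it uses lists instead of tensors, and the function names and the temperature, threshold and step values are made up rather than taken from the repo's defaults.

```python
import math

def softmax(logits, temperature=1.0):
    scaled = [z / temperature for z in logits]
    m = max(scaled)
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]

def log_softmax(logits):
    m = max(logits)
    log_z = m + math.log(sum(math.exp(z - m) for z in logits))
    return [z - log_z for z in logits]

def uda_consistency_loss(weak_logits, strong_logits, temperature, threshold):
    """Mean masked CE between sharpened weak-view targets and strong-view predictions."""
    per_sample = []
    for weak, strong in zip(weak_logits, strong_logits):
        target = softmax(weak, temperature)               # treated as a constant target
        mask = 1.0 if max(target) >= threshold else 0.0   # keep confident samples only
        ce = -sum(t * lp for t, lp in zip(target, log_softmax(strong)))
        per_sample.append(ce * mask)
    # Like torch.mean over the batch, masked-out samples still count in the denominator.
    return sum(per_sample) / len(per_sample)

def uda_weight(step, lambda_u, uda_steps):
    """Linear warm-up of the unsupervised weight, capped at lambda_u."""
    return lambda_u * min(1.0, (step + 1) / uda_steps)

# Hypothetical Teacher logits for two unlabeled images (weak and strong views).
weak = [[4.0, 0.5, 0.1], [0.4, 0.3, 0.2]]
strong = [[2.0, 1.0, 0.0], [0.1, 0.5, 0.2]]

loss_u = uda_consistency_loss(weak, strong, temperature=0.7, threshold=0.6)
loss_none = uda_consistency_loss(weak, strong, temperature=0.7, threshold=1.1)
print(loss_u, uda_weight(step=0, lambda_u=8.0, uda_steps=10))
```

With these values only the first sample is confident enough to pass the mask, and an unreachable threshold such as 1.1 masks every sample and gives a zero loss.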