1. 程式人生 > 其它 >[論文理解] Meta Pseudo Labels

[論文理解] Meta Pseudo Labels

Intro

GOOGLE 21年的CVPR,提出了一種Teacher、Student都在訓練中進行優化的基於偽標籤的優化方法,最重要的是效能好,是目前引數量同等情況下在IMAGENET上精度最高的方法,TOP1 ACC高達90.2%。

文章的貢獻主要有:

  1. 提出一種形式化的蒸餾方法,該方法利用偽標籤同時更新Teacher網路和Student網路。
  2. 文章提出的方法具有超高的效能,同參數量情況下率先將IMAGENET準確率提升到90+。

Model

直入主題,文章提出的方法叫Meta Pseudo Labels,因而相較於傳統的Pseudo Labels方法多了Meta建模的過程。傳統的基於偽標籤的蒸餾方法是基於一個預訓練好的Teacher模型,利用Teacher模型提供的偽標籤作為Student模型的Target,進行訓練。而本文的方法可以通過Student模型在有標籤資料上的表現(Loss)來幫助Teacher模型優化。與其他半監督模型不太一樣的是,Teacher模型並非是通過EMA方式進行更新的,而是梯度方式。

Method

對於傳統的蒸餾方法,用偽標籤方式進行優化的過程可以描述為:

\[\theta_{S}^{\mathrm{PL}}=\underset{\theta_{S}}{\operatorname{argmin}} \underbrace{\mathbb{E}_{x_{u}}\left[\operatorname{CE}\left(T\left(x_{u} ; \theta_{T}\right), S\left(x_{u} ; \theta_{S}\right)\right)\right]}_{:=\mathcal{L}_{u}\left(\theta_{T}, \theta_{S}\right)} \]

其中\(T\)

表示Teacher模型,\(S\)表示Student模型,\(\theta\)表示該模型的引數,\(CE\)為交叉熵損失函式,\(\theta_S^{PL}\)為利用偽標籤方法得到的最優Student模型引數。該部分是在無標籤的資料上進行的,因為傳統的半監督學習方法就是有標籤部分損失加上無標籤部分的一致性損失,蒸餾過程對應無標籤部分的一致性損失。

本文的想法是,利用Student模型在有標籤資料上的表現,來更新Teacher模型,那麼這種“表現”數學化其實對應的就是Student模型在有標籤資料上的Loss,因此可以表示為:

\[\mathbb{E}_{x_{l}, y_{l}}\left[\operatorname{CE}\left(y_{l}, S\left(x_{l} ; \theta_{S}^{\mathrm{PL}}\right)\right)\right]:=\mathcal{L}_{l}\left(\theta_{S}^{\mathrm{PL}}\right) \]

上式的公式有一個引數\(\theta_{S}^{\mathrm{PL}}\)

,可以看到這個引數其實由上面第一個公式定義,因而可以看作是以\(\theta_T\)作為輸入變數,\(\theta_S\)作為優化引數的函式形式,因而可以寫為\(\theta_{S}^{\mathrm{PL}}(\theta_T)\),那麼上面第二個公式的損失可以定義為\(\mathcal{L}_{l}\left(\theta_{S}^{\mathrm{PL}}(\theta_T)\right)\)

這樣其實就已經完成了Meta模型的建模,即將一個模型的引數作為某一函式表達的輸入,另一模型的引數作為該函式表達的引數,經過對該引數表達的損失函式的優化,得到最優引數。

因此,要在這一過程中更新Teacher模型,則需要最小化第二個公式的損失:

\[\begin{aligned}\min _{\theta_{T}} & \mathcal{L}_{l}\left(\theta_{S}^{\mathrm{PL}}\left(\theta_{T}\right)\right), \\\text { where } & \theta_{S}^{\mathrm{PL}}\left(\theta_{T}\right)=\underset{\theta_{S}}{\operatorname{argmin}} \mathcal{L}_{u}\left(\theta_{T}, \theta_{S}\right)\end{aligned} \]

很顯然,上式的argmin函式沒法用梯度方式來優化,因為得等到\(\theta_S\)達到最優,才能進行下一步,顯然會導致訓練無法端到端。文章對該問題做了一個one-step的近似:

\[\theta_{S}^{\mathrm{PL}}\left(\theta_{T}\right) \approx \theta_{S}-\eta_{S} \cdot \nabla_{\theta_{S}} \mathcal{L}_{u}\left(\theta_{T}, \theta_{S}\right) \]

到這裡,上式的優化目標變成了:

\[\min _{\theta_{T}} \quad \mathcal{L}_{l}\left(\theta_{S}-\eta_{S} \cdot \nabla_{\theta_{S}} \mathcal{L}_{u}\left(\theta_{T}, \theta_{S}\right)\right) \]

OK,那麼講道理如果能對該式求\(\theta_T\)的梯度,就可以利用梯度下降方法來端到端優化了,定義上式為\(R\),那麼具體求解過程為:

\[\underbrace{\frac{\partial R}{\partial \theta_{T}}}_{1 \times|T|}=\frac{\partial}{\partial \theta_{T}} \operatorname{CE}\left(y_{l}, S\left(x_{l} ; \mathbb{E}_{\widehat{y}_{u} \sim T\left(x_{u} ; \theta_{T}\right)}\left[\theta_{S}-\eta_{S} \nabla_{\theta_{S}} \operatorname{CE}\left(\widehat{y}_{u}, S\left(x_{u} ; \theta_{S}\right)\right)\right]\right)\right) \]

為簡化表示,定義:

\[\underbrace{\bar{\theta}_{S}^{\prime}}_{|S| \times 1}=\mathbb{E}_{\widehat{y}_{u} \sim T\left(x_{u} ; \theta_{T}\right)}\left[\theta_{S}-\eta_{S} \nabla_{\theta_{S}} \mathbf{C E}\left(\widehat{y}_{u}, S\left(x_{u} ; \theta_{S}\right)\right)\right] \]

則上式可以表示為:

\[\begin{aligned}\underbrace{\frac{\partial R}{\partial \theta_{T}}}_{1 \times|T|} &=\frac{\partial}{\partial \theta_{T}} \operatorname{CE}\left(y_{l}, S\left(x_{l} ; \mathbb{E}_{\widehat{y}_{u} \sim T\left(x_{u} ; \theta_{T}\right)}\left[\theta_{S}-\eta_{S} \nabla_{\theta_{S}} \mathbf{C E}\left(\widehat{y}_{u}, S\left(x_{u} ; \theta_{S}\right)\right)\right]\right)\right) \\&=\frac{\partial}{\partial \theta_{T}} \operatorname{CE}\left(y_{l}, S\left(x_{l} ; \bar{\theta}_{S}^{\prime}\right)\right) \\&=\underbrace{\left.\frac{\partial \operatorname{CE}\left(y_{l}, S\left(x_{l} ; \bar{\theta}_{S}^{\prime}\right)\right)}{\partial \theta_{S}}\right|_{\left.\theta_{S}=\bar{\theta}_{S}^{\prime}\right)}}_{|\times| S \mid} \cdot \underbrace{\frac{\partial \bar{\theta}_{S}^{\prime}}{\partial \theta_{T}}}_{|S| \times|T|}\end{aligned} \]

上式的左邊其實是很容易利用梯度下降求解的,因為可以利用Student模型在有標籤資料集上更新前後引數相減得到梯度:

由於

\[\theta_{S}^*=\theta_{S}-\eta_{S} \nabla_{\theta_{S}} \mathbf{C E}\left(y_{l}, S\left(x_{l} ; \theta_{S}\right)\right) \]

因此很容易利用更新前後\(\theta_S\)進行相減得到其梯度:

\[\eta_{S} \nabla_{\theta_{S}} \mathbf{C E}\left(y_{l}, S\left(x_{l} ; \theta_{S}\right)\right) =\theta_{S}-\theta_{S}^* \]

所以現在需要聚焦到前面式子的右側項\(\frac{\partial \bar{\theta}_{S}^{\prime}}{\partial \theta_{T}}\)

將這該式展開:

\[\begin{aligned}\underbrace{\frac{\partial \bar{\theta}_{S}^{\prime}}{\partial \theta_{T}}}_{|S| \times|T|} &=\frac{\partial}{\partial \theta_{T}} \mathbb{E}_{\widehat{y}_{u} \sim T\left(x_{u} ; \theta_{T}\right)}\left[\theta_{S}-\eta_{S} \nabla_{\theta_{S}} \operatorname{CE}\left(\widehat{y}_{u}, S\left(x_{u} ; \theta_{S}\right)\right)\right] \\&=\frac{\partial}{\partial \theta_{T}} \mathbb{E}_{\widehat{y}_{u} \sim T\left(x_{u} ; \theta_{T}\right)}\left[\theta_{S}-\eta_{S} \cdot\left(\left.\frac{\partial \operatorname{CE}\left(\widehat{y}_{u}, S\left(x_{u} ; \theta_{S}\right)\right)}{\partial \theta_{S}}\right|_{\theta_{S}=\theta_{S}}\right)^{\top}\right]\end{aligned} \]

為了簡化表示,我們再次定義:

\[\underbrace{g_{S}\left(\widehat{y}_{u}\right)}_{|S| \times|1|}=\left(\left.\frac{\partial \operatorname{CE}\left(\widehat{y}_{u}, S\left(x_{u} ; \theta_{S}\right)\right)}{\partial \theta_{S}}\right|_{\theta_{S}=\theta_{S}}\right)^{\top} \]

那麼上式就變成了:

\[\underbrace{\frac{\partial \bar{\theta}_{S}^{\prime}}{\partial \theta_{T}}}_{|S| \times|T|}=-\eta_{S} \cdot \frac{\partial}{\partial \theta_{T}} \mathbb{E}_{\widehat{y}_{u} \sim T\left(x_{u} ; \theta_{T}\right)}[\underbrace{g_{S}\left(\widehat{y}_{u}\right)}_{|S| \times 1}] \]

這裡的\(g_{S}\left(\widehat{y}_{u}\right)\)並不依賴\(\theta_T\)的,只是\(\widehat{y}_{u}\)需要利用偽標籤演算法依賴Teacher模型的引數罷了,舉一個例子:

\[\frac{\partial }{\partial \theta} \mathbb{E}_{x \sim p(x;\theta)}f(x) \]

這個式子要求梯度可以這麼做:

\[\frac{\partial }{\partial \theta} \mathbb{E}_{x \sim p(x;\theta)}f(x) \\ = \frac{\partial}{\partial \theta} \int p(x;\theta)f(x)dx \\ = \int \frac{\partial}{\partial \theta}p(x;\theta)f(x)dx \\ = \int p(x;\theta) \nabla_\theta log(p(x;\theta))f(x)dx \\ =\mathbb{E}_{x \sim p(x;\theta)} f(x)\nabla_\theta log(p(x;\theta)) \]

那麼同理呀,上式可以寫成:

\[\frac{\partial }{\partial \theta_T} \mathbb{E}_{\hat{y}_u \sim T(x_u;\theta_T)}[g_s(\hat{y}_u)] \\ = \frac{\partial}{\partial \theta_T} \sum_{\hat{y}_u} p(\hat{y}_u|x_u;\theta_T)g_s(\hat{y}_u) \\ = \sum_{\hat{y}_u} \frac{\partial}{\partial \theta_T}p(\hat{y}_u|x_u;\theta_T)g_s(\hat{y}_u) \\ = \sum_{\hat{y}_u} p(\hat{y}_u|x_u;\theta_T) \frac{\partial}{\partial \theta_T} log(p(\hat{y}_u|x_u;\theta_T)g_s(\hat{y}_u) \\ =\mathbb{E}_{\hat{y}_u \sim T(x_u;\theta_T)} [g_s(\hat{y}_u) \frac{\partial}{\partial \theta_T} log(p(\hat{y}_u|x_u;\theta_T)] \]

因此有:

\[\begin{aligned}\underbrace{\frac{\partial \bar{\theta}_{S}^{(t+1)}}{\partial \theta_{T}}}_{|S| \times|T|} &=-\eta_{S} \cdot \frac{\partial}{\partial \theta_{T}} \mathbb{E}_{\widehat{y}_{u} \sim T\left(x_{u} ; \theta_{T}\right)}\left[g_{S}\left(\widehat{y}_{u}\right)\right] \\&=-\eta_{S} \cdot \mathbb{E}_{\widehat{y}_{u} \sim T\left(x_{u} ; \theta_{T}\right)}[\underbrace{g_{S}\left(\widehat{y}_{u}\right)}_{|S| \times 1} \underbrace{\cdot \underbrace{\frac{\partial \log P\left(\widehat{y}_{u} \mid x_{u} ; \theta_{T}\right)}{\partial \theta_{T}}}_{1 \times|T|}]}\\&=\eta_{S} \cdot \mathbb{E}_{\widehat{y}_{u} \sim T\left(x_{u} ; \theta_{T}\right)}[\underbrace{g_{S}\left(\widehat{y}_{u}\right)}_{|S| \times 1} \cdot \underbrace{\frac{\partial \operatorname{CE}\left(\widehat{y}_{u}, T\left(x_{u} ; \theta_{T}\right)\right)}{\partial \theta_{T}}}_{1 \times|T|}]\end{aligned} \]

到這一步就可以利用交叉熵損失項來計算該部分梯度了。

到這裡,再整理一下上面提到的左項和右項:

\[\begin{aligned}\underbrace{\frac{\partial R}{\partial \theta_{T}}}_{1 \times|T|} &=\underbrace{\left.\frac{\partial \mathbf{C E}\left(y_{l}, S\left(x_{l} ; \bar{\theta}_{S}^{\prime}\right)\right)}{\partial \theta_{S}}\right|_{\theta_{S}=\bar{\theta}_{S}^{\prime}}}_{1 \times|S|} \underbrace{\frac{\partial \bar{\theta}_{S}^{\prime}}{\partial \theta_{T}}}_{|S| \times|T|} \\&=\eta_{S} \cdot \underbrace{\left.\frac{\partial \operatorname{CE}\left(y_{l}, S\left(x_{l} ; \bar{\theta}_{S}^{\prime}\right)\right)}{\partial \theta_{S}}\right|_{\theta_{S}=\bar{\theta}_{S}^{\prime}}}_{1 \times|S|} \cdot \mathbb{E}_{\widehat{y}_{u} \sim T\left(x_{u} ; \theta_{T}\right)}[\underbrace{g_{S}\left(\widehat{y}_{u}\right)}_{|S| \times 1} \cdot \underbrace{\frac{\partial \operatorname{CE}\left(\widehat{y}_{u}, T\left(x_{u} ; \theta_{T}\right)\right)}{\partial \theta_{T}}}_{1 \times|T|}]\end{aligned} \]

上式均值項需要進行取樣才能計算(過程就是對batch內樣本計算),以batch內一個樣本為例,其梯度為:

\[\begin{aligned}\nabla_{\theta_{T}} \mathcal{L}_{l} &=\eta_{S} \cdot \underbrace{\frac{\partial \operatorname{CE}\left(y_{l}, S\left(x_{l} ; \theta_{S}^{\prime}\right)\right)}{\partial \theta_{S}}}_{1 \times|S|} \cdot \underbrace{\left(\left.\frac{\partial \mathbf{C E}\left(\widehat{y}_{u}, S\left(x_{u} ; \theta_{S}\right)\right)}{\partial \theta_{S}}\right|_{\theta_{S}=\theta_{S}}\right)^{\top}}_{|S| \times 1} \cdot \underbrace{\frac{\partial \operatorname{CE}\left(\widehat{y}_{u}, T\left(x_{u} ; \theta_{T}\right)\right)}{\partial \theta_{T}}}_{1 \times|T|} \\&=\underbrace{\eta_{S} \cdot\left(\left(\nabla_{\theta_{S}^{\prime}} \operatorname{CE}\left(y_{l}, S\left(x_{l} ; \theta_{S}^{\prime}\right)\right)^{\top} \cdot \nabla_{\theta_{S}} \operatorname{CE}\left(\widehat{y}_{u}, S\left(x_{u} ; \theta_{S}\right)\right)\right)\right.}_{\text {A scalar }:=h} \cdot \nabla_{\theta_{T}} \mathbf{C E}\left(\widehat{y}_{u}, T\left(x_{u} ; \theta_{T}\right)\right)\end{aligned} \]

可以看到左端項其實利用矩陣乘法已經是一個scalar了,右端項為一vector。

到這裡,理論推導部分就結束了。

Algorithm

演算法部分相當簡單,和UDA損失一起使用,那麼基本分為兩個過程,首先利用Teacher模型提供偽標籤,優化更新Student模型,然後利用上面計算的公式,求出scalar h,代入求得梯度項,更新Teacher模型;Teacher和Student模型交替進行優化。

Experiments

實驗上,一般先在CIFAR10等這樣的小資料集上進行一輪比較:

基本是最優的。

在IMAGENET上進行全監督實驗時,其實也是按半監督的方式來做的,只是將IMAGENET的全部樣本當作有標籤的樣本,然後再用了1.3億張JFT資料集當無標籤樣本訓練的。

其實驗結果:

300+M的模型就已經達到90%的acc了。。

Code

Ref: https://github.com/kekmodel/MPL-pytorch

batch_size = images_l.shape[0]
t_images = torch.cat((images_l, images_uw, images_us))
t_logits = teacher_model(t_images)
t_logits_l = t_logits[:batch_size]
t_logits_uw, t_logits_us = t_logits[batch_size:].chunk(2)
del t_logits

t_loss_l = criterion(t_logits_l, targets)

soft_pseudo_label = torch.softmax(t_logits_uw.detach()/args.temperature, dim=-1)
max_probs, hard_pseudo_label = torch.max(soft_pseudo_label, dim=-1)
mask = max_probs.ge(args.threshold).float()
t_loss_u = torch.mean(
    -(soft_pseudo_label * torch.log_softmax(t_logits_us, dim=-1)).sum(dim=-1) * mask
)
weight_u = args.lambda_u * min(1., (step+1) / args.uda_steps)
t_loss_uda = t_loss_l + weight_u * t_loss_u

s_images = torch.cat((images_l, images_us))
s_logits = student_model(s_images)
s_logits_l = s_logits[:batch_size]
s_logits_us = s_logits[batch_size:]
del s_logits

s_loss_l_old = F.cross_entropy(s_logits_l.detach(), targets)
s_loss = criterion(s_logits_us, hard_pseudo_label)

s_scaler.scale(s_loss).backward()
if args.grad_clip > 0:
    s_scaler.unscale_(s_optimizer)
    nn.utils.clip_grad_norm_(student_model.parameters(), args.grad_clip)
s_scaler.step(s_optimizer)
s_scaler.update()
s_scheduler.step()
if args.ema > 0:
    avg_student_model.update_parameters(student_model)

with amp.autocast(enabled=args.amp):
    with torch.no_grad():
        s_logits_l = student_model(images_l)
    s_loss_l_new = F.cross_entropy(s_logits_l.detach(), targets)
    # dot_product = s_loss_l_new - s_loss_l_old
    # test
    dot_product = s_loss_l_old - s_loss_l_new
    # moving_dot_product = moving_dot_product * 0.99 + dot_product * 0.01
    # dot_product = dot_product - moving_dot_product
    _, hard_pseudo_label = torch.max(t_logits_us.detach(), dim=-1)
    t_loss_mpl = dot_product * F.cross_entropy(t_logits_us, hard_pseudo_label)
    t_loss = t_loss_uda + t_loss_mpl

t_scaler.scale(t_loss).backward()
if args.grad_clip > 0:
    t_scaler.unscale_(t_optimizer)
    nn.utils.clip_grad_norm_(teacher_model.parameters(), args.grad_clip)
t_scaler.step(t_optimizer)
t_scaler.update()
t_scheduler.step()

teacher_model.zero_grad()
student_model.zero_grad()

if args.world_size > 1:
    s_loss = reduce_tensor(s_loss.detach(), args.world_size)
    t_loss = reduce_tensor(t_loss.detach(), args.world_size)
    t_loss_l = reduce_tensor(t_loss_l.detach(), args.world_size)
    t_loss_u = reduce_tensor(t_loss_u.detach(), args.world_size)
    t_loss_mpl = reduce_tensor(t_loss_mpl.detach(), args.world_size)
    mask = reduce_tensor(mask, args.world_size)