1. 程式人生 > 其它 >手把手教你實現GAN半監督學習

手把手教你實現GAN半監督學習

作者:劉威威

編輯:田 旭

引言

本文主要介紹如何在tensorflow上僅使用200個帶標籤的mnist影象,實現在一萬張測試圖片上99%的測試精度,原理在於使用GAN做半監督學習。前文主要介紹一些原理部分,後文詳細介紹程式碼及其實現原理。前文介紹比較簡單,有基礎的同學請掠過直接看第二部分,文章末尾給出了程式碼GitHub連結。對GAN不瞭解的同學可以檢視微信公眾號:機器學習演算法全棧工程師 的GAN入門文章。

監督,無監督,半監督學習介紹

在正式介紹實現半監督學習之前,我在這裡首先介紹一下監督學習(supervised learning),半監督學習(semi-supervised learning)和無監督學習(unsupervised learning)的區別。監督學習是指在訓練集中包含訓練資料的標籤(label),比如類別標籤,位置標籤等等。最普遍使用標籤學習的是分類任務,對於分類任務,輸入給網路訓練樣本(samples)的一些特徵(feature)以及此樣本對應的標籤(label),通過神經網路擬合的方法,神經網路可以在特徵和標籤之間找到一個合適的對映關係(mapping),這樣當訓練完成後,輸入給網路沒有label的樣本,神經網路可以通過這一個對映關係猜出它屬於哪一類。典型機器學習的監督學習的例子是KNN和SVM。目前機器視覺領域的急速發展離不開監督學習。

而無監督學習的訓練事先沒有訓練標籤,直接輸入給演算法一些資料,演算法會努力學習資料的共同點,尋找樣本之間的規律性。無監督學習是很典型的學習,人的學習有時候就是基於無監督的,比如我並不懂音樂,但是我聽了上百首歌曲後,我可以根據我聽的結果將音樂分為搖滾樂(記為0類)、民謠(記為1類)、純音樂(記為2類)等等,事實上,我並不知道具體是哪一類,所以將它們記為0,1,2三類。典型的無監督學習方法是聚類演算法,比如k-means。

東方快車電影裡面大偵探有過一個臺詞,人們的話只有對與錯,沒有中間地帶,最後經過一系列事件後他找到了對與錯之間的betweeness。在監督學習和無監督學習之間,同樣存在著中間地帶-半監督學習。半監督學習簡單來說就是將無監督學習和監督學習相結合,一部分包含了監督學習一部分包含了無監督學習,比如給一個分類任務,此分類任務的訓練集中有精確標籤的資料非常少,但是包含了大量的沒有標註的資料,如果直接用監督學習的方法去做的話,效果不一定很好,有標註的訓練資料太少很容易導致過擬合,而且大量的無標註的資料都沒有充分的利用,最常見的例子是在醫學影象的分析檢測任務中,醫學影象本身就不容易獲得,要獲得精標註的影象就需要有經驗的醫生去一個一個標註,顯然他們並沒有那麼多的時間。這時候就是半監督學習的用武之地了,半監督學習很適合用在標籤資料少,訓練資料又比較多的情況。

常見的半監督學習方法主要有:

1.Self training

2.Generative model

3.S3VMs

4.Graph-Based AIgorithems

5.Multiview AIgorithems

接下來我會結合Improved Techniques for Training GANs這篇論文詳細介紹如何使用目前最火的生成模型GAN去實現半監督學習,也即是半監督學習的第二種方法,並給出詳細的程式碼解釋,對理論不是很熟悉的同學可以直接看程式碼。另外註明:我只復現了論文半監督學習的部分,之前也有人復現了此部分,但是我感覺他對原文有很大的曲解,他使用了所有的標籤去幫助生成,並不在分類上,不太符合半監督學習的本質,而且程式碼很複雜,感興趣的可以看這個連結https://github.com/gitlimlab/SSGAN-Tensorflow。

Improved Techniques for Training GANs

GAN是無監督學習的代表,它可以不斷學習模擬資料的分佈進而生成和訓練資料相似分佈的樣本,在訓練過程不需要標籤,GAN在無監督學習領域,生成領域,半監督學習領域以及強化學習領域都有廣泛的應用。但是GAN存在很多的訓練不穩定等等的問題,作者good fellow在2016年放出了Improved Techniques for Training GANs,對GAN訓練不穩定的問題做了一些解釋和經驗上的解決方案,並給出了和半監督學習結合的方法。

從平衡點角度解釋GAN的不穩定性來說,GAN的納什均衡點是一個鞍點,並不是一個區域性最小值點,基於梯度的方法主要是尋找高維空間中的極小值點,因此使用梯度訓練的方法很難使GAN收斂到平衡點。為此,為了一部分緩解這個問題,goodfellow聯合提出了一些改進方案,

主要有:

Feature matching,

Minibatch discrimination

weight Historical averaging (相當於一個正則化的方式)

One-sided label smoothing

Virtual batch normalization

後來發現Feature matching在半監督學習上表現良好,mini-batch discrimination表現很差。

semi-supervised GAN

對於一個普通的分類器來說,假設對MNIST分類,一共有10類資料,分別是0-9,分類器模型以資料x作為輸入,輸出一個K=10維的向量,經過soft max後計算出分類概率最大的那個類別。在監督學習領域,往往是通過最小化類別標籤 y 和預測分佈

的交叉熵來實現最好的結果。

但是將GAN用在半監督學習領域的時候需要做一些改變,生成器不做改變,仍然負責從輸入噪聲資料中生成影象,判別器D不在是一個簡單的真假分類(二分類)器,假設輸入資料有K類,D就是K+1的分類器,多出的那一類是判別輸入是否是生成器G生成的影象。網路的流程圖見圖一。

圖一 網路的流程圖

網路結構確定了之後就是損失函式的設計部分,藉助GAN我們就可以從無標籤資料中學習,只要知道輸入資料是真實資料,那就可以通過最大化

來實現,上述式子可解釋為不管輸入的是哪一類真的圖片(不是生成器G生成的假圖片),只要最大化輸出它是真影象的概率就可以了,不需要具體分出是哪一類。由於GAN的生成器的參與,訓練資料中有一半都是生成的假資料。

下面給出判別器D的損失函式設計,D損失函式包括兩個部分,一個是監督學習損失,一個是半監督學習損失,具體公式如下:

其中

對於無監督學習來說,只需要輸出真假就可以了,不需要確定是哪一類,因此我們令

其中

表示判別是假影象的概率,那麼D(x)就代表了輸出是真影象的概率,那麼無監督學習的損失函式就可以表示為

這不就是GAN的損失函式嘛!好了,到這裡得出結論,在半監督學習中,判別器的分類要多分一類,多出的這一類表示的是生成器生成的假影象這一類,另外判別器的損失函式不僅包括了監督損失而且還有無監督的損失函式,在訓練過程中同時最小化這兩者。損失函式介紹完畢,接下來介紹程式碼實現部分。

程式碼實現及解讀

注:完整程式碼的GitHub連線在文章底部。這裡只擷取關鍵部分做介紹。

在程式碼中,我使用feature matching,one side label smoothing方式,並沒有使用論文中介紹的Historical averaging,而是隻對判別器D使用了簡單的l2正則化,防止過擬合,另外論文中介紹的Minibatch discrimination, Virtual batch normalization等等都沒有使用,主要是這兩者在半監督學習中表現不是很好,但是如果想獲得好的生成結果還是很有用的。

1網路結構

首先介紹網路結構部分,因為是在mnist資料集比較簡單,所以隨便搭了一個判別器和生成器,具體如下:

判別器的網路結構如下面程式碼所示:

def discriminator(self, name, inputs, reuse):
        l = tf.shape(inputs)[0]
        inputs = tf.reshape(inputs, (l,self.img_size,self.img_size,self.dim))
        with tf.variable_scope(name,reuse=reuse):
            out = []
            output = conv2d('d_con1',inputs,5, 64, stride=2, padding='SAME') #14*14
            output1 = lrelu(self.bn('d_bn1',output))
            out.append(output1)
            # output1 = tf.contrib.keras.layers.GaussianNoise
            output = conv2d('d_con2', output1, 3, 64*2, stride=2, padding='SAME')#7*7
            output2 = lrelu(self.bn('d_bn2', output))
            out.append(output2)
            output = conv2d('d_con3', output2, 3, 64*4, stride=1, padding='VALID')#5*5
            output3 = lrelu(self.bn('d_bn3', output))
            out.append(output3)
            output = conv2d('d_con4', output3, 3, 64*4, stride=2, padding='VALID')#2*2
            output4 = lrelu(self.bn('d_bn4', output))
            out.append(output4)
            output = tf.reshape(output4, [l, 2*2*64*4])# 2*2*64*4
            output = fc('d_fc', output, self.num_class)
            # output = tf.nn.softmax(output)
            return output, out

其中conv2d()是卷積操作,引數依次是,層的名字,輸入tensor,卷積核大小,輸出通道數,步長,padding。判別器中每一層都加了歸一化層,這裡使用最簡單的歸一化,函式如下所示,另外每一層的啟用函式使用leakrelu。判別器D最終返回兩個值,第一個是計算的logits,另外一個是一個列表,列表的每一個元素代表判別器每一層的輸出,為接下來實現feature matching做準備。

def bn(self, name, input):
        val = tf.contrib.layers.batch_norm(input, decay=0.9,
                                           updates_collections=None,
                                           epsilon=1e-5,
                                           scale=True,
                                           is_training=True,
                                           scope=name)
        return val


def lrelu(x, leak=0.2):
    return tf.maximum(x, leak * x)

生成器結構如下面程式碼所示:其最後一層啟用函式使用tanh

def generator(self,name, noise, reuse):
        with tf.variable_scope(name,reuse=reuse):
            l = self.batch_size
            output = fc('g_dc', noise, 2*2*64)
            output = tf.reshape(output, [-1, 2, 2, 64])
            output = tf.nn.relu(self.bn('g_bn1',output))
            output = deconv2d('g_dcon1',output,5,outshape=[l, 4, 4, 64*4])
            output = tf.nn.relu(self.bn('g_bn2',output))

            output = deconv2d('g_dcon2', output, 5, outshape=[l, 8, 8, 64 * 2])
            output = tf.nn.relu(self.bn('g_bn3', output))

            output = deconv2d('g_dcon3', output, 5, outshape=[l, 16, 16,64 * 1])
            output = tf.nn.relu(self.bn('g_bn4', output))

            output = deconv2d('g_dcon4', output, 5, outshape=[l, 32, 32, self.dim])
            output = tf.image.resize_images(output, (28, 28))
            # output = tf.nn.relu(self.bn('g_bn4', output))
            return tf.nn.tanh(output)

網路結構是根據DCGAN的結構改的,所以網路簡要介紹到這裡。

2網路初始化

接下來介紹網路初始化方面:

首先在train.py裡建立一個Train的類,並做一些初始化

class Train(object):
    def __init__(self, sess, args):
        #sess=tf.Session()
        self.sess = sess
        self.img_size = 28   # the size of image
        self.trainable = True
        self.batch_size = 100  # must be even number
        self.lr = 0.0002
        self.mm = 0.5      # momentum term for adam
        self.z_dim = 128   # the dimension of noise z
        self.EPOCH = 50    # the number of max epoch
        self.LAMBDA = 0.1  # parameter of WGAN-GP
        self.model = args.model  # 'DCGAN' or 'WGAN'
        self.dim = 1       # RGB is different with gray pic
        self.num_class = 11
        self.load_model = args.load_model
        self.build_model()  # initializer

args是傳進來的引數,主要包括三個,一個是args.model,選擇DCGAN模式還是WGAN-GP模式,二者的不同主要在於損失函式不同和優化器的學習率不同,其他都一樣。第二個引數是args.trainable,訓練還是測試,訓練時為True,測試是False。Loadmodel表示是否選擇載入訓練好的權重。

import argparse
parser.add_argument('--model', type=str, default='DCGAN', help='DCGAN or WGAN-GP')
parser.add_argument('--trainable', type=bool, default=False,help='True for train and False for test')
parser.add_argument('--load_model', type=bool, default=True, help='True for load ckpt model and False for otherwise')
parser.add_argument('--label_num', type=int, default=2, help='the num of labled images we use, 2*100=200,batchsize:100')

3Build_model函式

Build_model函式裡面主要包括了網路訓練前的準備工作,主要包括損失函式的設計和優化器的設計。以下程式碼連在一起正好是build_model函式的全部內容,下文將詳細做出介紹,尤其是損失函式部分。

def build_model(self):
        # build  placeholders
        self.x = tf.placeholder(tf.float32, shape=[self.batch_size, self.img_size*self.img_size*self.dim], name='real_img')
        self.z = tf.placeholder(tf.float32, shape=[self.batch_size, self.z_dim], name='noise')
        self.label = tf.placeholder(tf.float32, shape=[self.batch_size, self.num_class-1], name='label')
        self.flag = tf.placeholder(tf.float32, shape=[], name='flag')
        self.flag2 = tf.placeholder(tf.float32, shape=[], name='flag2')
        # define the network
        self.G_img = self.generator('gen', self.z, reuse=False)
        ximg = tf.reshape(self.x, (self.batch_size, self.img_size, self.img_size, self.dim))
        d_in = tf.concat([ximg, self.G_img], axis=0)

        self.D_logits_, self.D_out_ = self.discriminator('dis', d_in, reuse=False)

        self.D_logits, self.D_logits_f = tf.split(self.D_logits_, [self.batch_size, self.batch_size], axis=0)

        d_regular = tf.add_n(tf.get_collection('regularizer', 'dis'), 'loss')
       #caculate the supervised loss
        batch_gl = tf.zeros_like(self.label, dtype=tf.float32)
        batchl_ = tf.concat([self.label, tf.zeros([self.batch_size, 1])], axis=1)
        batch_gl = tf.concat([batch_gl, tf.ones([self.batch_size, 1])], axis=1)
        batchl = tf.concat([batchl_, batch_gl], axis=0)*0.9  # one side label smoothing
         s_l = tf.losses.softmax_cross_entropy(onehot_labels=batchl, logits=self.D_logits_, label_smoothing=None)
        s_logits_ = tf.nn.softmax(self.D_logits_)
        un_s = tf.reduce_sum(s_logits_[:self.batch_size, -1])/(tf.reduce_sum(s_logits_[:self.batch_size,:])) 
                + tf.reduce_sum(s_logits_[self.batch_size:,:-1])/tf.reduce_sum(s_logits_[self.batch_size:,:])
        f_match = tf.constant(0., dtype=tf.float32)
        for i in range(4):
            d_layer, d_glayer = tf.split(self.D_out_[i], [self.batch_size, self.batch_size], axis=0)
            f_match += tf.reduce_mean(tf.multiply(tf.subtract(d_layer, d_glayer),tf.subtract(d_layer, d_glayer)))
        self.d_loss_real = -tf.log(tf.reduce_sum(s_logits_[:self.batch_size, :-1])/tf.reduce_sum(s_logits_[:self.batch_size, :]))
            self.d_loss_fake = -tf.log(tf.reduce_sum(s_logits_[self.batch_size:, -1])/tf.reduce_sum(s_logits_[self.batch_size:, :]))
            self.g_loss = self.d_loss_fake + f_match*0.01*self.flag2
            self.d_l_1, self.d_l_2, self.d_l_3 = self.d_loss_fake + self.d_loss_real, self.flag*s_l, (1-self.flag)*un_s
            self.d_loss = self.d_l_1 + self.d_l_2 + self.d_l_3

首先,建立了五個placeholder,flag表示兩個標誌位,只有0-1兩種情況,注意到我num_class是11,也就是做11分類,但是lable的placeholder中shape是(batchsize,10),因為傳進去訓練之前會將label擴充套件到[batchsize, 11]。為了方便,我將生成器的生成結果和真實資料X級聯在一起作為判別器的輸入,輸出再把他它們結果split分開。

d_regular 表示正則化,這裡我將判別器中所有的weights做了l2正則。

監督學習的損失函式使用常見的交叉熵損失函式,對生成器生成的影象的label的one_hot型為:

[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]

將原始的label擴充套件到(batchsize,11)後再和生成器生成的假資料的label再第一維度concat到一起得到batchl,另外乘以0.9,做單邊標籤平滑(one side smoothing),由此計算得到監督學習的損失函式值s_l,。

生成器G的損失函式

生成器G的損失函式包括兩部分,一個是來自GAN訓練的部分,另外一個是feature matching , 論文中提到的feature matching意思是特徵匹配,主要思想是希望生成器生成的假資料輸入到判別器,經過判別器每一層計算的結果和將真實資料X輸入到判別器,判別器每一層的結果儘可能的相似,公式如下:

其中f(x)是D的每一層的輸出。Feature matching 是指導G進行訓練,所以我將他放在了G的損失函式裡。

分類器D的損失函式

相比較G的損失函式,D的損失函式就比較麻煩了。

接下來介紹無監督學習的損失函式實現:

在前面介紹的無監督學習的損失函式中,有一部分和GAN的損失函式很相似,所以再程式碼中我們使用了無監督學習的時候沒有標籤的指導,此時判別器或者稱為分類器D無法正確對輸入進行分類,此時只要求D能夠區分真假就可以了,由此我們得到了無監督學習的損失un_s,直觀上也很好理解,假設輸入給判別器D真影象,它結果經過soft max後輸出類似下面表格的形式

其中前十個黃色區域表示對0-9的分類概率,最後一個灰色的表示對假影象的分類概率,由於無監督學習中判別器D並不知道具體是哪一類資料,所以乾脆D的損失函式最小化輸出假影象的概率就可以了,當輸入為生成器生成的假影象時,只要最小化D輸出為真影象的概率,由此我們得到了un_s.。但是此時有一個問題,即是有監督學習的時候不就沒有用了嗎,因為這時候應該使用s_l.為了解決這個問題,我使用了一個標誌位flag作為控制他們之間的使用,具體程式碼:

flag*s_l + ( 1 – flag)*un_s

有標籤的時候flag是1,表示使用s_l,無監督的時候flag是0,表示使用無監督損失函式。此時已經完成了判別器D損失函式的一部分設計,剩下的一部分和GAN中的D的損失一樣,在程式碼中我給出了兩種損失函式,一個是原始GAN的交叉熵損失函式,和DCGAN使用的一樣,另外一個是improved wgan論文中使用的損失函式,但是在做了對比之後,我強烈建議使用DCGAN來做,improved wgan的損失函式雖然在生成結果的優化上有很大幫助,但是並不適合半監督學習中。

訓練部分

接下來就是訓練部分:

此時可能有一個疑問,我們是如何實現只使用200帶標籤的資料訓練的,答案就在flag這個標誌位裡,在訓練部分程式碼中,當迭代次數小於2的時候,flag=1, 此時表示使用s_l作為損失函式的一部分,當flag=0的時候,un_s起作用而s_l並沒有起作用,這時,即使我們feed了正確的標籤資料,但是s_l不起作用,就相當於沒有使用標籤。

for idx in range(iters):
       start_t = time.time()
       flag = 1 if idx<args.label_num else 0 # set we use 500 train data with label.

flag2的作用本來是使用他控制feature matching是否工作的,這裡暫時設定為1。

(訓練部分詳細程式碼請移步文章下面github連結檢視)

測試

def test(self):
        count = 0.
        print 'testing................'
        for i in range(10000//self.batch_size):
            testx, textl = mnist.test.next_batch(self.batch_size)
            prediction = self.sess.run(self.prediction, feed_dict={self.x:testx, self.label:textl})
            count += np.sum(prediction)
        return count/10000.

測試部分程式碼如上圖所示,沒訓練完成一個epoch,就測試依次,測試的時候,使用了一個temp儲存測試的最大精度,當測試結果比前幾次都要好是,temp會更新到最好的測試精度,並儲存模型,否則不儲存模型,這樣做的好處在於我儲存的模型測試精度一定是最好的。

測試精度結果變化圖

本文實現及程式碼

使用GAN實現半監督學習程式碼https://github.com/LDOUBLEV/semi-supervised-GAN

如果感覺有用的話,歡迎star,fork。

參考文獻

https://arxiv.org/abs/1606.03498