1. 程式人生 > 實用技巧 >fine-tuning of VGG

fine-tuning of VGG

一、 fine-tuning

由於資料集的限制,我們可以使用預訓練的模型,來重新fine-tuning(微調)。

使用卷積網路作為特徵提取器,凍結卷積操作層,這是因為卷積層提取的特徵對於許多工都有用處,使用新的資料集訓練新定義的全連線層。

何時以及如何Fine-tune

決定如何使用遷移學習的因素有很多,這是最重要的只有兩個:新資料集的大小、以及新資料和原資料集的相似程度。有一點一定記住:網路前幾層學到的是通用特徵,後面幾層學到的是與類別相關的特徵。這裡有使用的四個場景:

1、新資料集比較小且和原資料集相似。因為新資料集比較小,如果fine-tune可能會過擬合;又因為新舊資料集類似,我們期望他們高層特徵類似,可以使用預訓練網路當做特徵提取器,用提取的特徵訓練線性分類器。

2、新資料集大且和原資料集相似。因為新資料集足夠大,可以fine-tune整個網路。

3、新資料集小且和原資料集不相似。新資料集小,最好不要fine-tune,和原資料集不類似,最好也不使用高層特徵。這時可是使用前面層的特徵來訓練SVM分類器。

4、新資料集大且和原資料集不相似。因為新資料集足夠大,可以重新訓練。但是實踐中fine-tune預訓練模型還是有益的。新資料集足夠大,可以fine-tine整個網路。

我們這次的作業屬於資料集與原來相似而且資料集很小的情況,可以使用預訓練網路當做特徵提取器,用提取的特徵訓練線性分類器。

二、程式碼重點

資料處理

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])

vgg_format = transforms.Compose([
                transforms.CenterCrop(224), 
    			#.中心裁剪:transforms.CenterCrop
				#class torchvision.transforms.CenterCrop(size)
				#功能:依據給定的size從中心裁剪
				#引數:
				#size- (sequence or int),若為sequence,則為(h,w),若為int,則(size,size)
                transforms.ToTensor(),
                normalize,
            ])

data_dir = './dogscats'

dsets = {x: datasets.ImageFolder(os.path.join(data_dir, x), vgg_format)
         for x in ['train', 'valid']}

dset_sizes = {x: len(dsets[x]) for x in ['train', 'valid']}
dset_classes = dsets['train'].classes

修改全連線層,凍結卷積層的引數

for param in model_vgg_new.parameters():
    param.requires_grad = False #訓練時不更改引數
model_vgg_new.classifier._modules['6'] = nn.Linear(4096, 2) #全連線輸出兩類貓或者狗
model_vgg_new.classifier._modules['7'] = torch.nn.LogSoftmax(dim = 1) # 資料處理

建立損失函式和優化器,訓練模型

criterion = nn.NLLLoss() #設定損失函式

lr = 0.001 # 學習率

optimizer_vgg = torch.optim.SGD(model_vgg_new.classifier[6].parameters(),lr = lr) # 隨機梯度下降

#訓練模型 (模板 建議直接背誦)
def train_model(model,dataloader,size,epochs=1,optimizer=None):
    model.train()
    
    for epoch in range(epochs):
        running_loss = 0.0
        running_corrects = 0
        count = 0
        for inputs,classes in dataloader:
            inputs = inputs.to(device)
            classes = classes.to(device)
            outputs = model(inputs)
            loss = criterion(outputs,classes)           
            optimizer = optimizer
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
            _,preds = torch.max(outputs.data,1)
            # statistics
            running_loss += loss.data.item()
            running_corrects += torch.sum(preds == classes.data)
            count += len(inputs)
            print('Training: No. ', count, ' process ... total: ', size)
        epoch_loss = running_loss / size
        epoch_acc = running_corrects.data.item() / size
        print('Loss: {:.4f} Acc: {:.4f}'.format(
                     epoch_loss, epoch_acc))

三、程式碼優化

1.資料處理

vgg_format = transforms.Compose([
                #transforms.CenterCrop(224), 
    			transforms.Resize((224,224)),
    #這裡選擇縮放而不是中心裁剪,因為簡單地選擇重心裁剪會讓影象的一些特徵直接丟失,嚴重的情況下直接無法捕捉到物體(cat or dog),這樣的情況下卷積也沒有什麼作用了
                transforms.ToTensor(),
                normalize,
            ])

下面的圖片就可以看出使用縮放而不是中心裁剪的原因:原本圖片的