fine-tuning of VGG
阿新 • • 發佈:2020-11-03
一、 fine-tuning
由於資料集的限制,我們可以使用預訓練的模型,來重新fine-tuning(微調)。
使用卷積網路作為特徵提取器,凍結卷積操作層,這是因為卷積層提取的特徵對於許多工都有用處,使用新的資料集訓練新定義的全連線層。
何時以及如何Fine-tune
決定如何使用遷移學習的因素有很多,這是最重要的只有兩個:新資料集的大小、以及新資料和原資料集的相似程度。有一點一定記住:網路前幾層學到的是通用特徵,後面幾層學到的是與類別相關的特徵。這裡有使用的四個場景:
1、新資料集比較小且和原資料集相似。因為新資料集比較小,如果fine-tune可能會過擬合;又因為新舊資料集類似,我們期望他們高層特徵類似,可以使用預訓練網路當做特徵提取器,用提取的特徵訓練線性分類器。
2、新資料集大且和原資料集相似。因為新資料集足夠大,可以fine-tune整個網路。
3、新資料集小且和原資料集不相似。新資料集小,最好不要fine-tune,和原資料集不類似,最好也不使用高層特徵。這時可是使用前面層的特徵來訓練SVM分類器。
4、新資料集大且和原資料集不相似。因為新資料集足夠大,可以重新訓練。但是實踐中fine-tune預訓練模型還是有益的。新資料集足夠大,可以fine-tine整個網路。
我們這次的作業屬於資料集與原來相似而且資料集很小的情況,可以使用預訓練網路當做特徵提取器,用提取的特徵訓練線性分類器。
二、程式碼重點
資料處理
normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) vgg_format = transforms.Compose([ transforms.CenterCrop(224), #.中心裁剪:transforms.CenterCrop #class torchvision.transforms.CenterCrop(size) #功能:依據給定的size從中心裁剪 #引數: #size- (sequence or int),若為sequence,則為(h,w),若為int,則(size,size) transforms.ToTensor(), normalize, ]) data_dir = './dogscats' dsets = {x: datasets.ImageFolder(os.path.join(data_dir, x), vgg_format) for x in ['train', 'valid']} dset_sizes = {x: len(dsets[x]) for x in ['train', 'valid']} dset_classes = dsets['train'].classes
修改全連線層,凍結卷積層的引數
for param in model_vgg_new.parameters():
param.requires_grad = False #訓練時不更改引數
model_vgg_new.classifier._modules['6'] = nn.Linear(4096, 2) #全連線輸出兩類貓或者狗
model_vgg_new.classifier._modules['7'] = torch.nn.LogSoftmax(dim = 1) # 資料處理
建立損失函式和優化器,訓練模型
criterion = nn.NLLLoss() #設定損失函式 lr = 0.001 # 學習率 optimizer_vgg = torch.optim.SGD(model_vgg_new.classifier[6].parameters(),lr = lr) # 隨機梯度下降 #訓練模型 (模板 建議直接背誦) def train_model(model,dataloader,size,epochs=1,optimizer=None): model.train() for epoch in range(epochs): running_loss = 0.0 running_corrects = 0 count = 0 for inputs,classes in dataloader: inputs = inputs.to(device) classes = classes.to(device) outputs = model(inputs) loss = criterion(outputs,classes) optimizer = optimizer optimizer.zero_grad() loss.backward() optimizer.step() _,preds = torch.max(outputs.data,1) # statistics running_loss += loss.data.item() running_corrects += torch.sum(preds == classes.data) count += len(inputs) print('Training: No. ', count, ' process ... total: ', size) epoch_loss = running_loss / size epoch_acc = running_corrects.data.item() / size print('Loss: {:.4f} Acc: {:.4f}'.format( epoch_loss, epoch_acc))
三、程式碼優化
1.資料處理
vgg_format = transforms.Compose([
#transforms.CenterCrop(224),
transforms.Resize((224,224)),
#這裡選擇縮放而不是中心裁剪,因為簡單地選擇重心裁剪會讓影象的一些特徵直接丟失,嚴重的情況下直接無法捕捉到物體(cat or dog),這樣的情況下卷積也沒有什麼作用了
transforms.ToTensor(),
normalize,
])
下面的圖片就可以看出使用縮放而不是中心裁剪的原因:原本圖片的