Transfer Learning(遷移學習)實戰---寶可夢影象分類
阿新 • • 發佈:2021-02-09
專案介紹
資料集:總共有5類寶可夢,共1165張圖片
模型搭建框架:PyTorch.
1.基於torch.utils.data 構建自定義資料讀取器,具體細節見程式碼。
2.通過比較用torchvision 提供的pretrained的resnet18網路和自定義的殘差網路訓練在測試集的結果,來檢視遷移學習對於當前任務的效果。
具體程式碼
首先自定義資料讀取器
import torch
import os,glob
import random,csv
import visdom
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import time
class Pokemon(Dataset):
    """Image-classification dataset read from a folder-per-class directory.

    Every sub-directory of ``root`` is treated as one class; class names are
    sorted and mapped to consecutive integer labels in ``name2label``.  The
    list of (image path, label) pairs is persisted in ``root/images.csv`` the
    first time the dataset is built, so that the shuffled order -- and thus
    the train/validation/test split -- is identical for every later instance.

    Args:
        root: Directory that contains one sub-directory per class.
        resize: Side length (pixels) of the square image returned by
            ``__getitem__``.
        mode: ``'train'`` selects the first 60% of the shuffled list,
            ``'validation'`` the next 20%; any other value selects the
            remaining 20% (the test split).
    """

    def __init__(self, root, resize, mode):
        super(Pokemon, self).__init__()
        self.root = root
        self.resize = resize
        self.mode = mode

        # Sorted directory names -> labels 0..N-1; plain files are ignored.
        self.name2label = {}
        for name in sorted(os.listdir(root)):
            if not os.path.isdir(os.path.join(root, name)):
                continue
            self.name2label[name] = len(self.name2label)

        self.images, self.labels = self.load_csv('images.csv')

        # 60/20/20 split over the order stored in the csv file.
        total = len(self.images)
        train_end = int(0.6 * total)
        val_end = int(0.8 * total)
        if mode == 'train':
            part = slice(0, train_end)
        elif mode == 'validation':
            part = slice(train_end, val_end)
        else:
            part = slice(val_end, total)
        self.images = self.images[part]
        self.labels = self.labels[part]

    def load_csv(self, filename):
        """Return ``(images, labels)`` read from ``root/filename``.

        If the csv file does not exist yet it is created first: all
        ``*.png``/``*.jpg``/``*.jpeg`` files of every class directory are
        collected, shuffled, and written as ``path,label`` rows.
        """
        csv_path = os.path.join(self.root, filename)

        if not os.path.exists(csv_path):
            images = []
            for name in self.name2label:
                for pattern in ('*.png', '*.jpg', '*.jpeg'):
                    images += glob.glob(os.path.join(self.root, name, pattern))
            print(len(images))
            random.shuffle(images)
            with open(csv_path, mode='w', newline='') as f:
                writer = csv.writer(f)
                for img in images:
                    # The parent directory name is the class name.
                    name = os.path.basename(os.path.dirname(img))
                    writer.writerow([img, self.name2label[name]])
            print('written into csv file:', filename)

        images, labels = [], []
        with open(csv_path, newline='') as f:
            for row in csv.reader(f):
                if not row:
                    # Tolerate blank lines (e.g. after a manual edit).
                    continue
                img, label = row
                images.append(img)
                labels.append(int(label))
        return images, labels

    def denormalize(self, x_hat):
        """Undo the ImageNet normalisation applied in ``__getitem__``.

        ``x_hat`` is expected to have the channel axis third from last
        (``[c, h, w]`` or ``[b, c, h, w]``); the per-channel statistics are
        reshaped to ``[3, 1, 1]`` so they broadcast over height and width.
        """
        mean = [0.485, 0.456, 0.406]
        std = [0.229, 0.224, 0.225]
        # Create the statistics on the input's device so that tensors that
        # were moved off the CPU can be denormalised as well.
        mean = torch.tensor(mean, device=x_hat.device).view(-1, 1, 1)
        std = torch.tensor(std, device=x_hat.device).view(-1, 1, 1)
        return x_hat * std + mean

    def __len__(self):
        return len(self.images)

    def _build_transform(self):
        """Compose the preprocessing pipeline for the current split."""
        enlarged = int(self.resize * 1.25)
        steps = [transforms.Resize((enlarged, enlarged))]  # unify the size
        if self.mode == 'train':
            # Augmentation is for training only; applying a random rotation
            # to validation/test images would make the reported accuracy
            # vary from run to run.  Large angles can keep the network from
            # converging, hence the small range.
            steps.append(transforms.RandomRotation(15))
        steps += [
            transforms.CenterCrop(self.resize),  # crop back to target size
            transforms.ToTensor(),
            # Follow the ImageNet data distribution.
            transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225]),
        ]
        return transforms.Compose(steps)

    def __getitem__(self, idx):
        """Load sample ``idx`` and return ``(image_tensor, label_tensor)``."""
        path, label = self.images[idx], self.labels[idx]
        # Close the file handle as soon as the pixel data has been converted.
        with Image.open(path) as handle:
            img = handle.convert('RGB')
        img = self._build_transform()(img)
        label = torch.tensor(label)
        return img, label
然後通過繼承nn.Module來自定義殘差網路
import torch
from torch import nn
from torch.nn import functional as F
class ResBlock(nn.Module):
    """Basic residual block: two 3x3 conv + batch-norm layers plus a shortcut.

    Args:
        ch_in: Number of input channels.
        ch_out: Number of output channels.
        stride: Stride of the first convolution (and of the projection
            shortcut); values above 1 down-sample the feature map.
    """

    def __init__(self, ch_in, ch_out, stride=1):
        super(ResBlock, self).__init__()
        self.conv1 = nn.Conv2d(ch_in, ch_out, kernel_size=3, stride=stride, padding=1)
        self.bn1 = nn.BatchNorm2d(ch_out)
        # kernel_size=3 with padding=1 and stride=1: no size change here.
        self.conv2 = nn.Conv2d(ch_out, ch_out, kernel_size=3, stride=1, padding=1)
        self.bn2 = nn.BatchNorm2d(ch_out)

        # The residual addition needs the shortcut and the main branch to
        # have the same shape.  A 1x1 projection is therefore required when
        # the channel count changes OR when the block down-samples; otherwise
        # the (empty) Sequential acts as the identity.
        self.extra = nn.Sequential()
        if stride != 1 or ch_out != ch_in:
            self.extra = nn.Sequential(
                nn.Conv2d(ch_in, ch_out, kernel_size=1, stride=stride),
                nn.BatchNorm2d(ch_out),
            )

    def forward(self, x):
        """Map ``[batch, ch_in, h, w]`` to ``[batch, ch_out, h', w']``."""
        out = F.relu(self.bn1(self.conv1(x)))
        out = self.bn2(self.conv2(out))
        # Residual connection followed by the final activation.
        out = self.extra(x) + out
        return F.relu(out)
class ResNet18(nn.Module):
    """Small custom residual network: a stem conv, four ResBlocks, a linear head.

    Args:
        num_class: Number of output classes (size of the logits vector).
    """

    def __init__(self, num_class):
        super(ResNet18, self).__init__()
        self.conv1 = nn.Sequential(
            nn.Conv2d(3, 16, kernel_size=3, stride=3, padding=0),
            nn.BatchNorm2d(16),
        )
        self.blk1 = ResBlock(16, 32, stride=3)
        self.blk2 = ResBlock(32, 64, stride=3)
        self.blk3 = ResBlock(64, 128, stride=2)
        self.blk4 = ResBlock(128, 256, stride=2)
        # The head expects a 256-channel 3x3 feature map.
        self.outlayer = nn.Linear(256 * 3 * 3, num_class)

    def forward(self, x):
        """Return class logits of shape ``[batch, num_class]``."""
        x = F.relu(self.conv1(x))
        x = self.blk1(x)
        x = self.blk2(x)
        x = self.blk3(x)
        x = self.blk4(x)
        # A 224x224 input already yields a 3x3 map here, for which this
        # pooling is an identity; for other input sizes it brings the map to
        # the 3x3 size the linear head was built for instead of failing with
        # a shape mismatch.  The op has no parameters, so saved state dicts
        # stay loadable.
        x = F.adaptive_avg_pool2d(x, (3, 3))
        # Flatten everything except the batch axis.
        x = x.flatten(1)
        return self.outlayer(x)
訓練時,通過在main()中設定transfer_learning 的引數值來分別訓練自定義殘差網路和基於預訓練的resnet18。
import torch
from torch import optim,nn
import visdom
import torchvision
from torch.utils.data import DataLoader
from pokemon import Pokemon
from resnet import ResNet18
from torchvision.models import resnet18
from utils import Flatten
# Hyper-parameters
batch_size = 32
lr = 1e-3
epochs = 10
device = torch.device('cpu')
torch.manual_seed(1)

# Data: the three splits come from the same directory and image size.
train_set = Pokemon('./pokemon', 224, mode='train')
validation_set = Pokemon('./pokemon', 224, mode='validation')
test_set = Pokemon('./pokemon', 224, mode='test')

# Only the training loader shuffles.  All loaders share `batch_size` so that
# changing the constant above affects every loader, not just training.
train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(validation_set, batch_size=batch_size)
test_loader = DataLoader(test_set, batch_size=batch_size)
def evalute(model, loader):
    """Return the classification accuracy of ``model`` over ``loader``.

    Args:
        model: Module mapping an image batch to per-class scores.
        loader: DataLoader yielding ``(images, labels)`` batches.

    Returns:
        Fraction of correctly classified samples in ``loader.dataset`` as a
        float; ``0.0`` when the dataset is empty.
    """
    size = len(loader.dataset)
    if size == 0:
        # Nothing to evaluate; avoid dividing by zero below.
        return 0.0

    # Evaluate in eval mode: otherwise BatchNorm normalises with the
    # statistics of each evaluation batch and keeps updating its running
    # averages with validation/test data.  The previous mode is restored
    # afterwards so the caller can continue training unchanged.
    was_training = model.training
    model.eval()
    correct = 0
    try:
        # No gradients are needed: these results are only used to inspect
        # the model, never to optimise its parameters.
        with torch.no_grad():
            for x, y in loader:
                x, y = x.to(device), y.to(device)
                pred = model(x).argmax(dim=1)
                # item() turns the 0-dim count tensor into a Python number.
                correct += torch.eq(pred, y).sum().item()
    finally:
        model.train(was_training)
    return correct / size
viz=visdom.Visdom() # visualization client, created at import time; main() plots the 'loss' and 'val_acc' curves through it
def main(transfer_learning=True):
    """Train a classifier, keep the best checkpoint, and report test accuracy.

    Args:
        transfer_learning: If true, fine-tune torchvision's pretrained
            resnet18 with a new linear head; otherwise train the custom
            ``ResNet18`` from scratch.
    """
    # One output per class directory found by the dataset.
    num_classes = len(train_set.name2label)

    if transfer_learning:
        # Start from the pretrained resnet18 and replace its last layer.
        # NOTE(review): `pretrained=True` is the older torchvision API; newer
        # releases prefer the `weights=` argument -- confirm against the
        # installed version before changing it.
        trained_model = resnet18(pretrained=True)
        model = nn.Sequential(
            *list(trained_model.children())[:-1],  # output size: [b, 512, 1, 1]
            Flatten(),
            nn.Linear(512, num_classes),
        ).to(device)
        print('Transfer Learning Model Loaded!')
    else:
        model = ResNet18(num_classes).to(device)

    optimizer = optim.Adam(model.parameters(), lr=lr)
    metric = nn.CrossEntropyLoss()

    # Start below any reachable accuracy so the first validation pass always
    # writes a checkpoint; with 0 a run whose accuracy never rose above 0
    # would later load a missing (or stale) checkpoint file.
    best_acc, best_epoch = -1.0, 0
    checkpoint_saved = False
    global_step = 0
    viz.line([0], [-1], win='loss', opts=dict(title='loss'))
    viz.line([0], [-1], win='val_acc', opts=dict(title='val_acc'))

    for epoch in range(epochs):
        # Make sure training-time behaviour (e.g. BatchNorm batch statistics)
        # is active regardless of what happened during evaluation.
        model.train()
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            output = model(x)
            loss = metric(output, y)
            optimizer.zero_grad()  # clear the accumulated gradients
            loss.backward()        # compute this step's gradients
            optimizer.step()       # update the parameters with them
            viz.line([loss.item()], [global_step], win='loss', update='append')
            global_step += 1

        # Validate after every epoch and keep the best parameters on disk.
        val_acc = evalute(model, val_loader)
        if val_acc > best_acc:
            best_epoch = epoch
            best_acc = val_acc
            torch.save(model.state_dict(), 'bestmodel.para')  # any suffix works
            checkpoint_saved = True
        viz.line([val_acc], [global_step], win='val_acc', update='append')

    print('best acc:', best_acc, 'best epoch:', best_epoch)
    if checkpoint_saved:
        # map_location keeps loading independent of the device the
        # checkpoint was written from.
        model.load_state_dict(torch.load('bestmodel.para', map_location=device))
        print('loaded from ckpt!')
    test_acc = evalute(model, test_loader)
    print('test acc:', test_acc)


if __name__ == '__main__':
    main(transfer_learning=False)
較舊版本的PyTorch沒有內建展平層(自1.2版起已提供nn.Flatten),因此這裡自己定義Flatten模組。
import torch
from torch import nn
class Flatten(nn.Module):
    """Flatten every axis after the first: ``[b, d1, d2, ...] -> [b, d1*d2*...]``.

    Kept for code that imports it from here; recent PyTorch releases ship an
    equivalent ``torch.nn.Flatten``.
    """

    def __init__(self):
        super(Flatten, self).__init__()

    def forward(self, x):
        # Compute the per-sample feature count with plain integers instead of
        # a tensor round-trip, and give both target sizes explicitly so that
        # an empty batch (where -1 could not be inferred) still works.
        features = 1
        for dim in x.shape[1:]:
            features *= dim
        # reshape (unlike view) also accepts non-contiguous inputs, e.g. the
        # result of permute/transpose; it is still a view when possible.
        return x.reshape(x.size(0), features)
最後是訓練結果
首先是自定義的網路的訓練結果,可以看到在訓練集上的表現為0.9的準確率,有點意料之外?
而且仔細觀察還能發現,在測試集上的表現居然還比驗證集上的表現好2個百分點,意料之外的沒有overfitting。
然後是基於pre-trained的resnet18的模型效果
意料之中的有所提高,但是提升效果沒有想象中那麼高,在測試集上的準確率為0.93。遷移學習在這個task上的作用並不明顯,當然部分原因是自定義網路的效果有點超出預期。。。