行人重識別(12)——程式碼實踐之模型訓練(train_model(表徵學習).py)
!轉載請註明原文地址!——東方旅行者
更多行人重識別文章移步我的專欄:行人重識別專欄
本文目錄
表徵學習模型訓練(train_model(表徵學習).py)
一、train_model(表徵學習).py作用
本檔案是行人重識別系統的核心檔案,用於模型訓練。
二、train_model(表徵學習).py編寫思路
為了便於管理一些常用引數,將其形成變數單獨指定,這些引數有:
1. 輸入圖片寬度width 2. 輸入圖片高度height 3. 訓練批次大小train_batch_size 4. 測試批次大小test_batch_size 5. 學習率train_lr 6. 開始訓練的批次start_epoch 7. 結束訓練的批次end_epoch 8. 動態學習率變化步長dy_step_size 9. 動態學習率變化倍數dy_step_gamma 10. 是否測試evaluate 11. 模型儲存的地址best_model_path、final_model_path 12. 最大準確率(用於對判斷最優模型)max_acc
1.train()方法
訓練函式train()需要的引數有epoch(當前批次)、model(使用的模型)、criterion_class(損失函式)、optimizer(優化器型別,用於反向傳播優化網路引數)、scheduler(用於管理學習率)、data_loader(指定資料載入器,獲得網路的輸入資料)。
首先使用enumerate迭代資料載入器吐出資料,每一次突出的資料是當前批數、(圖片,行人ID,攝像機ID),在獲取輸入資料後,對優化器進行清零,防止上次計算結果對這次計算產生影響,使用模型進行運算,得到運算結果(表徵學習階段計算結果就是分類向量),使用損失函式根據計算結果與真實結果計算損失,然後使用損失的backward()方法對損失進行反向傳播優化計算,然後使用scheduler的step()方法更新學習率,在損失的backward()計算完畢後,一定不要忘記使用optimizer的step()方法更新引數,否則網路引數不會修改
然後使用torch.argmax按行求最大值,即計算分類結果,然後根據分類結果與真實結果計算本次訓練的準確率。若本次準確率高於最高準確率,則儲存此時模型至最優模型地址並列印相關資訊。
2.main()方法
主函式main()是訓練過程呼叫的函式,首先需要使用資料管理器載入資料集,便於之後資料載入器獲取索引列表。然後進行訓練資料處理器、測試資料處理器的宣告。
訓練資料處理器需要使用自定義的隨機裁剪、水平翻轉、將圖片轉為張量、歸一化。
測試資料處理器需要使用圖片尺度調整(統一尺寸)、將圖片轉為張量、歸一化。
然後宣告訓練集資料吞吐器、測試集資料吞吐器與查詢集資料吞吐器。在生成吞吐器需要指定使用的Dataset型別的資料集(自定義的ImageDataset,在dataset_loader.py中,同時在這裡還需要指定使用的資料處理器),一個批次的大小,是否重排序,是否丟棄最後無法組成一整個批次的資料。
接下來載入模型,設定損失函式、優化器(torch.optim.SGD需要指定待優化的引數,學習率,權重衰減)、動態學習率(lr_scheduler.ReduceLROnPlateau需要指定使用的優化器;mode可選擇‘min’或者‘max’,min表示當監控量停止下降的時候,學習率將減小,max表示當監控量停止上升的時候,學習率將減小。預設值為‘min’;factor學習率每次降低多少;patience容忍網路的效能不提升的次數,高於這個次數就降低學習率;min_lr學習率的下限)等。
設定模型為訓練模式,呼叫train函式進行模型訓練。最後將模型引數儲存到目標地址。
三、程式碼
import os,sys,time,datetime
import os.path as osp
import numpy as np
from IPython import embed
import torch
import torch.nn as nn
import transform as T
from torch.utils.data import DataLoader
from torch.optim import lr_scheduler
from model.ReIDNet import ReIDNet
from dataset_manager import Market1501#資料管理器
from dataset_loader import ImageDataset#資料載入器
"""
本檔案是行人重識別系統的核心檔案,用於表徵學習模型訓練。
"""
#設定輸入引數
width=64 #圖片寬度
height=128 #圖片高度
train_batch_size=32 #訓練批量
test_batch_size=32 #測試批量
train_lr=0.01 #學習率
start_epoch=0 #開始訓練的批次
end_epoch=1 #結束訓練的批次
dy_step_size=800 #動態學習率變化步長
dy_step_gamma=0.9 #動態學習率變化倍數
evaluate=False #是否測試
max_acc=-1 #最大準確率
best_model_path='./model/param/net_params_best.pth' #最優模型儲存地址
final_model_path='./model/param/net_params_final.pth' #最終模型儲存地址
def main():
#資料集載入
dataset=Market1501()
#訓練資料處理器
transform_train=T.Compose([
T.Random2DTransform(height,width),#尺度統一,隨機裁剪
T.RandomHorizontalFlip(),#水平翻轉
T.ToTensor(),#圖片轉張量
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),#歸一化,引數固定
]
)
#測試資料處理器
transform_test=T.Compose([
T.Resize((height,width)),#尺度統一
T.ToTensor(),#圖片轉張量
T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),#歸一化,引數固定
]
)
#train資料集吞吐器
train_data_loader=DataLoader(
ImageDataset(dataset.train, transform=transform_train),#自定義的資料集,使用訓練資料處理器
batch_size=train_batch_size,#一個批次的大小(一個批次有多少個圖片張量)
drop_last=True,#丟棄最後無法稱為一整個批次的資料
)
print("train_data_loader inited")
#query資料集吞吐器
query_data_loader=DataLoader(
ImageDataset(dataset.query, transform=transform_test),#自定義的資料集,使用測試資料處理器
batch_size=test_batch_size,#一個批次的大小(一個批次有多少個圖片張量)
shuffle=False,#不重排
drop_last=True,#丟棄最後無法稱為一整個批次的資料
)
print("query_data_loader inited")
#gallery資料集吞吐器
gallery_data_loader=DataLoader(
ImageDataset(dataset.gallery, transform=transform_test),#自定義的資料集,使用測試資料處理器
batch_size=test_batch_size,#一個批次的大小(一個批次有多少個圖片張量)
shuffle=False,#不重排
drop_last=True,#丟棄最後無法稱為一整個批次的資料
)
print("gallery_data_loader inited\n")
#載入模型
model=ReIDNet(num_classes=751,loss={'softmax'})#指定分類的數量,與使用的損失函式以便決定模型輸出何種計算結果
print("=>ReIDNet loaded")
print("Model size: {:.5f}M\n".format(sum(p.numel() for p in model.parameters())/1000000.0))
#損失函式
criterion_class=nn.CrossEntropyLoss()
"""
優化器
引數1,待優化的引數
引數2,學習率
引數3,權重衰減
"""
optimizer=torch.optim.SGD(model.parameters(),lr=train_lr,weight_decay=5e-04)
"""
動態學習率
引數1,指定使用的優化器
引數2,mode,可選擇‘min’(min表示當監控量停止下降的時候,學習率將減小)或者‘max’(max表示當監控量停止上升的時候,學習率將減小)
引數3,factor,代表學習率每次降低多少
引數4,patience,容忍網路的效能不提升的次數,高於這個次數就降低學習率
引數5,min_lr,學習率的下限
"""
scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=dy_step_gamma, patience=10, min_lr=0.0001)
#如果是測試
if evaluate:
test(model,query_data_loader,gallery_data_loader)
return 0
#如果是訓練
print('————model start training————\n')
bt=time.time()#訓練的開始時間
for epoch in range(start_epoch,end_epoch):
model.train(True)
train(epoch,model,criterion_class,optimizer,scheduler,train_data_loader)
et=time.time()#訓練的結束時間
print('**模型訓練結束, 儲存最終引數到{}**\n'.format(final_model_path))
torch.save(model.state_dict(), final_model_path)
print('————訓練總用時{:.2f}小時————'.format((et-bt)/3600.0))
def train(epoch, model, criterion_class, optimizer, scheduler, data_loader):
"""
訓練函式train
引數1,epoch(當前批次)
引數2,model(使用的模型)
引數3,criterion_class(損失函式)
引數4,optimizer(優化器型別,用於反向傳播優化網路引數)
引數5,scheduler(用於管理學習率)
引數6,data_loader(指定資料載入器,獲得網路的輸入資料)
"""
global max_acc
for batch_idx, (imgs, pids, cids) in enumerate(data_loader):
optimizer.zero_grad()#優化器進行清零,防止上次計算結果對這次計算產生影響
outputs=model(imgs)
loss=criterion_class(outputs,pids)#使用損失函式根據計算結果與真實結果計算損失
loss.backward()#根據損失進行反向傳播優化計算
scheduler.step(loss)#更新學習率
optimizer.step()#更新網路中指定的需要優化的引數
pred = torch.argmax(outputs, 1)#按行求最大值,計算分類結果
current_acc=100*(pred == pids).sum().float()/len(pids)
if current_acc>max_acc:
max_acc=current_acc
print('**最高準確度更新為{}%,儲存此模型到{}**\n'.format(max_acc,best_model_path))
torch.save(model.state_dict(), best_model_path)
if batch_idx%10==0:
print('————————————————————————————————')
pred = torch.argmax(outputs, 1)
print('Epoch: {}, Batch: {}, Loss: {}'.format(epoch + 1, batch_idx, loss.data))
print('Current Accuracy: {:.2f}%'.format(100*(pred == pids).sum().float()/len(pids)))
print('————————————————————————————————\n')
return 0
main()
四、訓練列印日誌
=> Market1501 loaded
------------------------------------------------------------------------
subset: train | num_id: 751 | num_imgs: 12936
subset: query | num_id: 750 | num_imgs: 3368
subset: gallery | num_id: 751 | num_imgs: 19732
------------------------------------------------------------------------
total | num_id: 1501 | num_imgs: 16304
------------------------------------------------------------------------
train_data_loader inited
query_data_loader inited
gallery_data_loader inited
=>ReIDNet loaded
Model size: 4.53483M
————model start training————
**最高準確度更新為0.0%,儲存此模型到./model/param/net_params_best.pth**
————————————————————————————————
Epoch: 1, Batch: 0, Loss: 6.598439693450928
Current Accuracy: 0.00%
————————————————————————————————
**最高準確度更新為21.875%,儲存此模型到./model/param/net_params_best.pth**
————————————————————————————————
Epoch: 1, Batch: 100, Loss: 6.623084545135498
Current Accuracy: 0.00%
————————————————————————————————
******************************中間省略若干******************************
******************************中間省略若干******************************
————————————————————————————————
————————————————————————————————
epoch: 50, batch: 400, loss: 5.373966217041016, lr:0.0009847709021836117
saving model params
model params saved
————————————————————————————————
saving model params
model params saved
————訓練總用時5.08小時————