1. 程式人生 > 其它 >行人重識別(12)——程式碼實踐之模型訓練(train_model(表徵學習).py)

行人重識別(12)——程式碼實踐之模型訓練(train_model(表徵學習).py)

技術標籤:行人重識別演算法計算機視覺行人重識別

!轉載請註明原文地址!——東方旅行者

更多行人重識別文章移步我的專欄:行人重識別專欄

本文目錄

表徵學習模型訓練(train_model(表徵學習).py)

一、train_model(表徵學習).py作用

本檔案是行人重識別系統的核心檔案,用於模型訓練。

二、train_model(表徵學習).py編寫思路

為了便於管理一些常用引數,將其形成變數單獨指定,這些引數有:

1.	輸入圖片寬度width
2.	輸入圖片高度height
3.	訓練批次大小train_batch_size
4.	測試批次大小test_batch_size
5.	學習率train_lr
6.	開始訓練的批次start_epoch
7.	結束訓練的批次end_epoch
8.	動態學習率變化步長dy_step_size
9.	動態學習率變化倍數dy_step_gamma
10.	是否測試evaluate
11.	模型儲存的地址best_model_path、final_model_path
12.	最大準確率(用於對判斷最優模型)max_acc

1.train()方法

訓練函式train()需要的引數有epoch(當前批次)、model(使用的模型)、criterion_class(損失函式)、optimizer(優化器型別,用於反向傳播優化網路引數)、scheduler(用於管理學習率)、data_loader(指定資料載入器,獲得網路的輸入資料)。
首先使用enumerate迭代資料載入器吐出資料,每一次突出的資料是當前批數、(圖片,行人ID,攝像機ID),在獲取輸入資料後,對優化器進行清零,防止上次計算結果對這次計算產生影響,使用模型進行運算,得到運算結果(表徵學習階段計算結果就是分類向量),使用損失函式根據計算結果與真實結果計算損失,然後使用損失的backward()方法對損失進行反向傳播優化計算,然後使用scheduler的step()方法更新學習率,在損失的backward()計算完畢後,一定不要忘記使用optimizer的step()方法更新引數,否則網路引數不會修改


然後使用torch.argmax按行求最大值,即計算分類結果,然後根據分類結果與真實結果計算本次訓練的準確率。若本次準確率高於最高準確率,則儲存此時模型至最優模型地址並列印相關資訊。

2.main()方法

主函式main()是訓練過程呼叫的函式,首先需要使用資料管理器載入資料集,便於之後資料載入器獲取索引列表。然後進行訓練資料處理器、測試資料處理器的宣告。
訓練資料處理器需要使用自定義的隨機裁剪、水平翻轉、將圖片轉為張量、歸一化。
測試資料處理器需要使用圖片尺度調整(統一尺寸)、將圖片轉為張量、歸一化。
然後宣告訓練集資料吞吐器、測試集資料吞吐器與查詢集資料吞吐器。在生成吞吐器需要指定使用的Dataset型別的資料集(自定義的ImageDataset,在dataset_loader.py中,同時在這裡還需要指定使用的資料處理器),一個批次的大小,是否重排序,是否丟棄最後無法組成一整個批次的資料。
接下來載入模型,設定損失函式、優化器(torch.optim.SGD需要指定待優化的引數,學習率,權重衰減)、動態學習率(lr_scheduler.ReduceLROnPlateau需要指定使用的優化器;mode可選擇‘min’或者‘max’,min表示當監控量停止下降的時候,學習率將減小,max表示當監控量停止上升的時候,學習率將減小。預設值為‘min’;factor學習率每次降低多少;patience容忍網路的效能不提升的次數,高於這個次數就降低學習率;min_lr學習率的下限)等。
設定模型為訓練模式,呼叫train函式進行模型訓練。最後將模型引數儲存到目標地址。

三、程式碼

import os,sys,time,datetime
import os.path as osp
import numpy as np
from IPython import embed

import torch
import torch.nn as nn
import transform as T
from torch.utils.data import DataLoader
from torch.optim import lr_scheduler
from model.ReIDNet import ReIDNet
from dataset_manager import Market1501#資料管理器
from dataset_loader import ImageDataset#資料載入器

"""
本檔案是行人重識別系統的核心檔案,用於表徵學習模型訓練。
"""
#設定輸入引數
width=64			#圖片寬度
height=128			#圖片高度
train_batch_size=32	#訓練批量
test_batch_size=32	#測試批量
train_lr=0.01		#學習率
start_epoch=0		#開始訓練的批次
end_epoch=1			#結束訓練的批次
dy_step_size=800	#動態學習率變化步長
dy_step_gamma=0.9	#動態學習率變化倍數
evaluate=False		#是否測試
max_acc=-1			#最大準確率
best_model_path='./model/param/net_params_best.pth'		#最優模型儲存地址
final_model_path='./model/param/net_params_final.pth'	#最終模型儲存地址

def main():
    #資料集載入
    dataset=Market1501()
    
    #訓練資料處理器
    transform_train=T.Compose([
        T.Random2DTransform(height,width),#尺度統一,隨機裁剪
        T.RandomHorizontalFlip(),#水平翻轉
        T.ToTensor(),#圖片轉張量
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),#歸一化,引數固定
    ]
    )
    
    #測試資料處理器
    transform_test=T.Compose([
        T.Resize((height,width)),#尺度統一
        T.ToTensor(),#圖片轉張量
        T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),#歸一化,引數固定
    ]
    )
    
    #train資料集吞吐器
    train_data_loader=DataLoader(
        ImageDataset(dataset.train, transform=transform_train),#自定義的資料集,使用訓練資料處理器
        batch_size=train_batch_size,#一個批次的大小(一個批次有多少個圖片張量)
        drop_last=True,#丟棄最後無法稱為一整個批次的資料
    )
    print("train_data_loader inited")
    
    #query資料集吞吐器
    query_data_loader=DataLoader(
        ImageDataset(dataset.query, transform=transform_test),#自定義的資料集,使用測試資料處理器
        batch_size=test_batch_size,#一個批次的大小(一個批次有多少個圖片張量)
        shuffle=False,#不重排
        drop_last=True,#丟棄最後無法稱為一整個批次的資料
    )
    print("query_data_loader inited")
    
    #gallery資料集吞吐器
    gallery_data_loader=DataLoader(
        ImageDataset(dataset.gallery, transform=transform_test),#自定義的資料集,使用測試資料處理器
        batch_size=test_batch_size,#一個批次的大小(一個批次有多少個圖片張量)
        shuffle=False,#不重排
        drop_last=True,#丟棄最後無法稱為一整個批次的資料
    )
    print("gallery_data_loader inited\n")
    
    #載入模型
    model=ReIDNet(num_classes=751,loss={'softmax'})#指定分類的數量,與使用的損失函式以便決定模型輸出何種計算結果
    print("=>ReIDNet loaded")
    print("Model size: {:.5f}M\n".format(sum(p.numel() for p in model.parameters())/1000000.0))
    
    #損失函式
    criterion_class=nn.CrossEntropyLoss()
    
    """
    優化器
    引數1,待優化的引數
    引數2,學習率
    引數3,權重衰減
    """
    optimizer=torch.optim.SGD(model.parameters(),lr=train_lr,weight_decay=5e-04)
    
    """
    動態學習率
    引數1,指定使用的優化器
    引數2,mode,可選擇‘min’(min表示當監控量停止下降的時候,學習率將減小)或者‘max’(max表示當監控量停止上升的時候,學習率將減小)
    引數3,factor,代表學習率每次降低多少
    引數4,patience,容忍網路的效能不提升的次數,高於這個次數就降低學習率
    引數5,min_lr,學習率的下限
    """
    scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=dy_step_gamma, patience=10, min_lr=0.0001)
    
    #如果是測試
    if evaluate:
        test(model,query_data_loader,gallery_data_loader)
        return 0
    #如果是訓練
    print('————model start training————\n')
    bt=time.time()#訓練的開始時間
    for epoch in range(start_epoch,end_epoch):
        model.train(True)
        train(epoch,model,criterion_class,optimizer,scheduler,train_data_loader)
    et=time.time()#訓練的結束時間
    print('**模型訓練結束, 儲存最終引數到{}**\n'.format(final_model_path))
    torch.save(model.state_dict(), final_model_path)
    print('————訓練總用時{:.2f}小時————'.format((et-bt)/3600.0))

def train(epoch, model, criterion_class, optimizer, scheduler, data_loader):
    """
    訓練函式train
    引數1,epoch(當前批次)
    引數2,model(使用的模型)
    引數3,criterion_class(損失函式)
    引數4,optimizer(優化器型別,用於反向傳播優化網路引數)
    引數5,scheduler(用於管理學習率)
    引數6,data_loader(指定資料載入器,獲得網路的輸入資料)
    """
    global max_acc
    for batch_idx, (imgs, pids, cids) in enumerate(data_loader):
        optimizer.zero_grad()#優化器進行清零,防止上次計算結果對這次計算產生影響
        outputs=model(imgs)
        loss=criterion_class(outputs,pids)#使用損失函式根據計算結果與真實結果計算損失
        loss.backward()#根據損失進行反向傳播優化計算
        scheduler.step(loss)#更新學習率
        optimizer.step()#更新網路中指定的需要優化的引數
        pred = torch.argmax(outputs, 1)#按行求最大值,計算分類結果
        current_acc=100*(pred == pids).sum().float()/len(pids)
        if current_acc>max_acc:
            max_acc=current_acc
            print('**最高準確度更新為{}%,儲存此模型到{}**\n'.format(max_acc,best_model_path))
            torch.save(model.state_dict(), best_model_path)
        if batch_idx%10==0:
            print('————————————————————————————————')
            pred = torch.argmax(outputs, 1)
            print('Epoch: {}, Batch: {}, Loss: {}'.format(epoch + 1, batch_idx, loss.data))
            print('Current Accuracy: {:.2f}%'.format(100*(pred == pids).sum().float()/len(pids)))     
            print('————————————————————————————————\n')
    return 0

main()

四、訓練列印日誌

=> Market1501 loaded
------------------------------------------------------------------------
  subset: train  	| num_id:   751  	|  num_imgs:   12936  
  subset: query  	| num_id:   750  	|  num_imgs:    3368  
  subset: gallery 	| num_id:   751  	|  num_imgs:   19732  
------------------------------------------------------------------------
  total 			| num_id:  1501  	|  num_imgs:   16304  
------------------------------------------------------------------------
train_data_loader inited
query_data_loader inited
gallery_data_loader inited

=>ReIDNet loaded
Model size: 4.53483M

————model start training————

**最高準確度更新為0.0%,儲存此模型到./model/param/net_params_best.pth**

————————————————————————————————
Epoch: 1, Batch: 0, Loss: 6.598439693450928
Current Accuracy: 0.00%
————————————————————————————————

**最高準確度更新為21.875%,儲存此模型到./model/param/net_params_best.pth**

————————————————————————————————
Epoch: 1, Batch: 100, Loss: 6.623084545135498
Current Accuracy: 0.00%
————————————————————————————————
******************************中間省略若干******************************
******************************中間省略若干******************************
————————————————————————————————
————————————————————————————————
epoch: 50, batch: 400, loss: 5.373966217041016, lr:0.0009847709021836117
saving model params
model params saved
————————————————————————————————

saving model params
model params saved

————訓練總用時5.08小時————