1. 程式人生 > 程式設計 >pytorch實現對輸入超過三通道的資料進行訓練

pytorch實現對輸入超過三通道的資料進行訓練

案例背景:視訊識別

假設每次輸入是8s的灰度視訊,視訊幀率為25fps,則視訊由200幀影象序列構成.每幀是一副單通道的灰度影象,通過pythonb裡面的np.stack(深度拼接)可將200幀拼接成200通道的深度資料.進而送到網路裡面去訓練.

如果輸入影象200通道覺得多,可以對視訊進行抽幀,針對具體場景可以隨機抽幀或等間隔抽幀.比如這裡等間隔抽取40幀.則最後輸入視訊相當於輸入一個40通道的影象資料了.

pytorch對超過三通道資料的載入:

讀取視訊每一幀,轉為array格式,然後依次將每一幀進行深度拼接,最後得到一個40通道的array格式的深度資料,儲存到pickle裡.

對每個視訊都進行上述操作,儲存到pickle裡.

我這裡將火的視訊深度資料儲存在一個.pkl檔案中,一共2504個火的視訊,即2504個火的深度資料.

將非火的視訊深度資料儲存在一個.pkl檔案中,一共3985個非火的視訊,即3985個非火的深度資料.

資料載入

import torch 
from torch.utils import data
import os
from PIL import Image
import numpy as np
import pickle
 
class Fire_Unfire(data.Dataset):
  def __init__(self,fire_path,unfire_path):
    self.pickle_fire = open(fire_path,'rb')
    self.pickle_unfire = open(unfire_path,'rb')
    
  def __getitem__(self,index):
    if index <2504:
      fire = pickle.load(self.pickle_fire)#高*寬*通道
      fire = fire.transpose(2,1)#通道*高*寬
      data = torch.from_numpy(fire)
      label = 1
      return data,label
    elif index>=2504 and index<6489:
      unfire = pickle.load(self.pickle_unfire)
      unfire = unfire.transpose(2,1)
      data = torch.from_numpy(unfire)
      label = 0
      return data,label
    
  def __len__(self):
    return 6489
root_path = './datasets/train'
dataset = Fire_Unfire(root_path +'/fire_train.pkl',root_path +'/unfire_train.pkl')
 
#轉換成pytorch網路輸入的格式(批量大小,通道數,高,寬)
from torch.utils.data import DataLoader
fire_dataloader = DataLoader(dataset,batch_size=4,shuffle=True,drop_last = True)

模型訓練

import torch
from torch.utils import data
from nets.mobilenet import mobilenet
from config.config import default_config
from torch.autograd import Variable as V
import numpy as np
import sys
import time
 
opt = default_config()
def train():
  #模型定義
  model = mobilenet().cuda()
  if opt.pretrain_model:
    model.load_state_dict(torch.load(opt.pretrain_model))
  
  #損失函式
  criterion = torch.nn.CrossEntropyLoss().cuda()
  
  #學習率
  lr = opt.lr
  
  #優化器
  optimizer = torch.optim.SGD(model.parameters(),lr = lr,weight_decay=opt.weight_decay)
  
  
  pre_loss = 0.0
  #訓練
  for epoch in range(opt.max_epoch):
     #訓練資料
    train_data = Fire_Unfire(opt.root_path +'/fire_train.pkl',opt.root_path +'/unfire_train.pkl')
    train_dataloader = data.DataLoader(train_data,batch_size=opt.batch_size,drop_last = True)
    loss_sum = 0.0
    for i,(datas,labels) in enumerate(train_dataloader):
      #print(i,datas.size(),labels)
      #梯度清零
      optimizer.zero_grad()
      #輸入
      input = V(datas.cuda()).float()
      #目標
      target = V(labels.cuda()).long()
      #輸出
      score = model(input).cuda()
      #損失
      loss = criterion(score,target)
      loss_sum += loss
      #反向傳播
      loss.backward()
      #梯度更新
      optimizer.step()      
    print('{}{}{}{}{}'.format('epoch:',epoch,','loss:',loss))
    torch.save(model.state_dict(),'models/mobilenet_%d.pth'%(epoch+370))

RuntimeError: Expected object of scalar type Long but got scalar type Float for argument #2 'target'

解決方案:target = target.long()

以上這篇pytorch實現對輸入超過三通道的資料進行訓練就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支援我們。