pytorch訓練ImageNet筆記(二)
1.import torch.backends.cudnn as cudnn的作用:
cudnn.benchmark = true
總的來說,大部分情況下,設定這個 flag 可以讓內建的 cuDNN 的 auto-tuner 自動尋找最適合當前配置的高效演算法,來達到優化執行效率的問題。
一般來講,應該遵循以下準則:
(1)如果網路的輸入資料維度或型別上變化不大,設定 torch.backends.cudnn.benchmark = true 可以增加執行效率;
(2)如果網路的輸入資料在每次 iteration 都變化的話,會導致 cnDNN 每次都會去尋找一遍最優配置,這樣反而會降低
2.在訓練過程中,可設定set_epoch()使得在每個itertion開始的時候打亂資料的分佈
3.shuffle:bool,可選。為True時表示每個epoch都對資料進行洗牌
4.sampler:Sampler,可選。從資料集中取樣樣本的方法。
5.collate_fn (callable, optional): 將一個list的sample組成一個mini-batch的函式
6.pin_memory (bool, optional): 如果設定為True,那麼data loader將會在返回它們之前,將tensors拷貝到CUDA中的固定內(CUDA pinned memory)中.
7.drop_last (bool, optional): 如果設定為True:這個是對最後的未完成的batch來說的,比如你的batch_size設定為64,而一個epoch只有100個樣本,那麼訓練的時候後面的36個就被扔掉了… 如果為False(預設),那麼會繼續正常執行,只是最後的batch_size會小一點。
8.target = target.cuda(async=True) # 這是一種用來包裹張量並記錄應用的操作
9.args.resume這個引數主要是用來設定是否從斷點處繼續訓練,比如原來訓練模型訓到一半停止了,希望繼續從儲存的最新epoch開始訓練,因此args.resume要麼是預設的None,要麼就是你儲存的模型檔案(.pth)的路徑。其中checkpoint = torch.load(args.resume)是用來匯入已訓練好的模型。model.load_state_dict(checkpoint[‘state_dict’])是完成匯入模型的引數初始化model這個網路的過程,load_state_dict是torch.nn.Module類中重要的方法之一.
10.按照迭代的次數n輸出前n次的top1與top5平均值:
top1.update(prec1[0], input.size(0)) #更新
top5.update(prec5[0], input.size(0)) #更新
class AverageMeter(object):
# Computes and stores the average and current value
"""
batch_time = AverageMeter()
即 self = batch_time
則 batch_time 具有__init__,reset,update三個屬性,
直接使用batch_time.update()呼叫
功能為:batch_time.update(time.time() - end)
僅一個引數,則直接儲存引數值
對應定義:def update(self, val, n=1)
losses.update(loss.data[0], input.size(0))
top1.update(prec1[0], input.size(0))
top5.update(prec5[0], input.size(0))
這些有兩個引數則求引數val的均值,儲存在avg中##不確定##
"""
def __init__(self):
self.reset() # __init__():reset parameters
def reset(self):
self.val = 0
self.avg = 0
self.sum = 0
self.count = 0
def update(self, val, n=1):
self.val = val
self.sum += val * n
self.count += n
self.avg = self.sum / self.count
reset()代表初始化,在訓練過程中,每次iterion都會進行top1與top5的更新,每次iterion將會載入batch_size個數據,通過pre1.update()計算每次的平均值,val代表本次iterion的預測值,self.sum代表總的概率(256*30%+256*40%+50*70%),self.count代表載入過的資料量(每次疊加),self.avg代表平均值(通過總概率除總資料量求得)。每個epoch單獨計算平均概率。