1. 程式人生 > 實用技巧 >在mini-batch訓練中使用tqdm來建立進度條

在mini-batch訓練中使用tqdm來建立進度條

tqdm是python中的一個用來供我們建立進度條的庫。在進行深度學習的研究時,我們可以使用這個庫為我們直觀地展示當前的訓練進度,下面來說說如何在mini-batch優化中使用這個庫。

我希望程式能夠在每一個epoch都使用進度條來顯示當前epoch的訓練情況,我使用的程式碼如下:

1 from tqdm import tqdm #從tqdm庫中匯入tadm類
2 
3 for epoch in range(epochs): #訓練輪次
4     with tqdm(total = batch_num, desc=f'Epoch {epoch+1}/{epochs}', unit='it') as pbar: #
建立一個進度條 5 for batch_idx in range(batch_num): #mini-batch訓練 6 ... 7 pbar.set_postfix({'batch_loss:'loss}) #在進度條後顯示當前batch的損失 8 pbar.update(1) #更當前進度,1表示完成了一個batch的訓練

所得到的的進度條如下圖所示:

“from tqdm import tqdm”就是從tqdm庫中匯入tqdm類。一開始我寫成了“import tqdm”,導致程式報錯,所以這一點要注意。

"with tqdm(total = batch_num, desc=f'Epoch {epoch+1}/{epochs}', unit='it') as pbar: #建立一個進度條"使用python的with結構來建立一個tqdm物件pbar。如果不使用with結構,就需要在一次epoch訓練的結尾呼叫tqdm物件的close()函式。這一語句中各個引數的意思是:

total:為一個epoch中batch的總數量,即迭代的總次數;

desc:放在進度條最前的一段描述,在此顯示的是當前的epoch及總共需要多少個epoch;

unit:迭代速度的單位,it是iteration的簡寫,這裡指的是以每秒完成多少次迭代作為速度度量並顯示在進度條上,見上圖中的“1.41it/s”。

“pbar.set_postfix({'batch_loss:'loss}) #在進度條後顯示當前batch的損失”,在進度條的最後顯示當前batch的損失,如圖中的“batch_loss=1.66e+5”。

pbar.update(1)用來更新進度,這裡的1指的是完成了一個batch的訓練,讓進度條加1。

上圖中的“03:21<01:21”分別顯示訓練當前batch已經花費的時間和還需要消耗的時間。