關於pytorch語義分割二分類問題的兩種做法
形式1:輸出為單通道
分析
即網路的輸出 output
為 [batch_size, 1, height, width] 形狀。其中 batch_szie
為批量大小,1
表示輸出一個通道,height
和 width
與輸入影象的高和寬保持一致。
在訓練時,輸出通道數是 1,網路得到的 output
包含的數值是任意的數。給定的 target
,是一個單通道標籤圖,數值只有 0 和 1 這兩種。為了讓網路輸出 output
不斷逼近這個標籤,首先會讓 output
經過一個sigmoid 函式,使其數值歸一化到[0, 1],得到 output1
,然後讓這個 output1
與 target
進行交叉熵計算,得到損失值,反向傳播更新網路權重。最終,網路經過學習,會使得 output1
target
。
訓練結束後,網路已經具備讓輸出的 output
經過轉換從而逼近 target
的能力。首先將輸出的 output
通過sigmoid 函式,然後取一個閾值(一般設定為0.5),大於閾值則取1反之則取0,從而得到預測圖 predict
。後續則是一些評估相關的計算。
程式碼實現
在這個過程中,訓練的損失函式為二進位制交叉熵損失函式,然後根據輸出是否用到了sigmoid有兩種可選的pytorch實現方式:
ouput = net(input) # net的最後一層沒有使用sigmoid loss_func1 = torch.nn.BCEWithLogitsLoss() loss = loss_func1(ouput, target)
當網路最後一層沒有使用sigmoid時,需要使用 torch.nn.BCEWithLogitsLoss()
,顧名思義,在這個函式中,拿到output首先會做一個sigmoid操作,再進行二進位制交叉熵計算。上面的操作等價於
ouput = net(input) # net的最後一層沒有使用sigmoid
output = F.sigmoid(ouput)
loss_func1 = torch.nn.BCEWithLoss()
loss = loss_func1(ouput, target)
當然,你也可以在網路最後一層加上sigmoid操作。從而省去第二行的程式碼(在預測時也可以省去)。
在預測試時,可用下面的程式碼實現預測圖的生成
ouput = net(input) # net的最後一層沒有使用sigmoid
output = F.sigmoid(ouput)
predict = torch.where(output>0.5,torch.ones_like(output),torch.zeros_like(output))
...
即大於0.5的記為1,小於0.5記為0。
形式2:輸出為多通道
分析
即網路的輸出 output
為 [batch_size, num_class, height, width] 形狀。其中 batch_szie
為批量大小,num_class
表示輸出的通道數與分類數量一致,height
和 width
與輸入影象的高和寬保持一致。
在訓練時,輸出通道數是 num_class
(這裡取2),網路得到的 output
包含的數值是任意的數。給定的 target
,是一個單通道標籤圖,數值只有 0 和 1 這兩種。為了讓網路輸出 output
不斷逼近這個標籤,首先會讓 output
經過一個 softmax 函式,使其數值歸一化到[0, 1],得到 output1
,在各通道中,這個數值加起來會等於1。對於target
他是一個單通道圖,首先使用onehot
編碼,轉換成 num_class
個通道的影象,每個通道中的取值是根據單通道中的取值計算出來的,例如單通道中的第一個畫素取值為1(0<= 1 <=num_class-1,這裡num_class=2),那麼onehot
編碼後,在第一個畫素的位置上,兩個通道的取值分別為0,1。也就是說畫素的取值決定了對應序號的通道取1,其他的通道取0,這個非常關鍵。上面的操作執行完後得到target1
,讓這個 output1
與 target1
進行交叉熵計算,得到損失值,反向傳播更新網路權重。最終,網路經過學習,會使得 output1
逼近target1
(在各通道層面上)。
訓練結束後,網路已經具備讓輸出的 output
經過轉換從而逼近 target
的能力。計算 output
中各通道每一個畫素位置上,取值最大的那個對應的通道序號,從而得到預測圖 predict
。後續則是一些評估相關的計算。
在這個過程中,則可以使用交叉熵損失函式:
ouput = net(input) # net的最後一層沒有使用sigmoid
loss_func = torch.nn.CrossEntropyLoss()
loss = loss_func(ouput, target)
根據前面的分析,我們知道,正常的ouput
是 [batch_size, num_class, height, width]形狀的,而target
是[batch_size, height, width]形狀的,需要按照上面的分析進行轉換才可以計算交叉熵,而在pytorch中,我們不需要進一步做這個處理,直接使用就可以了。
在預測試時,使用下面的程式碼實現預測圖的生成
ouput = net(input) # net的最後一層沒有使用sigmoid
predict = output.argmax(dim=1)
...
即得到輸出後,在通道方向上找出最大值所在的索引號。
小結
總的來說,我覺得第二種方式更值得推廣,一方面不用考慮閾值的選取問題;另一方面,該方法同樣適用於多類別的語義分割任務,通用性更強。
參考資料
[1]https://blog.csdn.net/longshaonihaoa/article/details/105253553