1. 程式人生 > >語義分割,去除邊緣線程式碼

語義分割,去除邊緣線程式碼



import tensorflow as tf
import scipy.misc as msc



'''
    對於語義分割的邊緣線,白色的為255,進行去除
'''




def remove_ignore_label(gt ,output=None ,pred=None):
    '''   去除label為255的值,進行交叉熵的計算
    gt: not one-hot
    output: a distriution of all labels, and is scaled to macth the size of gt
    NOTE the result is a flatted tensor
    and all label which is bigger that or equal to self.category_num is void label
    '''
    gt = tf.reshape(gt ,shape=[-1])  # (180000,)  把矩陣  轉化為向量

    indices = tf.squeeze(tf.where(tf.less(gt, 21)) ,axis=1) #除去邊緣線  判斷是否小於 255
    #tf.less(gt, 21)  找到所以小於21的label,相當於除去邊緣線,  某位置< 21  返回True, 否則返回False
    #tf.where(tf.less(gt, 21))  返回這個位置的index,在為True的位置
    #tf.squeeze 壓縮為1的維度

    gt = tf.gather(gt ,indices)
    #  根據indices  取出這個位置的值


    if output is not None:
        output = tf.reshape(output, shape=[-1, 21])   #轉化為21維度的特徵,每個特徵,相當於一張圖片
        output = tf.gather(output ,indices) # output 輸出也是  [小於21的索引值(相當與一張圖片除為255的所以值) , 21]
        return gt ,output
    elif pred is not None:
        pred = tf.reshape(pred, shape=[-1])
        pred = tf.gather(pred, indices)
        return gt ,pred



label = tf.truncated_normal(shape=(3,281,500),stddev=0.1)  # 輸入圖片的label (b, w, h)

output = tf.truncated_normal(shape=(3,281,500,21),stddev=0.1) # 網路輸出圖片的概率 (b, w, h, c)  #21代表類別



label,output = remove_ignore_label(label,output)
label = tf.cast(label, tf.int32)
loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=label, logits=output))

#這裡進行計算交叉熵

with tf.Session() as sess:
    print(sess.run(loss))