語義分割,去除邊緣線程式碼
阿新 • • 發佈:2018-11-06
import tensorflow as tf import scipy.misc as msc ''' 對於語義分割的邊緣線,白色的為255,進行去除 ''' def remove_ignore_label(gt ,output=None ,pred=None): ''' 去除label為255的值,進行交叉熵的計算 gt: not one-hot output: a distriution of all labels, and is scaled to macth the size of gt NOTE the result is a flatted tensor and all label which is bigger that or equal to self.category_num is void label ''' gt = tf.reshape(gt ,shape=[-1]) # (180000,) 把矩陣 轉化為向量 indices = tf.squeeze(tf.where(tf.less(gt, 21)) ,axis=1) #除去邊緣線 判斷是否小於 255 #tf.less(gt, 21) 找到所以小於21的label,相當於除去邊緣線, 某位置< 21 返回True, 否則返回False #tf.where(tf.less(gt, 21)) 返回這個位置的index,在為True的位置 #tf.squeeze 壓縮為1的維度 gt = tf.gather(gt ,indices) # 根據indices 取出這個位置的值 if output is not None: output = tf.reshape(output, shape=[-1, 21]) #轉化為21維度的特徵,每個特徵,相當於一張圖片 output = tf.gather(output ,indices) # output 輸出也是 [小於21的索引值(相當與一張圖片除為255的所以值) , 21] return gt ,output elif pred is not None: pred = tf.reshape(pred, shape=[-1]) pred = tf.gather(pred, indices) return gt ,pred label = tf.truncated_normal(shape=(3,281,500),stddev=0.1) # 輸入圖片的label (b, w, h) output = tf.truncated_normal(shape=(3,281,500,21),stddev=0.1) # 網路輸出圖片的概率 (b, w, h, c) #21代表類別 label,output = remove_ignore_label(label,output) label = tf.cast(label, tf.int32) loss = tf.reduce_mean(tf.nn.sparse_softmax_cross_entropy_with_logits(labels=label, logits=output)) #這裡進行計算交叉熵 with tf.Session() as sess: print(sess.run(loss))