解決Tensorflow sess.run導致的記憶體溢位問題
阿新 • • 發佈:2020-02-06
下面是呼叫模型進行批量測試的程式碼(出現溢位),開始以為導致溢位的原因是資料讀入方式問題引起的,用了tf,PIL和cv等方式讀入圖片資料,發現越來越慢,記憶體佔用飆升,除錯時發現是sess.run這裡出了問題(隨著for迴圈進行速度越來越慢)。
# Creates graph from saved GraphDef create_graph(pb_path) # Init tf Session config = tf.ConfigProto() config.gpu_options.allow_growth = True sess = tf.Session(config=config) init = tf.global_variables_initializer() sess.run(init) input_image_tensor = sess.graph.get_tensor_by_name("create_inputs/batch:0") output_tensor_name = sess.graph.get_tensor_by_name("conv6/out_1:0") for filename in os.listdir(image_dir): image_path = os.path.join(image_dir,filename) start = time.time() image_data = cv2.imread(image_path) image_data = cv2.resize(image_data,(w,h)) image_data_1 = image_data - IMG_MEAN input_image = np.expand_dims(image_data_1,0) raw_output_up = tf.image.resize_bilinear(output_tensor_name,size=[h,w],align_corners=True) raw_output_up = tf.argmax(raw_output_up,axis=3) predict_img = sess.run(raw_output_up,feed_dict={input_image_tensor: input_image}) # 1,height,width predict_img = np.squeeze(predict_img) # height, width voc_palette = visual.make_palette(3) masked_im = visual.vis_seg(image_data,predict_img,voc_palette) cv2.imwrite("%s_pred.png" % (save_dir + filename.split(".")[0]),masked_im) print(time.time() - start) print(">>>>>>Done")
下面是解決溢位問題的程式碼(將部分程式碼放在for迴圈外)
# Creates graph from saved GraphDef create_graph(pb_path) # Init tf Session config = tf.ConfigProto() config.gpu_options.allow_growth = True sess = tf.Session(config=config) init = tf.global_variables_initializer() sess.run(init) input_image_tensor = sess.graph.get_tensor_by_name("create_inputs/batch:0") output_tensor_name = sess.graph.get_tensor_by_name("conv6/out_1:0") ############################################################################################################## raw_output_up = tf.image.resize_bilinear(output_tensor_name,align_corners=True) raw_output_up = tf.argmax(raw_output_up,axis=3) ############################################################################################################## for filename in os.listdir(image_dir): image_path = os.path.join(image_dir,0) predict_img = sess.run(raw_output_up,masked_im) print(time.time() - start) print(">>>>>>Done")
總結:
在迭代過程中,在sess.run的for迴圈中不要加入tensorflow一些op操作,會增加圖節點,否則隨著迭代的進行,tf的圖會越來越大,最終導致溢位;
建議不要使用tf.gfile.FastGFile(image_path,'rb').read()讀入資料(有可能會造成溢位),用opencv之類讀取。
以上這篇解決Tensoflow sess.run導致的記憶體溢位問題就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支援我們。