1. 程式人生 > 程式設計 >解決Tensorflow sess.run導致的記憶體溢位問題

解決Tensorflow sess.run導致的記憶體溢位問題

下面是呼叫模型進行批量測試的程式碼(出現溢位),開始以為導致溢位的原因是資料讀入方式問題引起的,用了tf,PIL和cv等方式讀入圖片資料,發現越來越慢,記憶體佔用飆升,除錯時發現是sess.run這裡出了問題(隨著for迴圈進行速度越來越慢)。

  # Creates graph from saved GraphDef
  create_graph(pb_path)
 
  # Init tf Session
  config = tf.ConfigProto()
  config.gpu_options.allow_growth = True
  sess = tf.Session(config=config)
  init = tf.global_variables_initializer()
  sess.run(init)
 
 
  input_image_tensor = sess.graph.get_tensor_by_name("create_inputs/batch:0") 
  output_tensor_name = sess.graph.get_tensor_by_name("conv6/out_1:0") 
 
 
  for filename in os.listdir(image_dir):
    image_path = os.path.join(image_dir,filename)
 
    start = time.time()
    image_data = cv2.imread(image_path)
    image_data = cv2.resize(image_data,(w,h))
    image_data_1 = image_data - IMG_MEAN
    input_image = np.expand_dims(image_data_1,0)
 
    raw_output_up = tf.image.resize_bilinear(output_tensor_name,size=[h,w],align_corners=True) 
    raw_output_up = tf.argmax(raw_output_up,axis=3)
    
 
    predict_img = sess.run(raw_output_up,feed_dict={input_image_tensor: input_image})    # 1,height,width
    predict_img = np.squeeze(predict_img)   # height, width 
 
    voc_palette = visual.make_palette(3)
    masked_im = visual.vis_seg(image_data,predict_img,voc_palette)
    cv2.imwrite("%s_pred.png" % (save_dir + filename.split(".")[0]),masked_im)
 
 
    print(time.time() - start)
 
  print(">>>>>>Done")

下面是解決溢位問題的程式碼(將部分程式碼放在for迴圈外

  # Creates graph from saved GraphDef
  create_graph(pb_path)
 
  # Init tf Session
  config = tf.ConfigProto()
  config.gpu_options.allow_growth = True
  sess = tf.Session(config=config)
  init = tf.global_variables_initializer()
  sess.run(init)
 
  input_image_tensor = sess.graph.get_tensor_by_name("create_inputs/batch:0") 
  output_tensor_name = sess.graph.get_tensor_by_name("conv6/out_1:0") 
  
##############################################################################################################
  raw_output_up = tf.image.resize_bilinear(output_tensor_name,align_corners=True) 
  raw_output_up = tf.argmax(raw_output_up,axis=3)
##############################################################################################################
 
  for filename in os.listdir(image_dir):
    image_path = os.path.join(image_dir,0)
    
    predict_img = sess.run(raw_output_up,masked_im)
    print(time.time() - start)
 
  print(">>>>>>Done")

總結:

在迭代過程中,在sess.run的for迴圈中不要加入tensorflow一些op操作,會增加圖節點,否則隨著迭代的進行,tf的圖會越來越大,最終導致溢位;

建議不要使用tf.gfile.FastGFile(image_path,'rb').read()讀入資料(有可能會造成溢位),用opencv之類讀取。

以上這篇解決Tensoflow sess.run導致的記憶體溢位問題就是小編分享給大家的全部內容了,希望能給大家一個參考,也希望大家多多支援我們。