faster RCNN的Python的畫出來loss曲線圖
阿新 • • 發佈:2019-01-06
由於要寫論文需要畫loss曲線,查詢網上的loss曲線視覺化的方法發現大多數是基於Imagenat的一些方法,在運用到Faster-Rcnn上時沒法用,本人不怎麼會編寫程式碼,所以想到能否用python直接寫一個程式碼,讀取txt然後畫出來,參考大神們的部落格,然後總和總算一下午時間,搞出來了,大牛們不要見笑。
首先,在訓練Faster-Rcnn時會自己生成log檔案,大概在/py-faster-rcnn/experiments/logs檔案下,把他直接拿出來,放在任意位置即可,因為是txt格式,可以直接用,如果嫌麻煩重新命名1.txt.接下來就是編寫程式了
一下為log基本的格式
I0530 08:54:19.183091 10143 solver.cpp:229] Iteration 22000, loss = 0.173712
I0530 08:54:19.183137 10143 solver.cpp:245] Train net output #0: rpn_cls_loss = 0.101713 (* 1 = 0.101713 loss)
I0530 08:54:19.183145 10143 solver.cpp:245] Train net output #1: rpn_loss_bbox = 0.071999 (* 1 = 0.071999 loss)
I0530 08:54:19.183148 10143 sgd_solver.cpp:106] Iteration 22000, lr = 0.001
通過發現,我們只需獲得 Iteration 和loss就行
#!/usr/bin/env python import os import sys import numpy as np import matplotlib.pyplot as plt import math import re import pylab from pylab import figure, show, legend from mpl_toolkits.axes_grid1 import host_subplot # read the log file fp = open('2.txt', 'r') train_iterations = [] train_loss = [] test_iterations = [] #test_accuracy = [] for ln in fp: # get train_iterations and train_loss if '] Iteration ' in ln and 'loss = ' in ln: arr = re.findall(r'ion \b\d+\b,',ln) train_iterations.append(int(arr[0].strip(',')[4:])) train_loss.append(float(ln.strip().split(' = ')[-1])) fp.close() host = host_subplot(111) plt.subplots_adjust(right=0.8) # ajust the right boundary of the plot window #par1 = host.twinx() # set labels host.set_xlabel("iterations") host.set_ylabel("RPN loss") #par1.set_ylabel("validation accuracy") # plot curves p1, = host.plot(train_iterations, train_loss, label="train RPN loss") #p2, = par1.plot(test_iterations, test_accuracy, label="validation accuracy") # set location of the legend, # 1->rightup corner, 2->leftup corner, 3->leftdown corner # 4->rightdown corner, 5->rightmid ... host.legend(loc=1) # set label color host.axis["left"].label.set_color(p1.get_color()) #par1.axis["right"].label.set_color(p2.get_color()) # set the range of x axis of host and y axis of par1 host.set_xlim([-1500,60000]) host.set_ylim([0., 1.6]) plt.draw() plt.show()
繪圖結果