caffe Python API 之Model訓練
阿新 • • 發佈:2018-11-06
# 訓練設定 # 使用GPU caffe.set_device(gpu_id) # 若不設定,預設為0 caffe.set_mode_gpu() # 使用CPU caffe.set_mode_cpu() # 載入Solver,有兩種常用方法 # 1. 無論模型中Slover型別是什麼統一設定為SGD solver = caffe.SGDSolver('/home/xxx/data/solver.prototxt') # 2. 根據solver的prototxt中solver_type讀取,預設為SGD solver = caffe.get_solver('/home/xxx/data/solver.prototxt') # 訓練模型 # 1.1 前向傳播 solver.net.forward() # train net solver.test_nets[0].forward() # test net (there can be more than one) # 1.2 反向傳播,計算梯度 solver.net.backward() # 2. 進行一次前向傳播一次反向傳播並根據梯度更新引數 solver.step(1) # 3. 根據solver檔案中設定進行完整model訓練 solver.solve()
如果想在訓練過程中儲存模型引數,呼叫
solver.net.save('mymodel.caffemodel')