caffe學習(8)Solver 配置詳解
Solver是求解學習模型的核心配置檔案,網路確定後,solver就決定了學習的效果。本文結合caffe.proto和網上資料,對solver配置進行學習。
Solver在caffe中的定義
通常的solver檔案與net檔案相互關聯,同樣的net我們往往使用不同的solver嘗試得到最好的效果,其執行程式碼為:
caffe train --solver=*_slover.prototxt
關於solver的一切,都在caffe.proto檔案中message SolverParameter 這一部分。
網路檔案源
// Proto filename for the train net, possibly combined with one or more
// test nets.
optional string net = 24;
// Inline train net param, possibly combined with one or more test nets.
optional NetParameter net_param = 25;
optional string train_net = 1; // Proto filename for the train net.
repeated string test_net = 2; // Proto filenames for the test nets.
optional NetParameter train_net_param = 21 ; // Inline train net params.
repeated NetParameter test_net_param = 22; // Inline test net params.
這是最開始的部分,需要說明net檔案的位置。在這四個train_net_param, train_net, net_param, net欄位中至少需要出現一個,當出現多個時,就會按著(1) test_net_param, (2) test_net, (3) net_param/net 的順序依次求解。必須為每個test_net指定一個test_iter。還可以為每個test_net指定test_level和/或test_stage。注意的是:檔案的路徑要從caffe的根目錄開始,其它的所有配置都是這樣。
可以看到這幾行的標籤序號並不是順序的,也說明caffe在不斷地修改,下一個可用的序號是41。
網路狀態
// The states for the train/test nets. Must be unspecified or
// specified once per net.
//
// By default, all states will have solver = true;
// train_state will have phase = TRAIN,
// and all test_state's will have phase = TEST.
// Other defaults are set according to the NetState defaults.
optional NetState train_state = 26;
repeated NetState test_state = 27;
網路狀態必須是未指定的或者只能在一個網路中指定一次。
關於NetState,其定義為:
message NetState {
optional Phase phase = 1 [default = TEST];
optional int32 level = 2 [default = 0];
repeated string stage = 3;
}
enum Phase {
TRAIN = 0;
TEST = 1;
}
迭代器
// The number of iterations for each test net.
repeated int32 test_iter = 3;
首先是test_iter
,這需要與test layer中的batch_size結合起來理解。mnist資料中測試樣本總數為10000,一次性執行全部資料效率很低,因此我們將測試資料分成幾個批次來執行,每個批次的數量就是batch_size。假設我們設定batch_size為100,則需要迭代100次才能將10000個數據全部執行完。因此test_iter設定為100。執行完一次全部資料,稱之為一個epoch。
// The number of iterations between two testing phases.
optional int32 test_interval = 4 [default = 0];
optional bool test_compute_loss = 19 [default = false];
// If true, run an initial test pass before the first iteration,
// ensuring memory availability and printing the starting value of the loss.
optional bool test_initialization = 32 [default = true];
test_interval
是指測試間隔,每訓練test_interval次,進行一次測試。同時test_compute_loss
可以選擇是否計算loss。test_initialization
是指在第一次迭代前,計算初始的loss以確保記憶體可用。
optional float base_lr = 5; // The base learning rate
// the number of iterations between displaying info. If display = 0, no info
// will be displayed.
optional int32 display = 6;
// Display the loss averaged over the last average_loss iterations
optional int32 average_loss = 33 [default = 1];
optional int32 max_iter = 7; // the maximum number of iterations
// accumulate gradients over `iter_size` x `batch_size` instances
optional int32 iter_size = 36 [default = 1];
base_lr
指基礎的學習率;display
是資訊顯示間隔,迭代一定次數顯示一次資訊。average_loss
用於顯示在上次average_loss迭代中的平均損失。max_iter
是最大迭代次數,需要合適設定達到精度、震盪的平衡。iter_size
是迭代器大小,梯度的計算是通過iter_size
x batch_size
決定的。
學習策略
optional string lr_policy = 8;
optional float gamma = 9; // The parameter to compute the learning rate.
optional float power = 10; // The parameter to compute the learning rate.
optional float momentum = 11; // The momentum value.
optional float weight_decay = 12; // The weight decay.
// regularization types supported: L1 and L2
// controlled by weight_decay
optional string regularization_type = 29 [default = "L2"];
// the stepsize for learning rate policy "step"
optional int32 stepsize = 13;
// the stepsize for learning rate policy "multistep"
repeated int32 stepvalue = 34;
只要是梯度下降法來求解優化,都會有一個學習率,也叫步長。base_lr用於設定基礎學習率,在迭代的過程中,可以對基礎學習率進行調整。怎麼樣進行調整,就是調整的策略,由lr_policy來設定。caffe提供了多種policy:
- fixed: 總是返回base_lr(學習率不變)
- step: 返回 base_lr * gamma ^ (floor(iter / step))
還需要設定stepsize引數以確定step,iter表示當前迭代次數。 - exp: 返回base_lr * gamma ^ iter, iter為當前迭代次數
- inv: 如果設定為inv,還需要設定一個power, 返回base_lr * (1 + gamma * iter) ^ (- power)
- multistep: 如果設定為multistep,則還需要設定一個stepvalue。這個引數和step很相似,step是均勻等間隔變化,而multistep則是根據stepvalue值變化。
- poly: 學習率進行多項式誤差, 返回 base_lr (1 - iter/max_iter) ^ (power)
- sigmoid: 學習率進行sigmod衰減,返回 base_lr ( 1/(1 + exp(-gamma * (iter - stepsize))))。
multistep示例:
base_lr: 0.01
momentum: 0.9
weight_decay: 0.0005
# The learning rate policy
lr_policy: "multistep"
gamma: 0.9
stepvalue: 5000
stepvalue: 7000
stepvalue: 8000
stepvalue: 9000
stepvalue: 9500
之後有momentum
,上次梯度更新的權重;weight_decay
權重衰減,防止過擬合;regularization_type
正則化方式。
clip_gradients
optional float clip_gradients = 35 [default = -1];
引數梯度的實際L2範數較大時,將clip_gradients設定為> = 0,以將引數梯度剪下到該L2範數。具體作用還不是很理解。
snapshot快照
optional int32 snapshot = 14 [default = 0]; // The snapshot interval
optional string snapshot_prefix = 15; // The prefix for the snapshot.
// whether to snapshot diff in the results or not. Snapshotting diff will help
// debugging but the final protocol buffer size will be much larger.
optional bool snapshot_diff = 16 [default = false];
enum SnapshotFormat {
HDF5 = 0;
BINARYPROTO = 1;
}
optional SnapshotFormat snapshot_format = 37 [default = BINARYPROTO];
快照可以將訓練出來的model和solver狀態進行儲存,snapshot
用於設定訓練多少次後進行儲存,預設為0,不儲存。snapshot_prefix
設定儲存路徑。還可以設定snapshot_diff
,是否儲存梯度值,儲存有利於除錯,但需要較大空間儲存,預設為false,不儲存。也可以設定snapshot_format
,儲存的型別。有兩種選擇:HDF5 和BINARYPROTO ,預設為BINARYPROTO。
執行模式
enum SolverMode {
CPU = 0;
GPU = 1;
}
optional SolverMode solver_mode = 17 [default = GPU];
// the device_id will that be used in GPU mode. Use device_id = 0 in default.
optional int32 device_id = 18 [default = 0];
// If non-negative, the seed with which the Solver will initialize the Caffe
// random number generator -- useful for reproducible results. Otherwise,
// (and by default) initialize using a seed derived from the system clock.
optional int64 random_seed = 20 [default = -1];
設定CPU或GPU模式,在GPU下還可以指定使用哪一塊GPU執行。random_seed
用於初始生成隨機數種子。
Solver型別
// type of the solver
optional string type = 40 [default = "SGD"];
// numerical stability for RMSProp, AdaGrad and AdaDelta and Adam
optional float delta = 31 [default = 1e-8];
// parameters for the Adam solver
optional float momentum2 = 39 [default = 0.999];
// RMSProp decay value
// MeanSquare(t) = rms_decay*MeanSquare(t-1) + (1-rms_decay)*SquareGradient(t)
optional float rms_decay = 38;
type
是solver的型別,目前有SGD、NESTEROV、ADAGRAD、RMSPROP、ADADELTA、ADAM = 5這六類。之後的一些是這些型別的特有引數,根據需要設定。
雜項
// If true, print information about the state of the net that may help with
// debugging learning problems.
optional bool debug_info = 23 [default = false];
// If false, don't save a snapshot after training finishes.
optional bool snapshot_after_train = 28 [default = true];
debug_info
用於輸出除錯資訊。snapshot_after_train
用於訓練後是否輸出快照。