
BN - Batch normalization

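BN is a killer technique. This post looks at how well it works by comparing its effect on training; for the theory behind it, see other blogs, since I won't repeat it here.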

Paper: Ioffe, S., & Szegedy, C. (2015). Batch normalization: Accelerating deep network training by reducing internal covariate shift. arXiv preprint arXiv:1502.03167.

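Background

While learning MXNet, I used AlexNet to classify the Fashion-MNIST dataset and compared the results before and after adding BN.

The network without BN is defined as follows: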

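'''
from mxnet.gluon import nn

# AlexNet-style network without batch normalization;
# ReLU is fused into each Conv2D/Dense layer via the activation argument.
net = nn.Sequential()
net.add(nn.Conv2D(96, kernel_size=11, strides=4, activation='relu'),
        nn.MaxPool2D(pool_size=3, strides=2),
        nn.Conv2D(256, kernel_size=5, padding=2, activation='relu'),
        nn.MaxPool2D(pool_size=3, strides=2),
        nn.Conv2D(384, kernel_size=3, padding=1, activation='relu'),
        nn.Conv2D(384, kernel_size=3, padding=1, activation='relu'),
        nn.Conv2D(256, kernel_size=3, padding=1, activation='relu'),
        nn.MaxPool2D(pool_size=3, strides=2),
        # fully connected layers with dropout (Dense flattens its input by default)
        nn.Dense(4096, activation="relu"), nn.Dropout(0.2),
        nn.Dense(4096, activation="relu"), nn.Dropout(0.5),
        # 10 output classes for Fashion-MNIST
        nn.Dense(10))
'''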
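The post does not say what input size the network receives. Fashion-MNIST images are 28×28, and working through the layer arithmetic shows why they have to be enlarged first. The sketch below is a hypothetical helper (`out_size`, `LAYERS` and `trace` are illustrative names, not part of MXNet) that applies the usual output-size formula floor((n + 2p − k) / s) + 1 to each Conv2D and MaxPool2D layer above, assuming square inputs and floor rounding (Gluon's default `ceil_mode=False`). The 224×224 input size is an assumption, presumably matching the resize used in the AlexNet chapter of Dive into Deep Learning, whose training log format this post's output resembles.

```python
# Hypothetical helper: spatial output size of a Conv2D or MaxPool2D layer,
# using floor rounding (Gluon's default ceil_mode=False for pooling).
def out_size(size, kernel, stride=1, padding=0):
    return (size + 2 * padding - kernel) // stride + 1

# (kind, kernel, stride, padding) for each spatial layer of the network above
LAYERS = [
    ("conv", 11, 4, 0),
    ("pool", 3, 2, 0),
    ("conv", 5, 1, 2),
    ("pool", 3, 2, 0),
    ("conv", 3, 1, 1),
    ("conv", 3, 1, 1),
    ("conv", 3, 1, 1),
    ("pool", 3, 2, 0),
]

def trace(size):
    """Return the feature-map side length after each layer, starting with the input."""
    sizes = [size]
    for _, k, s, p in LAYERS:
        size = out_size(size, k, s, p)
        if size < 1:
            raise ValueError(f"feature map collapsed to {size}x{size}")
        sizes.append(size)
    return sizes

sizes = trace(224)  # assumed input size
print(sizes)                  # [224, 54, 26, 26, 12, 12, 12, 12, 5]
print(256 * sizes[-1] ** 2)   # 6400 features into the first Dense layer
```

With a 224×224 input the last pooling layer outputs 5×5 maps, so the first `Dense(4096)` sees 256 × 5 × 5 = 6400 features. `trace(28)` raises `ValueError` because a raw 28×28 image shrinks to nothing at the second pooling layer.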

The training loss and accuracy per epoch are as follows:

'''
epoch 1, loss 1.3111, train acc 0.508, test acc 0.762, time 41.8 sec
epoch 2, loss 0.6442, train acc 0.760, test acc 0.816, time 39.1 sec
epoch 3, loss 0.5270, train acc 0.804, test acc 0.827, time 39.1 sec
epoch 4, loss 0.4626, train acc 0.830, test acc 0.861, time 39.4 sec
epoch 5, loss 0.4231, train acc 0.846, test acc 0.867, time 39.1 sec
epoch 6, loss 0.3947, train acc 0.857, test acc 0.873, time 39.9 sec
epoch 7, loss 0.3721, train acc 0.865, test acc 0.879, time 58.0 sec
epoch 8, loss 0.3548, train acc 0.871, test acc 0.883, time 39.1 sec
epoch 9, loss 0.3379, train acc 0.877, test acc 0.882, time 39.1 sec
epoch 10, loss 0.3271, train acc 0.881, test acc 0.890, time 39.0 sec
epoch 11, loss 0.3173, train acc 0.883, test acc 0.893, time 39.1 sec
epoch 12, loss 0.3069, train acc 0.887, test acc 0.893, time 39.5 sec
epoch 13, loss 0.2972, train acc 0.892, test acc 0.898, time 39.2 sec
epoch 14, loss 0.2891, train acc 0.894, test acc 0.903, time 39.0 sec
epoch 15, loss 0.2817, train acc 0.897, test acc 0.904, time 39.1 sec
epoch 16, loss 0.2762, train acc 0.898, test acc 0.903, time 39.1 sec
epoch 17, loss 0.2703, train acc 0.901, test acc 0.909, time 39.1 sec
epoch 18, loss 0.2634, train acc 0.904, test acc 0.907, time 39.1 sec
epoch 19, loss 0.2551, train acc 0.906, test acc 0.910, time 39.1 sec
epoch 20, loss 0.2488, train acc 0.908, test acc 0.909, time 39.5 sec
epoch 21, loss 0.2432, train acc 0.910, test acc 0.910, time 39.1 sec
epoch 22, loss 0.2391, train acc 0.912, test acc 0.912, time 39.3 sec
epoch 23, loss 0.2321, train acc 0.914, test acc 0.914, time 39.3 sec
epoch 24, loss 0.2274, train acc 0.916, test acc 0.912, time 39.1 sec
epoch 25, loss 0.2204, train acc 0.918, test acc 0.913, time 39.7 sec
epoch 26, loss 0.2164, train acc 0.920, test acc 0.917, time 39.1 sec
epoch 27, loss 0.2128, train acc 0.921, test acc 0.917, time 39.1 sec
epoch 28, loss 0.2085, train acc 0.923, test acc 0.918, time 39.2 sec
epoch 29, loss 0.2016, train acc 0.925, test acc 0.920, time 39.3 sec
epoch 30, loss 0.1969, train acc 0.927, test acc 0.917, time 39.2 sec
epoch 31, loss 0.1932, train acc 0.928, test acc 0.921, time 39.2 sec
epoch 32, loss 0.1891, train acc 0.930, test acc 0.916, time 39.1 sec
epoch 33, loss 0.1865, train acc 0.930, test acc 0.919, time 39.6 sec
epoch 34, loss 0.1801, train acc 0.932, test acc 0.917, time 39.3 sec
epoch 35, loss 0.1745, train acc 0.934, test acc 0.918, time 39.7 sec
epoch 36, loss 0.1709, train acc 0.936, test acc 0.919, time 39.3 sec
epoch 37, loss 0.1673, train acc 0.938, test acc 0.922, time 39.3 sec
epoch 38, loss 0.1639, train acc 0.939, test acc 0.922, time 39.6 sec
epoch 39, loss 0.1587, train acc 0.940, test acc 0.922, time 39.8 sec
epoch 40, loss 0.1554, train acc 0.941, test acc 0.921, time 39.5 sec
'''

The network with BN added is defined as follows:

'''
from mxnet.gluon import nn

# The same AlexNet-style network with BatchNorm inserted between each
# Conv2D/Dense layer and its ReLU, so the activation is no longer fused into the layer.
net = nn.Sequential()
net.add(nn.Conv2D(96, kernel_size=11, strides=4),
        nn.BatchNorm(),
        nn.Activation('relu'),
        nn.MaxPool2D(pool_size=3, strides=2),
        nn.Conv2D(256, kernel_size=5, padding=2),
        nn.BatchNorm(),
        nn.Activation('relu'),
        nn.MaxPool2D(pool_size=3, strides=2),
        nn.Conv2D(384, kernel_size=3, padding=1),
        nn.BatchNorm(),
        nn.Activation('relu'),
        nn.Conv2D(384, kernel_size=3, padding=1),
        nn.BatchNorm(),
        nn.Activation('relu'),
        nn.Conv2D(256, kernel_size=3, padding=1),
        nn.BatchNorm(),
        nn.Activation('relu'),
        nn.MaxPool2D(pool_size=3, strides=2),
        # fully connected layers: Dense -> BatchNorm -> ReLU -> Dropout
        nn.Dense(4096), nn.BatchNorm(), nn.Activation('relu'), nn.Dropout(0.2),
        # note: 2048 units here, vs. 4096 in the network without BN
        nn.Dense(2048), nn.BatchNorm(), nn.Activation('relu'), nn.Dropout(0.5),
        nn.Dense(10))
'''
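Each `nn.BatchNorm()` in this network normalizes its input per channel using the statistics of the current mini-batch, then applies a learned scale γ and shift β, as described in Ioffe & Szegedy (2015). For Conv2D outputs of shape (N, C, H, W) the statistics are taken over N, H and W; for Dense outputs of shape (N, C) over N only. The NumPy sketch below illustrates that forward pass. It is not MXNet's implementation: the function name `batch_norm` is hypothetical, and `momentum=0.9` and `eps=1e-5` are chosen to mirror Gluon's `nn.BatchNorm` defaults. During training it also updates running averages of the mean and variance, which replace the batch statistics at inference time.

```python
import numpy as np

def batch_norm(x, gamma, beta, running_mean, running_var,
               training=True, momentum=0.9, eps=1e-5):
    """Sketch of a batch-norm forward pass (illustrative, not MXNet's code).

    x is (N, C) for a Dense output or (N, C, H, W) for a Conv2D output;
    statistics are computed per channel C over all other axes.
    """
    axes = (0,) if x.ndim == 2 else (0, 2, 3)
    shape = (1, -1) if x.ndim == 2 else (1, -1, 1, 1)
    if training:
        # statistics of the current mini-batch
        mean = x.mean(axis=axes)
        var = x.var(axis=axes)
        # update running estimates in place; they are used at inference time
        running_mean[:] = momentum * running_mean + (1 - momentum) * mean
        running_var[:] = momentum * running_var + (1 - momentum) * var
    else:
        mean, var = running_mean, running_var
    # normalize, then scale and shift with the learned gamma and beta
    x_hat = (x - mean.reshape(shape)) / np.sqrt(var.reshape(shape) + eps)
    return gamma.reshape(shape) * x_hat + beta.reshape(shape)

# Usage: a fake conv feature map with 96 channels, far from zero mean / unit variance
rng = np.random.default_rng(0)
x = rng.normal(loc=3.0, scale=5.0, size=(32, 96, 6, 6))
gamma, beta = np.ones(96), np.zeros(96)
rm, rv = np.zeros(96), np.ones(96)
y = batch_norm(x, gamma, beta, rm, rv)
print(np.abs(y.mean(axis=(0, 2, 3))).max())  # close to 0
print(y.var(axis=(0, 2, 3)).mean())           # close to 1
```

This sketch uses the biased batch variance for the running estimate; the paper uses an unbiased estimate for inference, and frameworks differ in this detail.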

The training loss and accuracy per epoch are as follows:

'''
epoch 1, loss 0.5627, train acc 0.802, test acc 0.867, time 46.9 sec
epoch 2, loss 0.3653, train acc 0.869, test acc 0.898, time 52.3 sec
epoch 3, loss 0.3055, train acc 0.888, test acc 0.907, time 47.6 sec
epoch 4, loss 0.2716, train acc 0.902, test acc 0.914, time 45.8 sec
epoch 5, loss 0.2446, train acc 0.911, test acc 0.919, time 45.7 sec
epoch 6, loss 0.2234, train acc 0.919, test acc 0.912, time 45.6 sec
epoch 7, loss 0.2066, train acc 0.924, test acc 0.907, time 45.6 sec
epoch 8, loss 0.1908, train acc 0.930, test acc 0.923, time 45.9 sec
epoch 9, loss 0.1766, train acc 0.935, test acc 0.929, time 45.6 sec
epoch 10, loss 0.1623, train acc 0.941, test acc 0.930, time 45.5 sec
epoch 11, loss 0.1511, train acc 0.944, test acc 0.925, time 45.7 sec
epoch 12, loss 0.1388, train acc 0.949, test acc 0.929, time 46.2 sec
epoch 13, loss 0.1262, train acc 0.954, test acc 0.910, time 45.3 sec
epoch 14, loss 0.1203, train acc 0.955, test acc 0.930, time 45.7 sec
epoch 15, loss 0.1072, train acc 0.961, test acc 0.920, time 45.4 sec
epoch 16, loss 0.0997, train acc 0.963, test acc 0.932, time 45.6 sec
epoch 17, loss 0.0938, train acc 0.966, test acc 0.928, time 45.9 sec
epoch 18, loss 0.0833, train acc 0.970, test acc 0.930, time 45.7 sec
epoch 19, loss 0.0767, train acc 0.972, test acc 0.926, time 45.3 sec
epoch 20, loss 0.0707, train acc 0.975, test acc 0.935, time 45.5 sec
epoch 21, loss 0.0624, train acc 0.978, test acc 0.936, time 45.8 sec
epoch 22, loss 0.0591, train acc 0.979, test acc 0.934, time 45.5 sec
epoch 23, loss 0.0526, train acc 0.982, test acc 0.932, time 45.2 sec
epoch 24, loss 0.0481, train acc 0.983, test acc 0.930, time 45.6 sec
epoch 25, loss 0.0440, train acc 0.985, test acc 0.938, time 45.7 sec
epoch 26, loss 0.0393, train acc 0.987, test acc 0.933, time 46.0 sec
epoch 27, loss 0.0359, train acc 0.988, test acc 0.935, time 45.5 sec
epoch 28, loss 0.0341, train acc 0.988, test acc 0.936, time 45.9 sec
epoch 29, loss 0.0301, train acc 0.990, test acc 0.937, time 51.6 sec
epoch 30, loss 0.0266, train acc 0.992, test acc 0.937, time 63.5 sec
epoch 31, loss 0.0236, train acc 0.993, test acc 0.921, time 44.8 sec
epoch 32, loss 0.0201, train acc 0.994, test acc 0.938, time 45.2 sec
epoch 33, loss 0.0187, train acc 0.994, test acc 0.939, time 45.3 sec
epoch 34, loss 0.0169, train acc 0.995, test acc 0.939, time 44.4 sec
epoch 35, loss 0.0142, train acc 0.996, test acc 0.938, time 45.0 sec
epoch 36, loss 0.0126, train acc 0.997, test acc 0.937, time 61.0 sec
epoch 37, loss 0.0131, train acc 0.996, test acc 0.936, time 44.9 sec
epoch 38, loss 0.0127, train acc 0.996, test acc 0.937, time 45.0 sec
epoch 39, loss 0.0116, train acc 0.997, test acc 0.939, time 61.0 sec
epoch 40, loss 0.0105, train acc 0.997, test acc 0.937, time 45.2 sec

'''

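Comparing the two networks and their training logs, adding BN makes the network converge much faster: after the first epoch the training loss is 0.5627 with BN versus 1.3111 without, and test accuracy first exceeds 0.90 at epoch 3 with BN versus epoch 14 without. After 40 epochs the BN network also ends with higher test accuracy (0.937 vs. 0.921), although each epoch takes roughly 45 s instead of roughly 39 s, and its larger gap between training accuracy (0.997) and test accuracy suggests it overfits more.

Note that the BN network's second fully connected layer has 2048 units rather than 4096, so the two runs differ in more than BN alone.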
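To put numbers on "converges faster" rather than eyeballing the two logs, a small parser can extract each epoch's metrics and report the first epoch at which test accuracy crosses a threshold. This is a hypothetical helper (`parse_log` and `first_epoch_reaching` are illustrative names), assuming the exact log line format shown above; the 0.90 threshold is arbitrary. The excerpts in the usage example are copied from the logs above; pasting the full logs gives the same answers.

```python
import re

# Matches the log format used above, e.g.
# "epoch 3, loss 0.3055, train acc 0.888, test acc 0.907, time 47.6 sec"
LINE = re.compile(
    r"epoch (\d+), loss ([\d.]+), train acc ([\d.]+), "
    r"test acc ([\d.]+), time ([\d.]+) sec"
)

def parse_log(text):
    """Return one dict per epoch line found in text."""
    rows = []
    for m in LINE.finditer(text):
        epoch, loss, train_acc, test_acc, secs = m.groups()
        rows.append({"epoch": int(epoch), "loss": float(loss),
                     "train_acc": float(train_acc),
                     "test_acc": float(test_acc), "time": float(secs)})
    return rows

def first_epoch_reaching(rows, threshold):
    """First epoch whose test accuracy is >= threshold, or None if never reached."""
    for r in rows:
        if r["test_acc"] >= threshold:
            return r["epoch"]
    return None

# Usage with lines copied from the two logs above
no_bn = """
epoch 1, loss 1.3111, train acc 0.508, test acc 0.762, time 41.8 sec
epoch 13, loss 0.2972, train acc 0.892, test acc 0.898, time 39.2 sec
epoch 14, loss 0.2891, train acc 0.894, test acc 0.903, time 39.0 sec
"""
with_bn = """
epoch 1, loss 0.5627, train acc 0.802, test acc 0.867, time 46.9 sec
epoch 2, loss 0.3653, train acc 0.869, test acc 0.898, time 52.3 sec
epoch 3, loss 0.3055, train acc 0.888, test acc 0.907, time 47.6 sec
"""
print(first_epoch_reaching(parse_log(no_bn), 0.90))    # 14
print(first_epoch_reaching(parse_log(with_bn), 0.90))  # 3
```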