
[Keras for Beginners] Building a Keras model with two feature inputs and two outputs trained together

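This example uses Keras to build a model that takes two different kinds of input data and produces two outputs, and trains both outputs at once. The structure is as follows: the main input passes through an Embedding layer and an LSTM; the LSTM output feeds the auxiliary output directly, and it is also concatenated with the auxiliary input, passed through three Dense layers, and fed to the main output.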


The Python code is as follows:

from keras.layers import Input, Embedding, LSTM, Dense
from keras.models import Model
import numpy as np
import keras

# Build the model
# Main input: sequences of 100 integer IDs, embedded and then summarized by an LSTM
main_input = Input((100,), dtype='int32', name='main_input')

x = Embedding(output_dim=512, input_dim=10000, input_length=100)(main_input)
lstm_out = LSTM(32)(x)
aux_output = Dense(1, activation='sigmoid', name='aux_output')(lstm_out)

# Auxiliary input: 5 extra features per sample, concatenated with the LSTM output
aux_input = Input((5,), name='aux_input')
x = keras.layers.concatenate([lstm_out, aux_input])
x = Dense(64, activation='relu')(x)
x = Dense(64, activation='relu')(x)
x = Dense(64, activation='relu')(x)
main_output = Dense(1, activation='sigmoid', name='main_output')(x)

model = Model(inputs=[main_input, aux_input], outputs=[main_output, aux_output])
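# Alternative form: one loss for every output, weights given as a list in output order:
# model.compile(optimizer='rmsprop', loss='binary_crossentropy', loss_weights=[1., 0.2])
# Here losses and weights are keyed by output layer name; the total loss is
# 1.0 * main_output_loss + 0.3 * aux_output_loss.
model.compile(optimizer='rmsprop',
              loss={'main_output': 'binary_crossentropy', 'aux_output': 'binary_crossentropy'},
              loss_weights={'main_output': 1., 'aux_output': 0.3})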

# Random training data, for illustration only
samples_len = 300
main_data = np.random.randint(1, 10000, size=(samples_len, 100), dtype='int32')
aux_data = np.random.randint(0, 10, size=(samples_len, 5), dtype='int32')
# Both outputs are trained against the same random binary labels
main_labels = np.random.randint(0, 2, size=(samples_len, 1), dtype='int32')

# Train both outputs together; inputs and targets are keyed by layer name
model.fit(x={'main_input': main_data, 'aux_input': aux_data},
          y={'main_output': main_labels, 'aux_output': main_labels},
          batch_size=32, epochs=10, verbose=1)

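# Evaluate on the same data used for training;
# returns [total loss, main_output loss, aux_output loss]
score = model.evaluate(x={'main_input': main_data, 'aux_input': aux_data},
                       y={'main_output': main_labels, 'aux_output': main_labels},
                       batch_size=10, verbose=1)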
print(score)
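
Both outputs above are compiled with the `binary_crossentropy` loss. As a rough illustration of what that loss measures, here is a minimal plain-Python sketch. The helper name `binary_crossentropy`, the sample labels, and the clipping value `eps` are assumptions made for this example and are not taken from the Keras source. It shows that a model which always predicts 0.5 scores about ln 2 ≈ 0.693, which is close to the per-output losses seen at the start of training on random labels.

```python
import math

def binary_crossentropy(y_true, y_pred, eps=1e-7):
    """Mean binary cross-entropy over paired labels and predicted probabilities."""
    total = 0.0
    for y, p in zip(y_true, y_pred):
        # Clip to avoid log(0); the eps value is an assumption of this sketch.
        p = min(max(p, eps), 1.0 - eps)
        total += -(y * math.log(p) + (1 - y) * math.log(1.0 - p))
    return total / len(y_true)

labels = [0, 1, 1, 0]  # hypothetical binary labels

# An uninformed model that always outputs 0.5 scores ln(2), whatever the labels are.
chance_loss = binary_crossentropy(labels, [0.5] * len(labels))

# Confident, correct predictions give a much smaller loss.
good_loss = binary_crossentropy(labels, [0.01, 0.99, 0.99, 0.01])

print(round(chance_loss, 4))
print(good_loss < chance_loss)
```

A loss that falls far below 0.693 on random labels, as in the run shown here, suggests the network is memorizing the 300 training samples rather than learning anything that generalizes.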

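Because the data is randomly generated, the numbers themselves carry no real meaning. A run produces the following output: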

Using TensorFlow backend.
Epoch 1/10

 32/300 [==>...........................] - ETA: 1:40 - loss: 1.0214 - main_output_loss: 0.8143 - aux_output_loss: 0.6903
 64/300 [=====>........................] - ETA: 45s - loss: 0.9316 - main_output_loss: 0.7241 - aux_output_loss: 0.6917
 96/300 [========>.....................] - ETA: 26s - loss: 0.9477 - main_output_loss: 0.7403 - aux_output_loss: 0.6914
128/300 [===========>..................] - ETA: 16s - loss: 0.9376 - main_output_loss: 0.7299 - aux_output_loss: 0.6925
160/300 [===============>..............] - ETA: 11s - loss: 0.9423 - main_output_loss: 0.7346 - aux_output_loss: 0.6925
192/300 [==================>...........] - ETA: 7s - loss: 0.9327 - main_output_loss: 0.7249 - aux_output_loss: 0.6927
224/300 [=====================>........] - ETA: 4s - loss: 0.9457 - main_output_loss: 0.7378 - aux_output_loss: 0.6933
256/300 [========================>.....] - ETA: 2s - loss: 0.9622 - main_output_loss: 0.7544 - aux_output_loss: 0.6928
288/300 [===========================>..] - ETA: 0s - loss: 0.9529 - main_output_loss: 0.7450 - aux_output_loss: 0.6930
300/300 [==============================] - 14s 46ms/step - loss: 0.9525 - main_output_loss: 0.7445 - aux_output_loss: 0.6932
Epoch 2/10

 32/300 [==>...........................] - ETA: 1s - loss: 0.9614 - main_output_loss: 0.7661 - aux_output_loss: 0.6513
 64/300 [=====>........................] - ETA: 1s - loss: 0.9061 - main_output_loss: 0.7119 - aux_output_loss: 0.6476
 96/300 [========>.....................] - ETA: 1s - loss: 0.8951 - main_output_loss: 0.7019 - aux_output_loss: 0.6439
128/300 [===========>..................] - ETA: 1s - loss: 0.8800 - main_output_loss: 0.6887 - aux_output_loss: 0.6376
160/300 [===============>..............] - ETA: 0s - loss: 0.8610 - main_output_loss: 0.6718 - aux_output_loss: 0.6308
192/300 [==================>...........] - ETA: 0s - loss: 0.8584 - main_output_loss: 0.6709 - aux_output_loss: 0.6250
224/300 [=====================>........] - ETA: 0s - loss: 0.8557 - main_output_loss: 0.6702 - aux_output_loss: 0.6186
256/300 [========================>.....] - ETA: 0s - loss: 0.8485 - main_output_loss: 0.6654 - aux_output_loss: 0.6102
288/300 [===========================>..] - ETA: 0s - loss: 0.8476 - main_output_loss: 0.6675 - aux_output_loss: 0.6001
300/300 [==============================] - 2s 6ms/step - loss: 0.8350 - main_output_loss: 0.6580 - aux_output_loss: 0.5900
Epoch 3/10

 32/300 [==>...........................] - ETA: 1s - loss: 0.7031 - main_output_loss: 0.5727 - aux_output_loss: 0.4347
 64/300 [=====>........................] - ETA: 1s - loss: 0.5957 - main_output_loss: 0.4866 - aux_output_loss: 0.3637
 96/300 [========>.....................] - ETA: 1s - loss: 0.5402 - main_output_loss: 0.4414 - aux_output_loss: 0.3293
128/300 [===========>..................] - ETA: 1s - loss: 0.5114 - main_output_loss: 0.4171 - aux_output_loss: 0.3146
160/300 [===============>..............] - ETA: 0s - loss: 0.4818 - main_output_loss: 0.3899 - aux_output_loss: 0.3062
192/300 [==================>...........] - ETA: 0s - loss: 0.4568 - main_output_loss: 0.3650 - aux_output_loss: 0.3059
224/300 [=====================>........] - ETA: 0s - loss: 0.4333 - main_output_loss: 0.3431 - aux_output_loss: 0.3009
256/300 [========================>.....] - ETA: 0s - loss: 0.4143 - main_output_loss: 0.3254 - aux_output_loss: 0.2965
288/300 [===========================>..] - ETA: 0s - loss: 0.3926 - main_output_loss: 0.3060 - aux_output_loss: 0.2888
300/300 [==============================] - 2s 6ms/step - loss: 0.3831 - main_output_loss: 0.2979 - aux_output_loss: 0.2840
Epoch 4/10

 32/300 [==>...........................] - ETA: 1s - loss: 0.1192 - main_output_loss: 0.0711 - aux_output_loss: 0.1603
 64/300 [=====>........................] - ETA: 1s - loss: 0.1319 - main_output_loss: 0.0819 - aux_output_loss: 0.1668
 96/300 [========>.....................] - ETA: 1s - loss: 0.1239 - main_output_loss: 0.0735 - aux_output_loss: 0.1681
128/300 [===========>..................] - ETA: 0s - loss: 0.1151 - main_output_loss: 0.0650 - aux_output_loss: 0.1669
160/300 [===============>..............] - ETA: 0s - loss: 0.1078 - main_output_loss: 0.0593 - aux_output_loss: 0.1617
192/300 [==================>...........] - ETA: 0s - loss: 0.1043 - main_output_loss: 0.0549 - aux_output_loss: 0.1647
224/300 [=====================>........] - ETA: 0s - loss: 0.0992 - main_output_loss: 0.0513 - aux_output_loss: 0.1595
256/300 [========================>.....] - ETA: 0s - loss: 0.0966 - main_output_loss: 0.0492 - aux_output_loss: 0.1580
288/300 [===========================>..] - ETA: 0s - loss: 0.0929 - main_output_loss: 0.0466 - aux_output_loss: 0.1544
300/300 [==============================] - 2s 6ms/step - loss: 0.0920 - main_output_loss: 0.0456 - aux_output_loss: 0.1545
Epoch 5/10

 32/300 [==>...........................] - ETA: 1s - loss: 0.0591 - main_output_loss: 0.0228 - aux_output_loss: 0.1210
 64/300 [=====>........................] - ETA: 1s - loss: 0.0583 - main_output_loss: 0.0184 - aux_output_loss: 0.1332
 96/300 [========>.....................] - ETA: 1s - loss: 0.0525 - main_output_loss: 0.0164 - aux_output_loss: 0.1203
128/300 [===========>..................] - ETA: 1s - loss: 0.0519 - main_output_loss: 0.0156 - aux_output_loss: 0.1213
160/300 [===============>..............] - ETA: 0s - loss: 0.0502 - main_output_loss: 0.0141 - aux_output_loss: 0.1202
192/300 [==================>...........] - ETA: 0s - loss: 0.0486 - main_output_loss: 0.0132 - aux_output_loss: 0.1180
224/300 [=====================>........] - ETA: 0s - loss: 0.0477 - main_output_loss: 0.0127 - aux_output_loss: 0.1166
256/300 [========================>.....] - ETA: 0s - loss: 0.0462 - main_output_loss: 0.0117 - aux_output_loss: 0.1148
288/300 [===========================>..] - ETA: 0s - loss: 0.0454 - main_output_loss: 0.0111 - aux_output_loss: 0.1141
300/300 [==============================] - 2s 6ms/step - loss: 0.0453 - main_output_loss: 0.0109 - aux_output_loss: 0.1147
Epoch 6/10

 32/300 [==>...........................] - ETA: 1s - loss: 0.0330 - main_output_loss: 0.0043 - aux_output_loss: 0.0954
 64/300 [=====>........................] - ETA: 1s - loss: 0.0368 - main_output_loss: 0.0047 - aux_output_loss: 0.1071
 96/300 [========>.....................] - ETA: 1s - loss: 0.0350 - main_output_loss: 0.0048 - aux_output_loss: 0.1008
128/300 [===========>..................] - ETA: 0s - loss: 0.0350 - main_output_loss: 0.0047 - aux_output_loss: 0.1010
160/300 [===============>..............] - ETA: 0s - loss: 0.0324 - main_output_loss: 0.0045 - aux_output_loss: 0.0930
192/300 [==================>...........] - ETA: 0s - loss: 0.0320 - main_output_loss: 0.0044 - aux_output_loss: 0.0921
224/300 [=====================>........] - ETA: 0s - loss: 0.0305 - main_output_loss: 0.0042 - aux_output_loss: 0.0877
256/300 [========================>.....] - ETA: 0s - loss: 0.0303 - main_output_loss: 0.0040 - aux_output_loss: 0.0876
288/300 [===========================>..] - ETA: 0s - loss: 0.0297 - main_output_loss: 0.0037 - aux_output_loss: 0.0867
300/300 [==============================] - 2s 6ms/step - loss: 0.0295 - main_output_loss: 0.0037 - aux_output_loss: 0.0860
Epoch 7/10

 32/300 [==>...........................] - ETA: 1s - loss: 0.0237 - main_output_loss: 0.0022 - aux_output_loss: 0.0718
 64/300 [=====>........................] - ETA: 1s - loss: 0.0226 - main_output_loss: 0.0020 - aux_output_loss: 0.0687
 96/300 [========>.....................] - ETA: 1s - loss: 0.0223 - main_output_loss: 0.0020 - aux_output_loss: 0.0676
128/300 [===========>..................] - ETA: 0s - loss: 0.0221 - main_output_loss: 0.0017 - aux_output_loss: 0.0677
160/300 [===============>..............] - ETA: 0s - loss: 0.0220 - main_output_loss: 0.0017 - aux_output_loss: 0.0677
192/300 [==================>...........] - ETA: 0s - loss: 0.0216 - main_output_loss: 0.0016 - aux_output_loss: 0.0666
224/300 [=====================>........] - ETA: 0s - loss: 0.0216 - main_output_loss: 0.0016 - aux_output_loss: 0.0667
256/300 [========================>.....] - ETA: 0s - loss: 0.0211 - main_output_loss: 0.0015 - aux_output_loss: 0.0651
288/300 [===========================>..] - ETA: 0s - loss: 0.0208 - main_output_loss: 0.0015 - aux_output_loss: 0.0643
300/300 [==============================] - 2s 6ms/step - loss: 0.0207 - main_output_loss: 0.0015 - aux_output_loss: 0.0640
Epoch 8/10

 32/300 [==>...........................] - ETA: 1s - loss: 0.0160 - main_output_loss: 9.5651e-04 - aux_output_loss: 0.0501
 64/300 [=====>........................] - ETA: 1s - loss: 0.0167 - main_output_loss: 8.6357e-04 - aux_output_loss: 0.0528
 96/300 [========>.....................] - ETA: 1s - loss: 0.0173 - main_output_loss: 8.7130e-04 - aux_output_loss: 0.0549
128/300 [===========>..................] - ETA: 0s - loss: 0.0159 - main_output_loss: 8.0989e-04 - aux_output_loss: 0.0502
160/300 [===============>..............] - ETA: 0s - loss: 0.0153 - main_output_loss: 7.4170e-04 - aux_output_loss: 0.0485
192/300 [==================>...........] - ETA: 0s - loss: 0.0150 - main_output_loss: 7.3019e-04 - aux_output_loss: 0.0476
224/300 [=====================>........] - ETA: 0s - loss: 0.0154 - main_output_loss: 7.0571e-04 - aux_output_loss: 0.0490
256/300 [========================>.....] - ETA: 0s - loss: 0.0154 - main_output_loss: 6.6068e-04 - aux_output_loss: 0.0491
288/300 [===========================>..] - ETA: 0s - loss: 0.0150 - main_output_loss: 6.4659e-04 - aux_output_loss: 0.0477
300/300 [==============================] - 2s 6ms/step - loss: 0.0148 - main_output_loss: 6.4725e-04 - aux_output_loss: 0.0471
Epoch 9/10

 32/300 [==>...........................] - ETA: 1s - loss: 0.0121 - main_output_loss: 2.1562e-04 - aux_output_loss: 0.0395
 64/300 [=====>........................] - ETA: 1s - loss: 0.0122 - main_output_loss: 3.7267e-04 - aux_output_loss: 0.0395
 96/300 [========>.....................] - ETA: 1s - loss: 0.0113 - main_output_loss: 3.3215e-04 - aux_output_loss: 0.0365
128/300 [===========>..................] - ETA: 1s - loss: 0.0113 - main_output_loss: 3.5326e-04 - aux_output_loss: 0.0365
160/300 [===============>..............] - ETA: 0s - loss: 0.0112 - main_output_loss: 3.2605e-04 - aux_output_loss: 0.0364
192/300 [==================>...........] - ETA: 0s - loss: 0.0110 - main_output_loss: 3.1679e-04 - aux_output_loss: 0.0357
224/300 [=====================>........] - ETA: 0s - loss: 0.0110 - main_output_loss: 3.0916e-04 - aux_output_loss: 0.0357
256/300 [========================>.....] - ETA: 0s - loss: 0.0108 - main_output_loss: 2.9004e-04 - aux_output_loss: 0.0350
288/300 [===========================>..] - ETA: 0s - loss: 0.0107 - main_output_loss: 2.7632e-04 - aux_output_loss: 0.0346
300/300 [==============================] - 2s 6ms/step - loss: 0.0106 - main_output_loss: 2.7436e-04 - aux_output_loss: 0.0344
Epoch 10/10

 32/300 [==>...........................] - ETA: 1s - loss: 0.0080 - main_output_loss: 1.8609e-04 - aux_output_loss: 0.0262
 64/300 [=====>........................] - ETA: 1s - loss: 0.0077 - main_output_loss: 1.5369e-04 - aux_output_loss: 0.0250
 96/300 [========>.....................] - ETA: 1s - loss: 0.1011 - main_output_loss: 0.0819 - aux_output_loss: 0.0641   
128/300 [===========>..................] - ETA: 1s - loss: 0.0785 - main_output_loss: 0.0619 - aux_output_loss: 0.0552
160/300 [===============>..............] - ETA: 0s - loss: 0.0645 - main_output_loss: 0.0496 - aux_output_loss: 0.0497
192/300 [==================>...........] - ETA: 0s - loss: 0.0553 - main_output_loss: 0.0414 - aux_output_loss: 0.0461
224/300 [=====================>........] - ETA: 0s - loss: 0.0487 - main_output_loss: 0.0356 - aux_output_loss: 0.0436
256/300 [========================>.....] - ETA: 0s - loss: 0.0436 - main_output_loss: 0.0312 - aux_output_loss: 0.0412
288/300 [===========================>..] - ETA: 0s - loss: 0.0396 - main_output_loss: 0.0278 - aux_output_loss: 0.0394
300/300 [==============================] - 2s 6ms/step - loss: 0.0384 - main_output_loss: 0.0267 - aux_output_loss: 0.0387

 10/300 [>.............................] - ETA: 1s
 40/300 [===>..........................] - ETA: 0s
 70/300 [======>.......................] - ETA: 0s
110/300 [==========>...................] - ETA: 0s
140/300 [=============>................] - ETA: 0s
180/300 [=================>............] - ETA: 0s
220/300 [=====================>........] - ETA: 0s
260/300 [=========================>....] - ETA: 0s
300/300 [==============================] - 1s 2ms/step
[0.0073259123601019382, 0.00040480913982416191, 0.023070343149205048]
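
The three numbers printed by `evaluate` appear to be the total loss followed by the loss of each output, in the order the outputs were passed to `Model`. A quick way to check this reading is to recompute the total from the two per-output losses and the `loss_weights` given to `compile`. The sketch below uses only the values printed above; the tolerance of `1e-8` is an arbitrary choice for this example.

```python
# Values copied from the evaluate() output above.
score = [0.0073259123601019382, 0.00040480913982416191, 0.023070343149205048]
total_loss, main_output_loss, aux_output_loss = score

# loss_weights as passed to model.compile() in the code above.
loss_weights = {'main_output': 1.0, 'aux_output': 0.3}

# Weighted sum of the per-output losses.
recomputed = (loss_weights['main_output'] * main_output_loss
              + loss_weights['aux_output'] * aux_output_loss)

# The tiny remaining gap is presumably due to single-precision
# arithmetic in the backend (an assumption, not verified here).
print(abs(recomputed - total_loss) < 1e-8)
```

The same relation holds for the training log: at the end of epoch 1, 0.7445 + 0.3 × 0.6932 ≈ 0.9525, which matches the reported `loss`.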