RNN入門(4)利用LSTM實現整數加法運算
本文將介紹LSTM模型在實現整數加法方面的應用。
我們以0-255之間的整數加法為例,生成的結果在0到510之間。為了能利用深度學習模型模擬整數的加法運算,我們需要將輸入的兩個加數和輸出的結果用二進位制表示,這樣就能得到向量,如加數在0-255內,可以用8位0-1向量來表示,前面的空位用0填充;結果在0-510內,可以用9位0-1向量來表示,前面的空位用0填充。因為兩個加數均在0-255內變化,所以共有256*256=65536個輸入向量以及65536個輸出向量,輸入向量為兩個加數的二進位制向量的拼接結果,因而是個16為的輸入向量。用以下的Python程式碼可以模擬以上過程:
import numpy as np
# 最多8位二進位制
BINARY_DIM = 8
# 將整數表示成為binary_dim位的二進位制數,高位用0補齊
def int_2_binary(number, binary_dim):
binary_list = list(map(lambda x: int(x), bin(number)[2:]))
number_dim = len(binary_list)
result_list = [0]*(binary_dim-number_dim)+binary_list
return result_list
# 將一個二進位制陣列轉為整數
def binary2int(binary_array):
out = 0
for index, x in enumerate(reversed(binary_array)):
out += x * pow(2, index)
return out
# 將[0,2**BINARY_DIM)所有數表示成二進位制
binary = np.array([int_2_binary(x, BINARY_DIM) for x in range(2**BINARY_DIM)])
# print(binary)
# 樣本的輸入向量和輸出向量
dataX = []
dataY = []
for i in range(binary.shape[0]):
for j in range(binary.shape[0]):
dataX.append(np.append(binary[i], binary[j]))
dataY.append(int_2_binary(i+j, BINARY_DIM+1))
# print(dataX)
# print(dataY)
# 重新特徵X和目標變數Y陣列,適應LSTM模型的輸入和輸出
X = np.reshape(dataX, (len(dataX), 2*BINARY_DIM, 1))
# print(X.shape)
Y = np.array(dataY)
# print(dataY.shape)
在以上程式碼中,得到的dataX和dataY以滿足要求,但為了能讓LSTM模型處理,需要改變這兩個資料集的形狀。
我們採用LSTM模型來訓練上述資料,LSTM模型的結構很簡單,就是簡單的一層LSTM層,然後加上Dropout層,最後是全連線層,啟用函式採用sigmoid函式,採用的損失函式為平均平方誤差。整個結構的示意圖如下:
模型訓練的程式碼如下:
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import Dropout
from keras.layers import LSTM
from keras import losses
from keras.utils import plot_model
# 定義LSTM模型
model = Sequential()
model.add(LSTM(256, input_shape=(X.shape[1], X.shape[2])))
model.add(Dropout(0.2))
model.add(Dense(Y.shape[1], activation='sigmoid'))
model.compile(loss=losses.mean_squared_error, optimizer='adam')
# print(model.summary())
# plot model
plot_model(model, to_file=r'./model.png', show_shapes=True)
# train model
epochs = 100
model.fit(X, Y, epochs=epochs, batch_size=128)
# save model
mp = r'./LSTM_Operation.h5'
model.save(mp)
該LSTM模型每批訓練128個樣本,共訓練100次,採用Adam優化器減少損失值。
對這個模型進行訓練,訓練100次,損失值為0.0045。接下來我們就要用這個訓練好的模型來預測。我們預測的方法為,雖然挑兩個在0-255內的加數,轉化為二進位制向量作為輸入向量,然後由LSTM模型輸出結果,將該結果取整作為輸出向量中的元素,最後將這個輸出向量轉化為整數,就是預測的兩個加數的和。模型預測的程式碼如下:
# use LSTM model to predict
for _ in range(100):
start = np.random.randint(0, len(dataX)-1)
# print(dataX[start])
number1 = dataX[start][0:BINARY_DIM]
number2 = dataX[start][BINARY_DIM:]
print('='*30)
print('%s: %s'%(number1, binary2int(number1)))
print('%s: %s'%(number2, binary2int(number2)))
sample = np.reshape(X[start], (1, 2*BINARY_DIM, 1))
predict = np.round(model.predict(sample), 0).astype(np.int32)[0]
print('%s: %s'%(predict, binary2int(predict)))
預測的100組樣本的輸出結果如下:
==============================
[1 0 0 1 1 1 0 1]: 157
[0 1 1 1 0 0 0 1]: 113
[1 0 0 0 0 1 1 1 0]: 270
==============================
[1 1 1 0 1 0 1 0]: 234
[0 1 0 0 1 1 0 0]: 76
[1 0 0 1 1 0 1 1 0]: 310
==============================
[1 1 0 0 0 1 0 0]: 196
[1 1 0 1 1 0 1 1]: 219
[1 1 0 0 1 1 1 1 1]: 415
==============================
[0 0 1 1 1 0 1 0]: 58
[0 0 1 0 0 0 1 1]: 35
[0 0 1 0 1 1 1 0 1]: 93
==============================
[1 0 0 0 0 0 0 0]: 128
[0 1 1 1 1 0 0 1]: 121
[0 1 1 1 1 1 0 0 1]: 249
==============================
[1 1 1 1 0 1 1 0]: 246
[1 1 0 1 0 1 0 1]: 213
[1 1 1 0 0 1 0 1 1]: 459
==============================
[1 1 1 0 0 1 1 0]: 230
[1 0 0 0 0 0 0 0]: 128
[1 0 1 1 0 0 1 1 0]: 358
==============================
[1 0 1 0 0 0 1 1]: 163
[0 1 1 0 0 1 0 1]: 101
[1 0 0 0 0 1 0 0 0]: 264
==============================
[1 0 1 0 0 1 1 0]: 166
[0 1 0 1 0 0 0 0]: 80
[0 1 1 1 1 0 1 1 0]: 246
==============================
[0 0 0 0 1 0 1 1]: 11
[0 1 0 0 0 1 0 1]: 69
[0 0 1 0 1 0 0 0 0]: 80
==============================
[1 1 1 1 0 1 1 1]: 247
[0 1 1 1 0 0 0 0]: 112
[1 0 1 1 0 0 1 1 1]: 359
==============================
[1 0 1 0 1 0 0 1]: 169
[1 1 0 0 0 0 0 0]: 192
[1 0 1 1 0 1 0 0 1]: 361
==============================
[1 0 1 1 0 0 0 1]: 177
[1 0 0 0 1 0 1 1]: 139
[1 0 0 1 1 1 1 0 0]: 316
==============================
[0 1 0 0 0 1 1 0]: 70
[0 0 1 0 1 1 1 0]: 46
[0 0 1 1 1 0 1 0 0]: 116
==============================
[1 0 0 1 1 0 1 1]: 155
[1 1 0 0 0 0 0 1]: 193
[1 0 1 0 1 1 1 0 0]: 348
==============================
[1 0 1 1 0 0 1 0]: 178
[1 0 0 0 1 1 1 1]: 143
[1 0 1 0 0 0 0 0 1]: 321
==============================
[0 1 0 1 1 1 1 1]: 95
[1 1 1 0 0 1 0 0]: 228
[1 0 1 0 0 0 0 1 1]: 323
==============================
[1 0 0 1 1 1 1 0]: 158
[0 0 0 1 1 0 0 1]: 25
[0 1 0 1 1 0 1 1 1]: 183
==============================
[1 1 1 0 1 0 1 1]: 235
[1 1 0 0 0 0 0 1]: 193
[1 1 0 1 0 1 1 0 0]: 428
==============================
[0 1 0 1 1 1 0 1]: 93
[0 1 1 1 0 1 1 0]: 118
[0 1 1 0 1 0 0 1 1]: 211
==============================
[1 1 1 1 1 1 1 1]: 255
[1 1 1 1 1 1 1 0]: 254
[1 1 1 1 1 1 1 0 1]: 509
==============================
[0 1 0 1 1 0 0 1]: 89
[0 1 0 1 1 1 1 0]: 94
[0 1 0 1 1 0 1 1 1]: 183
==============================
[0 1 1 1 0 0 0 0]: 112
[0 0 1 1 0 1 0 0]: 52
[0 1 0 1 0 0 1 0 0]: 164
==============================
[1 0 0 0 0 0 0 0]: 128
[1 1 0 1 1 0 1 0]: 218
[1 0 1 0 1 1 0 1 0]: 346
==============================
[0 0 1 1 0 1 0 1]: 53
[1 0 1 1 1 1 1 0]: 190
[0 1 1 1 1 0 0 1 1]: 243
==============================
[0 1 1 1 1 0 0 0]: 120
[1 1 0 1 0 1 0 1]: 213
[1 0 1 0 0 1 1 0 1]: 333
==============================
[0 1 1 1 1 0 1 1]: 123
[1 1 1 0 1 1 0 1]: 237
[1 0 1 1 0 1 0 0 0]: 360
==============================
[1 0 0 1 1 0 1 0]: 154
[0 1 1 0 1 0 0 1]: 105
[1 0 0 0 0 0 0 1 1]: 259
==============================
[0 0 0 1 1 0 0 1]: 25
[0 1 0 1 1 0 1 0]: 90
[0 0 1 1 1 0 0 1 1]: 115
==============================
[1 1 1 1 0 0 0 1]: 241
[0 0 0 1 1 1 1 1]: 31
[1 0 0 0 1 0 0 0 0]: 272
==============================
[0 1 0 0 0 1 1 0]: 70
[1 1 1 0 1 0 0 1]: 233
[1 0 0 1 0 1 1 1 1]: 303
==============================
[1 0 1 0 1 1 0 1]: 173
[0 1 1 1 0 1 0 0]: 116
[1 0 0 1 0 0 0 0 1]: 289
==============================
[0 1 0 0 1 0 0 0]: 72
[1 1 1 1 1 0 1 0]: 250
[1 0 1 0 0 0 0 1 0]: 322
==============================
[1 1 1 1 0 0 0 0]: 240
[0 1 0 0 0 0 1 0]: 66
[1 0 0 1 1 0 0 1 0]: 306
==============================
[0 1 0 0 0 1 1 1]: 71
[1 0 0 1 0 1 1 0]: 150
[0 1 1 0 1 1 1 0 1]: 221
==============================
[0 1 1 0 1 1 0 1]: 109
[0 0 1 0 0 1 0 1]: 37
[0 1 0 0 1 0 0 1 0]: 146
==============================
[1 1 0 0 0 0 0 0]: 192
[1 1 1 0 0 0 0 1]: 225
[1 1 0 1 0 0 0 0 1]: 417
==============================
[1 0 0 0 0 0 1 1]: 131
[1 1 0 1 1 1 1 0]: 222
[1 0 1 1 0 0 0 0 1]: 353
==============================
[0 0 0 0 0 1 0 0]: 4
[1 1 1 0 0 0 1 0]: 226
[0 1 1 1 0 0 1 1 0]: 230
==============================
[1 1 1 0 1 1 1 1]: 239
[1 1 0 1 1 0 1 1]: 219
[1 1 1 0 0 1 0 1 0]: 458
==============================
[0 0 1 1 0 1 0 1]: 53
[1 1 1 1 0 0 1 0]: 242
[1 0 0 1 0 0 1 1 1]: 295
==============================
[1 0 0 1 0 0 0 1]: 145
[0 1 0 0 0 1 0 0]: 68
[0 1 1 0 1 0 1 0 1]: 213
==============================
[0 0 1 1 0 0 0 0]: 48
[1 0 1 1 0 1 1 1]: 183
[0 1 1 1 0 0 1 1 1]: 231
==============================
[0 1 1 0 0 1 1 1]: 103
[0 0 0 1 1 1 1 0]: 30
[0 1 0 0 0 0 1 0 1]: 133
==============================
[0 1 0 1 1 1 0 1]: 93
[1 1 0 1 0 0 1 0]: 210
[1 0 0 1 0 1 1 1 1]: 303
==============================
[1 0 0 0 1 0 1 0]: 138
[0 1 1 1 1 0 0 1]: 121
[1 0 0 0 0 0 0 1 1]: 259
==============================
[0 0 0 0 0 0 1 1]: 3
[0 0 1 1 0 0 0 1]: 49
[0 0 0 1 1 0 1 0 0]: 52
==============================
[1 0 0 0 0 0 1 0]: 130
[0 0 0 1 0 0 0 0]: 16
[0 1 0 0 1 0 0 1 0]: 146
==============================
[0 0 0 1 0 0 0 0]: 16
[1 0 0 1 0 0 1 0]: 146
[0 1 0 1 0 0 0 1 0]: 162
==============================
[0 1 0 1 0 1 0 0]: 84
[0 0 0 0 1 1 0 0]: 12
[0 0 1 1 0 0 0 0 0]: 96
==============================
[1 0 1 0 1 0 1 1]: 171
[1 1 0 1 1 0 1 1]: 219
[1 1 0 0 0 0 1 1 0]: 390
==============================
[1 1 1 1 1 1 1 0]: 254
[0 1 1 0 1 0 1 0]: 106
[1 0 1 1 0 1 0 0 0]: 360
==============================
[1 0 0 0 0 0 1 0]: 130
[0 0 0 0 1 1 1 0]: 14
[0 1 0 0 1 0 0 0 0]: 144
==============================
[1 0 1 0 0 1 0 1]: 165
[0 0 1 1 1 0 1 1]: 59
[0 1 1 1 0 0 0 0 0]: 224
==============================
[0 0 1 1 1 0 1 0]: 58
[1 1 1 1 0 0 1 0]: 242
[1 0 0 1 0 1 1 0 0]: 300
==============================
[0 1 0 0 1 1 0 1]: 77
[0 0 0 1 1 1 1 1]: 31
[0 0 1 1 0 1 1 0 0]: 108
==============================
[1 0 0 1 1 0 1 0]: 154
[0 1 0 1 0 1 0 1]: 85
[0 1 1 1 0 1 1 1 1]: 239
==============================
[0 1 1 0 1 1 0 1]: 109
[0 1 1 0 1 0 0 1]: 105
[0 1 1 0 1 0 1 1 0]: 214
==============================
[0 1 1 1 1 1 1 1]: 127
[0 1 1 1 0 0 1 0]: 114
[0 1 1 1 1 0 0 0 1]: 241
==============================
[0 1 1 0 0 1 0 1]: 101
[0 1 0 1 0 0 0 0]: 80
[0 1 0 1 1 0 1 0 1]: 181
==============================
[0 1 1 0 1 1 1 0]: 110
[0 1 0 1 0 1 1 0]: 86
[0 1 1 0 0 0 1 0 0]: 196
==============================
[0 0 0 1 0 0 1 1]: 19
[1 0 0 1 0 0 0 0]: 144
[0 1 0 1 0 0 0 1 1]: 163
==============================
[1 1 1 1 0 1 0 0]: 244
[1 1 0 1 0 0 1 1]: 211
[1 1 1 0 0 0 1 1 1]: 455
==============================
[0 0 0 0 1 1 1 0]: 14
[1 0 1 1 0 0 1 0]: 178
[0 1 1 0 0 0 0 0 0]: 192
==============================
[0 1 1 0 0 0 0 0]: 96
[1 0 0 1 1 1 0 0]: 156
[0 1 1 1 1 1 1 0 0]: 252
==============================
[0 0 1 1 0 1 0 0]: 52
[0 1 1 1 1 1 0 1]: 125
[0 1 0 1 1 0 0 0 1]: 177
==============================
[0 0 0 0 1 1 0 0]: 12
[0 1 0 1 1 1 0 1]: 93
[0 0 1 1 0 1 0 0 1]: 105
==============================
[0 1 1 0 0 1 0 1]: 101
[1 1 0 1 0 1 0 0]: 212
[1 0 0 1 1 1 0 0 1]: 313
==============================
[1 1 0 0 0 0 0 1]: 193
[1 1 0 0 1 1 0 1]: 205
[1 1 0 0 0 1 1 1 0]: 398
==============================
[0 1 1 1 0 0 1 0]: 114
[0 0 0 0 0 0 0 0]: 0
[0 0 1 1 1 0 0 1 0]: 114
==============================
[1 0 0 0 1 1 1 0]: 142
[1 0 1 1 1 1 0 1]: 189
[1 0 1 0 0 1 0 1 1]: 331
==============================
[1 0 1 1 0 1 1 1]: 183
[0 1 0 1 0 1 1 0]: 86
[1 0 0 0 0 1 1 0 1]: 269
==============================
[1 0 1 0 0 0 1 1]: 163
[1 1 1 0 0 1 0 1]: 229
[1 1 0 0 0 1 0 0 0]: 392
==============================
[0 0 1 1 0 0 0 1]: 49
[1 1 1 0 0 1 1 1]: 231
[1 0 0 0 1 1 0 0 0]: 280
==============================
[1 0 0 0 1 1 1 1]: 143
[1 0 1 0 1 0 0 0]: 168
[1 0 0 1 1 0 1 1 1]: 311
==============================
[0 1 0 0 0 0 0 0]: 64
[0 0 0 0 0 1 0 1]: 5
[0 0 1 0 0 0 1 0 1]: 69
==============================
[1 1 1 1 1 0 1 1]: 251
[1 0 1 1 1 0 0 1]: 185
[1 1 0 1 1 0 1 0 0]: 436
==============================
[1 1 1 0 1 1 1 0]: 238
[1 1 0 0 0 0 1 0]: 194
[1 1 0 1 1 0 0 0 0]: 432
==============================
[0 0 1 1 1 1 0 0]: 60
[0 0 0 1 0 1 1 1]: 23
[0 0 1 0 1 0 0 1 1]: 83
==============================
[0 1 1 1 0 1 0 0]: 116
[1 1 1 1 1 1 0 0]: 252
[1 0 1 1 1 0 0 0 0]: 368
==============================
[1 1 0 1 0 1 1 0]: 214
[1 1 1 1 0 1 0 0]: 244
[1 1 1 0 0 1 0 1 0]: 458
==============================
[1 1 1 1 1 1 1 0]: 254
[1 1 0 1 0 0 0 1]: 209
[1 1 1 0 0 1 1 1 1]: 463
==============================
[0 0 0 0 0 0 1 0]: 2
[0 0 0 0 1 1 0 1]: 13
[0 0 0 0 0 1 1 1 1]: 15
==============================
[0 1 1 0 0 1 1 1]: 103
[1 0 1 1 1 1 1 0]: 190
[1 0 0 1 0 0 1 0 1]: 293
==============================
[1 1 1 1 0 1 1 0]: 246
[0 1 0 1 0 0 1 0]: 82
[1 0 1 0 0 1 0 0 0]: 328
==============================
[0 1 1 1 0 0 1 1]: 115
[0 0 1 1 1 0 1 1]: 59
[0 1 0 1 0 1 1 1 0]: 174
==============================
[0 1 0 1 1 0 0 1]: 89
[0 1 1 0 1 0 1 1]: 107
[0 1 1 0 0 0 1 0 0]: 196
==============================
[0 1 0 0 0 1 0 0]: 68
[0 0 1 1 1 0 0 0]: 56
[0 0 1 1 1 1 1 0 0]: 124
==============================
[1 1 0 0 1 0 0 0]: 200
[1 0 1 0 0 0 1 0]: 162
[1 0 1 1 0 1 0 1 0]: 362
==============================
[1 1 1 1 0 0 1 1]: 243
[0 1 1 0 0 0 1 1]: 99
[1 0 1 0 1 0 1 1 0]: 342
==============================
[0 0 1 0 1 0 0 1]: 41
[0 1 0 0 1 0 0 1]: 73
[0 0 1 1 1 0 0 1 0]: 114
==============================
[0 0 0 1 1 1 0 1]: 29
[1 0 1 0 1 1 1 0]: 174
[0 1 1 0 0 1 0 1 1]: 203
==============================
[0 0 0 0 1 1 1 1]: 15
[0 0 1 1 1 1 0 1]: 61
[0 0 1 0 0 1 1 0 0]: 76
==============================
[1 1 1 1 1 0 1 1]: 251
[1 1 0 1 0 0 0 0]: 208
[1 1 1 0 0 1 0 1 1]: 459
==============================
[1 1 1 0 1 0 0 0]: 232
[0 1 1 0 0 0 1 0]: 98
[1 0 1 0 0 1 0 1 0]: 330
==============================
[1 0 1 1 0 1 0 0]: 180
[0 1 0 1 0 1 1 1]: 87
[1 0 0 0 0 1 0 1 1]: 267
==============================
[1 0 0 0 0 1 1 0]: 134
[1 0 0 1 0 1 0 1]: 149
[1 0 0 0 1 1 0 1 1]: 283
==============================
[1 0 1 0 1 1 0 1]: 173
[0 1 1 1 1 1 0 0]: 124
[1 0 0 1 0 1 0 0 1]: 297
==============================
[0 1 0 0 1 0 0 0]: 72
[0 1 1 0 0 0 1 1]: 99
[0 1 0 1 0 1 0 1 1]: 171
==============================
[1 1 0 1 0 1 0 1]: 213
[0 0 0 1 1 1 1 0]: 30
[0 1 1 1 1 0 0 1 1]: 243
可以看到,這個簡單的LSTM模型的預測的結果全部正確。因此,這就可以用來模擬0-255內的整數的加法運算,是不是很神奇呢?
如果需要想將加數的範圍擴大,只需要改變程式碼中的BINARY_DIM變數即可。但是,加數的範圍越大,樣本就越大,如2**10=1024內的加法,就會有1024*1024=1048576個樣本,這樣大的樣本量的無疑需要更多的訓練時間。
本文到此結束,感謝閱讀!如果不當之處,請速聯絡筆者,歡迎大家交流!祝您好運~
注意:本人現已開通微信公眾號: Python爬蟲與演算法(微訊號為:easy_web_scrape), 歡迎大家關注哦~~
完整的Python程式碼如下:
import numpy as np
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import Dropout
from keras.layers import LSTM
from keras import losses
from keras.utils import plot_model
# 最多8位二進位制
BINARY_DIM = 8
# 將整數表示成為binary_dim位的二進位制數,高位用0補齊
def int_2_binary(number, binary_dim):
binary_list = list(map(lambda x: int(x), bin(number)[2:]))
number_dim = len(binary_list)
result_list = [0]*(binary_dim-number_dim)+binary_list
return result_list
# 將一個二進位制陣列轉為整數
def binary2int(binary_array):
out = 0
for index, x in enumerate(reversed(binary_array)):
out += x * pow(2, index)
return out
# 將[0,2**BINARY_DIM)所有數表示成二進位制
binary = np.array([int_2_binary(x, BINARY_DIM) for x in range(2**BINARY_DIM)])
# print(binary)
# 樣本的輸入向量和輸出向量
dataX = []
dataY = []
for i in range(binary.shape[0]):
for j in range(binary.shape[0]):
dataX.append(np.append(binary[i], binary[j]))
dataY.append(int_2_binary(i+j, BINARY_DIM+1))
# print(dataX)
# print(dataY)
# 重新特徵X和目標變數Y陣列,適應LSTM模型的輸入和輸出
X = np.reshape(dataX, (len(dataX), 2*BINARY_DIM, 1))
# print(X.shape)
Y = np.array(dataY)
# print(dataY.shape)
# 定義LSTM模型
model = Sequential()
model.add(LSTM(256, input_shape=(X.shape[1], X.shape[2])))
model.add(Dropout(0.2))
model.add(Dense(Y.shape[1], activation='sigmoid'))
model.compile(loss=losses.mean_squared_error, optimizer='adam')
# print(model.summary())
# plot model
plot_model(model, to_file=r'./model.png', show_shapes=True)
# train model
epochs = 100
model.fit(X, Y, epochs=epochs, batch_size=128)
# save model
mp = r'./LSTM_Operation.h5'
model.save(mp)
# use LSTM model to predict
for _ in range(100):
start = np.random.randint(0, len(dataX)-1)
# print(dataX[start])
number1 = dataX[start][0:BINARY_DIM]
number2 = dataX[start][BINARY_DIM:]
print('='*30)
print('%s: %s'%(number1, binary2int(number1)))
print('%s: %s'%(number2, binary2int(number2)))
sample = np.reshape(X[start], (1, 2*BINARY_DIM, 1))
predict = np.round(model.predict(sample), 0).astype(np.int32)[0]
print('%s: %s'%(predict, binary2int(predict)))