keras兩個API
阿新 • • 發佈:2019-01-01
1.累加API
from keras.models import Sequential
from keras.layers import Dense
model = Sequential()
model.add(Dense(2, input_dim=1))
model.add(Dense(1))
但是他有很多限制
For example, it is not straightforward to define models that may have multiple different input sources, produce multiple output destinations or models that re-use layers.
2.函式式API
舉個最簡單的例子
from keras.models import Model
from keras.layers import Input
from keras.layers import Dense
visible = Input(shape=(2,))
hidden = Dense(2)(visible)
model = Model(inputs=visible, outputs=hidden)
稍微複雜點的例子
# Convolutional Neural Network from keras.utils import plot_model from keras.models import Model from keras.layers import Input from keras.layers import Dense from keras.layers import Flatten from keras.layers.convolutional import Conv2D from keras.layers.pooling import MaxPooling2D visible = Input(shape=(64,64,1)) conv1 = Conv2D(32, kernel_size=4, activation='relu')(visible) pool1 = MaxPooling2D(pool_size=(2, 2))(conv1) conv2 = Conv2D(16, kernel_size=4, activation='relu')(pool1) pool2 = MaxPooling2D(pool_size=(2, 2))(conv2) flat = Flatten()(pool2) hidden1 = Dense(10, activation='relu')(flat) output = Dense(1, activation='sigmoid')(hidden1) model = Model(inputs=visible, outputs=output) # summarize layers print(model.summary()) # plot graph plot_model(model, to_file='convolutional_neural_network.png')
再複雜一下,體現這個API的優越性
# Shared Input Layer from keras.utils import plot_model from keras.models import Model from keras.layers import Input from keras.layers import Dense from keras.layers import Flatten from keras.layers.convolutional import Conv2D from keras.layers.pooling import MaxPooling2D from keras.layers.merge import concatenate # input layer visible = Input(shape=(64,64,1)) # first feature extractor conv1 = Conv2D(32, kernel_size=4, activation='relu')(visible) pool1 = MaxPooling2D(pool_size=(2, 2))(conv1) flat1 = Flatten()(pool1) # second feature extractor conv2 = Conv2D(16, kernel_size=8, activation='relu')(visible) pool2 = MaxPooling2D(pool_size=(2, 2))(conv2) flat2 = Flatten()(pool2) # merge feature extractors merge = concatenate([flat1, flat2]) # interpretation layer hidden1 = Dense(10, activation='relu')(merge) # prediction output output = Dense(1, activation='sigmoid')(hidden1) model = Model(inputs=visible, outputs=output) # summarize layers print(model.summary()) # plot graph plot_model(model, to_file='shared_input_layer.png')
這裡體現了分支和merge
當然還有下面這種結構
# Shared Feature Extraction Layer
from keras.utils import plot_model
from keras.models import Model
from keras.layers import Input
from keras.layers import Dense
from keras.layers.recurrent import LSTM
from keras.layers.merge import concatenate
# define input
visible = Input(shape=(100,1))
# feature extraction
extract1 = LSTM(10)(visible)
# first interpretation model
interp1 = Dense(10, activation='relu')(extract1)
# second interpretation model
interp11 = Dense(10, activation='relu')(extract1)
interp12 = Dense(20, activation='relu')(interp11)
interp13 = Dense(10, activation='relu')(interp12)
# merge interpretation
merge = concatenate([interp1, interp13])
# output
output = Dense(1, activation='sigmoid')(merge)
model = Model(inputs=visible, outputs=output)
# summarize layers
print(model.summary())
# plot graph
plot_model(model, to_file='shared_feature_extractor.png')
多個input一個ouput也可以
多個output也可以
# Multiple Outputs
from keras.utils import plot_model
from keras.models import Model
from keras.layers import Input
from keras.layers import Dense
from keras.layers.recurrent import LSTM
from keras.layers.wrappers import TimeDistributed
# input layer
visible = Input(shape=(100,1))
# feature extraction
extract = LSTM(10, return_sequences=True)(visible)
# classification output
class11 = LSTM(10)(extract)
class12 = Dense(10, activation='relu')(class11)
output1 = Dense(1, activation='sigmoid')(class12)
# sequence output
output2 = TimeDistributed(Dense(1, activation='linear'))(extract)
# output
model = Model(inputs=visible, outputs=[output1, output2])
# summarize layers
print(model.summary())
# plot graph
plot_model(model, to_file='multiple_outputs.png')
參考:https://machinelearningmastery.com/keras-functional-api-deep-learning/