1. 程式人生 > 其它 >基於tensorflow2 的通過自定義class的方式搭建網路,鳶尾花分類

基於tensorflow2 的通過自定義class的方式搭建網路,鳶尾花分類

# 1)import【引入相關模組】
import tensorflow as tf
from tensorflow.keras.layers import Dense
from tensorflow.keras import Model
from sklearn import datasets
import numpy as np
import matplotlib.pyplot as plt

# 2)train,test 【告知喂入網路的訓練集測試集以及相應的標籤】
x_train = datasets.load_iris().data
y_train = datasets.load_iris().target

np.random.seed(
116) np.random.shuffle(x_train) np.random.seed(116) np.random.shuffle(y_train) tf.random.set_seed(116) # 3) class MyModel() class IrisModel(Model): def __init__(self): super(IrisModel, self).__init__() self.d1 = Dense(3, activation='sigmoid', kernel_regularizer=tf.keras.regularizers.l2())
def call(self, x): y = self.d1(x) return y model = IrisModel() # 4)model.compile 【告知訓練時選擇哪種優化器,選擇哪個損失函式,選擇哪種評測指標】 model.compile(optimizer=tf.keras.optimizers.SGD(lr=0.1), loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False), metrics=['sparse_categorical_accuracy
']) # 5)model.fit 【在fit()中執行訓練過程,告知訓練集合測試集的輸入特徵和標籤,告知batch大小,告知要迭代多少次資料集】 history = model.fit(x_train, y_train, verbose=0, batch_size=32, epochs=500, validation_split=0.2, validation_freq=20) # 6)model.summary 【列印網路結構和引數統計】 model.summary() # 畫訓練集和測試集的loss曲線 train_loss = np.array(history.history['loss']) sparse_categorical_accuracy = np.array(history.history['sparse_categorical_accuracy']) epoch = np.array(history.epoch) train_line, = plt.plot(epoch, train_loss) sparse_categorical_accuracy_line, = plt.plot(epoch, sparse_categorical_accuracy) plt.legend(handles=[train_line, sparse_categorical_accuracy_line], labels=['train_loss', 'sparse_categorical_accuracy']) plt.xlabel('epoch') plt.ylabel('loss') plt.show()