tensorflow2.0——鳶尾花資料集的一元分類
阿新 • • 發佈:2020-08-04
"""Binary logistic regression on two iris classes with TensorFlow 2.x.

Reads a local iris CSV whose last column is the class number, keeps only the
rows with class 0 or 1, mean-centres the first two feature columns, appends a
constant-1 bias column, fits the weights with full-batch gradient descent, and
plots the decision line together with the loss and accuracy curves.
"""
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf

# Where the dataset can be downloaded from; this script reads the local copy below.
TRAIN_URL = 'http://download.tensorflow.org/data/iris_training.csv'
DATA_PATH = './鳶尾花資料集/iris.csv'

# Default hyper-parameters.
ITERATIONS = 2000
LEARN_RATE = 0.1


def load_binary_iris(path=DATA_PATH):
    """Load the iris CSV and build the design matrix and labels for classes 0/1.

    Args:
        path: CSV file whose first row is a header and whose last column is the
            numeric class number (NOTE(review): assumed; confirm the local file).

    Returns:
        (x, y): ``x`` is a float32 tensor of shape (n, 3) holding the first two
        mean-centred feature columns plus a constant-1 bias column; ``y`` is a
        float32 tensor of shape (n, 1) holding the 0/1 labels.
    """
    iris = pd.read_csv(path, header=0).to_numpy(dtype=np.float64)
    iris = iris[iris[:, -1] < 2]  # keep only the two classes numbered 0 and 1

    features = iris[:, 0:2]  # only the first two feature columns
    # Centre each feature on 0 so the plotted decision line passes near the data.
    features = features - features.mean(axis=0)
    bias = np.ones((features.shape[0], 1))  # constant column that learns the intercept

    x = tf.cast(np.hstack((features, bias)), tf.float32)
    # Cast labels to the model dtype instead of mixing float64 numpy with float32 tensors.
    y = tf.cast(iris[:, -1].reshape(-1, 1), tf.float32)
    return x, y


def train(x, y, iterations=ITERATIONS, learn_rate=LEARN_RATE, seed=None):
    """Fit logistic-regression weights with plain full-batch gradient descent.

    Args:
        x: float32 tensor (n, d) of inputs, bias column included.
        y: float32 tensor (n, 1) of 0/1 labels.
        iterations: number of gradient-descent steps.
        learn_rate: step size.
        seed: optional seed for the random weight initialisation.

    Returns:
        (w, loss_list, acc_list): the trained ``tf.Variable`` of shape (d, 1),
        and per-iteration loss and accuracy as Python floats (both measured
        before that iteration's weight update).
    """
    rng = np.random.default_rng(seed)
    w = tf.Variable(rng.standard_normal((x.shape[1], 1)), dtype=tf.float32)
    loss_list = []
    acc_list = []

    for i in range(iterations):
        with tf.GradientTape() as tape:
            logits = tf.matmul(x, w)
            # Same value as -(y*log(p) + (1-y)*log(1-p)) with p = sigmoid(logits),
            # but it stays finite when p saturates to exactly 0 or 1.
            loss = tf.reduce_mean(
                tf.nn.sigmoid_cross_entropy_with_logits(labels=y, logits=logits))
        dloss_dw = tape.gradient(loss, w)
        w.assign_sub(learn_rate * dloss_dw)

        # Accuracy of the predictions that produced this iteration's loss.
        y_p = tf.sigmoid(logits)
        acc = tf.reduce_mean(tf.cast(tf.equal(tf.round(y_p), y), tf.float32))

        loss_value = float(loss)
        acc_value = float(acc)
        loss_list.append(loss_value)
        acc_list.append(acc_value)
        if i % 100 == 0:
            print('第{}次, loss:{},acc:{}'.format(i, loss_value, acc_value))

    return w, loss_list, acc_list


def plot_results(x, y, w, loss_list, acc_list):
    """Plot the samples with the learned decision line, the loss and the accuracy.

    Expects ``w`` to hold exactly three weights: two features plus the bias.
    """
    x_np = x.numpy()
    w0, w1, w2 = w.numpy().ravel()
    x1 = x_np[:, 0]
    # Decision boundary: w0*x1 + w1*x2 + w2 = 0  =>  x2 = -(w0*x1 + w2) / w1
    x2 = -(w0 * x1 + w2) / w1

    plt.rcParams["font.family"] = 'SimHei'  # font with CJK glyphs for the titles
    # With a CJK font the default unicode minus sign is not rendered; use ASCII '-'.
    plt.rcParams['axes.unicode_minus'] = False

    plt.subplot(221)
    cm_pt = mpl.colors.ListedColormap(['red', 'green'])
    plt.scatter(x=x_np[:, 0], y=x_np[:, 1], c=y.numpy().ravel(), cmap=cm_pt)
    plt.plot(x1, x2, label='預測線')
    plt.legend()

    plt.subplot(223)
    plt.title('損失值')
    plt.plot(loss_list)

    plt.subplot(224)
    plt.title('準確率')
    plt.plot(acc_list)

    plt.show()


def main():
    """Load the data, train the model and show the plots."""
    x, y = load_binary_iris(DATA_PATH)
    w, loss_list, acc_list = train(x, y)
    plot_results(x, y, w, loss_list, acc_list)


if __name__ == "__main__":
    main()