1. 程式人生 > 其它 >鳶尾花資料集----決策樹vs神經網路

鳶尾花資料集----決策樹vs神經網路

為方便比較兩種不同的分類預測演算法，我們均使用 sklearn 中 datasets 模組提供的鳶尾花資料集。

決策樹:

  # Decision-tree classifier on the iris dataset, using the first two features.
import numpy as np
from sklearn import datasets
from sklearn.model_selection import train_test_split
import matplotlib as mpl
import matplotlib.pyplot as plt
from sklearn import tree
from sklearn.pipeline import Pipeline
from sklearn.tree import DecisionTreeClassifier
from sklearn.preprocessing import StandardScaler

# Use a CJK-capable font so the Chinese labels/titles do not render as empty
# boxes, and draw minus signs with the ASCII hyphen instead of U+2212.
mpl.rcParams['font.sans-serif'] = [u'SimHei']
mpl.rcParams['axes.unicode_minus'] = False


def load_two_feature_iris():
    """Return ``(x, y)``: iris features limited to the first two columns, and labels.

    ``load_iris()`` already separates the features (``'data'``) from the labels
    (``'target'``, encoded as integers 0-2), so no label encoding is needed.
    Only the first two columns (sepal length, sepal width) are kept so the
    decision regions can be drawn in 2-D.
    """
    dataset = datasets.load_iris()
    x = np.array(dataset['data'])[:, :2]
    y = np.array(dataset['target'])
    return x, y


def plot_decision_regions(model, x, y, x_test, y_test, n=100, m=100):
    """Shade the class predicted by *model* over the bounding box of *x*.

    An ``n`` x ``m`` grid spanning the range of both features is classified
    and coloured by predicted class. The test samples are overlaid with large
    markers and all samples with small markers.
    """
    x1_min, x1_max = x[:, 0].min(), x[:, 0].max()  # range of feature 0
    x2_min, x2_max = x[:, 1].min(), x[:, 1].max()  # range of feature 1
    x1, x2 = np.meshgrid(np.linspace(x1_min, x1_max, n),
                         np.linspace(x2_min, x2_max, m))
    grid = np.stack((x1.flat, x2.flat), axis=1)  # (n*m, 2) grid points
    # Reshape predictions back to the grid so pcolormesh can draw them.
    grid_hat = model.predict(grid).reshape(x1.shape)

    cm_light = mpl.colors.ListedColormap(['#A0FFA0', '#FFA0A0', '#A0A0FF'])
    cm_dark = mpl.colors.ListedColormap(['g', 'r', 'b'])
    plt.figure(facecolor='w')
    plt.pcolormesh(x1, x2, grid_hat, cmap=cm_light)  # predicted regions
    plt.scatter(x_test[:, 0], x_test[:, 1], c=y_test.ravel(), edgecolors='k',
                s=100, cmap=cm_dark, marker='o')  # test samples
    plt.scatter(x[:, 0], x[:, 1], c=y.ravel(), edgecolors='k', s=40,
                cmap=cm_dark)  # all samples
    plt.xlabel("花萼長度", fontsize=15)  # sepal length
    plt.ylabel("花萼寬度", fontsize=15)  # sepal width
    plt.xlim(x1_min, x1_max)
    plt.ylim(x2_min, x2_max)
    plt.grid(True)
    plt.title(u'鳶尾花資料的決策樹分類', fontsize=17)
    plt.show()


def plot_depth_vs_error(x_train, y_train, x_test, y_test, max_depth=44):
    """Print and plot test error rate for tree depths ``1..max_depth``.

    A fresh entropy-criterion tree is fit at every depth, so the curve shows
    how test error changes as the tree is allowed to grow (overfitting).
    The scaler is omitted here. Per-feature standardization should not change
    which partitions a tree can form.
    """
    depth = np.arange(1, max_depth + 1)
    err_list = []
    for d in depth:
        # random_state fixes sklearn's per-split feature permutation, so ties
        # between equally good splits resolve the same way on every run.
        clf = DecisionTreeClassifier(criterion='entropy', max_depth=d,
                                     random_state=1)
        clf.fit(x_train, y_train)
        err = 1 - np.mean(clf.predict(x_test) == y_test)
        err_list.append(err)
        # This value is an error rate. It was previously mislabelled as accuracy (準確度).
        print(d, ' 錯誤率: %.2f%%' % (100 * err))
    plt.figure(facecolor='w')
    plt.plot(depth, err_list, 'ro-', lw=2)
    plt.xlabel(u'決策樹深度', fontsize=15)
    plt.ylabel(u'錯誤率', fontsize=15)
    plt.title(u'決策樹深度與過擬合', fontsize=17)
    plt.grid(True)
    plt.show()


def main():
    """Fit, export, plot and evaluate a depth-3 decision tree on iris."""
    x, y = load_two_feature_iris()
    # 70/30 train/test split; the fixed random_state makes it reproducible.
    x_train, x_test, y_train, y_test = train_test_split(
        x, y, test_size=0.3, random_state=1)

    model = Pipeline([
        ('ss', StandardScaler()),
        ('DTC', DecisionTreeClassifier(criterion='entropy', max_depth=3,
                                       random_state=1))])
    model.fit(x_train, y_train)
    y_test_hat = model.predict(x_test)  # predicted labels for the test set

    # Export the fitted tree. Render it with: dot -Tpng -o iris_tree.png iris_tree.dot
    # Split thresholds are in standardized units because the tree is fit on
    # the StandardScaler output. named_steps returns the fitted tree directly.
    # The old get_params('DTC') only worked because the string acted as a
    # truthy `deep` flag. A single portable relative path replaces the two
    # unclosed exports, one of which went to a machine-specific D:\ path.
    with open('iris_tree.dot', 'w', encoding='utf-8') as f:
        tree.export_graphviz(model.named_steps['DTC'], out_file=f)

    plot_decision_regions(model, x, y, x_test, y_test)

    # Accuracy on the held-out test set, not the training set.
    acc = np.mean(y_test_hat == y_test.reshape(-1))
    print('準確度: %.2f%%' % (100 * acc))

    plot_depth_vs_error(x_train, y_train, x_test, y_test)


if __name__ == '__main__':
    main()





神經網路:

 # Feed-forward (BP) neural-network classifier on the iris dataset.
import numpy as np
import torch
import torch.nn.functional as Fun


class Net(torch.nn.Module):
    """One-hidden-layer feed-forward (BP) classifier.

    Maps ``n_feature`` inputs through a ReLU hidden layer of ``n_hidden``
    units to ``n_output`` raw class scores (logits).
    """

    def __init__(self, n_feature, n_hidden, n_output):
        super(Net, self).__init__()
        self.hidden = torch.nn.Linear(n_feature, n_hidden)  # hidden layer
        self.out = torch.nn.Linear(n_hidden, n_output)  # output layer

    def forward(self, x):
        # The hidden layer uses ReLU. An earlier comment wrongly said sigmoid.
        x = Fun.relu(self.hidden(x))
        # No softmax here: CrossEntropyLoss expects raw logits. Apply
        # Fun.softmax to the output if probabilities are needed.
        return self.out(x)


def train(net, inputs, labels, epochs=500, lr=0.02):
    """Train *net* with full-batch SGD on cross-entropy loss; return *net*."""
    optimizer = torch.optim.SGD(net.parameters(), lr=lr)
    loss_func = torch.nn.CrossEntropyLoss()  # standard loss for classification
    for _ in range(epochs):
        out = net(inputs)
        loss = loss_func(out, labels)  # compare the output with the labels
        optimizer.zero_grad()  # clear gradients from the previous step
        loss.backward()  # backpropagation
        optimizer.step()  # apply the gradients
    return net


def accuracy(net, inputs, labels):
    """Return the fraction of rows where the argmax of ``net(inputs)`` equals *labels*."""
    with torch.no_grad():  # evaluation only; no autograd graph needed
        prediction = torch.max(net(inputs), 1)[1]  # [1] selects argmax indices
    return float(np.mean(prediction.numpy() == labels.numpy()))


def load_iris_split(test_size=0.3, random_state=1):
    """Load iris and return ``(x_train, x_test, y_train, y_test)`` as tensors.

    The split parameters default to the values the decision-tree script uses.
    scikit-learn is imported here so that ``Net``, ``train`` and ``accuracy``
    can be imported without it.
    """
    from sklearn import datasets
    from sklearn.model_selection import train_test_split

    dataset = datasets.load_iris()
    x_train, x_test, y_train, y_test = train_test_split(
        dataset['data'], dataset['target'],
        test_size=test_size, random_state=random_state)
    return (torch.FloatTensor(x_train), torch.FloatTensor(x_test),
            torch.LongTensor(y_train), torch.LongTensor(y_test))


def main():
    """Train the network on 70% of iris and report accuracy on the other 30%."""
    torch.manual_seed(0)  # reproducible weight initialization
    x_train, x_test, y_train, y_test = load_iris_split()
    # Uses all four features, while the decision-tree script keeps only two,
    # so the two accuracies are not a strictly like-for-like comparison.
    net = Net(n_feature=4, n_hidden=20, n_output=3)
    train(net, x_train, y_train)
    # Evaluate on held-out data. The old code scored the training samples,
    # which overstates accuracy.
    print("鳶尾花測試集預測準確率", accuracy(net, x_test, y_test))


if __name__ == '__main__':
    main()

鳶尾花資料集:

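共 150 個樣本，分為三個類別：setosa、versicolor、virginica（各 50 個）。下表為 UCI 機器學習資料庫的原始檔內容；其中第 35 與第 38 筆樣本與 Fisher 原始資料不符，sklearn 的 load_iris() 已將其修正為 4.9,3.1,1.5,0.2 與 4.9,3.6,1.4,0.1，因此程式實際使用的資料與下表略有差異。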
各欄依序為：花萼長度、花萼寬度、花瓣長度、花瓣寬度、種類

5.1,3.5,1.4,0.2,Iris-setosa
4.9,3.0,1.4,0.2,Iris-setosa
4.7,3.2,1.3,0.2,Iris-setosa
4.6,3.1,1.5,0.2,Iris-setosa
5.0,3.6,1.4,0.2,Iris-setosa
5.4,3.9,1.7,0.4,Iris-setosa
4.6,3.4,1.4,0.3,Iris-setosa
5.0,3.4,1.5,0.2,Iris-setosa
4.4,2.9,1.4,0.2,Iris-setosa
4.9,3.1,1.5,0.1,Iris-setosa
5.4,3.7,1.5,0.2,Iris-setosa
4.8,3.4,1.6,0.2,Iris-setosa
4.8,3.0,1.4,0.1,Iris-setosa
4.3,3.0,1.1,0.1,Iris-setosa
5.8,4.0,1.2,0.2,Iris-setosa
5.7,4.4,1.5,0.4,Iris-setosa
5.4,3.9,1.3,0.4,Iris-setosa
5.1,3.5,1.4,0.3,Iris-setosa
5.7,3.8,1.7,0.3,Iris-setosa
5.1,3.8,1.5,0.3,Iris-setosa
5.4,3.4,1.7,0.2,Iris-setosa
5.1,3.7,1.5,0.4,Iris-setosa
4.6,3.6,1.0,0.2,Iris-setosa
5.1,3.3,1.7,0.5,Iris-setosa
4.8,3.4,1.9,0.2,Iris-setosa
5.0,3.0,1.6,0.2,Iris-setosa
5.0,3.4,1.6,0.4,Iris-setosa
5.2,3.5,1.5,0.2,Iris-setosa
5.2,3.4,1.4,0.2,Iris-setosa
4.7,3.2,1.6,0.2,Iris-setosa
4.8,3.1,1.6,0.2,Iris-setosa
5.4,3.4,1.5,0.4,Iris-setosa
5.2,4.1,1.5,0.1,Iris-setosa
5.5,4.2,1.4,0.2,Iris-setosa
4.9,3.1,1.5,0.1,Iris-setosa
5.0,3.2,1.2,0.2,Iris-setosa
5.5,3.5,1.3,0.2,Iris-setosa
4.9,3.1,1.5,0.1,Iris-setosa
4.4,3.0,1.3,0.2,Iris-setosa
5.1,3.4,1.5,0.2,Iris-setosa
5.0,3.5,1.3,0.3,Iris-setosa
4.5,2.3,1.3,0.3,Iris-setosa
4.4,3.2,1.3,0.2,Iris-setosa
5.0,3.5,1.6,0.6,Iris-setosa
5.1,3.8,1.9,0.4,Iris-setosa
4.8,3.0,1.4,0.3,Iris-setosa
5.1,3.8,1.6,0.2,Iris-setosa
4.6,3.2,1.4,0.2,Iris-setosa
5.3,3.7,1.5,0.2,Iris-setosa
5.0,3.3,1.4,0.2,Iris-setosa
7.0,3.2,4.7,1.4,Iris-versicolor
6.4,3.2,4.5,1.5,Iris-versicolor
6.9,3.1,4.9,1.5,Iris-versicolor
5.5,2.3,4.0,1.3,Iris-versicolor
6.5,2.8,4.6,1.5,Iris-versicolor
5.7,2.8,4.5,1.3,Iris-versicolor
6.3,3.3,4.7,1.6,Iris-versicolor
4.9,2.4,3.3,1.0,Iris-versicolor
6.6,2.9,4.6,1.3,Iris-versicolor
5.2,2.7,3.9,1.4,Iris-versicolor
5.0,2.0,3.5,1.0,Iris-versicolor
5.9,3.0,4.2,1.5,Iris-versicolor
6.0,2.2,4.0,1.0,Iris-versicolor
6.1,2.9,4.7,1.4,Iris-versicolor
5.6,2.9,3.6,1.3,Iris-versicolor
6.7,3.1,4.4,1.4,Iris-versicolor
5.6,3.0,4.5,1.5,Iris-versicolor
5.8,2.7,4.1,1.0,Iris-versicolor
6.2,2.2,4.5,1.5,Iris-versicolor
5.6,2.5,3.9,1.1,Iris-versicolor
5.9,3.2,4.8,1.8,Iris-versicolor
6.1,2.8,4.0,1.3,Iris-versicolor
6.3,2.5,4.9,1.5,Iris-versicolor
6.1,2.8,4.7,1.2,Iris-versicolor
6.4,2.9,4.3,1.3,Iris-versicolor
6.6,3.0,4.4,1.4,Iris-versicolor
6.8,2.8,4.8,1.4,Iris-versicolor
6.7,3.0,5.0,1.7,Iris-versicolor
6.0,2.9,4.5,1.5,Iris-versicolor
5.7,2.6,3.5,1.0,Iris-versicolor
5.5,2.4,3.8,1.1,Iris-versicolor
5.5,2.4,3.7,1.0,Iris-versicolor
5.8,2.7,3.9,1.2,Iris-versicolor
6.0,2.7,5.1,1.6,Iris-versicolor
5.4,3.0,4.5,1.5,Iris-versicolor
6.0,3.4,4.5,1.6,Iris-versicolor
6.7,3.1,4.7,1.5,Iris-versicolor
6.3,2.3,4.4,1.3,Iris-versicolor
5.6,3.0,4.1,1.3,Iris-versicolor
5.5,2.5,4.0,1.3,Iris-versicolor
5.5,2.6,4.4,1.2,Iris-versicolor
6.1,3.0,4.6,1.4,Iris-versicolor
5.8,2.6,4.0,1.2,Iris-versicolor
5.0,2.3,3.3,1.0,Iris-versicolor
5.6,2.7,4.2,1.3,Iris-versicolor
5.7,3.0,4.2,1.2,Iris-versicolor
5.7,2.9,4.2,1.3,Iris-versicolor
6.2,2.9,4.3,1.3,Iris-versicolor
5.1,2.5,3.0,1.1,Iris-versicolor
5.7,2.8,4.1,1.3,Iris-versicolor
6.3,3.3,6.0,2.5,Iris-virginica
5.8,2.7,5.1,1.9,Iris-virginica
7.1,3.0,5.9,2.1,Iris-virginica
6.3,2.9,5.6,1.8,Iris-virginica
6.5,3.0,5.8,2.2,Iris-virginica
7.6,3.0,6.6,2.1,Iris-virginica
4.9,2.5,4.5,1.7,Iris-virginica
7.3,2.9,6.3,1.8,Iris-virginica
6.7,2.5,5.8,1.8,Iris-virginica
7.2,3.6,6.1,2.5,Iris-virginica
6.5,3.2,5.1,2.0,Iris-virginica
6.4,2.7,5.3,1.9,Iris-virginica
6.8,3.0,5.5,2.1,Iris-virginica
5.7,2.5,5.0,2.0,Iris-virginica
5.8,2.8,5.1,2.4,Iris-virginica
6.4,3.2,5.3,2.3,Iris-virginica
6.5,3.0,5.5,1.8,Iris-virginica
7.7,3.8,6.7,2.2,Iris-virginica
7.7,2.6,6.9,2.3,Iris-virginica
6.0,2.2,5.0,1.5,Iris-virginica
6.9,3.2,5.7,2.3,Iris-virginica
5.6,2.8,4.9,2.0,Iris-virginica
7.7,2.8,6.7,2.0,Iris-virginica
6.3,2.7,4.9,1.8,Iris-virginica
6.7,3.3,5.7,2.1,Iris-virginica
7.2,3.2,6.0,1.8,Iris-virginica
6.2,2.8,4.8,1.8,Iris-virginica
6.1,3.0,4.9,1.8,Iris-virginica
6.4,2.8,5.6,2.1,Iris-virginica
7.2,3.0,5.8,1.6,Iris-virginica
7.4,2.8,6.1,1.9,Iris-virginica
7.9,3.8,6.4,2.0,Iris-virginica
6.4,2.8,5.6,2.2,Iris-virginica
6.3,2.8,5.1,1.5,Iris-virginica
6.1,2.6,5.6,1.4,Iris-virginica
7.7,3.0,6.1,2.3,Iris-virginica
6.3,3.4,5.6,2.4,Iris-virginica
6.4,3.1,5.5,1.8,Iris-virginica
6.0,3.0,4.8,1.8,Iris-virginica
6.9,3.1,5.4,2.1,Iris-virginica
6.7,3.1,5.6,2.4,Iris-virginica
6.9,3.1,5.1,2.3,Iris-virginica
5.8,2.7,5.1,1.9,Iris-virginica
6.8,3.2,5.9,2.3,Iris-virginica
6.7,3.3,5.7,2.5,Iris-virginica
6.7,3.0,5.2,2.3,Iris-virginica
6.3,2.5,5.0,1.9,Iris-virginica
6.5,3.0,5.2,2.0,Iris-virginica
6.2,3.4,5.4,2.3,Iris-virginica
5.9,3.0,5.1,1.8,Iris-virginica