基於決策樹模型對 IRIS 資料集分類
阿新 • • 發佈:2018-11-25
基於決策樹模型對 IRIS 資料集分類
文章目錄
1 python 實現
載入資料集
IRIS 資料集在 sklearn 模組中已經提供。
# -*- coding: utf-8 -*-
from matplotlib import pyplot as plt
import numpy as np
from sklearn import tree
from sklearn.datasets import load_iris
if __name__ == '__main__':
print('\n\n\n\n\n\n\n\n\n\n')
# show data info
data = load_iris() # 載入 IRIS 資料集
print('keys: \n', data.keys()) # ['data', 'target', 'target_names', 'DESCR', 'feature_names']
feature_names = data. get('feature_names')
print('feature names: \n', data.get('feature_names')) # 檢視屬性名稱
print('target names: \n', data.get('target_names')) # 檢視 label 名稱
x = data.get('data') # 獲取樣本矩陣
y = data.get('target') # 獲取與樣本對應的 label 向量
print(x.shape, y.shape) # 檢視樣本資料
print(data.get('DESCR' ))
視覺化資料集
# visualize the data
f = []
f.append(y==0) # 類別為第一類的樣本的邏輯索引
f.append(y==1) # 類別為第二類的樣本的邏輯索引
f.append(y==2) # 類別為第三類的樣本的邏輯索引
color = ['red','blue','green']
fig, axes = plt.subplots(4,4) # 繪製四個屬性兩輛之間的散點圖
for i, ax in enumerate(axes.flat):
row = i // 4
col = i % 4
if row == col:
ax.text(.1,.5, feature_names[row])
ax.set_xticks([])
ax.set_yticks([])
continue
for k in range(3):
ax.scatter(x[f[k],row], x[f[k],col], c=color[k], s=3)
fig.subplots_adjust(hspace=0.3, wspace=0.3) # 設定間距
plt.show()
分類和預測
# 隨機劃分訓練集和測試集
num = x.shape[0] # 樣本總數
ratio = 7/3 # 劃分比例,訓練集數目:測試集數目
num_test = int(num/(1+ratio)) # 測試集樣本數目
num_train = num - num_test # 訓練集樣本數目
index = np.arange(num) # 產生樣本標號
np.random.shuffle(index) # 洗牌
x_test = x[index[:num_test],:] # 取出洗牌後前 num_test 作為測試集
y_test = y[index[:num_test]]
x_train = x[index[num_test:],:] # 剩餘作為訓練集
y_train = y[index[num_test:]]
# 構建決策樹
clf = tree.DecisionTreeClassifier() # 建立決策樹物件
clf.fit(x_train, y_train) # 決策樹擬合
# 預測
y_test_pre = clf.predict(x_test) # 利用擬合的決策樹進行預測
print('the predict values are', y_test_pre) # 顯示結果
計算準確率
# 計算分類準確率
acc = sum(y_test_pre==y_test)/num_test
print('the accuracy is', acc) # 顯示預測準確率
由於資料集的劃分是隨機的每次得到的準確率都不一樣,一般位於91%-97%之間。
2 基於MATLAB 實現
Matlab 對資料的視覺化
實現的實現過程與 python 的流程是一樣的,只是兩種程式語言的語法上的差異。
clc
clear all
close all;
load fisheriris; % 載入資料集
% 資料視覺化
x = meas;
y = species;
class = unique(y);
attr = {'sepal length', 'sepal width', 'petal length', 'petal width'};
ind1 = ismember(y, class{1});
ind2 = ismember(y, class{2});
ind3 = ismember(y, class{3});
s=10;
for i=1:4
for j=1:4
subplot(4,4,4*(i-1)+j);
if i==j
set(gca, 'xtick', [], 'ytick', []);
text(.2, .5, attr{i});
set(gca, 'box', 'on');
continue;
end
scatter(x(ind1,i), x(ind1,j), s, 'r', 'MarkerFaceColor', 'r');
hold on
scatter(x(ind2,i), x(ind2,j), s, 'b', 'MarkerFaceColor', 'b');
hold on
scatter(x(ind3,i), x(ind3,j), s, 'g', 'MarkerFaceColor', 'g');
set(gca, 'box', 'on');
end
end
% 隨機劃分訓練集和測試集
ratio = 7/3;
num = length(x);
num_test = round(num/(1+ratio));
num_train = num - num_test;
index = randperm(num);
x_train = x(index(1:num_train),:);
y_train = y(index(1:num_train));
x_test = x(index(num_train+1:end),:);
y_test = y(index(num_train+1:end));
% 構建決策樹並預測結果
tree = fitctree(x_train, y_train);
y_test_p = predict(tree, x_test);
% 計算預測準確率
acc = sum(strcmp(y_test,y_test_p))/num_test;
disp(['The accuracy is ', num2str(acc)]);