1. 程式人生 > >決策邊界視覺化

決策邊界視覺化

視覺化邊界

python程式碼實現


import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn import datasets
from sklearn import svm

# 載入分類資料
iris = datasets.load_iris()
# 在這裡只討論兩個特徵的情況, 因為多於兩個特徵是無法進行視覺化的
X = iris.data[:, 0:2]
y = iris.target

# 使用SVM分類器
clf = svm.LinearSVC().fit(X, y)
# 接下來進行視覺化, 要想進行視覺化, 我們核心就是要呼叫plt.contour函式畫圖, 但是它要求傳入三個矩陣, 而我們的x1和x2為向量, 預測的值也為向量, 所有我們需要將x1和x2轉換為矩陣

# 獲取邊界範圍, 為了產生資料
x1_min, x1_max = np.min(X[:, 0]) - 1, np.max(X[:, 0]) + 1
x2_min, x2_max = np.min(X[:, 1]) - 1, np.max(X[:, 1]) + 1

# 生成新的資料, 並呼叫meshgrid網格搜尋函式幫助我們生成矩陣
xx1, xx2 = np.meshgrid(np.arange(x1_min, x1_max, 0.02), np.arange(x2_min, x2_max, 0.02))
# 有了新的資料, 我們需要將這些資料輸入到分類器獲取到結果, 但是因為輸入的是矩陣, 我們需要給你將其轉換為符合條件的資料
Z = clf.predict(np.c_[xx1.ravel(), xx2.ravel()])
# 這個時候得到的是Z還是一個向量, 將這個向量轉為矩陣即可
Z = Z.reshape(xx1.shape)
plt.figure()
# 為什麼需要輸入矩陣, 因為等高線函式其實是3D函式, 3D座標是三個平面, 平面對應矩陣
plt.contour(xx1, xx2, Z, cmap=plt.cm.RdYlBu)
plt.scatter(X[:, 0], X[:, 1], c=y, cmap=plt.cm.RdYlBu)
plt.show()

MATLAB 實現


%% Initialize data
clear, clc, close all;
load('data.mat');

y(y == 0) = -1;
%% Training a SVM(Support Vector Machine) Classifier
svm = fitcsvm(X, y, 'KernelFunction', 'linear');
y_pred = predict(svm, X);
mean(double(y_pred == y))
jhplotdata(X, y);
hold on;
x1_min = min(X(:, 1)) - 1;
x1_max = max(X(:, 1)) + 1;
x2_min = min(X(:, 2)) - 1;
x2_max = max(X(:, 2)) + 1;
[XX, YY] = meshgrid(x1_min:0.02:x1_max, x2_min:0.02:x2_max);
Z = predict(svm, [XX(:) YY(:)]);
Z = reshape(Z, size(XX));
contour(XX, YY, Z);