Python構建SVM分類器(線性)
阿新 • • 發佈:2019-01-03
1.SVM建立線性分類器
SVM用來構建分類器和迴歸器的監督學習模型,SVM通過對數學方程組的求解,可以找出兩組資料之間的最佳分割邊界。
2.準備工作
我們首先對資料進行視覺化,使用的檔案來自學習書籍配套管網。
首先增加以下程式碼:
import numpy as np
import matplotlib.pyplot as plt
import utilities
# Load input data
input_file = 'data_multivar.txt'
X, y = utilities.load_data(input_file)
剛剛匯入了需要的程式包,確定了檔案的名稱,接下來看load_data()方法:
#載入輸入檔案中的多變數資料 def load_data(input_file): X = [] y = [] with open(input_file, 'r') as f: for line in f.readlines(): data = [float(x) for x in line.split(',')] X.append(data[:-1]) y.append(data[-1]) X = np.array(X) y = np.array(y) return X, y
需要將資料分成類,如下所示:
class_0 = np.array([X[i] for i in range(len(X)) if y[i]==0])
class_1 = np.array([X[i] for i in range(len(X)) if y[i]==1])
資料分類後,我們將它們畫出來:
可以發現數據有兩個型別組成,我們的目標是要建立一個可以將實心方塊和空心方塊分開的模型、plt.figure() plt.scatter(class_0[:,0], class_0[:,1], facecolors='black', edgecolors='black', marker='s') plt.scatter(class_1[:,0], class_1[:,1], facecolors='None', edgecolors='black', marker='s') plt.title('Input data') plt.show()
3.實現步驟
(1)將資料集分割為訓練資料集和測試資料集,加入以下程式碼:
from sklearn import cross_validation
from sklearn.svm import SVC
X_train, X_test, y_train, y_test = cross_validation.train_test_split(X, y, test_size=0.25, random_state=5)
(2)用線性核函式初始化一個SVM物件,並訓練線性SVM分類器,加入以下函式:
params = {'kernel': 'linear'}
classifier = SVC(**params)
classifier.fit(X_train, y_train)
utilities.plot_classifier(classifier, X_train, y_train, 'Training dataset')
可以看到以下的圖形
plot_classifier函式為之前構造過的畫圖函式。
接下來看分類器對測試集的執行,增加以下程式碼:
y_test_pred = classifier.predict(X_test)
utilities.plot_classifier(classifier, X_test, y_test, 'Test dataset')
接下來計算訓練集的準確性:
from sklearn.metrics import classification_report
target_names = ['Class-' + str(int(i)) for i in set(y)]
print("\n" + "#"*30)
print("\nClassifier performance on training dataset\n")
print(classification_report(y_train,classifier.predict(X_train),
target_names = target_names))
print("#"*30+"\n")
最後檢視分類器為測試集生成的分類報告:print("#"*30)
print("\nClassification report on test dataset\n")
print(classification_report(y_test,y_test_pred,
target_names=target_names))
print("#"*30+"\n")
從前面的資料視覺化圖形中我們可以看到實心方塊完全被空心方塊保衛,也就是兩種資料不是線性可分,無法畫出一條分離兩種型別資料的直線,需要利用非線性分類器。