sklearn通過OneVsRestClassifier實現svm.SVC的多分類
阿新 • • 發佈:2019-01-05
這個repo
用來記錄一些python技巧、書籍、學習連結等,歡迎star
svm.SVC 支援向量機分類是一個很有效的分類方式,但是其只對2分類有效(sklearn中並不是,針對多分類其使用了1vs多,decision_function_shape : 'ovo', 'ovr', default='ovr'
, 這裡假裝只對2分類有效,用來進行下面的內容,  ̄□ ̄||),不過,可以將多分類經過多次2分類最終實現多分類,而sklearn中的multiclass包就可以實現這種方式,減少我們重複造輪子。
import numpy as np
from sklearn.datasets import load_digits
from sklearn.preprocessing import label_binarize
from sklearn.multiclass import OneVsRestClassifier
from sklearn.model_selection import train_test_split
digits = load_digits()
x, y = digits.data, digits.target
y = label_binarize(y, classes=list(range(10 )))
x_train, x_test, y_train, y_test = train_test_split(x, y)
model = OneVsRestClassifier(svm.SVC(kernel='linear'))
clf = model.fit(x_train, y_train)
In [236]: clf.score(x_train, y_train)
Out[236]: 0.97475872308834444
In [237]: clf.score(x_test, y_test)
Out[237]: 0.85999999999999999
In [242]: np.argmax(y_test, axis=1 )
Out[242]: array([0, 0, 2, ..., 5, 6, 7], dtype=int64)
In [243]: np.argmax(clf.decision_function(x_test), axis=1)
Out[243]: array([0, 0, 2, ..., 5, 6, 7], dtype=int64)