SVM 解決類別不平衡問題(scikit_learn)
阿新 • • 發佈:2019-01-07
在支援向量機中, 是負責懲罰錯誤分類資料的超引數。
解決資料類別不平衡的一個方法就是使用基於類別增加權重的值
其中,是誤分類的懲罰項,是與類別 的出現頻率成反比的權重引數, 就是類別 對應的 加權值
主要思路就是增大誤分類 少數類別 帶來的影響,保證 少數類別 的分類正確性,避免被多數類別掩蓋
在scikit-learn 中,使用 svc 方法時,可以通過設定引數
class_weight=’balanced’
實現上述加權功能
引數‘balanced’ 會自動按照以下公式計算權值:
其中, 為類別 對應權值, 為資料總數,為類別數量,即資料有 個種類,是類別 的資料個數
0.匯入庫
# Load libraries
from sklearn.svm import SVC
from sklearn import datasets
from sklearn.preprocessing import StandardScaler
import numpy as np
1、載入Iris Flower 資料集
#只加載兩個類別的資料,兩類,各50個
iris = datasets.load_iris()
X = iris.data[:100 ,:]
y = iris.target[:100]
2.不均衡化資料集
# 刪掉前四十個資料,資料總數變為60個
X = X[40:,:]
y = y[40:]
# 類別為0的類別不變,類別不為0的全部變為1
y = np.where((y == 0), 0, 1)
y
array([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
可以看到,有60個數據,10個為類別0,50個為類別1
3.特徵標準化
# Standarize features
scaler = StandardScaler()
X_std = scaler.fit_transform(X)
4.使用加權類別訓練SVM分類器
# Create support vector classifier
svc = SVC(kernel='linear', class_weight='balanced', C=1.0, random_state=0)
# Train classifier
model = svc.fit(X_std, y)
翻譯自Chris Albon 部落格
原文地址