1. 程式人生 > >關於sklearn下class_weight引數的一點原始碼閱讀與測試

關於sklearn下class_weight引數的一點原始碼閱讀與測試

版權宣告:歡迎轉載,請註明原出處 https://blog.csdn.net/go_og/article/details/81281387

一直沒有很在意過sklearn的class_weight的這個引數的具體作用細節,只大致瞭解是是用於處理樣本不均衡。後來在簡書上閱讀svm鬆弛變數的一些推導的時候,看到樣本不均衡的帶來的問題時候,想更深層次的看一下class_weight的具體作用方式,

svm鬆弛變數的簡書連結:https://www.jianshu.com/p/8a499171baa9

該文中的樣本不均衡的描述:

“樣本偏斜是指資料集中正負類樣本數量不均,比如正類樣本有10000個,負類樣本只有100個,這就可能使得超平面被“推向”負類(因為負類數量少,分佈得不夠廣),影響結果的準確性。” 

隨後翻開sklearn LR的原始碼:

我們以分類作為說明重點

在輸入引數class_weight=‘balanced’的時候:

 
  1. # compute the class weights for the entire dataset y

  2. if class_weight == "balanced":

  3. class_weight = compute_class_weight(class_weight,

  4. np.arange(len(self.classes_)),

  5. y)

  6. class_weight = dict(enumerate(class_weight))

進一步閱讀 compute_class_weight這個函式:

 
  1. elif class_weight == 'balanced':

  2. # Find the weight of each class as present in y.

  3. le = LabelEncoder()

  4. y_ind = le.fit_transform(y)

  5. if not all(np.in1d(classes, le.classes_)):

  6. raise ValueError("classes should have valid labels that are in y")

  7.  
  8. recip_freq = len(y) / (len(le.classes_) *

  9. np.bincount(y_ind).astype(np.float64))

  10. weight = recip_freq[le.transform(classes)]

compute_class_weight這個函式的作用是對於輸入的樣本,平衡類別之間的權重,下面寫段測試程式碼測試這個函式:

 
  1. # coding:utf-8

  2.  
  3. from sklearn.utils.class_weight import compute_class_weight

  4.  
  5. class_weight = 'balanced'

  6. label = [0] * 9 + [1]*1 + [2, 2]

  7. print(label) # [0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 2, 2]

  8. classes=[0, 1, 2]

  9. weight = compute_class_weight(class_weight, classes, label)

  10. print(weight) #[ 0.44444444 4. 2. ]

  11. print(.44444444 * 9) # 3.99999996

  12. print(4 * 1) # 4

  13. print(2 * 2) # 4

如上圖所示,可以看到這個函式把樣本的平衡後的權重乘積為4,每個類別均如此。

關於class_weight與sample_weight在損失函式上的具體計算方式:

 
  1. sample_weight *= class_weight_[le.fit_transform(y_bin)] # sample_weight 與 class_weight相乘

  2.  
  3. # Logistic loss is the negative of the log of the logistic function.

  4. out = -np.sum(sample_weight * log_logistic(yz)) + .5 * alpha * np.dot(w, w)

上述可以看出對於每個樣本,計算的損失函式乘上對應的sample_weight來計算最終的損失。這樣計算而來的損失函式不會因為樣本不平衡而被“推向”樣本量偏少的類別中。

class_weight以及sample_weight並沒有進行不平衡資料的處理,比如,上下采樣。詳細參見SMOTE EasyEnsemble等。

--------------------- 本文來自 摸摸小松鼠寶寶 的CSDN 部落格 ,全文地址請點選:https://blog.csdn.net/go_og/article/details/81281387?utm_source=copy