Handling Imbalanced Datasets (Resampling strategies for imbalanced datasets)
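Introduction
This section covers some common ways to balance class samples. All of the content is drawn from the reference below.
The following reference is excellent and well worth reading: Resampling strategies for imbalanced datasets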
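Why balance the samples
If the positive and negative samples differ greatly in number, or the class sizes are far apart, the model tends to predict the most frequent class. Doing so can still yield high accuracy, but that accuracy says little about how good the model actually is.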
In a dataset with highly unbalanced classes, if the classifier always "predicts" the most common class without performing any analysis of the features, it will still have a high accuracy rate, obviously illusory.
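To make the point concrete, here is a minimal sketch using only the Python standard library. The label list is toy data invented for illustration, and `majority_baseline_accuracy` is a hypothetical helper name, not part of any library.

```python
from collections import Counter


def majority_baseline_accuracy(labels):
    """Accuracy of a classifier that always predicts the most common label."""
    counts = Counter(labels)
    majority_label, majority_count = counts.most_common(1)[0]
    return majority_label, majority_count / len(labels)


# Toy labels (invented for illustration): 95 "normal" records, 5 "attack" records.
toy_labels = ["normal"] * 95 + ["attack"] * 5
label, accuracy = majority_baseline_accuracy(toy_labels)
print(label, accuracy)  # normal 0.95

# Recall on the minority class is zero, because "attack" is never predicted.
predictions = [label] * len(toy_labels)
attack_hits = sum(
    1 for truth, pred in zip(toy_labels, predictions)
    if truth == "attack" and pred == "attack"
)
attack_recall = attack_hits / toy_labels.count("attack")
print(attack_recall)  # 0.0
```

The 95% accuracy comes entirely from the class proportions, which is why per-class metrics such as recall are more informative on imbalanced data.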
Solutions
There are two main directions for dealing with class imbalance: under-sampling and over-sampling. (A widely adopted technique for dealing with highly unbalanced datasets is called resampling. It consists of removing samples from the majority class (under-sampling) and/or adding more examples from the minority class (over-sampling).)
Under-sampling
Under-sampling draws a subset of the samples from the larger class so that its size matches that of the smaller class. (There are many ways to choose the subset.)
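The simplest variant, random under-sampling, might be sketched as follows with only the standard library. `random_under_sample` is a hypothetical helper written for this illustration, and the rows and labels are toy data.

```python
import random
from collections import defaultdict


def random_under_sample(rows, labels, seed=0):
    """Randomly drop rows until every class is as small as the smallest class."""
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for row, label in zip(rows, labels):
        by_class[label].append(row)
    target = min(len(group) for group in by_class.values())
    out_rows, out_labels = [], []
    for label, group in by_class.items():
        kept = rng.sample(group, target)  # sampling without replacement
        out_rows.extend(kept)
        out_labels.extend([label] * target)
    return out_rows, out_labels


# Toy data (invented): 9 rows of class "a" and 3 rows of class "b".
rows = list(range(12))
labels = ["a"] * 9 + ["b"] * 3
under_rows, under_labels = random_under_sample(rows, labels)
print(under_labels.count("a"), under_labels.count("b"))  # 3 3
```

Six of the nine "a" rows are thrown away here, which is exactly the information loss this approach trades for balance.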
Over-sampling
Over-sampling resamples the smaller class so that it has more samples. (There are many ways to do the resampling.)
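Random over-sampling, the simplest form, could look roughly like this. The sketch uses only the standard library; `random_over_sample` is a hypothetical helper and the data is a toy example.

```python
import random
from collections import defaultdict


def random_over_sample(rows, labels, seed=0):
    """Randomly repeat rows until every class is as large as the largest class."""
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for row, label in zip(rows, labels):
        by_class[label].append(row)
    target = max(len(group) for group in by_class.values())
    out_rows, out_labels = [], []
    for label, group in by_class.items():
        # Draw the missing rows with replacement from the class itself.
        extra = rng.choices(group, k=target - len(group))
        out_rows.extend(group + extra)
        out_labels.extend([label] * target)
    return out_rows, out_labels


# Toy data (invented): 9 rows of class "a" and 3 rows of class "b".
rows = list(range(12))
labels = ["a"] * 9 + ["b"] * 3
over_rows, over_labels = random_over_sample(rows, labels)
print(over_labels.count("a"), over_labels.count("b"))  # 9 9
```

Every added "b" row is an exact copy of one of the three originals, so no new information enters the dataset.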
The figure below summarizes the difference between under-sampling and over-sampling.
Of course, both approaches come at a cost: under-sampling can lose information, and over-sampling can cause overfitting.
Despite the advantage of balancing classes, these techniques also have their weaknesses (there is no free lunch). The simplest implementation of over-sampling is to duplicate random records from the minority class, which can cause overfitting. In under-sampling, the simplest technique involves removing random records from the majority class, which can cause loss of information.
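One way to get a feel for these costs before resampling is to count how many records each strategy would discard or duplicate. The sketch below assumes the simplest policies (shrink every class to the smallest one, or grow every class to the largest one); `resampling_cost` is a hypothetical helper and the class counts are invented.

```python
def resampling_cost(class_counts):
    """Return (records discarded by under-sampling, records duplicated by over-sampling)."""
    smallest = min(class_counts.values())
    largest = max(class_counts.values())
    discarded = sum(n - smallest for n in class_counts.values())
    duplicated = sum(largest - n for n in class_counts.values())
    return discarded, duplicated


# Toy class counts, invented for illustration.
toy_counts = {"a": 1000, "b": 200, "c": 50}
discarded, duplicated = resampling_cost(toy_counts)
print(discarded, duplicated)  # 1100 1750

# Share of the original data that under-sampling would throw away.
print(discarded / sum(toy_counts.values()))  # 0.88
```

With counts this skewed, under-sampling keeps only 12% of the data, while over-sampling produces a dataset in which more than half of the rows are copies.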
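A simple experiment
Below we run a simple experiment on the NSL-KDD dataset. Only simple over-sampling and under-sampling are implemented here; for other sampling methods, see the references mentioned above, repeated here.
- An excellent reference: Resampling strategies for imbalanced datasets
- A brief introduction to the principles: imbalanced-learn
Preparing the dataset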
- import pandas as pd
- import numpy as np
- import matplotlib.pyplot as plt
Next, load the dataset.
- COL_NAMES = ["duration", "protocol_type", "service", "flag", "src_bytes",
-     "dst_bytes", "land", "wrong_fragment", "urgent", "hot", "num_failed_logins",
-     "logged_in", "num_compromised", "root_shell", "su_attempted", "num_root",
-     "num_file_creations", "num_shells", "num_access_files", "num_outbound_cmds",
-     "is_host_login", "is_guest_login", "count", "srv_count", "serror_rate",
-     "srv_serror_rate", "rerror_rate", "srv_rerror_rate", "same_srv_rate",
-     "diff_srv_rate", "srv_diff_host_rate", "dst_host_count", "dst_host_srv_count",
-     "dst_host_same_srv_rate", "dst_host_diff_srv_rate", "dst_host_same_src_port_rate",
-     "dst_host_srv_diff_host_rate", "dst_host_serror_rate", "dst_host_srv_serror_rate",
-     "dst_host_rerror_rate", "dst_host_srv_rerror_rate", "labels"]
- # Load the dataset
- Trainfilepath = './NSL-KDD/KDDTrain+.txt'
- dfDataTrain = pd.read_csv(Trainfilepath, names=COL_NAMES, index_col=False)
Let's take a quick look at how the attack types are distributed.
- target_count = dfDataTrain.labels.value_counts()
- target_count.plot(kind='barh', title='Count (target)');
Here we experiment with only four attack types: back, neptune, smurf, and teardrop. Let's take a quick look at how these four are distributed.
- DataBack = dfDataTrain[dfDataTrain['labels'] == 'back']
- DataNeptune = dfDataTrain[dfDataTrain['labels'] == 'neptune']
- DataSmurf = dfDataTrain[dfDataTrain['labels'] == 'smurf']
- DataTeardrop = dfDataTrain[dfDataTrain['labels'] == 'teardrop']
- # Merge the four subsets into a new dataset and shuffle the rows
- DataAll = pd.concat([DataBack, DataNeptune, DataSmurf, DataTeardrop], axis=0, ignore_index=True).sample(frac=1)
- # Check the distribution of each class
- target_count = DataAll.labels.value_counts()
- target_count.plot(kind='barh', title='Count (target)');
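The class distribution plotted above can also be summarized numerically. The sketch below computes each class's share and the ratio between the largest and smallest class from a plain list of labels. It uses toy labels rather than the NSL-KDD data, and `class_summary` is a hypothetical helper.

```python
from collections import Counter


def class_summary(labels):
    """Return per-class counts, per-class shares, and the largest/smallest ratio."""
    counts = Counter(labels)
    total = sum(counts.values())
    shares = {label: n / total for label, n in counts.items()}
    ratio = max(counts.values()) / min(counts.values())
    return counts, shares, ratio


# Toy labels (invented for illustration), not the real attack counts.
toy_labels = ["a"] * 40 + ["b"] * 8 + ["c"] * 2
counts, shares, ratio = class_summary(toy_labels)
print(dict(counts))  # {'a': 40, 'b': 8, 'c': 2}
print(shares["a"])   # 0.8
print(ratio)         # 20.0
```

With a pandas column, the same counts come from `value_counts()`, as in the code above.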
Over-Sampling
We use simple over-sampling, that is, repeating existing records of the smaller classes to increase their count.
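- from imblearn.over_sampling import RandomOverSampler
- # Simple random over-sampling
- ros = RandomOverSampler()
- X = DataAll.iloc[:, :41].to_numpy()
- y = DataAll['labels'].to_numpy()
- # fit_resample replaces fit_sample, which newer imbalanced-learn releases no longer provide
- X_ros, y_ros = ros.fit_resample(X, y)
- print(X_ros.shape[0] - X.shape[0], 'new random picked points')
- # Rebuild a pandas DataFrame (stored separately so DataAll keeps the original data)
- DataOver = pd.DataFrame(X_ros, columns=COL_NAMES[:-1])
- DataOver['labels'] = y_ros
- # Visualize the result
- target_count = DataOver.labels.value_counts()
- target_count.plot(kind='barh', title='Count (target)');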
Looking at the final result, every class now has 40000+ samples, the same number the largest class had before.
Under-Sampling
Next is a simple implementation of under-sampling, which likewise just removes records directly from the larger classes.
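- from imblearn.under_sampling import RandomUnderSampler
- # Run this on the merged four-class data from before over-sampling
- rus = RandomUnderSampler()
- X = DataAll.iloc[:, :41].to_numpy()
- y = DataAll['labels'].to_numpy()
- X_rus, y_rus = rus.fit_resample(X, y)
- # Indices of the kept rows (replaces the removed return_indices=True option)
- id_rus = rus.sample_indices_
- # Rebuild a pandas DataFrame
- DataUnder = pd.DataFrame(X_rus, columns=COL_NAMES[:-1])
- DataUnder['labels'] = y_rus
- # Plot the result
- target_count = DataUnder.labels.value_counts()
- target_count.plot(kind='barh', title='Count (target)');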
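As you can see, every class now has 800+ samples, which completes the under-sampling.
This is only a brief introduction to over-sampling and under-sampling; for other sampling methods, see the references mentioned above.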