1. 程式人生 > >過取樣(處理資料不平衡問題)

過取樣(處理資料不平衡問題)

import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.model_selection import GridSearchCV
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import classification_report
from imblearn.over_sampling import SMOTE

def load_and_analyse_data():
    data = pd.read_csv('./data/creditcard.csv')
    # ----------------------預處理---------------------------------------------

    # ----------------------標準化Amount列---------
    data['normAmout'] = StandardScaler().fit_transform(data['Amount'].values.reshape(-1, 1))
    data = data.drop(['Time', 'Amount'], axis=1)
    # ----------------------------------------------

    X = data.ix[:, data.columns != 'Class']
    y = data.ix[:, data.columns == 'Class']
    X_train,X_test,y_train,y_test = train_test_split(X,y,test_size=0.3,random_state=0)
    # ----------------------取樣-------------------
    sample_solver = SMOTE(random_state=0)
    X_sample ,y_sample = sample_solver.fit_sample(X_train,y_train)#從原始的訓練集採出樣本,用來訓練模型
    return np.array(X_test),np.array(y_test).reshape(len(y_test)),np.array(X_sample),np.array(y_sample).reshape(len(y_sample))

if __name__ == '__main__':
    X_test, y_test, X_sample, y_sample  = load_and_analyse_data()
    X_train,X_dev,y_train,y_dev = train_test_split(X_sample,y_sample,test_size=0.3,random_state=1)

    print("X_train:{}  X_dev:{}  X_test:{}".format(len(y_train), len(y_dev), len(y_test)))
    model = LogisticRegression()
    parameters = {'C':[0.001,0.003,0.01,0.03,0.1,0.3,1,3,10]}
    gs  = GridSearchCV(model,parameters,verbose=5,cv=5)
    gs.fit(X_train,y_train)#訓練模型,訓練集為取樣後的資料
    print('最佳模型:',gs.best_params_,gs.best_score_)
    print('在取樣資料上的效能表現:')
    print(gs.score(X_dev,y_dev))
    y_dev_pre = gs.predict(X_dev)
    print(classification_report(y_dev,y_dev_pre))
    print('在原始資料上的效能表現:')
    print(gs.score(X_test,y_test))
    y_pre = gs.predict(X_test)
    print(classification_report(y_test,y_pre))

 

資料:

連結: https://pan.baidu.com/s/1OlZ-nkS4sbjSgoaetqqOGg 提取碼: ggr8

什麼是過取樣:

目的:處理資料不平衡問題。

方法:當資料不平衡的時,比如樣本標籤1有10000個數據,樣本標籤0有100個數據,這時如果採用下采樣會浪費很多樣本,

所以引入過取樣,過取樣是根據樣本標籤少的樣本的規律去生成更多該標籤樣本,這樣使得資料趨向於平衡。

典型的過取樣方式是SMOTE等

關於SMOTE具體演算法:

https://blog.csdn.net/jiede1/article/details/70215477

1、對於少數類中每一個樣本x,以歐氏距離為標準計算它到少數類樣本集Smin中所有樣本的距離,得到其k近鄰。
2、根據樣本不平衡比例設定一個取樣比例以確定取樣倍率N,對於每一個少數類樣本x,從其k近鄰中隨機選擇若干個樣本,假設選擇的近鄰為xn。
3、對於每一個隨機選出的近鄰xn,分別與原樣本按照如下的公式構建新的樣本 。

        

 

效果對比: