codemosi的專欄,點選我可以跳到目錄一欄
本文以spark 1.0.0版本MLlib演算法為準進行分析
一、程式碼結構
邏輯迴歸程式碼主要包含三個部分
1、classfication:邏輯迴歸分類器
2、optimization:優化方法,包含了隨機梯度、LBFGS兩種演算法
3、evaluation:演算法效果評估計算
二、邏輯迴歸分類器
1、LogisticRegressionModel類
(1) 根據訓練資料集得到的weights來預測新的資料點的分類
(2)預測新資料分類
採用 這個公式來進行預測。
其中w為權重向量weightMatrix,X表示預測資料dataMatrix,a
threshold變數用來控制分類的閾值,預設值為0.5。表示如果預測值<threshold則為分類0.0,否則為1.0
如果threshold設定為空,這會輸出實際值
2、LogisticRegressionWithSGD類
此類主要接收外部資料集、演算法引數等輸入進行訓練得到一個邏輯迴歸模型LogisticRegressionModel
接收的輸入引數包括:
input:輸入資料集合,分類標籤lable只能是1.0和0.0兩種,feature為double型別
numIterations:迭代次數,預設為100
stepSize:迭代步伐大小,預設為1.0
miniBatchFraction:每次迭代參與計算的樣本比例,預設為1.0
initialWeights:weight向量初始值,預設為0向量
regParam:regularization正則化控制引數,預設值為0.0
在LogisticRegressionWithSGD中可以看出它使用了GradientDescent(梯度下降)來優化weight引數的
3、GeneralizedLinearModel類
LogisticRegressionWithSGD中的run方法會呼叫GeneralizedLinearModel中的run方法來訓練訓練資料
在run方法中最關鍵的就是optimize方法,正是通過它來求得weightMatrix的最優解
三、優化方法
邏輯迴歸採用了梯度下降演算法來尋找weight的最優解
邏輯迴歸cost function
其中:
對J(Θ)求導數後得到梯度為:
1、GradientDescent類
負責梯度下降演算法的執行,分為Gradient梯度計算與weight update兩個步驟來計算
2、Gradient類
負責演算法梯度計算,包含了LogisticGradient、LeastSquaresGradient、HingeGradient三種梯度計算實現,本文主要介紹LogisticGradient的實現:
其中data為公式中的x,label為公式中的y,weights為公式中的Θ
gradient就是對J(Θ)求導的計算結果, loss為J(Θ)的計算結果
3、Updater類
負責weight的迭代更新計算,包含了SimpleUpdater、L1Updater、SquaredL2Updater三種更新策略
(1)SimpleUpdater
沒有使用regularization,weights更新規則為:
其中:iter表示這是執行的第幾次迭代
(2)L1Updater
使用了L1 regularization(R(w) = ||w||),利用soft-thresholding方法求解,weight更新規則為:
signum是符號函式,它的取值如下:
(3)SquaredL2Updater
使用了L2 regularization(R(w) = 1/2 ||w||^2),weights更新規則為:
注意:Mllib中的邏輯迴歸演算法預設使用的SimpleUpdater
四、演算法效果評估
BinaryClassificationMetrics類中包含了多種演算法演算法效果評估計算方法:
相關 | 不相關 | |
檢索到 | true positives (tp) | false positives(fp) |
未檢索到 |
false negatives(fn) | true negatives (tn) |
1、ROC(receiver operating characteristic接收者操作特徵)
調整分類器threshold取值,以FPR為橫座標,TPR為縱座標做ROC曲線
Area Under roc Curve(AUC):處於ROC curve下方的那部分面積的大小
通常,AUC的值介於0.5到1.0之間,較大的AUC代表了較好的效能
2、precision-recall(準確率-召回率)
準確率和召回率是互相影響的,理想情況下肯定是做到兩者都高,
但是一般情況下準確率高、召回率就低,召回率低、準確率高,
當然如果兩者都低,那是什麼地方出問題了
3、F-Measure
在precision與recall都要求高的情況下,可以用F來衡量