1. 程式人生 > 其它 >C++線性迴歸(SGD、BGD、MBGD)

C++線性迴歸(SGD、BGD、MBGD)

技術標籤:機器學習機器學習大資料

程式碼

#include<iostream>
#include<random>

using namespace std;

double w = 0;//y=w*x+b,初始w
double b = 0;//y=w*x+b,初始b
double true_w = 10.0;//真實w
double true_b = 10.0;//真實b
double alpha = 0.3;//步長
int samples = 1000;//樣本數

double* x = new double[samples];//樣本特徵
double* y = new double[samples]
;//樣本標籤 /* * 函式:生成資料 * 輸入:無 * 輸出:無 */ void generator() { default_random_engine gen;//隨機數 normal_distribution<double>normal(0.01, 0.01);//正態分佈 for (int i = 0; i < samples; i++) { x[i] = rand() * 1.0 / RAND_MAX;//x取值為[0,1] y[i] = true_w * x[i] + true_b + normal(gen);//根據x計算y } } /* * 函式:隨機梯度下降 * 輸入:特徵值x,標籤y * 輸出:w、b的預測值 */
void SGD(double* x, double* y) { for (int i = 0; i < samples; i++) { double TempW = w - 2 * alpha * x[i] * (w * x[i] + b - y[i]);//求導 double TempB = b - 2 * alpha * (w * x[i] + b - y[i]);//求導 w = TempW;//更新 b = TempB;//更新 } cout << "w=" << w << " b=" <<
b << endl; } /* * 函式:批量梯度下降 * 輸入:特徵值x,標籤y,停止求解的閾值dert * 輸出:w、b的預測值 */ void BGD(double* x, double* y, double dert) { double TempW = 0; double tempB = 0; while (1) { double RcdW = w; double RcdB = b; for (int i = 0; i < samples; i++) { TempW += 2 * x[i] * (w * x[i] + b - y[i]);//導數累加 tempB += 2 * (w * x[i] + b - y[i]);//導數累加 } w -= alpha * TempW / (2.0 * samples);//更新 b -= alpha * tempB / (2.0 * samples);//更新 if (abs(RcdW - w) < dert && abs(RcdB - b) < dert)break; } cout << "w=" << w << " b=" << b << endl; } /* * 函式:批量梯度下降 * 輸入:特徵值x,標籤y,每批數量 * 輸出:w、b的預測值 */ void MBGD(double* x, double* y, int batch) { for (int i = 0; i < samples / batch; i++) { double TempW = 0; double TempB = 0; for (int j = 0; j < batch; j++) { int k = i * batch + j; TempW += 2 * x[k] * (w * x[k] + b - y[k]);//導數累加 TempB += 2 * (w * x[k] + b - y[k]);//導數累加 } w -= alpha * TempW / (2.0 * samples);//更新 b -= alpha * TempB / (2.0 * samples);//更新 } cout << "w=" << w << " b=" << b << endl; } int main(){ generator(); SGD(x, y); generator(); BGD(x, y, 0.001); generator(); MBGD(x, y, 10); system("pause"); }

結果

最後的求解結果還行,記錄一下。
擬合結果