使用numpy跟sympy實作Linear regression « Terrence的宅宅幻想
這一陣子在上機器學習的課,對於線性回歸(linear regression)的演算法一直不是很理解
這幾天嘗試用numpy去實作才總算真的能夠稍微理解,今天這裡就用簡單線性回歸(Simple linear regression)為例來做個筆記
線性回歸(linear regression)
個人的理解就是想要找出一群資料(Data Set)的線性關係(回歸方程式),以對未來未知的新進資料進行推估或預測.用在機器學習領域則可以做簡單的二元分類.做法是把資料丟到回歸方程式看出來的結果是否大於某個門檻值(threshold)
用簡單線性回歸作說明,假設有一組(x, y)資料代表散落在平面上的一堆點
拉回正題,簡單線性回歸因為只有一個feature(或稱自變數),所以其回歸方程式可以再二維平面上簡單表示成
這回歸方程式可以表示x與y之間的線性關係的一條線,可以用來推估新的x進來的時候,他的y可能是多少
現在問題來了,這條線(方程式)怎麼來的
網路上還蠻容易找到已經套好的“公式解”直接求a跟b的值,但是我想更深入探討求a跟b過程的精神
參考上圖每筆資料的應變數y跟用回歸方程式y=ax+b求出來的y之間都存在著一個距離一般稱之為殘差(residual)
對於我們來說希望這條線到所有點的距離總和也就是殘差的和要最小,這樣這條線就最能代表這群資料的線性關係
對於這種求差距最小的方法常用的是最小平方法(least squares method)
最小平方法(least squares method)
資料(x, y)跟回歸方程式求出來的點(x, y^=ax+b)的差距(y - y^)可以用以下公式表示
而因為是“差”所以有可能有正有負,因此取平方保證差距是個正值以確保之後加總
在此之前我一直有個疑問,為什麼不用絕對值就好呢?
查了一些資料跟複習了大一微積分之後才多少有點領略
使用平方(squares)取正值有兩個原因
針對第一點,統計學上在某些情況下差距只取絕對值的時候會難以看出資料間的離散程度,但是透過平方除了可以取正值之外最重要的一個因素就是“放大差異
而最後我們可以得到所有點的殘差和(sum residual)的公式如下
對於回歸方程式的系數a,b我們的最終目標是讓上面這條公式算出來的差距總和“最小”,這樣就確保y=ax+b這條線到所有點的距離總和最小
而要怎麼讓上面這個公式有最小值呢?以前的時候我百思不得其解,回頭翻大一微積分的時候終於比較能理解做法跟使用微積分的原因,答案是用微積分的“梯度(Gradient)”
在談梯度之前先聊聊為什麼用微積分,在做線性回歸(linear regression)的時候由於求的解是一條直線,如果單純用絕對值去算殘差和的話他的表示式仍是直線
直線沒辦法知道極限值(Limit)在哪裡,要求最小值(也可以說是極限值)比較難以有直觀的解.但方程式經過平方(squares)後就不一樣了.它會變成一條曲線.
有曲線之後就會有山谷出現,有山谷自然就會有最低點或最高點(極限值),透過微分可以簡單替我們找出一條曲線什麼時候會有最低點
舉個簡單例子,參考以下公式及其代表曲線
上面這條曲線的最低點(極限值),透過微分方程式=0的時候可以找到極限值這個性質去找最低點
微分的解是x=1的時候我們可以找到該曲線的極限值
拉回正題,回頭看我們的殘差和公式
上面這條公式要如何求最佳解呢,套用梯度的定義,我們分別針對a跟b作微分可以得到梯度向量
我們將得兩組微分方程式可以對a,b解聯立方程式求得a跟b的最佳解,這就是最小平方法了
舉個實際的例子,假設有三個點(1, 2), (2, 4), (3, 3),求代表這三個點的直線回歸方程式
上面三筆資料帶入殘差和的公式可以得到
展開後可以得到
現在的目標就是要求上面這條公式帶入什麼樣的a跟b會有最小值,現在分別對a跟b微分可以到兩條方程式
對a微分
對b微分
只要將上面兩式子分別等於0就可以用聯立方程式求出a,b
最後得到a=0.5跟b=2這條線就是我們想找的最佳解
使用numpy跟sympy實作
用numpy可以方便我們作數值運算,用sympy則可以替我們處理方程式的計算
首先產生測試資料,用y=x這條方程式為基準去隨機產生一些點,上下震盪幅度約為2
上面的結果大概會長的像下面一樣
簡單解釋一下linspace這個函式會在一個區間產生連續的一群資料,除了可以當產生資料的工具之外
用在畫圖上也可以用來畫方程式線,vectorize是把一個把python函式包裝成一個可以對numpy的資料進行運算的新function
現在開始進行linear regression的實作、如下
用sympy的symbols可以替我們產生可以當作方程式未知數的“符號”並且可以進行數學幾何運算
之後再用solve就可以很簡單的算出方程式=0的解
之後再用求到的a,b畫出回歸方程式的線
其結果會長得像如下
藍色的線是我們用來產生資料的原始方程式,綠色的線則是經由回歸算出來的回歸方程式
由於資料點不多,回歸方程式跟產生資料的方程式差的有點多
這就牽扯到機器學習跟統計的問題了,資料要夠多才能誤差比較小預測比較精準
下面一個資料比較多的結果,可以看出來回歸方程式已經非常接近原始產生資料的方程式了
這次只用簡單線性回歸作介紹,但上面這個方法可以在套用到多個feature,不管有多少個自變數方法都一樣
希望下次有機會再用多變數進行分享