計算與推斷思維 十三、預測
十三、預測
資料科學的一個重要方面,是發現數據可以告訴我們什麼未來的事情。氣候和汙染的資料說了幾十年內溫度的什麼事情?根據一個人的網際網路個人資訊,哪些網站可能會讓他感興趣?病人的病史如何用來判斷他或她對治療的反應?
為了回答這樣的問題,資料科學家已經開發出了預測的方法。在本章中,我們將研究一種最常用的方法,基於一個變數的值來預測另一個變數。
方法的基礎由弗朗西斯·高爾頓爵士(Sir Francis Galton)奠定。我們在 7.1 節看到,高爾頓研究了身體特徵是如何從一代傳到下一代的。他最著名的工作之一,是根據父母的高度預測子女的身高。我們已經研究了高爾頓為此收集的資料集。heights
# Galton's data on heights of parents and their adult children
galton = Table.read_table('galton.csv')
heights = Table().with_columns(
'MidParent', galton.column('midparentHeight'),
'Child', galton.column('childHeight')
)
heights
MidParent | Child |
---|---|
75.43 | 73.2 |
75.43 | 69.2 |
75.43 | 69 |
75.43 | 69 |
73.66 | 73.5 |
73.66 | 72.5 |
73.66 | 65.5 |
73.66 | 65.5 |
72.06 | 71 |
72.06 | 68 |
(省略了 924 行)
heights.scatter('MidParent')
收集資料的主要原因是能夠預測成年子女的身高,他們的父母與資料集中相似。 在注意到兩個變數之間的正相關之後,我們在第 7.1 節中做了這些預測。
我們的方法是,基於新人的雙親身高周圍的所有點來做預測。 為此,我們編寫了一個名為predict_child
def predict_child(mpht):
"""Return a prediction of the height of a child
whose parents have a midparent height of mpht.
The prediction is the average height of the children
whose midparent height is in the range mpht plus or minus 0.5 inches.
"""
close_points = heights.where('MidParent', are.between(mpht-0.5, mpht + 0.5))
return close_points.column('Child').mean()
我們將函式應用於Midparent
列,視覺化我們的結果。
# Apply predict_child to all the midparent heights
heights_with_predictions = heights.with_column(
'Prediction', heights.apply(predict_child, 'MidParent')
)
# Draw the original scatter plot along with the predicted values
heights_with_predictions.scatter('MidParent')
給定雙親身高的預測值,大致位於給定身高處的垂直條形的中心。這種預測方法稱為迴歸。 本章後面我們會看到這個術語的來源。 我們也會看到,我們是否可以避免將“接近”任意定義為“在半英寸之內”。 但是首先我們要開發一個可用於很多環境的方法,來決定一個變數作為另一個變數的預測值有多好。
相關性
在本節中,我們將開發一種度量,度量散點圖緊密聚集在一條直線上的程度。 形式上,這被稱為測量線性關聯。
hybrid
表包含了 1997 年到 2013 年在美國銷售的混合動力車的資料。資料來自佛羅里達大學 Larry Winner 教授的線上資料檔案。這些列為:
vehicle
:車的型號year
:出廠年份msrp
: 2013 年製造商的建議零售價(美元)acceleration
: 加速度(千米每小時每秒)mpg
: 燃油效率(英里每加侖)class
: 型號的類別
(省略了 143 行)
下圖是msrp
與acceleration
的散點圖。 這意味著msrp
繪製在縱軸上並且acceleration
在橫軸上。
hybrid.scatter('acceleration', 'msrp')
注意正相關。 散點圖傾斜向上,表明加速度較大的車輛通常成本更高;相反,價格更高的汽車通常具有更大的加速。
msrp
與mpg
的散點圖表明瞭負相關。 mpg
較高的混合動力車往往成本較低。 這似乎令人驚訝,直到你明白了,加速更快的汽車往往燃油效率更低,行駛里程更低。 之前的散點圖顯示,這些也是價格更高的車型。
hybrid.scatter('mpg', 'msrp')
除了負相關,價格與效率的散點圖顯示了兩個變數之間的非線性關係。 這些點似乎圍繞在一條曲線周圍,而不是一條直線。
但是,如果我們只將資料限制在 SUV 類別中,價格和效率之間仍然負相關的,但是這種關係似乎更為線性。 SUV 價格與加速度之間的關係也呈線性趨勢,但是斜率是正的。
suv = hybrid.where('class', 'SUV')
suv.scatter('mpg', 'msrp')
suv.scatter('acceleration', 'msrp')
你會注意到,即使不關注變數被測量的單位,我們也可以從散點圖的大體方向和形狀中得到有用的資訊。
事實上,我們可以將所有的變數繪製成標準單位,並且繪圖看起來是一樣的。 這給了我們一個方法,來比較兩個散點圖中的線性程度。
回想一下,在前面的章節中,我們定義了standard_units
函式來將數值陣列轉換為標準單位。
def standard_units(any_numbers):
"Convert any array of numbers to standard units."
return (any_numbers - np.mean(any_numbers))/np.std(any_numbers)
我們可以使用這個函式重新繪製 SUV 的兩個散點圖,所有變數都以標準單位測量。
Table().with_columns(
'mpg (standard units)', standard_units(suv.column('mpg')),
'msrp (standard units)', standard_units(suv.column('msrp'))
).scatter(0, 1)
plots.xlim(-3, 3)
plots.ylim(-3, 3);
Table().with_columns(
'acceleration (standard units)', standard_units(suv.column('acceleration')),
'msrp (standard units)', standard_units(suv.column('msrp'))
).scatter(0, 1)
plots.xlim(-3, 3)
plots.ylim(-3, 3);
我們在這些數字中看到的關聯與我們之前看到的一樣。 另外,由於現在兩張散點圖的刻度完全相同,我們可以看到,第二張圖中的線性關係比第一張圖中的線性關係更加模糊。
我們現在將定義一個度量,使用標準單位來量化我們看到的這種關聯。
相關係數
相關係數測量兩個變數之間線性關係的強度。 在圖形上,它測量散點圖聚集在一條直線上的程度。
相關係數這個術語不容易表述,所以它通常縮寫為相關性並用r
表示。
以下是一些關於r
的數學事實,我們將通過模擬觀察。
- 相關係數
r
是介於-1
和1
之間的數字。 r
度量了散點圖圍繞一條直線聚集的程度。- 如果散點圖是完美的向上傾斜的直線,
r = 1
,如果散點圖是完美的向下傾斜的直線,r = -1
。
函式r_scatter
接受r
值作為引數,模擬相關性非常接近r
的散點圖。 由於模擬中的隨機性,相關性不會完全等於r
。
呼叫r_scatter
幾次,以r
的不同值作為引數,並檢視散點圖如何變化。
當r = 1
時,散點圖是完全線性的,向上傾斜。 當r = -1
時,散點圖是完全線性的,向下傾斜。 當r = 0
時,散點圖是圍繞水平軸的不定形雲,並且變數據說是不相關的。
r_scatter(0.9)
r_scatter(0.25)
r_scatter(0)
r_scatter(-0.55)
計算r
目前為止,r
的公式還不清楚。 它擁有超出本課程範圍的數學基礎。 然而,你將會看到,這個計算很簡單,可以幫助我們理解r
的幾個屬性。
r
的公式:
r
是兩個變數的乘積的均值,這兩個變數都以標準單位來衡量。
以下是計算中的步驟。 我們將把這些步驟應用於x
和y
值的簡單表格。
x = np.arange(1, 7, 1)
y = make_array(2, 3, 1, 5, 2, 7)
t = Table().with_columns(
'x', x,
'y', y
)
t
x | y |
---|---|
1 | 2 |
2 | 3 |
3 | 1 |
4 | 5 |
5 | 2 |
6 | 7 |
根據散點圖,我們預計r
將是正值,但不等於 1。
t.scatter(0, 1, s=30, color='red')
第一步:將每個變數轉換為標準單位。
t_su = t.with_columns(
'x (standard units)', standard_units(x),
'y (standard units)', standard_units(y)
)
t_su
x | y | x (standard units) | y (standard units) |
---|---|---|---|
1 | 2 | -1.46385 | -0.648886 |
2 | 3 | -0.87831 | -0.162221 |
3 | 1 | -0.29277 | -1.13555 |
4 | 5 | 0.29277 | 0.811107 |
5 | 2 | 0.87831 | -0.648886 |
6 | 7 | 1.46385 | 1.78444 |
第二步:將每一對標準單位相乘
t_product = t_su.with_column('product of standard units', t_su.column(2) * t_su.column(3))
t_product
x | y | x (standard units) | y (standard units) | product of standard units |
---|---|---|---|---|
1 | 2 | -1.46385 | -0.648886 | 0.949871 |
2 | 3 | -0.87831 | -0.162221 | 0.142481 |
3 | 1 | -0.29277 | -1.13555 | 0.332455 |
4 | 5 | 0.29277 | 0.811107 | 0.237468 |
5 | 2 | 0.87831 | -0.648886 | -0.569923 |
6 | 7 | 1.46385 | 1.78444 | 2.61215 |
第三步:r
是第二步計算的乘積的均值。
# r is the average of the products of standard units
r = np.mean(t_product.column(4))
r
0.61741639718977093
正如我們的預期,r
是個不等於的正值。
r
的性質
計算結果表明:
r
是一個純數字。 它沒有單位。 這是因為r
基於標準單位。
r
不受任何軸上單位的影響。 這也是因為r
基於標準單位。
r
不受軸的交換的影響。 在代數上,這是因為標準單位的乘積不依賴於哪個變數被稱為x
和y
。 在幾何上,軸的切換關於y = x
直線翻轉了散點圖,但不會改變群聚度和關聯的符號。
t.scatter('y', 'x', s=30, color='red')
correlation
函式
我們將要重複計算相關性,所以定義一個函式會有幫助,這個函式通過執行上述所有步驟來計算它。 讓我們定義一個函式correlation
,它接受一個表格,和兩列的標籤。該函式返回r
,它是標準單位下這些列的值的乘積的平均值。
def correlation(t, x, y):
return np.mean(standard_units(t.column(x))*standard_units(t.column(y)))
讓我們在t
的x
和y
列上呼叫函式。 該函式返回x
和y
之間的相關性的相同答案,就像直接應用r
的公式一樣。
correlation(t, 'x', 'y')
0.61741639718977093
我們注意到,變數被指定的順序並不重要。
correlation(t, 'y', 'x')
0.61741639718977093
在suv
表的列上呼叫correlation
,可以使我們看到價格和效率之間的相關性,以及價格和加速度之間的相關性。
correlation(suv, 'mpg', 'msrp')
-0.6667143635709919
correlation(suv, 'acceleration', 'msrp')
0.48699799279959155
這些數值證實了我們的觀察:
價格和效率之間存在負相關關係,而價格和加速度之間存在正相關關係。
價格和加速度之間的線性關係(相關性約為 0.5),比價格和效率之間的線性關係稍弱(相關性約為 -0.67)。
相關性是一個簡單而強大的概念,但有時會被誤用。 在使用r
之前,重要的是要知道相關效能做和不能做什麼。
相關不是因果
相關只衡量關聯,並不意味著因果。 儘管學區內的孩子的體重與數學能力之間的相關性可能是正的,但這並不意味著做數學會使孩子更重,或者說增加體重會提高孩子的數學能力。 年齡是一個使人混淆的變數:平均來說,較大的孩子比較小的孩子更重,數學能力更好。
相關性度量線性關聯
相關性只測量一種關聯 - 線性關聯。 具有較強非線性關聯的變數可能具有非常低的相關性。 這裡有一個變數的例子,它具有完美的二次關聯y = x ^ 2
,但是相關性等於 0。
new_x = np.arange(-4, 4.1, 0.5)
nonlinear = Table().with_columns(
'x', new_x,
'y', new_x**2
)
nonlinear.scatter('x', 'y', s=30, color='r')
correlation(nonlinear, 'x', 'y')
0.0
相關性受到離群點影響
離群點可能對相關性有很大的影響。 下面是一個例子,其中通過增加一個離群點,r
等於 1 的散點圖變成r
等於 0 的圖。
line = Table().with_columns(
'x', make_array(1, 2, 3, 4),
'y', make_array(1, 2, 3, 4)
)
line.scatter('x', 'y', s=30, color='r')
correlation(line, 'x', 'y')
1.0
outlier = Table().with_columns(
'x', make_array(1, 2, 3, 4, 5),
'y', make_array(1, 2, 3, 4, 0)
)
outlier.scatter('x', 'y', s=30, color='r')
correlation(outlier, 'x', 'y')
0.0
生態相關性應謹慎解讀
基於彙總資料的相關性可能會產生誤導。 作為一個例子,這裡是 2014 年 SAT 批判性閱讀和數學成績的資料。50 個州和華盛頓特區各有一個點。Participation Rate
列包含參加考試的高中學生的百分比。 接下來的三列顯示了每個州的測試每個部分的平均得分,最後一列是測試總得分的平均值。
sat2014 = Table.read_table('sat2014.csv').sort('State')
sat2014
State | Participation Rate | Critical Reading | Math | Writing | Combined |
---|---|---|---|---|---|
Alabama | 6.7 | 547 | 538 | 532 | 1617 |
Alaska | 54.2 | 507 | 503 | 475 | 1485 |
Arizona | 36.4 | 522 | 525 | 500 | 1547 |
Arkansas | 4.2 | 573 | 571 | 554 | 1698 |
California | 60.3 | 498 | 510 | 496 | 1504 |
Colorado | 14.3 | 582 | 586 | 567 | 1735 |
Connecticut | 88.4 | 507 | 510 | 508 | 1525 |
Delaware | 100 | 456 | 459 | 444 | 1359 |
District of Columbia | 100 | 440 | 438 | 431 | 1309 |
Florida | 72.2 | 491 | 485 | 472 | 1448 |
(省略了 41 行)
數學得分與批判性閱讀得分的散點圖緊密聚集在一條直線上; 相關性接近 0.985。
sat2014.scatter('Critical Reading', 'Math')
correlation(sat2014, 'Critical Reading', 'Math')
0.98475584110674341
這是個非常高的相關性。但重要的是要注意,這並不能反映學生的數學和批判性閱讀得分之間的關係強度。
資料由每個州的平均分陣列成。但是各州不參加考試 - 而是學生。表中的資料通過將每個州的所有學生聚集為(這個州里面的兩個變數的均值處的)單個點而建立。但並不是所有州的學生都會在這個位置,因為學生的表現各不相同。如果你為每個學生繪製一個點,而不是每個州一個點,那麼在上圖中的每個點周圍都會有一圈雲狀的點。整體畫面會更模糊。學生的數學和批判性閱讀得分之間的相關性,將低於基於州均值計算的數值。
基於聚合和均值的相關性被稱為生態相關性,並且經常用於報告。正如我們剛剛所看到的,他們必須謹慎解讀。
嚴重還是開玩笑?
2012 年,在著名的《新英格蘭醫學雜誌》(New England Journal of Medicine)上發表的一篇論文,研究了一組國家巧克力消費與的諾貝爾獎之間的關係。《科學美國人》(Scientific American)嚴肅地做出迴應,而其他人更加輕鬆。 歡迎你自行決定!下面的圖表應該讓你有興趣去看看。
迴歸直線
相關係數r
並不只是測量散點圖中的點聚集在一條直線上的程度。 它也有助於確定點聚集的直線。 在這一節中,我們將追溯高爾頓和皮爾遜發現這條直線的路線。
高爾頓的父母及其成年子女身高的資料顯示出線性關係。 當我們基於雙親身高的子女身高的預測大致沿著直線時,就證實了線性。
galton = Table.read_table('galton.csv')
heights = Table().with_columns(
'MidParent', galton.column('midparentHeight'),
'Child', galton.column('childHeight')
)
def predict_child(mpht):
"""Return a prediction of the height of a child
whose parents have a midparent height of mpht.
The prediction is the average height of the children
whose midparent height is in the range mpht plus or minus 0.5 inches.
"""
close_points = heights.where('MidParent', are.between(mpht-0.5, mpht + 0.5))
return close_points.column('Child').mean()
heights_with_predictions = heights.with_column(
'Prediction', heights.apply(predict_child, 'MidParent')
)
heights_with_predictions.scatter('MidParent')
標準單位下的度量
讓我們看看,我們是否能找到一個方法來確定這條線。 首先,注意到線性關聯不依賴於度量單位 - 我們也可以用標準單位來衡量這兩個變數。
def standard_units(xyz):
"Convert any array of numbers to standard units."
return (xyz - np.mean(xyz))/np.std(xyz)
heights_SU = Table().with_columns(
'MidParent SU', standard_units(heights.column('MidParent')),
'Child SU', standard_units(heights.column('Child'))
)
heights_SU
MidParent SU | Child SU |
---|---|
3.45465 | 1.80416 |
3.45465 | 0.686005 |
3.45465 | 0.630097 |
3.45465 | 0.630097 |
2.47209 | 1.88802 |
2.47209 | 1.60848 |
2.47209 | -0.348285 |
2.47209 | -0.348285 |
1.58389 | 1.18917 |
1.58389 | 0.350559 |
(省略了 924 行)
在這個刻度上,我們可以像以前一樣精確地計算我們的預測。 但是首先我們必須弄清楚,如何將“接近”的點的舊定義轉換為新的刻度上的一個值。 我們曾經說過,如果雙親高度在 0.5 英寸之內,它們就是“接近”的。 由於標準單位以標準差為單位測量距離,所以我們必須計算出,0.5 英寸是多少個雙親身高的標準差。
雙親身高的標準差約為 1.8 英寸。 所以 0.5 英寸約為 0.28 個標準差。
sd_midparent = np.std(heights.column(0))
sd_midparent
1.8014050969207571
0.5/sd_midparent
0.27756111096536701
現在我們準備修改我們的預測函式,來預測標準單位。 所有改變的是,我們正在使用標準單位的值的表格,並定義如上所述的“接近”。
def predict_child_su(mpht_su):
"""Return a prediction of the height (in standard units) of a child
whose parents have a midparent height of mpht_su in standard units.
"""
close = 0.5/sd_midparent
close_points = heights_SU.where('MidParent SU', are.between(mpht_su-close, mpht_su + close))
return close_points.column('Child SU').mean()
heights_with_su_predictions = heights_SU.with_column(
'Prediction SU', heights_SU.apply(predict_child_su, 'MidParent SU')
)
heights_with_su_predictions.scatter('MidParent SU')
這個繪圖看起來就像在原始刻度上繪圖。 只改變了軸上的數字。 這證實了我們可以通過在標準單位下工作,來理解預測過程。
確定標準單位下的直線
高爾頓的散點圖形狀是個橄欖球 - 就是說,像橄欖球一樣大致橢圓形。不是所有的散點圖都是橄欖形的,甚至那些線性關聯的也不都是。但在這一節中,我們假裝我們是高爾頓,只能處理橄欖形的散點圖。在下一節中,我們將把我們的分析推廣到其他形狀的繪圖。
這裡是一個橄欖形散點圖,兩個變數以標準單位測量。 45 度線顯示為紅色。
但是 45 度線不是經過垂直條形的中心的線。你可以看到在下圖中,1.5 個標準單位的垂直線顯示為黑色。藍線附近的散點圖上的點的高度都大致在 -2 到 3 的範圍內。紅線太高,無法命中中心。
所以 45 度線不是“均值圖”。該線是下面顯示的綠線。
兩條線都經過原點(0,0)
。綠線穿過垂直條形的中心(至少大概),比紅色的 45 度線平坦。
45 度線的斜率為 1。所以綠色的“均值圖”直線的斜率是正值但小於 1。
這可能是什麼值呢?你猜對了 - 這是r
。
標準單位下的迴歸直線
綠色的“均值圖”線被稱為迴歸直線,我們將很快解釋原因。 但首先,讓我們模擬一些r
值不同的橄欖形散點圖,看看直線是如何變化的。 在每種情況中,繪製紅色 45 度線作比較。
執行模擬的函式為regression_line
,並以r
為引數。
regression_line(0.95)
regression_line(0.6)
當r
接近於 1 時,散點圖,45 度線和迴歸線都非常接近。 但是對於r
較低值來說,迴歸線顯然更平坦。
迴歸效應
就預測而言,這意味著,對於雙親身高為 1.5 個標準單位的家長來說,我們對女子身高的預測要稍低於 1.5 個標準單位。如果雙親高度是 2 個標準單位,我們對子女身高的預測,會比 2 個標準單位少一些。
換句話說,我們預測,子女會比父母更接近均值。
弗朗西斯·高爾頓爵士就不高興了。他一直希望,特別高的父母會有特別高的子女。然而,資料是清楚的,高爾頓意識到,高個子父母通常擁有並不是特別高的子女。高爾頓沮喪地將這種現象稱為“迴歸平庸”。
高爾頓還注意到,特別矮的父母通常擁有相對於他們這一代高一些的子女。一般來說,一個變數的平均值遠遠低於另一個變數的平均值。這被稱為迴歸效應。
迴歸直線的方程
在迴歸中,我們使用一個變數(我們稱x
)的值來預測另一個變數的值(我們稱之為y
)。 當變數x
和y
以標準單位測量時,基於x
預測y
的迴歸線斜率為r
並通過原點。 因此,迴歸線的方程可寫為:
在資料的原始單位下,就變成了:
原始單位的迴歸線的斜率和截距可以從上圖中匯出。
下面的三個函式計算相關性,斜率和截距。 它們都有三個引數:表的名稱,包含x
的列的標籤以及包含y
的列的標籤。
def correlation(t, label_x, label_y):
return np.mean(standard_units(t.column(label_x))*standard_units(t.column(label_y)))
def slope(t, label_x, label_y):
r = correlation(t, label_x, label_y)
return r*np.std(t.column(label_y))/np.std(t.column(label_x))
def intercept(t, label_x, label_y):
return np.mean(t.column(label_y)) - slope(t, label_x, label_y)*np.mean(t.column(label_x))
迴歸直線和高爾頓的資料
雙親身高和子女身高之間的相關性是 0.32:
galton_r = correlation(heights, 'MidParent', 'Child')
galton_r
0.32094989606395924
我們也可以找到迴歸直線的方程,來基於雙親身高預測子女身高:
galton_slope = slope(heights, 'MidParent', 'Child')
galton_intercept = intercept(heights, 'MidParent', 'Child')
galton_slope, galton_intercept
(0.63736089696947895, 22.636240549589751)
迴歸直線的方程是:
這也成為迴歸方程。迴歸方程的主要用途是根據x
預測y
。
例如,對於 70.48 英寸的雙親身高,迴歸直線預測,子女身高為 67.56 英寸。
galton_slope*70.48 + galton_intercept
67.557436567998622
我們最初的預測,通過計算雙親身高接近 70.48 的所有子女的平均身高來完成,這個預測非常接近:67.63 英寸,而回歸線的預測是 67.55 英寸。
heights_with_predictions.where('MidParent', are.equal_to(70.48)).show(3)
MidParent | Child | Prediction |
---|---|---|
70.48 | 74 | 67.6342 |
70.48 | 70 | 67.6342 |
70.48 | 68 | 67.6342 |
(省略了 5 行)
這裡是高爾頓的表格的所有行,我們的原始預測,以及子女身高的迴歸預測。
heights_with_predictions = heights_with_predictions.with_column(
'Regression Prediction', galton_slope*heights.column('MidParent') + galton_intercept
)
heights_with_predictions
MidParent | Child | Prediction | Regression Prediction |
---|---|---|---|
75.43 | 73.2 | 70.1 | 70.7124 |
75.43 | 69.2 | 70.1 | 70.7124 |
75.43 | 69 | 70.1 | 70.7124 |
75.43 | 69 | 70.1 | 70.7124 |
73.66 | 73.5 | 70.4158 | 69.5842 |
73.66 | 72.5 | 70.4158 | 69.5842 |
73.66 | 65.5 | 70.4158 | 69.5842 |
73.66 | 65.5 | 70.4158 | 69.5842 |
72.06 | 71 | 68.5025 | 68.5645 |
72.06 | 68 | 68.5025 | 68.5645 |
(省略了 924 行)
heights_with_predictions.scatter('MidParent')
灰色圓點顯示迴歸預測,全部在迴歸線上。 注意這條線與均值的金色圖非常接近。 對於這些資料,迴歸線很好地逼近垂直條形的中心。
擬合值
所有的預測值都在直線上,被稱為“擬合值”。 函式fit
使用表名和x
和y
的標籤,並返回一個擬合值陣列,散點圖中每個點一個。
def fit(table, x, y):
"""Return the height of the regression line at each x value."""
a = slope(table, x, y)
b = intercept(table, x, y)
return a * table.column(x) + b
下圖比上圖更輕易看到直線:
heights.with_column('Fitted', fit(heights, 'MidParent', 'Child')).scatter('MidParent')
另一個繪製直線的方式是在表方法scatter
中,使用選項fit_line=True
。
heights.scatter('MidParent', fit_line=True)
斜率的測量單位
斜率是一個比值,值得花點時間來研究它的測量單位。 我們的例子來自熟悉的醫院系統中產婦的資料集。 孕期體重與高度的散點圖看起來像是一個橄欖球,已經在一場比賽中使用了很多次,但足夠接近橄欖球,我們可以讓我們的擬合直線穿過它來證明。 在後面的章節中,我們將看到如何使這種證明更正式。
baby = Table.read_table('baby.csv')
baby.scatter('Maternal Height', 'Maternal Pregnancy Weight', fit_line=True)
slope(baby, 'Maternal Height', 'Maternal Pregnancy Weight')
3.5728462592750558
迴歸線的斜率是 3.57 磅每英寸。 這意味著,對於身高相差 1 英寸的兩名女性來說,我們對孕期體重的預測相差 3.57 磅。 對於身高相差 2 英寸的女性,我們預測的孕期體重相差2 * 3.57 ~= 7.14
磅。
請注意,散點圖中的連續垂直條形相距 1 英寸,因為高度已經舍入到最近的英寸。 另一種考慮斜率的方法是取兩個相連的條形(相隔 1 英寸),相當於兩組身高相差 1 英寸的女性。 3.57 磅每英寸的斜率意味著,較高組的平均孕期體重比較矮組多大約 3.57 磅。
示例
假設我們的目標是使用迴歸,基於巴塞特獵犬的體重來估計它的身高,所用的樣本與迴歸模型看起來一致。 假設觀察到的相關性r
為 0.5,並且這兩個變數的彙總統計量如下表所示:
average | SD |
---|---|
height | 14 inches |
weight | 50 pounds |
為了計算迴歸線的方程,我們需要斜率和截距。
迴歸線的方程允許我們,根據給定重量(磅)計算估計高度(英寸):
線的斜率衡量隨著重量的單位增長的估計高度的增長。 斜率是正值,重要的是要注意,這並不表示我們認為,如果體重增加巴塞特獵狗就會變得更高。 斜率反映了兩組狗的平均身高的差異,這兩組狗的體重相差 1 磅。 具體來說,考慮一組重量為w
磅,以及另一組重量為w + 1
磅的狗。 我們估計,第二組的均值高出 0.2 英寸。 對於樣本中的所有w
值都是如此。
一般來說,迴歸線的斜率可以解釋為隨著x
單位增長的y
平均增長。 請注意,如果斜率為負值,那麼對於x
的每單位增長,y
的平均值會減少。
尾註
即使我們沒有建立迴歸方程的數學基礎,我們可以看到,當散點圖是橄欖形的時候,它會給出相當好的預測。 這是一個令人驚訝的數學事實,無論散點圖的形狀如何,同一個方程給出所有直線中的“最好”的預測。 這是下一節的主題。
最小二乘法
我們已經回溯了高爾頓和皮爾森用於開發迴歸線方程的步驟,它穿過橄欖形的散點圖。但不是所有的散點圖都是橄欖形的,甚至不是線性的。每個散點圖都有一個“最優”直線嗎?如果是這樣,我們仍然可以使用上一節中開發的斜率和截距公式,還是需要新的公式?
為了解決這些問題,我們需要一個“最優”的合理定義。回想一下,這條線的目的是預測或估計y
的值,在給定x
值的情況下。估計通常不是完美的。每個值都由於誤差而偏離真正的值。“最優”直線的合理標準是,它在所有直線中總體誤差儘可能最小。
在本節中,我們將精確確定這個標準,看看我們能否確定標準下的最優直線。
我們的第一個例子是小說《小女人》資料集,每章都有一行。目標是根據句子數來估計字元數(即字母,空格標點符號等等)。回想一下,我們在本課程的第一堂課中試圖實現它。
little_women = Table.read_table('little_women.csv')
little_women = little_women.move_to_start('Periods')
little_women.show(3)
Periods | Characters |
---|---|
189 | 21759 |
188 | 22148 |
231 | 20558 |
(省略了 44 行)
little_women.scatter('Periods', 'Characters')
為了探索資料,我們將需要使用上一節定義的函式correlation
,slope
,intercept
和fit
。
correlation(little_women, 'Periods', 'Characters')
0.92295768958548163
散點圖明顯接近線性,相關性大於 0.92。
估計中的誤差
下圖顯示了我們在上一節中開發的散點圖和直線。 我們還不知道這是否是所有直線中最優的。 我們首先必須準確表達“最優”的意思。
lw_with_predictions = little_women.with_column('Linear Prediction', fit(little_women, 'Periods', 'Characters'))
lw_with_predictions.scatter('Periods')
對應於散點圖上的每個點,預測的誤差是計算為實際值減去預測值。 它是點與直線之間的垂直距離,如果點線上之下,則為負值。
actual = lw_with_predictions.column('Characters')
predicted = lw_with_predictions.column('Linear Prediction')
errors = actual - predicted
lw_with_predictions.with_column('Error', errors)
Periods | Characters | Linear Prediction | Error |
---|---|---|---|
189 | 21759 | 21183.6 | 575.403 |
188 | 22148 | 21096.6 | 1051.38 |
231 | 20558 | 24836.7 | -4278.67 |
195 | 25526 | 21705.5 | 3820.54 |
255 | 23395 | 26924.1 | -3529.13 |
140 | 14622 | 16921.7 | -2299.68 |
131 | 14431 | 16138.9 | -1707.88 |
214 | 22476 | 23358 | -882.043 |
337 | 33767 | 34056.3 | -289.317 |
185 | 18508 | 20835.7 | -2327.69 |
(省略了 37 行)
我們可以使用slope
和intercept
來計算擬合直線的斜率和截距。 下圖顯示了該直線(淺藍色)。 對應於四個點的誤差以紅色顯示。 這四個點沒什麼特別的。 他們只是為了展示的清晰而被選中。 函式lw_errors
以斜率和截距(按照該順序)作為引數,並繪製該圖形。
lw_reg_slope = slope(little_women, 'Periods', 'Characters')
lw_reg_intercept = intercept(little_women, 'Periods', 'Characters')
print('Slope of Regression Line: ', np.round(lw_reg_slope), 'characters per period')
print('Intercept of Regression Line:', np.round(lw_reg_intercept), 'characters')
lw_errors(lw_reg_slope, lw_reg_intercept)
Slope of Regression Line: 87.0 characters per period
Intercept of Regression Line: 4745.0 characters
如果我們用不同的線來建立我們的估計,誤差將會不同。 下面的圖表顯示瞭如果我們使用另一條線進行估算,誤差會有多大。 第二張圖顯示了通過使用完全愚蠢的線獲得了較大誤差。
lw_errors(50, 10000)
lw_errors(-100, 50000)
均方根誤差(RMSE)
我們現在需要的是誤差大小的一個總體衡量。 你會認識到建立它的方法 - 這正是我們開發標準差的方式。
如果你使用任意直線來計算你的估計值,那麼你的一些誤差可能是正的,而其他的則是負的。 為了避免誤差大小在測量時抵消,我們將採用誤差平方的均值而不是誤差的均值。
估計的均方誤差大概是誤差的平方有多大,但正如我們前面提到的,它的單位很難解釋。 取平方根產生均方根誤差(RMSE),與預測變數的單位相同,因此更容易理解。
使 RMSE 最小
到目前為止,我們的觀察可以總結如下。
- 要根據
x
估算y
,可以使用任何你想要的直線。 - 每個直線都有估計的均方根誤差。
- “更好”的直線有更小的誤差。
有沒有“最好”的直線? 也就是說,是否有一條線可以使所有行中的均方根誤差最小?
為了回答這個問題,我們首先定義一個函式lw_rmse
,通過《小女人》的散點圖來計算任意直線的均方根誤差。 函式將斜率和截距(按此順序)作為引數。
def lw_rmse(slope, intercept):
lw_errors(slope, intercept)
x = little_women.column('Periods')
y = little_women.column('Characters')
fitted = slope * x + intercept
mse = np.mean((y - fitted) ** 2)
print("Root mean squared error:", mse ** 0.5)
lw_rmse(50, 10000)
Root mean squared error: 4322.16783177
lw_rmse(-100, 50000)
Root mean squared