pytorch入門 | 簡單的線性模型
阿新 • • 發佈:2022-05-30
1.程式碼
import torch x_data = [1.0, 2.0, 3.0] y_data = [2.0, 4.0, 6.0] w = torch.tensor([1.0]) # w的初值為1.0 w.requires_grad = True # 表示需要計算梯度,(只有手動設定需要計算梯度的時候才會去計算) def forward(x): return x*w # w是一個Tensor (所以這個乘法運算子已經被過載了,x已經被自動型別轉換為了tensor型別) def loss(x, y): y_pred = forward(x) return (y_pred - y)**2 print("predict (before training)", 4, forward(4).item()) for epoch in range(100): for x, y in zip(x_data, y_data): l =loss(x,y) # 前饋過程 (l是一個張量,tensor主要是在建立計算圖 forward, compute the loss) l.backward() # 自動地把計算鏈路(計算圖)上所有需要計算梯度的地方上的梯度都求出來(backward,compute grad for Tensor whose requires_grad set to True) # 隨即計算圖被釋放,下次再進行l =loss(x,y計算的時候會建立新的計算圖 #(注:有時候構建的神經網路每一次執行的時候計算圖可能是不一樣的,每進行一次反向傳播之後將計算圖釋放,準備下一次的圖,這是一種靈活的方式) print('\tgrad:', x, y, w.grad.item())# 這裡的item()作用是直接把梯度裡面的數值直接拿出來,變成python中的標量,也是為了防止產生計算圖 w.data = w.data - 0.01 * w.grad.data # 這裡不能直接使用w.grad,這裡的w.grad也是一個tensor(權重更新時,需要用到標量);如果直接使用w.grad 則是在建立計算圖; # 取標量w.grad.data 來計算時是不會再建立計算圖的 # 這裡這是對權重進行純數值的修改 w.grad.data.zero_() # after update, remember set the grad to zero # 每一次運算後都需要將上一次的梯度記錄清空(釋放) print('progress:', epoch, l.item()) # 訓練的;輪數以及該輪最後訓練時候的loss值 # 取出loss使用l.item,不要直接使用l(l是tensor型別,會構建計算圖) print("predict (after training)", 4, forward(4).item())
predict (before training) 4 4.0 grad: 1.0 2.0 -2.0 grad: 2.0 4.0 -7.840000152587891 grad: 3.0 6.0 -16.228801727294922 progress: 0 7.315943717956543 grad: 1.0 2.0 -1.478623867034912 grad: 2.0 4.0 -5.796205520629883 grad: 3.0 6.0 -11.998146057128906 progress: 1 3.9987640380859375 grad: 1.0 2.0 -1.0931644439697266 grad: 2.0 4.0 -4.285204887390137 grad: 3.0 6.0 -8.870372772216797 progress: 2 2.1856532096862793 grad: 1.0 2.0 -0.8081896305084229 grad: 2.0 4.0 -3.1681032180786133 grad: 3.0 6.0 -6.557973861694336 progress: 3 1.1946394443511963 grad: 1.0 2.0 -0.5975041389465332 grad: 2.0 4.0 -2.3422164916992188 grad: 3.0 6.0 -4.848389625549316 progress: 4 0.6529689431190491 grad: 1.0 2.0 -0.4417421817779541 grad: 2.0 4.0 -1.7316293716430664 grad: 3.0 6.0 -3.58447265625 progress: 5 0.35690122842788696 grad: 1.0 2.0 -0.3265852928161621 grad: 2.0 4.0 -1.2802143096923828 grad: 3.0 6.0 -2.650045394897461 progress: 6 0.195076122879982 grad: 1.0 2.0 -0.24144840240478516 grad: 2.0 4.0 -0.9464778900146484 grad: 3.0 6.0 -1.9592113494873047 progress: 7 0.10662525147199631 grad: 1.0 2.0 -0.17850565910339355 grad: 2.0 4.0 -0.699742317199707 grad: 3.0 6.0 -1.4484672546386719 progress: 8 0.0582793727517128 grad: 1.0 2.0 -0.1319713592529297 grad: 2.0 4.0 -0.5173273086547852 grad: 3.0 6.0 -1.070866584777832 progress: 9 0.03185431286692619 grad: 1.0 2.0 -0.09756779670715332 grad: 2.0 4.0 -0.3824653625488281 grad: 3.0 6.0 -0.7917022705078125 progress: 10 0.017410902306437492 grad: 1.0 2.0 -0.07213282585144043 grad: 2.0 4.0 -0.2827606201171875 grad: 3.0 6.0 -0.5853137969970703 progress: 11 0.009516451507806778 grad: 1.0 2.0 -0.053328514099121094 grad: 2.0 4.0 -0.2090473175048828 grad: 3.0 6.0 -0.43272972106933594 progress: 12 0.005201528314501047 grad: 1.0 2.0 -0.039426326751708984 grad: 2.0 4.0 -0.15455150604248047 grad: 3.0 6.0 -0.3199195861816406 progress: 13 0.0028430151287466288 grad: 1.0 2.0 -0.029148340225219727 grad: 2.0 4.0 -0.11426162719726562 grad: 3.0 6.0 -0.23652076721191406 progress: 14 0.0015539465239271522 grad: 1.0 2.0 -0.021549701690673828 grad: 2.0 4.0 -0.08447456359863281 grad: 3.0 6.0 -0.17486286163330078 progress: 15 0.0008493617060594261 grad: 1.0 2.0 -0.01593184471130371 grad: 2.0 4.0 -0.062453269958496094 grad: 3.0 6.0 -0.12927818298339844 progress: 16 0.00046424579340964556 grad: 1.0 2.0 -0.011778593063354492 grad: 2.0 4.0 -0.046172142028808594 grad: 3.0 6.0 -0.09557533264160156 progress: 17 0.0002537401160225272 grad: 1.0 2.0 -0.00870823860168457 grad: 2.0 4.0 -0.03413581848144531 grad: 3.0 6.0 -0.07066154479980469 progress: 18 0.00013869594840798527 grad: 1.0 2.0 -0.006437778472900391 grad: 2.0 4.0 -0.025236129760742188 grad: 3.0 6.0 -0.052239418029785156 progress: 19 7.580435340059921e-05 grad: 1.0 2.0 -0.004759550094604492 grad: 2.0 4.0 -0.018657684326171875 grad: 3.0 6.0 -0.038620948791503906 progress: 20 4.143271507928148e-05 grad: 1.0 2.0 -0.003518819808959961 grad: 2.0 4.0 -0.0137939453125 grad: 3.0 6.0 -0.028553009033203125 progress: 21 2.264650902361609e-05 grad: 1.0 2.0 -0.00260162353515625 grad: 2.0 4.0 -0.010198593139648438 grad: 3.0 6.0 -0.021108627319335938 progress: 22 1.2377059647405986e-05 grad: 1.0 2.0 -0.0019233226776123047 grad: 2.0 4.0 -0.0075397491455078125 grad: 3.0 6.0 -0.0156097412109375 progress: 23 6.768445018678904e-06 grad: 1.0 2.0 -0.0014221668243408203 grad: 2.0 4.0 -0.0055751800537109375 grad: 3.0 6.0 -0.011541366577148438 progress: 24 3.7000872907810844e-06 grad: 1.0 2.0 -0.0010514259338378906 grad: 2.0 4.0 -0.0041217803955078125 grad: 3.0 6.0 -0.008531570434570312 progress: 25 2.021880391112063e-06 grad: 1.0 2.0 -0.0007772445678710938 grad: 2.0 4.0 -0.0030469894409179688 grad: 3.0 6.0 -0.006305694580078125 progress: 26 1.1044940038118511e-06 grad: 1.0 2.0 -0.0005745887756347656 grad: 2.0 4.0 -0.0022525787353515625 grad: 3.0 6.0 -0.0046634674072265625 progress: 27 6.041091182851233e-07 grad: 1.0 2.0 -0.0004248619079589844 grad: 2.0 4.0 -0.0016651153564453125 grad: 3.0 6.0 -0.003444671630859375 progress: 28 3.296045179013163e-07 grad: 1.0 2.0 -0.0003139972686767578 grad: 2.0 4.0 -0.0012311935424804688 grad: 3.0 6.0 -0.0025491714477539062 progress: 29 1.805076408345485e-07 grad: 1.0 2.0 -0.00023221969604492188 grad: 2.0 4.0 -0.0009107589721679688 grad: 3.0 6.0 -0.0018854141235351562 progress: 30 9.874406714516226e-08 grad: 1.0 2.0 -0.00017189979553222656 grad: 2.0 4.0 -0.0006742477416992188 grad: 3.0 6.0 -0.00139617919921875 progress: 31 5.4147676564753056e-08 grad: 1.0 2.0 -0.0001270771026611328 grad: 2.0 4.0 -0.0004978179931640625 grad: 3.0 6.0 -0.00102996826171875 progress: 32 2.9467628337442875e-08 grad: 1.0 2.0 -9.393692016601562e-05 grad: 2.0 4.0 -0.0003681182861328125 grad: 3.0 6.0 -0.0007610321044921875 progress: 33 1.6088051779661328e-08 grad: 1.0 2.0 -6.937980651855469e-05 grad: 2.0 4.0 -0.00027179718017578125 grad: 3.0 6.0 -0.000560760498046875 progress: 34 8.734787115827203e-09 grad: 1.0 2.0 -5.125999450683594e-05 grad: 2.0 4.0 -0.00020122528076171875 grad: 3.0 6.0 -0.0004177093505859375 progress: 35 4.8466972657479346e-09 grad: 1.0 2.0 -3.790855407714844e-05 grad: 2.0 4.0 -0.000148773193359375 grad: 3.0 6.0 -0.000308990478515625 progress: 36 2.6520865503698587e-09 grad: 1.0 2.0 -2.8133392333984375e-05 grad: 2.0 4.0 -0.000110626220703125 grad: 3.0 6.0 -0.0002288818359375 progress: 37 1.4551915228366852e-09 grad: 1.0 2.0 -2.09808349609375e-05 grad: 2.0 4.0 -8.20159912109375e-05 grad: 3.0 6.0 -0.00016880035400390625 progress: 38 7.914877642178908e-10 grad: 1.0 2.0 -1.5497207641601562e-05 grad: 2.0 4.0 -6.103515625e-05 grad: 3.0 6.0 -0.000125885009765625 progress: 39 4.4019543565809727e-10 grad: 1.0 2.0 -1.1444091796875e-05 grad: 2.0 4.0 -4.482269287109375e-05 grad: 3.0 6.0 -9.1552734375e-05 progress: 40 2.3283064365386963e-10 grad: 1.0 2.0 -8.344650268554688e-06 grad: 2.0 4.0 -3.24249267578125e-05 grad: 3.0 6.0 -6.580352783203125e-05 progress: 41 1.2028067430946976e-10 grad: 1.0 2.0 -5.9604644775390625e-06 grad: 2.0 4.0 -2.288818359375e-05 grad: 3.0 6.0 -4.57763671875e-05 progress: 42 5.820766091346741e-11 grad: 1.0 2.0 -4.291534423828125e-06 grad: 2.0 4.0 -1.71661376953125e-05 grad: 3.0 6.0 -3.719329833984375e-05 progress: 43 3.842615114990622e-11 grad: 1.0 2.0 -3.337860107421875e-06 grad: 2.0 4.0 -1.33514404296875e-05 grad: 3.0 6.0 -2.86102294921875e-05 progress: 44 2.2737367544323206e-11 grad: 1.0 2.0 -2.6226043701171875e-06 grad: 2.0 4.0 -1.049041748046875e-05 grad: 3.0 6.0 -2.288818359375e-05 progress: 45 1.4551915228366852e-11 grad: 1.0 2.0 -1.9073486328125e-06 grad: 2.0 4.0 -7.62939453125e-06 grad: 3.0 6.0 -1.430511474609375e-05 progress: 46 5.6843418860808015e-12 grad: 1.0 2.0 -1.430511474609375e-06 grad: 2.0 4.0 -5.7220458984375e-06 grad: 3.0 6.0 -1.1444091796875e-05 progress: 47 3.637978807091713e-12 grad: 1.0 2.0 -1.1920928955078125e-06 grad: 2.0 4.0 -4.76837158203125e-06 grad: 3.0 6.0 -1.1444091796875e-05 progress: 48 3.637978807091713e-12 grad: 1.0 2.0 -9.5367431640625e-07 grad: 2.0 4.0 -3.814697265625e-06 grad: 3.0 6.0 -8.58306884765625e-06 progress: 49 2.0463630789890885e-12 grad: 1.0 2.0 -7.152557373046875e-07 grad: 2.0 4.0 -2.86102294921875e-06 grad: 3.0 6.0 -5.7220458984375e-06 progress: 50 9.094947017729282e-13 grad: 1.0 2.0 -7.152557373046875e-07 grad: 2.0 4.0 -2.86102294921875e-06 grad: 3.0 6.0 -5.7220458984375e-06 progress: 51 9.094947017729282e-13 grad: 1.0 2.0 -7.152557373046875e-07 grad: 2.0 4.0 -2.86102294921875e-06 grad: 3.0 6.0 -5.7220458984375e-06 progress: 52 9.094947017729282e-13 grad: 1.0 2.0 -7.152557373046875e-07 grad: 2.0 4.0 -2.86102294921875e-06 grad: 3.0 6.0 -5.7220458984375e-06 progress: 53 9.094947017729282e-13 grad: 1.0 2.0 -7.152557373046875e-07 grad: 2.0 4.0 -2.86102294921875e-06 grad: 3.0 6.0 -5.7220458984375e-06 progress: 54 9.094947017729282e-13 grad: 1.0 2.0 -7.152557373046875e-07 grad: 2.0 4.0 -2.86102294921875e-06 grad: 3.0 6.0 -5.7220458984375e-06 progress: 55 9.094947017729282e-13 grad: 1.0 2.0 -7.152557373046875e-07 grad: 2.0 4.0 -2.86102294921875e-06 grad: 3.0 6.0 -5.7220458984375e-06 progress: 56 9.094947017729282e-13 grad: 1.0 2.0 -7.152557373046875e-07 grad: 2.0 4.0 -2.86102294921875e-06 grad: 3.0 6.0 -5.7220458984375e-06 progress: 57 9.094947017729282e-13 grad: 1.0 2.0 -7.152557373046875e-07 grad: 2.0 4.0 -2.86102294921875e-06 grad: 3.0 6.0 -5.7220458984375e-06 progress: 58 9.094947017729282e-13 grad: 1.0 2.0 -7.152557373046875e-07 grad: 2.0 4.0 -2.86102294921875e-06 grad: 3.0 6.0 -5.7220458984375e-06 progress: 59 9.094947017729282e-13 grad: 1.0 2.0 -7.152557373046875e-07 grad: 2.0 4.0 -2.86102294921875e-06 grad: 3.0 6.0 -5.7220458984375e-06 progress: 60 9.094947017729282e-13 grad: 1.0 2.0 -7.152557373046875e-07 grad: 2.0 4.0 -2.86102294921875e-06 grad: 3.0 6.0 -5.7220458984375e-06 progress: 61 9.094947017729282e-13 grad: 1.0 2.0 -7.152557373046875e-07 grad: 2.0 4.0 -2.86102294921875e-06 grad: 3.0 6.0 -5.7220458984375e-06 progress: 62 9.094947017729282e-13 grad: 1.0 2.0 -7.152557373046875e-07 grad: 2.0 4.0 -2.86102294921875e-06 grad: 3.0 6.0 -5.7220458984375e-06 progress: 63 9.094947017729282e-13 grad: 1.0 2.0 -7.152557373046875e-07 grad: 2.0 4.0 -2.86102294921875e-06 grad: 3.0 6.0 -5.7220458984375e-06 progress: 64 9.094947017729282e-13 grad: 1.0 2.0 -7.152557373046875e-07 grad: 2.0 4.0 -2.86102294921875e-06 grad: 3.0 6.0 -5.7220458984375e-06 progress: 65 9.094947017729282e-13 grad: 1.0 2.0 -7.152557373046875e-07 grad: 2.0 4.0 -2.86102294921875e-06 grad: 3.0 6.0 -5.7220458984375e-06 progress: 66 9.094947017729282e-13 grad: 1.0 2.0 -7.152557373046875e-07 grad: 2.0 4.0 -2.86102294921875e-06 grad: 3.0 6.0 -5.7220458984375e-06 progress: 67 9.094947017729282e-13 grad: 1.0 2.0 -7.152557373046875e-07 grad: 2.0 4.0 -2.86102294921875e-06 grad: 3.0 6.0 -5.7220458984375e-06 progress: 68 9.094947017729282e-13 grad: 1.0 2.0 -7.152557373046875e-07 grad: 2.0 4.0 -2.86102294921875e-06 grad: 3.0 6.0 -5.7220458984375e-06 progress: 69 9.094947017729282e-13 grad: 1.0 2.0 -7.152557373046875e-07 grad: 2.0 4.0 -2.86102294921875e-06 grad: 3.0 6.0 -5.7220458984375e-06 progress: 70 9.094947017729282e-13 grad: 1.0 2.0 -7.152557373046875e-07 grad: 2.0 4.0 -2.86102294921875e-06 grad: 3.0 6.0 -5.7220458984375e-06 progress: 71 9.094947017729282e-13 grad: 1.0 2.0 -7.152557373046875e-07 grad: 2.0 4.0 -2.86102294921875e-06 grad: 3.0 6.0 -5.7220458984375e-06 progress: 72 9.094947017729282e-13 grad: 1.0 2.0 -7.152557373046875e-07 grad: 2.0 4.0 -2.86102294921875e-06 grad: 3.0 6.0 -5.7220458984375e-06 progress: 73 9.094947017729282e-13 grad: 1.0 2.0 -7.152557373046875e-07 grad: 2.0 4.0 -2.86102294921875e-06 grad: 3.0 6.0 -5.7220458984375e-06 progress: 74 9.094947017729282e-13 grad: 1.0 2.0 -7.152557373046875e-07 grad: 2.0 4.0 -2.86102294921875e-06 grad: 3.0 6.0 -5.7220458984375e-06 progress: 75 9.094947017729282e-13 grad: 1.0 2.0 -7.152557373046875e-07 grad: 2.0 4.0 -2.86102294921875e-06 grad: 3.0 6.0 -5.7220458984375e-06 progress: 76 9.094947017729282e-13 grad: 1.0 2.0 -7.152557373046875e-07 grad: 2.0 4.0 -2.86102294921875e-06 grad: 3.0 6.0 -5.7220458984375e-06 progress: 77 9.094947017729282e-13 grad: 1.0 2.0 -7.152557373046875e-07 grad: 2.0 4.0 -2.86102294921875e-06 grad: 3.0 6.0 -5.7220458984375e-06 progress: 78 9.094947017729282e-13 grad: 1.0 2.0 -7.152557373046875e-07 grad: 2.0 4.0 -2.86102294921875e-06 grad: 3.0 6.0 -5.7220458984375e-06 progress: 79 9.094947017729282e-13 grad: 1.0 2.0 -7.152557373046875e-07 grad: 2.0 4.0 -2.86102294921875e-06 grad: 3.0 6.0 -5.7220458984375e-06 progress: 80 9.094947017729282e-13 grad: 1.0 2.0 -7.152557373046875e-07 grad: 2.0 4.0 -2.86102294921875e-06 grad: 3.0 6.0 -5.7220458984375e-06 progress: 81 9.094947017729282e-13 grad: 1.0 2.0 -7.152557373046875e-07 grad: 2.0 4.0 -2.86102294921875e-06 grad: 3.0 6.0 -5.7220458984375e-06 progress: 82 9.094947017729282e-13 grad: 1.0 2.0 -7.152557373046875e-07 grad: 2.0 4.0 -2.86102294921875e-06 grad: 3.0 6.0 -5.7220458984375e-06 progress: 83 9.094947017729282e-13 grad: 1.0 2.0 -7.152557373046875e-07 grad: 2.0 4.0 -2.86102294921875e-06 grad: 3.0 6.0 -5.7220458984375e-06 progress: 84 9.094947017729282e-13 grad: 1.0 2.0 -7.152557373046875e-07 grad: 2.0 4.0 -2.86102294921875e-06 grad: 3.0 6.0 -5.7220458984375e-06 progress: 85 9.094947017729282e-13 grad: 1.0 2.0 -7.152557373046875e-07 grad: 2.0 4.0 -2.86102294921875e-06 grad: 3.0 6.0 -5.7220458984375e-06 progress: 86 9.094947017729282e-13 grad: 1.0 2.0 -7.152557373046875e-07 grad: 2.0 4.0 -2.86102294921875e-06 grad: 3.0 6.0 -5.7220458984375e-06 progress: 87 9.094947017729282e-13 grad: 1.0 2.0 -7.152557373046875e-07 grad: 2.0 4.0 -2.86102294921875e-06 grad: 3.0 6.0 -5.7220458984375e-06 progress: 88 9.094947017729282e-13 grad: 1.0 2.0 -7.152557373046875e-07 grad: 2.0 4.0 -2.86102294921875e-06 grad: 3.0 6.0 -5.7220458984375e-06 progress: 89 9.094947017729282e-13 grad: 1.0 2.0 -7.152557373046875e-07 grad: 2.0 4.0 -2.86102294921875e-06 grad: 3.0 6.0 -5.7220458984375e-06 progress: 90 9.094947017729282e-13 grad: 1.0 2.0 -7.152557373046875e-07 grad: 2.0 4.0 -2.86102294921875e-06 grad: 3.0 6.0 -5.7220458984375e-06 progress: 91 9.094947017729282e-13 grad: 1.0 2.0 -7.152557373046875e-07 grad: 2.0 4.0 -2.86102294921875e-06 grad: 3.0 6.0 -5.7220458984375e-06 progress: 92 9.094947017729282e-13 grad: 1.0 2.0 -7.152557373046875e-07 grad: 2.0 4.0 -2.86102294921875e-06 grad: 3.0 6.0 -5.7220458984375e-06 progress: 93 9.094947017729282e-13 grad: 1.0 2.0 -7.152557373046875e-07 grad: 2.0 4.0 -2.86102294921875e-06 grad: 3.0 6.0 -5.7220458984375e-06 progress: 94 9.094947017729282e-13 grad: 1.0 2.0 -7.152557373046875e-07 grad: 2.0 4.0 -2.86102294921875e-06 grad: 3.0 6.0 -5.7220458984375e-06 progress: 95 9.094947017729282e-13 grad: 1.0 2.0 -7.152557373046875e-07 grad: 2.0 4.0 -2.86102294921875e-06 grad: 3.0 6.0 -5.7220458984375e-06 progress: 96 9.094947017729282e-13 grad: 1.0 2.0 -7.152557373046875e-07 grad: 2.0 4.0 -2.86102294921875e-06 grad: 3.0 6.0 -5.7220458984375e-06 progress: 97 9.094947017729282e-13 grad: 1.0 2.0 -7.152557373046875e-07 grad: 2.0 4.0 -2.86102294921875e-06 grad: 3.0 6.0 -5.7220458984375e-06 progress: 98 9.094947017729282e-13 grad: 1.0 2.0 -7.152557373046875e-07 grad: 2.0 4.0 -2.86102294921875e-06 grad: 3.0 6.0 -5.7220458984375e-06 progress: 99 9.094947017729282e-13 predict (after training) 4 7.999998569488525
2.說明
- w是Tensor(張量型別),Tensor中包含data和grad。
grad初始為None,呼叫l.backward()方法後w.grad為Tensor,故更新w.data時需使用w.grad.data。
-
“如果w需要計算梯度,那構建的計算圖中,跟w相關的tensor都預設需要計算梯度。”
w是Tensor, forward函式的返回值也是Tensor,loss函式的返回值也是Tensor
例如: -
w = torch.Tensor([1.0])和w = torch.tensor([1.0])都可以
-
本演算法中反向傳播主要體現在,l.backward()。呼叫該方法後w.grad由None更新為Tensor型別,且w.grad.data的值用於後續w.data的更新。
l.backward()會把計算圖中所有需要梯度(grad)的地方都會求出來,然後把梯度都存在對應的變數中,隨即計算圖被釋放。
- tensor型別與數值型別做加法計算的時候,要使用item(),防止產生計算圖而導致佔用記憶體