Deep Learning with PyTorch: One-Dimensional Linear Regression
# Code implementation of one-dimensional linear regression
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt

# Training data: 15 (x, y) pairs, each shaped (15, 1)
x_train = np.array([[3.3], [4.4], [5.5], [6.71], [6.93], [4.168],
                    [9.779], [6.182], [7.59], [2.167], [7.042],
                    [10.791], [5.313], [7.997], [3.1]], dtype=np.float32)
y_train = np.array([[1.7], [2.76], [2.09], [3.19], [1.694], [1.573],
                    [3.366], [2.596], [2.53], [1.221], [2.827],
                    [3.465], [1.65], [2.904], [1.3]], dtype=np.float32)

x_train = torch.from_numpy(x_train)
y_train = torch.from_numpy(y_train)

class LinearRegression(nn.Module):
    def __init__(self):
        super(LinearRegression, self).__init__()
        self.linear = nn.Linear(1, 1)  # one input feature, one output

    def forward(self, x):
        return self.linear(x)

model = LinearRegression()
criterion = nn.MSELoss()                            # mean squared error
optimizer = optim.SGD(model.parameters(), lr=1e-3)  # plain gradient descent

num_epochs = 2000
for epoch in range(num_epochs):
    inputs = x_train
    target = y_train

    # forward pass
    out = model(inputs)
    loss = criterion(out, target)

    # backward pass and parameter update
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if (epoch + 1) % 100 == 0:
        print('Epoch [{} / {}],loss {}'.format(epoch + 1, num_epochs, loss.item()))

# Switch to evaluation mode and plot the fitted line against the data
model.eval()
with torch.no_grad():
    predict = model(x_train).numpy()

plt.plot(x_train.numpy(), y_train.numpy(), 'ro', label="Original data")
plt.plot(x_train.numpy(), predict, 'b', label="Fitting data")
plt.legend()
plt.show()
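Here nn.Linear(1, 1) represents a single affine map, and nn.MSELoss is the mean squared error over the N = 15 training pairs that SGD drives down with respect to the weight w and bias b:

\[
\hat{y} = w x + b, \qquad L(w, b) = \frac{1}{N} \sum_{i=1}^{N} \left( w x_i + b - y_i \right)^2 .
\]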
Output:
Epoch [100 / 2000],loss 0.23131796717643738
Epoch [200 / 2000],loss 0.22819313406944275
Epoch [300 / 2000],loss 0.22522489726543427
Epoch [400 / 2000],loss 0.22240526974201202
Epoch [500 / 2000],loss 0.2197268307209015
Epoch [600 / 2000],loss 0.2171824872493744
Epoch [700 / 2000],loss 0.21476560831069946
Epoch [800 / 2000],loss 0.2124696969985962
Epoch [900 / 2000],loss 0.21028876304626465
Epoch [1000 / 2000],loss 0.2082170695066452
Epoch [1100 / 2000],loss 0.20624907314777374
Epoch [1200 / 2000],loss 0.20437967777252197
Epoch [1300 / 2000],loss 0.20260383188724518
Epoch [1400 / 2000],loss 0.20091693103313446
Epoch [1500 / 2000],loss 0.19931448996067047
Epoch [1600 / 2000],loss 0.19779227674007416
Epoch [1700 / 2000],loss 0.19634634256362915
Epoch [1800 / 2000],loss 0.19497279822826385
Epoch [1900 / 2000],loss 0.19366800785064697
Epoch [2000 / 2000],loss 0.1924285888671875
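After training, the fitted line can be read directly off the nn.Linear layer. As a quick sanity check, here is a minimal sketch (assuming model, x_train and y_train from the script above are still in scope) that compares the learned parameters against the closed-form least-squares line computed with np.polyfit:

import numpy as np

# Assumes `model`, `x_train` and `y_train` from the training script are in scope.
w = model.linear.weight.item()   # learned slope
b = model.linear.bias.item()     # learned intercept
print('learned fit:   y = {:.4f} * x + {:.4f}'.format(w, b))

# Independent reference: the closed-form least-squares line for the same data.
w_ls, b_ls = np.polyfit(x_train.numpy().ravel(), y_train.numpy().ravel(), 1)
print('least squares: y = {:.4f} * x + {:.4f}'.format(w_ls, b_ls))

Since the printed loss is still falling at epoch 2000, the learned parameters will not yet coincide with the closed-form solution; training longer or raising the learning rate would narrow the gap.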