`loss.data[0]` error: invalid index of a 0-dim tensor. Use `tensor.item()` in Python or `tensor.item<T>()` in C++ to convert a 0-dim tensor to a number
阿新 · Published: 2021-01-22
I've just started learning PyTorch.

Original code:
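```python
from torch.autograd import Variable

# logistic_regression, criterion, optimizer, x and y are assumed to be defined earlier
for e in range(1000):
    out = logistic_regression(Variable(x))
    loss = criterion(out, Variable(y))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if (e + 1) % 20 == 0:
        # Fails: loss is a 0-dim tensor, so loss.data cannot be indexed with [0]
        print('epoch: {}, loss: {}'.format(e + 1, loss.data[0]))
```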
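This raises the error in the title. Since PyTorch 0.4, a loss is returned as a 0-dim (scalar) tensor, so `loss.data` is 0-dim as well and has no element `[0]`. Newer versions raise an `IndexError` instead of returning the value.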
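The failure can be reproduced without a model. The sketch below uses a hypothetical scalar tensor (`0.6931`) as a stand-in for a loss value. It shows that indexing a 0-dim tensor raises `IndexError` and that `.item()`, which the error message suggests, returns a plain Python number.

```python
import torch

# Hypothetical stand-in for a loss value: a 0-dim (scalar) tensor.
loss = torch.tensor(0.6931)
print(loss.dim())  # prints 0: the tensor has no dimensions to index

caught = None
try:
    loss.data[0]  # same indexing as in the original training loop
except IndexError as err:
    caught = err
    print("IndexError:", err)

# .item() converts the 0-dim tensor to a Python float
value = loss.item()
print(type(value), value)
```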
Fixed version:

Deleting the `[0]` after `loss.data` makes it run:
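```python
from torch.autograd import Variable

# logistic_regression, criterion, optimizer, x and y are assumed to be defined earlier
for e in range(1000):
    out = logistic_regression(Variable(x))
    loss = criterion(out, Variable(y))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if (e + 1) % 20 == 0:
        # loss.data is itself a 0-dim tensor, so it is printed as tensor(...)
        print('epoch: {}, loss: {}'.format(e + 1, loss.data))
```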
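Removing `[0]` works, but `loss.data` is still a tensor, so the log line shows a `tensor(...)` wrapper. The error message itself recommends `.item()`, which yields a plain float.

Below is a minimal, self-contained sketch of the same loop using `.item()`. It rests on several assumptions that are not from the original post:

- the toy data `x` and `y`
- a model built as `nn.Linear` + `nn.Sigmoid`
- `nn.BCELoss`
- SGD with `lr=0.1`
- the `history` list

It also drops the `Variable` wrapper, which has not been needed since PyTorch 0.4.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical toy data: 100 two-dimensional points with binary labels.
x = torch.randn(100, 2)
y = (x[:, 0] + x[:, 1] > 0).float().unsqueeze(1)  # shape (100, 1)

# Hypothetical logistic regression: one linear layer followed by a sigmoid.
logistic_regression = nn.Sequential(nn.Linear(2, 1), nn.Sigmoid())
criterion = nn.BCELoss()
optimizer = torch.optim.SGD(logistic_regression.parameters(), lr=0.1)

history = []  # loss values recorded every 20 epochs
for e in range(1000):
    # Tensors track gradients directly; no Variable(...) wrapper needed.
    out = logistic_regression(x)
    loss = criterion(out, y)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    if (e + 1) % 20 == 0:
        # .item() turns the 0-dim loss tensor into a Python float.
        value = loss.item()
        history.append(value)
        print('epoch: {}, loss: {:.4f}'.format(e + 1, value))
```

Using `.item()` also makes it natural to collect loss values in an ordinary Python list, as `history` does here.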