Pytorch深入學習階段二(三)
阿新 • • 發佈:2022-04-07
Pytorch學習階段二(三)
一、真實的torch.nn
轉化資料型別:
x_train, y_train, x_valid, y_valid = map(
torch.tensor, (x_train, y_train, x_valid, y_valid)
)
torch.nn
- module:建立可呼叫物件,包含權重等狀態,並且可以更新權重
- Parameter:即需要被訓練的權重,設定
requires_grad
來設定更新 - functional:一個包含啟用函式,損失函式等的模型
torch.optim
:包含SGD等許多優化器,在後向傳播的過程中更新權重
Dataset
:__len__
__getitem__
重寫後為神經網路載入資料
DataLoader
:返回一個迭代器,可用於迭代資料
二、TensorBoard使用
初始化:
from torch.utils.tensorboard import SummaryWriter
# default `log_dir` is "runs" - we'll be more specific here
writer = SummaryWriter('runs/fashion_mnist_experiment_1')
新增圖片:
# write to tensorboard writer.add_image('four_fashion_mnist_images', img_grid)
run:
tensorboard --logdir=runs --port=8080
在中控臺點選: http://localhost:8080
或者瀏覽器瀏覽此網頁
新增視覺化網路:
writer.add_graph(net, images)
新增圖表:
# ...log the running loss writer.add_scalar('training loss', running_loss / 1000, epoch * len(trainloader) + i)