1. 程式人生 > 其它 >Pytorch深入學習階段二(三)

Pytorch深入學習階段二(三)

Pytorch學習階段二(三)

一、真實的torch.nn

轉化資料型別:

x_train, y_train, x_valid, y_valid = map(
    torch.tensor, (x_train, y_train, x_valid, y_valid)
)

torch.nn

  • module:建立可呼叫物件,包含權重等狀態,並且可以更新權重
  • Parameter:即需要被訓練的權重,設定requires_grad來設定更新
  • functional:一個包含啟用函式,損失函式等的模型

torch.optim:包含SGD等許多優化器,在後向傳播的過程中更新權重

Dataset__len__

__getitem__重寫後為神經網路載入資料

DataLoader:返回一個迭代器,可用於迭代資料

二、TensorBoard使用

初始化:

from torch.utils.tensorboard import SummaryWriter

# default `log_dir` is "runs" - we'll be more specific here
writer = SummaryWriter('runs/fashion_mnist_experiment_1')

新增圖片:

# write to tensorboard
writer.add_image('four_fashion_mnist_images', img_grid)

run:

tensorboard --logdir=runs --port=8080

在中控臺點選: http://localhost:8080或者瀏覽器瀏覽此網頁

新增視覺化網路:

writer.add_graph(net, images)

新增圖表:

# ...log the running loss
            writer.add_scalar('training loss',
                            running_loss / 1000,
                            epoch * len(trainloader) + i)