1. 程式人生 > 其它 >使用tensorboard繪製pytorch網路模型

使用tensorboard繪製pytorch網路模型

本文簡單介紹一下如何使用tensorboard繪製pytorch網路模型

版本
torch == 1.10.1
tensorboardX == 2.5

程式碼

import torch
import torch.nn as nn
from tensorboardX import SummaryWriter

# 定義超引數
batch_size = 64
learning_rate = 1e-2
num_epoches = 20


# 定義網路
class CNN(nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(1, 16, kernel_size=3),
            nn.BatchNorm2d(16),
            nn.ReLU(True))

        self.layer2 = nn.Sequential(
            nn.Conv2d(16, 32, kernel_size=3),
            nn.BatchNorm2d(32),
            nn.ReLU(True),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

        self.layer3 = nn.Sequential(
            nn.Conv2d(32, 64, kernel_size=3),
            nn.BatchNorm2d(64),
            nn.ReLU(True))

        self.layer4 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3),
            nn.BatchNorm2d(128),
            nn.ReLU(True),
            nn.MaxPool2d(kernel_size=2, stride=2)
        )

        self.fc = nn.Sequential(
            nn.Linear(128 * 4 * 4, 1024),
            nn.ReLU(True),
            nn.Linear(1024, 128),
            nn.ReLU(True),
            nn.Linear(128, 10)
        )

    def forward(self, x):
        x = self.layer1(x)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x


model = CNN()
dummy_input = torch.rand(20, 1, 28, 28)
writer = SummaryWriter('log')
with SummaryWriter(comment='LeNet') as w:
    w.add_graph(model, (dummy_input,))

在pychram中執行完程式碼後,可以看到在專案的根目錄下有一個名為runs的目錄,該目錄下有一個剛生成的目錄,例如 Apr04_16-33-37_LAPTOP-QGED210TLeNet

在pycharm的命令列中執行命令
tensorboard --logdir ./runs/Apr04_16-33-37_LAPTOP-QGED210TLeNet

然後pycharm會給出tensorboard的訪問路徑,例如http://localhost:6006/,接下來就可以用瀏覽器進行訪問了

參考: https://zhuanlan.zhihu.com/p/58961505