程式人生 > 其它 > 當 PyTorch 遇上 tqdm

當 PyTorch 遇上 tqdm

技術標籤:Python、PyTorch、python、深度學習、pytorch

以下是訓練模型時使用 tqdm 美化控制檯輸出的範本:
【資料集載入類】

import torch
from torch.utils.data import Dataset

__all__ = ["TestDataset"]


class TestDataset(Dataset):
    """Synthetic random dataset used to exercise the tqdm training-loop output.

    Holds one tensor of shape ``(20, 10, 1, 200, 200)`` filled by
    ``torch.rand`` (uniform values in [0, 1)). Samples are indexed along
    dim 0, so the dataset has 20 items.
    """

    def __init__(self):
        super().__init__()
        self.dataset = torch.rand(20, 10, 1, 200, 200)

    def __getitem__(self, item):
        """Return an ``(input, target)`` pair for sample ``item``.

        ``input`` is the full sample ``self.dataset[item]`` with shape
        ``(10, 1, 200, 200)``; ``target`` is ``self.dataset[item, 10:]``.
        """
        # NOTE(review): dim 1 has size 10, so the slice ``10:`` is always empty
        # (shape (0, 1, 200, 200)). A split such as ``[item, :10]`` /
        # ``[item, 10:]`` over a larger dim was probably intended — confirm
        # before relying on the target.
        return self.dataset[item], self.dataset[item, 10:]

    def __len__(self):
        """Return the number of samples (size of dim 0)."""
        return self.dataset.shape[0]

【訓練程式碼】

import time
from random import random
from datetime import datetime

from torch.utils.data import DataLoader
from tqdm import tqdm, trange

from TestDataset import TestDataset

max_epochs = 5

train_set = TestDataset()
test_set = TestDataset()
train_loader = DataLoader(train_set)
test_loader = DataLoader(test_set)


def training(epoch: int, test_frequency: int = 5):
    """Run a simulated training loop that reports progress through tqdm.

    Each epoch walks ``train_loader`` once. It sleeps 1 s per batch in place
    of a real training step and draws a random "loss" with ``random()``. The
    bar shows the running mean loss and the elapsed time. After every
    ``test_frequency``-th epoch, a one-line test report with a fixed
    placeholder loss is shown.

    Args:
        epoch: Number of epochs to run. The name is kept so existing
            ``training(epoch=...)`` calls still work.
        test_frequency: Report the test line after every
            ``test_frequency``-th epoch; a value <= 0 disables the report
            (previously 0 raised ZeroDivisionError).
    """
    num_epochs = epoch
    for cur_epoch in range(num_epochs):
        # The loader is iterated directly and the bar is advanced by hand, so
        # tqdm only needs the total, not an iterable it would never consume.
        with tqdm(
            total=len(train_loader),
            bar_format=(
                '{desc} {n_fmt:>4s}/{total_fmt:<4s} '
                '{percentage:3.0f}%|{bar}| {postfix}'
            ),
        ) as t:
            # \33[36m switches the line to cyan; the postfix below ends with
            # \33[0m to reset the terminal colour.
            t.set_description_str(f"\33[36m【Epoch {cur_epoch + 1:04d}】")
            start_time = datetime.now()
            running_loss = 0.0
            for batch, _ in enumerate(train_loader, start=1):
                # Placeholder for the real training step.
                time.sleep(1)
                # Placeholder for the current batch loss.
                running_loss += random()
                delta_time = datetime.now() - start_time
                t.set_postfix_str(
                    f"train_loss={running_loss / batch:.6f}, 執行時長:{delta_time}\33[0m"
                )
                t.update()

        if test_frequency > 0 and (cur_epoch + 1) % test_frequency == 0:
            with tqdm(
                total=len(test_loader),
                bar_format='{desc} {postfix}',
            ) as t:
                # Placeholder for a real evaluation pass.
                time.sleep(2)
                test_loss = 3.1415926
                # \33[35m = magenta, reset by the trailing \33[0m.
                t.set_description_str("\33[35m【測試集】")
                t.set_postfix_str(f"test_loss={test_loss:.6f}\33[0m")
                t.update()


if __name__ == "__main__":
    training(epoch=max_epochs, test_frequency=1)

【效果圖】
(原文此處為訓練過程控制檯輸出的效果截圖,圖片未保留。)