pytorch訓練main函式模板
阿新 • 發佈:2021-01-03
# -*- encoding: utf-8 -*-
"""
@File : main.py
@Time : 2020/11/14
@Author : Ding
@Description: main function
Training entry point for a two-branch (rainfall + water-level) ConvGRU
encoder / shared decoder model. Loads data, builds the network, resumes
from a checkpoint when one exists, and runs the train/validate loop with
LR scheduling and early stopping.
"""

import os
import time

import numpy as np
import torch
from torch import nn
import torch.optim as optim
from torch.optim import lr_scheduler
from tqdm import tqdm

# NOTE: aliased import — the original `from ConvLSTM import config` was
# immediately shadowed by `config = config.get_config()`, destroying the
# module reference.
from ConvLSTM import config as config_module
from ConvLSTM.decoder import Decoder
from ConvLSTM.earlystopping import EarlyStopping
from ConvLSTM.encoder import Encoder
from ConvLSTM.model import ED
from ConvLSTM.net_params import convgru_encoder_params, convgru_decoder_params
from dataload import dataload
from dataload.dataload import DataLoad

config = config_module.get_config()
TIMESTAMP = time.strftime('%Y-%m-%d', time.localtime(time.time()))
# TIMESTAMP = "2020-12-29"

# Seed every RNG source so runs are reproducible.
random_seed = 1996
np.random.seed(random_seed)
torch.manual_seed(random_seed)  # seed the CPU RNG so results are deterministic
if torch.cuda.device_count() > 1:
    torch.cuda.manual_seed_all(random_seed)  # seed every GPU
else:
    torch.cuda.manual_seed(random_seed)
# Trade cuDNN autotuning for run-to-run determinism: the same input always
# produces the same output.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

save_dir = '/data/code/save_model/' + TIMESTAMP  # checkpoint directory

'''
data loading
'''
dataload.load_csvs(config['data_root'])
trainFolder = DataLoad('train')
validFolder = DataLoad('val')
# test_loader = DataLoad('test')
trainLoader = torch.utils.data.DataLoader(trainFolder,
                                          batch_size=config['batchsz'],
                                          shuffle=True)
validLoader = torch.utils.data.DataLoader(validFolder,
                                          batch_size=config['batchsz'],
                                          shuffle=True)
# testLoader = torch.utils.data.DataLoader(test_loader,
#                                          batch_size=config['batchsz'],
#                                          shuffle=True)

encoder_params = convgru_encoder_params
decoder_params = convgru_decoder_params


def train():
    '''
    main function to run the training

    Builds the two-encoder/one-decoder network, resumes from
    ``save_dir/checkpoint.pth.tar`` when present, then alternates training
    and validation epochs. Validation loss drives both the
    ReduceLROnPlateau scheduler and early stopping. Per-epoch average
    losses are written to ``avg_train_losses.txt`` / ``avg_valid_losses.txt``.
    '''
    # Two parallel encoders share the same architecture: one for rainfall,
    # one for water level.
    encoder_rain = Encoder(convgru_encoder_params[0],
                           convgru_encoder_params[1]).to(config['device'])
    encoder_wl = Encoder(convgru_encoder_params[0],
                         convgru_encoder_params[1]).to(config['device'])
    decoder = Decoder(convgru_decoder_params[0],
                      convgru_decoder_params[1]).to(config['device'])
    net = ED(encoder_rain=encoder_rain, encoder_wl=encoder_wl,
             decoder=decoder).to(config['device'])

    # initialize the early_stopping object
    early_stopping = EarlyStopping(patience=20, verbose=True)

    if torch.cuda.device_count() > 1:
        net = nn.DataParallel(net)
    net.to(config['device'])

    checkpoint_path = os.path.join(save_dir, 'checkpoint.pth.tar')
    if os.path.exists(checkpoint_path):
        # BUGFIX: the original checked this path but then loaded a
        # hard-coded '/data/code/save_model/2020-12-31/checkpoint.pth.tar';
        # resume from the checkpoint that was actually found.
        print('==> loading existing model')
        model_info = torch.load(checkpoint_path)
        net.load_state_dict(model_info['state_dict'])
        # lr is restored from the saved optimizer state below.
        optimizer = torch.optim.Adam(net.parameters())
        optimizer.load_state_dict(model_info['optimizer'])
        cur_epoch = model_info['epoch'] + 1
    else:
        if not os.path.isdir(save_dir):
            os.makedirs(save_dir)
        cur_epoch = 0
        optimizer = optim.Adam(net.parameters(), lr=config['lr'])

    # BUGFIX: was `.cuda()`, which fails on CPU-only hosts and bypassed
    # the file's own config['device'] convention.
    lossfunction = nn.MSELoss().to(config['device'])
    pla_lr_scheduler = lr_scheduler.ReduceLROnPlateau(optimizer,
                                                      factor=0.5,
                                                      patience=4,
                                                      verbose=True)
    # to track the training loss as the model trains
    train_losses = []
    # to track the validation loss as the model trains
    valid_losses = []
    # to track the average training loss per epoch as the model trains
    avg_train_losses = []
    # to track the average validation loss per epoch as the model trains
    avg_valid_losses = []
    # mini_val_loss = np.inf
    for epoch in range(cur_epoch, config['epochs'] + 1):
        ###################
        # train the model #
        ###################
        t = tqdm(trainLoader, leave=False, total=len(trainLoader))
        for i, (inputVar, targetVar) in enumerate(t):
            # BUGFIX: inputs were left on the CPU while net/label were moved;
            # .to() is a no-op if the loader already yields device tensors.
            inputs = inputVar.to(config['device'])  # B,S,C,H,W
            label = targetVar.to(config['device'])  # B,S,C,H,W
            optimizer.zero_grad()
            net.train()
            pred = net(inputs)  # B,S,C,H,W
            loss = lossfunction(pred, label)
            loss_aver = loss.item() / config['batchsz']
            train_losses.append(loss_aver)
            loss.backward()
            # gradient clipping to keep training stable
            torch.nn.utils.clip_grad_value_(net.parameters(), clip_value=10.0)
            optimizer.step()
            t.set_postfix({  # progress-bar readout
                'trainloss': '{:.6f}'.format(loss_aver),
                'epoch': '{:02d}'.format(epoch)
            })

        ######################
        # validate the model #
        ######################
        with torch.no_grad():
            net.eval()
            t = tqdm(validLoader, leave=False, total=len(validLoader))
            for i, (inputVar, targetVar) in enumerate(t):
                if i >= 3000:  # cap validation cost per epoch
                    break
                inputs = inputVar.to(config['device'])
                label = targetVar.to(config['device'])
                pred = net(inputs)
                loss = lossfunction(pred, label)
                loss_aver = loss.item() / config['batchsz']
                # record validation loss
                valid_losses.append(loss_aver)
                # print("validloss: {:.6f}, epoch : {:02d}".format(loss_aver, epoch), end='\r', flush=True)
                t.set_postfix({
                    'validloss': '{:.6f}'.format(loss_aver),
                    'epoch': '{:02d}'.format(epoch)
                })

        torch.cuda.empty_cache()
        # print training/validation statistics
        # calculate average loss over an epoch
        train_loss = np.average(train_losses)
        valid_loss = np.average(valid_losses)
        avg_train_losses.append(train_loss)
        avg_valid_losses.append(valid_loss)

        # epoch_len = len(str(config['epochs']))
        # print_msg = (f'[{epoch:>{epoch_len}}/{args.epochs:>{epoch_len}}] ' +
        #              f'train_loss: {train_loss:.6f} ' +
        #              f'valid_loss: {valid_loss:.6f}')
        # print(print_msg)

        # clear lists to track next epoch
        train_losses = []
        valid_losses = []
        pla_lr_scheduler.step(valid_loss)  # lr_scheduler
        model_dict = {
            'epoch': epoch,
            'state_dict': net.state_dict(),
            'optimizer': optimizer.state_dict()
        }
        early_stopping(valid_loss.item(), model_dict, epoch, save_dir)
        if early_stopping.early_stop:
            print("Early stopping")
            break

    with open("avg_train_losses.txt", 'wt') as f:
        for i in avg_train_losses:
            print(i, file=f)

    with open("avg_valid_losses.txt", 'wt') as f:
        for i in avg_valid_losses:
            print(i, file=f)


def load_checkpoint(model, checkpoint, optimizer, loadOptimizer):
    """Optionally warm-start *model* (and *optimizer*) from a checkpoint file.

    Only parameters whose names already exist in *model* are loaded, so a
    partially matching checkpoint still works.

    Args:
        model: the network to update in place.
        checkpoint: path to a torch-saved checkpoint, or the string 'No'
            to skip loading entirely.
        optimizer: optimizer whose state may be restored.
        loadOptimizer: when truthy, also restore the optimizer state.

    Returns:
        (model, optimizer) — the same objects, possibly updated.
    """
    if checkpoint != 'No':
        print("loading checkpoint...")
        model_dict = model.state_dict()
        modelCheckpoint = torch.load(checkpoint)
        pretrained_dict = modelCheckpoint['state_dict']
        # keep only entries whose keys exist in the current model
        new_dict = {k: v for k, v in pretrained_dict.items()
                    if k in model_dict.keys()}
        model_dict.update(new_dict)
        # report how many parameters were actually transferred
        print('Total : {}, update: {}'.format(len(pretrained_dict),
                                              len(new_dict)))
        model.load_state_dict(model_dict)
        print("loaded finished!")
        # set loadOptimizer to False to skip restoring the optimizer
        if loadOptimizer:
            optimizer.load_state_dict(modelCheckpoint['optimizer'])
            print('loaded! optimizer')
        else:
            print('not loaded optimizer')
    else:
        print('No checkpoint is included')
    return model, optimizer


if __name__ == "__main__":
    train()