FSL-GNN程式碼解讀
阿新 • • 發佈:2021-10-06
FSL-GNN程式碼解讀
main.py(主函式)
1、載入資料集:
train_loader = generator.Generator(args.dataset_root, args, partition='train', dataset=args.dataset)
2、初始化或載入模型:
enc_nn = models.load_model('enc_nn', args, io) metric_nn = models.load_model('metric_nn', args, io) if enc_nn is None or metric_nn is None: enc_nn, metric_nn = models.create_models(args=args) softmax_module = models.SoftmaxModule()
models.create_models(args=args)
: in models.py
def create_models(args): print (args.dataset) if 'omniglot' == args.dataset: enc_nn = EmbeddingOmniglot(args, 64) elif 'mini_imagenet' == args.dataset: enc_nn = EmbeddingImagenet(args, 128) else: raise NameError('Dataset ' + args.dataset + ' not knows') return enc_nn, MetricNN(args, emb_size=enc_nn.emb_size) class EmbeddingOmniglot(): # 特徵提取 class EmbeddingImagenet(): # 略 class MetricNN(nn.Module): if self.metric_network == 'gnn_iclr_nl':…… # 正常的網路 self.gnn_obj = gnn_iclr.GNN_nl() # in gnn_iclr.py elif self.metric_network == 'gnn_iclr_active':…… # 主動學習 self.gnn_obj = gnn_iclr.GNN_active()# in gnn_iclr.py class SoftmaxModule(): # 線性分類
class GNN_nl(nn.Module) & class GNN_active(nn.Module)
: in gnn_iclr.py
class GNN_nl(nn.Module): # 圖網路主要部分
class Wcompute(nn.Module) # W鄰接矩陣計算
class Gconv(nn.Module) # 組圖
def gmul(input) # 更新圖節點特徵,W直接返回
3、訓練
# 權重衰減 weight_decay = 1e-6 # 優化器 opt_enc_nn = optim.Adam(enc_nn.parameters(), lr=args.lr, weight_decay=weight_decay) opt_metric_nn = optim.Adam(metric_nn.parameters(), lr=args.lr, weight_decay=weight_decay) # 梯度置零,也就是把loss關於weight的導數變成0 opt_enc_nn.zero_grad() opt_metric_nn.zero_grad() # 訓練 loss_d_metric = train_batch( model=[enc_nn, metric_nn, softmax_module], data=[batch_x, label_x, batches_xi, labels_yi, oracles_yi, hidden_labels]) # 更新引數 opt_enc_nn.step() opt_metric_nn.step() # 自適應引數 adjust_learning_rate(optimizers=[opt_enc_nn, opt_metric_nn], lr=args.lr, iter=batch_idx) # 顯示訓練中loss的更新 if batch_idx % args.log_interval == 0: display_str = 'Train Iter: {}'.format(batch_idx) display_str += '\tLoss_d_metric: {:.6f}'.format(total_loss/counter) io.cprint(display_str) # 測試 def test_one_shot(args, model, test_samples=5000, partition='test') 定義於 test.py 中 val_acc_aux = test.test_one_shot # 驗證集上測試 test_acc_aux = test.test_one_shot # 測試集上測試 test.test_one_shot( # 訓練集上測試 args, model=[enc_nn, metric_nn, softmax_module], test_samples=test_samples, partition='train') # 測試完畢,將模型設定回訓練狀態 enc_nn.train() metric_nn.train() # 若在驗證集上的效果繼續變好,則更新 if val_acc_aux is not None and val_acc_aux >= val_acc: # 儲存模型 torch.save(enc_nn, 'checkpoints/%s/models/enc_nn.t7' % args.exp_name) torch.save(metric_nn, 'checkpoints/%s/models/metric_nn.t7' % args.exp_name) # 全部訓練完畢後進行測試 test.test_one_shot