1. 程式人生 > 其它 >FSL-GNN程式碼解讀

FSL-GNN程式碼解讀

FSL-GNN程式碼解讀

main.py(主函式)

1、載入資料集:

train_loader = generator.Generator(args.dataset_root, args, partition='train', dataset=args.dataset)

2、初始化或載入模型:

enc_nn = models.load_model('enc_nn', args, io)
metric_nn = models.load_model('metric_nn', args, io)

if enc_nn is None or metric_nn is None:
	enc_nn, metric_nn = models.create_models(args=args)
    softmax_module = models.SoftmaxModule()

models.create_models(args=args) : in models.py

def create_models(args):
    print (args.dataset)
    if 'omniglot' == args.dataset:
        enc_nn = EmbeddingOmniglot(args, 64)
    elif 'mini_imagenet' == args.dataset:
        enc_nn = EmbeddingImagenet(args, 128)
    else:
        raise NameError('Dataset ' + args.dataset + ' not knows')
    return enc_nn, MetricNN(args, emb_size=enc_nn.emb_size)

class EmbeddingOmniglot():				# 特徵提取
class EmbeddingImagenet():				# 略

class MetricNN(nn.Module):
	if self.metric_network == 'gnn_iclr_nl':……		# 正常的網路
	self.gnn_obj = gnn_iclr.GNN_nl()			# in gnn_iclr.py
	
	elif self.metric_network == 'gnn_iclr_active':……	# 主動學習
	self.gnn_obj = gnn_iclr.GNN_active()# in gnn_iclr.py
	
class SoftmaxModule():		# 線性分類

class GNN_nl(nn.Module) & class GNN_active(nn.Module) : in gnn_iclr.py

class GNN_nl(nn.Module):		# 圖網路主要部分
	class Wcompute(nn.Module)	# W鄰接矩陣計算
    class Gconv(nn.Module)		# 組圖
		def gmul(input)		# 更新圖節點特徵,W直接返回

3、訓練

# 權重衰減
weight_decay = 1e-6

# 優化器
opt_enc_nn = optim.Adam(enc_nn.parameters(), lr=args.lr, weight_decay=weight_decay)
opt_metric_nn = optim.Adam(metric_nn.parameters(), lr=args.lr, weight_decay=weight_decay)

# 梯度置零,也就是把loss關於weight的導數變成0
opt_enc_nn.zero_grad()
opt_metric_nn.zero_grad()

# 訓練
loss_d_metric = train_batch(
	model=[enc_nn, metric_nn, 
	softmax_module],
	data=[batch_x, label_x, batches_xi, labels_yi, oracles_yi, hidden_labels])

# 更新引數
opt_enc_nn.step()
opt_metric_nn.step()

# 自適應引數
adjust_learning_rate(optimizers=[opt_enc_nn, opt_metric_nn], lr=args.lr, iter=batch_idx)

# 顯示訓練中loss的更新
if batch_idx % args.log_interval == 0:
	display_str = 'Train Iter: {}'.format(batch_idx)
	display_str += '\tLoss_d_metric: {:.6f}'.format(total_loss/counter)
	io.cprint(display_str)

# 測試
def test_one_shot(args, model, test_samples=5000, partition='test') 定義於 test.py 中
val_acc_aux = test.test_one_shot	# 驗證集上測試
test_acc_aux = test.test_one_shot	# 測試集上測試
test.test_one_shot(					# 訓練集上測試
	args, 
	model=[enc_nn, metric_nn, softmax_module],
	test_samples=test_samples, 
	partition='train')				

# 測試完畢,將模型設定回訓練狀態
enc_nn.train()
metric_nn.train()

# 若在驗證集上的效果繼續變好,則更新
if val_acc_aux is not None and val_acc_aux >= val_acc:

# 儲存模型
torch.save(enc_nn, 'checkpoints/%s/models/enc_nn.t7' % args.exp_name)
torch.save(metric_nn, 'checkpoints/%s/models/metric_nn.t7' % args.exp_name)

# 全部訓練完畢後進行測試
test.test_one_shot