1. 程式人生 > >[ pytorch ] ——基本使用:(5) 模型並聯

[ pytorch ] ——基本使用:(5) 模型並聯

模型並聯要注意的是:即使並聯的模組是一樣的,也要用不同的變數來定義,不然model.parameters裡面只會出現一次該模組,而不是並聯的全部模組。

class MyModel(nn.Module):     # Resnet50 + Encoder_Decoder
    def __init__(self,class_num=2):   # input the dim of output fea-map of Resnet: 7*7*2048
        super(MyModel, self).__init__()
        
        ### 分別用不同 變數名 定義每個並聯模組 ###
        for iii in range(3):
            locals()["block" + str(iii + 1)] = []  # Variable name: block1、block2、block3
            locals()["block" + str(iii + 1)] += [torch.nn.Conv2d(in_channels=3, out_channels=10, kernel_size=3)]
            locals()["block" + str(iii + 1)] += [torch.nn.AdaptiveAvgPool2d((1,1))]
            locals()["block" + str(iii + 1)] = torch.nn.Sequential(*(locals()["block" + str(iii+1)]))
        ### 分別用不同 變數名 定義每個並聯模組 ###

        self.Block1 = locals()["block" + str(1)]
        self.Block2 = locals()["block" + str(2)]
        self.Block3 = locals()["block" + str(3)]

        self.classifier = ClassBlock(10, class_num=class_num)

    def forward(self, input):   # input is 7*7*2048!

        x = input
        
        ### 並聯forward ###
        x1 = self.Block1(x)
        x2 = self.Block2(x)
        x3 = self.Block3(x)
        x1 = torch.squeeze(x1)
        x2 = torch.squeeze(x2)
        x3 = torch.squeeze(x3)
        ### 並聯forward ###
        
        feature = torch.add(x1, x2)  # feature fusion
        feature = torch.add(feature, x3)

        feature = self.classifier(feature)

        return feature

################
#    main()
# --------------

net = MyModel()

print(list(net.named_parameters()))  # 可以查到net.Block1、 net.Block2、 net.Block3