1. 程式人生 > >pytorch搭建卷積神經網路(alexnet、vgg16、resnet50)以及訓練

pytorch搭建卷積神經網路(alexnet、vgg16、resnet50)以及訓練

文末有程式碼和資料集連結!!!!

(注:文章中所有path指檔案的路徑)

因畢業設計需要,接觸卷積神經網路。由於pytorch方便使用,所以最後使用pytorch來完成卷積神經網路訓練。

接觸到的網路有Alexnet、vgg16、resnet50,畢業答辯完後,一直在訓練Alexnet。

1.卷積神經網路搭建

  pytorch中有torchvision.models,裡面有許多已搭建好的模型。如果採用預訓練模型,只需要修改最後分類的類別。

雖然這樣但是我還是inception v3模型修改上失敗。

alexnet和vgg16修改的是全連線層的最後一層。

model.classifier = nn.Sequential(nn.Linear(25088, 4096),      #vgg16
                                 nn.ReLU(),
                                 nn.Dropout(p=0.5),
                                 nn.Linear(4096, 4096),
                                 nn.ReLU(),
                                 nn.Dropout(p=0.5),
                                 nn.Linear(4096, 2))
alexnet_model.classifier = nn.Sequential(
            nn.Dropout(),
            nn.Linear(256 * 6 * 6, 4096),
            nn.ReLU(inplace=True),
            nn.Dropout(),
            nn.Linear(4096, 4096),
            nn.ReLU(inplace=True),
            nn.Linear(4096, 2),
        )

resnet50只需要修改最後的fc層。

model.fc = nn.Linear(2048, 2)

簡單的修改,就可以完成。

如果採用要採用預訓練模型的話,還需要對修改處引數的進行修改。(vgg16和alexnet需要,resnet50不需要,原因我認為是修改的地方不同)

 for index, parma in enumerate(model.classifier.parameters()):
     if index == 6:
        parma.requires_grad = True

2.訓練

這張圖是我所認為的神經網路訓練的七步吧。

(1) 模型的建立上文已介紹。

(2) 資料集的建立:在PyTorch中對於資料集的檔案格式有一定的要求。如圖4-10所示,在目錄下分別建cat和dog資料夾,這就相當於做標籤

(3)對資料集進行預處理:這裡採用的是資料增強變化的方法,包括對圖片大小進行壓縮和輸入畫素統一,都為224224,還有影象翻轉以及歸一化。

data_transform = transforms.Compose([
    transforms.Scale((224,224), 2),                           #對影象大小統一
    transforms.RandomHorizontalFlip(),                        #影象翻轉
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[    #影象歸一化
                             0.229, 0.224, 0.225])
         ])

(4)資料集的載入,載入方式有三種:1.如果採用pytorch模組自帶的資料集就可以使用torchvision.datasets.       來新增資料集。2.和我下面程式碼一樣,使用torchvision.datasets.ImageFolder,不過資料夾要按照(2)中固定格式來建立資料集。3.參照pytorch中的原始碼自己寫一個相對應的函式。

train_dataset = torchvision.datasets.ImageFolder(root='/path/data/train/',transform=data_transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size = batch_size, shuffle=True, num_workers=0)

val_dataset = torchvision.datasets.ImageFolder(root='/path/data/val/', transform=data_transform)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size = batch_size, shuffle=True, num_workers=0)

(5)  模型的訓練

    for epoch in range(num_epochs):
        batch_size_start = time.time()
        running_loss = 0.0
        for i, (inputs, labels) in enumerate(train_loader):
            if epoch >= 5:
                optimizer = torch.optim.SGD(model.classifier.parameters(), lr=lr2)
                print("lr", lr2)
            else:
                optimizer = torch.optim.SGD(model.classifier.parameters(), lr=lr1)
                print("lr", lr1)
            inputs = Variable(inputs)
            labels = Variable(labels)
            optimizer.zero_grad()
            outputs = model(inputs)
            criterion = nn.CrossEntropyLoss()
            loss = criterion(outputs, labels)        #交叉熵
            loss.backward()
            optimizer.step()                          #更新權重
            running_loss += loss.data[0]

        print('Epoch [%d/%d], Loss: %.4f,need time %.4f'
                  % (epoch + 1, num_epochs, running_loss / (4000 / batch_size), time.time() - batch_size_start))

(6)驗證集的驗證  ,程式碼中有模型的儲存

        correct = 0
        total = 0
        model.eval()
        for (images, labels) in val_loader:
            batch_size_start = time.time()
            images = Variable(images)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum()

        # print("正確的數量:", correct)
        print(" Val BatchSize cost time :%.4f s" % (time.time() - batch_size_start))
        print('Test Accuracy of the model on the %d Val images: %.4f' % (total, float(correct) / total))
        if (float(correct) / total) >= 0.99:
            print('the Accuracy>=0.98 the num_epochs:%d'% epoch)
            break
        x_epoch.append(epoch)
        Acc = round((float(correct) / total), 3)
        y_acc.append(Acc)

        picName = os.path.join(codeDirRoot, "log", "pic",
                               "alexnet%s.png" % experimentSuffix)
        line_chart(x_epoch, y_acc, picName)

        # if (epoch + 1) % adjustLREpoch == 0:
        #     adjust_learning_rate(optimizer, LRModulus)

        if (epoch+1) % saveModelEpoch != 0:
            continue
        saveModelName = os.path.join(codeDirRoot, "model", "alexnet%s_model.pkl"%experimentSuffix + "_" + str(epoch))
        torch.save(model.state_dict(), saveModelName)

(7) 測試集的測試,程式碼中包含模型的載入。

model.load_state_dict(torch.load(
    "/path/cnn/model/vgg16/39_vgg16_model.pkl",map_location=lambda storage, loc: storage))
model.eval()
correct = 0
total = 0
for images, labels in test_loader:
    images = Variable(images)
    outputs = model(images)
    _, predicted = torch.max(outputs.data, 1)
    total += labels.size(0)
    correct += (predicted == labels).sum()
print("正確的數量%d,所有圖片數量%d:" % (correct, total))
print('val accuracy of the %d val images:%.4f' % (total, float(correct) / total))

這是完整的過程。在這個過程中加入了,警告忽略、日誌儲存、圖形化資料。程式碼如下。

import warnings
warnings.filterwarnings("ignore")
class Logger(object):

    def __init__(self, filename="Default.log"):
        self.terminal = sys.stdout
        self.log = open(filename, "a")

    def write(self, message):
        self.terminal.write(message)
        self.log.write(message)

    def flush(self):
        pass
sys.stdout = Logger("/path/cnn/log/resnet50_image_show.txt")
# 畫折線圖形並儲存
def line_chart(x_epoch, y_acc, picName):
    plt.figure()#建立繪圖物件
    plt.plot(x_epoch, y_acc, "b--", linewidth=1)   #在當前繪圖物件繪圖(X軸,Y軸,藍色虛線,線寬度)
    plt.ylim(0.00, 1.00)
    plt.xlabel("epoch")            #X軸標籤
    plt.ylabel("accuracy")               #Y軸標籤
    plt.title("alexnet-Line _chart")          #圖示題
    # plt.savefig(os.path.join(codeDirRoot, "log", "pic", "resnet50%s.png"%experimentSuffix))  # 儲存圖
    plt.savefig(picName)  # 儲存圖

我在老師要求下,做了最後的識別結果輸出。下面是完整的程式碼。

warnings.filterwarnings("ignore")     #忽略警告
class Logger(object):                                #儲存日誌函式
    def __init__(self, filename="Default.log"):
        self.terminal = sys.stdout
        self.log = open(filename, "a")

    def write(self, message):
        self.terminal.write(message)
        self.log.write(message)

    def flush(self):
        pass
sys.stdout = Logger("path/cnn/log/alexnet_image_show.txt")

#顯示圖片函式
def imshow(inp, title=None):
    """Imshow for Tensor."""
    inp = inp.numpy().transpose((1, 2, 0))
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    inp = std * inp + mean
    inp = np.clip(inp, 0, 1)
    plt.imshow(inp)
    if title is not None:
        plt.title(title)
    plt.pause(0.001)
# 模型搭建
model = models.alexnet(pretrained=False)
model.classifier = nn.Sequential(nn.Linear(9216, 4096),
                                 nn.ReLU(),
                                 nn.Dropout(p=0.5),
                                 nn.Linear(4096, 4096),
                                 nn.ReLU(),
                                 nn.Dropout(p=0.5),
                                 nn.Linear(4096, 2))
print("model", model)
#載入預訓練模型
model.load_state_dict(torch.load("/path/cnn/model/alexnet_model.pkl", map_location=lambda storage, loc: storage))
#資料預處理
data_transform = transforms.Compose([
    transforms.Scale((224, 224), 2),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    ])
#建立資料集
test_dataset = torchvision.datasets.ImageFolder("/path/data/show", data_transform)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=True)
#分類的類別
class_names = test_dataset.classes
# 顯示一些圖片預測函式
def visualize_model(model, num_images):
    model.eval()
    images_so_far = 0

    for i, data in enumerate(test_loader):
        inputs, labels = data
        inputs, labels = Variable(inputs), Variable(labels)
        outputs = model(inputs)
        _, predicted = torch.max(outputs.data, 1)

        for j in range(inputs.size()[0]):
            images_so_far += 1
            ax = plt.subplot(num_images//2, 2, images_so_far)
            ax.axis('off')
            ax.set_title('predicted: {}'.format(class_names[predicted[j]]))
            imshow(inputs.cpu().data[j])
            if images_so_far == num_images:
                return
visualize_model(model, 10)       顯示十張圖片

# plt.ioff()     #“關閉互動模式”。
plt.savefig("/path/cnn/log/pic/alexnet.png")  # 儲存圖
plt.show()

這就是整個過程。

百度網盤連結:連結: https://pan.baidu.com/s/1d8qVvi_1jBMH2Ayd5fDi3A 提取碼: 6dkr