
PyTorch: error when running the official tutorial code: "An attempt has been made to start a new process before the current process has finished its bootstrapping phase"

RuntimeError: 
        An attempt has been made to start a new process before the
        current process has finished its bootstrapping phase.

        This probably means that you are not using fork to start your
        child processes and you have forgotten to use the proper idiom
        in the main module:

            if __name__ == '__main__':
                freeze_support()
                ...

        The "freeze_support()" line can be omitted if the program
        is not going to be frozen to produce an executable.

I hit this error while running the following CIFAR-10 data-loading code from the official PyTorch tutorial:

import torch
import torch.utils.data
import torchvision
from torchvision import transforms
import matplotlib.pyplot as plt
import numpy as np

# Convert images to tensors and normalize each RGB channel to [-1, 1]
transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)

trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=False, num_workers=2)
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

# With num_workers=2, creating the iterator starts worker processes. Under the
# spawn start method each worker re-runs this top-level code, reaches this line
# again while still bootstrapping, and raises the RuntimeError shown above.
dataiter = iter(trainloader)

def imshow(img):
    img = img / 2 + 0.5  # unnormalize back to [0, 1]
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))  # CHW -> HWC for matplotlib
    plt.show()

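Running the code this way produced the error shown at the top. After searching online, I found the following explanation and two ways to fix it.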

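Cause: because num_workers is set above 0, DataLoader loads data in separate worker processes, and code that starts processes has to be launched from the main module under an if __name__ == '__main__': guard. When child processes are started with the spawn method (the default on Windows, and on macOS since Python 3.8) rather than fork, each child starts a fresh interpreter and re-imports the main module. Without the guard, the child re-runs all the top-level code, reaches iter(trainloader) again, and tries to start workers of its own before it has finished starting up. That raises this RuntimeError.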
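You can ask the standard library's multiprocessing module which start method your interpreter uses. This is a minimal sketch. describe_start_method is a hypothetical helper name, and the printed values depend on your OS and Python version.

```python
import multiprocessing as mp

def describe_start_method() -> str:
    # Linux has traditionally defaulted to "fork"; Windows uses "spawn", and so
    # does macOS since Python 3.8. With "spawn" (and "forkserver"), the main
    # module is imported again in new processes, so the __main__ guard is needed.
    return mp.get_start_method()

if __name__ == "__main__":
    print("default start method:", describe_start_method())
    print("available start methods:", mp.get_all_start_methods())
```

If this prints "spawn", any script that creates a DataLoader with num_workers > 0 at module top level is expected to hit the error above.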

Solution 1: move the code into a main function and call it under the __main__ guard

import torch
import torch.utils.data
import torchvision
from torchvision import transforms
import matplotlib.pyplot as plt
import numpy as np

def main():
    transform = transforms.Compose(
        [transforms.ToTensor(),
         transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
    )

    trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
    trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=2)

    testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
    testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=False, num_workers=2)
    classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
    dataiter = iter(trainloader)  # worker processes start here, now only in the parent

def imshow(img):
    img = img / 2 + 0.5  # unnormalize back to [0, 1]
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))  # CHW -> HWC for matplotlib
    plt.show()

# Spawned workers re-import this file under the name '__mp_main__', so this
# block (and the DataLoader workers that main() starts) runs only in the parent
if __name__ == '__main__':
    main()

With this change the code runs and the error no longer appears.
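You can try the same guard pattern without PyTorch or a dataset download. The minimal sketch below uses only the standard library. load_sample and load_all are hypothetical stand-ins for per-sample loading and for the DataLoader workers, not PyTorch APIs. The spawn start method is forced so that the example behaves the same way on every OS.

```python
import multiprocessing as mp

def load_sample(index: int) -> int:
    # Stand-in for per-sample work such as decoding and transforming an image.
    # It is defined at module top level so a spawned worker can import it by name.
    return index * index

def load_all(indices):
    # Force "spawn" (already the default on Windows) so every OS starts
    # workers as fresh interpreters that re-import this file.
    ctx = mp.get_context("spawn")
    with ctx.Pool(processes=2) as pool:
        return pool.map(load_sample, indices)

def main():
    print(load_all(range(5)))  # [0, 1, 4, 9, 16]

if __name__ == "__main__":
    # Workers import this file as "__mp_main__", so only the parent process
    # enters this block and no pool is created recursively.
    main()
```

If you delete the if line and call main() at top level instead, each spawned worker calls main() again while still bootstrapping. The workers then fail with the same RuntimeError instead of doing any work.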

Solution 2: set num_workers to 0 so that no worker processes are started

import torch
import torch.utils.data
import torchvision
from torchvision import transforms
import matplotlib.pyplot as plt
import numpy as np

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))]
)

# num_workers=0: batches are loaded in the main process, so no worker
# processes are started and no __main__ guard is required
trainset = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
trainloader = torch.utils.data.DataLoader(trainset, batch_size=4, shuffle=True, num_workers=0)

testset = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4, shuffle=False, num_workers=0)
classes = ('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')
dataiter = iter(trainloader)

def imshow(img):
    img = img / 2 + 0.5  # unnormalize back to [0, 1]
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))  # CHW -> HWC for matplotlib
    plt.show()

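This also runs successfully. The trade-off is that all data loading then happens in the main process, so it can be slower than using several workers.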
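In plain Python, the effect of num_workers=0 looks roughly like the sketch below. load_batch is a hypothetical helper that only mirrors the idea. It is not how DataLoader is implemented. With num_workers=0 the work runs in the calling process, so module-level code is safe without a guard. Any positive value starts spawn workers and brings back the need for the guard.

```python
import multiprocessing as mp
from typing import Callable, Iterable, List

def square(x: int) -> int:
    # Stand-in for loading and transforming one sample.
    return x * x

def load_batch(indices: Iterable[int], fn: Callable[[int], int],
               num_workers: int = 0) -> List[int]:
    # Hypothetical helper: 0 means "do the work in the current process",
    # which starts no child processes.
    indices = list(indices)
    if num_workers == 0:
        return [fn(i) for i in indices]
    # Positive values start worker processes; under "spawn" the caller must sit
    # behind an `if __name__ == "__main__":` guard, and fn must be a top-level
    # function so it can be pickled and sent to the workers.
    ctx = mp.get_context("spawn")
    with ctx.Pool(processes=num_workers) as pool:
        return pool.map(fn, indices)

# Safe at module level: num_workers=0 never spawns a process.
print(load_batch(range(4), square, num_workers=0))  # [0, 1, 4, 9]
```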

Summary: the root cause was that I did not understand how Python multiprocessing starts child processes.