1. 程式人生 > 實用技巧 >Pytorch RuntimeError: expected type torch.cuda.FloatTensor but got torch.FloatTensor

Pytorch RuntimeError: expected type torch.cuda.FloatTensor but got torch.FloatTensor

Pytorch RuntimeError: expected type torch.cuda.FloatTensor but got torch.FloatTensor

在使用Pytorch框架訓練模型時,丟擲RuntimeError: expected type torch.cuda.FloatTensor but got torch.FloatTensor。

產生原因及分析

待訓練網路在GPU中運算,但有部分資料未進入GPU。
基於Pytorch框架使用GPU進行訓練時,輸入資料(樣本、標記)、網路結構均會在GPU中進行計算,例如:

……
device = 'cuda:0'
model = models.resnet18(pretrained=True).to(device)
……
inputs = inputs.to(device)
labels = labels.to(device)
……

但如果在網路計算過程中,有新加入Tensor但沒有明確指定其在GPU中運算(預設是在CPU中),則會丟擲上述異常。

解決方法1

解決辦法(單GPU)

通過上述分析可知,在網路中引入新的Tensor時,顯式指定其執行裝置為GPU,可解決上述問題,例如:

……
random_tensor += torch.rand([batch_size, 1, 1, 1], dtype=inputs.dtype).to(device)
……

在單GPU時,該方法能夠解決上述問題;但在多GPU情況下,仍會丟擲不在同一個GPU上計算的異常。

解決辦法(多GPU)

對於多GPU,網路中新增Tensor一般會與輸入(inputs)進行計算;因此獲取inputs所在的GPU裝置,將新Tensor的計算裝置設定為與之相同,問題得解,例如:

……
random_tensor += torch.rand([batch_size, 1, 1, 1], dtype=inputs.dtype).to(inputs.device)
……

或者

……
random_tensor += torch.rand([batch_size, 1, 1, 1], dtype=inputs.dtype, device=inputs.device)
……

解決辦法2

常見錯誤 RuntimeError: expected type torch.FloatTensor but got torch.cuda.FloatTensor

https://www.jianshu.com/p/0be7a375bdbe

https://blog.csdn.net/qq_38410428/article/details/82973895

計算中有的引數為cuda型,有的引數卻是cpu型,就會遇到這樣的錯誤。

解決辦法:

該加.cuda()的加上,不該用.cpu()的地方去掉它。