Pytorch RuntimeError: expected type torch.cuda.FloatTensor but got torch.FloatTensor
阿新 • • 發佈:2020-11-06
Pytorch RuntimeError: expected type torch.cuda.FloatTensor but got torch.FloatTensor
在使用Pytorch框架訓練模型時,丟擲RuntimeError: expected type torch.cuda.FloatTensor but got torch.FloatTensor。
產生原因及分析
待訓練網路在GPU中運算,但有部分資料未進入GPU。
基於Pytorch框架使用GPU進行訓練時,輸入資料(樣本、標記)、網路結構均會在GPU中進行計算,例如:
…… device = 'cuda:0' model = models.resnet18(pretrained=True).to(device) …… inputs = inputs.to(device) labels = labels.to(device) ……
但如果在網路計算過程中,有新加入Tensor但沒有明確指定其在GPU中運算(預設是在CPU中),則會丟擲上述異常。
解決方法1
解決辦法(單GPU)
通過上述分析可知,在網路中引入新的Tensor時,顯式指定其執行裝置為GPU,可解決上述問題,例如:
……
random_tensor += torch.rand([batch_size, 1, 1, 1], dtype=inputs.dtype).to(device)
……
在單GPU時,該方法能夠解決上述問題;但在多GPU情況下,仍會丟擲不在同一個GPU上計算的異常。
解決辦法(多GPU)
對於多GPU,網路中新增Tensor一般會與輸入(inputs)進行計算;因此獲取inputs所在的GPU裝置,將新Tensor的計算裝置設定為與之相同,問題得解,例如:
……
random_tensor += torch.rand([batch_size, 1, 1, 1], dtype=inputs.dtype).to(inputs.device)
……
或者
……
random_tensor += torch.rand([batch_size, 1, 1, 1], dtype=inputs.dtype, device=inputs.device)
……
常見錯誤 RuntimeError: expected type torch.FloatTensor but got torch.cuda.FloatTensor
https://www.jianshu.com/p/0be7a375bdbe
https://blog.csdn.net/qq_38410428/article/details/82973895
計算中有的引數為cuda型,有的引數卻是cpu型,就會遇到這樣的錯誤。
解決辦法:
該加.cuda()的加上,不該用.cpu()的地方去掉它。