torch or numpy
阿新 • • 發佈:2017-11-07
docs 展示 他能 乘法 錯誤 method correct multi 重點
黃色:重點
粉色:不懂
- Torch 自稱為神經網絡界的 Numpy, 因為他能將 torch 產生的 tensor 放在 GPU 中加速運算 (前提是你有合適的 GPU), 就像 Numpy 會把 array 放在 CPU 中加速運算.
import torch import numpy as np np_data = np.arange(6).reshape((2, 3)) torch_data = torch.from_numpy(np_data) tensor2array = torch_data.numpy() print( ‘\nnumpy array:‘, np_data, #
Torch 中的數學運算
其實 torch 中 tensor 的運算和 numpy array 的如出一轍, 我們就以對比的形式來看. 如果想了解 torch 中其它更多有用的運算符, API就是你要去的地方.
# abs 絕對值計算 data = [-1, -2, 1, 2] tensor
除了簡單的計算, 矩陣運算才是神經網絡中最重要的部分. 所以我們展示下矩陣的乘法. 註意一下包含了一個 numpy 中可行, 但是 torch 中不可行的方式.
# matrix multiplication 矩陣點乘 data = [[1,2], [3,4]] tensor = torch.FloatTensor(data) # 轉換成32位浮點 tensor # correct method print( ‘\nmatrix multiplication (matmul)‘, ‘\nnumpy: ‘, np.matmul(data, data), # [[7, 10], [15, 22]] ‘\ntorch: ‘, torch.mm(tensor, tensor) # [[7, 10], [15, 22]] ) # !!!! 下面是錯誤的方法 !!!! data = np.array(data) print( ‘\nmatrix multiplication (dot)‘, ‘\nnumpy: ‘, data.dot(data), # [[7, 10], [15, 22]] 在numpy 中可行 ‘\ntorch: ‘, tensor.dot(tensor) # torch 會轉換成 [1,2,3,4].dot([1,2,3,4) = 30.0 )
torch or numpy