201208-PyTorch求解矩陣正交基
阿新 • • 發佈:2020-12-10
技術標籤:Pytorch
- 參考 MATLAB Link
orth is obtained from U in the singular value decomposition, [U,S] = svd(A,‘econ’). If r = rank(A), the first r columns of U form an orthonormal basis for the range of A.
-
步驟:
- svd分解
- 按秩索引
-
程式碼
from scipy import linalg as LA import torch # Step 1: 建立相關矩陣 print('Step 1: 建立相關矩陣') A = np.array([[1.,0,1],[0,1,0],[1,0,1]]) B = torch.Tensor(A) print(B) print('\n') # Step 2: 求矩陣的秩 print('Step 2: 求矩陣的秩') r = torch.matrix_rank(B) print(r) print('\n') # Step 3: 矩陣的svd分解 print('Step 3: 矩陣的svd分解及正交基索引') u,s,v = torch.svd(B) torchResult = u[:,:r] print('u:\n',u) print('s:\n',s) print('v:\n',v) print('\n') # Step 4: 對比scipy直接計算與pytorch間接計算結果 print('Step 4: 對比scipy直接計算與pytorch間接計算結果') scipyResult = LA.orth(B) print('scipy.LA.orth():\n',scipyResult) scipyResult.dtype print('scipyResult - torchResult\n',torch.dist(torch.Tensor(scipyResult), torchResult,2))
- 結果
Step 1: 建立相關矩陣 tensor([[1., 0., 1.], [0., 1., 0.], [1., 0., 1.]]) Step 2: 求矩陣的秩 tensor(2) Step 3: 矩陣的svd分解及正交基索引 u: tensor([[-0.7071, 0.0000, 0.7071], [ 0.0000, -1.0000, 0.0000], [-0.7071, 0.0000, -0.7071]]) s: tensor([2.0000e+00, 1.0000e+00, 1.3491e-08]) v: tensor([[-7.0711e-01, -0.0000e+00, 7.0711e-01], [ 0.0000e+00, -1.0000e+00, 0.0000e+00], [-7.0711e-01, -1.1921e-07, -7.0711e-01]]) Step 4: 對比scipy直接計算與pytorch間接計算結果 scipy.LA.orth(): [[-0.7071068 0. ] [ 0. -1. ] [-0.70710677 0. ]] scipyResult - torchResult tensor(0.)
- 函式
def torchOrth(A):
r = torch.matrix_rank(A)
u,s,v = torch.svd(A)
return u[:,:r]
A = np.array([[1.,0,1],[0,1,0],[1,0,1]])
B = torch.Tensor(A)
torchOrth(B)
- 實踐
當引數過大的時候,仍舊存在一定的誤差, 此時可設定64位精度
torch.set_default_dtype(torch.float64)
def torchOrth(A): r = torch.matrix_rank(A) u,s,v = torch.svd(A) return u[:,:r] nFea = 10 nWin = 10 nOrt = 500 randomOth = random.randn(nWin * nFea + 1, nOrt) if nFea * nWin >= nOrt: wenh1 = LA.orth(2 * randomOth)-1 else: wenh1 = LA.orth(2 * randomOth.T-1).T randomOth = torch.Tensor(randomOth) if nFea * nWin >= nOrt: wenh2 = torchOrth(2 * randomOth)-1 else: wenh2 = torchOrth(2 * randomOth.T-1).T torch.dist(torch.Tensor(wenh1),wenh2,2)