1. 程式人生 > 其它 >201208-PyTorch求解矩陣正交基

201208-PyTorch求解矩陣正交基

技術標籤:Pytorch

orth is obtained from U in the singular value decomposition, [U,S] = svd(A,‘econ’). If r = rank(A), the first r columns of U form an orthonormal basis for the range of A.

  • 步驟:

    • svd分解
    • 按秩索引
  • 程式碼

from scipy import linalg as LA
import torch


# Step 1: 建立相關矩陣
print('Step 1: 建立相關矩陣')
A = np.array([[1.,0,1],[0,1,0],[1,0,1]])
B = torch.Tensor(A)
print(B)
print('\n')

# Step 2: 求矩陣的秩
print('Step 2: 求矩陣的秩')
r = torch.matrix_rank(B)
print(r)
print('\n')

# Step 3: 矩陣的svd分解
print('Step 3: 矩陣的svd分解及正交基索引')
u,s,v = torch.svd(B)
torchResult = u[:,:r]
print('u:\n',u)
print('s:\n',s)
print('v:\n',v)
print('\n')

# Step 4: 對比scipy直接計算與pytorch間接計算結果
print('Step 4: 對比scipy直接計算與pytorch間接計算結果')
scipyResult = LA.orth(B)
print('scipy.LA.orth():\n',scipyResult)
scipyResult.dtype

print('scipyResult - torchResult\n',torch.dist(torch.Tensor(scipyResult), torchResult,2))
  • 結果
Step 1: 建立相關矩陣
tensor([[1., 0., 1.],
        [0., 1., 0.],
        [1., 0., 1.]])


Step 2: 求矩陣的秩
tensor(2)


Step 3: 矩陣的svd分解及正交基索引
u:
 tensor([[-0.7071,  0.0000,  0.7071],
        [ 0.0000, -1.0000,  0.0000],
        [-0.7071,  0.0000, -0.7071]])
s:
 tensor([2.0000e+00, 1.0000e+00, 1.3491e-08])
v:
 tensor([[-7.0711e-01, -0.0000e+00,  7.0711e-01],
        [ 0.0000e+00, -1.0000e+00,  0.0000e+00],
        [-7.0711e-01, -1.1921e-07, -7.0711e-01]])


Step 4: 對比scipy直接計算與pytorch間接計算結果
scipy.LA.orth():
 [[-0.7071068   0.        ]
 [ 0.         -1.        ]
 [-0.70710677  0.        ]]
scipyResult - torchResult
 tensor(0.)
  • 函式
def torchOrth(A):
    r = torch.matrix_rank(A)
    u,s,v = torch.svd(A)
    return u[:,:r]

A = np.array([[1.,0,1],[0,1,0],[1,0,1]])
B = torch.Tensor(A)
torchOrth(B)
  • 實踐

當引數過大的時候,仍舊存在一定的誤差, 此時可設定64位精度
torch.set_default_dtype(torch.float64)

def torchOrth(A):
    r = torch.matrix_rank(A)
    u,s,v = torch.svd(A)
    return u[:,:r]

nFea = 10
nWin = 10
nOrt = 500

randomOth = random.randn(nWin * nFea + 1, nOrt)

if nFea * nWin >= nOrt:
    wenh1 = LA.orth(2 * randomOth)-1
else:
    wenh1 = LA.orth(2 * randomOth.T-1).T

randomOth = torch.Tensor(randomOth)
if nFea * nWin >= nOrt:
    wenh2 = torchOrth(2 * randomOth)-1
else:
    wenh2 = torchOrth(2 * randomOth.T-1).T

torch.dist(torch.Tensor(wenh1),wenh2,2)