
Hyperspectral Classification

HybridSN Hyperspectral Classification

S. K. Roy, G. Krishna, S. R. Dubey, B. B. Chaudhuri HybridSN: Exploring 3-D–2-D CNN Feature Hierarchy for Hyperspectral Image Classification, IEEE GRSL 2020

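This paper builds a hybrid network for hyperspectral image classification: it first applies 3D convolutions and then 2D convolutions. The code is fairly simple, and a walkthrough follows.

The datasets have already been downloaded into the src/data folder. There are three datasets: {Indian-pines, PaviaU, Salinas}.

Import the basic libraries: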

import numpy as np
import matplotlib.pyplot as plt
import scipy.io as sio
from sklearn.decomposition import PCA
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, accuracy_score, classification_report, cohen_kappa_score
import spectral
import torch
import torchvision
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim

%matplotlib inline

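1. Define the HybridSN class

The network architecture of the model is shown in the figure below:

The network is built from this architecture as follows (shapes are per sample, as (channels, spectral bands, height, width)):

3D convolution part:

conv1: input (1, 30, 25, 25), 8 kernels of size 7x3x3 ==> (8, 24, 23, 23)

conv2: input (8, 24, 23, 23), 16 kernels of size 5x3x3 ==> (16, 20, 21, 21)

conv3: input (16, 20, 21, 21), 32 kernels of size 3x3x3 ==> (32, 18, 19, 19)

The next stage is a 2D convolution, so the 32 feature maps and 18 spectral slices are merged by a reshape into 32*18 = 576 channels, giving (576, 19, 19).

2D convolution: input (576, 19, 19), 64 kernels of size 3x3 ==> (64, 17, 17)

A flatten operation then turns this into a vector of 64*17*17 = 18496 dimensions.

Next come fully connected layers with 256 and 128 nodes, each followed by Dropout with rate 0.4.

The output layer has 16 nodes, one per land-cover class.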
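Every layer above uses stride 1 and no padding, so each axis shrinks by the kernel size minus one. The short sketch below simply replays that arithmetic with the standard output-size formula to confirm the 18496-dimensional flattened vector; the helper names `conv_out` and `hybridsn_shapes` are illustrative and not part of the original code.

```python
def conv_out(size, kernel, padding=0, stride=1):
    """Output length along one axis: (size - kernel + 2*padding) // stride + 1."""
    return (size - kernel + 2 * padding) // stride + 1


def hybridsn_shapes(bands=30, patch=25):
    """Per-sample shapes through the HybridSN feature extractor (stride 1, no padding)."""
    shape = (1, bands, patch, patch)          # (channels, depth, height, width)
    shapes = [shape]
    for out_ch, (kd, kh, kw) in [(8, (7, 3, 3)), (16, (5, 3, 3)), (32, (3, 3, 3))]:
        _, d, h, w = shape
        shape = (out_ch, conv_out(d, kd), conv_out(h, kh), conv_out(w, kw))
        shapes.append(shape)
    c, d, h, w = shape
    shapes.append((c * d, h, w))              # merge channel and spectral-depth axes
    conv2d = (64, conv_out(h, 3), conv_out(w, 3))
    shapes.append(conv2d)
    flat = conv2d[0] * conv2d[1] * conv2d[2]  # length after flatten
    return shapes, flat


shapes, flat = hybridsn_shapes()
for s in shapes:
    print(s)
print("flattened length:", flat)   # flattened length: 18496
```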

Below is the code of the HybridSN class:

class HybridSN(nn.Module):
    def __init__(self, num_classes):
        super(HybridSN, self).__init__()
        # Output size along each axis: out = (width - kernel_size + 2*padding)/stride + 1
        # => padding = (stride * (out-1) + kernel_size - width) / 2

        # 3D convolution block
        self.block_1_3D = nn.Sequential(
            nn.Conv3d(
                in_channels=1,
                out_channels=8,
                kernel_size=(7, 3, 3),
                stride=1,
                padding=0
            ),
            nn.ReLU(inplace=True),
            nn.Conv3d(
                in_channels=8,
                out_channels=16,
                kernel_size=(5, 3, 3),
                stride=1,
                padding=0
            ),
            nn.ReLU(inplace=True),
            nn.Conv3d(
                in_channels=16,
                out_channels=32,
                kernel_size=(3, 3, 3),
                stride=1,
                padding=0
            ),
            nn.ReLU(inplace=True)
        )

        # 2D convolution block
        self.block_2_2D = nn.Sequential(
            nn.Conv2d(
                in_channels=576,
                out_channels=64,
                kernel_size=(3, 3)
            ),
            nn.ReLU(inplace=True)
        )

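        # Fully connected layers; a ReLU after each hidden layer keeps the classifier
        # nonlinear (otherwise the three Linear layers reduce to a single linear map)
        self.classifier = nn.Sequential(
            nn.Linear(
                in_features=18496,
                out_features=256
            ),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.4),
            nn.Linear(
                in_features=256,
                out_features=128
            ),
            nn.ReLU(inplace=True),
            nn.Dropout(p=0.4),
            nn.Linear(
                in_features=128,
                out_features=num_classes
            )
        )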

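    def forward(self, x):
        # x: (batch, 1, bands, height, width)
        y = self.block_1_3D(x)
        # merge channel and spectral-depth axes: (N, 32, 18, 19, 19) -> (N, 576, 19, 19)
        y = y.view(-1, y.shape[1] * y.shape[2], y.shape[3], y.shape[4])
        y = self.block_2_2D(y)
        # flatten each sample: (N, 64, 17, 17) -> (N, 18496)
        y = y.view(y.size(0), -1)
        y = self.classifier(y)
        # return raw logits; nn.CrossEntropyLoss applies log-softmax internally
        return y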

if __name__ == '__main__':
    # Feed a random input to check that the network runs end to end
    x = torch.randn(1, 1, 30, 25, 25)
    net = HybridSN(num_classes=16)
    y = net(x)
    #print(y.shape)
    print(y)
tensor([[ 0.0472, -0.0176, -0.0186, -0.0334,  0.0459, -0.0728, -0.0416, -0.0740,
          0.0646,  0.0419, -0.0739,  0.0102, -0.0338, -0.0616, -0.0066, -0.0077]],
       grad_fn=<AddmmBackward>)
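With `num_classes=16`, counting the trainable parameters by hand shows where the model's capacity sits. The sketch below only does the arithmetic (weights plus one bias per output unit, for the layers defined in the class above); in PyTorch, `sum(p.numel() for p in net.parameters())` should report the same total, since ReLU and Dropout have no parameters.

```python
def conv_params(in_ch, out_ch, kernel):
    """Weights (in_ch * out_ch * kernel volume) plus one bias per output channel."""
    volume = 1
    for k in kernel:
        volume *= k
    return in_ch * out_ch * volume + out_ch


def linear_params(in_features, out_features):
    """Weight matrix plus one bias per output unit."""
    return in_features * out_features + out_features


num_classes = 16
layers = {
    "conv3d_1": conv_params(1, 8, (7, 3, 3)),
    "conv3d_2": conv_params(8, 16, (5, 3, 3)),
    "conv3d_3": conv_params(16, 32, (3, 3, 3)),
    "conv2d": conv_params(576, 64, (3, 3)),
    "fc_256": linear_params(18496, 256),
    "fc_128": linear_params(256, 128),
    "fc_out": linear_params(128, num_classes),
}
total = sum(layers.values())
for name, count in layers.items():
    print(f"{name:>9}: {count:>9,}")
print(f"{'total':>9}: {total:>9,}")   # total: 5,122,176
```

The first fully connected layer (18496 x 256) alone holds 4,735,232 of the 5,122,176 parameters, i.e. over 90% of the model.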

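2. Create the dataset

First apply PCA to reduce the spectral dimension of the hyperspectral data; then extract a patch around every pixel and store the patches in the channels-last layout used by Keras (they are transposed for PyTorch later); finally, randomly draw 10% of the samples as the training set, stratified by class, and use the rest as the test set.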
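The PCA step treats every pixel as one sample and every band as one feature: the (H, W, B) cube is flattened to (H*W, B), reduced, and reshaped back. The sketch below spells this out in plain NumPy via an SVD, as an illustration of what `PCA(n_components, whiten=True)` computes; it should agree with scikit-learn's output up to the sign of each component, but the tutorial itself uses scikit-learn, and `pca_whiten` is a hypothetical name.

```python
import numpy as np


def pca_whiten(cube, n_components):
    """Plain-NumPy PCA with whitening: (H, W, B) cube -> (H, W, n_components)."""
    h, w, b = cube.shape
    flat = cube.reshape(-1, b).astype(np.float64)  # one row per pixel, one column per band
    flat -= flat.mean(axis=0)                      # center every band
    u, _, _ = np.linalg.svd(flat, full_matrices=False)
    # the projection onto component k is u[:, k] * s[k]; whitening divides it by that
    # component's standard deviation s[k] / sqrt(n - 1), leaving u[:, k] * sqrt(n - 1)
    scores = u[:, :n_components] * np.sqrt(flat.shape[0] - 1)
    return scores.reshape(h, w, n_components)


rng = np.random.default_rng(0)
cube = rng.normal(size=(12, 10, 40))   # toy stand-in for the 145 x 145 x 200 Indian Pines cube
reduced = pca_whiten(cube, 5)
print(reduced.shape)                    # (12, 10, 5)
cov = np.cov(reduced.reshape(-1, 5), rowvar=False)
print(np.allclose(cov, np.eye(5)))      # True: uncorrelated components with unit variance
```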

First, define the basic functions:

# Apply a (whitened) PCA transform to the hyperspectral data X, keeping numComponents components
def applyPCA(X, numComponents):
    newX = np.reshape(X, (-1, X.shape[2]))
    pca = PCA(n_components=numComponents, whiten=True)
    newX = pca.fit_transform(newX)
    newX = np.reshape(newX, (X.shape[0], X.shape[1], numComponents))
    return newX

# Pixels near the border lack a full neighborhood for patch extraction, so zero-pad the image first
def padWithZeros(X, margin=2):
    newX = np.zeros((X.shape[0] + 2 * margin, X.shape[1] + 2* margin, X.shape[2]))
    x_offset = margin
    y_offset = margin
    newX[x_offset:X.shape[0] + x_offset, y_offset:X.shape[1] + y_offset, :] = X
    return newX

# Extract a patch around every pixel and build a channels-last array (the layout Keras uses)
def createImageCubes(X, y, windowSize=5, removeZeroLabels = True):
    # 給 X 做 padding
    margin = int((windowSize - 1) / 2)
    zeroPaddedX = padWithZeros(X, margin=margin)
    # split patches
    patchesData = np.zeros((X.shape[0] * X.shape[1], windowSize, windowSize, X.shape[2]))
    patchesLabels = np.zeros((X.shape[0] * X.shape[1]))
    patchIndex = 0
    for r in range(margin, zeroPaddedX.shape[0] - margin):
        for c in range(margin, zeroPaddedX.shape[1] - margin):
            patch = zeroPaddedX[r - margin:r + margin + 1, c - margin:c + margin + 1]
            patchesData[patchIndex, :, :, :] = patch
            patchesLabels[patchIndex] = y[r-margin, c-margin]
            patchIndex = patchIndex + 1
    if removeZeroLabels:
        patchesData = patchesData[patchesLabels>0,:,:,:]
        patchesLabels = patchesLabels[patchesLabels>0]
        patchesLabels -= 1
    return patchesData, patchesLabels

def splitTrainTestSet(X, y, testRatio, randomState=345):
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=testRatio, random_state=randomState, stratify=y)
    return X_train, X_test, y_train, y_test
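`createImageCubes` visits every pixel in a Python double loop and first allocates room for all H*W patches, labeled or not. A vectorized variant can build the same patches with NumPy's `sliding_window_view` (NumPy 1.20 or later) and materialize only the labeled ones. This is a sketch assuming an odd window size and 0 meaning "unlabeled", as in the original; `image_cubes_vectorized` is a hypothetical name.

```python
import numpy as np
from numpy.lib.stride_tricks import sliding_window_view


def image_cubes_vectorized(X, y, window_size=5, remove_zero_labels=True):
    """Build (K, window, window, C) patches centered on pixels of X (H, W, C).

    Labels follow createImageCubes: 0 means unlabeled; kept labels are shifted to start at 0.
    """
    margin = (window_size - 1) // 2
    # zero padding on the two spatial axes only
    padded = np.pad(X, ((margin, margin), (margin, margin), (0, 0)))
    # one window per original pixel, shape (H, W, C, window, window); this is a view, not a copy
    windows = sliding_window_view(padded, (window_size, window_size), axis=(0, 1))
    mask = (y > 0) if remove_zero_labels else np.ones(y.shape, dtype=bool)
    # boolean indexing walks pixels in row-major order, like the original double loop
    patches = np.ascontiguousarray(windows[mask].transpose(0, 2, 3, 1))
    labels = y[mask].astype(np.float64)
    if remove_zero_labels:
        labels -= 1
    return patches, labels


rng = np.random.default_rng(1)
X_toy = rng.normal(size=(6, 7, 3))
y_toy = rng.integers(0, 4, size=(6, 7))
patches, labels = image_cubes_vectorized(X_toy, y_toy, window_size=5)
print(patches.shape[0] == np.count_nonzero(y_toy))   # True
```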

Now load the data and build the dataset:

# number of land-cover classes
class_num = 16
X = sio.loadmat('./src/data/Indian_pines_corrected.mat')['indian_pines_corrected']
y = sio.loadmat('./src/data/Indian_pines_gt.mat')['indian_pines_gt']

# fraction of samples used for testing
test_ratio = 0.90
# size of the patch extracted around each pixel
patch_size = 25
# number of principal components kept after PCA
pca_components = 30

print('Hyperspectral data shape: ', X.shape)
print('Label shape: ', y.shape)

print('\n... ... PCA transformation ... ...')
X_pca = applyPCA(X, numComponents=pca_components)
print('Data shape after PCA: ', X_pca.shape)

print('\n... ... create data cubes ... ...')
X_pca, y = createImageCubes(X_pca, y, windowSize=patch_size)
print('Data cube X shape: ', X_pca.shape)
print('Data cube y shape: ', y.shape)

print('\n... ... create train & test data ... ...')
Xtrain, Xtest, ytrain, ytest = splitTrainTestSet(X_pca, y, test_ratio)
print('Xtrain shape: ', Xtrain.shape)
print('Xtest  shape: ', Xtest.shape)

# Add a trailing channel axis of size 1 to Xtrain and Xtest (channels-last, as Keras expects)
Xtrain = Xtrain.reshape(-1, patch_size, patch_size, pca_components, 1)
Xtest  = Xtest.reshape(-1, patch_size, patch_size, pca_components, 1)
print('before transpose: Xtrain shape: ', Xtrain.shape)
print('before transpose: Xtest  shape: ', Xtest.shape)

# PyTorch's Conv3d expects (N, C, D, H, W), so transpose the axes
Xtrain = Xtrain.transpose(0, 4, 3, 1, 2)
Xtest  = Xtest.transpose(0, 4, 3, 1, 2)
print('after transpose: Xtrain shape: ', Xtrain.shape)
print('after transpose: Xtest  shape: ', Xtest.shape)


""" Training dataset"""
class TrainDS(torch.utils.data.Dataset):
    def __init__(self):
        self.len = Xtrain.shape[0]
        self.x_data = torch.FloatTensor(Xtrain)
        self.y_data = torch.LongTensor(ytrain)
    def __getitem__(self, index):
        # return the sample at this index and its label
        return self.x_data[index], self.y_data[index]
    def __len__(self):
        # return the number of samples
        return self.len

""" Testing dataset"""
class TestDS(torch.utils.data.Dataset):
    def __init__(self):
        self.len = Xtest.shape[0]
        self.x_data = torch.FloatTensor(Xtest)
        self.y_data = torch.LongTensor(ytest)
    def __getitem__(self, index):
        # return the sample at this index and its label
        return self.x_data[index], self.y_data[index]
    def __len__(self):
        # return the number of samples
        return self.len

# Create the train and test DataLoaders
trainset = TrainDS()
testset  = TestDS()
train_loader = torch.utils.data.DataLoader(dataset=trainset, batch_size=128, shuffle=True, num_workers=0)
test_loader  = torch.utils.data.DataLoader(dataset=testset,  batch_size=128, shuffle=False, num_workers=0)
Hyperspectral data shape:  (145, 145, 200)
Label shape:  (145, 145)

... ... PCA transformation ... ...
Data shape after PCA:  (145, 145, 30)

... ... create data cubes ... ...
Data cube X shape:  (10249, 25, 25, 30)
Data cube y shape:  (10249,)

... ... create train & test data ... ...
Xtrain shape:  (1024, 25, 25, 30)
Xtest  shape:  (9225, 25, 25, 30)
before transpose: Xtrain shape:  (1024, 25, 25, 30, 1)
before transpose: Xtest  shape:  (9225, 25, 25, 30, 1)
after transpose: Xtrain shape:  (1024, 1, 30, 25, 25)
after transpose: Xtest  shape:  (9225, 1, 30, 25, 25)
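The `transpose(0, 4, 3, 1, 2)` call moves the size-1 channel axis to position 1 and the 30 PCA components to the depth position, which is the (N, C, D, H, W) layout `nn.Conv3d` expects and matches the (1, 1, 30, 25, 25) test input used earlier. A tiny NumPy check of where one element ends up:

```python
import numpy as np

n, patch, bands = 2, 25, 30
# channels-last cube as produced by reshape(-1, patch_size, patch_size, pca_components, 1)
x = np.arange(n * patch * patch * bands, dtype=np.float32).reshape(n, patch, patch, bands, 1)
# axes (N, H, W, D, 1) -> (N, 1, D, H, W): batch, input channel, spectral depth, height, width
x_t = x.transpose(0, 4, 3, 1, 2)
print(x_t.shape)                                  # (2, 1, 30, 25, 25)
print(x[1, 4, 7, 12, 0] == x_t[1, 0, 12, 4, 7])   # True
```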

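3. Start training

# Train on the GPU if one is available; in Colab, enable it under "Runtime" -> "Change runtime type"
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# move the network to the GPU (or CPU)
net = HybridSN(num_classes=16).to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(net.parameters(), lr=0.001)

# start training
# note: total_loss is never reset, so "loss avg" below is the running sum of all
# batch losses divided by the number of epochs, not a mean per batch
total_loss = 0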
net.train()
for epoch in range(100):
    for i, (inputs, labels) in enumerate(train_loader):
        inputs = inputs.to(device)
        labels = labels.to(device)
        # reset the optimizer's gradients
        optimizer.zero_grad()
        # forward pass + backward pass + parameter update
        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()
    print('[Epoch: %d]   [loss avg: %.4f]   [current loss: %.4f]' %(epoch + 1, total_loss/(epoch+1), loss.item()))

print('Finished Training')
[Epoch: 1]   [loss avg: 20.7095]   [current loss: 2.3290]
[Epoch: 2]   [loss avg: 20.1627]   [current loss: 2.5282]
[Epoch: 3]   [loss avg: 19.4209]   [current loss: 1.9093]
[Epoch: 4]   [loss avg: 18.4654]   [current loss: 1.7938]
[Epoch: 5]   [loss avg: 17.0903]   [current loss: 1.2731]
[Epoch: 6]   [loss avg: 15.6037]   [current loss: 1.0079]
[Epoch: 7]   [loss avg: 14.1798]   [current loss: 0.4204]
[Epoch: 8]   [loss avg: 12.8882]   [current loss: 0.3953]
[Epoch: 9]   [loss avg: 11.7441]   [current loss: 0.2800]
[Epoch: 10]   [loss avg: 10.7405]   [current loss: 0.1938]
[Epoch: 11]   [loss avg: 9.8935]   [current loss: 0.0548]
[Epoch: 12]   [loss avg: 9.1668]   [current loss: 0.1489]
[Epoch: 13]   [loss avg: 8.5507]   [current loss: 0.1108]
[Epoch: 14]   [loss avg: 8.0388]   [current loss: 0.1399]
[Epoch: 15]   [loss avg: 7.5784]   [current loss: 0.1596]
[Epoch: 16]   [loss avg: 7.1522]   [current loss: 0.0643]
[Epoch: 17]   [loss avg: 6.7597]   [current loss: 0.0592]
[Epoch: 18]   [loss avg: 6.3997]   [current loss: 0.0298]
[Epoch: 19]   [loss avg: 6.0752]   [current loss: 0.0166]
[Epoch: 20]   [loss avg: 5.7817]   [current loss: 0.0281]
[Epoch: 21]   [loss avg: 5.5139]   [current loss: 0.0208]
[Epoch: 22]   [loss avg: 5.2746]   [current loss: 0.0992]
[Epoch: 23]   [loss avg: 5.0558]   [current loss: 0.0131]
[Epoch: 24]   [loss avg: 4.8551]   [current loss: 0.0555]
[Epoch: 25]   [loss avg: 4.6687]   [current loss: 0.0217]
[Epoch: 26]   [loss avg: 4.4965]   [current loss: 0.0635]
[Epoch: 27]   [loss avg: 4.3345]   [current loss: 0.0035]
[Epoch: 28]   [loss avg: 4.1922]   [current loss: 0.0182]
[Epoch: 29]   [loss avg: 4.0607]   [current loss: 0.0082]
[Epoch: 30]   [loss avg: 3.9388]   [current loss: 0.0260]
[Epoch: 31]   [loss avg: 3.8228]   [current loss: 0.0065]
[Epoch: 32]   [loss avg: 3.7112]   [current loss: 0.0450]
[Epoch: 33]   [loss avg: 3.6075]   [current loss: 0.0144]
[Epoch: 34]   [loss avg: 3.5066]   [current loss: 0.0199]
[Epoch: 35]   [loss avg: 3.4119]   [current loss: 0.0080]
[Epoch: 36]   [loss avg: 3.3190]   [current loss: 0.0087]
[Epoch: 37]   [loss avg: 3.2320]   [current loss: 0.0324]
[Epoch: 38]   [loss avg: 3.1489]   [current loss: 0.0355]
[Epoch: 39]   [loss avg: 3.0693]   [current loss: 0.0288]
[Epoch: 40]   [loss avg: 2.9933]   [current loss: 0.0027]
[Epoch: 41]   [loss avg: 2.9223]   [current loss: 0.0162]
[Epoch: 42]   [loss avg: 2.8553]   [current loss: 0.0024]
[Epoch: 43]   [loss avg: 2.7906]   [current loss: 0.0026]
[Epoch: 44]   [loss avg: 2.7351]   [current loss: 0.0393]
[Epoch: 45]   [loss avg: 2.6765]   [current loss: 0.0425]
[Epoch: 46]   [loss avg: 2.6233]   [current loss: 0.0027]
[Epoch: 47]   [loss avg: 2.5706]   [current loss: 0.0543]
[Epoch: 48]   [loss avg: 2.5202]   [current loss: 0.0059]
[Epoch: 49]   [loss avg: 2.4715]   [current loss: 0.0086]
[Epoch: 50]   [loss avg: 2.4285]   [current loss: 0.0198]
[Epoch: 51]   [loss avg: 2.3821]   [current loss: 0.0280]
[Epoch: 52]   [loss avg: 2.3387]   [current loss: 0.0229]
[Epoch: 53]   [loss avg: 2.2958]   [current loss: 0.0064]
[Epoch: 54]   [loss avg: 2.2554]   [current loss: 0.0517]
[Epoch: 55]   [loss avg: 2.2168]   [current loss: 0.0210]
[Epoch: 56]   [loss avg: 2.1824]   [current loss: 0.0684]
[Epoch: 57]   [loss avg: 2.1498]   [current loss: 0.0913]
[Epoch: 58]   [loss avg: 2.1182]   [current loss: 0.0024]
[Epoch: 59]   [loss avg: 2.0845]   [current loss: 0.0261]
[Epoch: 60]   [loss avg: 2.0552]   [current loss: 0.0577]
[Epoch: 61]   [loss avg: 2.0244]   [current loss: 0.0200]
[Epoch: 62]   [loss avg: 1.9957]   [current loss: 0.0715]
[Epoch: 63]   [loss avg: 1.9669]   [current loss: 0.0226]
[Epoch: 64]   [loss avg: 1.9377]   [current loss: 0.0270]
[Epoch: 65]   [loss avg: 1.9097]   [current loss: 0.0007]
[Epoch: 66]   [loss avg: 1.8865]   [current loss: 0.0040]
[Epoch: 67]   [loss avg: 1.8608]   [current loss: 0.0250]
[Epoch: 68]   [loss avg: 1.8350]   [current loss: 0.0119]
[Epoch: 69]   [loss avg: 1.8103]   [current loss: 0.0008]
[Epoch: 70]   [loss avg: 1.7860]   [current loss: 0.0014]
[Epoch: 71]   [loss avg: 1.7625]   [current loss: 0.0011]
[Epoch: 72]   [loss avg: 1.7396]   [current loss: 0.0349]
[Epoch: 73]   [loss avg: 1.7170]   [current loss: 0.0176]
[Epoch: 74]   [loss avg: 1.6940]   [current loss: 0.0005]
[Epoch: 75]   [loss avg: 1.6723]   [current loss: 0.0199]
[Epoch: 76]   [loss avg: 1.6511]   [current loss: 0.0015]
[Epoch: 77]   [loss avg: 1.6326]   [current loss: 0.0620]
[Epoch: 78]   [loss avg: 1.6134]   [current loss: 0.0016]
[Epoch: 79]   [loss avg: 1.5959]   [current loss: 0.0129]
[Epoch: 80]   [loss avg: 1.5779]   [current loss: 0.0096]
[Epoch: 81]   [loss avg: 1.5606]   [current loss: 0.0300]
[Epoch: 82]   [loss avg: 1.5433]   [current loss: 0.0016]
[Epoch: 83]   [loss avg: 1.5268]   [current loss: 0.0050]
[Epoch: 84]   [loss avg: 1.5100]   [current loss: 0.0503]
[Epoch: 85]   [loss avg: 1.4926]   [current loss: 0.0007]
[Epoch: 86]   [loss avg: 1.4757]   [current loss: 0.0013]
[Epoch: 87]   [loss avg: 1.4598]   [current loss: 0.0007]
[Epoch: 88]   [loss avg: 1.4434]   [current loss: 0.0005]
[Epoch: 89]   [loss avg: 1.4272]   [current loss: 0.0001]
[Epoch: 90]   [loss avg: 1.4115]   [current loss: 0.0028]
[Epoch: 91]   [loss avg: 1.3962]   [current loss: 0.0008]
[Epoch: 92]   [loss avg: 1.3811]   [current loss: 0.0000]
[Epoch: 93]   [loss avg: 1.3664]   [current loss: 0.0002]
[Epoch: 94]   [loss avg: 1.3519]   [current loss: 0.0008]
[Epoch: 95]   [loss avg: 1.3380]   [current loss: 0.0001]
[Epoch: 96]   [loss avg: 1.3242]   [current loss: 0.0002]
[Epoch: 97]   [loss avg: 1.3107]   [current loss: 0.0007]
[Epoch: 98]   [loss avg: 1.2974]   [current loss: 0.0004]
[Epoch: 99]   [loss avg: 1.2844]   [current loss: 0.0001]
[Epoch: 100]   [loss avg: 1.2722]   [current loss: 0.0001]
Finished Training
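In the training loop, `total_loss` grows by one term per mini-batch (1024 / 128 = 8 batches per epoch) and is never reset. The printed "loss avg" is therefore the cumulative sum of batch losses divided by the epoch count, roughly eight times the per-batch loss averaged over all epochs so far, which is why epoch 1 reports 20.7 while the last batch loss is 2.33. The made-up numbers below contrast that metric with a per-epoch mean batch loss; both helper functions are hypothetical.

```python
def printed_loss_avg(batch_losses_per_epoch):
    """The tutorial's metric: running sum of every batch loss so far / epochs seen."""
    total, history = 0.0, []
    for epoch, losses in enumerate(batch_losses_per_epoch):
        total += sum(losses)
        history.append(total / (epoch + 1))
    return history


def mean_batch_loss(batch_losses_per_epoch):
    """Mean batch loss within each epoch (the sum is reset every epoch)."""
    return [sum(losses) / len(losses) for losses in batch_losses_per_epoch]


# made-up losses: 2 epochs x 8 batches (1024 training samples / batch size 128)
epochs = [[2.5] * 8, [1.5] * 8]
print(printed_loss_avg(epochs))   # [20.0, 16.0]
print(mean_batch_loss(epochs))    # [2.5, 1.5]
```

In the PyTorch loop, the second metric corresponds to setting a running sum to 0 at the start of each epoch and dividing it by `len(train_loader)` before printing.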

4. Model testing

count = 0
# evaluate the model on the test set
net.eval()
with torch.no_grad():  # gradients are not needed at test time
    for inputs, _ in test_loader:
        inputs = inputs.to(device)
        outputs = net(inputs)
        outputs = np.argmax(outputs.detach().cpu().numpy(), axis=1)
        if count == 0:
            y_pred_test = outputs
            count = 1
        else:
            y_pred_test = np.concatenate((y_pred_test, outputs))

# Generate the classification report
classification = classification_report(ytest, y_pred_test, digits=4)
print(classification)
              precision    recall  f1-score   support

         0.0     1.0000    0.9512    0.9750        41
         1.0     0.9886    0.9463    0.9670      1285
         2.0     0.9750    0.9933    0.9841       747
         3.0     0.9742    0.8873    0.9287       213
         4.0     0.9448    0.9839    0.9640       435
         5.0     0.9671    0.9833    0.9751       657
         6.0     1.0000    0.8800    0.9362        25
         7.0     0.9389    1.0000    0.9685       430
         8.0     1.0000    0.8889    0.9412        18
         9.0     0.9887    0.9989    0.9937       875
        10.0     0.9727    0.9824    0.9775      2210
        11.0     0.9981    0.9663    0.9819       534
        12.0     1.0000    0.9459    0.9722       185
        13.0     0.9965    1.0000    0.9982      1139
        14.0     0.9942    0.9827    0.9884       347
        15.0     0.9222    0.9881    0.9540        84

    accuracy                         0.9785      9225
   macro avg     0.9788    0.9612    0.9691      9225
weighted avg     0.9789    0.9785    0.9785      9225

5. Helper functions

The following helper functions, provided for reference, compute per-class accuracy and report the results:

from operator import truediv

def AA_andEachClassAccuracy(confusion_matrix):
    counter = confusion_matrix.shape[0]
    list_diag = np.diag(confusion_matrix)
    list_raw_sum = np.sum(confusion_matrix, axis=1)
    each_acc = np.nan_to_num(truediv(list_diag, list_raw_sum))
    average_acc = np.mean(each_acc)
    return each_acc, average_acc


def reports(test_loader, y_test, name):
    count = 0
    # evaluate the model on the test set
    net.eval()
    with torch.no_grad():
        for inputs, _ in test_loader:
            inputs = inputs.to(device)
            outputs = net(inputs)
            outputs = np.argmax(outputs.detach().cpu().numpy(), axis=1)
            if count == 0:
                y_pred = outputs
                count = 1
            else:
                y_pred = np.concatenate((y_pred, outputs))

    if name == 'IP':
        target_names = ['Alfalfa', 'Corn-notill', 'Corn-mintill', 'Corn',
                        'Grass-pasture', 'Grass-trees', 'Grass-pasture-mowed',
                        'Hay-windrowed', 'Oats', 'Soybean-notill', 'Soybean-mintill',
                        'Soybean-clean', 'Wheat', 'Woods', 'Buildings-Grass-Trees-Drives',
                        'Stone-Steel-Towers']
    elif name == 'SA':
        target_names = ['Brocoli_green_weeds_1','Brocoli_green_weeds_2','Fallow','Fallow_rough_plow','Fallow_smooth',
                        'Stubble','Celery','Grapes_untrained','Soil_vinyard_develop','Corn_senesced_green_weeds',
                        'Lettuce_romaine_4wk','Lettuce_romaine_5wk','Lettuce_romaine_6wk','Lettuce_romaine_7wk',
                        'Vinyard_untrained','Vinyard_vertical_trellis']
    elif name == 'PU':
        target_names = ['Asphalt','Meadows','Gravel','Trees', 'Painted metal sheets','Bare Soil','Bitumen',
                        'Self-Blocking Bricks','Shadows']

    classification = classification_report(y_test, y_pred, target_names=target_names)
    oa = accuracy_score(y_test, y_pred)
    confusion = confusion_matrix(y_test, y_pred)
    each_acc, aa = AA_andEachClassAccuracy(confusion)
    kappa = cohen_kappa_score(y_test, y_pred)

    return classification, confusion, oa*100, each_acc*100, aa*100, kappa*100
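`reports` gets OA from `accuracy_score`, AA from `AA_andEachClassAccuracy` and kappa from `cohen_kappa_score`. All three can also be read off the confusion matrix directly, which makes the definitions explicit: OA is the trace over the total, AA averages the per-class recalls (diagonal over row sums), and Cohen's kappa corrects OA for the agreement expected by chance. The sketch below assumes scikit-learn's convention of true classes in rows and predictions in columns; `oa_aa_kappa` is a hypothetical helper, not part of the original code.

```python
import numpy as np


def oa_aa_kappa(confusion):
    """Overall accuracy, average per-class accuracy and Cohen's kappa from a confusion matrix."""
    confusion = np.asarray(confusion, dtype=np.float64)
    total = confusion.sum()
    oa = np.trace(confusion) / total
    row_sums = confusion.sum(axis=1)                 # true samples per class
    with np.errstate(divide="ignore", invalid="ignore"):
        per_class = np.nan_to_num(np.diag(confusion) / row_sums)
    aa = per_class.mean()
    # chance agreement: sum over classes of (true share) * (predicted share)
    pe = np.sum(row_sums * confusion.sum(axis=0)) / total**2
    kappa = (oa - pe) / (1 - pe)
    return float(oa), float(aa), float(kappa)


oa, aa, kappa = oa_aa_kappa([[5, 1], [2, 2]])
print(round(oa, 4), round(aa, 4), round(kappa, 4))   # 0.7 0.6667 0.3478
```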

Write the evaluation results to a file:

classification, confusion, oa, each_acc, aa, kappa = reports(test_loader, ytest, 'IP')
classification = str(classification)
confusion = str(confusion)
file_name = "classification_report.txt"

with open(file_name, 'w') as x_file:
    x_file.write('\n')
    x_file.write('{} Kappa accuracy (%)'.format(kappa))
    x_file.write('\n')
    x_file.write('{} Overall accuracy (%)'.format(oa))
    x_file.write('\n')
    x_file.write('{} Average accuracy (%)'.format(aa))
    x_file.write('\n')
    x_file.write('\n')
    x_file.write('{}'.format(classification))
    x_file.write('\n')
    x_file.write('{}'.format(confusion))

The following code displays the classification map:

# load the original image
X = sio.loadmat('./src/data/Indian_pines_corrected.mat')['indian_pines_corrected']
y = sio.loadmat('./src/data/Indian_pines_gt.mat')['indian_pines_gt']

height = y.shape[0]
width = y.shape[1]

X = applyPCA(X, numComponents= pca_components)
X = padWithZeros(X, patch_size//2)

# predict the class of each labeled pixel, one patch at a time
outputs = np.zeros((height, width))
net.eval()
with torch.no_grad():
    for i in range(height):
        for j in range(width):
            if int(y[i, j]) == 0:
                continue  # skip unlabeled background pixels
            # X is padded by patch_size // 2, so this patch is centered on pixel (i, j)
            image_patch = X[i:i+patch_size, j:j+patch_size, :]
            image_patch = image_patch.reshape(1, image_patch.shape[0], image_patch.shape[1], image_patch.shape[2], 1)
            X_test_image = torch.FloatTensor(image_patch.transpose(0, 4, 3, 1, 2)).to(device)
            prediction = net(X_test_image)
            prediction = np.argmax(prediction.detach().cpu().numpy(), axis=1)
            outputs[i][j] = prediction[0] + 1  # shift back to 1-based class ids
        if i % 20 == 0:
            print('... ... row ', i, ' handling ... ...')
... ... row  0  handling ... ...
... ... row  20  handling ... ...
... ... row  40  handling ... ...
... ... row  60  handling ... ...
... ... row  80  handling ... ...
... ... row  100  handling ... ...
... ... row  120  handling ... ...
... ... row  140  handling ... ...
predict_image = spectral.imshow(classes=outputs.astype(int), figsize=(5, 5))
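The pixel-by-pixel loop above runs one forward pass per labeled pixel, which can be slow even for a 145 x 145 scene. A common alternative is to collect the labeled coordinates and predict in mini-batches. The sketch below does that in NumPy; `predict_map` and its `predict_fn` argument are hypothetical names, and `predict_fn` stands in for the trained network.

```python
import numpy as np


def predict_map(padded, labels, patch_size, predict_fn, batch_size=256):
    """Batched version of the pixel-by-pixel prediction loop (a sketch).

    padded:     PCA-reduced cube zero-padded by patch_size // 2, shape (H+2m, W+2m, D)
    labels:     ground-truth map (H, W); pixels labeled 0 are skipped
    predict_fn: callable mapping a (B, 1, D, patch, patch) float32 array to (B, classes) scores
    """
    out = np.zeros(labels.shape, dtype=int)
    rows, cols = np.nonzero(labels)
    for start in range(0, len(rows), batch_size):
        r, c = rows[start:start + batch_size], cols[start:start + batch_size]
        # the patch for pixel (i, j) is padded[i:i+patch_size, j:j+patch_size]
        batch = np.stack([padded[i:i + patch_size, j:j + patch_size] for i, j in zip(r, c)])
        # (B, H, W, D) -> (B, 1, D, H, W), the layout the network expects
        batch = np.ascontiguousarray(batch[..., np.newaxis].transpose(0, 4, 3, 1, 2), dtype=np.float32)
        scores = predict_fn(batch)
        out[r, c] = np.argmax(scores, axis=1) + 1   # shift back to 1-based class ids
    return out


rng = np.random.default_rng(2)
image = rng.normal(size=(4, 5, 3))
gt = rng.integers(0, 3, size=(4, 5))
padded = np.pad(image, ((1, 1), (1, 1), (0, 0)))
# toy "network": the scores are simply the spectrum of each patch's center pixel
pred = predict_map(padded, gt, 3, lambda b: b[:, 0, :, 1, 1], batch_size=4)
```

With the trained model, `predict_fn` could be something like `lambda b: net(torch.from_numpy(b).to(device)).cpu().numpy()`, called after `net.eval()` and inside `torch.no_grad()`.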