1. 程式人生 > >遙感分類的一種取樣方法

遙感分類的一種取樣方法

如深度學習,輸入要求為一小鄰域(下文稱鄰域塊)代表中心畫素型別。現有柵格影象,以及摳圖面檔案(.shp)。以下主要集中與arcgis操作。閱讀本文前,建議閱讀多光譜遙感分類:使用CNN1(一)

一種方法是使用隨機點,但是就本任務目標其弊端明顯(鄰域塊重疊相關,可通過設定隨機點間隔解決,但會使樣本大大減少)。具體參考
在這裡插入圖片描述

本文將描述的方法基於漁網。通過建立漁網(設定像元間隔)->疊加分析.相交。可以得到近可能多的點。
在這裡插入圖片描述

本文相關程式碼如下,讀者有必要自行取捨。(包括將點shp檔案匯出座標檔案topos,根據座標檔案取樣到各個資料夾)。

"""
@file: dataCreate.py
@time: 2018/10/15
"""
import os import shapefile import gdal import pandas as pd import matplotlib.pyplot as plt from pyecharts import Bar import numpy as np import cv2 import shutil import sys dataset = gdal.Open(r"E:\Experiment\Mine\dataformat.tif") rer=shapefile.Reader(r"E:\Experiment\Mine\相交.shp") size=64 def topos
(): res=[] geo=dataset.GetGeoTransform() for i in range(rer.numRecords):#rer.numRecords pos=rer.shape(i).points[0] label=rer.record(i)[1] xoffset = int((pos[0] - geo[0]) / geo[1]) yoffset = int((pos[1] - geo[3]) / geo[5]) res.append([xoffset,yoffset,label]
) res=pd.DataFrame(res) print(res.head()) res.to_csv("../output/pos.csv",header=None,index=None) def get_cell(pos_x, pos_y): try: output = [] for i in [1,2,3]: band = dataset.GetRasterBand(i) if (int(pos_x - size / 2) < 0 or int(pos_y - size / 2) < 0 or int(pos_x - size / 2) + size > dataset.RasterXSize or int(pos_y - size / 2) + size > dataset.RasterYSize): return None t = band.ReadAsArray(int(pos_x - size / 2), int(pos_y - size / 2), size, size) output.append(t) img = np.moveaxis(np.array(output, dtype=np.uint8), 0, 2) except: return None return img def createdoc(dic): d="../output/img" if os.path.exists(d): shutil.rmtree(d) os.makedirs(d) for i in dic.values(): os.makedirs(os.path.join(d,i)) def toImg(): data=pd.read_csv("../output/pos.csv",header=None) labeldic=num2label() createdoc(labeldic) for line,row in data.iterrows(): img=get_cell(row[0],row[1]) if img is None: continue label=labeldic.get(row[2]) # cv2.imwrite("../output/img/%s/%d.png" % (label,line),img) cv2.imencode('.png', img)[1].tofile("../output/img/%s/%d.png" % (label,line)) def num2label(): eo=pd.read_csv(r"E:\Experiment\Mine\Export_Output.txt",index_col=0) dic=dict() for _,row in eo.iterrows(): dic[row[1]]=row[2] return dic if __name__ == '__main__': # toImg() # info() pass

取樣結果如下:
在這裡插入圖片描述

[[‘城鄉居民建設用地_紅白頂’ ‘排土場’ ‘未利用土地_裸土地’ ‘水體’ ‘排土場’ ‘排土場’]
[‘城鄉居民建設用地_灰白頂’ ‘採場’ ‘林地_灰’ ‘採場’ ‘採場’ ‘排土場’]
[‘耕地_旱地_綠色’ ‘水體’ ‘耕地_旱地_綠色’ ‘林地_紅’ ‘排土場’ ‘選礦場’]
[‘選礦場’ ‘耕地_旱地_灰色’ ‘水體’ ‘選礦場’ ‘耕地_旱地_灰色’ ‘林地_紅’]
[‘選礦場’ ‘耕地_旱地_灰色’ ‘採場’ ‘採場’ ‘選礦場’ ‘林地_紅’]
[‘選礦場’ ‘採場’ ‘採場’ ‘選礦場’ ‘林地_黑’ ‘排土場’]]

其視覺化程式碼如下:

"""
@file: tongji.py
@time: 2018/10/16
"""
import os
import sys
import re
import numpy as np

import torch
import torchvision
from pyecharts import Bar
import random
import pandas as pd
import matplotlib.pyplot as plt
from torchvision import datasets

from dataDeal.datacreate import num2label
from PIL import Image

plt.rcParams['font.sans-serif']=['SimHei'] #用來正常顯示中文標籤
plt.rcParams['axes.unicode_minus']=False #用來正常顯示負號

def show():
    """
    顯示資料夾子資料夾下圖片統計表
    :return:
    """
    path=r"../output/img"
    d=os.listdir(path)
    d_len=[len(os.listdir(os.path.join(path,i))) for i in d]

    line = Bar(path)
    line.add("圖片數量", d, d_len, mark_point=["average", "max", "min"],xaxis_rotate=50)
    line.render(path="../output/圖片數量.html")

def info():
    path="../output/pos.csv"
    res=pd.read_csv(path,header=None)

    xy=res.iloc[:, 2].value_counts()
    label=[num2label().get(i) for i in xy.index.values]

    # print(xy.index.values)
    line = Bar(path)
    line.add("點數量", label,xy.values, is_smooth=True, mark_line=["max", "average","min"],xaxis_rotate=50)
    line.render("../output/點數量.html")

def imageshow():

    # 所有路徑、標籤
    path = r"../output/img"

    # imgall=[]
    # labelall=[]
    # for root, dirs, files in os.walk(path):  # 目錄
    #     for f in files:
    #         p = os.path.join(root, f)
    #         label=re.split("/|\\\\",p)[-2]
    #         imgall.append(p)
    #         labelall.append(label)
    # print(len(imgall),len(labelall))
    #
    # row=4
    # col=4
    # pos=random.sample(range(len(labelall)),row*col)
    # # print(pos)
    # plt.figure(figsize=(10,8))
    # for i,value in enumerate(pos):
    #     plt.subplot(row,col,i+1)
    #     plt.imshow(Image.open(imgall[value]))
    #     plt.title(labelall[value])
    #     plt.xticks([])
    #     plt.yticks([])
    #
    # plt.show()

    image_datasets =datasets.ImageFolder(os.path.join(path),transform=torchvision.transforms.ToTensor())
    dataloaders =torch.utils.data.DataLoader(image_datasets, batch_size=36,shuffle=True)

    inputs, classes = next(iter(dataloaders))
    labels=[image_datasets.classes[i] for i in classes.numpy()]
    print(np.array(labels).reshape(-1,6))
    out = torchvision.utils.make_grid(inputs,6,0)
    inp = out.numpy().transpose((1, 2, 0))
    plt.imshow(inp)
    plt.show()

if __name__ == '__main__':

    # show()
    imageshow()

    pass