多光譜遙感分類:使用CNN1(一)
阿新 • • 發佈:2018-12-18
程式碼源於很久以前練手的一個Demo,時間長了,許多魔改版都不見了,目前只剩下此簡陋版本。讀者如有相關需求,可根據隻言片語斷章取義。由於程式碼較為混亂且基礎,不再上傳GitHub。
所用資料為多光譜遙感影像(.tif,由arcgis匯出RGB彩色影象),摳圖所得點檔案(.shp)(由摳圖面檔案使用arcgis隨機生成點生成,至少有一個欄位,即標籤)。
工具篇
根據點shp檔案(樣本點集合),對柵格影象的3、2、1波段切圖,並儲存在相應標籤下的資料夾,注意shp、tif的投影座標一致
from osgeo import gdal
import numpy as np
import shapefile
import cv2
import os
# Side length, in pixels, of the square window that __getACell cuts around each sample point.
size=64
# Band count; not referenced by any of the code visible in this script.
bands=3
# Raster the patches are read from (opened once and shared by the functions below).
dataset = gdal.Open(r"E:\資料2\test_tif_peizhun_subset_proj_.tif")
# Point shapefile holding the sample locations; getShpDataForNum uses field 0 of each record as the label.
rer=shapefile.Reader(r'E:\shps\test.shp')
def __createDir(path):
    """Create directory ``path`` (with any missing parents) if needed.

    An existing directory is left untouched. If the directory cannot be
    created, or ``path`` exists but is not a directory, the message below is
    printed and the process exits with status 1.
    """
    try:
        # exist_ok=True replaces the separate exists() test, which could race
        # with another process creating the same folder.
        os.makedirs(path, exist_ok=True)
    except OSError:
        print("建立資料夾失敗")
        # SystemExit is what exit() raises; raising it directly does not rely
        # on the `site` module having installed the interactive `exit` helper.
        raise SystemExit(1)
def __getACell(geo,pos):
    """Cut a ``size`` x ``size`` window of bands 3, 2, 1 around a map position.

    Args:
        geo: GDAL geotransform tuple. Only the origin (indices 0 and 3) and the
            pixel size (indices 1 and 5) are used, so the rotation terms are
            ignored.
        pos: ``(x, y)`` position in the same coordinate system as ``geo``.

    Returns:
        An array of shape ``(size, size, 3)`` with dtype uint8 whose last axis
        holds bands 3, 2, 1 in that order, or ``None`` when the window does not
        fit inside the raster or the read fails.
    """
    try:
        xoffset = int((pos[0] - geo[0]) / geo[1])
        yoffset = int((pos[1] - geo[3]) / geo[5])
        print("pixels: x= %d,y= %d" % (xoffset, yoffset))

        # Upper-left corner of the window; it is the same for every band, so
        # compute and validate it once instead of once per band.
        left = int(xoffset - size / 2)
        top = int(yoffset - size / 2)
        if (left < 0 or top < 0
                or left + size > dataset.RasterXSize
                or top + size > dataset.RasterYSize):
            return None

        output = [
            dataset.GetRasterBand(i).ReadAsArray(left, top, size, size)
            for i in (3, 2, 1)
        ]
        # NOTE(review): the uint8 cast wraps values above 255 -- confirm the
        # raster really stores 8-bit samples.
        img = np.moveaxis(np.array(output, dtype=np.uint8), 0, 2)
    except Exception:
        # Best effort: an unreadable sample is skipped by the caller. Unlike a
        # bare ``except:``, Ctrl-C and SystemExit still propagate.
        return None
    return img
def getShpDataForNum():
    """Write one JPEG patch per shapefile point into ``data/org/<label>/``.

    The label of a point is field 0 of its record. A folder is created for
    every distinct label first; then each point's patch is cut with
    ``__getACell`` and saved as ``<label>.<record index>.jpg``. Points whose
    patch comes back as ``None`` are reported and skipped.
    """
    labels = [record[0] for record in rer.records()]

    for distinct_label in set(labels):
        __createDir(os.path.join("data/org/" + str(distinct_label)))

    for idx in range(rer.numRecords):
        print("deal %d: " % (idx + 1))
        point = rer.shape(idx).points[0]
        patch = __getACell(dataset.GetGeoTransform(), point)
        if patch is None:
            print("the area of points %d is out range." % (idx))
            continue

        label = labels[idx]
        out_path = "data/org/%s/%s.%d.jpg" % (label, label, idx)
        cv2.imwrite(out_path, patch)
        print(out_path)

    print("deal finish,to numpy array.")
# Run the patch extraction immediately when this script is executed.
getShpDataForNum()
如下,將上述所得檔案拆分為測試集和訓練集。
import os
import shutil
import random
def createDir(path):
    """Create directory ``path`` (with any missing parents) if needed.

    An existing directory is left untouched. If the directory cannot be
    created, or ``path`` exists but is not a directory, the message below is
    printed and the process exits with status 1.
    """
    try:
        # exist_ok=True replaces the separate exists() test, which could race
        # with another process creating the same folder.
        os.makedirs(path, exist_ok=True)
    except OSError:
        print("建立資料夾失敗")
        # SystemExit is what exit() raises; raising it directly does not rely
        # on the `site` module having installed the interactive `exit` helper.
        raise SystemExit(1)
def _split_train_test(src_root="data/org/", train_root="data/train/",
                      test_root="data/test/", test_ratio=0.25):
    """Copy every class folder under ``src_root`` into a train and a test tree.

    For each sub-folder the file list is shuffled, ``int(count * test_ratio)``
    files are copied to ``test_root/<class>/`` and the rest to
    ``train_root/<class>/``. The source files are left in place.

    Args:
        src_root: Folder containing one sub-folder per class (trailing slash
            expected, as paths are built by concatenation).
        train_root: Destination root for the training copies.
        test_root: Destination root for the test copies.
        test_ratio: Fraction of each class that goes to the test tree.
    """
    createDir(train_root)
    createDir(test_root)
    for dir_item in os.listdir(src_root):
        createDir(train_root + dir_item)
        createDir(test_root + dir_item)

        org_data = os.listdir(src_root + dir_item + "/")
        random.shuffle(org_data)
        num = int(len(org_data) * test_ratio)
        # Split by a positive index. Slicing with -num breaks when num == 0:
        # [:-0] is empty and [-0:] is everything, which used to send every
        # file of a small class to the test set.
        split = len(org_data) - num

        print(src_root + dir_item + " start.")
        for d in org_data[:split]:
            shutil.copyfile(src_root + dir_item + "/" + d, train_root + dir_item + "/" + d)
        for d in org_data[split:]:
            shutil.copyfile(src_root + dir_item + "/" + d, test_root + dir_item + "/" + d)
        print(src_root + dir_item + " finished")


# Run the split with the folders used throughout this script.
_split_train_test()
以下顯示指定資料夾下各子資料夾中檔案數目的直方圖。
import os
import seaborn as sns
import matplotlib.pyplot as plt
def show(path,title):
    """Show a bar chart of how many entries each sub-folder of ``path`` holds.

    Args:
        path: Folder whose immediate sub-folders are counted.
        title: Title drawn above the chart.
    """
    # Only sub-folders are counted; a stray file (e.g. a thumbnail cache)
    # would otherwise make the inner os.listdir raise. Sorting keeps the bar
    # order stable between runs.
    names = sorted(
        entry for entry in os.listdir(path)
        if os.path.isdir(os.path.join(path, entry))
    )
    counts = [len(os.listdir(os.path.join(path, entry))) for entry in names]

    plt.rcParams['font.sans-serif'] = ['SimHei']  # font used so the Chinese labels render
    plt.rcParams['axes.unicode_minus'] = False  # keep the minus sign rendering correctly
    # x/y must be keywords: seaborn >= 0.12 rejects them as positional args.
    sns.barplot(x=names, y=counts)
    plt.xlabel("樣本型別")
    plt.ylabel("數量")
    plt.title(title)
    # Print each count just above its bar.
    for i, count in enumerate(counts):
        plt.text(i, count + 2, "%d" % count, ha="center", va="bottom")
    plt.show()
# Plot the per-sub-folder file counts found under data/1_train.
show(r"data/1_train","訓練集源資料取樣集")
由於其他原因,資料更改。如下為使用shp樣本點對應的畫素座標所採圖集。此時分為train pos.txt和test pos.txt諸如此類。
from osgeo import gdal
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import os
import cv2,shutil
class Tiff:
    """Cut fixed-size image patches out of a raster at listed pixel positions.

    On construction a position file and a feature file are joined column-wise
    into ``contact_src`` (unless that file already exists) and the joined table
    is loaded into ``self.fea``. ``get_cells`` then writes one JPEG per row,
    grouped into folders named after the second-to-last column of the table.
    """

    def createDir(self, path):
        """Create directory ``path`` (with missing parents) if needed.

        Prints a message and exits with status 1 when the directory cannot be
        created or ``path`` exists but is not a directory.
        """
        try:
            # exist_ok=True avoids the race of a separate exists() check.
            os.makedirs(path, exist_ok=True)
        except OSError:
            print("建立資料夾失敗")
            raise SystemExit(1)

    def __init__(self, pos_src, other_feather, contact_src, size=128, bands=(3, 2, 1),
                 tif_src=r"D:/lishihang/jiangxia_simple/ZY3_GS_jiangxia1.tif"):
        """Open the raster and load (building it first if needed) the sample table.

        Args:
            pos_src: Space-separated text file of positions, no header.
            other_feather: Tab-separated text file of features, no header.
            contact_src: Path of the joined CSV; created when missing.
            size: Side length in pixels of the sampling window.
            bands: 1-based raster band numbers to read, in output channel order.
            tif_src: Path of the raster to sample.
        """
        self.dataset = gdal.Open(tif_src)  # raster the patches are read from
        self.size = size  # sampling window side length
        # Copy into a fresh list so no default object is shared between
        # instances (the default used to be a mutable list literal).
        self.bands = list(bands)
        self.contact_pos_feather(pos_src, other_feather, contact_src)
        self.fea = pd.read_csv(contact_src, header=None)

    def get_cell(self, pos_x, pos_y):
        """Read the ``size`` x ``size`` window centred on pixel ``(pos_x, pos_y)``.

        Returns:
            Array of shape ``(size, size, len(bands))`` with dtype uint8, or
            ``None`` when the window does not fit inside the raster or the
            read fails.
        """
        try:
            left = int(pos_x - self.size / 2)
            top = int(pos_y - self.size / 2)
            # Reject windows that leave the raster up front instead of relying
            # on the read itself to fail.
            if (left < 0 or top < 0
                    or left + self.size > self.dataset.RasterXSize
                    or top + self.size > self.dataset.RasterYSize):
                return None
            output = [
                self.dataset.GetRasterBand(i).ReadAsArray(left, top, self.size, self.size)
                for i in self.bands
            ]
            # NOTE(review): the uint8 cast wraps values above 255 -- confirm
            # the raster really stores 8-bit samples.
            img2 = np.moveaxis(np.array(output, dtype=np.uint8), 0, 2)
        except Exception:
            # Best effort: the caller skips samples that cannot be read.
            # Ctrl-C and SystemExit are no longer swallowed.
            return None
        return img2

    def get_cells(self, target_src):
        """Write one JPEG per table row into ``target_src/<label>/``.

        Column 1 of a row is passed as ``pos_x`` and column 0 as ``pos_y``;
        the second-to-last column is the label. Files are named
        ``<label>.<row index>.jpg``. Rows whose window cannot be read are
        skipped.
        """
        fea_len = len(self.fea)
        self.createDir(target_src)
        for label in set(self.fea.iloc[:, -2]):
            self.createDir("%s/%s" % (target_src, label))
        print("fea length: %d" % fea_len)

        # itertuples keeps each column's own dtype. Reading a row with
        # iloc[i, :] upcasts an all-numeric table to float, which turned an
        # integer label 3 into "3.0" and no longer matched the folder created
        # above.
        for i, row in enumerate(self.fea.itertuples(index=False, name=None)):
            img = self.get_cell(row[1], row[0])
            if img is None:
                continue
            cv2.imwrite("%s/%s/%s.%d.jpg" % (target_src, row[-2], row[-2], i), img)
            if i % 1000 == 0:
                print("%d/%d hava finsh save." % (i, fea_len))

    def contact_pos_feather(self, pos_src, other_feather, target):
        """Join the position and feature files column-wise and save as CSV.

        Does nothing except print a notice when ``target`` already exists.
        The output has no header and no index column; position columns come
        first, feature columns after.
        """
        if os.path.exists(target):
            print("檔案已存在")
            return
        pos = pd.read_csv(pos_src, header=None, sep=' ')
        feather = pd.read_csv(other_feather, header=None, sep='\t')
        fea = pd.concat([pos, feather], axis=1)
        print("pos Length=%d,feather Length=%d,fea Length=%d" % (len(pos), len(feather), len(fea)))
        fea.to_csv(target, index=False, header=False)
# Script entry point. Constructing Tiff joins the position and feature files
# into tr_1.txt (unless that file already exists) and loads the result; no
# patches are written here because get_cells is not called.
if __name__ == '__main__':
    tiff=Tiff(r"D:/tr_sample_1.txt",r"D:/train1.txt",r"tr_1.txt")
    # Disabled alternative for the test split:
    # tiff=Tiff(r"D:/te_sample_1.txt",r"D:/test1.txt",r"te_1.txt")
    # tiff.get_cells("data/1_test")