迭代閾值影象分割.py
阿新 • • 發佈:2022-12-08
""" Created on 2020/12/29 16:00. @Author: yubaby@anne @Email: [email protected] """ import time import gdal from gdalconst import * import numpy as np def tif_read(tifpath): image = gdal.Open(tifpath, GA_ReadOnly) im_clos = image.RasterXSize im_rows = image.RasterYSize band_array = image.ReadAsArray(0, 0, im_clos, im_rows).astype(np.float) del image return band_array def tif_write(im_data, path): if 'int8' in im_data.dtype.name: datatype = gdal.GDT_Byte elif 'int16' in im_data.dtype.name: datatype = gdal.GDT_UInt16 else: datatype = gdal.GDT_Float32 if len(im_data.shape) == 3: im_bands, im_height, im_width = im_data.shape elif len(im_data.shape) == 2: im_data = np.array([im_data]) im_bands, im_height, im_width = im_data.shape driver = gdal.GetDriverByName("GTiff") dataset = driver.Create(path, int(im_width), int(im_height), int(im_bands), datatype) for i in range(im_bands): dataset.GetRasterBand(i + 1).WriteArray(im_data[i]) del dataset def iteration_threshold(F): """ (1)設定初始閾值t1。 當目標與背景的面積相當時,可以將初始閾值t1置為整幅影象的平均灰度; 當目標與背景的面積相差較大時,更好的選擇是將初始閾值t1置為最大灰度值與最小灰度值的中間值。 """ t1 = np.mean(F) ''' (2)根據t1將影象F分割為F1和F2兩部分, 其中F1包含所有灰度值小於t1的畫素,F2包含所有大於t1的畫素,分別求出F1和F2的平均灰度值μ1和μ2。 ''' # F = F.flatten() F1 = F[F <= t1] F2 = F[F > t1] u1 = np.mean(F1) u2 = np.mean(F2) ''' (3)計算新的閾值t2=(μ1+μ2)/2。 ''' t2 = (u1 + u2) / 2 ''' (4)指定常數t0(很小的正數), 如果|t2-t1|<=t0,即迭代過程中前後兩次閾值很接近時(或者說μ1和μ2不再變化),終止迭代; 否則令t1=t2,重複步驟(2)、(3)、(4)。 設定常數t0的目的是為了加快迭代速度,如果不關心迭代速度,則可以設定為0。 ''' t0 = 0 count = 0 while abs(t2 - t1) != t0: # PS:t0=0時,|t2-t1|不可能<0,所以,條件變為|t2-t1|=0,即t1=t2時停止迭代 t1 = t2 F1 = F[F <= t1] F2 = F[F > t1] u1 = np.mean(F1) u2 = np.mean(F2) t2 = (u1 + u2) / 2 count += 1 print('迭代次數:', count) return t2 def segmentation(F, T): F[F >= T] = 1 F[F < T] = 0 return F if __name__ == '__main__': start = time.time() path_input = 'J:\BaiduNetdiskDownload\GF3\GF3_KAS_FSI_013781_E121.4_N37.6_20190323_L1A_HHHV_L10003900117\zone\\tif\\' path_output = path_input + 'output\\' list_tif = ['lee_2019_C2_HH', # 可用資料:基本能目視看出目標地物 'lee_2019_C2_HV', 'lee_2019_DB_HH', 'lee_2019_DB_HV', 'mean_2019_C2_HH', 'mean_2019_C2_HV'] for tif in list_tif: img = tif_read(path_input + tif + '.tif') thr = iteration_threshold(img) result = segmentation(img, thr) tif_write(result, path_output + tif + '-' + str(thr) + '.tif') print('It takes', time.time() - start, "seconds.")