1. 程式人生 > 其它 >迭代閾值影象分割.py

迭代閾值影象分割.py

"""
Created on 2020/12/29 16:00.

@Author: yubaby@anne
@Email: [email protected]
"""


import time
import gdal
from gdalconst import *
import numpy as np


def tif_read(tifpath):
    image = gdal.Open(tifpath, GA_ReadOnly)
    im_clos = image.RasterXSize
    im_rows = image.RasterYSize
    band_array = image.ReadAsArray(0, 0, im_clos, im_rows).astype(np.float)
    del image
    return band_array


def tif_write(im_data, path):
    if 'int8' in im_data.dtype.name:
        datatype = gdal.GDT_Byte
    elif 'int16' in im_data.dtype.name:
        datatype = gdal.GDT_UInt16
    else:
        datatype = gdal.GDT_Float32
    if len(im_data.shape) == 3:
        im_bands, im_height, im_width = im_data.shape
    elif len(im_data.shape) == 2:
        im_data = np.array([im_data])
        im_bands, im_height, im_width = im_data.shape
    driver = gdal.GetDriverByName("GTiff")
    dataset = driver.Create(path, int(im_width), int(im_height), int(im_bands), datatype)
    for i in range(im_bands):
        dataset.GetRasterBand(i + 1).WriteArray(im_data[i])
    del dataset


def iteration_threshold(F):
    """
    (1)設定初始閾值t1。
    當目標與背景的面積相當時,可以將初始閾值t1置為整幅影象的平均灰度;
    當目標與背景的面積相差較大時,更好的選擇是將初始閾值t1置為最大灰度值與最小灰度值的中間值。
    """
    t1 = np.mean(F)

    '''
    (2)根據t1將影象F分割為F1和F2兩部分,
    其中F1包含所有灰度值小於t1的畫素,F2包含所有大於t1的畫素,分別求出F1和F2的平均灰度值μ1和μ2。
    '''
    # F = F.flatten()
    F1 = F[F <= t1]
    F2 = F[F > t1]
    u1 = np.mean(F1)
    u2 = np.mean(F2)

    '''
    (3)計算新的閾值t2=(μ1+μ2)/2。
    '''
    t2 = (u1 + u2) / 2

    '''
    (4)指定常數t0(很小的正數),
    如果|t2-t1|<=t0,即迭代過程中前後兩次閾值很接近時(或者說μ1和μ2不再變化),終止迭代;
    否則令t1=t2,重複步驟(2)、(3)、(4)。
    設定常數t0的目的是為了加快迭代速度,如果不關心迭代速度,則可以設定為0。
    '''
    t0 = 0
    count = 0
    while abs(t2 - t1) != t0:  # PS:t0=0時,|t2-t1|不可能<0,所以,條件變為|t2-t1|=0,即t1=t2時停止迭代
        t1 = t2
        F1 = F[F <= t1]
        F2 = F[F > t1]
        u1 = np.mean(F1)
        u2 = np.mean(F2)
        t2 = (u1 + u2) / 2
        count += 1
        print('迭代次數:', count)

    return t2


def segmentation(F, T):
    F[F >= T] = 1
    F[F < T] = 0
    return F


if __name__ == '__main__':
    start = time.time()

    path_input = 'J:\BaiduNetdiskDownload\GF3\GF3_KAS_FSI_013781_E121.4_N37.6_20190323_L1A_HHHV_L10003900117\zone\\tif\\'
    path_output = path_input + 'output\\'

    list_tif = ['lee_2019_C2_HH',  # 可用資料:基本能目視看出目標地物
                'lee_2019_C2_HV',
                'lee_2019_DB_HH',
                'lee_2019_DB_HV',
                'mean_2019_C2_HH',
                'mean_2019_C2_HV']

    for tif in list_tif:
        img = tif_read(path_input + tif + '.tif')
        thr = iteration_threshold(img)
        result = segmentation(img, thr)
        tif_write(result, path_output + tif + '-' + str(thr) + '.tif')

    print('It takes', time.time() - start, "seconds.")