1. 程式人生 > >機器學習工具代碼

機器學習工具代碼

input per where 位置 n) enter pri http dense

(持續整理)

數組閾值處理

"""
img 為圖像數組,同時也是numpy數組  
將img數據小於min的都設為min,同時將大於max的都設為max  
"""
img[np.where(img < min)] = min  
img[np.where(img > 250)] = max  

歸一化和中心化

min = np.min(img)
max = np.max(img)
center = (min + max) / 2
img = (img - center) /(max - min) * 2

最大聯通域

from skimage import measure


def max_connected_domain_3D(arr):
    # 取相同數字的最大連通域
    labels = measure.label(arr)  # <1.2s
    t = np.bincount(labels.flatten())[1:]  # <1.5s
    max_pixel = np.argmax(t) + 1  # 位置變了,去除了0
    labels[labels != max_pixel] = 0
    labels[labels == max_pixel] = 1
    return labels.astype(np.uint8)

# 測試  
arr = [[1, 1, 0, 3], [1, 0, 3, 3], [0, 1, 3, 3], [0, 0, 0, 0]]
arr = np.asarray(arr)
print(arr)
print(max_connected_domain_3D(arr))

\[ 1 1 0 3\1 0 3 3\0 1 3 3\0 0 0 0\\]
\[ \Downarrow \]
\[ 0 0 0 1\0 0 1 1\0 0 1 1\0 0 0 0 \]

arr = np.squeeze(arr) # 從數組的形狀中刪除單維度條目,即把shape中為1的維度去掉
y = np.transpose(y,(1,2,0))  # 將數組的軸交換 (0, 1, 2) => (1, 2, 0)
"""
出處為寫nrrd文件的時候,可以考慮nrrd的數組存儲形式與正常數組維度不一致
"""

繪制模型

from keras.utils import plot_model

plot_model(model, "RUnet.png", True)

demo

from keras import models
from keras import layers
from keras import regularizers
from keras.utils import plot_model


def get_model(x, y, z):
    model = models.Sequential()
    model.add(layers.Conv3D(16, (3, 3, 2), activation='relu', input_shape=(x, y, z, 1)))
    model.add(layers.BatchNormalization())
    model.add(layers.Conv3D(8, (3, 3, 2), activation='relu', kernel_regularizer=regularizers.l2(0.1)))
    model.add(layers.BatchNormalization())
    model.add(layers.Conv3D(8, (3, 3, 2), activation='relu', kernel_regularizer=regularizers.l2(0.1)))
    model.add(layers.BatchNormalization())
    model.add(layers.Conv3D(8, (3, 3, 1), activation='relu', kernel_regularizer=regularizers.l2(0.1)))
    model.add(layers.Dropout(rate=0.1))
    model.add(layers.BatchNormalization())
    model.add(layers.Flatten())
    model.add(layers.BatchNormalization())
    model.add(layers.Dense(13, activation='relu'))
    model.add(layers.BatchNormalization())
    model.add(layers.Dense(8, activation='relu'))
    model.add(layers.BatchNormalization())
    model.add(layers.Dense(8, activation='relu'))
    model.add(layers.Dense(2, activation='sigmoid'))
    model.summary()
    return model

if __name__ == '__main__':
    model = get_model(125, 125, 10)
    plot_model(model, r"C:\Users\fan\Desktop\model.png", True)
    

效果圖
技術分享圖片

註:需要安裝graphviz

數據混淆

def data_confusion(data, label):
    # 進行數據混淆
    permutation = np.random.permutation(label.shape[0])
    shuffled_data = data[permutation, :, :]
    shuffled_label = label[permutation]
    return shuffled_data, shuffled_label

機器學習工具代碼