
Calling a TensorFlow/Keras-trained audio classification model from Golang

1 Scenario analysis

In outbound calling, the callee is often busy or the number is out of service, so the returned ring-back tone gets picked up by the speech recognition engine and passed on to the business flow. This drags down the connection rate reported in post-call statistics. To solve the problem, we want to classify the ring-back audio before it reaches the speech recognition engine and simply drop anything identified as a non-connected call instead of feeding it into the flow.
After analyzing the recordings and the recognized text for this scenario, we found that most of the misrecognized ring-back tones are "subscriber busy" or "number does not exist" announcements, whose audio characteristics differ clearly from those of a normally connected call. Below are the transcription results produced by the iFLYTEK (科大訊飛) speech recognition engine for failed ring-back tones.

The transcription statistics confirm this analysis. (Ring-back tones such as video ring-back have not been counted yet and are not treated as a major failure category here.)

2 Model training: sound classification based on deep learning

The implementation follows the article keras實現聲音二分類, which explains the MFCC audio feature and walks through the whole pipeline in detail; it is worth reading. Here we only paste the code used in our experiments, starting with the model training code.

import os
import keras
import librosa
import numpy as np
import matplotlib.pyplot as plt
from keras import Sequential
from keras.utils import to_categorical
from keras.layers import Dense
from sklearn.model_selection import train_test_split
import tensorflow as tf
from keras import backend as K

DATA = 'data.npy'
TARGET = 'target.npy'


def load_label(label_path):
    """
    Walk the given directory; the entries found are used as class labels later on.
    :param label_path:
    :return:
    """
    label = os.listdir(label_path)
    return label


# Extract MFCC features
def wav2mfcc(path, max_pad_size=11):
    """
    Note: the audio files we have differ in duration, so the extracted MFCC matrices differ in size.
    The neural network, however, requires inputs of the same shape, so we pad/truncate here. By default
    20 MFCC coefficients are extracted per frame; if there are fewer than 11 frames we pad the missing
    frames with zeros, and if there are more than 11 frames we keep only the first 11.
    :param path:    path of the audio file
    :param max_pad_size:    number of frames to keep, at most 11
    :return:
    """
    # Read the audio file at its native sample rate
    y, sr = librosa.load(path=path, sr=None, mono=True)
    y = y[::3]  # we do not need a high sample rate, so keep only every third sample
    audio_mac = librosa.feature.mfcc(y=y, sr=16000)
    y_shape = audio_mac.shape[1]
    if y_shape < max_pad_size:
        """
        函式numpy.pad(array, pad_width, mode),其中 array 是我們需要填充的矩陣,pad_width是各個維度上首尾填充的個數。
        舉個例子,假定我們設定的 pad_width 是((0,0), (0,2)),而待處理的 mfcc 係數是 20 * 11 的矩陣。
        我們把 mfcc 係數看成 20 行 11 列的矩陣,進行 pad 操作,第一個(0,0)對行進行操作,
        表示每一行最前面和最後面增加的數個數為零,也就相當於總共增加了 0 列。第二個(0,2)對列操作,
        表示每一列最前面增加的數為 0 個,但最後面要增加兩個數,也就相當於總共增加了 2 行。
        mode 設定為 ‘constant’,表明填充的是常數,且預設為 0 
        """
        pad_size = max_pad_size - y_shape
        audio_mac = np.pad(audio_mac, ((0, 0), (0, pad_size)), mode='constant')
    else:
        audio_mac = audio_mac[:, :max_pad_size]
    return audio_mac


def save_data_to_array(label_path, max_pad_size=11):
    """
    Save the processed data so that it can be reused next time.
    :param label_path:
    :param max_pad_size:
    :return:
    """
    mfcc_vectors = []
    target = []
    labels = load_label(label_path=label_path)
    for i, label in enumerate(labels):
        path = label_path + '/' + label
        wavfiles = [path + '/' + file for file in os.listdir(path)]
        for wavfile in wavfiles:
            wav = wav2mfcc(wavfile, max_pad_size=max_pad_size)
            mfcc_vectors.append(wav)
            target.append(i)
    np.save(DATA, mfcc_vectors)
    np.save(TARGET, target)
    # return mfcc_vectors, target


def get_train_test(split_ratio=.6, random_state=42):
    """
    Use sklearn's train_test_split to split the dataset into a training set (60%) and a test set (40%).
    :param split_ratio:
    :param random_state:
    :return:
    """
    X = np.load(DATA)
    y = np.load(TARGET)
    assert X.shape[0] == y.shape[0]
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=(1 - split_ratio), random_state=random_state,
                                                        shuffle=True)
    return X_train, X_test, y_train, y_test


def main():
    x_train, x_test, y_train, y_test = get_train_test()
    # Reshape into a 2-D matrix whose second dimension is 220 (20 MFCC coefficients x 11 frames)
    x_train = x_train.reshape(-1, 220)
    x_test = x_test.reshape(-1, 220)
    # One-hot encode the labels with Keras
    y_train_hot = to_categorical(y_train)
    y_test_hot = to_categorical(y_test)
    model = Sequential()
    model.add(Dense(64, activation='relu', input_shape=(220,), name="input_layer"))
    model.add(Dense(64, activation='relu', name="dropout_layer1"))
    model.add(Dense(64, activation='relu', name="dropout_layer2"))
    model.add(Dense(2, activation='softmax', name="output_layer"))

    # Model training
    sess = tf.Session()
    K.set_session(sess)
    # Print the full names of the input_layer and output_layer nodes; they are needed in the Go code to define the input and output nodes
    for n in sess.graph.as_graph_def().node:
        if 'input_layer' in n.name:
            print(n.name)
        if 'output_layer' in n.name:
            print(n.name)

    model.compile(loss=keras.losses.categorical_crossentropy,
                  optimizer=keras.optimizers.RMSprop(),
                  metrics=['accuracy'])
    history = model.fit(x_train, y_train_hot, batch_size=100, epochs=100, verbose=1,
                        validation_data=(x_test, y_test_hot))

    # The key part is below
    # Use TF to save the graph model instead of Keras save model to load it in Golang
    builder = tf.saved_model.builder.SavedModelBuilder("cnnModel")
    # Tag the model, required for Go
    builder.add_meta_graph_and_variables(sess, ["myTag"])
    builder.save()

    model.save("classaud.h5")
    plot_history(history)

    sess.close()


def save():
    label_path = 'F:\\doc\\專案\\音訊分類\\audio-heji\\'
    save_data_to_array(label_path, max_pad_size=11)


def plot_history(history):
    plt.plot(history.history['acc'], label='train')
    plt.plot(history.history['val_acc'], label='validation')
    plt.legend()
    plt.show()


if __name__ == "__main__":
    save_data_to_array("F:\\doc\\專案\\音訊分類\\audio-heji\\", max_pad_size=11)
    # save_data_to_array("/home/audio/audio-heji/", max_pad_size=11)
    main()
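
To make the fixed 20 x 11 MFCC shape and the later reshape(-1, 220) concrete, here is a minimal, self-contained sketch of the padding and flattening logic on a dummy matrix (numpy only, no audio file needed; the shapes mirror the wav2mfcc function above):

import numpy as np

# pretend MFCC output: 20 coefficients x 7 frames (fewer than max_pad_size = 11)
mfcc = np.random.randn(20, 7)
max_pad_size = 11
pad_size = max_pad_size - mfcc.shape[1]
# ((0, 0), (0, pad_size)): add no rows, append pad_size zero-filled columns
padded = np.pad(mfcc, ((0, 0), (0, pad_size)), mode='constant')
print(padded.shape)              # (20, 11)
print(padded.reshape(-1).shape)  # (220,) -- the 220-dim vector fed to the Dense network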

Audio file renaming code

import os


def rename(pic_path):
    """
    Rename the files in the two directories: files under the first directory, fail-audio, get names like 10001, and files under the second, success-audio, get names like 90001 (the base offset in the code below is adjusted accordingly for each directory).
    :param pic_path:
    :return:
    """
    piclist = os.listdir(pic_path)
    i = 1
    print("ok")
    for pic in piclist:
        if pic.endswith(".wav"):
            old_path = os.path.join(os.path.abspath(pic_path), pic)
            new_path = os.path.join(os.path.abspath(pic_path), str(
                90000 + (int(i))) + '.wav')
            os.renames(old_path, new_path)
            print("把原命名格式:" + old_path + u"轉換為新命名格式:" + new_path)
            i = i + 1


# Load labels
def load_label(label_path):
    label = os.listdir(label_path)
    return label

if __name__ == '__main__':
    rename("F:\\doc\\專案\\音訊分類\\audio-heji\\success-audio")

Test code

import librosa
import numpy as np
import os
from keras.models import load_model


# Extract MFCC features
def wav2mfcc(path, max_pad_size=11):
    y, sr = librosa.load(path=path, sr=None, mono=True)
    y = y[::3]  # keep only every third sample
    audio_mac = librosa.feature.mfcc(y=y, sr=16000)
    y_shape = audio_mac.shape[1]
    if y_shape < max_pad_size:
        pad_size = max_pad_size - y_shape
        audio_mac = np.pad(audio_mac, ((0, 0), (0, pad_size)), mode='constant')
    else:
        audio_mac = audio_mac[:, :max_pad_size]
    return audio_mac


def load_label(label_path):
    """
    Walk the given directory; the entries found are used as class labels later on.
    :param label_path:
    :return:
    """
    label = os.listdir(label_path)
    return label


if __name__ == '__main__':
    # Load the model
    model = load_model('classaud.h5')  # load the trained model
    wavs = [wav2mfcc("F:\\doc\\專案\\音訊分類\\test\\" + file, 11) for file in os.listdir("F:\\doc\\專案\\音訊分類\\test")]
    X = np.array(wavs)
    X = X.reshape(-1, 220)
    print(X.shape)

    for j in range(X.shape[0]):
        print(j)
        print(X[j:j+1])
        result = model.predict(X[j:j+1])[0]
        print("Prediction:", result)
        # During training the label set was: 0 -> fail-audio, 1 -> success-audio
        name = ["fail-audio", "success-audio"]  # the same label list as used during training
        ind = 0  # index of the largest value in the result
        for i in range(len(result)):
            if result[i] > result[ind]:
                ind = i
        print("Predicted class:", name[ind])

To be able to load the trained model from Golang later, some extra code was added to the training script; the main additions are shown below.

# Model training
    sess = tf.Session()
    K.set_session(sess)
    # Print the full names of the input_layer and output_layer nodes; they are needed in the Go code to define the input and output nodes
    for n in sess.graph.as_graph_def().node:
        if 'input_layer' in n.name:
            print(n.name)
        if 'output_layer' in n.name:
            print(n.name)

    model.compile(loss=keras.losses.categorical_crossentropy,
                  optimizer=keras.optimizers.RMSprop(),
                  metrics=['accuracy'])
    history = model.fit(x_train, y_train_hot, batch_size=100, epochs=100, verbose=1,
                        validation_data=(x_test, y_test_hot))

    # The key part is below
    # Use TF to save the graph model instead of Keras save model to load it in Golang
    builder = tf.saved_model.builder.SavedModelBuilder("cnnModel")
    # Tag the model, required for Go
    builder.add_meta_graph_and_variables(sess, ["myTag"])
    builder.save()

After training finishes, the models are generated: the cnnModel folder contains the saved_model.pb file and a variables directory and is the model used from Golang later, while the .h5 model is the one used for the local test above. The result is shown in the figure below.

In addition, during training we printed the name of every node of these layers. These names deserve attention, because a wrong node name caused problems later in the Golang environment. The printed output is as follows:

WARNING:tensorflow:From E:\pycharm\nlu-algorithm-package\venv\lib\site-packages\tensorflow\python\framework\op_def_library.py:263: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.
Instructions for updating:
Colocations handled automatically by placer.
2020-08-27 19:06:19.437965: I tensorflow/core/platform/cpu_feature_guard.cc:141] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2
input_layer_input
input_layer/random_uniform/shape
input_layer/random_uniform/min
input_layer/random_uniform/max
input_layer/random_uniform/RandomUniform
input_layer/random_uniform/sub
input_layer/random_uniform/mul
input_layer/random_uniform
input_layer/kernel
input_layer/kernel/Assign
input_layer/kernel/read
input_layer/Const
input_layer/bias
input_layer/bias/Assign
input_layer/bias/read
input_layer/MatMul
input_layer/BiasAdd
input_layer/Relu
output_layer/random_uniform/shape
output_layer/random_uniform/min
output_layer/random_uniform/max
output_layer/random_uniform/RandomUniform
output_layer/random_uniform/sub
output_layer/random_uniform/mul
output_layer/random_uniform
output_layer/kernel
output_layer/kernel/Assign
output_layer/kernel/read
output_layer/Const
output_layer/bias
output_layer/bias/Assign
output_layer/bias/read
output_layer/MatMul
output_layer/BiasAdd
output_layer/Softmax
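
For reference, these node names can also be recovered from the exported SavedModel itself, without rerunning training. A minimal Python sketch, assuming the same TensorFlow 1.x environment and the "myTag" tag used above:

import tensorflow as tf

with tf.Session(graph=tf.Graph()) as sess:
    # load the SavedModel exported by the training script
    tf.saved_model.loader.load(sess, ["myTag"], "cnnModel")
    # print the operations belonging to the input and output layers
    for op in sess.graph.get_operations():
        if "input_layer" in op.name or "output_layer" in op.name:
            print(op.name)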

One record from the test output is shown below:

25
[[-1.66098328e+02 -2.12202209e+02 -4.31954193e+02 -4.47424835e+02
  -3.33869904e+02 -3.58775604e+02 -5.43935608e+02 -6.60088867e+02
  -2.30052383e+02 -6.87607422e+01 -1.21050777e+01  4.23143768e+01
   4.26393929e+01  4.03902130e+01  3.61016235e+01  3.74344673e+01
   3.76606216e+01  2.93862953e+01  1.91747296e+00 -9.67822647e+00
  -1.10491867e+01 -1.35342636e+01 -6.71506195e+01 -6.73550262e+01
  -6.63572083e+01 -4.63606071e+01 -4.91609383e+01 -5.08716202e+01
  -4.48681793e+01 -1.87636101e+00 -4.61852417e+01 -4.17765961e+01
  -2.01916142e+01 -9.79832268e+00 -9.91659927e+00 -1.03536911e+01
  -3.72100983e+01 -3.87253189e+01 -3.71894073e+01 -2.30876770e+01
   7.12325764e+00 -1.86448917e+01 -1.23296547e+01 -1.36949463e+01
   2.36775169e+01  2.40340843e+01  2.65298653e+01 -1.50987358e+01
  -1.75523911e+01 -1.85126324e+01 -1.25076523e+01  7.13483989e-01
   1.25870705e+01  9.64732361e+00  1.51968069e+01  1.57215271e+01
   1.60245552e+01  1.79179173e+01  8.83283997e+00  6.01854324e+00
   5.03554726e+00  7.35849476e+00  7.32028627e+00  2.61785851e+01
   2.36798325e+01  2.17165947e+01 -8.92044258e+00 -8.98554516e+00
  -9.36619282e+00  8.72706318e+00  7.63292933e+00  8.11320019e+00
   1.23650103e+01  4.05273724e+00 -2.74963975e+00  6.03793526e+00
   3.32019657e-01 -1.07288876e+01 -1.08059797e+01 -1.12256699e+01
   3.14207101e+00  2.12598038e+00  3.41194320e+00  1.52846613e+01
   3.58989859e+00 -7.22187281e+00 -2.92357159e+00 -1.40336580e+01
   1.25538235e+01  1.30331860e+01  1.53311596e+01  4.33768129e+00
   3.44467473e+00  3.40012264e+00  8.20392323e+00  4.26902437e+00
   2.08825417e+01  1.70654850e+01  6.18888092e+00  1.60812531e+01
   1.66913948e+01  1.89202042e+01 -3.88430738e+00 -3.84901094e+00
  -2.63745117e+00 -1.11108208e+00  6.36666417e-01 -2.25304246e+00
  -8.85197830e+00 -1.91202374e+01 -1.46577859e+00 -1.23925185e+00
   7.83551395e-01  2.47491241e+00  1.99022865e+00  1.92004573e+00
  -4.81319517e-01  5.03325987e+00 -3.67527580e+00  1.00166121e+01
   1.32022190e+01  1.41424477e+00  1.61442292e+00  1.70821917e+00
   1.69235497e+01  1.71120987e+01  1.50358353e+01  1.89076972e+00
   6.39788628e+00 -1.38607597e+00  1.02029294e-01 -1.02280178e+01
   4.86864090e+00  5.27380610e+00  7.52170134e+00 -5.28688669e+00
  -4.54772520e+00 -1.99729645e+00  9.84556198e+00  6.01770210e+00
   6.42514515e+00  5.05019569e+00 -3.58427215e+00  2.53060675e+00
   2.77301073e+00  3.36337566e+00 -3.74559736e+00 -3.88710737e+00
  -2.85192370e+00  8.00442696e+00  5.90060949e-01  4.97644138e+00
   8.34950066e+00  6.98132086e+00 -2.52411366e+00 -2.62051678e+00
  -2.87881041e+00  7.85154676e+00  7.58860874e+00  7.21512508e+00
   6.53992605e+00 -2.57507980e-01 -4.69269657e+00 -3.40787172e+00
   1.32537198e+00 -4.12547922e+00 -4.38115311e+00 -5.22326088e+00
   2.03997135e+00  2.89152718e+00  4.46722126e+00  6.32854557e+00
   5.27089882e+00  8.66948891e+00  5.42871141e+00  1.18013754e+01
  -1.06842923e+00 -1.21782100e+00 -1.62583649e+00  4.68027020e+00
   4.82862568e+00  5.17801666e+00  5.02924442e+00  3.57280898e+00
  -1.60658951e+01 -3.89933228e+00  5.28476810e+00  7.85575271e-01
   7.23506689e-01  9.12119806e-01  1.29786149e-01 -1.08789623e+00
  -2.01344442e+00 -3.12455368e+00  4.22501802e+00 -2.05132604e-01
   4.64199352e+00  1.28417645e+01 -2.14332151e+00 -2.28186941e+00
  -2.92554092e+00 -7.33241290e-02 -4.54517424e-01 -9.12135363e-01
  -1.92673039e+00  1.32999837e+00 -5.95955181e+00 -1.38899193e+01
  -1.53170991e+00 -3.57915735e+00 -3.69184279e+00 -3.76903296e+00
  -3.22209269e-01 -6.59340382e-01  5.86225927e-01  1.03645115e+01
   2.81656504e+00 -1.55326450e+00 -1.87907255e+00 -2.12706447e+00]]
Prediction: [1.4359492e-26 1.0000000e+00]
Predicted class: success-audio

In the Golang part below, this record is used as the test input when calling the TensorFlow model.
Note: the versions used here are TensorFlow 1.13.1, Keras 2.2.4 and Python 3.6.0; audio feature extraction uses the librosa library, version 0.8.0.
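
A quick way to confirm that the local environment matches these versions (a minimal sketch):

import sys
import tensorflow as tf
import keras
import librosa

print("python:", sys.version.split()[0])
print("tensorflow:", tf.__version__)
print("keras:", keras.__version__)
print("librosa:", librosa.__version__)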

We also tried other approaches to audio classification. In particular, the K-nearest-neighbours method described in Python Project – Music Genre Classification reached roughly 70% accuracy in our experiments and is also worth a look; a sketch of this route follows.
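
For comparison, a minimal sketch of the KNN route on the same saved features. This assumes the data.npy / target.npy files produced by save_data_to_array above and uses scikit-learn's KNeighborsClassifier, not the exact code from the referenced article:

import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.neighbors import KNeighborsClassifier

X = np.load('data.npy').reshape(-1, 220)   # flatten the 20 x 11 MFCC matrices
y = np.load('target.npy')
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.4, random_state=42)

knn = KNeighborsClassifier(n_neighbors=5)  # n_neighbors=5 is an arbitrary choice here
knn.fit(X_train, y_train)
print("accuracy:", knn.score(X_test, y_test))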

3 Calling the TensorFlow/Keras-trained model from Golang

With the Go version of TensorFlow already installed as described in 安裝 go 版 TensorFlow的問題記錄, all that is left is to write the Go code and test it. The verification code follows the article golang呼叫tensorflow/keras訓練的模型 and is given below.

package main
 
import (
        "fmt"
        tf "github.com/tensorflow/tensorflow/tensorflow/go"
)
 
func main() {
        // Feature length
        const MAXLEN int = 220
        // For convenience the input is hard-coded here: it is the 220-dimensional MFCC feature vector printed by the Python test above
        input_data := [1][MAXLEN]float32{{-1.66098328e+02,-2.12202209e+02,-4.31954193e+02,-4.47424835e+02,-3.33869904e+02,-3.58775604e+02,-5.43935608e+02,-6.60088867e+02,-2.30052383e+02,-6.87607422e+01,-1.21050777e+01,4.23143768e+01,4.26393929e+01,4.03902130e+01,3.61016235e+01,3.74344673e+01,3.76606216e+01,2.93862953e+01,1.91747296e+00,-9.67822647e+00,-1.10491867e+01,-1.35342636e+01,-6.71506195e+01,-6.73550262e+01,-6.63572083e+01,-4.63606071e+01,-4.91609383e+01,-5.08716202e+01,-4.48681793e+01,-1.87636101e+00,-4.61852417e+01,-4.17765961e+01,-2.01916142e+01,-9.79832268e+00,-9.91659927e+00,-1.03536911e+01,-3.72100983e+01,-3.87253189e+01,-3.71894073e+01,-2.30876770e+01,7.12325764e+00,-1.86448917e+01,-1.23296547e+01,-1.36949463e+01,2.36775169e+01,2.40340843e+01,2.65298653e+01,-1.50987358e+01,-1.75523911e+01,-1.85126324e+01,-1.25076523e+01,7.13483989e-01,1.25870705e+01,9.64732361e+00,1.51968069e+01,1.57215271e+01,1.60245552e+01,1.79179173e+01,8.83283997e+00,6.01854324e+00,5.03554726e+00,7.35849476e+00,7.32028627e+00,2.61785851e+01,2.36798325e+01,2.17165947e+01,-8.92044258e+00,-8.98554516e+00,-9.36619282e+00,8.72706318e+00,7.63292933e+00,8.11320019e+00,1.23650103e+01,4.05273724e+00,-2.74963975e+00,6.03793526e+00,3.32019657e-01,-1.07288876e+01,-1.08059797e+01,-1.12256699e+01,3.14207101e+00,2.12598038e+00,3.41194320e+00,1.52846613e+01,3.58989859e+00,-7.22187281e+00,-2.92357159e+00,-1.40336580e+01,1.25538235e+01,1.30331860e+01,1.53311596e+01,4.33768129e+00,3.44467473e+00,3.40012264e+00,8.20392323e+00,4.26902437e+00,2.08825417e+01,1.70654850e+01,6.18888092e+00,1.60812531e+01,1.66913948e+01,1.89202042e+01,-3.88430738e+00,-3.84901094e+00,-2.63745117e+00,-1.11108208e+00,6.36666417e-01,-2.25304246e+00,-8.85197830e+00,-1.91202374e+01,-1.46577859e+00,-1.23925185e+00,7.83551395e-01,2.47491241e+00,1.99022865e+00,1.92004573e+00,-4.81319517e-01,5.03325987e+00,-3.67527580e+00,1.00166121e+01,1.32022190e+01,1.41424477e+00,1.61442292e+00,1.70821917e+00,1.69235497e+01,1.71120987e+01,1.50358353e+01,1.89076972e+00,6.39788628e+00,-1.38607597e+00,1.02029294e-01,-1.02280178e+01,4.86864090e+00,5.27380610e+00,7.52170134e+00,-5.28688669e+00,-4.54772520e+00,-1.99729645e+00,9.84556198e+00,6.01770210e+00,6.42514515e+00,5.05019569e+00,-3.58427215e+00,2.53060675e+00,2.77301073e+00,3.36337566e+00,-3.74559736e+00,-3.88710737e+00,-2.85192370e+00,8.00442696e+00,5.90060949e-01,4.97644138e+00,8.34950066e+00,6.98132086e+00,-2.52411366e+00,-2.62051678e+00,-2.87881041e+00,7.85154676e+00,7.58860874e+00,7.21512508e+00,6.53992605e+00,-2.57507980e-01,-4.69269657e+00,-3.40787172e+00,1.32537198e+00,-4.12547922e+00,-4.38115311e+00,-5.22326088e+00,2.03997135e+00,2.89152718e+00,4.46722126e+00,6.32854557e+00,5.27089882e+00,8.66948891e+00,5.42871141e+00,1.18013754e+01,-1.06842923e+00,-1.21782100e+00,-1.62583649e+00,4.68027020e+00,4.82862568e+00,5.17801666e+00,5.02924442e+00,3.57280898e+00,-1.60658951e+01,-3.89933228e+00,5.28476810e+00,7.85575271e-01,7.23506689e-01,9.12119806e-01,1.29786149e-01,-1.08789623e+00,-2.01344442e+00,-3.12455368e+00,4.22501802e+00,-2.05132604e-01,4.64199352e+00,1.28417645e+01,-2.14332151e+00,-2.28186941e+00,-2.92554092e+00,-7.33241290e-02,-4.54517424e-01,-9.12135363e-01,-1.92673039e+00,1.32999837e+00,-5.95955181e+00,-1.38899193e+01,-1.53170991e+00,-3.57915735e+00,-3.69184279e+00,-3.76903296e+00,-3.22209269e-01,-6.59340382e-01,5.86225927e-01,1.03645115e+01,2.81656504e+00,-1.55326450e+00,-1.87907255e+00,-2.12706447e+00}}
        tensor, err := tf.NewTensor(input_data)
        if err != nil {
                fmt.Printf("Error NewTensor: err: %s", err.Error())
                return
        }
        // Load the model
        model, err := tf.LoadSavedModel("cnnModel", []string{"myTag"}, nil)
        if err != nil {
                fmt.Printf("Error loading Saved Model: %s\n", err.Error())
                return
        }
        // Run inference
        result, err := model.Session.Run(
                map[tf.Output]*tf.Tensor{
                        // the input layer (input_layer) defined in the Python TensorFlow/Keras code
                        model.Graph.Operation("input_layer").Output(0): tensor,
                },
                []tf.Output{
                        // the output layer (output_layer) defined in the Python TensorFlow/Keras code
                        model.Graph.Operation("output_layer/Softmax").Output(0),
                },
                nil,
        )
 
        if err != nil {
                fmt.Printf("Error running the session with input, err: %s  ", err.Error())
                return
        }
        // Print the result (an interface{} value)
        fmt.Printf("Result value: %v", result[0].Value())
}

Upload the trained model and the verification program to the gopath directory on the virtual machine:

[root@localhost gopath]# cd /home/gopath/
[root@localhost gopath]# 
[root@localhost gopath]# ll
總用量 8
-rw-r--r--. 1 root root 5072 8月  27 21:21 cnn.go
drwxr-xr-x. 3 root root   45 8月  27 19:27 cnnModel
drwxr-xr-x. 3 root root   25 8月  27 15:40 pkg
drwxr-xr-x. 5 root root   65 8月  27 15:40 src

Running the program produces an error, as shown below:

[root@localhost gopath]# go run cnn.go 
2020-08-27 21:22:55.306065: I tensorflow/cc/saved_model/reader.cc:31] Reading SavedModel from: cnnModel
2020-08-27 21:22:55.312739: I tensorflow/cc/saved_model/reader.cc:54] Reading meta graph with tags { myTag }
2020-08-27 21:22:55.320111: I tensorflow/core/platform/cpu_feature_guard.cc:142] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2
2020-08-27 21:22:55.335002: I tensorflow/core/platform/profile_utils/cpu_utils.cc:94] CPU Frequency: 1799995000 Hz
2020-08-27 21:22:55.335970: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x1ba6ea0 initialized for platform Host (this does not guarantee that XLA will be used). Devices:
2020-08-27 21:22:55.336091: I tensorflow/compiler/xla/service/service.cc:176]   StreamExecutor device (0): Host, Default Version
2020-08-27 21:22:55.364460: I tensorflow/cc/saved_model/loader.cc:202] Restoring SavedModel bundle.
2020-08-27 21:22:55.449190: I tensorflow/cc/saved_model/loader.cc:311] SavedModel load for tags { myTag }; Status: success. Took 143138 microseconds.
模型read成功panic: nil-Operation. If the Output was created with a Scope object, see Scope.Err() for details.

goroutine 1 [running]:
github.com/tensorflow/tensorflow/tensorflow/go.Output.c(...)
	/home/gopath/src/github.com/tensorflow/tensorflow/tensorflow/go/operation.go:130
github.com/tensorflow/tensorflow/tensorflow/go.newCRunArgs(0xc000097e50, 0xc000097e20, 0x1, 0x1, 0x0, 0x0, 0x0, 0xc000097668)
	/home/gopath/src/github.com/tensorflow/tensorflow/tensorflow/go/session.go:369 +0x594
github.com/tensorflow/tensorflow/tensorflow/go.(*Session).Run(0xc00000e0c0, 0xc000097e50, 0xc000097e20, 0x1, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, ...)
	/home/gopath/src/github.com/tensorflow/tensorflow/tensorflow/go/session.go:143 +0x1e2
main.main()
	/home/gopath/cnn.go:27 +0x1199
exit status 2
[root@localhost gopath]#

The error indicates that one of the requested graph nodes cannot be found (a nil Operation), yet the input and output node names are already set in the code, as shown below:

map[tf.Output]*tf.Tensor{
                        // the input layer (input_layer) defined in the Python TensorFlow/Keras code
                        model.Graph.Operation("input_layer").Output(0): tensor,
                },
                []tf.Output{
                        // the output layer (output_layer) defined in the Python TensorFlow/Keras code
                        model.Graph.Operation("output_layer/Softmax").Output(0),
                },

If the node cannot be found, could the node name binding be wrong? Looking back at the layer node names printed during training, the graph starts with "input_layer_input" and ends with "output_layer/Softmax", whereas our code used the name set during training directly. So the input node is changed from "input_layer" to "input_layer_input", as follows:

map[tf.Output]*tf.Tensor{
                        // the input layer (input_layer) defined in the Python TensorFlow/Keras code
                        model.Graph.Operation("input_layer_input").Output(0): tensor,
                },
                []tf.Output{
                        // the output layer (output_layer) defined in the Python TensorFlow/Keras code
                        model.Graph.Operation("output_layer/Softmax").Output(0),
                },

Re-running the program gives the following result:

[root@localhost gopath]# go run cnn.go 
2020-08-27 21:28:33.161277: I tensorflow/cc/saved_model/reader.cc:31] Reading SavedModel from: cnnModel
2020-08-27 21:28:33.166881: I tensorflow/cc/saved_model/reader.cc:54] Reading meta graph with tags { myTag }
2020-08-27 21:28:33.173316: I tensorflow/core/platform/cpu_feature_guard.cc:142] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2
2020-08-27 21:28:33.186007: I tensorflow/core/platform/profile_utils/cpu_utils.cc:94] CPU Frequency: 1799995000 Hz
2020-08-27 21:28:33.187068: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x13fbea0 initialized for platform Host (this does not guarantee that XLA will be used). Devices:
2020-08-27 21:28:33.187126: I tensorflow/compiler/xla/service/service.cc:176]   StreamExecutor device (0): Host, Default Version
2020-08-27 21:28:33.210363: I tensorflow/cc/saved_model/loader.cc:202] Restoring SavedModel bundle.
2020-08-27 21:28:33.294067: I tensorflow/cc/saved_model/loader.cc:311] SavedModel load for tags { myTag }; Status: success. Took 132796 microseconds.
模型read成功模型識別成功Result value: [[1.4359492e-26 1]]

This matches the prediction from the Python test above, so calling the Keras-trained model from Golang works; what remains is to plan how to use it in the actual production scenario.