神經演算法與決策樹分析bankloan資料
阿新 • • 發佈:2022-03-27
神經網路分析程式碼如下
import pandas as pd filename = 'C:/Users/透心涼i/Desktop/data/data/bankloan.xls' data_tr = pd.read_excel(filename) #print(data_tr) # 匯入資料 #讀取資料 x_tr = data_tr.iloc[:,:8] y_tr = data_tr.iloc[:,8] #print(x_tr) #print(y_tr) from tensorflow.keras.models import Sequential from tensorflow.keras.layers importDense, Activation model = Sequential() # 建立模型 model.add(Dense(input_dim = 8, units = 16)) model.add(Activation('relu')) # 用relu函式作為啟用函式,能夠大幅提供準確度 model.add(Dense(input_dim = 16, units = 32)) model.add(Activation('sigmoid')) model.add(Dense(input_dim = 32, units = 1)) model.add(Activation('sigmoid')) # 由於是0-1輸出,用sigmoid函式作為啟用函式 model.compile(loss = 'binary_crossentropy', optimizer = 'adam') # 編譯模型。由於我們做的是二元分類,所以我們指定損失函式為binary_crossentropy,以及模式為binary # 另外常見的損失函式還有mean_squared_error、categorical_crossentropy等,請閱讀幫助檔案。 # 求解方法我們指定用adam,還有sgd、rmsprop等可選 model.fit(x_tr, y_tr, epochs = 1000, batch_size = 10) #訓練模型,學習一千次 yp = model.predict(x_tr).reshape(len(y_tr)) # 分類預測 score = model.evaluate(x_tr, y_tr, batch_size=256) #分類預測損失值 print("分類預測損失值") print(score)
結果如圖:
決策樹分析程式碼如下:
import pandas as pd # 引數初始化 import pandas as pd import os os.chdir('C:/Users/透心涼i') data = pd.read_excel('bankloan.xls') x = data.iloc[:,:8].astype(int) y = data.iloc[:,8].astype(int) from sklearn.tree import DecisionTreeClassifier as DTC dtc = DTC(criterion='entropy') # 建立決策樹模型,基於資訊熵 dtc.fit(x, y) # 訓練模型 # 匯入相關函式,視覺化決策樹。 # 匯出的結果是一個dot檔案,需要安裝Graphviz才能將它轉換為pdf或png等格式。 from sklearn.tree import export_graphviz x = pd.DataFrame(x) with open(r"C:/Users/透心涼i/Desktop/data/tree3.dot", 'w',encoding="utf-8") as f: export_graphviz(dtc, feature_names = x.columns, out_file = f) f.close() from IPython.display import Image from sklearn import tree import pydotplus import os os.environ["PATH"] += os.pathsep + 'C:/Program Files/Graphviz/bin/' dot_data = tree.export_graphviz(dtc, out_file=None, #regr_1 是對應分類器 feature_names=data.columns[:8], #對應特徵的名字 class_names=data.columns[8], #對應類別的名字 filled=True, rounded=True, special_characters=True) dot_data = dot_data.replace('helvetica', 'MicrosoftYaHei') graph = pydotplus.graph_from_dot_data(dot_data) graph.write_png('C:/Users/透心涼i/Desktop/data/example.png') #儲存影象 Image(graph.create_png())
結果如下: