sklearn:繪製由VotingClassifier計算的類概率¶
阿新 • • 發佈:2018-12-07
繪製由三個不同分類器預測並由VotingClassifier平均的玩具資料集中第一個樣本的類概率。
首先,初始化三個示例性分類器(LogisticRegression,GaussianNB和RandomForestClassifier)並用於初始化具有權重[1,1,5]的軟投票VotingClassifier,這意味著RandomForestClassifier的預測概率計數是該值的5倍。 計算平均概率時其他分類器的權重。
為了使概率加權視覺化,我們將每個分類器擬合到訓練集上,並繪製該示例資料集中第一個樣本的預測類概率。
import numpy as np import matplotlib.pyplot as plt from sklearn.linear_model import LogisticRegression from sklearn.naive_bayes import GaussianNB from sklearn.ensemble import RandomForestClassifier from sklearn.ensemble import VotingClassifier clf1 = LogisticRegression(random_state=123) clf2 = RandomForestClassifier(random_state=123) clf3 = GaussianNB() X = np.array([[-1.0, -1.0], [-1.2, -1.4], [-3.4, -2.2], [1.1, 1.2]]) y = np.array([1, 1, 2, 2]) eclf = VotingClassifier(estimators=[('lr', clf1), ('rf', clf2), ('gnb', clf3)], voting='soft', weights=[1, 1, 5]) # predict class probabilities for all classifiers probas = [c.fit(X, y).predict_proba(X) for c in (clf1, clf2, clf3, eclf)] # get class probabilities for the first sample in the dataset class1_1 = [pr[0, 0] for pr in probas] class2_1 = [pr[0, 1] for pr in probas] # plotting N = 4 # number of groups ind = np.arange(N) # group positions width = 0.35 # bar width fig, ax = plt.subplots() # bars for classifier 1-3 p1 = ax.bar(ind, np.hstack(([class1_1[:-1], [0]])), width, color='green') p2 = ax.bar(ind + width, np.hstack(([class2_1[:-1], [0]])), width, color='lightgreen') # bars for VotingClassifier p3 = ax.bar(ind, [0, 0, 0, class1_1[-1]], width, color='blue') p4 = ax.bar(ind + width, [0, 0, 0, class2_1[-1]], width, color='steelblue') # plot annotations plt.axvline(2.8, color='k', linestyle='dashed') ax.set_xticks(ind + width) ax.set_xticklabels(['LogisticRegression\nweight 1', 'GaussianNB\nweight 1', 'RandomForestClassifier\nweight 5', 'VotingClassifier\n(average probabilities)'], rotation=40, ha='right') plt.ylim([0, 1]) plt.title('Class probabilities for sample 1 by different classifiers') plt.legend([p1[0], p2[0]], ['class 1', 'class 2'], loc='upper left') plt.show()