機器學習入門-載入sklearn中資料並用matplotlib進行視覺化
阿新 • • 發佈:2018-11-05
from sklearn import datasets import matplotlib.pyplot as plt def get_data(): """ 從sklearn中獲取鳶尾花的資料 :return: 鳶尾花資料的字典,字典中包括的key有:【'data', 'target', 'target_names', 'DESCR', 'feature_names'] 簡單介紹一下: data就是(150, 4)的資料集,target表示1位陣列,數字0~2表示分類, target_names表示分類名,DESCR表示對資料的描述 feature_names: 特徵值名稱 """ iris = datasets.load_iris() return iris def draw_graph(iris_data): """ 獲取兩個維度的資料進行資料視覺化,由於鳶尾花共有4個特徵,在平面中只能繪製2個特徵,所以獲取特徵1和特徵2進行繪製 :param: 鳶尾花的資料集 :return: """ X = iris_data.data[:, :2] target_names = iris_data.target_names print(target_names) print(iris_data.feature_names ) y = iris_data.target plt.scatter(X[y == 0, 0], X[y == 0, 1], color='red', marker='o', label=target_names[0]) plt.scatter(X[y == 1, 0], X[y == 1, 1], color='blue', marker='*', label=target_names[1]) plt.scatter(X[y == 2, 0], X[y == 2, 1], color='green', marker='+', label=target_names[2]) plt.legend() plt.title(u'Distribution of 3 different irises in length and width') plt.show() if __name__ == '__main__': iris = get_data() draw_graph(iris)
執行結果:
方法學習:
1、在sklearn中所有的資料集都放在datasets模組裡面,匯入對應的資料直接用loadxx
2、在sklearn中load出來的資料是一個字典,直接可以用原始data.屬性值獲取對應的值,比如 data.feature_names就可以獲取特徵的名字
3、用matplotlib的時候, 多看官方文件 https://matplotlib.org/users/pyplot_tutorial.html
scatter表示隨機的繪製點,marker也有不同的描述方法,可以在文件中看到