OneR演算法的Python簡單實現
阿新 • • 發佈:2018-11-13
OneR演算法就是,在已有資料中,根據具有相同特徵值的個體最可能屬於哪個類別進行分類。即取效果最好的那個特徵進行分類。
#-*- coding=utf-8 -*- # import numpy as np from sklearn.datasets import load_iris from collections import defaultdict from operator import itemgetter #該演算法目的是通過這四個特徵中的一個以分辨種類,即,如果某一植物的特徵feature_index 的離散值為valu #那麼該植物最有可能是most_frequent_class,錯誤率為error #X為離散後的資料,y_true為每組資料的植株種類,feature_index為以第幾個特徵為標準,value為特徵值 def train_feature_value(X,y_true,feature_index,value): class_counts = defaultdict(int) for sample,y in zip(X,y_true): if sample[feature_index] == value: class_counts[y]+=1 sorted_class_counts = sorted(class_counts.items(),key=itemgetter(1),reverse=True) print(sorted_class_counts) most_frequent_class = sorted_class_counts[0][0] print(most_frequent_class) incorrect_predictions = [class_count for class_vlue,class_count in class_counts.items() if class_vlue != most_frequent_class] print(incorrect_predictions) error = sum(incorrect_predictions) return most_frequent_class,error if __name__ == '__main__': #從scikit-learn庫中讀取內建的“Iris植物分類資料集” dataset = load_iris() x = dataset.data#每株植物的四個特徵 y = dataset.target#每株植物的種類,有4個種類 #求4個特徵的平均值 attribute_means = x.mean(axis=0) #當該值大於平局值時為1,小於平局值時為0,完成原始資料的離散化 x_d = np.array(x>=attribute_means,dtype='int') train_feature_value(x_d,y,0,1) #TODO predictors = {} errors = []