1. 程式人生 > >OneR演算法的Python簡單實現

OneR演算法的Python簡單實現

OneR演算法就是,在已有資料中,根據具有相同特徵值的個體最可能屬於哪個類別進行分類。即取效果最好的那個特徵進行分類。

#-*- coding=utf-8 -*-
#
import numpy as np
from sklearn.datasets import load_iris
from collections import defaultdict
from operator import itemgetter


#該演算法目的是通過這四個特徵中的一個以分辨種類,即,如果某一植物的特徵feature_index 的離散值為valu
#那麼該植物最有可能是most_frequent_class,錯誤率為error
#X為離散後的資料,y_true為每組資料的植株種類,feature_index為以第幾個特徵為標準,value為特徵值
def train_feature_value(X,y_true,feature_index,value):
	class_counts = defaultdict(int)
	for sample,y in zip(X,y_true):
		if sample[feature_index] == value:
			class_counts[y]+=1
	sorted_class_counts = sorted(class_counts.items(),key=itemgetter(1),reverse=True)
	print(sorted_class_counts)
	most_frequent_class = sorted_class_counts[0][0]
	print(most_frequent_class)
	incorrect_predictions = [class_count for class_vlue,class_count in class_counts.items() if class_vlue != most_frequent_class]
	print(incorrect_predictions)
	error = sum(incorrect_predictions)
	return most_frequent_class,error


if __name__ == '__main__':
	#從scikit-learn庫中讀取內建的“Iris植物分類資料集”
	dataset = load_iris()
	x = dataset.data#每株植物的四個特徵
	y = dataset.target#每株植物的種類,有4個種類

	#求4個特徵的平均值
	attribute_means = x.mean(axis=0)
	#當該值大於平局值時為1,小於平局值時為0,完成原始資料的離散化
	x_d = np.array(x>=attribute_means,dtype='int')
	train_feature_value(x_d,y,0,1)
	#TODO
	predictors = {}
	errors = []