K-NN演算法的C語言實現
阿新 • • 發佈:2019-01-01
k-NN(k-Nearest Neighbors) k-近鄰演算法
概述
- k-近鄰演算法採用測量不同的特徵值之間的距離方法進行分類
k-近鄰演算法的一般流程
- 收集資料:可以使用任何方法
- 準備資料:距離計算所需要的數值,最好是結構化的資料格式
- 分析資料:可以使用任何方法
- 訓練演算法:此步驟不適用於k-近鄰演算法
- 測試演算法:計算錯誤率
- 使用演算法:首先需要輸入樣本資料和結構化的輸出結果,然後使用k-近鄰演算法判定輸入資料分別屬於哪個分類,最後應用對計算出的分類執行後續的處理
對未知類別屬性的資料集中的每個點依次執行以下操作
- 計算已知類別資料集中的點與當前點的距離
- 按照距離遞增次序排序
- 選取與當前點距離最小的k個點
- 確定前k個點所在類別的出現頻率
- 返回前k個點出現頻率最高的類別作為當前點的預測分類
按照上述步驟,可以實現k-近鄰演算法
k-近鄰演算法的C語言實現
#include <stdlib.h>
#include <stdio.h>
#include <math.h>
#define SIZE_ATTR 3 /* 屬性維度 */
#define SIZE_TRAIN 500 /* 訓練集大小 */
#define SIZE_TEST 500 /* 測試集大小 */
#define K 7 /* 所選k值 */
#define FILE_TRAIN "train.txt"
/* 記錄所構成的結構體變數 */
typedef struct _DataVector {
int id; /* 標號 */
float attr[SIZE_ATTR]; /* 屬性 */
int label; /* 類別 */
} DataVector;
/* 把記錄中的屬性換成距離後的結構體變數 */
typedef struct _DistanceVector {
int id; /* 標號 */
int label; /* 類別 */
float distance; /* 距離 */
} DistanceVector;
/* 屬性的結構體變數
可以先對屬性值做一個分析,再做下一步針對性處理(如歸一化特徵值處理) */
typedef struct _AttrValue {
float max; /* 屬性的最大值 */
float min; /* 屬性的最小值 */
float length; /* 屬性的長度 */
} AttrValue;
/* 定義全域性變數 */
DataVector trainSet[SIZE_TRAIN]; /* 訓練集 */
DataVector testSet[SIZE_TEST]; /* 測試集 */
DistanceVector knn[SIZE_TRAIN]; /* 距離儲存 */
AttrValue av[SIZE_ATTR]; /* 屬性的屬性 */
/* 從檔案中載入資料到記憶體 */
void loadDataFromFile(FILE *fp, char *fileName, DataVector *dv, int length) {
int i, j;
if ((fp = fopen(fileName, "r")) == NULL) {
printf("open \"%s\" failured!/n", fileName);
exit(1);
}
for (i = 0; i < length; ++i) {
for (j = 0; j < SIZE_ATTR; ++j) {
fscanf(fp, "%f ", &dv[i].attr[j]);
}
fscanf(fp, "%d\n", &dv[i].label);
}
fclose(fp);
}
/* 準備資料 */
void loadData() {
FILE *fp = NULL;
loadDataFromFile(fp, FILE_TRAIN, trainSet, SIZE_TRAIN);
loadDataFromFile(fp, FILE_TRAIN, testSet, SIZE_TEST);
printf("loading data success!\n");
}
/* 資料分析(預處理)
計算每個屬性長度,為歸一化特徵值準備 */
void preProcess() {
int i, j;
/* 初始化 */
for (i = 0; i < SIZE_ATTR; ++i) {
av[i].max = trainSet[0].attr[i];
av[i].min = trainSet[0].attr[i];
}
/* 計算屬性最大最小值 */
for (i = 0; i < SIZE_TRAIN; ++i) {
for (j = 0; j < SIZE_ATTR; ++j) {
if (trainSet[i].attr[j] > av[j].max) {
av[j].max = trainSet[i].attr[j];
} else if (trainSet[i].attr[j] < av[j].min) {
av[j].min = trainSet[i].attr[j];
}
}
}
/* 計算屬性長度 */
for (i = 0; i < SIZE_ATTR; ++i) {
av[i].length = av[i].max - av[i].min;
}
}
/* 歸一化特徵值
公式:newValue = (oldValue - min) / (max - min) */
float autoNorm(float oldValue, AttrValue *av) {
return (oldValue - (av->min)) / (av->length);
}
/* 距離計算
這裡計算的是歐式距離 */
float calcDistance(DataVector d1, DataVector d2) {
float sum = 0.0;
float newValue;
int i;
for (i = 0; i < SIZE_ATTR; ++i) {
newValue = autoNorm((d1.attr[i] - d2.attr[i]), av+i);
sum += newValue * newValue;
}
return (float) sqrt(sum);
}
/* 把每個資料的屬性向量轉化為距離 */
void transDistance(DataVector dv) {
int i;
for (i = 0; i < SIZE_TRAIN; ++i) {
/* 對距離進行賦值 */
knn[i].id = i;
knn[i].label = trainSet[i].label;
knn[i].distance = calcDistance(trainSet[i], dv);
}
}
/* 對所有距離進行排序,選取距離最小的k個數據向量(此處使用直接選擇排序) */
void knnSort() {
int i, j, k;
DistanceVector temp;
for (i = 0; i < K; ++i) {
k = i;
/* 從無序序列中挑出一個最小的元素 */
for (j = i + 1; j <= SIZE_TRAIN; ++j) {
if (knn[k].distance > knn[j].distance) {
k = j;
}
}
temp = knn[i];
knn[i] = knn[k];
knn[k] = temp;
}
}
/* 預測分類 */
int forecastClassification() {
int freq[K] = {0};
int maxFreq = 0;
int i, j, k = 0;
/* 確定前k個點所在類別出現的概率
這裡有點欠妥,因為分類最多能出現k個,出現了重複類別重複計算*/
for (i = 0; i < K; ++i) {
for (j = 0; j < K; ++j) {
if (knn[j].label == knn[i].label) {
freq[i]++;
}
}
}
/* 找到最大頻率 */
for (i = 0; i < K; ++i) {
if (freq[i] > maxFreq) {
maxFreq = freq[i];
k = i;
}
}
/* 得到最大頻率的類別 */
return knn[k].label;
}
/* 對測試資料進行測試 */
void test() {
int i;
int k = 0;
loadData();
preProcess();
/* 對每一條測試資料進行計算 */
for (i = 0; i < SIZE_TEST; ++i) {
transDistance(testSet[i]);
knnSort();
if (testSet[i].label == forecastClassification()) {
printf("1");
} else {
printf("0");
++k;
}
}
printf("\nTest end, wrong time is %d, the correct rate is %.2f%%\n", k, (float) (SIZE_TEST - k)/SIZE_TEST*100);
}
void main() {
test();
system("pause");
}
參考資料
- 機器學習實戰. Peter Harrington
測試材料
- 機器學習實戰原始碼/Ch02/datingTestSet2.txt
- 下載連結: