KNN(最鄰近值演算法) scala實現
阿新 • • 發佈:2018-12-29
最鄰近值演算法實現
工程目錄結構
程式碼
訓練模型
package com.knn.model
/**
* 訓練資料模型
*
* @param aA 資料a
* @param bA 資料b
* @param typeA 型別
*/
class KNNModel(aA:Double,bA:Double,typeA:String) {
var a:Double = aA
var b:Double = bA
var resType: String = typeA
//距離
var distince:Double = 0
}
核心演算法程式碼
package com.knn.core
import com.knn.model.KNNModel
import scala.collection.immutable.ListMap
/**
* 最鄰近演算法核心演算法
*/
class KNN_Core {
// val knnModel = new KNNModel(null,null,null,null,null);
/**
* 對訓練資料進行升序排序(根據距離來進行排序)
* @param knnMOdels
* @return
*/
private def sortByDistince(knnMOdels:List[KNNModel]):List[KNNModel] ={
//進行升序排序
return knnMOdels.sortBy(knn => knn.distince)
}
/**
* 使用歐幾里得度量計算出距離
* @param knnMOdels
* @param k
*/
private def coluaclateDistince(knnMOdels:List[KNNModel],k: KNNModel):Unit =
knnMOdels.foreach(n=>{
n.distince = Math.sqrt((k.a-n.a)*(k.a-n.a)+(k.b-n.b)*(k.b-n.b))
})
/**
* 獲取距離目標點附近(指定集合大小的範圍內存在最多的資料)
* @param ks
* @return
*/
private def findMostValue(ks:List[KNNModel]):String ={
//找出訓練集中在規定數量中存在最多的類
var resType = ""
var typeCountMap:Map[String,Int] = Map()
//進行計數
ks.toStream.foreach(k=>{
if (typeCountMap.contains(k.resType)){
typeCountMap+= (k.resType -> (typeCountMap(k.resType)+1))
}else{
typeCountMap+=(k.resType -> 1)
}
})
//獲取最多數量型別(根據鍵值進行排序)
resType = ListMap(typeCountMap.toSeq.sortWith(_._2 >_._2):_*).take(1).keySet.head
return resType
}
def reckonRelize(kns:List[KNNModel],kn:KNNModel,k: Int):String={
//計算距離
coluaclateDistince(kns,kn)
//根據距離排序
var knsSort = sortByDistince(kns)
//獲取前k個數據
var knss = knsSort.take(k)
//獲取k個數據中數量最多的型別
return findMostValue(knss)
}
}
執行程式碼
package com.knn
import com.knn.core.KNN_Core
import com.knn.model.KNNModel
/**
* 分割類
*/
object app {
def main(args: Array[String]): Unit = {
//資料準備
var knnModels:List[KNNModel] = List()
knnModels = knnModels.::(new KNNModel(1.1, 1.1, "A"))
knnModels = knnModels.::(new KNNModel(1.2, 1.2, "A"))
knnModels = knnModels.::(new KNNModel(1.1, 1.0, "A"))
knnModels = knnModels.::(new KNNModel(3.0, 3.1, "B"))
knnModels = knnModels.::(new KNNModel(3.1, 3.0, "B"))
knnModels = knnModels.::(new KNNModel(5.4, 6.0, "C"))
knnModels = knnModels.::(new KNNModel(5.5, 6.3, "C"))
knnModels = knnModels.::(new KNNModel(6.0, 12.0, "C"))
knnModels = knnModels.::(new KNNModel(10.0, 12.0, "M"))
//待預測資料
var knnModle = new KNNModel(4.0, 3.2, "A")
var kNN_Core = new KNN_Core
//演算法實現
var resType = kNN_Core.reckonRelize(knnModels,knnModle,5)
println("預測結果",resType)
}
}