1. 程式人生 > >基於KNN演算法實現的單個圖片數字識別

基於KNN演算法實現的單個圖片數字識別

Test.csv中第1434行,圖片數字值為”0,最終歸類為0,正確。

Test.csv中第14686行,圖片數字值為”8,最終歸類為8,正確。

4原始碼

最後附上本次基於KNN思想實現單個數字圖片識別的全部原始碼。

/** 

* @Title: DigitClassification.java

* @Package com.org.meify

* @Description: 單個數字圖片識別

* @authormeify

* @date 2016年1月15日下午4:19:04

* @version V1.0 

*/

publicclass DigitClassification {

         // 訓練集圖片向量map

         privatestatic HashMap<String, ArrayList<Integer>> trainingMap = new HashMap<String, ArrayList<Integer>>();

         // 測試集圖片向量map

         privatestatic HashMap<String, ArrayList<Integer>> testMap = new HashMap<String, ArrayList<Integer>>();

        

         /**

         * @Title: getSubArray

         * @Description: 從數字下標1開始擷取子集合

         * @param  arr

         * @return ArrayList<Integer>  

         * @throws

         */

         public ArrayList<Integer> getSubArray(String[] arr) {

                   ArrayList<Integer> list = new ArrayList<Integer>();

                   for (inti = 1; i < arr.length; i++) {

                            list.add(Integer.valueOf(arr[i]));

                   }

                   returnlist;

         }

        

         /**

         * @Title: toList

         * @Description: 將陣列轉為集合

         * @param  arr 

         * @return ArrayList<Integer>  

         * @throws

         */

         public ArrayList<Integer> toList(String[] arr) {

                   ArrayList<Integer> list = new ArrayList<Integer>();

                   for (inti = 0; i < arr.length; i++) {

                            list.add(Integer.valueOf(arr[i]));

                   }

                   returnlist;

         }

        

         /**

         * @Title: getDistance

         * @Description: 計算兩個向量之間的距離(向量歐式距離)

         * @
[email protected]
list1 * @[email protected] list2 * @return double * @throws */ publicdouble getDistance(ArrayList<Integer> list1, ArrayList<Integer> list2) { if(list1.size() != list2.size()) { System.out.println("警告:兩個向量大小不等"); return 0.0d; } intsum = 0; for (inti = 0; i < list1.size(); i++) { inta = list1.get(i); intb = list2.get(i); sum += (a - b) * (a - b); } return Math.sqrt((double) sum); } /** * @Title: display * @Description: 以矩陣的形式展示圖片點陣圖片大小為28*28 * @paramlist * @return void * @throws */ void display(ArrayList<Integer> list) { if(list.size() == 784) { for(inti=0;i<784;i++) { System.out.print(list.get(i) + " "); if((i+1)%28 == 0) { System.out.println(); } } } } /** * @Title: loadTrainData * @Description: 載入訓練集資料 * @param path * @return void * @throws */ void loadTrainData(String path) { try { ArrayList<String[]> csvList = new ArrayList<String[]>(); CsvReader reader = new CsvReader(path, ',', Charset.forName("SJIS")); reader.readHeaders(); // 跳過CSV header欄 intindex = 1; while (reader.readRecord()) { csvList.add(reader.getValues()); String[] arr = reader.getValues(); String key = "line" + index + "/" + arr[0]; // key:行號 + 圖片的真實數字值ֵ ArrayList<Integer> values = getSubArray(arr); trainingMap.put(key, values); index ++; } reader.close(); } catch (Exception ex) { System.out.println(ex); } } /** * @Title: loadTestData * @Description: 載入測試集資料 * @
[email protected]
path * @return void * @throws */ void loadTestData(String path) { try { ArrayList<String[]> csvList = new ArrayList<String[]>(); CsvReader reader = new CsvReader(path, ',', Charset.forName("SJIS")); reader.readHeaders(); // 跳過CSV header欄 intindex = 1; while (reader.readRecord()) { csvList.add(reader.getValues()); String[] arr = reader.getValues(); String key = "line" + index; // key:行號 ArrayList<Integer> values = toList(arr); testMap.put(key, values); index ++; } reader.close(); } catch (Exception ex) { System.out.println(ex); } } /** * @Title: classify * @Description: * @param testList 測試圖片向量 * @param k K值決定臨近K個圖片 * @returnint * @throws */ int classify(ArrayList<Integer> testList, intk) { HashMap<Double, String> distanceMap = new HashMap<Double, String>(); // 計算當前圖片向量和每個訓練集向量的距離 Iterator<Entry<String, ArrayList<Integer>>> train_iter = trainingMap.entrySet().iterator(); while (train_iter.hasNext()) { Map.Entry<String, ArrayList<Integer>> train_entry = (Entry<String, ArrayList<Integer>>) train_iter.next(); String train_key = (String) train_entry.getKey(); // ArrayList<Integer> trainList = train_entry.getValue(); // doubledistance = getDistance(trainList, testList); distanceMap.put(distance, train_key); } // 初始化每個數字出現的頻次map HashMap<String, Integer> countMap = new HashMap<String, Integer>(); for(inti=0;i<10; i++) { countMap.put(String.valueOf(i), 0); } // 將距離map排序,並選取前K小的圖片 SortedMap<Double, String> sortMap = new TreeMap<Double, String>(distanceMap); Set<Entry<Double, String>> sort_entry = sortMap.entrySet(); Iterator<Entry<Double, String>> sort_it = sort_entry.iterator(); inti = 0; while (sort_it.hasNext() && i < k) { Entry<Double, String> entry = sort_it.next(); String str = entry.getValue(); String digit = str.split("/")[1]; intnum = countMap.get(digit); num ++; i++; countMap.put(digit, num); } // 刪選出現頻次最大的數字作為當前圖片所屬數字 intmax = 0; String targetValue = "unknown"; Iterator<Entry<String, Integer>> iter2 = countMap.entrySet().iterator(); while(iter2.hasNext()) { Map.Entry<String, Integer> entry = (Entry<String, Integer>) iter2.next(); String digit = entry.getKey(); intnum = entry.getValue(); if(num > max) { max = num; targetValue = digit; } } return Integer.valueOf(targetValue); } /** * @Title: getRealDigits * @Description: 將所有的測試圖片向量進行歸類,使用KNN演算法思想,得出每個測試圖片上的數字 * @param k * @return void * @throws */ void getRealDigits(intk) { Iterator<Entry<String, ArrayList<Integer>>> test_iter = testMap.entrySet().iterator(); intindex = 1; while (test_iter.hasNext()) { Map.Entry<String, ArrayList<Integer>> test_entry = (Entry<String, ArrayList<Integer>>) test_iter.next(); String test_key = (String) test_entry.getKey(); // 當前行號 ArrayList<Integer> testList = test_entry.getValue(); // 當前測試圖片的向量 // 展示當前圖片向量 display(testList); intfinalDigit = classify(testList, k); System.out.println("line:" + index + ",圖片上的數字為:" + finalDigit); index ++; } } publicstaticvoid main(String[] args) { DigitClassification classification = new DigitClassification(); // 1.載入所有訓練圖片向量 classification.loadTrainData("D://train.csv"); // 2.載入所有測試圖片向量 classification.loadTestData("D://test.csv"); // 3.利用KNN演算法思想,對測試圖片進行歸類其中K值取20 classification.getRealDigits(20); } }


相關推薦

KNN演算法——實現手寫數字識別(Sklearn實現

KNN專案實戰——手寫數字識別 1、資料集介紹 需要識別的數字已經使用圖形處理軟體,處理成具有相同的色彩和大小:寬高是32畫素x32畫素的黑白影象。儘管採用本文格式儲存影象不能有效地利用記憶體空間,但是為了方便理解,我們將圖片轉換為文字格式。 數字的文字格式如下:

基於KNN演算法實現單個圖片數字識別

Test.csv中第1434行,圖片數字值為”0“,最終歸類為0,正確。 Test.csv中第14686行,圖片數字值為”8“,最終歸類為8,正確。 4原始碼 最後附上本次基於KNN思想實現單個數字圖片識別的全部原始碼。 /** * @Title: DigitClassification.java

機器學習--k-近鄰演算法kNN實現手寫數字識別

這裡的手寫數字以0,1的形式儲存在文字檔案中,大小是32x32.目錄trainingDigits有1934個樣本。0-9每個數字大約有200個樣本,命名規則如下: 下劃線前的數字代表是樣本0-9的

【人工智慧】利用C語言實現KNN演算法進行手寫數字識別

KNN演算法稱為鄰近演算法,或者說K最近鄰(kNN,k-NearestNeighbor)分類演算法。所謂K最近鄰,就是k個最近的鄰居的意思,說的是每個樣本都可以用它最接近的k個鄰居來代表。kNN演算法的核心思想是如果一個樣本在特徵空間中的k個最相鄰的樣本中的大多數屬於某一個類

R語言基於KNN演算法實現蘑菇毒性識別

R語言:基於KNN演算法實現蘑菇毒性識別 平臺:Ubuntu16.04LTS   RStudio 資料集介紹: trainData.txt  訓練資料集。包含4339個樣本(行),每個樣本共6個特徵(列),其中前5列為蘑菇樣本的特徵值,第6列為蘑菇的毒性屬性,0表示無毒,1

學習KNN(三)KNN+HOG實現手寫數字識別

在學習KNN(二)KNN演算法手寫數字識別的OpenCV實現我們直接將畫素值作為特徵,實現了KNN演算法的手寫數字識別問題,並得到了較好的準確率,但是就像其他機器學習演算法一樣,KNN的物件同樣是特徵,所以我們可以用一種特徵提取演算法配合KNN實現手寫數字識

[分享] Python實現基於深度學習的手寫數字識別演算法

本文將採用深度學習中的卷積神經網路來訓練手寫數字識別模型。使用卷積神經網路建立合理的模型結構,利用卷積層中設定一定數目的卷積核(即濾波器),通過訓練資料使模型學習到能夠反映出十個不同手寫提數字特徵的卷積核權值,最後通過全連線層使用softmax函式給出預測數字圖對應每種數字可能性的概率多少。 本文以學習基於

python資料建模與KNN演算法實現手寫體數字識別

      資料建模指的是對現實世界各類資料的抽象組織,建立一一個適合的模型對資料進行處理。在資料分析與挖掘中,我們通常需要根據一-些資料建 立起特定的模型,然後處理。模型的建立需要依賴於演算法, - -般,常見的演算法有分類、聚類、關聯、

各種機器學習方法(線性迴歸、支援向量機、決策樹、樸素貝葉斯、KNN演算法、邏輯迴歸)實現手寫數字識別並用準確率、召回率、F1進行評估

本文轉自:http://blog.csdn.net/net_wolf_007/article/details/51794254 前面兩章對資料進行了簡單的特徵提取及線性迴歸分析。識別率已經達到了85%, 完成了數字識別的第一步:資料探測。 這一章要做的就各

一看就懂的K近鄰演算法(KNN),K-D樹,並實現手寫數字識別

1. 什麼是KNN 1.1 KNN的通俗解釋 何謂K近鄰演算法,即K-Nearest Neighbor algorithm,簡稱KNN演算法,單從名字來猜想,可以簡單粗暴的認為是:K個最近的鄰居,當K=1時,演算法便成了最近鄰演算法,即尋找最近的那個鄰居。 用官方的話來說,所謂K近鄰演算法,即是給定一個訓練資

Java基於opencv實現圖像數字識別(一)

binary oid ring 是把 sca pre 內存 還需要 自己 Java基於opencv實現圖像數字識別(一) 最近分到了一個任務,要做數字識別,我分配到的任務是把數字一個個的分開;當時一臉懵逼,直接百度java如何分割圖片中的數字,然後就百度到了用Buffere

Java基於opencv實現圖像數字識別(二)—基本流程

數字 都是 模型 PE 設計 category 理解 兩種 ace Java基於opencv實現圖像數字識別(二)—基本流程 做一個項目之前呢,我們應該有一個總體把握,或者是進度條;來一步步的督促著我們來完成這個項目,在我們正式開始前呢,我們先討論下流程。 我做的主要是表格

Hadoop偽分佈安裝詳解+MapReduce執行原理+基於MapReduce的KNN演算法實現

本篇部落格將圍繞Hadoop偽分佈安裝+MapReduce執行原理+基於MapReduce的KNN演算法實現這三個方面進行敘述。 (一)Hadoop偽分佈安裝 1、簡述Hadoop的安裝模式中–偽分佈模式與叢集模式的區別與聯絡. Hadoop的安裝方式有三種:本地模式,偽分佈模式

加權歐氏距離KNN演算法實現人臉識別(Python實現)

前沿: 本實踐是純屬小白練手入門小專案,希望未來可以手動自己用神經網路來識別人臉。共勉,加油! 題目內容: 針對標準人臉樣本庫,選擇訓練和測試樣本,對基本的knn分類演算法設計智慧演算法進行改進,能夠對測試樣本識別出身份。 題目要求: 1) 選擇合適的編碼

機器學習實戰k近鄰演算法(kNN)應用之手寫數字識別程式碼解讀

from numpy import * from os import listdir import operator import time #k-NN簡單實現函式 def classify0(inX,dataSet,labels,k): #求出樣本集的行數,也就是labels標籤的數目

python tensorflow 基於cnn實現手寫數字識別

感覺剛才的程式碼不夠給力,所以再儲存一份基於cnn的手寫數字自識別的程式碼 # -*- coding: utf-8 -*- import tensorflow as tf from tensorflow.examples.tutorials.mnist

MachineLearning— (KNN)k Nearest Neighbor實現手寫數字識別(三)

    本篇博文主要結合前兩篇的knn演算法理論部分knn理論理解(一)和knn理論理解(二),做一個KNN的實現,主要是根據《機器學習實戰》這本書的內容,一個非常經典有趣的例子就是使用knn最近鄰演算法來實現對手寫數字的識別,下面將給出Python程式碼,儘量使用詳盡的解

基於k近鄰(KNN)的手寫數字識別

作者:faaronzheng 轉載請註明出處! 最近再看Machine Learning in Action. k近鄰演算法這一章節提供了不少例子,本著Talk is cheap的原則,我們用手寫數字識別來實際測試一下。 簡單的介紹一下k近鄰演算法(KNN):給定測試樣本

基於感知機的手寫數字識別java實現

多層感知機的手寫數字識別,迭代10次對訓練集的正確率97 Main函式,在繪製完數字後,要點下確定按鈕再去識別,重繪按鈕自然是再次繪圖    訓練自己的網路結構會替換之前訓練的網路結構,沒有寫儲存或者另存新網路模型。結果對訓練集變現很好,對繪圖的識別結果仍不

MNIST資料集實現手寫數字識別基於tensorflow)

主要應用了下面幾個方法來提高準確率; 使用隨機梯度下降(batch) 使用Relu啟用函式去線性化 使用正則化避免過擬合 使用帶指數衰減的學習率 使用滑動平均模型 使用交叉熵損失函式來刻畫預測值和真實值之間的差距的損失函式 第一步,匯入MNIST資料集 from