2.3測試演算法:使用k-近鄰演算法識別手寫數字
阿新 • • 發佈:2021-07-21
1 #將每個32*32影象陣列轉為1*1024特徵值陣列 2 def img2vector(filename): 3 returnVect = zeros((1,1024)) #初始化returnVect為1行1024列的全零陣列 4 fr = open(filename) 5 for i in range(32): 6 lineStr = fr.readline() 7 for j in range(32): 8 returnVect[0,32*i+j] = int(lineStr[j]) 9#將32*32的影象矩陣轉換為1*1024的陣列 10 return returnVect
1 #手寫數字識別系統的測試程式碼 2 def handwritingClassTest(): 3 hwLabels = [] 4 trainingFileList = os.listdir('trainingDigits') #listdir返回指定資料夾下檔案的列表 5 m = len(trainingFileList) #m值為列表長度,即檔案數目 6 trainingMat = zeros((m,1024)) #初始化全零陣列(m,1024) 7 for i in range(m): 8 fileNameStr = trainingFileList[i] #取出某個一個檔案 9 fileStr = fileNameStr.split('.')[0] #得到去除檔案格式後的檔名 10 classNumStr = int(fileStr.split('_')[0]) #將該元素轉化為int型別 11 hwLabels.append(classNumStr) #將該值加到標籤列表中。 12 trainingMat[i,:] = img2vector('trainingDigits/%s' % fileNameStr) 13 #根據路徑將相應的檔案內容轉換為1行1024列的資料 14 testFileList = os.listdir('testDigits') #返回指定資料夾下檔案的列表 15 errorCount = 0.0 16 mTest = len(testFileList)#獲得檔案個數 17 for i in range(mTest): 18 fileNameStr = testFileList[i] #取出某個檔案 19 fileStr = fileNameStr.split('.')[0] #得到去除檔案格式後的檔名 20 classNumStr = int(fileStr.split('_')[0]) #得到實際標籤值 21 vectorUnderTest = img2vector('testDigits/%s' % fileNameStr) 22 #將檔案內容轉化為1*1024形式的資料 23 classifierResult = classify0(vectorUnderTest, \ 24 trainingMat,hwLabels,3) 25 #對輸入向量進行k近鄰演算法預測,預測標籤值為classfierResult 26 print('the classifier came back with: %d, the real answer is :%d'\ 27 % (classifierResult,classNumStr)) 28 if (classifierResult != classNumStr): 29 errorCount += 1.0 30 print('\nthe total number of errors is: %f' % errorCount) 31 32 print('the total error rate is: %f' % (errorCount/float(mTest)))
錄入程式碼時要注意縮排!!!(剛開始時,由於img2vector函式return語句縮排有誤,執行出來識別的錯誤率奇高無比,經仔細的畢對檢查,發現是該語句縮排打錯了。)
對於語法錯誤,有時在spider提示的錯誤行中檢查半天查不出來問題,但是卻報”invalid syntax“, 此時也需要檢查上一句的括號有沒有缺失!!