SVM程式碼實現非線性分類 Matlab版
阿新 • • 發佈:2019-01-03
一、前言
在已經實現線性分類的基礎上,進一步實現非線性分類的情況。其流程與程式碼與上一篇部落格中的線性分類實現很相似。下面先談一下sv與bsv,這些概念在西瓜書上可以找到的。關於這裡出現的公式的推倒,請參考上上一篇部落格:
二、流程及實現
流程圖與線性分類一樣
而關於二次規劃檢視幫助文件即可。
還是直接上程式碼,註釋都比較詳細:
%------------主函式---------------- clear all; close all; C = 10; %成本約束引數 kertype = 'rbf'; %rbf高斯核 %①------資料準備5*10方格,每個方格20個點,共1000個點 x1=[]; x3=[]; for i=0:1:9 for j=0:1:4 b=j+rand(20,1);%隨機生成20個點 c=i+rand(20,1); x0=[x1;b]; x2=[x3;c]; x1=x0; %這裡的x1放所有點的橫座標 x3=x2; %x3放所有點的縱座標,(x1,x3) end end y0=[];%這個矩陣放所有點的標記 y1 = ones(20,1); %20個+1標記 y2 = -ones(20,1);%20個-1標記 for k=0:1:24 %迴圈賦值,使得5*10方格內相鄰的格子標記都不一樣 y0=[y0;y1]; y0=[y0;y2]; end x1=x1.' x3=x3.' %記得轉置一下哦 figure; %建立一個用來顯示圖形輸出的一個視窗物件 for m=1:1:25 plot(x1(1,(1+20*(2*m-2)):(20*(2*m-1))),x3(1,(1+20*(2*m-2)):(20*(2*m-1))),'k.'); %畫圖 hold on; plot(x1(1,(1+20*(2*m-1)):(20*(2*m))),x3(1,(1+20*(2*m-1)):(20*(2*m))),'b+'); %畫圖 hold on; %在同一個figure中畫幾幅圖時,用此句 end %axis([0 5 0 10]); %設定座標軸範圍 %②-------------訓練樣本 X = [x1;x3]; %訓練樣本2*n矩陣,n為樣本個數,d為特徵向量個數 Y = y0.'; %訓練目標1*n矩陣,n為樣本個數,值為+1或-1 svm = svmTrain(X,Y,kertype,C); %訓練樣本 %%%把支援向量標出來,若支援向量畫的不對,此時可通過在kernel函式調參來修改 for i=1:1:svm.svnum if svm.Ysv(1,i)==1 plot(svm.Xsv(1,i),svm.Xsv(2,i),'mo');%一類支援向量用粉色圈住 else plot(svm.Xsv(1,i),svm.Xsv(2,i),'ko');%另一類支援向量黑色圈 end end %plot(svm.Xsv(1,:),svm.Xsv(2,:),'ro'); %③-------------測試 [x1,x2] = meshgrid(0:0.05:5,0:0.05:10); %最大值控制著等高線在幾乘幾範圍畫出來 [rows,cols] = size(x1); nt = rows*cols; Xt = [reshape(x1,1,nt);reshape(x2,1,nt)]; %前半句reshape(x1,1,nt)是將x1轉成1*(rows*cols)的矩陣,所以Xt是2*(rows*cols)的矩陣 %reshape函式重新調整矩陣的行、列、維數 y3 = ones(1,floor(nt/2)); y4 = -ones(1,floor(nt/2)+1); Yt = [y3,y4]; result = svmTest(svm, Xt, Yt, kertype); %④--------------畫曲線的等高線圖 Yd = reshape(result.Y,rows,cols); contour(x1,x2,Yd,3); %產生三個水平的等高線 title('5*10資料分類'); x1=xlabel('X軸'); x2=ylabel('Y軸'); %-----------訓練樣本的函式svmTrain--------- function svm = svmTrain(X,Y,kertype,C) % Options是用來控制演算法的選項引數的向量,optimset無參時,建立一個選項結構所有欄位為預設值的選項 options = optimset; options.LargeScale = 'off';%LargeScale指大規模搜尋,off表示在規模搜尋模式關閉 options.Display = 'off'; %表示無輸出 %二次規劃來求解問題,可輸入命令help quadprog檢視詳情 n = length(Y); %返回Y最長維數 H = (Y'*Y).*kernel(X,X,kertype); f = -ones(n,1); %f為1*n個-1,f相當於Quadprog函式中的c A = []; b = []; Aeq = Y; %相當於Quadprog函式中的A1,b1 beq = 0; lb = zeros(n,1); %相當於Quadprog函式中的LB,UB ub = C*ones(n,1); a0 = zeros(n,1); % a0是解的初始近似值 [a,fval,eXitflag,output,lambda] = quadprog(H,f,A,b,Aeq,beq,lb,ub,a0,options); %a是輸出變數,問題的解 %fval是目標函式在解a處的值 %eXitflag>0,則程式收斂於解x;=0則函式的計算達到了最大次數;<0則問題無可行解,或程式執行失敗 %output輸出程式執行的某些資訊 %lambda為在解a處的值Lagrange乘子 epsilon = 1e-8; %0<a<a(max)則認為x為支援向量,find返回一個包含陣列X中每個非零元素的線性索引的向量。 sv_label = find(abs(a)>epsilon); svm.a = a(sv_label); svm.Xsv = X(:,sv_label); svm.Ysv = Y(sv_label); svm.svnum = length(sv_label); %svm.label = sv_label; end %---------------核函式kernel--------------- function K = kernel(X,Y,type) %X 維數*個數 switch type case 'linear' %此時代表線性核 K = X'*Y; case 'rbf' %此時代表高斯核 delta = 0.5; %改變這個引數圖會變的不一樣唉。。。越大支援向量越多。。。 delta = delta*delta; XX = sum(X'.*X',2); %2表示將矩陣中的按行為單位進行求和 YY = sum(Y'.*Y',2); XY = X'*Y; K = abs(repmat(XX,[1 size(YY,1)]) + repmat(YY',[size(XX,1) 1]) - 2*XY); K = exp(-K./delta); end end %---------------測試的函式svmTest------------- function result = svmTest(svm, Xt, Yt, kertype) temp = (svm.a'.*svm.Ysv)*kernel(svm.Xsv,svm.Xsv,kertype); %total_b = svm.Ysv-temp; b = mean(svm.Ysv-temp); %b取均值 w = (svm.a'.*svm.Ysv)*kernel(svm.Xsv,Xt,kertype); result.score = w + b; Y = sign(w+b); %f(x) result.Y = Y; result.accuracy = size(find(Y==Yt))/size(Yt); end
執行結果: