1. 程式人生 > >sklearn中SVM簡單使用

sklearn中SVM簡單使用

在看周志華老師的機器學習書時,利用sklearn中的SVM解決第六章的一些課後習題。

*************************************************************************************************************

1、在西瓜資料集3.0alpha上分別用線性和高斯核訓練一個SVM,並比較其支援向量的差別。

2、以“密度”為輸入,“含糖量”為輸出,訓練一個SVR。

資料如下:

	編號	密度	含糖率	好瓜
	1	0.697	0.46	是
	2	0.774	0.376	是
	3	0.634	0.264	是
	4	0.608	0.318	是
	5	0.556	0.215	是
	6	0.403	0.237	是
	7	0.481	0.149	是
	8	0.437	0.211	是
	9	0.666	0.091	否
	10	0.243	0.267	否
	11	0.245	0.057	否
	12	0.343	0.099	否
	13	0.639	0.161	否
	14	0.657	0.198	否
	15	0.36	0.37	否
	16	0.593	0.042	否
	17	0.719	0.103	否
程式碼:
# -*- coding: utf-8 -*-
"""
Created on Thu Dec 29 21:14:16 2016

@author: ZQ
"""
import numpy as np
import matplotlib.pyplot as plt

from sklearn.svm import SVC
from sklearn.svm import SVR

def plot_decision_function(X,classifier,sample_weight,axis,title):
    xx,yy = np.meshgrid(np.linspace(-4,5,500),np.linspace(-4,5,500))
    
    Z = classifier.decision_function(np.c_[xx.ravel(),yy.ravel()])
    Z = Z.reshape(xx.shape)
    
    axis.contourf(xx,yy,Z,alpha = 0.75,cmap = plt.cm.bone)
    axis.scatter(X[:,0],X[:,1],c=y,s=100*sample_weight,alpha = 0.9,
                 cmap = plt.cm.bone)
    axis.axis('off')
    axis.set_title(title)

def loadData(filename):   
    data = []
    with open(filename) as f:
        for line in f.readlines():
            data.append(line.strip().split('\t')[1:])
    return np.array(data[1:])
def initData(data):
    m,n = np.shape(data)
    retDat = np.zeros((m,n))
    for i in range(m):
        for j in range(n):
            if data[i][j] == '是':
                retDat[i][j] = 0
            elif data[i][j] == '否':
                retDat[i][j] = 1
            else:
                retDat[i][j] = float(data[i][j])
    return retDat
if __name__ == '__main__':
    data = loadData('watermelon3.0Alpha.txt')
    num_data = initData(data)
    x = num_data[:,:2]
    y = num_data[:,-1]
    
    """
    #線性畫圖
    lclf = SVC(C = 1.0,gamma=0.1,kernel='linear')
    bclf = SVC(C = 1.0,gamma = 0.1)
    
    plt.scatter(x[:,0],x[:,1],c = y)
    xx = np.linspace(-5,5)
    lclf.fit(x,y)
    lw = lclf.coef_[0]
    la = -lw[0]/lw[1]
    ly = la*xx - lclf.intercept_[0]/lw[1]
    h0 = plt.plot(xx,ly,'k-',label = 'linear')
    """
    """
    #高斯畫圖
    bclf.fit(x,y)
    
    weight = np.ones(len(x))
    fig,axis = plt.subplots(1,1)
    plot_decision_function(x,bclf,weight,axis,'test')
    """
    #訓練的SVR
    svr_rbf = SVR(kernel = 'rbf',C = 1e3,gamma = 3)
    X = x[:,0].reshape((17,1))
    y = x[:,1].reshape((17,1))
    y_rbf = svr_rbf.fit(X,y.ravel()).predict(X)
    plt.scatter(x[:,0],x[:,1],c = x[:,0])
    plt.scatter(x[:,0],y_rbf,c = x[:,0],marker = '*')
    plt.show()

結果如下:

對於線性和高斯核的SVC的支援向量都是相同的,如下:


更為直觀的畫圖顯示,第一個為線性核,第二個為高斯核。



不知道是程式碼原因還是資料原因導致結果不是很好。希望大家能多多指正。

SVR的結果如下: