用 TensorFlow 實現 SVM 對鳶尾花資料進行分類
阿新 • • 發佈:2019-01-10
"""Train a linear soft-margin SVM on two Iris features with TensorFlow 1.x.

The script separates the samples whose target is 0 (plotted as "setosa")
from all other samples, using the features in columns 0 and 3 of
``iris.data``.  It then plots the learned linear separator, the train/test
accuracy per step and the loss per step.
"""
import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from sklearn import datasets

sess = tf.Session()

# ---------------------------------------------------------------------------
# Data: keep columns 0 and 3 of every sample; label is +1 for target 0, else -1.
# ---------------------------------------------------------------------------
iris = datasets.load_iris()
x_vals = np.array([[x[0], x[3]] for x in iris.data])
y_vals = np.array([1 if y == 0 else -1 for y in iris.target])

# Random 80/20 train/test split.  int() keeps the sample size an integer on
# interpreters where round() returns a float.
train_indices = np.random.choice(
    len(x_vals), int(round(len(x_vals) * 0.8)), replace=False
)
test_indices = np.array(list(set(range(len(x_vals))) - set(train_indices)))
x_vals_train = x_vals[train_indices]
x_vals_test = x_vals[test_indices]
y_vals_train = y_vals[train_indices]
y_vals_test = y_vals[test_indices]

# ---------------------------------------------------------------------------
# Model: f(x) = x . A + b
# ---------------------------------------------------------------------------
batch_size = 100
x_data = tf.placeholder(shape=[None, 2], dtype=tf.float32)
y_target = tf.placeholder(shape=[None, 1], dtype=tf.float32)
A = tf.Variable(tf.random_normal(shape=[2, 1]))
b = tf.Variable(tf.random_normal(shape=[1, 1]))
model_output = tf.add(tf.matmul(x_data, A), b)

# Loss = mean(max(0, 1 - y * f(x))) + alpha * sum(A ** 2)
l2_norm = tf.reduce_sum(tf.square(A))
alpha = tf.constant([0.1])
classification_term = tf.reduce_mean(
    tf.maximum(0.0, tf.subtract(1.0, tf.multiply(model_output, y_target)))
)
loss = tf.add(classification_term, tf.multiply(alpha, l2_norm))

# Predicted class is the sign of f(x); accuracy is the fraction of matches.
prediction = tf.sign(model_output)
accuracy = tf.reduce_mean(
    tf.cast(tf.equal(prediction, y_target), tf.float32)
)

my_opt = tf.train.GradientDescentOptimizer(0.01)
train_step = my_opt.minimize(loss)

# tf.initialize_all_variables() is deprecated in favour of this call.
init = tf.global_variables_initializer()
sess.run(init)

# ---------------------------------------------------------------------------
# Training loop
# ---------------------------------------------------------------------------
loss_vec = []
train_accuracy = []
test_accuracy = []

# The evaluation feeds never change, so build them once outside the loop.
train_feed = {x_data: x_vals_train, y_target: np.transpose([y_vals_train])}
test_feed = {x_data: x_vals_test, y_target: np.transpose([y_vals_test])}

for i in range(500):
    # Mini-batch sampled with replacement from the training split.
    rand_index = np.random.choice(len(x_vals_train), size=batch_size)
    rand_x = x_vals_train[rand_index]
    rand_y = np.transpose([y_vals_train[rand_index]])
    batch_feed = {x_data: rand_x, y_target: rand_y}

    sess.run(train_step, feed_dict=batch_feed)
    temp_loss = sess.run(loss, feed_dict=batch_feed)
    loss_vec.append(temp_loss)

    # Training accuracy must be measured on the training split (the original
    # fed the test split here, so both curves were identical).
    train_accuracy_temp = sess.run(accuracy, feed_dict=train_feed)
    train_accuracy.append(train_accuracy_temp)

    test_accuracy_temp = sess.run(accuracy, feed_dict=test_feed)
    test_accuracy.append(test_accuracy_temp)

    if (i + 1) % 100 == 0:
        print('step #' + str(i + 1) + ' A = ' + str(sess.run(A))
              + ' \nb= ' + str(sess.run(b)))
        print('loss is ' + str(temp_loss))

# ---------------------------------------------------------------------------
# Decision boundary
# ---------------------------------------------------------------------------
[[a1], [a2]] = sess.run(A)
[[b1]] = sess.run(b)
# Nothing below needs the graph any more; release the session's resources.
sess.close()

# Boundary: a1 * x0 + a2 * x1 + b = 0, with x0 (column 0) on the plot's y axis
# and x1 (column 3) on its x axis, hence x0 = -(a2 / a1) * x1 - b / a1.
# The intercept is negative because the model ADDS b.
slope = -a2 / a1
y_intercept = -b1 / a1

x1_vals = [d[1] for d in x_vals]
best_fit = [x1 * slope + y_intercept for x1 in x1_vals]

setosa_x = [d[1] for idx, d in enumerate(x_vals) if y_vals[idx] == 1]
setosa_y = [d[0] for idx, d in enumerate(x_vals) if y_vals[idx] == 1]
not_setosa_x = [d[1] for idx, d in enumerate(x_vals) if y_vals[idx] == -1]
not_setosa_y = [d[0] for idx, d in enumerate(x_vals) if y_vals[idx] == -1]

# ---------------------------------------------------------------------------
# Plots
# ---------------------------------------------------------------------------
plt.plot(setosa_x, setosa_y, 'o', label='setosa')
plt.plot(not_setosa_x, not_setosa_y, 'x', label='Non-setosa')
plt.plot(x1_vals, best_fit, 'r-', label='Linear Separator', linewidth=3)
plt.ylim([0, 10])
plt.legend(loc='lower right')
plt.title('sepal length vs petal width')
plt.xlabel('petal width')
# The y axis shows column 0, the feature the title calls sepal length.
plt.ylabel('sepal length')
plt.show()

plt.plot(train_accuracy, 'k-', label='Training Accuracy')
plt.plot(test_accuracy, 'r--', label='Test Accuracy')
plt.title('Train and Test Set Accuracy')
# Without a legend the two labelled curves cannot be told apart.
plt.legend(loc='lower right')
plt.show()

plt.plot(loss_vec, 'k-')
plt.title('loss')
plt.show()