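# 感知器實現鳶尾花的分類 (Iris classification with a perceptron)
# 阿新 • 發佈:2019-04-19 (author: 阿新, published 2019-04-19)
# Page tags (truncated by extraction): scalar param where self. spa min break pyplot con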
import numpy as np
from sklearn.datasets import load_iris
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib as mpl
class perceptron:
    """Single-layer perceptron for binary (0/1) classification.

    Weights are learned with the perceptron rule: for each sample the
    weights move by ``eta * (target - prediction)`` times the input, and
    the bias moves by ``eta * (target - prediction)``.

    Attributes
    ----------
    eta : float
        Learning rate.
    epoch : int
        Number of passes over the training data.
    w_ : ndarray of shape (n_features + 1,)
        Set by ``fit``. ``w_[0]`` is the bias, ``w_[1:]`` the feature weights.
    errors_ : list of int
        Set by ``fit``. Number of weight updates made in each epoch.
    """

    def __init__(self, eta, epoch):
        """Store the hyper-parameters.

        Parameters
        ----------
        eta : float
            Learning rate.
        epoch : int
            Number of passes over the training data.
        """
        self.eta = eta
        self.epoch = epoch

    def step(self, z):
        """Apply the unit step activation.

        Parameters
        ----------
        z : array_like or scalar
            Net input.

        Returns
        -------
        ndarray
            1 where ``z >= 0`` and 0 elsewhere (0-d array for a scalar).
        """
        # asarray lets plain Python lists be compared element-wise.
        return np.where(np.asarray(z) >= 0, 1, 0)

    def fit(self, X, y):
        """Learn the weights from labelled samples.

        Parameters
        ----------
        X : array_like of shape (n_samples, n_features)
            Training inputs.
        y : array_like with n_samples entries
            Target labels (0 or 1).

        Returns
        -------
        ndarray of shape (n_features + 1,)
            The learned weights, bias first (same object as ``self.w_``).

        Raises
        ------
        ValueError
            If ``X`` is not 2-D or ``X`` and ``y`` disagree on sample count.
        """
        X, y = np.asarray(X), np.asarray(y)
        if X.ndim != 2:
            raise ValueError(
                "X must be 2-D (n_samples, n_features), got ndim=%d" % X.ndim
            )
        # zip() below would silently drop the surplus samples otherwise.
        if y.ndim == 0 or y.shape[0] != X.shape[0]:
            raise ValueError("X and y must contain the same number of samples")

        self.w_ = np.zeros(X.shape[1] + 1)
        self.errors_ = []
        for _ in range(self.epoch):
            errors = 0
            for xi, target in zip(X, y):
                # update is zero when the sample is already classified correctly
                update = self.eta * (target - self.predict(xi))
                self.w_[1:] += update * xi
                self.w_[0] += update
                errors += int(update != 0.0)
            # No early stop: errors_ always holds exactly `epoch` entries.
            self.errors_.append(errors)
        return self.w_

    def predict(self, X):
        """Return the predicted label(s) for ``X``.

        Parameters
        ----------
        X : array_like
            One sample of shape (n_features,) or a batch of shape
            (n_samples, n_features).

        Returns
        -------
        ndarray
            Predicted 0/1 label(s); requires ``fit`` to have set ``w_``.
        """
        z = np.dot(X, self.w_[1:]) + self.w_[0]
        return self.step(z)
if __name__ == '__main__':
    # Train the perceptron on two Iris classes (labels 0 and 1), report the
    # accuracy on a held-out 20 % and plot the per-epoch update count.

    # return_X_y is passed by keyword: recent scikit-learn releases reject
    # the positional form load_iris(True).
    X, y = load_iris(return_X_y=True)

    # Put features and label in one frame so duplicate rows can be dropped
    # together; column 4 holds the label.
    data = pd.DataFrame(np.concatenate((X, y.reshape(-1, 1)), axis=1))
    data = data.drop_duplicates()
    data = data[data[4] != 2]  # keep only classes 0 and 1

    # Only feature columns 1, 2 and 3 are used as inputs.
    features, labels = data[[1, 2, 3]].values, data[4].values

    # The rows come ordered by class, so an unshuffled 80/20 split would
    # leave a test set containing a single class. Shuffle with a fixed seed
    # to keep runs reproducible.
    order = np.random.RandomState(0).permutation(len(features))
    features, labels = features[order], labels[order]

    n_train = int(len(data) * 0.8)
    train_x, train_label = features[:n_train, :], labels[:n_train]
    test_x, test_label = features[n_train:, :], labels[n_train:]

    p = perceptron(0.1, 10)
    w = p.fit(train_x, train_label)
    print(w)
    print()

    # predict() applies the learned weights and the step function.
    out = p.predict(test_x)
    right = np.sum(out == test_label)
    print(str(right * 100 / len(test_label)) + '%')

    # SimHei is requested so the Chinese title renders; unicode_minus is
    # disabled so minus signs still display with that font.
    mpl.rcParams['font.family'] = 'SimHei'
    mpl.rcParams['axes.unicode_minus'] = False
    plt.plot(p.errors_, 'o-')
    plt.title('感知器實現鳶尾花的二分類')
    plt.show()
# 感知器實現鳶尾花的分類