1. 程式人生 > 其它 >用隨機森林做分類

用隨機森林做分類

# Random-forest classification on FFT amplitude features.
#
# Pipeline: read train.csv, turn every row into an FFT amplitude spectrum,
# keep only the frequency bins that are "active" in at least one training
# row, fit a random forest on those bins, report accuracy on the training
# rows and on the last HOLDOUT_ROWS rows, then predict test.csv -> out.csv.

import numpy as np
import pandas as pd
import torch
import torch.fft as fft
from torch import nn

# Divisor used to normalise FFT magnitudes (the original script hard-coded
# 240; presumably the number of samples per row -- confirm against the data).
SIGNAL_LENGTH = 240
# A frequency bin is kept when its scaled magnitude reaches this value.
MAGNITUDE_THRESHOLD = 0.2
# Number of trailing rows of train.csv held out for validation.
HOLDOUT_ROWS = 20
# ID written for the first row of test.csv; later rows count up from it.
# NOTE(review): presumably continues the training IDs -- verify.
FIRST_TEST_ID = 210


def amplitude_spectrum(signals, n_samples=SIGNAL_LENGTH):
    """Return the scaled FFT magnitude of every row of *signals*.

    Args:
        signals: 2-D array-like of real values, one signal per row.
        n_samples: normalisation divisor; magnitudes are computed as
            ``abs(fft(row)) / n_samples * 2``.

    Returns:
        A float32 NumPy array with the same shape as *signals*.
    """
    # float32 matches what ``torch.Tensor(ndarray)`` produced originally, so
    # values sitting near the threshold are classified exactly as before.
    tensor = torch.as_tensor(np.asarray(signals, dtype=np.float32))
    magnitude = torch.abs(torch.fft.fft(tensor)) / n_samples * 2
    return magnitude.numpy()


def select_active_bins(spectrum, threshold=MAGNITUDE_THRESHOLD):
    """Return the bins whose magnitude reaches *threshold* in any row.

    Bins are listed in order of first appearance when rows are scanned top
    to bottom and bins left to right, which fixes the column order of the
    feature matrix handed to the classifier.

    Args:
        spectrum: 2-D array-like of magnitudes, one row per signal.
        threshold: minimum magnitude for a bin to count as active.

    Returns:
        A list of unique integer bin indices (possibly empty).
    """
    mask = np.atleast_2d(np.asarray(spectrum)) >= threshold
    # np.nonzero walks the mask in row-major order; dict.fromkeys then drops
    # repeats while keeping the first occurrence, without a quadratic
    # ``x not in list`` scan.
    _, columns = np.nonzero(mask)
    return list(dict.fromkeys(columns.tolist()))


def main():
    """Train on train.csv, print both accuracies, and write out.csv."""
    # Imported here so the helpers above stay importable on machines that
    # do not have scikit-learn installed.
    from sklearn.ensemble import RandomForestClassifier
    from sklearn.metrics import accuracy_score

    # Last column is the label; everything before it is the raw signal.
    train = pd.read_csv('train.csv').drop(['ID'], axis=1).to_numpy()
    spectrum = amplitude_spectrum(train[:, :-1])
    labels = train[:, -1]  # kept 1-D: sklearn warns on column-vector targets

    # The FFT is taken per row, so transforming once and then splitting
    # gives the same numbers as transforming each split separately.
    fit_spectrum = spectrum[:-HOLDOUT_ROWS]
    holdout_spectrum = spectrum[-HOLDOUT_ROWS:]
    fit_labels = labels[:-HOLDOUT_ROWS]
    holdout_labels = labels[-HOLDOUT_ROWS:]

    # Bins are chosen from the fitting rows only, never from held-out rows.
    bins = select_active_bins(fit_spectrum)
    print(bins)
    print(len(bins))

    # Original author's note: accuracy was about 0.83, "not very good".
    clf = RandomForestClassifier(n_estimators=2000, max_depth=8)
    clf.fit(fit_spectrum[:, bins], fit_labels)
    print(accuracy_score(fit_labels, clf.predict(fit_spectrum[:, bins])))
    print(accuracy_score(holdout_labels,
                         clf.predict(holdout_spectrum[:, bins])))

    # test.csv has no label column, so every remaining column is signal.
    test = pd.read_csv('test.csv').drop(['ID'], axis=1).to_numpy()
    predictions = clf.predict(amplitude_spectrum(test)[:, bins])
    out = pd.DataFrame({
        'ID': np.arange(len(predictions)) + FIRST_TEST_ID,
        'CLASS': predictions,
    })
    out[['ID', 'CLASS']].to_csv('out.csv', index=False)


if __name__ == "__main__":
    # Not used by the pipeline above; retained from the original script so
    # that any later code relying on these names keeps working.
    import matplotlib.pyplot as plt
    import sklearn
    from sklearn import svm
    from sklearn import tree

    main()