【學習排序】 Learning to Rank 中Listwise關於ListNet演算法講解及實現
阿新 • • 發佈:2019-02-05
程式碼如下:
package listNet_xiuzhang; import java.io.BufferedReader; import java.io.File; import java.io.FileInputStream; import java.io.FileWriter; import java.io.InputStreamReader; public class listNet { //檔案總行數(標記數) private static int sumLabel; //特徵值 46個 (標號1-46) private static double feature[][] = new double[100000][48]; //特徵值權重 46個 (標號1-46) private static double weight [] = new double[48]; //相關度 其值有0-2三個級別 從1開始記錄 private static int label [] = new int[1000000]; //查詢id 從1開始記錄 private static int qid [] = new int[1000000]; //每個Qid的doc數量 private static int doc_ofQid[] = new int[100000]; private static int ITER_NUM=30; //迭代次數 private static int weidu=46; //特徵數 private static int qid_Num=0; //Qid數量 private static int tempQid=-1; //臨時Qid數 private static int tempDoc=0; //臨時doc數 /** * 函式功能 讀取檔案 * 引數 String filePath 檔案路徑 */ public static void ReadTxtFile(String filePath) { try { String encoding="GBK"; File file=new File(filePath); if(file.isFile() && file.exists()) { //判斷檔案是否存在 InputStreamReader read = new InputStreamReader(new FileInputStream(file), encoding); BufferedReader bufferedReader = new BufferedReader(read); String lineTxt = null; sumLabel =1; //初始化從1記錄 //按行讀取資料並分解資料 while((lineTxt = bufferedReader.readLine()) != null) { String str = null; int lengthLine = lineTxt.length(); //獲取資料 字串空格分隔 String arrays[] = lineTxt.split(" "); for(int i=0; i<arrays.length; i++) { //獲取每行樣本的Label值 if(i==0) { label[sumLabel] = Integer.parseInt(arrays[0]); } else if(i>=weidu+2){ //讀取至#跳出 0-label 1-qid 2:47-特徵 continue; } else { String subArrays[] = arrays[i].split(":"); //特徵:特徵值 if(i==1) { //獲取qid //判斷是否是新的Qid if(tempQid != Integer.parseInt(subArrays[1])) { if(tempQid != -1){ //不是第一次出現新Qid //賦值上一個為qid_Num對應的tempDoc個文件 doc_ofQid[qid_Num]=tempDoc; tempDoc=0; } //當tempQid不等於當前qid時下標加1 //相等則直接跳至Doc加1直到不等 qid_Num++; tempQid=Integer.parseInt(subArrays[1]); } tempDoc++; //新的文件 qid[sumLabel] = Integer.parseInt(subArrays[1]); } else { //獲取46維特徵值 int number = Integer.parseInt(subArrays[0]); //判斷特徵 double value = Double.parseDouble(subArrays[1]); feature[sumLabel][number] = value; //number陣列標號:1-46 } } } sumLabel++; } doc_ofQid[qid_Num]=tempDoc; read.close(); } else { System.out.println("找不到指定的檔案\n"); } } catch (Exception e) { System.out.println("讀取檔案內容出錯"); e.printStackTrace(); } } /** * 學習排序 * 訓練模型得到46維權重 */ public static void LearningToRank() { //變數 double index [] = new double[1000000]; double tao [] = new double[1000000]; double yita=0.00003; //初始化 for(int i=0;i<weidu+2;i++) { //從1到136為權重,0和137無用 weight[i] = (double) 1.0; //權重初值 } System.out.println("training..."); //計算權重 學習演算法 for(int iter = 0; iter<ITER_NUM; iter++) //迭代ITER_NUM次 { System.out.println("---迭代次數:"+iter); int now_doc=0; //全域性文件索引 for(int i=1; i<=qid_Num; i++) //總樣qid數 相當於兩層迴圈T和m { double delta_w[] = new double[weidu+2]; //46個梯度組成的向量 int doc_of_i=doc_ofQid[i]; //該Qid的文件數 //得分f(w),一個QID有多個文件,一個文件為一個分,所以一個i對應一個分數陣列 double fw[] = new double[doc_of_i+2]; /* 第一步 算得分陣列fw fin */ for(int k=1;k<=doc_of_i;k++) { //初始化 fw[k]=0.0; } for(int k=1;k<=doc_of_i;k++) { //每個文件的得分 for(int p=1;p<=weidu;p++) { fw[k]=fw[k]+weight[p]*feature[now_doc+k][p]; //算出這個文件的分數 } } /* * 第二步 算梯度delta_w向量 * a=Σp*x,a是向量 * b=Σexpf(x),b是數字 * c=expf(x)*x,c是向量 * 最終結果delta_w是向量 */ double[] a=new double[weidu+2],c=new double[weidu+2]; for(int k=0;k<weidu+2;k++){a[k]=0.0;} //初始化 for(int k=0;k<weidu+2;k++){c[k]=0.0;} //初始化 double b=0.0; //算a:---- for(int k=1; k<=doc_of_i; k++) { double p=1.0; //先不topK double[] temp=new double[48]; for(int q=1;q<=weidu;q++) { //算P: ----第q個向量排XX的概率是多少 //分母: double fenmu=0.0; for(int m=1;m<=doc_of_i;m++) { fenmu=fenmu+Math.exp(fw[m]); //所有文件得分 } //top-1 exp(s1) / exp(s1)+exp(s2)+..+exp(sn) for(int m=1;m<=doc_of_i;m++) { p=p*(Math.exp(fw[m])/fenmu); } //算積 temp[q]=temp[q]+p*feature[now_doc+k][q]; } for(int q=1; q<=weidu; q++){ a[q]=a[q]+temp[q]; } } //End a //算b:---- fin. for(int k=1; k<=doc_of_i; k++){ b=b+Math.exp(fw[k]); } //算c:---- for(int k=1; k<=doc_of_i; k++){ double[] temp=new double[weidu+2]; for(int q=1; q<=weidu; q++){ temp[q]=temp[q]+Math.exp(fw[k])*feature[now_doc+k][q]; } for(int q=1; q<=weidu; q++){ c[q]=c[q]+temp[q]; } } //算梯度:delta_x=-a+1/b*c for(int q=1; q<=weidu; q++){ delta_w[q]= (-1)*a[q] + ((1.0/b)*c[q]); } //********** /* 第三步 更新權重 fin. */ for(int k=1; k<=weidu; k++){ weight[k]=weight[k]-yita*delta_w[k]; } now_doc=now_doc+doc_of_i; //更新當前文件索引 } } //End 迭代次數 //輸出權重 for(int i=1;i<=weidu;i++) //從1到136為權重,0和137無用 { System.out.println(i+"wei:"+weight[i]); } } /** * 輸出權重到檔案fileModel * @param fileModel */ public static void WriteFileModel(String fileModel) { //輸出權重到檔案 try { System.out.println("write start.總行數:"+sumLabel); FileWriter fileWriter = new FileWriter(fileModel); //寫資料 fileWriter.write("## ListNet"); fileWriter.write("\r\n"); fileWriter.write("## Epochs = "+ITER_NUM); fileWriter.write("\r\n"); fileWriter.write("## No. of features = 46"); fileWriter.write("\r\n"); fileWriter.write("1 2 3 4 5 6 7 8 9 10 ... 39 40 41 42 43 44 45 46"); fileWriter.write("\r\n"); fileWriter.write("0"); fileWriter.write("\r\n"); for(int k=0; k<weidu; k++){ fileWriter.write("0 "+k+" "+weight[k+1]); fileWriter.write("\r\n"); } fileWriter.close(); System.out.println("write fin."); } catch(Exception e) { System.out.println("寫檔案內容出錯"); e.printStackTrace(); } } /** * 預測排序 * 正規應對test.txt檔案進行打分排序 * 但我們是在Hadoop實現該打分排序步驟 此函式僅測試train.txt打分 */ public static void PredictRank(String fileScore) { //輸出得分 try { System.out.println("write start.總行數:"+sumLabel); String encoding = "GBK"; FileWriter fileWriter = new FileWriter(fileScore); //寫資料 for(int k=1; k<sumLabel; k++){ double score=0.0; for(int j=1;j<=weidu;j++){ score=score+weight[j]*feature[k][j]; } fileWriter.write("qid:"+qid[k]+" score:"+score+" label:"+label[k]); fileWriter.write("\r\n"); } fileWriter.close(); System.out.println("write fin."); } catch(Exception e) { System.out.println("寫檔案內容出錯"); e.printStackTrace(); } } /** * 主函式 */ public static void main(String args[]) { String fileInput = "Fold1\\train.txt"; //訓練 String fileModel = "model_weight.txt"; //輸出權重模型 String fileScore = "score_listNet.txt"; //輸出得分 //第1步 讀取檔案並解析資料 System.out.println("read..."); ReadTxtFile(fileInput); System.out.println("read and write well."); //第2步 排序計算 LearningToRank(); //第3步 輸出模型 WriteFileModel(fileModel); //第4步 打分預測排序 PredictRank(fileScore); } /* * End */ }