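Machine learning feature engineering: adding all combinations of any power to any set of attributes
In machine learning we often need to add columns derived from existing attributes. Say there are two attributes, x and y. When the outcome is roughly linear in them, plain linear regression gives a model easily. Often, though, a linear model (a first-degree polynomial in several variables) cannot fit the data.
In that case we try adding attributes derived from the given ones to see whether the fit improves. Starting from the two columns x and y, adding the degree-2 attributes gives x, y, x^2, x*y, y^2, five attributes in all, and experience tells us that a model over these five can fit a curve.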
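As a small illustration of the two-column case, the sketch below builds the five attributes by hand for one sample. The names `Degree2Features` and `expand` are made up for this example and are not part of the code later in the article.

```java
import java.util.Arrays;

class Degree2Features {
    // Returns {x, y, x^2, x*y, y^2} for one sample (hypothetical helper).
    static double[] expand(double x, double y) {
        return new double[] {x, y, x * x, x * y, y * y};
    }

    public static void main(String[] args) {
        // prints [2.0, 3.0, 4.0, 6.0, 9.0]
        System.out.println(Arrays.toString(expand(2, 3)));
    }
}
```

A linear model over these five columns is a quadratic in x and y, which is why it can follow a curve that a model over x and y alone cannot.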
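For power 2 the combinations are easy to write out by hand. But with 4 original columns and power 5, how many columns should be added, and how is each one computed? That gets tedious. For example, (x+y+z)^3 expands to x^3 + y^3 + z^3 + 3x^2y + 3x^2z + 3xy^2 + 3y^2z + 3xz^2 + 3yz^2 + 6xyz, which already has 10 distinct terms.
Let m be the number of columns and n the power. For example, with m = 2 and n = 2 the result is [{0,2}, {1,1}, {2,0}], meaning that three columns are appended: the first is x^0 * y^2, the second is x^1 * y^1, and the third is x^2 * y^0.
Looking at this, what we need is every solution of the equation X1 + X2 + X3 + ... + Xm = n, where each Xi is an integer with 0 <= Xi <= n.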
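A standard counting fact (stars and bars, not stated in the original text) gives the number of such solutions, which is a handy check on any enumeration:

```latex
\#\bigl\{(X_1,\dots,X_m)\in\mathbb{Z}_{\ge 0}^{m} : X_1+X_2+\dots+X_m=n\bigr\}=\binom{n+m-1}{m-1}
```

For m = 2 and n = 2 this is C(3,1) = 3, matching the three columns listed above; for m = 3 and n = 5 it is C(7,2) = 21.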
So the approach is to define an int[m]: an array of m elements, each ranging from 0 to n, and to keep every assignment whose elements sum to n.
Here is the program:
/**
 * @author wuweifeng wrote on 2018/6/4.
 */
public class LineAdder {
    private static int lines = 3;
    private static int power = 5;
    private static int[] resultArray;

    public static void main(String[] args) {
        resultArray = new int[lines];
        deal(0);
    }

    public static void deal(int m) {
        for (int i = 0; i <= power; i++) {
            resultArray[m] = i;
            if (m == lines - 1) {
                // last position filled: check whether this is a solution
                if (check()) {
                    print();
                    // for a fixed prefix only one last value can match, so stop
                    return;
                }
            } else {
                deal(m + 1);
            }
        }
    }

    /**
     * Checks whether the current array is a solution.
     *
     * @return true if the exponents sum to power
     */
    private static boolean check() {
        int total = 0;
        for (int one : resultArray) {
            total += one;
        }
        return power == total;
    }

    private static void print() {
        for (int one : resultArray) {
            System.out.print(one);
        }
        System.out.print("\n");
    }
}
The output is:
005
014
023
032
041
050
104
113
122
131
140
203
212
221
230
302
311
320
401
410
500
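These are all the exponent combinations for 3 columns and power 5. Next we rework the program so that it can process a text file row by row and append the new columns directly to each row.
Here is the code, in two classes: LineAdder generates the exponent combinations and TextDeal appends the columns to a CSV file.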
package ploy;
import java.util.ArrayList;
import java.util.List;
/**
* @author wuweifeng wrote on 2018/6/4.
*/
public class LineAdder {
private int lines = 3;
private int power = 5;
private List<int[]> resultList = new ArrayList<>();
private int[] resultArray;
public List<int[]> lineAdd(int lines, int power) {
resultArray = new int[lines];
this.lines = lines;
this.power = power;
deal(0);
return resultList;
}
private void deal(int m) {
for (int i = 0; i <= power; i++) {
resultArray[m] = i;
if (m == lines - 1) {
//last position filled: check whether this is a solution
if (check()) {
print();
return;
}
} else {
deal(m + 1);
}
}
}
/**
* Checks whether the current array is a solution.
*
* @return true if the exponents sum to power
*/
private boolean check() {
int total = 0;
for (int one : resultArray) {
total += one;
}
return power == total;
}
private void print() {
for (int one : resultArray) {
System.out.print(one);
}
System.out.print("\n");
int[] temp = new int[resultArray.length];
System.arraycopy(resultArray, 0, temp, 0, resultArray.length);
resultList.add(temp);
}
}
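LineAdder fills every position with each value from 0 to power and only checks the sum at the last position, so it can visit up to (power+1)^lines candidate arrays. One possible alternative, sketched here with a hypothetical class `PrunedLineAdder` that is not part of the original code, passes the remaining degree down the recursion so that the last exponent is forced and no candidate is ever rejected. It is meant to return the same arrays in the same order.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

class PrunedLineAdder {
    // Hypothetical alternative to LineAdder.lineAdd(lines, power).
    static List<int[]> lineAdd(int lines, int power) {
        List<int[]> out = new ArrayList<>();
        if (lines > 0) {
            fill(new int[lines], 0, power, out);
        }
        return out;
    }

    // Fills positions pos.. so that the exponents still to be chosen sum to 'remaining'.
    private static void fill(int[] current, int pos, int remaining, List<int[]> out) {
        if (pos == current.length - 1) {
            current[pos] = remaining; // the last exponent is forced
            out.add(current.clone());
            return;
        }
        for (int i = 0; i <= remaining; i++) {
            current[pos] = i;
            fill(current, pos + 1, remaining - i, out);
        }
    }

    public static void main(String[] args) {
        // Prints 21 arrays, from [0, 0, 5] to [5, 0, 0].
        for (int[] exps : lineAdd(3, 5)) {
            System.out.println(Arrays.toString(exps));
        }
    }
}
```

For small inputs such as 3 columns and power 5 the difference is negligible; the pruning matters mainly when many columns are expanded at once.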
package ploy;
import java.io.*;
import java.util.List;
/**
* @author wuweifeng wrote on 2018/6/5.
*/
public class TextDeal {
public static void main(String[] args) throws IOException {
new TextDeal().linePower("/Users/wuwf/Downloads/ml_data/1邏輯迴歸入門/train_test_deal.csv",
"/Users/wuwf/Downloads/ml_data/1邏輯迴歸入門/train_test_deal-3.csv", 3, 1, 2, 3, 6);
}
/**
* @param filePath
* path of the input file
* @param outputPath
* path of the output file
* @param power
* the power to expand to
* @param lineNums
* 0-based indexes of the columns to expand; when omitted, all columns are used
*/
public void linePower(String filePath, String outputPath, Integer power, Integer... lineNums) throws IOException {
BufferedReader reader = buildReader(filePath);
BufferedWriter writer = buildWriter(outputPath);
addCSVHeader(reader, writer, power, lineNums);
}
private Integer[] getLineNums(String[] lines, Integer... lineNums) {
//null or empty (no columns passed to the varargs parameter) means all columns
if (lineNums == null || lineNums.length == 0) {
lineNums = new Integer[lines.length];
for (int i = 0; i < lines.length; i++) {
lineNums[i] = i;
}
}
return lineNums;
}
private List<int[]> getAddList(int power, Integer... lineNums) {
LineAdder lineAdder = new LineAdder();
//work out which columns need to be added (one exponent array per new column)
return lineAdder.lineAdd(lineNums.length, power);
}
/**
* Appends the new column names to the header (the first line),
* then appends the computed values to every data row.
*/
private void addCSVHeader(BufferedReader reader, BufferedWriter writer, Integer power, Integer... lineNums)
throws IOException {
//read the first line
String header = reader.readLine();
//all column names
String[] lines = header.split(",");
lineNums = getLineNums(lines, lineNums);
//work out which columns need to be added
List<int[]> list = getAddList(power, lineNums);
String[] addLines = new String[list.size()];
String[] needLines = new String[lineNums.length];
for (int i = 0; i < lineNums.length; i++) {
needLines[i] = lines[lineNums[i]];
}
//build the name of each new column, e.g. a0b2 for a^0 * b^2
for (int i = 0; i < list.size(); i++) {
int[] array = list.get(i);
String s = "";
for (int j = 0; j < array.length; j++) {
s += needLines[j] + array[j];
}
addLines[i] = s;
}
for (String addLine : addLines) {
header += "," + addLine;
}
//write the extended header to the output file
writer.write(header);
writer.newLine();
writer.flush();
String oneLine;
while ((oneLine = reader.readLine()) != null) {
addLines = new String[list.size()];
lines = oneLine.split(",");
needLines = new String[lineNums.length];
for (int i = 0; i < lineNums.length; i++) {
needLines[i] = lines[lineNums[i]];
}
//compute the value of each new column
for (int i = 0; i < list.size(); i++) {
int[] array = list.get(i);
double s = 1;
try {
for (int j = 0; j < array.length; j++) {
//e.g. for columns a,b and exponents 02, the value is a^0 times b^2
s *= Math.pow(Double.valueOf(needLines[j]), array[j]);
}
addLines[i] = s + "";
} catch (Exception e) {
addLines[i] = "?";
}
}
for (String addLine : addLines) {
oneLine += "," + addLine;
}
writer.write(oneLine);
//end the row in the output file
writer.newLine();
}
//flush any rows still in the buffer
writer.flush();
//close the streams
reader.close();
writer.close();
}
private BufferedReader buildReader(String filePath) {
try {
return new BufferedReader(new FileReader(new File(filePath)));
} catch (FileNotFoundException e) {
e.printStackTrace();
return null;
}
}
private BufferedWriter buildWriter(String outputPath) {
//open the output file for writing
try {
return new BufferedWriter(new OutputStreamWriter(new FileOutputStream(outputPath), "utf-8"));
} catch (UnsupportedEncodingException | FileNotFoundException e) {
e.printStackTrace();
return null;
}
}
}
Suppose the CSV file looks like this:
a,b
1,2
2,3
4,5
After running with power 2 on columns 0 and 1, the result is:
a,b,a0b2,a1b1,a2b0
1,2,4.0,2.0,1.0
2,3,9.0,6.0,4.0
4,5,25.0,20.0,16.0
As you can see, the degree-2 expansion has been appended to every row.
This class can generate and compute the expansion for any power.
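To check a single row without touching the file system, the per-row computation can be reproduced in memory. The sketch below uses a hypothetical `RowExpander.expandRow` helper. It follows the same idea as the loop in `addCSVHeader`, but it only catches `NumberFormatException`, so rows with too few cells are not handled here.

```java
import java.util.Arrays;
import java.util.List;

class RowExpander {
    // Appends one product column per exponent array to a comma-separated row.
    // cols[j] is the 0-based column raised to exps[j] (hypothetical names).
    static String expandRow(String row, int[] cols, List<int[]> exponents) {
        String[] cells = row.split(",");
        StringBuilder sb = new StringBuilder(row);
        for (int[] exps : exponents) {
            String cell;
            try {
                double product = 1;
                for (int j = 0; j < exps.length; j++) {
                    product *= Math.pow(Double.parseDouble(cells[cols[j]]), exps[j]);
                }
                cell = String.valueOf(product);
            } catch (NumberFormatException e) {
                cell = "?"; // same placeholder the article's code writes for bad cells
            }
            sb.append(',').append(cell);
        }
        return sb.toString();
    }

    public static void main(String[] args) {
        List<int[]> degree2 = Arrays.asList(new int[] {0, 2}, new int[] {1, 1}, new int[] {2, 0});
        // prints 4,5,25.0,20.0,16.0
        System.out.println(expandRow("4,5", new int[] {0, 1}, degree2));
    }
}
```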
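Each call to lineAdd yields the terms of exactly one total degree. If the goal is every term from degree 2 up to some maximum, one option is to collect the results for each degree in turn. This is only a sketch under that reading: `AllDegrees`, `exact` and `upTo` are hypothetical names, and the enumeration is rewritten compactly rather than copied from LineAdder.

```java
import java.util.ArrayList;
import java.util.List;

class AllDegrees {
    // Adds every exponent array whose entries from pos onward sum to 'remaining'.
    static void exact(int[] cur, int pos, int remaining, List<int[]> out) {
        if (pos == cur.length - 1) {
            cur[pos] = remaining; // the last exponent is forced
            out.add(cur.clone());
            return;
        }
        for (int i = 0; i <= remaining; i++) {
            cur[pos] = i;
            exact(cur, pos + 1, remaining - i, out);
        }
    }

    // Hypothetical helper: all terms of total degree 2..maxPower over 'lines' columns.
    static List<int[]> upTo(int lines, int maxPower) {
        if (lines < 1) {
            throw new IllegalArgumentException("lines must be at least 1");
        }
        List<int[]> out = new ArrayList<>();
        for (int p = 2; p <= maxPower; p++) {
            exact(new int[lines], 0, p, out);
        }
        return out;
    }

    public static void main(String[] args) {
        // 4 columns, degrees 2 to 5: 10 + 20 + 35 + 56 terms, prints 121
        System.out.println(upTo(4, 5).size());
    }
}
```

Under this reading, the earlier question of how many columns to add for 4 original columns and power 5 has the answer 121.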