Hadoop學習筆記三 -- 決策樹演算法實現使用者風險等級分類
前言
剛剛過去的2016年被稱為人工智慧的元年,在AlphaGo大戰李世石取得里程碑式的勝利後,神經網路和深度學習的概念瞬間進入了人們的視野,各大商業巨頭也紛紛將自己的目標轉移到這個還沒有任何明確方向但所有人都知道它一旦出手將改變世界的人工智慧方向中。在這個過程中,人們也突然發現在過去幾年大資料儲存技術和硬體處理能力不斷髮展,而產出卻有限,主要是面對如此紛繁複雜的資料,人們卻不知道如何利用。答案就在那裡,卻不知道如何尋找答案。所以資料探勘、機器學習的演算法的學習和研究又成了高度熱門的話題。本文繼上一篇部落格中研究的KNN演算法,對機器學習中另一個比較簡單的演算法 – 決策樹演算法進行學習和研究。KNN演算法是基於節點之間的歐式距離進行分類,演算法簡單易懂,比較大的缺陷是計算量比較大而且無法給出資料的內在含義,而決策樹演算法相對而言在資料內在含義方面有比較大的優勢,得到的結果也容易在業務上被理解。
決策樹演算法
決策樹演算法的規則跟人腦決策非常相似,通過一系列IF-ELSE的問題進行決策實現最終的分類。以下是一個極簡單的決策樹例子。
決策樹演算法執行的過程也是決策樹構造的過程,面對龐雜的資料,在構造決策樹時,需要解決的第一個問題就是當前資料集上哪個特徵在劃分資料分類時起決定性作用。如在上一個部落格中使用者風險等級劃分的案例,使用者有股票、基金及貴金屬投資,理財產品投資,存款機貨幣市場投資三個方面的資料,而實際的商業使用者有更多維度的資料,我們必須找到決定性的特徵,才能劃分最好的結果,所以我們必須評估每個特徵的重要性。在找到第一個決策點後,整個資料集就會被劃分成幾個分支,接下來再檢查這幾個分支下的資料是否屬於同一類,如果是同一類資料,則停止劃分,如果不屬於同一類資料,則需要繼續尋找決策點,建立分支的虛擬碼如下:
檢測資料集中每個子項是否屬於同一分類:
If so return 類標籤;
Else
尋找劃分資料集的最好特徵
劃分資料集
建立分支節點
For 每個劃分的子集
迭代並增加返回結果到分支節點中
return 分支節點
資訊增益
劃分資料集的最大原則是將無序的資料變得更加有序,在劃分資料集之前和之後資訊發生的變化稱為資訊增益,在計算完每個特徵值劃分資料集獲得的資訊增益後,獲得資訊增益最高的特徵就是最好的選擇。而集合資訊的度量方式稱為夏農熵。夏農熵的計算公式為
在MapReduce中實現每個維度資訊增益的計算。
public class CalcShannonEntMapper extends
Mapper<LongWritable, Text, Text, AttributeWritable> {
@Override
protected void setup(Context context) throws IOException,
InterruptedException {
super.setup(context);
}
@Override
protected void map(LongWritable key, Text value, Context context)
throws IOException, InterruptedException {
String line = value.toString();
StringTokenizer tokenizer = new StringTokenizer(line);
Long id = Long.parseLong(tokenizer.nextToken());
String category = tokenizer.nextToken();
boolean isCategory = true;
while (tokenizer.hasMoreTokens()) {
isCategory = false;
String attribute = tokenizer.nextToken();
String[] entry = attribute.split(":");
context.write(new Text(entry[0]), new AttributeWritable(id,
category, entry[1]));
}
if (isCategory) {
context.write(new Text(category), new AttributeWritable(id,
category, category));
}
}
@Override
protected void cleanup(Context context) throws IOException,
InterruptedException {
super.cleanup(context);
}
}
public class CalcShannonEntReducer extends
Reducer<Text, AttributeWritable, Text, AttributeGainWritable> {
@Override
protected void setup(Context context) throws IOException,
InterruptedException {
super.setup(context);
}
@Override
protected void reduce(Text key, Iterable<AttributeWritable> values,
Context context) throws IOException, InterruptedException {
String attributeName = key.toString();
double totalNum = 0.0;
Map<String, Map<String, Integer>> attrValueSplits = new HashMap<String, Map<String, Integer>>();
Iterator<AttributeWritable> iterator = values.iterator();
boolean isCategory = false;
while (iterator.hasNext()) {
AttributeWritable attribute = iterator.next();
String attributeValue = attribute.getAttributeValue();
if (attributeName.equals(attributeValue)) {
isCategory = true;
break;
}
Map<String, Integer> attrValueSplit = attrValueSplits
.get(attributeValue);
if (null == attrValueSplit) {
attrValueSplit = new HashMap<String, Integer>();
attrValueSplits.put(attributeValue, attrValueSplit);
}
String category = attribute.getCategory();
Integer categoryNum = attrValueSplit.get(category);
attrValueSplit.put(category, null == categoryNum ? 1
: categoryNum + 1);
totalNum++;
}
if (isCategory) {
System.out.println("is Category");
int sum = 0;
iterator = values.iterator();
while (iterator.hasNext()) {
iterator.next();
sum += 1;
}
System.out.println("sum: " + sum);
context.write(key, new AttributeGainWritable(attributeName, sum,
true, null));
} else {
double gainInfo = 0.0;
double splitInfo = 0.0;
for (Map<String, Integer> attrValueSplit : attrValueSplits.values()) {
double totalCategoryNum = 0;
for (Integer categoryNum : attrValueSplit.values()) {
totalCategoryNum += categoryNum;
}
double entropy = 0.0;
for (Integer categoryNum : attrValueSplit.values()) {
double p = categoryNum / totalCategoryNum;
entropy -= p * (Math.log(p) / Math.log(2));
}
double dj = totalCategoryNum / totalNum;
gainInfo += dj * entropy;
splitInfo -= dj * (Math.log(dj) / Math.log(2));
}
double gainRatio = splitInfo == 0.0 ? 0.0 : gainInfo / splitInfo;
StringBuilder splitPoints = new StringBuilder();
for (String attrValue : attrValueSplits.keySet()) {
splitPoints.append(attrValue).append(",");
}
splitPoints.deleteCharAt(splitPoints.length() - 1);
context.write(key, new AttributeGainWritable(attributeName,
gainRatio, false, splitPoints.toString()));
}
}
@Override
protected void cleanup(Context context) throws IOException,
InterruptedException {
super.cleanup(context);
}
}
實驗
我們還是用上一篇部落格中使用者風險等級分類的例子中的資料,去測試決策樹演算法的優劣,但由於決策樹演算法只能對是或者否進行判斷,所以,對案例中的資料進行了改造,示例如下:
使用者 | 股票、基金及貴金屬投資 | 理財產品投資 | 存款及貨幣市場投資 | 風險分類 |
---|---|---|---|---|
1 | 1 | 1 | 1 | high |
2 | 0 | 1 | 1 | middle |
3 | 0 | 0 | 1 | low |
把一組已經打好標籤的資料作為訓練資料,另一組沒打標籤的資料作為測試資料,測試的結果如下:
實驗結果非常具有可讀性,也符合業務的常理,但是由於決策樹演算法只能輸入0-1資料,運算結果的錯誤率為13.6%,相對KNN來說是錯誤率是提高了,要進一步降低錯誤率,可以增加判斷的維度,比如對於理財產品來說,有不同型別的理財產品,可以依據理財產品的型別增加幾個維度等。