/*
 * 資料探勘：Apriori 演算法的 Java 實現
 * 阿新 · 發佈於 2021-01-18
 *
 * 實驗結果：
 * 最小支援度為 0.005 時，只需約 10 秒即可得到結果，並輸出所有頻繁項集。
 *
 * 程式碼：
 */
package com.company;
import java.io.BufferedReader;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.nio.charset.StandardCharsets;
import java.util.*;
/**
 * Apriori frequent-itemset miner over a transaction file with one transaction per line and
 * whitespace-separated, non-negative integer item ids.
 *
 * <p>The minimum support is read from standard input as a fraction of the number of
 * transactions. Every frequent itemset is printed level by level with its relative support
 * (fraction of transactions containing it), followed by the running total and the elapsed time.
 *
 * <p>Level-2 candidates are pre-filtered with a hashed pair-count table: a pair is generated
 * only if its bucket count reaches the minimum count. Bucket counts can only over-estimate a
 * pair's true count, so no frequent pair is lost.
 *
 * <p>Itemsets in {@link #ck} hold their items in ascending order. While a level is being
 * counted, each list carries one extra trailing element holding its support count.
 */
public class Main {
    /** Transactions, each sorted ascending with duplicate items removed. */
    private static final ArrayList<int[]> data = new ArrayList<>();
    /** Distinct item ids (as strings) mapped to the number of transactions containing them; only its size is used. */
    private static final HashMap<String, Integer> fir = new HashMap<>();
    /** Minimum support as a fraction of the number of transactions, as read from stdin. */
    private static double SUPPORT_PERCENT;
    /** Minimum support as an absolute transaction count ({@code SUPPORT_PERCENT * data.size()}). */
    private static double minCount;
    /** Per-item support counts, indexed by item id. */
    private static int[] a;
    /** Hashed pair counts used to pre-filter level-2 candidates; see {@link #pairBucket}. */
    private static int[] b;
    /** Number of distinct items. */
    private static int Firlen;
    /** Total number of item occurrences over all transactions; the pair hash is taken modulo {@code Datalen + 1}. */
    private static int Datalen;
    /** Largest item id seen while loading, or -1 if no items were read. */
    private static int maxItem = -1;
    /** Running total of frequent itemsets printed so far. */
    private static int sum = 0;
    /** Current level: candidates (with a trailing count) while counting, frequent itemsets after {@link #delete()}. */
    private static List<List<Integer>> ck = new ArrayList<>();

    /**
     * Entry point: reads the minimum support fraction from stdin, mines the input file and prints
     * all frequent itemsets.
     *
     * @param args optional; {@code args[0]} overrides the input file (default {@code retail.dat})
     * @throws IOException if the input file cannot be read or contains a negative item id
     */
    public static void main(String[] args) throws IOException {
        // Fixed locale so "0.005" parses regardless of the platform's decimal separator.
        // System.in is deliberately left open.
        Scanner s = new Scanner(System.in).useLocale(Locale.ROOT);
        SUPPORT_PERCENT = s.nextDouble();
        long startTime = System.currentTimeMillis();
        loadData(args.length > 0 ? args[0] : "retail.dat");
        minCount = SUPPORT_PERCENT * data.size();
        // Item ids are used directly as indices, so size by the largest id, not the distinct count.
        a = new int[maxItem + 1];
        // Bucket indices are reduced modulo (Datalen + 1), so exactly that many buckets exist.
        b = new int[Datalen + 1];
        Countandhash();
        com();
        showfir();
        while (ck.size() > 1) {
            link();
            if (ck.isEmpty())
                break;
            tradata();
            showitem();
            delete();
            System.out.println("總數:" + sum);
        }
        long endTime = System.currentTimeMillis();
        System.out.println("程式執行時間:" + (endTime - startTime) + "ms");
    }

    /**
     * Reads transactions from {@code path} (UTF-8), one per line.
     *
     * <p>Each transaction is stored sorted ascending with duplicates removed, because
     * {@link #tradata()} relies on sorted transactions. Blank lines are skipped and are not
     * counted as transactions.
     *
     * @throws IOException if the file cannot be read or an item id is negative
     */
    private static void loadData(String path) throws IOException {
        try (BufferedReader reader = new BufferedReader(
                new InputStreamReader(new FileInputStream(path), StandardCharsets.UTF_8))) {
            String line;
            while ((line = reader.readLine()) != null) {
                int[] items = parseTransaction(line);
                if (items.length == 0)
                    continue;
                for (int item : items) {
                    if (item < 0)
                        throw new IOException("Negative item id " + item + " in line: " + line);
                    String key = Integer.toString(item);
                    Integer seen = fir.get(key);
                    fir.put(key, seen == null ? 1 : seen + 1);
                    maxItem = Math.max(maxItem, item);
                }
                Datalen += items.length;
                data.add(items);
            }
        }
        Firlen = fir.size();
    }

    /** Parses one line into a sorted, duplicate-free array of item ids; a blank line yields an empty array. */
    private static int[] parseTransaction(String line) {
        String trimmed = line.trim();
        if (trimmed.isEmpty())
            return new int[0];
        String[] tokens = trimmed.split("\\s+");
        int[] items = new int[tokens.length];
        for (int i = 0; i < tokens.length; i++) {
            items[i] = Integer.parseInt(tokens[i]);
        }
        Arrays.sort(items);
        int unique = 0;
        for (int i = 0; i < items.length; i++) {
            if (unique == 0 || items[i] != items[unique - 1])
                items[unique++] = items[i];
        }
        return Arrays.copyOf(items, unique);
    }

    /**
     * Counts every item into {@link #a}, and every pair of items co-occurring in a transaction
     * into its hash bucket in {@link #b}.
     */
    public static void Countandhash() {
        for (int[] array : data) {
            for (int i = 0; i < array.length - 1; i++) {
                for (int j = i + 1; j < array.length; j++) {
                    b[pairBucket(array[i], array[j])] += 1;
                }
            }
            for (int j : array) {
                a[j] += 1;
            }
        }
    }

    /**
     * Hash bucket of the item pair (x, y), in the range [0, Datalen]. Computed in {@code long}
     * so that large item ids cannot overflow into a negative index.
     */
    private static int pairBucket(int x, int y) {
        return (int) (((long) x * (Firlen - 1) + y) % (Datalen + 1));
    }

    /** Seeds {@link #ck} with each single item whose count reaches the minimum count, in ascending id order. */
    public static void com() {
        for (int item = 0; item < a.length; item++) {
            if (a[item] >= minCount) {
                List<Integer> single = new ArrayList<>();
                single.add(item);
                ck.add(single);
            }
        }
    }

    /** Prints the frequent 1-itemsets with their relative support and adds their number to {@link #sum}. */
    public static void showfir() {
        int count = 0;
        System.out.println("1項集如下:");
        for (int i = 0; i < a.length; i++) {
            if (a[i] >= minCount) {
                count++;
                System.out.println("{" + i + "}:" + support(a[i]));
            }
        }
        System.out.println("1項集個數:" + count);
        sum += count;
    }

    /** Relative support: the fraction of transactions that contain an itemset with this count. */
    private static double support(int count) {
        return (double) count / data.size();
    }

    /**
     * Replaces {@link #ck} (frequent k-itemsets, ascending and lexicographically ordered) with the
     * (k+1)-candidates, each carrying a trailing zero count.
     *
     * <p>For k = 1, pairs are kept only if their hash bucket reaches the minimum count. For k ≥ 2,
     * itemsets sharing their first k-1 items are joined, and a candidate is kept only if every one
     * of its k-subsets is frequent.
     */
    public static void link() {
        List<List<Integer>> candidates = new ArrayList<>();
        if (ck.get(0).size() == 1) {
            for (int i = 0; i < ck.size() - 1; i++) {
                int x = ck.get(i).get(0);
                for (int j = i + 1; j < ck.size(); j++) {
                    int y = ck.get(j).get(0);
                    if (b[pairBucket(x, y)] >= minCount) {
                        candidates.add(new ArrayList<>(Arrays.asList(x, y, 0)));
                    }
                }
            }
            ck = candidates;
            return;
        }
        int k = ck.get(0).size();
        Set<List<Integer>> frequent = new HashSet<>(ck);
        for (int i = 0; i < ck.size() - 1; i++) {
            List<Integer> left = ck.get(i);
            List<Integer> prefix = left.subList(0, k - 1);
            for (int j = i + 1; j < ck.size(); j++) {
                List<Integer> right = ck.get(j);
                if (!prefix.equals(right.subList(0, k - 1)))
                    continue;
                List<Integer> candidate = new ArrayList<>(left);
                candidate.add(right.get(k - 1));
                // Apriori property: a frequent itemset cannot have an infrequent subset.
                if (allSubsetsFrequent(candidate, frequent)) {
                    candidate.add(0);
                    candidates.add(candidate);
                }
            }
        }
        ck = candidates;
    }

    /** True if every subset obtained by removing exactly one item from {@code candidate} is in {@code frequent}. */
    private static boolean allSubsetsFrequent(List<Integer> candidate, Set<List<Integer>> frequent) {
        for (int drop = 0; drop < candidate.size(); drop++) {
            List<Integer> subset = new ArrayList<>(candidate);
            subset.remove(drop);
            if (!frequent.contains(subset))
                return false;
        }
        return true;
    }

    /**
     * Increments each candidate's trailing count once for every transaction containing all of
     * its items. Relies on transactions and candidates both being sorted ascending.
     */
    public static void tradata() {
        for (int[] datum : data) {
            if (datum.length == 0)
                continue;
            int largest = datum[datum.length - 1];
            for (List<Integer> candidate : ck) {
                int size = candidate.size() - 1; // the last slot holds the count
                if (candidate.get(size - 1) > largest)
                    continue; // the candidate's largest item cannot be in this transaction
                int n = 0;
                for (int item : datum) {
                    int want = candidate.get(n);
                    if (want < item)
                        break; // the transaction skipped past a required item
                    if (want == item && ++n == size) {
                        candidate.set(size, candidate.get(size) + 1);
                        break;
                    }
                }
            }
        }
    }

    /** Prints the candidates of the current level that reach the minimum count and adds their number to {@link #sum}. */
    public static void showitem() {
        int co = 0;
        int level = ck.get(0).size() - 1;
        System.out.println(level + "項集如下:");
        for (List<Integer> integers : ck) {
            int count = integers.get(integers.size() - 1);
            if (count >= minCount) {
                co++;
                StringBuilder line = new StringBuilder("{");
                for (int j = 0; j < integers.size() - 1; j++) {
                    line.append(integers.get(j)).append(' ');
                }
                line.append("}:").append(support(count));
                System.out.println(line);
            }
        }
        System.out.println(level + "項集個數:" + co);
        sum += co;
    }

    /** Keeps only candidates reaching the minimum count and strips their trailing count slot. */
    public static void delete() {
        List<List<Integer>> survivors = new ArrayList<>();
        for (List<Integer> integers : ck) {
            int last = integers.size() - 1;
            if (integers.get(last) >= minCount) {
                survivors.add(new ArrayList<>(integers.subList(0, last)));
            }
        }
        ck = survivors;
    }
}