1. 程式人生 > 推薦演算法slope one之java實現

推薦演算法slope one之java實現

import java.util.*;

/**
 * Daniel Lemire A simple implementation of the weighted slope one algorithm in
 * Java for item-based collaborative filtering. Assumes Java 1.5.
 * 
 * See main function for example.
 * 
 * June 1st 2006. Revised by Marco Ponzi on March 29th 2007
 */

public class SlopeOne {

	/**
	 * Demonstration: builds a small four-user rating data set, prints it
	 * together with the difference/frequency matrices, then predicts ratings
	 * for a new user twice (first with one known rating, then with two).
	 */
	public static void main(String args[]) {
		// this is my data base
		Map<UserId, Map<ItemId, Float>> data = new HashMap<UserId, Map<ItemId, Float>>();
		// items (names are left-padded so the printed matrix rows line up)
		ItemId item1 = new ItemId("       candy");
		ItemId item2 = new ItemId("         dog");
		ItemId item3 = new ItemId("         cat");
		ItemId item4 = new ItemId("         war");
		ItemId item5 = new ItemId("strange food");

		mAllItems = new ItemId[] { item1, item2, item3, item4, item5 };

		// I'm going to fill it in
		HashMap<ItemId, Float> user1 = new HashMap<ItemId, Float>();
		HashMap<ItemId, Float> user2 = new HashMap<ItemId, Float>();
		HashMap<ItemId, Float> user3 = new HashMap<ItemId, Float>();
		HashMap<ItemId, Float> user4 = new HashMap<ItemId, Float>();
		user1.put(item1, 1.0f);
		user1.put(item2, 0.5f);
		user1.put(item4, 0.1f);
		data.put(new UserId("Bob"), user1);
		user2.put(item1, 1.0f);
		user2.put(item3, 0.5f);
		user2.put(item4, 0.2f);
		data.put(new UserId("Jane"), user2);
		user3.put(item1, 0.9f);
		user3.put(item2, 0.4f);
		user3.put(item3, 0.5f);
		user3.put(item4, 0.1f);
		data.put(new UserId("Jo"), user3);
		// StrangeJo has no rating for item2 and item3; only this user rates item5.
		user4.put(item1, 0.1f);
		user4.put(item4, 1.0f);
		user4.put(item5, 0.4f);
		data.put(new UserId("StrangeJo"), user4);
		// next, I create my predictor engine
		SlopeOne so = new SlopeOne(data);
		System.out.println("Here's the data I have accumulated...");
		so.printData();
		// then, I'm going to test it out...
		HashMap<ItemId, Float> user = new HashMap<ItemId, Float>();
		System.out.println("Ok, now we predict...");
		user.put(item5, 0.4f);
		System.out.println("Inputting...");
		SlopeOne.print(user);
		System.out.println("Getting...");
		SlopeOne.print(so.predict(user));
		//
		user.put(item4, 0.2f);
		System.out.println("Inputting...");
		SlopeOne.print(user);
		System.out.println("Getting...");
		SlopeOne.print(so.predict(user));
	}

	/** Ratings per user: user -> (item -> rating). Stored by reference, not copied. */
	Map<UserId, Map<ItemId, Float>> mData;
	/** mDiffMatrix.get(a).get(b) is the average of (rating of a - rating of b) over users who rated both. */
	Map<ItemId, Map<ItemId, Float>> mDiffMatrix;
	/** mFreqMatrix.get(a).get(b) is the number of users who rated both a and b. */
	Map<ItemId, Map<ItemId, Integer>> mFreqMatrix;

	/**
	 * Items, in display order, used as rows and columns by {@link #printData()}.
	 * Assigned by {@link #main(String[])}; may be null when the class is used
	 * from other code.
	 */
	static ItemId[] mAllItems;

	/**
	 * Creates a predictor over the given ratings and immediately builds the
	 * difference and frequency matrices from them.
	 *
	 * @param data ratings per user (user -> item -> rating); must not be null
	 * @throws NullPointerException if data is null
	 */
	public SlopeOne(Map<UserId, Map<ItemId, Float>> data) {
		if (data == null) {
			throw new NullPointerException("data");
		}
		mData = data;
		buildDiffMatrix();
	}

	/**
	 * Based on existing data, and using weights, try to predict all missing
	 * ratings. Each item the user rated contributes (average difference +
	 * the user's rating), weighted by the number of users who rated both items.
	 *
	 * Items for which no prediction is possible (no co-ratings with anything
	 * the user rated) are left out of the result. The user's own ratings are
	 * always copied into the result unchanged. No frequency threshold is
	 * applied: every available co-rating is used.
	 *
	 * @param user the ratings known for the user (item -> rating)
	 * @return a new map holding the predicted and the known ratings
	 */
	public Map<ItemId, Float> predict(Map<ItemId, Float> user) {
		return predictAll(user, true);
	}

	/**
	 * Based on existing data, and not using weights, try to predict all
	 * missing ratings. Each item the user rated that has been co-rated with
	 * the target item contributes (average difference + the user's rating);
	 * the prediction is the plain average of those contributions.
	 *
	 * Pairs that were never co-rated are skipped, and items with no usable
	 * pair at all are left out of the result. The user's own ratings are
	 * always copied into the result unchanged.
	 *
	 * @param user the ratings known for the user (item -> rating)
	 * @return a new map holding the predicted and the known ratings
	 */
	public Map<ItemId, Float> weightlesspredict(Map<ItemId, Float> user) {
		return predictAll(user, false);
	}

	/**
	 * Shared prediction loop for the weighted and the unweighted variant.
	 *
	 * @param user     the ratings known for the user
	 * @param weighted true to weight each contribution by its co-rating count,
	 *                 false to give every contribution the weight 1
	 */
	private Map<ItemId, Float> predictAll(Map<ItemId, Float> user,
			boolean weighted) {
		HashMap<ItemId, Float> result = new HashMap<ItemId, Float>();
		for (Map.Entry<ItemId, Map<ItemId, Float>> row : mDiffMatrix.entrySet()) {
			ItemId target = row.getKey();
			Map<ItemId, Float> diffs = row.getValue();
			Map<ItemId, Integer> counts = mFreqMatrix.get(target);
			if (diffs == null || (weighted && counts == null)) {
				continue;
			}
			float sum = 0.0f;
			int totalWeight = 0;
			for (Map.Entry<ItemId, Float> rated : user.entrySet()) {
				Float diff = diffs.get(rated.getKey());
				Float rating = rated.getValue();
				if (diff == null || rating == null) {
					// never co-rated with the target, or no usable rating
					continue;
				}
				int weight = 1;
				if (weighted) {
					Integer count = counts.get(rated.getKey());
					if (count == null) {
						continue;
					}
					weight = count.intValue();
				}
				sum += (diff.floatValue() + rating.floatValue()) * weight;
				totalWeight += weight;
			}
			if (totalWeight > 0) {
				result.put(target, sum / totalWeight);
			}
		}
		// what the user actually rated always wins over a prediction
		result.putAll(user);
		return result;
	}

	/**
	 * Prints every user's ratings, then one line per item showing, for each
	 * column item, the average difference followed by the co-rating count.
	 */
	public void printData() {
		for (Map.Entry<UserId, Map<ItemId, Float>> entry : mData.entrySet()) {
			System.out.println(entry.getKey());
			print(entry.getValue());
		}
		for (ItemId item : itemsToPrint()) {
			System.out.print("\n" + item + ":");
			printMatrixes(mDiffMatrix.get(item), mFreqMatrix.get(item));
		}
	}

	/** Returns mAllItems when it has been set, otherwise the items found in the matrix. */
	private ItemId[] itemsToPrint() {
		if (mAllItems != null) {
			return mAllItems;
		}
		return mDiffMatrix.keySet().toArray(new ItemId[mDiffMatrix.size()]);
	}

	/**
	 * Prints one matrix row: for every column item the average difference
	 * (width 10, 3 decimals) and the co-rating count (width 10).
	 *
	 * @param ratings     difference row of the item, or null if it has none
	 * @param frequencies frequency row of the item, or null if it has none
	 */
	private void printMatrixes(Map<ItemId, Float> ratings,
			Map<ItemId, Integer> frequencies) {
		for (ItemId item : itemsToPrint()) {
			Float diff = (ratings == null) ? null : ratings.get(item);
			Integer count = (frequencies == null) ? null : frequencies.get(item);
			if (diff == null) {
				// No co-rating for this pair: print a full-width placeholder
				// instead of handing null to the %f conversion.
				System.out.format("%10s", "null");
			} else {
				System.out.format("%10.3f", diff);
			}
			System.out.print(" ");
			System.out.format("%10d", count);
		}
		System.out.println();
	}

	/**
	 * Prints one " item --> rating" line per entry of the given map.
	 *
	 * @param user ratings to print (item -> rating)
	 */
	public static void print(Map<ItemId, Float> user) {
		for (Map.Entry<ItemId, Float> entry : user.entrySet()) {
			System.out.println(" " + entry.getKey() + " --> " + entry.getValue());
		}
	}

	/**
	 * Rebuilds mDiffMatrix and mFreqMatrix from mData. For every user and
	 * every ordered pair of items (a, b) that user rated, including a == b,
	 * the difference rating(a) - rating(b) is summed and the pair is counted;
	 * afterwards every sum is divided by its count to get the average.
	 */
	public void buildDiffMatrix() {
		mDiffMatrix = new HashMap<ItemId, Map<ItemId, Float>>();
		mFreqMatrix = new HashMap<ItemId, Map<ItemId, Integer>>();
		// first pass: accumulate difference sums and co-rating counts
		for (Map<ItemId, Float> user : mData.values()) {
			for (Map.Entry<ItemId, Float> entry : user.entrySet()) {
				ItemId item = entry.getKey();
				Map<ItemId, Float> diffRow = mDiffMatrix.get(item);
				if (diffRow == null) {
					diffRow = new HashMap<ItemId, Float>();
					mDiffMatrix.put(item, diffRow);
					// both matrices always get their rows together
					mFreqMatrix.put(item, new HashMap<ItemId, Integer>());
				}
				Map<ItemId, Integer> freqRow = mFreqMatrix.get(item);
				for (Map.Entry<ItemId, Float> entry2 : user.entrySet()) {
					ItemId other = entry2.getKey();
					Integer oldCount = freqRow.get(other);
					Float oldDiff = diffRow.get(other);
					int count = (oldCount == null) ? 0 : oldCount.intValue();
					float diffSum = (oldDiff == null) ? 0.0f : oldDiff.floatValue();
					float observedDiff = entry.getValue() - entry2.getValue();
					freqRow.put(other, count + 1);
					diffRow.put(other, diffSum + observedDiff);
				}
			}
		}
		// second pass: turn every sum into an average
		for (Map.Entry<ItemId, Map<ItemId, Float>> row : mDiffMatrix.entrySet()) {
			Map<ItemId, Integer> freqRow = mFreqMatrix.get(row.getKey());
			for (Map.Entry<ItemId, Float> cell : row.getValue().entrySet()) {
				int count = freqRow.get(cell.getKey()).intValue();
				cell.setValue(cell.getValue().floatValue() / count);
			}
		}
	}
}

class UserId {
	String content;

	public UserId(String s) {
		content = s;
	}

	public int hashCode() {
		return content.hashCode();
	}

	public String toString() {
		return content;
	}
}

class ItemId {
	String content;

	public ItemId(String s) {
		content = s;
	}

	public int hashCode() {
		return content.hashCode();
	}

	public String toString() {
		return content;
	}
}