使用模擬退火演算法優化 Hash 函式
阿新 • • 發佈:2020-10-04
# 背景
現有個處理股票行情訊息的系統,其架構如下:
將 5000 個非負整數分配至 15 個桶`(bucket)`中,並儘可能保證每個桶中的元素之和接近(每個桶中的元素個數無限制)。
每個整數元素可能的放置方法有 15 種,這個問題總共可能的解有 155000種,暴力求解的可能性微乎其微。作為工程問題,最優解不是必要的,可以退而求其次尋找一個可接受的次優解:
symbolAndCount = new TreeMap<>();
for (int i=0; i distribution = findBestDistribution(symbolAndCount);
// 測試效果
int[] buckets = new int[NUM_OF_BUCKETS];
for (Map.Entry entry : symbolAndCount.entrySet()) {
Map.Entry floor = distribution.floorEntry(entry.getKey());
int bucketIndex = floor == null ? 0 : floor.getValue();
buckets[bucketIndex] += entry.getValue();
}
System.out.printf("buckets: %s\n", Arrays.toString(buckets));
}
public static TreeMap findBestDistribution(Map symbolAndCount) {
// 每個桶均勻分佈的情況(最優情況)
int avg = symbolAndCount.values().stream().mapToInt(Integer::intValue).sum() / NUM_OF_BUCKETS;
// 嘗試將 symbol 放入不同的桶
int bucketIdx = 0;
int[] buckets = new int[NUM_OF_BUCKETS];
String[] bulkheads = new String[NUM_OF_BUCKETS-1];
for (Map.Entry entry : symbolAndCount.entrySet()) {
// 如果首個 symbol 資料量過大,則分配給其一個獨立的桶
int count = entry.getValue();
if (count / 2 > avg && bucketIdx == 0 && buckets[0] == 0) {
buckets[bucketIdx] += count;
continue;
}
// 評估將 symbol 放入桶後的效果
// 1. 如果桶中的數量更接近期望,則將其放入當前桶中
// 2. 如果桶中的數量更遠離期望,則將其放入下個桶中
double before = Math.abs(buckets[bucketIdx] - avg);
double after = Math.abs(buckets[bucketIdx] + count - avg);
if (after > before && bucketIdx < buckets.length - 1) {
bulkheads[bucketIdx++] = entry.getKey();
}
buckets[bucketIdx] += count;
}
System.out.printf("expectation: %d\n", avg);
System.out.printf("bulkheads: %s\n", Arrays.toString(bulkheads));
TreeMap distribution = new TreeMap<>();
for (int i=0; i32 -1。在 5000 個 symbol 的情況下,單執行緒嘗試遍歷所有 seed 的時間約為 25 小時。
通常情況下 symbol 的數量會超過 5000,因此實際的搜尋時間會大於這個值。此外,受限於計算資源限制,無法進行大規模的並行搜尋,因此窮舉法的耗時是不可接受的。
幸好本例並不要求最優解,可以引入啟發式搜尋演算法,加快訓練速度。由於本人在這方面並不熟悉,為了降低程式設計難度,最終選擇了模擬退火`(simulated annealing)`演算法。它模擬固體退火過程的熱平衡問題與隨機搜尋尋優問題的相似性來達到尋找全域性最優或近似全域性最優的目的。
相較於最簡單的爬山法,模擬退火演算法通以一定的概率接受較差的解,從而擴大搜索範圍,保證解近似最優。
```
/**
* Basic framework of simulated annealing algorithm
* @param the solution of given problem
*/
public abstract class SimulatedAnnealing {
protected final int numberOfIterations; // stopping condition for simulations
protected final double coolingRate; // the percentage by which we reduce the temperature of the system
protected final double initialTemperature; // the starting energy of the system
protected final double minimumTemperature; // optional stopping condition
protected final long simulationTime; // optional stopping condition
protected final int detectionInterval; // optional stopping condition
protected SimulatedAnnealing(int numberOfIterations, double coolingRate) {
this(numberOfIterations, coolingRate, 10000000, 1, 0, 0);
}
protected SimulatedAnnealing(int numberOfIterations, double coolingRate, double initialTemperature, double minimumTemperature, long simulationTime, int detectionInterval) {
this.numberOfIterations = numberOfIterations;
this.coolingRate = coolingRate;
this.initialTemperature = initialTemperature;
this.minimumTemperature = minimumTemperature;
this.simulationTime = simulationTime;
this.detectionInterval = detectionInterval;
}
protected abstract double score(X currentSolution);
protected abstract X neighbourSolution(X currentSolution);
public X simulateAnnealing(X currentSolution) {
final long startTime = System.currentTimeMillis();
// Initialize searching
X bestSolution = currentSolution;
double bestScore = score(bestSolution);
double currentScore = bestScore;
double t = initialTemperature;
for (int i = 0; i < numberOfIterations; i++) {
if (currentScore < bestScore) {
// If the new solution is better, accept it unconditionally
bestScore = currentScore;
bestSolution = currentSolution;
} else {
// If the new solution is worse, calculate an acceptance probability for the worse solution
// At high temperatures, the system is more likely to accept the solutions that are worse
boolean rejectWorse = Math.exp((bestScore - currentScore) / t) < Math.random();
if (rejectWorse || currentScore == bestScore) {
currentSolution = neighbourSolution(currentSolution);
currentScore = score(currentSolution);
}
}
// Stop searching when the temperature is too low
if ((t *= coolingRate) < minimumTemperature) {
break;
}
// Stop searching when simulation time runs out
if (simulationTime > 0 && (i+1) % detectionInterval == 0) {
if (System.currentTimeMillis() - startTime > simulationTime)
break;
}
}
return bestSolution;
}
}
```
```
/**
* Search best hash seed for given key distribution and number of buckets with simulated annealing algorithm
*/
@Data
public class SimulatedAnnealingHashing extends SimulatedAnnealing {
private static final int DISTRIBUTION_BATCH = 100;
static final int SEARCH_BATCH = 200;
private final int[] hashCodes = new int[SEARCH_BATCH];
private final long[][] buckets = new long[SEARCH_BATCH][];
@Data
public class HashingSolution {
private final int begin, range; // the begin and range for searching
private int bestSeed; // the best seed found in this search
private long bestScore; // the score corresponding to bestSeed
private long calculateDivergence(long[] bucket) {
long divergence = 0;
for (int i=0; i keyAndCounts, int numOfBuckets) {
super(100000000, .9999);
distributions = buildDistribution(keyAndCounts);
long sum = 0;
for (KeyDistribution[] batch : distributions) {
for (KeyDistribution distribution : batch) {
sum += distribution.getCount();
}
}
this.expectation = sum / numOfBuckets;
this.searchOutset = 0;
for (int i = 0; i< buckets.length; i++) {
buckets[i] = new long[numOfBuckets];
}
}
/**
* SimulatedAnnealingHashing Derivative
* @param prototype prototype simulation
* @param searchOutset the outset for searching
* @param simulationTime the expect time consuming for simulation
*/
private SimulatedAnnealingHashing(SimulatedAnnealingHashing prototype, int searchOutset, long simulationTime) {
super(prototype.numberOfIterations, prototype.coolingRate, prototype.initialTemperature, prototype.minimumTemperature,
simulationTime, 10000);
distributions = prototype.distributions;
expectation = prototype.expectation;
for (int i = 0; i< buckets.length; i++) {
buckets[i] = new long[prototype.buckets[i].length];
}
this.searchOutset = searchOutset;
this.searchMax = searchMin = searchOutset;
}
@Override
public String toString() {
return String.format("expectation: %d, outset:%d, search(min:%d, max:%d)", expectation, searchOutset, searchMin, searchMax);
}
private KeyDistribution[][] buildDistribution(Map symbolCounts) {
int bucketNum = symbolCounts.size() / DISTRIBUTION_BATCH + Integer.signum(symbolCounts.size() % DISTRIBUTION_BATCH);
KeyDistribution[][] distributions = new KeyDistribution[bucketNum][];
int bucketIndex = 0;
List batch = new ArrayList<>(DISTRIBUTION_BATCH);
for (Map.Entry entry : symbolCounts.entrySet()) {
batch.add(new KeyDistribution(entry.getKey().toCharArray(), entry.getValue()));
if (batch.size() == DISTRIBUTION_BATCH) {
distributions[bucketIndex++] = batch.toArray(new KeyDistribution[0]);
batch.clear();
}
}
if (batch.size() > 0) {
distributions[bucketIndex] = batch.toArray(new KeyDistribution[0]);
batch.clear();
}
return distributions;
}
@Override
protected double score(HashingSolution currentSolution) {
return currentSolution.solve().bestScore;
}
@Override
protected HashingSolution neighbourSolution(HashingSolution currentSolution) {
// The default range of neighbourhood is [-100, 100]
int rand = ThreadLocalRandom.current().nextInt(-100, 101);
int next = currentSolution.begin + rand;
searchMin = Math.min(next, searchMin);
searchMax = Math.max(next, searchMax);
return new HashingSolution(next, currentSolution.range);
}
public HashingSolution solve() {
searchMin = searchMax = searchOutset;
HashingSolution initialSolution = new HashingSolution(searchOutset, SEARCH_BATCH);
return simulateAnnealing(initialSolution);
}
public SimulatedAnnealingHashing derive(int searchOutset, long simulationTime) {
return new SimulatedAnnealingHashing(this, searchOutset, simulationTime);
}
}
```
## ForkJoin 框架
為了達到更好的搜尋效果,可以將整個搜尋區域遞迴地劃分為兩兩相鄰的區域,然後在這些區域上執行併發的搜尋,並遞迴地合併相鄰區域的搜尋結果。
使用 JDK 提供的 ForkJoinPool 與 RecursiveTask 能很好地完成以上任務。
```
@Data
@Slf4j
public class HashingSeedCalculator {
/**
* Recursive search task
*/
private class HashingSeedCalculatorSearchTask extends RecursiveTask {
private SimulatedAnnealingHashing simulation;
private final int level;
private final int center, range;
private HashingSeedCalculatorSearchTask() {
this.center = 0;
this.range = Integer.MAX_VALUE / SimulatedAnnealingHashing.SEARCH_BATCH;
this.level = traversalDepth;
this.simulation = hashingSimulation;
}
private HashingSeedCalculatorSearchTask(HashingSeedCalculatorSearchTask parent, int center, int range) {
this.center = center;
this.range = range;
this.level = parent.level - 1;
this.simulation = parent.simulation;
}
@Override
protected HashingSolution compute() {
if (level == 0) {
long actualCenter = center * SimulatedAnnealingHashing.SEARCH_BATCH;
log.info("Searching around center {}", actualCenter);
HashingSolution solution = simulation.derive(center, perShardRunningMills).solve();
log.info("Searching around center {} found {}", actualCenter, solution);
return solution;
} else {
int halfRange = range / 2;
int leftCenter = center - halfRange, rightCenter = center + halfRange;
ForkJoinTask leftTask = new HashingSeedCalculatorSearchTask(this, leftCenter, halfRange).fork();
ForkJoinTask rightTask = new HashingSeedCalculatorSearchTask(this, rightCenter, halfRange).fork();
HashingSolution left = leftTask.join();
HashingSolution right = rightTask.join();
return left.getBestScore() < right.getBestScore() ? left : right;
}
}
}
private final int poolParallelism;
private final int traversalDepth;
private final long perShardRunningMills;
private final SimulatedAnnealingHashing hashingSimulation;
/**
* HashingSeedCalculator
* @param numberOfShards the shard of the whole search range [Integer.MIN_VALUE, Integer.MAX_VALUE]
* @param totalRunningHours the expect total time consuming for searching
* @param symbolCounts the key and it`s distribution
* @param numOfBuckets the number of buckets
*/
public HashingSeedCalculator(int numberOfShards, int totalRunningHours, Map symbolCounts, int numOfBuckets) {
int n = (int) (Math.log(numberOfShards) / Math.log(2));
if (Math.pow(2, n) != numberOfShards) {
throw new IllegalArgumentException();
}
this.traversalDepth = n;
this.poolParallelism = Math.max(ForkJoinPool.getCommonPoolParallelism() / 3 * 2, 1); // conservative estimation for parallelism
this.perShardRunningMills = TimeUnit.HOURS.toMillis(totalRunningHours * poolParallelism) / numberOfShards;
this.hashingSimulation = new SimulatedAnnealingHashing(symbolCounts, numOfBuckets);
}
@Override
public String toString() {
int numberOfShards = (int) Math.pow(2, traversalDepth);
int totalRunningHours = (int) TimeUnit.MILLISECONDS.toHours(perShardRunningMills * numberOfShards) / poolParallelism;
return "HashingSeedCalculator(" +
"numberOfShards: " + numberOfShards +
", perShardRunningMinutes: " + TimeUnit.MILLISECONDS.toMinutes(perShardRunningMills) +
", totalRunningHours: " + totalRunningHours +
", poolParallelism: " + poolParallelism +
", traversalDepth: " + traversalDepth + ")";
}
public synchronized HashingSolution searchBestSeed() {
long now = System.currentTimeMillis();
log.info("SearchBestSeed start");
ForkJoinTask root = new HashingSeedCalculatorSearchTask().fork();
HashingSolution initSolution = hashingSimulation.derive(0, perShardRunningMills).solve();
HashingSolution bestSolution = root.join();
log.info("Found init solution {}", initSolution);
log.info("Found best solution {}", bestSolution);
if (initSolution.getBestScore() < bestSolution.getBestScore()) {
bestSolution = initSolution;
}
long cost = System.currentTimeMillis() - now;
log.info("SearchBestSeed finish (cost:{}ms)", cost);
return bestSolution;
}
}
```
# 效果
將改造後的程式碼部署到測試環境後,某日訓練日誌:
> 12:49:15.227 85172866 INFO hash.HashingSeedCalculator - Found init solution (seed:15231, score:930685828341164)
12:49:15.227 85172866 INFO hash.HashingSeedCalculator - Found best solution (seed:362333, score:793386389726926)
12:49:15.227 85172866 INFO hash.HashingSeedCalculator - SearchBestSeed finish (cost:10154898ms)
12:49:15.227 85172866 INFO hash.TrainingService -
Training result: (seed:362333, score:793386389726926)
Buckets: 15
Expectation: 44045697
Result of Hashing.HashCode(seed=362333): 21327108 [42512742, 40479608, 43915771, 47211553, 45354264, 43209190, 43196570, 44725786, 41999747, 46450288, 46079231, 45116615, 44004021, 43896194, 42533877]
Result of Hashing.HashCode(seed=31): 66929172 [39723630, 48721463, 43365391, 46301448, 43931616, 44678194, 39064877, 45922454, 43171141, 40715060, 33964547, 49709090, 58869949, 34964729, 47581868]
當晚使用 `BKDRHash(seed=31)` 對新的交易日資料的進行分片:
> 04:00:59.001 partition messages per minute [45171, 68641, 62001, 80016, 55977, 61916, 55102, 49322, 55982, 57081, 51100, 70437, 135992, 37823, 58552] , messages total [39654953, 48666261, 43310578, 46146841, 43834832, 44577454, 38990331, 45871075, 43106710, 40600708, 33781629, 49752592, 58584246, 34928991, 47545369]
當晚使用 `BKDRHash(seed=362333)` 對新的交易日資料的進行分片:
> 04:00:59.001 partition messages per minute [62424, 82048, 64184, 47000, 57206, 69439, 64430, 60096, 46986, 58182, 54557, 41523, 64310, 72402, 100326] , messages total [44985772, 48329212, 39995385, 43675702, 45216341, 45524616, 41335804, 44917938, 44605376, 44054821, 43371892, 42068637, 44000817, 42617562, 44652695]
對比日誌發現 hash 經過優化後,分割槽的均勻程度有了顯著的上升,並且熱點分片也被消除了,基本達到當初設想的優化
由於資料量巨大,系統中啟動了 15 個執行緒來消費行情訊息。訊息分配的策略較為簡單:對 symbol 的 hashCode 取模,將訊息分配給其中一個執行緒進行處理。 經過驗證,每個執行緒分配到的 symbol 數量較為均勻,於是系統愉快地上線了。 執行一段時間後,突然收到了系統的告警,但此時並非訊息峰值時間段。經過排查後,發現問題出現在 hash 函式上:
雖然每個執行緒被分配到的 symbol 數量較為均衡,但是部分熱門 symbol 的報價訊息量會更多,如果熱門 symbol 集中到特定執行緒上,就會造成執行緒負載不均衡,使得系統整體的吞吐量大打折扣。 為提高系統的吞吐量,有必要訊息分發邏輯進行一些改造,避免出現熱點執行緒。為此,系統需要記錄下某天內每個 symbol 的訊息量,然後在第二天使用這些資料,對分發邏輯進行調整。具體的改造的方案可以分為兩種: - 放棄使用 hash 函式 - 對 hash 函式進行優化 # 放棄 hash 函式 問題可以抽象為: >
- 根據所有 symbol 的訊息總數計算一個期望的分佈均值
(expectation)
。 - 將每個 symbol 的訊息數按照 symbol 的順序進行排列,最後將這組陣列劃分為 15 個區間,並且儘可能使得每個區間元素之和與 expection 接近。
- 使用一個有序查詢表記錄每個區間的首個 symbol,後續就可以按照這個表對資料進行劃分。