如何用多執行緒實現歸併排序
等我有時間了,一定要把《演算法導論》啃完,這本書的印刷質量實在太好了,滑稽。
之前聽吳恩達老大說過Python裡面的Numpy包的矩陣運算就是多執行緒的,所以能做到的情況下儘量用矩陣運算代替迴圈,這樣能大大加快運算的速度。
為了提高速度,如果不涉及外部資源讀取的話,要提高執行速度就要做到平行計算,依賴於處理器的數量;如果需要等待耗時的外部資源讀取,就可以通過併發邊讀邊運算。
演算法導論有一章節提到了並行迴圈,多執行緒矩陣乘法和多執行緒歸併排序,方法都是講一個大的計算過程分成幾個獨立的小部分,各個部分讓單獨的執行緒去計算。
排序裡面講問題分解的典型的就有快排和歸併,接下來看一下怎麼寫多執行緒的。
多執行緒歸併排序
直接點的思考方式,歸併排序先要把一個數據分成兩個,然後這兩個分別歸併排序,拍完了把兩個歸併到一起,典型的遞迴。
那麼我們直接點,先把陣列分割好,然後開兩個執行緒,一個執行緒給一個,等著兩個執行緒都搞定了,在把兩個結果合併起來。或者你覺得兩個執行緒每個要處理的還是太長了,那就在這兩個執行緒裡面再把拿到的陣列分割了,各自再開兩個。嘗試一下
先看下單執行緒的版本,做下測試
import java.util.Random; public class Main { public static void main(String[] args) { int length = 1000; int[] data = (new Data(length)).getData(); printArr(data); System.out.println(); mergeSort(data); printArr(data); } //遞迴 private static void mergeSort(int[] nums,int[] tmp,int left,int right){ if(left<right){ int center = (left+right)/2; mergeSort(nums,tmp,left,center); mergeSort(nums,tmp,center+1,right); merge(nums,tmp,left,center+1,right); } } //合併 private static void merge(int[] nums,int[] tmp,int leftPos, int rightPos, int rightEnd){ int leftEnd = rightPos-1; int tmpPos = leftPos; int numElements = rightEnd - leftPos + 1; while(leftPos<=leftEnd&&rightPos<=rightEnd){ if(nums[leftPos]<nums[rightPos]) tmp[tmpPos++]=nums[leftPos++]; else tmp[tmpPos++]=nums[rightPos++]; } while(leftPos<=leftEnd) tmp[tmpPos++]=nums[leftPos++]; while(rightPos<=rightEnd) tmp[tmpPos++]=nums[rightPos++]; for(int i = 0;i<numElements;i++,rightEnd--) nums[rightEnd]=tmp[rightEnd]; } public static void mergeSort(int[] nums){ int[] tmp = new int[nums.length]; mergeSort(nums,tmp,0,nums.length-1); } //列印 public static void printArr(int[] arr) { for(int i : arr){ System.out.print(i+" "); } } } /** * 產生隨機資料 */ class Data{ int length; int[] data; public Data(int length){ this.length = length; data = new int[length]; } public int[] getData(){ Random random = new Random(System.currentTimeMillis()); for(int i=0;i<length;i++){ data[i]=random.nextInt(2*length); } return data; } }
可以看到演算法是能正常執行的
按上面思路的多執行緒版本呢?用 兩個執行緒試驗了下
只修改了main函式,加入了一個verify用作驗證排序是不是OK的,不能人眼看吧
import java.util.Random; import java.util.concurrent.CountDownLatch; public class Main { public static void main (String[] args) throws InterruptedException { int length = 1000; int[] data = (new Data(length)).getData(); printArr(data); System.out.println(); // mergeSort(data); //在這裡修改 int center = data.length/2; int[] tmp = new int[data.length]; CountDownLatch latch = new CountDownLatch(2);//CountDownLatch能夠使一個執行緒在等待另 //外一些執行緒完成各自工作之後,再繼續執行 new Thread(new Runnable(){ @Override public void run() { mergeSort(data,tmp,0,center); latch.countDown(); } }).start(); new Thread(new Runnable(){ @Override public void run() { mergeSort(data,tmp,center+1,data.length-1); latch.countDown(); } }).start(); latch.await(); merge(data, tmp, 0, center+1, data.length-1); printArr(data); System.out.println(); verify(data); } //遞迴 private static void mergeSort(int[] nums,int[] tmp,int left,int right){ if(left<right){ int center = (left+right)/2; mergeSort(nums,tmp,left,center); mergeSort(nums,tmp,center+1,right); merge(nums,tmp,left,center+1,right); } } //合併 private static void merge(int[] nums,int[] tmp,int leftPos, int rightPos, int rightEnd){ int leftEnd = rightPos-1; int tmpPos = leftPos; int numElements = rightEnd - leftPos + 1; while(leftPos<=leftEnd&&rightPos<=rightEnd){ if(nums[leftPos]<nums[rightPos]) tmp[tmpPos++]=nums[leftPos++]; else tmp[tmpPos++]=nums[rightPos++]; } while(leftPos<=leftEnd) tmp[tmpPos++]=nums[leftPos++]; while(rightPos<=rightEnd) tmp[tmpPos++]=nums[rightPos++]; for(int i = 0;i<numElements;i++,rightEnd--) nums[rightEnd]=tmp[rightEnd]; } public static void mergeSort(int[] nums){ int[] tmp = new int[nums.length]; mergeSort(nums,tmp,0,nums.length-1); } //列印 public static void printArr(int[] arr) { for(int i : arr){ System.out.print(i+" "); } } public static void verify(int[] nums) { for(int i=0;i<nums.length-1;i++){ if(nums[i]>nums[i+1]){ System.out.println("排序失敗"); return; } } System.out.println("排序成功"); } } /** * 產生隨機資料 */ class Data{ int length; int[] data; public Data(int length){ this.length = length; data = new int[length]; } public int[] getData(){ Random random = new Random(System.currentTimeMillis()); for(int i=0;i<length;i++){ data[i]=random.nextInt(2*length); } return data; } }
結果是OK的
上面是按自己的構思開啟的執行緒。
其實Java本身提供了更好的解決方案,就是Fork/Join
框架, 貼一下這個框架的介紹
使用Fork/Join 我們需要知道兩個類:
- ForkJoinTask:我們要使用ForkJoin框架,必須首先建立一個ForkJoin任務。它提供在任務中執行fork()和join()操作的機制,通常情況下我們不需要直接繼承ForkJoinTask類,而只需要繼承它的子類,Fork/Join框架提供了以下兩個子類:
- RecursiveAction:用於沒有返回結果的任務。
- RecursiveTask :用於有返回結果的任務。
- ForkJoinPool :ForkJoinTask需要通過ForkJoinPool來執行,任務分割出的子任務會新增到當前工作執行緒所維護的雙端佇列中,進入佇列的頭部。當一個工作執行緒的佇列裡暫時沒有任務時,它會隨機從其他工作執行緒的佇列的尾部獲取一個任務。
下面看下如何用這個框架實現多執行緒歸併排序
import java.util.Random;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveAction;
import java.util.concurrent.RecursiveTask;
import java.util.concurrent.TimeUnit;
public class Main {
public static void main (String[] args) throws InterruptedException {
int length = 1000;
int[] data = (new Data(length)).getData();
printArr(data);
System.out.println();
// mergeSort(data);
//在這裡修改
// int center = data.length/2;
int[] tmp = new int[data.length];
// CountDownLatch latch = new CountDownLatch(2);//CountDownLatch能夠使一個執行緒在等待另外一些執行緒完成各自工作之後,再繼續執行
// new Thread(new Runnable(){
// @Override
// public void run() {
// mergeSort(data,tmp,0,center);
// latch.countDown();
// }
// }).start();
// new Thread(new Runnable(){
// @Override
// public void run() {
// mergeSort(data,tmp,center+1,data.length-1);
// latch.countDown();
// }
// }).start();
// latch.await();
// merge(data, tmp, 0, center+1, data.length-1);
//Fork/Join 從這裡開始
ForkJoinPool forkJoinPool = new ForkJoinPool();
Main.mergeTask task = new Main.mergeTask(data, tmp, 0, data.length-1);//建立任務
forkJoinPool.execute(task);//執行任務
forkJoinPool.awaitTermination(2, TimeUnit.SECONDS);//阻塞當前執行緒直到pool中的任務都完成了
printArr(data);
System.out.println();
verify(data);
}
//遞迴
private static void mergeSort(int[] nums,int[] tmp,int left,int right){
if(left<right){
int center = (left+right)/2;
mergeSort(nums,tmp,left,center);
mergeSort(nums,tmp,center+1,right);
merge(nums,tmp,left,center+1,right);
}
}
//合併
private static void merge(int[] nums,int[] tmp,int leftPos, int rightPos, int rightEnd){
int leftEnd = rightPos-1;
int tmpPos = leftPos;
int numElements = rightEnd - leftPos + 1;
while(leftPos<=leftEnd&&rightPos<=rightEnd){
if(nums[leftPos]<nums[rightPos])
tmp[tmpPos++]=nums[leftPos++];
else
tmp[tmpPos++]=nums[rightPos++];
}
while(leftPos<=leftEnd)
tmp[tmpPos++]=nums[leftPos++];
while(rightPos<=rightEnd)
tmp[tmpPos++]=nums[rightPos++];
for(int i = 0;i<numElements;i++,rightEnd--)
nums[rightEnd]=tmp[rightEnd];
}
public static void mergeSort(int[] nums){
int[] tmp = new int[nums.length];
mergeSort(nums,tmp,0,nums.length-1);
}
//列印
public static void printArr(int[] arr) {
for(int i : arr){
System.out.print(i+" ");
}
}
public static void verify(int[] nums) {
for(int i=0;i<nums.length-1;i++){
if(nums[i]>nums[i+1]){
System.out.println("排序失敗");
return;
}
}
System.out.println("排序成功");
}
static class mergeTask extends RecursiveAction {
private static final int THRESHOLD = 2;//設定任務大小閾值
private int start;
private int end;
private int[] data;
private int[] tmp;
public mergeTask(int[] data, int[] tmp, int start, int end){
this.data = data;
this.tmp = tmp;
this.start = start;
this.end = end;
}
@Override
protected void compute(){
if((end - start)<=THRESHOLD){
mergeSort(data,tmp,start,end);
}else{
int center = (start + end)/2;
Main.mergeTask leftTask = new Main.mergeTask(data, tmp, start, center);
Main.mergeTask rightTask = new Main.mergeTask(data, tmp, center+1, end);
leftTask.fork();
rightTask.fork();
leftTask.join();
rightTask.join();
merge(data, tmp, start, center+1, end);
}
}
}
}
/**
* 產生隨機資料
*/
class Data{
int length;
int[] data;
public Data(int length){
this.length = length;
data = new int[length];
}
public int[] getData(){
Random random = new Random(System.currentTimeMillis());
for(int i=0;i<length;i++){
data[i]=random.nextInt(2*length);
}
return data;
}
}
結果也是OK的
以上都沒有涉及到鎖,雖然操作的是共享的陣列,但是被讀寫的區域是被隔離開的。
也是在演算法導論上瞟到多執行緒演算法這麼一章,順藤摸瓜才知道有Fork/Join 這個東西,要學的東西真的多。
搞完這個我又聯想到之前看過的一道演算法題:
在大量的資料中,尋找最大的k個數,或者是出現次數最多的k個數據,比如說這個資料有10個G,放在一個大檔案中,電腦記憶體4G。
解題思路就是先把這個檔案分塊,為了確保相同的資料在一個塊中,通過計算Hash值來分塊,相同Hash 放到一個塊中。比如每分100個塊,這樣平均一個塊就在100M左右,對每個塊分別載入記憶體找最大的前K個數或者出現最多的前K個數據,最後比較這100*K個數據來得到結果。
怎麼用多執行緒求解?