Fork/Join框架 demo
阿新 • • 發佈:2020-07-20
demo1:使⽤Fork/Join來求,斐波那契數列第n項
斐波那契數列數列是⼀個線性遞推數列,從第三項開始,每⼀項的值都等於
前兩項之和:
1, 1, 2, 3, 5, 8, 13, 21, 34, 55, 89······
如果設f(n)為該數列的第n項(n∈N*),那麼有:f(n) = f(n-1) + f(n-2)。
import org.junit.Test;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ForkJoinPool;import java.util.concurrent.Future;
import java.util.concurrent.RecursiveTask;
/**
* Created by DELL on 2020/7/20.
*/
public class FibonacciTest {
class Fibonacci extends RecursiveTask<Integer> {
int n;
public Fibonacci(int n) {
this.n = n;
}
// 主要的實現邏輯都在compute()⾥@Override
protected Integer compute() {
// 這⾥先假設 n >= 0
if (n <= 1) {
return n;
} else {
// f(n-1)
Fibonacci f1 = new Fibonacci(n - 1);
f1.fork();
// f(n-2)
Fibonacci f2 = new Fibonacci(n - 2);f2.fork();
// f(n) = f(n-1) + f(n-2)
return f1.join() + f2.join();
}
}
}
@Test
public void testFib() throws ExecutionException, InterruptedException {
ForkJoinPool forkJoinPool = new ForkJoinPool();
System.out.println("CPU核數:" + Runtime.getRuntime().availableProcessors());
long start = System.currentTimeMillis();
Fibonacci fibonacci = new Fibonacci(40);
Future<Integer> future = forkJoinPool.submit(fibonacci);
System.out.println(future.get());
long end = System.currentTimeMillis();
System.out.println(String.format("耗時:%d millis", end - start));
}
}
輸出:
CPU核數:4 計算結果:102334155 耗時:9490 millis
demo2:求1+2+3+4的結果
Fork/Join框架分割任務,將每個子任務最多執行兩個數的相加,那麼我們設定分割的閾值是2,由於是4個數字相加,所以Fork/Join框架會把這個任務fork成兩個子任務,子任務一負責計算1+2,子任務二負責計算3+4,然後再join兩個子任務的結果。因為是有結果的任務,所以必須繼承RecursiveTask
import java.util.concurrent.ExecutionException; import java.util.concurrent.ForkJoinPool; import java.util.concurrent.Future; import java.util.concurrent.RecursiveTask; /** * * @author aikq * @date 2018年11月21日 20:37 */ public class ForkJoinTaskDemo { public static void main(String[] args) { ForkJoinPool pool = new ForkJoinPool(); CountTask task = new CountTask(1,4); Future<Integer> result = pool.submit(task); try { System.out.println("計算結果=" + result.get()); } catch (InterruptedException e) { e.printStackTrace(); } catch (ExecutionException e) { e.printStackTrace(); } } } class CountTask extends RecursiveTask<Integer>{ private static final long serialVersionUID = -7524245439872879478L; private static final int THREAD_HOLD = 2; private int start; private int end; public CountTask(int start,int end){ this.start = start; this.end = end; } @Override protected Integer compute() { int sum = 0; //如果任務足夠小就計算 boolean canCompute = (end - start) <= THREAD_HOLD; if(canCompute){ for(int i=start;i<=end;i++){ sum += i; } }else{ int middle = (start + end) / 2; CountTask left = new CountTask(start,middle); CountTask right = new CountTask(middle+1,end); //執行子任務 left.fork(); right.fork(); //獲取子任務結果 int lResult = left.join(); int rResult = right.join(); sum = lResult + rResult; } return sum; } }