1. 程式人生 > >Java7 Fork-Join 框架:任務切分,並行處理

Java7 Fork-Join 框架:任務切分,並行處理

package forkJoin;

import java.util.concurrent.RecursiveTask;

public class SumTask extends RecursiveTask<Integer> {
    private static final int THRESHOLD = 20;

    private int[] array;
    private int low;
    private int high;

    public SumTask(int[] array, int low, int high) {
        this.array = array;
        this.low = low;
        this.high = high;
    }

    @Override
    protected Integer compute() {
        int sum = 0;
        if (high - low + 1 <= THRESHOLD) {
            System.out.println(low + " - " + high + "  計算");
//            測試並行的個數,統計輸出過程中的文字,看看有多少執行緒停止在這裡就知道有多少平行計算
//            參考 ForkJoinPool 初始化設定的並行數
//            try {
//                Thread.sleep(11111111);
//            } catch (InterruptedException e) {
//                e.printStackTrace();
//            }
            // 小於閾值則直接計算
            for (int i = low; i <= high; i++) {
                sum += array[i];
            }
        } else {
            System.out.println(low + " - " + high + "  切分");
            // 1. 一個大任務分割成兩個子任務
            int mid = (low + high) / 2;
            SumTask left = new SumTask(array, low, mid);
            SumTask right = new SumTask(array, mid + 1, high);

            // 2. 分別平行計算
            invokeAll(left, right);

            // 3. 合併結果
            sum = left.join() + right.join();

            // 另一種方式
            try {
                sum = left.get() + right.get();
            } catch (Throwable e) {
                System.out.println(e.getMessage());
            }
        }
        return sum;
    }
}
package forkJoin;

import java.util.Random;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveAction;

public class Main {

    /*static class MyTaskTest extends RecursiveTask<Integer> {
        final int n;

        MyTaskTest(int n) {
            this.n = n;
        }

        @Override
        protected Integer compute() {
            if (n <= 1) return n;
            MyTaskTest f1 = new MyTaskTest(n - 1);
            f1.fork();
            MyTaskTest f2 = new MyTaskTest(n - 2);
            return f2.compute() + f1.join();
        }
    }*/

    /*class SortTask extends RecursiveAction {
        static final int THRESHOLD = 2;
        final long[] array;
        final int lo;
        final int hi;

        SortTask(long[] array, int lo, int hi) {
            this.array = array;
            this.lo = lo;
            this.hi = hi;
        }

        protected void compute() {
            if (hi - lo < THRESHOLD)
                sequentiallySort(array, lo, hi);
            else {
                int mid = (lo + hi) >>> 1;
                invokeAll(new SortTask(array, lo, mid),
                        new SortTask(array, mid, hi));
                merge(array, lo, hi);
            }
        }
    }*/

    private static int[] genArray() {
        int[] array = new int[100];
        for (int i = 0; i < array.length; i++) {
            array[i] = new Random().nextInt(500);
        }
        return array;
    }

    public static void main(String[] args) throws ExecutionException, InterruptedException {
        /**
         * 下面以一個有返回值的大任務為例,介紹一下RecursiveTask的用法。
         大任務是:計算隨機的100個數字的和。
         小任務是:每次只能20個數值的和。
         */
        int[] array = genArray();

//        System.out.println(Arrays.toString(array));
        int total = 0;
        for (int i = 0; i < array.length; i++) {
            total += array[i];
        }
        System.out.println("目標和:" + total);

        // 1. 建立任務
        SumTask sumTask = new SumTask(array, 0, array.length - 1);

        // 2. 建立執行緒池
        // 設定平行計算的個數
        int processors = Runtime.getRuntime().availableProcessors();
        ForkJoinPool forkJoinPool = new ForkJoinPool(processors * 2);

        // 3. 提交任務到執行緒池
        forkJoinPool.submit(sumTask);
//        forkJoinPool.shutdown();

        long begin = System.currentTimeMillis();
        // 4. 獲取結果
        Integer result = sumTask.get();// wait for
        long end = System.currentTimeMillis();
        System.out.println(String.format("結果 %s ,耗時 %sms", result, end - begin));

        if (result == total) {
            System.out.println("測試成功");
        } else {
            System.out.println("fork join 使用失敗!!!!");
        }
    }
}

上面的程式碼是一個100個整數累加的任務,切分到小於20個數的時候直接進行累加,不再切分。