1. 程式人生 > 其它 >011Java併發包010分支合併框架

011Java併發包010分支合併框架

本文主要學習了多執行緒的分支合併框架。

部分內容來自以下部落格:

https://segmentfault.com/a/1190000016781127

https://segmentfault.com/a/1190000016877931

1 簡介

JDK1.7版本引入了一套Fork/Join框架。Fork/Join框架的基本思想就是將一個大任務分解(Fork)成一系列子任務,子任務可以繼續往下分解,當多個不同的子任務都執行完成後,可以將它們各自的結果合併(Join)成一個大結果,最終合併成大任務的結果。

Fork/Join 框架要完成兩件事情:

1)Fork:把一個複雜任務進行分拆

2)Join:把分拆任務的結果進行合併

Fork/Join框架的實現非常複雜,內部大量運用了位操作和無鎖演算法。

Fork/Join框架內部還涉及到三大核心元件:ForkJoinPool(執行緒池)、ForkJoinTask(任務)、ForkJoinWorkerThread(工作執行緒),外加WorkQueue(任務佇列)。

2 類和介面

2.1 ForkJoinPool

ForkJoinPool是分支合併池,類似於執行緒池ThreadPoolExecutor,同樣是ExecutorService介面的一個實現類。

ForkJoinPool類的實現:

1 public class ForkJoinPool extends AbstractExecutorService {

在ForkJoinPool類中提供了三個構造方法:

1 public ForkJoinPool();
2 public ForkJoinPool(int parallelism);
3 public ForkJoinPool(int parallelism, ForkJoinWorkerThreadFactory factory, UncaughtExceptionHandler handler, boolean asyncMode);

最終呼叫的是下面這個私有構造器:

1 private ForkJoinPool(int parallelism, ForkJoinWorkerThreadFactory factory, UncaughtExceptionHandler handler, int
mode, String workerNamePrefix);

其引數含義如下:

parallelism:並行級別,預設值為CPU核心數,ForkJoinPool裡工作執行緒數量與該引數有關,但它不表示最大執行緒數。

factory:工作執行緒工廠,預設是DefaultForkJoinWorkerThreadFactory,其實就是用來建立ForkJoinWorkerThread工作執行緒物件。

handler:異常處理器。

mode:排程模式,true表示FIFO_QUEUE,false表示LIFO_QUEUE。

workerNamePrefix:工作執行緒的名稱字首。

2.2 ForkJoinTask

ForkJoinTask是Future介面的抽象實現類,提供了用於分解任務的fork()方法和用於合併任務的join()方法。

在ThreadPoolExecutor類中,使用執行緒池執行任務呼叫的execute()方法中要求傳入Runnable介面的例項。但是在ForkJoinPool類中,除了可以傳入Runnable介面的例項外,還可以傳入ForkJoinTask抽象類的例項,並且傳入Runnable介面的例項也會被適配為ForkJoinTask抽象類的例項。

2.3 RecursiveTask

通常情況下使用ForkJoinTask抽象類的例項,並不需要直接繼承ForkJoinTask類,只需要繼承其子類:

1)RecursiveAction:用於沒有返回結果的任務

2)RecursiveTask:用於有返回結果的任務

其中,最常用的還是RecursiveTask類。

2.4 ForkJoinWorkerThread

ForkJoinWorkerThread類是Thread的子類,作為執行緒池中的工作執行緒執行任務,其內部維護了一個WorkerQueue型別的雙向任務佇列。

工作執行緒在執行任務時,優先處理自身任務佇列中的任務(FIFO或者LIFO),當自身佇列中的任務為空時,會竊取其他任務佇列中的任務(FIFO)。

2.5 WorkerQueue

WorkerQueue類是ForkJoinPool類的一個內部類,代表儲存ForkJoinTask例項的雙端佇列。

在ForkJoinPool類的私有構造方法中,有一個int型別的mode引數,其取值如下:

1 static final int LIFO_QUEUE = 0;
2 static final int FIFO_QUEUE = 1 << 16;

當入參為LIFO_QUEUE時,表示同步,對於工作執行緒(Worker)自身佇列中的任務,採用後進先出(LIFO)的方式執行。

當入參為FIFO_QUEUE時,表示非同步,對於工作執行緒(Worker)自身佇列中的任務,採用先進先出(FIFO)的方式執行。

3 實現原理

3.1 提交任務

使用ForkJoinPool的submit方法提交任務得到ForkJoinTask物件:

1 public <T> ForkJoinTask<T> submit(ForkJoinTask<T> task) {
2   if (task == null)
3     throw new NullPointerException();
4   externalPush(task);
5   return task;
6 }

繼續檢視externalPush方法:

 1 final void externalPush(ForkJoinTask<?> task) {
 2   WorkQueue[] ws; WorkQueue q; int m;
 3   int r = ThreadLocalRandom.getProbe();
 4   int rs = runState;
 5   if ((ws = workQueues) != null && (m = (ws.length - 1)) >= 0 &&
 6     (q = ws[m & r & SQMASK]) != null && r != 0 && rs > 0 &&
 7     U.compareAndSwapInt(q, QLOCK, 0, 1)) {
 8     ForkJoinTask<?>[] a; int am, n, s;
 9     if ((a = q.array) != null &&
10       (am = a.length - 1) > (n = (s = q.top) - q.base)) {
11       int j = ((am & s) << ASHIFT) + ABASE;
12       U.putOrderedObject(a, j, task);
13       U.putOrderedInt(q, QTOP, s + 1);
14       U.putIntVolatile(q, QLOCK, 0);
15       if (n <= 1)
16         signalWork(ws, q);
17       return;
18     }
19     U.compareAndSwapInt(q, QLOCK, 1, 0);
20   }
21   externalSubmit(task);
22 }

該方法包含兩個部分:

1)嘗試將任務新增到任務佇列,新增後則建立或啟用一個工作執行緒,在此過程中使用了CAS保證執行緒安全。

2)新增佇列失敗,則呼叫externalSubmit方法初始化佇列,並將任務加入到佇列。

3.2 分解任務

3.2.1 建立或喚醒工作執行緒

呼叫ForkJoinTask的fork方法完成任務分解:

1 public final ForkJoinTask<V> fork() {
2   Thread t;
3   if ((t = Thread.currentThread()) instanceof ForkJoinWorkerThread)// 呼叫執行緒為工作執行緒
4     ((ForkJoinWorkerThread)t).workQueue.push(this);// 將任務新增到自身佇列
5   else
6     ForkJoinPool.common.externalPush(this);// 呼叫ForkJoinPool的externalPush方法
7   return this;
8 }

該方法包含兩個部分:

1)呼叫執行緒為工作執行緒,將任務新增到自身佇列。

2)呼叫執行緒為其他外部執行緒,繼續呼叫ForkJoinPool的externalPush方法,嘗試將任務新增到任務佇列並激活工作執行緒。

繼續檢視push方法,新增任務到自身佇列:

 1 final void push(ForkJoinTask<?> task) {
 2   ForkJoinTask<?>[] a; ForkJoinPool p;
 3   int b = base, s = top, n;
 4   if ((a = array) != null) {  // ignore if queue removed
 5     int m = a.length - 1;   // fenced write for task visibility
 6     U.putOrderedObject(a, ((m & s) << ASHIFT) + ABASE, task);
 7     U.putOrderedInt(this, QTOP, s + 1);
 8     if ((n = s - b) <= 1) {
 9       if ((p = pool) != null)
10         p.signalWork(p.workQueues, this);// 喚醒或建立工作執行緒
11     }
12     else if (n >= m)
13       growArray();// 擴容
14   }
15 }

該方法包含兩個部分:

1)判斷是否需要擴容,不需要擴容則喚醒或建立工作執行緒。

2)需要擴容,則進行擴容操作。

繼續檢視signalWork方法,建立或喚醒工作執行緒:

 1 final void signalWork(WorkQueue[] ws, WorkQueue q) {
 2   long c; int sp, i; WorkQueue v; Thread p;
 3   while ((c = ctl) < 0L) {            // too few active
 4     if ((sp = (int)c) == 0) {         // 沒有空閒工作程序
 5       if ((c & ADD_WORKER) != 0L)      // 工作程序太少
 6         tryAddWorker(c);// 增加工作程序
 7       break;
 8     }
 9     // 有工作程序,喚醒
10     if (ws == null)              // unstarted/terminated
11       break;
12     if (ws.length <= (i = sp & SMASK))     // terminated
13       break;
14     if ((v = ws[i]) == null)          // terminating
15       break;
16     int vs = (sp + SS_SEQ) & ~INACTIVE;    // next scanState
17     int d = sp - v.scanState;         // screen CAS
18     long nc = (UC_MASK & (c + AC_UNIT)) | (SP_MASK & v.stackPred);
19     if (d == 0 && U.compareAndSwapLong(this, CTL, c, nc)) {
20       v.scanState = vs;           // activate v
21       if ((p = v.parker) != null)
22         U.unpark(p);
23       break;
24     }
25     if (q != null && q.base == q.top)     // no more work
26       break;
27   }
28 }

繼續檢視tryAddWorker方法:

 1 private void tryAddWorker(long c) {
 2   boolean add = false;
 3   do {
 4     // 設定活躍工作執行緒數和總工作執行緒數
 5     long nc = ((AC_MASK & (c + AC_UNIT)) |
 6          (TC_MASK & (c + TC_UNIT)));
 7     if (ctl == c) {
 8       int rs, stop;         // check if terminating
 9       if ((stop = (rs = lockRunState()) & STOP) == 0)
10         add = U.compareAndSwapLong(this, CTL, c, nc);
11       unlockRunState(rs, rs & ~RSLOCK);
12       if (stop != 0)
13         break;
14       if (add) {
15         // 建立工作執行緒
16         createWorker();
17         break;
18       }
19     }
20   } while (((c = ctl) & ADD_WORKER) != 0L && (int)c == 0);
21 }

繼續檢視createWorker方法:

 1 private boolean createWorker() {
 2   ForkJoinWorkerThreadFactory fac = factory;
 3   Throwable ex = null;
 4   ForkJoinWorkerThread wt = null;
 5   try {
 6     // 使用執行緒池工廠建立執行緒
 7     if (fac != null && (wt = fac.newThread(this)) != null) {
 8       // 啟動執行緒
 9       wt.start();
10       return true;
11     }
12   } catch (Throwable rex) {
13     ex = rex;
14   }
15   // 出現異常,登出該工作執行緒
16   deregisterWorker(wt, ex);
17   return false;
18 }

3.2.2 啟動任務

ForkJoinWorkerThread在執行start方法後,會執行run方法:

 1 public void run() {
 2   if (workQueue.array == null) { // only run once
 3     Throwable exception = null;
 4     try {
 5       onStart();
 6       pool.runWorker(workQueue);
 7     } catch (Throwable ex) {
 8       exception = ex;
 9     } finally {
10       try {
11         onTermination(exception);
12       } catch (Throwable ex) {
13         if (exception == null)
14           exception = ex;
15       } finally {
16         pool.deregisterWorker(this, exception);
17       }
18     }
19   }
20 }

在run方法內部呼叫了ForkJoinPool物件的runWorker方法:

 1 final void runWorker(WorkQueue w) {
 2   w.growArray();          // 初始化任務佇列
 3   int seed = w.hint;        // initially holds randomization hint
 4   int r = (seed == 0) ? 1 : seed; // avoid 0 for xorShift
 5   for (ForkJoinTask<?> t;;) {
 6     if ((t = scan(w, r)) != null)// 嘗試獲取任務
 7       w.runTask(t);// 執行任務
 8     else if (!awaitWork(w, r))// 獲取失敗,加入等待任務佇列
 9       break;// 等待失敗,跳出方法並登出工作執行緒
10     r ^= r << 13; r ^= r >>> 17; r ^= r << 5; // xorshift
11   }
12 }

3.2.3 竊取任務

使用scan方法竊取任務:

 1 private ForkJoinTask<?> scan(WorkQueue w, int r) {
 2   WorkQueue[] ws; int m;
 3   if ((ws = workQueues) != null && (m = ws.length - 1) > 0 && w != null) {
 4     int ss = w.scanState;           // initially non-negative
 5     for (int origin = r & m, k = origin, oldSum = 0, checkSum = 0;;) {
 6       WorkQueue q; ForkJoinTask<?>[] a; ForkJoinTask<?> t;
 7       int b, n; long c;
 8       if ((q = ws[k]) != null) {// 定位任務佇列
 9         if ((n = (b = q.base) - q.top) < 0 &&
10           (a = q.array) != null) {   // non-empty
11           long i = (((a.length - 1) & b) << ASHIFT) + ABASE;
12           if ((t = ((ForkJoinTask<?>)
13                U.getObjectVolatile(a, i))) != null &&
14             q.base == b) {
15             if (ss >= 0) {
16               if (U.compareAndSwapObject(a, i, t, null)) {
17                 q.base = b + 1;
18                 if (n < -1)    // signal others
19                   signalWork(ws, q);// 建立獲喚醒工作執行緒執行任務
20                 return t;
21               }
22             }
23             else if (oldSum == 0 &&  // try to activate
24                 w.scanState < 0)
25               tryRelease(c = ctl, ws[m & (int)c], AC_UNIT);// 喚醒棧頂工作執行緒
26           }
27           if (ss < 0)          // refresh
28             ss = w.scanState;
29           r ^= r << 1; r ^= r >>> 3; r ^= r << 10;
30           origin = k = r & m;      // move and rescan
31           oldSum = checkSum = 0;
32           continue;
33         }
34         checkSum += b;
35       }
36       // 已掃描全部工作執行緒,但並未找到任務
37       if ((k = (k + 1) & m) == origin) {  // continue until stable
38         if ((ss >= 0 || (ss == (ss = w.scanState))) &&
39           oldSum == (oldSum = checkSum)) {
40           if (ss < 0 || w.qlock < 0)  // already inactive
41             break;
42           int ns = ss | INACTIVE;    // 嘗試對當前工作執行緒滅活
43           long nc = ((SP_MASK & ns) |
44                (UC_MASK & ((c = ctl) - AC_UNIT)));
45           w.stackPred = (int)c;     // hold prev stack top
46           U.putInt(w, QSCANSTATE, ns);
47           if (U.compareAndSwapLong(this, CTL, c, nc))
48             ss = ns;
49           else
50             w.scanState = ss;     // back out
51         }
52         checkSum = 0;
53       }
54     }
55   }
56   return null;
57 }

3.2.4 執行任務

竊取到任務後,呼叫runTask方法執行任務:

 1 final void runTask(ForkJoinTask<?> task) {
 2   if (task != null) {
 3     scanState &= ~SCANNING; // mark as busy
 4     (currentSteal = task).doExec();// 執行任務
 5     U.putOrderedObject(this, QCURRENTSTEAL, null); // release for GC
 6     execLocalTasks();// 執行本地任務
 7     ForkJoinWorkerThread thread = owner;
 8     if (++nsteals < 0)   // collect on overflow
 9       transferStealCount(pool);// 增加竊取任務數
10     scanState |= SCANNING;
11     if (thread != null)
12       thread.afterTopLevelExec();// 執行鉤子函式
13   }
14 }

3.2.5 阻塞等待

如何未竊取到任務,會呼叫awaitWork方法等待獲取任務:

 1 private boolean awaitWork(WorkQueue w, int r) {
 2   if (w == null || w.qlock < 0)         // w is terminating
 3     return false;
 4   for (int pred = w.stackPred, spins = SPINS, ss;;) {
 5     if ((ss = w.scanState) >= 0)
 6       break;
 7     else if (spins > 0) {
 8       r ^= r << 6; r ^= r >>> 21; r ^= r << 7;
 9       if (r >= 0 && --spins == 0) {     // randomize spins
10         WorkQueue v; WorkQueue[] ws; int s, j; AtomicLong sc;
11         if (pred != 0 && (ws = workQueues) != null &&
12           (j = pred & SMASK) < ws.length &&
13           (v = ws[j]) != null &&    // see if pred parking
14           (v.parker == null || v.scanState >= 0))
15           spins = SPINS;        // continue spinning
16       }
17     }
18     else if (w.qlock < 0)           // recheck after spins
19       return false;
20     else if (!Thread.interrupted()) {
21       long c, prevctl, parkTime, deadline;
22       int ac = (int)((c = ctl) >> AC_SHIFT) + (config & SMASK);
23       if ((ac <= 0 && tryTerminate(false, false)) ||
24         (runState & STOP) != 0)      // pool terminating
25         return false;
26       if (ac <= 0 && ss == (int)c) {    // is last waiter
27         prevctl = (UC_MASK & (c + AC_UNIT)) | (SP_MASK & pred);
28         int t = (short)(c >>> TC_SHIFT); // shrink excess spares
29         if (t > 2 && U.compareAndSwapLong(this, CTL, c, prevctl))
30           return false;         // else use timed wait
31         parkTime = IDLE_TIMEOUT * ((t >= 0) ? 1 : 1 - t);
32         deadline = System.nanoTime() + parkTime - TIMEOUT_SLOP;
33       }
34       else
35         prevctl = parkTime = deadline = 0L;
36       Thread wt = Thread.currentThread();
37       U.putObject(wt, PARKBLOCKER, this);  // emulate LockSupport
38       w.parker = wt;
39       if (w.scanState < 0 && ctl == c)   // recheck before park
40         U.park(false, parkTime);
41       U.putOrderedObject(w, QPARKER, null);
42       U.putObject(wt, PARKBLOCKER, null);
43       if (w.scanState >= 0)
44         break;
45       if (parkTime != 0L && ctl == c &&
46         deadline - System.nanoTime() <= 0L &&
47         U.compareAndSwapLong(this, CTL, c, prevctl))
48         return false;           // shrink pool
49     }
50   }
51   return true;
52 }

3.3 合併任務

使用ForkJoinTask的join方法可以獲取任務的執行結果:

1 public final V join() {
2   int s;
3   if ((s = doJoin() & DONE_MASK) != NORMAL)
4     reportException(s);
5   return getRawResult();
6 }

檢視doJoin方法:

1 private int doJoin() {
2   int s; Thread t; ForkJoinWorkerThread wt; ForkJoinPool.WorkQueue w;
3   return (s = status) < 0 ? s :
4     ((t = Thread.currentThread()) instanceof ForkJoinWorkerThread) ?
5     (w = (wt = (ForkJoinWorkerThread)t).workQueue).
6     tryUnpush(this) && (s = doExec()) < 0 ? s :
7     wt.pool.awaitJoin(w, this, 0L) :
8     externalAwaitDone();
9 }

4 使用

4.1 計算多個整數的和

任務類定義,因為需要返回結果,所以繼承RecursiveTask,並覆寫compute方法。

任務的拆分通過ForkJoinTask的fork方法執行,join方法用於等待任務執行後返回。

 1 class SumTask extends RecursiveTask<Integer> {
 2   private static final int THRESHOLD = 10;// 拆分閾值
 3   private int begin;// 拆分開始值
 4   private int end;// 拆分結束值
 5   public SumTask(int begin, int end) {
 6     this.begin = begin;
 7     this.end = end;
 8   }
 9   @Override
10   protected Integer compute() {
11     Integer value = 0;
12     if (end - begin <= THRESHOLD) {// 小於閾值,直接計算
13       for (int i = begin; i <= end; i++) {
14         value += i;
15       }
16     } else {// 大於閾值,遞迴計算
17       int middle = (begin + end) / 2;
18       SumTask beginTask = new SumTask(begin, middle);
19       SumTask endTask = new SumTask(middle + 1, end);
20       beginTask.fork();
21       endTask.fork();
22       value = beginTask.join() + endTask.join();
23     }
24     return value;
25   }
26 }
27 public class DemoTest {
28   public static void main(String[] args) {
29     SumTask sumTask = new SumTask(1, 100);
30     ForkJoinPool pool = new ForkJoinPool();
31     try {
32       ForkJoinTask<Integer> task = pool.submit(sumTask);
33       System.out.println(task.get());
34     } catch (Exception e) {
35       e.printStackTrace();
36     } finally {
37       pool.shutdown();
38     }
39   }
40 }

最終結果是5050。