併發程式設計---ThreadLocal原始碼解析
在遇到執行緒安全問題的時候,我們一般都是使用同步來解決,比如內建鎖、顯示鎖等等。執行緒安全的主要起因是因為多個執行緒同時操作一個共享變數,如果我們換種思路,在某些場景下,我們為這些執行緒提供共享變數的副本,讓他們在自己的私有域中去操作這些變數,執行緒之間互不影響,那是不是就不會產生執行緒安全問題了?ThreadLocal提供了這樣的一種實現。
ThreadLocal內部封裝了ThreadLocalMap結構來為執行緒提供儲存資料的私有域空間,而Thread類提供了成員變數threadLocals來ThreadLocalMap,這樣ThreadLocal、TreadLocalMap、Thread就緊密聯絡起來了。ThreadLocal對外提供了get、set、remove等方法來供我們操作Thread的私有域空間ThreadLocalMap。這裡我們先說個大概,後面分析原始碼的時候再來一一解釋。
接下來直接看ThreadLocal的原始碼。
ThreadLocal的類結構
ThreadLocal的類是java.lang包下的一個普通類,沒有任何類的繼承與介面實現。
public class ThreadLocal<T> {
......
}
ThreadLocal的成員變數
private final int threadLocalHashCode = nextHashCode(); private static AtomicInteger nextHashCode = new AtomicInteger(); private static final int HASH_INCREMENT = 0x61c88647;
ThreadLocal的構造方法
public ThreadLocal() {}
ThreadLocal的內部類
ThreadLocalMap:
我們第一眼就看到ThreadLocalMap中又有一個內部類Entry,好,我們一個一個看。
Entry:
Entry就是ThreadLocalMap中實際存放資料的單個節點,為了便於理解,我們可以參照HashMap中的Node節點。Entry組成的陣列就是ThreadLocalMap的底層封裝資料的資料結構。
Entry繼承於WeakReference(弱引用),對於弱引用,我們先做個大概的瞭解。
如果一個物件僅被WeakReference指向,而沒有其他任何強引用指向的話,在下一次GC的時候,弱引用指向的物件就會被回收。
//ThreadLocalMap的map中定義內部類Entry,Entry就是具體儲存資料的結構
//Entry繼承了弱引用
//Entry的key是啥?是ThreadLocal的弱引用
static class Entry extends WeakReference<ThreadLocal<?>> {
//存放的資料
Object value;
//Entry的構造方法
Entry(ThreadLocal<?> k, Object v) {
super(k);
value = v;
}
}
Entry中有兩個成員變數,一個是Ojbect型別的value,還有一個是繼承於WeakReference的型別為ThreadLocal的reference。我們可以把reference看做是key。
接著繼續看ThreadLocalMap中的成員變數和構造方法。
static class ThreadLocalMap {
//節點陣列的初始化容量值
private static final int INITIAL_CAPACITY = 16;
//Entry節點陣列,存放資料的陣列
private Entry[] table;
//Entry陣列中實際儲存資料的數目,初始為0
private int size = 0;
//Entry陣列擴容的閾值
private int threshold;
//ThreadLocalMap的構造方法
ThreadLocalMap(ThreadLocal<?> firstKey, Object firstValue) {
//初始化Entry陣列,容量為預設的初始值16
table = new Entry[INITIAL_CAPACITY];
//threadLocalHashCode = nextHashCode(),
//INITIAL_CAPACITY為16,所以(INITIAL_CAPACITY - 1)的二進位制形式為1111,
//與(INITIAL_CAPACITY - 1)進行位與運算就是相當於threadLocalHashCode對16取模
//這是因為Entry陣列是一個長度為16的陣列圓環,而key的落腳點即是在這個HashCode對16取模的值
//i就是當前這個key在Entry環形陣列的索引值
int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
//將ThreadLocal和value值構建成一個Entry,放置在ENtry陣列中,
table[i] = new Entry(firstKey, firstValue);
//因為是構造方法,這裡肯定是第一次存入資料,所以size為1
size = 1;
//設定entry陣列的閾值,閾值為當前Entry陣列長度的三分之二
setThreshold(INITIAL_CAPACITY);
}
//這個方法是ThreadLocal的方法
private static int nextHashCode() {
//nextHashCode為AtomicInteger型別
//AtomicInteger的getAndAdd()方法就是以用Unsafe的設定方式去更新這個AtomicInteger
//更新為當前值+HASH_INCREMENT
return nextHashCode.getAndAdd(HASH_INCREMENT);
}
//這個方法是AtomicInteger的方法
public final int getAndAddInt(Object var1, long var2, int var4) {
int var5;
do {
//var5即為當前這個AtomicInteger的值
var5 = this.getIntVolatile(var1, var2);
} while(!this.compareAndSwapInt(var1, var2, var5, var5 + var4));
//將AtomicInteger的當前值var5更新為var5+var4,而war4即為增量
return var5;
}
//這個方法是Entry本身的方法
private void setThreshold(int len) {
//閾值為當前entry陣列長度的三分之二
threshold = len * 2 / 3;
}
//ThreadLocalMap的構造方法,引數為一個ThreadLocalMap
private ThreadLocalMap(ThreadLocalMap parentMap) {
//獲取引數ThreadLocalMap中的Entry陣列
Entry[] parentTable = parentMap.table;
//獲取引數Entry陣列的長度
int len = parentTable.length;
//設定閾值為陣列長度的三分之二
setThreshold(len);
//建立一個新的陣列,將陣列賦值給當前Entry陣列table
table = new Entry[len];
//迴圈遍歷
for (int j = 0; j < len; j++) {
//獲取引數entry陣列的每個entry節點
Entry e = parentTable[j];
if (e != null) {
@SuppressWarnings("unchecked")
//e.get()返回引用referent,這個referent即為ThreadLocal
ThreadLocal<Object> key = (ThreadLocal<Object>) e.get();
if (key != null) {
//獲取value
Object value = key.childValue(e.value);
//對key和value做完基本校驗後,組建新的Entry節點
Entry c = new Entry(key, value);
//計算下角標位置
int h = key.threadLocalHashCode & (len - 1);
while (table[h] != null)
//如果該下角標位置已經有元素了,計算下個索引位置
h = nextIndex(h, len);
//直到計算出的索引位置上沒有元素時,將新建的entry放到該索引位置
table[h] = c;
//entry陣列的元素數量加一
size++;
}
}
}
}
//當前下角標i的下一個索引位置,如果達到entry陣列的長度16的話,重新從0開始
private static int nextIndex(int i, int len) {
return ((i + 1 < len) ? i + 1 : 0);
}
}
ThreadLocalMap中維護了一個初始容量為16的entry陣列。這個entry陣列就是儲存資料的底層結構,還有一個閾值,看過HashMap底層原始碼的就不會對這個概念陌生,另外其實還有一個負載因子,不過這個負載因子並沒有宣告成員變數,而是在程式碼中直接使用的,這個負載因子為三分之二,我們可以看下setThreshold()這個方法,threshold = len * 2 / 3。
繼續往下看,有兩個方法比較重要的,是咱們理解ThreadLocalMap資料結構的重要切入點。
//根據當前索引位置和陣列長度獲取下一個索引值
private static int nextIndex(int i, int len) {
return ((i + 1 < len) ? i + 1 : 0);
}
//根據當前索引位置和陣列長度獲取上一個索引值
private static int prevIndex(int i, int len) {
return ((i - 1 >= 0) ? i - 1 : len - 1);
}
我們看nextIndex()方法,噹噹前索引值加1,如果小於陣列長度i+1,否則返回0。就是說如果當前索引值加一等於陣列的長度就返回0。我們想到了啥?圓鍾,23點再加一個小時等於24點,24就為一天的中時數,而24點也是零點,起點。我們會想到Entry陣列是一個環形狀。再看nextIndex()方法,當前索引值減1後如果小於0,返回陣列的長度減1,即15,就是i等於0的時候,i減一不是等於負一,而是十五,這個時候我們可以確認entry陣列就是一個環形結構。使用線性探測法來解決雜湊衝突的。
下圖即為Entry陣列的結構圖
圖片來源於:https://www.cnblogs.com/micrari/p/6790229.html
Entry陣列上每個節點為一個Entry,每個Entry由一個指向ThreadLocal的的弱引用為key,value即為我們設定的變數值。
這裡再想下怎麼通過Key(ThreadLocal)來計算索引值?
這個計算索引值不是通過類似key.hashCode()這種方式來計算的,而是根據型別為AtomicInteger的nextHashCode成員變數和增量值HASH_INCREMENT成員變數來計算的,計算方式就是通過nextHashCode加上HASH_INCREMENT值的和與Entry陣列長度的位與運算來計算的。如程式碼所示。
int i = key.threadLocalHashCode & (table.length - 1);
private final int threadLocalHashCode = nextHashCode();
private static int nextHashCode() {
return nextHashCode.getAndAdd(HASH_INCREMENT);
}
理解了Entry陣列的資料結構,我們繼續看ThreadLocalMap提供的主要方法。
獲取:private Entry getEntry(ThreadLocal<?> key)
//根據key值獲取Entry節點
private Entry getEntry(ThreadLocal<?> key) {
//根據key值計算索引位置
int i = key.threadLocalHashCode & (table.length - 1);
//獲取entry陣列中該索引位置的Entry節點
Entry e = table[i];
if (e != null && e.get() == key)
//如果e不為null並且e的Reference(ThreadLocal)與key相同,直接返回e節點
return e;
else
//如果根據計算出的索引值沒有找到Entry節點
return getEntryAfterMiss(key, i, e);
}
private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
Entry[] tab = table;
int len = tab.length;
while (e != null) {
//如果e不為null
//獲取entry的key,即ThreadLocal
ThreadLocal<?> k = e.get();
if (k == key)
//如果和key相等直接返回該元素
return e;
if (k == null)
//如果k為null,清理無效的entry,或者說清理ThreadLocal已經被回收的entry
expungeStaleEntry(i);
else
i = nextIndex(i, len);
e = tab[i];
}
//如果e為null,就直接返回null了
return null;
}
//該方法主要做了兩件事
//第一將索引為staleSlot的節點entry的value置為null,並且將entry置為null,有利於垃圾回收
//第二從索引stateSlot的下一個索引處開始遍歷判斷每個entry的ThreadLocal是否為null,如果為null,將
//該entry的value和entry本身置為null,如果不為null,進行rehash重新計算索引值,判斷重新計算出來的
//索引值和當前迴圈的索引值是否相等,如果相等,進入下一個迴圈,如果不等,在環形索引中尋找為節點為空的
//下角標,將e節點放置在這個索引位置
private int expungeStaleEntry(int staleSlot) {
//獲取ThreadLocalMap的entry陣列和陣列的長度
Entry[] tab = table;
int len = tab.length;
//因為在getEntryAfterMiss方法中已經判定k==null了
//既然key為null,所以顯示將key對應的value置為null
tab[staleSlot].value = null;
//顯示將這個節點entry也置為null,置為null有助於垃圾回收
tab[staleSlot] = null;
//entry陣列的元素個數減一
size--;
//執行Rehash直到再次遇到null值
Entry e;
int i;
//迴圈遍歷,i的初始值為當前下角標stateSlot的下一個索引位置
for (i = nextIndex(staleSlot, len);
//將entry陣列中下角標為當前遍歷的角標i的節點賦值給e
(e = tab[i]) != null;
//每迴圈完一次去獲取下一個索引位置賦值給i
i = nextIndex(i, len)) {
//獲取當前遍歷的entry的key值,即ThreadLocal
ThreadLocal<?> k = e.get();
if (k == null) {
//如果key(threadLocal)為null,即把key對應的value和當前這個節點都置為null
//有助於垃圾回收
e.value = null;
tab[i] = null;
size--;
} else {
//如果key不為null
//計算索引值
int h = k.threadLocalHashCode & (len - 1);
if (h != i) {
//如果新計算的索引值跟現在遍歷的索引值不相等
//將當前遍歷的索引值對應的節點置為null
tab[i] = null;
// Unlike Knuth 6.4 Algorithm R, we must scan until
// null because multiple entries could have been stale.
//在環形索引中尋找為節點為空的下角標,將e節點放置在這個索引位置
while (tab[h] != null)
h = nextIndex(h, len);
tab[h] = e;
}
}
}
return i;
}
ThreadLocalMap通過key(ThreadLocal)來獲取Entry節點,首先通過key來計算索引值,再通過索引值獲取到某個Entry。如果Entry的key與引數key相同,則直接返回這個Entry節點;如果Entry為null,則直接返回null;如果Entry不為null,但是key不相同,就走getEntryAfterMiss()這個方法。這個方法裡面主要是判斷entry的key(ThreadLocal)。如果key既不相等也不為null,迴圈遍歷下個索引值對應的entry。但是如果key為null,這個時候會走expungeStaleEntry()方法了,這個方法比較重要,我們單獨來說說。
首先我們想象key為null代表著什麼?key為threadLocal,即threadLocal為null,而threadLocal為弱引用指向的,其實這裡表示為ThreadLocal被回收了,雖然ThreadLocal被回收了,但是key對應的value是跟Thread掛鉤的,value可能還沒被回收,所以這裡我們需要顯示的將value和entry置為null,以便於垃圾回收這些物件,同時防止記憶體洩露。不僅如此程式碼中還會開始遍歷該entry索引後面的整個Entry陣列,如果那個entry的key為null,都會顯示將object和entry置為null,讓其被回收,防止記憶體洩露。
設定值:private void set(ThreadLocal<?> key , Object value)
private void set(ThreadLocal<?> key, Object value) {
//獲取entry陣列和陣列的長度
Entry[] tab = table;
int len = tab.length;
//計算key值對應的索引位置
int i = key.threadLocalHashCode & (len-1);
//根據計算的索引值獲取對應的Entry,從該索引處開始迴圈向後遍歷
for (Entry e = tab[i];
e != null;
e = tab[i = nextIndex(i, len)]) {
//根據Entry獲取ThreadLocal
ThreadLocal<?> k = e.get();
if (k == key) {
//如果key與當前entry的key相同
//直接用引數value覆蓋entry中的原value
e.value = value;
return;
}
if (k == null) {
//如果k為null
//替換無效的entry
replaceStaleEntry(key, value, i);
return;
}
}
//建立一個新的Entry節點
tab[i] = new Entry(key, value);
int sz = ++size;
if (!cleanSomeSlots(i, sz) && sz >= threshold)
//如果元素個數大於或者等於閾值,擴容
rehash();
}
private void replaceStaleEntry(ThreadLocal<?> key, Object value,int staleSlot) {
Entry[] tab = table;
int len = tab.length;
Entry e;
int slotToExpunge = staleSlot;
//向索引staleSlot的前面開始迴圈遍歷,直到tab[i]不為null
//向前遍歷找到最近的一個ThreadLocal為null的entry
for (int i = prevIndex(staleSlot, len);
(e = tab[i]) != null;
i = prevIndex(i, len))
if (e.get() == null)
//如果entry的key(ThreadLocal)為null
//獲取entry的索引值
slotToExpunge = i;
//向staleSlot的後面遍歷
for (int i = nextIndex(staleSlot, len);
(e = tab[i]) != null;
i = nextIndex(i, len)) {
ThreadLocal<?> k = e.get();
if (k == key) {
//如果entry的key等於引數key
//直接覆蓋entry的value值
e.value = value;
//??
tab[i] = tab[staleSlot];
tab[staleSlot] = e;
if (slotToExpunge == staleSlot)
slotToExpunge = i;
cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
return;
}
if (k == null && slotToExpunge == staleSlot)
slotToExpunge = i;
}
// If key not found, put new entry in stale slot
tab[staleSlot].value = null;
tab[staleSlot] = new Entry(key, value);
// If there are any other stale entries in run, expunge them
if (slotToExpunge != staleSlot)
cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
}
private boolean cleanSomeSlots(int i, int n) {
boolean removed = false;
Entry[] tab = table;
int len = tab.length;
do {
i = nextIndex(i, len);
Entry e = tab[i];
if (e != null && e.get() == null) {
n = len;
removed = true;
i = expungeStaleEntry(i);
}
} while ( (n >>>= 1) != 0);
return removed;
}
private void rehash() {
expungeStaleEntries();
// Use lower threshold for doubling to avoid hysteresis
if (size >= threshold - threshold / 4)
resize();
}
private void resize() {
Entry[] oldTab = table;
int oldLen = oldTab.length;
int newLen = oldLen * 2;
Entry[] newTab = new Entry[newLen];
int count = 0;
for (int j = 0; j < oldLen; ++j) {
Entry e = oldTab[j];
if (e != null) {
ThreadLocal<?> k = e.get();
if (k == null) {
e.value = null; // Help the GC
} else {
int h = k.threadLocalHashCode & (newLen - 1);
while (newTab[h] != null)
h = nextIndex(h, newLen);
newTab[h] = e;
count++;
}
}
}
setThreshold(newLen);
size = count;
table = newTab;
}
終於看完ThreadLocalMap了,我們可以接著看ThreadLocal的程式碼了!。
protected T initialValue() {
return null;
}
設定,void set(T value);
public void set(T value) {
//獲取當前執行緒
Thread t = Thread.currentThread();
//獲取當前執行緒的TreadLocalMap
ThreadLocalMap map = getMap(t);
if (map != null)
//如果ThreadLocalMap不為null,直接呼叫ThreadLocalMap的set方法
map.set(this, value);
else
//如果ThreadLocalMap為,以當前執行緒和value值建立ThreadLocalMap
createMap(t, value);
}
ThreadLocalMap getMap(Thread t) {
return t.threadLocals;
}
void createMap(Thread t, T firstValue) {
//建立ThreadLocalMap並用當前執行緒指向該map
t.threadLocals = new ThreadLocalMap(this, firstValue);
}
從set()方法可以看出每個執行緒(Thread)有一個threadLocals變數,如程式碼所示:
//Thread類的成員變數
ThreadLocal.ThreadLocalMap threadLocals = null;
ThreadLocal在設定值的時候,會先判斷當前執行緒有沒有初始化ThreadLocalMap,如果沒有,先根據當前thredLocal(key)和value值生成ThreadLocalMap,並用該執行緒的成員變數threadLocals指向這個ThreadLocalMap;如果當前執行緒已經關聯ThreadLocalMap了,則直接通過ThreadLocalMap的set方法設定值。
獲取:T get();
public T get() {
//獲取當前執行緒
Thread t = Thread.currentThread();
//獲取當前執行緒關聯的ThreadLocalMap
ThreadLocalMap map = getMap(t);
if (map != null) {
//如果ThreadLocalMap不為null,根據key(ThreadLocal)值獲取entry
ThreadLocalMap.Entry e = map.getEntry(this);
if (e != null) {
@SuppressWarnings("unchecked")
T result = (T)e.value;
//獲取entry的value值返回
return result;
}
}
//否則初始化當前執行緒的ThreadLocalMap,value為null
return setInitialValue();
}
private T setInitialValue() {
//value為空
T value = initialValue();
//獲取當前執行緒
Thread t = Thread.currentThread();
ThreadLocalMap map = getMap(t);
if (map != null)
map.set(this, value);
else
createMap(t, value);
return value;
}
protected T initialValue() {
return null;
}
到此,ThreadLocal的主要程式碼就介紹完了。
ThreadLocal是否存在記憶體洩露問題?
會,我們先來看下ThreadLocal的引用和資料結構圖,圖片來源於:http://www.importnew.com/22039.html,map指的ThreadLocalMap,實線代表強引用,虛線代表弱引用。
我們看到ThreadLocal有一個強引用和一個弱引用,強引用來自高層程式碼中的引用,比如ThreadLocal tl = new TheadLocal(),tl這就是一個強引用,而弱應用來自於ThreadLocalMap中的Entry的key的引用。當高層程式碼中把threadlocal例項置為null以後,就沒有任何強引用指向threadlocal例項,而只有一個弱引用去指向ThreadLocal,但是我們知道弱引用指向的物件在GC時是會被回收的,所以threadlocal將會被gc回收。這也是Entry中的key使用弱應用的原因,否則TreadLoca就算在高層程式碼中釋放引用後,因為Entry還存在,key仍然指向ThreadLocal,所以讓不會被回收,容易造成記憶體洩露。
當ThreadLocal被回收後,我們的value還不能回收,因為存在一條從current thread連線過來的強引用.,只要thread存在,這個引用就會一直存在,只有當thread結束以後, current thread才會被銷燬,強引用才會斷開, 此時Current Thread, Map, value才能全部被GC回收。
所以這裡存在一個風險就是,在current Thread到銷燬的這段時間內,存在由於value值過多或者過大導致的記憶體洩露問題,我們在想下,如果我們是使用的執行緒池,出現什麼結果,執行緒用完後,直接放回執行緒池中,不會被銷燬,那麼那些value就會一直存在,這樣產生記憶體洩露的可能性大大增加。
JDK是怎麼解決這個問題的呢?
我們回過頭來在看看ThreadLocalMap的set和get方法,我們發現程式碼裡都會迴圈遍歷Entry陣列,檢查entey中的key(ThreadLocal)是否為null,如果為null,會顯示的將entry的value和entry本身置為null,這樣以便entry和entry的value能被GC回收,防止記憶體洩露。
既然知道了記憶體洩露的前因後果,我們在使用TheadLocal時候就要特別注意這方面的問題,比如我們再用完TheadLocal後記得用remove()方法去清除資料。