- ThreadLocal原理
- ThreadLocal源码研读
- set方法
- getMap
- createMap
- map.set
- get方法
- map.getEntry(this)
- remove
- FastThreadLocal
ThreadLocal原理
我们都知道当使用ThreadLocal维护变量时,ThreadLocal为每个使用该变量的线程提供独立的变量副本,所以每一个线程都可以独立地改变自己的副本,而不会影响其它线程所对应的副本。
详细介绍一下ThreadLocal是如何实现为线程提供变量副本的,方便下面源码的理解:
首先我们要知道每个线程下都有一个私有变量map,当我们使用ThreadLocal进行set(val)变量时,会向当前线程下的map中put一个键为当前ThreadLocal对象(虚引用),值为val的键值对,这样当使用ThreadLocal的get方法时,会直接向当前线程下的map获得键为此ThreadLocal的值。由于此操作只在当前线程下,所以完美的避免了并发
如果没看懂,建议你多读几遍,带着问题往下看。
ThreadLocal源码研读
对于ThreadLocal源码 本身没有什么好研究的,因为它就五个我们可以调用的方法。ThreadLocalMap才是我们要研究的核心。
为了研究ThreadLocalMap源码,我们从ThreadLocal的set方法开始
set方法
就从我们常用的set方法开始。
public void set(T value) {//获取当前执行线程
Thread t = Thread.currentThread();
//获得当前线程的map
ThreadLocalMap map = getMap(t);
//如果map不等于null,key为当前ThreadLocal对象,value为我们的值,如果获得的map为null,则为该线程初始化map
if (map != null)
map.set(this, value);
else
createMap(t, value);
}
在这个方法里面有三个需要解释的地方:
- getMap(t)
- createMap(t, value)
- map.set(this, value)
getMap
此方法实质就是获得当前线程的map对象,而此map是绑定在线程上的
Thread t) {//Thread类中有一个ThreadLocalMap类型的对象threadLocals
return t.threadLocals;
}public class Thread implements Runnable {
...
ThreadLocal.ThreadLocalMap threadLocals = null;
...
通过源码可以发现getMap(t)方法为获得当前线程的map。追踪到Thread类中,可以发现该类中有一个ThreadLocalMap类型的threadLocals变量。
createMap
而对于 set方法中的createMap(t, value); 方法也很简单,就是为该线程初始化一个map,使用的构造函数会为该map插入第一个键值对。
void createMap(Thread t, T firstValue) {t.threadLocals = new ThreadLocalMap(this, firstValue);
}Object firstValue) {
//初始大小:16
table = new Entry[INITIAL_CAPACITY];
//threadLocal的hashCode与Entry大小进行&操作得到该value放的位置i
int i = firstKey.threadLocalHashCode & (INITIAL_CAPACITY - 1);
table[i] = new Entry(firstKey, firstValue);
size = 1;
//设置阈值private void setThreshold(int len) {
threshold = len * 2 / 3;
}
通过ThreadLocalMap构造函数,即发现了ThreadLocalMap也是一个Entry数组实现的map,通过hash得到数组的下标,把值放入即可。firstKey就是当前的ThreadLocal对象。
有个问题为什么初始化Entry数组的大小为24
2
4
呢?其实不只初始化为24
2
4
,在每次扩容是都是2倍的扩容。原因就是2n−1
2
n
−
1
的二进制都是1,在与hash码进行与操作时不会造成浪费。
还有一点需要提出的是:firstKey.threadLocalHashCode
new AtomicInteger();
private static final int HASH_INCREMENT = 0x61c88647;
private static int nextHashCode() {
return
通过源码可以发现nextHashCode每次都是增加HASH_INCREMENT,而HASH_INCREMENT的值为0x61c88647,有什么含义吗?
原来0x61c88647做为hash码与2n−1
2
n
−
1
进行与操作时会减少冲突,这与jdk8中hashMap源码的(h = key.hashCode()) ^ (h >>> 16)含义一样。
至于为什么使用0x61c88647 会减少冲突,这要问那些大数学家了。我把google搜索到的内容贴上:
This number represents the golden ratio (sqrt(5)-1) times two to the power of 31 ((sqrt(5)-1) * (2^31)). The result is then a golden number, either 2654435769 or -1640531527.
好像是黄金分割数。
We established thus that the HASH_INCREMENT has something to do with fibonacci hashing, using the golden ratio. If we look carefully at the way that hashing is done in the ThreadLocalMap, we see why this is necessary. The standard java.util.HashMap uses linked lists to resolve clashes. The ThreadLocalMapsimply looks for the next available space and inserts the element there. It finds the first space by bit masking, thus only the lower few bits are significant. If the first space is full, it simply puts the element in the next available space. The HASH_INCREMENT spaces the keys out in the sparce hash table, so that the possibility of finding a value next to ours is reduced.
在构造函数中还new了一个Entry对象。而Entry对象继承了弱引用
class Entry extends WeakReference<ThreadLocal<?>> {/** The value associated with this ThreadLocal. */
Object value;
Entry(ThreadLocal<?> k, Object v) {
super(k);
value = v;
}
}
通过这里我们发现,在ThreadLocalMap中的key其实是ThreadLocal的虚引用。
我们知道引分为:强引用、软引用、弱引用、虚引用。对于弱引用而言,处于第三档,一般在第二次GC时被回收。
那么ThradLocal为什么要用虚引用呢?
原因就是避免内存泄漏,当jvm进行垃圾回收时,会从root节点开始,把不可达的对象进行清理。而ThreadLocalMap是绑定在线程上的,只要线程不被销毁,那么此对象就一直可达。具体的引用链为:currentThread(当前线程)->threadLocals(ThreadLocalMap对象)->Entry数组->某个entry对象。ThreadLocal有一套自己的清理机制,会在下面详细介绍。
在研究map.set方法时,首先看其它两个方法。
private static int nextIndex(int i, int len) {
return ((i + 1 < len) ? i + 1 : 0);
}
//获得以i为索引位置len为总长的上一个索引位置
private static int prevIndex(int i, int len) {
return ((i - 1 >= 0) ? i - 1 : len - 1);
}
通过这两个方法,我们可以发现其实此map就是一个环形的Entry数组。具体的示意图
map.set
private void set(ThreadLocal<?> key, Object value) {Entry[] tab = table;
int len = tab.length;
int i = key.threadLocalHashCode & (len-1);
//循环结束条件为:e=null
for (Entry e = tab[i];
e != null;
e = tab[i = nextIndex(i, len)]) {
ThreadLocal<?> k = e.get();
//找到key 重新赋值
if (k == key) {
e.value = value;
return;
}
//发现被gc回收的key 进行value的回收操作 并set新值
if (k == null) {
replaceStaleEntry(key, value, i);
return;
}
}
tab[i] = new Entry(key, value);
int sz = ++size;
//如果未进行段式清理 并且需要扩容 则进行rehash操作
if
解释:获得当前ThreadLocal的位置i,然后以i为起点向后进行遍历,如果找到可key则重新为其赋值,如果找到了被回收的key,进行清理并赋值,否则新建一个Entry对象。如果必要,进行重新hash的操作。
private void replaceStaleEntry(ThreadLocal<?> key, Object value,int staleSlot) {
Entry[] tab = table;
int len = tab.length;
Entry e;
//记录段式清理的回收起点
int slotToExpunge = staleSlot;
//以要回收的位置staleSlot开始,向前找第一个被GC回收的位置,
//如果找到则为slotToExpunge重新赋值
for (int i = prevIndex(staleSlot, len);
(e = tab[i]) != null;
i = prevIndex(i, len))
if (e.get() == null)
slotToExpunge = i;
//以要回收的位置staleSlot开始 向后
for (int i = nextIndex(staleSlot, len);
(e = tab[i]) != null;
i = nextIndex(i, len)) {
ThreadLocal<?> k = e.get();
//因为staleSlot是需要回收的,为了保证正常的顺序,进行位置的调换
if (k == key) {
e.value = value;
tab[i] = tab[staleSlot];
tab[staleSlot] = e;
//当前位置i需要被回收
if (slotToExpunge == staleSlot)
slotToExpunge = i;
//进行段式清理
cleanSomeSlots(expungeStaleEntry(slotToExpunge), len);
return;
}
//向后便利发现了key被回收 并且向前没找到key需要回收的 重新赋值回收的起点
if (k == null && slotToExpunge == staleSlot)
slotToExpunge = i;
}
// 如果没找到key 则new一个新的对象
tab[staleSlot].value = null;
tab[staleSlot] = new Entry(key, value);
// 如果回收起点改变了 进行回收操作
if
解释:首先以要回收的位置staleSlot开始,默认以staleSlot为第一个要回收的key。然后以staleSlot为起点,向前遍历,找到第一个被回收的key,记录下来。以staleSlot为起点,向后遍历,如果找到了要set的key,赋值后,与要回收的位置staleSlot进行交换,然后以当前位置为第一个要回收的起点,进行回收操作。
如果向后遍历没有找到要set的key,则进行新建操作。如果回收起点变了,进行清理操作。
Entry[] tab = table;
int len = tab.length;
// 回收staleSlot位置
tab[staleSlot].value = null;
tab[staleSlot] = null;
size--;
Entry e;
int i;
//以回收的位置为起点,向后遍历
//如果key被回收,则设置其value为null
//如果未被回收 进行rehash操作
for (i = nextIndex(staleSlot, len);
(e = tab[i]) != null;
i = nextIndex(i, len)) {
ThreadLocal<?> k = e.get();
//key为null 回收其value
if (k == null) {
e.value = null;
tab[i] = null;
size--;
} else {
//获得hash位置
int h = k.threadLocalHashCode & (len - 1);
if (h != i) {
tab[i] = null;
//因为用的是线性探测法 向后遍历
while (tab[h] != null)
h = nextIndex(h, len);
tab[h] = e;
}
}
}
//返回回收后 空的位置i
return
解释:首先把回收的位置staleSlot进行回收,然后以staleSlot为起点,向后遍历,如果发现其它被回收的key,对其value进行回收操作。如果key未被回收,则进行rehash操作。
private boolean cleanSomeSlots(int i, int n) {boolean removed = false;
Entry[] tab = table;
int len = tab.length;
//以i为起点 进行回收操作
do {
i = nextIndex(i, len);
Entry e = tab[i];
//发现被回收的key 进行段式清理
if (e != null && e.get() == null) {
n = len;
removed = true;
i = expungeStaleEntry(i);
}
} while ( (n >>>= 1) != 0);
return
解释:此方法,会以i为起点,进行回收操作,如果发现了被回收的key则进行段式清理,否则进行logn次的清理。
private void rehash() {expungeStaleEntries();
//size是否达到阈值
//实际阈值为:(len * 2 / 3) - (len * 2 / 3) / 4 = len / 2
if (size >= threshold - threshold / 4)
resize();
}
解释:rehash会进行一次全量的清理,然后再进行判断,是否进行resize操作。
private void expungeStaleEntries() {Entry[] tab = table;
int len = tab.length;
//全量的清理
for (int j = 0; j < len; j++) {
Entry e = tab[j];
if (e != null && e.get() == null)
expungeStaleEntry(j);
}
}
解释:从起点开始,进行全量的清理
private void resize() {Entry[] oldTab = table;
int oldLen = oldTab.length;
//扩容操作 为原来的2倍
int newLen = oldLen * 2;
Entry[] newTab = new Entry[newLen];
int count = 0;
for (int j = 0; j < oldLen; ++j) {
Entry e = oldTab[j];
if (e != null) {
ThreadLocal<?> k = e.get();
if (k == null) {
e.value = null;
} else {
//新的slot
int h = k.threadLocalHashCode & (newLen - 1);
//线性探测法 向后探测
while (newTab[h] != null)
h = nextIndex(h, newLen);
newTab[h] = e;
count++;
}
}
}
setThreshold(newLen);
size = count;
//替换旧table
解释:新的table容量为以前table的2倍,然后简历oldTable,对于未被回收的slot进行迁移新table的操作。最后设置新的阈值与size,新建的table替换老table。
get方法
下面开始看get方法
public T get() {Thread t = Thread.currentThread();
//得到当前线程的map集合
ThreadLocalMap map = getMap(t);
if (map != null) {
ThreadLocalMap.Entry e = map.getEntry(this);
if (e != null) {
@SuppressWarnings("unchecked")
T result = (T)e.value;
return result;
}
}
return
主要看map.getEntry(this)
map.getEntry(this)
private Entry getEntry(ThreadLocal<?> key) {//获得slot位置
int i = key.threadLocalHashCode & (table.length - 1);
Entry e = table[i];
//如果找到 直接返回
if (e != null && e.get() == key)
return e;
else
//可能第一个不是 进行线性探测
return
解释:根据slot直接获得Entry对象,如果此对象的key为要查询的key,直接返回值,否则进行线性探测。
private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {Entry[] tab = table;
int len = tab.length;
while (e != null) {
ThreadLocal<?> k = e.get();
//找到 直接返回
if (k == key)
return e;
//key被回收 进行回收操作
if (k == null)
expungeStaleEntry(i);
else
//未找到 继续遍历
i = nextIndex(i, len);
e = tab[i];
}
return null;
}
解释:以位置i开始,向后进行遍历,继续查找需要get的key。
- 如果找到,直接返回
- 如果key被回收,则进行回收操作
- 未找到,继续下一个slot
remove
private void remove(ThreadLocal<?> key) {Entry[] tab = table;
int len = tab.length;
int i = key.threadLocalHashCode & (len - 1);
for (Entry e = tab[i];
e != null;
e = tab[i = nextIndex(i, len)]) {
if (e.get() == key) {
// 断开弱引用
e.clear();
// 进行清理
expungeStaleEntry(i);
return;
}
}
}
解释:remove方法就很简单,查找到key 断开弱引用,然后进行清理即可
另外建议:尽管ThreadLocal自带一套清理机制,在set操作和get操作都有很大的几率进行清理操作。但大家在使用完ThreadLocal中保存的值时,也最好手动操作一次remove方法。避免entry.value较长时间存留。
FastThreadLocal
其实在netty中,也实现了自己的ThreadLocal。
在jdk中,纵然使用了黄金分割数,也是只能尽量减少冲突,而不能完全避免。基于线性探测的开放寻址,势必会消耗一定的时间。如果我们直接使用数组来进行索引,也就完美的避免了冲突,只是会消耗大一点的内存。
而netty就是基于数组实现的FastThreadLocal.具体源码为就不再去介绍,在这里奉上对比的时间。
@throws InterruptedException*/
@Test
public void testThreadLocalOnJDK() throws InterruptedException {
ThreadLocal<Integer>[] threadLocals = new ThreadLocal[THREAD_LOCAL_NUM];
for (int i = 0; i < THREAD_LOCAL_NUM; i++) {
threadLocals[i] = new ThreadLocal<>();
}
Thread thread = new Thread(() -> {
long startTime = System.currentTimeMillis();
for (int i = 0; i < THREAD_LOCAL_NUM; i++) {
threadLocals[i].set(i);
}
for (int i = 0; i < THREAD_LOCAL_NUM; i++) {
for (int j = 0; j < GET_COUNT; j++) {
threadLocals[i].get();
}
}
System.out.println(Thread.currentThread().getName() + "耗时:" + (System.currentTimeMillis() - startTime) + "ms");
});
thread.setName("jdk-thread");
thread.start();
thread.join();
}
/**
* 测试netty的ThreadLocal耗时
* @throws InterruptedException
*/
@Test
public void testThreadLocalOnNetty() throws InterruptedException {
FastThreadLocal<Integer>[] threadLocals = new FastThreadLocal[THREAD_LOCAL_NUM];
for (int i = 0; i < THREAD_LOCAL_NUM; i++) {
threadLocals[i] = new FastThreadLocal<>();
}
Thread thread = new FastThreadLocalThread(() -> {
long startTime = System.currentTimeMillis();
for (int i = 0; i < THREAD_LOCAL_NUM; i++) {
threadLocals[i].set(i);
}
for (int i = 0; i < THREAD_LOCAL_NUM; i++) {
for (int j = 0; j < GET_COUNT; j++) {
threadLocals[i].get();
}
}
System.out.println(Thread.currentThread().getName() + "耗时:" + (System.currentTimeMillis() - startTime) + "ms");
});
thread.setName("netty-thread");
thread.start();
thread.join();
}
结果: