通过源码学习ThreadLocal实现原理

ThreadLocal是什么?

ThreadLocal是一个与Thread线程绑定的变量,其填充的数据只属于当前线程,变量的数据对于其他线程而言是隔离的。

简单来说就是,ThreadLocal是一个变量,在A线程中设值a,B线程中设值b,其数据与设值时的线程相绑定,不同线程之间是隔离的。也就是说这个值隶属于设值时的线程对象,相当于在Thread对象中额外定义了一个属性(实际上是额外定义了一个map对象,通过key-value的形式来存储数据,后面会解释)。

ThreadLocal的使用场景?

在web开发中,用户信息一般存储于session中,而session一般可以在controller层中进行获取,如果在service层中需要使用到用户信息的话,就需要一路传递用户信息,显得有点恶心且繁琐,而用户信息又是比较常用到的一个数据对象,且可以与线程进行绑定(这个线程干的活一般都只会跟这个用户有关)。所以可以考虑使用ThreadLocal进行用户信息的存储,且搭配拦截器或者AOP的形式从session中获取用户信息并存储到UserHolder中。

public class UserHolder {

    /**
     * 存储用户信息
     */
    private static ThreadLocal<User> USER_THREAD_LOCAL = new ThreadLocal<>();

    /**
     * 获取用户信息
     *
     * @return
     */
    public User getUser() {
        return USER_THREAD_LOCAL.get();
    }

    /**
     * 设置用户信息
     *
     * @param user
     */
    public void setUser(User user) {
        USER_THREAD_LOCAL.set(user);
    }

    /**
     * 移除用户信息
     */
    public void removeUser() {
        USER_THREAD_LOCAL.remove();
    }

}

上面是一个比较常见且比较好理解的业务上的使用场景,实际上在框架中可以看到对ThreadLocal的大量使用,比如:链路日志的链路id传递MDC对象,Spring的事务管理TransactionSynchronizationManager等。

ThreadLocal的实现原理?

1、看源码前先大概看下ThreadLocal的基本使用以及提供的几个方法

public class Test {

    private static ThreadLocal<Integer> tl1 = new ThreadLocal<>();

    private static ThreadLocal<Integer> tl2 = new ThreadLocal<>();

    public static void main(String[] args) {
        tl1.set(1);
        tl1.get();
        tl1.remove();
    }

}

可以看到,ThreadLocal就是一个变量,只不过这个变量的设值是通过set,get来设置的;同时ThreadLocal也是一个对象,而且一般作为一个类的静态属性,即使用static进行修饰。

这里有个问题,为什么推荐使用static进行修饰呢?
因为如果我们不使用static进行修饰的话,不同的Test对象中就会实例化不同的ThreadLocal对象tl1,虽然他们都叫tl1,但是因为他们不是同一个对象,所以获取到的数据也是不一样的。而我们的数据是跟着ThreadLocal对象走的,只是这个对象与线程Thread有关,换个说法,唯一能确定我们的数据的条件是(ThreadLocal对象+Thread对象)。所以我们需要使用static进行修饰来保证tl1的唯一性。(当然使用其他办法也可以,比如Spring容器的单例等,只是static是最佳的一种模式)

2、从get、set方法入手

public T get() {
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null) {
        ThreadLocalMap.Entry e = map.getEntry(this);
        if (e != null) {
            @SuppressWarnings("unchecked")
            T result = (T)e.value;
            return result;
        }
    }
    return setInitialValue();
}

public void set(T value) {
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null)
        map.set(this, value);
    else
        createMap(t, value);
}

可以看到,ThreadLocal的get和set方法,其基本操作都是从当前的线程对象中获取到一个map对象,然后通过map对象来获取对应的值。所以不难猜测,整个ThreadLocal的核心就是ThreadLocalMap。不废话,我们直接看整个数据存储的一个基本模型。

3、基本模型

从上图可以看出,在Thread类中有一个ThreadLocalMap属性用于存储数据,而ThreadLocalMap与我们常见的HashMap很类似,也是通过将数据包装成一个Entry对象存储于一个数组中,但是与HashMap还是有些许的不同,在这里姑且先认为是一个常见的map,暂不讨论具体细节。

4、如何从线程中获取ThreadLocalMap
无论是set还是get方法,第一步都是先获取ThreadLocalMap,而ThreadLocalMap是存储在Thread中的,所以需要先获取到当前线程对象。通过Thread.currentThread();方法进行获取,这个方法是一个native方法,其实现我们就不关心了。

接着调用getMap方法,入参为线程对象,方法内部实现也极其简单,就是直接读取

ThreadLocalMap getMap(Thread t) {
    return t.threadLocals;
}

获取到ThreadLocalMap后,后续的逻辑就很简单了,基本都能看懂。而到这里,ThreadLocal基本的存取值,与线程对象绑定、线程间数据的隔离等秘密也已经解开。接下来我们深入探究下整个的核心ThreadLocalMap。

ThreadLocalMap

ThreadLocalMap与HashMap其实长得很像,基本的思想逻辑也很像,这里只探讨一些不同的点。

1、为什么使用ThreadLocal做为Map的key?
因为在一个线程中,我们可以同时使用多个ThreadLocal对象来存储获取数据,也就是说数据与ThreadLocal关联,故使用ThreadLocal作为key非常合适。

2、为什么Entry是一个WeakReference弱引用类型?
插播学习一下java中四种引用类型之间的区别:

  • 强引用:最常使用到的=赋值的情况就是强引用,对象只有在与GC Root之间没有强引用链的情况下才会被回收。
  • 软引用:内存不够就回收,内存充足不回收。
  • 弱引用:只要发生GC,一定被回收。
  • 虚引用:是最弱的引用关系,无法通过虚引用来取得一个对象实例。为一个对象设置虚引用关联的唯一目的就是能在这个对象被收集器回收时收到一个系统通知。
static class Entry extends WeakReference<ThreadLocal<?>> {
    /** The value associated with this ThreadLocal. */
    Object value;

    Entry(ThreadLocal<?> k, Object v) {
        super(k);
        value = v;
    }
}

Entry继承了WeakReference类,通过构造函数我们可以看出,Entry只是对key也就是ThreadLocal是一个弱引用,而具体的value值还是一个强引用,为什么呢?

首先,我们访问ThreadLocalMap中的值都是通过ThreadLocal对象来访问的,如果说ThreadLocal对象已经不存在强引用了,那么这个ThreadLocal对象就应该被回收掉。
比如:我们在上述例子中的Test对象中,执行tl1 = null时。原先的tl1对象就不存在引用了,这个时候就应该回收该对象。而如果不使用弱引用的话,就会导致ThreadLocalMap中对tl1还存在一个强引用,这个时候tl1对象就永远不会被回收(除非Thread对象被回收),从而导致内存泄露。

这个时候你可能又要问了,那为什么value不是一个弱引用?
因为value是作为一个值存储在map中,方便后续我们使用的时候获取这个值的,所以在整个堆中,他的引用可能有且只有一个,也就是Entry的一个引用,设想一下,如果他是一个弱引用,只要发生GC,那么他就没了,那么请问我们存储他的意义是什么?所以value不能是一个弱引用。也正因为如此,我们在使用完ThreadLocal之后,一定要记得remove,否则还是有可能发生内存泄露的问题。(这里还有一点是因为大部分情况下会使用线程池对线程进行复用,所以使用完没有remove的话,可能会导致数据污染)

3、发生hash冲突时,也是使用链表或者红黑树吗?

private void set(ThreadLocal<?> key, Object value) {

    // We don't use a fast path as with get() because it is at
    // least as common to use set() to create new entries as
    // it is to replace existing ones, in which case, a fast
    // path would fail more often than not.

    Entry[] tab = table;
    int len = tab.length;
    int i = key.threadLocalHashCode & (len-1);

    for (Entry e = tab[i];
         e != null;
         e = tab[i = nextIndex(i, len)]) {
        ThreadLocal<?> k = e.get();

        if (k == key) {
            e.value = value;
            return;
        }

        if (k == null) {
            replaceStaleEntry(key, value, i);
            return;
        }
    }

    tab[i] = new Entry(key, value);
    int sz = ++size;
    if (!cleanSomeSlots(i, sz) && sz >= threshold)
        rehash();
}

通过set方法可以看出,ThreadLocalMap的hash值取的是key(也就是ThreadLocal的threadLocalHashCode,这个值在初始化ThreadLocal对象时自动生成)与当前容量的一个与运算得出一个i值,这个i值就是Entry数组中的下标。i当然是存在冲突的可能的。在发生冲突时,会自动的调用nextIndex方法,在这个方法内部会直接进行i+1,如果超过len的值时,会从0开始。也就是把Entry数组当成了一个圆环,发生冲突之后就往后面找,直到找到一个空的位置。

private Entry getEntry(ThreadLocal<?> key) {
    int i = key.threadLocalHashCode & (table.length - 1);
    Entry e = table[i];
    if (e != null && e.get() == key)
        return e;
    else
        return getEntryAfterMiss(key, i, e);
}

/**
 * Version of getEntry method for use when key is not found in
 * its direct hash slot.
 *
 * @param  key the thread local object
 * @param  i the table index for key's hash code
 * @param  e the entry at table[i]
 * @return the entry associated with key, or null if no such
 */
private Entry getEntryAfterMiss(ThreadLocal<?> key, int i, Entry e) {
    Entry[] tab = table;
    int len = tab.length;

    while (e != null) {
        ThreadLocal<?> k = e.get();
        if (k == key)
            return e;
        if (k == null)
            expungeStaleEntry(i);
        else
            i = nextIndex(i, len);
        e = tab[i];
    }
    return null;
}

而在get方法中,计算出i的值并获取到Entry对象后会额外做一个key的判断,如果key的值不对,则会调用getEntryAfterMiss方法,这个方法会从该值开始往后寻找,直到查找到key相等的Entry对象进行返回。

所以在发生hash冲突时,并没有使用链表或者红黑树,而是直接往后寻找的形式来解决hash冲突。可能会问,万一后面的值都被占满了找不到空位怎么办?答案是ThreadLocalMap会进行扩容,所以不存在找不到空位置的情况。

4、map是如何进行扩容的?

private void resize() {
    Entry[] oldTab = table;
    int oldLen = oldTab.length;
    int newLen = oldLen * 2;
    Entry[] newTab = new Entry[newLen];
    int count = 0;

    for (int j = 0; j < oldLen; ++j) {
        Entry e = oldTab[j];
        if (e != null) {
            ThreadLocal<?> k = e.get();
            if (k == null) {
                e.value = null; // Help the GC
            } else {
                int h = k.threadLocalHashCode & (newLen - 1);
                while (newTab[h] != null)
                    h = nextIndex(h, newLen);
                newTab[h] = e;
                count++;
            }
        }
    }

    setThreshold(newLen);
    size = count;
    table = newTab;
}

扩容主要关注resize方法,其核心是直接进行翻倍扩容,初始化一个newTab数组,并重新计算hash值并且赋值。由于ThreadLocalMap是跟线程进行绑定的,所以也不需要考虑多线程的问题。