TransmittableThreadLocal 相信很多人用过,一个在多线程情况下共享线程上下文的利器
名字好长,以下简称 ttl
本文以之前一个真实项目开发遇到的问题,还原当时从源码的角度分析并解决问题的过程

环境

item version
java 8
springboot 2.2.2.RELEASE
ttl 2.11.4

代码如下,主线程并行启复数任务丢给线程池处理 点击查看代码

List<ProcDef> defs = createsValidate(procCreate);
List<CompletableFuture<CreateResult>> cfs = Lists.newArrayListWithCapacity(defs.size());
defs.forEach(def -> {
    AbstractTransientVariable variable = procCreate.getProcDefKeyVars().get(def.getProcDefKey());
    CompletableFuture<CreateResult> cf = CompletableFuture.supplyAsync(() -> create(def, variable), threadPoolTaskExecutor)
	    .handle((r, e) -> {
		if (e != null) {
		    expHandle(def.getProcDefKey(), procCreate.getUserId(), variable, e);
		}
		return r;
	    });
    cfs.add(cf);
});

List<CreateResult> results = cfs.stream().map(CompletableFuture::join).filter(Objects::nonNull).collect(Collectors.toList());
if (defs.size() != results.size()) {
    log.error("Process create fail exist: [{}]", JacksonUtil.toJsonString(procCreate));
}

第一行 createsValidate 方法在主线程设置了当前用户的上下文
将用户信息放入了 ttl,方便后续子线程使用
测试的时候, 线程池在处理任务的时候,有时会获取不到主线程 ttl 信息
很奇怪,之前也是一直这样使用,为什么没有问题

于是本地main方法模拟

点击查看代码

UserContext.set(new ContextUser().setUserId("mycx26"));

IntStream.range(0, 10).forEach(e -> {
    Supplier<Void> supplier = () -> {
	String userId = UserContext.get() != null ? UserContext.getUserId() : null;
	System.out.println(Thread.currentThread().getName() + " get: " + userId);
	return null;
    };
    CompletableFuture.supplyAsync(supplier);
});

Thread.currentThread().join();

这里主线程将用户信息放入 ttl,依次将异步任务丢给线程池,任务执行获取 ttl 并打印

点击查看代码

ForkJoinPool.commonPool-worker-9 get: mycx26
ForkJoinPool.commonPool-worker-6 get: mycx26
ForkJoinPool.commonPool-worker-13 get: mycx26
ForkJoinPool.commonPool-worker-4 get: mycx26
ForkJoinPool.commonPool-worker-11 get: mycx26
ForkJoinPool.commonPool-worker-2 get: mycx26
ForkJoinPool.commonPool-worker-6 get: mycx26
ForkJoinPool.commonPool-worker-15 get: mycx26
ForkJoinPool.commonPool-worker-8 get: mycx26
ForkJoinPool.commonPool-worker-9 get: mycx26

从输出结果看,各线程都拿到了用户信息,似乎又没有问题

将相同的代码放到工程的单元测试方法里跑

点击查看代码

ForkJoinPool.commonPool-worker-10 get: null
ForkJoinPool.commonPool-worker-15 get: null
ForkJoinPool.commonPool-worker-9 get: null
ForkJoinPool.commonPool-worker-1 get: null
ForkJoinPool.commonPool-worker-13 get: null
ForkJoinPool.commonPool-worker-8 get: null
ForkJoinPool.commonPool-worker-3 get: null
ForkJoinPool.commonPool-worker-2 get: null
ForkJoinPool.commonPool-worker-11 get: null
ForkJoinPool.commonPool-worker-15 get: null

结果却截然相反,到这里我有点怀疑 ttl 对于多线程支持的泛用性了

找到 ttl 的 github readme 阅读
要保证线程池中传递值,一种方式是修饰 Runnable 和 Callable,Supplier 也有类似的包装器
于是修改代码重新测试,测试通过

虽然问题是解决了,但是原因却无从得知,等于还是绕过了问题
下次遇到 ttl 的问题,不知道原理还是无从下手
找到了一个已经 closed 类似的 issue
https://github.com/alibaba/transmittable-thread-local/issues/138
但还是没有解决我的疑问
没有办法,只能看源码了,问题还是要一个一个解决

一. main方法没有修饰的任务为什么能跨越线程池传递 ttl

1.1 首先看看 ttl 的 set 方法做了什么

点击查看代码

public final void set(T value) {
    if (!disableIgnoreNullValueSemantics && null == value) {
        // may set null to remove value
        remove();
    } else {
        super.set(value);
        addThisToHolder();
    }
}

else 走了父类 ThreadLocal 的 set 方法 点击查看代码

public void set(T value) {
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null) {
        map.set(this, value);
    } else {
        createMap(t, value);
    }
}

先看一下 ttl 的继承体系

ttl 继承 InheritableThreadLocal,InheritableThreadLocal 继承 ThreadLocal

InheritableThreadLocal 可以让子线程访问父线程设置的本地变量

点击查看代码

ThreadLocalMap getMap(Thread t) {
   return t.inheritableThreadLocals;
}

void createMap(Thread t, T firstValue) {
    t.inheritableThreadLocals = new ThreadLocalMap(this, firstValue);
}

通过重写 getMap 和 createMap 方法将 ThreadLocal 的维护职责
由 threadLocals 转移给了 inheritableThreadLocals
threadLocals 和 inheritableThreadLocals 类型一样
是 ThreadLocal 中的静态内部类 ThreadLocalMap
为了维护线程本地变量定制化的哈希map, 两者由 Thread 持有

回到上文 TheadLocal set方法
首先获取当前线程,入参调用 getMap 方法获取当前线程的 inheritableThreadLocals

  • map不为null
    将 ttl 做为 key,value 作为值,放入当前线程的 inheritableThreadLocals
  • map为null
    将 ttl 和 value 构造一个新的 ThreadLocalMap,初始化当前线程的 inheritableThreadLocals

1.2 接下来看 CompletableFuture 的 supplyAsync 方法
这个方法调用栈很深,如果多线程功力不深,基本看不懂
但这不妨碍排查这个问题
supplyAsync 默认用的 ForkJoinPool 跑任务
那么必然会启一个线程
即必然会调用 Thread 的 init 方法初始化线程

首先将断点加到 CompletableFuture.supplyAsync(supplier); 这行
debug跑起来
然后将断点加到 Thread init 方法的第一行
(防止jvm启动初始化的线程产生干扰,比如 c2 complier thread)

点击查看代码

Thread parent = currentThread();

if (inheritThreadLocals && parent.inheritableThreadLocals != null)
        this.inheritableThreadLocals =
            ThreadLocal.createInheritedMap(parent.inheritableThreadLocals);

重点是这里,判断父线程的 inheritableThreadLocals 如果不为 null
就把父线程的 inheritableThreadLocals 复制到子线程

1.3 接着看 ttl 的 get 方法

点击查看代码

public final T get() {
    T value = super.get();
    if (disableIgnoreNullValueSemantics || null != value) addThisToHolder();
    return value;
}

public T get() {
    Thread t = Thread.currentThread();
    ThreadLocalMap map = getMap(t);
    if (map != null) {
        ThreadLocalMap.Entry e = map.getEntry(this);
        if (e != null) {
            @SuppressWarnings("unchecked")
            T result = (T)e.value;
            return result;
        }
    }
    return setInitialValue();
}

同样走的 ThreadLocal 的 get 方法
首先获取当前线程,getMap 获取其对应的 inheritableThreadLocals
顺利拿到之前父线程设置的变量
到这里,第一个问题算是解了

二. 同样的代码跑在单元测试,没有修饰的任务为什么不能跨越线程池传递 ttl

这里我有理由怀疑是 spring 容器在拉起的时候,提前用到了 ForkJoinPool 的 commonPool
但是项目依赖众多,如何定位
既然用到了,那么将断点加在 ForkJoinPool 启动线程
然后沿着调用栈帧一直向上找不就行了
将断点加在 ForkJoinPool 的 createWorker方法的第一行
开始找

点击查看代码

/* 检查逻辑删除字段只能有最多一个 */
    Assert.isTrue(fieldList.parallelStream().filter(TableFieldInfo::isLogicDelete).count() < 2L,
        String.format("annotation of @TableLogic can't more than one in class : %s.", clazz.getName()));

果然,熟悉的身影,mybatis plus
spring容器拉起时在初始化 SqlSessionFactory 时
会调用 TableInfoHelper 的 initTableFields 方法初始化表主键和字段
注意这里用的 stream 的并行流 parallel stream,很熟悉了
底层默认用的 ForkJoinPool 的 commonPool
那么在主线程设置的 TTL,线程池中的线程之前已经初始化,当然就拿不到了
好,这是第二个问题

三. 为什么项目中自定义线程池获取不到前面主线程创建的 ttl
和二是相同的问题,执行操作前,线程池已经被调度执行任务了
线程如果池化,那么后续在跑异步任务时就没有父子线程之说了
那么现在只剩最后一个问题

四. 为什么项目中任务加了包装器后又拿到了

点击查看代码

TtlWrappers.wrap(() -> create(def, variable))

没有什么办法,跟进去吧

4.1 看看 TtlWrappers 的静态方法 wrap 做了什么

点击查看代码

public static <T> Supplier<T> wrap(@Nullable Supplier<T> supplier) {
    if (supplier == null) return null;
    else if (supplier instanceof TtlEnhanced) return supplier;
    else return new TtlSupplier<T>(supplier);
}

看样子,大概是想用装饰模式包装 Supplier 为 TTL wrapper
new 了一个 TtlSupplier,这是 TtlWrappers 的一个静态内部类
继续进去

点击查看代码

TtlSupplier(@NonNull Supplier<T> supplier) {
    this.supplier = supplier;
    this.capture = capture();
}

supplier完成赋值后,重点是后面的 capture

点击查看代码

/**
 * Capture all {@link TransmittableThreadLocal} and registered {@link ThreadLocal} values in the current thread.
 *
 * @return the captured {@link TransmittableThreadLocal} values
 * @since 2.3.0
 */
@NonNull
public static Object capture() {
    return new Snapshot(captureTtlValues(), captureThreadLocalValues());
}

capture 静态方法位于 ttl 的静态内部类 Transmitter 中
注释很清晰,捕获当前线程的所有 ttl 和 ThreadLocal 的值
new Snapshort 继续跟进去

点击查看代码

private static class Snapshot {
    final WeakHashMap<TransmittableThreadLocal<Object>, Object> ttl2Value;
    final WeakHashMap<ThreadLocal<Object>, Object> threadLocal2Value;

    private Snapshot(WeakHashMap<TransmittableThreadLocal<Object>, Object> ttl2Value, WeakHashMap<ThreadLocal<Object>, Object> threadLocal2Value) {
        this.ttl2Value = ttl2Value;
        this.threadLocal2Value = threadLocal2Value;
    }
}

Snaphost 同样是 ttl 的静态内部类
构造方法的第一个参数方法 captureTtlValues 跟进去

点击查看代码

private static WeakHashMap<TransmittableThreadLocal<Object>, Object> captureTtlValues() {
    WeakHashMap<TransmittableThreadLocal<Object>, Object> ttl2Value = new WeakHashMap<TransmittableThreadLocal<Object>, Object>();
    for (TransmittableThreadLocal<Object> threadLocal : holder.get().keySet()) {
        ttl2Value.put(threadLocal, threadLocal.copyValue());
    }
    return ttl2Value;
}

同样来自 Transmitter
好,代码并不复杂,重点是 holder

点击查看代码

// Note about holder:
// 1. holder self is a InheritableThreadLocal(a *ThreadLocal*).
// 2. The type of value in holder is WeakHashMap<TransmittableThreadLocal<Object>, ?>.
//    2.1 but the WeakHashMap is used as a *Set*:
//        - the value of WeakHashMap is *always null,
//        - and never be used.
//    2.2 WeakHashMap support *null* value.
private static InheritableThreadLocal<WeakHashMap<TransmittableThreadLocal<Object>, ?>> holder =
        new InheritableThreadLocal<WeakHashMap<TransmittableThreadLocal<Object>, ?>>() {
            @Override
            protected WeakHashMap<TransmittableThreadLocal<Object>, ?> initialValue() {
                return new WeakHashMap<TransmittableThreadLocal<Object>, Object>();
            }

            @Override
            protected WeakHashMap<TransmittableThreadLocal<Object>, ?> childValue(WeakHashMap<TransmittableThreadLocal<Object>, ?> parentValue) {
                return new WeakHashMap<TransmittableThreadLocal<Object>, Object>(parentValue);
            }
        };

holder 为 ttl 的静态成员变量,类型为 InheritableThreadLocal 的匿名内部类
重写了 initialValue 和 childValue 方法
再看注释
这里 value 的 type 是 WeakHashMap 并且这个 map 被当作 set 用了
还记得上文分析 ttl 的 set 方法吗,有一块没有讲
对,就是 else 的 addThisToHolder

点击查看代码

private void addThisToHolder() {
    if (!holder.get().containsKey(this)) {
        holder.get().put((TransmittableThreadLocal<Object>) this, null); // WeakHashMap supports null value.
    }
}

set方法
第一步将 ttl 做为 key,value 作为值,放入当前线程的 inheritableThreadLocals
第二步 addThisToHolder
holder.get()方法获取当前线程的 inheritableThreadLocals 变量
key为 holder本身,这里不要乱了
第一次肯定是拿不到的,那么这里为什么没有npe?
上文提到他重写了 initialValue 方法
继续
首先判断当前线程的 inheritableThreadLocals 是否包含 holder
第一次肯定没有
那么把 ttl 本身作为key, 放入当前线程的 inheritableThreadLocals 维持的 map 中

至此 TtlSupplier 的 capture 属性已经持有了主线程的所有 ttl 快照

4.2 接下来看 TtlSupplier 重写的 get 方法

这是核心的行为,可以断定,其必然做了增强

点击查看代码

public T get() {
    final Object backup = replay(capture);
    try {
        return supplier.get();
    } finally {
        restore(backup);
    }
}

结构很清晰,先replay,再执行核心行为,最后restore
replay 跟进去

点击查看代码

/**
 * Replay the captured {@link TransmittableThreadLocal} and registered {@link ThreadLocal} values from {@link #capture()},
 * and return the backup {@link TransmittableThreadLocal} values in the current thread before replay.
 *
 * @param captured captured {@link TransmittableThreadLocal} values from other thread from {@link #capture()}
 * @return the backup {@link TransmittableThreadLocal} values before replay
 * @see #capture()
 * @since 2.3.0
 */
@NonNull
public static Object replay(@NonNull Object captured) {
    final Snapshot capturedSnapshot = (Snapshot) captured;
    return new Snapshot(replayTtlValues(capturedSnapshot.ttl2Value), replayThreadLocalValues(capturedSnapshot.threadLocal2Value));
}

同样是位于 ttl 的静态内部类 Transmitter 的静态方法
上文 Tramsmitter capture()方法捕获的主线程快照这里用到了
replayTtlValues 方法跟进去

点击查看代码

private static WeakHashMap<TransmittableThreadLocal<Object>, Object> replayTtlValues(@NonNull WeakHashMap<TransmittableThreadLocal<Object>, Object> captured) {
    WeakHashMap<TransmittableThreadLocal<Object>, Object> backup = new WeakHashMap<TransmittableThreadLocal<Object>, Object>();

    for (final Iterator<TransmittableThreadLocal<Object>> iterator = holder.get().keySet().iterator(); iterator.hasNext(); ) {
        TransmittableThreadLocal<Object> threadLocal = iterator.next();

        // backup
        backup.put(threadLocal, threadLocal.get());

        // clear the TTL values that is not in captured
        // avoid the extra TTL values after replay when run task
        if (!captured.containsKey(threadLocal)) {
            iterator.remove();
            threadLocal.superRemove();
        }
    }

    // set TTL values to captured
    setTtlValuesTo(captured);

    // call beforeExecute callback
    doExecuteCallback(true);

    return backup;
}

private static void setTtlValuesTo(@NonNull WeakHashMap<TransmittableThreadLocal<Object>, Object> ttlValues) {
    for (Map.Entry<TransmittableThreadLocal<Object>, Object> entry : ttlValues.entrySet()) {
        TransmittableThreadLocal<Object> threadLocal = entry.getKey();
        threadLocal.set(entry.getValue());
    }
}

注意这里已经到子线程了
首先通过 holder 获取当前线程的 inheritableThreadLocals 变量
很可能没有
但是如果线程之前已经池化用完没有remove,这里是有的
遍历 map 的 key ttl
这里为了不污染子线程上下文,先做了备份
对于快照中不包含的 ttl 信息依次 remove
然后遍历快照信息设置到当前线程的 inheritableThreadLocals
doExecuteCallback 方法是 ttl 为开发者留的一个勾子方法
时机在任务执行前
最后返回子线程 ttl 备份

再回到 TtlSupplier 的 get 方法
supplier 的 get 方法执行任务
最后还剩 restore,传入上面子线程的 ttl 备份

点击查看代码

/**
 * Restore the backup {@link TransmittableThreadLocal} and
 * registered {@link ThreadLocal} values from {@link #replay(Object)}/{@link #clear()}.
 *
 * @param backup the backup {@link TransmittableThreadLocal} values from {@link #replay(Object)}/{@link #clear()}
 * @see #replay(Object)
 * @see #clear()
 * @since 2.3.0
 */
public static void restore(@NonNull Object backup) {
    final Snapshot backupSnapshot = (Snapshot) backup;
    restoreTtlValues(backupSnapshot.ttl2Value);
    restoreThreadLocalValues(backupSnapshot.threadLocal2Value);
}

private static void restoreTtlValues(@NonNull WeakHashMap<TransmittableThreadLocal<Object>, Object> backup) {
    // call afterExecute callback
    doExecuteCallback(false);

    for (final Iterator<TransmittableThreadLocal<Object>> iterator = holder.get().keySet().iterator(); iterator.hasNext(); ) {
        TransmittableThreadLocal<Object> threadLocal = iterator.next();

        // clear the TTL values that is not in backup
        // avoid the extra TTL values after restore
        if (!backup.containsKey(threadLocal)) {
            iterator.remove();
            threadLocal.superRemove();
        }
    }

    // restore TTL values
    setTtlValuesTo(backup);
}

主要看 restoreTtlValues 方法
doExecuteCallback 和上面逻辑类似
区别在时机在任务执行后
holder 获取子线程的 inheritableThreadLocals 变量
遍历 map 的 key ttl
对于不在备份的 ttl 全部删除
最后恢复子线程的 ttl

仿佛一切没有发生过

至此最后一个问题解决

你对的不一定对,你错了一定是错了
源码面前,没有什么秘密可言了

声明:本站所有文章,如无特殊说明或标注,均为本站原创发布。任何个人或组织,在未征得本站同意时,禁止复制、盗用、采集、发布本站内容到任何网站、书籍等各类媒体平台。如若本站内容侵犯了原著者的合法权益,可联系我们进行处理。