From a10d35e6cdab391e8a637b860f4c1ba5e57e96c2 Mon Sep 17 00:00:00 2001 From: dmitrii Date: Wed, 3 Jun 2026 01:23:59 +0200 Subject: [PATCH] Fix request context propagation for async executor tasks --- .../main/java/dev/aikido/agent/Wrappers.java | 14 ++ .../AbstractExecutorServiceWrapper.java | 52 ++++ .../DelegatedExecutorServiceWrapper.java | 72 ++++++ .../executor/ExecutorContextPropagation.java | 67 ++++++ .../executor/ForkJoinPoolWrapper.java | 52 ++++ .../ScheduledThreadPoolExecutorWrapper.java | 73 ++++++ .../executor/ThreadPoolExecutorWrapper.java | 42 ++++ .../context/ContextPropagatingCallable.java | 29 +++ .../context/ContextPropagatingRunnable.java | 27 +++ .../agent_api/context/ContextPropagation.java | 33 +++ .../java/context/ContextPropagationTest.java | 224 ++++++++++++++++++ end2end/spring_boot_postgres.py | 21 ++ .../demo/AsyncContextPropagationConfig.java | 23 ++ .../AsyncContextPropagationController.java | 138 +++++++++++ .../demo/AsyncContextPropagationService.java | 15 ++ 15 files changed, 882 insertions(+) create mode 100644 agent/src/main/java/dev/aikido/agent/wrappers/executor/AbstractExecutorServiceWrapper.java create mode 100644 agent/src/main/java/dev/aikido/agent/wrappers/executor/DelegatedExecutorServiceWrapper.java create mode 100644 agent/src/main/java/dev/aikido/agent/wrappers/executor/ExecutorContextPropagation.java create mode 100644 agent/src/main/java/dev/aikido/agent/wrappers/executor/ForkJoinPoolWrapper.java create mode 100644 agent/src/main/java/dev/aikido/agent/wrappers/executor/ScheduledThreadPoolExecutorWrapper.java create mode 100644 agent/src/main/java/dev/aikido/agent/wrappers/executor/ThreadPoolExecutorWrapper.java create mode 100644 agent_api/src/main/java/dev/aikido/agent_api/context/ContextPropagatingCallable.java create mode 100644 agent_api/src/main/java/dev/aikido/agent_api/context/ContextPropagatingRunnable.java create mode 100644 agent_api/src/main/java/dev/aikido/agent_api/context/ContextPropagation.java create mode 100644 agent_api/src/test/java/context/ContextPropagationTest.java create mode 100644 sample-apps/SpringBootPostgres/src/main/java/com/example/demo/AsyncContextPropagationConfig.java create mode 100644 sample-apps/SpringBootPostgres/src/main/java/com/example/demo/AsyncContextPropagationController.java create mode 100644 sample-apps/SpringBootPostgres/src/main/java/com/example/demo/AsyncContextPropagationService.java diff --git a/agent/src/main/java/dev/aikido/agent/Wrappers.java b/agent/src/main/java/dev/aikido/agent/Wrappers.java index a71c13b26..3c628a340 100644 --- a/agent/src/main/java/dev/aikido/agent/Wrappers.java +++ b/agent/src/main/java/dev/aikido/agent/Wrappers.java @@ -1,6 +1,11 @@ package dev.aikido.agent; import dev.aikido.agent.wrappers.*; +import dev.aikido.agent.wrappers.executor.AbstractExecutorServiceWrapper; +import dev.aikido.agent.wrappers.executor.DelegatedExecutorServiceWrapper; +import dev.aikido.agent.wrappers.executor.ForkJoinPoolWrapper; +import dev.aikido.agent.wrappers.executor.ScheduledThreadPoolExecutorWrapper; +import dev.aikido.agent.wrappers.executor.ThreadPoolExecutorWrapper; import dev.aikido.agent.wrappers.file.FileConstructorMultiArgumentWrapper; import dev.aikido.agent.wrappers.file.FileConstructorSingleArgumentWrapper; import dev.aikido.agent.wrappers.javalin.*; @@ -9,6 +14,8 @@ import dev.aikido.agent.wrappers.spring.SpringWebfluxWrapper; import dev.aikido.agent.wrappers.spring.SpringControllerWrapper; import dev.aikido.agent.wrappers.spring.SpringMVCJakartaWrapper; +import dev.aikido.agent.wrappers.spring.SpringMVCJavaxWrapper; +import dev.aikido.agent.wrappers.spring.SpringWebfluxWrapper; import java.util.Arrays; import java.util.List; @@ -17,6 +24,13 @@ public final class Wrappers { private Wrappers() {} public static final List WRAPPERS = Arrays.asList( new PostgresWrapper(), + + new DelegatedExecutorServiceWrapper(), + new ThreadPoolExecutorWrapper(), + new AbstractExecutorServiceWrapper(), + new ForkJoinPoolWrapper(), + new ScheduledThreadPoolExecutorWrapper(), + new SpringMVCJakartaWrapper(), new SpringMVCJavaxWrapper(), new SpringWebfluxWrapper(), diff --git a/agent/src/main/java/dev/aikido/agent/wrappers/executor/AbstractExecutorServiceWrapper.java b/agent/src/main/java/dev/aikido/agent/wrappers/executor/AbstractExecutorServiceWrapper.java new file mode 100644 index 000000000..1f152c3f1 --- /dev/null +++ b/agent/src/main/java/dev/aikido/agent/wrappers/executor/AbstractExecutorServiceWrapper.java @@ -0,0 +1,52 @@ +package dev.aikido.agent.wrappers.executor; + +import dev.aikido.agent.wrappers.Wrapper; +import net.bytebuddy.asm.Advice; +import net.bytebuddy.description.method.MethodDescription; +import net.bytebuddy.description.type.TypeDescription; +import net.bytebuddy.matcher.ElementMatcher; + +import java.util.concurrent.AbstractExecutorService; +import java.util.concurrent.Callable; + +import static net.bytebuddy.implementation.bytecode.assign.Assigner.Typing.DYNAMIC; +import static net.bytebuddy.matcher.ElementMatchers.isMethod; +import static net.bytebuddy.matcher.ElementMatchers.isSubTypeOf; +import static net.bytebuddy.matcher.ElementMatchers.named; +import static net.bytebuddy.matcher.ElementMatchers.takesArguments; + +public class AbstractExecutorServiceWrapper implements Wrapper { + @Override + public String getName() { + return SubmitAdvice.class.getName(); + } + + @Override + public ElementMatcher getMatcher() { + return isMethod() + .and(named("submit")) + .and( + takesArguments(Runnable.class) + .or(takesArguments(Callable.class)) + .or(takesArguments(Runnable.class, Object.class)) + ); + } + + @Override + public ElementMatcher getTypeMatcher() { + return isSubTypeOf(AbstractExecutorService.class); + } + + public static class SubmitAdvice { + @Advice.OnMethodEnter(suppress = Throwable.class) + public static void before( + @Advice.Argument(value = 0, readOnly = false, typing = DYNAMIC) Object task + ) { + if (task instanceof Runnable) { + task = ExecutorContextPropagation.wrap((Runnable) task); + } else if (task instanceof Callable) { + task = ExecutorContextPropagation.wrap((Callable) task); + } + } + } +} diff --git a/agent/src/main/java/dev/aikido/agent/wrappers/executor/DelegatedExecutorServiceWrapper.java b/agent/src/main/java/dev/aikido/agent/wrappers/executor/DelegatedExecutorServiceWrapper.java new file mode 100644 index 000000000..697615aa1 --- /dev/null +++ b/agent/src/main/java/dev/aikido/agent/wrappers/executor/DelegatedExecutorServiceWrapper.java @@ -0,0 +1,72 @@ +package dev.aikido.agent.wrappers.executor; + +import dev.aikido.agent.wrappers.Wrapper; +import net.bytebuddy.asm.Advice; +import net.bytebuddy.description.method.MethodDescription; +import net.bytebuddy.description.type.TypeDescription; +import net.bytebuddy.matcher.ElementMatcher; + +import java.lang.reflect.Method; +import java.net.URL; +import java.net.URLClassLoader; +import java.util.concurrent.Callable; + +import static net.bytebuddy.implementation.bytecode.assign.Assigner.Typing.DYNAMIC; +import static net.bytebuddy.matcher.ElementMatchers.isMethod; +import static net.bytebuddy.matcher.ElementMatchers.nameStartsWith; +import static net.bytebuddy.matcher.ElementMatchers.named; +import static net.bytebuddy.matcher.ElementMatchers.takesArguments; + +public class DelegatedExecutorServiceWrapper implements Wrapper { + @Override + public String getName() { + return DelegatedExecutorAdvice.class.getName(); + } + + @Override + public ElementMatcher getMatcher() { + return isMethod() + .and(named("execute").or(named("submit"))) + .and( + takesArguments(Runnable.class) + .or(takesArguments(Callable.class)) + .or(takesArguments(Runnable.class, Object.class)) + ); + } + + @Override + public ElementMatcher getTypeMatcher() { + return nameStartsWith("java.util.concurrent.Executors$"); + } + + public static class DelegatedExecutorAdvice { + @Advice.OnMethodEnter(suppress = Throwable.class) + public static void before( + @Advice.Argument(value = 0, readOnly = false, typing = DYNAMIC) Object task + ) throws Exception { + if (task == null) { + return; + } + + // This advice is applied to JDK classes loaded by the bootstrap classloader. + // Load agent_api reflectively because bootstrap classes cannot directly reference agent classes. + String jarFilePath = System.getProperty("AIK_agent_api_jar"); + if (jarFilePath == null || jarFilePath.isBlank()) { + return; + } + + URLClassLoader classLoader = new URLClassLoader(new URL[] { new URL(jarFilePath) }); + Class contextPropagationClass = classLoader.loadClass( + "dev.aikido.agent_api.context.ContextPropagation" + ); + + if (task instanceof Runnable) { + Method wrapRunnable = contextPropagationClass.getMethod("wrap", Runnable.class); + task = wrapRunnable.invoke(null, task); + } else if (task instanceof Callable) { + Method wrapCallable = contextPropagationClass.getMethod("wrap", Callable.class); + task = wrapCallable.invoke(null, task); + } + } + } +} diff --git a/agent/src/main/java/dev/aikido/agent/wrappers/executor/ExecutorContextPropagation.java b/agent/src/main/java/dev/aikido/agent/wrappers/executor/ExecutorContextPropagation.java new file mode 100644 index 000000000..8466083f0 --- /dev/null +++ b/agent/src/main/java/dev/aikido/agent/wrappers/executor/ExecutorContextPropagation.java @@ -0,0 +1,67 @@ +package dev.aikido.agent.wrappers.executor; + +import java.lang.reflect.Method; +import java.net.URL; +import java.net.URLClassLoader; +import java.util.concurrent.Callable; + +public final class ExecutorContextPropagation { + private static Method wrapRunnableMethod; + private static Method wrapCallableMethod; + private static boolean disabled; + + private ExecutorContextPropagation() {} + + public static Runnable wrap(Runnable task) { + if (task == null || disabled) { + return task; + } + + try { + init(); + if (wrapRunnableMethod == null) { + return task; + } + return (Runnable) wrapRunnableMethod.invoke(null, task); + } catch (Throwable ignored) { + disabled = true; + return task; + } + } + + @SuppressWarnings("unchecked") + public static Callable wrap(Callable task) { + if (task == null || disabled) { + return task; + } + + try { + init(); + if (wrapCallableMethod == null) { + return task; + } + return (Callable) wrapCallableMethod.invoke(null, task); + } catch (Throwable ignored) { + disabled = true; + return task; + } + } + + private static synchronized void init() throws Exception { + if (disabled || wrapRunnableMethod != null) { + return; + } + + String jarFilePath = System.getProperty("AIK_agent_api_jar"); + if (jarFilePath == null || jarFilePath.isBlank()) { + disabled = true; + return; + } + + URLClassLoader classLoader = new URLClassLoader(new URL[] { new URL(jarFilePath) }); + Class clazz = classLoader.loadClass("dev.aikido.agent_api.context.ContextPropagation"); + + wrapRunnableMethod = clazz.getMethod("wrap", Runnable.class); + wrapCallableMethod = clazz.getMethod("wrap", Callable.class); + } +} diff --git a/agent/src/main/java/dev/aikido/agent/wrappers/executor/ForkJoinPoolWrapper.java b/agent/src/main/java/dev/aikido/agent/wrappers/executor/ForkJoinPoolWrapper.java new file mode 100644 index 000000000..b78f0a95d --- /dev/null +++ b/agent/src/main/java/dev/aikido/agent/wrappers/executor/ForkJoinPoolWrapper.java @@ -0,0 +1,52 @@ +package dev.aikido.agent.wrappers.executor; + +import dev.aikido.agent.wrappers.Wrapper; +import net.bytebuddy.asm.Advice; +import net.bytebuddy.description.method.MethodDescription; +import net.bytebuddy.description.type.TypeDescription; +import net.bytebuddy.matcher.ElementMatcher; + +import java.util.concurrent.Callable; +import java.util.concurrent.ForkJoinPool; + +import static net.bytebuddy.implementation.bytecode.assign.Assigner.Typing.DYNAMIC; +import static net.bytebuddy.matcher.ElementMatchers.isMethod; +import static net.bytebuddy.matcher.ElementMatchers.isSubTypeOf; +import static net.bytebuddy.matcher.ElementMatchers.named; +import static net.bytebuddy.matcher.ElementMatchers.takesArguments; + +public class ForkJoinPoolWrapper implements Wrapper { + @Override + public String getName() { + return ForkJoinAdvice.class.getName(); + } + + @Override + public ElementMatcher getMatcher() { + return isMethod() + .and(named("execute").or(named("submit"))) + .and( + takesArguments(Runnable.class) + .or(takesArguments(Callable.class)) + .or(takesArguments(Runnable.class, Object.class)) + ); + } + + @Override + public ElementMatcher getTypeMatcher() { + return isSubTypeOf(ForkJoinPool.class); + } + + public static class ForkJoinAdvice { + @Advice.OnMethodEnter(suppress = Throwable.class) + public static void before( + @Advice.Argument(value = 0, readOnly = false, typing = DYNAMIC) Object task + ) { + if (task instanceof Runnable) { + task = ExecutorContextPropagation.wrap((Runnable) task); + } else if (task instanceof Callable) { + task = ExecutorContextPropagation.wrap((Callable) task); + } + } + } +} diff --git a/agent/src/main/java/dev/aikido/agent/wrappers/executor/ScheduledThreadPoolExecutorWrapper.java b/agent/src/main/java/dev/aikido/agent/wrappers/executor/ScheduledThreadPoolExecutorWrapper.java new file mode 100644 index 000000000..b2c436e6e --- /dev/null +++ b/agent/src/main/java/dev/aikido/agent/wrappers/executor/ScheduledThreadPoolExecutorWrapper.java @@ -0,0 +1,73 @@ +package dev.aikido.agent.wrappers.executor; + +import dev.aikido.agent.wrappers.Wrapper; +import net.bytebuddy.asm.Advice; +import net.bytebuddy.description.method.MethodDescription; +import net.bytebuddy.description.type.TypeDescription; +import net.bytebuddy.matcher.ElementMatcher; + +import java.lang.reflect.Method; +import java.net.URL; +import java.net.URLClassLoader; +import java.util.concurrent.Callable; +import java.util.concurrent.ScheduledThreadPoolExecutor; +import java.util.concurrent.TimeUnit; + +import static net.bytebuddy.implementation.bytecode.assign.Assigner.Typing.DYNAMIC; +import static net.bytebuddy.matcher.ElementMatchers.isMethod; +import static net.bytebuddy.matcher.ElementMatchers.isSubTypeOf; +import static net.bytebuddy.matcher.ElementMatchers.named; +import static net.bytebuddy.matcher.ElementMatchers.takesArguments; + +public class ScheduledThreadPoolExecutorWrapper implements Wrapper { + @Override + public String getName() { + return ScheduleAdvice.class.getName(); + } + + @Override + public ElementMatcher getMatcher() { + return isMethod() + .and(named("schedule")) + .and( + takesArguments(Runnable.class, long.class, TimeUnit.class) + .or(takesArguments(Callable.class, long.class, TimeUnit.class)) + ); + } + + @Override + public ElementMatcher getTypeMatcher() { + return isSubTypeOf(ScheduledThreadPoolExecutor.class); + } + + public static class ScheduleAdvice { + @Advice.OnMethodEnter(suppress = Throwable.class) + public static void before( + @Advice.Argument(value = 0, readOnly = false, typing = DYNAMIC) Object task + ) throws Exception { + if (task == null) { + return; + } + + // This advice is applied to JDK classes loaded by the bootstrap classloader. + // Load agent_api reflectively because bootstrap classes cannot directly reference agent classes. + String jarFilePath = System.getProperty("AIK_agent_api_jar"); + if (jarFilePath == null || jarFilePath.isBlank()) { + return; + } + + URLClassLoader classLoader = new URLClassLoader(new URL[] { new URL(jarFilePath) }); + Class contextPropagationClass = classLoader.loadClass( + "dev.aikido.agent_api.context.ContextPropagation" + ); + + if (task instanceof Runnable) { + Method wrapRunnable = contextPropagationClass.getMethod("wrap", Runnable.class); + task = wrapRunnable.invoke(null, task); + } else if (task instanceof Callable) { + Method wrapCallable = contextPropagationClass.getMethod("wrap", Callable.class); + task = wrapCallable.invoke(null, task); + } + } + } +} diff --git a/agent/src/main/java/dev/aikido/agent/wrappers/executor/ThreadPoolExecutorWrapper.java b/agent/src/main/java/dev/aikido/agent/wrappers/executor/ThreadPoolExecutorWrapper.java new file mode 100644 index 000000000..ece585980 --- /dev/null +++ b/agent/src/main/java/dev/aikido/agent/wrappers/executor/ThreadPoolExecutorWrapper.java @@ -0,0 +1,42 @@ +package dev.aikido.agent.wrappers.executor; + +import dev.aikido.agent.wrappers.Wrapper; +import net.bytebuddy.asm.Advice; +import net.bytebuddy.description.method.MethodDescription; +import net.bytebuddy.description.type.TypeDescription; +import net.bytebuddy.matcher.ElementMatcher; + +import java.util.concurrent.ThreadPoolExecutor; + +import static net.bytebuddy.matcher.ElementMatchers.isMethod; +import static net.bytebuddy.matcher.ElementMatchers.isSubTypeOf; +import static net.bytebuddy.matcher.ElementMatchers.named; +import static net.bytebuddy.matcher.ElementMatchers.takesArguments; + +public class ThreadPoolExecutorWrapper implements Wrapper { + @Override + public String getName() { + return ExecuteAdvice.class.getName(); + } + + @Override + public ElementMatcher getMatcher() { + return isMethod() + .and(named("execute")) + .and(takesArguments(Runnable.class)); + } + + @Override + public ElementMatcher getTypeMatcher() { + return isSubTypeOf(ThreadPoolExecutor.class); + } + + public static class ExecuteAdvice { + @Advice.OnMethodEnter(suppress = Throwable.class) + public static void before( + @Advice.Argument(value = 0, readOnly = false) Runnable task + ) { + task = ExecutorContextPropagation.wrap(task); + } + } +} diff --git a/agent_api/src/main/java/dev/aikido/agent_api/context/ContextPropagatingCallable.java b/agent_api/src/main/java/dev/aikido/agent_api/context/ContextPropagatingCallable.java new file mode 100644 index 000000000..ec9dd26a0 --- /dev/null +++ b/agent_api/src/main/java/dev/aikido/agent_api/context/ContextPropagatingCallable.java @@ -0,0 +1,29 @@ +package dev.aikido.agent_api.context; + +import java.util.concurrent.Callable; + +public final class ContextPropagatingCallable implements Callable { + private final Callable delegate; + private final ContextObject context; + + public ContextPropagatingCallable(Callable delegate, ContextObject context) { + this.delegate = delegate; + this.context = context; + } + + @Override + public T call() throws Exception { + ContextObject previousContext = Context.get(); + + try { + Context.set(context); + return delegate.call(); + } finally { + if (previousContext != null) { + Context.set(previousContext); + } else { + Context.reset(); + } + } + } +} diff --git a/agent_api/src/main/java/dev/aikido/agent_api/context/ContextPropagatingRunnable.java b/agent_api/src/main/java/dev/aikido/agent_api/context/ContextPropagatingRunnable.java new file mode 100644 index 000000000..20edb2a80 --- /dev/null +++ b/agent_api/src/main/java/dev/aikido/agent_api/context/ContextPropagatingRunnable.java @@ -0,0 +1,27 @@ +package dev.aikido.agent_api.context; + +public final class ContextPropagatingRunnable implements Runnable { + private final Runnable delegate; + private final ContextObject context; + + public ContextPropagatingRunnable(Runnable delegate, ContextObject context) { + this.delegate = delegate; + this.context = context; + } + + @Override + public void run() { + ContextObject previousContext = Context.get(); + + try { + Context.set(context); + delegate.run(); + } finally { + if (previousContext != null) { + Context.set(previousContext); + } else { + Context.reset(); + } + } + } +} diff --git a/agent_api/src/main/java/dev/aikido/agent_api/context/ContextPropagation.java b/agent_api/src/main/java/dev/aikido/agent_api/context/ContextPropagation.java new file mode 100644 index 000000000..a3223ce85 --- /dev/null +++ b/agent_api/src/main/java/dev/aikido/agent_api/context/ContextPropagation.java @@ -0,0 +1,33 @@ +package dev.aikido.agent_api.context; + +import java.util.concurrent.Callable; + +public final class ContextPropagation { + private ContextPropagation() {} + + public static Runnable wrap(Runnable task) { + if (task == null || task instanceof ContextPropagatingRunnable) { + return task; + } + + ContextObject context = Context.get(); + if (context == null) { + return task; + } + + return new ContextPropagatingRunnable(task, context); + } + + public static Callable wrap(Callable task) { + if (task == null || task instanceof ContextPropagatingCallable) { + return task; + } + + ContextObject context = Context.get(); + if (context == null) { + return task; + } + + return new ContextPropagatingCallable<>(task, context); + } +} diff --git a/agent_api/src/test/java/context/ContextPropagationTest.java b/agent_api/src/test/java/context/ContextPropagationTest.java new file mode 100644 index 000000000..1a1fe3306 --- /dev/null +++ b/agent_api/src/test/java/context/ContextPropagationTest.java @@ -0,0 +1,224 @@ +package context; + +import dev.aikido.agent_api.context.Context; +import dev.aikido.agent_api.context.ContextObject; +import dev.aikido.agent_api.context.ContextPropagatingCallable; +import dev.aikido.agent_api.context.ContextPropagatingRunnable; +import dev.aikido.agent_api.context.ContextPropagation; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +import java.util.concurrent.Callable; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; + +class ContextPropagationTest { + @AfterEach + void tearDown() { + Context.reset(); + } + + @Test + void wrapRunnableReturnsNullForNullTask() { + Assertions.assertNull(ContextPropagation.wrap((Runnable) null)); + } + + @Test + void wrapCallableReturnsNullForNullTask() { + Assertions.assertNull(ContextPropagation.wrap((Callable) null)); + } + + @Test + void wrapRunnableReturnsOriginalTaskWhenNoContextIsSet() { + Runnable task = () -> {}; + + Runnable wrapped = ContextPropagation.wrap(task); + + Assertions.assertSame(task, wrapped); + } + + @Test + void wrapCallableReturnsOriginalTaskWhenNoContextIsSet() { + Callable task = () -> "ok"; + + Callable wrapped = ContextPropagation.wrap(task); + + Assertions.assertSame(task, wrapped); + } + + @Test + void wrapRunnableReturnsSameTaskWhenAlreadyWrapped() { + ContextObject contextObject = new ContextObject(); + Runnable task = new ContextPropagatingRunnable(() -> {}, contextObject); + + Runnable wrapped = ContextPropagation.wrap(task); + + Assertions.assertSame(task, wrapped); + } + + @Test + void wrapCallableReturnsSameTaskWhenAlreadyWrapped() { + ContextObject contextObject = new ContextObject(); + Callable task = new ContextPropagatingCallable<>(() -> "ok", contextObject); + + Callable wrapped = ContextPropagation.wrap(task); + + Assertions.assertSame(task, wrapped); + } + + @Test + void wrapRunnableCapturesCurrentContext() { + ContextObject requestContext = new ContextObject(); + Context.set(requestContext); + + AtomicReference contextDuringRun = new AtomicReference<>(); + Runnable wrapped = ContextPropagation.wrap(() -> contextDuringRun.set(Context.get())); + + Context.reset(); + wrapped.run(); + + Assertions.assertSame(requestContext, contextDuringRun.get()); + Assertions.assertNull(Context.get(), "Expected worker context to be cleared after task execution"); + } + + @Test + void wrapCallableCapturesCurrentContext() throws Exception { + ContextObject requestContext = new ContextObject(); + Context.set(requestContext); + + Callable wrapped = ContextPropagation.wrap(Context::get); + + Context.reset(); + ContextObject contextDuringCall = wrapped.call(); + + Assertions.assertSame(requestContext, contextDuringCall); + Assertions.assertNull(Context.get(), "Expected worker context to be cleared after task execution"); + } + + @Test + void contextPropagatingRunnableRestoresPreviousWorkerContext() { + ContextObject capturedContext = new ContextObject(); + ContextObject previousWorkerContext = new ContextObject(); + + ContextPropagatingRunnable task = new ContextPropagatingRunnable( + () -> Assertions.assertSame(capturedContext, Context.get()), + capturedContext + ); + + Context.set(previousWorkerContext); + task.run(); + + Assertions.assertSame(previousWorkerContext, Context.get()); + } + + @Test + void contextPropagatingCallableRestoresPreviousWorkerContext() throws Exception { + ContextObject capturedContext = new ContextObject(); + ContextObject previousWorkerContext = new ContextObject(); + + ContextPropagatingCallable task = new ContextPropagatingCallable<>( + Context::get, + capturedContext + ); + + Context.set(previousWorkerContext); + ContextObject contextDuringCall = task.call(); + + Assertions.assertSame(capturedContext, contextDuringCall); + Assertions.assertSame(previousWorkerContext, Context.get()); + } + + @Test + void contextPropagatingRunnableClearsContextWhenWorkerHadNoPreviousContext() { + ContextObject capturedContext = new ContextObject(); + + ContextPropagatingRunnable task = new ContextPropagatingRunnable( + () -> Assertions.assertSame(capturedContext, Context.get()), + capturedContext + ); + + Context.reset(); + task.run(); + + Assertions.assertNull(Context.get()); + } + + @Test + void contextPropagatingCallableClearsContextWhenWorkerHadNoPreviousContext() throws Exception { + ContextObject capturedContext = new ContextObject(); + + ContextPropagatingCallable task = new ContextPropagatingCallable<>( + Context::get, + capturedContext + ); + + Context.reset(); + ContextObject contextDuringCall = task.call(); + + Assertions.assertSame(capturedContext, contextDuringCall); + Assertions.assertNull(Context.get()); + } + + @Test + void contextPropagatingRunnableRestoresPreviousWorkerContextAfterException() { + ContextObject capturedContext = new ContextObject(); + ContextObject previousWorkerContext = new ContextObject(); + + ContextPropagatingRunnable task = new ContextPropagatingRunnable( + () -> { + throw new IllegalStateException("boom"); + }, + capturedContext + ); + + Context.set(previousWorkerContext); + + Assertions.assertThrows(IllegalStateException.class, task::run); + Assertions.assertSame(previousWorkerContext, Context.get()); + } + + @Test + void contextPropagatingCallableRestoresPreviousWorkerContextAfterException() { + ContextObject capturedContext = new ContextObject(); + ContextObject previousWorkerContext = new ContextObject(); + + ContextPropagatingCallable task = new ContextPropagatingCallable<>( + () -> { + throw new IllegalStateException("boom"); + }, + capturedContext + ); + + Context.set(previousWorkerContext); + + Assertions.assertThrows(IllegalStateException.class, task::call); + Assertions.assertSame(previousWorkerContext, Context.get()); + } + + @Test + void wrappedRunnableRunsDelegate() { + ContextObject requestContext = new ContextObject(); + Context.set(requestContext); + + AtomicBoolean delegateCalled = new AtomicBoolean(false); + Runnable wrapped = ContextPropagation.wrap(() -> delegateCalled.set(true)); + + Context.reset(); + wrapped.run(); + + Assertions.assertTrue(delegateCalled.get()); + } + + @Test + void wrappedCallableReturnsDelegateResult() throws Exception { + ContextObject requestContext = new ContextObject(); + Context.set(requestContext); + + Callable wrapped = ContextPropagation.wrap(() -> "ok"); + + Context.reset(); + + Assertions.assertEquals("ok", wrapped.call()); + } +} diff --git a/end2end/spring_boot_postgres.py b/end2end/spring_boot_postgres.py index 8421b6c1b..01278db2a 100644 --- a/end2end/spring_boot_postgres.py +++ b/end2end/spring_boot_postgres.py @@ -6,6 +6,27 @@ safe_request=Request("/api/pets/create", body={"name": "Bobby"}), unsafe_request=Request("/api/pets/create", body={"name": "Malicious Pet', 'Gru from the Minions') -- "}) ) + +for endpoint in [ + "completable-future-single", + "submit-callable", + "thread-pool-execute", + "fork-join-submit", + "scheduled-callable", + "spring-task-executor", + "spring-async-annotation", +]: + spring_boot_postgres_app.add_payload( + f"sql async context propagation {endpoint}", + safe_request=Request( + f"/api/pets/create/async/{endpoint}", + body={"name": "Bobby"} + ), + unsafe_request=Request( + f"/api/pets/create/async/{endpoint}", + body={"name": "Malicious Pet', 'Gru from the Minions') -- "} + ) + ) spring_boot_postgres_app.add_payload("command injection", safe_request=Request("/api/commands/execute/Johnny", method='GET'), unsafe_request=Request("/api/commands/execute/%27%3B%20sleep%202%3B%20%23%20", method='GET'), diff --git a/sample-apps/SpringBootPostgres/src/main/java/com/example/demo/AsyncContextPropagationConfig.java b/sample-apps/SpringBootPostgres/src/main/java/com/example/demo/AsyncContextPropagationConfig.java new file mode 100644 index 000000000..695e905f8 --- /dev/null +++ b/sample-apps/SpringBootPostgres/src/main/java/com/example/demo/AsyncContextPropagationConfig.java @@ -0,0 +1,23 @@ +package com.example.demo; + +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.scheduling.annotation.EnableAsync; +import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor; + +import java.util.concurrent.Executor; + +@Configuration +@EnableAsync +public class AsyncContextPropagationConfig { + @Bean(name = "asyncContextPropagationExecutor") + public Executor asyncContextPropagationExecutor() { + ThreadPoolTaskExecutor executor = new ThreadPoolTaskExecutor(); + executor.setThreadNamePrefix("async-context-"); + executor.setCorePoolSize(2); + executor.setMaxPoolSize(2); + executor.setQueueCapacity(10); + executor.initialize(); + return executor; + } +} diff --git a/sample-apps/SpringBootPostgres/src/main/java/com/example/demo/AsyncContextPropagationController.java b/sample-apps/SpringBootPostgres/src/main/java/com/example/demo/AsyncContextPropagationController.java new file mode 100644 index 000000000..636dab94f --- /dev/null +++ b/sample-apps/SpringBootPostgres/src/main/java/com/example/demo/AsyncContextPropagationController.java @@ -0,0 +1,138 @@ +package com.example.demo; + +import org.springframework.beans.factory.annotation.Qualifier; +import org.springframework.http.MediaType; +import org.springframework.web.bind.annotation.*; + +import java.util.concurrent.*; + +@RestController +@RequestMapping("/api/pets/create/async") +public class AsyncContextPropagationController { + private final Executor springExecutor; + private final AsyncContextPropagationService asyncContextPropagationService; + + public AsyncContextPropagationController( + @Qualifier("asyncContextPropagationExecutor") Executor springExecutor, + AsyncContextPropagationService asyncContextPropagationService + ) { + this.springExecutor = springExecutor; + this.asyncContextPropagationService = asyncContextPropagationService; + } + + private record PetCreate(String name) {} + + @PostMapping( + path = "/completable-future-single", + consumes = MediaType.APPLICATION_JSON_VALUE, + produces = MediaType.APPLICATION_JSON_VALUE + ) + public PetsController.Rows completableFutureSingle(@RequestBody PetCreate pet) throws Exception { + ExecutorService executor = Executors.newSingleThreadExecutor(); + try { + return CompletableFuture + .supplyAsync(() -> createPet(pet.name()), executor) + .get(); + } finally { + executor.shutdown(); + } + } + + @PostMapping( + path = "/submit-callable", + consumes = MediaType.APPLICATION_JSON_VALUE, + produces = MediaType.APPLICATION_JSON_VALUE + ) + public PetsController.Rows submitCallable(@RequestBody PetCreate pet) throws Exception { + ExecutorService executor = Executors.newSingleThreadExecutor(); + try { + return executor.submit(() -> createPet(pet.name())).get(); + } finally { + executor.shutdown(); + } + } + + @PostMapping( + path = "/thread-pool-execute", + consumes = MediaType.APPLICATION_JSON_VALUE, + produces = MediaType.APPLICATION_JSON_VALUE + ) + public PetsController.Rows threadPoolExecute(@RequestBody PetCreate pet) throws Exception { + ThreadPoolExecutor executor = (ThreadPoolExecutor) Executors.newFixedThreadPool(1); + try { + CompletableFuture future = new CompletableFuture<>(); + executor.execute(() -> { + try { + future.complete(createPet(pet.name())); + } catch (Throwable throwable) { + future.completeExceptionally(throwable); + } + }); + return future.get(); + } finally { + executor.shutdown(); + } + } + + @PostMapping( + path = "/fork-join-submit", + consumes = MediaType.APPLICATION_JSON_VALUE, + produces = MediaType.APPLICATION_JSON_VALUE + ) + public PetsController.Rows forkJoinSubmit(@RequestBody PetCreate pet) throws Exception { + return ForkJoinPool.commonPool() + .submit(() -> createPet(pet.name())) + .get(); + } + + @PostMapping( + path = "/scheduled-callable", + consumes = MediaType.APPLICATION_JSON_VALUE, + produces = MediaType.APPLICATION_JSON_VALUE + ) + public PetsController.Rows scheduledCallable(@RequestBody PetCreate pet) throws Exception { + ScheduledThreadPoolExecutor executor = new ScheduledThreadPoolExecutor(1); + try { + return executor.schedule( + () -> createPet(pet.name()), + 1, + TimeUnit.MILLISECONDS + ).get(); + } finally { + executor.shutdown(); + } + } + + @PostMapping( + path = "/spring-task-executor", + consumes = MediaType.APPLICATION_JSON_VALUE, + produces = MediaType.APPLICATION_JSON_VALUE + ) + public PetsController.Rows springTaskExecutor(@RequestBody PetCreate pet) throws Exception { + CompletableFuture future = new CompletableFuture<>(); + springExecutor.execute(() -> { + try { + future.complete(createPet(pet.name())); + } catch (Throwable throwable) { + future.completeExceptionally(throwable); + } + }); + return future.get(); + } + + @PostMapping( + path = "/spring-async-annotation", + consumes = MediaType.APPLICATION_JSON_VALUE, + produces = MediaType.APPLICATION_JSON_VALUE + ) + public PetsController.Rows springAsyncAnnotation(@RequestBody PetCreate pet) throws Exception { + return asyncContextPropagationService + .createPetWithAsyncAnnotation(pet.name()) + .get(); + } + + private PetsController.Rows createPet(String name) { + Integer rowsCreated = DatabaseHelper.createPetByName(name); + return new PetsController.Rows(rowsCreated); + } +} diff --git a/sample-apps/SpringBootPostgres/src/main/java/com/example/demo/AsyncContextPropagationService.java b/sample-apps/SpringBootPostgres/src/main/java/com/example/demo/AsyncContextPropagationService.java new file mode 100644 index 000000000..bb8573952 --- /dev/null +++ b/sample-apps/SpringBootPostgres/src/main/java/com/example/demo/AsyncContextPropagationService.java @@ -0,0 +1,15 @@ +package com.example.demo; + +import org.springframework.scheduling.annotation.Async; +import org.springframework.stereotype.Service; + +import java.util.concurrent.CompletableFuture; + +@Service +public class AsyncContextPropagationService { + @Async("asyncContextPropagationExecutor") + public CompletableFuture createPetWithAsyncAnnotation(String name) { + Integer rowsCreated = DatabaseHelper.createPetByName(name); + return CompletableFuture.completedFuture(new PetsController.Rows(rowsCreated)); + } +}