diff --git a/dd-java-agent/instrumentation/spring/spring-messaging-4.0/src/main/java/datadog/trace/instrumentation/springmessaging/SpringMessageHandlerInstrumentation.java b/dd-java-agent/instrumentation/spring/spring-messaging-4.0/src/main/java/datadog/trace/instrumentation/springmessaging/SpringMessageHandlerInstrumentation.java index edd55b2ba98..d56f18b279e 100644 --- a/dd-java-agent/instrumentation/spring/spring-messaging-4.0/src/main/java/datadog/trace/instrumentation/springmessaging/SpringMessageHandlerInstrumentation.java +++ b/dd-java-agent/instrumentation/spring/spring-messaging-4.0/src/main/java/datadog/trace/instrumentation/springmessaging/SpringMessageHandlerInstrumentation.java @@ -20,6 +20,8 @@ import datadog.trace.agent.tooling.annotation.AppliesOn; import datadog.trace.bootstrap.instrumentation.api.AgentScope; import datadog.trace.bootstrap.instrumentation.api.AgentSpan; +import datadog.trace.bootstrap.instrumentation.java.concurrent.AsyncResultExtensions; +import java.util.concurrent.CompletionStage; import net.bytebuddy.asm.Advice; import org.springframework.messaging.Message; import org.springframework.messaging.handler.invocation.InvocableHandlerMethod; @@ -84,17 +86,24 @@ public static AgentScope onEnter(@Advice.This InvocableHandlerMethod thiz) { } @Advice.OnMethodExit(onThrowable = Throwable.class, suppress = Throwable.class) - public static void onExit(@Advice.Enter AgentScope scope, @Advice.Thrown Throwable error) { + public static void onExit( + @Advice.Enter AgentScope scope, + @Advice.Return(readOnly = false) Object result, + @Advice.Thrown Throwable error) { if (null == scope) { return; } AgentSpan span = scope.span(); - if (null != error) { - DECORATE.onError(span, error); - } scope.close(); - DECORATE.beforeFinish(span); - span.finish(); + if (result instanceof CompletionStage) { + result = ((CompletionStage) result).whenComplete(AsyncResultExtensions.finishSpan(span)); + } else { + if (null != error) { + DECORATE.onError(span, error); + } + DECORATE.beforeFinish(span); + span.finish(); + } } } } diff --git a/dd-java-agent/instrumentation/spring/spring-messaging-4.0/src/test/groovy/SpringListenerSQSTest.groovy b/dd-java-agent/instrumentation/spring/spring-messaging-4.0/src/test/groovy/SpringListenerSQSTest.groovy index ec15fe7d808..99691073866 100644 --- a/dd-java-agent/instrumentation/spring/spring-messaging-4.0/src/test/groovy/SpringListenerSQSTest.groovy +++ b/dd-java-agent/instrumentation/spring/spring-messaging-4.0/src/test/groovy/SpringListenerSQSTest.groovy @@ -15,6 +15,9 @@ import org.springframework.context.annotation.AnnotationConfigApplicationContext import org.springframework.messaging.support.GenericMessage import software.amazon.awssdk.services.sqs.SqsAsyncClient +import java.util.concurrent.TimeUnit + + class SpringListenerSQSTest extends InstrumentationSpecification { @Override @@ -128,7 +131,57 @@ class SpringListenerSQSTest extends InstrumentationSpecification { } } - static sendMessage(TraceAssert traceAssert, InetSocketAddress address, DDSpan parentSpan) { + def "async handler span duration reflects CompletableFuture completion"() { + setup: + def context = new AnnotationConfigApplicationContext(Config) + def address = context.getBean(SQSRestServer).waitUntilStarted().localAddress() + def template = SqsTemplate.newTemplate(context.getBean(SqsAsyncClient)) + TEST_WRITER.waitForTraces(2) + TEST_WRITER.clear() + + when: + TraceUtils.runUnderTrace("parent") { + template.sendAsync("SpringListenerSQSAsync", "an async message").get() + } + + then: + def sendingSpan + assertTraces(4, SORT_TRACES_BY_START) { + sortSpansByStart() + trace(3) { + basicSpan(it, "parent") + getQueueUrl(it, address, span(0), "SpringListenerSQSAsync") + sendMessage(it, address, span(0), "SpringListenerSQSAsync") + sendingSpan = span(2) + } + trace(1) { + receiveMessage(it, address, sendingSpan, "SpringListenerSQSAsync") + } + trace(1) { + span { + serviceName "my-service" + operationName "spring.consume" + resourceName "TestListener.observeAsync" + spanType DDSpanTypes.MESSAGE_CONSUMER + errored false + measured true + childOf(sendingSpan) + // The span duration should be at least 500ms since the async handler sleeps 500ms + assert span(0).durationNano > TimeUnit.MILLISECONDS.toNanos(500) + tags { + "$Tags.COMPONENT" "spring-messaging" + "$Tags.SPAN_KIND" Tags.SPAN_KIND_CONSUMER + defaultTags(true) + } + } + } + trace(1) { + deleteMessageBatch(it, address, "SpringListenerSQSAsync") + } + } + } + + static sendMessage(TraceAssert traceAssert, InetSocketAddress address, DDSpan parentSpan, String queueName = "SpringListenerSQS") { traceAssert.span { serviceName "sqs" operationName "aws.http" @@ -148,7 +201,7 @@ class SpringListenerSQSTest extends InstrumentationSpecification { "aws_service" "Sqs" "aws.operation" "SendMessage" "aws.agent" "java-aws-sdk" - "aws.queue.url" "http://localhost:${address.port}/000000000000/SpringListenerSQS" + "aws.queue.url" "http://localhost:${address.port}/000000000000/${queueName}" "aws.requestId" "00000000-0000-0000-0000-000000000000" urlTags("http://localhost:${address.port}/", ExpectedQueryParams.getExpectedQueryParams("SendMessage")) defaultTags() @@ -156,7 +209,7 @@ class SpringListenerSQSTest extends InstrumentationSpecification { } } - static getQueueUrl(TraceAssert traceAssert, InetSocketAddress address, DDSpan parentSpan) { + static getQueueUrl(TraceAssert traceAssert, InetSocketAddress address, DDSpan parentSpan, String queueName = "SpringListenerSQS") { traceAssert.span { serviceName "java-aws-sdk" operationName "aws.http" @@ -176,16 +229,16 @@ class SpringListenerSQSTest extends InstrumentationSpecification { "aws_service" "Sqs" "aws.operation" "GetQueueUrl" "aws.agent" "java-aws-sdk" - "aws.queue.name" "SpringListenerSQS" + "aws.queue.name" queueName "aws.requestId" "00000000-0000-0000-0000-000000000000" - "queuename" "SpringListenerSQS" + "queuename" queueName urlTags("http://localhost:${address.port}/", ExpectedQueryParams.getExpectedQueryParams("GetQueueUrl")) defaultTags() } } } - static receiveMessage(TraceAssert traceAssert, InetSocketAddress address, DDSpan parentSpan) { + static receiveMessage(TraceAssert traceAssert, InetSocketAddress address, DDSpan parentSpan, String queueName = "SpringListenerSQS") { traceAssert.span { serviceName "sqs" operationName "aws.http" @@ -201,7 +254,7 @@ class SpringListenerSQSTest extends InstrumentationSpecification { "aws_service" "Sqs" "aws.operation" "ReceiveMessage" "aws.agent" "java-aws-sdk" - "aws.queue.url" "http://localhost:${address.port}/000000000000/SpringListenerSQS" + "aws.queue.url" "http://localhost:${address.port}/000000000000/${queueName}" "aws.requestId" "00000000-0000-0000-0000-000000000000" defaultTags(true) } @@ -225,7 +278,7 @@ class SpringListenerSQSTest extends InstrumentationSpecification { } } - static deleteMessageBatch(TraceAssert traceAssert, InetSocketAddress address) { + static deleteMessageBatch(TraceAssert traceAssert, InetSocketAddress address, String queueName = "SpringListenerSQS") { traceAssert.span { serviceName "sqs" operationName "aws.http" @@ -245,7 +298,7 @@ class SpringListenerSQSTest extends InstrumentationSpecification { "aws_service" "Sqs" "aws.operation" "DeleteMessageBatch" "aws.agent" "java-aws-sdk" - "aws.queue.url" "http://localhost:${address.port}/000000000000/SpringListenerSQS" + "aws.queue.url" "http://localhost:${address.port}/000000000000/${queueName}" "aws.requestId" "00000000-0000-0000-0000-000000000000" urlTags("http://localhost:${address.port}/", ExpectedQueryParams.getExpectedQueryParams("DeleteMessageBatch")) defaultTags() diff --git a/dd-java-agent/instrumentation/spring/spring-messaging-4.0/src/test/groovy/listener/TestListener.groovy b/dd-java-agent/instrumentation/spring/spring-messaging-4.0/src/test/groovy/listener/TestListener.groovy index 7d0dd2e612d..b058bfc7453 100644 --- a/dd-java-agent/instrumentation/spring/spring-messaging-4.0/src/test/groovy/listener/TestListener.groovy +++ b/dd-java-agent/instrumentation/spring/spring-messaging-4.0/src/test/groovy/listener/TestListener.groovy @@ -3,10 +3,20 @@ package listener import io.awspring.cloud.sqs.annotation.SqsListener import org.springframework.stereotype.Component +import java.util.concurrent.CompletableFuture + @Component class TestListener { @SqsListener(queueNames = "SpringListenerSQS") void observe(String message) { println "Received $message" } + + @SqsListener(queueNames = "SpringListenerSQSAsync") + CompletableFuture observeAsync(String message) { + return CompletableFuture.runAsync { + Thread.sleep(500) + println "Async received $message" + } + } }