diff --git a/btrace-client/src/main/java/org/openjdk/btrace/client/Client.java b/btrace-client/src/main/java/org/openjdk/btrace/client/Client.java
index b80b9410..ca6f2b2e 100644
--- a/btrace-client/src/main/java/org/openjdk/btrace/client/Client.java
+++ b/btrace-client/src/main/java/org/openjdk/btrace/client/Client.java
@@ -758,7 +758,7 @@ public void attach(String pid, String agentPath, String sysCp, String bootCp) th
}
}
- void connectAndListProbes(String host, CommandListener listener) throws IOException {
+ public void connectAndListProbes(String host, CommandListener listener) throws IOException {
if (sock != null) {
throw new IllegalStateException();
}
@@ -1071,11 +1071,11 @@ public synchronized void close() throws IOException {
reset();
}
- boolean isDisconnected() {
+ public boolean isDisconnected() {
return disconnected;
}
- void disconnect() throws IOException {
+ public void disconnect() throws IOException {
disconnected = true;
if (log.isDebugEnabled()) {
log.debug("sending DISCONNECT request to agent");
diff --git a/btrace-dist/src/main/resources/samples/GpuBridge.java b/btrace-dist/src/main/resources/samples/GpuBridge.java
new file mode 100644
index 00000000..f8e451a4
--- /dev/null
+++ b/btrace-dist/src/main/resources/samples/GpuBridge.java
@@ -0,0 +1,79 @@
+import org.openjdk.btrace.core.annotations.*;
+import org.openjdk.btrace.core.extensions.Injected;
+import org.openjdk.btrace.gpu.GpuBridgeService;
+
+import static org.openjdk.btrace.core.BTraceUtils.*;
+
+/**
+ * Traces GPU model inference via ONNX Runtime and DJL (Deep Java Library).
+ * Tracks inference latency, batch sizes, and model load times.
+ *
+ *
",
+ location = @Location(Kind.RETURN))
+ public static void onOnnxModelLoad(@Duration long dur) {
+ gpu.recordModelLoad("onnx", "session", dur);
+ }
+
+ // ==================== DJL (Deep Java Library) ====================
+
+ @OnMethod(
+ clazz = "/ai\\.djl\\.inference\\.Predictor/",
+ method = "predict",
+ location = @Location(Kind.RETURN))
+ public static void onDjlPredict(@Duration long dur) {
+ gpu.recordInference("djl", "predictor", dur);
+ }
+
+ @OnMethod(
+ clazz = "/ai\\.djl\\.repository\\.zoo\\.ModelZoo/",
+ method = "loadModel",
+ location = @Location(Kind.RETURN))
+ public static void onDjlModelLoad(@Duration long dur) {
+ gpu.recordModelLoad("djl", "model-zoo", dur);
+ }
+
+ // ==================== TensorFlow Java ====================
+
+ @OnMethod(
+ clazz = "/org\\.tensorflow\\.Session/",
+ method = "run",
+ location = @Location(Kind.RETURN))
+ public static void onTensorFlowRun(@Duration long dur) {
+ gpu.recordInference("tensorflow", "session", dur);
+ }
+
+ // ==================== Periodic summary ====================
+
+ @OnTimer(30000)
+ public static void periodicSummary() {
+ println(gpu.getSummary());
+ }
+
+ @OnEvent("summary")
+ public static void onDemandSummary() {
+ println(gpu.getSummary());
+ }
+}
diff --git a/btrace-dist/src/main/resources/samples/LlmTrace.java b/btrace-dist/src/main/resources/samples/LlmTrace.java
new file mode 100644
index 00000000..98e2cefc
--- /dev/null
+++ b/btrace-dist/src/main/resources/samples/LlmTrace.java
@@ -0,0 +1,114 @@
+/*
+ * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
+ * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
+ *
+ * This code is free software; you can redistribute it and/or modify it
+ * under the terms of the GNU General Public License version 2 only, as
+ * published by the Free Software Foundation. Oracle designates this
+ * particular file as subject to the Classpath exception as provided
+ * by Oracle in the LICENSE file that accompanied this code.
+ *
+ * This code is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+ * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
+ * version 2 for more details (a copy is included in the LICENSE file that
+ * accompanied this code).
+ *
+ * You should have received a copy of the GNU General Public License version
+ * 2 along with this work; if not, write to the Free Software Foundation,
+ * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
+ *
+ * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
+ * or visit www.oracle.com if you need additional information or have any
+ * questions.
+ */
+
+import static org.openjdk.btrace.core.BTraceUtils.*;
+
+import org.openjdk.btrace.core.annotations.*;
+import org.openjdk.btrace.llm.LlmTraceService;
+
+/**
+ * Sample BTrace script that traces LLM API calls using the btrace-llm-trace extension.
+ *
+ * This is a generic template. It intercepts Langchain4j's ChatLanguageModel.generate()
+ * method and records call metrics. Adapt the @OnMethod annotations to target your specific
+ * LLM SDK (OpenAI Java SDK, Anthropic Java SDK, Spring AI, etc.).
+ *
+ * <p>Usage:
+ * btrace <PID> LlmTrace.java
+ *
+ * <p>Send a named event "summary" to print the current stats:
+ * (Ctrl-C, option 3, enter "summary")
+ */
+@BTrace
+public class LlmTrace {
+
+ @Injected // NOTE(review): this script has no import of org.openjdk.btrace.core.extensions.Injected, so @Injected resolves via the annotations.* wildcard — confirm that is the intended annotation
+ private static LlmTraceService llm;
+
+ /**
+ * Trace Langchain4j ChatLanguageModel.generate() calls.
+ * Captures latency on every call completion.
+ */
+ @OnMethod(
+ clazz = "+dev.langchain4j.model.chat.ChatLanguageModel",
+ method = "generate",
+ location = @Location(Kind.RETURN))
+ public static void onLangchain4jGenerate(
+ @ProbeClassName String className,
+ @ProbeMethodName String methodName,
+ @Duration long duration) {
+ // Model name extracted from the class; token counts need return value parsing
+ // For a production script, parse the Response return value
+ llm.recordCall("langchain4j", className, 0, 0, duration);
+ println(strcat(strcat(strcat("LLM call: ", className), " "),
+ strcat(str(duration / 1000000L), "ms")));
+ }
+
+ /**
+ * Trace Langchain4j StreamingChatLanguageModel calls.
+ */
+ @OnMethod(
+ clazz = "+dev.langchain4j.model.chat.StreamingChatLanguageModel",
+ method = "generate",
+ location = @Location(Kind.RETURN))
+ public static void onLangchain4jStreaming(
+ @ProbeClassName String className,
+ @Duration long duration) {
+ llm.recordStreamingCall(className, 0, 0, duration, 0);
+ }
+
+ /**
+ * Trace errors from any ChatLanguageModel implementation.
+ */
+ @OnMethod(
+ clazz = "+dev.langchain4j.model.chat.ChatLanguageModel",
+ method = "generate",
+ location = @Location(Kind.ERROR))
+ public static void onLangchain4jError(
+ @ProbeClassName String className,
+ @Duration long duration,
+ Throwable error) {
+ llm.recordError(className, Strings.str(error.getClass()), duration);
+ println(strcat("LLM ERROR: ", Strings.str(error)));
+ }
+
+ /**
+ * Print summary on named event "summary".
+ */
+ @OnEvent("summary")
+ public static void onSummary() {
+ println(llm.getSummary());
+ }
+
+ /**
+ * Print summary periodically (every 30 seconds).
+ */
+ @OnTimer(30000)
+ public static void onTimer() {
+ if (llm.getTotalCalls() > 0) {
+ println(llm.getSummary());
+ }
+ }
+}
diff --git a/btrace-dist/src/main/resources/samples/RagQuality.java b/btrace-dist/src/main/resources/samples/RagQuality.java
new file mode 100644
index 00000000..f75497ac
--- /dev/null
+++ b/btrace-dist/src/main/resources/samples/RagQuality.java
@@ -0,0 +1,73 @@
+import org.openjdk.btrace.core.annotations.*;
+import org.openjdk.btrace.core.extensions.Injected;
+import org.openjdk.btrace.rag.RagQualityService;
+
+import static org.openjdk.btrace.core.BTraceUtils.*;
+
+/**
+ * Traces RAG pipeline performance: vector DB query latency, similarity
+ * scores, and empty retrieval rates. Targets Pinecone Java client.
+ *
+ * Attach to a JVM running a RAG pipeline:
+ *
+ * btrace <pid> RagQuality.java
+ *
+ */
+@BTrace
+public class RagQuality {
+
+ @Injected
+ private static RagQualityService rag;
+
+ // ==================== Pinecone ====================
+
+ @OnMethod(
+ clazz = "/io\\.pinecone\\..*/",
+ method = "query",
+ location = @Location(Kind.RETURN))
+ public static void onPineconeQuery(@Duration long dur) {
+ rag.recordQuery("pinecone", dur);
+ }
+
+ // ==================== Milvus ====================
+
+ @OnMethod(
+ clazz = "/io\\.milvus\\.client\\..*/",
+ method = "search",
+ location = @Location(Kind.RETURN))
+ public static void onMilvusSearch(@Duration long dur) {
+ rag.recordQuery("milvus", dur);
+ }
+
+ // ==================== Weaviate ====================
+
+ @OnMethod(
+ clazz = "/io\\.weaviate\\.client\\..*/",
+ method = "/get|search/",
+ location = @Location(Kind.RETURN))
+ public static void onWeaviateQuery(@Duration long dur) {
+ rag.recordQuery("weaviate", dur);
+ }
+
+ // ==================== Chroma ====================
+
+ @OnMethod(
+ clazz = "/tech\\.amikos\\.chromadb\\..*/",
+ method = "query",
+ location = @Location(Kind.RETURN))
+ public static void onChromaQuery(@Duration long dur) {
+ rag.recordQuery("chroma", dur);
+ }
+
+ // ==================== Periodic summary ====================
+
+ @OnTimer(30000)
+ public static void periodicSummary() {
+ println(rag.getSummary());
+ }
+
+ @OnEvent("summary")
+ public static void onDemandSummary() {
+ println(rag.getSummary());
+ }
+}
diff --git a/btrace-dist/src/main/resources/samples/VibeGuard.java b/btrace-dist/src/main/resources/samples/VibeGuard.java
new file mode 100644
index 00000000..40afbaa3
--- /dev/null
+++ b/btrace-dist/src/main/resources/samples/VibeGuard.java
@@ -0,0 +1,69 @@
+import org.openjdk.btrace.core.annotations.*;
+import org.openjdk.btrace.core.extensions.Injected;
+import org.openjdk.btrace.vibeguard.VibeGuardService;
+
+import static org.openjdk.btrace.core.BTraceUtils.*;
+
+/**
+ * Runtime behavioral contracts for AI-generated code. Enforces latency
+ * budgets, call rate limits, and null-safety on methods that may have
+ * been generated by an LLM.
+ *
+ * Attach to a JVM running AI-generated service code:
+ *
+ * btrace <pid> VibeGuard.java
+ *
+ */
+@BTrace
+public class VibeGuard {
+
+ @Injected
+ private static VibeGuardService guard;
+
+ // ==================== Latency budgets ====================
+
+ /**
+ * Enforce 500ms latency budget on all methods in the AI-generated package.
+ * Adjust the clazz pattern to match your project structure.
+ */
+ @OnMethod(
+ clazz = "/com\\.myapp\\.ai\\..*/",
+ method = "/.*/",
+ location = @Location(Kind.RETURN))
+ public static void checkAiLatency(
+ @ProbeClassName String cls,
+ @ProbeMethodName String method,
+ @Duration long dur) {
+ guard.checkLatency(Strings.strcat(cls, Strings.strcat(".", method)),
+ dur, 500_000_000L);
+ }
+
+ // ==================== Null safety on return values ====================
+
+ @OnMethod(
+ clazz = "/com\\.myapp\\.ai\\..*/",
+ method = "/.*/",
+ location = @Location(Kind.RETURN))
+ public static void checkNullReturn(
+ @ProbeClassName String cls,
+ @ProbeMethodName String method,
+ @Return Object ret) {
+ guard.checkNotNull(
+ Strings.strcat(cls, Strings.strcat(".", method)), ret);
+ }
+
+ // ==================== Alert on violations ====================
+
+ @OnTimer(10000)
+ public static void checkViolations() {
+ if (guard.hasViolations()) {
+ println("=== VIBE GUARD ALERT ===");
+ println(guard.getSummary());
+ }
+ }
+
+ @OnEvent("summary")
+ public static void onDemandSummary() {
+ println(guard.getSummary());
+ }
+}
diff --git a/btrace-extensions/btrace-gpu-bridge/build.gradle b/btrace-extensions/btrace-gpu-bridge/build.gradle
new file mode 100644
index 00000000..23a56aec
--- /dev/null
+++ b/btrace-extensions/btrace-gpu-bridge/build.gradle
@@ -0,0 +1,33 @@
+plugins {
+ id 'org.openjdk.btrace.extension'
+ alias(libs.plugins.shadow)
+}
+
+java {
+ sourceCompatibility = 8
+ targetCompatibility = 8
+}
+
+compileJava {
+ javaCompiler = javaToolchains.compilerFor {
+ languageVersion.set(JavaLanguageVersion.of(11))
+ }
+}
+
+btraceExtension {
+ id = 'btrace-gpu-bridge'
+ name = 'BTrace GPU Bridge'
+ description = 'GPU and model inference observability: ONNX Runtime, DJL, TensorFlow Java, and Panama FFM compute tracing'
+ services = ['org.openjdk.btrace.gpu.GpuBridgeService']
+}
+
+dependencies {
+ apiCompileOnly project(':btrace-core')
+ implImplementation project(':btrace-core')
+
+ testImplementation libs.junit.jupiter
+}
+
+test {
+ useJUnitPlatform()
+}
diff --git a/btrace-extensions/btrace-gpu-bridge/src/api/java/org/openjdk/btrace/gpu/GpuBridgeService.java b/btrace-extensions/btrace-gpu-bridge/src/api/java/org/openjdk/btrace/gpu/GpuBridgeService.java
new file mode 100644
index 00000000..ea9cd57a
--- /dev/null
+++ b/btrace-extensions/btrace-gpu-bridge/src/api/java/org/openjdk/btrace/gpu/GpuBridgeService.java
@@ -0,0 +1,131 @@
+package org.openjdk.btrace.gpu;
+
+import org.openjdk.btrace.core.extensions.ServiceDescriptor;
+
+/**
+ * BTrace extension for GPU compute and model inference observability.
+ *
+ * Traces Java-to-GPU boundaries: ONNX Runtime inference sessions,
+ * DJL (Deep Java Library) predictions, TensorFlow Java operations,
+ * and Panama FFM native calls to CUDA/ROCm libraries. Zero external
+ * dependencies — instruments existing client classes.
+ *
+ * <p>Usage in a BTrace script:
+ *
+ * @Injected GpuBridgeService gpu;
+ *
+ * @OnMethod(clazz = "ai.onnxruntime.OrtSession", method = "run")
+ * void onInference(@Duration long dur) {
+ * gpu.recordInference("onnx", "resnet50", dur);
+ * }
+ *
+ */
+@ServiceDescriptor
+public interface GpuBridgeService {
+
+ // ==================== Inference recording ====================
+
+ /**
+ * Records a model inference call with duration only.
+ *
+ * @param runtime runtime name (e.g. "onnx", "djl", "tensorflow")
+ * @param modelName model identifier
+ * @param durationNanos inference duration in nanoseconds
+ */
+ void recordInference(String runtime, String modelName, long durationNanos);
+
+ /**
+ * Records a model inference together with its batch size.
+ *
+ * @param runtime runtime name
+ * @param modelName model identifier
+ * @param batchSize batch size of the inference request
+ * @param durationNanos inference duration
+ */
+ void recordInference(String runtime, String modelName, int batchSize, long durationNanos);
+
+ /**
+ * Starts a detailed inference record builder. Allocation-free (ThreadLocal-pooled).
+ *
+ *
+ * gpu.inference("onnx", "bert-base")
+ * .batchSize(32)
+ * .inputElements(512 * 768)
+ * .outputElements(512 * 2)
+ * .deviceType("cuda")
+ * .deviceId(0)
+ * .duration(durationNanos)
+ * .record();
+ *
+ *
+ * @param runtime runtime name
+ * @param modelName model identifier
+ * @return an inference record builder (thread-local, do not store)
+ */
+ InferenceRecord inference(String runtime, String modelName);
+
+ // ==================== Memory tracking ====================
+
+ /**
+ * Records GPU memory allocation.
+ *
+ * @param deviceType device type ("cuda", "rocm", "cpu")
+ * @param deviceId device index
+ * @param bytes allocated bytes
+ */
+ void recordMemoryAlloc(String deviceType, int deviceId, long bytes);
+
+ /**
+ * Records GPU memory deallocation.
+ *
+ * @param deviceType device type
+ * @param deviceId device index
+ * @param bytes freed bytes
+ */
+ void recordMemoryFree(String deviceType, int deviceId, long bytes);
+
+ // ==================== Native call tracking ====================
+
+ /**
+ * Records a native/FFM call to a GPU library (e.g. cuBLAS, cuDNN).
+ *
+ * @param library library name
+ * @param function function name
+ * @param durationNanos call duration
+ */
+ void recordNativeCall(String library, String function, long durationNanos);
+
+ // ==================== Model lifecycle ====================
+
+ /**
+ * Records model load/initialization time.
+ *
+ * @param runtime runtime name
+ * @param modelName model identifier
+ * @param durationNanos load duration
+ */
+ void recordModelLoad(String runtime, String modelName, long durationNanos);
+
+ // ==================== Reporting ====================
+
+ /** Returns a formatted summary of all GPU/inference metrics. */
+ String getSummary();
+
+ /** Returns summary for a specific model. */
+ String getModelSummary(String modelName);
+
+ /** Total number of inference calls across all models. */
+ long getTotalInferences();
+
+ /** Returns estimated GPU memory currently allocated (bytes), or -1 if not tracked. */
+ long getCurrentGpuMemoryBytes();
+
+ /** Returns peak GPU memory seen (bytes), or -1 if not tracked. */
+ long getPeakGpuMemoryBytes();
+
+ /** Returns total number of native/FFM calls recorded. */
+ long getTotalNativeCalls();
+
+ /** Resets all collected metrics. */
+ void reset();
+}
diff --git a/btrace-extensions/btrace-gpu-bridge/src/api/java/org/openjdk/btrace/gpu/InferenceRecord.java b/btrace-extensions/btrace-gpu-bridge/src/api/java/org/openjdk/btrace/gpu/InferenceRecord.java
new file mode 100644
index 00000000..58d3a416
--- /dev/null
+++ b/btrace-extensions/btrace-gpu-bridge/src/api/java/org/openjdk/btrace/gpu/InferenceRecord.java
@@ -0,0 +1,42 @@
+package org.openjdk.btrace.gpu;
+
+/**
+ * Fluent builder for recording a model inference with detailed metrics.
+ *
+ * Obtain via {@link GpuBridgeService#inference(String, String)}.
+ * Allocation-free (ThreadLocal-pooled). Do not store the returned reference.
+ *
+ *
+ * gpu.inference("onnx", "bert-base")
+ * .batchSize(32)
+ * .inputElements(512 * 768)
+ * .outputElements(512 * 2)
+ * .deviceType("cuda")
+ * .deviceId(0)
+ * .duration(durationNanos)
+ * .record();
+ *
+ */
+public interface InferenceRecord {
+
+ /** Batch size for this inference. */
+ InferenceRecord batchSize(int size);
+
+ /** Total number of input tensor elements. */
+ InferenceRecord inputElements(long elements);
+
+ /** Total number of output tensor elements. */
+ InferenceRecord outputElements(long elements);
+
+ /** Device type: "cuda", "rocm", "cpu", "mps", etc. */
+ InferenceRecord deviceType(String type);
+
+ /** Device index (for multi-GPU). */
+ InferenceRecord deviceId(int id);
+
+ /** Inference duration in nanoseconds. */
+ InferenceRecord duration(long nanos);
+
+ /** Commits this inference record. */
+ void record();
+}
diff --git a/btrace-extensions/btrace-gpu-bridge/src/impl/java/org/openjdk/btrace/gpu/GpuBridgeServiceImpl.java b/btrace-extensions/btrace-gpu-bridge/src/impl/java/org/openjdk/btrace/gpu/GpuBridgeServiceImpl.java
new file mode 100644
index 00000000..6cceacb9
--- /dev/null
+++ b/btrace-extensions/btrace-gpu-bridge/src/impl/java/org/openjdk/btrace/gpu/GpuBridgeServiceImpl.java
@@ -0,0 +1,335 @@
+package org.openjdk.btrace.gpu;
+
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.atomic.AtomicLong;
+import org.openjdk.btrace.core.extensions.Extension;
+
+/**
+ * Thread-safe GPU compute and inference tracking with lock-free statistics.
+ */
+public final class GpuBridgeServiceImpl extends Extension implements GpuBridgeService {
+
+  private final Map<String, ModelStats> modelStats = new ConcurrentHashMap<>(); // key: runtime + "/" + modelName
+  private final Map<String, DeviceMemory> deviceMemory = new ConcurrentHashMap<>(); // key: deviceType + ":" + deviceId
+  private final Map<String, NativeCallStats> nativeStats = new ConcurrentHashMap<>(); // key: library + "::" + function
+  // One reusable builder per thread; inference() resets and hands out this instance.
+  private final ThreadLocal<InferenceRecordImpl> inferenceRecordPool =
+      ThreadLocal.withInitial(InferenceRecordImpl::new);
+
+ // ==================== Inference recording ====================
+
+  @Override
+  public void recordInference(String runtime, String modelName, long durationNanos) {
+    // A duration-only call adds nothing to the batch-size total, which is exactly what the
+    // batch-aware overload does when given a batch size of zero.
+    recordInference(runtime, modelName, 0, durationNanos);
+  }
+
+  @Override
+  public void recordInference(String runtime, String modelName, int batchSize, long durationNanos) {
+    // Each counter is updated with its own atomic operation; no lock spans the group.
+    ModelStats modelEntry = getOrCreateModel(runtime, modelName);
+    modelEntry.inferences.incrementAndGet();
+    modelEntry.totalBatchSize.addAndGet(batchSize);
+    modelEntry.totalDurationNanos.addAndGet(durationNanos);
+    updateMinMax(modelEntry.minDurationNanos, modelEntry.maxDurationNanos, durationNanos);
+  }
+
+ @Override
+ public InferenceRecord inference(String runtime, String modelName) {
+ return inferenceRecordPool.get().reset(this, runtime, modelName);
+ }
+
+ void commitInferenceRecord(InferenceRecordImpl rec) {
+ ModelStats stats = getOrCreateModel(rec.runtime, rec.modelName);
+ stats.inferences.incrementAndGet();
+ stats.totalBatchSize.addAndGet(rec.batchSizeVal);
+ stats.totalInputElements.addAndGet(rec.inputElem);
+ stats.totalOutputElements.addAndGet(rec.outputElem);
+ stats.totalDurationNanos.addAndGet(rec.durationVal);
+ updateMinMax(stats.minDurationNanos, stats.maxDurationNanos, rec.durationVal);
+ if (rec.deviceTypeVal != null) {
+ stats.lastDeviceType = rec.deviceTypeVal;
+ stats.lastDeviceId = rec.deviceIdVal;
+ }
+ }
+
+ // ==================== Memory tracking ====================
+
+  @Override
+  public void recordMemoryAlloc(String deviceType, int deviceId, long bytes) {
+    DeviceMemory dm = getOrCreateDevice(deviceType, deviceId);
+    // Use the total returned by this add; a later get() could miss this allocation's high-water mark.
+    long reached = dm.currentBytes.addAndGet(bytes);
+    dm.totalAllocated.addAndGet(bytes);
+    dm.allocCount.incrementAndGet();
+    // Raise the peak with a CAS loop; stop as soon as the stored peak is already >= reached.
+    long peak;
+    do {
+      peak = dm.peakBytes.get();
+      if (reached <= peak) break;
+    } while (!dm.peakBytes.compareAndSet(peak, reached));
+  }
+
+ @Override
+ public void recordMemoryFree(String deviceType, int deviceId, long bytes) {
+ DeviceMemory dm = getOrCreateDevice(deviceType, deviceId);
+ dm.currentBytes.addAndGet(-bytes);
+ dm.freeCount.incrementAndGet();
+ }
+
+ // ==================== Native call tracking ====================
+
+ @Override
+ public void recordNativeCall(String library, String function, long durationNanos) {
+ String key = library + "::" + function;
+ NativeCallStats stats = nativeStats.computeIfAbsent(key, k -> new NativeCallStats());
+ stats.calls.incrementAndGet();
+ stats.totalDurationNanos.addAndGet(durationNanos);
+ }
+
+ // ==================== Model lifecycle ====================
+
+ @Override
+ public void recordModelLoad(String runtime, String modelName, long durationNanos) {
+ ModelStats stats = getOrCreateModel(runtime, modelName);
+ stats.loadCount.incrementAndGet();
+ stats.totalLoadDurationNanos.addAndGet(durationNanos);
+ }
+
+ // ==================== Reporting ====================
+
+  /** Renders per-model, device-memory and native-call statistics as a multi-line text report. */
+  @Override
+  public String getSummary() {
+    if (modelStats.isEmpty() && deviceMemory.isEmpty() && nativeStats.isEmpty()) {
+      return "No GPU/inference activity recorded.";
+    }
+
+    StringBuilder sb = new StringBuilder();
+    sb.append("=== GPU Bridge Summary ===\n\n");
+
+    long totalInf = 0;
+
+    for (Map.Entry<String, ModelStats> entry : modelStats.entrySet()) {
+      ModelStats s = entry.getValue();
+      long inf = s.inferences.get();
+      totalInf += inf;
+
+      sb.append("Model: ").append(entry.getKey());
+      if (s.lastDeviceType != null) {
+        sb.append(" [").append(s.lastDeviceType).append(":").append(s.lastDeviceId).append("]");
+      }
+      sb.append("\n");
+
+      // Inference count, plus load statistics when the model was loaded at least once
+      sb.append(" Inferences: ").append(inf);
+      long loads = s.loadCount.get();
+      if (loads > 0) {
+        long avgLoadMs = (s.totalLoadDurationNanos.get() / loads) / 1_000_000;
+        sb.append(" (loaded ").append(loads).append("x, avg ").append(avgLoadMs).append("ms)");
+      }
+      sb.append("\n");
+
+      // Latency: skipped without inferences, because min still holds its Long.MAX_VALUE seed
+      if (inf > 0) {
+        long avgMs = (s.totalDurationNanos.get() / inf) / 1_000_000;
+        long minMs = s.minDurationNanos.get() / 1_000_000;
+        long maxMs = s.maxDurationNanos.get() / 1_000_000;
+        sb.append(" Latency: avg ").append(avgMs).append("ms");
+        sb.append(", min ").append(minMs).append("ms");
+        sb.append(", max ").append(maxMs).append("ms\n");
+      }
+
+      // Batch size: integer average over all inferences of this model
+      long totalBatch = s.totalBatchSize.get();
+      if (totalBatch > 0 && inf > 0) {
+        sb.append(" Batch size: avg ").append(totalBatch / inf).append("\n");
+      }
+
+      // Throughput (items/sec): total batch items divided by total inference time
+      long totalDur = s.totalDurationNanos.get();
+      if (totalBatch > 0 && totalDur > 0) {
+        double itemsPerSec = (totalBatch * 1_000_000_000.0) / totalDur;
+        sb.append(" Throughput: ").append(String.format("%.1f", itemsPerSec)).append(" items/sec\n");
+      }
+
+      // Tensor elements: cumulative input/output element counts
+      long inEl = s.totalInputElements.get();
+      long outEl = s.totalOutputElements.get();
+      if (inEl > 0 || outEl > 0) {
+        sb.append(" Tensors: ").append(formatElements(inEl)).append(" in / ")
+            .append(formatElements(outEl)).append(" out\n");
+      }
+
+      sb.append("\n");
+    }
+
+    // Device memory
+    if (!deviceMemory.isEmpty()) {
+      sb.append("--- Device Memory ---\n");
+      for (Map.Entry<String, DeviceMemory> entry : deviceMemory.entrySet()) {
+        DeviceMemory dm = entry.getValue();
+        sb.append(" ").append(entry.getKey()).append(": ");
+        sb.append("current ").append(formatBytes(dm.currentBytes.get()));
+        sb.append(", peak ").append(formatBytes(dm.peakBytes.get()));
+        sb.append(" (").append(dm.allocCount.get()).append(" allocs, ")
+            .append(dm.freeCount.get()).append(" frees)\n");
+      }
+      sb.append("\n");
+    }
+
+    // Native calls
+    if (!nativeStats.isEmpty()) {
+      sb.append("--- Native Calls ---\n");
+      for (Map.Entry<String, NativeCallStats> entry : nativeStats.entrySet()) {
+        NativeCallStats ns = entry.getValue();
+        long calls = ns.calls.get();
+        // An entry can be observed before its first increment, so guard the division.
+        long avgUs = calls > 0 ? (ns.totalDurationNanos.get() / calls) / 1000 : 0;
+        sb.append(" ").append(entry.getKey()).append(": ")
+            .append(calls).append(" calls, avg ").append(avgUs).append("us\n");
+      }
+      sb.append("\n");
+    }
+
+    sb.append("--- Totals ---\n");
+    sb.append(" Inferences: ").append(totalInf).append("\n");
+
+    return sb.toString();
+  }
+
+  @Override
+  public String getModelSummary(String modelName) {
+    // Keys are "runtime/model": match the full key or a whole "/model" segment, never a partial name.
+    for (Map.Entry<String, ModelStats> entry : modelStats.entrySet()) {
+      if (entry.getKey().equals(modelName) || entry.getKey().endsWith("/" + modelName)) {
+        ModelStats s = entry.getValue();
+        long inf = s.inferences.get();
+        long avgMs = inf > 0 ? (s.totalDurationNanos.get() / inf) / 1_000_000 : 0;
+        return entry.getKey() + ": " + inf + " inferences, avg " + avgMs + "ms";
+      }
+    }
+    return "No data for model: " + modelName;
+  }
+
+ @Override
+ public long getTotalInferences() {
+ long total = 0;
+ for (ModelStats s : modelStats.values()) {
+ total += s.inferences.get();
+ }
+ return total;
+ }
+
+ @Override
+ public long getCurrentGpuMemoryBytes() {
+ if (deviceMemory.isEmpty()) return -1;
+ long total = 0;
+ for (DeviceMemory dm : deviceMemory.values()) {
+ total += dm.currentBytes.get();
+ }
+ return total;
+ }
+
+ @Override
+ public long getPeakGpuMemoryBytes() {
+ if (deviceMemory.isEmpty()) return -1;
+ long peak = 0;
+ for (DeviceMemory dm : deviceMemory.values()) {
+ peak = Math.max(peak, dm.peakBytes.get());
+ }
+ return peak;
+ }
+
+ @Override
+ public long getTotalNativeCalls() {
+ long total = 0;
+ for (NativeCallStats ns : nativeStats.values()) {
+ total += ns.calls.get();
+ }
+ return total;
+ }
+
+ @Override
+ public void reset() {
+ modelStats.clear();
+ deviceMemory.clear();
+ nativeStats.clear();
+ }
+
+ @Override
+ public void close() {
+ String summary = getSummary();
+ if (!"No GPU/inference activity recorded.".equals(summary)) {
+ getContext().send(summary);
+ }
+ }
+
+ // ==================== Internals ====================
+
+ private ModelStats getOrCreateModel(String runtime, String modelName) {
+ String key = runtime + "/" + modelName;
+ return modelStats.computeIfAbsent(key, k -> new ModelStats());
+ }
+
+ private DeviceMemory getOrCreateDevice(String deviceType, int deviceId) {
+ String key = deviceType + ":" + deviceId;
+ return deviceMemory.computeIfAbsent(key, k -> new DeviceMemory());
+ }
+
+ private static void updateMinMax(AtomicLong min, AtomicLong max, long value) { // folds value into a running min/max pair without locking
+ long cur;
+ do {
+ cur = min.get();
+ if (value >= cur) break; // stored minimum already covers value; nothing to write
+ } while (!min.compareAndSet(cur, value)); // min changed since it was read: re-read and retry
+ do {
+ cur = max.get();
+ if (value <= cur) break; // stored maximum already covers value; nothing to write
+ } while (!max.compareAndSet(cur, value)); // max changed since it was read: re-read and retry
+ }
+
+  static String formatBytes(long bytes) { // human-readable size; KB/MB are truncated, GB has one decimal
+    if (bytes < 0) return "-" + formatBytes(bytes == Long.MIN_VALUE ? Long.MAX_VALUE : -bytes); // -MIN_VALUE overflows, so clamp
+    if (bytes < 1024) return bytes + "B";
+    if (bytes < 1024 * 1024) return (bytes / 1024) + "KB";
+    if (bytes < 1024L * 1024 * 1024) return (bytes / (1024 * 1024)) + "MB";
+    return String.format("%.1fGB", bytes / (1024.0 * 1024 * 1024));
+  }
+
+ static String formatElements(long elements) {
+ if (elements < 1000) return elements + "";
+ if (elements < 1_000_000) return String.format("%.1fK", elements / 1000.0);
+ if (elements < 1_000_000_000) return String.format("%.1fM", elements / 1_000_000.0);
+ return String.format("%.1fB", elements / 1_000_000_000.0);
+ }
+
+ static final class ModelStats {
+ final AtomicLong inferences = new AtomicLong();
+ final AtomicLong totalBatchSize = new AtomicLong();
+ final AtomicLong totalInputElements = new AtomicLong();
+ final AtomicLong totalOutputElements = new AtomicLong();
+ final AtomicLong totalDurationNanos = new AtomicLong();
+ final AtomicLong minDurationNanos = new AtomicLong(Long.MAX_VALUE);
+ final AtomicLong maxDurationNanos = new AtomicLong(0);
+ final AtomicLong loadCount = new AtomicLong();
+ final AtomicLong totalLoadDurationNanos = new AtomicLong();
+ volatile String lastDeviceType;
+ volatile int lastDeviceId;
+ }
+
+ static final class DeviceMemory {
+ final AtomicLong currentBytes = new AtomicLong();
+ final AtomicLong peakBytes = new AtomicLong();
+ final AtomicLong totalAllocated = new AtomicLong();
+ final AtomicLong allocCount = new AtomicLong();
+ final AtomicLong freeCount = new AtomicLong();
+ }
+
+ static final class NativeCallStats {
+ final AtomicLong calls = new AtomicLong();
+ final AtomicLong totalDurationNanos = new AtomicLong();
+ }
+}
diff --git a/btrace-extensions/btrace-gpu-bridge/src/impl/java/org/openjdk/btrace/gpu/InferenceRecordImpl.java b/btrace-extensions/btrace-gpu-bridge/src/impl/java/org/openjdk/btrace/gpu/InferenceRecordImpl.java
new file mode 100644
index 00000000..c4cf48f9
--- /dev/null
+++ b/btrace-extensions/btrace-gpu-bridge/src/impl/java/org/openjdk/btrace/gpu/InferenceRecordImpl.java
@@ -0,0 +1,74 @@
+package org.openjdk.btrace.gpu;
+
+/**
+ * ThreadLocal-pooled builder for {@link InferenceRecord}. Zero allocation.
+ */
+final class InferenceRecordImpl implements InferenceRecord {
+
+ String runtime;
+ String modelName;
+ int batchSizeVal;
+ long inputElem;
+ long outputElem;
+ String deviceTypeVal;
+ int deviceIdVal;
+ long durationVal;
+
+ private GpuBridgeServiceImpl service;
+
+ InferenceRecordImpl() {}
+
+ InferenceRecordImpl reset(GpuBridgeServiceImpl service, String runtime, String modelName) {
+ this.service = service;
+ this.runtime = runtime;
+ this.modelName = modelName;
+ this.batchSizeVal = 1;
+ this.inputElem = 0;
+ this.outputElem = 0;
+ this.deviceTypeVal = null;
+ this.deviceIdVal = 0;
+ this.durationVal = 0;
+ return this;
+ }
+
+ @Override
+ public InferenceRecord batchSize(int size) {
+ this.batchSizeVal = size;
+ return this;
+ }
+
+ @Override
+ public InferenceRecord inputElements(long elements) {
+ this.inputElem = elements;
+ return this;
+ }
+
+ @Override
+ public InferenceRecord outputElements(long elements) {
+ this.outputElem = elements;
+ return this;
+ }
+
+ @Override
+ public InferenceRecord deviceType(String type) {
+ this.deviceTypeVal = type;
+ return this;
+ }
+
+ @Override
+ public InferenceRecord deviceId(int id) {
+ this.deviceIdVal = id;
+ return this;
+ }
+
+ @Override
+ public InferenceRecord duration(long nanos) {
+ this.durationVal = nanos;
+ return this;
+ }
+
+ @Override
+ public void record() {
+ service.commitInferenceRecord(this);
+ }
+}
diff --git a/btrace-extensions/btrace-gpu-bridge/src/test/java/org/openjdk/btrace/gpu/GpuBridgeServiceTest.java b/btrace-extensions/btrace-gpu-bridge/src/test/java/org/openjdk/btrace/gpu/GpuBridgeServiceTest.java
new file mode 100644
index 00000000..521561ef
--- /dev/null
+++ b/btrace-extensions/btrace-gpu-bridge/src/test/java/org/openjdk/btrace/gpu/GpuBridgeServiceTest.java
@@ -0,0 +1,222 @@
+package org.openjdk.btrace.gpu;
+
+import static org.junit.jupiter.api.Assertions.*;
+
+import java.util.concurrent.CountDownLatch;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+
+/** Unit tests for {@link GpuBridgeServiceImpl}: recording, reporting, reset and concurrency. */
+class GpuBridgeServiceTest {
+
+  private GpuBridgeServiceImpl service;
+
+  /** Creates a fresh service per test so every counter starts from scratch. */
+  @BeforeEach
+  void setUp() {
+    this.service = new GpuBridgeServiceImpl();
+  }
+
+  // ==================== Inference recording ====================
+
+  @Test
+  void durationOnlyInference() {
+    service.recordInference("onnx", "resnet50", 10_000_000L);
+
+    assertEquals(1, service.getTotalInferences());
+    String report = service.getSummary();
+    assertTrue(report.contains("onnx/resnet50"));
+  }
+
+  @Test
+  void inferenceWithBatchSize() {
+    service.recordInference("djl", "bert", 32, 50_000_000L);
+
+    assertEquals(1, service.getTotalInferences());
+    String report = service.getSummary();
+    assertTrue(report.contains("avg 32"));
+  }
+
+  @Test
+  void fluentBuilder() {
+    // Every optional builder setter is exercised once.
+    service.inference("onnx", "yolo-v8")
+        .batchSize(16)
+        .inputElements(640L * 640 * 3)
+        .outputElements(8400L * 84)
+        .deviceType("cuda")
+        .deviceId(0)
+        .duration(25_000_000L)
+        .record();
+
+    assertEquals(1, service.getTotalInferences());
+    String report = service.getSummary();
+    assertTrue(report.contains("cuda:0"));
+    assertTrue(report.contains("items/sec"));
+  }
+
+  @Test
+  void fluentBuilderMinimal() {
+    // Only the duration is supplied; the remaining values keep their defaults.
+    service.inference("tensorflow", "mobilenet").duration(5_000_000L).record();
+
+    assertEquals(1, service.getTotalInferences());
+  }
+
+  @Test
+  void multipleModelsTrackedSeparately() {
+    service.recordInference("onnx", "resnet50", 10_000_000L);
+    service.recordInference("onnx", "bert", 20_000_000L);
+    service.recordInference("onnx", "resnet50", 12_000_000L);
+
+    assertEquals(3, service.getTotalInferences());
+    String resnetLine = service.getModelSummary("resnet50");
+    assertTrue(resnetLine.contains("2 inferences"));
+    String bertLine = service.getModelSummary("bert");
+    assertTrue(bertLine.contains("1 inferences"));
+  }
+
+  @Test
+  void latencyMinMax() {
+    service.recordInference("onnx", "model", 5_000_000L);
+    service.recordInference("onnx", "model", 50_000_000L);
+    service.recordInference("onnx", "model", 10_000_000L);
+
+    String report = service.getSummary();
+    assertTrue(report.contains("min 5ms"));
+    assertTrue(report.contains("max 50ms"));
+  }
+
+  // ==================== Memory tracking ====================
+
+  @Test
+  void memoryAllocAndFree() {
+    // Allocate 100MB: current and peak both reflect it.
+    service.recordMemoryAlloc("cuda", 0, 1024 * 1024 * 100);
+    assertEquals(100 * 1024 * 1024L, service.getCurrentGpuMemoryBytes());
+    assertEquals(100 * 1024 * 1024L, service.getPeakGpuMemoryBytes());
+
+    // Free 50MB: current drops, peak stays where it was.
+    service.recordMemoryFree("cuda", 0, 1024 * 1024 * 50);
+    assertEquals(50 * 1024 * 1024L, service.getCurrentGpuMemoryBytes());
+    assertEquals(100 * 1024 * 1024L, service.getPeakGpuMemoryBytes());
+  }
+
+  @Test
+  void noMemoryTracked() {
+    // With nothing recorded both getters are expected to report -1.
+    assertEquals(-1, service.getCurrentGpuMemoryBytes());
+    assertEquals(-1, service.getPeakGpuMemoryBytes());
+  }
+
+  @Test
+  void memorySummary() {
+    service.recordMemoryAlloc("cuda", 0, 500 * 1024 * 1024L);
+
+    String report = service.getSummary();
+    assertTrue(report.contains("Device Memory"));
+    assertTrue(report.contains("500MB"));
+  }
+
+  // ==================== Native call tracking ====================
+
+  @Test
+  void nativeCalls() {
+    service.recordNativeCall("cublas", "sgemm", 500_000L);
+    service.recordNativeCall("cublas", "sgemm", 600_000L);
+    service.recordNativeCall("cudnn", "conv_forward", 1_000_000L);
+
+    assertEquals(3, service.getTotalNativeCalls());
+    String report = service.getSummary();
+    assertTrue(report.contains("cublas::sgemm"));
+    assertTrue(report.contains("2 calls"));
+  }
+
+  // ==================== Model lifecycle ====================
+
+  @Test
+  void modelLoad() {
+    // A 2 second model load followed by one inference on the same model.
+    service.recordModelLoad("onnx", "bert", 2_000_000_000L);
+    service.recordInference("onnx", "bert", 10_000_000L);
+
+    String report = service.getSummary();
+    assertTrue(report.contains("loaded 1x"));
+    assertTrue(report.contains("2000ms"));
+  }
+
+  // ==================== Reporting ====================
+
+  @Test
+  void noDataSummary() {
+    String report = service.getSummary();
+    assertEquals("No GPU/inference activity recorded.", report);
+  }
+
+  @Test
+  void unknownModelSummary() {
+    String line = service.getModelSummary("unknown");
+    assertEquals("No data for model: unknown", line);
+  }
+
+  @Test
+  void formatBytes() {
+    assertEquals("100B", GpuBridgeServiceImpl.formatBytes(100));
+    assertEquals("10KB", GpuBridgeServiceImpl.formatBytes(10 * 1024));
+    assertEquals("256MB", GpuBridgeServiceImpl.formatBytes(256 * 1024 * 1024));
+    assertEquals("1.5GB", GpuBridgeServiceImpl.formatBytes((long) (1.5 * 1024 * 1024 * 1024)));
+  }
+
+  @Test
+  void formatElements() {
+    assertEquals("500", GpuBridgeServiceImpl.formatElements(500));
+    assertEquals("1.5K", GpuBridgeServiceImpl.formatElements(1500));
+    assertEquals("2.0M", GpuBridgeServiceImpl.formatElements(2_000_000));
+    assertEquals("1.0B", GpuBridgeServiceImpl.formatElements(1_000_000_000));
+  }
+
+  @Test
+  void reset() {
+    // Touch every kind of metric, then wipe them all.
+    service.recordInference("onnx", "model", 1L);
+    service.recordMemoryAlloc("cuda", 0, 1024);
+    service.recordNativeCall("cublas", "fn", 1L);
+    service.reset();
+
+    assertEquals(0, service.getTotalInferences());
+    assertEquals(-1, service.getCurrentGpuMemoryBytes());
+    assertEquals(0, service.getTotalNativeCalls());
+    assertEquals("No GPU/inference activity recorded.", service.getSummary());
+  }
+
+  // ==================== Concurrency ====================
+
+  @Test
+  void concurrentInference() throws Exception {
+    int threads = 8;
+    int infsPerThread = 1000;
+
+    runOnThreads(threads, () -> {
+      for (int n = 0; n < infsPerThread; n++) {
+        service.recordInference("onnx", "model", 32, 1_000_000L);
+      }
+    });
+
+    assertEquals(threads * infsPerThread, service.getTotalInferences());
+  }
+
+  @Test
+  void concurrentBuilderInference() throws Exception {
+    int threads = 8;
+    int infsPerThread = 500;
+
+    runOnThreads(threads, () -> {
+      for (int n = 0; n < infsPerThread; n++) {
+        service.inference("djl", "bert")
+            .batchSize(16)
+            .duration(2_000_000L)
+            .record();
+      }
+    });
+
+    assertEquals(threads * infsPerThread, service.getTotalInferences());
+  }
+
+  /**
+   * Starts {@code threadCount} threads that each run {@code work} once and blocks until all of
+   * them have finished. The latch is released in a {@code finally} so a failing worker cannot
+   * leave the caller waiting forever.
+   */
+  private static void runOnThreads(int threadCount, Runnable work) throws InterruptedException {
+    CountDownLatch finished = new CountDownLatch(threadCount);
+    for (int started = 0; started < threadCount; started++) {
+      Thread worker = new Thread(() -> {
+        try {
+          work.run();
+        } finally {
+          finished.countDown();
+        }
+      });
+      worker.start();
+    }
+    finished.await();
+  }
+}
diff --git a/btrace-extensions/btrace-llm-trace/build.gradle b/btrace-extensions/btrace-llm-trace/build.gradle
new file mode 100644
index 00000000..75f7d67d
--- /dev/null
+++ b/btrace-extensions/btrace-llm-trace/build.gradle
@@ -0,0 +1,33 @@
+// Build script for the btrace-llm-trace extension module.
+// Applies the BTrace extension plugin and the shadow plugin referenced through the `libs` alias.
+plugins {
+ id 'org.openjdk.btrace.extension'
+ alias(libs.plugins.shadow)
+}
+
+// Source and target levels are both pinned to Java 8.
+java {
+ sourceCompatibility = 8
+ targetCompatibility = 8
+}
+
+// compileJava runs with a Java 11 toolchain compiler, while the levels above stay at 8.
+compileJava {
+ javaCompiler = javaToolchains.compilerFor {
+ languageVersion.set(JavaLanguageVersion.of(11))
+ }
+}
+
+// Extension metadata: id, display name, description, and the service interface(s) it exposes.
+btraceExtension {
+ id = 'btrace-llm-trace'
+ name = 'BTrace LLM Tracing'
+ description = 'LLM inference observability: token counts, latency, cost tracking for Java LLM SDKs'
+ services = ['org.openjdk.btrace.llm.LlmTraceService']
+}
+
+// btrace-core is compile-only for the API side and a regular dependency for the impl side.
+// NOTE(review): apiCompileOnly / implImplementation are presumably configurations created by the
+// extension plugin for the src/api and src/impl source sets; confirm against the plugin.
+dependencies {
+ apiCompileOnly project(':btrace-core')
+ implImplementation project(':btrace-core')
+
+ testImplementation libs.junit.jupiter
+}
+
+// Run tests on the JUnit Platform (JUnit 5).
+test {
+ useJUnitPlatform()
+}
diff --git a/btrace-extensions/btrace-llm-trace/src/api/java/org/openjdk/btrace/llm/CallRecord.java b/btrace-extensions/btrace-llm-trace/src/api/java/org/openjdk/btrace/llm/CallRecord.java
new file mode 100644
index 00000000..9dda73e3
--- /dev/null
+++ b/btrace-extensions/btrace-llm-trace/src/api/java/org/openjdk/btrace/llm/CallRecord.java
@@ -0,0 +1,73 @@
+package org.openjdk.btrace.llm;
+
+/**
+ * Fluent builder for recording an LLM API call.
+ *
+ * <p>Obtain via {@link LlmTraceService#call(String)}. All setters are optional, including
+ * {@link #duration(long)} — if omitted, duration defaults to 0.
+ *
+ * <p>Call {@link #record()} to commit the metrics.
+ *
+ * <p><b>Allocation-free:</b> Instances are pooled per-thread internally,
+ * so calling {@code call()} does not allocate on the heap. Safe for hot-path
+ * instrumentation. However, the returned reference must not be stored or shared
+ * across threads — use it inline and call {@link #record()} immediately.
+ *
+ * <pre>
+ * llm.call("claude-sonnet-4-20250514")
+ *     .provider("anthropic")
+ *     .inputTokens(1500)
+ *     .outputTokens(300)
+ *     .cacheReadTokens(800)
+ *     .duration(durationNanos)
+ *     .record();
+ * </pre>
+ */
+public interface CallRecord {
+
+ /** Sets the provider name (e.g. "openai", "anthropic", "google"). */
+ CallRecord provider(String provider);
+
+ /** Sets the number of input/prompt tokens. */
+ CallRecord inputTokens(int tokens);
+
+ /** Sets the number of output/completion tokens. */
+ CallRecord outputTokens(int tokens);
+
+ /**
+ * Sets the number of cache-read input tokens (prompt caching).
+ * Anthropic: {@code usage.cache_read_input_tokens}.
+ * OpenAI: {@code usage.prompt_tokens_details.cached_tokens}.
+ */
+ CallRecord cacheReadTokens(int tokens);
+
+ /**
+ * Sets the number of cache-creation input tokens.
+ * Anthropic: {@code usage.cache_creation_input_tokens}.
+ */
+ CallRecord cacheCreationTokens(int tokens);
+
+ /** Marks this call as a streaming response. */
+ CallRecord streaming();
+
+ /**
+ * Sets the time-to-first-token for streaming calls.
+ * Only meaningful when {@link #streaming()} is set.
+ *
+ * @param nanos time from request start to first token, in nanoseconds
+ */
+ CallRecord timeToFirstToken(long nanos);
+
+ /**
+ * Sets the total call duration.
+ *
+ * @param nanos duration in nanoseconds (typically from {@code @Duration})
+ */
+ CallRecord duration(long nanos);
+
+ /**
+ * Commits this call record to the trace service.
+ * Must be called exactly once to record the metrics.
+ */
+ void record();
+}
diff --git a/btrace-extensions/btrace-llm-trace/src/api/java/org/openjdk/btrace/llm/LlmTraceService.java b/btrace-extensions/btrace-llm-trace/src/api/java/org/openjdk/btrace/llm/LlmTraceService.java
new file mode 100644
index 00000000..3a3519e3
--- /dev/null
+++ b/btrace-extensions/btrace-llm-trace/src/api/java/org/openjdk/btrace/llm/LlmTraceService.java
@@ -0,0 +1,155 @@
+package org.openjdk.btrace.llm;
+
+import org.openjdk.btrace.core.extensions.ServiceDescriptor;
+
+/**
+ * LLM inference observability service for BTrace scripts.
+ *
+ * <p>Records LLM API call metrics — token counts, latencies, costs — and provides
+ * aggregated statistics. Scripts use {@code @OnMethod} to intercept LLM SDK calls
+ * and feed data into this service.
+ *
+ * <p>Thread-safe. All methods can be called concurrently from instrumented threads.
+ *
+ * <h3>Quick start — latency only (works with any SDK)</h3>
+ * <pre>
+ * &#64;Injected private static LlmTraceService llm;
+ *
+ * &#64;OnMethod(clazz = "+dev.langchain4j.model.chat.ChatLanguageModel",
+ *     method = "generate", location = &#64;Location(Kind.RETURN))
+ * public static void onChat(&#64;ProbeClassName String cls, &#64;Duration long dur) {
+ *     llm.recordCall(cls, dur);
+ * }
+ * </pre>
+ *
+ * <h3>Full metrics with fluent builder</h3>
+ * <pre>
+ * llm.call("claude-sonnet-4-20250514")
+ *     .provider("anthropic")
+ *     .inputTokens(1500)
+ *     .outputTokens(300)
+ *     .cacheReadTokens(800)
+ *     .duration(durationNanos)
+ *     .record();
+ * </pre>
+ */
+@ServiceDescriptor
+public interface LlmTraceService {
+
+ // ==================== Simple recording ====================
+
+ /**
+ * Records an LLM call with only latency (no token info).
+ * Use this when token counts aren't easily extractable.
+ *
+ * @param model model identifier or class name
+ * @param durationNanos call duration in nanoseconds
+ */
+ void recordCall(String model, long durationNanos);
+
+ /**
+ * Records an LLM call with token counts.
+ *
+ * @param model model identifier (e.g. "gpt-4o", "claude-sonnet-4-20250514")
+ * @param inputTokens number of input/prompt tokens
+ * @param outputTokens number of output/completion tokens
+ * @param durationNanos call duration in nanoseconds
+ */
+ void recordCall(String model, int inputTokens, int outputTokens, long durationNanos);
+
+ // ==================== Fluent builder ====================
+
+ /**
+ * Starts a fluent call record for the given model.
+ *
+ * <p>The returned builder is allocation-free (ThreadLocal-pooled). It is safe
+ * to use on hot paths. The builder must be used inline on the calling thread
+ * and {@link CallRecord#record()} must be called before the next {@code call()}.
+ *
+ * <pre>
+ * llm.call("claude-sonnet-4-20250514")
+ *     .provider("anthropic")
+ *     .inputTokens(1500)
+ *     .outputTokens(300)
+ *     .cacheReadTokens(800)
+ *     .streaming()
+ *     .timeToFirstToken(200_000_000L)
+ *     .duration(durationNanos)
+ *     .record();
+ * </pre>
+ *
+ * @param model model identifier
+ * @return a call record builder (thread-local, do not store)
+ */
+ CallRecord call(String model);
+
+ // ==================== Specialized recording ====================
+
+ /**
+ * Records an embedding API call.
+ *
+ * @param model embedding model identifier (e.g. "text-embedding-3-small")
+ * @param tokenCount number of tokens embedded
+ * @param durationNanos call duration in nanoseconds
+ */
+ void recordEmbedding(String model, int tokenCount, long durationNanos);
+
+ /**
+ * Records a tool/function call invocation by the LLM.
+ *
+ * @param model model that made the tool call
+ * @param toolName name of the tool/function called
+ */
+ void recordToolUse(String model, String toolName);
+
+ /**
+ * Records a failed LLM API call.
+ *
+ * @param model model identifier
+ * @param errorType error class name or HTTP status code
+ * @param durationNanos call duration before failure
+ */
+ void recordError(String model, String errorType, long durationNanos);
+
+ // ==================== Reporting ====================
+
+ /**
+ * Returns a formatted summary of all recorded metrics.
+ * Includes per-model token counts, latency stats, cost estimates,
+ * cache hit rates, tool use, and error breakdown.
+ *
+ * @return multi-line summary string
+ */
+ String getSummary();
+
+ /**
+ * Returns a one-line summary for a specific model.
+ *
+ * @param model model identifier
+ * @return summary string, or "No data" if none recorded
+ */
+ String getModelSummary(String model);
+
+ /**
+ * Returns total estimated cost across all models in USD.
+ * Uses built-in pricing table. Returns -1 if all models are unknown.
+ *
+ * @return estimated cost in USD, or -1 if pricing unknown
+ */
+ double getEstimatedCostUsd();
+
+ /** Returns total number of chat completion calls recorded. */
+ long getTotalCalls();
+
+ /** Returns total input tokens across all calls. */
+ long getTotalInputTokens();
+
+ /** Returns total output tokens across all calls. */
+ long getTotalOutputTokens();
+
+ /** Returns total number of embedding calls recorded. */
+ long getTotalEmbeddingCalls();
+
+ /** Resets all collected metrics. */
+ void reset();
+}
diff --git a/btrace-extensions/btrace-llm-trace/src/impl/java/org/openjdk/btrace/llm/CallRecordImpl.java b/btrace-extensions/btrace-llm-trace/src/impl/java/org/openjdk/btrace/llm/CallRecordImpl.java
new file mode 100644
index 00000000..591be372
--- /dev/null
+++ b/btrace-extensions/btrace-llm-trace/src/impl/java/org/openjdk/btrace/llm/CallRecordImpl.java
@@ -0,0 +1,98 @@
+package org.openjdk.btrace.llm;
+
+/**
+ * Mutable, reusable backing object for the {@link CallRecord} fluent API.
+ *
+ * <p>Instances are pooled per-thread via {@link ThreadLocal} and re-armed through
+ * {@link #reset(LlmTraceServiceImpl, String)}, so a {@code call()} invocation does not need to
+ * allocate a new builder.
+ *
+ * <p>Not thread-safe — intended to be obtained, configured, and recorded within a single
+ * BTrace handler method invocation on the same thread.
+ */
+final class CallRecordImpl implements CallRecord {
+
+  // Identity of the call being described.
+  String model;
+  String providerVal;
+
+  // Token accounting.
+  int inputTok;
+  int outputTok;
+  int cacheReadTok;
+  int cacheCreateTok;
+
+  // Streaming flag and timings (nanoseconds).
+  boolean isStreaming;
+  long ttftVal;
+  long durationVal;
+
+  /** Service that receives this record when {@link #record()} is invoked. */
+  private LlmTraceServiceImpl service;
+
+  CallRecordImpl() {
+    // Created once per thread via ThreadLocal
+  }
+
+  /**
+   * Binds this instance to a new call and clears everything left over from its previous use.
+   *
+   * @param service service the record is committed to
+   * @param model model identifier of the new call
+   * @return this instance, ready for fluent configuration
+   */
+  CallRecordImpl reset(LlmTraceServiceImpl service, String model) {
+    // Clear the optional values first.
+    providerVal = null;
+    isStreaming = false;
+    inputTok = 0;
+    outputTok = 0;
+    cacheReadTok = 0;
+    cacheCreateTok = 0;
+    ttftVal = 0;
+    durationVal = 0;
+
+    // Then bind the new call.
+    this.service = service;
+    this.model = model;
+    return this;
+  }
+
+  /** Stores the provider name and returns this builder. */
+  @Override
+  public CallRecord provider(String provider) {
+    providerVal = provider;
+    return this;
+  }
+
+  /** Stores the input token count and returns this builder. */
+  @Override
+  public CallRecord inputTokens(int tokens) {
+    inputTok = tokens;
+    return this;
+  }
+
+  /** Stores the output token count and returns this builder. */
+  @Override
+  public CallRecord outputTokens(int tokens) {
+    outputTok = tokens;
+    return this;
+  }
+
+  /** Stores the cache-read token count and returns this builder. */
+  @Override
+  public CallRecord cacheReadTokens(int tokens) {
+    cacheReadTok = tokens;
+    return this;
+  }
+
+  /** Stores the cache-creation token count and returns this builder. */
+  @Override
+  public CallRecord cacheCreationTokens(int tokens) {
+    cacheCreateTok = tokens;
+    return this;
+  }
+
+  /** Flags the call as streaming and returns this builder. */
+  @Override
+  public CallRecord streaming() {
+    isStreaming = true;
+    return this;
+  }
+
+  /** Stores the time-to-first-token in nanoseconds and returns this builder. */
+  @Override
+  public CallRecord timeToFirstToken(long nanos) {
+    ttftVal = nanos;
+    return this;
+  }
+
+  /** Stores the total duration in nanoseconds and returns this builder. */
+  @Override
+  public CallRecord duration(long nanos) {
+    durationVal = nanos;
+    return this;
+  }
+
+  /** Hands the accumulated values to the bound service. */
+  @Override
+  public void record() {
+    service.commitCallRecord(this);
+  }
+}
diff --git a/btrace-extensions/btrace-llm-trace/src/impl/java/org/openjdk/btrace/llm/LlmTraceServiceImpl.java b/btrace-extensions/btrace-llm-trace/src/impl/java/org/openjdk/btrace/llm/LlmTraceServiceImpl.java
new file mode 100644
index 00000000..6bff6af8
--- /dev/null
+++ b/btrace-extensions/btrace-llm-trace/src/impl/java/org/openjdk/btrace/llm/LlmTraceServiceImpl.java
@@ -0,0 +1,443 @@
+package org.openjdk.btrace.llm;
+
+import java.util.Locale;
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.atomic.AtomicLong;
+import org.openjdk.btrace.core.extensions.Extension;
+
+/**
+ * Thread-safe LLM call tracing with lock-free per-model statistics.
+ * Zero external dependencies.
+ */
+public final class LlmTraceServiceImpl extends Extension implements LlmTraceService {
+
+ private final Map modelStats = new ConcurrentHashMap<>();
+ private final Map embeddingStats = new ConcurrentHashMap<>();
+ private final Map toolUseCounts = new ConcurrentHashMap<>();
+ private final Map errorCounts = new ConcurrentHashMap<>();
+
+ // ==================== Simple recording ====================
+
+ @Override
+ public void recordCall(String model, long durationNanos) {
+ ModelStats stats = getOrCreate(modelStats, model);
+ stats.calls.incrementAndGet();
+ stats.totalDurationNanos.addAndGet(durationNanos);
+ updateMinMax(stats.minDurationNanos, stats.maxDurationNanos, durationNanos);
+ }
+
+ @Override
+ public void recordCall(String model, int inputTokens, int outputTokens, long durationNanos) {
+ ModelStats stats = getOrCreate(modelStats, model);
+ stats.calls.incrementAndGet();
+ stats.inputTokens.addAndGet(inputTokens);
+ stats.outputTokens.addAndGet(outputTokens);
+ stats.totalDurationNanos.addAndGet(durationNanos);
+ updateMinMax(stats.minDurationNanos, stats.maxDurationNanos, durationNanos);
+ }
+
+ // ==================== Fluent builder ====================
+
+ /**
+ * ThreadLocal-pooled builder — one CallRecordImpl per thread, reused across calls.
+ * Eliminates per-call heap allocation, making the builder safe for hot paths.
+ */
+ private final ThreadLocal callRecordPool =
+ ThreadLocal.withInitial(CallRecordImpl::new);
+
+ @Override
+ public CallRecord call(String model) {
+ return callRecordPool.get().reset(this, model);
+ }
+
+ void commitCallRecord(CallRecordImpl rec) {
+ ModelStats stats = getOrCreate(modelStats, rec.model);
+ if (rec.providerVal != null) {
+ stats.provider = rec.providerVal;
+ }
+ stats.calls.incrementAndGet();
+ stats.inputTokens.addAndGet(rec.inputTok);
+ stats.outputTokens.addAndGet(rec.outputTok);
+ stats.cacheReadTokens.addAndGet(rec.cacheReadTok);
+ stats.cacheCreationTokens.addAndGet(rec.cacheCreateTok);
+ stats.totalDurationNanos.addAndGet(rec.durationVal);
+ updateMinMax(stats.minDurationNanos, stats.maxDurationNanos, rec.durationVal);
+ if (rec.isStreaming) {
+ stats.streamingCalls.incrementAndGet();
+ stats.totalTimeToFirstToken.addAndGet(rec.ttftVal);
+ }
+ }
+
+ // ==================== Specialized recording ====================
+
+ @Override
+ public void recordEmbedding(String model, int tokenCount, long durationNanos) {
+ ModelStats stats = getOrCreate(embeddingStats, model);
+ stats.calls.incrementAndGet();
+ stats.inputTokens.addAndGet(tokenCount);
+ stats.totalDurationNanos.addAndGet(durationNanos);
+ updateMinMax(stats.minDurationNanos, stats.maxDurationNanos, durationNanos);
+ }
+
+ @Override
+ public void recordToolUse(String model, String toolName) {
+ String key = model + "::" + toolName;
+ toolUseCounts.computeIfAbsent(key, k -> new AtomicLong()).incrementAndGet();
+ getOrCreate(modelStats, model).toolCalls.incrementAndGet();
+ }
+
+ @Override
+ public void recordError(String model, String errorType, long durationNanos) {
+ String key = model + "::" + errorType;
+ errorCounts.computeIfAbsent(key, k -> new AtomicLong()).incrementAndGet();
+ getOrCreate(modelStats, model).errors.incrementAndGet();
+ }
+
+ // ==================== Reporting ====================
+
+ @Override
+ public String getSummary() {
+ if (modelStats.isEmpty() && embeddingStats.isEmpty()) {
+ return "No LLM calls recorded.";
+ }
+
+ StringBuilder sb = new StringBuilder();
+ sb.append("=== LLM Trace Summary ===\n\n");
+
+ long totalCalls = 0;
+ long totalIn = 0;
+ long totalOut = 0;
+ double totalCost = 0;
+
+ // Chat completions
+ for (Map.Entry entry : modelStats.entrySet()) {
+ String model = entry.getKey();
+ ModelStats s = entry.getValue();
+ long calls = s.calls.get();
+ long inTok = s.inputTokens.get();
+ long outTok = s.outputTokens.get();
+ long cacheRead = s.cacheReadTokens.get();
+ long cacheCreate = s.cacheCreationTokens.get();
+
+ totalCalls += calls;
+ totalIn += inTok;
+ totalOut += outTok;
+
+ sb.append("Model: ").append(model);
+ if (!"unknown".equals(s.provider)) {
+ sb.append(" (").append(s.provider).append(")");
+ }
+ sb.append("\n");
+
+ // Calls
+ sb.append(" Calls: ").append(calls);
+ long streaming = s.streamingCalls.get();
+ if (streaming > 0) {
+ sb.append(" (").append(streaming).append(" streaming)");
+ }
+ sb.append("\n");
+
+ // Tokens
+ if (inTok > 0 || outTok > 0) {
+ sb.append(" Tokens: ").append(inTok).append(" in / ").append(outTok).append(" out");
+ if (calls > 0) {
+ sb.append(" (avg ").append(inTok / calls).append("/").append(outTok / calls).append(")");
+ }
+ sb.append("\n");
+ }
+
+ // Cache
+ if (cacheRead > 0 || cacheCreate > 0) {
+ sb.append(" Cache: ");
+ if (cacheRead > 0) {
+ sb.append(cacheRead).append(" read");
+ // Show cache hit rate relative to total input
+ if (inTok > 0) {
+ long hitPct = (cacheRead * 100) / inTok;
+ sb.append(" (").append(hitPct).append("% hit)");
+ }
+ }
+ if (cacheCreate > 0) {
+ if (cacheRead > 0) sb.append(", ");
+ sb.append(cacheCreate).append(" created");
+ }
+ sb.append("\n");
+ }
+
+ // Latency
+ if (calls > 0) {
+ long avgMs = (s.totalDurationNanos.get() / calls) / 1_000_000;
+ long minMs = s.minDurationNanos.get() / 1_000_000;
+ long maxMs = s.maxDurationNanos.get() / 1_000_000;
+ sb.append(" Latency: avg ").append(avgMs).append("ms");
+ sb.append(", min ").append(minMs).append("ms");
+ sb.append(", max ").append(maxMs).append("ms\n");
+ }
+
+ // TTFT
+ if (streaming > 0) {
+ long avgTtft = (s.totalTimeToFirstToken.get() / streaming) / 1_000_000;
+ sb.append(" TTFT (avg): ").append(avgTtft).append("ms\n");
+ }
+
+ // Tool calls
+ long tc = s.toolCalls.get();
+ if (tc > 0) {
+ sb.append(" Tool calls: ").append(tc).append("\n");
+ }
+
+ // Errors
+ long errs = s.errors.get();
+ if (errs > 0) {
+ sb.append(" Errors: ").append(errs).append("\n");
+ }
+
+ // Cost
+ double cost = estimateCost(model, inTok, outTok, cacheRead);
+ if (cost >= 0) {
+ totalCost += cost;
+ sb.append(" Est. cost: $").append(formatCost(cost));
+ if (cacheRead > 0) {
+ double uncachedCost = estimateCost(model, inTok + cacheRead, outTok, 0);
+ if (uncachedCost > cost) {
+ sb.append(" (saved $").append(formatCost(uncachedCost - cost)).append(" via cache)");
+ }
+ }
+ sb.append("\n");
+ }
+ sb.append("\n");
+ }
+
+ // Embeddings
+ if (!embeddingStats.isEmpty()) {
+ sb.append("--- Embeddings ---\n");
+ for (Map.Entry entry : embeddingStats.entrySet()) {
+ ModelStats s = entry.getValue();
+ long calls = s.calls.get();
+ long tokens = s.inputTokens.get();
+ long avgMs = calls > 0 ? (s.totalDurationNanos.get() / calls) / 1_000_000 : 0;
+ sb.append(" ").append(entry.getKey()).append(": ")
+ .append(calls).append(" calls, ")
+ .append(tokens).append(" tokens, avg ")
+ .append(avgMs).append("ms\n");
+ }
+ sb.append("\n");
+ }
+
+ // Tool use breakdown
+ if (!toolUseCounts.isEmpty()) {
+ sb.append("--- Tool Use ---\n");
+ for (Map.Entry entry : toolUseCounts.entrySet()) {
+ sb.append(" ").append(entry.getKey()).append(": ")
+ .append(entry.getValue().get()).append("\n");
+ }
+ sb.append("\n");
+ }
+
+ // Error breakdown
+ if (!errorCounts.isEmpty()) {
+ sb.append("--- Errors ---\n");
+ for (Map.Entry entry : errorCounts.entrySet()) {
+ sb.append(" ").append(entry.getKey()).append(": ")
+ .append(entry.getValue().get()).append("\n");
+ }
+ sb.append("\n");
+ }
+
+ sb.append("--- Totals ---\n");
+ sb.append(" Calls: ").append(totalCalls).append("\n");
+ if (totalIn > 0 || totalOut > 0) {
+ sb.append(" Tokens: ").append(totalIn).append(" in / ").append(totalOut).append(" out\n");
+ }
+ if (totalCost > 0) {
+ sb.append(" Est. total cost: $").append(formatCost(totalCost)).append("\n");
+ }
+
+ return sb.toString();
+ }
+
+ @Override
+ public String getModelSummary(String model) {
+ ModelStats s = modelStats.get(model);
+ if (s == null) {
+ return "No data for model: " + model;
+ }
+ long calls = s.calls.get();
+ long inTok = s.inputTokens.get();
+ long outTok = s.outputTokens.get();
+ long avgMs = calls > 0 ? (s.totalDurationNanos.get() / calls) / 1_000_000 : 0;
+
+ StringBuilder sb = new StringBuilder();
+ sb.append(model).append(": ").append(calls).append(" calls");
+ if (inTok > 0 || outTok > 0) {
+ sb.append(", ").append(inTok).append("/").append(outTok).append(" tokens (in/out)");
+ }
+ sb.append(", avg ").append(avgMs).append("ms");
+ return sb.toString();
+ }
+
+ @Override
+ public double getEstimatedCostUsd() {
+ double total = 0;
+ boolean anyKnown = false;
+ for (Map.Entry entry : modelStats.entrySet()) {
+ ModelStats s = entry.getValue();
+ double cost = estimateCost(entry.getKey(),
+ s.inputTokens.get(), s.outputTokens.get(), s.cacheReadTokens.get());
+ if (cost >= 0) {
+ total += cost;
+ anyKnown = true;
+ }
+ }
+ return anyKnown ? total : -1;
+ }
+
+ @Override
+ public long getTotalCalls() {
+ long total = 0;
+ for (ModelStats s : modelStats.values()) {
+ total += s.calls.get();
+ }
+ return total;
+ }
+
+ @Override
+ public long getTotalInputTokens() {
+ long total = 0;
+ for (ModelStats s : modelStats.values()) {
+ total += s.inputTokens.get();
+ }
+ return total;
+ }
+
+ @Override
+ public long getTotalOutputTokens() {
+ long total = 0;
+ for (ModelStats s : modelStats.values()) {
+ total += s.outputTokens.get();
+ }
+ return total;
+ }
+
+ @Override
+ public long getTotalEmbeddingCalls() {
+ long total = 0;
+ for (ModelStats s : embeddingStats.values()) {
+ total += s.calls.get();
+ }
+ return total;
+ }
+
+ @Override
+ public void reset() {
+ modelStats.clear();
+ embeddingStats.clear();
+ toolUseCounts.clear();
+ errorCounts.clear();
+ }
+
+ @Override
+ public void close() {
+ String summary = getSummary();
+ if (!"No LLM calls recorded.".equals(summary)) {
+ getContext().send(summary);
+ }
+ }
+
+ // ==================== Internals ====================
+
+ private static ModelStats getOrCreate(Map map, String key) {
+ return map.computeIfAbsent(key, k -> new ModelStats());
+ }
+
+ private static void updateMinMax(AtomicLong min, AtomicLong max, long value) {
+ long cur;
+ do {
+ cur = min.get();
+ if (value >= cur) break;
+ } while (!min.compareAndSet(cur, value));
+ do {
+ cur = max.get();
+ if (value <= cur) break;
+ } while (!max.compareAndSet(cur, value));
+ }
+
+ /**
+ * Estimates cost in USD. Cache-read tokens are priced at ~10% of input rate
+ * for models that support caching.
+ */
+ static double estimateCost(String model, long inputTokens, long outputTokens,
+ long cacheReadTokens) {
+ double inputPer1M = -1;
+ double outputPer1M = -1;
+ double cacheReadPer1M = -1;
+
+ String m = model.toLowerCase();
+
+ // Anthropic Claude
+ if (m.contains("claude") && m.contains("opus")) {
+ inputPer1M = 15.0; outputPer1M = 75.0; cacheReadPer1M = 1.50;
+ } else if (m.contains("claude") && m.contains("sonnet")) {
+ inputPer1M = 3.0; outputPer1M = 15.0; cacheReadPer1M = 0.30;
+ } else if (m.contains("claude") && m.contains("haiku")) {
+ inputPer1M = 0.80; outputPer1M = 4.0; cacheReadPer1M = 0.08;
+ }
+ // OpenAI GPT
+ else if (m.contains("gpt-4o-mini")) {
+ inputPer1M = 0.15; outputPer1M = 0.60; cacheReadPer1M = 0.075;
+ } else if (m.contains("gpt-4o")) {
+ inputPer1M = 2.50; outputPer1M = 10.0; cacheReadPer1M = 1.25;
+ } else if (m.contains("gpt-4") && m.contains("turbo")) {
+ inputPer1M = 10.0; outputPer1M = 30.0;
+ } else if (m.contains("gpt-4")) {
+ inputPer1M = 30.0; outputPer1M = 60.0;
+ } else if (m.contains("gpt-3.5")) {
+ inputPer1M = 0.50; outputPer1M = 1.50;
+ } else if (m.contains("o1-mini")) {
+ inputPer1M = 3.0; outputPer1M = 12.0; cacheReadPer1M = 1.50;
+ } else if (m.contains("o1")) {
+ inputPer1M = 15.0; outputPer1M = 60.0; cacheReadPer1M = 7.50;
+ }
+ // Google Gemini
+ else if (m.contains("gemini") && m.contains("pro")) {
+ inputPer1M = 1.25; outputPer1M = 5.0;
+ } else if (m.contains("gemini") && m.contains("flash")) {
+ inputPer1M = 0.075; outputPer1M = 0.30;
+ }
+
+ if (inputPer1M < 0) {
+ return -1;
+ }
+
+ double cost = (inputTokens * inputPer1M / 1_000_000.0)
+ + (outputTokens * outputPer1M / 1_000_000.0);
+ if (cacheReadTokens > 0 && cacheReadPer1M > 0) {
+ cost += (cacheReadTokens * cacheReadPer1M / 1_000_000.0);
+ }
+ return cost;
+ }
+
+ static String formatCost(double cost) {
+ if (cost < 0.01) {
+ return String.format("%.4f", cost);
+ }
+ return String.format("%.2f", cost);
+ }
+
+ /** Lock-free per-model statistics. */
+ static final class ModelStats {
+ volatile String provider = "unknown";
+ final AtomicLong calls = new AtomicLong();
+ final AtomicLong streamingCalls = new AtomicLong();
+ final AtomicLong inputTokens = new AtomicLong();
+ final AtomicLong outputTokens = new AtomicLong();
+ final AtomicLong cacheReadTokens = new AtomicLong();
+ final AtomicLong cacheCreationTokens = new AtomicLong();
+ final AtomicLong totalDurationNanos = new AtomicLong();
+ final AtomicLong minDurationNanos = new AtomicLong(Long.MAX_VALUE);
+ final AtomicLong maxDurationNanos = new AtomicLong(0);
+ final AtomicLong totalTimeToFirstToken = new AtomicLong();
+ final AtomicLong toolCalls = new AtomicLong();
+ final AtomicLong errors = new AtomicLong();
+ }
+}
diff --git a/btrace-extensions/btrace-llm-trace/src/test/java/org/openjdk/btrace/llm/LlmTraceServiceTest.java b/btrace-extensions/btrace-llm-trace/src/test/java/org/openjdk/btrace/llm/LlmTraceServiceTest.java
new file mode 100644
index 00000000..8951ca9f
--- /dev/null
+++ b/btrace-extensions/btrace-llm-trace/src/test/java/org/openjdk/btrace/llm/LlmTraceServiceTest.java
@@ -0,0 +1,335 @@
+package org.openjdk.btrace.llm;
+
+import static org.junit.jupiter.api.Assertions.*;
+
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+
+class LlmTraceServiceTest {
+
+ private LlmTraceServiceImpl service;
+
+ @BeforeEach
+ void setUp() {
+ service = new LlmTraceServiceImpl();
+ }
+
+ // ==================== Simple recording ====================
+
+ @Test
+ void recordCallDurationOnly() {
+ service.recordCall("gpt-4o", 500_000_000L);
+ service.recordCall("gpt-4o", 300_000_000L);
+
+ assertEquals(2, service.getTotalCalls());
+ assertEquals(0, service.getTotalInputTokens(), "No tokens recorded");
+ String summary = service.getModelSummary("gpt-4o");
+ assertTrue(summary.contains("2 calls"));
+ assertTrue(summary.contains("400ms"), "avg of 500ms and 300ms");
+ }
+
+ @Test
+ void recordCallWithTokens() {
+ service.recordCall("claude-sonnet-4-20250514", 1000, 200, 500_000_000L);
+ service.recordCall("claude-sonnet-4-20250514", 800, 300, 400_000_000L);
+
+ assertEquals(2, service.getTotalCalls());
+ assertEquals(1800, service.getTotalInputTokens());
+ assertEquals(500, service.getTotalOutputTokens());
+ }
+
+ // ==================== Fluent builder ====================
+
+ @Test
+ void fluentBuilderBasic() {
+ service.call("gpt-4o")
+ .inputTokens(1000)
+ .outputTokens(200)
+ .duration(500_000_000L)
+ .record();
+
+ assertEquals(1, service.getTotalCalls());
+ assertEquals(1000, service.getTotalInputTokens());
+ assertEquals(200, service.getTotalOutputTokens());
+ }
+
+ @Test
+ void fluentBuilderWithProvider() {
+ service.call("claude-sonnet-4-20250514")
+ .provider("anthropic")
+ .inputTokens(1500)
+ .outputTokens(300)
+ .duration(800_000_000L)
+ .record();
+
+ String summary = service.getSummary();
+ assertTrue(summary.contains("anthropic"));
+ assertTrue(summary.contains("claude-sonnet"));
+ }
+
+ @Test
+ void fluentBuilderWithCache() {
+ service.call("claude-sonnet-4-20250514")
+ .inputTokens(500)
+ .outputTokens(200)
+ .cacheReadTokens(1000)
+ .cacheCreationTokens(200)
+ .duration(300_000_000L)
+ .record();
+
+ String summary = service.getSummary();
+ assertTrue(summary.contains("Cache:"), "Should show cache section");
+ assertTrue(summary.contains("1000 read"), "Should show cache read tokens");
+ assertTrue(summary.contains("200 created"), "Should show cache creation tokens");
+ assertTrue(summary.contains("saved $"), "Should show cache savings");
+ }
+
+ @Test
+ void fluentBuilderStreaming() {
+ service.call("gpt-4o")
+ .inputTokens(2000)
+ .outputTokens(500)
+ .streaming()
+ .timeToFirstToken(200_000_000L)
+ .duration(3_000_000_000L)
+ .record();
+
+ String summary = service.getSummary();
+ assertTrue(summary.contains("streaming"));
+ assertTrue(summary.contains("TTFT"));
+ }
+
+ @Test
+ void fluentBuilderDurationOnly() {
+ // Minimal usage — just model + duration
+ service.call("some-model").duration(100_000_000L).record();
+
+ assertEquals(1, service.getTotalCalls());
+ }
+
+ // ==================== Embeddings ====================
+
+ @Test
+ void recordEmbedding() {
+ service.recordEmbedding("text-embedding-3-small", 500, 50_000_000L);
+ service.recordEmbedding("text-embedding-3-small", 300, 30_000_000L);
+
+ assertEquals(2, service.getTotalEmbeddingCalls());
+ assertEquals(0, service.getTotalCalls(), "Embeddings don't count as chat calls");
+
+ String summary = service.getSummary();
+ assertTrue(summary.contains("Embeddings"));
+ assertTrue(summary.contains("text-embedding-3-small"));
+ assertTrue(summary.contains("2 calls"));
+ assertTrue(summary.contains("800 tokens"));
+ }
+
+ // ==================== Tool use ====================
+
+ @Test
+ void recordToolUse() {
+ service.recordCall("claude-sonnet-4-20250514", 1000, 200, 500_000_000L);
+ service.recordToolUse("claude-sonnet-4-20250514", "search_web");
+ service.recordToolUse("claude-sonnet-4-20250514", "search_web");
+ service.recordToolUse("claude-sonnet-4-20250514", "run_code");
+
+ String summary = service.getSummary();
+ assertTrue(summary.contains("Tool calls: 3"));
+ assertTrue(summary.contains("Tool Use"));
+ assertTrue(summary.contains("search_web"));
+ }
+
+ // ==================== Errors ====================
+
+ @Test
+ void recordError() {
+ service.recordCall("gpt-4o", 1000, 0, 100_000_000L);
+ service.recordError("gpt-4o", "RateLimitException", 50_000_000L);
+
+ String summary = service.getSummary();
+ assertTrue(summary.contains("Errors: 1"));
+ assertTrue(summary.contains("RateLimitException"));
+ }
+
+ // ==================== Model summary ====================
+
+ @Test
+ void getModelSummaryUnknown() {
+ assertTrue(service.getModelSummary("nonexistent").contains("No data"));
+ }
+
+ @Test
+ void getModelSummaryDurationOnly() {
+ service.recordCall("my-model", 500_000_000L);
+ String result = service.getModelSummary("my-model");
+ assertTrue(result.contains("1 calls"));
+ assertTrue(result.contains("500ms"));
+ // Should NOT show "0/0 tokens"
+ assertFalse(result.contains("0/0"));
+ }
+
+ @Test
+ void getModelSummaryWithTokens() {
+ service.recordCall("gpt-4o", 1000, 200, 500_000_000L);
+ String result = service.getModelSummary("gpt-4o");
+ assertTrue(result.contains("1000/200 tokens"));
+ }
+
+ // ==================== Cost estimation ====================
+
+ @Test
+ void estimateCostClaudeSonnet() {
+ // Sonnet: $3/1M input, $15/1M output
+ double cost = LlmTraceServiceImpl.estimateCost("claude-sonnet-4-20250514",
+ 1_000_000, 100_000, 0);
+ assertEquals(3.0 + 1.5, cost, 0.01);
+ }
+
+ @Test
+ void estimateCostWithCacheRead() {
+ // Sonnet: $3/1M input, $15/1M output, $0.30/1M cache-read
+ double cost = LlmTraceServiceImpl.estimateCost("claude-sonnet-4-20250514",
+ 500_000, 100_000, 500_000);
+ double expected = (500_000 * 3.0 / 1e6) + (100_000 * 15.0 / 1e6) + (500_000 * 0.30 / 1e6);
+ assertEquals(expected, cost, 0.001);
+ }
+
+ @Test
+ void estimateCostGpt4o() {
+ double cost = LlmTraceServiceImpl.estimateCost("gpt-4o", 1_000_000, 1_000_000, 0);
+ assertEquals(2.50 + 10.0, cost, 0.01);
+ }
+
+ @Test
+ void estimateCostUnknownModel() {
+ assertEquals(-1, LlmTraceServiceImpl.estimateCost("my-custom-model", 1000, 1000, 0));
+ }
+
+ @Test
+ void estimatedCostUsdAcrossModels() {
+ service.recordCall("claude-sonnet-4-20250514", 1_000_000, 0, 1_000_000_000L);
+ service.recordCall("gpt-4o", 1_000_000, 0, 1_000_000_000L);
+
+ double cost = service.getEstimatedCostUsd();
+ assertEquals(5.50, cost, 0.01);
+ }
+
+ @Test
+ void estimatedCostUsdUnknownModels() {
+ service.recordCall("my-custom-model", 1000, 500, 100_000_000L);
+ assertEquals(-1, service.getEstimatedCostUsd());
+ }
+
+ // ==================== Latency tracking ====================
+
+ @Test
+ void latencyMinMax() {
+ service.recordCall("gpt-4o", 100, 50, 100_000_000L);
+ service.recordCall("gpt-4o", 100, 50, 500_000_000L);
+ service.recordCall("gpt-4o", 100, 50, 200_000_000L);
+
+ String summary = service.getSummary();
+ assertTrue(summary.contains("min 100ms"));
+ assertTrue(summary.contains("max 500ms"));
+ }
+
+ // ==================== Reset ====================
+
+ @Test
+ void resetClearsEverything() {
+ service.recordCall("gpt-4o", 1000, 200, 500_000_000L);
+ service.recordEmbedding("text-embedding-3-small", 100, 10_000_000L);
+ service.recordToolUse("gpt-4o", "search");
+ service.recordError("gpt-4o", "Timeout", 100_000_000L);
+
+ service.reset();
+
+ assertEquals(0, service.getTotalCalls());
+ assertEquals(0, service.getTotalInputTokens());
+ assertEquals(0, service.getTotalOutputTokens());
+ assertEquals(0, service.getTotalEmbeddingCalls());
+ assertEquals("No LLM calls recorded.", service.getSummary());
+ }
+
+ // ==================== Edge cases ====================
+
+ @Test
+ void summaryEmptyWhenNoCalls() {
+ assertEquals("No LLM calls recorded.", service.getSummary());
+ }
+
+ @Test
+ void summaryOmitsTokenLineWhenNoTokens() {
+ service.recordCall("my-model", 500_000_000L);
+ String summary = service.getSummary();
+ // Should show the model and latency but not "0 in / 0 out"
+ assertTrue(summary.contains("my-model"));
+ assertFalse(summary.contains("0 in / 0 out"));
+ }
+
+  @Test
+  void cacheHitRateCalculation() {
+    service.call("claude-sonnet-4-20250514")
+        .inputTokens(200)
+        .outputTokens(100)
+        .cacheReadTokens(800)
+        .duration(100_000_000L)
+        .record();
+
+    String summary = service.getSummary();
+    // Cache reads (800) exceed input tokens (200) here, so a reads/input ratio would come out as
+    // a meaningless 400%. No hit-rate figure is asserted for that reason; the test only checks
+    // that the cache section is present and reports the raw read count.
+    assertTrue(summary.contains("Cache:"));
+    assertTrue(summary.contains("800 read"));
+  }
+
+ // ==================== Concurrency ====================
+
+ @Test
+ void concurrentRecording() throws InterruptedException {
+ int threads = 8;
+ int callsPerThread = 1000;
+ Thread[] workers = new Thread[threads];
+
+ for (int t = 0; t < threads; t++) {
+ workers[t] = new Thread(() -> {
+ for (int i = 0; i < callsPerThread; i++) {
+ service.recordCall("gpt-4o", 100, 50, 10_000_000L);
+ }
+ });
+ workers[t].start();
+ }
+ for (Thread w : workers) {
+ w.join();
+ }
+
+ assertEquals(threads * callsPerThread, service.getTotalCalls());
+ assertEquals(threads * callsPerThread * 100L, service.getTotalInputTokens());
+ }
+
+ @Test
+ void concurrentBuilderRecording() throws InterruptedException {
+ int threads = 4;
+ int callsPerThread = 500;
+ Thread[] workers = new Thread[threads];
+
+ for (int t = 0; t < threads; t++) {
+ workers[t] = new Thread(() -> {
+ for (int i = 0; i < callsPerThread; i++) {
+ service.call("claude-sonnet-4-20250514")
+ .provider("anthropic")
+ .inputTokens(100)
+ .outputTokens(50)
+ .duration(10_000_000L)
+ .record();
+ }
+ });
+ workers[t].start();
+ }
+ for (Thread w : workers) {
+ w.join();
+ }
+
+ assertEquals(threads * callsPerThread, service.getTotalCalls());
+ }
+}
diff --git a/btrace-extensions/btrace-rag-quality/build.gradle b/btrace-extensions/btrace-rag-quality/build.gradle
new file mode 100644
index 00000000..f22b500a
--- /dev/null
+++ b/btrace-extensions/btrace-rag-quality/build.gradle
@@ -0,0 +1,33 @@
+plugins {
+ id 'org.openjdk.btrace.extension'
+ alias(libs.plugins.shadow)
+}
+
+java {
+ sourceCompatibility = 8
+ targetCompatibility = 8
+}
+
+compileJava {
+ javaCompiler = javaToolchains.compilerFor {
+ languageVersion.set(JavaLanguageVersion.of(11))
+ }
+}
+
+btraceExtension {
+ id = 'btrace-rag-quality'
+ name = 'BTrace RAG Quality'
+ description = 'RAG pipeline observability: retrieval latency, similarity scores, chunk quality for vector DB clients'
+ services = ['org.openjdk.btrace.rag.RagQualityService']
+}
+
+dependencies {
+ apiCompileOnly project(':btrace-core')
+ implImplementation project(':btrace-core')
+
+ testImplementation libs.junit.jupiter
+}
+
+test {
+ useJUnitPlatform()
+}
diff --git a/btrace-extensions/btrace-rag-quality/src/api/java/org/openjdk/btrace/rag/QueryRecord.java b/btrace-extensions/btrace-rag-quality/src/api/java/org/openjdk/btrace/rag/QueryRecord.java
new file mode 100644
index 00000000..08de459a
--- /dev/null
+++ b/btrace-extensions/btrace-rag-quality/src/api/java/org/openjdk/btrace/rag/QueryRecord.java
@@ -0,0 +1,44 @@
+package org.openjdk.btrace.rag;
+
+/**
+ * Fluent builder for recording a vector DB query with detailed metrics.
+ *
+ * <p>Obtain via {@link RagQualityService#query(String)}. All setters are optional.
+ * Call {@link #record()} to commit the metrics.
+ *
+ * <p>Allocation-free: Instances are pooled per-thread internally.
+ * The returned reference must not be stored or shared across threads.
+ *
+ * <pre>{@code
+ * rag.query("pinecone")
+ *     .resultCount(5)
+ *     .topScore(0.92f)
+ *     .lowScore(0.71f)
+ *     .embeddingDimension(1536)
+ *     .duration(durationNanos)
+ *     .record();
+ * }</pre>
+ */
+public interface QueryRecord {
+
+  /**
+   * Number of results/chunks returned by the query. With the bundled implementation a count of
+   * 0, which is also the value used when this setter is never called, is recorded as an empty
+   * retrieval.
+   */
+  QueryRecord resultCount(int count);
+
+  /**
+   * Highest similarity score in the result set (0.0 to 1.0). The bundled implementation ignores
+   * negative values.
+   */
+  QueryRecord topScore(float score);
+
+  /**
+   * Lowest similarity score in the result set (0.0 to 1.0). The bundled implementation ignores
+   * negative values.
+   */
+  QueryRecord lowScore(float score);
+
+  /** Embedding dimension used for the query vector. */
+  QueryRecord embeddingDimension(int dimension);
+
+  /** Total tokens across all returned chunks. */
+  QueryRecord totalChunkTokens(int tokens);
+
+  /** Query duration in nanoseconds. */
+  QueryRecord duration(long nanos);
+
+  /** Commits this query record to the service. */
+  void record();
+}
diff --git a/btrace-extensions/btrace-rag-quality/src/api/java/org/openjdk/btrace/rag/RagQualityService.java b/btrace-extensions/btrace-rag-quality/src/api/java/org/openjdk/btrace/rag/RagQualityService.java
new file mode 100644
index 00000000..1eaccb98
--- /dev/null
+++ b/btrace-extensions/btrace-rag-quality/src/api/java/org/openjdk/btrace/rag/RagQualityService.java
@@ -0,0 +1,111 @@
+package org.openjdk.btrace.rag;
+
+import org.openjdk.btrace.core.extensions.ServiceDescriptor;
+
+/**
+ * BTrace extension service for RAG (Retrieval-Augmented Generation) pipeline observability.
+ *
+ * <p>Tracks vector database query performance, retrieval quality metrics,
+ * and end-to-end RAG pipeline latency. Works with any vector DB client
+ * (Pinecone, Milvus, Weaviate, Chroma, pgvector, Qdrant).
+ *
+ * <p>Usage in a BTrace script:
+ * <pre>{@code
+ * @Injected RagQualityService rag;
+ *
+ * @OnMethod(clazz = "io.pinecone.PineconeClient", method = "query")
+ * void onQuery(@Duration long dur) {
+ *     rag.recordQuery("pinecone", 10, dur);
+ * }
+ * }</pre>
+ */
+@ServiceDescriptor
+public interface RagQualityService {
+
+  // ==================== Query recording ====================
+
+  /**
+   * Records a vector DB query with duration only.
+   *
+   * @param source vector DB or index name (e.g. "pinecone", "milvus-products")
+   * @param durationNanos query duration in nanoseconds
+   */
+  void recordQuery(String source, long durationNanos);
+
+  /**
+   * Records a vector DB query with result count and duration. The bundled implementation also
+   * counts a query with zero results as an empty retrieval.
+   *
+   * @param source vector DB or index name
+   * @param resultCount number of results/chunks returned (top-K)
+   * @param durationNanos query duration in nanoseconds
+   */
+  void recordQuery(String source, int resultCount, long durationNanos);
+
+  /**
+   * Starts a detailed query record builder. Allocation-free (ThreadLocal-pooled).
+   *
+   * <p>Use inline on the calling thread and call {@link QueryRecord#record()}
+   * before the next {@code query()} call. Do not store the returned reference.
+   *
+   * <pre>{@code
+   * rag.query("pinecone")
+   *     .resultCount(5)
+   *     .topScore(0.92f)
+   *     .lowScore(0.71f)
+   *     .embeddingDimension(1536)
+   *     .duration(durationNanos)
+   *     .record();
+   * }</pre>
+   *
+   * @param source vector DB or index name
+   * @return a query record builder (thread-local, do not store)
+   */
+  QueryRecord query(String source);
+
+  // ==================== Pipeline recording ====================
+
+  /**
+   * Records an end-to-end RAG pipeline invocation (retrieve + generate).
+   *
+   * @param pipelineName pipeline identifier
+   * @param retrievalNanos time spent in retrieval phase
+   * @param generationNanos time spent in generation phase
+   */
+  void recordPipeline(String pipelineName, long retrievalNanos, long generationNanos);
+
+  /**
+   * Records a chunk that was retrieved and used in context.
+   * Useful for tracking context window utilization.
+   *
+   * @param source vector DB or index name
+   * @param chunkTokens approximate token count of the chunk
+   */
+  void recordChunk(String source, int chunkTokens);
+
+  /**
+   * Records a retrieval that returned no results (empty context). The bundled implementation
+   * does not count this as a query.
+   *
+   * @param source vector DB or index name
+   */
+  void recordEmptyRetrieval(String source);
+
+  // ==================== Reporting ====================
+
+  /** Returns a formatted summary of all RAG quality metrics. */
+  String getSummary();
+
+  /** Returns summary for a specific source/index. */
+  String getSourceSummary(String source);
+
+  /** Returns total number of queries recorded across all sources. */
+  long getTotalQueries();
+
+  /** Returns total number of empty retrievals. */
+  long getTotalEmptyRetrievals();
+
+  /**
+   * Returns the average top similarity score across all queries that reported one. The bundled
+   * implementation returns -1 when no query has reported a score.
+   */
+  float getAverageTopScore();
+
+  /** Resets all collected metrics. */
+  void reset();
+}
diff --git a/btrace-extensions/btrace-rag-quality/src/impl/java/org/openjdk/btrace/rag/QueryRecordImpl.java b/btrace-extensions/btrace-rag-quality/src/impl/java/org/openjdk/btrace/rag/QueryRecordImpl.java
new file mode 100644
index 00000000..b46a2909
--- /dev/null
+++ b/btrace-extensions/btrace-rag-quality/src/impl/java/org/openjdk/btrace/rag/QueryRecordImpl.java
@@ -0,0 +1,73 @@
+package org.openjdk.btrace.rag;
+
+/**
+ * Reusable {@link QueryRecord} builder. The service keeps one instance per thread and re-arms it
+ * through {@link #reset(RagQualityServiceImpl, String)} on every {@code query()} call, so no new
+ * builder is created per query.
+ */
+final class QueryRecordImpl implements QueryRecord {
+
+  // Staged values; read by RagQualityServiceImpl.commitQueryRecord when record() is called.
+  String source;
+  long durationVal;
+  int resultCountVal;
+  int embDimension;
+  int chunkTokensVal;
+  float topScoreVal;
+  float lowScoreVal;
+
+  private RagQualityServiceImpl service;
+
+  QueryRecordImpl() {}
+
+  /**
+   * Binds this builder to {@code service} and {@code source} and clears every staged value.
+   * Scores are cleared to -1, which the service treats as "not provided".
+   *
+   * @return this instance
+   */
+  QueryRecordImpl reset(RagQualityServiceImpl service, String source) {
+    this.service = service;
+    this.source = source;
+    durationVal = 0;
+    resultCountVal = 0;
+    embDimension = 0;
+    chunkTokensVal = 0;
+    topScoreVal = -1f;
+    lowScoreVal = -1f;
+    return this;
+  }
+
+  @Override
+  public QueryRecord resultCount(int count) {
+    resultCountVal = count;
+    return this;
+  }
+
+  @Override
+  public QueryRecord topScore(float score) {
+    topScoreVal = score;
+    return this;
+  }
+
+  @Override
+  public QueryRecord lowScore(float score) {
+    lowScoreVal = score;
+    return this;
+  }
+
+  @Override
+  public QueryRecord embeddingDimension(int dimension) {
+    embDimension = dimension;
+    return this;
+  }
+
+  @Override
+  public QueryRecord totalChunkTokens(int tokens) {
+    chunkTokensVal = tokens;
+    return this;
+  }
+
+  @Override
+  public QueryRecord duration(long nanos) {
+    durationVal = nanos;
+    return this;
+  }
+
+  /** Hands the staged values to the service this builder was last reset with. */
+  @Override
+  public void record() {
+    service.commitQueryRecord(this);
+  }
+}
diff --git a/btrace-extensions/btrace-rag-quality/src/impl/java/org/openjdk/btrace/rag/RagQualityServiceImpl.java b/btrace-extensions/btrace-rag-quality/src/impl/java/org/openjdk/btrace/rag/RagQualityServiceImpl.java
new file mode 100644
index 00000000..3acfd7d4
--- /dev/null
+++ b/btrace-extensions/btrace-rag-quality/src/impl/java/org/openjdk/btrace/rag/RagQualityServiceImpl.java
@@ -0,0 +1,320 @@
+package org.openjdk.btrace.rag;
+
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.atomic.AtomicLong;
+import org.openjdk.btrace.core.extensions.Extension;
+
+/**
+ * Thread-safe RAG quality tracking with lock-free per-source statistics.
+ */
+public final class RagQualityServiceImpl extends Extension implements RagQualityService {
+
+  // Statistics keyed by source name and by pipeline name. The type arguments are required:
+  // with raw types, lookups yield Object and the typed uses in this class do not compile.
+  private final Map<String, SourceStats> sourceStats = new ConcurrentHashMap<>();
+  private final Map<String, PipelineStats> pipelineStats = new ConcurrentHashMap<>();
+
+  // One reusable builder per thread, handed out by query().
+  private final ThreadLocal<QueryRecordImpl> queryRecordPool =
+      ThreadLocal.withInitial(QueryRecordImpl::new);
+
+ // ==================== Simple recording ====================
+
+  @Override
+  public void recordQuery(String source, long durationNanos) {
+    SourceStats target = getOrCreate(source);
+    // The query counter is bumped first so readers never see other data for a zero-query source.
+    target.queries.incrementAndGet();
+    updateMinMax(target.minDurationNanos, target.maxDurationNanos, durationNanos);
+    target.totalDurationNanos.addAndGet(durationNanos);
+  }
+
+  @Override
+  public void recordQuery(String source, int resultCount, long durationNanos) {
+    SourceStats target = getOrCreate(source);
+    // The query counter is bumped first so readers never see empties without a query.
+    target.queries.incrementAndGet();
+    updateMinMax(target.minDurationNanos, target.maxDurationNanos, durationNanos);
+    target.totalDurationNanos.addAndGet(durationNanos);
+    target.totalResults.addAndGet(resultCount);
+    boolean nothingFound = resultCount == 0;
+    if (nothingFound) {
+      // A query that returned zero results is also an empty retrieval.
+      target.emptyRetrievals.incrementAndGet();
+    }
+  }
+
+ // ==================== Fluent builder ====================
+
+  @Override
+  public QueryRecord query(String source) {
+    // Reuse the calling thread's builder and re-arm it for this source.
+    QueryRecordImpl pooled = queryRecordPool.get();
+    return pooled.reset(this, source);
+  }
+
+  /**
+   * Folds a completed builder into the statistics of its source. Invoked by
+   * {@link QueryRecordImpl#record()}.
+   *
+   * @param rec the builder whose staged values are to be recorded
+   */
+  void commitQueryRecord(QueryRecordImpl rec) {
+    SourceStats stats = getOrCreate(rec.source);
+    long duration = rec.durationVal;
+    int results = rec.resultCountVal;
+
+    stats.queries.incrementAndGet();
+    stats.totalResults.addAndGet(results);
+    stats.totalDurationNanos.addAndGet(duration);
+    updateMinMax(stats.minDurationNanos, stats.maxDurationNanos, duration);
+
+    if (results == 0) {
+      stats.emptyRetrievals.incrementAndGet();
+    }
+
+    // Scores are stored as fixed-point longs (score * 10000) so they fit the AtomicLong counters.
+    if (rec.topScoreVal >= 0) {
+      long topFixed = (long) (rec.topScoreVal * 10000);
+      stats.scoredQueries.incrementAndGet();
+      stats.totalTopScore.addAndGet(topFixed);
+      updateMin(stats.minTopScore, topFixed);
+      updateMax(stats.maxTopScore, topFixed);
+    }
+    if (rec.lowScoreVal >= 0) {
+      long lowFixed = (long) (rec.lowScoreVal * 10000);
+      stats.totalLowScore.addAndGet(lowFixed);
+    }
+
+    // Optional extras: only positive values are taken into account.
+    if (rec.chunkTokensVal > 0) {
+      stats.totalChunkTokens.addAndGet(rec.chunkTokensVal);
+    }
+    if (rec.embDimension > 0) {
+      stats.lastEmbeddingDimension = rec.embDimension;
+    }
+  }
+
+ // ==================== Pipeline recording ====================
+
+  @Override
+  public void recordPipeline(String pipelineName, long retrievalNanos, long generationNanos) {
+    // Lazily create the per-pipeline counters on first use.
+    PipelineStats pipeline =
+        pipelineStats.computeIfAbsent(pipelineName, name -> new PipelineStats());
+    pipeline.invocations.incrementAndGet();
+    pipeline.totalGenerationNanos.addAndGet(generationNanos);
+    pipeline.totalRetrievalNanos.addAndGet(retrievalNanos);
+  }
+
+  @Override
+  public void recordChunk(String source, int chunkTokens) {
+    // Chunks add to the context-token total and the chunk count; the query counter is untouched.
+    SourceStats target = getOrCreate(source);
+    target.chunkCount.incrementAndGet();
+    target.totalChunkTokens.addAndGet(chunkTokens);
+  }
+
+  @Override
+  public void recordEmptyRetrieval(String source) {
+    // Only the empty-retrieval counter moves; the query counter of the source is not touched.
+    SourceStats target = getOrCreate(source);
+    target.emptyRetrievals.incrementAndGet();
+  }
+
+ // ==================== Reporting ====================
+
+  /**
+   * Builds a multi-line, human-readable report: one section per source, then the pipelines (if
+   * any), then overall totals.
+   *
+   * @return the report, or {@code "No RAG queries recorded."} when nothing has been recorded
+   */
+  @Override
+  public String getSummary() {
+    if (sourceStats.isEmpty() && pipelineStats.isEmpty()) {
+      return "No RAG queries recorded.";
+    }
+
+    StringBuilder sb = new StringBuilder();
+    sb.append("=== RAG Quality Summary ===\n\n");
+
+    long totalQueries = 0;
+    long totalEmpty = 0;
+
+    for (Map.Entry<String, SourceStats> entry : sourceStats.entrySet()) {
+      String source = entry.getKey();
+      SourceStats s = entry.getValue();
+      long queries = s.queries.get();
+      totalQueries += queries;
+      long empty = s.emptyRetrievals.get();
+      totalEmpty += empty;
+
+      sb.append("Source: ").append(source).append("\n");
+      sb.append(" Queries: ").append(queries);
+      if (empty > 0) {
+        sb.append(" (").append(empty).append(" empty");
+        // recordEmptyRetrieval() bumps the empty counter without counting a query, so queries
+        // can be 0 here. Print the percentage only when the division is defined; the unguarded
+        // form threw ArithmeticException (also from close()).
+        if (queries > 0) {
+          sb.append(", ").append(empty * 100 / queries).append("%");
+        }
+        sb.append(")");
+      }
+      sb.append("\n");
+
+      // Results
+      long totalRes = s.totalResults.get();
+      if (totalRes > 0 && queries > 0) {
+        sb.append(" Results: ").append(totalRes)
+            .append(" total (avg ").append(totalRes / queries).append("/query)\n");
+      }
+
+      // Similarity scores (stored as fixed-point, score * 10000)
+      long scored = s.scoredQueries.get();
+      if (scored > 0) {
+        float avgTop = (s.totalTopScore.get() / (float) scored) / 10000f;
+        float minTop = s.minTopScore.get() / 10000f;
+        float maxTop = s.maxTopScore.get() / 10000f;
+        sb.append(" Similarity (top): avg ").append(String.format("%.3f", avgTop));
+        sb.append(", min ").append(String.format("%.3f", minTop));
+        sb.append(", max ").append(String.format("%.3f", maxTop));
+        sb.append("\n");
+
+        long totalLow = s.totalLowScore.get();
+        if (totalLow > 0) {
+          // NOTE(review): the low-score sum is averaged over the number of queries that reported
+          // a top score; queries that set only one of the two scores skew this figure.
+          float avgLow = (totalLow / (float) scored) / 10000f;
+          float spread = avgTop - avgLow;
+          sb.append(" Similarity (low): avg ").append(String.format("%.3f", avgLow));
+          sb.append(" (spread ").append(String.format("%.3f", spread)).append(")\n");
+        }
+      }
+
+      // Latency (min/max hold their initial sentinels until a query is recorded, hence the guard)
+      if (queries > 0) {
+        long avgMs = (s.totalDurationNanos.get() / queries) / 1_000_000;
+        long minMs = s.minDurationNanos.get() / 1_000_000;
+        long maxMs = s.maxDurationNanos.get() / 1_000_000;
+        sb.append(" Latency: avg ").append(avgMs).append("ms");
+        sb.append(", min ").append(minMs).append("ms");
+        sb.append(", max ").append(maxMs).append("ms\n");
+      }
+
+      // Chunks
+      long chunks = s.chunkCount.get();
+      long chunkTokens = s.totalChunkTokens.get();
+      if (chunkTokens > 0) {
+        sb.append(" Context: ").append(chunkTokens).append(" tokens");
+        if (chunks > 0) {
+          sb.append(" (").append(chunks).append(" chunks, avg ")
+              .append(chunkTokens / chunks).append(" tok/chunk)");
+        }
+        sb.append("\n");
+      }
+
+      sb.append("\n");
+    }
+
+    // Pipelines
+    if (!pipelineStats.isEmpty()) {
+      sb.append("--- Pipelines ---\n");
+      for (Map.Entry<String, PipelineStats> entry : pipelineStats.entrySet()) {
+        PipelineStats ps = entry.getValue();
+        long inv = ps.invocations.get();
+        long avgRetMs = inv > 0 ? (ps.totalRetrievalNanos.get() / inv) / 1_000_000 : 0;
+        long avgGenMs = inv > 0 ? (ps.totalGenerationNanos.get() / inv) / 1_000_000 : 0;
+        sb.append(" ").append(entry.getKey()).append(": ")
+            .append(inv).append(" invocations, avg retrieval ")
+            .append(avgRetMs).append("ms, avg generation ")
+            .append(avgGenMs).append("ms\n");
+      }
+      sb.append("\n");
+    }
+
+    sb.append("--- Totals ---\n");
+    sb.append(" Queries: ").append(totalQueries).append("\n");
+    if (totalEmpty > 0) {
+      sb.append(" Empty retrievals: ").append(totalEmpty);
+      if (totalQueries > 0) {
+        sb.append(" (").append(totalEmpty * 100 / totalQueries).append("%)");
+      }
+      sb.append("\n");
+    }
+
+    return sb.toString();
+  }
+
+  @Override
+  public String getSourceSummary(String source) {
+    SourceStats stats = sourceStats.get(source);
+    if (stats == null) {
+      return "No data for source: " + source;
+    }
+
+    long queries = stats.queries.get();
+    long avgMs = 0;
+    if (queries > 0) {
+      // Mean duration per query, converted from nanoseconds to whole milliseconds.
+      avgMs = (stats.totalDurationNanos.get() / queries) / 1_000_000;
+    }
+    long empty = stats.emptyRetrievals.get();
+
+    String line = source + ": " + queries + " queries, avg " + avgMs + "ms";
+    // The empty-retrieval count is mentioned only when there is at least one.
+    return empty > 0 ? line + ", " + empty + " empty" : line;
+  }
+
+  @Override
+  public long getTotalQueries() {
+    // Sum of the per-source query counters.
+    return sourceStats.values().stream()
+        .mapToLong(stats -> stats.queries.get())
+        .sum();
+  }
+
+  @Override
+  public long getTotalEmptyRetrievals() {
+    // Sum of the per-source empty-retrieval counters.
+    return sourceStats.values().stream()
+        .mapToLong(stats -> stats.emptyRetrievals.get())
+        .sum();
+  }
+
+  @Override
+  public float getAverageTopScore() {
+    long scoredCount = 0;
+    long fixedPointSum = 0;
+    for (SourceStats stats : sourceStats.values()) {
+      scoredCount += stats.scoredQueries.get();
+      fixedPointSum += stats.totalTopScore.get();
+    }
+    if (scoredCount == 0) {
+      return -1f; // no query has reported a top score
+    }
+    // The sum is stored as score * 10000; scale the mean back to a plain score.
+    float meanFixedPoint = fixedPointSum / (float) scoredCount;
+    return meanFixedPoint / 10000f;
+  }
+
+  @Override
+  public void reset() {
+    // Dropping the map entries discards every counter; new ones are created on the next record.
+    pipelineStats.clear();
+    sourceStats.clear();
+  }
+
+  @Override
+  public void close() {
+    String summary = getSummary();
+    if ("No RAG queries recorded.".equals(summary)) {
+      return; // nothing was recorded, so there is nothing to report
+    }
+    getContext().send(summary);
+  }
+
+ // ==================== Internals ====================
+
+  /** Returns the statistics for {@code source}, creating an empty entry on first use. */
+  private SourceStats getOrCreate(String source) {
+    return sourceStats.computeIfAbsent(source, name -> new SourceStats());
+  }
+
+  /** Applies {@code value} to both the running minimum and the running maximum. */
+  private static void updateMinMax(AtomicLong min, AtomicLong max, long value) {
+    updateMax(max, value);
+    updateMin(min, value);
+  }
+
+  /**
+   * Lowers {@code min} to {@code value} if {@code value} is smaller, retrying the compare-and-set
+   * until it succeeds or another thread has stored a value that is already small enough.
+   */
+  private static void updateMin(AtomicLong min, long value) {
+    long seen = min.get();
+    while (value < seen && !min.compareAndSet(seen, value)) {
+      seen = min.get();
+    }
+  }
+
+  /**
+   * Raises {@code max} to {@code value} if {@code value} is larger, retrying the compare-and-set
+   * until it succeeds or another thread has stored a value that is already large enough.
+   */
+  private static void updateMax(AtomicLong max, long value) {
+    long seen = max.get();
+    while (value > seen && !max.compareAndSet(seen, value)) {
+      seen = max.get();
+    }
+  }
+
+  /** Counters for one source (vector DB or index); all numeric state lives in atomics. */
+  static final class SourceStats {
+    final AtomicLong queries = new AtomicLong();
+    final AtomicLong totalResults = new AtomicLong();
+    // Incremented by zero-result queries and by recordEmptyRetrieval(), so it may exceed queries.
+    final AtomicLong emptyRetrievals = new AtomicLong();
+    // Number of builder-recorded queries that supplied a non-negative top score.
+    final AtomicLong scoredQueries = new AtomicLong();
+    final AtomicLong totalTopScore = new AtomicLong(); // fixed-point * 10000
+    // Min/max top score, also fixed-point * 10000; the initial values act as "nothing seen yet".
+    final AtomicLong minTopScore = new AtomicLong(Long.MAX_VALUE);
+    final AtomicLong maxTopScore = new AtomicLong(0);
+    final AtomicLong totalLowScore = new AtomicLong(); // fixed-point * 10000
+    final AtomicLong totalDurationNanos = new AtomicLong();
+    final AtomicLong minDurationNanos = new AtomicLong(Long.MAX_VALUE);
+    final AtomicLong maxDurationNanos = new AtomicLong(0);
+    // Fed by recordChunk() and by builder records that set totalChunkTokens.
+    final AtomicLong totalChunkTokens = new AtomicLong();
+    // Incremented only by recordChunk().
+    final AtomicLong chunkCount = new AtomicLong();
+    // Last positive dimension set through the builder; not read anywhere in this class.
+    volatile int lastEmbeddingDimension;
+  }
+
+  /** Counters for one named pipeline, updated by recordPipeline(). */
+  static final class PipelineStats {
+    final AtomicLong invocations = new AtomicLong();
+    // Running sums; the summary divides them by invocations to report averages.
+    final AtomicLong totalRetrievalNanos = new AtomicLong();
+    final AtomicLong totalGenerationNanos = new AtomicLong();
+  }
+}
diff --git a/btrace-extensions/btrace-rag-quality/src/test/java/org/openjdk/btrace/rag/RagQualityServiceTest.java b/btrace-extensions/btrace-rag-quality/src/test/java/org/openjdk/btrace/rag/RagQualityServiceTest.java
new file mode 100644
index 00000000..51ae7243
--- /dev/null
+++ b/btrace-extensions/btrace-rag-quality/src/test/java/org/openjdk/btrace/rag/RagQualityServiceTest.java
@@ -0,0 +1,210 @@
+package org.openjdk.btrace.rag;
+
+import static org.junit.jupiter.api.Assertions.*;
+
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.atomic.AtomicLong;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+
+class RagQualityServiceTest {
+
+ private RagQualityServiceImpl service;
+
+ @BeforeEach
+ void setUp() {
+ service = new RagQualityServiceImpl();
+ }
+
+ @Test
+ void durationOnlyQuery() {
+ service.recordQuery("pinecone", 5_000_000L);
+ assertEquals(1, service.getTotalQueries());
+ String summary = service.getSummary();
+ assertTrue(summary.contains("pinecone"));
+ assertTrue(summary.contains("1"));
+ }
+
+ @Test
+ void queryWithResultCount() {
+ service.recordQuery("milvus", 10, 8_000_000L);
+ assertEquals(1, service.getTotalQueries());
+ assertEquals(0, service.getTotalEmptyRetrievals());
+ assertTrue(service.getSummary().contains("10 total"));
+ }
+
+ @Test
+ void emptyRetrieval() {
+ service.recordQuery("weaviate", 0, 2_000_000L);
+ assertEquals(1, service.getTotalEmptyRetrievals());
+ assertTrue(service.getSummary().contains("1 empty"));
+ }
+
+ @Test
+ void explicitEmptyRetrieval() {
+ service.recordEmptyRetrieval("chroma");
+ assertEquals(1, service.getTotalEmptyRetrievals());
+ }
+
+ @Test
+ void fluentBuilder() {
+ service.query("pinecone")
+ .resultCount(5)
+ .topScore(0.92f)
+ .lowScore(0.71f)
+ .embeddingDimension(1536)
+ .totalChunkTokens(3000)
+ .duration(10_000_000L)
+ .record();
+
+ assertEquals(1, service.getTotalQueries());
+ float avgScore = service.getAverageTopScore();
+ assertTrue(avgScore > 0.91f && avgScore < 0.93f, "avgScore=" + avgScore);
+ String summary = service.getSummary();
+ assertTrue(summary.contains("0.920"));
+ assertTrue(summary.contains("3000 tokens"));
+ }
+
+ @Test
+ void fluentBuilderMinimal() {
+ service.query("qdrant")
+ .duration(1_000_000L)
+ .record();
+
+ assertEquals(1, service.getTotalQueries());
+ assertEquals(-1f, service.getAverageTopScore());
+ }
+
+ @Test
+ void fluentBuilderEmptyResult() {
+ service.query("pgvector")
+ .resultCount(0)
+ .duration(500_000L)
+ .record();
+
+ assertEquals(1, service.getTotalEmptyRetrievals());
+ }
+
+ @Test
+ void multipleSourcesTrackedSeparately() {
+ service.recordQuery("pinecone", 5, 10_000_000L);
+ service.recordQuery("milvus", 3, 8_000_000L);
+ service.recordQuery("pinecone", 7, 12_000_000L);
+
+ assertEquals(3, service.getTotalQueries());
+ assertTrue(service.getSourceSummary("pinecone").contains("2 queries"));
+ assertTrue(service.getSourceSummary("milvus").contains("1 queries"));
+ }
+
+ @Test
+ void pipelineRecording() {
+ service.recordPipeline("qa-bot", 5_000_000L, 50_000_000L);
+ service.recordPipeline("qa-bot", 3_000_000L, 45_000_000L);
+
+ String summary = service.getSummary();
+ assertTrue(summary.contains("qa-bot"));
+ assertTrue(summary.contains("2 invocations"));
+ }
+
+  @Test
+  void chunkRecording() {
+    service.recordChunk("pinecone", 500);
+    service.recordChunk("pinecone", 750);
+
+    // Chunks don't count as queries
+    assertTrue(service.getSourceSummary("pinecone").contains("0 queries"));
+    // ...but the chunk data itself must be recorded: 500 + 750 tokens over 2 chunks.
+    assertTrue(service.getSummary().contains("1250 tokens (2 chunks, avg 625 tok/chunk)"));
+  }
+
+ @Test
+ void similarityScoreAggregation() {
+ service.query("pinecone").topScore(0.90f).duration(1_000_000L).record();
+ service.query("pinecone").topScore(0.80f).duration(1_000_000L).record();
+ service.query("milvus").topScore(0.70f).duration(1_000_000L).record();
+
+ float avg = service.getAverageTopScore();
+ // (0.90 + 0.80 + 0.70) / 3 = 0.8
+ assertTrue(avg > 0.79f && avg < 0.81f, "avg=" + avg);
+ }
+
+ @Test
+ void latencyMinMax() {
+ service.recordQuery("pinecone", 10_000_000L);
+ service.recordQuery("pinecone", 2_000_000L);
+ service.recordQuery("pinecone", 50_000_000L);
+
+ String summary = service.getSummary();
+ assertTrue(summary.contains("min 2ms"));
+ assertTrue(summary.contains("max 50ms"));
+ }
+
+ @Test
+ void noDataSummary() {
+ assertEquals("No RAG queries recorded.", service.getSummary());
+ }
+
+ @Test
+ void unknownSourceSummary() {
+ assertEquals("No data for source: unknown", service.getSourceSummary("unknown"));
+ }
+
+ @Test
+ void reset() {
+ service.recordQuery("pinecone", 5_000_000L);
+ service.recordEmptyRetrieval("milvus");
+ service.recordPipeline("qa", 1L, 1L);
+ service.reset();
+
+ assertEquals(0, service.getTotalQueries());
+ assertEquals(0, service.getTotalEmptyRetrievals());
+ assertEquals("No RAG queries recorded.", service.getSummary());
+ }
+
+ @Test
+ void concurrentRecording() throws Exception {
+ int threads = 8;
+ int queriesPerThread = 1000;
+ CountDownLatch latch = new CountDownLatch(threads);
+
+ for (int t = 0; t < threads; t++) {
+ new Thread(() -> {
+ try {
+ for (int i = 0; i < queriesPerThread; i++) {
+ service.recordQuery("pinecone", 5, 1_000_000L);
+ }
+ } finally {
+ latch.countDown();
+ }
+ }).start();
+ }
+ latch.await();
+
+ assertEquals(threads * queriesPerThread, service.getTotalQueries());
+ }
+
+ @Test
+ void concurrentBuilderRecording() throws Exception {
+ int threads = 8;
+ int queriesPerThread = 500;
+ CountDownLatch latch = new CountDownLatch(threads);
+
+ for (int t = 0; t < threads; t++) {
+ new Thread(() -> {
+ try {
+ for (int i = 0; i < queriesPerThread; i++) {
+ service.query("milvus")
+ .resultCount(3)
+ .topScore(0.85f)
+ .duration(2_000_000L)
+ .record();
+ }
+ } finally {
+ latch.countDown();
+ }
+ }).start();
+ }
+ latch.await();
+
+ assertEquals(threads * queriesPerThread, service.getTotalQueries());
+ }
+}
diff --git a/btrace-extensions/btrace-vibe-guard/build.gradle b/btrace-extensions/btrace-vibe-guard/build.gradle
new file mode 100644
index 00000000..58a3ef5f
--- /dev/null
+++ b/btrace-extensions/btrace-vibe-guard/build.gradle
@@ -0,0 +1,33 @@
+// Build script for the btrace-vibe-guard extension module.
+plugins {
+ id 'org.openjdk.btrace.extension'
+ alias(libs.plugins.shadow)
+}
+
+// Source and bytecode level are both Java 8.
+java {
+ sourceCompatibility = 8
+ targetCompatibility = 8
+}
+
+// The compiler itself comes from a Java 11 toolchain; the language/bytecode level stays at 8
+// as configured in the java {} block.
+compileJava {
+ javaCompiler = javaToolchains.compilerFor {
+ languageVersion.set(JavaLanguageVersion.of(11))
+ }
+}
+
+// Extension metadata; 'services' lists the service interface this extension provides.
+btraceExtension {
+ id = 'btrace-vibe-guard'
+ name = 'BTrace Vibe Guard'
+ description = 'Behavioral contracts and runtime validation for AI-generated code paths'
+ services = ['org.openjdk.btrace.vibeguard.VibeGuardService']
+}
+
+dependencies {
+ // btrace-core is needed by both the api and impl source sets (see src/api and src/impl).
+ apiCompileOnly project(':btrace-core')
+ implImplementation project(':btrace-core')
+
+ testImplementation libs.junit.jupiter
+}
+
+test {
+ useJUnitPlatform()
+}
diff --git a/btrace-extensions/btrace-vibe-guard/src/api/java/org/openjdk/btrace/vibeguard/VibeGuardService.java b/btrace-extensions/btrace-vibe-guard/src/api/java/org/openjdk/btrace/vibeguard/VibeGuardService.java
new file mode 100644
index 00000000..cbdc713f
--- /dev/null
+++ b/btrace-extensions/btrace-vibe-guard/src/api/java/org/openjdk/btrace/vibeguard/VibeGuardService.java
@@ -0,0 +1,125 @@
+package org.openjdk.btrace.vibeguard;
+
+import org.openjdk.btrace.core.extensions.ServiceDescriptor;
+
+/**
+ * BTrace extension for runtime behavioral contracts on AI-generated code.
+ *
+ * <p>Validates that methods respect declared invariants at runtime — useful for
+ * "vibe coding" workflows where LLMs generate code that needs guardrails.
+ * Tracks contract violations, latency budgets, call frequency limits, and
+ * return value constraints without modifying the target code.
+ *
+ * <p>Contracts are defined in BTrace scripts and enforced at instrumentation
+ * points. When a contract is violated, the service records it and optionally
+ * triggers an alert via the configured handler.
+ *
+ * <p>Usage in a BTrace script:
+ *
+ * <pre>{@code
+ * @Injected VibeGuardService guard;
+ *
+ * @OnMethod(clazz = "com.app.AiService", method = "generate")
+ * void onEntry() {
+ *   guard.checkCallRate("AiService.generate", 100); // max 100 calls/sec
+ * }
+ *
+ * @OnMethod(clazz = "com.app.AiService", method = "generate",
+ *     location = @Location(Kind.RETURN))
+ * void onReturn(@Duration long dur) {
+ *   guard.checkLatency("AiService.generate", dur, 500_000_000L); // 500ms budget
+ * }
+ * }</pre>
+ *
+ * <p>Every {@code check*} and {@code assertCondition} call counts as one check for the named
+ * contract, whether or not it results in a violation.
+ */
+@ServiceDescriptor
+public interface VibeGuardService {
+
+ // ==================== Contract checks ====================
+
+ /**
+ * Checks that a method's latency does not exceed the budget.
+ * Records a violation if {@code durationNanos > budgetNanos}.
+ *
+ * @param contract contract/method name
+ * @param durationNanos actual duration, in nanoseconds
+ * @param budgetNanos maximum allowed duration, in nanoseconds
+ */
+ void checkLatency(String contract, long durationNanos, long budgetNanos);
+
+ /**
+ * Checks that call rate does not exceed the limit per second.
+ * Uses a sliding window to detect bursts. Each invocation of this method
+ * is itself counted as one call of the contract.
+ *
+ * @param contract contract/method name
+ * @param maxPerSecond maximum allowed calls per second
+ */
+ void checkCallRate(String contract, int maxPerSecond);
+
+ /**
+ * Asserts a boolean condition. Records a violation if false.
+ *
+ * @param contract contract name
+ * @param condition the condition to check
+ * @param message violation message if condition is false
+ */
+ void assertCondition(String contract, boolean condition, String message);
+
+ /**
+ * Checks that a numeric return value is within bounds.
+ * Records a violation if {@code value < min} or {@code value > max}.
+ *
+ * @param contract contract name
+ * @param value actual value
+ * @param min minimum allowed (inclusive)
+ * @param max maximum allowed (inclusive)
+ */
+ void checkRange(String contract, long value, long min, long max);
+
+ /**
+ * Checks that a return value is not null. Records a violation if null.
+ *
+ * @param contract contract name
+ * @param value the value to check
+ */
+ void checkNotNull(String contract, Object value);
+
+ // ==================== Tracking ====================
+
+ /**
+ * Tags a method invocation as AI-generated. Used to compare behavior
+ * between AI and human code paths.
+ *
+ * @param contract contract/method name
+ * @param durationNanos execution duration, in nanoseconds
+ */
+ void trackAiCodePath(String contract, long durationNanos);
+
+ /**
+ * Tags a method invocation as human-written (baseline).
+ *
+ * @param contract contract/method name
+ * @param durationNanos execution duration, in nanoseconds
+ */
+ void trackHumanCodePath(String contract, long durationNanos);
+
+ // ==================== Reporting ====================
+
+ /** Returns a formatted summary of all contract violations and stats. */
+ String getSummary();
+
+ /** Returns total number of contract violations across all contracts. */
+ long getTotalViolations();
+
+ /** Returns number of violations for a specific contract. */
+ long getViolations(String contract);
+
+ /** Returns total number of contract checks performed. */
+ long getTotalChecks();
+
+ /**
+ * Returns true if any contract has been violated.
+ * Useful in {@code @OnEvent} handlers for alerting.
+ */
+ boolean hasViolations();
+
+ /** Resets all metrics and violation history. */
+ void reset();
+}
diff --git a/btrace-extensions/btrace-vibe-guard/src/impl/java/org/openjdk/btrace/vibeguard/VibeGuardServiceImpl.java b/btrace-extensions/btrace-vibe-guard/src/impl/java/org/openjdk/btrace/vibeguard/VibeGuardServiceImpl.java
new file mode 100644
index 00000000..152730a1
--- /dev/null
+++ b/btrace-extensions/btrace-vibe-guard/src/impl/java/org/openjdk/btrace/vibeguard/VibeGuardServiceImpl.java
@@ -0,0 +1,320 @@
+package org.openjdk.btrace.vibeguard;
+
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+import java.util.concurrent.atomic.AtomicLong;
+import org.openjdk.btrace.core.extensions.Extension;
+
+/**
+ * Thread-safe behavioral contract enforcement with lock-free statistics.
+ *
+ * <p>All state lives in {@link ConcurrentHashMap}s keyed by contract name and is updated with
+ * atomics only, so the check/track methods may be called from any instrumented thread. Reports
+ * are built from individual atomic reads and are therefore a best-effort snapshot while other
+ * threads are still recording.
+ */
+public final class VibeGuardServiceImpl extends Extension implements VibeGuardService {
+
+  /** Summary returned while nothing has been recorded; {@link #close()} does not send it. */
+  private static final String EMPTY_SUMMARY = "No contracts checked.";
+
+  private final Map<String, ContractStats> contracts = new ConcurrentHashMap<>();
+  private final Map<String, RateWindow> rateWindows = new ConcurrentHashMap<>();
+  private final Map<String, CodePathStats> aiPaths = new ConcurrentHashMap<>();
+  private final Map<String, CodePathStats> humanPaths = new ConcurrentHashMap<>();
+
+  // ==================== Contract checks ====================
+
+  /**
+   * Counts one check, records a violation when {@code durationNanos > budgetNanos}, and folds
+   * the duration into the contract's latency statistics (sample count, total, min, max).
+   */
+  @Override
+  public void checkLatency(String contract, long durationNanos, long budgetNanos) {
+    ContractStats stats = getOrCreate(contract);
+    stats.checks.incrementAndGet();
+    if (durationNanos > budgetNanos) {
+      recordViolation(
+          stats,
+          "Latency " + (durationNanos / 1_000_000)
+              + "ms exceeded budget " + (budgetNanos / 1_000_000) + "ms");
+    }
+    // Latency samples are counted separately from checks: the same contract name may also be
+    // used for rate/range/null checks, which must not dilute the average latency.
+    stats.latencySamples.incrementAndGet();
+    stats.totalDurationNanos.addAndGet(durationNanos);
+    updateMinMax(stats.minDurationNanos, stats.maxDurationNanos, durationNanos);
+  }
+
+  /**
+   * Counts one check, registers one call in the contract's sliding one-second window and records
+   * a violation when the window then holds more than {@code maxPerSecond} calls.
+   */
+  @Override
+  public void checkCallRate(String contract, int maxPerSecond) {
+    ContractStats stats = getOrCreate(contract);
+    stats.checks.incrementAndGet();
+
+    RateWindow window = rateWindows.computeIfAbsent(contract, k -> new RateWindow());
+    long count = window.recordAndCount(System.nanoTime());
+    if (count > maxPerSecond) {
+      recordViolation(stats, "Rate " + count + "/sec exceeded limit " + maxPerSecond + "/sec");
+    }
+  }
+
+  /** Counts one check and records a violation with {@code message} when the condition is false. */
+  @Override
+  public void assertCondition(String contract, boolean condition, String message) {
+    ContractStats stats = getOrCreate(contract);
+    stats.checks.incrementAndGet();
+    if (!condition) {
+      recordViolation(stats, message);
+    }
+  }
+
+  /** Counts one check and records a violation when {@code value} lies outside [min, max]. */
+  @Override
+  public void checkRange(String contract, long value, long min, long max) {
+    ContractStats stats = getOrCreate(contract);
+    stats.checks.incrementAndGet();
+    if (value < min || value > max) {
+      recordViolation(stats, "Value " + value + " outside range [" + min + ", " + max + "]");
+    }
+  }
+
+  /** Counts one check and records a violation when {@code value} is null. */
+  @Override
+  public void checkNotNull(String contract, Object value) {
+    ContractStats stats = getOrCreate(contract);
+    stats.checks.incrementAndGet();
+    if (value == null) {
+      recordViolation(stats, "Unexpected null return");
+    }
+  }
+
+  // ==================== Tracking ====================
+
+  @Override
+  public void trackAiCodePath(String contract, long durationNanos) {
+    track(aiPaths, contract, durationNanos);
+  }
+
+  @Override
+  public void trackHumanCodePath(String contract, long durationNanos) {
+    track(humanPaths, contract, durationNanos);
+  }
+
+  // ==================== Reporting ====================
+
+  /**
+   * Builds the textual report: one section per contract, an optional AI-vs-human section and the
+   * totals. Contracts and code paths are listed in name order so the report is deterministic.
+   *
+   * @return the report, or {@code "No contracts checked."} when nothing at all was recorded
+   */
+  @Override
+  public String getSummary() {
+    // Human-only tracking is data too; without the third test it would be reported as empty
+    // and close() would drop it.
+    if (contracts.isEmpty() && aiPaths.isEmpty() && humanPaths.isEmpty()) {
+      return EMPTY_SUMMARY;
+    }
+
+    StringBuilder sb = new StringBuilder();
+    sb.append("=== Vibe Guard Summary ===\n\n");
+
+    long totalChecks = 0;
+    long totalViolations = 0;
+
+    Map<String, ContractStats> sortedContracts = new java.util.TreeMap<>(contracts);
+    for (Map.Entry<String, ContractStats> entry : sortedContracts.entrySet()) {
+      ContractStats s = entry.getValue();
+      // Read violations BEFORE checks: every writer bumps checks first, so this order
+      // guarantees checks >= violations (no division by zero, no >100%) under concurrency.
+      long violations = s.violations.get();
+      long checks = s.checks.get();
+      totalChecks += checks;
+      totalViolations += violations;
+      appendContract(sb, entry.getKey(), s, checks, violations);
+    }
+
+    if (!aiPaths.isEmpty() || !humanPaths.isEmpty()) {
+      appendCodePaths(sb);
+    }
+
+    sb.append("--- Totals ---\n");
+    sb.append(" Checks: ").append(totalChecks).append("\n");
+    sb.append(" Violations: ").append(totalViolations);
+    if (totalViolations == 0) {
+      sb.append(" (all contracts satisfied)");
+    }
+    sb.append("\n");
+
+    return sb.toString();
+  }
+
+  @Override
+  public long getTotalViolations() {
+    long total = 0;
+    for (ContractStats s : contracts.values()) {
+      total += s.violations.get();
+    }
+    return total;
+  }
+
+  @Override
+  public long getViolations(String contract) {
+    ContractStats s = contracts.get(contract);
+    return s != null ? s.violations.get() : 0;
+  }
+
+  @Override
+  public long getTotalChecks() {
+    long total = 0;
+    for (ContractStats s : contracts.values()) {
+      total += s.checks.get();
+    }
+    return total;
+  }
+
+  @Override
+  public boolean hasViolations() {
+    for (ContractStats s : contracts.values()) {
+      if (s.violations.get() > 0) {
+        return true;
+      }
+    }
+    return false;
+  }
+
+  @Override
+  public void reset() {
+    contracts.clear();
+    rateWindows.clear();
+    aiPaths.clear();
+    humanPaths.clear();
+  }
+
+  /** Sends the final summary through the extension context, unless nothing was recorded. */
+  @Override
+  public void close() {
+    String summary = getSummary();
+    if (!EMPTY_SUMMARY.equals(summary)) {
+      getContext().send(summary);
+    }
+  }
+
+  // ==================== Internals ====================
+
+  private ContractStats getOrCreate(String contract) {
+    return contracts.computeIfAbsent(contract, k -> new ContractStats());
+  }
+
+  /** Counts one violation and remembers its message as the most recent one. */
+  private static void recordViolation(ContractStats stats, String message) {
+    stats.violations.incrementAndGet();
+    stats.lastViolationMessage = message;
+  }
+
+  /** Adds one call with the given duration to the code-path statistics of {@code contract}. */
+  private static void track(Map<String, CodePathStats> paths, String contract, long durationNanos) {
+    CodePathStats s = paths.computeIfAbsent(contract, k -> new CodePathStats());
+    s.calls.incrementAndGet();
+    s.totalDurationNanos.addAndGet(durationNanos);
+  }
+
+  /** Appends the report section of a single contract. */
+  private static void appendContract(
+      StringBuilder sb, String name, ContractStats s, long checks, long violations) {
+    sb.append("Contract: ").append(name).append("\n");
+    sb.append(" Checks: ").append(checks);
+    if (violations > 0) {
+      sb.append(" | VIOLATIONS: ").append(violations);
+      // Safe: the caller read violations before checks, hence checks >= violations >= 1.
+      sb.append(" (").append(violations * 100 / checks).append("%)");
+    } else {
+      sb.append(" | OK");
+    }
+    sb.append("\n");
+
+    // Latency stats, only if checkLatency was used for this contract.
+    long samples = s.latencySamples.get();
+    if (samples > 0) {
+      long min = s.minDurationNanos.get();
+      long avgMs = (s.totalDurationNanos.get() / samples) / 1_000_000;
+      long minMs = min == Long.MAX_VALUE ? 0 : min / 1_000_000;
+      long maxMs = s.maxDurationNanos.get() / 1_000_000;
+      sb.append(" Latency: avg ").append(avgMs).append("ms");
+      sb.append(", min ").append(minMs).append("ms");
+      sb.append(", max ").append(maxMs).append("ms\n");
+    }
+
+    String last = s.lastViolationMessage;
+    if (violations > 0 && last != null) {
+      sb.append(" Last: ").append(last).append("\n");
+    }
+    sb.append("\n");
+  }
+
+  /** Appends the "AI vs Human" section covering every name tracked on either side. */
+  private void appendCodePaths(StringBuilder sb) {
+    sb.append("--- AI vs Human Code Paths ---\n");
+    java.util.Set<String> names = new java.util.TreeSet<>(aiPaths.keySet());
+    names.addAll(humanPaths.keySet());
+
+    for (String name : names) {
+      CodePathStats ai = aiPaths.get(name);
+      CodePathStats human = humanPaths.get(name);
+      // Read each counter once so the printed numbers and the comparison agree.
+      long aiCalls = ai != null ? ai.calls.get() : 0;
+      long aiAvg = aiCalls > 0 ? ai.totalDurationNanos.get() / aiCalls : 0;
+      long hCalls = human != null ? human.calls.get() : 0;
+      long hAvg = hCalls > 0 ? human.totalDurationNanos.get() / hCalls : 0;
+
+      sb.append(" ").append(name).append(": ");
+      if (ai != null) {
+        sb.append("AI ").append(aiCalls).append(" calls avg ").append(aiAvg / 1_000_000).append("ms");
+      }
+      if (ai != null && human != null) {
+        sb.append(" | ");
+      }
+      if (human != null) {
+        sb.append("Human ").append(hCalls).append(" calls avg ").append(hAvg / 1_000_000).append("ms");
+      }
+
+      // Relative difference of the averages; needs data on both sides and a non-zero baseline.
+      if (aiCalls > 0 && hCalls > 0 && hAvg > 0) {
+        long pctDiff = ((aiAvg - hAvg) * 100) / hAvg;
+        if (pctDiff > 0) {
+          sb.append(" [AI ").append(pctDiff).append("% slower]");
+        } else if (pctDiff < 0) {
+          sb.append(" [AI ").append(-pctDiff).append("% faster]");
+        }
+      }
+      sb.append("\n");
+    }
+    sb.append("\n");
+  }
+
+  /** Lowers {@code min} and/or raises {@code max} to {@code value} using CAS retry loops. */
+  private static void updateMinMax(AtomicLong min, AtomicLong max, long value) {
+    long cur;
+    do {
+      cur = min.get();
+      if (value >= cur) break;
+    } while (!min.compareAndSet(cur, value));
+    do {
+      cur = max.get();
+      if (value <= cur) break;
+    } while (!max.compareAndSet(cur, value));
+  }
+
+  /** Counters of one contract. */
+  static final class ContractStats {
+    final AtomicLong checks = new AtomicLong();
+    final AtomicLong violations = new AtomicLong();
+    /** Number of durations recorded via checkLatency (a subset of {@link #checks}). */
+    final AtomicLong latencySamples = new AtomicLong();
+    final AtomicLong totalDurationNanos = new AtomicLong();
+    final AtomicLong minDurationNanos = new AtomicLong(Long.MAX_VALUE);
+    final AtomicLong maxDurationNanos = new AtomicLong(0);
+    volatile String lastViolationMessage;
+  }
+
+  /** Call count and accumulated duration of one tracked code path. */
+  static final class CodePathStats {
+    final AtomicLong calls = new AtomicLong();
+    final AtomicLong totalDurationNanos = new AtomicLong();
+  }
+
+  /**
+   * Simple sliding-window rate counter. Tracks calls in the last second using a circular buffer
+   * of 10 x 100ms buckets. Lock-free via CAS.
+   *
+   * <p>Each bucket is a single {@link AtomicLong} that packs the bucket's interval number
+   * ("epoch", upper bits) together with its call count (lower {@value #COUNT_BITS} bits), so
+   * "restart a stale bucket" and "count one call" are one atomic compare-and-set and concurrent
+   * callers cannot wipe each other's increments.
+   */
+  static final class RateWindow {
+    private static final int BUCKETS = 10;
+    private static final long BUCKET_NS = 100_000_000L; // 100ms
+    private static final int COUNT_BITS = 24;
+    private static final long COUNT_MASK = (1L << COUNT_BITS) - 1;
+    /** Initial slot value; its epoch is lower than any epoch a long timestamp can produce. */
+    private static final long EMPTY = Long.MIN_VALUE;
+
+    private final AtomicLong[] slots = new AtomicLong[BUCKETS];
+
+    RateWindow() {
+      for (int i = 0; i < BUCKETS; i++) {
+        slots[i] = new AtomicLong(EMPTY);
+      }
+    }
+
+    /**
+     * Registers one call at {@code nowNanos} and returns the number of calls in the window.
+     *
+     * @param nowNanos timestamp from a monotonic nanosecond clock; may be negative, as
+     *     {@link System#nanoTime()} values are allowed to be
+     * @return calls recorded in the current 100ms interval and the nine before it (plus any
+     *     newer interval another thread has already opened)
+     */
+    long recordAndCount(long nowNanos) {
+      // floorDiv/floorMod keep epoch and index well-defined for negative timestamps; plain
+      // '/' and '%' would produce a negative array index.
+      long epoch = Math.floorDiv(nowNanos, BUCKET_NS);
+      AtomicLong slot = slots[(int) Math.floorMod(epoch, (long) BUCKETS)];
+
+      for (;;) {
+        long current = slot.get();
+        long slotEpoch = current >> COUNT_BITS;
+        long updated;
+        if (slotEpoch == epoch) {
+          if ((current & COUNT_MASK) == COUNT_MASK) {
+            break; // counter saturated; stay at the maximum instead of spilling into the epoch
+          }
+          updated = current + 1;
+        } else if (slotEpoch < epoch) {
+          updated = (epoch << COUNT_BITS) | 1L; // stale bucket: restart it with this call
+        } else {
+          break; // a newer interval already owns this bucket; this call is older than the window
+        }
+        if (slot.compareAndSet(current, updated)) {
+          break;
+        }
+      }
+
+      // Sum all buckets that are not older than the window.
+      long total = 0;
+      long oldestExclusive = epoch - BUCKETS;
+      for (AtomicLong s : slots) {
+        long value = s.get();
+        if ((value >> COUNT_BITS) > oldestExclusive) {
+          total += value & COUNT_MASK;
+        }
+      }
+      return total;
+    }
+  }
+}
diff --git a/btrace-extensions/btrace-vibe-guard/src/test/java/org/openjdk/btrace/vibeguard/VibeGuardServiceTest.java b/btrace-extensions/btrace-vibe-guard/src/test/java/org/openjdk/btrace/vibeguard/VibeGuardServiceTest.java
new file mode 100644
index 00000000..747c591f
--- /dev/null
+++ b/btrace-extensions/btrace-vibe-guard/src/test/java/org/openjdk/btrace/vibeguard/VibeGuardServiceTest.java
@@ -0,0 +1,227 @@
+package org.openjdk.btrace.vibeguard;
+
+import static org.junit.jupiter.api.Assertions.*;
+
+import java.util.concurrent.CountDownLatch;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+
+class VibeGuardServiceTest {
+
+ private VibeGuardServiceImpl service;
+
+ @BeforeEach
+ void setUp() {
+ service = new VibeGuardServiceImpl();
+ }
+
+ // ==================== Latency checks ====================
+
+ @Test
+ void latencyWithinBudget() {
+ service.checkLatency("api.generate", 100_000_000L, 500_000_000L);
+ assertEquals(1, service.getTotalChecks());
+ assertEquals(0, service.getTotalViolations());
+ assertFalse(service.hasViolations());
+ }
+
+ @Test
+ void latencyExceedsBudget() {
+ service.checkLatency("api.generate", 600_000_000L, 500_000_000L);
+ assertEquals(1, service.getTotalViolations());
+ assertTrue(service.hasViolations());
+ String summary = service.getSummary();
+ assertTrue(summary.contains("VIOLATIONS: 1"));
+ assertTrue(summary.contains("exceeded budget"));
+ }
+
+ @Test
+ void latencyStatsTracked() {
+ service.checkLatency("api.call", 100_000_000L, 1_000_000_000L);
+ service.checkLatency("api.call", 200_000_000L, 1_000_000_000L);
+ service.checkLatency("api.call", 300_000_000L, 1_000_000_000L);
+
+ assertEquals(3, service.getTotalChecks());
+ String summary = service.getSummary();
+ assertTrue(summary.contains("avg 200ms"));
+ assertTrue(summary.contains("min 100ms"));
+ assertTrue(summary.contains("max 300ms"));
+ }
+
+ // ==================== Call rate checks ====================
+
+ @Test
+ void callRateWithinLimit() {
+ // Single call should never exceed any reasonable limit
+ service.checkCallRate("api.query", 1000);
+ assertEquals(0, service.getTotalViolations());
+ }
+
+ @Test
+ void callRateExceedsLimit() {
+ // Flood with calls — rate window should detect burst
+ for (int i = 0; i < 200; i++) {
+ service.checkCallRate("api.query", 10);
+ }
+ // After 200 calls in rapid succession, should have violations
+ assertTrue(service.getTotalViolations() > 0);
+ }
+
+ // ==================== Condition assertions ====================
+
+ @Test
+ void assertConditionTrue() {
+ service.assertCondition("positive-balance", true, "Balance must be positive");
+ assertEquals(0, service.getTotalViolations());
+ }
+
+ @Test
+ void assertConditionFalse() {
+ service.assertCondition("positive-balance", false, "Balance must be positive");
+ assertEquals(1, service.getTotalViolations());
+ assertTrue(service.getSummary().contains("Balance must be positive"));
+ }
+
+ // ==================== Range checks ====================
+
+ @Test
+ void rangeWithinBounds() {
+ service.checkRange("response-code", 200, 100, 599);
+ assertEquals(0, service.getTotalViolations());
+ }
+
+ @Test
+ void rangeBelowMin() {
+ service.checkRange("response-code", 50, 100, 599);
+ assertEquals(1, service.getTotalViolations());
+ assertTrue(service.getSummary().contains("outside range"));
+ }
+
+ @Test
+ void rangeAboveMax() {
+ service.checkRange("response-code", 700, 100, 599);
+ assertEquals(1, service.getTotalViolations());
+ }
+
+ @Test
+ void rangeAtBoundaries() {
+ service.checkRange("val", 100, 100, 200);
+ service.checkRange("val", 200, 100, 200);
+ assertEquals(0, service.getTotalViolations());
+ }
+
+ // ==================== Null checks ====================
+
+ @Test
+ void checkNotNullWithValue() {
+ service.checkNotNull("api.result", "hello");
+ assertEquals(0, service.getTotalViolations());
+ }
+
+ @Test
+ void checkNotNullWithNull() {
+ service.checkNotNull("api.result", null);
+ assertEquals(1, service.getTotalViolations());
+ assertTrue(service.getSummary().contains("Unexpected null"));
+ }
+
+ // ==================== AI vs Human tracking ====================
+
+ @Test
+ void aiVsHumanComparison() {
+ service.trackAiCodePath("Parser.parse", 50_000_000L);
+ service.trackAiCodePath("Parser.parse", 60_000_000L);
+ service.trackHumanCodePath("Parser.parse", 30_000_000L);
+ service.trackHumanCodePath("Parser.parse", 40_000_000L);
+
+ String summary = service.getSummary();
+ assertTrue(summary.contains("AI vs Human"));
+ assertTrue(summary.contains("AI 2 calls"));
+ assertTrue(summary.contains("Human 2 calls"));
+ // AI avg 55ms, Human avg 35ms -> AI ~57% slower
+ assertTrue(summary.contains("slower"));
+ }
+
+ @Test
+ void aiOnlyTracking() {
+ service.trackAiCodePath("Renderer.render", 100_000_000L);
+ String summary = service.getSummary();
+ assertTrue(summary.contains("AI 1 calls"));
+ assertFalse(summary.contains("Human"));
+ }
+
+ // ==================== Reporting ====================
+
+ @Test
+ void noDataSummary() {
+ assertEquals("No contracts checked.", service.getSummary());
+ }
+
+ @Test
+ void getViolationsPerContract() {
+ service.assertCondition("a", false, "fail");
+ service.assertCondition("a", false, "fail");
+ service.assertCondition("b", false, "fail");
+
+ assertEquals(2, service.getViolations("a"));
+ assertEquals(1, service.getViolations("b"));
+ assertEquals(0, service.getViolations("c"));
+ }
+
+ @Test
+ void multipleContractsInSummary() {
+ service.checkLatency("fast-api", 10_000_000L, 100_000_000L);
+ service.checkLatency("slow-api", 500_000_000L, 100_000_000L);
+
+ String summary = service.getSummary();
+ assertTrue(summary.contains("fast-api"));
+ assertTrue(summary.contains("slow-api"));
+ assertTrue(summary.contains("Checks: 2"));
+ assertTrue(summary.contains("Violations: 1"));
+ }
+
+ @Test
+ void allContractsSatisfied() {
+ service.checkLatency("api", 10_000_000L, 100_000_000L);
+ service.assertCondition("invariant", true, "ok");
+ service.checkRange("val", 50, 0, 100);
+
+ String summary = service.getSummary();
+ assertTrue(summary.contains("all contracts satisfied"));
+ }
+
+ @Test
+ void reset() {
+ service.assertCondition("a", false, "fail");
+ service.trackAiCodePath("b", 100L);
+ service.reset();
+
+ assertEquals(0, service.getTotalChecks());
+ assertEquals(0, service.getTotalViolations());
+ assertFalse(service.hasViolations());
+ assertEquals("No contracts checked.", service.getSummary());
+ }
+
+ @Test
+ void concurrentChecks() throws Exception {
+ int threads = 8;
+ int checksPerThread = 1000;
+ CountDownLatch latch = new CountDownLatch(threads);
+
+ for (int t = 0; t < threads; t++) {
+ new Thread(() -> {
+ try {
+ for (int i = 0; i < checksPerThread; i++) {
+ service.checkLatency("concurrent-api", 50_000_000L, 100_000_000L);
+ }
+ } finally {
+ latch.countDown();
+ }
+ }).start();
+ }
+ latch.await();
+
+ assertEquals(threads * checksPerThread, service.getTotalChecks());
+ assertEquals(0, service.getTotalViolations());
+ }
+}
diff --git a/btrace-mcp-server/README.md b/btrace-mcp-server/README.md
new file mode 100644
index 00000000..a23d5111
--- /dev/null
+++ b/btrace-mcp-server/README.md
@@ -0,0 +1,114 @@
+# BTrace MCP Server
+
+An MCP (Model Context Protocol) server that exposes BTrace operations as tools, allowing LLM clients (Claude Desktop, Claude Code, Cursor, etc.) to diagnose running JVMs.
+
+## Tools
+
+| Tool | Description |
+|------|-------------|
+| `list_jvms` | List all attachable Java VMs on this host |
+| `deploy_oneliner` | Deploy a BTrace oneliner probe (e.g. `com.example.Service::method @return { print method, duration }`) |
+| `deploy_script` | Deploy a full BTrace Java script |
+| `list_probes` | List active probes on a JVM |
+| `send_event` | Send an event to a running probe |
+| `detach_probe` | Detach from a running probe (probe continues) |
+| `exit_probe` | Stop and remove a probe |
+
+## Prompts
+
+| Prompt | Description |
+|--------|-------------|
+| `diagnose_slow_endpoint` | Step-by-step guide to diagnose a slow HTTP endpoint |
+| `find_exception_source` | Guide to find where exceptions originate |
+| `profile_method` | Guide to profile a method's latency |
+
+## Building
+
+```bash
+./gradlew :btrace-mcp-server:build
+```
+
+## Configuration
+
+### Claude Desktop
+
+Add to `~/Library/Application Support/Claude/claude_desktop_config.json` (macOS) or `%APPDATA%\Claude\claude_desktop_config.json` (Windows):
+
+```json
+{
+ "mcpServers": {
+ "btrace": {
+ "command": "java",
+ "args": [
+ "-cp", "/path/to/btrace/lib/*",
+ "org.openjdk.btrace.mcp.BTraceMcpServer"
+ ]
+ }
+ }
+}
+```
+
+### Claude Code
+
+Add to your project's `.mcp.json`:
+
+```json
+{
+ "mcpServers": {
+ "btrace": {
+ "command": "java",
+ "args": [
+ "-cp", "/path/to/btrace/lib/*",
+ "org.openjdk.btrace.mcp.BTraceMcpServer"
+ ]
+ }
+ }
+}
+```
+
+Or configure globally in `~/.claude/settings.json`:
+
+```json
+{
+ "mcpServers": {
+ "btrace": {
+ "command": "java",
+ "args": [
+ "-cp", "/path/to/btrace/lib/*",
+ "org.openjdk.btrace.mcp.BTraceMcpServer"
+ ]
+ }
+ }
+}
+```
+
+### Using the BTrace distribution
+
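+If you have BTrace installed (e.g. via SDKMAN), you can reference the distribution directly. MCP clients typically start the command without a shell, so `$BTRACE_HOME` is not expanded: replace it with the absolute installation path, and use `;` instead of `:` as the classpath separator on Windows: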
+
+```json
+{
+ "mcpServers": {
+ "btrace": {
+ "command": "java",
+ "args": [
+ "-cp", "$BTRACE_HOME/lib/*:$BTRACE_HOME/build/btrace-mcp-server.jar",
+ "org.openjdk.btrace.mcp.BTraceMcpServer"
+ ]
+ }
+ }
+}
+```
+
+## Usage Examples
+
+Once configured, you can ask your LLM client:
+
+- "List all running Java processes"
+- "Attach to PID 12345 and trace all methods in com.example.UserService that take longer than 100ms"
+- "Find where NullPointerExceptions are being thrown in PID 12345"
+- "Profile the latency of com.example.OrderService::processOrder"
+
+## Protocol
+
+The server uses MCP over stdio (stdin/stdout JSON-RPC). All logging goes to stderr.
diff --git a/btrace-mcp-server/build.gradle b/btrace-mcp-server/build.gradle
new file mode 100644
index 00000000..adb4a4f9
--- /dev/null
+++ b/btrace-mcp-server/build.gradle
@@ -0,0 +1,28 @@
+// MCP server is compiled for Java 11+.
+// NOTE(review): the original rationale cited ConcurrentHashMap.forEach, but that method has
+// existed since Java 8 — confirm which Java 11 API this module actually depends on.
+compileJava {
+ sourceCompatibility = 11
+ targetCompatibility = 11
+}
+
+dependencies {
+ implementation libs.slf4j
+ implementation libs.slf4j.simple
+
+ // tools.jar is put on the compile classpath only when the file exists
+ // (presumably older JDK layouts only — TODO confirm against getToolsJar()).
+ def toolsJar = getToolsJar()
+ if (toolsJar.getAsFile().exists()) {
+ compileOnly files("${toolsJar}")
+ }
+
+ implementation project(':btrace-core')
+ implementation project(':btrace-client')
+ implementation project(':btrace-compiler')
+ implementation project(':btrace-boot')
+}
+
+// Make the jar runnable via `java -jar` by declaring the server entry point.
+jar {
+ manifest {
+ attributes(
+ 'Main-Class': 'org.openjdk.btrace.mcp.BTraceMcpServer'
+ )
+ }
+}
diff --git a/btrace-mcp-server/src/main/java/org/openjdk/btrace/mcp/BTraceMcpServer.java b/btrace-mcp-server/src/main/java/org/openjdk/btrace/mcp/BTraceMcpServer.java
new file mode 100644
index 00000000..7c06eb05
--- /dev/null
+++ b/btrace-mcp-server/src/main/java/org/openjdk/btrace/mcp/BTraceMcpServer.java
@@ -0,0 +1,226 @@
+/*
+ * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
+ * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
+ *
+ * This code is free software; you can redistribute it and/or modify it
+ * under the terms of the GNU General Public License version 2 only, as
+ * published by the Free Software Foundation. Oracle designates this
+ * particular file as subject to the "Classpath" exception as provided
+ * by Oracle in the LICENSE file that accompanied this code.
+ *
+ * This code is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+ * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
+ * version 2 for more details (a copy is included in the LICENSE file that
+ * accompanied this code).
+ *
+ * You should have received a copy of the GNU General Public License version
+ * 2 along with this work; if not, write to the Free Software Foundation,
+ * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
+ *
+ * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
+ * or visit www.oracle.com if you need additional information or have any
+ * questions.
+ */
+
+package org.openjdk.btrace.mcp;
+
+import java.io.IOException;
+import java.util.ArrayList;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+import org.openjdk.btrace.mcp.prompts.DiagnosticPrompts;
+import org.openjdk.btrace.mcp.tools.DeployOnelinerHandler;
+import org.openjdk.btrace.mcp.tools.DeployScriptHandler;
+import org.openjdk.btrace.mcp.tools.DetachProbeHandler;
+import org.openjdk.btrace.mcp.tools.ExitProbeHandler;
+import org.openjdk.btrace.mcp.tools.ListJvmsHandler;
+import org.openjdk.btrace.mcp.tools.ListProbesHandler;
+import org.openjdk.btrace.mcp.tools.SendEventHandler;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * BTrace MCP (Model Context Protocol) server. Exposes BTrace operations as MCP tools over stdio
+ * JSON-RPC transport, allowing LLM clients to instrument and diagnose running JVMs.
+ *
+ * <p>Messages are processed one at a time on the calling thread. Requests (messages with an
+ * {@code id}) always receive exactly one result or error; notifications (no {@code id}) are
+ * never answered.
+ */
+public final class BTraceMcpServer {
+  static {
+    // slf4j-simple reads its configuration once, when the first logger is created. This block
+    // must therefore run before the 'log' field below is initialized; setting the properties
+    // in main() is too late because class initialization happens before main() is entered.
+    // stdout carries the JSON-RPC stream, so log output is forced to stderr.
+    System.setProperty("org.slf4j.simpleLogger.logFile", "System.err");
+    if (System.getProperty("org.slf4j.simpleLogger.defaultLogLevel") == null) {
+      System.setProperty("org.slf4j.simpleLogger.defaultLogLevel", "warn");
+    }
+  }
+
+  private static final Logger log = LoggerFactory.getLogger(BTraceMcpServer.class);
+  private static final String SERVER_NAME = "btrace-mcp-server";
+  private static final String SERVER_VERSION = "0.1.0";
+  private static final String PROTOCOL_VERSION = "2024-11-05";
+
+  // JSON-RPC 2.0 error codes
+  private static final int INVALID_REQUEST = -32600;
+  private static final int METHOD_NOT_FOUND = -32601;
+  private static final int INVALID_PARAMS = -32602;
+  private static final int INTERNAL_ERROR = -32603;
+
+  private final McpProtocol protocol;
+
+  BTraceMcpServer(McpProtocol protocol) {
+    this.protocol = protocol;
+  }
+
+  /** Runs the server on stdin/stdout until stdin reaches EOF; exits with 1 on an I/O error. */
+  public static void main(String[] args) {
+    log.info("Starting {}", SERVER_NAME);
+    McpProtocol proto = new McpProtocol(System.in, System.out);
+    BTraceMcpServer server = new BTraceMcpServer(proto);
+    try {
+      server.run();
+    } catch (IOException e) {
+      log.error("Server error", e);
+      System.exit(1);
+    }
+  }
+
+  /**
+   * Reads messages until the protocol reports EOF and dispatches each one.
+   *
+   * @throws IOException if reading a message or writing a response fails
+   */
+  void run() throws IOException {
+    while (true) {
+      Map<String, Object> message = protocol.readMessage();
+      if (message == null) {
+        log.info("EOF on stdin, shutting down");
+        break;
+      }
+
+      Object id = message.get("id");
+      Object rawMethod = message.get("method");
+      if (rawMethod == null) {
+        // Response or notification without method — ignore
+        continue;
+      }
+      if (!(rawMethod instanceof String)) {
+        // Malformed input must not take the server down with a ClassCastException.
+        if (id != null) {
+          protocol.sendError(id, INVALID_REQUEST, "Invalid request: method must be a string");
+        }
+        continue;
+      }
+      String method = (String) rawMethod;
+
+      if (id == null) {
+        // Notification (e.g. notifications/initialized): JSON-RPC forbids answering it, even
+        // with an error. None of them needs server-side action here.
+        log.debug("Ignoring notification: {}", method);
+        continue;
+      }
+
+      // Params that are absent or not a JSON object are treated as missing.
+      Map<String, Object> params = asObject(message.get("params"));
+      try {
+        dispatch(id, method, params);
+      } catch (Exception e) {
+        log.error("Error handling method: {}", method, e);
+        protocol.sendError(id, INTERNAL_ERROR, "Internal error: " + e.getMessage());
+      }
+    }
+  }
+
+  /** Routes one request to its handler; unknown methods are answered with "method not found". */
+  private void dispatch(Object id, String method, Map<String, Object> params) throws IOException {
+    switch (method) {
+      case "initialize":
+        handleInitialize(id, params);
+        break;
+      case "notifications/initialized":
+        // Client acknowledgement — no response needed
+        break;
+      case "ping":
+        // Liveness check: answered with an empty result.
+        protocol.sendResult(id, new LinkedHashMap<String, Object>());
+        break;
+      case "tools/list":
+        handleToolsList(id);
+        break;
+      case "tools/call":
+        handleToolsCall(id, params);
+        break;
+      case "prompts/list":
+        handlePromptsList(id);
+        break;
+      case "prompts/get":
+        handlePromptsGet(id, params);
+        break;
+      default:
+        protocol.sendError(id, METHOD_NOT_FOUND, "Method not found: " + method);
+    }
+  }
+
+  /** Returns {@code value} as a JSON object map, or null when it is absent or another type. */
+  @SuppressWarnings("unchecked")
+  private static Map<String, Object> asObject(Object value) {
+    return value instanceof Map ? (Map<String, Object>) value : null;
+  }
+
+  /** Answers {@code initialize} with protocol version, capabilities and server info. */
+  private void handleInitialize(Object id, Map<String, Object> params) throws IOException {
+    Map<String, Object> serverInfo = new LinkedHashMap<>();
+    serverInfo.put("name", SERVER_NAME);
+    serverInfo.put("version", SERVER_VERSION);
+
+    // Empty capability objects: the features are advertised without any sub-options.
+    Map<String, Object> capabilities = new LinkedHashMap<>();
+    capabilities.put("tools", new LinkedHashMap<String, Object>());
+    capabilities.put("prompts", new LinkedHashMap<String, Object>());
+
+    Map<String, Object> result = new LinkedHashMap<>();
+    result.put("protocolVersion", PROTOCOL_VERSION);
+    result.put("capabilities", capabilities);
+    result.put("serverInfo", serverInfo);
+
+    protocol.sendResult(id, result);
+  }
+
+  /** Answers {@code tools/list} with the schema of every tool, in a fixed order. */
+  private void handleToolsList(Object id) throws IOException {
+    List<Object> tools = new ArrayList<>();
+    tools.add(ListJvmsHandler.schema());
+    tools.add(DeployOnelinerHandler.schema());
+    tools.add(DeployScriptHandler.schema());
+    tools.add(ListProbesHandler.schema());
+    tools.add(SendEventHandler.schema());
+    tools.add(DetachProbeHandler.schema());
+    tools.add(ExitProbeHandler.schema());
+
+    Map<String, Object> result = new LinkedHashMap<>();
+    result.put("tools", tools);
+    protocol.sendResult(id, result);
+  }
+
+  /**
+   * Answers {@code tools/call}: runs the named tool with its arguments and returns its result.
+   * Missing params, a missing or non-string tool name and unknown tools are reported as
+   * "invalid params" errors.
+   */
+  @SuppressWarnings({"rawtypes", "unchecked"})
+  private void handleToolsCall(Object id, Map<String, Object> params) throws IOException {
+    if (params == null) {
+      protocol.sendError(id, INVALID_PARAMS, "Missing params");
+      return;
+    }
+    Object rawName = params.get("name");
+    if (!(rawName instanceof String)) {
+      // Without this guard a missing name would hit switch(null) and surface as an
+      // unhelpful "Internal error: null".
+      protocol.sendError(id, INVALID_PARAMS, "Missing or non-string tool name");
+      return;
+    }
+    String toolName = (String) rawName;
+
+    // NOTE(review): kept raw because the handlers' parameter types are not visible from here.
+    // Absent, null or non-object arguments all become an empty map.
+    Map arguments = asObject(params.get("arguments"));
+    if (arguments == null) {
+      arguments = new LinkedHashMap<String, Object>();
+    }
+
+    Map result;
+    switch (toolName) {
+      case "list_jvms":
+        result = ListJvmsHandler.execute(arguments);
+        break;
+      case "deploy_oneliner":
+        result = DeployOnelinerHandler.execute(arguments);
+        break;
+      case "deploy_script":
+        result = DeployScriptHandler.execute(arguments);
+        break;
+      case "list_probes":
+        result = ListProbesHandler.execute(arguments);
+        break;
+      case "send_event":
+        result = SendEventHandler.execute(arguments);
+        break;
+      case "detach_probe":
+        result = DetachProbeHandler.execute(arguments);
+        break;
+      case "exit_probe":
+        result = ExitProbeHandler.execute(arguments);
+        break;
+      default:
+        protocol.sendError(id, INVALID_PARAMS, "Unknown tool: " + toolName);
+        return;
+    }
+    protocol.sendResult(id, result);
+  }
+
+  /** Answers {@code prompts/list} with the available diagnostic prompts. */
+  private void handlePromptsList(Object id) throws IOException {
+    Map<String, Object> result = new LinkedHashMap<>();
+    result.put("prompts", DiagnosticPrompts.listPrompts());
+    protocol.sendResult(id, result);
+  }
+
+  /**
+   * Answers {@code prompts/get} with the named prompt. Missing params, a missing or non-string
+   * name and unknown prompts are reported as "invalid params" errors.
+   */
+  @SuppressWarnings({"rawtypes", "unchecked"})
+  private void handlePromptsGet(Object id, Map<String, Object> params) throws IOException {
+    if (params == null) {
+      protocol.sendError(id, INVALID_PARAMS, "Missing params");
+      return;
+    }
+    Object rawName = params.get("name");
+    if (!(rawName instanceof String)) {
+      protocol.sendError(id, INVALID_PARAMS, "Missing or non-string prompt name");
+      return;
+    }
+    String name = (String) rawName;
+
+    // NOTE(review): kept raw because getPrompt's parameter type is not visible from here.
+    // Absent, null or non-object arguments all become an empty map.
+    Map promptArgs = asObject(params.get("arguments"));
+    if (promptArgs == null) {
+      promptArgs = new LinkedHashMap<String, Object>();
+    }
+
+    Map result = DiagnosticPrompts.getPrompt(name, promptArgs);
+    if (result == null) {
+      protocol.sendError(id, INVALID_PARAMS, "Unknown prompt: " + name);
+      return;
+    }
+    protocol.sendResult(id, result);
+  }
+}
diff --git a/btrace-mcp-server/src/main/java/org/openjdk/btrace/mcp/ClientManager.java b/btrace-mcp-server/src/main/java/org/openjdk/btrace/mcp/ClientManager.java
new file mode 100644
index 00000000..a85c431d
--- /dev/null
+++ b/btrace-mcp-server/src/main/java/org/openjdk/btrace/mcp/ClientManager.java
@@ -0,0 +1,86 @@
+/*
+ * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
+ * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
+ *
+ * This code is free software; you can redistribute it and/or modify it
+ * under the terms of the GNU General Public License version 2 only, as
+ * published by the Free Software Foundation. Oracle designates this
+ * particular file as subject to the "Classpath" exception as provided
+ * by Oracle in the LICENSE file that accompanied this code.
+ *
+ * This code is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+ * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
+ * version 2 for more details (a copy is included in the LICENSE file that
+ * accompanied this code).
+ *
+ * You should have received a copy of the GNU General Public License version
+ * 2 along with this work; if not, write to the Free Software Foundation,
+ * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
+ *
+ * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
+ * or visit www.oracle.com if you need additional information or have any
+ * questions.
+ */
+
+package org.openjdk.btrace.mcp;
+
+import java.util.Map;
+import java.util.concurrent.ConcurrentHashMap;
+import org.openjdk.btrace.client.Client;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Manages BTrace client instances per JVM. Tracks active sessions so that tools like send_event,
+ * detach_probe, and exit_probe can interact with already-deployed probes.
+ */
+public final class ClientManager {
+ private static final Logger log = LoggerFactory.getLogger(ClientManager.class);
+
+ /** Key is "pid:port", value is the active Client instance. */
+ private static final Map activeClients = new ConcurrentHashMap<>();
+
+ private ClientManager() {}
+
+ /** Creates a new BTrace Client for the given port. Does not reuse existing sessions. */
+ public static Client getClient(int port) {
+ return new Client(port);
+ }
+
+ /** Registers a client as active for a given PID and port. */
+ public static void registerClient(String pid, int port, Client client) {
+ String key = pid + ":" + port;
+ activeClients.put(key, client);
+ log.debug("Registered client for {}", key);
+ }
+
+ /** Returns an existing active client for the given PID and port, or null if none. */
+ public static Client getExistingClient(String pid, int port) {
+ String key = pid + ":" + port;
+ return activeClients.get(key);
+ }
+
+ /** Removes and returns an active client for the given PID and port. */
+ public static Client removeClient(String pid, int port) {
+ String key = pid + ":" + port;
+ Client removed = activeClients.remove(key);
+ if (removed != null) {
+ log.debug("Removed client for {}", key);
+ }
+ return removed;
+ }
+
+ /** Closes all active client sessions. Called on shutdown. */
+ public static void closeAll() {
+ for (Map.Entry entry : activeClients.entrySet()) {
+ try {
+ entry.getValue().close();
+ log.debug("Closed client for {}", entry.getKey());
+ } catch (Exception e) {
+ log.warn("Error closing client for {}", entry.getKey(), e);
+ }
+ }
+ activeClients.clear();
+ }
+}
diff --git a/btrace-mcp-server/src/main/java/org/openjdk/btrace/mcp/McpProtocol.java b/btrace-mcp-server/src/main/java/org/openjdk/btrace/mcp/McpProtocol.java
new file mode 100644
index 00000000..7b222b05
--- /dev/null
+++ b/btrace-mcp-server/src/main/java/org/openjdk/btrace/mcp/McpProtocol.java
@@ -0,0 +1,430 @@
+/*
+ * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
+ * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
+ *
+ * This code is free software; you can redistribute it and/or modify it
+ * under the terms of the GNU General Public License version 2 only, as
+ * published by the Free Software Foundation. Oracle designates this
+ * particular file as subject to the "Classpath" exception as provided
+ * by Oracle in the LICENSE file that accompanied this code.
+ *
+ * This code is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+ * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
+ * version 2 for more details (a copy is included in the LICENSE file that
+ * accompanied this code).
+ *
+ * You should have received a copy of the GNU General Public License version
+ * 2 along with this work; if not, write to the Free Software Foundation,
+ * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
+ *
+ * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
+ * or visit www.oracle.com if you need additional information or have any
+ * questions.
+ */
+
+package org.openjdk.btrace.mcp;
+
+import java.io.BufferedReader;
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.InputStreamReader;
+import java.io.OutputStream;
+import java.nio.charset.StandardCharsets;
+import java.util.ArrayList;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Handles MCP JSON-RPC protocol over stdio. Reads JSON-RPC messages from stdin and writes responses
+ * to stdout. All logging goes to stderr.
+ */
+final class McpProtocol {
+ private static final Logger log = LoggerFactory.getLogger(McpProtocol.class);
+ private final BufferedReader reader;
+ private final OutputStream out;
+
+ McpProtocol(InputStream in, OutputStream out) {
+ this.reader = new BufferedReader(new InputStreamReader(in, StandardCharsets.UTF_8));
+ this.out = out;
+ }
+
+ /**
+  * Reads the next JSON-RPC message from the input, one message per line.
+  *
+  * <p>Blank lines are skipped rather than reported as {@code null}, so that {@code null}
+  * unambiguously means end of stream.
+  *
+  * @return the parsed message, or {@code null} on EOF
+  * @throws IOException if reading fails
+  * @throws IllegalArgumentException if a non-blank line is not a JSON object
+  */
+ Map<String, Object> readMessage() throws IOException {
+   String line;
+   while ((line = reader.readLine()) != null) {
+     String trimmed = line.trim();
+     if (trimmed.isEmpty()) {
+       continue; // separator or keep-alive line, not end of stream
+     }
+     log.debug("Received: {}", trimmed);
+     return parseJson(trimmed);
+   }
+   return null;
+ }
+
+ /**
+  * Serializes {@code message} to JSON and writes it to the output as one UTF-8 line terminated
+  * by a newline, then flushes. Synchronized so that whole messages are written one at a time.
+  */
+ synchronized void writeMessage(Map message) throws IOException {
+   String payload = toJson(message);
+   log.debug("Sending: {}", payload);
+   byte[] encoded = payload.getBytes(StandardCharsets.UTF_8);
+   out.write(encoded);
+   out.write('\n');
+   out.flush();
+ }
+
+ /** Writes a JSON-RPC 2.0 success response carrying {@code result} for request {@code id}. */
+ void sendResult(Object id, Object result) throws IOException {
+   Map<String, Object> envelope = new LinkedHashMap<>();
+   envelope.put("jsonrpc", "2.0");
+   envelope.put("id", id);
+   envelope.put("result", result);
+   writeMessage(envelope);
+ }
+
+ /** Writes a JSON-RPC error response that carries no {@code data} member. */
+ void sendError(Object id, int code, String message) throws IOException {
+   sendError(id, code, message, null);
+ }
+
+ /**
+  * Writes a JSON-RPC error response. The {@code data} member is included only when
+  * {@code data} is non-null.
+  */
+ void sendError(Object id, int code, String message, Object data) throws IOException {
+   Map<String, Object> envelope = new LinkedHashMap<>();
+   envelope.put("jsonrpc", "2.0");
+   envelope.put("id", id);
+
+   Map<String, Object> detail = new LinkedHashMap<>();
+   detail.put("code", code);
+   detail.put("message", message);
+   if (data != null) {
+     detail.put("data", data);
+   }
+   envelope.put("error", detail);
+
+   writeMessage(envelope);
+ }
+
+ // --- Minimal JSON parser (handles objects, arrays, strings, numbers, booleans, null) ---
+
+ /**
+  * Parses {@code json} and returns it as a map.
+  *
+  * @throws IllegalArgumentException if the text is malformed or its top-level value is not an
+  *     object
+  */
+ @SuppressWarnings("unchecked")
+ static Map parseJson(String json) {
+   Object parsed = new JsonParser(json.trim()).parseValue();
+   if (!(parsed instanceof Map)) {
+     throw new IllegalArgumentException("Expected JSON object, got: " + json);
+   }
+   return (Map) parsed;
+ }
+
+ /**
+  * Minimal recursive-descent JSON reader over an in-memory string.
+  *
+  * <p>Objects become {@link LinkedHashMap}s, arrays become {@link ArrayList}s, integral numbers
+  * become {@code Integer} when they fit and {@code Long} otherwise, and numbers with a fraction
+  * or exponent become {@code Double}. Malformed input is reported with
+  * {@link IllegalArgumentException} (of which {@link NumberFormatException} is a subclass).
+  */
+ private static final class JsonParser {
+   private final String src;
+   private int pos;
+
+   JsonParser(String src) {
+     this.src = src;
+     this.pos = 0;
+   }
+
+   /** Parses the value at the cursor, dispatching on its first non-whitespace character. */
+   Object parseValue() {
+     skipWhitespace();
+     if (pos >= src.length()) {
+       throw new IllegalArgumentException("Unexpected end of JSON");
+     }
+     switch (src.charAt(pos)) {
+       case '{':
+         return parseObject();
+       case '[':
+         return parseArray();
+       case '"':
+         return parseString();
+       case 't':
+       case 'f':
+         return parseBoolean();
+       case 'n':
+         return parseNull();
+       default:
+         return parseNumber();
+     }
+   }
+
+   /** Parses an object; member order is preserved. */
+   Map<String, Object> parseObject() {
+     expect('{');
+     Map<String, Object> map = new LinkedHashMap<>();
+     skipWhitespace();
+     if (consumeIf('}')) {
+       return map;
+     }
+     do {
+       // parseString() and expect() skip leading whitespace themselves.
+       String key = parseString();
+       expect(':');
+       map.put(key, parseValue());
+       skipWhitespace();
+     } while (consumeIf(','));
+     expect('}');
+     return map;
+   }
+
+   /** Parses an array. */
+   List<Object> parseArray() {
+     expect('[');
+     List<Object> list = new ArrayList<>();
+     skipWhitespace();
+     if (consumeIf(']')) {
+       return list;
+     }
+     do {
+       list.add(parseValue());
+       skipWhitespace();
+     } while (consumeIf(','));
+     expect(']');
+     return list;
+   }
+
+   /**
+    * Parses a double-quoted string and decodes its escapes. Unknown single-character escapes
+    * are passed through as the escaped character.
+    */
+   String parseString() {
+     expect('"');
+     StringBuilder sb = new StringBuilder();
+     while (pos < src.length()) {
+       char c = src.charAt(pos++);
+       if (c == '"') {
+         return sb.toString();
+       }
+       if (c != '\\') {
+         sb.append(c);
+         continue;
+       }
+       if (pos >= src.length()) {
+         break; // dangling backslash at end of input
+       }
+       char esc = src.charAt(pos++);
+       switch (esc) {
+         case 'b':
+           sb.append('\b');
+           break;
+         case 'f':
+           sb.append('\f');
+           break;
+         case 'n':
+           sb.append('\n');
+           break;
+         case 'r':
+           sb.append('\r');
+           break;
+         case 't':
+           sb.append('\t');
+           break;
+         case 'u':
+           sb.append(parseUnicodeEscape());
+           break;
+         default:
+           // Covers the quote, backslash and slash escapes; anything else is kept leniently.
+           sb.append(esc);
+       }
+     }
+     throw new IllegalArgumentException("Unterminated string");
+   }
+
+   /**
+    * Decodes the four hex digits following a backslash-u escape. A truncated or non-hex escape
+    * is rejected instead of being silently dropped, which would otherwise corrupt the string.
+    */
+   private char parseUnicodeEscape() {
+     if (pos + 4 > src.length()) {
+       throw new IllegalArgumentException("Truncated unicode escape at " + pos);
+     }
+     int code = 0;
+     for (int i = 0; i < 4; i++) {
+       int digit = Character.digit(src.charAt(pos + i), 16);
+       if (digit < 0) {
+         throw new IllegalArgumentException("Invalid unicode escape at " + pos);
+       }
+       code = (code << 4) | digit;
+     }
+     pos += 4;
+     return (char) code;
+   }
+
+   /**
+    * Parses a number. Only ASCII digits are accepted, and a value without any digit is reported
+    * with its position rather than as an opaque number-format failure.
+    */
+   Object parseNumber() {
+     int start = pos;
+     consumeIf('-');
+     int digits = skipDigits();
+     boolean isFloat = false;
+     if (consumeIf('.')) {
+       isFloat = true;
+       digits += skipDigits();
+     }
+     if (pos < src.length() && (src.charAt(pos) == 'e' || src.charAt(pos) == 'E')) {
+       isFloat = true;
+       pos++;
+       if (!consumeIf('+')) {
+         consumeIf('-');
+       }
+       skipDigits();
+     }
+     if (digits == 0) {
+       throw new IllegalArgumentException("Invalid JSON value at " + start);
+     }
+     String numStr = src.substring(start, pos);
+     if (isFloat) {
+       return Double.parseDouble(numStr);
+     }
+     long val = Long.parseLong(numStr);
+     // Deliberately not a ternary: mixing int and long there would widen the int to long.
+     if (val >= Integer.MIN_VALUE && val <= Integer.MAX_VALUE) {
+       return (int) val;
+     }
+     return val;
+   }
+
+   /** Parses the literal {@code true} or {@code false}. */
+   Object parseBoolean() {
+     if (src.startsWith("true", pos)) {
+       pos += 4;
+       return Boolean.TRUE;
+     }
+     if (src.startsWith("false", pos)) {
+       pos += 5;
+       return Boolean.FALSE;
+     }
+     throw new IllegalArgumentException("Invalid boolean at " + pos);
+   }
+
+   /** Parses the literal {@code null}. */
+   Object parseNull() {
+     if (src.startsWith("null", pos)) {
+       pos += 4;
+       return null;
+     }
+     throw new IllegalArgumentException("Invalid null at " + pos);
+   }
+
+   /** Advances the cursor past any whitespace. */
+   void skipWhitespace() {
+     while (pos < src.length() && Character.isWhitespace(src.charAt(pos))) {
+       pos++;
+     }
+   }
+
+   /** Skips whitespace, then consumes {@code c} or fails with the position and what was found. */
+   void expect(char c) {
+     skipWhitespace();
+     if (pos >= src.length() || src.charAt(pos) != c) {
+       String found = pos < src.length() ? String.valueOf(src.charAt(pos)) : "EOF";
+       throw new IllegalArgumentException(
+           "Expected '" + c + "' at " + pos + " but got: " + found);
+     }
+     pos++;
+   }
+
+   /** Consumes {@code c} if it is the next character; does not skip whitespace. */
+   private boolean consumeIf(char c) {
+     if (pos < src.length() && src.charAt(pos) == c) {
+       pos++;
+       return true;
+     }
+     return false;
+   }
+
+   /** Consumes a run of ASCII digits and returns how many were consumed. */
+   private int skipDigits() {
+     int begin = pos;
+     while (pos < src.length() && src.charAt(pos) >= '0' && src.charAt(pos) <= '9') {
+       pos++;
+     }
+     return pos - begin;
+   }
+ }
+
+ // --- Minimal JSON serializer ---
+
+ /**
+  * Serializes a value to compact JSON text.
+  *
+  * <p>Supported inputs are {@code null}, strings, numbers, booleans, maps, lists and object
+  * arrays; any other object is written as the JSON string of its {@code toString()}.
+  * Non-finite floating-point values (NaN, infinities) have no JSON representation and are
+  * written as {@code null}. Map keys are converted with {@link String#valueOf(Object)} so a
+  * non-string key cannot cause a {@link ClassCastException}.
+  *
+  * @param obj value to serialize, may be {@code null}
+  * @return the JSON text
+  */
+ static String toJson(Object obj) {
+   if (obj == null) {
+     return "null";
+   }
+   if (obj instanceof String) {
+     return escapeJsonString((String) obj);
+   }
+   if (obj instanceof Double || obj instanceof Float) {
+     double d = ((Number) obj).doubleValue();
+     // "NaN" and "Infinity" are not valid JSON tokens.
+     return Double.isNaN(d) || Double.isInfinite(d) ? "null" : obj.toString();
+   }
+   if (obj instanceof Number || obj instanceof Boolean) {
+     return obj.toString();
+   }
+   if (obj instanceof Map) {
+     StringBuilder sb = new StringBuilder("{");
+     boolean first = true;
+     for (Map.Entry<?, ?> entry : ((Map<?, ?>) obj).entrySet()) {
+       if (!first) {
+         sb.append(",");
+       }
+       sb.append(escapeJsonString(String.valueOf(entry.getKey())));
+       sb.append(":");
+       sb.append(toJson(entry.getValue()));
+       first = false;
+     }
+     sb.append("}");
+     return sb.toString();
+   }
+   if (obj instanceof List) {
+     List<?> list = (List<?>) obj;
+     StringBuilder sb = new StringBuilder("[");
+     for (int i = 0; i < list.size(); i++) {
+       if (i > 0) {
+         sb.append(",");
+       }
+       sb.append(toJson(list.get(i)));
+     }
+     sb.append("]");
+     return sb.toString();
+   }
+   if (obj instanceof Object[]) {
+     Object[] arr = (Object[]) obj;
+     StringBuilder sb = new StringBuilder("[");
+     for (int i = 0; i < arr.length; i++) {
+       if (i > 0) {
+         sb.append(",");
+       }
+       sb.append(toJson(arr[i]));
+     }
+     sb.append("]");
+     return sb.toString();
+   }
+   return escapeJsonString(obj.toString());
+ }
+
+ /**
+  * Returns {@code s} as a double-quoted JSON string literal. Quote, backslash and the common
+  * control characters use their short escapes; any other character below 0x20 is written as a
+  * four-digit unicode escape.
+  */
+ private static String escapeJsonString(String s) {
+   StringBuilder quoted = new StringBuilder("\"");
+   for (char ch : s.toCharArray()) {
+     if (ch == '"') {
+       quoted.append("\\\"");
+     } else if (ch == '\\') {
+       quoted.append("\\\\");
+     } else if (ch == '\b') {
+       quoted.append("\\b");
+     } else if (ch == '\f') {
+       quoted.append("\\f");
+     } else if (ch == '\n') {
+       quoted.append("\\n");
+     } else if (ch == '\r') {
+       quoted.append("\\r");
+     } else if (ch == '\t') {
+       quoted.append("\\t");
+     } else if (ch < 0x20) {
+       quoted.append(String.format("\\u%04x", (int) ch));
+     } else {
+       quoted.append(ch);
+     }
+   }
+   return quoted.append("\"").toString();
+ }
+
+ /**
+  * Builds a tool result whose {@code content} is a single text item holding {@code text} and
+  * whose {@code isError} flag is the given value.
+  */
+ static Map toolResult(String text, boolean isError) {
+   Map<String, Object> textItem = new LinkedHashMap<>();
+   textItem.put("type", "text");
+   textItem.put("text", text);
+
+   List<Object> items = new ArrayList<>();
+   items.add(textItem);
+
+   Map<String, Object> envelope = new LinkedHashMap<>();
+   envelope.put("content", items);
+   envelope.put("isError", isError);
+   return envelope;
+ }
+}
diff --git a/btrace-mcp-server/src/main/java/org/openjdk/btrace/mcp/prompts/DiagnosticPrompts.java b/btrace-mcp-server/src/main/java/org/openjdk/btrace/mcp/prompts/DiagnosticPrompts.java
new file mode 100644
index 00000000..84c0c7fe
--- /dev/null
+++ b/btrace-mcp-server/src/main/java/org/openjdk/btrace/mcp/prompts/DiagnosticPrompts.java
@@ -0,0 +1,320 @@
+/*
+ * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
+ * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
+ *
+ * This code is free software; you can redistribute it and/or modify it
+ * under the terms of the GNU General Public License version 2 only, as
+ * published by the Free Software Foundation. Oracle designates this
+ * particular file as subject to the "Classpath" exception as provided
+ * by Oracle in the LICENSE file that accompanied this code.
+ *
+ * This code is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+ * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
+ * version 2 for more details (a copy is included in the LICENSE file that
+ * accompanied this code).
+ *
+ * You should have received a copy of the GNU General Public License version
+ * 2 along with this work; if not, write to the Free Software Foundation,
+ * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
+ *
+ * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
+ * or visit www.oracle.com if you need additional information or have any
+ * questions.
+ */
+
+package org.openjdk.btrace.mcp.prompts;
+
+import java.util.ArrayList;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+
+/** Provides MCP prompt templates for common BTrace diagnostic scenarios. */
+public final class DiagnosticPrompts {
+
+ private DiagnosticPrompts() {}
+
+ /**
+  * Returns all available prompt schemas for prompts/list.
+  *
+  * @return a new mutable list with one schema map per prompt, in a fixed order
+  */
+ public static List<Map<String, Object>> listPrompts() {
+   List<Map<String, Object>> prompts = new ArrayList<>();
+   prompts.add(diagnoseSlowEndpointSchema());
+   prompts.add(findExceptionSourceSchema());
+   prompts.add(profileMethodSchema());
+   return prompts;
+ }
+
+ /**
+  * Returns a prompt by name, or null if not found.
+  *
+  * @param name prompt name; {@code null} is treated as "not found" instead of failing in the
+  *     string switch
+  * @param arguments values substituted into the prompt text, may be {@code null}
+  * @return the rendered prompt, or {@code null} when no prompt has that name
+  */
+ public static Map<String, Object> getPrompt(String name, Map<String, Object> arguments) {
+   if (name == null) {
+     return null; // switching on a null String would throw NullPointerException
+   }
+   switch (name) {
+     case "diagnose_slow_endpoint":
+       return diagnoseSlowEndpoint(arguments);
+     case "find_exception_source":
+       return findExceptionSource(arguments);
+     case "profile_method":
+       return profileMethod(arguments);
+     default:
+       return null;
+   }
+ }
+
+ // --- diagnose_slow_endpoint ---
+
+ /** Builds the prompts/list schema (name, description, arguments) for diagnose_slow_endpoint. */
+ private static Map<String, Object> diagnoseSlowEndpointSchema() {
+   Map<String, Object> prompt = new LinkedHashMap<>();
+   prompt.put("name", "diagnose_slow_endpoint");
+   prompt.put(
+       "description",
+       "Step-by-step guide to diagnose a slow HTTP endpoint using BTrace. "
+           + "Helps identify where time is spent in request processing.");
+
+   List<Map<String, Object>> args = new ArrayList<>();
+   args.add(
+       promptArg("endpoint_class", "Fully qualified class name of the endpoint/controller", true));
+   args.add(promptArg("endpoint_method", "Method name to diagnose", true));
+   args.add(
+       promptArg(
+           "pid", "PID of the target JVM (optional, will use list_jvms if not provided)", false));
+   prompt.put("arguments", args);
+   return prompt;
+ }
+
+ /**
+  * Renders the diagnose_slow_endpoint prompt as a five-step Markdown guide.
+  *
+  * <p>Reads the arguments "endpoint_class" (default "com.example.Controller"), "endpoint_method"
+  * (default "handleRequest") and "pid" (no default) and wraps the text with promptResult.
+  */
+ private static Map diagnoseSlowEndpoint(Map arguments) {
+ String endpointClass = getArg(arguments, "endpoint_class", "com.example.Controller");
+ String endpointMethod = getArg(arguments, "endpoint_method", "handleRequest");
+ String pid = getArg(arguments, "pid", null);
+
+ StringBuilder text = new StringBuilder();
+ // Title: class and method joined with "::".
+ text.append("# Diagnosing Slow Endpoint: ")
+ .append(endpointClass)
+ .append("::")
+ .append(endpointMethod)
+ .append("\n\n");
+
+ // Step 1 shows the supplied PID, or points at the list_jvms tool when none was given.
+ text.append("## Step 1: Find the target JVM\n");
+ if (pid != null) {
+ text.append("Target PID: ").append(pid).append("\n");
+ } else {
+ text.append("Use the `list_jvms` tool to find the target JVM PID.\n");
+ }
+ text.append("\n");
+
+ // Step 2: a oneliner for the single endpoint method, inside a fenced block.
+ text.append("## Step 2: Measure endpoint latency\n");
+ text.append("Deploy a oneliner to measure the method's execution time:\n");
+ text.append("```\n");
+ text.append(endpointClass)
+ .append("::")
+ .append(endpointMethod)
+ .append(" @return { print method, duration }\n");
+ text.append("```\n\n");
+
+ // Step 3: a oneliner using the method pattern /.*/ on the same class, filtered to >10ms.
+ // NOTE(review): the pattern appears to cover only methods of endpointClass itself, while the
+ // text speaks of "the internal methods it calls" - confirm the wording matches the behavior.
+ text.append("## Step 3: Trace internal method calls\n");
+ text.append(
+ "If the endpoint is slow, trace the internal methods it calls to find the bottleneck:\n");
+ text.append("```\n");
+ text.append(endpointClass)
+ .append("::/.*/")
+ .append(" @return if duration>10ms { print method, duration }\n");
+ text.append("```\n\n");
+
+ // Step 4: static checklist of common causes.
+ text.append("## Step 4: Analyze results\n");
+ text.append(
+ "Look for methods with unexpectedly high durations. Common causes include:\n");
+ text.append("- Database queries taking too long\n");
+ text.append("- External service calls with high latency\n");
+ text.append("- Lock contention\n");
+ text.append("- Excessive object allocation\n\n");
+
+ // Step 5: reminder to remove the probe.
+ text.append("## Step 5: Clean up\n");
+ text.append("Use `exit_probe` to remove the instrumentation when done.\n");
+
+ return promptResult(text.toString());
+ }
+
+ // --- find_exception_source ---
+
+ /** Builds the prompts/list schema (name, description, arguments) for find_exception_source. */
+ private static Map<String, Object> findExceptionSourceSchema() {
+   Map<String, Object> prompt = new LinkedHashMap<>();
+   prompt.put("name", "find_exception_source");
+   prompt.put(
+       "description",
+       "Guide to find where specific exceptions originate in a running JVM. "
+           + "Uses BTrace to intercept exception constructors and capture stack traces.");
+
+   List<Map<String, Object>> args = new ArrayList<>();
+   args.add(
+       promptArg("exception_class", "Exception class name (e.g. NullPointerException)", true));
+   args.add(promptArg("pid", "PID of the target JVM (optional)", false));
+   prompt.put("arguments", args);
+   return prompt;
+ }
+
+ /**
+  * Renders the find_exception_source prompt as a four-step Markdown guide containing a BTrace
+  * script that prints a stack trace whenever the chosen exception class is constructed.
+  *
+  * <p>Reads the arguments "exception_class" (default "java.lang.NullPointerException") and "pid"
+  * (no default). The generated probe targets the method named {@code <init>}, i.e. the
+  * constructors of the exception class, as the prompt description promises; an empty method
+  * name would not select them.
+  */
+ private static Map<String, Object> findExceptionSource(Map<String, Object> arguments) {
+   String exceptionClass = getArg(arguments, "exception_class", "java.lang.NullPointerException");
+   String pid = getArg(arguments, "pid", null);
+
+   StringBuilder text = new StringBuilder();
+   text.append("# Finding Exception Source: ").append(exceptionClass).append("\n\n");
+
+   // Step 1 shows the supplied PID, or points at the list_jvms tool when none was given.
+   text.append("## Step 1: Find the target JVM\n");
+   if (pid != null) {
+     text.append("Target PID: ").append(pid).append("\n");
+   } else {
+     text.append("Use the `list_jvms` tool to find the target JVM PID.\n");
+   }
+   text.append("\n");
+
+   text.append("## Step 2: Deploy an exception tracing script\n");
+   text.append("Deploy this BTrace script to capture exception creation with stack traces:\n\n");
+   text.append("```java\n");
+   text.append("import org.openjdk.btrace.core.annotations.*;\n");
+   text.append("import static org.openjdk.btrace.core.BTraceUtils.*;\n\n");
+   text.append("@BTrace\n");
+   text.append("public class ExceptionTracer {\n");
+   text.append(" @OnMethod(\n");
+   text.append(" clazz = \"").append(exceptionClass).append("\",\n");
+   // Constructors are addressed by the JVM-internal method name <init>.
+   text.append(" method = \"<init>\"\n");
+   text.append(" )\n");
+   text.append(" public static void onException(@Self Throwable self) {\n");
+   text.append(" println(\"--- Exception created: \" + Strings.str(self) + \" ---\");\n");
+   text.append(" Threads.jstack();\n");
+   text.append(" println(\"\");\n");
+   text.append(" }\n");
+   text.append("}\n");
+   text.append("```\n\n");
+
+   text.append("## Step 3: Analyze the stack traces\n");
+   text.append(
+       "Each time the exception is created, you will see the full stack trace "
+           + "showing exactly which code path creates it.\n\n");
+
+   text.append("## Step 4: Clean up\n");
+   text.append("Use `exit_probe` to remove the instrumentation when done.\n");
+
+   return promptResult(text.toString());
+ }
+
+ // --- profile_method ---
+
+ /** Builds the prompts/list schema (name, description, arguments) for profile_method. */
+ private static Map<String, Object> profileMethodSchema() {
+   Map<String, Object> prompt = new LinkedHashMap<>();
+   prompt.put("name", "profile_method");
+   prompt.put(
+       "description",
+       "Guide to profile a specific method's latency distribution using BTrace. "
+           + "Captures timing data to understand performance characteristics.");
+
+   List<Map<String, Object>> args = new ArrayList<>();
+   args.add(promptArg("class_name", "Fully qualified class name", true));
+   args.add(promptArg("method_name", "Method name to profile", true));
+   args.add(promptArg("pid", "PID of the target JVM (optional)", false));
+   prompt.put("arguments", args);
+   return prompt;
+ }
+
+ /**
+  * Renders the profile_method prompt as a five-step Markdown guide.
+  *
+  * <p>Reads the arguments "class_name" (default "com.example.Service"), "method_name" (default
+  * "process") and "pid" (no default). The guide contains a oneliner and a full BTrace script,
+  * both parameterized with the class and method names, and is wrapped with promptResult.
+  */
+ private static Map profileMethod(Map arguments) {
+ String className = getArg(arguments, "class_name", "com.example.Service");
+ String methodName = getArg(arguments, "method_name", "process");
+ String pid = getArg(arguments, "pid", null);
+
+ StringBuilder text = new StringBuilder();
+ // Title: class and method joined with "::".
+ text.append("# Profiling Method: ")
+ .append(className)
+ .append("::")
+ .append(methodName)
+ .append("\n\n");
+
+ // Step 1 shows the supplied PID, or points at the list_jvms tool when none was given.
+ text.append("## Step 1: Find the target JVM\n");
+ if (pid != null) {
+ text.append("Target PID: ").append(pid).append("\n");
+ } else {
+ text.append("Use the `list_jvms` tool to find the target JVM PID.\n");
+ }
+ text.append("\n");
+
+ // Step 2: a oneliner printing method and duration on return, inside a fenced block.
+ text.append("## Step 2: Quick latency check with oneliner\n");
+ text.append("Start with a simple oneliner to see individual call durations:\n");
+ text.append("```\n");
+ text.append(className)
+ .append("::")
+ .append(methodName)
+ .append(" @return { print method, duration }\n");
+ text.append("```\n\n");
+
+ // Step 3: a full script that counts calls and sums durations.
+ // NOTE(review): the heading says "histogram", but the script below only keeps a call count and
+ // a total time and prints an average - confirm whether the heading or the script should change.
+ text.append("## Step 3: Detailed profiling with histogram\n");
+ text.append("For a latency distribution, deploy this BTrace script:\n\n");
+ text.append("```java\n");
+ text.append("import org.openjdk.btrace.core.annotations.*;\n");
+ text.append("import org.openjdk.btrace.core.BTraceUtils;\n");
+ text.append("import static org.openjdk.btrace.core.BTraceUtils.*;\n\n");
+ text.append("@BTrace\n");
+ text.append("public class MethodProfiler {\n");
+ text.append(" private static long count;\n");
+ text.append(" private static long totalTime;\n\n");
+ text.append(" @OnMethod(\n");
+ text.append(" clazz = \"").append(className).append("\",\n");
+ text.append(" method = \"").append(methodName).append("\",\n");
+ text.append(" location = @Location(Kind.RETURN)\n");
+ text.append(" )\n");
+ text.append(" public static void onReturn(@Duration long duration) {\n");
+ text.append(" count++;\n");
+ text.append(" totalTime += duration;\n");
+ // The generated script divides the duration by 1000000 before labelling it "ms".
+ text.append(
+ " println(\"Call #\" + count + \": \" + (duration / 1000000) + \"ms\");\n");
+ text.append(" }\n\n");
+ // The summary handler is what Step 4's send_event triggers.
+ text.append(" @OnEvent\n");
+ text.append(" public static void onEvent() {\n");
+ text.append(" println(\"=== Summary ===\");\n");
+ text.append(" println(\"Total calls: \" + count);\n");
+ text.append(" if (count > 0) {\n");
+ text.append(
+ " println(\"Avg duration: \" + ((totalTime / count) / 1000000) + \"ms\");\n");
+ text.append(" }\n");
+ text.append(" }\n");
+ text.append("}\n");
+ text.append("```\n\n");
+
+ text.append("## Step 4: Get summary\n");
+ text.append("Use `send_event` to trigger the @OnEvent handler and get a summary.\n\n");
+
+ // Step 5: reminder to remove the probe.
+ text.append("## Step 5: Clean up\n");
+ text.append("Use `exit_probe` to remove the instrumentation when done.\n");
+
+ return promptResult(text.toString());
+ }
+
+ // --- Helpers ---
+
+ /** Builds one prompt-argument descriptor with the keys name, description and required. */
+ private static Map promptArg(
+     String name, String description, boolean required) {
+   Map<String, Object> descriptor = new LinkedHashMap<>();
+   descriptor.put("name", name);
+   descriptor.put("description", description);
+   descriptor.put("required", required);
+   return descriptor;
+ }
+
+ /**
+  * Returns the string form of {@code args.get(key)}, or {@code defaultVal} when {@code args} is
+  * null or holds no non-null value for the key.
+  */
+ private static String getArg(Map args, String key, String defaultVal) {
+   Object found = args != null ? args.get(key) : null;
+   if (found != null) {
+     return found.toString();
+   }
+   return defaultVal;
+ }
+
+ /**
+  * Wraps {@code text} as a prompt result: a {@code messages} list holding one message with role
+  * "user" and a text content item.
+  */
+ private static Map promptResult(String text) {
+   Map<String, Object> body = new LinkedHashMap<>();
+   body.put("type", "text");
+   body.put("text", text);
+
+   Map<String, Object> userMessage = new LinkedHashMap<>();
+   userMessage.put("role", "user");
+   userMessage.put("content", body);
+
+   List<Object> conversation = new ArrayList<>();
+   conversation.add(userMessage);
+
+   Map<String, Object> envelope = new LinkedHashMap<>();
+   envelope.put("messages", conversation);
+   return envelope;
+ }
+}
diff --git a/btrace-mcp-server/src/main/java/org/openjdk/btrace/mcp/tools/DeployOnelinerHandler.java b/btrace-mcp-server/src/main/java/org/openjdk/btrace/mcp/tools/DeployOnelinerHandler.java
new file mode 100644
index 00000000..2556554a
--- /dev/null
+++ b/btrace-mcp-server/src/main/java/org/openjdk/btrace/mcp/tools/DeployOnelinerHandler.java
@@ -0,0 +1,230 @@
+/*
+ * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
+ * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
+ *
+ * This code is free software; you can redistribute it and/or modify it
+ * under the terms of the GNU General Public License version 2 only, as
+ * published by the Free Software Foundation. Oracle designates this
+ * particular file as subject to the "Classpath" exception as provided
+ * by Oracle in the LICENSE file that accompanied this code.
+ *
+ * This code is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+ * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
+ * version 2 for more details (a copy is included in the LICENSE file that
+ * accompanied this code).
+ *
+ * You should have received a copy of the GNU General Public License version
+ * 2 along with this work; if not, write to the Free Software Foundation,
+ * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
+ *
+ * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
+ * or visit www.oracle.com if you need additional information or have any
+ * questions.
+ */
+
+package org.openjdk.btrace.mcp.tools;
+
+import java.io.PrintWriter;
+import java.io.StringWriter;
+import java.util.ArrayList;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicBoolean;
+import org.openjdk.btrace.client.Client;
+import org.openjdk.btrace.compiler.oneliner.OnelinerAST.OnelinerNode;
+import org.openjdk.btrace.compiler.oneliner.OnelinerCodeGenerator;
+import org.openjdk.btrace.compiler.oneliner.OnelinerParser;
+import org.openjdk.btrace.compiler.oneliner.OnelinerValidator;
+import org.openjdk.btrace.core.comm.Command;
+import org.openjdk.btrace.core.comm.PrintableCommand;
+import org.openjdk.btrace.mcp.ClientManager;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/** Handles the deploy_oneliner MCP tool - deploys a BTrace oneliner probe. */
+public final class DeployOnelinerHandler {
+ private static final Logger log = LoggerFactory.getLogger(DeployOnelinerHandler.class);
+ private static final int DEFAULT_PORT = 2020;
+ private static final int PROBE_TIMEOUT_SECONDS = 30;
+
+ private DeployOnelinerHandler() {}
+
+ /**
+  * Describes the deploy_oneliner tool for MCP tools/list: its name, description and an input
+  * schema with the properties pid, oneliner and port, of which pid and oneliner are required.
+  */
+ public static Map schema() {
+   Map<String, Object> properties = new LinkedHashMap<>();
+   properties.put(
+       "pid", property("string", "PID of the target JVM (use list_jvms to find it)"));
+   properties.put(
+       "oneliner",
+       property(
+           "string",
+           "BTrace oneliner expression, e.g. "
+               + "\"com.example.Service::method @return { print method, duration }\""));
+   properties.put("port", property("integer", "BTrace agent port (default: 2020)"));
+
+   List<Object> required = new ArrayList<>();
+   required.add("pid");
+   required.add("oneliner");
+
+   Map<String, Object> inputSchema = new LinkedHashMap<>();
+   inputSchema.put("type", "object");
+   inputSchema.put("properties", properties);
+   inputSchema.put("required", required);
+
+   Map<String, Object> tool = new LinkedHashMap<>();
+   tool.put("name", "deploy_oneliner");
+   tool.put(
+       "description",
+       "Deploy a BTrace oneliner probe to a running JVM. "
+           + "Oneliners are concise probe expressions like: "
+           + "\"com.example.Service::method @return { print duration }\" "
+           + "or \"com.example.Dao::query @return if duration>100ms { print method, duration }\". "
+           + "The probe attaches to the target JVM and captures output for the specified duration.");
+   tool.put("inputSchema", inputSchema);
+   return tool;
+ }
+
+ /** Builds one input-schema property entry with the keys type and description. */
+ private static Map<String, Object> property(String type, String description) {
+   Map<String, Object> entry = new LinkedHashMap<>();
+   entry.put("type", type);
+   entry.put("description", description);
+   return entry;
+ }
+
+ /**
+  * Executes the deploy_oneliner tool: parses and validates the oneliner, generates and compiles
+  * a BTrace script from it, attaches to the target JVM, submits the script and waits up to
+  * {@code PROBE_TIMEOUT_SECONDS} for the first STATUS or EXIT command.
+  *
+  * <p>Failures are reported in the returned tool result (with {@code isError} set) rather than
+  * thrown. If the waiting thread is interrupted, its interrupt flag is restored before the
+  * error result is returned.
+  *
+  * @param arguments tool arguments: "pid" and "oneliner" are required, "port" is optional and
+  *     defaults to {@code DEFAULT_PORT}
+  * @return an MCP tool result with a single text content item
+  */
+ public static Map<String, Object> execute(Map<String, Object> arguments) {
+   String pid = getStringArg(arguments, "pid");
+   String oneliner = getStringArg(arguments, "oneliner");
+   int port = getIntArg(arguments, "port", DEFAULT_PORT);
+
+   if (pid == null || pid.isEmpty()) {
+     return toolResult("Error: 'pid' parameter is required", true);
+   }
+   if (oneliner == null || oneliner.isEmpty()) {
+     return toolResult("Error: 'oneliner' parameter is required", true);
+   }
+
+   try {
+     // Parse and validate the oneliner
+     OnelinerNode ast = OnelinerParser.parse(oneliner);
+     OnelinerValidator.validate(ast, oneliner);
+     String className = "BTraceOneliner_" + System.currentTimeMillis();
+     String javaSource = OnelinerCodeGenerator.generate(ast, className);
+     String fileName = className + ".java";
+
+     log.info("Generated oneliner source for {}: {}", oneliner, javaSource);
+
+     // Compile the oneliner
+     Client client = ClientManager.getClient(port);
+     StringWriter errorWriter = new StringWriter();
+     PrintWriter errPw = new PrintWriter(errorWriter);
+     byte[] code = client.compileSource(fileName, javaSource, ".", errPw, null);
+
+     if (code == null) {
+       errPw.flush(); // make sure buffered diagnostics reach errorWriter before reading it
+       String errors = errorWriter.toString();
+       return toolResult(
+           "Oneliner compilation failed:\n" + errors + "\nGenerated source:\n" + javaSource, true);
+     }
+
+     // Attach and submit
+     client.attach(pid, null, ".");
+
+     // StringBuffer rather than StringBuilder: the listener appends to it and this thread reads
+     // it, and the two may not be the same thread.
+     StringBuffer output = new StringBuffer();
+     CountDownLatch statusLatch = new CountDownLatch(1);
+     AtomicBoolean success = new AtomicBoolean(false);
+     AtomicBoolean exited = new AtomicBoolean(false);
+
+     client.submit(
+         "localhost",
+         fileName,
+         code,
+         new String[0],
+         cmd -> {
+           int type = cmd.getType();
+           if (cmd instanceof PrintableCommand) {
+             StringWriter sw = new StringWriter();
+             PrintWriter pw = new PrintWriter(sw);
+             ((PrintableCommand) cmd).print(pw);
+             pw.flush();
+             output.append(sw.toString());
+           }
+           if (type == Command.STATUS) {
+             success.set(true);
+             statusLatch.countDown();
+           }
+           if (type == Command.EXIT) {
+             exited.set(true);
+             statusLatch.countDown();
+           }
+         });
+
+     // Wait for probe to start (or fail)
+     boolean started = statusLatch.await(PROBE_TIMEOUT_SECONDS, TimeUnit.SECONDS);
+
+     if (!started) {
+       return toolResult("Probe deployment timed out after " + PROBE_TIMEOUT_SECONDS + "s", true);
+     }
+
+     if (exited.get() && !success.get()) {
+       return toolResult("Probe exited with error:\n" + output.toString(), true);
+     }
+
+     // NOTE(review): the client is not registered with ClientManager here; confirm that the
+     // follow-up tools named in the message below can reach the probe without it.
+     StringBuilder resultText = new StringBuilder();
+     resultText.append("Probe deployed successfully to PID ").append(pid).append(".\n");
+     resultText.append("Oneliner: ").append(oneliner).append("\n");
+     resultText.append("Port: ").append(port).append("\n");
+     if (output.length() > 0) {
+       resultText.append("\nInitial output:\n").append(output.toString());
+     }
+     resultText.append(
+         "\nUse send_event, detach_probe, or exit_probe to interact with the running probe.");
+     return toolResult(resultText.toString(), false);
+   } catch (InterruptedException e) {
+     // Restore the flag so callers further up the stack can still observe the interruption.
+     Thread.currentThread().interrupt();
+     log.error("Interrupted while deploying oneliner", e);
+     return toolResult(
+         "Error deploying oneliner: interrupted while waiting for the probe to start", true);
+   } catch (Exception e) {
+     log.error("Failed to deploy oneliner", e);
+     // Some exceptions carry no message; fall back to the exception's string form.
+     String reason = e.getMessage() != null ? e.getMessage() : e.toString();
+     return toolResult("Error deploying oneliner: " + reason, true);
+   }
+ }
+
+ /** Returns the string form of {@code args.get(key)}, or null when absent or {@code args} is null. */
+ private static String getStringArg(Map args, String key) {
+   if (args == null) {
+     return null;
+   }
+   Object raw = args.get(key);
+   if (raw == null) {
+     return null;
+   }
+   return raw.toString();
+ }
+
+ /**
+  * Reads an integer argument. Numbers are converted with {@code intValue()}, other values are
+  * parsed from their string form; a missing or unparsable value yields {@code defaultVal}.
+  */
+ private static int getIntArg(Map args, String key, int defaultVal) {
+   if (args == null) {
+     return defaultVal;
+   }
+   Object raw = args.get(key);
+   if (raw == null) {
+     return defaultVal;
+   }
+   if (raw instanceof Number) {
+     return ((Number) raw).intValue();
+   }
+   int parsed;
+   try {
+     parsed = Integer.parseInt(raw.toString());
+   } catch (NumberFormatException ignored) {
+     // Not a number: fall back to the default, as for a missing value.
+     parsed = defaultVal;
+   }
+   return parsed;
+ }
+
+ /**
+  * Builds an MCP tool result whose {@code content} is a single text item holding {@code text}
+  * and whose {@code isError} flag is the given value.
+  */
+ private static Map toolResult(String text, boolean isError) {
+   Map<String, Object> textItem = new LinkedHashMap<>();
+   textItem.put("type", "text");
+   textItem.put("text", text);
+
+   List<Object> items = new ArrayList<>();
+   items.add(textItem);
+
+   Map<String, Object> envelope = new LinkedHashMap<>();
+   envelope.put("content", items);
+   envelope.put("isError", isError);
+   return envelope;
+ }
+}
diff --git a/btrace-mcp-server/src/main/java/org/openjdk/btrace/mcp/tools/DeployScriptHandler.java b/btrace-mcp-server/src/main/java/org/openjdk/btrace/mcp/tools/DeployScriptHandler.java
new file mode 100644
index 00000000..125048ca
--- /dev/null
+++ b/btrace-mcp-server/src/main/java/org/openjdk/btrace/mcp/tools/DeployScriptHandler.java
@@ -0,0 +1,226 @@
+/*
+ * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
+ * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
+ *
+ * This code is free software; you can redistribute it and/or modify it
+ * under the terms of the GNU General Public License version 2 only, as
+ * published by the Free Software Foundation. Oracle designates this
+ * particular file as subject to the "Classpath" exception as provided
+ * by Oracle in the LICENSE file that accompanied this code.
+ *
+ * This code is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+ * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
+ * version 2 for more details (a copy is included in the LICENSE file that
+ * accompanied this code).
+ *
+ * You should have received a copy of the GNU General Public License version
+ * 2 along with this work; if not, write to the Free Software Foundation,
+ * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
+ *
+ * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
+ * or visit www.oracle.com if you need additional information or have any
+ * questions.
+ */
+
+package org.openjdk.btrace.mcp.tools;
+
+import java.io.PrintWriter;
+import java.io.StringWriter;
+import java.util.ArrayList;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.atomic.AtomicBoolean;
+import org.openjdk.btrace.client.Client;
+import org.openjdk.btrace.core.comm.Command;
+import org.openjdk.btrace.core.comm.PrintableCommand;
+import org.openjdk.btrace.mcp.ClientManager;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/** Handles the deploy_script MCP tool - compiles and deploys a full BTrace Java script. */
+public final class DeployScriptHandler {
+  private static final Logger log = LoggerFactory.getLogger(DeployScriptHandler.class);
+  private static final int DEFAULT_PORT = 2020; // used when no "port" argument is supplied
+  private static final int PROBE_TIMEOUT_SECONDS = 30; // max wait for the first STATUS/EXIT
+
+  private DeployScriptHandler() {} // static methods only
+
+  /** Builds the deploy_script tool descriptor (name, description, inputSchema) for tools/list. */
+  public static Map schema() {
+    Map pid = new LinkedHashMap<>();
+    pid.put("type", "string");
+    pid.put("description", "PID of the target JVM (use list_jvms to find it)");
+
+    Map script = new LinkedHashMap<>();
+    script.put("type", "string");
+    script.put(
+        "description",
+        "Full BTrace Java source code. Must include @BTrace annotation and proper imports.");
+
+    Map argItem = new LinkedHashMap<>();
+    argItem.put("type", "string");
+    Map args = new LinkedHashMap<>();
+    args.put("type", "array");
+    args.put("items", argItem);
+    args.put("description", "Optional arguments to pass to the BTrace script");
+
+    Map port = new LinkedHashMap<>();
+    port.put("type", "integer");
+    port.put("description", "BTrace agent port (default: 2020)");
+
+    Map params = new LinkedHashMap<>();
+    params.put("pid", pid);
+    params.put("script", script);
+    params.put("args", args);
+    params.put("port", port);
+
+    List mandatory = new ArrayList<>();
+    mandatory.add("pid");
+    mandatory.add("script");
+
+    Map input = new LinkedHashMap<>();
+    input.put("type", "object");
+    input.put("properties", params);
+    input.put("required", mandatory);
+
+    Map descriptor = new LinkedHashMap<>();
+    descriptor.put("name", "deploy_script");
+    descriptor.put(
+        "description",
+        "Deploy a full BTrace Java script to a running JVM. "
+            + "The script must be a valid BTrace program with @BTrace annotation. "
+            + "Use this for complex instrumentation that cannot be expressed as a oneliner.");
+    descriptor.put("inputSchema", input);
+    return descriptor;
+  }
+
+  /**
+   * Executes the deploy_script tool: compiles the script, attaches to the target JVM, submits
+   * the probe and waits up to {@code PROBE_TIMEOUT_SECONDS} seconds for a STATUS or EXIT command.
+   *
+   * @param arguments tool arguments: "pid" and "script" are required, "args" and "port" optional
+   * @return tool result map; "isError" is true for bad input, compile errors, timeout or failure
+   */
+  @SuppressWarnings("unchecked")
+  public static Map execute(Map arguments) {
+    String pid = getStringArg(arguments, "pid");
+    String script = getStringArg(arguments, "script");
+    int port = getIntArg(arguments, "port", DEFAULT_PORT);
+
+    if (pid == null || pid.isEmpty()) {
+      return toolResult("Error: 'pid' parameter is required", true);
+    }
+    if (script == null || script.isEmpty()) {
+      return toolResult("Error: 'script' parameter is required", true);
+    }
+
+    // Extract args; String.valueOf keeps a null list element from throwing before the try block.
+    String[] btraceArgs = new String[0];
+    Object argsObj = arguments == null ? null : arguments.get("args");
+    if (argsObj instanceof List) {
+      List argsList = (List) argsObj;
+      btraceArgs = new String[argsList.size()];
+      for (int i = 0; i < argsList.size(); i++) {
+        btraceArgs[i] = String.valueOf(argsList.get(i));
+      }
+    }
+
+    try {
+      String fileName = "BTraceScript_" + System.currentTimeMillis() + ".java";
+      // Compile the script
+      Client client = ClientManager.getClient(port);
+      StringWriter errorWriter = new StringWriter();
+      PrintWriter errPw = new PrintWriter(errorWriter);
+      byte[] code = client.compileSource(fileName, script, ".", errPw, null);
+      if (code == null) {
+        return toolResult("Script compilation failed:\n" + errorWriter.toString(), true);
+      }
+
+      // Attach and submit
+      client.attach(pid, null, ".");
+
+      // StringBuffer, not StringBuilder: the listener may append from another thread while this
+      // method reads the text after the latch is released.
+      StringBuffer output = new StringBuffer();
+      CountDownLatch statusLatch = new CountDownLatch(1);
+      AtomicBoolean success = new AtomicBoolean(false);
+      AtomicBoolean exited = new AtomicBoolean(false);
+      client.submit(
+          "localhost",
+          fileName,
+          code,
+          btraceArgs,
+          cmd -> {
+            int type = cmd.getType();
+            if (cmd instanceof PrintableCommand) {
+              StringWriter sw = new StringWriter();
+              ((PrintableCommand) cmd).print(new PrintWriter(sw));
+              output.append(sw.toString());
+            }
+            if (type == Command.STATUS) {
+              success.set(true);
+              statusLatch.countDown();
+            }
+            if (type == Command.EXIT) {
+              exited.set(true);
+              statusLatch.countDown();
+            }
+          });
+      if (!statusLatch.await(PROBE_TIMEOUT_SECONDS, TimeUnit.SECONDS)) {
+        return toolResult("Probe deployment timed out after " + PROBE_TIMEOUT_SECONDS + "s", true);
+      }
+      if (exited.get() && !success.get()) {
+        return toolResult("Probe exited with error:\n" + output.toString(), true);
+      }
+      String resultText = "Script deployed successfully to PID " + pid + ".\nPort: " + port + "\n";
+      if (output.length() > 0) {
+        resultText += "\nInitial output:\n" + output.toString();
+      }
+      resultText +=
+          "\nUse send_event, detach_probe, or exit_probe to interact with the running probe.";
+      return toolResult(resultText, false);
+    } catch (Exception e) {
+      if (e instanceof InterruptedException) {
+        Thread.currentThread().interrupt(); // keep the interrupt visible to the caller
+      }
+      log.error("Failed to deploy script", e);
+      return toolResult("Error deploying script: " + e.getMessage(), true);
+    }
+  }
+
+  private static String getStringArg(Map args, String key) {
+    Object raw = (args != null) ? args.get(key) : null; // absent map behaves like absent key
+    return (raw != null) ? raw.toString() : null;
+  }
+
+  /** Reads {@code key} as an int; absent or unparsable values yield {@code defaultVal}. */
+  private static int getIntArg(Map args, String key, int defaultVal) {
+    Object raw = (args != null) ? args.get(key) : null;
+    if (raw instanceof Number) {
+      return ((Number) raw).intValue(); // numeric values are narrowed, not re-parsed
+    }
+    if (raw != null) {
+      try {
+        return Integer.parseInt(raw.toString());
+      } catch (NumberFormatException ignored) { // non-numeric text: fall through to default
+      }
+    }
+    return defaultVal;
+  }
+
+  private static Map toolResult(String text, boolean isError) {
+    Map envelope = new LinkedHashMap<>(); // keys stay in insertion order
+    List items = new ArrayList<>();
+    Map textItem = new LinkedHashMap<>();
+    textItem.put("type", "text");
+    textItem.put("text", text);
+    items.add(textItem);
+    envelope.put("content", items);
+    envelope.put("isError", isError);
+    return envelope;
+  }
+}
diff --git a/btrace-mcp-server/src/main/java/org/openjdk/btrace/mcp/tools/DetachProbeHandler.java b/btrace-mcp-server/src/main/java/org/openjdk/btrace/mcp/tools/DetachProbeHandler.java
new file mode 100644
index 00000000..e277963b
--- /dev/null
+++ b/btrace-mcp-server/src/main/java/org/openjdk/btrace/mcp/tools/DetachProbeHandler.java
@@ -0,0 +1,134 @@
+/*
+ * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
+ * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
+ *
+ * This code is free software; you can redistribute it and/or modify it
+ * under the terms of the GNU General Public License version 2 only, as
+ * published by the Free Software Foundation. Oracle designates this
+ * particular file as subject to the "Classpath" exception as provided
+ * by Oracle in the LICENSE file that accompanied this code.
+ *
+ * This code is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+ * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
+ * version 2 for more details (a copy is included in the LICENSE file that
+ * accompanied this code).
+ *
+ * You should have received a copy of the GNU General Public License version
+ * 2 along with this work; if not, write to the Free Software Foundation,
+ * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
+ *
+ * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
+ * or visit www.oracle.com if you need additional information or have any
+ * questions.
+ */
+
+package org.openjdk.btrace.mcp.tools;
+
+import java.util.ArrayList;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+import org.openjdk.btrace.client.Client;
+import org.openjdk.btrace.mcp.ClientManager;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/**
+ * Handles the detach_probe MCP tool - disconnects the client while the probe itself keeps
+ * running on the target JVM.
+ */
+public final class DetachProbeHandler {
+  private static final Logger log = LoggerFactory.getLogger(DetachProbeHandler.class);
+  private static final int DEFAULT_PORT = 2020; // used when no "port" argument is supplied
+
+  private DetachProbeHandler() {} // static methods only
+
+  /** Builds the detach_probe tool descriptor (name, description, inputSchema) for tools/list. */
+  public static Map schema() {
+    Map pid = new LinkedHashMap<>();
+    pid.put("type", "string");
+    pid.put("description", "PID of the target JVM");
+
+    Map port = new LinkedHashMap<>();
+    port.put("type", "integer");
+    port.put("description", "BTrace agent port (default: 2020)");
+
+    Map params = new LinkedHashMap<>();
+    params.put("pid", pid);
+    params.put("port", port);
+
+    List mandatory = new ArrayList<>();
+    mandatory.add("pid");
+
+    Map input = new LinkedHashMap<>();
+    input.put("type", "object");
+    input.put("properties", params);
+    input.put("required", mandatory);
+
+    Map descriptor = new LinkedHashMap<>();
+    descriptor.put("name", "detach_probe");
+    descriptor.put(
+        "description",
+        "Detach from a running BTrace probe. The probe continues running on the target JVM "
+            + "but the client disconnects. You can later reconnect to it.");
+    descriptor.put("inputSchema", input);
+    return descriptor;
+  }
+
+  /**
+   * Executes detach_probe: disconnects the client registered for the pid/port and unregisters it.
+   */
+  public static Map execute(Map arguments) {
+    String pid = getStringArg(arguments, "pid");
+    if (pid == null || pid.isEmpty()) {
+      return toolResult("Error: 'pid' parameter is required", true);
+    }
+    int port = getIntArg(arguments, "port", DEFAULT_PORT);
+
+    try {
+      Client session = ClientManager.getExistingClient(pid, port);
+      if (session != null) {
+        session.sendDisconnect();
+        ClientManager.removeClient(pid, port);
+        return toolResult("Detached from probe on PID " + pid + ". Probe continues running.", false);
+      }
+      return toolResult("No active BTrace session for PID " + pid, true);
+    } catch (Exception e) {
+      log.error("Failed to detach probe", e);
+      return toolResult("Error detaching from PID " + pid + ": " + e.getMessage(), true);
+    }
+  }
+
+  private static String getStringArg(Map args, String key) {
+    Object raw = (args != null) ? args.get(key) : null; // absent map behaves like absent key
+    return (raw != null) ? raw.toString() : null;
+  }
+
+  /** Reads {@code key} as an int; absent or unparsable values yield {@code defaultVal}. */
+  private static int getIntArg(Map args, String key, int defaultVal) {
+    Object raw = (args != null) ? args.get(key) : null;
+    if (raw instanceof Number) {
+      return ((Number) raw).intValue(); // numeric values are narrowed, not re-parsed
+    }
+    if (raw != null) {
+      try {
+        return Integer.parseInt(raw.toString());
+      } catch (NumberFormatException ignored) { // non-numeric text: fall through to default
+      }
+    }
+    return defaultVal;
+  }
+
+  private static Map toolResult(String text, boolean isError) {
+    Map envelope = new LinkedHashMap<>(); // keys stay in insertion order
+    List items = new ArrayList<>();
+    Map textItem = new LinkedHashMap<>();
+    textItem.put("type", "text");
+    textItem.put("text", text);
+    items.add(textItem);
+    envelope.put("content", items);
+    envelope.put("isError", isError);
+    return envelope;
+  }
+}
diff --git a/btrace-mcp-server/src/main/java/org/openjdk/btrace/mcp/tools/ExitProbeHandler.java b/btrace-mcp-server/src/main/java/org/openjdk/btrace/mcp/tools/ExitProbeHandler.java
new file mode 100644
index 00000000..feea97c3
--- /dev/null
+++ b/btrace-mcp-server/src/main/java/org/openjdk/btrace/mcp/tools/ExitProbeHandler.java
@@ -0,0 +1,131 @@
+/*
+ * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
+ * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
+ *
+ * This code is free software; you can redistribute it and/or modify it
+ * under the terms of the GNU General Public License version 2 only, as
+ * published by the Free Software Foundation. Oracle designates this
+ * particular file as subject to the "Classpath" exception as provided
+ * by Oracle in the LICENSE file that accompanied this code.
+ *
+ * This code is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+ * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
+ * version 2 for more details (a copy is included in the LICENSE file that
+ * accompanied this code).
+ *
+ * You should have received a copy of the GNU General Public License version
+ * 2 along with this work; if not, write to the Free Software Foundation,
+ * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
+ *
+ * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
+ * or visit www.oracle.com if you need additional information or have any
+ * questions.
+ */
+
+package org.openjdk.btrace.mcp.tools;
+
+import java.util.ArrayList;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+import org.openjdk.btrace.client.Client;
+import org.openjdk.btrace.mcp.ClientManager;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/** Handles the exit_probe MCP tool - stops a probe and removes it from the target JVM. */
+public final class ExitProbeHandler {
+  private static final Logger log = LoggerFactory.getLogger(ExitProbeHandler.class);
+  private static final int DEFAULT_PORT = 2020; // used when no "port" argument is supplied
+
+  private ExitProbeHandler() {} // static methods only
+
+  /** Builds the exit_probe tool descriptor (name, description, inputSchema) for tools/list. */
+  public static Map schema() {
+    Map pid = new LinkedHashMap<>();
+    pid.put("type", "string");
+    pid.put("description", "PID of the target JVM");
+
+    Map port = new LinkedHashMap<>();
+    port.put("type", "integer");
+    port.put("description", "BTrace agent port (default: 2020)");
+
+    Map params = new LinkedHashMap<>();
+    params.put("pid", pid);
+    params.put("port", port);
+
+    List mandatory = new ArrayList<>();
+    mandatory.add("pid");
+
+    Map input = new LinkedHashMap<>();
+    input.put("type", "object");
+    input.put("properties", params);
+    input.put("required", mandatory);
+
+    Map descriptor = new LinkedHashMap<>();
+    descriptor.put("name", "exit_probe");
+    descriptor.put(
+        "description",
+        "Stop and remove a BTrace probe from the target JVM. "
+            + "This terminates the probe and removes all instrumentation.");
+    descriptor.put("inputSchema", input);
+    return descriptor;
+  }
+
+  /**
+   * Executes exit_probe: sends exit code 0 via the client for the pid/port and unregisters it.
+   */
+  public static Map execute(Map arguments) {
+    String pid = getStringArg(arguments, "pid");
+    if (pid == null || pid.isEmpty()) {
+      return toolResult("Error: 'pid' parameter is required", true);
+    }
+    int port = getIntArg(arguments, "port", DEFAULT_PORT);
+
+    try {
+      Client session = ClientManager.getExistingClient(pid, port);
+      if (session != null) {
+        session.sendExit(0);
+        ClientManager.removeClient(pid, port);
+        return toolResult("Probe stopped and removed from PID " + pid, false);
+      }
+      return toolResult("No active BTrace session for PID " + pid, true);
+    } catch (Exception e) {
+      log.error("Failed to exit probe", e);
+      return toolResult("Error stopping probe on PID " + pid + ": " + e.getMessage(), true);
+    }
+  }
+
+  private static String getStringArg(Map args, String key) {
+    Object raw = (args != null) ? args.get(key) : null; // absent map behaves like absent key
+    return (raw != null) ? raw.toString() : null;
+  }
+
+  /** Reads {@code key} as an int; absent or unparsable values yield {@code defaultVal}. */
+  private static int getIntArg(Map args, String key, int defaultVal) {
+    Object raw = (args != null) ? args.get(key) : null;
+    if (raw instanceof Number) {
+      return ((Number) raw).intValue(); // numeric values are narrowed, not re-parsed
+    }
+    if (raw != null) {
+      try {
+        return Integer.parseInt(raw.toString());
+      } catch (NumberFormatException ignored) { // non-numeric text: fall through to default
+      }
+    }
+    return defaultVal;
+  }
+
+  private static Map toolResult(String text, boolean isError) {
+    Map envelope = new LinkedHashMap<>(); // keys stay in insertion order
+    List items = new ArrayList<>();
+    Map textItem = new LinkedHashMap<>();
+    textItem.put("type", "text");
+    textItem.put("text", text);
+    items.add(textItem);
+    envelope.put("content", items);
+    envelope.put("isError", isError);
+    return envelope;
+  }
+}
diff --git a/btrace-mcp-server/src/main/java/org/openjdk/btrace/mcp/tools/ListJvmsHandler.java b/btrace-mcp-server/src/main/java/org/openjdk/btrace/mcp/tools/ListJvmsHandler.java
new file mode 100644
index 00000000..83639a5e
--- /dev/null
+++ b/btrace-mcp-server/src/main/java/org/openjdk/btrace/mcp/tools/ListJvmsHandler.java
@@ -0,0 +1,135 @@
+/*
+ * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
+ * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
+ *
+ * This code is free software; you can redistribute it and/or modify it
+ * under the terms of the GNU General Public License version 2 only, as
+ * published by the Free Software Foundation. Oracle designates this
+ * particular file as subject to the "Classpath" exception as provided
+ * by Oracle in the LICENSE file that accompanied this code.
+ *
+ * This code is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+ * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
+ * version 2 for more details (a copy is included in the LICENSE file that
+ * accompanied this code).
+ *
+ * You should have received a copy of the GNU General Public License version
+ * 2 along with this work; if not, write to the Free Software Foundation,
+ * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
+ *
+ * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
+ * or visit www.oracle.com if you need additional information or have any
+ * questions.
+ */
+
+package org.openjdk.btrace.mcp.tools;
+
+import java.util.ArrayList;
+import java.util.Collection;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+import sun.jvmstat.monitor.MonitoredHost;
+import sun.jvmstat.monitor.MonitoredVm;
+import sun.jvmstat.monitor.MonitoredVmUtil;
+import sun.jvmstat.monitor.VmIdentifier;
+
+/** Handles the list_jvms MCP tool - lists the attachable Java VMs found via jvmstat. */
+public final class ListJvmsHandler {
+  private static final Logger log = LoggerFactory.getLogger(ListJvmsHandler.class);
+
+  private ListJvmsHandler() {} // static methods only
+
+  /** Builds the list_jvms tool descriptor for tools/list; the tool takes no input parameters. */
+  public static Map schema() {
+    Map input = new LinkedHashMap<>();
+    input.put("type", "object");
+    input.put("properties", new LinkedHashMap<>());
+
+    Map descriptor = new LinkedHashMap<>();
+    descriptor.put("name", "list_jvms");
+    descriptor.put(
+        "description",
+        "List all attachable Java Virtual Machines on this host. "
+            + "Returns PID, main class, and whether BTrace is already attached. "
+            + "Use this to find the PID of the JVM you want to instrument.");
+    descriptor.put("inputSchema", input);
+    return descriptor;
+  }
+
+  /** Executes the list_jvms tool, one text line per VM; the arguments map is not read. */
+  public static Map execute(Map arguments) {
+    try {
+      Collection<Map<String, Object>> vms = listVms();
+      if (vms.isEmpty()) {
+        return toolResult("No attachable Java VMs found.", false);
+      }
+      StringBuilder sb = new StringBuilder("Attachable Java VMs:\n\n");
+      String row = " PID: %s | Main Class: %s | BTrace: %s\n";
+      for (Map<String, Object> vm : vms) {
+        sb.append(String.format(row, vm.get("pid"), vm.get("mainClass"), vm.get("btraceAttached")));
+      }
+      return toolResult(sb.toString(), false);
+    } catch (Exception e) {
+      log.error("Failed to list JVMs", e);
+      return toolResult("Error listing JVMs: " + e.getMessage(), true);
+    }
+  }
+
+  /** Lists pid, main class and BTrace status of each attachable VM reported by jvmstat. */
+  private static Collection<Map<String, Object>> listVms() {
+    List<Map<String, Object>> result = new ArrayList<>();
+    try {
+      MonitoredHost vmHost = MonitoredHost.getMonitoredHost((String) null);
+      for (Integer vmPid : vmHost.activeVms()) { // enumerate and look up via the same handle
+        MonitoredVm mvm = vmHost.getMonitoredVm(new VmIdentifier(vmPid.toString()));
+        try {
+          if (MonitoredVmUtil.isAttachable(mvm)) {
+            Map<String, Object> vmInfo = new LinkedHashMap<>();
+            vmInfo.put("pid", vmPid);
+            vmInfo.put("mainClass", MonitoredVmUtil.mainClass(mvm, false));
+            vmInfo.put("btraceAttached", hasBTraceServer(vmPid) ? "attached" : "not attached");
+            result.add(vmInfo);
+          }
+        } finally {
+          mvm.detach(); // release the jvmstat handle once the VM has been inspected
+        }
+      }
+    } catch (Exception e) {
+      log.warn("Error listing VMs", e); // best effort: return what was collected so far
+    }
+    return result;
+  }
+
+  private static boolean hasBTraceServer(int pid) {
+    com.sun.tools.attach.VirtualMachine vm = null;
+    try {
+      vm = com.sun.tools.attach.VirtualMachine.attach(String.valueOf(pid));
+      return vm.getSystemProperties().containsKey("btrace.port");
+    } catch (Throwable ignored) {
+      return false; // best effort: any attach/query failure reads as "not attached"
+    } finally {
+      if (vm != null) {
+        try {
+          vm.detach();
+        } catch (Exception ignored) { // nothing useful to do if detaching fails
+        }
+      }
+    }
+  }
+
+  private static Map toolResult(String text, boolean isError) {
+    Map envelope = new LinkedHashMap<>(); // keys stay in insertion order
+    List items = new ArrayList<>();
+    Map textItem = new LinkedHashMap<>();
+    textItem.put("type", "text");
+    textItem.put("text", text);
+    items.add(textItem);
+    envelope.put("content", items);
+    envelope.put("isError", isError);
+    return envelope;
+  }
+}
diff --git a/btrace-mcp-server/src/main/java/org/openjdk/btrace/mcp/tools/ListProbesHandler.java b/btrace-mcp-server/src/main/java/org/openjdk/btrace/mcp/tools/ListProbesHandler.java
new file mode 100644
index 00000000..2fe8ebb2
--- /dev/null
+++ b/btrace-mcp-server/src/main/java/org/openjdk/btrace/mcp/tools/ListProbesHandler.java
@@ -0,0 +1,155 @@
+/*
+ * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
+ * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
+ *
+ * This code is free software; you can redistribute it and/or modify it
+ * under the terms of the GNU General Public License version 2 only, as
+ * published by the Free Software Foundation. Oracle designates this
+ * particular file as subject to the "Classpath" exception as provided
+ * by Oracle in the LICENSE file that accompanied this code.
+ *
+ * This code is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+ * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
+ * version 2 for more details (a copy is included in the LICENSE file that
+ * accompanied this code).
+ *
+ * You should have received a copy of the GNU General Public License version
+ * 2 along with this work; if not, write to the Free Software Foundation,
+ * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
+ *
+ * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
+ * or visit www.oracle.com if you need additional information or have any
+ * questions.
+ */
+
+package org.openjdk.btrace.mcp.tools;
+
+import java.io.StringWriter;
+import java.util.ArrayList;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.TimeUnit;
+import org.openjdk.btrace.client.Client;
+import org.openjdk.btrace.core.comm.Command;
+import org.openjdk.btrace.core.comm.PrintableCommand;
+import org.openjdk.btrace.mcp.ClientManager;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/** Handles the list_probes MCP tool - lists the probes active on a JVM. */
+public final class ListProbesHandler {
+  private static final Logger log = LoggerFactory.getLogger(ListProbesHandler.class);
+  private static final int DEFAULT_PORT = 2020; // used when no "port" argument is supplied
+  private static final int TIMEOUT_SECONDS = 10; // max wait for the LIST_PROBES command
+
+  private ListProbesHandler() {} // static methods only
+
+  /** Builds the list_probes tool descriptor (name, description, inputSchema) for tools/list. */
+  public static Map schema() {
+    Map pid = new LinkedHashMap<>();
+    pid.put("type", "string");
+    pid.put("description", "PID of the target JVM");
+
+    Map port = new LinkedHashMap<>();
+    port.put("type", "integer");
+    port.put("description", "BTrace agent port (default: 2020)");
+
+    Map params = new LinkedHashMap<>();
+    params.put("pid", pid);
+    params.put("port", port);
+
+    List mandatory = new ArrayList<>();
+    mandatory.add("pid");
+
+    Map input = new LinkedHashMap<>();
+    input.put("type", "object");
+    input.put("properties", params);
+    input.put("required", mandatory);
+
+    Map descriptor = new LinkedHashMap<>();
+    descriptor.put("name", "list_probes");
+    descriptor.put(
+        "description",
+        "List active BTrace probes on a running JVM. "
+            + "Shows which probes are currently deployed and their IDs.");
+    descriptor.put("inputSchema", input);
+    return descriptor;
+  }
+
+  /** Executes the list_probes tool: attaches to the pid and returns the printed probe listing. */
+  public static Map execute(Map arguments) {
+    String pid = getStringArg(arguments, "pid");
+    int port = getIntArg(arguments, "port", DEFAULT_PORT);
+
+    if (pid == null || pid.isEmpty()) {
+      return toolResult("Error: 'pid' parameter is required", true);
+    }
+
+    try {
+      Client client = ClientManager.getClient(port);
+      client.attach(pid, null, ".");
+
+      // StringBuffer, not StringBuilder: the listener may append from another thread.
+      StringBuffer output = new StringBuffer();
+      CountDownLatch latch = new CountDownLatch(1);
+      client.connectAndListProbes(
+          "localhost",
+          cmd -> {
+            if (cmd instanceof PrintableCommand) {
+              StringWriter sw = new StringWriter();
+              ((PrintableCommand) cmd).print(new java.io.PrintWriter(sw));
+              output.append(sw.toString());
+            }
+            if (cmd.getType() == Command.LIST_PROBES) {
+              latch.countDown();
+            }
+          });
+
+      if (!latch.await(TIMEOUT_SECONDS, TimeUnit.SECONDS)) {
+        return toolResult("Timed out waiting for probe list from PID " + pid, true);
+      }
+      return toolResult("Active probes on PID " + pid + ":\n" + output.toString(), false);
+    } catch (Exception e) {
+      if (e instanceof InterruptedException) {
+        Thread.currentThread().interrupt(); // keep the interrupt visible to the caller
+      }
+      log.error("Failed to list probes", e);
+      return toolResult("Error listing probes on PID " + pid + ": " + e.getMessage(), true);
+    }
+  }
+
+  private static String getStringArg(Map args, String key) {
+    Object raw = (args != null) ? args.get(key) : null; // absent map behaves like absent key
+    return (raw != null) ? raw.toString() : null;
+  }
+
+  /** Reads {@code key} as an int; absent or unparsable values yield {@code defaultVal}. */
+  private static int getIntArg(Map args, String key, int defaultVal) {
+    Object raw = (args != null) ? args.get(key) : null;
+    if (raw instanceof Number) {
+      return ((Number) raw).intValue(); // numeric values are narrowed, not re-parsed
+    }
+    if (raw != null) {
+      try {
+        return Integer.parseInt(raw.toString());
+      } catch (NumberFormatException ignored) { // non-numeric text: fall through to default
+      }
+    }
+    return defaultVal;
+  }
+
+  private static Map toolResult(String text, boolean isError) {
+    Map envelope = new LinkedHashMap<>(); // keys stay in insertion order
+    List items = new ArrayList<>();
+    Map textItem = new LinkedHashMap<>();
+    textItem.put("type", "text");
+    textItem.put("text", text);
+    items.add(textItem);
+    envelope.put("content", items);
+    envelope.put("isError", isError);
+    return envelope;
+  }
+}
diff --git a/btrace-mcp-server/src/main/java/org/openjdk/btrace/mcp/tools/SendEventHandler.java b/btrace-mcp-server/src/main/java/org/openjdk/btrace/mcp/tools/SendEventHandler.java
new file mode 100644
index 00000000..095936b0
--- /dev/null
+++ b/btrace-mcp-server/src/main/java/org/openjdk/btrace/mcp/tools/SendEventHandler.java
@@ -0,0 +1,146 @@
+/*
+ * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved.
+ * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
+ *
+ * This code is free software; you can redistribute it and/or modify it
+ * under the terms of the GNU General Public License version 2 only, as
+ * published by the Free Software Foundation. Oracle designates this
+ * particular file as subject to the "Classpath" exception as provided
+ * by Oracle in the LICENSE file that accompanied this code.
+ *
+ * This code is distributed in the hope that it will be useful, but WITHOUT
+ * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
+ * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License
+ * version 2 for more details (a copy is included in the LICENSE file that
+ * accompanied this code).
+ *
+ * You should have received a copy of the GNU General Public License version
+ * 2 along with this work; if not, write to the Free Software Foundation,
+ * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA.
+ *
+ * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA
+ * or visit www.oracle.com if you need additional information or have any
+ * questions.
+ */
+
+package org.openjdk.btrace.mcp.tools;
+
+import java.util.ArrayList;
+import java.util.LinkedHashMap;
+import java.util.List;
+import java.util.Map;
+import org.openjdk.btrace.client.Client;
+import org.openjdk.btrace.mcp.ClientManager;
+import org.slf4j.Logger;
+import org.slf4j.LoggerFactory;
+
+/** Handles the send_event MCP tool - delivers an event to a running probe. */
+public final class SendEventHandler {
+  private static final Logger log = LoggerFactory.getLogger(SendEventHandler.class);
+  private static final int DEFAULT_PORT = 2020; // used when no "port" argument is supplied
+
+  private SendEventHandler() {} // static methods only
+
+  /** Builds the send_event tool descriptor (name, description, inputSchema) for tools/list. */
+  public static Map schema() {
+    Map pid = new LinkedHashMap<>();
+    pid.put("type", "string");
+    pid.put("description", "PID of the target JVM");
+
+    Map eventName = new LinkedHashMap<>();
+    eventName.put("type", "string");
+    eventName.put("description", "Name of the event to send (optional)");
+
+    Map port = new LinkedHashMap<>();
+    port.put("type", "integer");
+    port.put("description", "BTrace agent port (default: 2020)");
+
+    Map params = new LinkedHashMap<>();
+    params.put("pid", pid);
+    params.put("event_name", eventName);
+    params.put("port", port);
+
+    List mandatory = new ArrayList<>();
+    mandatory.add("pid");
+
+    Map input = new LinkedHashMap<>();
+    input.put("type", "object");
+    input.put("properties", params);
+    input.put("required", mandatory);
+
+    Map descriptor = new LinkedHashMap<>();
+    descriptor.put("name", "send_event");
+    descriptor.put(
+        "description",
+        "Send an event to a running BTrace probe. "
+            + "Events can trigger @OnEvent handlers in the probe script. "
+            + "If no event_name is specified, an unnamed event is sent.");
+    descriptor.put("inputSchema", input);
+    return descriptor;
+  }
+
+  /**
+   * Executes the send_event tool: looks up the client registered for the pid/port pair and sends
+   * either the named event or, when "event_name" is absent or empty, an unnamed one.
+   */
+  public static Map execute(Map arguments) {
+    String pid = getStringArg(arguments, "pid");
+    if (pid == null || pid.isEmpty()) {
+      return toolResult("Error: 'pid' parameter is required", true);
+    }
+    String eventName = getStringArg(arguments, "event_name");
+    int port = getIntArg(arguments, "port", DEFAULT_PORT);
+    boolean unnamed = eventName == null || eventName.isEmpty();
+
+    try {
+      Client session = ClientManager.getExistingClient(pid, port);
+      if (session == null) {
+        String hint = ". Deploy a probe first using deploy_oneliner or deploy_script.";
+        return toolResult("No active BTrace session for PID " + pid + hint, true);
+      }
+
+      // An absent or empty event_name selects the no-argument sendEvent().
+      if (unnamed) {
+        session.sendEvent();
+        return toolResult("Unnamed event sent to PID " + pid, false);
+      }
+      session.sendEvent(eventName);
+      return toolResult("Event '" + eventName + "' sent to PID " + pid, false);
+    } catch (Exception e) {
+      log.error("Failed to send event", e);
+      return toolResult("Error sending event to PID " + pid + ": " + e.getMessage(), true);
+    }
+  }
+
+  private static String getStringArg(Map args, String key) {
+    Object raw = (args != null) ? args.get(key) : null; // absent map behaves like absent key
+    return (raw != null) ? raw.toString() : null;
+  }
+
+  /** Reads {@code key} as an int; absent or unparsable values yield {@code defaultVal}. */
+  private static int getIntArg(Map args, String key, int defaultVal) {
+    Object raw = (args != null) ? args.get(key) : null;
+    if (raw instanceof Number) {
+      return ((Number) raw).intValue(); // numeric values are narrowed, not re-parsed
+    }
+    if (raw != null) {
+      try {
+        return Integer.parseInt(raw.toString());
+      } catch (NumberFormatException ignored) { // non-numeric text: fall through to default
+      }
+    }
+    return defaultVal;
+  }
+
+  private static Map toolResult(String text, boolean isError) {
+    Map envelope = new LinkedHashMap<>(); // keys stay in insertion order
+    List items = new ArrayList<>();
+    Map textItem = new LinkedHashMap<>();
+    textItem.put("type", "text");
+    textItem.put("text", text);
+    items.add(textItem);
+    envelope.put("content", items);
+    envelope.put("isError", isError);
+    return envelope;
+  }
+}
diff --git a/settings.gradle b/settings.gradle
index bbbbc640..6225e86e 100644
--- a/settings.gradle
+++ b/settings.gradle
@@ -45,6 +45,10 @@ include 'btrace-extensions:btrace-metrics'
include 'btrace-extensions:btrace-utils'
include 'btrace-extensions:btrace-statsd'
include 'btrace-extensions:btrace-streams'
+include 'btrace-extensions:btrace-llm-trace'
+include 'btrace-extensions:btrace-rag-quality'
+include 'btrace-extensions:btrace-vibe-guard'
+include 'btrace-extensions:btrace-gpu-bridge'
dependencyResolutionManagement {
versionCatalogs {