diff --git a/btrace-client/src/main/java/org/openjdk/btrace/client/Client.java b/btrace-client/src/main/java/org/openjdk/btrace/client/Client.java index b80b9410..ca6f2b2e 100644 --- a/btrace-client/src/main/java/org/openjdk/btrace/client/Client.java +++ b/btrace-client/src/main/java/org/openjdk/btrace/client/Client.java @@ -758,7 +758,7 @@ public void attach(String pid, String agentPath, String sysCp, String bootCp) th } } - void connectAndListProbes(String host, CommandListener listener) throws IOException { + public void connectAndListProbes(String host, CommandListener listener) throws IOException { if (sock != null) { throw new IllegalStateException(); } @@ -1071,11 +1071,11 @@ public synchronized void close() throws IOException { reset(); } - boolean isDisconnected() { + public boolean isDisconnected() { return disconnected; } - void disconnect() throws IOException { + public void disconnect() throws IOException { disconnected = true; if (log.isDebugEnabled()) { log.debug("sending DISCONNECT request to agent"); diff --git a/btrace-dist/src/main/resources/samples/GpuBridge.java b/btrace-dist/src/main/resources/samples/GpuBridge.java new file mode 100644 index 00000000..f8e451a4 --- /dev/null +++ b/btrace-dist/src/main/resources/samples/GpuBridge.java @@ -0,0 +1,79 @@ +import org.openjdk.btrace.core.annotations.*; +import org.openjdk.btrace.core.extensions.Injected; +import org.openjdk.btrace.gpu.GpuBridgeService; + +import static org.openjdk.btrace.core.BTraceUtils.*; + +/** + * Traces GPU model inference via ONNX Runtime and DJL (Deep Java Library). + * Tracks inference latency, batch sizes, and model load times. + * + *

Attach to a JVM running ONNX or DJL inference: + *

+ * btrace <pid> GpuBridge.java
+ * 
+ */ +@BTrace +public class GpuBridge { + + @Injected + private static GpuBridgeService gpu; + + // ==================== ONNX Runtime ==================== + + @OnMethod( + clazz = "ai.onnxruntime.OrtSession", + method = "run", + location = @Location(Kind.RETURN)) + public static void onOnnxInference(@Duration long dur) { + gpu.recordInference("onnx", "session", dur); + } + + @OnMethod( + clazz = "ai.onnxruntime.OrtSession", + method = "", + location = @Location(Kind.RETURN)) + public static void onOnnxModelLoad(@Duration long dur) { + gpu.recordModelLoad("onnx", "session", dur); + } + + // ==================== DJL (Deep Java Library) ==================== + + @OnMethod( + clazz = "/ai\\.djl\\.inference\\.Predictor/", + method = "predict", + location = @Location(Kind.RETURN)) + public static void onDjlPredict(@Duration long dur) { + gpu.recordInference("djl", "predictor", dur); + } + + @OnMethod( + clazz = "/ai\\.djl\\.repository\\.zoo\\.ModelZoo/", + method = "loadModel", + location = @Location(Kind.RETURN)) + public static void onDjlModelLoad(@Duration long dur) { + gpu.recordModelLoad("djl", "model-zoo", dur); + } + + // ==================== TensorFlow Java ==================== + + @OnMethod( + clazz = "/org\\.tensorflow\\.Session/", + method = "run", + location = @Location(Kind.RETURN)) + public static void onTensorFlowRun(@Duration long dur) { + gpu.recordInference("tensorflow", "session", dur); + } + + // ==================== Periodic summary ==================== + + @OnTimer(30000) + public static void periodicSummary() { + println(gpu.getSummary()); + } + + @OnEvent("summary") + public static void onDemandSummary() { + println(gpu.getSummary()); + } +} diff --git a/btrace-dist/src/main/resources/samples/LlmTrace.java b/btrace-dist/src/main/resources/samples/LlmTrace.java new file mode 100644 index 00000000..98e2cefc --- /dev/null +++ b/btrace-dist/src/main/resources/samples/LlmTrace.java @@ -0,0 +1,114 @@ +/* + * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the Classpath exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +import static org.openjdk.btrace.core.BTraceUtils.*; + +import org.openjdk.btrace.core.annotations.*; +import org.openjdk.btrace.llm.LlmTraceService; + +/** + * Sample BTrace script that traces LLM API calls using the btrace-llm-trace extension. + * + *

This is a generic template. It intercepts Langchain4j's ChatLanguageModel.generate() + * method and records call metrics. Adapt the @OnMethod annotations to target your specific + * LLM SDK (OpenAI Java SDK, Anthropic Java SDK, Spring AI, etc.). + * + *

Usage: + * btrace <PID> LlmTrace.java + * + *

Send a named event "summary" to print the current stats: + * (Ctrl-C, option 3, enter "summary") + */ +@BTrace +public class LlmTrace { + + @Injected + private static LlmTraceService llm; + + /** + * Trace Langchain4j ChatLanguageModel.generate() calls. + * Captures latency on every call completion. + */ + @OnMethod( + clazz = "+dev.langchain4j.model.chat.ChatLanguageModel", + method = "generate", + location = @Location(Kind.RETURN)) + public static void onLangchain4jGenerate( + @ProbeClassName String className, + @ProbeMethodName String methodName, + @Duration long duration) { + // Model name extracted from the class; token counts need return value parsing + // For a production script, parse the Response return value + llm.recordCall("langchain4j", className, 0, 0, duration); + println(strcat(strcat(strcat("LLM call: ", className), " "), + strcat(str(duration / 1000000L), "ms"))); + } + + /** + * Trace Langchain4j StreamingChatLanguageModel calls. + */ + @OnMethod( + clazz = "+dev.langchain4j.model.chat.StreamingChatLanguageModel", + method = "generate", + location = @Location(Kind.RETURN)) + public static void onLangchain4jStreaming( + @ProbeClassName String className, + @Duration long duration) { + llm.recordStreamingCall(className, 0, 0, duration, 0); + } + + /** + * Trace errors from any ChatLanguageModel implementation. + */ + @OnMethod( + clazz = "+dev.langchain4j.model.chat.ChatLanguageModel", + method = "generate", + location = @Location(Kind.ERROR)) + public static void onLangchain4jError( + @ProbeClassName String className, + @Duration long duration, + Throwable error) { + llm.recordError(className, Strings.str(error.getClass()), duration); + println(strcat("LLM ERROR: ", Strings.str(error))); + } + + /** + * Print summary on named event "summary". + */ + @OnEvent("summary") + public static void onSummary() { + println(llm.getSummary()); + } + + /** + * Print summary periodically (every 30 seconds). + */ + @OnTimer(30000) + public static void onTimer() { + if (llm.getTotalCalls() > 0) { + println(llm.getSummary()); + } + } +} diff --git a/btrace-dist/src/main/resources/samples/RagQuality.java b/btrace-dist/src/main/resources/samples/RagQuality.java new file mode 100644 index 00000000..f75497ac --- /dev/null +++ b/btrace-dist/src/main/resources/samples/RagQuality.java @@ -0,0 +1,73 @@ +import org.openjdk.btrace.core.annotations.*; +import org.openjdk.btrace.core.extensions.Injected; +import org.openjdk.btrace.rag.RagQualityService; + +import static org.openjdk.btrace.core.BTraceUtils.*; + +/** + * Traces RAG pipeline performance: vector DB query latency, similarity + * scores, and empty retrieval rates. Targets Pinecone Java client. + * + *

Attach to a JVM running a RAG pipeline: + *

+ * btrace <pid> RagQuality.java
+ * 
+ */ +@BTrace +public class RagQuality { + + @Injected + private static RagQualityService rag; + + // ==================== Pinecone ==================== + + @OnMethod( + clazz = "/io\\.pinecone\\..*/", + method = "query", + location = @Location(Kind.RETURN)) + public static void onPineconeQuery(@Duration long dur) { + rag.recordQuery("pinecone", dur); + } + + // ==================== Milvus ==================== + + @OnMethod( + clazz = "/io\\.milvus\\.client\\..*/", + method = "search", + location = @Location(Kind.RETURN)) + public static void onMilvusSearch(@Duration long dur) { + rag.recordQuery("milvus", dur); + } + + // ==================== Weaviate ==================== + + @OnMethod( + clazz = "/io\\.weaviate\\.client\\..*/", + method = "/get|search/", + location = @Location(Kind.RETURN)) + public static void onWeaviateQuery(@Duration long dur) { + rag.recordQuery("weaviate", dur); + } + + // ==================== Chroma ==================== + + @OnMethod( + clazz = "/tech\\.amikos\\.chromadb\\..*/", + method = "query", + location = @Location(Kind.RETURN)) + public static void onChromaQuery(@Duration long dur) { + rag.recordQuery("chroma", dur); + } + + // ==================== Periodic summary ==================== + + @OnTimer(30000) + public static void periodicSummary() { + println(rag.getSummary()); + } + + @OnEvent("summary") + public static void onDemandSummary() { + println(rag.getSummary()); + } +} diff --git a/btrace-dist/src/main/resources/samples/VibeGuard.java b/btrace-dist/src/main/resources/samples/VibeGuard.java new file mode 100644 index 00000000..40afbaa3 --- /dev/null +++ b/btrace-dist/src/main/resources/samples/VibeGuard.java @@ -0,0 +1,69 @@ +import org.openjdk.btrace.core.annotations.*; +import org.openjdk.btrace.core.extensions.Injected; +import org.openjdk.btrace.vibeguard.VibeGuardService; + +import static org.openjdk.btrace.core.BTraceUtils.*; + +/** + * Runtime behavioral contracts for AI-generated code. Enforces latency + * budgets, call rate limits, and null-safety on methods that may have + * been generated by an LLM. + * + *

Attach to a JVM running AI-generated service code: + *

+ * btrace <pid> VibeGuard.java
+ * 
+ */ +@BTrace +public class VibeGuard { + + @Injected + private static VibeGuardService guard; + + // ==================== Latency budgets ==================== + + /** + * Enforce 500ms latency budget on all methods in the AI-generated package. + * Adjust the clazz pattern to match your project structure. + */ + @OnMethod( + clazz = "/com\\.myapp\\.ai\\..*/", + method = "/.*/", + location = @Location(Kind.RETURN)) + public static void checkAiLatency( + @ProbeClassName String cls, + @ProbeMethodName String method, + @Duration long dur) { + guard.checkLatency(Strings.strcat(cls, Strings.strcat(".", method)), + dur, 500_000_000L); + } + + // ==================== Null safety on return values ==================== + + @OnMethod( + clazz = "/com\\.myapp\\.ai\\..*/", + method = "/.*/", + location = @Location(Kind.RETURN)) + public static void checkNullReturn( + @ProbeClassName String cls, + @ProbeMethodName String method, + @Return Object ret) { + guard.checkNotNull( + Strings.strcat(cls, Strings.strcat(".", method)), ret); + } + + // ==================== Alert on violations ==================== + + @OnTimer(10000) + public static void checkViolations() { + if (guard.hasViolations()) { + println("=== VIBE GUARD ALERT ==="); + println(guard.getSummary()); + } + } + + @OnEvent("summary") + public static void onDemandSummary() { + println(guard.getSummary()); + } +} diff --git a/btrace-extensions/btrace-gpu-bridge/build.gradle b/btrace-extensions/btrace-gpu-bridge/build.gradle new file mode 100644 index 00000000..23a56aec --- /dev/null +++ b/btrace-extensions/btrace-gpu-bridge/build.gradle @@ -0,0 +1,33 @@ +plugins { + id 'org.openjdk.btrace.extension' + alias(libs.plugins.shadow) +} + +java { + sourceCompatibility = 8 + targetCompatibility = 8 +} + +compileJava { + javaCompiler = javaToolchains.compilerFor { + languageVersion.set(JavaLanguageVersion.of(11)) + } +} + +btraceExtension { + id = 'btrace-gpu-bridge' + name = 'BTrace GPU Bridge' + description = 'GPU and model inference observability: ONNX Runtime, DJL, TensorFlow Java, and Panama FFM compute tracing' + services = ['org.openjdk.btrace.gpu.GpuBridgeService'] +} + +dependencies { + apiCompileOnly project(':btrace-core') + implImplementation project(':btrace-core') + + testImplementation libs.junit.jupiter +} + +test { + useJUnitPlatform() +} diff --git a/btrace-extensions/btrace-gpu-bridge/src/api/java/org/openjdk/btrace/gpu/GpuBridgeService.java b/btrace-extensions/btrace-gpu-bridge/src/api/java/org/openjdk/btrace/gpu/GpuBridgeService.java new file mode 100644 index 00000000..ea9cd57a --- /dev/null +++ b/btrace-extensions/btrace-gpu-bridge/src/api/java/org/openjdk/btrace/gpu/GpuBridgeService.java @@ -0,0 +1,131 @@ +package org.openjdk.btrace.gpu; + +import org.openjdk.btrace.core.extensions.ServiceDescriptor; + +/** + * BTrace extension for GPU compute and model inference observability. + * + *

Traces Java-to-GPU boundaries: ONNX Runtime inference sessions, + * DJL (Deep Java Library) predictions, TensorFlow Java operations, + * and Panama FFM native calls to CUDA/ROCm libraries. Zero external + * dependencies — instruments existing client classes. + * + *

Usage in a BTrace script: + *

+ * @Injected GpuBridgeService gpu;
+ *
+ * @OnMethod(clazz = "ai.onnxruntime.OrtSession", method = "run")
+ * void onInference(@Duration long dur) {
+ *     gpu.recordInference("onnx", "resnet50", dur);
+ * }
+ * 
+ */ +@ServiceDescriptor +public interface GpuBridgeService { + + // ==================== Inference recording ==================== + + /** + * Records a model inference call with duration only. + * + * @param runtime runtime name (e.g. "onnx", "djl", "tensorflow") + * @param modelName model identifier + * @param durationNanos inference duration in nanoseconds + */ + void recordInference(String runtime, String modelName, long durationNanos); + + /** + * Records a model inference with batch size and tensor dimensions. + * + * @param runtime runtime name + * @param modelName model identifier + * @param batchSize batch size of the inference request + * @param durationNanos inference duration + */ + void recordInference(String runtime, String modelName, int batchSize, long durationNanos); + + /** + * Starts a detailed inference record builder. Allocation-free (ThreadLocal-pooled). + * + *
+   * gpu.inference("onnx", "bert-base")
+   *     .batchSize(32)
+   *     .inputElements(512 * 768)
+   *     .outputElements(512 * 2)
+   *     .deviceType("cuda")
+   *     .deviceId(0)
+   *     .duration(durationNanos)
+   *     .record();
+   * 
+ * + * @param runtime runtime name + * @param modelName model identifier + * @return an inference record builder (thread-local, do not store) + */ + InferenceRecord inference(String runtime, String modelName); + + // ==================== Memory tracking ==================== + + /** + * Records GPU memory allocation. + * + * @param deviceType device type ("cuda", "rocm", "cpu") + * @param deviceId device index + * @param bytes allocated bytes + */ + void recordMemoryAlloc(String deviceType, int deviceId, long bytes); + + /** + * Records GPU memory deallocation. + * + * @param deviceType device type + * @param deviceId device index + * @param bytes freed bytes + */ + void recordMemoryFree(String deviceType, int deviceId, long bytes); + + // ==================== Native call tracking ==================== + + /** + * Records a native/FFM call to a GPU library (e.g. cuBLAS, cuDNN). + * + * @param library library name + * @param function function name + * @param durationNanos call duration + */ + void recordNativeCall(String library, String function, long durationNanos); + + // ==================== Model lifecycle ==================== + + /** + * Records model load/initialization time. + * + * @param runtime runtime name + * @param modelName model identifier + * @param durationNanos load duration + */ + void recordModelLoad(String runtime, String modelName, long durationNanos); + + // ==================== Reporting ==================== + + /** Returns a formatted summary of all GPU/inference metrics. */ + String getSummary(); + + /** Returns summary for a specific model. */ + String getModelSummary(String modelName); + + /** Total number of inference calls across all models. */ + long getTotalInferences(); + + /** Returns estimated GPU memory currently allocated (bytes), or -1 if not tracked. */ + long getCurrentGpuMemoryBytes(); + + /** Returns peak GPU memory seen (bytes), or -1 if not tracked. */ + long getPeakGpuMemoryBytes(); + + /** Returns total number of native/FFM calls recorded. */ + long getTotalNativeCalls(); + + /** Resets all collected metrics. */ + void reset(); +} diff --git a/btrace-extensions/btrace-gpu-bridge/src/api/java/org/openjdk/btrace/gpu/InferenceRecord.java b/btrace-extensions/btrace-gpu-bridge/src/api/java/org/openjdk/btrace/gpu/InferenceRecord.java new file mode 100644 index 00000000..58d3a416 --- /dev/null +++ b/btrace-extensions/btrace-gpu-bridge/src/api/java/org/openjdk/btrace/gpu/InferenceRecord.java @@ -0,0 +1,42 @@ +package org.openjdk.btrace.gpu; + +/** + * Fluent builder for recording a model inference with detailed metrics. + * + *

Obtain via {@link GpuBridgeService#inference(String, String)}. + * Allocation-free (ThreadLocal-pooled). Do not store the returned reference. + * + *

+ * gpu.inference("onnx", "bert-base")
+ *     .batchSize(32)
+ *     .inputElements(512 * 768)
+ *     .outputElements(512 * 2)
+ *     .deviceType("cuda")
+ *     .deviceId(0)
+ *     .duration(durationNanos)
+ *     .record();
+ * 
+ */ +public interface InferenceRecord { + + /** Batch size for this inference. */ + InferenceRecord batchSize(int size); + + /** Total number of input tensor elements. */ + InferenceRecord inputElements(long elements); + + /** Total number of output tensor elements. */ + InferenceRecord outputElements(long elements); + + /** Device type: "cuda", "rocm", "cpu", "mps", etc. */ + InferenceRecord deviceType(String type); + + /** Device index (for multi-GPU). */ + InferenceRecord deviceId(int id); + + /** Inference duration in nanoseconds. */ + InferenceRecord duration(long nanos); + + /** Commits this inference record. */ + void record(); +} diff --git a/btrace-extensions/btrace-gpu-bridge/src/impl/java/org/openjdk/btrace/gpu/GpuBridgeServiceImpl.java b/btrace-extensions/btrace-gpu-bridge/src/impl/java/org/openjdk/btrace/gpu/GpuBridgeServiceImpl.java new file mode 100644 index 00000000..6cceacb9 --- /dev/null +++ b/btrace-extensions/btrace-gpu-bridge/src/impl/java/org/openjdk/btrace/gpu/GpuBridgeServiceImpl.java @@ -0,0 +1,335 @@ +package org.openjdk.btrace.gpu; + +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicLong; +import org.openjdk.btrace.core.extensions.Extension; + +/** + * Thread-safe GPU compute and inference tracking with lock-free statistics. + */ +public final class GpuBridgeServiceImpl extends Extension implements GpuBridgeService { + + private final Map modelStats = new ConcurrentHashMap<>(); + private final Map deviceMemory = new ConcurrentHashMap<>(); + private final Map nativeStats = new ConcurrentHashMap<>(); + + private final ThreadLocal inferenceRecordPool = + ThreadLocal.withInitial(InferenceRecordImpl::new); + + // ==================== Inference recording ==================== + + @Override + public void recordInference(String runtime, String modelName, long durationNanos) { + ModelStats stats = getOrCreateModel(runtime, modelName); + stats.inferences.incrementAndGet(); + stats.totalDurationNanos.addAndGet(durationNanos); + updateMinMax(stats.minDurationNanos, stats.maxDurationNanos, durationNanos); + } + + @Override + public void recordInference(String runtime, String modelName, int batchSize, long durationNanos) { + ModelStats stats = getOrCreateModel(runtime, modelName); + stats.inferences.incrementAndGet(); + stats.totalBatchSize.addAndGet(batchSize); + stats.totalDurationNanos.addAndGet(durationNanos); + updateMinMax(stats.minDurationNanos, stats.maxDurationNanos, durationNanos); + } + + @Override + public InferenceRecord inference(String runtime, String modelName) { + return inferenceRecordPool.get().reset(this, runtime, modelName); + } + + void commitInferenceRecord(InferenceRecordImpl rec) { + ModelStats stats = getOrCreateModel(rec.runtime, rec.modelName); + stats.inferences.incrementAndGet(); + stats.totalBatchSize.addAndGet(rec.batchSizeVal); + stats.totalInputElements.addAndGet(rec.inputElem); + stats.totalOutputElements.addAndGet(rec.outputElem); + stats.totalDurationNanos.addAndGet(rec.durationVal); + updateMinMax(stats.minDurationNanos, stats.maxDurationNanos, rec.durationVal); + if (rec.deviceTypeVal != null) { + stats.lastDeviceType = rec.deviceTypeVal; + stats.lastDeviceId = rec.deviceIdVal; + } + } + + // ==================== Memory tracking ==================== + + @Override + public void recordMemoryAlloc(String deviceType, int deviceId, long bytes) { + DeviceMemory dm = getOrCreateDevice(deviceType, deviceId); + dm.currentBytes.addAndGet(bytes); + dm.totalAllocated.addAndGet(bytes); + dm.allocCount.incrementAndGet(); + // Update peak + long cur; + long newVal = dm.currentBytes.get(); + do { + cur = dm.peakBytes.get(); + if (newVal <= cur) break; + } while (!dm.peakBytes.compareAndSet(cur, newVal)); + } + + @Override + public void recordMemoryFree(String deviceType, int deviceId, long bytes) { + DeviceMemory dm = getOrCreateDevice(deviceType, deviceId); + dm.currentBytes.addAndGet(-bytes); + dm.freeCount.incrementAndGet(); + } + + // ==================== Native call tracking ==================== + + @Override + public void recordNativeCall(String library, String function, long durationNanos) { + String key = library + "::" + function; + NativeCallStats stats = nativeStats.computeIfAbsent(key, k -> new NativeCallStats()); + stats.calls.incrementAndGet(); + stats.totalDurationNanos.addAndGet(durationNanos); + } + + // ==================== Model lifecycle ==================== + + @Override + public void recordModelLoad(String runtime, String modelName, long durationNanos) { + ModelStats stats = getOrCreateModel(runtime, modelName); + stats.loadCount.incrementAndGet(); + stats.totalLoadDurationNanos.addAndGet(durationNanos); + } + + // ==================== Reporting ==================== + + @Override + public String getSummary() { + if (modelStats.isEmpty() && deviceMemory.isEmpty() && nativeStats.isEmpty()) { + return "No GPU/inference activity recorded."; + } + + StringBuilder sb = new StringBuilder(); + sb.append("=== GPU Bridge Summary ===\n\n"); + + long totalInf = 0; + + for (Map.Entry entry : modelStats.entrySet()) { + ModelStats s = entry.getValue(); + long inf = s.inferences.get(); + totalInf += inf; + + sb.append("Model: ").append(entry.getKey()); + if (s.lastDeviceType != null) { + sb.append(" [").append(s.lastDeviceType).append(":").append(s.lastDeviceId).append("]"); + } + sb.append("\n"); + + // Inference stats + sb.append(" Inferences: ").append(inf); + long loads = s.loadCount.get(); + if (loads > 0) { + long avgLoadMs = (s.totalLoadDurationNanos.get() / loads) / 1_000_000; + sb.append(" (loaded ").append(loads).append("x, avg ").append(avgLoadMs).append("ms)"); + } + sb.append("\n"); + + // Latency + if (inf > 0) { + long avgMs = (s.totalDurationNanos.get() / inf) / 1_000_000; + long minMs = s.minDurationNanos.get() / 1_000_000; + long maxMs = s.maxDurationNanos.get() / 1_000_000; + sb.append(" Latency: avg ").append(avgMs).append("ms"); + sb.append(", min ").append(minMs).append("ms"); + sb.append(", max ").append(maxMs).append("ms\n"); + } + + // Batch size + long totalBatch = s.totalBatchSize.get(); + if (totalBatch > 0 && inf > 0) { + sb.append(" Batch size: avg ").append(totalBatch / inf).append("\n"); + } + + // Throughput (items/sec) + long totalDur = s.totalDurationNanos.get(); + if (totalBatch > 0 && totalDur > 0) { + double itemsPerSec = (totalBatch * 1_000_000_000.0) / totalDur; + sb.append(" Throughput: ").append(String.format("%.1f", itemsPerSec)).append(" items/sec\n"); + } + + // Tensor elements + long inEl = s.totalInputElements.get(); + long outEl = s.totalOutputElements.get(); + if (inEl > 0 || outEl > 0) { + sb.append(" Tensors: ").append(formatElements(inEl)).append(" in / ") + .append(formatElements(outEl)).append(" out\n"); + } + + sb.append("\n"); + } + + // Device memory + if (!deviceMemory.isEmpty()) { + sb.append("--- Device Memory ---\n"); + for (Map.Entry entry : deviceMemory.entrySet()) { + DeviceMemory dm = entry.getValue(); + sb.append(" ").append(entry.getKey()).append(": "); + sb.append("current ").append(formatBytes(dm.currentBytes.get())); + sb.append(", peak ").append(formatBytes(dm.peakBytes.get())); + sb.append(" (").append(dm.allocCount.get()).append(" allocs, ") + .append(dm.freeCount.get()).append(" frees)\n"); + } + sb.append("\n"); + } + + // Native calls + if (!nativeStats.isEmpty()) { + sb.append("--- Native Calls ---\n"); + long totalNative = 0; + for (Map.Entry entry : nativeStats.entrySet()) { + NativeCallStats ns = entry.getValue(); + long calls = ns.calls.get(); + totalNative += calls; + long avgUs = calls > 0 ? (ns.totalDurationNanos.get() / calls) / 1000 : 0; + sb.append(" ").append(entry.getKey()).append(": ") + .append(calls).append(" calls, avg ").append(avgUs).append("us\n"); + } + sb.append("\n"); + } + + sb.append("--- Totals ---\n"); + sb.append(" Inferences: ").append(totalInf).append("\n"); + + return sb.toString(); + } + + @Override + public String getModelSummary(String modelName) { + // Search by model name suffix + for (Map.Entry entry : modelStats.entrySet()) { + if (entry.getKey().endsWith(modelName) || entry.getKey().equals(modelName)) { + ModelStats s = entry.getValue(); + long inf = s.inferences.get(); + long avgMs = inf > 0 ? (s.totalDurationNanos.get() / inf) / 1_000_000 : 0; + return entry.getKey() + ": " + inf + " inferences, avg " + avgMs + "ms"; + } + } + return "No data for model: " + modelName; + } + + @Override + public long getTotalInferences() { + long total = 0; + for (ModelStats s : modelStats.values()) { + total += s.inferences.get(); + } + return total; + } + + @Override + public long getCurrentGpuMemoryBytes() { + if (deviceMemory.isEmpty()) return -1; + long total = 0; + for (DeviceMemory dm : deviceMemory.values()) { + total += dm.currentBytes.get(); + } + return total; + } + + @Override + public long getPeakGpuMemoryBytes() { + if (deviceMemory.isEmpty()) return -1; + long peak = 0; + for (DeviceMemory dm : deviceMemory.values()) { + peak = Math.max(peak, dm.peakBytes.get()); + } + return peak; + } + + @Override + public long getTotalNativeCalls() { + long total = 0; + for (NativeCallStats ns : nativeStats.values()) { + total += ns.calls.get(); + } + return total; + } + + @Override + public void reset() { + modelStats.clear(); + deviceMemory.clear(); + nativeStats.clear(); + } + + @Override + public void close() { + String summary = getSummary(); + if (!"No GPU/inference activity recorded.".equals(summary)) { + getContext().send(summary); + } + } + + // ==================== Internals ==================== + + private ModelStats getOrCreateModel(String runtime, String modelName) { + String key = runtime + "/" + modelName; + return modelStats.computeIfAbsent(key, k -> new ModelStats()); + } + + private DeviceMemory getOrCreateDevice(String deviceType, int deviceId) { + String key = deviceType + ":" + deviceId; + return deviceMemory.computeIfAbsent(key, k -> new DeviceMemory()); + } + + private static void updateMinMax(AtomicLong min, AtomicLong max, long value) { + long cur; + do { + cur = min.get(); + if (value >= cur) break; + } while (!min.compareAndSet(cur, value)); + do { + cur = max.get(); + if (value <= cur) break; + } while (!max.compareAndSet(cur, value)); + } + + static String formatBytes(long bytes) { + if (bytes < 0) return "-" + formatBytes(-bytes); + if (bytes < 1024) return bytes + "B"; + if (bytes < 1024 * 1024) return (bytes / 1024) + "KB"; + if (bytes < 1024L * 1024 * 1024) return (bytes / (1024 * 1024)) + "MB"; + return String.format("%.1fGB", bytes / (1024.0 * 1024 * 1024)); + } + + static String formatElements(long elements) { + if (elements < 1000) return elements + ""; + if (elements < 1_000_000) return String.format("%.1fK", elements / 1000.0); + if (elements < 1_000_000_000) return String.format("%.1fM", elements / 1_000_000.0); + return String.format("%.1fB", elements / 1_000_000_000.0); + } + + static final class ModelStats { + final AtomicLong inferences = new AtomicLong(); + final AtomicLong totalBatchSize = new AtomicLong(); + final AtomicLong totalInputElements = new AtomicLong(); + final AtomicLong totalOutputElements = new AtomicLong(); + final AtomicLong totalDurationNanos = new AtomicLong(); + final AtomicLong minDurationNanos = new AtomicLong(Long.MAX_VALUE); + final AtomicLong maxDurationNanos = new AtomicLong(0); + final AtomicLong loadCount = new AtomicLong(); + final AtomicLong totalLoadDurationNanos = new AtomicLong(); + volatile String lastDeviceType; + volatile int lastDeviceId; + } + + static final class DeviceMemory { + final AtomicLong currentBytes = new AtomicLong(); + final AtomicLong peakBytes = new AtomicLong(); + final AtomicLong totalAllocated = new AtomicLong(); + final AtomicLong allocCount = new AtomicLong(); + final AtomicLong freeCount = new AtomicLong(); + } + + static final class NativeCallStats { + final AtomicLong calls = new AtomicLong(); + final AtomicLong totalDurationNanos = new AtomicLong(); + } +} diff --git a/btrace-extensions/btrace-gpu-bridge/src/impl/java/org/openjdk/btrace/gpu/InferenceRecordImpl.java b/btrace-extensions/btrace-gpu-bridge/src/impl/java/org/openjdk/btrace/gpu/InferenceRecordImpl.java new file mode 100644 index 00000000..c4cf48f9 --- /dev/null +++ b/btrace-extensions/btrace-gpu-bridge/src/impl/java/org/openjdk/btrace/gpu/InferenceRecordImpl.java @@ -0,0 +1,74 @@ +package org.openjdk.btrace.gpu; + +/** + * ThreadLocal-pooled builder for {@link InferenceRecord}. Zero allocation. + */ +final class InferenceRecordImpl implements InferenceRecord { + + String runtime; + String modelName; + int batchSizeVal; + long inputElem; + long outputElem; + String deviceTypeVal; + int deviceIdVal; + long durationVal; + + private GpuBridgeServiceImpl service; + + InferenceRecordImpl() {} + + InferenceRecordImpl reset(GpuBridgeServiceImpl service, String runtime, String modelName) { + this.service = service; + this.runtime = runtime; + this.modelName = modelName; + this.batchSizeVal = 1; + this.inputElem = 0; + this.outputElem = 0; + this.deviceTypeVal = null; + this.deviceIdVal = 0; + this.durationVal = 0; + return this; + } + + @Override + public InferenceRecord batchSize(int size) { + this.batchSizeVal = size; + return this; + } + + @Override + public InferenceRecord inputElements(long elements) { + this.inputElem = elements; + return this; + } + + @Override + public InferenceRecord outputElements(long elements) { + this.outputElem = elements; + return this; + } + + @Override + public InferenceRecord deviceType(String type) { + this.deviceTypeVal = type; + return this; + } + + @Override + public InferenceRecord deviceId(int id) { + this.deviceIdVal = id; + return this; + } + + @Override + public InferenceRecord duration(long nanos) { + this.durationVal = nanos; + return this; + } + + @Override + public void record() { + service.commitInferenceRecord(this); + } +} diff --git a/btrace-extensions/btrace-gpu-bridge/src/test/java/org/openjdk/btrace/gpu/GpuBridgeServiceTest.java b/btrace-extensions/btrace-gpu-bridge/src/test/java/org/openjdk/btrace/gpu/GpuBridgeServiceTest.java new file mode 100644 index 00000000..521561ef --- /dev/null +++ b/btrace-extensions/btrace-gpu-bridge/src/test/java/org/openjdk/btrace/gpu/GpuBridgeServiceTest.java @@ -0,0 +1,222 @@ +package org.openjdk.btrace.gpu; + +import static org.junit.jupiter.api.Assertions.*; + +import java.util.concurrent.CountDownLatch; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +class GpuBridgeServiceTest { + + private GpuBridgeServiceImpl service; + + @BeforeEach + void setUp() { + service = new GpuBridgeServiceImpl(); + } + + // ==================== Inference recording ==================== + + @Test + void durationOnlyInference() { + service.recordInference("onnx", "resnet50", 10_000_000L); + assertEquals(1, service.getTotalInferences()); + assertTrue(service.getSummary().contains("onnx/resnet50")); + } + + @Test + void inferenceWithBatchSize() { + service.recordInference("djl", "bert", 32, 50_000_000L); + assertEquals(1, service.getTotalInferences()); + assertTrue(service.getSummary().contains("avg 32")); + } + + @Test + void fluentBuilder() { + service.inference("onnx", "yolo-v8") + .batchSize(16) + .inputElements(640L * 640 * 3) + .outputElements(8400L * 84) + .deviceType("cuda") + .deviceId(0) + .duration(25_000_000L) + .record(); + + assertEquals(1, service.getTotalInferences()); + String summary = service.getSummary(); + assertTrue(summary.contains("cuda:0")); + assertTrue(summary.contains("items/sec")); + } + + @Test + void fluentBuilderMinimal() { + service.inference("tensorflow", "mobilenet") + .duration(5_000_000L) + .record(); + + assertEquals(1, service.getTotalInferences()); + } + + @Test + void multipleModelsTrackedSeparately() { + service.recordInference("onnx", "resnet50", 10_000_000L); + service.recordInference("onnx", "bert", 20_000_000L); + service.recordInference("onnx", "resnet50", 12_000_000L); + + assertEquals(3, service.getTotalInferences()); + assertTrue(service.getModelSummary("resnet50").contains("2 inferences")); + assertTrue(service.getModelSummary("bert").contains("1 inferences")); + } + + @Test + void latencyMinMax() { + service.recordInference("onnx", "model", 5_000_000L); + service.recordInference("onnx", "model", 50_000_000L); + service.recordInference("onnx", "model", 10_000_000L); + + String summary = service.getSummary(); + assertTrue(summary.contains("min 5ms")); + assertTrue(summary.contains("max 50ms")); + } + + // ==================== Memory tracking ==================== + + @Test + void memoryAllocAndFree() { + service.recordMemoryAlloc("cuda", 0, 1024 * 1024 * 100); // 100MB + assertEquals(100 * 1024 * 1024L, service.getCurrentGpuMemoryBytes()); + assertEquals(100 * 1024 * 1024L, service.getPeakGpuMemoryBytes()); + + service.recordMemoryFree("cuda", 0, 1024 * 1024 * 50); // free 50MB + assertEquals(50 * 1024 * 1024L, service.getCurrentGpuMemoryBytes()); + assertEquals(100 * 1024 * 1024L, service.getPeakGpuMemoryBytes()); // peak unchanged + } + + @Test + void noMemoryTracked() { + assertEquals(-1, service.getCurrentGpuMemoryBytes()); + assertEquals(-1, service.getPeakGpuMemoryBytes()); + } + + @Test + void memorySummary() { + service.recordMemoryAlloc("cuda", 0, 500 * 1024 * 1024L); + String summary = service.getSummary(); + assertTrue(summary.contains("Device Memory")); + assertTrue(summary.contains("500MB")); + } + + // ==================== Native call tracking ==================== + + @Test + void nativeCalls() { + service.recordNativeCall("cublas", "sgemm", 500_000L); + service.recordNativeCall("cublas", "sgemm", 600_000L); + service.recordNativeCall("cudnn", "conv_forward", 1_000_000L); + + assertEquals(3, service.getTotalNativeCalls()); + String summary = service.getSummary(); + assertTrue(summary.contains("cublas::sgemm")); + assertTrue(summary.contains("2 calls")); + } + + // ==================== Model lifecycle ==================== + + @Test + void modelLoad() { + service.recordModelLoad("onnx", "bert", 2_000_000_000L); // 2 seconds + service.recordInference("onnx", "bert", 10_000_000L); + + String summary = service.getSummary(); + assertTrue(summary.contains("loaded 1x")); + assertTrue(summary.contains("2000ms")); + } + + // ==================== Reporting ==================== + + @Test + void noDataSummary() { + assertEquals("No GPU/inference activity recorded.", service.getSummary()); + } + + @Test + void unknownModelSummary() { + assertEquals("No data for model: unknown", service.getModelSummary("unknown")); + } + + @Test + void formatBytes() { + assertEquals("100B", GpuBridgeServiceImpl.formatBytes(100)); + assertEquals("10KB", GpuBridgeServiceImpl.formatBytes(10 * 1024)); + assertEquals("256MB", GpuBridgeServiceImpl.formatBytes(256 * 1024 * 1024)); + assertEquals("1.5GB", GpuBridgeServiceImpl.formatBytes((long) (1.5 * 1024 * 1024 * 1024))); + } + + @Test + void formatElements() { + assertEquals("500", GpuBridgeServiceImpl.formatElements(500)); + assertEquals("1.5K", GpuBridgeServiceImpl.formatElements(1500)); + assertEquals("2.0M", GpuBridgeServiceImpl.formatElements(2_000_000)); + assertEquals("1.0B", GpuBridgeServiceImpl.formatElements(1_000_000_000)); + } + + @Test + void reset() { + service.recordInference("onnx", "model", 1L); + service.recordMemoryAlloc("cuda", 0, 1024); + service.recordNativeCall("cublas", "fn", 1L); + service.reset(); + + assertEquals(0, service.getTotalInferences()); + assertEquals(-1, service.getCurrentGpuMemoryBytes()); + assertEquals(0, service.getTotalNativeCalls()); + assertEquals("No GPU/inference activity recorded.", service.getSummary()); + } + + @Test + void concurrentInference() throws Exception { + int threads = 8; + int infsPerThread = 1000; + CountDownLatch latch = new CountDownLatch(threads); + + for (int t = 0; t < threads; t++) { + new Thread(() -> { + try { + for (int i = 0; i < infsPerThread; i++) { + service.recordInference("onnx", "model", 32, 1_000_000L); + } + } finally { + latch.countDown(); + } + }).start(); + } + latch.await(); + + assertEquals(threads * infsPerThread, service.getTotalInferences()); + } + + @Test + void concurrentBuilderInference() throws Exception { + int threads = 8; + int infsPerThread = 500; + CountDownLatch latch = new CountDownLatch(threads); + + for (int t = 0; t < threads; t++) { + new Thread(() -> { + try { + for (int i = 0; i < infsPerThread; i++) { + service.inference("djl", "bert") + .batchSize(16) + .duration(2_000_000L) + .record(); + } + } finally { + latch.countDown(); + } + }).start(); + } + latch.await(); + + assertEquals(threads * infsPerThread, service.getTotalInferences()); + } +} diff --git a/btrace-extensions/btrace-llm-trace/build.gradle b/btrace-extensions/btrace-llm-trace/build.gradle new file mode 100644 index 00000000..75f7d67d --- /dev/null +++ b/btrace-extensions/btrace-llm-trace/build.gradle @@ -0,0 +1,33 @@ +plugins { + id 'org.openjdk.btrace.extension' + alias(libs.plugins.shadow) +} + +java { + sourceCompatibility = 8 + targetCompatibility = 8 +} + +compileJava { + javaCompiler = javaToolchains.compilerFor { + languageVersion.set(JavaLanguageVersion.of(11)) + } +} + +btraceExtension { + id = 'btrace-llm-trace' + name = 'BTrace LLM Tracing' + description = 'LLM inference observability: token counts, latency, cost tracking for Java LLM SDKs' + services = ['org.openjdk.btrace.llm.LlmTraceService'] +} + +dependencies { + apiCompileOnly project(':btrace-core') + implImplementation project(':btrace-core') + + testImplementation libs.junit.jupiter +} + +test { + useJUnitPlatform() +} diff --git a/btrace-extensions/btrace-llm-trace/src/api/java/org/openjdk/btrace/llm/CallRecord.java b/btrace-extensions/btrace-llm-trace/src/api/java/org/openjdk/btrace/llm/CallRecord.java new file mode 100644 index 00000000..9dda73e3 --- /dev/null +++ b/btrace-extensions/btrace-llm-trace/src/api/java/org/openjdk/btrace/llm/CallRecord.java @@ -0,0 +1,73 @@ +package org.openjdk.btrace.llm; + +/** + * Fluent builder for recording an LLM API call. + * + *

Obtain via {@link LlmTraceService#call(String)}. All setters are optional except + * {@link #duration(long)} — if omitted, duration defaults to 0. + * + *

Call {@link #record()} to commit the metrics. + * + *

Allocation-free: Instances are pooled per-thread internally, + * so calling {@code call()} does not allocate on the heap. Safe for hot-path + * instrumentation. However, the returned reference must not be stored or shared + * across threads — use it inline and call {@link #record()} immediately. + * + *

+ * llm.call("claude-sonnet-4-20250514")
+ *     .provider("anthropic")
+ *     .inputTokens(1500)
+ *     .outputTokens(300)
+ *     .cacheReadTokens(800)
+ *     .duration(durationNanos)
+ *     .record();
+ * 
+ */ +public interface CallRecord { + + /** Sets the provider name (e.g. "openai", "anthropic", "google"). */ + CallRecord provider(String provider); + + /** Sets the number of input/prompt tokens. */ + CallRecord inputTokens(int tokens); + + /** Sets the number of output/completion tokens. */ + CallRecord outputTokens(int tokens); + + /** + * Sets the number of cache-read input tokens (prompt caching). + * Anthropic: {@code usage.cache_read_input_tokens}. + * OpenAI: {@code usage.prompt_tokens_details.cached_tokens}. + */ + CallRecord cacheReadTokens(int tokens); + + /** + * Sets the number of cache-creation input tokens. + * Anthropic: {@code usage.cache_creation_input_tokens}. + */ + CallRecord cacheCreationTokens(int tokens); + + /** Marks this call as a streaming response. */ + CallRecord streaming(); + + /** + * Sets the time-to-first-token for streaming calls. + * Only meaningful when {@link #streaming()} is set. + * + * @param nanos time from request start to first token, in nanoseconds + */ + CallRecord timeToFirstToken(long nanos); + + /** + * Sets the total call duration. + * + * @param nanos duration in nanoseconds (typically from {@code @Duration}) + */ + CallRecord duration(long nanos); + + /** + * Commits this call record to the trace service. + * Must be called exactly once to record the metrics. + */ + void record(); +} diff --git a/btrace-extensions/btrace-llm-trace/src/api/java/org/openjdk/btrace/llm/LlmTraceService.java b/btrace-extensions/btrace-llm-trace/src/api/java/org/openjdk/btrace/llm/LlmTraceService.java new file mode 100644 index 00000000..3a3519e3 --- /dev/null +++ b/btrace-extensions/btrace-llm-trace/src/api/java/org/openjdk/btrace/llm/LlmTraceService.java @@ -0,0 +1,155 @@ +package org.openjdk.btrace.llm; + +import org.openjdk.btrace.core.extensions.ServiceDescriptor; + +/** + * LLM inference observability service for BTrace scripts. + * + *

Records LLM API call metrics — token counts, latencies, costs — and provides + * aggregated statistics. Scripts use {@code @OnMethod} to intercept LLM SDK calls + * and feed data into this service. + * + *

Thread-safe. All methods can be called concurrently from instrumented threads. + * + *

Quick start — latency only (works with any SDK)

+ *
+ * @Injected private static LlmTraceService llm;
+ *
+ * @OnMethod(clazz = "+dev.langchain4j.model.chat.ChatLanguageModel",
+ *           method = "generate", location = @Location(Kind.RETURN))
+ * public static void onChat(@ProbeClassName String cls, @Duration long dur) {
+ *     llm.recordCall(cls, dur);
+ * }
+ * 
+ * + *

Full metrics with fluent builder

+ *
+ * llm.call("claude-sonnet-4-20250514")
+ *     .provider("anthropic")
+ *     .inputTokens(1500)
+ *     .outputTokens(300)
+ *     .cacheReadTokens(800)
+ *     .duration(durationNanos)
+ *     .record();
+ * 
+ */ +@ServiceDescriptor +public interface LlmTraceService { + + // ==================== Simple recording ==================== + + /** + * Records an LLM call with only latency (no token info). + * Use this when token counts aren't easily extractable. + * + * @param model model identifier or class name + * @param durationNanos call duration in nanoseconds + */ + void recordCall(String model, long durationNanos); + + /** + * Records an LLM call with token counts. + * + * @param model model identifier (e.g. "gpt-4o", "claude-sonnet-4-20250514") + * @param inputTokens number of input/prompt tokens + * @param outputTokens number of output/completion tokens + * @param durationNanos call duration in nanoseconds + */ + void recordCall(String model, int inputTokens, int outputTokens, long durationNanos); + + // ==================== Fluent builder ==================== + + /** + * Starts a fluent call record for the given model. + * + *

The returned builder is allocation-free (ThreadLocal-pooled). It is safe + * to use on hot paths. The builder must be used inline on the calling thread + * and {@link CallRecord#record()} must be called before the next {@code call()}. + * + *

+   * llm.call("claude-sonnet-4-20250514")
+   *     .provider("anthropic")
+   *     .inputTokens(1500)
+   *     .outputTokens(300)
+   *     .cacheReadTokens(800)
+   *     .streaming()
+   *     .timeToFirstToken(200_000_000L)
+   *     .duration(durationNanos)
+   *     .record();
+   * 
+ * + * @param model model identifier + * @return a call record builder (thread-local, do not store) + */ + CallRecord call(String model); + + // ==================== Specialized recording ==================== + + /** + * Records an embedding API call. + * + * @param model embedding model identifier (e.g. "text-embedding-3-small") + * @param tokenCount number of tokens embedded + * @param durationNanos call duration in nanoseconds + */ + void recordEmbedding(String model, int tokenCount, long durationNanos); + + /** + * Records a tool/function call invocation by the LLM. + * + * @param model model that made the tool call + * @param toolName name of the tool/function called + */ + void recordToolUse(String model, String toolName); + + /** + * Records a failed LLM API call. + * + * @param model model identifier + * @param errorType error class name or HTTP status code + * @param durationNanos call duration before failure + */ + void recordError(String model, String errorType, long durationNanos); + + // ==================== Reporting ==================== + + /** + * Returns a formatted summary of all recorded metrics. + * Includes per-model token counts, latency stats, cost estimates, + * cache hit rates, tool use, and error breakdown. + * + * @return multi-line summary string + */ + String getSummary(); + + /** + * Returns a one-line summary for a specific model. + * + * @param model model identifier + * @return summary string, or "No data" if none recorded + */ + String getModelSummary(String model); + + /** + * Returns total estimated cost across all models in USD. + * Uses built-in pricing table. Returns -1 if all models are unknown. + * + * @return estimated cost in USD, or -1 if pricing unknown + */ + double getEstimatedCostUsd(); + + /** Returns total number of chat completion calls recorded. */ + long getTotalCalls(); + + /** Returns total input tokens across all calls. */ + long getTotalInputTokens(); + + /** Returns total output tokens across all calls. */ + long getTotalOutputTokens(); + + /** Returns total number of embedding calls recorded. */ + long getTotalEmbeddingCalls(); + + /** Resets all collected metrics. */ + void reset(); +} diff --git a/btrace-extensions/btrace-llm-trace/src/impl/java/org/openjdk/btrace/llm/CallRecordImpl.java b/btrace-extensions/btrace-llm-trace/src/impl/java/org/openjdk/btrace/llm/CallRecordImpl.java new file mode 100644 index 00000000..591be372 --- /dev/null +++ b/btrace-extensions/btrace-llm-trace/src/impl/java/org/openjdk/btrace/llm/CallRecordImpl.java @@ -0,0 +1,98 @@ +package org.openjdk.btrace.llm; + +/** + * Mutable builder implementing the {@link CallRecord} fluent API. + * + *

Instances are pooled per-thread via {@link ThreadLocal} to avoid + * heap allocation on every {@code call()} invocation. This makes the + * builder safe for hot-path instrumentation — no garbage is produced. + * + *

Not thread-safe — intended to be created, configured, and recorded + * within a single BTrace handler method invocation on the same thread. + */ +final class CallRecordImpl implements CallRecord { + + String model; + String providerVal; + int inputTok; + int outputTok; + int cacheReadTok; + int cacheCreateTok; + boolean isStreaming; + long ttftVal; + long durationVal; + + private LlmTraceServiceImpl service; + + CallRecordImpl() { + // Created once per thread via ThreadLocal + } + + /** Resets all fields and binds this record to a new call. */ + CallRecordImpl reset(LlmTraceServiceImpl service, String model) { + this.service = service; + this.model = model; + this.providerVal = null; + this.inputTok = 0; + this.outputTok = 0; + this.cacheReadTok = 0; + this.cacheCreateTok = 0; + this.isStreaming = false; + this.ttftVal = 0; + this.durationVal = 0; + return this; + } + + @Override + public CallRecord provider(String provider) { + this.providerVal = provider; + return this; + } + + @Override + public CallRecord inputTokens(int tokens) { + this.inputTok = tokens; + return this; + } + + @Override + public CallRecord outputTokens(int tokens) { + this.outputTok = tokens; + return this; + } + + @Override + public CallRecord cacheReadTokens(int tokens) { + this.cacheReadTok = tokens; + return this; + } + + @Override + public CallRecord cacheCreationTokens(int tokens) { + this.cacheCreateTok = tokens; + return this; + } + + @Override + public CallRecord streaming() { + this.isStreaming = true; + return this; + } + + @Override + public CallRecord timeToFirstToken(long nanos) { + this.ttftVal = nanos; + return this; + } + + @Override + public CallRecord duration(long nanos) { + this.durationVal = nanos; + return this; + } + + @Override + public void record() { + service.commitCallRecord(this); + } +} diff --git a/btrace-extensions/btrace-llm-trace/src/impl/java/org/openjdk/btrace/llm/LlmTraceServiceImpl.java b/btrace-extensions/btrace-llm-trace/src/impl/java/org/openjdk/btrace/llm/LlmTraceServiceImpl.java new file mode 100644 index 00000000..6bff6af8 --- /dev/null +++ b/btrace-extensions/btrace-llm-trace/src/impl/java/org/openjdk/btrace/llm/LlmTraceServiceImpl.java @@ -0,0 +1,443 @@ +package org.openjdk.btrace.llm; + +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicLong; +import org.openjdk.btrace.core.extensions.Extension; + +/** + * Thread-safe LLM call tracing with lock-free per-model statistics. + * Zero external dependencies. + */ +public final class LlmTraceServiceImpl extends Extension implements LlmTraceService { + + private final Map modelStats = new ConcurrentHashMap<>(); + private final Map embeddingStats = new ConcurrentHashMap<>(); + private final Map toolUseCounts = new ConcurrentHashMap<>(); + private final Map errorCounts = new ConcurrentHashMap<>(); + + // ==================== Simple recording ==================== + + @Override + public void recordCall(String model, long durationNanos) { + ModelStats stats = getOrCreate(modelStats, model); + stats.calls.incrementAndGet(); + stats.totalDurationNanos.addAndGet(durationNanos); + updateMinMax(stats.minDurationNanos, stats.maxDurationNanos, durationNanos); + } + + @Override + public void recordCall(String model, int inputTokens, int outputTokens, long durationNanos) { + ModelStats stats = getOrCreate(modelStats, model); + stats.calls.incrementAndGet(); + stats.inputTokens.addAndGet(inputTokens); + stats.outputTokens.addAndGet(outputTokens); + stats.totalDurationNanos.addAndGet(durationNanos); + updateMinMax(stats.minDurationNanos, stats.maxDurationNanos, durationNanos); + } + + // ==================== Fluent builder ==================== + + /** + * ThreadLocal-pooled builder — one CallRecordImpl per thread, reused across calls. + * Eliminates per-call heap allocation, making the builder safe for hot paths. + */ + private final ThreadLocal callRecordPool = + ThreadLocal.withInitial(CallRecordImpl::new); + + @Override + public CallRecord call(String model) { + return callRecordPool.get().reset(this, model); + } + + void commitCallRecord(CallRecordImpl rec) { + ModelStats stats = getOrCreate(modelStats, rec.model); + if (rec.providerVal != null) { + stats.provider = rec.providerVal; + } + stats.calls.incrementAndGet(); + stats.inputTokens.addAndGet(rec.inputTok); + stats.outputTokens.addAndGet(rec.outputTok); + stats.cacheReadTokens.addAndGet(rec.cacheReadTok); + stats.cacheCreationTokens.addAndGet(rec.cacheCreateTok); + stats.totalDurationNanos.addAndGet(rec.durationVal); + updateMinMax(stats.minDurationNanos, stats.maxDurationNanos, rec.durationVal); + if (rec.isStreaming) { + stats.streamingCalls.incrementAndGet(); + stats.totalTimeToFirstToken.addAndGet(rec.ttftVal); + } + } + + // ==================== Specialized recording ==================== + + @Override + public void recordEmbedding(String model, int tokenCount, long durationNanos) { + ModelStats stats = getOrCreate(embeddingStats, model); + stats.calls.incrementAndGet(); + stats.inputTokens.addAndGet(tokenCount); + stats.totalDurationNanos.addAndGet(durationNanos); + updateMinMax(stats.minDurationNanos, stats.maxDurationNanos, durationNanos); + } + + @Override + public void recordToolUse(String model, String toolName) { + String key = model + "::" + toolName; + toolUseCounts.computeIfAbsent(key, k -> new AtomicLong()).incrementAndGet(); + getOrCreate(modelStats, model).toolCalls.incrementAndGet(); + } + + @Override + public void recordError(String model, String errorType, long durationNanos) { + String key = model + "::" + errorType; + errorCounts.computeIfAbsent(key, k -> new AtomicLong()).incrementAndGet(); + getOrCreate(modelStats, model).errors.incrementAndGet(); + } + + // ==================== Reporting ==================== + + @Override + public String getSummary() { + if (modelStats.isEmpty() && embeddingStats.isEmpty()) { + return "No LLM calls recorded."; + } + + StringBuilder sb = new StringBuilder(); + sb.append("=== LLM Trace Summary ===\n\n"); + + long totalCalls = 0; + long totalIn = 0; + long totalOut = 0; + double totalCost = 0; + + // Chat completions + for (Map.Entry entry : modelStats.entrySet()) { + String model = entry.getKey(); + ModelStats s = entry.getValue(); + long calls = s.calls.get(); + long inTok = s.inputTokens.get(); + long outTok = s.outputTokens.get(); + long cacheRead = s.cacheReadTokens.get(); + long cacheCreate = s.cacheCreationTokens.get(); + + totalCalls += calls; + totalIn += inTok; + totalOut += outTok; + + sb.append("Model: ").append(model); + if (!"unknown".equals(s.provider)) { + sb.append(" (").append(s.provider).append(")"); + } + sb.append("\n"); + + // Calls + sb.append(" Calls: ").append(calls); + long streaming = s.streamingCalls.get(); + if (streaming > 0) { + sb.append(" (").append(streaming).append(" streaming)"); + } + sb.append("\n"); + + // Tokens + if (inTok > 0 || outTok > 0) { + sb.append(" Tokens: ").append(inTok).append(" in / ").append(outTok).append(" out"); + if (calls > 0) { + sb.append(" (avg ").append(inTok / calls).append("/").append(outTok / calls).append(")"); + } + sb.append("\n"); + } + + // Cache + if (cacheRead > 0 || cacheCreate > 0) { + sb.append(" Cache: "); + if (cacheRead > 0) { + sb.append(cacheRead).append(" read"); + // Show cache hit rate relative to total input + if (inTok > 0) { + long hitPct = (cacheRead * 100) / inTok; + sb.append(" (").append(hitPct).append("% hit)"); + } + } + if (cacheCreate > 0) { + if (cacheRead > 0) sb.append(", "); + sb.append(cacheCreate).append(" created"); + } + sb.append("\n"); + } + + // Latency + if (calls > 0) { + long avgMs = (s.totalDurationNanos.get() / calls) / 1_000_000; + long minMs = s.minDurationNanos.get() / 1_000_000; + long maxMs = s.maxDurationNanos.get() / 1_000_000; + sb.append(" Latency: avg ").append(avgMs).append("ms"); + sb.append(", min ").append(minMs).append("ms"); + sb.append(", max ").append(maxMs).append("ms\n"); + } + + // TTFT + if (streaming > 0) { + long avgTtft = (s.totalTimeToFirstToken.get() / streaming) / 1_000_000; + sb.append(" TTFT (avg): ").append(avgTtft).append("ms\n"); + } + + // Tool calls + long tc = s.toolCalls.get(); + if (tc > 0) { + sb.append(" Tool calls: ").append(tc).append("\n"); + } + + // Errors + long errs = s.errors.get(); + if (errs > 0) { + sb.append(" Errors: ").append(errs).append("\n"); + } + + // Cost + double cost = estimateCost(model, inTok, outTok, cacheRead); + if (cost >= 0) { + totalCost += cost; + sb.append(" Est. cost: $").append(formatCost(cost)); + if (cacheRead > 0) { + double uncachedCost = estimateCost(model, inTok + cacheRead, outTok, 0); + if (uncachedCost > cost) { + sb.append(" (saved $").append(formatCost(uncachedCost - cost)).append(" via cache)"); + } + } + sb.append("\n"); + } + sb.append("\n"); + } + + // Embeddings + if (!embeddingStats.isEmpty()) { + sb.append("--- Embeddings ---\n"); + for (Map.Entry entry : embeddingStats.entrySet()) { + ModelStats s = entry.getValue(); + long calls = s.calls.get(); + long tokens = s.inputTokens.get(); + long avgMs = calls > 0 ? (s.totalDurationNanos.get() / calls) / 1_000_000 : 0; + sb.append(" ").append(entry.getKey()).append(": ") + .append(calls).append(" calls, ") + .append(tokens).append(" tokens, avg ") + .append(avgMs).append("ms\n"); + } + sb.append("\n"); + } + + // Tool use breakdown + if (!toolUseCounts.isEmpty()) { + sb.append("--- Tool Use ---\n"); + for (Map.Entry entry : toolUseCounts.entrySet()) { + sb.append(" ").append(entry.getKey()).append(": ") + .append(entry.getValue().get()).append("\n"); + } + sb.append("\n"); + } + + // Error breakdown + if (!errorCounts.isEmpty()) { + sb.append("--- Errors ---\n"); + for (Map.Entry entry : errorCounts.entrySet()) { + sb.append(" ").append(entry.getKey()).append(": ") + .append(entry.getValue().get()).append("\n"); + } + sb.append("\n"); + } + + sb.append("--- Totals ---\n"); + sb.append(" Calls: ").append(totalCalls).append("\n"); + if (totalIn > 0 || totalOut > 0) { + sb.append(" Tokens: ").append(totalIn).append(" in / ").append(totalOut).append(" out\n"); + } + if (totalCost > 0) { + sb.append(" Est. total cost: $").append(formatCost(totalCost)).append("\n"); + } + + return sb.toString(); + } + + @Override + public String getModelSummary(String model) { + ModelStats s = modelStats.get(model); + if (s == null) { + return "No data for model: " + model; + } + long calls = s.calls.get(); + long inTok = s.inputTokens.get(); + long outTok = s.outputTokens.get(); + long avgMs = calls > 0 ? (s.totalDurationNanos.get() / calls) / 1_000_000 : 0; + + StringBuilder sb = new StringBuilder(); + sb.append(model).append(": ").append(calls).append(" calls"); + if (inTok > 0 || outTok > 0) { + sb.append(", ").append(inTok).append("/").append(outTok).append(" tokens (in/out)"); + } + sb.append(", avg ").append(avgMs).append("ms"); + return sb.toString(); + } + + @Override + public double getEstimatedCostUsd() { + double total = 0; + boolean anyKnown = false; + for (Map.Entry entry : modelStats.entrySet()) { + ModelStats s = entry.getValue(); + double cost = estimateCost(entry.getKey(), + s.inputTokens.get(), s.outputTokens.get(), s.cacheReadTokens.get()); + if (cost >= 0) { + total += cost; + anyKnown = true; + } + } + return anyKnown ? total : -1; + } + + @Override + public long getTotalCalls() { + long total = 0; + for (ModelStats s : modelStats.values()) { + total += s.calls.get(); + } + return total; + } + + @Override + public long getTotalInputTokens() { + long total = 0; + for (ModelStats s : modelStats.values()) { + total += s.inputTokens.get(); + } + return total; + } + + @Override + public long getTotalOutputTokens() { + long total = 0; + for (ModelStats s : modelStats.values()) { + total += s.outputTokens.get(); + } + return total; + } + + @Override + public long getTotalEmbeddingCalls() { + long total = 0; + for (ModelStats s : embeddingStats.values()) { + total += s.calls.get(); + } + return total; + } + + @Override + public void reset() { + modelStats.clear(); + embeddingStats.clear(); + toolUseCounts.clear(); + errorCounts.clear(); + } + + @Override + public void close() { + String summary = getSummary(); + if (!"No LLM calls recorded.".equals(summary)) { + getContext().send(summary); + } + } + + // ==================== Internals ==================== + + private static ModelStats getOrCreate(Map map, String key) { + return map.computeIfAbsent(key, k -> new ModelStats()); + } + + private static void updateMinMax(AtomicLong min, AtomicLong max, long value) { + long cur; + do { + cur = min.get(); + if (value >= cur) break; + } while (!min.compareAndSet(cur, value)); + do { + cur = max.get(); + if (value <= cur) break; + } while (!max.compareAndSet(cur, value)); + } + + /** + * Estimates cost in USD. Cache-read tokens are priced at ~10% of input rate + * for models that support caching. + */ + static double estimateCost(String model, long inputTokens, long outputTokens, + long cacheReadTokens) { + double inputPer1M = -1; + double outputPer1M = -1; + double cacheReadPer1M = -1; + + String m = model.toLowerCase(); + + // Anthropic Claude + if (m.contains("claude") && m.contains("opus")) { + inputPer1M = 15.0; outputPer1M = 75.0; cacheReadPer1M = 1.50; + } else if (m.contains("claude") && m.contains("sonnet")) { + inputPer1M = 3.0; outputPer1M = 15.0; cacheReadPer1M = 0.30; + } else if (m.contains("claude") && m.contains("haiku")) { + inputPer1M = 0.80; outputPer1M = 4.0; cacheReadPer1M = 0.08; + } + // OpenAI GPT + else if (m.contains("gpt-4o-mini")) { + inputPer1M = 0.15; outputPer1M = 0.60; cacheReadPer1M = 0.075; + } else if (m.contains("gpt-4o")) { + inputPer1M = 2.50; outputPer1M = 10.0; cacheReadPer1M = 1.25; + } else if (m.contains("gpt-4") && m.contains("turbo")) { + inputPer1M = 10.0; outputPer1M = 30.0; + } else if (m.contains("gpt-4")) { + inputPer1M = 30.0; outputPer1M = 60.0; + } else if (m.contains("gpt-3.5")) { + inputPer1M = 0.50; outputPer1M = 1.50; + } else if (m.contains("o1-mini")) { + inputPer1M = 3.0; outputPer1M = 12.0; cacheReadPer1M = 1.50; + } else if (m.contains("o1")) { + inputPer1M = 15.0; outputPer1M = 60.0; cacheReadPer1M = 7.50; + } + // Google Gemini + else if (m.contains("gemini") && m.contains("pro")) { + inputPer1M = 1.25; outputPer1M = 5.0; + } else if (m.contains("gemini") && m.contains("flash")) { + inputPer1M = 0.075; outputPer1M = 0.30; + } + + if (inputPer1M < 0) { + return -1; + } + + double cost = (inputTokens * inputPer1M / 1_000_000.0) + + (outputTokens * outputPer1M / 1_000_000.0); + if (cacheReadTokens > 0 && cacheReadPer1M > 0) { + cost += (cacheReadTokens * cacheReadPer1M / 1_000_000.0); + } + return cost; + } + + static String formatCost(double cost) { + if (cost < 0.01) { + return String.format("%.4f", cost); + } + return String.format("%.2f", cost); + } + + /** Lock-free per-model statistics. */ + static final class ModelStats { + volatile String provider = "unknown"; + final AtomicLong calls = new AtomicLong(); + final AtomicLong streamingCalls = new AtomicLong(); + final AtomicLong inputTokens = new AtomicLong(); + final AtomicLong outputTokens = new AtomicLong(); + final AtomicLong cacheReadTokens = new AtomicLong(); + final AtomicLong cacheCreationTokens = new AtomicLong(); + final AtomicLong totalDurationNanos = new AtomicLong(); + final AtomicLong minDurationNanos = new AtomicLong(Long.MAX_VALUE); + final AtomicLong maxDurationNanos = new AtomicLong(0); + final AtomicLong totalTimeToFirstToken = new AtomicLong(); + final AtomicLong toolCalls = new AtomicLong(); + final AtomicLong errors = new AtomicLong(); + } +} diff --git a/btrace-extensions/btrace-llm-trace/src/test/java/org/openjdk/btrace/llm/LlmTraceServiceTest.java b/btrace-extensions/btrace-llm-trace/src/test/java/org/openjdk/btrace/llm/LlmTraceServiceTest.java new file mode 100644 index 00000000..8951ca9f --- /dev/null +++ b/btrace-extensions/btrace-llm-trace/src/test/java/org/openjdk/btrace/llm/LlmTraceServiceTest.java @@ -0,0 +1,335 @@ +package org.openjdk.btrace.llm; + +import static org.junit.jupiter.api.Assertions.*; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +class LlmTraceServiceTest { + + private LlmTraceServiceImpl service; + + @BeforeEach + void setUp() { + service = new LlmTraceServiceImpl(); + } + + // ==================== Simple recording ==================== + + @Test + void recordCallDurationOnly() { + service.recordCall("gpt-4o", 500_000_000L); + service.recordCall("gpt-4o", 300_000_000L); + + assertEquals(2, service.getTotalCalls()); + assertEquals(0, service.getTotalInputTokens(), "No tokens recorded"); + String summary = service.getModelSummary("gpt-4o"); + assertTrue(summary.contains("2 calls")); + assertTrue(summary.contains("400ms"), "avg of 500ms and 300ms"); + } + + @Test + void recordCallWithTokens() { + service.recordCall("claude-sonnet-4-20250514", 1000, 200, 500_000_000L); + service.recordCall("claude-sonnet-4-20250514", 800, 300, 400_000_000L); + + assertEquals(2, service.getTotalCalls()); + assertEquals(1800, service.getTotalInputTokens()); + assertEquals(500, service.getTotalOutputTokens()); + } + + // ==================== Fluent builder ==================== + + @Test + void fluentBuilderBasic() { + service.call("gpt-4o") + .inputTokens(1000) + .outputTokens(200) + .duration(500_000_000L) + .record(); + + assertEquals(1, service.getTotalCalls()); + assertEquals(1000, service.getTotalInputTokens()); + assertEquals(200, service.getTotalOutputTokens()); + } + + @Test + void fluentBuilderWithProvider() { + service.call("claude-sonnet-4-20250514") + .provider("anthropic") + .inputTokens(1500) + .outputTokens(300) + .duration(800_000_000L) + .record(); + + String summary = service.getSummary(); + assertTrue(summary.contains("anthropic")); + assertTrue(summary.contains("claude-sonnet")); + } + + @Test + void fluentBuilderWithCache() { + service.call("claude-sonnet-4-20250514") + .inputTokens(500) + .outputTokens(200) + .cacheReadTokens(1000) + .cacheCreationTokens(200) + .duration(300_000_000L) + .record(); + + String summary = service.getSummary(); + assertTrue(summary.contains("Cache:"), "Should show cache section"); + assertTrue(summary.contains("1000 read"), "Should show cache read tokens"); + assertTrue(summary.contains("200 created"), "Should show cache creation tokens"); + assertTrue(summary.contains("saved $"), "Should show cache savings"); + } + + @Test + void fluentBuilderStreaming() { + service.call("gpt-4o") + .inputTokens(2000) + .outputTokens(500) + .streaming() + .timeToFirstToken(200_000_000L) + .duration(3_000_000_000L) + .record(); + + String summary = service.getSummary(); + assertTrue(summary.contains("streaming")); + assertTrue(summary.contains("TTFT")); + } + + @Test + void fluentBuilderDurationOnly() { + // Minimal usage — just model + duration + service.call("some-model").duration(100_000_000L).record(); + + assertEquals(1, service.getTotalCalls()); + } + + // ==================== Embeddings ==================== + + @Test + void recordEmbedding() { + service.recordEmbedding("text-embedding-3-small", 500, 50_000_000L); + service.recordEmbedding("text-embedding-3-small", 300, 30_000_000L); + + assertEquals(2, service.getTotalEmbeddingCalls()); + assertEquals(0, service.getTotalCalls(), "Embeddings don't count as chat calls"); + + String summary = service.getSummary(); + assertTrue(summary.contains("Embeddings")); + assertTrue(summary.contains("text-embedding-3-small")); + assertTrue(summary.contains("2 calls")); + assertTrue(summary.contains("800 tokens")); + } + + // ==================== Tool use ==================== + + @Test + void recordToolUse() { + service.recordCall("claude-sonnet-4-20250514", 1000, 200, 500_000_000L); + service.recordToolUse("claude-sonnet-4-20250514", "search_web"); + service.recordToolUse("claude-sonnet-4-20250514", "search_web"); + service.recordToolUse("claude-sonnet-4-20250514", "run_code"); + + String summary = service.getSummary(); + assertTrue(summary.contains("Tool calls: 3")); + assertTrue(summary.contains("Tool Use")); + assertTrue(summary.contains("search_web")); + } + + // ==================== Errors ==================== + + @Test + void recordError() { + service.recordCall("gpt-4o", 1000, 0, 100_000_000L); + service.recordError("gpt-4o", "RateLimitException", 50_000_000L); + + String summary = service.getSummary(); + assertTrue(summary.contains("Errors: 1")); + assertTrue(summary.contains("RateLimitException")); + } + + // ==================== Model summary ==================== + + @Test + void getModelSummaryUnknown() { + assertTrue(service.getModelSummary("nonexistent").contains("No data")); + } + + @Test + void getModelSummaryDurationOnly() { + service.recordCall("my-model", 500_000_000L); + String result = service.getModelSummary("my-model"); + assertTrue(result.contains("1 calls")); + assertTrue(result.contains("500ms")); + // Should NOT show "0/0 tokens" + assertFalse(result.contains("0/0")); + } + + @Test + void getModelSummaryWithTokens() { + service.recordCall("gpt-4o", 1000, 200, 500_000_000L); + String result = service.getModelSummary("gpt-4o"); + assertTrue(result.contains("1000/200 tokens")); + } + + // ==================== Cost estimation ==================== + + @Test + void estimateCostClaudeSonnet() { + // Sonnet: $3/1M input, $15/1M output + double cost = LlmTraceServiceImpl.estimateCost("claude-sonnet-4-20250514", + 1_000_000, 100_000, 0); + assertEquals(3.0 + 1.5, cost, 0.01); + } + + @Test + void estimateCostWithCacheRead() { + // Sonnet: $3/1M input, $15/1M output, $0.30/1M cache-read + double cost = LlmTraceServiceImpl.estimateCost("claude-sonnet-4-20250514", + 500_000, 100_000, 500_000); + double expected = (500_000 * 3.0 / 1e6) + (100_000 * 15.0 / 1e6) + (500_000 * 0.30 / 1e6); + assertEquals(expected, cost, 0.001); + } + + @Test + void estimateCostGpt4o() { + double cost = LlmTraceServiceImpl.estimateCost("gpt-4o", 1_000_000, 1_000_000, 0); + assertEquals(2.50 + 10.0, cost, 0.01); + } + + @Test + void estimateCostUnknownModel() { + assertEquals(-1, LlmTraceServiceImpl.estimateCost("my-custom-model", 1000, 1000, 0)); + } + + @Test + void estimatedCostUsdAcrossModels() { + service.recordCall("claude-sonnet-4-20250514", 1_000_000, 0, 1_000_000_000L); + service.recordCall("gpt-4o", 1_000_000, 0, 1_000_000_000L); + + double cost = service.getEstimatedCostUsd(); + assertEquals(5.50, cost, 0.01); + } + + @Test + void estimatedCostUsdUnknownModels() { + service.recordCall("my-custom-model", 1000, 500, 100_000_000L); + assertEquals(-1, service.getEstimatedCostUsd()); + } + + // ==================== Latency tracking ==================== + + @Test + void latencyMinMax() { + service.recordCall("gpt-4o", 100, 50, 100_000_000L); + service.recordCall("gpt-4o", 100, 50, 500_000_000L); + service.recordCall("gpt-4o", 100, 50, 200_000_000L); + + String summary = service.getSummary(); + assertTrue(summary.contains("min 100ms")); + assertTrue(summary.contains("max 500ms")); + } + + // ==================== Reset ==================== + + @Test + void resetClearsEverything() { + service.recordCall("gpt-4o", 1000, 200, 500_000_000L); + service.recordEmbedding("text-embedding-3-small", 100, 10_000_000L); + service.recordToolUse("gpt-4o", "search"); + service.recordError("gpt-4o", "Timeout", 100_000_000L); + + service.reset(); + + assertEquals(0, service.getTotalCalls()); + assertEquals(0, service.getTotalInputTokens()); + assertEquals(0, service.getTotalOutputTokens()); + assertEquals(0, service.getTotalEmbeddingCalls()); + assertEquals("No LLM calls recorded.", service.getSummary()); + } + + // ==================== Edge cases ==================== + + @Test + void summaryEmptyWhenNoCalls() { + assertEquals("No LLM calls recorded.", service.getSummary()); + } + + @Test + void summaryOmitsTokenLineWhenNoTokens() { + service.recordCall("my-model", 500_000_000L); + String summary = service.getSummary(); + // Should show the model and latency but not "0 in / 0 out" + assertTrue(summary.contains("my-model")); + assertFalse(summary.contains("0 in / 0 out")); + } + + @Test + void cacheHitRateCalculation() { + service.call("claude-sonnet-4-20250514") + .inputTokens(200) + .outputTokens(100) + .cacheReadTokens(800) + .duration(100_000_000L) + .record(); + + String summary = service.getSummary(); + // 800 cache reads out of 200 input = 400% is wrong semantically, + // so check that the cache section exists + assertTrue(summary.contains("Cache:")); + assertTrue(summary.contains("800 read")); + } + + // ==================== Concurrency ==================== + + @Test + void concurrentRecording() throws InterruptedException { + int threads = 8; + int callsPerThread = 1000; + Thread[] workers = new Thread[threads]; + + for (int t = 0; t < threads; t++) { + workers[t] = new Thread(() -> { + for (int i = 0; i < callsPerThread; i++) { + service.recordCall("gpt-4o", 100, 50, 10_000_000L); + } + }); + workers[t].start(); + } + for (Thread w : workers) { + w.join(); + } + + assertEquals(threads * callsPerThread, service.getTotalCalls()); + assertEquals(threads * callsPerThread * 100L, service.getTotalInputTokens()); + } + + @Test + void concurrentBuilderRecording() throws InterruptedException { + int threads = 4; + int callsPerThread = 500; + Thread[] workers = new Thread[threads]; + + for (int t = 0; t < threads; t++) { + workers[t] = new Thread(() -> { + for (int i = 0; i < callsPerThread; i++) { + service.call("claude-sonnet-4-20250514") + .provider("anthropic") + .inputTokens(100) + .outputTokens(50) + .duration(10_000_000L) + .record(); + } + }); + workers[t].start(); + } + for (Thread w : workers) { + w.join(); + } + + assertEquals(threads * callsPerThread, service.getTotalCalls()); + } +} diff --git a/btrace-extensions/btrace-rag-quality/build.gradle b/btrace-extensions/btrace-rag-quality/build.gradle new file mode 100644 index 00000000..f22b500a --- /dev/null +++ b/btrace-extensions/btrace-rag-quality/build.gradle @@ -0,0 +1,33 @@ +plugins { + id 'org.openjdk.btrace.extension' + alias(libs.plugins.shadow) +} + +java { + sourceCompatibility = 8 + targetCompatibility = 8 +} + +compileJava { + javaCompiler = javaToolchains.compilerFor { + languageVersion.set(JavaLanguageVersion.of(11)) + } +} + +btraceExtension { + id = 'btrace-rag-quality' + name = 'BTrace RAG Quality' + description = 'RAG pipeline observability: retrieval latency, similarity scores, chunk quality for vector DB clients' + services = ['org.openjdk.btrace.rag.RagQualityService'] +} + +dependencies { + apiCompileOnly project(':btrace-core') + implImplementation project(':btrace-core') + + testImplementation libs.junit.jupiter +} + +test { + useJUnitPlatform() +} diff --git a/btrace-extensions/btrace-rag-quality/src/api/java/org/openjdk/btrace/rag/QueryRecord.java b/btrace-extensions/btrace-rag-quality/src/api/java/org/openjdk/btrace/rag/QueryRecord.java new file mode 100644 index 00000000..08de459a --- /dev/null +++ b/btrace-extensions/btrace-rag-quality/src/api/java/org/openjdk/btrace/rag/QueryRecord.java @@ -0,0 +1,44 @@ +package org.openjdk.btrace.rag; + +/** + * Fluent builder for recording a vector DB query with detailed metrics. + * + *

Obtain via {@link RagQualityService#query(String)}. All setters are optional. + * Call {@link #record()} to commit the metrics. + * + *

Allocation-free: Instances are pooled per-thread internally. + * The returned reference must not be stored or shared across threads. + * + *

+ * rag.query("pinecone")
+ *     .resultCount(5)
+ *     .topScore(0.92f)
+ *     .lowScore(0.71f)
+ *     .embeddingDimension(1536)
+ *     .duration(durationNanos)
+ *     .record();
+ * 
+ */ +public interface QueryRecord { + + /** Number of results/chunks returned by the query. */ + QueryRecord resultCount(int count); + + /** Highest similarity score in the result set (0.0 to 1.0). */ + QueryRecord topScore(float score); + + /** Lowest similarity score in the result set (0.0 to 1.0). */ + QueryRecord lowScore(float score); + + /** Embedding dimension used for the query vector. */ + QueryRecord embeddingDimension(int dimension); + + /** Total tokens across all returned chunks. */ + QueryRecord totalChunkTokens(int tokens); + + /** Query duration in nanoseconds. */ + QueryRecord duration(long nanos); + + /** Commits this query record to the service. */ + void record(); +} diff --git a/btrace-extensions/btrace-rag-quality/src/api/java/org/openjdk/btrace/rag/RagQualityService.java b/btrace-extensions/btrace-rag-quality/src/api/java/org/openjdk/btrace/rag/RagQualityService.java new file mode 100644 index 00000000..1eaccb98 --- /dev/null +++ b/btrace-extensions/btrace-rag-quality/src/api/java/org/openjdk/btrace/rag/RagQualityService.java @@ -0,0 +1,111 @@ +package org.openjdk.btrace.rag; + +import org.openjdk.btrace.core.extensions.ServiceDescriptor; + +/** + * BTrace extension service for RAG (Retrieval-Augmented Generation) pipeline observability. + * + *

Tracks vector database query performance, retrieval quality metrics, + * and end-to-end RAG pipeline latency. Works with any vector DB client + * (Pinecone, Milvus, Weaviate, Chroma, pgvector, Qdrant). + * + *

Usage in a BTrace script: + *

+ * @Injected RagQualityService rag;
+ *
+ * @OnMethod(clazz = "io.pinecone.PineconeClient", method = "query")
+ * void onQuery(@Duration long dur) {
+ *     rag.recordQuery("pinecone", 10, dur);
+ * }
+ * 
+ */ +@ServiceDescriptor +public interface RagQualityService { + + // ==================== Query recording ==================== + + /** + * Records a vector DB query with duration only. + * + * @param source vector DB or index name (e.g. "pinecone", "milvus-products") + * @param durationNanos query duration in nanoseconds + */ + void recordQuery(String source, long durationNanos); + + /** + * Records a vector DB query with result count and duration. + * + * @param source vector DB or index name + * @param resultCount number of results/chunks returned (top-K) + * @param durationNanos query duration in nanoseconds + */ + void recordQuery(String source, int resultCount, long durationNanos); + + /** + * Starts a detailed query record builder. Allocation-free (ThreadLocal-pooled). + * + *

Use inline on the calling thread and call {@link QueryRecord#record()} + * before the next {@code query()} call. Do not store the returned reference. + * + *

+   * rag.query("pinecone")
+   *     .resultCount(5)
+   *     .topScore(0.92f)
+   *     .lowScore(0.71f)
+   *     .embeddingDimension(1536)
+   *     .duration(durationNanos)
+   *     .record();
+   * 
+ * + * @param source vector DB or index name + * @return a query record builder (thread-local, do not store) + */ + QueryRecord query(String source); + + // ==================== Pipeline recording ==================== + + /** + * Records an end-to-end RAG pipeline invocation (retrieve + generate). + * + * @param pipelineName pipeline identifier + * @param retrievalNanos time spent in retrieval phase + * @param generationNanos time spent in generation phase + */ + void recordPipeline(String pipelineName, long retrievalNanos, long generationNanos); + + /** + * Records a chunk that was retrieved and used in context. + * Useful for tracking context window utilization. + * + * @param source vector DB or index name + * @param chunkTokens approximate token count of the chunk + */ + void recordChunk(String source, int chunkTokens); + + /** + * Records a retrieval that returned no results (empty context). + * + * @param source vector DB or index name + */ + void recordEmptyRetrieval(String source); + + // ==================== Reporting ==================== + + /** Returns a formatted summary of all RAG quality metrics. */ + String getSummary(); + + /** Returns summary for a specific source/index. */ + String getSourceSummary(String source); + + /** Returns total number of queries recorded across all sources. */ + long getTotalQueries(); + + /** Returns total number of empty retrievals. */ + long getTotalEmptyRetrievals(); + + /** Returns average similarity score across all queries that reported scores. */ + float getAverageTopScore(); + + /** Resets all collected metrics. */ + void reset(); +} diff --git a/btrace-extensions/btrace-rag-quality/src/impl/java/org/openjdk/btrace/rag/QueryRecordImpl.java b/btrace-extensions/btrace-rag-quality/src/impl/java/org/openjdk/btrace/rag/QueryRecordImpl.java new file mode 100644 index 00000000..b46a2909 --- /dev/null +++ b/btrace-extensions/btrace-rag-quality/src/impl/java/org/openjdk/btrace/rag/QueryRecordImpl.java @@ -0,0 +1,73 @@ +package org.openjdk.btrace.rag; + +/** + * ThreadLocal-pooled builder implementing the {@link QueryRecord} fluent API. + * One instance per thread, reused across calls — zero heap allocation. + */ +final class QueryRecordImpl implements QueryRecord { + + String source; + int resultCountVal; + float topScoreVal; + float lowScoreVal; + int embDimension; + int chunkTokensVal; + long durationVal; + + private RagQualityServiceImpl service; + + QueryRecordImpl() {} + + QueryRecordImpl reset(RagQualityServiceImpl service, String source) { + this.service = service; + this.source = source; + this.resultCountVal = 0; + this.topScoreVal = -1f; + this.lowScoreVal = -1f; + this.embDimension = 0; + this.chunkTokensVal = 0; + this.durationVal = 0; + return this; + } + + @Override + public QueryRecord resultCount(int count) { + this.resultCountVal = count; + return this; + } + + @Override + public QueryRecord topScore(float score) { + this.topScoreVal = score; + return this; + } + + @Override + public QueryRecord lowScore(float score) { + this.lowScoreVal = score; + return this; + } + + @Override + public QueryRecord embeddingDimension(int dimension) { + this.embDimension = dimension; + return this; + } + + @Override + public QueryRecord totalChunkTokens(int tokens) { + this.chunkTokensVal = tokens; + return this; + } + + @Override + public QueryRecord duration(long nanos) { + this.durationVal = nanos; + return this; + } + + @Override + public void record() { + service.commitQueryRecord(this); + } +} diff --git a/btrace-extensions/btrace-rag-quality/src/impl/java/org/openjdk/btrace/rag/RagQualityServiceImpl.java b/btrace-extensions/btrace-rag-quality/src/impl/java/org/openjdk/btrace/rag/RagQualityServiceImpl.java new file mode 100644 index 00000000..3acfd7d4 --- /dev/null +++ b/btrace-extensions/btrace-rag-quality/src/impl/java/org/openjdk/btrace/rag/RagQualityServiceImpl.java @@ -0,0 +1,320 @@ +package org.openjdk.btrace.rag; + +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicLong; +import org.openjdk.btrace.core.extensions.Extension; + +/** + * Thread-safe RAG quality tracking with lock-free per-source statistics. + */ +public final class RagQualityServiceImpl extends Extension implements RagQualityService { + + private final Map sourceStats = new ConcurrentHashMap<>(); + private final Map pipelineStats = new ConcurrentHashMap<>(); + + private final ThreadLocal queryRecordPool = + ThreadLocal.withInitial(QueryRecordImpl::new); + + // ==================== Simple recording ==================== + + @Override + public void recordQuery(String source, long durationNanos) { + SourceStats stats = getOrCreate(source); + stats.queries.incrementAndGet(); + stats.totalDurationNanos.addAndGet(durationNanos); + updateMinMax(stats.minDurationNanos, stats.maxDurationNanos, durationNanos); + } + + @Override + public void recordQuery(String source, int resultCount, long durationNanos) { + SourceStats stats = getOrCreate(source); + stats.queries.incrementAndGet(); + stats.totalResults.addAndGet(resultCount); + stats.totalDurationNanos.addAndGet(durationNanos); + updateMinMax(stats.minDurationNanos, stats.maxDurationNanos, durationNanos); + if (resultCount == 0) { + stats.emptyRetrievals.incrementAndGet(); + } + } + + // ==================== Fluent builder ==================== + + @Override + public QueryRecord query(String source) { + return queryRecordPool.get().reset(this, source); + } + + void commitQueryRecord(QueryRecordImpl rec) { + SourceStats stats = getOrCreate(rec.source); + stats.queries.incrementAndGet(); + stats.totalResults.addAndGet(rec.resultCountVal); + stats.totalDurationNanos.addAndGet(rec.durationVal); + updateMinMax(stats.minDurationNanos, stats.maxDurationNanos, rec.durationVal); + + if (rec.resultCountVal == 0) { + stats.emptyRetrievals.incrementAndGet(); + } + + if (rec.topScoreVal >= 0) { + stats.scoredQueries.incrementAndGet(); + // Accumulate scores as fixed-point (multiply by 10000) to avoid floating point atomics + stats.totalTopScore.addAndGet((long) (rec.topScoreVal * 10000)); + updateMin(stats.minTopScore, (long) (rec.topScoreVal * 10000)); + updateMax(stats.maxTopScore, (long) (rec.topScoreVal * 10000)); + } + if (rec.lowScoreVal >= 0) { + stats.totalLowScore.addAndGet((long) (rec.lowScoreVal * 10000)); + } + if (rec.chunkTokensVal > 0) { + stats.totalChunkTokens.addAndGet(rec.chunkTokensVal); + } + if (rec.embDimension > 0) { + stats.lastEmbeddingDimension = rec.embDimension; + } + } + + // ==================== Pipeline recording ==================== + + @Override + public void recordPipeline(String pipelineName, long retrievalNanos, long generationNanos) { + PipelineStats ps = pipelineStats.computeIfAbsent(pipelineName, k -> new PipelineStats()); + ps.invocations.incrementAndGet(); + ps.totalRetrievalNanos.addAndGet(retrievalNanos); + ps.totalGenerationNanos.addAndGet(generationNanos); + } + + @Override + public void recordChunk(String source, int chunkTokens) { + SourceStats stats = getOrCreate(source); + stats.totalChunkTokens.addAndGet(chunkTokens); + stats.chunkCount.incrementAndGet(); + } + + @Override + public void recordEmptyRetrieval(String source) { + getOrCreate(source).emptyRetrievals.incrementAndGet(); + } + + // ==================== Reporting ==================== + + @Override + public String getSummary() { + if (sourceStats.isEmpty() && pipelineStats.isEmpty()) { + return "No RAG queries recorded."; + } + + StringBuilder sb = new StringBuilder(); + sb.append("=== RAG Quality Summary ===\n\n"); + + long totalQueries = 0; + long totalEmpty = 0; + + for (Map.Entry entry : sourceStats.entrySet()) { + String source = entry.getKey(); + SourceStats s = entry.getValue(); + long queries = s.queries.get(); + totalQueries += queries; + long empty = s.emptyRetrievals.get(); + totalEmpty += empty; + + sb.append("Source: ").append(source).append("\n"); + sb.append(" Queries: ").append(queries); + if (empty > 0) { + sb.append(" (").append(empty).append(" empty, ") + .append(empty * 100 / queries).append("%)"); + } + sb.append("\n"); + + // Results + long totalRes = s.totalResults.get(); + if (totalRes > 0 && queries > 0) { + sb.append(" Results: ").append(totalRes) + .append(" total (avg ").append(totalRes / queries).append("/query)\n"); + } + + // Similarity scores + long scored = s.scoredQueries.get(); + if (scored > 0) { + float avgTop = (s.totalTopScore.get() / (float) scored) / 10000f; + float minTop = s.minTopScore.get() / 10000f; + float maxTop = s.maxTopScore.get() / 10000f; + sb.append(" Similarity (top): avg ").append(String.format("%.3f", avgTop)); + sb.append(", min ").append(String.format("%.3f", minTop)); + sb.append(", max ").append(String.format("%.3f", maxTop)); + sb.append("\n"); + + long totalLow = s.totalLowScore.get(); + if (totalLow > 0) { + float avgLow = (totalLow / (float) scored) / 10000f; + float spread = avgTop - avgLow; + sb.append(" Similarity (low): avg ").append(String.format("%.3f", avgLow)); + sb.append(" (spread ").append(String.format("%.3f", spread)).append(")\n"); + } + } + + // Latency + if (queries > 0) { + long avgMs = (s.totalDurationNanos.get() / queries) / 1_000_000; + long minMs = s.minDurationNanos.get() / 1_000_000; + long maxMs = s.maxDurationNanos.get() / 1_000_000; + sb.append(" Latency: avg ").append(avgMs).append("ms"); + sb.append(", min ").append(minMs).append("ms"); + sb.append(", max ").append(maxMs).append("ms\n"); + } + + // Chunks + long chunks = s.chunkCount.get(); + long chunkTokens = s.totalChunkTokens.get(); + if (chunkTokens > 0) { + sb.append(" Context: ").append(chunkTokens).append(" tokens"); + if (chunks > 0) { + sb.append(" (").append(chunks).append(" chunks, avg ") + .append(chunkTokens / chunks).append(" tok/chunk)"); + } + sb.append("\n"); + } + + sb.append("\n"); + } + + // Pipelines + if (!pipelineStats.isEmpty()) { + sb.append("--- Pipelines ---\n"); + for (Map.Entry entry : pipelineStats.entrySet()) { + PipelineStats ps = entry.getValue(); + long inv = ps.invocations.get(); + long avgRetMs = inv > 0 ? (ps.totalRetrievalNanos.get() / inv) / 1_000_000 : 0; + long avgGenMs = inv > 0 ? (ps.totalGenerationNanos.get() / inv) / 1_000_000 : 0; + sb.append(" ").append(entry.getKey()).append(": ") + .append(inv).append(" invocations, avg retrieval ") + .append(avgRetMs).append("ms, avg generation ") + .append(avgGenMs).append("ms\n"); + } + sb.append("\n"); + } + + sb.append("--- Totals ---\n"); + sb.append(" Queries: ").append(totalQueries).append("\n"); + if (totalEmpty > 0) { + sb.append(" Empty retrievals: ").append(totalEmpty); + if (totalQueries > 0) { + sb.append(" (").append(totalEmpty * 100 / totalQueries).append("%)"); + } + sb.append("\n"); + } + + return sb.toString(); + } + + @Override + public String getSourceSummary(String source) { + SourceStats s = sourceStats.get(source); + if (s == null) { + return "No data for source: " + source; + } + long queries = s.queries.get(); + long avgMs = queries > 0 ? (s.totalDurationNanos.get() / queries) / 1_000_000 : 0; + long empty = s.emptyRetrievals.get(); + StringBuilder sb = new StringBuilder(); + sb.append(source).append(": ").append(queries).append(" queries, avg ").append(avgMs).append("ms"); + if (empty > 0) { + sb.append(", ").append(empty).append(" empty"); + } + return sb.toString(); + } + + @Override + public long getTotalQueries() { + long total = 0; + for (SourceStats s : sourceStats.values()) { + total += s.queries.get(); + } + return total; + } + + @Override + public long getTotalEmptyRetrievals() { + long total = 0; + for (SourceStats s : sourceStats.values()) { + total += s.emptyRetrievals.get(); + } + return total; + } + + @Override + public float getAverageTopScore() { + long totalScored = 0; + long totalScore = 0; + for (SourceStats s : sourceStats.values()) { + totalScored += s.scoredQueries.get(); + totalScore += s.totalTopScore.get(); + } + if (totalScored == 0) return -1f; + return (totalScore / (float) totalScored) / 10000f; + } + + @Override + public void reset() { + sourceStats.clear(); + pipelineStats.clear(); + } + + @Override + public void close() { + String summary = getSummary(); + if (!"No RAG queries recorded.".equals(summary)) { + getContext().send(summary); + } + } + + // ==================== Internals ==================== + + private SourceStats getOrCreate(String source) { + return sourceStats.computeIfAbsent(source, k -> new SourceStats()); + } + + private static void updateMinMax(AtomicLong min, AtomicLong max, long value) { + updateMin(min, value); + updateMax(max, value); + } + + private static void updateMin(AtomicLong min, long value) { + long cur; + do { + cur = min.get(); + if (value >= cur) break; + } while (!min.compareAndSet(cur, value)); + } + + private static void updateMax(AtomicLong max, long value) { + long cur; + do { + cur = max.get(); + if (value <= cur) break; + } while (!max.compareAndSet(cur, value)); + } + + static final class SourceStats { + final AtomicLong queries = new AtomicLong(); + final AtomicLong totalResults = new AtomicLong(); + final AtomicLong emptyRetrievals = new AtomicLong(); + final AtomicLong scoredQueries = new AtomicLong(); + final AtomicLong totalTopScore = new AtomicLong(); // fixed-point * 10000 + final AtomicLong minTopScore = new AtomicLong(Long.MAX_VALUE); + final AtomicLong maxTopScore = new AtomicLong(0); + final AtomicLong totalLowScore = new AtomicLong(); // fixed-point * 10000 + final AtomicLong totalDurationNanos = new AtomicLong(); + final AtomicLong minDurationNanos = new AtomicLong(Long.MAX_VALUE); + final AtomicLong maxDurationNanos = new AtomicLong(0); + final AtomicLong totalChunkTokens = new AtomicLong(); + final AtomicLong chunkCount = new AtomicLong(); + volatile int lastEmbeddingDimension; + } + + static final class PipelineStats { + final AtomicLong invocations = new AtomicLong(); + final AtomicLong totalRetrievalNanos = new AtomicLong(); + final AtomicLong totalGenerationNanos = new AtomicLong(); + } +} diff --git a/btrace-extensions/btrace-rag-quality/src/test/java/org/openjdk/btrace/rag/RagQualityServiceTest.java b/btrace-extensions/btrace-rag-quality/src/test/java/org/openjdk/btrace/rag/RagQualityServiceTest.java new file mode 100644 index 00000000..51ae7243 --- /dev/null +++ b/btrace-extensions/btrace-rag-quality/src/test/java/org/openjdk/btrace/rag/RagQualityServiceTest.java @@ -0,0 +1,210 @@ +package org.openjdk.btrace.rag; + +import static org.junit.jupiter.api.Assertions.*; + +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicLong; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +class RagQualityServiceTest { + + private RagQualityServiceImpl service; + + @BeforeEach + void setUp() { + service = new RagQualityServiceImpl(); + } + + @Test + void durationOnlyQuery() { + service.recordQuery("pinecone", 5_000_000L); + assertEquals(1, service.getTotalQueries()); + String summary = service.getSummary(); + assertTrue(summary.contains("pinecone")); + assertTrue(summary.contains("1")); + } + + @Test + void queryWithResultCount() { + service.recordQuery("milvus", 10, 8_000_000L); + assertEquals(1, service.getTotalQueries()); + assertEquals(0, service.getTotalEmptyRetrievals()); + assertTrue(service.getSummary().contains("10 total")); + } + + @Test + void emptyRetrieval() { + service.recordQuery("weaviate", 0, 2_000_000L); + assertEquals(1, service.getTotalEmptyRetrievals()); + assertTrue(service.getSummary().contains("1 empty")); + } + + @Test + void explicitEmptyRetrieval() { + service.recordEmptyRetrieval("chroma"); + assertEquals(1, service.getTotalEmptyRetrievals()); + } + + @Test + void fluentBuilder() { + service.query("pinecone") + .resultCount(5) + .topScore(0.92f) + .lowScore(0.71f) + .embeddingDimension(1536) + .totalChunkTokens(3000) + .duration(10_000_000L) + .record(); + + assertEquals(1, service.getTotalQueries()); + float avgScore = service.getAverageTopScore(); + assertTrue(avgScore > 0.91f && avgScore < 0.93f, "avgScore=" + avgScore); + String summary = service.getSummary(); + assertTrue(summary.contains("0.920")); + assertTrue(summary.contains("3000 tokens")); + } + + @Test + void fluentBuilderMinimal() { + service.query("qdrant") + .duration(1_000_000L) + .record(); + + assertEquals(1, service.getTotalQueries()); + assertEquals(-1f, service.getAverageTopScore()); + } + + @Test + void fluentBuilderEmptyResult() { + service.query("pgvector") + .resultCount(0) + .duration(500_000L) + .record(); + + assertEquals(1, service.getTotalEmptyRetrievals()); + } + + @Test + void multipleSourcesTrackedSeparately() { + service.recordQuery("pinecone", 5, 10_000_000L); + service.recordQuery("milvus", 3, 8_000_000L); + service.recordQuery("pinecone", 7, 12_000_000L); + + assertEquals(3, service.getTotalQueries()); + assertTrue(service.getSourceSummary("pinecone").contains("2 queries")); + assertTrue(service.getSourceSummary("milvus").contains("1 queries")); + } + + @Test + void pipelineRecording() { + service.recordPipeline("qa-bot", 5_000_000L, 50_000_000L); + service.recordPipeline("qa-bot", 3_000_000L, 45_000_000L); + + String summary = service.getSummary(); + assertTrue(summary.contains("qa-bot")); + assertTrue(summary.contains("2 invocations")); + } + + @Test + void chunkRecording() { + service.recordChunk("pinecone", 500); + service.recordChunk("pinecone", 750); + + String summary = service.getSourceSummary("pinecone"); + // Chunks don't count as queries + assertTrue(summary.contains("0 queries")); + } + + @Test + void similarityScoreAggregation() { + service.query("pinecone").topScore(0.90f).duration(1_000_000L).record(); + service.query("pinecone").topScore(0.80f).duration(1_000_000L).record(); + service.query("milvus").topScore(0.70f).duration(1_000_000L).record(); + + float avg = service.getAverageTopScore(); + // (0.90 + 0.80 + 0.70) / 3 = 0.8 + assertTrue(avg > 0.79f && avg < 0.81f, "avg=" + avg); + } + + @Test + void latencyMinMax() { + service.recordQuery("pinecone", 10_000_000L); + service.recordQuery("pinecone", 2_000_000L); + service.recordQuery("pinecone", 50_000_000L); + + String summary = service.getSummary(); + assertTrue(summary.contains("min 2ms")); + assertTrue(summary.contains("max 50ms")); + } + + @Test + void noDataSummary() { + assertEquals("No RAG queries recorded.", service.getSummary()); + } + + @Test + void unknownSourceSummary() { + assertEquals("No data for source: unknown", service.getSourceSummary("unknown")); + } + + @Test + void reset() { + service.recordQuery("pinecone", 5_000_000L); + service.recordEmptyRetrieval("milvus"); + service.recordPipeline("qa", 1L, 1L); + service.reset(); + + assertEquals(0, service.getTotalQueries()); + assertEquals(0, service.getTotalEmptyRetrievals()); + assertEquals("No RAG queries recorded.", service.getSummary()); + } + + @Test + void concurrentRecording() throws Exception { + int threads = 8; + int queriesPerThread = 1000; + CountDownLatch latch = new CountDownLatch(threads); + + for (int t = 0; t < threads; t++) { + new Thread(() -> { + try { + for (int i = 0; i < queriesPerThread; i++) { + service.recordQuery("pinecone", 5, 1_000_000L); + } + } finally { + latch.countDown(); + } + }).start(); + } + latch.await(); + + assertEquals(threads * queriesPerThread, service.getTotalQueries()); + } + + @Test + void concurrentBuilderRecording() throws Exception { + int threads = 8; + int queriesPerThread = 500; + CountDownLatch latch = new CountDownLatch(threads); + + for (int t = 0; t < threads; t++) { + new Thread(() -> { + try { + for (int i = 0; i < queriesPerThread; i++) { + service.query("milvus") + .resultCount(3) + .topScore(0.85f) + .duration(2_000_000L) + .record(); + } + } finally { + latch.countDown(); + } + }).start(); + } + latch.await(); + + assertEquals(threads * queriesPerThread, service.getTotalQueries()); + } +} diff --git a/btrace-extensions/btrace-vibe-guard/build.gradle b/btrace-extensions/btrace-vibe-guard/build.gradle new file mode 100644 index 00000000..58a3ef5f --- /dev/null +++ b/btrace-extensions/btrace-vibe-guard/build.gradle @@ -0,0 +1,33 @@ +plugins { + id 'org.openjdk.btrace.extension' + alias(libs.plugins.shadow) +} + +java { + sourceCompatibility = 8 + targetCompatibility = 8 +} + +compileJava { + javaCompiler = javaToolchains.compilerFor { + languageVersion.set(JavaLanguageVersion.of(11)) + } +} + +btraceExtension { + id = 'btrace-vibe-guard' + name = 'BTrace Vibe Guard' + description = 'Behavioral contracts and runtime validation for AI-generated code paths' + services = ['org.openjdk.btrace.vibeguard.VibeGuardService'] +} + +dependencies { + apiCompileOnly project(':btrace-core') + implImplementation project(':btrace-core') + + testImplementation libs.junit.jupiter +} + +test { + useJUnitPlatform() +} diff --git a/btrace-extensions/btrace-vibe-guard/src/api/java/org/openjdk/btrace/vibeguard/VibeGuardService.java b/btrace-extensions/btrace-vibe-guard/src/api/java/org/openjdk/btrace/vibeguard/VibeGuardService.java new file mode 100644 index 00000000..cbdc713f --- /dev/null +++ b/btrace-extensions/btrace-vibe-guard/src/api/java/org/openjdk/btrace/vibeguard/VibeGuardService.java @@ -0,0 +1,125 @@ +package org.openjdk.btrace.vibeguard; + +import org.openjdk.btrace.core.extensions.ServiceDescriptor; + +/** + * BTrace extension for runtime behavioral contracts on AI-generated code. + * + *

Validates that methods respect declared invariants at runtime — useful for + * "vibe coding" workflows where LLMs generate code that needs guardrails. + * Tracks contract violations, latency budgets, call frequency limits, and + * return value constraints without modifying the target code. + * + *

Contracts are defined in BTrace scripts and enforced at instrumentation + * points. When a contract is violated, the service records it and optionally + * triggers an alert via the configured handler. + * + *

Usage in a BTrace script: + *

+ * @Injected VibeGuardService guard;
+ *
+ * @OnMethod(clazz = "com.app.AiService", method = "generate")
+ * void onEntry() {
+ *     guard.checkCallRate("AiService.generate", 100); // max 100 calls/sec
+ * }
+ *
+ * @OnMethod(clazz = "com.app.AiService", method = "generate",
+ *          location = @Location(Kind.RETURN))
+ * void onReturn(@Duration long dur) {
+ *     guard.checkLatency("AiService.generate", dur, 500_000_000L); // 500ms budget
+ * }
+ * 
+ */ +@ServiceDescriptor +public interface VibeGuardService { + + // ==================== Contract checks ==================== + + /** + * Checks that a method's latency does not exceed the budget. + * Records a violation if {@code durationNanos > budgetNanos}. + * + * @param contract contract/method name + * @param durationNanos actual duration + * @param budgetNanos maximum allowed duration + */ + void checkLatency(String contract, long durationNanos, long budgetNanos); + + /** + * Checks that call rate does not exceed the limit per second. + * Uses a sliding window to detect bursts. + * + * @param contract contract/method name + * @param maxPerSecond maximum allowed calls per second + */ + void checkCallRate(String contract, int maxPerSecond); + + /** + * Asserts a boolean condition. Records a violation if false. + * + * @param contract contract name + * @param condition the condition to check + * @param message violation message if condition is false + */ + void assertCondition(String contract, boolean condition, String message); + + /** + * Checks that a numeric return value is within bounds. + * + * @param contract contract name + * @param value actual value + * @param min minimum allowed (inclusive) + * @param max maximum allowed (inclusive) + */ + void checkRange(String contract, long value, long min, long max); + + /** + * Checks that a return value is not null. Records a violation if null. + * + * @param contract contract name + * @param value the value to check + */ + void checkNotNull(String contract, Object value); + + // ==================== Tracking ==================== + + /** + * Tags a method invocation as AI-generated. Used to compare behavior + * between AI and human code paths. + * + * @param contract contract/method name + * @param durationNanos execution duration + */ + void trackAiCodePath(String contract, long durationNanos); + + /** + * Tags a method invocation as human-written (baseline). + * + * @param contract contract/method name + * @param durationNanos execution duration + */ + void trackHumanCodePath(String contract, long durationNanos); + + // ==================== Reporting ==================== + + /** Returns a formatted summary of all contract violations and stats. */ + String getSummary(); + + /** Returns total number of contract violations across all contracts. */ + long getTotalViolations(); + + /** Returns number of violations for a specific contract. */ + long getViolations(String contract); + + /** Returns total number of contract checks performed. */ + long getTotalChecks(); + + /** + * Returns true if any contract has been violated. + * Useful in {@code @OnEvent} handlers for alerting. + */ + boolean hasViolations(); + + /** Resets all metrics and violation history. */ + void reset(); +} diff --git a/btrace-extensions/btrace-vibe-guard/src/impl/java/org/openjdk/btrace/vibeguard/VibeGuardServiceImpl.java b/btrace-extensions/btrace-vibe-guard/src/impl/java/org/openjdk/btrace/vibeguard/VibeGuardServiceImpl.java new file mode 100644 index 00000000..152730a1 --- /dev/null +++ b/btrace-extensions/btrace-vibe-guard/src/impl/java/org/openjdk/btrace/vibeguard/VibeGuardServiceImpl.java @@ -0,0 +1,320 @@ +package org.openjdk.btrace.vibeguard; + +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicLong; +import org.openjdk.btrace.core.extensions.Extension; + +/** + * Thread-safe behavioral contract enforcement with lock-free statistics. + */ +public final class VibeGuardServiceImpl extends Extension implements VibeGuardService { + + private final Map contracts = new ConcurrentHashMap<>(); + private final Map rateWindows = new ConcurrentHashMap<>(); + private final Map aiPaths = new ConcurrentHashMap<>(); + private final Map humanPaths = new ConcurrentHashMap<>(); + + // ==================== Contract checks ==================== + + @Override + public void checkLatency(String contract, long durationNanos, long budgetNanos) { + ContractStats stats = getOrCreate(contract); + stats.checks.incrementAndGet(); + if (durationNanos > budgetNanos) { + stats.violations.incrementAndGet(); + stats.lastViolationMessage = "Latency " + (durationNanos / 1_000_000) + + "ms exceeded budget " + (budgetNanos / 1_000_000) + "ms"; + } + stats.totalDurationNanos.addAndGet(durationNanos); + updateMinMax(stats.minDurationNanos, stats.maxDurationNanos, durationNanos); + } + + @Override + public void checkCallRate(String contract, int maxPerSecond) { + ContractStats stats = getOrCreate(contract); + stats.checks.incrementAndGet(); + + RateWindow window = rateWindows.computeIfAbsent(contract, k -> new RateWindow()); + long now = System.nanoTime(); + long count = window.recordAndCount(now); + if (count > maxPerSecond) { + stats.violations.incrementAndGet(); + stats.lastViolationMessage = "Rate " + count + "/sec exceeded limit " + maxPerSecond + "/sec"; + } + } + + @Override + public void assertCondition(String contract, boolean condition, String message) { + ContractStats stats = getOrCreate(contract); + stats.checks.incrementAndGet(); + if (!condition) { + stats.violations.incrementAndGet(); + stats.lastViolationMessage = message; + } + } + + @Override + public void checkRange(String contract, long value, long min, long max) { + ContractStats stats = getOrCreate(contract); + stats.checks.incrementAndGet(); + if (value < min || value > max) { + stats.violations.incrementAndGet(); + stats.lastViolationMessage = "Value " + value + " outside range [" + min + ", " + max + "]"; + } + } + + @Override + public void checkNotNull(String contract, Object value) { + ContractStats stats = getOrCreate(contract); + stats.checks.incrementAndGet(); + if (value == null) { + stats.violations.incrementAndGet(); + stats.lastViolationMessage = "Unexpected null return"; + } + } + + // ==================== Tracking ==================== + + @Override + public void trackAiCodePath(String contract, long durationNanos) { + CodePathStats s = aiPaths.computeIfAbsent(contract, k -> new CodePathStats()); + s.calls.incrementAndGet(); + s.totalDurationNanos.addAndGet(durationNanos); + } + + @Override + public void trackHumanCodePath(String contract, long durationNanos) { + CodePathStats s = humanPaths.computeIfAbsent(contract, k -> new CodePathStats()); + s.calls.incrementAndGet(); + s.totalDurationNanos.addAndGet(durationNanos); + } + + // ==================== Reporting ==================== + + @Override + public String getSummary() { + if (contracts.isEmpty() && aiPaths.isEmpty()) { + return "No contracts checked."; + } + + StringBuilder sb = new StringBuilder(); + sb.append("=== Vibe Guard Summary ===\n\n"); + + long totalChecks = 0; + long totalViolations = 0; + + for (Map.Entry entry : contracts.entrySet()) { + String name = entry.getKey(); + ContractStats s = entry.getValue(); + long checks = s.checks.get(); + long violations = s.violations.get(); + totalChecks += checks; + totalViolations += violations; + + sb.append("Contract: ").append(name).append("\n"); + sb.append(" Checks: ").append(checks); + if (violations > 0) { + sb.append(" | VIOLATIONS: ").append(violations); + sb.append(" (").append(violations * 100 / checks).append("%)"); + } else { + sb.append(" | OK"); + } + sb.append("\n"); + + // Latency stats if tracked + long dur = s.totalDurationNanos.get(); + if (dur > 0 && checks > 0) { + long avgMs = (dur / checks) / 1_000_000; + long minMs = s.minDurationNanos.get() == Long.MAX_VALUE ? 0 : s.minDurationNanos.get() / 1_000_000; + long maxMs = s.maxDurationNanos.get() / 1_000_000; + sb.append(" Latency: avg ").append(avgMs).append("ms"); + sb.append(", min ").append(minMs).append("ms"); + sb.append(", max ").append(maxMs).append("ms\n"); + } + + // Last violation + if (violations > 0 && s.lastViolationMessage != null) { + sb.append(" Last: ").append(s.lastViolationMessage).append("\n"); + } + sb.append("\n"); + } + + // AI vs Human comparison + if (!aiPaths.isEmpty() || !humanPaths.isEmpty()) { + sb.append("--- AI vs Human Code Paths ---\n"); + // Collect all contract names from both + ConcurrentHashMap allNames = new ConcurrentHashMap<>(); + for (String k : aiPaths.keySet()) allNames.put(k, Boolean.TRUE); + for (String k : humanPaths.keySet()) allNames.put(k, Boolean.TRUE); + + for (String name : allNames.keySet()) { + CodePathStats ai = aiPaths.get(name); + CodePathStats human = humanPaths.get(name); + + sb.append(" ").append(name).append(": "); + if (ai != null) { + long aiCalls = ai.calls.get(); + long aiAvgMs = aiCalls > 0 ? (ai.totalDurationNanos.get() / aiCalls) / 1_000_000 : 0; + sb.append("AI ").append(aiCalls).append(" calls avg ").append(aiAvgMs).append("ms"); + } + if (ai != null && human != null) sb.append(" | "); + if (human != null) { + long hCalls = human.calls.get(); + long hAvgMs = hCalls > 0 ? (human.totalDurationNanos.get() / hCalls) / 1_000_000 : 0; + sb.append("Human ").append(hCalls).append(" calls avg ").append(hAvgMs).append("ms"); + } + + // Performance comparison + if (ai != null && human != null) { + long aiCalls = ai.calls.get(); + long hCalls = human.calls.get(); + if (aiCalls > 0 && hCalls > 0) { + long aiAvg = ai.totalDurationNanos.get() / aiCalls; + long hAvg = human.totalDurationNanos.get() / hCalls; + if (hAvg > 0) { + long pctDiff = ((aiAvg - hAvg) * 100) / hAvg; + if (pctDiff > 0) { + sb.append(" [AI ").append(pctDiff).append("% slower]"); + } else if (pctDiff < 0) { + sb.append(" [AI ").append(-pctDiff).append("% faster]"); + } + } + } + } + sb.append("\n"); + } + sb.append("\n"); + } + + sb.append("--- Totals ---\n"); + sb.append(" Checks: ").append(totalChecks).append("\n"); + sb.append(" Violations: ").append(totalViolations); + if (totalViolations == 0) { + sb.append(" (all contracts satisfied)"); + } + sb.append("\n"); + + return sb.toString(); + } + + @Override + public long getTotalViolations() { + long total = 0; + for (ContractStats s : contracts.values()) { + total += s.violations.get(); + } + return total; + } + + @Override + public long getViolations(String contract) { + ContractStats s = contracts.get(contract); + return s != null ? s.violations.get() : 0; + } + + @Override + public long getTotalChecks() { + long total = 0; + for (ContractStats s : contracts.values()) { + total += s.checks.get(); + } + return total; + } + + @Override + public boolean hasViolations() { + for (ContractStats s : contracts.values()) { + if (s.violations.get() > 0) return true; + } + return false; + } + + @Override + public void reset() { + contracts.clear(); + rateWindows.clear(); + aiPaths.clear(); + humanPaths.clear(); + } + + @Override + public void close() { + String summary = getSummary(); + if (!"No contracts checked.".equals(summary)) { + getContext().send(summary); + } + } + + // ==================== Internals ==================== + + private ContractStats getOrCreate(String contract) { + return contracts.computeIfAbsent(contract, k -> new ContractStats()); + } + + private static void updateMinMax(AtomicLong min, AtomicLong max, long value) { + long cur; + do { + cur = min.get(); + if (value >= cur) break; + } while (!min.compareAndSet(cur, value)); + do { + cur = max.get(); + if (value <= cur) break; + } while (!max.compareAndSet(cur, value)); + } + + static final class ContractStats { + final AtomicLong checks = new AtomicLong(); + final AtomicLong violations = new AtomicLong(); + final AtomicLong totalDurationNanos = new AtomicLong(); + final AtomicLong minDurationNanos = new AtomicLong(Long.MAX_VALUE); + final AtomicLong maxDurationNanos = new AtomicLong(0); + volatile String lastViolationMessage; + } + + static final class CodePathStats { + final AtomicLong calls = new AtomicLong(); + final AtomicLong totalDurationNanos = new AtomicLong(); + } + + /** + * Simple sliding-window rate counter. Tracks calls in the last second + * using a circular buffer of 10 x 100ms buckets. Lock-free via CAS. + */ + static final class RateWindow { + private static final int BUCKETS = 10; + private static final long BUCKET_NS = 100_000_000L; // 100ms + private final AtomicLong[] counts = new AtomicLong[BUCKETS]; + private final AtomicLong[] timestamps = new AtomicLong[BUCKETS]; + + RateWindow() { + for (int i = 0; i < BUCKETS; i++) { + counts[i] = new AtomicLong(); + timestamps[i] = new AtomicLong(); + } + } + + long recordAndCount(long nowNanos) { + int bucket = (int) ((nowNanos / BUCKET_NS) % BUCKETS); + long bucketTime = (nowNanos / BUCKET_NS) * BUCKET_NS; + + // Reset bucket if stale + if (timestamps[bucket].get() != bucketTime) { + timestamps[bucket].set(bucketTime); + counts[bucket].set(0); + } + counts[bucket].incrementAndGet(); + + // Sum all non-stale buckets + long total = 0; + long windowStart = nowNanos - (BUCKETS * BUCKET_NS); + for (int i = 0; i < BUCKETS; i++) { + if (timestamps[i].get() > windowStart) { + total += counts[i].get(); + } + } + return total; + } + } +} diff --git a/btrace-extensions/btrace-vibe-guard/src/test/java/org/openjdk/btrace/vibeguard/VibeGuardServiceTest.java b/btrace-extensions/btrace-vibe-guard/src/test/java/org/openjdk/btrace/vibeguard/VibeGuardServiceTest.java new file mode 100644 index 00000000..747c591f --- /dev/null +++ b/btrace-extensions/btrace-vibe-guard/src/test/java/org/openjdk/btrace/vibeguard/VibeGuardServiceTest.java @@ -0,0 +1,227 @@ +package org.openjdk.btrace.vibeguard; + +import static org.junit.jupiter.api.Assertions.*; + +import java.util.concurrent.CountDownLatch; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +class VibeGuardServiceTest { + + private VibeGuardServiceImpl service; + + @BeforeEach + void setUp() { + service = new VibeGuardServiceImpl(); + } + + // ==================== Latency checks ==================== + + @Test + void latencyWithinBudget() { + service.checkLatency("api.generate", 100_000_000L, 500_000_000L); + assertEquals(1, service.getTotalChecks()); + assertEquals(0, service.getTotalViolations()); + assertFalse(service.hasViolations()); + } + + @Test + void latencyExceedsBudget() { + service.checkLatency("api.generate", 600_000_000L, 500_000_000L); + assertEquals(1, service.getTotalViolations()); + assertTrue(service.hasViolations()); + String summary = service.getSummary(); + assertTrue(summary.contains("VIOLATIONS: 1")); + assertTrue(summary.contains("exceeded budget")); + } + + @Test + void latencyStatsTracked() { + service.checkLatency("api.call", 100_000_000L, 1_000_000_000L); + service.checkLatency("api.call", 200_000_000L, 1_000_000_000L); + service.checkLatency("api.call", 300_000_000L, 1_000_000_000L); + + assertEquals(3, service.getTotalChecks()); + String summary = service.getSummary(); + assertTrue(summary.contains("avg 200ms")); + assertTrue(summary.contains("min 100ms")); + assertTrue(summary.contains("max 300ms")); + } + + // ==================== Call rate checks ==================== + + @Test + void callRateWithinLimit() { + // Single call should never exceed any reasonable limit + service.checkCallRate("api.query", 1000); + assertEquals(0, service.getTotalViolations()); + } + + @Test + void callRateExceedsLimit() { + // Flood with calls — rate window should detect burst + for (int i = 0; i < 200; i++) { + service.checkCallRate("api.query", 10); + } + // After 200 calls in rapid succession, should have violations + assertTrue(service.getTotalViolations() > 0); + } + + // ==================== Condition assertions ==================== + + @Test + void assertConditionTrue() { + service.assertCondition("positive-balance", true, "Balance must be positive"); + assertEquals(0, service.getTotalViolations()); + } + + @Test + void assertConditionFalse() { + service.assertCondition("positive-balance", false, "Balance must be positive"); + assertEquals(1, service.getTotalViolations()); + assertTrue(service.getSummary().contains("Balance must be positive")); + } + + // ==================== Range checks ==================== + + @Test + void rangeWithinBounds() { + service.checkRange("response-code", 200, 100, 599); + assertEquals(0, service.getTotalViolations()); + } + + @Test + void rangeBelowMin() { + service.checkRange("response-code", 50, 100, 599); + assertEquals(1, service.getTotalViolations()); + assertTrue(service.getSummary().contains("outside range")); + } + + @Test + void rangeAboveMax() { + service.checkRange("response-code", 700, 100, 599); + assertEquals(1, service.getTotalViolations()); + } + + @Test + void rangeAtBoundaries() { + service.checkRange("val", 100, 100, 200); + service.checkRange("val", 200, 100, 200); + assertEquals(0, service.getTotalViolations()); + } + + // ==================== Null checks ==================== + + @Test + void checkNotNullWithValue() { + service.checkNotNull("api.result", "hello"); + assertEquals(0, service.getTotalViolations()); + } + + @Test + void checkNotNullWithNull() { + service.checkNotNull("api.result", null); + assertEquals(1, service.getTotalViolations()); + assertTrue(service.getSummary().contains("Unexpected null")); + } + + // ==================== AI vs Human tracking ==================== + + @Test + void aiVsHumanComparison() { + service.trackAiCodePath("Parser.parse", 50_000_000L); + service.trackAiCodePath("Parser.parse", 60_000_000L); + service.trackHumanCodePath("Parser.parse", 30_000_000L); + service.trackHumanCodePath("Parser.parse", 40_000_000L); + + String summary = service.getSummary(); + assertTrue(summary.contains("AI vs Human")); + assertTrue(summary.contains("AI 2 calls")); + assertTrue(summary.contains("Human 2 calls")); + // AI avg 55ms, Human avg 35ms -> AI ~57% slower + assertTrue(summary.contains("slower")); + } + + @Test + void aiOnlyTracking() { + service.trackAiCodePath("Renderer.render", 100_000_000L); + String summary = service.getSummary(); + assertTrue(summary.contains("AI 1 calls")); + assertFalse(summary.contains("Human")); + } + + // ==================== Reporting ==================== + + @Test + void noDataSummary() { + assertEquals("No contracts checked.", service.getSummary()); + } + + @Test + void getViolationsPerContract() { + service.assertCondition("a", false, "fail"); + service.assertCondition("a", false, "fail"); + service.assertCondition("b", false, "fail"); + + assertEquals(2, service.getViolations("a")); + assertEquals(1, service.getViolations("b")); + assertEquals(0, service.getViolations("c")); + } + + @Test + void multipleContractsInSummary() { + service.checkLatency("fast-api", 10_000_000L, 100_000_000L); + service.checkLatency("slow-api", 500_000_000L, 100_000_000L); + + String summary = service.getSummary(); + assertTrue(summary.contains("fast-api")); + assertTrue(summary.contains("slow-api")); + assertTrue(summary.contains("Checks: 2")); + assertTrue(summary.contains("Violations: 1")); + } + + @Test + void allContractsSatisfied() { + service.checkLatency("api", 10_000_000L, 100_000_000L); + service.assertCondition("invariant", true, "ok"); + service.checkRange("val", 50, 0, 100); + + String summary = service.getSummary(); + assertTrue(summary.contains("all contracts satisfied")); + } + + @Test + void reset() { + service.assertCondition("a", false, "fail"); + service.trackAiCodePath("b", 100L); + service.reset(); + + assertEquals(0, service.getTotalChecks()); + assertEquals(0, service.getTotalViolations()); + assertFalse(service.hasViolations()); + assertEquals("No contracts checked.", service.getSummary()); + } + + @Test + void concurrentChecks() throws Exception { + int threads = 8; + int checksPerThread = 1000; + CountDownLatch latch = new CountDownLatch(threads); + + for (int t = 0; t < threads; t++) { + new Thread(() -> { + try { + for (int i = 0; i < checksPerThread; i++) { + service.checkLatency("concurrent-api", 50_000_000L, 100_000_000L); + } + } finally { + latch.countDown(); + } + }).start(); + } + latch.await(); + + assertEquals(threads * checksPerThread, service.getTotalChecks()); + assertEquals(0, service.getTotalViolations()); + } +} diff --git a/btrace-mcp-server/README.md b/btrace-mcp-server/README.md new file mode 100644 index 00000000..a23d5111 --- /dev/null +++ b/btrace-mcp-server/README.md @@ -0,0 +1,114 @@ +# BTrace MCP Server + +An MCP (Model Context Protocol) server that exposes BTrace operations as tools, allowing LLM clients (Claude Desktop, Claude Code, Cursor, etc.) to diagnose running JVMs. + +## Tools + +| Tool | Description | +|------|-------------| +| `list_jvms` | List all attachable Java VMs on this host | +| `deploy_oneliner` | Deploy a BTrace oneliner probe (e.g. `com.example.Service::method @return { print method, duration }`) | +| `deploy_script` | Deploy a full BTrace Java script | +| `list_probes` | List active probes on a JVM | +| `send_event` | Send an event to a running probe | +| `detach_probe` | Detach from a running probe (probe continues) | +| `exit_probe` | Stop and remove a probe | + +## Prompts + +| Prompt | Description | +|--------|-------------| +| `diagnose_slow_endpoint` | Step-by-step guide to diagnose a slow HTTP endpoint | +| `find_exception_source` | Guide to find where exceptions originate | +| `profile_method` | Guide to profile a method's latency | + +## Building + +```bash +./gradlew :btrace-mcp-server:build +``` + +## Configuration + +### Claude Desktop + +Add to `~/Library/Application Support/Claude/claude_desktop_config.json` (macOS) or `%APPDATA%\Claude\claude_desktop_config.json` (Windows): + +```json +{ + "mcpServers": { + "btrace": { + "command": "java", + "args": [ + "-cp", "/path/to/btrace/lib/*", + "org.openjdk.btrace.mcp.BTraceMcpServer" + ] + } + } +} +``` + +### Claude Code + +Add to your project's `.mcp.json`: + +```json +{ + "mcpServers": { + "btrace": { + "command": "java", + "args": [ + "-cp", "/path/to/btrace/lib/*", + "org.openjdk.btrace.mcp.BTraceMcpServer" + ] + } + } +} +``` + +Or configure globally in `~/.claude/settings.json`: + +```json +{ + "mcpServers": { + "btrace": { + "command": "java", + "args": [ + "-cp", "/path/to/btrace/lib/*", + "org.openjdk.btrace.mcp.BTraceMcpServer" + ] + } + } +} +``` + +### Using the BTrace distribution + +If you have BTrace installed (e.g. via SDKMAN), you can reference the distribution directly: + +```json +{ + "mcpServers": { + "btrace": { + "command": "java", + "args": [ + "-cp", "$BTRACE_HOME/lib/*:$BTRACE_HOME/build/btrace-mcp-server.jar", + "org.openjdk.btrace.mcp.BTraceMcpServer" + ] + } + } +} +``` + +## Usage Examples + +Once configured, you can ask your LLM client: + +- "List all running Java processes" +- "Attach to PID 12345 and trace all methods in com.example.UserService that take longer than 100ms" +- "Find where NullPointerExceptions are being thrown in PID 12345" +- "Profile the latency of com.example.OrderService::processOrder" + +## Protocol + +The server uses MCP over stdio (stdin/stdout JSON-RPC). All logging goes to stderr. diff --git a/btrace-mcp-server/build.gradle b/btrace-mcp-server/build.gradle new file mode 100644 index 00000000..adb4a4f9 --- /dev/null +++ b/btrace-mcp-server/build.gradle @@ -0,0 +1,28 @@ +// MCP server requires Java 11+ (uses ConcurrentHashMap.forEach, etc.) +compileJava { + sourceCompatibility = 11 + targetCompatibility = 11 +} + +dependencies { + implementation libs.slf4j + implementation libs.slf4j.simple + + def toolsJar = getToolsJar() + if (toolsJar.getAsFile().exists()) { + compileOnly files("${toolsJar}") + } + + implementation project(':btrace-core') + implementation project(':btrace-client') + implementation project(':btrace-compiler') + implementation project(':btrace-boot') +} + +jar { + manifest { + attributes( + 'Main-Class': 'org.openjdk.btrace.mcp.BTraceMcpServer' + ) + } +} diff --git a/btrace-mcp-server/src/main/java/org/openjdk/btrace/mcp/BTraceMcpServer.java b/btrace-mcp-server/src/main/java/org/openjdk/btrace/mcp/BTraceMcpServer.java new file mode 100644 index 00000000..7c06eb05 --- /dev/null +++ b/btrace-mcp-server/src/main/java/org/openjdk/btrace/mcp/BTraceMcpServer.java @@ -0,0 +1,226 @@ +/* + * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +package org.openjdk.btrace.mcp; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import org.openjdk.btrace.mcp.prompts.DiagnosticPrompts; +import org.openjdk.btrace.mcp.tools.DeployOnelinerHandler; +import org.openjdk.btrace.mcp.tools.DeployScriptHandler; +import org.openjdk.btrace.mcp.tools.DetachProbeHandler; +import org.openjdk.btrace.mcp.tools.ExitProbeHandler; +import org.openjdk.btrace.mcp.tools.ListJvmsHandler; +import org.openjdk.btrace.mcp.tools.ListProbesHandler; +import org.openjdk.btrace.mcp.tools.SendEventHandler; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * BTrace MCP (Model Context Protocol) server. Exposes BTrace operations as MCP tools over stdio + * JSON-RPC transport, allowing LLM clients to instrument and diagnose running JVMs. + */ +public final class BTraceMcpServer { + private static final Logger log = LoggerFactory.getLogger(BTraceMcpServer.class); + private static final String SERVER_NAME = "btrace-mcp-server"; + private static final String SERVER_VERSION = "0.1.0"; + private static final String PROTOCOL_VERSION = "2024-11-05"; + + private final McpProtocol protocol; + + BTraceMcpServer(McpProtocol protocol) { + this.protocol = protocol; + } + + public static void main(String[] args) { + // Redirect System.out logging to stderr so stdout stays clean for MCP JSON-RPC + System.setProperty("org.slf4j.simpleLogger.logFile", "System.err"); + System.setProperty("org.slf4j.simpleLogger.defaultLogLevel", "warn"); + + log.info("Starting {}", SERVER_NAME); + McpProtocol proto = new McpProtocol(System.in, System.out); + BTraceMcpServer server = new BTraceMcpServer(proto); + try { + server.run(); + } catch (IOException e) { + log.error("Server error", e); + System.exit(1); + } + } + + @SuppressWarnings("unchecked") + void run() throws IOException { + while (true) { + Map message = protocol.readMessage(); + if (message == null) { + log.info("EOF on stdin, shutting down"); + break; + } + + String method = (String) message.get("method"); + Object id = message.get("id"); + Map params = + message.containsKey("params") ? (Map) message.get("params") : null; + + if (method == null) { + // Response or notification without method — ignore + continue; + } + + try { + switch (method) { + case "initialize": + handleInitialize(id, params); + break; + case "notifications/initialized": + // Client acknowledgement — no response needed + break; + case "tools/list": + handleToolsList(id); + break; + case "tools/call": + handleToolsCall(id, params); + break; + case "prompts/list": + handlePromptsList(id); + break; + case "prompts/get": + handlePromptsGet(id, params); + break; + default: + protocol.sendError(id, -32601, "Method not found: " + method); + } + } catch (Exception e) { + log.error("Error handling method: {}", method, e); + protocol.sendError(id, -32603, "Internal error: " + e.getMessage()); + } + } + } + + private void handleInitialize(Object id, Map params) throws IOException { + Map serverInfo = new LinkedHashMap<>(); + serverInfo.put("name", SERVER_NAME); + serverInfo.put("version", SERVER_VERSION); + + Map toolsCap = new LinkedHashMap<>(); + Map promptsCap = new LinkedHashMap<>(); + Map capabilities = new LinkedHashMap<>(); + capabilities.put("tools", toolsCap); + capabilities.put("prompts", promptsCap); + + Map result = new LinkedHashMap<>(); + result.put("protocolVersion", PROTOCOL_VERSION); + result.put("capabilities", capabilities); + result.put("serverInfo", serverInfo); + + protocol.sendResult(id, result); + } + + private void handleToolsList(Object id) throws IOException { + List tools = new ArrayList<>(); + tools.add(ListJvmsHandler.schema()); + tools.add(DeployOnelinerHandler.schema()); + tools.add(DeployScriptHandler.schema()); + tools.add(ListProbesHandler.schema()); + tools.add(SendEventHandler.schema()); + tools.add(DetachProbeHandler.schema()); + tools.add(ExitProbeHandler.schema()); + + Map result = new LinkedHashMap<>(); + result.put("tools", tools); + protocol.sendResult(id, result); + } + + @SuppressWarnings("unchecked") + private void handleToolsCall(Object id, Map params) throws IOException { + if (params == null) { + protocol.sendError(id, -32602, "Missing params"); + return; + } + String toolName = (String) params.get("name"); + Map arguments = + params.containsKey("arguments") + ? (Map) params.get("arguments") + : new LinkedHashMap<>(); + + Map result; + switch (toolName) { + case "list_jvms": + result = ListJvmsHandler.execute(arguments); + break; + case "deploy_oneliner": + result = DeployOnelinerHandler.execute(arguments); + break; + case "deploy_script": + result = DeployScriptHandler.execute(arguments); + break; + case "list_probes": + result = ListProbesHandler.execute(arguments); + break; + case "send_event": + result = SendEventHandler.execute(arguments); + break; + case "detach_probe": + result = DetachProbeHandler.execute(arguments); + break; + case "exit_probe": + result = ExitProbeHandler.execute(arguments); + break; + default: + protocol.sendError(id, -32602, "Unknown tool: " + toolName); + return; + } + protocol.sendResult(id, result); + } + + private void handlePromptsList(Object id) throws IOException { + Map result = new LinkedHashMap<>(); + result.put("prompts", DiagnosticPrompts.listPrompts()); + protocol.sendResult(id, result); + } + + @SuppressWarnings("unchecked") + private void handlePromptsGet(Object id, Map params) throws IOException { + if (params == null) { + protocol.sendError(id, -32602, "Missing params"); + return; + } + String name = (String) params.get("name"); + Map promptArgs = + params.containsKey("arguments") + ? (Map) params.get("arguments") + : new LinkedHashMap<>(); + + Map result = DiagnosticPrompts.getPrompt(name, promptArgs); + if (result == null) { + protocol.sendError(id, -32602, "Unknown prompt: " + name); + return; + } + protocol.sendResult(id, result); + } +} diff --git a/btrace-mcp-server/src/main/java/org/openjdk/btrace/mcp/ClientManager.java b/btrace-mcp-server/src/main/java/org/openjdk/btrace/mcp/ClientManager.java new file mode 100644 index 00000000..a85c431d --- /dev/null +++ b/btrace-mcp-server/src/main/java/org/openjdk/btrace/mcp/ClientManager.java @@ -0,0 +1,86 @@ +/* + * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +package org.openjdk.btrace.mcp; + +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import org.openjdk.btrace.client.Client; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Manages BTrace client instances per JVM. Tracks active sessions so that tools like send_event, + * detach_probe, and exit_probe can interact with already-deployed probes. + */ +public final class ClientManager { + private static final Logger log = LoggerFactory.getLogger(ClientManager.class); + + /** Key is "pid:port", value is the active Client instance. */ + private static final Map activeClients = new ConcurrentHashMap<>(); + + private ClientManager() {} + + /** Creates a new BTrace Client for the given port. Does not reuse existing sessions. */ + public static Client getClient(int port) { + return new Client(port); + } + + /** Registers a client as active for a given PID and port. */ + public static void registerClient(String pid, int port, Client client) { + String key = pid + ":" + port; + activeClients.put(key, client); + log.debug("Registered client for {}", key); + } + + /** Returns an existing active client for the given PID and port, or null if none. */ + public static Client getExistingClient(String pid, int port) { + String key = pid + ":" + port; + return activeClients.get(key); + } + + /** Removes and returns an active client for the given PID and port. */ + public static Client removeClient(String pid, int port) { + String key = pid + ":" + port; + Client removed = activeClients.remove(key); + if (removed != null) { + log.debug("Removed client for {}", key); + } + return removed; + } + + /** Closes all active client sessions. Called on shutdown. */ + public static void closeAll() { + for (Map.Entry entry : activeClients.entrySet()) { + try { + entry.getValue().close(); + log.debug("Closed client for {}", entry.getKey()); + } catch (Exception e) { + log.warn("Error closing client for {}", entry.getKey(), e); + } + } + activeClients.clear(); + } +} diff --git a/btrace-mcp-server/src/main/java/org/openjdk/btrace/mcp/McpProtocol.java b/btrace-mcp-server/src/main/java/org/openjdk/btrace/mcp/McpProtocol.java new file mode 100644 index 00000000..7b222b05 --- /dev/null +++ b/btrace-mcp-server/src/main/java/org/openjdk/btrace/mcp/McpProtocol.java @@ -0,0 +1,430 @@ +/* + * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +package org.openjdk.btrace.mcp; + +import java.io.BufferedReader; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.io.OutputStream; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Handles MCP JSON-RPC protocol over stdio. Reads JSON-RPC messages from stdin and writes responses + * to stdout. All logging goes to stderr. + */ +final class McpProtocol { + private static final Logger log = LoggerFactory.getLogger(McpProtocol.class); + private final BufferedReader reader; + private final OutputStream out; + + McpProtocol(InputStream in, OutputStream out) { + this.reader = new BufferedReader(new InputStreamReader(in, StandardCharsets.UTF_8)); + this.out = out; + } + + /** Reads a single JSON-RPC message from stdin. Returns null on EOF. */ + Map readMessage() throws IOException { + String line = reader.readLine(); + if (line == null) { + return null; + } + line = line.trim(); + if (line.isEmpty()) { + return null; + } + log.debug("Received: {}", line); + return parseJson(line); + } + + /** Writes a JSON-RPC response to stdout (one line, newline terminated). */ + synchronized void writeMessage(Map message) throws IOException { + String json = toJson(message); + log.debug("Sending: {}", json); + out.write(json.getBytes(StandardCharsets.UTF_8)); + out.write('\n'); + out.flush(); + } + + /** Sends a JSON-RPC success response. */ + void sendResult(Object id, Object result) throws IOException { + Map response = new LinkedHashMap<>(); + response.put("jsonrpc", "2.0"); + response.put("id", id); + response.put("result", result); + writeMessage(response); + } + + /** Sends a JSON-RPC error response. */ + void sendError(Object id, int code, String message) throws IOException { + sendError(id, code, message, null); + } + + /** Sends a JSON-RPC error response with optional data. */ + void sendError(Object id, int code, String message, Object data) throws IOException { + Map error = new LinkedHashMap<>(); + error.put("code", code); + error.put("message", message); + if (data != null) { + error.put("data", data); + } + Map response = new LinkedHashMap<>(); + response.put("jsonrpc", "2.0"); + response.put("id", id); + response.put("error", error); + writeMessage(response); + } + + // --- Minimal JSON parser (handles objects, arrays, strings, numbers, booleans, null) --- + + @SuppressWarnings("unchecked") + static Map parseJson(String json) { + Object result = new JsonParser(json.trim()).parseValue(); + if (result instanceof Map) { + return (Map) result; + } + throw new IllegalArgumentException("Expected JSON object, got: " + json); + } + + private static final class JsonParser { + private final String src; + private int pos; + + JsonParser(String src) { + this.src = src; + this.pos = 0; + } + + Object parseValue() { + skipWhitespace(); + if (pos >= src.length()) { + throw new IllegalArgumentException("Unexpected end of JSON"); + } + char c = src.charAt(pos); + if (c == '{') { + return parseObject(); + } + if (c == '[') { + return parseArray(); + } + if (c == '"') { + return parseString(); + } + if (c == 't' || c == 'f') { + return parseBoolean(); + } + if (c == 'n') { + return parseNull(); + } + return parseNumber(); + } + + Map parseObject() { + expect('{'); + Map map = new LinkedHashMap<>(); + skipWhitespace(); + if (pos < src.length() && src.charAt(pos) == '}') { + pos++; + return map; + } + while (true) { + skipWhitespace(); + String key = parseString(); + skipWhitespace(); + expect(':'); + Object value = parseValue(); + map.put(key, value); + skipWhitespace(); + if (pos < src.length() && src.charAt(pos) == ',') { + pos++; + } else { + break; + } + } + expect('}'); + return map; + } + + List parseArray() { + expect('['); + List list = new ArrayList<>(); + skipWhitespace(); + if (pos < src.length() && src.charAt(pos) == ']') { + pos++; + return list; + } + while (true) { + list.add(parseValue()); + skipWhitespace(); + if (pos < src.length() && src.charAt(pos) == ',') { + pos++; + } else { + break; + } + } + expect(']'); + return list; + } + + String parseString() { + expect('"'); + StringBuilder sb = new StringBuilder(); + while (pos < src.length()) { + char c = src.charAt(pos++); + if (c == '"') { + return sb.toString(); + } + if (c == '\\') { + if (pos >= src.length()) { + break; + } + char esc = src.charAt(pos++); + switch (esc) { + case '"': + case '\\': + case '/': + sb.append(esc); + break; + case 'b': + sb.append('\b'); + break; + case 'f': + sb.append('\f'); + break; + case 'n': + sb.append('\n'); + break; + case 'r': + sb.append('\r'); + break; + case 't': + sb.append('\t'); + break; + case 'u': + if (pos + 4 <= src.length()) { + String hex = src.substring(pos, pos + 4); + sb.append((char) Integer.parseInt(hex, 16)); + pos += 4; + } + break; + default: + sb.append(esc); + } + } else { + sb.append(c); + } + } + throw new IllegalArgumentException("Unterminated string"); + } + + Object parseNumber() { + int start = pos; + if (pos < src.length() && src.charAt(pos) == '-') { + pos++; + } + while (pos < src.length() && Character.isDigit(src.charAt(pos))) { + pos++; + } + boolean isFloat = false; + if (pos < src.length() && src.charAt(pos) == '.') { + isFloat = true; + pos++; + while (pos < src.length() && Character.isDigit(src.charAt(pos))) { + pos++; + } + } + if (pos < src.length() && (src.charAt(pos) == 'e' || src.charAt(pos) == 'E')) { + isFloat = true; + pos++; + if (pos < src.length() && (src.charAt(pos) == '+' || src.charAt(pos) == '-')) { + pos++; + } + while (pos < src.length() && Character.isDigit(src.charAt(pos))) { + pos++; + } + } + String numStr = src.substring(start, pos); + if (isFloat) { + return Double.parseDouble(numStr); + } + long val = Long.parseLong(numStr); + if (val >= Integer.MIN_VALUE && val <= Integer.MAX_VALUE) { + return (int) val; + } + return val; + } + + Object parseBoolean() { + if (src.startsWith("true", pos)) { + pos += 4; + return Boolean.TRUE; + } + if (src.startsWith("false", pos)) { + pos += 5; + return Boolean.FALSE; + } + throw new IllegalArgumentException("Invalid boolean at " + pos); + } + + Object parseNull() { + if (src.startsWith("null", pos)) { + pos += 4; + return null; + } + throw new IllegalArgumentException("Invalid null at " + pos); + } + + void skipWhitespace() { + while (pos < src.length() && Character.isWhitespace(src.charAt(pos))) { + pos++; + } + } + + void expect(char c) { + skipWhitespace(); + if (pos >= src.length() || src.charAt(pos) != c) { + throw new IllegalArgumentException( + "Expected '" + c + "' at " + pos + " but got: " + (pos < src.length() ? src.charAt(pos) : "EOF")); + } + pos++; + } + } + + // --- Minimal JSON serializer --- + + @SuppressWarnings("unchecked") + static String toJson(Object obj) { + if (obj == null) { + return "null"; + } + if (obj instanceof String) { + return escapeJsonString((String) obj); + } + if (obj instanceof Number || obj instanceof Boolean) { + return obj.toString(); + } + if (obj instanceof Map) { + Map map = (Map) obj; + StringBuilder sb = new StringBuilder("{"); + boolean first = true; + for (Map.Entry entry : map.entrySet()) { + if (!first) { + sb.append(","); + } + sb.append(escapeJsonString(entry.getKey())); + sb.append(":"); + sb.append(toJson(entry.getValue())); + first = false; + } + sb.append("}"); + return sb.toString(); + } + if (obj instanceof List) { + List list = (List) obj; + StringBuilder sb = new StringBuilder("["); + for (int i = 0; i < list.size(); i++) { + if (i > 0) { + sb.append(","); + } + sb.append(toJson(list.get(i))); + } + sb.append("]"); + return sb.toString(); + } + if (obj instanceof Object[]) { + Object[] arr = (Object[]) obj; + StringBuilder sb = new StringBuilder("["); + for (int i = 0; i < arr.length; i++) { + if (i > 0) { + sb.append(","); + } + sb.append(toJson(arr[i])); + } + sb.append("]"); + return sb.toString(); + } + return escapeJsonString(obj.toString()); + } + + private static String escapeJsonString(String s) { + StringBuilder sb = new StringBuilder("\""); + for (int i = 0; i < s.length(); i++) { + char c = s.charAt(i); + switch (c) { + case '"': + sb.append("\\\""); + break; + case '\\': + sb.append("\\\\"); + break; + case '\b': + sb.append("\\b"); + break; + case '\f': + sb.append("\\f"); + break; + case '\n': + sb.append("\\n"); + break; + case '\r': + sb.append("\\r"); + break; + case '\t': + sb.append("\\t"); + break; + default: + if (c < 0x20) { + sb.append(String.format("\\u%04x", (int) c)); + } else { + sb.append(c); + } + } + } + sb.append("\""); + return sb.toString(); + } + + /** Helper to build a tool content result (text). */ + static Map toolResult(String text, boolean isError) { + Map content = new LinkedHashMap<>(); + content.put("type", "text"); + content.put("text", text); + + List contentList = new ArrayList<>(); + contentList.add(content); + + Map result = new LinkedHashMap<>(); + result.put("content", contentList); + result.put("isError", isError); + return result; + } +} diff --git a/btrace-mcp-server/src/main/java/org/openjdk/btrace/mcp/prompts/DiagnosticPrompts.java b/btrace-mcp-server/src/main/java/org/openjdk/btrace/mcp/prompts/DiagnosticPrompts.java new file mode 100644 index 00000000..84c0c7fe --- /dev/null +++ b/btrace-mcp-server/src/main/java/org/openjdk/btrace/mcp/prompts/DiagnosticPrompts.java @@ -0,0 +1,320 @@ +/* + * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +package org.openjdk.btrace.mcp.prompts; + +import java.util.ArrayList; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; + +/** Provides MCP prompt templates for common BTrace diagnostic scenarios. */ +public final class DiagnosticPrompts { + + private DiagnosticPrompts() {} + + /** Returns all available prompt schemas for prompts/list. */ + public static List> listPrompts() { + List> prompts = new ArrayList<>(); + prompts.add(diagnoseSlowEndpointSchema()); + prompts.add(findExceptionSourceSchema()); + prompts.add(profileMethodSchema()); + return prompts; + } + + /** Returns a prompt by name, or null if not found. */ + public static Map getPrompt(String name, Map arguments) { + switch (name) { + case "diagnose_slow_endpoint": + return diagnoseSlowEndpoint(arguments); + case "find_exception_source": + return findExceptionSource(arguments); + case "profile_method": + return profileMethod(arguments); + default: + return null; + } + } + + // --- diagnose_slow_endpoint --- + + private static Map diagnoseSlowEndpointSchema() { + Map prompt = new LinkedHashMap<>(); + prompt.put("name", "diagnose_slow_endpoint"); + prompt.put( + "description", + "Step-by-step guide to diagnose a slow HTTP endpoint using BTrace. " + + "Helps identify where time is spent in request processing."); + + List> args = new ArrayList<>(); + args.add(promptArg("endpoint_class", "Fully qualified class name of the endpoint/controller", true)); + args.add(promptArg("endpoint_method", "Method name to diagnose", true)); + args.add(promptArg("pid", "PID of the target JVM (optional, will use list_jvms if not provided)", false)); + prompt.put("arguments", args); + return prompt; + } + + private static Map diagnoseSlowEndpoint(Map arguments) { + String endpointClass = getArg(arguments, "endpoint_class", "com.example.Controller"); + String endpointMethod = getArg(arguments, "endpoint_method", "handleRequest"); + String pid = getArg(arguments, "pid", null); + + StringBuilder text = new StringBuilder(); + text.append("# Diagnosing Slow Endpoint: ") + .append(endpointClass) + .append("::") + .append(endpointMethod) + .append("\n\n"); + + text.append("## Step 1: Find the target JVM\n"); + if (pid != null) { + text.append("Target PID: ").append(pid).append("\n"); + } else { + text.append("Use the `list_jvms` tool to find the target JVM PID.\n"); + } + text.append("\n"); + + text.append("## Step 2: Measure endpoint latency\n"); + text.append("Deploy a oneliner to measure the method's execution time:\n"); + text.append("```\n"); + text.append(endpointClass) + .append("::") + .append(endpointMethod) + .append(" @return { print method, duration }\n"); + text.append("```\n\n"); + + text.append("## Step 3: Trace internal method calls\n"); + text.append( + "If the endpoint is slow, trace the internal methods it calls to find the bottleneck:\n"); + text.append("```\n"); + text.append(endpointClass) + .append("::/.*/") + .append(" @return if duration>10ms { print method, duration }\n"); + text.append("```\n\n"); + + text.append("## Step 4: Analyze results\n"); + text.append( + "Look for methods with unexpectedly high durations. Common causes include:\n"); + text.append("- Database queries taking too long\n"); + text.append("- External service calls with high latency\n"); + text.append("- Lock contention\n"); + text.append("- Excessive object allocation\n\n"); + + text.append("## Step 5: Clean up\n"); + text.append("Use `exit_probe` to remove the instrumentation when done.\n"); + + return promptResult(text.toString()); + } + + // --- find_exception_source --- + + private static Map findExceptionSourceSchema() { + Map prompt = new LinkedHashMap<>(); + prompt.put("name", "find_exception_source"); + prompt.put( + "description", + "Guide to find where specific exceptions originate in a running JVM. " + + "Uses BTrace to intercept exception constructors and capture stack traces."); + + List> args = new ArrayList<>(); + args.add(promptArg("exception_class", "Exception class name (e.g. NullPointerException)", true)); + args.add(promptArg("pid", "PID of the target JVM (optional)", false)); + prompt.put("arguments", args); + return prompt; + } + + private static Map findExceptionSource(Map arguments) { + String exceptionClass = getArg(arguments, "exception_class", "java.lang.NullPointerException"); + String pid = getArg(arguments, "pid", null); + + StringBuilder text = new StringBuilder(); + text.append("# Finding Exception Source: ").append(exceptionClass).append("\n\n"); + + text.append("## Step 1: Find the target JVM\n"); + if (pid != null) { + text.append("Target PID: ").append(pid).append("\n"); + } else { + text.append("Use the `list_jvms` tool to find the target JVM PID.\n"); + } + text.append("\n"); + + text.append("## Step 2: Deploy an exception tracing script\n"); + text.append("Deploy this BTrace script to capture exception creation with stack traces:\n\n"); + text.append("```java\n"); + text.append("import org.openjdk.btrace.core.annotations.*;\n"); + text.append("import static org.openjdk.btrace.core.BTraceUtils.*;\n\n"); + text.append("@BTrace\n"); + text.append("public class ExceptionTracer {\n"); + text.append(" @OnMethod(\n"); + text.append(" clazz = \"").append(exceptionClass).append("\",\n"); + text.append(" method = \"\"\n"); + text.append(" )\n"); + text.append(" public static void onException(@Self Throwable self) {\n"); + text.append(" println(\"--- Exception created: \" + Strings.str(self) + \" ---\");\n"); + text.append(" Threads.jstack();\n"); + text.append(" println(\"\");\n"); + text.append(" }\n"); + text.append("}\n"); + text.append("```\n\n"); + + text.append("## Step 3: Analyze the stack traces\n"); + text.append( + "Each time the exception is created, you will see the full stack trace " + + "showing exactly which code path creates it.\n\n"); + + text.append("## Step 4: Clean up\n"); + text.append("Use `exit_probe` to remove the instrumentation when done.\n"); + + return promptResult(text.toString()); + } + + // --- profile_method --- + + private static Map profileMethodSchema() { + Map prompt = new LinkedHashMap<>(); + prompt.put("name", "profile_method"); + prompt.put( + "description", + "Guide to profile a specific method's latency distribution using BTrace. " + + "Captures timing data to understand performance characteristics."); + + List> args = new ArrayList<>(); + args.add(promptArg("class_name", "Fully qualified class name", true)); + args.add(promptArg("method_name", "Method name to profile", true)); + args.add(promptArg("pid", "PID of the target JVM (optional)", false)); + prompt.put("arguments", args); + return prompt; + } + + private static Map profileMethod(Map arguments) { + String className = getArg(arguments, "class_name", "com.example.Service"); + String methodName = getArg(arguments, "method_name", "process"); + String pid = getArg(arguments, "pid", null); + + StringBuilder text = new StringBuilder(); + text.append("# Profiling Method: ") + .append(className) + .append("::") + .append(methodName) + .append("\n\n"); + + text.append("## Step 1: Find the target JVM\n"); + if (pid != null) { + text.append("Target PID: ").append(pid).append("\n"); + } else { + text.append("Use the `list_jvms` tool to find the target JVM PID.\n"); + } + text.append("\n"); + + text.append("## Step 2: Quick latency check with oneliner\n"); + text.append("Start with a simple oneliner to see individual call durations:\n"); + text.append("```\n"); + text.append(className) + .append("::") + .append(methodName) + .append(" @return { print method, duration }\n"); + text.append("```\n\n"); + + text.append("## Step 3: Detailed profiling with histogram\n"); + text.append("For a latency distribution, deploy this BTrace script:\n\n"); + text.append("```java\n"); + text.append("import org.openjdk.btrace.core.annotations.*;\n"); + text.append("import org.openjdk.btrace.core.BTraceUtils;\n"); + text.append("import static org.openjdk.btrace.core.BTraceUtils.*;\n\n"); + text.append("@BTrace\n"); + text.append("public class MethodProfiler {\n"); + text.append(" private static long count;\n"); + text.append(" private static long totalTime;\n\n"); + text.append(" @OnMethod(\n"); + text.append(" clazz = \"").append(className).append("\",\n"); + text.append(" method = \"").append(methodName).append("\",\n"); + text.append(" location = @Location(Kind.RETURN)\n"); + text.append(" )\n"); + text.append(" public static void onReturn(@Duration long duration) {\n"); + text.append(" count++;\n"); + text.append(" totalTime += duration;\n"); + text.append( + " println(\"Call #\" + count + \": \" + (duration / 1000000) + \"ms\");\n"); + text.append(" }\n\n"); + text.append(" @OnEvent\n"); + text.append(" public static void onEvent() {\n"); + text.append(" println(\"=== Summary ===\");\n"); + text.append(" println(\"Total calls: \" + count);\n"); + text.append(" if (count > 0) {\n"); + text.append( + " println(\"Avg duration: \" + ((totalTime / count) / 1000000) + \"ms\");\n"); + text.append(" }\n"); + text.append(" }\n"); + text.append("}\n"); + text.append("```\n\n"); + + text.append("## Step 4: Get summary\n"); + text.append("Use `send_event` to trigger the @OnEvent handler and get a summary.\n\n"); + + text.append("## Step 5: Clean up\n"); + text.append("Use `exit_probe` to remove the instrumentation when done.\n"); + + return promptResult(text.toString()); + } + + // --- Helpers --- + + private static Map promptArg( + String name, String description, boolean required) { + Map arg = new LinkedHashMap<>(); + arg.put("name", name); + arg.put("description", description); + arg.put("required", required); + return arg; + } + + private static String getArg(Map args, String key, String defaultVal) { + if (args == null) { + return defaultVal; + } + Object val = args.get(key); + if (val == null) { + return defaultVal; + } + return val.toString(); + } + + private static Map promptResult(String text) { + Map message = new LinkedHashMap<>(); + message.put("role", "user"); + + Map content = new LinkedHashMap<>(); + content.put("type", "text"); + content.put("text", text); + message.put("content", content); + + List messages = new ArrayList<>(); + messages.add(message); + + Map result = new LinkedHashMap<>(); + result.put("messages", messages); + return result; + } +} diff --git a/btrace-mcp-server/src/main/java/org/openjdk/btrace/mcp/tools/DeployOnelinerHandler.java b/btrace-mcp-server/src/main/java/org/openjdk/btrace/mcp/tools/DeployOnelinerHandler.java new file mode 100644 index 00000000..2556554a --- /dev/null +++ b/btrace-mcp-server/src/main/java/org/openjdk/btrace/mcp/tools/DeployOnelinerHandler.java @@ -0,0 +1,230 @@ +/* + * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +package org.openjdk.btrace.mcp.tools; + +import java.io.PrintWriter; +import java.io.StringWriter; +import java.util.ArrayList; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import org.openjdk.btrace.client.Client; +import org.openjdk.btrace.compiler.oneliner.OnelinerAST.OnelinerNode; +import org.openjdk.btrace.compiler.oneliner.OnelinerCodeGenerator; +import org.openjdk.btrace.compiler.oneliner.OnelinerParser; +import org.openjdk.btrace.compiler.oneliner.OnelinerValidator; +import org.openjdk.btrace.core.comm.Command; +import org.openjdk.btrace.core.comm.PrintableCommand; +import org.openjdk.btrace.mcp.ClientManager; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** Handles the deploy_oneliner MCP tool - deploys a BTrace oneliner probe. */ +public final class DeployOnelinerHandler { + private static final Logger log = LoggerFactory.getLogger(DeployOnelinerHandler.class); + private static final int DEFAULT_PORT = 2020; + private static final int PROBE_TIMEOUT_SECONDS = 30; + + private DeployOnelinerHandler() {} + + /** Returns tool schema for MCP tools/list. */ + public static Map schema() { + Map tool = new LinkedHashMap<>(); + tool.put("name", "deploy_oneliner"); + tool.put( + "description", + "Deploy a BTrace oneliner probe to a running JVM. " + + "Oneliners are concise probe expressions like: " + + "\"com.example.Service::method @return { print duration }\" " + + "or \"com.example.Dao::query @return if duration>100ms { print method, duration }\". " + + "The probe attaches to the target JVM and captures output for the specified duration."); + + Map properties = new LinkedHashMap<>(); + + Map pidProp = new LinkedHashMap<>(); + pidProp.put("type", "string"); + pidProp.put("description", "PID of the target JVM (use list_jvms to find it)"); + properties.put("pid", pidProp); + + Map onelinerProp = new LinkedHashMap<>(); + onelinerProp.put("type", "string"); + onelinerProp.put( + "description", + "BTrace oneliner expression, e.g. " + + "\"com.example.Service::method @return { print method, duration }\""); + properties.put("oneliner", onelinerProp); + + Map portProp = new LinkedHashMap<>(); + portProp.put("type", "integer"); + portProp.put("description", "BTrace agent port (default: 2020)"); + properties.put("port", portProp); + + List required = new ArrayList<>(); + required.add("pid"); + required.add("oneliner"); + + Map inputSchema = new LinkedHashMap<>(); + inputSchema.put("type", "object"); + inputSchema.put("properties", properties); + inputSchema.put("required", required); + tool.put("inputSchema", inputSchema); + return tool; + } + + /** Executes the deploy_oneliner tool. */ + @SuppressWarnings("unchecked") + public static Map execute(Map arguments) { + String pid = getStringArg(arguments, "pid"); + String oneliner = getStringArg(arguments, "oneliner"); + int port = getIntArg(arguments, "port", DEFAULT_PORT); + + if (pid == null || pid.isEmpty()) { + return toolResult("Error: 'pid' parameter is required", true); + } + if (oneliner == null || oneliner.isEmpty()) { + return toolResult("Error: 'oneliner' parameter is required", true); + } + + try { + // Parse and validate the oneliner + OnelinerNode ast = OnelinerParser.parse(oneliner); + OnelinerValidator.validate(ast, oneliner); + String className = "BTraceOneliner_" + System.currentTimeMillis(); + String javaSource = OnelinerCodeGenerator.generate(ast, className); + String fileName = className + ".java"; + + log.info("Generated oneliner source for {}: {}", oneliner, javaSource); + + // Compile the oneliner + Client client = ClientManager.getClient(port); + StringWriter errorWriter = new StringWriter(); + PrintWriter errPw = new PrintWriter(errorWriter); + byte[] code = client.compileSource(fileName, javaSource, ".", errPw, null); + + if (code == null) { + String errors = errorWriter.toString(); + return toolResult( + "Oneliner compilation failed:\n" + errors + "\nGenerated source:\n" + javaSource, true); + } + + // Attach and submit + client.attach(pid, null, "."); + + StringBuilder output = new StringBuilder(); + CountDownLatch statusLatch = new CountDownLatch(1); + AtomicBoolean success = new AtomicBoolean(false); + AtomicBoolean exited = new AtomicBoolean(false); + + client.submit( + "localhost", + fileName, + code, + new String[0], + cmd -> { + int type = cmd.getType(); + if (cmd instanceof PrintableCommand) { + StringWriter sw = new StringWriter(); + ((PrintableCommand) cmd).print(new java.io.PrintWriter(sw)); + output.append(sw.toString()); + } + if (type == Command.STATUS) { + success.set(true); + statusLatch.countDown(); + } + if (type == Command.EXIT) { + exited.set(true); + statusLatch.countDown(); + } + }); + + // Wait for probe to start (or fail) + boolean started = statusLatch.await(PROBE_TIMEOUT_SECONDS, TimeUnit.SECONDS); + + if (!started) { + return toolResult("Probe deployment timed out after " + PROBE_TIMEOUT_SECONDS + "s", true); + } + + if (exited.get() && !success.get()) { + return toolResult("Probe exited with error:\n" + output.toString(), true); + } + + String resultText = + "Probe deployed successfully to PID " + + pid + + ".\n" + + "Oneliner: " + + oneliner + + "\n" + + "Port: " + + port + + "\n"; + if (output.length() > 0) { + resultText += "\nInitial output:\n" + output.toString(); + } + resultText += + "\nUse send_event, detach_probe, or exit_probe to interact with the running probe."; + return toolResult(resultText, false); + } catch (Exception e) { + log.error("Failed to deploy oneliner", e); + return toolResult("Error deploying oneliner: " + e.getMessage(), true); + } + } + + private static String getStringArg(Map args, String key) { + Object val = args == null ? null : args.get(key); + return val == null ? null : val.toString(); + } + + private static int getIntArg(Map args, String key, int defaultVal) { + Object val = args == null ? null : args.get(key); + if (val == null) { + return defaultVal; + } + if (val instanceof Number) { + return ((Number) val).intValue(); + } + try { + return Integer.parseInt(val.toString()); + } catch (NumberFormatException e) { + return defaultVal; + } + } + + private static Map toolResult(String text, boolean isError) { + Map content = new LinkedHashMap<>(); + content.put("type", "text"); + content.put("text", text); + List contentList = new ArrayList<>(); + contentList.add(content); + Map result = new LinkedHashMap<>(); + result.put("content", contentList); + result.put("isError", isError); + return result; + } +} diff --git a/btrace-mcp-server/src/main/java/org/openjdk/btrace/mcp/tools/DeployScriptHandler.java b/btrace-mcp-server/src/main/java/org/openjdk/btrace/mcp/tools/DeployScriptHandler.java new file mode 100644 index 00000000..125048ca --- /dev/null +++ b/btrace-mcp-server/src/main/java/org/openjdk/btrace/mcp/tools/DeployScriptHandler.java @@ -0,0 +1,226 @@ +/* + * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +package org.openjdk.btrace.mcp.tools; + +import java.io.PrintWriter; +import java.io.StringWriter; +import java.util.ArrayList; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import org.openjdk.btrace.client.Client; +import org.openjdk.btrace.core.comm.Command; +import org.openjdk.btrace.core.comm.PrintableCommand; +import org.openjdk.btrace.mcp.ClientManager; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** Handles the deploy_script MCP tool - deploys a full BTrace Java script. */ +public final class DeployScriptHandler { + private static final Logger log = LoggerFactory.getLogger(DeployScriptHandler.class); + private static final int DEFAULT_PORT = 2020; + private static final int PROBE_TIMEOUT_SECONDS = 30; + + private DeployScriptHandler() {} + + /** Returns tool schema for MCP tools/list. */ + public static Map schema() { + Map tool = new LinkedHashMap<>(); + tool.put("name", "deploy_script"); + tool.put( + "description", + "Deploy a full BTrace Java script to a running JVM. " + + "The script must be a valid BTrace program with @BTrace annotation. " + + "Use this for complex instrumentation that cannot be expressed as a oneliner."); + + Map properties = new LinkedHashMap<>(); + + Map pidProp = new LinkedHashMap<>(); + pidProp.put("type", "string"); + pidProp.put("description", "PID of the target JVM (use list_jvms to find it)"); + properties.put("pid", pidProp); + + Map scriptProp = new LinkedHashMap<>(); + scriptProp.put("type", "string"); + scriptProp.put( + "description", + "Full BTrace Java source code. Must include @BTrace annotation and proper imports."); + properties.put("script", scriptProp); + + Map argsProp = new LinkedHashMap<>(); + argsProp.put("type", "array"); + Map argsItems = new LinkedHashMap<>(); + argsItems.put("type", "string"); + argsProp.put("items", argsItems); + argsProp.put("description", "Optional arguments to pass to the BTrace script"); + properties.put("args", argsProp); + + Map portProp = new LinkedHashMap<>(); + portProp.put("type", "integer"); + portProp.put("description", "BTrace agent port (default: 2020)"); + properties.put("port", portProp); + + List required = new ArrayList<>(); + required.add("pid"); + required.add("script"); + + Map inputSchema = new LinkedHashMap<>(); + inputSchema.put("type", "object"); + inputSchema.put("properties", properties); + inputSchema.put("required", required); + tool.put("inputSchema", inputSchema); + return tool; + } + + /** Executes the deploy_script tool. */ + @SuppressWarnings("unchecked") + public static Map execute(Map arguments) { + String pid = getStringArg(arguments, "pid"); + String script = getStringArg(arguments, "script"); + int port = getIntArg(arguments, "port", DEFAULT_PORT); + + if (pid == null || pid.isEmpty()) { + return toolResult("Error: 'pid' parameter is required", true); + } + if (script == null || script.isEmpty()) { + return toolResult("Error: 'script' parameter is required", true); + } + + // Extract args + String[] btraceArgs = new String[0]; + Object argsObj = arguments == null ? null : arguments.get("args"); + if (argsObj instanceof List) { + List argsList = (List) argsObj; + btraceArgs = new String[argsList.size()]; + for (int i = 0; i < argsList.size(); i++) { + btraceArgs[i] = argsList.get(i).toString(); + } + } + + try { + String fileName = "BTraceScript_" + System.currentTimeMillis() + ".java"; + + // Compile the script + Client client = ClientManager.getClient(port); + StringWriter errorWriter = new StringWriter(); + PrintWriter errPw = new PrintWriter(errorWriter); + byte[] code = client.compileSource(fileName, script, ".", errPw, null); + + if (code == null) { + String errors = errorWriter.toString(); + return toolResult("Script compilation failed:\n" + errors, true); + } + + // Attach and submit + client.attach(pid, null, "."); + + StringBuilder output = new StringBuilder(); + CountDownLatch statusLatch = new CountDownLatch(1); + AtomicBoolean success = new AtomicBoolean(false); + AtomicBoolean exited = new AtomicBoolean(false); + final String[] finalArgs = btraceArgs; + + client.submit( + "localhost", + fileName, + code, + finalArgs, + cmd -> { + int type = cmd.getType(); + if (cmd instanceof PrintableCommand) { + StringWriter sw = new StringWriter(); + ((PrintableCommand) cmd).print(new java.io.PrintWriter(sw)); + output.append(sw.toString()); + } + if (type == Command.STATUS) { + success.set(true); + statusLatch.countDown(); + } + if (type == Command.EXIT) { + exited.set(true); + statusLatch.countDown(); + } + }); + + boolean started = statusLatch.await(PROBE_TIMEOUT_SECONDS, TimeUnit.SECONDS); + + if (!started) { + return toolResult("Probe deployment timed out after " + PROBE_TIMEOUT_SECONDS + "s", true); + } + + if (exited.get() && !success.get()) { + return toolResult("Probe exited with error:\n" + output.toString(), true); + } + + String resultText = + "Script deployed successfully to PID " + pid + ".\nPort: " + port + "\n"; + if (output.length() > 0) { + resultText += "\nInitial output:\n" + output.toString(); + } + resultText += + "\nUse send_event, detach_probe, or exit_probe to interact with the running probe."; + return toolResult(resultText, false); + } catch (Exception e) { + log.error("Failed to deploy script", e); + return toolResult("Error deploying script: " + e.getMessage(), true); + } + } + + private static String getStringArg(Map args, String key) { + Object val = args == null ? null : args.get(key); + return val == null ? null : val.toString(); + } + + private static int getIntArg(Map args, String key, int defaultVal) { + Object val = args == null ? null : args.get(key); + if (val == null) { + return defaultVal; + } + if (val instanceof Number) { + return ((Number) val).intValue(); + } + try { + return Integer.parseInt(val.toString()); + } catch (NumberFormatException e) { + return defaultVal; + } + } + + private static Map toolResult(String text, boolean isError) { + Map content = new LinkedHashMap<>(); + content.put("type", "text"); + content.put("text", text); + List contentList = new ArrayList<>(); + contentList.add(content); + Map result = new LinkedHashMap<>(); + result.put("content", contentList); + result.put("isError", isError); + return result; + } +} diff --git a/btrace-mcp-server/src/main/java/org/openjdk/btrace/mcp/tools/DetachProbeHandler.java b/btrace-mcp-server/src/main/java/org/openjdk/btrace/mcp/tools/DetachProbeHandler.java new file mode 100644 index 00000000..e277963b --- /dev/null +++ b/btrace-mcp-server/src/main/java/org/openjdk/btrace/mcp/tools/DetachProbeHandler.java @@ -0,0 +1,134 @@ +/* + * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +package org.openjdk.btrace.mcp.tools; + +import java.util.ArrayList; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import org.openjdk.btrace.client.Client; +import org.openjdk.btrace.mcp.ClientManager; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** + * Handles the detach_probe MCP tool - detaches the client from a running probe (probe continues + * running). + */ +public final class DetachProbeHandler { + private static final Logger log = LoggerFactory.getLogger(DetachProbeHandler.class); + private static final int DEFAULT_PORT = 2020; + + private DetachProbeHandler() {} + + /** Returns tool schema for MCP tools/list. */ + public static Map schema() { + Map tool = new LinkedHashMap<>(); + tool.put("name", "detach_probe"); + tool.put( + "description", + "Detach from a running BTrace probe. The probe continues running on the target JVM " + + "but the client disconnects. You can later reconnect to it."); + + Map properties = new LinkedHashMap<>(); + + Map pidProp = new LinkedHashMap<>(); + pidProp.put("type", "string"); + pidProp.put("description", "PID of the target JVM"); + properties.put("pid", pidProp); + + Map portProp = new LinkedHashMap<>(); + portProp.put("type", "integer"); + portProp.put("description", "BTrace agent port (default: 2020)"); + properties.put("port", portProp); + + List required = new ArrayList<>(); + required.add("pid"); + + Map inputSchema = new LinkedHashMap<>(); + inputSchema.put("type", "object"); + inputSchema.put("properties", properties); + inputSchema.put("required", required); + tool.put("inputSchema", inputSchema); + return tool; + } + + /** Executes the detach_probe tool. */ + public static Map execute(Map arguments) { + String pid = getStringArg(arguments, "pid"); + int port = getIntArg(arguments, "port", DEFAULT_PORT); + + if (pid == null || pid.isEmpty()) { + return toolResult("Error: 'pid' parameter is required", true); + } + + try { + Client client = ClientManager.getExistingClient(pid, port); + if (client == null) { + return toolResult("No active BTrace session for PID " + pid, true); + } + + client.sendDisconnect(); + ClientManager.removeClient(pid, port); + return toolResult("Detached from probe on PID " + pid + ". Probe continues running.", false); + } catch (Exception e) { + log.error("Failed to detach probe", e); + return toolResult("Error detaching from PID " + pid + ": " + e.getMessage(), true); + } + } + + private static String getStringArg(Map args, String key) { + Object val = args == null ? null : args.get(key); + return val == null ? null : val.toString(); + } + + private static int getIntArg(Map args, String key, int defaultVal) { + Object val = args == null ? null : args.get(key); + if (val == null) { + return defaultVal; + } + if (val instanceof Number) { + return ((Number) val).intValue(); + } + try { + return Integer.parseInt(val.toString()); + } catch (NumberFormatException e) { + return defaultVal; + } + } + + private static Map toolResult(String text, boolean isError) { + Map content = new LinkedHashMap<>(); + content.put("type", "text"); + content.put("text", text); + List contentList = new ArrayList<>(); + contentList.add(content); + Map result = new LinkedHashMap<>(); + result.put("content", contentList); + result.put("isError", isError); + return result; + } +} diff --git a/btrace-mcp-server/src/main/java/org/openjdk/btrace/mcp/tools/ExitProbeHandler.java b/btrace-mcp-server/src/main/java/org/openjdk/btrace/mcp/tools/ExitProbeHandler.java new file mode 100644 index 00000000..feea97c3 --- /dev/null +++ b/btrace-mcp-server/src/main/java/org/openjdk/btrace/mcp/tools/ExitProbeHandler.java @@ -0,0 +1,131 @@ +/* + * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +package org.openjdk.btrace.mcp.tools; + +import java.util.ArrayList; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import org.openjdk.btrace.client.Client; +import org.openjdk.btrace.mcp.ClientManager; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** Handles the exit_probe MCP tool - stops and removes a probe from the target JVM. */ +public final class ExitProbeHandler { + private static final Logger log = LoggerFactory.getLogger(ExitProbeHandler.class); + private static final int DEFAULT_PORT = 2020; + + private ExitProbeHandler() {} + + /** Returns tool schema for MCP tools/list. */ + public static Map schema() { + Map tool = new LinkedHashMap<>(); + tool.put("name", "exit_probe"); + tool.put( + "description", + "Stop and remove a BTrace probe from the target JVM. " + + "This terminates the probe and removes all instrumentation."); + + Map properties = new LinkedHashMap<>(); + + Map pidProp = new LinkedHashMap<>(); + pidProp.put("type", "string"); + pidProp.put("description", "PID of the target JVM"); + properties.put("pid", pidProp); + + Map portProp = new LinkedHashMap<>(); + portProp.put("type", "integer"); + portProp.put("description", "BTrace agent port (default: 2020)"); + properties.put("port", portProp); + + List required = new ArrayList<>(); + required.add("pid"); + + Map inputSchema = new LinkedHashMap<>(); + inputSchema.put("type", "object"); + inputSchema.put("properties", properties); + inputSchema.put("required", required); + tool.put("inputSchema", inputSchema); + return tool; + } + + /** Executes the exit_probe tool. */ + public static Map execute(Map arguments) { + String pid = getStringArg(arguments, "pid"); + int port = getIntArg(arguments, "port", DEFAULT_PORT); + + if (pid == null || pid.isEmpty()) { + return toolResult("Error: 'pid' parameter is required", true); + } + + try { + Client client = ClientManager.getExistingClient(pid, port); + if (client == null) { + return toolResult("No active BTrace session for PID " + pid, true); + } + + client.sendExit(0); + ClientManager.removeClient(pid, port); + return toolResult("Probe stopped and removed from PID " + pid, false); + } catch (Exception e) { + log.error("Failed to exit probe", e); + return toolResult("Error stopping probe on PID " + pid + ": " + e.getMessage(), true); + } + } + + private static String getStringArg(Map args, String key) { + Object val = args == null ? null : args.get(key); + return val == null ? null : val.toString(); + } + + private static int getIntArg(Map args, String key, int defaultVal) { + Object val = args == null ? null : args.get(key); + if (val == null) { + return defaultVal; + } + if (val instanceof Number) { + return ((Number) val).intValue(); + } + try { + return Integer.parseInt(val.toString()); + } catch (NumberFormatException e) { + return defaultVal; + } + } + + private static Map toolResult(String text, boolean isError) { + Map content = new LinkedHashMap<>(); + content.put("type", "text"); + content.put("text", text); + List contentList = new ArrayList<>(); + contentList.add(content); + Map result = new LinkedHashMap<>(); + result.put("content", contentList); + result.put("isError", isError); + return result; + } +} diff --git a/btrace-mcp-server/src/main/java/org/openjdk/btrace/mcp/tools/ListJvmsHandler.java b/btrace-mcp-server/src/main/java/org/openjdk/btrace/mcp/tools/ListJvmsHandler.java new file mode 100644 index 00000000..83639a5e --- /dev/null +++ b/btrace-mcp-server/src/main/java/org/openjdk/btrace/mcp/tools/ListJvmsHandler.java @@ -0,0 +1,135 @@ +/* + * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +package org.openjdk.btrace.mcp.tools; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; +import sun.jvmstat.monitor.MonitoredHost; +import sun.jvmstat.monitor.MonitoredVm; +import sun.jvmstat.monitor.MonitoredVmUtil; +import sun.jvmstat.monitor.VmIdentifier; + +/** Handles the list_jvms MCP tool - lists all attachable Java VMs. */ +public final class ListJvmsHandler { + private static final Logger log = LoggerFactory.getLogger(ListJvmsHandler.class); + + private ListJvmsHandler() {} + + /** Returns tool schema for MCP tools/list. */ + public static Map schema() { + Map tool = new LinkedHashMap<>(); + tool.put("name", "list_jvms"); + tool.put( + "description", + "List all attachable Java Virtual Machines on this host. " + + "Returns PID, main class, and whether BTrace is already attached (+/-). " + + "Use this to find the PID of the JVM you want to instrument."); + // No input parameters needed + Map inputSchema = new LinkedHashMap<>(); + inputSchema.put("type", "object"); + inputSchema.put("properties", new LinkedHashMap<>()); + tool.put("inputSchema", inputSchema); + return tool; + } + + /** Executes the list_jvms tool. */ + public static Map execute(Map arguments) { + try { + Collection> vms = listVms(); + StringBuilder sb = new StringBuilder(); + if (vms.isEmpty()) { + sb.append("No attachable Java VMs found."); + } else { + sb.append("Attachable Java VMs:\n\n"); + for (Map vm : vms) { + sb.append( + String.format( + " PID: %s | Main Class: %s | BTrace: %s\n", + vm.get("pid"), vm.get("mainClass"), vm.get("btraceAttached"))); + } + } + return toolResult(sb.toString(), false); + } catch (Exception e) { + log.error("Failed to list JVMs", e); + return toolResult("Error listing JVMs: " + e.getMessage(), true); + } + } + + private static Collection> listVms() { + List> result = new ArrayList<>(); + try { + MonitoredHost vmHost = MonitoredHost.getMonitoredHost((String) null); + for (Integer vmPid : MonitoredHost.getMonitoredHost("localhost").activeVms()) { + VmIdentifier id = new VmIdentifier(vmPid.toString()); + MonitoredVm mvm = vmHost.getMonitoredVm(id); + if (MonitoredVmUtil.isAttachable(mvm)) { + Map vmInfo = new LinkedHashMap<>(); + vmInfo.put("pid", vmPid); + vmInfo.put("mainClass", MonitoredVmUtil.mainClass(mvm, false)); + vmInfo.put("btraceAttached", hasBTraceServer(vmPid) ? "attached" : "not attached"); + result.add(vmInfo); + } + } + } catch (Exception e) { + log.warn("Error listing VMs", e); + } + return result; + } + + private static boolean hasBTraceServer(int pid) { + com.sun.tools.attach.VirtualMachine vm = null; + try { + vm = com.sun.tools.attach.VirtualMachine.attach(String.valueOf(pid)); + return vm.getSystemProperties().containsKey("btrace.port"); + } catch (Throwable ignored) { + return false; + } finally { + if (vm != null) { + try { + vm.detach(); + } catch (Exception ignored) { + } + } + } + } + + private static Map toolResult(String text, boolean isError) { + Map content = new LinkedHashMap<>(); + content.put("type", "text"); + content.put("text", text); + List contentList = new ArrayList<>(); + contentList.add(content); + Map result = new LinkedHashMap<>(); + result.put("content", contentList); + result.put("isError", isError); + return result; + } +} diff --git a/btrace-mcp-server/src/main/java/org/openjdk/btrace/mcp/tools/ListProbesHandler.java b/btrace-mcp-server/src/main/java/org/openjdk/btrace/mcp/tools/ListProbesHandler.java new file mode 100644 index 00000000..2fe8ebb2 --- /dev/null +++ b/btrace-mcp-server/src/main/java/org/openjdk/btrace/mcp/tools/ListProbesHandler.java @@ -0,0 +1,155 @@ +/* + * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +package org.openjdk.btrace.mcp.tools; + +import java.io.StringWriter; +import java.util.ArrayList; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import org.openjdk.btrace.client.Client; +import org.openjdk.btrace.core.comm.Command; +import org.openjdk.btrace.core.comm.PrintableCommand; +import org.openjdk.btrace.mcp.ClientManager; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** Handles the list_probes MCP tool - lists active probes on a JVM. */ +public final class ListProbesHandler { + private static final Logger log = LoggerFactory.getLogger(ListProbesHandler.class); + private static final int DEFAULT_PORT = 2020; + private static final int TIMEOUT_SECONDS = 10; + + private ListProbesHandler() {} + + /** Returns tool schema for MCP tools/list. */ + public static Map schema() { + Map tool = new LinkedHashMap<>(); + tool.put("name", "list_probes"); + tool.put( + "description", + "List active BTrace probes on a running JVM. " + + "Shows which probes are currently deployed and their IDs."); + + Map properties = new LinkedHashMap<>(); + + Map pidProp = new LinkedHashMap<>(); + pidProp.put("type", "string"); + pidProp.put("description", "PID of the target JVM"); + properties.put("pid", pidProp); + + Map portProp = new LinkedHashMap<>(); + portProp.put("type", "integer"); + portProp.put("description", "BTrace agent port (default: 2020)"); + properties.put("port", portProp); + + List required = new ArrayList<>(); + required.add("pid"); + + Map inputSchema = new LinkedHashMap<>(); + inputSchema.put("type", "object"); + inputSchema.put("properties", properties); + inputSchema.put("required", required); + tool.put("inputSchema", inputSchema); + return tool; + } + + /** Executes the list_probes tool. */ + public static Map execute(Map arguments) { + String pid = getStringArg(arguments, "pid"); + int port = getIntArg(arguments, "port", DEFAULT_PORT); + + if (pid == null || pid.isEmpty()) { + return toolResult("Error: 'pid' parameter is required", true); + } + + try { + Client client = ClientManager.getClient(port); + client.attach(pid, null, "."); + + StringBuilder output = new StringBuilder(); + CountDownLatch latch = new CountDownLatch(1); + + client.connectAndListProbes( + "localhost", + cmd -> { + if (cmd instanceof PrintableCommand) { + StringWriter sw = new StringWriter(); + ((PrintableCommand) cmd).print(new java.io.PrintWriter(sw)); + output.append(sw.toString()); + } + if (cmd.getType() == Command.LIST_PROBES) { + latch.countDown(); + } + }); + + boolean done = latch.await(TIMEOUT_SECONDS, TimeUnit.SECONDS); + if (!done) { + return toolResult("Timed out waiting for probe list from PID " + pid, true); + } + + String resultText = "Active probes on PID " + pid + ":\n" + output.toString(); + return toolResult(resultText, false); + } catch (Exception e) { + log.error("Failed to list probes", e); + return toolResult("Error listing probes on PID " + pid + ": " + e.getMessage(), true); + } + } + + private static String getStringArg(Map args, String key) { + Object val = args == null ? null : args.get(key); + return val == null ? null : val.toString(); + } + + private static int getIntArg(Map args, String key, int defaultVal) { + Object val = args == null ? null : args.get(key); + if (val == null) { + return defaultVal; + } + if (val instanceof Number) { + return ((Number) val).intValue(); + } + try { + return Integer.parseInt(val.toString()); + } catch (NumberFormatException e) { + return defaultVal; + } + } + + private static Map toolResult(String text, boolean isError) { + Map content = new LinkedHashMap<>(); + content.put("type", "text"); + content.put("text", text); + List contentList = new ArrayList<>(); + contentList.add(content); + Map result = new LinkedHashMap<>(); + result.put("content", contentList); + result.put("isError", isError); + return result; + } +} diff --git a/btrace-mcp-server/src/main/java/org/openjdk/btrace/mcp/tools/SendEventHandler.java b/btrace-mcp-server/src/main/java/org/openjdk/btrace/mcp/tools/SendEventHandler.java new file mode 100644 index 00000000..095936b0 --- /dev/null +++ b/btrace-mcp-server/src/main/java/org/openjdk/btrace/mcp/tools/SendEventHandler.java @@ -0,0 +1,146 @@ +/* + * Copyright (c) 2024, Oracle and/or its affiliates. All rights reserved. + * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. + * + * This code is free software; you can redistribute it and/or modify it + * under the terms of the GNU General Public License version 2 only, as + * published by the Free Software Foundation. Oracle designates this + * particular file as subject to the "Classpath" exception as provided + * by Oracle in the LICENSE file that accompanied this code. + * + * This code is distributed in the hope that it will be useful, but WITHOUT + * ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or + * FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License + * version 2 for more details (a copy is included in the LICENSE file that + * accompanied this code). + * + * You should have received a copy of the GNU General Public License version + * 2 along with this work; if not, write to the Free Software Foundation, + * Inc., 51 Franklin St, Fifth Floor, Boston, MA 02110-1301 USA. + * + * Please contact Oracle, 500 Oracle Parkway, Redwood Shores, CA 94065 USA + * or visit www.oracle.com if you need additional information or have any + * questions. + */ + +package org.openjdk.btrace.mcp.tools; + +import java.util.ArrayList; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import org.openjdk.btrace.client.Client; +import org.openjdk.btrace.mcp.ClientManager; +import org.slf4j.Logger; +import org.slf4j.LoggerFactory; + +/** Handles the send_event MCP tool - sends an event to a running probe. */ +public final class SendEventHandler { + private static final Logger log = LoggerFactory.getLogger(SendEventHandler.class); + private static final int DEFAULT_PORT = 2020; + + private SendEventHandler() {} + + /** Returns tool schema for MCP tools/list. */ + public static Map schema() { + Map tool = new LinkedHashMap<>(); + tool.put("name", "send_event"); + tool.put( + "description", + "Send an event to a running BTrace probe. " + + "Events can trigger @OnEvent handlers in the probe script. " + + "If no event_name is specified, an unnamed event is sent."); + + Map properties = new LinkedHashMap<>(); + + Map pidProp = new LinkedHashMap<>(); + pidProp.put("type", "string"); + pidProp.put("description", "PID of the target JVM"); + properties.put("pid", pidProp); + + Map eventNameProp = new LinkedHashMap<>(); + eventNameProp.put("type", "string"); + eventNameProp.put("description", "Name of the event to send (optional)"); + properties.put("event_name", eventNameProp); + + Map portProp = new LinkedHashMap<>(); + portProp.put("type", "integer"); + portProp.put("description", "BTrace agent port (default: 2020)"); + properties.put("port", portProp); + + List required = new ArrayList<>(); + required.add("pid"); + + Map inputSchema = new LinkedHashMap<>(); + inputSchema.put("type", "object"); + inputSchema.put("properties", properties); + inputSchema.put("required", required); + tool.put("inputSchema", inputSchema); + return tool; + } + + /** Executes the send_event tool. */ + public static Map execute(Map arguments) { + String pid = getStringArg(arguments, "pid"); + String eventName = getStringArg(arguments, "event_name"); + int port = getIntArg(arguments, "port", DEFAULT_PORT); + + if (pid == null || pid.isEmpty()) { + return toolResult("Error: 'pid' parameter is required", true); + } + + try { + Client client = ClientManager.getExistingClient(pid, port); + if (client == null) { + return toolResult( + "No active BTrace session for PID " + + pid + + ". Deploy a probe first using deploy_oneliner or deploy_script.", + true); + } + + if (eventName != null && !eventName.isEmpty()) { + client.sendEvent(eventName); + return toolResult("Event '" + eventName + "' sent to PID " + pid, false); + } else { + client.sendEvent(); + return toolResult("Unnamed event sent to PID " + pid, false); + } + } catch (Exception e) { + log.error("Failed to send event", e); + return toolResult("Error sending event to PID " + pid + ": " + e.getMessage(), true); + } + } + + private static String getStringArg(Map args, String key) { + Object val = args == null ? null : args.get(key); + return val == null ? null : val.toString(); + } + + private static int getIntArg(Map args, String key, int defaultVal) { + Object val = args == null ? null : args.get(key); + if (val == null) { + return defaultVal; + } + if (val instanceof Number) { + return ((Number) val).intValue(); + } + try { + return Integer.parseInt(val.toString()); + } catch (NumberFormatException e) { + return defaultVal; + } + } + + private static Map toolResult(String text, boolean isError) { + Map content = new LinkedHashMap<>(); + content.put("type", "text"); + content.put("text", text); + List contentList = new ArrayList<>(); + contentList.add(content); + Map result = new LinkedHashMap<>(); + result.put("content", contentList); + result.put("isError", isError); + return result; + } +} diff --git a/settings.gradle b/settings.gradle index bbbbc640..6225e86e 100644 --- a/settings.gradle +++ b/settings.gradle @@ -45,6 +45,10 @@ include 'btrace-extensions:btrace-metrics' include 'btrace-extensions:btrace-utils' include 'btrace-extensions:btrace-statsd' include 'btrace-extensions:btrace-streams' +include 'btrace-extensions:btrace-llm-trace' +include 'btrace-extensions:btrace-rag-quality' +include 'btrace-extensions:btrace-vibe-guard' +include 'btrace-extensions:btrace-gpu-bridge' dependencyResolutionManagement { versionCatalogs {