diff --git a/.github/workflows/e2e-java-tracer.yaml b/.github/workflows/e2e-java-tracer.yaml index 7e92e9eee..6ed17ce90 100644 --- a/.github/workflows/e2e-java-tracer.yaml +++ b/.github/workflows/e2e-java-tracer.yaml @@ -3,17 +3,9 @@ name: E2E - Java Tracer on: pull_request: paths: - - 'codeflash/languages/java/**' - - 'codeflash/languages/base.py' - - 'codeflash/languages/registry.py' - - 'codeflash/tracer.py' - - 'codeflash/benchmarking/function_ranker.py' - - 'codeflash/discovery/functions_to_optimize.py' - - 'codeflash/optimization/**' - - 'codeflash/verification/**' + - 'codeflash/**' - 'codeflash-java-runtime/**' - - 'tests/test_languages/fixtures/java_tracer_e2e/**' - - 'tests/scripts/end_to_end_test_java_tracer.py' + - 'tests/**' - '.github/workflows/e2e-java-tracer.yaml' workflow_dispatch: diff --git a/code_to_optimize/java-gradle/codeflash.toml b/code_to_optimize/java-gradle/codeflash.toml deleted file mode 100644 index bf6e45279..000000000 --- a/code_to_optimize/java-gradle/codeflash.toml +++ /dev/null @@ -1,4 +0,0 @@ -[tool.codeflash] -module-root = "src/main/java" -tests-root = "src/test/java" -formatter-cmds = [] diff --git a/code_to_optimize/java/codeflash.toml b/code_to_optimize/java/codeflash.toml deleted file mode 100644 index 4016df28a..000000000 --- a/code_to_optimize/java/codeflash.toml +++ /dev/null @@ -1,6 +0,0 @@ -# Codeflash configuration for Java project - -[tool.codeflash] -module-root = "src/main/java" -tests-root = "src/test/java" -formatter-cmds = [] diff --git a/codeflash-java-runtime/src/main/java/com/codeflash/ReplayHelper.java b/codeflash-java-runtime/src/main/java/com/codeflash/ReplayHelper.java index f4b9ec453..3a73038c1 100644 --- a/codeflash-java-runtime/src/main/java/com/codeflash/ReplayHelper.java +++ b/codeflash-java-runtime/src/main/java/com/codeflash/ReplayHelper.java @@ -12,20 +12,179 @@ public class ReplayHelper { - private final Connection db; + private final Connection traceDb; + + // Codeflash instrumentation state — read from environment variables once + private final String mode; // "behavior", "performance", or null + private final int loopIndex; + private final String testIteration; + private final String outputFile; // SQLite path for behavior capture + private final int innerIterations; // for performance looping + + // Behavior mode: lazily opened SQLite connection for writing results + private Connection behaviorDb; + private boolean behaviorDbInitialized; public ReplayHelper(String traceDbPath) { try { - this.db = DriverManager.getConnection("jdbc:sqlite:" + traceDbPath); + this.traceDb = DriverManager.getConnection("jdbc:sqlite:" + traceDbPath); } catch (SQLException e) { throw new RuntimeException("Failed to open trace database: " + traceDbPath, e); } + + // Read codeflash instrumentation env vars (set by the test runner) + this.mode = System.getenv("CODEFLASH_MODE"); + this.loopIndex = parseIntEnv("CODEFLASH_LOOP_INDEX", 1); + this.testIteration = getEnvOrDefault("CODEFLASH_TEST_ITERATION", "0"); + this.outputFile = System.getenv("CODEFLASH_OUTPUT_FILE"); + this.innerIterations = parseIntEnv("CODEFLASH_INNER_ITERATIONS", 10); } public void replay(String className, String methodName, String descriptor, int invocationIndex) throws Exception { - // Query the function_calls table for this method at the given index + // Deserialize args and resolve method (done once, outside timing) + Object[] allArgs = loadArgs(className, methodName, descriptor, invocationIndex); + Class targetClass = Class.forName(className); + + Type[] paramTypes = Type.getArgumentTypes(descriptor); + Class[] paramClasses = new Class[paramTypes.length]; + for (int i = 0; i < paramTypes.length; i++) { + paramClasses[i] = typeToClass(paramTypes[i]); + } + + Method method = targetClass.getDeclaredMethod(methodName, paramClasses); + method.setAccessible(true); + boolean isStatic = Modifier.isStatic(method.getModifiers()); + + Object instance = null; + if (!isStatic) { + try { + java.lang.reflect.Constructor ctor = targetClass.getDeclaredConstructor(); + ctor.setAccessible(true); + instance = ctor.newInstance(); + } catch (NoSuchMethodException e) { + instance = new org.objenesis.ObjenesisStd().newInstance(targetClass); + } + } + + // Get the calling test method name from the stack trace + String testMethodName = getCallingTestMethodName(); + // Module name = the test class that called us + String testClassName = getCallingTestClassName(); + + if ("behavior".equals(mode)) { + replayBehavior(method, instance, allArgs, className, methodName, testClassName, testMethodName); + } else if ("performance".equals(mode)) { + replayPerformance(method, instance, allArgs, className, methodName, testClassName, testMethodName); + } else { + // No codeflash mode — just invoke (trace-only or manual testing) + method.invoke(instance, allArgs); + } + } + + private void replayBehavior(Method method, Object instance, Object[] args, + String className, String methodName, + String testClassName, String testMethodName) throws Exception { + String invId = testIteration + "_" + testMethodName; + + // Print start marker (same format as behavior instrumentation) + System.out.println("!$######" + testClassName + ":" + testClassName + "." + testMethodName + + ":" + methodName + ":" + loopIndex + ":" + invId + "######$!"); + + long startNs = System.nanoTime(); + Object result; + try { + result = method.invoke(instance, args); + } catch (java.lang.reflect.InvocationTargetException e) { + throw (Exception) e.getCause(); + } + long durationNs = System.nanoTime() - startNs; + + // Print end marker + System.out.println("!######" + testClassName + ":" + testClassName + "." + testMethodName + + ":" + methodName + ":" + loopIndex + ":" + invId + ":" + durationNs + "######!"); + + // Write return value to SQLite for correctness comparison + if (outputFile != null && !outputFile.isEmpty()) { + writeBehaviorResult(testClassName, testMethodName, methodName, invId, durationNs, result); + } + } + + private void replayPerformance(Method method, Object instance, Object[] args, + String className, String methodName, + String testClassName, String testMethodName) throws Exception { + // Performance mode: run inner loop for JIT warmup, print timing for each iteration + int maxInner = innerIterations; + for (int inner = 0; inner < maxInner; inner++) { + int loopId = (loopIndex - 1) * maxInner + inner; + String invId = testMethodName; + + // Print start marker + System.out.println("!$######" + testClassName + ":" + testClassName + "." + testMethodName + + ":" + methodName + ":" + loopId + ":" + invId + "######$!"); + + long startNs = System.nanoTime(); + try { + method.invoke(instance, args); + } catch (java.lang.reflect.InvocationTargetException e) { + // Swallow — performance mode doesn't check correctness + } + long durationNs = System.nanoTime() - startNs; + + // Print end marker + System.out.println("!######" + testClassName + ":" + testClassName + "." + testMethodName + + ":" + methodName + ":" + loopId + ":" + invId + ":" + durationNs + "######!"); + } + } + + private void writeBehaviorResult(String testClassName, String testMethodName, + String functionName, String invId, + long durationNs, Object result) { + try { + ensureBehaviorDb(); + String sql = "INSERT INTO test_results VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?)"; + try (PreparedStatement ps = behaviorDb.prepareStatement(sql)) { + ps.setString(1, testClassName); // test_module_path + ps.setString(2, testClassName); // test_class_name + ps.setString(3, testMethodName); // test_function_name + ps.setString(4, functionName); // function_getting_tested + ps.setInt(5, loopIndex); // loop_index + ps.setString(6, invId); // iteration_id + ps.setLong(7, durationNs); // runtime + ps.setBytes(8, serializeResult(result)); // return_value + ps.setString(9, "function_call"); // verification_type + ps.executeUpdate(); + } + } catch (Exception e) { + System.err.println("ReplayHelper: SQLite behavior write error: " + e.getMessage()); + } + } + + private void ensureBehaviorDb() throws SQLException { + if (behaviorDbInitialized) return; + behaviorDbInitialized = true; + behaviorDb = DriverManager.getConnection("jdbc:sqlite:" + outputFile); + try (java.sql.Statement stmt = behaviorDb.createStatement()) { + stmt.execute("CREATE TABLE IF NOT EXISTS test_results (" + + "test_module_path TEXT, test_class_name TEXT, test_function_name TEXT, " + + "function_getting_tested TEXT, loop_index INTEGER, iteration_id TEXT, " + + "runtime INTEGER, return_value BLOB, verification_type TEXT)"); + } + } + + private byte[] serializeResult(Object result) { + if (result == null) return null; + try { + return Serializer.serialize(result); + } catch (Exception e) { + // Fall back to String.valueOf if Kryo fails + return String.valueOf(result).getBytes(java.nio.charset.StandardCharsets.UTF_8); + } + } + + private Object[] loadArgs(String className, String methodName, String descriptor, int invocationIndex) + throws SQLException { byte[] argsBlob; - try (PreparedStatement stmt = db.prepareStatement( + try (PreparedStatement stmt = traceDb.prepareStatement( "SELECT args FROM function_calls " + "WHERE classname = ? AND function = ? AND descriptor = ? " + "ORDER BY time_ns LIMIT 1 OFFSET ?")) { @@ -43,46 +202,35 @@ public void replay(String className, String methodName, String descriptor, int i } } - // Deserialize args Object deserialized = Serializer.deserialize(argsBlob); if (!(deserialized instanceof Object[])) { throw new RuntimeException("Deserialized args is not Object[], got: " + (deserialized == null ? "null" : deserialized.getClass().getName())); } - Object[] allArgs = (Object[]) deserialized; - - // Load the target class - Class targetClass = Class.forName(className); + return (Object[]) deserialized; + } - // Parse descriptor to find parameter types - Type[] paramTypes = Type.getArgumentTypes(descriptor); - Class[] paramClasses = new Class[paramTypes.length]; - for (int i = 0; i < paramTypes.length; i++) { - paramClasses[i] = typeToClass(paramTypes[i]); + private static String getCallingTestMethodName() { + StackTraceElement[] stack = Thread.currentThread().getStackTrace(); + // Walk up: [0]=getStackTrace, [1]=this method, [2]=replay(), [3]=calling test method + for (int i = 3; i < stack.length; i++) { + String method = stack[i].getMethodName(); + if (method.startsWith("replay_")) { + return method; + } } + return stack.length > 3 ? stack[3].getMethodName() : "unknown"; + } - // Find the method - Method method = targetClass.getDeclaredMethod(methodName, paramClasses); - method.setAccessible(true); - - boolean isStatic = Modifier.isStatic(method.getModifiers()); - - if (isStatic) { - method.invoke(null, allArgs); - } else { - // Args contain only explicit parameters (no 'this'). - // Create a default instance via no-arg constructor or Kryo. - Object instance; - try { - java.lang.reflect.Constructor ctor = targetClass.getDeclaredConstructor(); - ctor.setAccessible(true); - instance = ctor.newInstance(); - } catch (NoSuchMethodException e) { - // Fall back to Objenesis instantiation (no constructor needed) - instance = new org.objenesis.ObjenesisStd().newInstance(targetClass); + private static String getCallingTestClassName() { + StackTraceElement[] stack = Thread.currentThread().getStackTrace(); + for (int i = 3; i < stack.length; i++) { + String cls = stack[i].getClassName(); + if (cls.contains("ReplayTest") || cls.contains("replay")) { + return cls; } - method.invoke(instance, allArgs); } + return stack.length > 3 ? stack[3].getClassName() : "unknown"; } private static Class typeToClass(Type type) throws ClassNotFoundException { @@ -106,11 +254,23 @@ private static Class typeToClass(Type type) throws ClassNotFoundException { } } + private static int parseIntEnv(String name, int defaultValue) { + String val = System.getenv(name); + if (val == null || val.isEmpty()) return defaultValue; + try { return Integer.parseInt(val); } catch (NumberFormatException e) { return defaultValue; } + } + + private static String getEnvOrDefault(String name, String defaultValue) { + String val = System.getenv(name); + return (val != null && !val.isEmpty()) ? val : defaultValue; + } + public void close() { - try { - if (db != null) db.close(); - } catch (SQLException e) { - System.err.println("Error closing ReplayHelper: " + e.getMessage()); + try { if (traceDb != null) traceDb.close(); } catch (SQLException e) { + System.err.println("Error closing ReplayHelper trace db: " + e.getMessage()); + } + try { if (behaviorDb != null) behaviorDb.close(); } catch (SQLException e) { + System.err.println("Error closing ReplayHelper behavior db: " + e.getMessage()); } } } diff --git a/codeflash-java-runtime/src/main/java/com/codeflash/tracer/TraceRecorder.java b/codeflash-java-runtime/src/main/java/com/codeflash/tracer/TraceRecorder.java index 2a22b74f4..28c2d2998 100644 --- a/codeflash-java-runtime/src/main/java/com/codeflash/tracer/TraceRecorder.java +++ b/codeflash-java-runtime/src/main/java/com/codeflash/tracer/TraceRecorder.java @@ -22,6 +22,7 @@ public final class TraceRecorder { private final TracerConfig config; private final TraceWriter writer; private final ConcurrentHashMap functionCounts = new ConcurrentHashMap<>(); + private final AtomicInteger droppedCaptures = new AtomicInteger(0); private final int maxFunctionCount; private final ExecutorService serializerExecutor; @@ -82,11 +83,13 @@ private void onEntryImpl(String className, String methodName, String descriptor, argsBlob = future.get(SERIALIZATION_TIMEOUT_MS, TimeUnit.MILLISECONDS); } catch (TimeoutException e) { future.cancel(true); + droppedCaptures.incrementAndGet(); System.err.println("[codeflash-tracer] Serialization timed out for " + className + "." + methodName); return; } catch (Exception e) { Throwable cause = e.getCause() != null ? e.getCause() : e; + droppedCaptures.incrementAndGet(); System.err.println("[codeflash-tracer] Serialization failed for " + className + "." + methodName + ": " + cause.getClass().getSimpleName() + ": " + cause.getMessage()); return; @@ -113,11 +116,15 @@ public void flush() { } metadata.put("totalCaptures", String.valueOf(totalCaptures)); + int dropped = droppedCaptures.get(); + metadata.put("droppedCaptures", String.valueOf(dropped)); + writer.writeMetadata(metadata); writer.flush(); writer.close(); System.err.println("[codeflash-tracer] Captured " + totalCaptures - + " invocations across " + functionCounts.size() + " methods"); + + " invocations across " + functionCounts.size() + " methods" + + (dropped > 0 ? " (" + dropped + " dropped due to serialization timeout/failure)" : "")); } } diff --git a/codeflash-java-runtime/src/main/java/com/codeflash/tracer/TracingClassVisitor.java b/codeflash-java-runtime/src/main/java/com/codeflash/tracer/TracingClassVisitor.java index c760ea636..90d4cd7a0 100644 --- a/codeflash-java-runtime/src/main/java/com/codeflash/tracer/TracingClassVisitor.java +++ b/codeflash-java-runtime/src/main/java/com/codeflash/tracer/TracingClassVisitor.java @@ -4,14 +4,20 @@ import org.objectweb.asm.MethodVisitor; import org.objectweb.asm.Opcodes; +import java.util.Collections; +import java.util.Map; + public class TracingClassVisitor extends ClassVisitor { private final String internalClassName; + private final Map methodLineNumbers; private String sourceFile; - public TracingClassVisitor(ClassVisitor classVisitor, String internalClassName) { + public TracingClassVisitor(ClassVisitor classVisitor, String internalClassName, + Map methodLineNumbers) { super(Opcodes.ASM9, classVisitor); this.internalClassName = internalClassName; + this.methodLineNumbers = methodLineNumbers != null ? methodLineNumbers : Collections.emptyMap(); } @Override @@ -37,7 +43,8 @@ public MethodVisitor visitMethod(int access, String name, String descriptor, return mv; } + int lineNumber = methodLineNumbers.getOrDefault(name + descriptor, 0); return new TracingMethodAdapter(mv, access, name, descriptor, - internalClassName, 0, sourceFile != null ? sourceFile : ""); + internalClassName, lineNumber, sourceFile != null ? sourceFile : ""); } } diff --git a/codeflash-java-runtime/src/main/java/com/codeflash/tracer/TracingTransformer.java b/codeflash-java-runtime/src/main/java/com/codeflash/tracer/TracingTransformer.java index 974c767a9..53ac775af 100644 --- a/codeflash-java-runtime/src/main/java/com/codeflash/tracer/TracingTransformer.java +++ b/codeflash-java-runtime/src/main/java/com/codeflash/tracer/TracingTransformer.java @@ -1,10 +1,16 @@ package com.codeflash.tracer; import org.objectweb.asm.ClassReader; +import org.objectweb.asm.ClassVisitor; import org.objectweb.asm.ClassWriter; +import org.objectweb.asm.Label; +import org.objectweb.asm.MethodVisitor; +import org.objectweb.asm.Opcodes; import java.lang.instrument.ClassFileTransformer; import java.security.ProtectionDomain; +import java.util.HashMap; +import java.util.Map; public class TracingTransformer implements ClassFileTransformer { @@ -22,11 +28,6 @@ public byte[] transform(ClassLoader loader, String className, return null; } - // Skip instrumentation if we're inside a recording call (e.g., during Kryo serialization) - if (TraceRecorder.isRecording()) { - return null; - } - // Skip internal JDK, framework, and synthetic classes if (className.startsWith("java/") || className.startsWith("javax/") @@ -51,6 +52,30 @@ public byte[] transform(ClassLoader loader, String className, private byte[] instrumentClass(String internalClassName, byte[] bytecode) { ClassReader cr = new ClassReader(bytecode); + + // Pre-scan: collect the first source line number for each method. + // ASM's visitMethod() doesn't provide line info — it arrives later via visitLineNumber(). + // We do a lightweight read pass first so the instrumentation pass has accurate line numbers. + Map methodLineNumbers = new HashMap<>(); + cr.accept(new ClassVisitor(Opcodes.ASM9) { + @Override + public MethodVisitor visitMethod(int access, String name, String descriptor, + String signature, String[] exceptions) { + String key = name + descriptor; + return new MethodVisitor(Opcodes.ASM9) { + private boolean captured = false; + + @Override + public void visitLineNumber(int line, Label start) { + if (!captured) { + methodLineNumbers.put(key, line); + captured = true; + } + } + }; + } + }, ClassReader.SKIP_FRAMES); + // Use COMPUTE_MAXS only (not COMPUTE_FRAMES) to preserve original stack map frames. // COMPUTE_FRAMES recomputes all frames and calls getCommonSuperClass() which either // triggers classloader deadlocks or produces incorrect frames when returning "java/lang/Object". @@ -58,7 +83,7 @@ private byte[] instrumentClass(String internalClassName, byte[] bytecode) { // adjusts offsets for injected code. Our AdviceAdapter only injects at method entry // (before any branch points), so existing frames remain valid. ClassWriter cw = new ClassWriter(cr, ClassWriter.COMPUTE_MAXS); - TracingClassVisitor cv = new TracingClassVisitor(cw, internalClassName); + TracingClassVisitor cv = new TracingClassVisitor(cw, internalClassName, methodLineNumbers); cr.accept(cv, ClassReader.EXPAND_FRAMES); return cw.toByteArray(); } diff --git a/codeflash/cli_cmds/cli.py b/codeflash/cli_cmds/cli.py index d76e60a11..7230eb3bc 100644 --- a/codeflash/cli_cmds/cli.py +++ b/codeflash/cli_cmds/cli.py @@ -9,7 +9,9 @@ from codeflash.cli_cmds.console import apologize_and_exit, logger from codeflash.code_utils import env_utils from codeflash.code_utils.code_utils import exit_with_message, normalize_ignore_paths -from codeflash.code_utils.config_parser import parse_config_file +from codeflash.code_utils.config_parser import LanguageConfig, parse_config_file +from codeflash.languages import set_current_language +from codeflash.languages.language_enum import Language from codeflash.languages.test_framework import set_current_test_framework from codeflash.lsp.helpers import is_LSP_enabled from codeflash.version import __version__ as version @@ -108,11 +110,14 @@ def process_pyproject_config(args: Namespace) -> Namespace: assert args.module_root is not None, "--module-root must be specified" assert Path(args.module_root).is_dir(), f"--module-root {args.module_root} must be a valid directory" - # For JS/TS projects, tests_root is optional (Jest auto-discovers tests) - # Default to module_root if not specified is_js_ts_project = pyproject_config.get("language") in ("javascript", "typescript") is_java_project = pyproject_config.get("language") == "java" + # Set the language singleton early so downstream code (e.g. get_git_diff) + # can use current_language_support() before function discovery. + if pyproject_config.get("language"): + set_current_language(pyproject_config["language"]) + # Set the test framework singleton for JS/TS projects if is_js_ts_project and pyproject_config.get("test_framework"): set_current_test_framework(pyproject_config["test_framework"]) @@ -185,11 +190,17 @@ def process_pyproject_config(args: Namespace) -> Namespace: args.ignore_paths = normalize_ignore_paths(args.ignore_paths, base_path=args.module_root) # If module-root is "." then all imports are relatives to it. # in this case, the ".." becomes outside project scope, causing issues with un-importable paths - args.project_root = project_root_from_module_root(args.module_root, pyproject_file_path) + args.project_root = project_root_from_module_root(Path(args.module_root), pyproject_file_path) args.tests_root = Path(args.tests_root).resolve() if args.benchmarks_root: args.benchmarks_root = Path(args.benchmarks_root).resolve() args.test_project_root = project_root_from_module_root(args.tests_root, pyproject_file_path) + + if is_java_project and pyproject_file_path.is_dir(): + # For Java projects, pyproject_file_path IS the project root directory (not a file). + # Override project_root which may have resolved to a sub-module. + args.project_root = pyproject_file_path.resolve() + args.test_project_root = pyproject_file_path.resolve() if is_LSP_enabled(): args.all = None return args @@ -208,8 +219,6 @@ def project_root_from_module_root(module_root: Path, pyproject_file_path: Path) return current.resolve() if (current / "build.gradle").exists() or (current / "build.gradle.kts").exists(): return current.resolve() - if (current / "codeflash.toml").exists(): - return current.resolve() current = current.parent return module_root.parent.resolve() @@ -250,6 +259,83 @@ def handle_optimize_all_arg_parsing(args: Namespace) -> Namespace: return args +def apply_language_config(args: Namespace, lang_config: LanguageConfig) -> Namespace: + config = lang_config.config + config_path = lang_config.config_path + + supported_keys = [ + "module_root", + "tests_root", + "benchmarks_root", + "ignore_paths", + "pytest_cmd", + "formatter_cmds", + "disable_telemetry", + "disable_imports_sorting", + "git_remote", + "override_fixtures", + ] + for key in supported_keys: + if key in config and ((hasattr(args, key) and getattr(args, key) is None) or not hasattr(args, key)): + setattr(args, key, config[key]) + + assert args.module_root is not None, "--module-root must be specified" + assert Path(args.module_root).is_dir(), f"--module-root {args.module_root} must be a valid directory" + + set_current_language(lang_config.language) + + is_js_ts = lang_config.language in (Language.JAVASCRIPT, Language.TYPESCRIPT) + if is_js_ts and config.get("test_framework"): + set_current_test_framework(config["test_framework"]) + + is_java = lang_config.language == Language.JAVA + if args.tests_root is None: + if is_java: + for test_dir in ["src/test/java", "test", "tests"]: + test_path = Path(args.module_root).parent / test_dir if "/" in test_dir else Path(test_dir) + if not test_path.is_absolute(): + test_path = Path.cwd() / test_path + if test_path.is_dir(): + args.tests_root = str(test_path) + break + if args.tests_root is None: + args.tests_root = str(Path.cwd() / "src" / "test" / "java") + elif is_js_ts: + for test_dir in ["test", "tests", "__tests__"]: + if Path(test_dir).is_dir(): + args.tests_root = test_dir + break + if args.tests_root is None and args.module_root: + module_root_path = Path(args.module_root) + for test_dir in ["test", "tests", "__tests__"]: + test_path = module_root_path / test_dir + if test_path.is_dir(): + args.tests_root = str(test_path) + break + if args.tests_root is None: + args.tests_root = args.module_root + else: + raise AssertionError("--tests-root must be specified") + + assert Path(args.tests_root).is_dir(), f"--tests-root {args.tests_root} must be a valid directory" + + args.module_root = Path(args.module_root).resolve() + if hasattr(args, "ignore_paths") and args.ignore_paths is not None: + args.ignore_paths = normalize_ignore_paths(args.ignore_paths, base_path=args.module_root) + args.project_root = project_root_from_module_root(args.module_root, config_path) + args.tests_root = Path(args.tests_root).resolve() + if args.benchmarks_root: + args.benchmarks_root = Path(args.benchmarks_root).resolve() + args.test_project_root = project_root_from_module_root(args.tests_root, config_path) + + if is_java and config_path.is_dir(): + # For Java projects, config_path IS the project root directory (from build-tool detection). + args.project_root = config_path.resolve() + args.test_project_root = config_path.resolve() + + return args + + def _handle_show_config() -> None: """Show current or auto-detected Codeflash configuration.""" from rich.table import Table @@ -370,7 +456,7 @@ def _build_parser() -> ArgumentParser: subparsers.add_parser("vscode-install", help="Install the Codeflash VSCode extension") subparsers.add_parser("init-actions", help="Initialize GitHub Actions workflow") - trace_optimize = subparsers.add_parser("optimize", help="Trace and optimize your project.") + trace_optimize = subparsers.add_parser("optimize", help="Trace and optimize your project.", add_help=False) trace_optimize.add_argument( "--max-function-count", diff --git a/codeflash/code_utils/config_parser.py b/codeflash/code_utils/config_parser.py index ef21ce051..6d165e9a4 100644 --- a/codeflash/code_utils/config_parser.py +++ b/codeflash/code_utils/config_parser.py @@ -1,19 +1,49 @@ from __future__ import annotations +from dataclasses import dataclass from pathlib import Path from typing import Any import tomlkit from codeflash.code_utils.config_js import find_package_json, parse_package_json_config +from codeflash.languages.language_enum import Language from codeflash.lsp.helpers import is_LSP_enabled PYPROJECT_TOML_CACHE: dict[Path, Path] = {} ALL_CONFIG_FILES: dict[Path, dict[str, Path]] = {} +@dataclass +class LanguageConfig: + config: dict[str, Any] + config_path: Path + language: Language + + +def _try_parse_java_build_config() -> tuple[dict[str, Any], Path] | None: + """Detect Java project from build files and parse config from pom.xml/gradle.properties. + + Returns (config_dict, project_root) if a Java project is found, None otherwise. + """ + dir_path = Path.cwd() + while dir_path != dir_path.parent: + if ( + (dir_path / "pom.xml").exists() + or (dir_path / "build.gradle").exists() + or (dir_path / "build.gradle.kts").exists() + ): + from codeflash.languages.java.build_tools import parse_java_project_config + + config = parse_java_project_config(dir_path) + if config is not None: + return config, dir_path + dir_path = dir_path.parent + return None + + def find_pyproject_toml(config_file: Path | None = None) -> Path: - # Find the pyproject.toml or codeflash.toml file on the root of the project + # Find the pyproject.toml file on the root of the project if config_file is not None: config_file = Path(config_file) @@ -29,21 +59,13 @@ def find_pyproject_toml(config_file: Path | None = None) -> Path: # see if it was encountered before in search if cur_path in PYPROJECT_TOML_CACHE: return PYPROJECT_TOML_CACHE[cur_path] - # map current path to closest file - check both pyproject.toml and codeflash.toml while dir_path != dir_path.parent: - # First check pyproject.toml (Python projects) config_file = dir_path / "pyproject.toml" if config_file.exists(): PYPROJECT_TOML_CACHE[cur_path] = config_file return config_file - # Then check codeflash.toml (Java/other projects) - config_file = dir_path / "codeflash.toml" - if config_file.exists(): - PYPROJECT_TOML_CACHE[cur_path] = config_file - return config_file - # Search in parent directories dir_path = dir_path.parent - msg = f"Could not find pyproject.toml or codeflash.toml in the current directory {Path.cwd()} or any of the parent directories. Please create it by running `codeflash init`, or pass the path to the config file with the --config-file argument." + msg = f"Could not find pyproject.toml in the current directory {Path.cwd()} or any of the parent directories. Please create it by running `codeflash init`, or pass the path to the config file with the --config-file argument." raise ValueError(msg) from None @@ -90,33 +112,169 @@ def find_conftest_files(test_paths: list[Path]) -> list[Path]: return list(list_of_conftest_files) -# TODO for claude: There should be different functions to parse it per language, which should be chosen during runtime +def normalize_toml_config(config: dict[str, Any], config_file_path: Path) -> dict[str, Any]: + path_keys = ["module-root", "tests-root", "benchmarks-root"] + path_list_keys = ["ignore-paths"] + str_keys = {"pytest-cmd": "pytest", "git-remote": "origin"} + bool_keys = { + "override-fixtures": False, + "disable-telemetry": False, + "disable-imports-sorting": False, + "benchmark": False, + } + list_str_keys = {"formatter-cmds": []} + + for key, default_value in str_keys.items(): + if key in config: + config[key] = str(config[key]) + else: + config[key] = default_value + for key, default_value in bool_keys.items(): + if key in config: + config[key] = bool(config[key]) + else: + config[key] = default_value + for key in path_keys: + if key in config: + config[key] = str((config_file_path.parent / Path(config[key])).resolve()) + for key, default_value in list_str_keys.items(): + if key in config: + config[key] = [str(cmd) for cmd in config[key]] + else: + config[key] = default_value + for key in path_list_keys: + if key in config: + config[key] = [str((config_file_path.parent / path).resolve()) for path in config[key]] + else: + config[key] = [] + + # Convert hyphenated keys to underscored keys + for key in list(config.keys()): + if "-" in key: + config[key.replace("-", "_")] = config[key] + del config[key] + + return config + + +def _parse_java_config_for_dir(dir_path: Path) -> dict[str, Any] | None: + from codeflash.languages.java.build_tools import parse_java_project_config + + return parse_java_project_config(dir_path) + + +_SUBDIR_SKIP = frozenset({ + ".git", ".hg", ".svn", "node_modules", ".venv", "venv", "__pycache__", + "target", "build", "dist", ".tox", ".mypy_cache", ".ruff_cache", ".pytest_cache", +}) + + +def _check_dir_for_configs( + dir_path: Path, + configs: list[LanguageConfig], + seen_languages: set[Language], +) -> None: + """Check a single directory for language config files and append any found to *configs*.""" + if Language.PYTHON not in seen_languages: + pyproject = dir_path / "pyproject.toml" + if pyproject.exists(): + try: + with pyproject.open("rb") as f: + data = tomlkit.parse(f.read()) + tool = data.get("tool", {}) + if isinstance(tool, dict) and "codeflash" in tool: + raw_config = dict(tool["codeflash"]) + normalized = normalize_toml_config(raw_config, pyproject) + seen_languages.add(Language.PYTHON) + configs.append(LanguageConfig(config=normalized, config_path=pyproject, language=Language.PYTHON)) + except Exception: + pass + + if Language.JAVASCRIPT not in seen_languages: + package_json = dir_path / "package.json" + if package_json.exists(): + try: + result = parse_package_json_config(package_json) + if result is not None: + config, path = result + seen_languages.add(Language.JAVASCRIPT) + configs.append(LanguageConfig(config=config, config_path=path, language=Language.JAVASCRIPT)) + except Exception: + pass + + if Language.JAVA not in seen_languages: + if ( + (dir_path / "pom.xml").exists() + or (dir_path / "build.gradle").exists() + or (dir_path / "build.gradle.kts").exists() + ): + try: + java_config = _parse_java_config_for_dir(dir_path) + if java_config is not None: + seen_languages.add(Language.JAVA) + configs.append(LanguageConfig(config=java_config, config_path=dir_path, language=Language.JAVA)) + except Exception: + pass + + +def find_all_config_files(start_dir: Path | None = None) -> list[LanguageConfig]: + if start_dir is None: + start_dir = Path.cwd() + + configs: list[LanguageConfig] = [] + seen_languages: set[Language] = set() + + # Walk upward from start_dir to filesystem root (closest config wins per language) + dir_path = start_dir.resolve() + while True: + _check_dir_for_configs(dir_path, configs, seen_languages) + + parent = dir_path.parent + if parent == dir_path: + break + dir_path = parent + + # Scan immediate subdirectories for monorepo language subprojects + resolved_start = start_dir.resolve() + try: + subdirs = sorted(p for p in resolved_start.iterdir() if p.is_dir() and p.name not in _SUBDIR_SKIP) + except OSError: + subdirs = [] + for subdir in subdirs: + _check_dir_for_configs(subdir, configs, seen_languages) + + return configs + + + def parse_config_file( config_file_path: Path | None = None, override_formatter_check: bool = False ) -> tuple[dict[str, Any], Path]: + # Detect all config sources — Java, package.json, pyproject.toml + java_result = _try_parse_java_build_config() if config_file_path is None else None package_json_path = find_package_json(config_file_path) pyproject_toml_path = find_closest_config_file("pyproject.toml") if config_file_path is None else None - codeflash_toml_path = find_closest_config_file("codeflash.toml") if config_file_path is None else None - # Pick the closest toml config (pyproject.toml or codeflash.toml). - # Java projects use codeflash.toml; Python projects use pyproject.toml. - closest_toml_path = None - if pyproject_toml_path and codeflash_toml_path: - closest_toml_path = max(pyproject_toml_path, codeflash_toml_path, key=lambda p: len(p.parent.parts)) - else: - closest_toml_path = pyproject_toml_path or codeflash_toml_path + # Use Java config only if no closer JS/Python config exists (monorepo support). + # In a monorepo with a parent pom.xml and a child package.json, the closer config wins. + if java_result is not None: + java_depth = len(java_result[1].parts) + has_closer = (package_json_path is not None and len(package_json_path.parent.parts) >= java_depth) or ( + pyproject_toml_path is not None and len(pyproject_toml_path.parent.parts) >= java_depth + ) + if not has_closer: + return java_result # When both config files exist, prefer the one closer to CWD. # This prevents a parent-directory package.json (e.g., monorepo root) - # from overriding a closer pyproject.toml or codeflash.toml. + # from overriding a closer pyproject.toml. use_package_json = False if package_json_path: - if closest_toml_path is None: + if pyproject_toml_path is None: use_package_json = True else: - # Compare depth: more path parts = closer to CWD = more specific package_json_depth = len(package_json_path.parent.parts) - toml_depth = len(closest_toml_path.parent.parts) + toml_depth = len(pyproject_toml_path.parent.parts) use_package_json = package_json_depth >= toml_depth if use_package_json: @@ -160,55 +318,13 @@ def parse_config_file( if config == {} and lsp_mode: return {}, config_file_path - # Preserve language field if present (important for Java/JS projects using codeflash.toml) - # default values: - path_keys = ["module-root", "tests-root", "benchmarks-root"] - path_list_keys = ["ignore-paths"] - str_keys = {"pytest-cmd": "pytest", "git-remote": "origin"} - bool_keys = { - "override-fixtures": False, - "disable-telemetry": False, - "disable-imports-sorting": False, - "benchmark": False, - } - # Note: formatter-cmds defaults to empty list. For Python projects, black is typically - # detected by the project detector. For Java projects, no formatter is supported yet. - list_str_keys = {"formatter-cmds": []} - - for key, default_value in str_keys.items(): - if key in config: - config[key] = str(config[key]) - else: - config[key] = default_value - for key, default_value in bool_keys.items(): - if key in config: - config[key] = bool(config[key]) - else: - config[key] = default_value - for key in path_keys: - if key in config: - config[key] = str((Path(config_file_path).parent / Path(config[key])).resolve()) - for key, default_value in list_str_keys.items(): - if key in config: - config[key] = [str(cmd) for cmd in config[key]] - else: - config[key] = default_value - - for key in path_list_keys: - if key in config: - config[key] = [str((Path(config_file_path).parent / path).resolve()) for path in config[key]] - else: - config[key] = [] + config = normalize_toml_config(config, config_file_path) # see if this is happening during GitHub actions setup - if config.get("formatter-cmds") and len(config.get("formatter-cmds")) > 0 and not override_formatter_check: - assert config.get("formatter-cmds")[0] != "your-formatter $file", ( + if config.get("formatter_cmds") and len(config.get("formatter_cmds")) > 0 and not override_formatter_check: + assert config.get("formatter_cmds")[0] != "your-formatter $file", ( "The formatter command is not set correctly in pyproject.toml. Please set the " "formatter command in the 'formatter-cmds' key. More info - https://docs.codeflash.ai/configuration" ) - for key in list(config.keys()): - if "-" in key: - config[key.replace("-", "_")] = config[key] - del config[key] return config, config_file_path diff --git a/codeflash/discovery/functions_to_optimize.py b/codeflash/discovery/functions_to_optimize.py index 5780f4def..ec58a747d 100644 --- a/codeflash/discovery/functions_to_optimize.py +++ b/codeflash/discovery/functions_to_optimize.py @@ -554,11 +554,13 @@ def get_all_replay_test_functions( def _get_java_replay_test_functions( - replay_test: list[Path], test_cfg: TestConfig, project_root_path: Path + replay_test: list[Path], test_cfg: TestConfig, project_root_path: Path | str ) -> tuple[dict[Path, list[FunctionToOptimize]], Path]: """Parse Java replay test files to extract functions and trace file path.""" from codeflash.languages.java.replay_test import parse_replay_test_metadata + project_root_path = Path(project_root_path) + trace_file_path: Path | None = None functions: dict[Path, list[FunctionToOptimize]] = defaultdict(list) diff --git a/codeflash/languages/java/build_tools.py b/codeflash/languages/java/build_tools.py index 28db2c9aa..f8a19c693 100644 --- a/codeflash/languages/java/build_tools.py +++ b/codeflash/languages/java/build_tools.py @@ -10,7 +10,8 @@ import xml.etree.ElementTree as ET from dataclasses import dataclass from enum import Enum -from pathlib import Path # noqa: TC003 — used at runtime +from pathlib import Path +from typing import Any logger = logging.getLogger(__name__) @@ -343,6 +344,218 @@ def _parse_surefire_reports(surefire_dir: Path) -> tuple[int, int, int, int]: return tests_run, failures, errors, skipped +def parse_java_project_config(project_root: Path) -> dict[str, Any] | None: + """Parse codeflash config from Maven/Gradle build files. + + Reads codeflash.* properties from pom.xml or gradle.properties, + then fills in defaults from auto-detected build tool conventions. + + Returns None if no Java build tool is detected. + """ + build_tool = detect_build_tool(project_root) + if build_tool == BuildTool.UNKNOWN: + return None + + # Read explicit codeflash properties from build files + user_config: dict[str, str] = {} + if build_tool == BuildTool.MAVEN: + user_config = _read_maven_codeflash_properties(project_root) + elif build_tool == BuildTool.GRADLE: + user_config = _read_gradle_codeflash_properties(project_root) + + # Auto-detect defaults — for multi-module Maven projects, scan module pom.xml files + source_root = find_source_root(project_root) + test_root = find_test_root(project_root) + + if build_tool == BuildTool.MAVEN: + source_from_modules, test_from_modules = _detect_roots_from_maven_modules(project_root) + # Module-level pom.xml declarations are more precise than directory-name heuristics + if source_from_modules is not None: + source_root = source_from_modules + if test_from_modules is not None: + test_root = test_from_modules + + # Build the config dict matching the format expected by the rest of codeflash + config: dict[str, Any] = { + "language": "java", + "module_root": str( + (project_root / user_config["moduleRoot"]).resolve() + if "moduleRoot" in user_config + else (source_root or project_root / "src" / "main" / "java") + ), + "tests_root": str( + (project_root / user_config["testsRoot"]).resolve() + if "testsRoot" in user_config + else (test_root or project_root / "src" / "test" / "java") + ), + "pytest_cmd": "pytest", + "git_remote": user_config.get("gitRemote", "origin"), + "disable_telemetry": user_config.get("disableTelemetry", "false").lower() == "true", + "disable_imports_sorting": False, + "override_fixtures": False, + "benchmark": False, + "formatter_cmds": [], + "ignore_paths": [], + } + + if "ignorePaths" in user_config: + config["ignore_paths"] = [ + str((project_root / p.strip()).resolve()) for p in user_config["ignorePaths"].split(",") if p.strip() + ] + + if "formatterCmds" in user_config: + config["formatter_cmds"] = [cmd.strip() for cmd in user_config["formatterCmds"].split(",") if cmd.strip()] + + return config + + +def _read_maven_codeflash_properties(project_root: Path) -> dict[str, str]: + """Read codeflash.* properties from pom.xml section.""" + pom_path = project_root / "pom.xml" + if not pom_path.exists(): + return {} + + try: + tree = _safe_parse_xml(pom_path) + root = tree.getroot() + ns = {"m": "http://maven.apache.org/POM/4.0.0"} + + result: dict[str, str] = {} + for props in [root.find("m:properties", ns), root.find("properties")]: + if props is None: + continue + for child in props: + tag = child.tag + # Strip Maven namespace prefix + if "}" in tag: + tag = tag.split("}", 1)[1] + if tag.startswith("codeflash.") and child.text: + key = tag[len("codeflash.") :] + result[key] = child.text.strip() + return result + except Exception: + logger.debug("Failed to read codeflash properties from pom.xml", exc_info=True) + return {} + + +def _read_gradle_codeflash_properties(project_root: Path) -> dict[str, str]: + """Read codeflash.* properties from gradle.properties.""" + props_path = project_root / "gradle.properties" + if not props_path.exists(): + return {} + + result: dict[str, str] = {} + try: + with props_path.open("r", encoding="utf-8") as f: + for line in f: + stripped = line.strip() + if stripped.startswith("#") or "=" not in stripped: + continue + key, value = stripped.split("=", 1) + key = key.strip() + if key.startswith("codeflash."): + result[key[len("codeflash.") :]] = value.strip() + return result + except Exception: + logger.debug("Failed to read codeflash properties from gradle.properties", exc_info=True) + return {} + + +def _detect_roots_from_maven_modules(project_root: Path) -> tuple[Path | None, Path | None]: + """Scan Maven module pom.xml files for custom sourceDirectory/testSourceDirectory. + + For multi-module projects like aerospike (client/, test/, benchmarks/), + finds the main source module and test module by parsing each module's build config. + """ + pom_path = project_root / "pom.xml" + if not pom_path.exists(): + return None, None + + try: + tree = _safe_parse_xml(pom_path) + root = tree.getroot() + ns = {"m": "http://maven.apache.org/POM/4.0.0"} + + # Find to get module names + modules: list[str] = [] + for modules_elem in [root.find("m:modules", ns), root.find("modules")]: + if modules_elem is not None: + for mod in modules_elem: + if mod.text: + modules.append(mod.text.strip()) + + if not modules: + return None, None + + # Collect candidate source and test roots with Java file counts + source_candidates: list[tuple[Path, int]] = [] + test_root: Path | None = None + + skip_modules = {"example", "examples", "benchmark", "benchmarks", "demo", "sample", "samples"} + + for module_name in modules: + module_pom = project_root / module_name / "pom.xml" + if not module_pom.exists(): + continue + + # Modules named "test" are test modules, not source modules + is_test_module = "test" in module_name.lower() + + try: + mod_tree = _safe_parse_xml(module_pom) + mod_root = mod_tree.getroot() + + for build in [mod_root.find("m:build", ns), mod_root.find("build")]: + if build is None: + continue + + for src_elem in [build.find("m:sourceDirectory", ns), build.find("sourceDirectory")]: + if src_elem is not None and src_elem.text: + src_text = src_elem.text.replace("${project.basedir}", str(project_root / module_name)) + src_path = Path(src_text) + if not src_path.is_absolute(): + src_path = project_root / module_name / src_path + if src_path.exists(): + if is_test_module and test_root is None: + test_root = src_path + elif module_name.lower() not in skip_modules: + java_count = sum(1 for _ in src_path.rglob("*.java")) + if java_count > 0: + source_candidates.append((src_path, java_count)) + + for test_elem in [build.find("m:testSourceDirectory", ns), build.find("testSourceDirectory")]: + if test_elem is not None and test_elem.text: + test_text = test_elem.text.replace("${project.basedir}", str(project_root / module_name)) + test_path = Path(test_text) + if not test_path.is_absolute(): + test_path = project_root / module_name / test_path + if test_path.exists() and test_root is None: + test_root = test_path + + # Also check standard module layouts + if module_name.lower() not in skip_modules and not is_test_module: + std_src = project_root / module_name / "src" / "main" / "java" + if std_src.exists(): + java_count = sum(1 for _ in std_src.rglob("*.java")) + if java_count > 0: + source_candidates.append((std_src, java_count)) + + if test_root is None: + std_test = project_root / module_name / "src" / "test" / "java" + if std_test.exists() and any(std_test.rglob("*.java")): + test_root = std_test + + except Exception: + continue + + # Pick the source root with the most Java files (likely the main library) + source_root = max(source_candidates, key=lambda x: x[1])[0] if source_candidates else None + return source_root, test_root + + except Exception: + return None, None + + def find_test_root(project_root: Path) -> Path | None: """Find the test root directory for a Java project. diff --git a/codeflash/languages/java/instrumentation.py b/codeflash/languages/java/instrumentation.py index 9ecbd613e..914fe7a70 100644 --- a/codeflash/languages/java/instrumentation.py +++ b/codeflash/languages/java/instrumentation.py @@ -785,26 +785,35 @@ def _add_behavior_instrumentation(source: str, class_name: str, func_name: str, if _is_test_annotation(stripped): if not helper_added: helper_added = True - result.append(line) - i += 1 - # Collect any additional annotations - while i < len(lines) and lines[i].strip().startswith("@"): - result.append(lines[i]) + # Check if the @Test line already contains the method signature and opening brace + # (common in compact test styles like replay tests: @Test void replay_foo_0() throws Exception {) + if "{" in line: + # The annotation line IS the method signature — don't look for a separate one + result.append(line) i += 1 - - # Now find the method signature and opening brace - method_lines = [] - while i < len(lines): - method_lines.append(lines[i]) - if "{" in lines[i]: - break + method_lines = [line] + else: + result.append(line) i += 1 - # Add the method signature lines - for ml in method_lines: - result.append(ml) - i += 1 + # Collect any additional annotations + while i < len(lines) and lines[i].strip().startswith("@"): + result.append(lines[i]) + i += 1 + + # Now find the method signature and opening brace + method_lines = [] + while i < len(lines): + method_lines.append(lines[i]) + if "{" in lines[i]: + break + i += 1 + + # Add the method signature lines + for ml in method_lines: + result.append(ml) + i += 1 # Extract the test method name from the method signature test_method_name = _extract_test_method_name(method_lines) diff --git a/codeflash/languages/java/jfr_parser.py b/codeflash/languages/java/jfr_parser.py index 7775378e6..7f3816856 100644 --- a/codeflash/languages/java/jfr_parser.py +++ b/codeflash/languages/java/jfr_parser.py @@ -152,6 +152,8 @@ def _frame_to_key(self, frame: dict[str, Any]) -> str | None: method_name = method.get("name", "") if not class_name or not method_name: return None + # JFR uses / separators (JVM internal format), normalize to dots for package matching + class_name = class_name.replace("/", ".") return f"{class_name}.{method_name}" def _store_method_info(self, key: str, frame: dict[str, Any]) -> None: @@ -159,7 +161,7 @@ def _store_method_info(self, key: str, frame: dict[str, Any]) -> None: return method = frame.get("method", {}) self._method_info[key] = { - "class_name": method.get("type", {}).get("name", ""), + "class_name": method.get("type", {}).get("name", "").replace("/", "."), "method_name": method.get("name", ""), "descriptor": method.get("descriptor", ""), "line_number": str(frame.get("lineNumber", 0)), diff --git a/codeflash/languages/java/replay_test.py b/codeflash/languages/java/replay_test.py index c753bf4fa..415b7a34e 100644 --- a/codeflash/languages/java/replay_test.py +++ b/codeflash/languages/java/replay_test.py @@ -12,9 +12,12 @@ logger = logging.getLogger(__name__) -def generate_replay_tests(trace_db_path: Path, output_dir: Path, project_root: Path, max_run_count: int = 256) -> int: - """Generate JUnit 5 replay test files from a trace SQLite database. +def generate_replay_tests( + trace_db_path: Path, output_dir: Path, project_root: Path, max_run_count: int = 256, test_framework: str = "junit5" +) -> int: + """Generate JUnit replay test files from a trace SQLite database. + Supports both JUnit 5 (default) and JUnit 4. Returns the number of test files generated. """ if not trace_db_path.exists(): @@ -44,9 +47,10 @@ def generate_replay_tests(trace_db_path: Path, output_dir: Path, project_root: P test_methods_code: list[str] = [] class_function_names: list[str] = [] + # Global test counter to avoid duplicate method names for overloaded Java methods + method_name_counters: dict[str, int] = {} for method_name, descriptor in method_list: - # Count invocations for this method count_result = conn.execute( "SELECT COUNT(*) FROM function_calls WHERE classname = ? AND function = ? AND descriptor = ?", (classname, method_name, descriptor), @@ -57,9 +61,14 @@ def generate_replay_tests(trace_db_path: Path, output_dir: Path, project_root: P safe_method = _sanitize_identifier(method_name) for i in range(invocation_count): + # Use a global counter per method name to avoid collisions on overloaded methods + test_idx = method_name_counters.get(safe_method, 0) + method_name_counters[safe_method] = test_idx + 1 + escaped_descriptor = descriptor.replace('"', '\\"') + access = "public " if test_framework == "junit4" else "" test_methods_code.append( - f" @Test void replay_{safe_method}_{i}() throws Exception {{\n" + f" @Test {access}void replay_{safe_method}_{test_idx}() throws Exception {{\n" f' helper.replay("{classname}", "{method_name}", ' f'"{escaped_descriptor}", {i});\n' f" }}" @@ -69,18 +78,28 @@ def generate_replay_tests(trace_db_path: Path, output_dir: Path, project_root: P # Generate the test file functions_comment = ",".join(class_function_names) + if test_framework == "junit4": + test_imports = "import org.junit.Test;\nimport org.junit.AfterClass;\n" + cleanup_annotation = "@AfterClass" + class_modifier = "public " + else: + test_imports = "import org.junit.jupiter.api.Test;\nimport org.junit.jupiter.api.AfterAll;\n" + cleanup_annotation = "@AfterAll" + class_modifier = "" + test_content = ( f"// codeflash:functions={functions_comment}\n" f"// codeflash:trace_file={trace_db_path.as_posix()}\n" f"// codeflash:classname={classname}\n" f"package codeflash.replay;\n\n" - f"import org.junit.jupiter.api.Test;\n" - f"import org.junit.jupiter.api.AfterAll;\n" + f"{test_imports}" f"import com.codeflash.ReplayHelper;\n\n" - f"class {test_class_name} {{\n" + f"{class_modifier}class {test_class_name} {{\n" f" private static final ReplayHelper helper =\n" f' new ReplayHelper("{trace_db_path.as_posix()}");\n\n' - f" @AfterAll static void cleanup() {{ helper.close(); }}\n\n" + "\n\n".join(test_methods_code) + "\n" + f" {cleanup_annotation} public static void cleanup() {{ helper.close(); }}\n\n" + + "\n\n".join(test_methods_code) + + "\n" "}\n" ) diff --git a/codeflash/languages/java/resources/codeflash-runtime-1.0.0.jar b/codeflash/languages/java/resources/codeflash-runtime-1.0.0.jar index cfcee9390..546a8b89d 100644 Binary files a/codeflash/languages/java/resources/codeflash-runtime-1.0.0.jar and b/codeflash/languages/java/resources/codeflash-runtime-1.0.0.jar differ diff --git a/codeflash/languages/java/tracer.py b/codeflash/languages/java/tracer.py index 7b5a30421..ab8f19514 100644 --- a/codeflash/languages/java/tracer.py +++ b/codeflash/languages/java/tracer.py @@ -14,6 +14,39 @@ logger = logging.getLogger(__name__) +GRACEFUL_SHUTDOWN_WAIT = 5 # seconds to wait after SIGTERM before SIGKILL + + +def _run_java_with_graceful_timeout( + java_command: list[str], env: dict[str, str], timeout: int, stage_name: str +) -> None: + """Run a Java command with graceful timeout handling. + + Sends SIGTERM first (allowing JFR dump and shutdown hooks to run), + then SIGKILL if the process doesn't exit within GRACEFUL_SHUTDOWN_WAIT seconds. + """ + if not timeout: + subprocess.run(java_command, env=env, check=False) + return + + import signal + + proc = subprocess.Popen(java_command, env=env) + try: + proc.wait(timeout=timeout) + except subprocess.TimeoutExpired: + logger.warning( + "%s stage timed out after %d seconds, sending SIGTERM for graceful shutdown...", stage_name, timeout + ) + proc.send_signal(signal.SIGTERM) + try: + proc.wait(timeout=GRACEFUL_SHUTDOWN_WAIT) + except subprocess.TimeoutExpired: + logger.warning("%s stage did not exit after SIGTERM, sending SIGKILL", stage_name) + proc.kill() + proc.wait() + + # --add-opens flags needed for Kryo serialization on Java 16+ ADD_OPENS_FLAGS = ( "--add-opens=java.base/java.util=ALL-UNNAMED " @@ -48,10 +81,7 @@ def trace( # Stage 1: JFR Profiling logger.info("Stage 1: Running JFR profiling...") jfr_env = self.build_jfr_env(jfr_file) - try: - subprocess.run(java_command, env=jfr_env, check=False, timeout=timeout or None) - except subprocess.TimeoutExpired: - logger.warning("JFR profiling stage timed out after %d seconds", timeout) + _run_java_with_graceful_timeout(java_command, jfr_env, timeout, "JFR profiling") if not jfr_file.exists(): logger.warning("JFR file was not created at %s", jfr_file) @@ -62,10 +92,7 @@ def trace( trace_db_path, packages, project_root=project_root, max_function_count=max_function_count, timeout=timeout ) agent_env = self.build_agent_env(config_path) - try: - subprocess.run(java_command, env=agent_env, check=False, timeout=timeout or None) - except subprocess.TimeoutExpired: - logger.warning("Argument capture stage timed out after %d seconds", timeout) + _run_java_with_graceful_timeout(java_command, agent_env, timeout, "Argument capture") if not trace_db_path.exists(): logger.error("Trace database was not created at %s", trace_db_path) @@ -95,7 +122,12 @@ def create_tracer_config( def build_jfr_env(self, jfr_file: Path) -> dict[str, str]: env = os.environ.copy() - jfr_opts = f"-XX:StartFlightRecording=filename={jfr_file.resolve()},settings=profile,dumponexit=true" + # Use profile settings with increased sampling frequency (1ms instead of default 10ms) + # This captures more samples for short-running programs + jfr_opts = ( + f"-XX:StartFlightRecording=filename={jfr_file.resolve()},settings=profile,dumponexit=true" + ",jdk.ExecutionSample#period=1ms" + ) existing = env.get("JAVA_TOOL_OPTIONS", "") env["JAVA_TOOL_OPTIONS"] = f"{existing} {jfr_opts}".strip() return env @@ -133,7 +165,7 @@ def detect_packages_from_source(module_root: Path) -> list[str]: if stripped.startswith("package "): pkg = stripped[8:].rstrip(";").strip() parts = pkg.split(".") - prefix = ".".join(parts[: min(2, len(parts))]) + prefix = ".".join(parts[: min(3, len(parts))]) packages.add(prefix) break if stripped and not stripped.startswith("//"): @@ -153,6 +185,7 @@ def run_java_tracer( max_function_count: int = 256, timeout: int = 0, max_run_count: int = 256, + test_framework: str = "junit5", ) -> tuple[Path, Path, int]: """High-level entry point: trace a Java command and generate replay tests. @@ -169,7 +202,11 @@ def run_java_tracer( ) test_count = generate_replay_tests( - trace_db_path=trace_db, output_dir=output_dir, project_root=project_root, max_run_count=max_run_count + trace_db_path=trace_db, + output_dir=output_dir, + project_root=project_root, + max_run_count=max_run_count, + test_framework=test_framework, ) return trace_db, jfr_file, test_count diff --git a/codeflash/main.py b/codeflash/main.py index 80d6d156a..0beda6d61 100644 --- a/codeflash/main.py +++ b/codeflash/main.py @@ -6,6 +6,8 @@ from __future__ import annotations +import copy +import logging import os import sys from pathlib import Path @@ -17,16 +19,27 @@ warnings.filterwarnings("ignore") -from codeflash.cli_cmds.cli import parse_args, process_pyproject_config +from codeflash.cli_cmds.cli import ( + apply_language_config, + handle_optimize_all_arg_parsing, + parse_args, + process_pyproject_config, +) from codeflash.cli_cmds.console import paneled_text from codeflash.code_utils import env_utils from codeflash.code_utils.checkpoint import ask_should_use_checkpoint_get_functions -from codeflash.code_utils.config_parser import parse_config_file +from codeflash.code_utils.config_parser import find_all_config_files, parse_config_file from codeflash.code_utils.version_check import check_for_newer_minor_version +from codeflash.languages.registry import UnsupportedLanguageError, get_language_support +from codeflash.setup.config_writer import write_config if TYPE_CHECKING: from argparse import Namespace + from codeflash.code_utils.config_parser import LanguageConfig + from codeflash.languages.language_enum import Language + from codeflash.setup.detector import DetectedProject + def main() -> None: """Entry point for the codeflash command-line interface.""" @@ -72,21 +85,188 @@ def main() -> None: ask_run_end_to_end_test(args) else: - # Check for first-run experience (no config exists) - loaded_args = _handle_config_loading(args) - if loaded_args is None: - sys.exit(0) - args = loaded_args - - if not env_utils.check_formatter_installed(args.formatter_cmds): + language_configs = find_all_config_files() + + # Auto-configure unconfigured languages detected from changed files + # Only for subagent/no-flags path (not --file which targets a specific file) + logger = logging.getLogger("codeflash") + if not (hasattr(args, "file") and args.file): + changed_files = get_changed_file_paths() + if changed_files: + unconfigured = detect_unconfigured_languages(language_configs, changed_files) + if unconfigured: + project_root = Path.cwd() + for lang in unconfigured: + new_config = auto_configure_language(lang, project_root, logger) + if new_config is not None: + language_configs.append(new_config) + + if not language_configs: + # Fallback: no multi-config found, use existing single-config path + loaded_args = _handle_config_loading(args) + if loaded_args is None: + sys.exit(0) + args = loaded_args + + if not env_utils.check_formatter_installed(args.formatter_cmds): + return + args.previous_checkpoint_functions = ask_should_use_checkpoint_get_functions(args) + init_sentry(enabled=not args.disable_telemetry, exclude_errors=True) + posthog_cf.initialize_posthog(enabled=not args.disable_telemetry) + + from codeflash.optimization import optimizer + + optimizer.run_with_args(args) return - args.previous_checkpoint_functions = ask_should_use_checkpoint_get_functions(args) - init_sentry(enabled=not args.disable_telemetry, exclude_errors=True) - posthog_cf.initialize_posthog(enabled=not args.disable_telemetry) - from codeflash.optimization import optimizer + # Filter to single language when --file is specified + if hasattr(args, "file") and args.file: + try: + file_lang_support = get_language_support(Path(args.file)) + file_language = file_lang_support.language + matching_configs = [lc for lc in language_configs if lc.language == file_language] + if matching_configs: + language_configs = matching_configs + # If no matching config found, let all configs run (existing behavior handles it) + except UnsupportedLanguageError: + pass # Unknown extension, let all configs run + + # Track whether --all was originally requested (before handle_optimize_all_arg_parsing + # resolves it — in multi-language mode, module_root isn't available yet so the resolution + # produces None; we re-resolve per language inside the loop) + optimize_all_requested = hasattr(args, "all") and args.all is not None + + # Multi-language path: run git/GitHub checks ONCE before the loop + args = handle_optimize_all_arg_parsing(args) + + results: dict[str, str] = {} + for lang_config in language_configs: + lang_name = lang_config.language.value + try: + pass_args = copy.deepcopy(args) + pass_args = apply_language_config(pass_args, lang_config) + + if optimize_all_requested: + pass_args.all = pass_args.module_root + + if not env_utils.check_formatter_installed(pass_args.formatter_cmds): + logger.info("Skipping %s: formatter not installed", lang_name) + results[lang_name] = "skipped" + continue + + pass_args.previous_checkpoint_functions = ask_should_use_checkpoint_get_functions(pass_args) + init_sentry(enabled=not pass_args.disable_telemetry, exclude_errors=True) + posthog_cf.initialize_posthog(enabled=not pass_args.disable_telemetry) + + logger.info("Processing %s (config: %s)", lang_name, lang_config.config_path) + + from codeflash.optimization import optimizer + + optimizer.run_with_args(pass_args) + results[lang_name] = "success" + except Exception: + logger.exception("Error processing %s, continuing with remaining languages", lang_name) + results[lang_name] = "failed" + + _log_orchestration_summary(logger, results) + + +def _log_orchestration_summary(logger: logging.Logger, results: dict[str, str]) -> None: + if not results: + return + parts = [f"{lang}: {status}" for lang, status in results.items()] + logger.info("Multi-language orchestration complete: %s", ", ".join(parts)) + + +def detect_unconfigured_languages(language_configs: list[LanguageConfig], changed_files: list[Path]) -> set[Language]: + configured = {lc.language for lc in language_configs} + changed_languages: set[Language] = set() + for f in changed_files: + try: + lang_support = get_language_support(f) + changed_languages.add(lang_support.language) + except UnsupportedLanguageError: + pass + return changed_languages - configured + + +def get_changed_file_paths() -> list[Path]: + import subprocess + + try: + result = subprocess.run( + ["git", "diff", "--name-only", "HEAD~1"], capture_output=True, text=True, timeout=10, check=False + ) + if result.returncode == 0: + return [Path(line) for line in result.stdout.strip().splitlines() if line] + except (subprocess.TimeoutExpired, FileNotFoundError, OSError): + pass + return [] + + +def detect_project_for_language(language: Language, project_root: Path) -> DetectedProject: + from codeflash.setup.detector import ( + DetectedProject, + _detect_formatter, + _detect_ignore_paths, + _detect_java_module_root, + _detect_js_module_root, + _detect_python_module_root, + _detect_test_runner, + _detect_tests_root, + ) - optimizer.run_with_args(args) + lang_str = language.value + + module_root_detectors = { + "python": _detect_python_module_root, + "java": _detect_java_module_root, + "javascript": _detect_js_module_root, + } + + detector = module_root_detectors.get(lang_str) + if detector is None: + msg = f"No auto-detection available for {lang_str}" + raise ValueError(msg) + + module_root, _ = detector(project_root) + tests_root, _ = _detect_tests_root(project_root, lang_str) + test_runner, _ = _detect_test_runner(project_root, lang_str) + formatter_cmds, _ = _detect_formatter(project_root, lang_str) + ignore_paths, _ = _detect_ignore_paths(project_root, lang_str) + + return DetectedProject( + language=lang_str, + project_root=project_root, + module_root=module_root, + tests_root=tests_root, + test_runner=test_runner, + formatter_cmds=formatter_cmds, + ignore_paths=ignore_paths, + ) + + +def auto_configure_language(language: Language, project_root: Path, logger: logging.Logger) -> LanguageConfig | None: + lang_str = language.value + try: + detected = detect_project_for_language(language, project_root) + success, msg = write_config(detected) + if success: + logger.info("Auto-created config for %s: %s", lang_str, msg) + logger.info("Review the generated config file to verify paths are correct.") + new_configs = find_all_config_files() + for nc in new_configs: + if nc.language == language: + return nc + logger.warning("Config was created for %s but could not be re-discovered.", lang_str) + return None + logger.warning("Could not auto-configure %s: %s. Skipping.", lang_str, msg) + logger.info("Run 'codeflash init' to set up %s manually.", lang_str) + return None + except Exception: + logger.exception("Auto-detection failed for %s. Skipping.", lang_str) + logger.info("Run 'codeflash init' to set up %s manually.", lang_str) + return None def _handle_config_loading(args: Namespace) -> Namespace | None: diff --git a/codeflash/setup/config_writer.py b/codeflash/setup/config_writer.py index 0889690d5..43ce03eb3 100644 --- a/codeflash/setup/config_writer.py +++ b/codeflash/setup/config_writer.py @@ -8,7 +8,7 @@ from __future__ import annotations import json -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any import tomlkit @@ -38,7 +38,7 @@ def write_config(detected: DetectedProject, config: CodeflashConfig | None = Non if detected.language == "python": return _write_pyproject_toml(detected.project_root, config) if detected.language == "java": - return _write_codeflash_toml(detected.project_root, config) + return _write_java_build_config(detected.project_root, config) return _write_package_json(detected.project_root, config) @@ -92,10 +92,10 @@ def _write_pyproject_toml(project_root: Path, config: CodeflashConfig) -> tuple[ return False, f"Failed to write pyproject.toml: {e}" -def _write_codeflash_toml(project_root: Path, config: CodeflashConfig) -> tuple[bool, str]: - """Write config to codeflash.toml [tool.codeflash] section for Java projects. +def _write_java_build_config(project_root: Path, config: CodeflashConfig) -> tuple[bool, str]: + """Write codeflash config to pom.xml properties or gradle.properties. - Creates codeflash.toml if it doesn't exist. + Only writes non-default values. Standard Maven/Gradle layouts need no config. Args: project_root: Project root directory. @@ -105,40 +105,141 @@ def _write_codeflash_toml(project_root: Path, config: CodeflashConfig) -> tuple[ Tuple of (success, message). """ - codeflash_toml_path = project_root / "codeflash.toml" + config_dict = config.to_pyproject_dict() - try: - # Load existing or create new - if codeflash_toml_path.exists(): - with codeflash_toml_path.open("rb") as f: - doc = tomlkit.parse(f.read()) - else: - doc = tomlkit.document() + # Filter out default values — only write overrides + defaults = {"module-root": "src/main/java", "tests-root": "src/test/java", "language": "java"} + non_default = {k: v for k, v in config_dict.items() if k not in defaults or str(v) != defaults.get(k)} + # Remove empty lists and False booleans + non_default = {k: v for k, v in non_default.items() if v not in ([], False, "", None)} - # Ensure [tool] section exists - if "tool" not in doc: - doc["tool"] = tomlkit.table() + if not non_default: + return True, "Standard Maven/Gradle layout detected — no config needed" - # Create codeflash section - codeflash_table = tomlkit.table() - codeflash_table.add(tomlkit.comment("Codeflash configuration for Java - https://docs.codeflash.ai")) + pom_path = project_root / "pom.xml" + if pom_path.exists(): + return _write_maven_properties(pom_path, non_default) - # Add config values - config_dict = config.to_pyproject_dict() - for key, value in config_dict.items(): - codeflash_table[key] = value + gradle_props_path = project_root / "gradle.properties" + return _write_gradle_properties(gradle_props_path, non_default) - # Update the document - doc["tool"]["codeflash"] = codeflash_table - # Write back - with codeflash_toml_path.open("w", encoding="utf8") as f: - f.write(tomlkit.dumps(doc)) +_MAVEN_KEY_MAP: dict[str, str] = { + "module-root": "moduleRoot", + "tests-root": "testsRoot", + "git-remote": "gitRemote", + "disable-telemetry": "disableTelemetry", + "ignore-paths": "ignorePaths", + "formatter-cmds": "formatterCmds", +} + + +def _write_maven_properties(pom_path: Path, config: dict[str, Any]) -> tuple[bool, str]: + """Add codeflash.* properties to pom.xml section. + + Uses text-based manipulation to preserve comments, formatting, and namespace declarations. + """ + import re - return True, f"Config saved to {codeflash_toml_path}" + try: + content = pom_path.read_text(encoding="utf-8") + + # Remove existing codeflash.* property lines (with surrounding whitespace) + content = re.sub(r"\n[ \t]*]*>[^<]*]*>", "", content) + + # Detect child indentation from existing properties or fall back to indent + 4 spaces + props_close = re.search(r"([ \t]*)", content) + if props_close: + parent_indent = props_close.group(1) + # Try to detect child indent from an existing property element + child_match = re.search( + r"\n([ \t]+)<[a-zA-Z]", + content[content.find("") : props_close.start()] if "" in content else "", + ) + child_indent = child_match.group(1) if child_match else parent_indent + " " + else: + parent_indent = "" + child_indent = " " + + # Build new property lines with detected indentation + new_lines = [] + for key, value in config.items(): + maven_key = f"codeflash.{_MAVEN_KEY_MAP.get(key, key)}" + if isinstance(value, list): + value = ",".join(str(v) for v in value) + elif isinstance(value, bool): + value = str(value).lower() + else: + value = str(value) + new_lines.append(f"{child_indent}<{maven_key}>{value}") + + properties_block = "\n".join(new_lines) + + # Insert before + if props_close: + content = ( + content[: props_close.start()] + + properties_block + + "\n" + + parent_indent + + "" + + content[props_close.end() :] + ) + else: + # No section — create one before + project_close = re.search(r"([ \t]*)", content) + if project_close: + indent = project_close.group(1) + inner = " " + indent + props_section = ( + f"{inner}\n" + + "\n".join(f" {line}" for line in new_lines) + + f"\n{inner}\n" + ) + content = ( + content[: project_close.start()] + + props_section + + indent + + "" + + content[project_close.end() :] + ) + + pom_path.write_text(content, encoding="utf-8") + return True, f"Config saved to {pom_path} " except Exception as e: - return False, f"Failed to write codeflash.toml: {e}" + return False, f"Failed to write Maven properties: {e}" + + +def _write_gradle_properties(props_path: Path, config: dict[str, Any]) -> tuple[bool, str]: + """Add codeflash.* entries to gradle.properties.""" + try: + lines = [] + if props_path.exists(): + lines = props_path.read_text(encoding="utf-8").splitlines() + + # Remove existing codeflash.* lines + lines = [line for line in lines if not line.strip().startswith("codeflash.")] + + # Add new config + if lines and lines[-1].strip(): + lines.append("") + lines.append("# Codeflash configuration — https://docs.codeflash.ai") + for key, value in config.items(): + gradle_key = f"codeflash.{_MAVEN_KEY_MAP.get(key, key)}" + if isinstance(value, list): + value = ",".join(str(v) for v in value) + elif isinstance(value, bool): + value = str(value).lower() + else: + value = str(value) + lines.append(f"{gradle_key}={value}") + + props_path.write_text("\n".join(lines) + "\n", encoding="utf-8") + return True, f"Config saved to {props_path}" + + except Exception as e: + return False, f"Failed to write gradle.properties: {e}" def _write_package_json(project_root: Path, config: CodeflashConfig) -> tuple[bool, str]: @@ -206,7 +307,7 @@ def remove_config(project_root: Path, language: str) -> tuple[bool, str]: if language == "python": return _remove_from_pyproject(project_root) if language == "java": - return _remove_from_codeflash_toml(project_root) + return _remove_java_build_config(project_root) return _remove_from_package_json(project_root) @@ -235,29 +336,42 @@ def _remove_from_pyproject(project_root: Path) -> tuple[bool, str]: return False, f"Failed to remove config: {e}" -def _remove_from_codeflash_toml(project_root: Path) -> tuple[bool, str]: - """Remove [tool.codeflash] section from codeflash.toml.""" - codeflash_toml_path = project_root / "codeflash.toml" - - if not codeflash_toml_path.exists(): - return True, "No codeflash.toml found" - - try: - with codeflash_toml_path.open("rb") as f: - doc = tomlkit.parse(f.read()) - - if "tool" in doc and "codeflash" in doc["tool"]: - del doc["tool"]["codeflash"] - - with codeflash_toml_path.open("w", encoding="utf8") as f: - f.write(tomlkit.dumps(doc)) +def _remove_java_build_config(project_root: Path) -> tuple[bool, str]: + """Remove codeflash.* properties from pom.xml or gradle.properties. - return True, "Removed [tool.codeflash] section from codeflash.toml" - - return True, "No codeflash config found in codeflash.toml" - - except Exception as e: - return False, f"Failed to remove config: {e}" + Priority matches _write_java_build_config: pom.xml first, then gradle.properties. + """ + # Try pom.xml first (matches write priority) — text-based removal preserves formatting + pom_path = project_root / "pom.xml" + if pom_path.exists(): + try: + import re + + content = pom_path.read_text(encoding="utf-8") + updated = re.sub(r"\n[ \t]*]*>[^<]*]*>", "", content) + if updated != content: + pom_path.write_text(updated, encoding="utf-8") + return True, "Removed codeflash properties from pom.xml" + except Exception as e: + return False, f"Failed to remove config from pom.xml: {e}" + + # Try gradle.properties + gradle_props = project_root / "gradle.properties" + if gradle_props.exists(): + try: + lines = gradle_props.read_text(encoding="utf-8").splitlines() + filtered = [ + line + for line in lines + if not line.strip().startswith("codeflash.") + and line.strip() != "# Codeflash configuration \u2014 https://docs.codeflash.ai" + ] + gradle_props.write_text("\n".join(filtered) + "\n", encoding="utf-8") + return True, "Removed codeflash properties from gradle.properties" + except Exception as e: + return False, f"Failed to remove config from gradle.properties: {e}" + + return True, "No Java build config found" def _remove_from_package_json(project_root: Path) -> tuple[bool, str]: diff --git a/codeflash/setup/detector.py b/codeflash/setup/detector.py index defe1a22d..81e900436 100644 --- a/codeflash/setup/detector.py +++ b/codeflash/setup/detector.py @@ -886,20 +886,25 @@ def has_existing_config(project_root: Path) -> tuple[bool, str | None]: Returns: Tuple of (has_config, config_file_type). - config_file_type is "pyproject.toml", "codeflash.toml", "package.json", or None. + config_file_type is "pyproject.toml", "pom.xml", "build.gradle", "package.json", or None. """ - # Check TOML config files (pyproject.toml, codeflash.toml) - for toml_filename in ("pyproject.toml", "codeflash.toml"): - toml_path = project_root / toml_filename - if toml_path.exists(): - try: - with toml_path.open("rb") as f: - data = tomlkit.parse(f.read()) - if "tool" in data and "codeflash" in data["tool"]: - return True, toml_filename - except Exception: - pass + # Check pyproject.toml (Python projects) + pyproject_path = project_root / "pyproject.toml" + if pyproject_path.exists(): + try: + with pyproject_path.open("rb") as f: + data = tomlkit.parse(f.read()) + if "tool" in data and "codeflash" in data["tool"]: + return True, "pyproject.toml" + except Exception: + pass + + # Check Java build files — for zero-config Java, any build file means "configured" + # because Java config is auto-detected from build files without explicit codeflash.* properties + for build_file in ("pom.xml", "build.gradle", "build.gradle.kts"): + if (project_root / build_file).exists(): + return True, build_file # Check package.json package_json_path = project_root / "package.json" diff --git a/codeflash/tracer.py b/codeflash/tracer.py index 84f58e9da..5f8a1a4ab 100644 --- a/codeflash/tracer.py +++ b/codeflash/tracer.py @@ -38,7 +38,7 @@ def _detect_non_python_language(args: Namespace | None) -> Language | None: - """Detect if the project uses a non-Python language from --file or config. + """Detect if the project uses a non-Python language from --file or build files. Returns a Language enum value if non-Python detected, None otherwise. """ @@ -66,15 +66,23 @@ def _detect_non_python_language(args: Namespace | None) -> Language | None: except Exception: pass - # Method 2: Check project config for language field + # Method 2: Detect Java from build files (pom.xml / build.gradle) + try: + from codeflash.languages.java.build_tools import BuildTool, detect_build_tool + + cwd = Path.cwd() + if detect_build_tool(cwd) != BuildTool.UNKNOWN: + return Language.JAVA + except Exception: + pass + + # Method 3: Check config file for language field (JS/TS via package.json) try: from codeflash.code_utils.config_parser import parse_config_file config_file = getattr(args, "config_file_path", None) if args else None config, _ = parse_config_file(config_file) lang_str = config.get("language", "") - if lang_str == "java": - return Language.JAVA if lang_str in ("javascript", "typescript"): return Language(lang_str) except Exception: @@ -336,8 +344,12 @@ def _run_java_tracer(existing_args: Namespace | None = None) -> ArgumentParser: max_function_count = getattr(config, "max_function_count", 256) timeout = int(getattr(config, "timeout", None) or getattr(config, "tracer_timeout", 0) or 0) + console.print("[bold]Java project detected[/]") + console.print(f" Project root: {project_root}") + console.print(f" Module root: {getattr(config, 'module_root', '?')}") + console.print(f" Tests root: {getattr(config, 'tests_root', '?')}") + from codeflash.code_utils.code_utils import get_run_tmp_file - from codeflash.languages.java.build_tools import find_test_root from codeflash.languages.java.tracer import JavaTracer, run_java_tracer tracer = JavaTracer() @@ -347,12 +359,16 @@ def _run_java_tracer(existing_args: Namespace | None = None) -> ArgumentParser: trace_db_path = get_run_tmp_file(Path("java_trace.db")) - # Place replay tests in the project's test source tree so Maven/Gradle can compile them - test_root = find_test_root(project_root) - if test_root: - output_dir = test_root / "codeflash" / "replay" + # Place replay tests in the project's test source tree so Maven/Gradle can compile them. + # Use the config's tests_root (correctly resolved for multi-module projects) not find_test_root(). + tests_root = Path(getattr(config, "tests_root", "")) + if tests_root.is_dir(): + output_dir = tests_root / "codeflash" / "replay" else: - output_dir = project_root / "src" / "test" / "java" / "codeflash" / "replay" + from codeflash.languages.java.build_tools import find_test_root + + test_root = find_test_root(project_root) + output_dir = (test_root or project_root / "src" / "test" / "java") / "codeflash" / "replay" output_dir.mkdir(parents=True, exist_ok=True) # Remaining args after our flags are the Java command @@ -364,6 +380,12 @@ def _run_java_tracer(existing_args: Namespace | None = None) -> ArgumentParser: sys.exit(1) java_command = remaining + # Detect test framework for replay test generation + from codeflash.languages.java.config import detect_java_project + + java_config = detect_java_project(project_root) + test_framework = java_config.test_framework if java_config else "junit5" + trace_db, jfr_file, test_count = run_java_tracer( java_command=java_command, trace_db_path=trace_db_path, @@ -372,6 +394,7 @@ def _run_java_tracer(existing_args: Namespace | None = None) -> ArgumentParser: output_dir=output_dir, max_function_count=max_function_count, timeout=timeout, + test_framework=test_framework, ) console.print(f"[bold green]Java tracing complete:[/] {test_count} replay test files generated") diff --git a/docs/configuration/java.mdx b/docs/configuration/java.mdx index 9d110fc55..720e5e091 100644 --- a/docs/configuration/java.mdx +++ b/docs/configuration/java.mdx @@ -1,101 +1,112 @@ --- title: "Java Configuration" -description: "Configure Codeflash for Java projects using codeflash.toml" +description: "Configure Codeflash for Java projects — zero config for standard layouts" icon: "java" -sidebarTitle: "Java (codeflash.toml)" +sidebarTitle: "Java (pom.xml / Gradle)" keywords: [ "configuration", - "codeflash.toml", "java", "maven", "gradle", "junit", + "pom.xml", + "gradle.properties", + "zero-config", ] --- # Java Configuration -Codeflash stores its configuration in `codeflash.toml` under the `[tool.codeflash]` section. +**Standard Maven/Gradle projects need zero configuration.** Codeflash auto-detects your project structure from `pom.xml` or `build.gradle` — no config file is required. -## Full Reference - -```toml -[tool.codeflash] -# Required -module-root = "src/main/java" -tests-root = "src/test/java" -language = "java" - -# Optional -test-framework = "junit5" # "junit5", "junit4", or "testng" -disable-telemetry = false -git-remote = "origin" -ignore-paths = ["src/main/java/generated/"] -``` - -All file paths are relative to the directory containing `codeflash.toml`. - - -Codeflash auto-detects most settings from your project structure. Running `codeflash init` will set up the correct config — manual configuration is usually not needed. - +For projects with non-standard layouts, you can add `codeflash.*` properties to your existing `pom.xml` or `gradle.properties`. ## Auto-Detection -When you run `codeflash init`, Codeflash inspects your project and auto-detects: +Codeflash inspects your build files and auto-detects: | Setting | Detection logic | |---------|----------------| -| `module-root` | Looks for `src/main/java` (Maven/Gradle standard layout) | -| `tests-root` | Looks for `src/test/java`, `test/`, `tests/` | -| `language` | Detected from build files (`pom.xml`, `build.gradle`) and `.java` files | -| `test-framework` | Checks build file dependencies for JUnit 5, JUnit 4, or TestNG | - -## Required Options - -- **`module-root`**: The source directory to optimize. Only code under this directory is discovered for optimization. For standard Maven/Gradle projects, this is `src/main/java`. -- **`tests-root`**: The directory where your tests are located. Codeflash discovers existing tests and places generated replay tests here. -- **`language`**: Must be set to `"java"` for Java projects. +| **Language** | Presence of `pom.xml` or `build.gradle` / `build.gradle.kts` | +| **Source root** | `src/main/java` (standard), or `` in `pom.xml`, or Gradle `sourceSets` | +| **Test root** | `src/test/java` (standard), or `` in `pom.xml` | +| **Test framework** | Checks build file dependencies for JUnit 5, JUnit 4, or TestNG | +| **Java version** | ``, `` in `pom.xml` | -## Optional Options +### Multi-module Maven projects -- **`test-framework`**: Test framework. Auto-detected from build dependencies. Supported values: `"junit5"` (default), `"junit4"`, `"testng"`. -- **`disable-telemetry`**: Disable anonymized telemetry. Defaults to `false`. -- **`git-remote`**: Git remote for pull requests. Defaults to `"origin"`. -- **`ignore-paths`**: Paths within `module-root` to skip during optimization. +For multi-module projects, Codeflash scans each module's `pom.xml` for `` and `` declarations. It picks the module with the most Java source files as the main source root, and identifies test modules by name. -## Multi-Module Projects - -For multi-module Maven/Gradle projects, place `codeflash.toml` at the project root and set `module-root` to the module you want to optimize: +For example, with this layout: ```text my-project/ -|- client/ -| |- src/main/java/com/example/client/ -| |- src/test/java/com/example/client/ -|- server/ -| |- src/main/java/com/example/server/ -|- pom.xml -|- codeflash.toml +|- client/ ← main library (most .java files) +| |- src/com/example/ +| |- pom.xml ← ${project.basedir}/src +|- test/ ← test module +| |- src/com/example/ +| |- pom.xml ← ${project.basedir}/src +|- benchmarks/ ← skipped (benchmark module) +|- pom.xml ← client, test, benchmarks ``` -```toml -[tool.codeflash] -module-root = "client/src/main/java" -tests-root = "client/src/test/java" -language = "java" +Codeflash auto-detects `client/src` as the source root and `test/src` as the test root — no manual configuration needed. + +## Custom Configuration + +If auto-detection doesn't match your project layout, add `codeflash.*` properties to your build files. + + + + +Add properties to your `pom.xml` `` section: + +```xml + + + client/src + test/src + true + upstream + src/main/java/generated/,src/main/java/proto/ + ``` -For non-standard layouts (like the Aerospike client where source is under `client/src/`), adjust paths accordingly: +This follows the same pattern as SonarQube (`sonar.sources`), JaCoCo, and other Java tools — config lives in the build file, not a separate tool-specific file. + + + + +Add properties to `gradle.properties`: -```toml -[tool.codeflash] -module-root = "client/src" -tests-root = "test/src" -language = "java" +```properties +# Only set values that differ from auto-detected defaults +codeflash.moduleRoot=lib/src/main/java +codeflash.testsRoot=lib/src/test/java +codeflash.disableTelemetry=true +codeflash.gitRemote=upstream +codeflash.ignorePaths=src/main/java/generated/ ``` -## Tracer Options + + + +## Available Properties + +All properties are optional — only set values that differ from auto-detected defaults. + +| Property | Description | Default | +|----------|------------|---------| +| `codeflash.moduleRoot` | Source directory to optimize | Auto-detected from `` or `src/main/java` | +| `codeflash.testsRoot` | Test directory | Auto-detected from `` or `src/test/java` | +| `codeflash.disableTelemetry` | Disable anonymized telemetry | `false` | +| `codeflash.gitRemote` | Git remote for pull requests | `origin` | +| `codeflash.ignorePaths` | Comma-separated paths to skip during optimization | Empty | +| `codeflash.formatterCmds` | Comma-separated formatter commands (`$file` = file path) | Empty | + +## Tracer CLI Options When using `codeflash optimize` to trace a Java program, these CLI options are available: @@ -111,9 +122,9 @@ Example with timeout: codeflash optimize --timeout 30 java -jar target/my-app.jar --app-args ``` -## Example +## Examples -### Standard Maven project +### Standard Maven project (zero config) ```text my-app/ @@ -124,17 +135,14 @@ my-app/ | |- test/java/com/example/ | |- AppTest.java |- pom.xml -|- codeflash.toml ``` -```toml -[tool.codeflash] -module-root = "src/main/java" -tests-root = "src/test/java" -language = "java" +Just run: +```bash +codeflash optimize java -jar target/my-app.jar ``` -### Gradle project +### Standard Gradle project (zero config) ```text my-lib/ @@ -142,12 +150,55 @@ my-lib/ | |- main/java/com/example/ | |- test/java/com/example/ |- build.gradle -|- codeflash.toml ``` -```toml -[tool.codeflash] -module-root = "src/main/java" -tests-root = "src/test/java" -language = "java" +Just run: +```bash +codeflash optimize java -cp build/classes/java/main com.example.Main ``` + +### Non-standard layout (with config) + +```text +aerospike-client-java/ +|- client/ +| |- src/com/aerospike/client/ ← source here (not src/main/java) +| |- pom.xml +|- test/ +| |- src/com/aerospike/test/ ← tests here +| |- pom.xml +|- pom.xml +``` + +If auto-detection doesn't pick up the right modules, add to the root `pom.xml`: + +```xml + + client/src + test/src + +``` + + +In most cases, even non-standard multi-module layouts are auto-detected correctly from `` and `` in each module's `pom.xml`. Only add manual config if auto-detection gets it wrong. + + +## FAQ + + + + No. Codeflash auto-detects Java projects from `pom.xml` or `build.gradle`. No initialization step or config file is needed for standard layouts. + + + + Codeflash reads config from your existing build files — `pom.xml` `` for Maven, `gradle.properties` for Gradle. No separate config file is created. + + + + Add `` and `` properties to your `pom.xml` or `gradle.properties`. These override auto-detection. + + + + Codeflash scans each module's `pom.xml` for `` and ``. It picks the module with the most Java files as the source root (skipping modules named `examples`, `benchmarks`, etc.) and identifies `test` modules for the test root. + + diff --git a/docs/getting-started/java-installation.mdx b/docs/getting-started/java-installation.mdx index a75e1f0b7..fb2a88ef2 100644 --- a/docs/getting-started/java-installation.mdx +++ b/docs/getting-started/java-installation.mdx @@ -12,10 +12,11 @@ keywords: "junit", "junit5", "tracing", + "zero-config", ] --- -Codeflash supports Java projects using Maven or Gradle build systems. It uses a two-stage tracing approach to capture method arguments and profiling data from running Java programs, then optimizes the hottest functions. +Codeflash supports Java projects using Maven or Gradle build systems. **No configuration file is needed** — Codeflash auto-detects your project structure from `pom.xml` or `build.gradle`. ### Prerequisites @@ -23,7 +24,7 @@ Before installing Codeflash, ensure you have: 1. **Java 11 or above** installed 2. **Maven or Gradle** as your build tool -3. **A Java project** with source code under a standard directory layout +3. **A Java project** with source code Good to have (optional): @@ -45,61 +46,48 @@ uv pip install codeflash ``` - + Navigate to your Java project root (where `pom.xml` or `build.gradle` is) and run: ```bash -codeflash init +codeflash optimize java -jar target/my-app.jar ``` -This will: -- Detect your build tool (Maven/Gradle) -- Find your source and test directories -- Create a `codeflash.toml` configuration file +That's it — no `init` step, no config file. Codeflash detects Maven/Gradle automatically and infers source and test directories from your build files. - - +Codeflash will: +1. Profile your program using JFR (Java Flight Recorder) +2. Capture method arguments using a bytecode instrumentation agent +3. Generate JUnit replay tests from the captured data +4. Rank functions by performance impact +5. Optimize the most impactful functions -Check that the configuration looks correct: + + -```bash -cat codeflash.toml -``` + +**Zero config for standard projects.** If your project uses the standard Maven/Gradle layout (`src/main/java`, `src/test/java`), everything is auto-detected. For non-standard layouts, see the [configuration guide](/configuration/java). + -You should see something like: +## Usage examples -```toml -[tool.codeflash] -module-root = "src/main/java" -tests-root = "src/test/java" -language = "java" +**Trace and optimize a JAR application:** +```bash +codeflash optimize java -jar target/my-app.jar --app-args ``` - - - -Trace and optimize a running Java program: - +**Optimize a specific file and function:** ```bash -codeflash optimize java -jar target/my-app.jar +codeflash --file src/main/java/com/example/Utils.java --function computeHash ``` -Or with Maven: - +**Trace a long-running program with a timeout:** ```bash -codeflash optimize mvn exec:java -Dexec.mainClass="com.example.Main" +codeflash optimize --timeout 30 java -jar target/my-server.jar ``` -Codeflash will: -1. Profile your program using JFR (Java Flight Recorder) -2. Capture method arguments using a bytecode instrumentation agent -3. Generate JUnit replay tests from the captured data -4. Rank functions by performance impact -5. Optimize the most impactful functions - - - +Each tracing stage runs for at most 30 seconds, then the captured data is processed. ## How it works diff --git a/tests/code_utils/test_config_parser.py b/tests/code_utils/test_config_parser.py new file mode 100644 index 000000000..dc47a4f1d --- /dev/null +++ b/tests/code_utils/test_config_parser.py @@ -0,0 +1,87 @@ +"""Tests for config_parser.py — monorepo language detection priority.""" + +from __future__ import annotations + +import json +import os +from pathlib import Path +from unittest.mock import patch + +import pytest + +from codeflash.code_utils.config_parser import parse_config_file + + +class TestMonorepoConfigPriority: + """Verify that closer config files win over parent Java build files in monorepos.""" + + def test_closer_package_json_wins_over_parent_pom_xml(self, tmp_path: Path) -> None: + """In monorepo/frontend/, a local package.json should win over a parent pom.xml.""" + # Parent Java project + (tmp_path / "pom.xml").write_text("", encoding="utf-8") + (tmp_path / "src" / "main" / "java").mkdir(parents=True) + + # Child JS project + frontend = tmp_path / "frontend" + frontend.mkdir() + (frontend / "package.json").write_text( + json.dumps({"name": "frontend", "codeflash": {"moduleRoot": "src"}}), + encoding="utf-8", + ) + (frontend / "src").mkdir() + + with patch("codeflash.code_utils.config_parser.Path") as mock_path_cls: + mock_path_cls.cwd.return_value = frontend + # find_package_json also uses Path.cwd; mock it at the source + with patch("codeflash.code_utils.config_js.Path") as mock_js_path_cls: + mock_js_path_cls.cwd.return_value = frontend + # Also need to let normal Path operations work + mock_path_cls.side_effect = Path + mock_path_cls.cwd.return_value = frontend + mock_js_path_cls.side_effect = Path + mock_js_path_cls.cwd.return_value = frontend + + config, root = parse_config_file() + + # Should detect JS, not Java + assert config.get("language") != "java", ( + "Closer package.json should take priority over parent pom.xml" + ) + + def test_java_wins_when_no_closer_js_config(self, tmp_path: Path) -> None: + """When only a pom.xml exists (no package.json/pyproject.toml closer), Java config wins.""" + (tmp_path / "pom.xml").write_text("", encoding="utf-8") + (tmp_path / "src" / "main" / "java").mkdir(parents=True) + + with patch("codeflash.code_utils.config_parser.Path") as mock_path_cls: + mock_path_cls.side_effect = Path + mock_path_cls.cwd.return_value = tmp_path + with patch("codeflash.code_utils.config_js.Path") as mock_js_path_cls: + mock_js_path_cls.side_effect = Path + mock_js_path_cls.cwd.return_value = tmp_path + + config, root = parse_config_file() + + assert config.get("language") == "java" + + def test_same_level_package_json_wins_over_pom_xml(self, tmp_path: Path) -> None: + """When pom.xml and package.json are at the same level, package.json wins (more specific).""" + (tmp_path / "pom.xml").write_text("", encoding="utf-8") + (tmp_path / "src" / "main" / "java").mkdir(parents=True) + (tmp_path / "package.json").write_text( + json.dumps({"name": "mixed-project", "codeflash": {"moduleRoot": "src"}}), + encoding="utf-8", + ) + + with patch("codeflash.code_utils.config_parser.Path") as mock_path_cls: + mock_path_cls.side_effect = Path + mock_path_cls.cwd.return_value = tmp_path + with patch("codeflash.code_utils.config_js.Path") as mock_js_path_cls: + mock_js_path_cls.side_effect = Path + mock_js_path_cls.cwd.return_value = tmp_path + + config, root = parse_config_file() + + assert config.get("language") != "java", ( + "Same-level package.json should take priority over pom.xml" + ) diff --git a/tests/scripts/end_to_end_test_java_tracer.py b/tests/scripts/end_to_end_test_java_tracer.py index e904a4e98..5555b041c 100644 --- a/tests/scripts/end_to_end_test_java_tracer.py +++ b/tests/scripts/end_to_end_test_java_tracer.py @@ -59,6 +59,7 @@ def run_test(expected_improvement_pct: int) -> bool: env = os.environ.copy() env["PYTHONIOENCODING"] = "utf-8" + env["PYTHONUNBUFFERED"] = "1" logging.info(f"Running command: {' '.join(command)}") logging.info(f"Working directory: {fixture_dir}") process = subprocess.Popen( @@ -73,13 +74,11 @@ def run_test(expected_improvement_pct: int) -> bool: output = [] for line in process.stdout: - logging.info(line.strip()) + print(line, end="", flush=True) output.append(line) return_code = process.wait() stdout = "".join(output) - if return_code != 0: - logging.error(f"Full output:\n{stdout}") if return_code != 0: logging.error(f"Command returned exit code {return_code}") @@ -90,7 +89,7 @@ def run_test(expected_improvement_pct: int) -> bool: logging.error("Failed to find replay test generation message") return False - # Validate: replay tests were discovered + # Validate: replay tests were discovered (global count) replay_match = re.search(r"Discovered \d+ existing unit tests? and (\d+) replay tests?", stdout) if not replay_match: logging.error("Failed to find replay test discovery message") @@ -101,6 +100,17 @@ def run_test(expected_improvement_pct: int) -> bool: return False logging.info(f"Replay tests discovered: {num_replay}") + # Validate: replay test files were used per-function + replay_file_match = re.search(r"Discovered \d+ existing unit test files?, (\d+) replay test files?", stdout) + if not replay_file_match: + logging.error("Failed to find per-function replay test file discovery message") + return False + num_replay_files = int(replay_file_match.group(1)) + if num_replay_files == 0: + logging.error("No replay test files discovered per-function") + return False + logging.info(f"Replay test files per-function: {num_replay_files}") + # Validate: at least one optimization was found if "⚡️ Optimization successful! 📄 " not in stdout: logging.error("Failed to find optimization success message") diff --git a/tests/scripts/end_to_end_test_utilities.py b/tests/scripts/end_to_end_test_utilities.py index 12259b339..33825db4d 100644 --- a/tests/scripts/end_to_end_test_utilities.py +++ b/tests/scripts/end_to_end_test_utilities.py @@ -149,8 +149,8 @@ def build_command( if config.function_name: base_command.extend(["--function", config.function_name]) - # Check if config exists (pyproject.toml or codeflash.toml) - if so, don't override it - has_codeflash_config = (cwd / "codeflash.toml").exists() + # Check if config exists (pyproject.toml, pom.xml, build.gradle) - if so, don't override it + has_codeflash_config = (cwd / "pom.xml").exists() or (cwd / "build.gradle").exists() or (cwd / "build.gradle.kts").exists() if not has_codeflash_config: pyproject_path = cwd / "pyproject.toml" if pyproject_path.exists(): diff --git a/tests/test_git_utils.py b/tests/test_git_utils.py index f3f23c1d9..0666a6136 100644 --- a/tests/test_git_utils.py +++ b/tests/test_git_utils.py @@ -282,64 +282,145 @@ def helper(): """ +UNSUPPORTED_LANG_DIFF = """\ +--- a/src/main.rs ++++ b/src/main.rs +@@ -1,3 +1,4 @@ + fn main() { ++ let x = 1; + println!("Hello"); + +""" + +JS_TS_DIFF = """\ +--- a/src/app.js ++++ b/src/app.js +@@ -1,3 +1,4 @@ + function start() { ++ const x = 1; + return true; + +--- a/src/utils.ts ++++ b/src/utils.ts +@@ -1,3 +1,4 @@ + function helper() { ++ const y = 2; + return false; + +--- a/src/Component.jsx ++++ b/src/Component.jsx +@@ -1,3 +1,4 @@ + function Component() { ++ const a = null; + return null; + +--- a/src/Page.tsx ++++ b/src/Page.tsx +@@ -1,3 +1,4 @@ + function Page() { ++ const b = null; + return null; + +""" + +ALL_THREE_LANGS_DIFF = """\ +--- a/src/main.py ++++ b/src/main.py +@@ -1,3 +1,4 @@ + def main(): ++ x = 1 + return True + +--- a/src/Main.java ++++ b/src/Main.java +@@ -1,3 +1,4 @@ + public class Main { ++ int x = 1; + public static void main(String[] args) {} + +--- a/src/app.js ++++ b/src/app.js +@@ -1,3 +1,4 @@ + function app() { ++ const x = 1; + return true; + +--- a/src/utils.ts ++++ b/src/utils.ts +@@ -1,3 +1,4 @@ + function util() { ++ const y = 2; + return false; + +""" + + class TestGetGitDiffMultiLanguage(unittest.TestCase): @patch("codeflash.code_utils.git_utils.git.Repo") - def test_java_diff_found_when_language_is_java(self, mock_repo_cls): - from codeflash.languages.current import reset_current_language, set_current_language - + def test_java_diff_found_without_singleton(self, mock_repo_cls): repo = mock_repo_cls.return_value repo.head.commit.hexsha = "abc123" repo.working_dir = "/repo" repo.git.diff.return_value = JAVA_ADDITION_DIFF - set_current_language("java") - try: - result = get_git_diff(repo_directory=None, uncommitted_changes=True) - assert len(result) == 1 - key = list(result.keys())[0] - assert str(key).endswith("Fibonacci.java") - assert result[key] == [7, 8] - finally: - reset_current_language() + result = get_git_diff(repo_directory=None, uncommitted_changes=True) + assert len(result) == 1 + key = list(result.keys())[0] + assert str(key).endswith("Fibonacci.java") + assert result[key] == [7, 8] @patch("codeflash.code_utils.git_utils.git.Repo") - def test_java_diff_found_regardless_of_current_language(self, mock_repo_cls): - from codeflash.languages.current import reset_current_language, set_current_language + def test_unsupported_extension_still_filtered(self, mock_repo_cls): + repo = mock_repo_cls.return_value + repo.head.commit.hexsha = "abc123" + repo.working_dir = "/repo" + repo.git.diff.return_value = UNSUPPORTED_LANG_DIFF + + result = get_git_diff(repo_directory=None, uncommitted_changes=True) + assert len(result) == 0 + @patch("codeflash.code_utils.git_utils.git.Repo") + def test_mixed_lang_diff_returns_all_languages(self, mock_repo_cls): repo = mock_repo_cls.return_value repo.head.commit.hexsha = "abc123" repo.working_dir = "/repo" - repo.git.diff.return_value = JAVA_ADDITION_DIFF + repo.git.diff.return_value = MIXED_LANG_DIFF - # get_git_diff uses all registered extensions, not just the current language's - set_current_language("python") - try: - result = get_git_diff(repo_directory=None, uncommitted_changes=True) - assert len(result) == 1 - key = list(result.keys())[0] - assert str(key).endswith("Fibonacci.java") - finally: - reset_current_language() + result = get_git_diff(repo_directory=None, uncommitted_changes=True) + assert len(result) == 2 + keys = [str(k) for k in result.keys()] + assert any(k.endswith("utils.py") for k in keys) + assert any(k.endswith("App.java") for k in keys) @patch("codeflash.code_utils.git_utils.git.Repo") - def test_mixed_lang_diff_returns_all_supported_extensions(self, mock_repo_cls): - from codeflash.languages.current import reset_current_language, set_current_language + def test_js_ts_extensions_found(self, mock_repo_cls): + repo = mock_repo_cls.return_value + repo.head.commit.hexsha = "abc123" + repo.working_dir = "/repo" + repo.git.diff.return_value = JS_TS_DIFF + + result = get_git_diff(repo_directory=None, uncommitted_changes=True) + assert len(result) == 4 + keys = [str(k) for k in result.keys()] + assert any(k.endswith("app.js") for k in keys) + assert any(k.endswith("utils.ts") for k in keys) + assert any(k.endswith("Component.jsx") for k in keys) + assert any(k.endswith("Page.tsx") for k in keys) + @patch("codeflash.code_utils.git_utils.git.Repo") + def test_mixed_all_three_languages(self, mock_repo_cls): repo = mock_repo_cls.return_value repo.head.commit.hexsha = "abc123" repo.working_dir = "/repo" - repo.git.diff.return_value = MIXED_LANG_DIFF + repo.git.diff.return_value = ALL_THREE_LANGS_DIFF - # All supported extensions are returned regardless of current language - set_current_language("python") - try: - result = get_git_diff(repo_directory=None, uncommitted_changes=True) - assert len(result) == 2 - paths = [str(k) for k in result.keys()] - assert any(p.endswith("utils.py") for p in paths) - assert any(p.endswith("App.java") for p in paths) - finally: - reset_current_language() + result = get_git_diff(repo_directory=None, uncommitted_changes=True) + assert len(result) == 4 + keys = [str(k) for k in result.keys()] + assert any(k.endswith("main.py") for k in keys) + assert any(k.endswith("Main.java") for k in keys) + assert any(k.endswith("app.js") for k in keys) + assert any(k.endswith("utils.ts") for k in keys) if __name__ == "__main__": diff --git a/tests/test_languages/fixtures/java_maven/codeflash.toml b/tests/test_languages/fixtures/java_maven/codeflash.toml deleted file mode 100644 index ecd20a562..000000000 --- a/tests/test_languages/fixtures/java_maven/codeflash.toml +++ /dev/null @@ -1,5 +0,0 @@ -# Codeflash configuration for Java project - -[tool.codeflash] -module-root = "src/main/java" -tests-root = "src/test/java" diff --git a/tests/test_languages/fixtures/java_tracer_e2e/codeflash.toml b/tests/test_languages/fixtures/java_tracer_e2e/codeflash.toml deleted file mode 100644 index a501ef8cb..000000000 --- a/tests/test_languages/fixtures/java_tracer_e2e/codeflash.toml +++ /dev/null @@ -1,6 +0,0 @@ -# Codeflash configuration for Java project - -[tool.codeflash] -module-root = "src/main/java" -tests-root = "src/test/java" -language = "java" diff --git a/tests/test_languages/fixtures/java_tracer_e2e/src/main/java/com/example/Workload.java b/tests/test_languages/fixtures/java_tracer_e2e/src/main/java/com/example/Workload.java index 9b6078000..7beb2a4ea 100644 --- a/tests/test_languages/fixtures/java_tracer_e2e/src/main/java/com/example/Workload.java +++ b/tests/test_languages/fixtures/java_tracer_e2e/src/main/java/com/example/Workload.java @@ -36,20 +36,30 @@ public int instanceMethod(int x, int y) { } public static void main(String[] args) { - // Exercise the methods so the tracer can capture invocations - System.out.println("computeSum(100) = " + computeSum(100)); - System.out.println("computeSum(50) = " + computeSum(50)); + // Run methods with large inputs so JFR can capture CPU samples. + // Small inputs finish too fast (<1ms) for JFR's 10ms sampling interval. + for (int round = 0; round < 1000; round++) { + computeSum(100_000); + repeatString("hello world ", 1000); + + List nums = new ArrayList<>(); + for (int i = 1; i <= 10_000; i++) nums.add(i); + filterEvens(nums); + Workload w = new Workload(); + w.instanceMethod(100_000, 42); + } + + // Also call with small inputs for variety in traced args + System.out.println("computeSum(100) = " + computeSum(100)); System.out.println("repeatString(\"ab\", 3) = " + repeatString("ab", 3)); - System.out.println("repeatString(\"x\", 5) = " + repeatString("x", 5)); - List nums = new ArrayList<>(); - for (int i = 1; i <= 10; i++) nums.add(i); - System.out.println("filterEvens(1..10) = " + filterEvens(nums)); + List small = new ArrayList<>(); + for (int i = 1; i <= 10; i++) small.add(i); + System.out.println("filterEvens(1..10) = " + filterEvens(small)); Workload w = new Workload(); System.out.println("instanceMethod(5, 3) = " + w.instanceMethod(5, 3)); - System.out.println("instanceMethod(10, 2) = " + w.instanceMethod(10, 2)); System.out.println("Workload complete."); } diff --git a/tests/test_languages/test_java/test_java_config_detection.py b/tests/test_languages/test_java/test_java_config_detection.py new file mode 100644 index 000000000..ebb8653af --- /dev/null +++ b/tests/test_languages/test_java/test_java_config_detection.py @@ -0,0 +1,444 @@ +"""Tests for Java project auto-detection from Maven/Gradle build files. + +Tests that codeflash can detect Java projects and infer module-root, +tests-root, and other config from pom.xml / build.gradle / gradle.properties +without requiring a standalone codeflash.toml config file. +""" + +from __future__ import annotations + +from pathlib import Path + +import pytest + +from codeflash.languages.java.build_tools import ( + BuildTool, + detect_build_tool, + find_source_root, + find_test_root, + parse_java_project_config, +) + + +# --------------------------------------------------------------------------- +# Build tool detection +# --------------------------------------------------------------------------- + + +class TestDetectBuildTool: + def test_detect_maven(self, tmp_path: Path) -> None: + (tmp_path / "pom.xml").write_text("", encoding="utf-8") + assert detect_build_tool(tmp_path) == BuildTool.MAVEN + + def test_detect_gradle(self, tmp_path: Path) -> None: + (tmp_path / "build.gradle").write_text("", encoding="utf-8") + assert detect_build_tool(tmp_path) == BuildTool.GRADLE + + def test_detect_gradle_kts(self, tmp_path: Path) -> None: + (tmp_path / "build.gradle.kts").write_text("", encoding="utf-8") + assert detect_build_tool(tmp_path) == BuildTool.GRADLE + + def test_maven_takes_priority_over_gradle(self, tmp_path: Path) -> None: + (tmp_path / "pom.xml").write_text("", encoding="utf-8") + (tmp_path / "build.gradle").write_text("", encoding="utf-8") + assert detect_build_tool(tmp_path) == BuildTool.MAVEN + + def test_unknown_when_no_build_file(self, tmp_path: Path) -> None: + assert detect_build_tool(tmp_path) == BuildTool.UNKNOWN + + def test_detect_maven_in_parent(self, tmp_path: Path) -> None: + (tmp_path / "pom.xml").write_text("", encoding="utf-8") + child = tmp_path / "module" + child.mkdir() + assert detect_build_tool(child) == BuildTool.MAVEN + + +# --------------------------------------------------------------------------- +# Source / test root detection (standard layouts) +# --------------------------------------------------------------------------- + + +class TestFindSourceRoot: + def test_standard_maven_layout(self, tmp_path: Path) -> None: + (tmp_path / "pom.xml").write_text("", encoding="utf-8") + src = tmp_path / "src" / "main" / "java" + src.mkdir(parents=True) + assert find_source_root(tmp_path) == src + + def test_fallback_to_src_with_java_files(self, tmp_path: Path) -> None: + src = tmp_path / "src" + src.mkdir() + (src / "App.java").write_text("class App {}", encoding="utf-8") + assert find_source_root(tmp_path) == src + + def test_returns_none_when_no_source(self, tmp_path: Path) -> None: + assert find_source_root(tmp_path) is None + + +class TestFindTestRoot: + def test_standard_maven_layout(self, tmp_path: Path) -> None: + (tmp_path / "pom.xml").write_text("", encoding="utf-8") + test = tmp_path / "src" / "test" / "java" + test.mkdir(parents=True) + assert find_test_root(tmp_path) == test + + def test_fallback_to_test_dir(self, tmp_path: Path) -> None: + test = tmp_path / "test" + test.mkdir() + assert find_test_root(tmp_path) == test + + def test_fallback_to_tests_dir(self, tmp_path: Path) -> None: + tests = tmp_path / "tests" + tests.mkdir() + assert find_test_root(tmp_path) == tests + + def test_returns_none_when_no_test_dir(self, tmp_path: Path) -> None: + assert find_test_root(tmp_path) is None + + +# --------------------------------------------------------------------------- +# parse_java_project_config — standard layouts +# --------------------------------------------------------------------------- + + +class TestParseJavaProjectConfigStandard: + def test_standard_maven_project(self, tmp_path: Path) -> None: + (tmp_path / "pom.xml").write_text("", encoding="utf-8") + src = tmp_path / "src" / "main" / "java" + src.mkdir(parents=True) + test = tmp_path / "src" / "test" / "java" + test.mkdir(parents=True) + + config = parse_java_project_config(tmp_path) + assert config is not None + assert config["language"] == "java" + assert config["module_root"] == str(src) + assert config["tests_root"] == str(test) + + def test_standard_gradle_project(self, tmp_path: Path) -> None: + (tmp_path / "build.gradle").write_text("", encoding="utf-8") + src = tmp_path / "src" / "main" / "java" + src.mkdir(parents=True) + test = tmp_path / "src" / "test" / "java" + test.mkdir(parents=True) + + config = parse_java_project_config(tmp_path) + assert config is not None + assert config["language"] == "java" + assert config["module_root"] == str(src) + assert config["tests_root"] == str(test) + + def test_returns_none_for_non_java_project(self, tmp_path: Path) -> None: + assert parse_java_project_config(tmp_path) is None + + def test_defaults_when_dirs_missing(self, tmp_path: Path) -> None: + (tmp_path / "pom.xml").write_text("", encoding="utf-8") + config = parse_java_project_config(tmp_path) + assert config is not None + # Falls back to default paths even if they don't exist + assert str(tmp_path / "src" / "main" / "java") == config["module_root"] + assert config["language"] == "java" + + +# --------------------------------------------------------------------------- +# parse_java_project_config — Maven properties (codeflash.*) +# --------------------------------------------------------------------------- + +MAVEN_POM_WITH_PROPERTIES = """\ + + 4.0.0 + com.example + test + 1.0 + + custom/src + custom/test + true + upstream + gen/,build/ + + +""" + + +class TestMavenCodeflashProperties: + def test_reads_custom_properties(self, tmp_path: Path) -> None: + (tmp_path / "pom.xml").write_text(MAVEN_POM_WITH_PROPERTIES, encoding="utf-8") + (tmp_path / "custom" / "src").mkdir(parents=True) + (tmp_path / "custom" / "test").mkdir(parents=True) + + config = parse_java_project_config(tmp_path) + assert config is not None + assert config["module_root"] == str((tmp_path / "custom" / "src").resolve()) + assert config["tests_root"] == str((tmp_path / "custom" / "test").resolve()) + assert config["disable_telemetry"] is True + assert config["git_remote"] == "upstream" + assert len(config["ignore_paths"]) == 2 + + def test_properties_override_auto_detection(self, tmp_path: Path) -> None: + (tmp_path / "pom.xml").write_text(MAVEN_POM_WITH_PROPERTIES, encoding="utf-8") + # Create standard dirs AND custom dirs + (tmp_path / "src" / "main" / "java").mkdir(parents=True) + (tmp_path / "custom" / "src").mkdir(parents=True) + (tmp_path / "custom" / "test").mkdir(parents=True) + + config = parse_java_project_config(tmp_path) + assert config is not None + # Should use custom paths from properties, not auto-detected standard paths + assert config["module_root"] == str((tmp_path / "custom" / "src").resolve()) + + def test_no_properties_uses_defaults(self, tmp_path: Path) -> None: + (tmp_path / "pom.xml").write_text( + '4.0.0', + encoding="utf-8", + ) + (tmp_path / "src" / "main" / "java").mkdir(parents=True) + + config = parse_java_project_config(tmp_path) + assert config is not None + assert config["disable_telemetry"] is False + assert config["git_remote"] == "origin" + + +# --------------------------------------------------------------------------- +# parse_java_project_config — Gradle properties +# --------------------------------------------------------------------------- + + +class TestGradleCodeflashProperties: + def test_reads_gradle_properties(self, tmp_path: Path) -> None: + (tmp_path / "build.gradle").write_text("", encoding="utf-8") + (tmp_path / "gradle.properties").write_text( + "codeflash.moduleRoot=lib/src\ncodeflash.testsRoot=lib/test\ncodeflash.disableTelemetry=true\n", + encoding="utf-8", + ) + (tmp_path / "lib" / "src").mkdir(parents=True) + (tmp_path / "lib" / "test").mkdir(parents=True) + + config = parse_java_project_config(tmp_path) + assert config is not None + assert config["module_root"] == str((tmp_path / "lib" / "src").resolve()) + assert config["tests_root"] == str((tmp_path / "lib" / "test").resolve()) + assert config["disable_telemetry"] is True + + def test_ignores_non_codeflash_properties(self, tmp_path: Path) -> None: + (tmp_path / "build.gradle").write_text("", encoding="utf-8") + (tmp_path / "gradle.properties").write_text( + "org.gradle.jvmargs=-Xmx2g\ncodeflash.gitRemote=upstream\n", + encoding="utf-8", + ) + + config = parse_java_project_config(tmp_path) + assert config is not None + assert config["git_remote"] == "upstream" + + def test_no_gradle_properties_uses_defaults(self, tmp_path: Path) -> None: + (tmp_path / "build.gradle").write_text("", encoding="utf-8") + (tmp_path / "src" / "main" / "java").mkdir(parents=True) + (tmp_path / "src" / "test" / "java").mkdir(parents=True) + + config = parse_java_project_config(tmp_path) + assert config is not None + assert config["git_remote"] == "origin" + assert config["disable_telemetry"] is False + + +# --------------------------------------------------------------------------- +# Multi-module Maven projects +# --------------------------------------------------------------------------- + +PARENT_POM = """\ + + 4.0.0 + com.example + parent + 1.0 + pom + + client + test + examples + + +""" + +CLIENT_POM = """\ + + 4.0.0 + + com.example + parent + 1.0 + + client + + ${project.basedir}/src + + +""" + +TEST_POM = """\ + + 4.0.0 + + com.example + parent + 1.0 + + test + + ${project.basedir}/src + + +""" + +EXAMPLES_POM = """\ + + 4.0.0 + + com.example + parent + 1.0 + + examples + + ${project.basedir}/src + + +""" + + +class TestMultiModuleMaven: + @pytest.fixture + def multi_module_project(self, tmp_path: Path) -> Path: + """Create a multi-module Maven project mimicking aerospike's layout.""" + (tmp_path / "pom.xml").write_text(PARENT_POM, encoding="utf-8") + + # Client module — main library with the most Java files + client = tmp_path / "client" + client.mkdir() + (client / "pom.xml").write_text(CLIENT_POM, encoding="utf-8") + client_src = client / "src" / "com" / "example" / "client" + client_src.mkdir(parents=True) + for i in range(10): + (client_src / f"Class{i}.java").write_text(f"class Class{i} {{}}", encoding="utf-8") + + # Test module — test code + test = tmp_path / "test" + test.mkdir() + (test / "pom.xml").write_text(TEST_POM, encoding="utf-8") + test_src = test / "src" / "com" / "example" / "test" + test_src.mkdir(parents=True) + (test_src / "ClientTest.java").write_text("class ClientTest {}", encoding="utf-8") + + # Examples module — should be skipped + examples = tmp_path / "examples" + examples.mkdir() + (examples / "pom.xml").write_text(EXAMPLES_POM, encoding="utf-8") + examples_src = examples / "src" / "com" / "example" + examples_src.mkdir(parents=True) + (examples_src / "Example.java").write_text("class Example {}", encoding="utf-8") + + return tmp_path + + def test_detects_client_as_source_root(self, multi_module_project: Path) -> None: + config = parse_java_project_config(multi_module_project) + assert config is not None + assert config["module_root"] == str(multi_module_project / "client" / "src") + + def test_detects_test_module_as_test_root(self, multi_module_project: Path) -> None: + config = parse_java_project_config(multi_module_project) + assert config is not None + assert config["tests_root"] == str(multi_module_project / "test" / "src") + + def test_skips_examples_module(self, multi_module_project: Path) -> None: + config = parse_java_project_config(multi_module_project) + assert config is not None + # The module_root should be client/src, not examples/src + assert config["module_root"] == str(multi_module_project / "client" / "src") + + def test_picks_module_with_most_java_files(self, multi_module_project: Path) -> None: + """Client has 10 .java files, examples has 1 — client should win.""" + config = parse_java_project_config(multi_module_project) + assert config is not None + assert "client" in config["module_root"] + + +# --------------------------------------------------------------------------- +# Language detection from config_parser +# --------------------------------------------------------------------------- + + +class TestLanguageDetectionViaConfigParser: + def test_java_detected_from_pom_xml(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: + (tmp_path / "pom.xml").write_text("", encoding="utf-8") + (tmp_path / "src" / "main" / "java").mkdir(parents=True) + (tmp_path / "src" / "test" / "java").mkdir(parents=True) + monkeypatch.chdir(tmp_path) + + from codeflash.code_utils.config_parser import _try_parse_java_build_config + + result = _try_parse_java_build_config() + assert result is not None + config, project_root = result + assert config["language"] == "java" + assert project_root == tmp_path + + def test_java_detected_from_build_gradle(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: + (tmp_path / "build.gradle").write_text("", encoding="utf-8") + (tmp_path / "src" / "main" / "java").mkdir(parents=True) + monkeypatch.chdir(tmp_path) + + from codeflash.code_utils.config_parser import _try_parse_java_build_config + + result = _try_parse_java_build_config() + assert result is not None + config, _ = result + assert config["language"] == "java" + + def test_no_java_detected_for_python_project(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: + (tmp_path / "pyproject.toml").write_text("[tool.codeflash]\nmodule-root='src'\ntests-root='tests'\n", encoding="utf-8") + monkeypatch.chdir(tmp_path) + + from codeflash.code_utils.config_parser import _try_parse_java_build_config + + result = _try_parse_java_build_config() + assert result is None + + +# --------------------------------------------------------------------------- +# Language detection from tracer +# --------------------------------------------------------------------------- + + +class TestTracerLanguageDetection: + def test_detects_java_from_build_files(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: + (tmp_path / "pom.xml").write_text("", encoding="utf-8") + monkeypatch.chdir(tmp_path) + + from codeflash.languages.base import Language + from codeflash.tracer import _detect_non_python_language + + result = _detect_non_python_language(None) + assert result == Language.JAVA + + def test_no_detection_without_build_files(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: + monkeypatch.chdir(tmp_path) + + from codeflash.tracer import _detect_non_python_language + + result = _detect_non_python_language(None) + assert result is None + + def test_detects_java_from_file_extension(self, tmp_path: Path) -> None: + java_file = tmp_path / "App.java" + java_file.write_text("class App {}", encoding="utf-8") + + from argparse import Namespace + + from codeflash.languages.base import Language + from codeflash.tracer import _detect_non_python_language + + args = Namespace(file=str(java_file)) + result = _detect_non_python_language(args) + assert result == Language.JAVA diff --git a/tests/test_languages/test_java/test_jfr_parser.py b/tests/test_languages/test_java/test_jfr_parser.py new file mode 100644 index 000000000..8b5cf8a6e --- /dev/null +++ b/tests/test_languages/test_java/test_jfr_parser.py @@ -0,0 +1,302 @@ +"""Tests for JFR parser — class name normalization, package filtering, addressable time.""" + +from __future__ import annotations + +import json +import subprocess +from pathlib import Path +from unittest.mock import patch + +import pytest + +from codeflash.languages.java.jfr_parser import JfrProfile + + +def _make_jfr_json(events: list[dict]) -> str: + """Create fake JFR JSON output matching the jfr print format.""" + return json.dumps({"recording": {"events": events}}) + + +def _make_execution_sample(class_name: str, method_name: str, start_time: str = "2026-01-01T00:00:00Z") -> dict: + return { + "type": "jdk.ExecutionSample", + "values": { + "startTime": start_time, + "stackTrace": { + "frames": [ + { + "method": { + "type": {"name": class_name}, + "name": method_name, + "descriptor": "()V", + }, + "lineNumber": 42, + } + ], + }, + }, + } + + +class TestClassNameNormalization: + """Test that JVM internal class names (com/example/Foo) are normalized to dots (com.example.Foo).""" + + def test_slash_separators_normalized_to_dots(self, tmp_path: Path) -> None: + jfr_file = tmp_path / "test.jfr" + jfr_file.write_text("dummy", encoding="utf-8") + + jfr_json = _make_jfr_json( + [ + _make_execution_sample("com/aerospike/client/command/Buffer", "bytesToInt"), + _make_execution_sample("com/aerospike/client/command/Buffer", "bytesToInt"), + _make_execution_sample("com/aerospike/client/util/Utf8", "encodedLength"), + ] + ) + + with patch("subprocess.run") as mock_run: + mock_run.return_value = subprocess.CompletedProcess(args=[], returncode=0, stdout=jfr_json, stderr="") + profile = JfrProfile(jfr_file, ["com.aerospike"]) + + assert profile._total_samples == 3 + assert len(profile._method_samples) == 2 + + # Keys should use dots, not slashes + assert "com.aerospike.client.command.Buffer.bytesToInt" in profile._method_samples + assert "com.aerospike.client.util.Utf8.encodedLength" in profile._method_samples + + def test_method_info_uses_dot_class_names(self, tmp_path: Path) -> None: + jfr_file = tmp_path / "test.jfr" + jfr_file.write_text("dummy", encoding="utf-8") + + jfr_json = _make_jfr_json( + [_make_execution_sample("com/example/MyClass", "myMethod")] + ) + + with patch("subprocess.run") as mock_run: + mock_run.return_value = subprocess.CompletedProcess(args=[], returncode=0, stdout=jfr_json, stderr="") + profile = JfrProfile(jfr_file, ["com.example"]) + + info = profile._method_info.get("com.example.MyClass.myMethod") + assert info is not None + assert info["class_name"] == "com.example.MyClass" + assert info["method_name"] == "myMethod" + + +class TestPackageFiltering: + def test_filters_by_package_prefix(self, tmp_path: Path) -> None: + jfr_file = tmp_path / "test.jfr" + jfr_file.write_text("dummy", encoding="utf-8") + + jfr_json = _make_jfr_json( + [ + _make_execution_sample("com/aerospike/client/Value", "get"), + _make_execution_sample("java/util/HashMap", "put"), + _make_execution_sample("com/aerospike/benchmarks/Main", "main"), + ] + ) + + with patch("subprocess.run") as mock_run: + mock_run.return_value = subprocess.CompletedProcess(args=[], returncode=0, stdout=jfr_json, stderr="") + profile = JfrProfile(jfr_file, ["com.aerospike"]) + + # Only com.aerospike classes should be in samples + assert len(profile._method_samples) == 2 + assert "com.aerospike.client.Value.get" in profile._method_samples + assert "com.aerospike.benchmarks.Main.main" in profile._method_samples + assert "java.util.HashMap.put" not in profile._method_samples + + def test_empty_packages_includes_all(self, tmp_path: Path) -> None: + jfr_file = tmp_path / "test.jfr" + jfr_file.write_text("dummy", encoding="utf-8") + + jfr_json = _make_jfr_json( + [ + _make_execution_sample("com/example/Foo", "bar"), + _make_execution_sample("java/lang/String", "length"), + ] + ) + + with patch("subprocess.run") as mock_run: + mock_run.return_value = subprocess.CompletedProcess(args=[], returncode=0, stdout=jfr_json, stderr="") + profile = JfrProfile(jfr_file, []) + + assert len(profile._method_samples) == 2 + + +class TestAddressableTime: + def test_addressable_time_proportional_to_samples(self, tmp_path: Path) -> None: + jfr_file = tmp_path / "test.jfr" + jfr_file.write_text("dummy", encoding="utf-8") + + # 3 samples for methodA, 1 for methodB, spanning 10 seconds + jfr_json = _make_jfr_json( + [ + _make_execution_sample("com/example/Foo", "methodA", "2026-01-01T00:00:00Z"), + _make_execution_sample("com/example/Foo", "methodA", "2026-01-01T00:00:03Z"), + _make_execution_sample("com/example/Foo", "methodA", "2026-01-01T00:00:06Z"), + _make_execution_sample("com/example/Foo", "methodB", "2026-01-01T00:00:10Z"), + ] + ) + + with patch("subprocess.run") as mock_run: + mock_run.return_value = subprocess.CompletedProcess(args=[], returncode=0, stdout=jfr_json, stderr="") + profile = JfrProfile(jfr_file, ["com.example"]) + + time_a = profile.get_addressable_time_ns("com.example.Foo", "methodA") + time_b = profile.get_addressable_time_ns("com.example.Foo", "methodB") + + # methodA has 3x the samples of methodB, so 3x the addressable time + assert time_a > 0 + assert time_b > 0 + assert time_a == pytest.approx(time_b * 3, rel=0.01) + + def test_addressable_time_zero_for_unknown_method(self, tmp_path: Path) -> None: + jfr_file = tmp_path / "test.jfr" + jfr_file.write_text("dummy", encoding="utf-8") + + jfr_json = _make_jfr_json( + [_make_execution_sample("com/example/Foo", "bar")] + ) + + with patch("subprocess.run") as mock_run: + mock_run.return_value = subprocess.CompletedProcess(args=[], returncode=0, stdout=jfr_json, stderr="") + profile = JfrProfile(jfr_file, ["com.example"]) + + assert profile.get_addressable_time_ns("com.example.Foo", "nonExistent") == 0.0 + + +class TestMethodRanking: + def test_ranking_ordered_by_sample_count(self, tmp_path: Path) -> None: + jfr_file = tmp_path / "test.jfr" + jfr_file.write_text("dummy", encoding="utf-8") + + jfr_json = _make_jfr_json( + [ + _make_execution_sample("com/example/A", "hot"), + _make_execution_sample("com/example/A", "hot"), + _make_execution_sample("com/example/A", "hot"), + _make_execution_sample("com/example/B", "warm"), + _make_execution_sample("com/example/B", "warm"), + _make_execution_sample("com/example/C", "cold"), + ] + ) + + with patch("subprocess.run") as mock_run: + mock_run.return_value = subprocess.CompletedProcess(args=[], returncode=0, stdout=jfr_json, stderr="") + profile = JfrProfile(jfr_file, ["com.example"]) + + ranking = profile.get_method_ranking() + assert len(ranking) == 3 + assert ranking[0]["method_name"] == "hot" + assert ranking[0]["sample_count"] == 3 + assert ranking[1]["method_name"] == "warm" + assert ranking[1]["sample_count"] == 2 + assert ranking[2]["method_name"] == "cold" + assert ranking[2]["sample_count"] == 1 + + def test_empty_ranking_when_no_samples(self, tmp_path: Path) -> None: + jfr_file = tmp_path / "test.jfr" + jfr_file.write_text("dummy", encoding="utf-8") + + jfr_json = _make_jfr_json([]) + + with patch("subprocess.run") as mock_run: + mock_run.return_value = subprocess.CompletedProcess(args=[], returncode=0, stdout=jfr_json, stderr="") + profile = JfrProfile(jfr_file, ["com.example"]) + + assert profile.get_method_ranking() == [] + + def test_ranking_uses_dot_class_names(self, tmp_path: Path) -> None: + jfr_file = tmp_path / "test.jfr" + jfr_file.write_text("dummy", encoding="utf-8") + + jfr_json = _make_jfr_json( + [_make_execution_sample("com/example/nested/Deep", "method")] + ) + + with patch("subprocess.run") as mock_run: + mock_run.return_value = subprocess.CompletedProcess(args=[], returncode=0, stdout=jfr_json, stderr="") + profile = JfrProfile(jfr_file, ["com.example"]) + + ranking = profile.get_method_ranking() + assert len(ranking) == 1 + assert ranking[0]["class_name"] == "com.example.nested.Deep" + + +class TestGracefulTimeout: + """Test that _run_java_with_graceful_timeout sends SIGTERM before SIGKILL.""" + + def test_sends_sigterm_on_timeout(self) -> None: + import signal + + from codeflash.languages.java.tracer import _run_java_with_graceful_timeout + + # Run a sleep command with a 1s timeout — should get SIGTERM'd + import os + + env = os.environ.copy() + _run_java_with_graceful_timeout(["sleep", "60"], env, timeout=1, stage_name="test") + # If we get here, the process was killed (didn't hang for 60s) + + def test_no_timeout_runs_normally(self) -> None: + import os + + from codeflash.languages.java.tracer import _run_java_with_graceful_timeout + + env = os.environ.copy() + _run_java_with_graceful_timeout(["echo", "hello"], env, timeout=0, stage_name="test") + # Should complete without error + + +class TestProjectRootResolution: + """Test that project_root is correctly set for Java multi-module projects.""" + + def test_java_project_root_is_build_root_not_module(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: + """For multi-module Maven, project_root should be the root with , not a sub-module.""" + # Create a multi-module project + (tmp_path / "pom.xml").write_text( + 'client', + encoding="utf-8", + ) + client = tmp_path / "client" + client.mkdir() + (client / "pom.xml").write_text("", encoding="utf-8") + src = client / "src" / "main" / "java" + src.mkdir(parents=True) + test = tmp_path / "src" / "test" / "java" + test.mkdir(parents=True) + monkeypatch.chdir(tmp_path) + + from codeflash.code_utils.config_parser import parse_config_file + + config, config_path = parse_config_file() + assert config["language"] == "java" + + # config_path should be the project root directory + assert config_path == tmp_path + + def test_project_root_is_path_not_string(self, tmp_path: Path, monkeypatch: pytest.MonkeyPatch) -> None: + """project_root from process_pyproject_config should be a Path for Java projects.""" + from argparse import Namespace + + (tmp_path / "pom.xml").write_text("", encoding="utf-8") + src = tmp_path / "src" / "main" / "java" + src.mkdir(parents=True) + test = tmp_path / "src" / "test" / "java" + test.mkdir(parents=True) + monkeypatch.chdir(tmp_path) + + from codeflash.cli_cmds.cli import process_pyproject_config + + # Create a minimal args namespace matching what parse_args produces + args = Namespace( + config_file=None, module_root=None, tests_root=None, benchmarks_root=None, + ignore_paths=None, pytest_cmd=None, formatter_cmds=None, disable_telemetry=None, + disable_imports_sorting=None, git_remote=None, override_fixtures=None, + benchmark=False, verbose=False, version=False, show_config=False, reset_config=False, + ) + args = process_pyproject_config(args) + + assert hasattr(args, "project_root") + assert isinstance(args.project_root, Path) + assert args.project_root == tmp_path diff --git a/tests/test_languages/test_java/test_replay_test_generation.py b/tests/test_languages/test_java/test_replay_test_generation.py new file mode 100644 index 000000000..da7138114 --- /dev/null +++ b/tests/test_languages/test_java/test_replay_test_generation.py @@ -0,0 +1,255 @@ +"""Tests for Java replay test generation — JUnit 4/5 support, overload handling, instrumentation skip.""" + +from __future__ import annotations + +import sqlite3 +from pathlib import Path + +import pytest + +from codeflash.languages.java.replay_test import generate_replay_tests, parse_replay_test_metadata + + +@pytest.fixture +def trace_db(tmp_path: Path) -> Path: + """Create a trace database with sample function calls.""" + db_path = tmp_path / "trace.db" + conn = sqlite3.connect(str(db_path)) + conn.execute( + "CREATE TABLE function_calls(" + "type TEXT, function TEXT, classname TEXT, filename TEXT, " + "line_number INTEGER, descriptor TEXT, time_ns INTEGER, args BLOB)" + ) + conn.execute("CREATE TABLE metadata(key TEXT PRIMARY KEY, value TEXT)") + conn.execute( + "INSERT INTO function_calls VALUES (?, ?, ?, ?, ?, ?, ?, ?)", + ("call", "add", "com.example.Calculator", "Calculator.java", 10, "(II)I", 1000, b"\x00"), + ) + conn.execute( + "INSERT INTO function_calls VALUES (?, ?, ?, ?, ?, ?, ?, ?)", + ("call", "add", "com.example.Calculator", "Calculator.java", 10, "(II)I", 2000, b"\x00"), + ) + conn.execute( + "INSERT INTO function_calls VALUES (?, ?, ?, ?, ?, ?, ?, ?)", + ("call", "multiply", "com.example.Calculator", "Calculator.java", 20, "(II)I", 3000, b"\x00"), + ) + conn.commit() + conn.close() + return db_path + + +@pytest.fixture +def trace_db_overloaded(tmp_path: Path) -> Path: + """Create a trace database with overloaded methods (same name, different descriptors).""" + db_path = tmp_path / "trace_overloaded.db" + conn = sqlite3.connect(str(db_path)) + conn.execute( + "CREATE TABLE function_calls(" + "type TEXT, function TEXT, classname TEXT, filename TEXT, " + "line_number INTEGER, descriptor TEXT, time_ns INTEGER, args BLOB)" + ) + conn.execute("CREATE TABLE metadata(key TEXT PRIMARY KEY, value TEXT)") + # Two overloads of estimateKeySize with different descriptors + for i in range(3): + conn.execute( + "INSERT INTO function_calls VALUES (?, ?, ?, ?, ?, ?, ?, ?)", + ("call", "estimateKeySize", "com.example.Command", "Command.java", 10, "(I)I", i * 1000, b"\x00"), + ) + for i in range(2): + conn.execute( + "INSERT INTO function_calls VALUES (?, ?, ?, ?, ?, ?, ?, ?)", + ( + "call", + "estimateKeySize", + "com.example.Command", + "Command.java", + 15, + "(Ljava/lang/String;)I", + (i + 10) * 1000, + b"\x00", + ), + ) + conn.commit() + conn.close() + return db_path + + +class TestGenerateReplayTestsJunit5: + def test_generates_junit5_by_default(self, trace_db: Path, tmp_path: Path) -> None: + output_dir = tmp_path / "output" + count = generate_replay_tests(trace_db, output_dir, tmp_path) + assert count == 1 + + test_file = list(output_dir.glob("*.java"))[0] + content = test_file.read_text(encoding="utf-8") + assert "import org.junit.jupiter.api.Test;" in content + assert "import org.junit.jupiter.api.AfterAll;" in content + assert "@Test void replay_add_0()" in content + + def test_junit5_class_is_package_private(self, trace_db: Path, tmp_path: Path) -> None: + output_dir = tmp_path / "output" + generate_replay_tests(trace_db, output_dir, tmp_path) + + test_file = list(output_dir.glob("*.java"))[0] + content = test_file.read_text(encoding="utf-8") + assert "class ReplayTest_" in content + assert "public class ReplayTest_" not in content + + +class TestGenerateReplayTestsJunit4: + def test_generates_junit4_imports(self, trace_db: Path, tmp_path: Path) -> None: + output_dir = tmp_path / "output" + count = generate_replay_tests(trace_db, output_dir, tmp_path, test_framework="junit4") + assert count == 1 + + test_file = list(output_dir.glob("*.java"))[0] + content = test_file.read_text(encoding="utf-8") + assert "import org.junit.Test;" in content + assert "import org.junit.AfterClass;" in content + assert "org.junit.jupiter" not in content + + def test_junit4_methods_are_public(self, trace_db: Path, tmp_path: Path) -> None: + output_dir = tmp_path / "output" + generate_replay_tests(trace_db, output_dir, tmp_path, test_framework="junit4") + + test_file = list(output_dir.glob("*.java"))[0] + content = test_file.read_text(encoding="utf-8") + assert "@Test public void replay_add_0()" in content + + def test_junit4_class_is_public(self, trace_db: Path, tmp_path: Path) -> None: + output_dir = tmp_path / "output" + generate_replay_tests(trace_db, output_dir, tmp_path, test_framework="junit4") + + test_file = list(output_dir.glob("*.java"))[0] + content = test_file.read_text(encoding="utf-8") + assert "public class ReplayTest_" in content + + def test_junit4_cleanup_uses_afterclass(self, trace_db: Path, tmp_path: Path) -> None: + output_dir = tmp_path / "output" + generate_replay_tests(trace_db, output_dir, tmp_path, test_framework="junit4") + + test_file = list(output_dir.glob("*.java"))[0] + content = test_file.read_text(encoding="utf-8") + assert "@AfterClass" in content + assert "@AfterAll" not in content + + +class TestOverloadedMethods: + def test_no_duplicate_method_names(self, trace_db_overloaded: Path, tmp_path: Path) -> None: + output_dir = tmp_path / "output" + count = generate_replay_tests(trace_db_overloaded, output_dir, tmp_path) + assert count == 1 + + test_file = list(output_dir.glob("*.java"))[0] + content = test_file.read_text(encoding="utf-8") + + # Should have 5 unique methods (3 from first overload + 2 from second) + assert "replay_estimateKeySize_0" in content + assert "replay_estimateKeySize_1" in content + assert "replay_estimateKeySize_2" in content + assert "replay_estimateKeySize_3" in content + assert "replay_estimateKeySize_4" in content + + # Verify no duplicates by counting occurrences + lines = content.splitlines() + method_lines = [l for l in lines if "void replay_estimateKeySize_" in l] + method_names = [l.split("void ")[1].split("(")[0] for l in method_lines] + assert len(method_names) == len(set(method_names)), f"Duplicate methods: {method_names}" + + +class TestReplayTestInstrumentation: + def test_replay_tests_instrumented_correctly(self, trace_db: Path, tmp_path: Path) -> None: + """Replay tests with compact @Test lines should be instrumented without orphaned code.""" + from codeflash.languages.java.discovery import discover_functions_from_source + + output_dir = tmp_path / "output" + generate_replay_tests(trace_db, output_dir, tmp_path) + + test_file = list(output_dir.glob("*.java"))[0] + + src = "public class Calculator { public int add(int a, int b) { return a + b; } }" + funcs = discover_functions_from_source(src, tmp_path / "Calculator.java") + target = funcs[0] + + from codeflash.languages.java.support import JavaSupport + + support = JavaSupport() + success, instrumented = support.instrument_existing_test( + test_path=test_file, + call_positions=[], + function_to_optimize=target, + tests_project_root=tmp_path, + mode="behavior", + ) + assert success + assert instrumented is not None + assert "__perfinstrumented" in instrumented + + # Verify no code outside class body + lines = instrumented.splitlines() + class_closed = False + for line in lines: + if line.strip() == "}" and not line.startswith(" "): + class_closed = True + elif class_closed and line.strip() and not line.strip().startswith("//"): + pytest.fail(f"Orphaned code outside class: {line}") + + def test_replay_tests_perf_instrumented(self, trace_db: Path, tmp_path: Path) -> None: + from codeflash.languages.java.discovery import discover_functions_from_source + + output_dir = tmp_path / "output" + generate_replay_tests(trace_db, output_dir, tmp_path) + + test_file = list(output_dir.glob("*.java"))[0] + + src = "public class Calculator { public int add(int a, int b) { return a + b; } }" + funcs = discover_functions_from_source(src, tmp_path / "Calculator.java") + target = funcs[0] + + from codeflash.languages.java.support import JavaSupport + + support = JavaSupport() + success, instrumented = support.instrument_existing_test( + test_path=test_file, + call_positions=[], + function_to_optimize=target, + tests_project_root=tmp_path, + mode="performance", + ) + assert success + assert "__perfonlyinstrumented" in instrumented + + def test_regular_tests_still_instrumented(self, tmp_path: Path) -> None: + from codeflash.languages.java.discovery import discover_functions_from_source + + src = "public class Calculator { public int add(int a, int b) { return a + b; } }" + funcs = discover_functions_from_source(src, tmp_path / "Calculator.java") + target = funcs[0] + + test_file = tmp_path / "CalculatorTest.java" + test_file.write_text( + """ +import org.junit.jupiter.api.Test; +public class CalculatorTest { + @Test + public void testAdd() { + Calculator calc = new Calculator(); + calc.add(1, 2); + } +} +""", + encoding="utf-8", + ) + + from codeflash.languages.java.support import JavaSupport + + support = JavaSupport() + success, instrumented = support.instrument_existing_test( + test_path=test_file, + call_positions=[], + function_to_optimize=target, + tests_project_root=tmp_path, + mode="behavior", + ) + assert success + assert "CODEFLASH_LOOP_INDEX" in instrumented diff --git a/tests/test_languages/test_registry.py b/tests/test_languages/test_registry.py index cdb44e1af..417a4a62e 100644 --- a/tests/test_languages/test_registry.py +++ b/tests/test_languages/test_registry.py @@ -272,6 +272,7 @@ def test_clear_registry_removes_everything(self): assert not is_language_supported(Language.PYTHON) # Re-register all languages by importing + from codeflash.languages.java.support import JavaSupport from codeflash.languages.javascript.support import JavaScriptSupport, TypeScriptSupport from codeflash.languages.python.support import PythonSupport @@ -279,6 +280,7 @@ def test_clear_registry_removes_everything(self): register_language(PythonSupport) register_language(JavaScriptSupport) register_language(TypeScriptSupport) + register_language(JavaSupport) # Should be supported again assert is_language_supported(Language.PYTHON) diff --git a/tests/test_multi_config_discovery.py b/tests/test_multi_config_discovery.py new file mode 100644 index 000000000..90cc7eca3 --- /dev/null +++ b/tests/test_multi_config_discovery.py @@ -0,0 +1,211 @@ +from __future__ import annotations + +import json +from pathlib import Path +from unittest.mock import patch + +import tomlkit + +from codeflash.code_utils.config_parser import find_all_config_files +from codeflash.languages.language_enum import Language + + +def write_toml(path: Path, data: dict) -> None: + path.write_text(tomlkit.dumps(data), encoding="utf-8") + + +class TestFindAllConfigFiles: + def test_finds_pyproject_toml_with_codeflash_section(self, tmp_path: Path, monkeypatch) -> None: + write_toml(tmp_path / "pyproject.toml", {"tool": {"codeflash": {"module-root": "src"}}}) + monkeypatch.chdir(tmp_path) + result = find_all_config_files() + assert len(result) == 1 + assert result[0].language == Language.PYTHON + assert result[0].config_path == tmp_path / "pyproject.toml" + + def test_finds_java_via_build_tool_detection(self, tmp_path: Path, monkeypatch) -> None: + java_config = {"language": "java", "module_root": str(tmp_path / "src/main/java")} + (tmp_path / "pom.xml").write_text("", encoding="utf-8") + monkeypatch.chdir(tmp_path) + with patch( + "codeflash.code_utils.config_parser._parse_java_config_for_dir", + return_value=java_config, + ): + result = find_all_config_files() + assert len(result) == 1 + assert result[0].language == Language.JAVA + assert result[0].config_path == tmp_path + + def test_finds_multiple_configs_python_and_java(self, tmp_path: Path, monkeypatch) -> None: + write_toml(tmp_path / "pyproject.toml", {"tool": {"codeflash": {"module-root": "src"}}}) + java_config = {"language": "java", "module_root": str(tmp_path / "src/main/java")} + (tmp_path / "pom.xml").write_text("", encoding="utf-8") + monkeypatch.chdir(tmp_path) + with patch( + "codeflash.code_utils.config_parser._parse_java_config_for_dir", + return_value=java_config, + ): + result = find_all_config_files() + assert len(result) == 2 + languages = {r.language for r in result} + assert languages == {Language.PYTHON, Language.JAVA} + + def test_skips_pyproject_without_codeflash_section(self, tmp_path: Path, monkeypatch) -> None: + write_toml(tmp_path / "pyproject.toml", {"tool": {"black": {"line-length": 120}}}) + monkeypatch.chdir(tmp_path) + result = find_all_config_files() + assert len(result) == 0 + + def test_finds_config_in_parent_directory(self, tmp_path: Path, monkeypatch) -> None: + write_toml(tmp_path / "pyproject.toml", {"tool": {"codeflash": {"module-root": "src"}}}) + subdir = tmp_path / "subproject" + subdir.mkdir() + java_config = {"language": "java", "module_root": str(subdir / "src/main/java")} + (subdir / "pom.xml").write_text("", encoding="utf-8") + monkeypatch.chdir(subdir) + with patch( + "codeflash.code_utils.config_parser._parse_java_config_for_dir", + return_value=java_config, + ): + result = find_all_config_files() + assert len(result) == 2 + languages = {r.language for r in result} + assert languages == {Language.PYTHON, Language.JAVA} + + def test_closest_config_wins_per_language(self, tmp_path: Path, monkeypatch) -> None: + write_toml(tmp_path / "pyproject.toml", {"tool": {"codeflash": {"module-root": "."}}}) + subdir = tmp_path / "sub" + subdir.mkdir() + write_toml(subdir / "pyproject.toml", {"tool": {"codeflash": {"module-root": "src"}}}) + monkeypatch.chdir(subdir) + result = find_all_config_files() + assert len(result) == 1 + assert result[0].language == Language.PYTHON + assert result[0].config_path == subdir / "pyproject.toml" + + def test_finds_package_json_with_codeflash_section(self, tmp_path: Path, monkeypatch) -> None: + pkg = {"codeflash": {"moduleRoot": "src"}} + (tmp_path / "package.json").write_text(json.dumps(pkg), encoding="utf-8") + monkeypatch.chdir(tmp_path) + result = find_all_config_files() + assert len(result) == 1 + assert result[0].language == Language.JAVASCRIPT + assert result[0].config_path == tmp_path / "package.json" + + def test_finds_all_three_config_types(self, tmp_path: Path, monkeypatch) -> None: + write_toml(tmp_path / "pyproject.toml", {"tool": {"codeflash": {"module-root": "src"}}}) + pkg = {"codeflash": {"moduleRoot": "src"}} + (tmp_path / "package.json").write_text(json.dumps(pkg), encoding="utf-8") + java_config = {"language": "java", "module_root": str(tmp_path / "src/main/java")} + (tmp_path / "pom.xml").write_text("", encoding="utf-8") + monkeypatch.chdir(tmp_path) + with patch( + "codeflash.code_utils.config_parser._parse_java_config_for_dir", + return_value=java_config, + ): + result = find_all_config_files() + assert len(result) == 3 + languages = {r.language for r in result} + assert languages == {Language.PYTHON, Language.JAVA, Language.JAVASCRIPT} + + def test_no_java_when_no_build_file_exists(self, tmp_path: Path, monkeypatch) -> None: + monkeypatch.chdir(tmp_path) + result = find_all_config_files() + assert len(result) == 0 + + def test_missing_codeflash_section_skipped(self, tmp_path: Path, monkeypatch) -> None: + write_toml(tmp_path / "pyproject.toml", {"tool": {"other": {"key": "value"}}}) + monkeypatch.chdir(tmp_path) + result = find_all_config_files() + assert len(result) == 0 + + def test_finds_java_in_subdirectory(self, tmp_path: Path, monkeypatch) -> None: + """Monorepo: Java project in a subdirectory is discovered from the repo root.""" + write_toml(tmp_path / "pyproject.toml", {"tool": {"codeflash": {"module-root": "src"}}}) + java_dir = tmp_path / "java" + java_dir.mkdir() + (java_dir / "pom.xml").write_text("", encoding="utf-8") + java_config = {"language": "java", "module_root": str(java_dir / "src/main/java")} + monkeypatch.chdir(tmp_path) + with patch( + "codeflash.code_utils.config_parser._parse_java_config_for_dir", + return_value=java_config, + ): + result = find_all_config_files() + assert len(result) == 2 + languages = {r.language for r in result} + assert languages == {Language.PYTHON, Language.JAVA} + java_result = next(r for r in result if r.language == Language.JAVA) + assert java_result.config_path == java_dir + + def test_finds_js_in_subdirectory(self, tmp_path: Path, monkeypatch) -> None: + """Monorepo: JS project in a subdirectory is discovered from the repo root.""" + write_toml(tmp_path / "pyproject.toml", {"tool": {"codeflash": {"module-root": "src"}}}) + js_dir = tmp_path / "js" + js_dir.mkdir() + pkg = {"codeflash": {"moduleRoot": "src"}} + (js_dir / "package.json").write_text(json.dumps(pkg), encoding="utf-8") + monkeypatch.chdir(tmp_path) + result = find_all_config_files() + assert len(result) == 2 + languages = {r.language for r in result} + assert languages == {Language.PYTHON, Language.JAVASCRIPT} + + def test_finds_all_three_in_monorepo_subdirs(self, tmp_path: Path, monkeypatch) -> None: + """Monorepo: Python at root, Java and JS in subdirectories.""" + write_toml(tmp_path / "pyproject.toml", {"tool": {"codeflash": {"module-root": "src"}}}) + java_dir = tmp_path / "java" + java_dir.mkdir() + (java_dir / "pom.xml").write_text("", encoding="utf-8") + java_config = {"language": "java", "module_root": str(java_dir / "src/main/java")} + js_dir = tmp_path / "js" + js_dir.mkdir() + pkg = {"codeflash": {"moduleRoot": "src"}} + (js_dir / "package.json").write_text(json.dumps(pkg), encoding="utf-8") + monkeypatch.chdir(tmp_path) + with patch( + "codeflash.code_utils.config_parser._parse_java_config_for_dir", + return_value=java_config, + ): + result = find_all_config_files() + assert len(result) == 3 + languages = {r.language for r in result} + assert languages == {Language.PYTHON, Language.JAVA, Language.JAVASCRIPT} + + def test_skips_hidden_and_build_subdirs(self, tmp_path: Path, monkeypatch) -> None: + """Subdirectory scan skips .git, node_modules, target, etc.""" + for name in [".git", "node_modules", "target", "build", "__pycache__"]: + d = tmp_path / name + d.mkdir() + write_toml(d / "pyproject.toml", {"tool": {"codeflash": {"module-root": "."}}}) + monkeypatch.chdir(tmp_path) + result = find_all_config_files() + assert len(result) == 0 + + def test_root_config_wins_over_subdir(self, tmp_path: Path, monkeypatch) -> None: + """Config at CWD (found during upward walk) takes precedence over subdirectory.""" + (tmp_path / "pom.xml").write_text("", encoding="utf-8") + java_dir = tmp_path / "java" + java_dir.mkdir() + (java_dir / "pom.xml").write_text("", encoding="utf-8") + java_config = {"language": "java", "module_root": str(tmp_path / "src/main/java")} + monkeypatch.chdir(tmp_path) + with patch( + "codeflash.code_utils.config_parser._parse_java_config_for_dir", + return_value=java_config, + ): + result = find_all_config_files() + java_results = [r for r in result if r.language == Language.JAVA] + assert len(java_results) == 1 + assert java_results[0].config_path == tmp_path + + +def test_find_all_functions_uses_registry_not_singleton() -> None: + """DISC-04: Verify find_all_functions_in_file uses per-file registry lookup.""" + import inspect + + from codeflash.discovery.functions_to_optimize import find_all_functions_in_file + + source = inspect.getsource(find_all_functions_in_file) + assert "get_language_support" in source + assert "current_language_support" not in source diff --git a/tests/test_multi_language_orchestration.py b/tests/test_multi_language_orchestration.py new file mode 100644 index 000000000..41e4ed9d7 --- /dev/null +++ b/tests/test_multi_language_orchestration.py @@ -0,0 +1,879 @@ +from __future__ import annotations + +import logging +from argparse import Namespace +from pathlib import Path +from unittest.mock import MagicMock, patch + +import tomlkit + +from codeflash.code_utils.config_parser import LanguageConfig, normalize_toml_config +from codeflash.languages.language_enum import Language + + +def write_toml(path: Path, data: dict) -> None: + path.write_text(tomlkit.dumps(data), encoding="utf-8") + + +def make_base_args(**overrides) -> Namespace: + defaults = { + "module_root": None, + "tests_root": None, + "benchmarks_root": None, + "ignore_paths": None, + "pytest_cmd": None, + "formatter_cmds": None, + "disable_telemetry": None, + "disable_imports_sorting": None, + "git_remote": None, + "override_fixtures": None, + "config_file": None, + "file": None, + "function": None, + "no_pr": False, + "verbose": False, + "command": None, + "verify_setup": False, + "version": False, + "show_config": False, + "reset_config": False, + "previous_checkpoint_functions": [], + } + defaults.update(overrides) + return Namespace(**defaults) + + +class TestApplyLanguageConfig: + def test_sets_module_root(self, tmp_path: Path) -> None: + src = tmp_path / "src" / "main" / "java" + src.mkdir(parents=True) + config = {"module_root": str(src)} + lang_config = LanguageConfig(config=config, config_path=tmp_path, language=Language.JAVA) + args = make_base_args() + + from codeflash.cli_cmds.cli import apply_language_config + + result = apply_language_config(args, lang_config) + assert result.module_root == src.resolve() + + def test_sets_tests_root(self, tmp_path: Path) -> None: + src = tmp_path / "src" / "main" / "java" + src.mkdir(parents=True) + tests = tmp_path / "src" / "test" / "java" + tests.mkdir(parents=True) + config = {"module_root": str(src), "tests_root": str(tests)} + lang_config = LanguageConfig(config=config, config_path=tmp_path, language=Language.JAVA) + args = make_base_args() + + from codeflash.cli_cmds.cli import apply_language_config + + result = apply_language_config(args, lang_config) + assert result.tests_root == tests.resolve() + + def test_resolves_paths_relative_to_config_parent(self, tmp_path: Path) -> None: + src = tmp_path / "src" / "main" / "java" + src.mkdir(parents=True) + tests = tmp_path / "src" / "test" / "java" + tests.mkdir(parents=True) + config = {"module_root": str(src), "tests_root": str(tests)} + lang_config = LanguageConfig(config=config, config_path=tmp_path, language=Language.JAVA) + args = make_base_args() + + from codeflash.cli_cmds.cli import apply_language_config + + result = apply_language_config(args, lang_config) + assert result.module_root.is_absolute() + assert result.tests_root.is_absolute() + + def test_sets_project_root(self, tmp_path: Path) -> None: + src = tmp_path / "src" / "main" / "java" + src.mkdir(parents=True) + tests = tmp_path / "src" / "test" / "java" + tests.mkdir(parents=True) + (tmp_path / "pom.xml").touch() + config = {"module_root": str(src), "tests_root": str(tests)} + lang_config = LanguageConfig(config=config, config_path=tmp_path, language=Language.JAVA) + args = make_base_args() + + from codeflash.cli_cmds.cli import apply_language_config + + result = apply_language_config(args, lang_config) + assert result.project_root == tmp_path.resolve() + + def test_preserves_cli_overrides(self, tmp_path: Path) -> None: + src = tmp_path / "src" / "main" / "java" + src.mkdir(parents=True) + override_module = tmp_path / "custom" + override_module.mkdir() + tests = tmp_path / "src" / "test" / "java" + tests.mkdir(parents=True) + config = {"module_root": str(src), "tests_root": str(tests)} + lang_config = LanguageConfig(config=config, config_path=tmp_path, language=Language.JAVA) + args = make_base_args(module_root=str(override_module)) + + from codeflash.cli_cmds.cli import apply_language_config + + result = apply_language_config(args, lang_config) + assert result.module_root == override_module.resolve() + + def test_copies_formatter_cmds(self, tmp_path: Path) -> None: + src = tmp_path / "src" + src.mkdir() + tests = tmp_path / "tests" + tests.mkdir() + config = {"module_root": str(src), "tests_root": str(tests), "formatter_cmds": ["black $file"]} + lang_config = LanguageConfig(config=config, config_path=tmp_path / "pyproject.toml", language=Language.PYTHON) + args = make_base_args() + + from codeflash.cli_cmds.cli import apply_language_config + + result = apply_language_config(args, lang_config) + assert result.formatter_cmds == ["black $file"] + + def test_sets_language_singleton(self, tmp_path: Path) -> None: + src = tmp_path / "src" / "main" / "java" + src.mkdir(parents=True) + tests = tmp_path / "src" / "test" / "java" + tests.mkdir(parents=True) + config = {"module_root": str(src), "tests_root": str(tests)} + lang_config = LanguageConfig(config=config, config_path=tmp_path, language=Language.JAVA) + args = make_base_args() + + with patch("codeflash.cli_cmds.cli.set_current_language") as mock_set: + from codeflash.cli_cmds.cli import apply_language_config + + apply_language_config(args, lang_config) + mock_set.assert_called_once_with(Language.JAVA) + + def test_handles_python_config(self, tmp_path: Path) -> None: + src = tmp_path / "src" + src.mkdir() + tests = tmp_path / "tests" + tests.mkdir() + config = {"module_root": str(src), "tests_root": str(tests)} + lang_config = LanguageConfig(config=config, config_path=tmp_path / "pyproject.toml", language=Language.PYTHON) + args = make_base_args() + + from codeflash.cli_cmds.cli import apply_language_config + + result = apply_language_config(args, lang_config) + assert result.module_root == src.resolve() + assert result.tests_root == tests.resolve() + + def test_java_default_tests_root(self, tmp_path: Path, monkeypatch) -> None: + src = tmp_path / "src" / "main" / "java" + src.mkdir(parents=True) + default_tests = tmp_path / "src" / "test" / "java" + default_tests.mkdir(parents=True) + monkeypatch.chdir(tmp_path) + config = {"module_root": str(src)} + lang_config = LanguageConfig(config=config, config_path=tmp_path, language=Language.JAVA) + args = make_base_args() + + from codeflash.cli_cmds.cli import apply_language_config + + result = apply_language_config(args, lang_config) + assert result.tests_root == default_tests.resolve() + + +def make_lang_config(tmp_path: Path, language: Language, subdir: str = "") -> LanguageConfig: + if language == Language.PYTHON: + src = tmp_path / subdir / "src" if subdir else tmp_path / "src" + tests = tmp_path / subdir / "tests" if subdir else tmp_path / "tests" + src.mkdir(parents=True, exist_ok=True) + tests.mkdir(parents=True, exist_ok=True) + config_path = tmp_path / subdir / "pyproject.toml" if subdir else tmp_path / "pyproject.toml" + return LanguageConfig( + config={"module_root": str(src), "tests_root": str(tests)}, + config_path=config_path, + language=Language.PYTHON, + ) + if language == Language.JAVASCRIPT: + src = tmp_path / subdir / "src" if subdir else tmp_path / "src" + tests = tmp_path / subdir / "tests" if subdir else tmp_path / "tests" + src.mkdir(parents=True, exist_ok=True) + tests.mkdir(parents=True, exist_ok=True) + config_path = tmp_path / subdir / "package.json" if subdir else tmp_path / "package.json" + return LanguageConfig( + config={"module_root": str(src), "tests_root": str(tests)}, + config_path=config_path, + language=Language.JAVASCRIPT, + ) + src = tmp_path / subdir / "src" / "main" / "java" if subdir else tmp_path / "src" / "main" / "java" + tests = tmp_path / subdir / "src" / "test" / "java" if subdir else tmp_path / "src" / "test" / "java" + src.mkdir(parents=True, exist_ok=True) + tests.mkdir(parents=True, exist_ok=True) + config_path = tmp_path / subdir if subdir else tmp_path + return LanguageConfig( + config={"module_root": str(src), "tests_root": str(tests)}, + config_path=config_path, + language=Language.JAVA, + ) + + +class TestMultiLanguageOrchestration: + @patch("codeflash.main.ask_should_use_checkpoint_get_functions", return_value=[]) + @patch("codeflash.main.env_utils.check_formatter_installed", return_value=True) + @patch("codeflash.main.handle_optimize_all_arg_parsing", side_effect=lambda args: args) + @patch("codeflash.optimization.optimizer.run_with_args") + @patch("codeflash.main.find_all_config_files") + @patch("codeflash.main.parse_args") + @patch("codeflash.main.print_codeflash_banner") + @patch("codeflash.main.check_for_newer_minor_version") + @patch("codeflash.telemetry.posthog_cf.initialize_posthog") + @patch("codeflash.telemetry.sentry.init_sentry") + def test_sequential_passes_calls_optimizer_per_language( + self, _sentry, _posthog, _ver, _banner, mock_parse_args, mock_find_configs, mock_run, _handle_all, _fmt, _ckpt, tmp_path: Path + ) -> None: + py_config = make_lang_config(tmp_path, Language.PYTHON) + java_config = make_lang_config(tmp_path, Language.JAVA) + mock_find_configs.return_value = [py_config, java_config] + mock_parse_args.return_value = make_base_args(disable_telemetry=False) + + from codeflash.main import main + + main() + + assert mock_run.call_count == 2 + + @patch("codeflash.main.ask_should_use_checkpoint_get_functions", return_value=[]) + @patch("codeflash.main.env_utils.check_formatter_installed", return_value=True) + @patch("codeflash.main.handle_optimize_all_arg_parsing", side_effect=lambda args: args) + @patch("codeflash.optimization.optimizer.run_with_args") + @patch("codeflash.main.find_all_config_files") + @patch("codeflash.main.parse_args") + @patch("codeflash.main.print_codeflash_banner") + @patch("codeflash.main.check_for_newer_minor_version") + @patch("codeflash.telemetry.posthog_cf.initialize_posthog") + @patch("codeflash.telemetry.sentry.init_sentry") + @patch("codeflash.cli_cmds.cli.set_current_language") + def test_singleton_set_per_pass( + self, + mock_set_lang, + _sentry, + _posthog, + _ver, + _banner, + mock_parse_args, + mock_find_configs, + mock_run, + _handle_all, + _fmt, + _ckpt, + tmp_path: Path, + ) -> None: + py_config = make_lang_config(tmp_path, Language.PYTHON) + java_config = make_lang_config(tmp_path, Language.JAVA) + mock_find_configs.return_value = [py_config, java_config] + mock_parse_args.return_value = make_base_args(disable_telemetry=False) + + from codeflash.main import main + + main() + + # set_current_language is called once per language pass via apply_language_config + lang_calls = [c for c in mock_set_lang.call_args_list if c[0][0] in (Language.PYTHON, Language.JAVA)] + assert len(lang_calls) >= 2 + called_langs = {c[0][0] for c in lang_calls} + assert Language.PYTHON in called_langs + assert Language.JAVA in called_langs + + @patch("codeflash.main.ask_should_use_checkpoint_get_functions", return_value=[]) + @patch("codeflash.main.env_utils.check_formatter_installed", return_value=True) + @patch("codeflash.optimization.optimizer.run_with_args") + @patch("codeflash.main.find_all_config_files", return_value=[]) + @patch("codeflash.main._handle_config_loading") + @patch("codeflash.main.parse_args") + @patch("codeflash.main.print_codeflash_banner") + @patch("codeflash.main.check_for_newer_minor_version") + @patch("codeflash.telemetry.posthog_cf.initialize_posthog") + @patch("codeflash.telemetry.sentry.init_sentry") + @patch("codeflash.main.get_changed_file_paths", return_value=[]) + def test_fallback_to_single_config_when_no_multi_configs( + self, _changed, _sentry, _posthog, _ver, _banner, mock_parse_args, mock_handle_config, mock_run, _fmt, _ckpt, tmp_path: Path + ) -> None: + base = make_base_args( + disable_telemetry=False, formatter_cmds=[], module_root=str(tmp_path), tests_root=str(tmp_path) + ) + mock_parse_args.return_value = base + mock_handle_config.return_value = base + + from codeflash.main import main + + main() + + mock_handle_config.assert_called_once() + mock_run.assert_called_once() + + @patch("codeflash.main.ask_should_use_checkpoint_get_functions", return_value=[]) + @patch("codeflash.main.env_utils.check_formatter_installed", return_value=True) + @patch("codeflash.main.handle_optimize_all_arg_parsing", side_effect=lambda args: args) + @patch("codeflash.optimization.optimizer.run_with_args") + @patch("codeflash.main.find_all_config_files") + @patch("codeflash.main.parse_args") + @patch("codeflash.main.print_codeflash_banner") + @patch("codeflash.main.check_for_newer_minor_version") + @patch("codeflash.telemetry.posthog_cf.initialize_posthog") + @patch("codeflash.telemetry.sentry.init_sentry") + def test_args_deep_copied_between_passes( + self, _sentry, _posthog, _ver, _banner, mock_parse_args, mock_find_configs, mock_run, _handle_all, _fmt, _ckpt, tmp_path: Path + ) -> None: + py_config = make_lang_config(tmp_path, Language.PYTHON) + java_config = make_lang_config(tmp_path, Language.JAVA) + mock_find_configs.return_value = [py_config, java_config] + mock_parse_args.return_value = make_base_args(disable_telemetry=False) + + from codeflash.main import main + + main() + + assert mock_run.call_count == 2 + call1_args = mock_run.call_args_list[0][0][0] + call2_args = mock_run.call_args_list[1][0][0] + # Args should be different objects (deep copied) + assert call1_args is not call2_args + # Module roots should differ between Python and Java configs + assert call1_args.module_root != call2_args.module_root + + + @patch("codeflash.main.ask_should_use_checkpoint_get_functions", return_value=[]) + @patch("codeflash.main.env_utils.check_formatter_installed", return_value=True) + @patch("codeflash.main.handle_optimize_all_arg_parsing", side_effect=lambda args: args) + @patch("codeflash.optimization.optimizer.run_with_args") + @patch("codeflash.main.find_all_config_files") + @patch("codeflash.main.parse_args") + @patch("codeflash.main.print_codeflash_banner") + @patch("codeflash.main.check_for_newer_minor_version") + @patch("codeflash.telemetry.posthog_cf.initialize_posthog") + @patch("codeflash.telemetry.sentry.init_sentry") + def test_error_in_one_language_does_not_block_others( + self, _sentry, _posthog, _ver, _banner, mock_parse_args, mock_find_configs, mock_run, _handle_all, _fmt, _ckpt, tmp_path: Path + ) -> None: + py_config = make_lang_config(tmp_path, Language.PYTHON) + java_config = make_lang_config(tmp_path, Language.JAVA) + mock_find_configs.return_value = [py_config, java_config] + mock_parse_args.return_value = make_base_args(disable_telemetry=False) + # First call (Python) raises, second call (Java) succeeds + mock_run.side_effect = [RuntimeError("Python optimizer crashed"), None] + + from codeflash.main import main + + main() + + assert mock_run.call_count == 2 + + @patch("codeflash.main.ask_should_use_checkpoint_get_functions", return_value=[]) + @patch("codeflash.main.env_utils.check_formatter_installed", return_value=True) + @patch("codeflash.main.handle_optimize_all_arg_parsing", side_effect=lambda args: args) + @patch("codeflash.optimization.optimizer.run_with_args") + @patch("codeflash.main.find_all_config_files") + @patch("codeflash.main.parse_args") + @patch("codeflash.main.print_codeflash_banner") + @patch("codeflash.main.check_for_newer_minor_version") + @patch("codeflash.telemetry.posthog_cf.initialize_posthog") + @patch("codeflash.telemetry.sentry.init_sentry") + def test_orchestration_summary_logged( + self, _sentry, _posthog, _ver, _banner, mock_parse_args, mock_find_configs, mock_run, _handle_all, _fmt, _ckpt, tmp_path: Path + ) -> None: + py_config = make_lang_config(tmp_path, Language.PYTHON) + java_config = make_lang_config(tmp_path, Language.JAVA) + mock_find_configs.return_value = [py_config, java_config] + mock_parse_args.return_value = make_base_args(disable_telemetry=False) + + with patch("codeflash.main._log_orchestration_summary") as mock_summary: + from codeflash.main import main + + main() + + mock_summary.assert_called_once() + results = mock_summary.call_args[0][1] + assert results["python"] == "success" + assert results["java"] == "success" + + @patch("codeflash.main.ask_should_use_checkpoint_get_functions", return_value=[]) + @patch("codeflash.main.env_utils.check_formatter_installed", return_value=True) + @patch("codeflash.main.handle_optimize_all_arg_parsing", side_effect=lambda args: args) + @patch("codeflash.optimization.optimizer.run_with_args") + @patch("codeflash.main.find_all_config_files") + @patch("codeflash.main.parse_args") + @patch("codeflash.main.print_codeflash_banner") + @patch("codeflash.main.check_for_newer_minor_version") + @patch("codeflash.telemetry.posthog_cf.initialize_posthog") + @patch("codeflash.telemetry.sentry.init_sentry") + def test_summary_reports_failure_status( + self, _sentry, _posthog, _ver, _banner, mock_parse_args, mock_find_configs, mock_run, _handle_all, _fmt, _ckpt, tmp_path: Path + ) -> None: + py_config = make_lang_config(tmp_path, Language.PYTHON) + java_config = make_lang_config(tmp_path, Language.JAVA) + mock_find_configs.return_value = [py_config, java_config] + mock_parse_args.return_value = make_base_args(disable_telemetry=False) + mock_run.side_effect = [RuntimeError("boom"), None] + + with patch("codeflash.main._log_orchestration_summary") as mock_summary: + from codeflash.main import main + + main() + + results = mock_summary.call_args[0][1] + assert results["python"] == "failed" + assert results["java"] == "success" + + +class TestOrchestrationSummaryLogging: + def test_summary_format_all_success(self) -> None: + import logging + + from codeflash.main import _log_orchestration_summary + + with patch.object(logging.Logger, "info") as mock_info: + logger = logging.getLogger("codeflash.test") + _log_orchestration_summary(logger, {"python": "success", "java": "success"}) + mock_info.assert_called_once() + msg = mock_info.call_args[0][0] % mock_info.call_args[0][1:] + assert "python: success" in msg + assert "java: success" in msg + + def test_summary_format_mixed_statuses(self) -> None: + import logging + + from codeflash.main import _log_orchestration_summary + + with patch.object(logging.Logger, "info") as mock_info: + logger = logging.getLogger("codeflash.test") + _log_orchestration_summary(logger, {"python": "failed", "java": "success", "javascript": "skipped"}) + mock_info.assert_called_once() + msg = mock_info.call_args[0][0] % mock_info.call_args[0][1:] + assert "python: failed" in msg + assert "java: success" in msg + assert "javascript: skipped" in msg + + def test_summary_no_results_no_log(self) -> None: + import logging + + from codeflash.main import _log_orchestration_summary + + with patch.object(logging.Logger, "info") as mock_info: + logger = logging.getLogger("codeflash.test") + _log_orchestration_summary(logger, {}) + mock_info.assert_not_called() + + @patch("codeflash.main.ask_should_use_checkpoint_get_functions", return_value=[]) + @patch("codeflash.main.env_utils.check_formatter_installed") + @patch("codeflash.main.handle_optimize_all_arg_parsing", side_effect=lambda args: args) + @patch("codeflash.optimization.optimizer.run_with_args") + @patch("codeflash.main.find_all_config_files") + @patch("codeflash.main.parse_args") + @patch("codeflash.main.print_codeflash_banner") + @patch("codeflash.main.check_for_newer_minor_version") + @patch("codeflash.telemetry.posthog_cf.initialize_posthog") + @patch("codeflash.telemetry.sentry.init_sentry") + def test_summary_reports_skipped_status( + self, _sentry, _posthog, _ver, _banner, mock_parse_args, mock_find_configs, mock_run, _handle_all, mock_fmt, _ckpt, tmp_path: Path + ) -> None: + py_config = make_lang_config(tmp_path, Language.PYTHON) + java_config = make_lang_config(tmp_path, Language.JAVA) + mock_find_configs.return_value = [py_config, java_config] + mock_parse_args.return_value = make_base_args(disable_telemetry=False) + # Python formatter check fails (skipped), Java succeeds + mock_fmt.side_effect = [False, True] + + with patch("codeflash.main._log_orchestration_summary") as mock_summary: + from codeflash.main import main + + main() + + results = mock_summary.call_args[0][1] + assert results["python"] == "skipped" + assert results["java"] == "success" + assert mock_run.call_count == 1 + + +class TestCLIPathRouting: + @patch("codeflash.main.ask_should_use_checkpoint_get_functions", return_value=[]) + @patch("codeflash.main.env_utils.check_formatter_installed", return_value=True) + @patch("codeflash.main.handle_optimize_all_arg_parsing", side_effect=lambda args: args) + @patch("codeflash.optimization.optimizer.run_with_args") + @patch("codeflash.main.find_all_config_files") + @patch("codeflash.main.parse_args") + @patch("codeflash.main.print_codeflash_banner") + @patch("codeflash.main.check_for_newer_minor_version") + @patch("codeflash.telemetry.posthog_cf.initialize_posthog") + @patch("codeflash.telemetry.sentry.init_sentry") + def test_file_flag_filters_to_matching_language( + self, _sentry, _posthog, _ver, _banner, mock_parse_args, mock_find_configs, mock_run, _handle_all, _fmt, _ckpt, tmp_path: Path + ) -> None: + py_config = make_lang_config(tmp_path, Language.PYTHON) + java_config = make_lang_config(tmp_path, Language.JAVA) + mock_find_configs.return_value = [py_config, java_config] + mock_parse_args.return_value = make_base_args(file="path/to/Foo.java", disable_telemetry=False) + + from codeflash.main import main + + main() + + assert mock_run.call_count == 1 + + @patch("codeflash.main.ask_should_use_checkpoint_get_functions", return_value=[]) + @patch("codeflash.main.env_utils.check_formatter_installed", return_value=True) + @patch("codeflash.main.handle_optimize_all_arg_parsing", side_effect=lambda args: args) + @patch("codeflash.optimization.optimizer.run_with_args") + @patch("codeflash.main.find_all_config_files") + @patch("codeflash.main.parse_args") + @patch("codeflash.main.print_codeflash_banner") + @patch("codeflash.main.check_for_newer_minor_version") + @patch("codeflash.telemetry.posthog_cf.initialize_posthog") + @patch("codeflash.telemetry.sentry.init_sentry") + def test_file_flag_python_file_filters_to_python( + self, _sentry, _posthog, _ver, _banner, mock_parse_args, mock_find_configs, mock_run, _handle_all, _fmt, _ckpt, tmp_path: Path + ) -> None: + py_config = make_lang_config(tmp_path, Language.PYTHON) + java_config = make_lang_config(tmp_path, Language.JAVA) + mock_find_configs.return_value = [py_config, java_config] + mock_parse_args.return_value = make_base_args(file="module.py", disable_telemetry=False) + + from codeflash.main import main + + main() + + assert mock_run.call_count == 1 + + @patch("codeflash.main.ask_should_use_checkpoint_get_functions", return_value=[]) + @patch("codeflash.main.env_utils.check_formatter_installed", return_value=True) + @patch("codeflash.main.handle_optimize_all_arg_parsing", side_effect=lambda args: args) + @patch("codeflash.optimization.optimizer.run_with_args") + @patch("codeflash.main.find_all_config_files") + @patch("codeflash.main.parse_args") + @patch("codeflash.main.print_codeflash_banner") + @patch("codeflash.main.check_for_newer_minor_version") + @patch("codeflash.telemetry.posthog_cf.initialize_posthog") + @patch("codeflash.telemetry.sentry.init_sentry") + def test_file_flag_unknown_extension_runs_all( + self, _sentry, _posthog, _ver, _banner, mock_parse_args, mock_find_configs, mock_run, _handle_all, _fmt, _ckpt, tmp_path: Path + ) -> None: + py_config = make_lang_config(tmp_path, Language.PYTHON) + java_config = make_lang_config(tmp_path, Language.JAVA) + mock_find_configs.return_value = [py_config, java_config] + mock_parse_args.return_value = make_base_args(file="Foo.rs", disable_telemetry=False) + + from codeflash.main import main + + main() + + assert mock_run.call_count == 2 + + @patch("codeflash.main.ask_should_use_checkpoint_get_functions", return_value=[]) + @patch("codeflash.main.env_utils.check_formatter_installed", return_value=True) + @patch("codeflash.main.handle_optimize_all_arg_parsing", side_effect=lambda args: args) + @patch("codeflash.optimization.optimizer.run_with_args") + @patch("codeflash.main.find_all_config_files") + @patch("codeflash.main.parse_args") + @patch("codeflash.main.print_codeflash_banner") + @patch("codeflash.main.check_for_newer_minor_version") + @patch("codeflash.telemetry.posthog_cf.initialize_posthog") + @patch("codeflash.telemetry.sentry.init_sentry") + def test_file_flag_no_matching_config_runs_all( + self, _sentry, _posthog, _ver, _banner, mock_parse_args, mock_find_configs, mock_run, _handle_all, _fmt, _ckpt, tmp_path: Path + ) -> None: + py_config = make_lang_config(tmp_path, Language.PYTHON) + mock_find_configs.return_value = [py_config] + mock_parse_args.return_value = make_base_args(file="Foo.java", disable_telemetry=False) + + from codeflash.main import main + + main() + + assert mock_run.call_count == 1 + + @patch("codeflash.main.ask_should_use_checkpoint_get_functions", return_value=[]) + @patch("codeflash.main.env_utils.check_formatter_installed", return_value=True) + @patch("codeflash.main.handle_optimize_all_arg_parsing", side_effect=lambda args: args) + @patch("codeflash.optimization.optimizer.run_with_args") + @patch("codeflash.main.find_all_config_files") + @patch("codeflash.main.parse_args") + @patch("codeflash.main.print_codeflash_banner") + @patch("codeflash.main.check_for_newer_minor_version") + @patch("codeflash.telemetry.posthog_cf.initialize_posthog") + @patch("codeflash.telemetry.sentry.init_sentry") + def test_all_flag_sets_module_root_per_language( + self, _sentry, _posthog, _ver, _banner, mock_parse_args, mock_find_configs, mock_run, _handle_all, _fmt, _ckpt, tmp_path: Path + ) -> None: + py_config = make_lang_config(tmp_path, Language.PYTHON) + java_config = make_lang_config(tmp_path, Language.JAVA) + mock_find_configs.return_value = [py_config, java_config] + mock_parse_args.return_value = make_base_args(all="", disable_telemetry=False) + + from codeflash.main import main + + main() + + assert mock_run.call_count == 2 + for call in mock_run.call_args_list: + passed_args = call[0][0] + assert passed_args.all == passed_args.module_root + + @patch("codeflash.main.ask_should_use_checkpoint_get_functions", return_value=[]) + @patch("codeflash.main.env_utils.check_formatter_installed", return_value=True) + @patch("codeflash.main.handle_optimize_all_arg_parsing", side_effect=lambda args: args) + @patch("codeflash.optimization.optimizer.run_with_args") + @patch("codeflash.main.find_all_config_files") + @patch("codeflash.main.parse_args") + @patch("codeflash.main.print_codeflash_banner") + @patch("codeflash.main.check_for_newer_minor_version") + @patch("codeflash.telemetry.posthog_cf.initialize_posthog") + @patch("codeflash.telemetry.sentry.init_sentry") + def test_no_flags_runs_all_language_passes( + self, _sentry, _posthog, _ver, _banner, mock_parse_args, mock_find_configs, mock_run, _handle_all, _fmt, _ckpt, tmp_path: Path + ) -> None: + py_config = make_lang_config(tmp_path, Language.PYTHON) + java_config = make_lang_config(tmp_path, Language.JAVA) + mock_find_configs.return_value = [py_config, java_config] + mock_parse_args.return_value = make_base_args(disable_telemetry=False) + + from codeflash.main import main + + main() + + assert mock_run.call_count == 2 + + @patch("codeflash.main.ask_should_use_checkpoint_get_functions", return_value=[]) + @patch("codeflash.main.env_utils.check_formatter_installed", return_value=True) + @patch("codeflash.main.handle_optimize_all_arg_parsing", side_effect=lambda args: args) + @patch("codeflash.optimization.optimizer.run_with_args") + @patch("codeflash.main.find_all_config_files") + @patch("codeflash.main.parse_args") + @patch("codeflash.main.print_codeflash_banner") + @patch("codeflash.main.check_for_newer_minor_version") + @patch("codeflash.telemetry.posthog_cf.initialize_posthog") + @patch("codeflash.telemetry.sentry.init_sentry") + def test_file_flag_typescript_extension( + self, _sentry, _posthog, _ver, _banner, mock_parse_args, mock_find_configs, mock_run, _handle_all, _fmt, _ckpt, tmp_path: Path + ) -> None: + # .tsx maps to Language.TYPESCRIPT, which is distinct from Language.JAVASCRIPT. + # When no TYPESCRIPT config exists, all configs run (fallback behavior). + py_config = make_lang_config(tmp_path, Language.PYTHON) + js_config = make_lang_config(tmp_path, Language.JAVASCRIPT, subdir="js-proj") + mock_find_configs.return_value = [py_config, js_config] + mock_parse_args.return_value = make_base_args(file="path/to/Component.tsx", disable_telemetry=False) + + from codeflash.main import main + + main() + + # No TYPESCRIPT config exists, so all configs run (same as unknown extension) + assert mock_run.call_count == 2 + + @patch("codeflash.main.ask_should_use_checkpoint_get_functions", return_value=[]) + @patch("codeflash.main.env_utils.check_formatter_installed", return_value=True) + @patch("codeflash.main.handle_optimize_all_arg_parsing", side_effect=lambda args: args) + @patch("codeflash.optimization.optimizer.run_with_args") + @patch("codeflash.main.find_all_config_files") + @patch("codeflash.main.parse_args") + @patch("codeflash.main.print_codeflash_banner") + @patch("codeflash.main.check_for_newer_minor_version") + @patch("codeflash.telemetry.posthog_cf.initialize_posthog") + @patch("codeflash.telemetry.sentry.init_sentry") + def test_file_flag_jsx_extension( + self, _sentry, _posthog, _ver, _banner, mock_parse_args, mock_find_configs, mock_run, _handle_all, _fmt, _ckpt, tmp_path: Path + ) -> None: + # .jsx maps to Language.JAVASCRIPT, so it correctly filters to the JS config. + py_config = make_lang_config(tmp_path, Language.PYTHON) + js_config = make_lang_config(tmp_path, Language.JAVASCRIPT, subdir="js-proj") + mock_find_configs.return_value = [py_config, js_config] + mock_parse_args.return_value = make_base_args(file="path/to/Widget.jsx", disable_telemetry=False) + + from codeflash.main import main + + main() + + assert mock_run.call_count == 1 + + +class TestDirectFunctionCoverage: + @patch("subprocess.run") + def test_get_changed_file_paths_returns_diff_files(self, mock_subprocess) -> None: + from codeflash.main import get_changed_file_paths + + mock_subprocess.return_value = MagicMock(returncode=0, stdout="src/main.py\nsrc/App.java\n") + result = get_changed_file_paths() + assert len(result) == 2 + assert Path("src/main.py") in result + assert Path("src/App.java") in result + + @patch("subprocess.run") + def test_get_changed_file_paths_returns_empty_on_failure(self, mock_subprocess) -> None: + from codeflash.main import get_changed_file_paths + + mock_subprocess.return_value = MagicMock(returncode=1, stdout="") + result = get_changed_file_paths() + assert result == [] + + def test_detect_project_for_language_java(self, tmp_path: Path) -> None: + from codeflash.main import detect_project_for_language + + with ( + patch( + "codeflash.setup.detector._detect_java_module_root", + return_value=(tmp_path / "src/main/java", "pom.xml"), + ), + patch( + "codeflash.setup.detector._detect_tests_root", + return_value=(tmp_path / "src/test/java", "maven"), + ), + patch("codeflash.setup.detector._detect_test_runner", return_value=("maven", "pom.xml")), + patch("codeflash.setup.detector._detect_formatter", return_value=([], None)), + patch("codeflash.setup.detector._detect_ignore_paths", return_value=([], None)), + ): + result = detect_project_for_language(Language.JAVA, tmp_path) + assert result is not None + assert result.language == "java" + + def test_detect_project_for_language_unsupported(self) -> None: + from codeflash.main import detect_project_for_language + + mock_lang = MagicMock() + mock_lang.value = "rust" + try: + detect_project_for_language(mock_lang, Path("/tmp")) + assert False, "Should have raised ValueError" + except ValueError as e: + assert "No auto-detection available" in str(e) + + def test_empty_config_no_module_root(self, tmp_path: Path) -> None: + config: dict = {} + result = normalize_toml_config(config, tmp_path / "pyproject.toml") + assert result["formatter_cmds"] == [] + assert result["disable_telemetry"] is False + assert "module_root" not in result + + +class TestNormalizeTomlConfig: + def test_converts_hyphenated_keys_to_underscored(self, tmp_path: Path) -> None: + config = {"module-root": "src", "tests-root": "tests"} + (tmp_path / "src").mkdir() + (tmp_path / "tests").mkdir() + result = normalize_toml_config(config, tmp_path / "pyproject.toml") + assert "module_root" in result + assert "tests_root" in result + assert "module-root" not in result + assert "tests-root" not in result + + def test_resolves_paths_relative_to_config_parent(self, tmp_path: Path) -> None: + src = tmp_path / "src" + src.mkdir() + config = {"module-root": "src"} + result = normalize_toml_config(config, tmp_path / "pyproject.toml") + assert result["module_root"] == str(src.resolve()) + + def test_applies_default_values(self, tmp_path: Path) -> None: + config: dict = {} + result = normalize_toml_config(config, tmp_path / "pyproject.toml") + assert result["formatter_cmds"] == [] + assert result["disable_telemetry"] is False + assert result["override_fixtures"] is False + assert result["git_remote"] == "origin" + assert result["pytest_cmd"] == "pytest" + + def test_preserves_explicit_values(self, tmp_path: Path) -> None: + config = {"disable-telemetry": True, "formatter-cmds": ["prettier $file"]} + result = normalize_toml_config(config, tmp_path / "pyproject.toml") + assert result["disable_telemetry"] is True + assert result["formatter_cmds"] == ["prettier $file"] + + def test_resolves_ignore_paths(self, tmp_path: Path) -> None: + config = {"ignore-paths": ["build", "dist"]} + result = normalize_toml_config(config, tmp_path / "pyproject.toml") + assert result["ignore_paths"] == [ + str((tmp_path / "build").resolve()), + str((tmp_path / "dist").resolve()), + ] + + def test_empty_ignore_paths_default(self, tmp_path: Path) -> None: + config: dict = {} + result = normalize_toml_config(config, tmp_path / "pyproject.toml") + assert result["ignore_paths"] == [] + + +class TestUnconfiguredLanguageDetection: + def test_detects_unconfigured_java_from_changed_files(self) -> None: + from codeflash.main import detect_unconfigured_languages + + configs = [LanguageConfig(config={}, config_path=Path("pyproject.toml"), language=Language.PYTHON)] + changed = [Path("src/main/java/Foo.java"), Path("src/Bar.py")] + result = detect_unconfigured_languages(configs, changed) + assert Language.JAVA in result + assert Language.PYTHON not in result + + def test_no_unconfigured_when_all_configured(self) -> None: + from codeflash.main import detect_unconfigured_languages + + configs = [ + LanguageConfig(config={}, config_path=Path("pyproject.toml"), language=Language.PYTHON), + LanguageConfig(config={}, config_path=Path(), language=Language.JAVA), + ] + changed = [Path("Foo.java"), Path("bar.py")] + result = detect_unconfigured_languages(configs, changed) + assert result == set() + + def test_ignores_unsupported_extensions(self) -> None: + from codeflash.main import detect_unconfigured_languages + + changed = [Path("main.rs"), Path("lib.go")] + result = detect_unconfigured_languages([], changed) + assert result == set() + + @patch("codeflash.main.find_all_config_files") + def test_auto_config_adds_language_config_on_success(self, mock_find_configs, tmp_path: Path) -> None: + from codeflash.main import auto_configure_language + + new_lc = LanguageConfig(config={}, config_path=tmp_path, language=Language.JAVA) + mock_find_configs.return_value = [new_lc] + + logger = logging.getLogger("codeflash.test") + with ( + patch("codeflash.main.write_config", return_value=(True, "Created config")) as mock_write, + patch("codeflash.main.detect_project_for_language") as mock_detect, + ): + mock_detect.return_value = MagicMock() + result = auto_configure_language(Language.JAVA, tmp_path, logger) + + assert result is not None + assert result.language == Language.JAVA + mock_write.assert_called_once() + + def test_auto_config_failure_logs_warning(self, tmp_path: Path, caplog: object) -> None: + from codeflash.main import auto_configure_language + + logger = logging.getLogger("codeflash.test") + with ( + patch("codeflash.main.detect_project_for_language", side_effect=RuntimeError("detection failed")), + caplog.at_level(logging.WARNING), # type: ignore[union-attr] + ): + result = auto_configure_language(Language.JAVA, tmp_path, logger) + + assert result is None + + @patch("codeflash.main.ask_should_use_checkpoint_get_functions", return_value=[]) + @patch("codeflash.main.env_utils.check_formatter_installed", return_value=True) + @patch("codeflash.main.handle_optimize_all_arg_parsing", side_effect=lambda args: args) + @patch("codeflash.optimization.optimizer.run_with_args") + @patch("codeflash.main.find_all_config_files") + @patch("codeflash.main.parse_args") + @patch("codeflash.main.print_codeflash_banner") + @patch("codeflash.main.check_for_newer_minor_version") + @patch("codeflash.telemetry.posthog_cf.initialize_posthog") + @patch("codeflash.telemetry.sentry.init_sentry") + def test_per_language_logging_shows_config_path( + self, _sentry, _posthog, _ver, _banner, mock_parse_args, mock_find_configs, mock_run, _handle_all, _fmt, _ckpt, tmp_path: Path + ) -> None: + py_config = make_lang_config(tmp_path, Language.PYTHON) + mock_find_configs.return_value = [py_config] + mock_parse_args.return_value = make_base_args(disable_telemetry=False) + + with patch("codeflash.main._log_orchestration_summary"): + from codeflash.main import main + + with patch("logging.Logger.info") as mock_log_info: + main() + logged_messages = [str(call) for call in mock_log_info.call_args_list] + processing_logs = [m for m in logged_messages if "Processing" in m and "config:" in m] + assert len(processing_logs) >= 1 diff --git a/tests/test_setup/test_config_writer.py b/tests/test_setup/test_config_writer.py new file mode 100644 index 000000000..89426bdfd --- /dev/null +++ b/tests/test_setup/test_config_writer.py @@ -0,0 +1,148 @@ +"""Tests for config_writer module — Java pom.xml formatting preservation.""" + +from pathlib import Path + + +class TestWriteMavenProperties: + """Tests for _write_maven_properties — text-based pom.xml editing.""" + + def test_preserves_comments(self, tmp_path: Path) -> None: + pom = tmp_path / "pom.xml" + pom.write_text( + '\n' + "\n" + " \n" + " \n" + " 17\n" + " \n" + "\n", + encoding="utf-8", + ) + + from codeflash.setup.config_writer import _write_maven_properties + + ok, _ = _write_maven_properties(pom, {"module-root": "src/main/java"}) + result = pom.read_text(encoding="utf-8") + + assert ok + assert "" in result + assert "src/main/java" in result + + def test_preserves_namespace(self, tmp_path: Path) -> None: + pom = tmp_path / "pom.xml" + pom.write_text( + '\n' + '\n' + " \n" + " 17\n" + " \n" + "\n", + encoding="utf-8", + ) + + from codeflash.setup.config_writer import _write_maven_properties + + ok, _ = _write_maven_properties(pom, {"module-root": "src/main/java"}) + result = pom.read_text(encoding="utf-8") + + assert ok + assert 'xmlns="http://maven.apache.org/POM/4.0.0"' in result + # Must NOT have ns0: prefix (ElementTree bug) + assert "ns0:" not in result + + def test_updates_existing_codeflash_properties(self, tmp_path: Path) -> None: + pom = tmp_path / "pom.xml" + pom.write_text( + "\n" + " \n" + " old/path\n" + " \n" + "\n", + encoding="utf-8", + ) + + from codeflash.setup.config_writer import _write_maven_properties + + ok, _ = _write_maven_properties(pom, {"module-root": "new/path"}) + result = pom.read_text(encoding="utf-8") + + assert ok + assert "old/path" not in result + assert "new/path" in result + + def test_creates_properties_section(self, tmp_path: Path) -> None: + pom = tmp_path / "pom.xml" + pom.write_text( + "\n" " 4.0.0\n" "\n", + encoding="utf-8", + ) + + from codeflash.setup.config_writer import _write_maven_properties + + ok, _ = _write_maven_properties(pom, {"module-root": "src/main/java"}) + result = pom.read_text(encoding="utf-8") + + assert ok + assert "" in result + assert "src/main/java" in result + + def test_converts_kebab_to_camelcase(self, tmp_path: Path) -> None: + pom = tmp_path / "pom.xml" + pom.write_text( + "\n \n \n\n", + encoding="utf-8", + ) + + from codeflash.setup.config_writer import _write_maven_properties + + ok, _ = _write_maven_properties(pom, {"ignore-paths": ["target", "build"]}) + result = pom.read_text(encoding="utf-8") + + assert ok + assert "target,build" in result + + +class TestRemoveJavaBuildConfig: + """Tests for _remove_java_build_config — preserves formatting during removal.""" + + def test_removes_codeflash_from_pom_preserving_others(self, tmp_path: Path) -> None: + pom = tmp_path / "pom.xml" + pom.write_text( + "\n" + " \n" + " \n" + " 17\n" + " src/main/java\n" + " \n" + "\n", + encoding="utf-8", + ) + + from codeflash.setup.config_writer import _remove_java_build_config + + ok, _ = _remove_java_build_config(tmp_path) + result = pom.read_text(encoding="utf-8") + + assert ok + assert "" in result + assert "17" in result + assert "codeflash.moduleRoot" not in result + + def test_removes_codeflash_from_gradle_properties(self, tmp_path: Path) -> None: + gradle = tmp_path / "gradle.properties" + gradle.write_text( + "org.gradle.jvmargs=-Xmx2g\n" + "# Codeflash configuration \u2014 https://docs.codeflash.ai\n" + "codeflash.moduleRoot=src/main/java\n" + "codeflash.testsRoot=src/test/java\n", + encoding="utf-8", + ) + + from codeflash.setup.config_writer import _remove_java_build_config + + ok, _ = _remove_java_build_config(tmp_path) + result = gradle.read_text(encoding="utf-8") + + assert ok + assert "org.gradle.jvmargs=-Xmx2g" in result + assert "codeflash." not in result diff --git a/tests/test_setup/test_detector.py b/tests/test_setup/test_detector.py index 781d393e6..3b0e165c8 100644 --- a/tests/test_setup/test_detector.py +++ b/tests/test_setup/test_detector.py @@ -558,6 +558,22 @@ def test_returns_false_when_no_config(self, tmp_path): assert has_config is False assert config_type is None + def test_java_pom_xml_is_zero_config(self, tmp_path): + """Java projects with pom.xml are zero-config — build file presence means configured.""" + (tmp_path / "pom.xml").write_text("4.0.0") + + has_config, config_type = has_existing_config(tmp_path) + assert has_config is True + assert config_type == "pom.xml" + + def test_java_build_gradle_is_zero_config(self, tmp_path): + """Java projects with build.gradle are zero-config — build file presence means configured.""" + (tmp_path / "build.gradle").write_text('plugins { id "java" }') + + has_config, config_type = has_existing_config(tmp_path) + assert has_config is True + assert config_type == "build.gradle" + def test_returns_false_for_empty_directory(self, tmp_path): """Should return False for empty directory.""" has_config, config_type = has_existing_config(tmp_path)