Skip to content

Commit 6d7a93e

Browse files
add thread local safe
1 parent 182b8de commit 6d7a93e

File tree

4 files changed

+917
-0
lines changed

4 files changed

+917
-0
lines changed
Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
package com.lambdatest.selenium.agent;
2+
3+
import java.util.Map;
4+
import java.util.concurrent.ConcurrentHashMap;
5+
import java.util.logging.Logger;
6+
7+
import org.openqa.selenium.WebDriver;
8+
import org.openqa.selenium.remote.RemoteWebDriver;
9+
10+
/**
11+
* ThreadLocal storage for WebDriver instances.
12+
*
13+
* This class provides thread-safe storage for WebDriver fields, enabling
14+
* parallel="methods" execution without code changes.
15+
*
16+
* Each thread maintains its own map of field -> driver mappings.
17+
*/
18+
public class ThreadLocalDriverStorage {
19+
20+
private static final Logger LOGGER = Logger.getLogger(ThreadLocalDriverStorage.class.getName());
21+
22+
// ThreadLocal storage: each thread has its own map of (fieldKey -> driver)
23+
private static final ThreadLocal<Map<String, WebDriver>> THREAD_STORAGE =
24+
ThreadLocal.withInitial(ConcurrentHashMap::new);
25+
26+
// Track which fields have been logged (to avoid spam)
27+
private static final Map<String, Boolean> LOGGED_FIELDS = new ConcurrentHashMap<>();
28+
29+
/**
30+
* Store a WebDriver for the current thread.
31+
*
32+
* @param fieldKey Unique identifier for the field (className.fieldName)
33+
* @param driver The WebDriver instance to store
34+
*/
35+
public static void setDriver(String fieldKey, WebDriver driver) {
36+
if (driver == null) {
37+
THREAD_STORAGE.get().remove(fieldKey);
38+
return;
39+
}
40+
41+
THREAD_STORAGE.get().put(fieldKey, driver);
42+
43+
// Log once per field per thread (first write only)
44+
String logKey = Thread.currentThread().getId() + ":" + fieldKey;
45+
if (!LOGGED_FIELDS.containsKey(logKey)) {
46+
LOGGED_FIELDS.put(logKey, true);
47+
48+
String sessionId = "unknown";
49+
if (driver instanceof RemoteWebDriver) {
50+
RemoteWebDriver rwd = (RemoteWebDriver) driver;
51+
try {
52+
if (rwd.getSessionId() != null) {
53+
sessionId = rwd.getSessionId().toString().substring(0, 8) + "...";
54+
}
55+
} catch (Exception e) {
56+
// Ignore - session might not be available yet
57+
}
58+
}
59+
60+
LOGGER.info(String.format(
61+
"[Thread-%d/%s] ThreadLocal field write: %s (key=%s) -> Session %s",
62+
Thread.currentThread().getId(),
63+
Thread.currentThread().getName(),
64+
fieldKey.substring(fieldKey.lastIndexOf('.') + 1),
65+
fieldKey,
66+
sessionId));
67+
}
68+
}
69+
70+
/**
71+
* Retrieve a WebDriver for the current thread.
72+
*
73+
* @param fieldKey Unique identifier for the field (className.fieldName)
74+
* @return The WebDriver instance, or null if not found
75+
*/
76+
public static WebDriver getDriver(String fieldKey) {
77+
WebDriver driver = THREAD_STORAGE.get().get(fieldKey);
78+
if (driver == null) {
79+
LOGGER.warning(String.format("[Thread-%d/%s] ThreadLocal driver lookup failed for key: %s (mapSize=%d)",
80+
Thread.currentThread().getId(),
81+
Thread.currentThread().getName(),
82+
fieldKey,
83+
THREAD_STORAGE.get().size()));
84+
} else {
85+
LOGGER.fine(String.format("[Thread-%d/%s] ThreadLocal driver lookup success for key: %s",
86+
Thread.currentThread().getId(),
87+
Thread.currentThread().getName(),
88+
fieldKey));
89+
}
90+
return driver;
91+
}
92+
93+
/**
94+
* Clean up ThreadLocal storage for the current thread.
95+
* Should be called after test completes to prevent memory leaks.
96+
*/
97+
public static void cleanupThread() {
98+
Map<String, WebDriver> storage = THREAD_STORAGE.get();
99+
if (storage != null) {
100+
storage.clear();
101+
}
102+
THREAD_STORAGE.remove();
103+
104+
LOGGER.fine(String.format(
105+
"[Thread-%d/%s] Cleaned up ThreadLocal driver storage",
106+
Thread.currentThread().getId(),
107+
Thread.currentThread().getName()));
108+
}
109+
110+
/**
111+
* Get all drivers for the current thread (for debugging).
112+
*/
113+
public static Map<String, WebDriver> getAllDrivers() {
114+
return new ConcurrentHashMap<>(THREAD_STORAGE.get());
115+
}
116+
117+
/**
118+
* Get the number of drivers stored for the current thread.
119+
*/
120+
public static int getDriverCount() {
121+
return THREAD_STORAGE.get().size();
122+
}
123+
}
124+
Lines changed: 262 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,262 @@
1+
package com.lambdatest.selenium.agent;
2+
3+
import java.lang.instrument.ClassFileTransformer;
4+
import java.lang.instrument.IllegalClassFormatException;
5+
import java.security.ProtectionDomain;
6+
import java.util.Map;
7+
import java.util.concurrent.ConcurrentHashMap;
8+
import java.util.logging.Logger;
9+
10+
import org.objectweb.asm.ClassReader;
11+
import org.objectweb.asm.ClassVisitor;
12+
import org.objectweb.asm.ClassWriter;
13+
import org.objectweb.asm.FieldVisitor;
14+
import org.objectweb.asm.MethodVisitor;
15+
import org.objectweb.asm.Opcodes;
16+
17+
/**
18+
* ASM-based transformer that intercepts WebDriver field access and redirects to ThreadLocal.
19+
*
20+
* This makes parallel="methods" execution thread-safe WITHOUT any code changes.
21+
*
22+
* Transformation:
23+
* - All GETFIELD on WebDriver fields → call to ThreadLocalDriverStorage.get()
24+
* - All PUTFIELD on WebDriver fields → call to ThreadLocalDriverStorage.set()
25+
*/
26+
public class WebDriverFieldTransformer implements ClassFileTransformer {
27+
28+
private static final Logger LOGGER = Logger.getLogger(WebDriverFieldTransformer.class.getName());
29+
30+
@Override
31+
public byte[] transform(ClassLoader loader, String className, Class<?> classBeingRedefined,
32+
ProtectionDomain protectionDomain, byte[] classfileBuffer)
33+
throws IllegalClassFormatException {
34+
35+
// Only transform test classes (heuristic: classes with test annotations or in test packages)
36+
// Exclude JDK, frameworks, and library classes
37+
if (className == null ||
38+
className.startsWith("java/") ||
39+
className.startsWith("javax/") ||
40+
className.startsWith("sun/") ||
41+
className.startsWith("com/sun/") ||
42+
className.startsWith("jdk/") || // JDK internal classes
43+
className.startsWith("org/xml/") || // XML parsers
44+
className.startsWith("org/jcp/") || // JCP (Java Community Process) internal classes
45+
className.startsWith("apple/") || // Apple-specific classes (macOS JDK)
46+
className.startsWith("org/testng/") ||
47+
className.startsWith("org/junit/") ||
48+
className.startsWith("net/bytebuddy/") ||
49+
className.startsWith("org/openqa/selenium/") ||
50+
className.startsWith("com/lambdatest/selenium/")) { // Don't transform SDK itself
51+
return null; // Don't transform JDK, TestNG, JUnit, ByteBuddy, or Selenium classes
52+
}
53+
54+
try {
55+
ClassReader cr = new ClassReader(classfileBuffer);
56+
ClassWriter cw = new ClassWriter(cr, ClassWriter.COMPUTE_MAXS);
57+
58+
WebDriverFieldVisitor visitor = new WebDriverFieldVisitor(Opcodes.ASM9, cw, className, loader);
59+
cr.accept(visitor, ClassReader.EXPAND_FRAMES);
60+
61+
if (visitor.wasTransformed()) {
62+
LOGGER.info("Transformed WebDriver fields in class: " + className.replace('/', '.'));
63+
return cw.toByteArray();
64+
}
65+
} catch (Exception e) {
66+
LOGGER.warning("Failed to transform class " + className + ": " + e.getMessage());
67+
}
68+
69+
return null;
70+
}
71+
72+
/**
73+
* ASM ClassVisitor that finds and transforms WebDriver field access.
74+
*/
75+
private static class WebDriverFieldVisitor extends ClassVisitor {
76+
77+
private final String className;
78+
private final ClassLoader loader;
79+
private boolean transformed = false;
80+
private final Map<String, String> webDriverFields = new ConcurrentHashMap<>();
81+
private String superClassName;
82+
83+
public WebDriverFieldVisitor(int api, ClassVisitor cv, String className, ClassLoader loader) {
84+
super(api, cv);
85+
this.className = className;
86+
this.loader = loader;
87+
}
88+
89+
@Override
90+
public void visit(int version, int access, String name, String signature, String superName, String[] interfaces) {
91+
this.superClassName = superName;
92+
super.visit(version, access, name, signature, superName, interfaces);
93+
}
94+
95+
@Override
96+
public FieldVisitor visitField(int access, String name, String descriptor, String signature, Object value) {
97+
// Check if field is WebDriver type
98+
if (descriptor.equals("Lorg/openqa/selenium/WebDriver;") ||
99+
descriptor.equals("Lorg/openqa/selenium/remote/RemoteWebDriver;")) {
100+
webDriverFields.put(name, descriptor);
101+
LOGGER.fine("Found WebDriver field: " + className + "." + name);
102+
}
103+
return super.visitField(access, name, descriptor, signature, value);
104+
}
105+
106+
@Override
107+
public MethodVisitor visitMethod(int access, String name, String descriptor, String signature, String[] exceptions) {
108+
MethodVisitor mv = super.visitMethod(access, name, descriptor, signature, exceptions);
109+
return new WebDriverFieldMethodVisitor(api, mv, className, this);
110+
}
111+
112+
public boolean wasTransformed() {
113+
return transformed || !webDriverFields.isEmpty();
114+
}
115+
116+
public void markTransformed() {
117+
transformed = true;
118+
}
119+
120+
public Map<String, String> getWebDriverFields() {
121+
return webDriverFields;
122+
}
123+
124+
public String getSuperClassName() {
125+
return superClassName;
126+
}
127+
}
128+
129+
/**
130+
* ASM MethodVisitor that intercepts GETFIELD and PUTFIELD instructions on WebDriver fields.
131+
*/
132+
private static class WebDriverFieldMethodVisitor extends MethodVisitor {
133+
134+
private final String className;
135+
private final WebDriverFieldVisitor parentVisitor;
136+
137+
public WebDriverFieldMethodVisitor(int api, MethodVisitor mv, String className,
138+
WebDriverFieldVisitor parentVisitor) {
139+
super(api, mv);
140+
this.className = className;
141+
this.parentVisitor = parentVisitor;
142+
}
143+
144+
@Override
145+
public void visitFieldInsn(int opcode, String owner, String name, String descriptor) {
146+
// Check if this is a WebDriver field access
147+
// Note: owner might be a parent class (inheritance), so we check descriptor only
148+
boolean isWebDriverField = descriptor.equals("Lorg/openqa/selenium/WebDriver;") ||
149+
descriptor.equals("Lorg/openqa/selenium/remote/RemoteWebDriver;");
150+
151+
if (isWebDriverField) {
152+
String key = resolveCanonicalFieldKey(owner, name);
153+
if (opcode == Opcodes.PUTFIELD) {
154+
// Intercept field write: driver = value
155+
// Stack before: [obj, value]
156+
// We need to call: ThreadLocalDriverStorage.setDriver(className, fieldName, value)
157+
parentVisitor.markTransformed();
158+
159+
// Swap value and object so we can remove object while keeping value
160+
mv.visitInsn(Opcodes.SWAP); // Stack: [value, obj]
161+
mv.visitInsn(Opcodes.POP); // Remove obj -> [value]
162+
163+
// Prepare arguments for setDriver(String, WebDriver)
164+
mv.visitLdcInsn(key); // Push field key -> [value, key]
165+
mv.visitInsn(Opcodes.SWAP); // -> [key, value]
166+
mv.visitMethodInsn(Opcodes.INVOKESTATIC,
167+
"com/lambdatest/selenium/agent/ThreadLocalDriverStorage",
168+
"setDriver",
169+
"(Ljava/lang/String;Lorg/openqa/selenium/WebDriver;)V",
170+
false);
171+
172+
return; // Don't call original PUTFIELD
173+
174+
} else if (opcode == Opcodes.GETFIELD) {
175+
// Intercept field read: value = driver
176+
// Stack before: [obj]
177+
// We need to call: ThreadLocalDriverStorage.getDriver(className, fieldName)
178+
parentVisitor.markTransformed();
179+
180+
mv.visitInsn(Opcodes.POP); // Remove obj from stack
181+
182+
// Call ThreadLocalDriverStorage.getDriver
183+
mv.visitLdcInsn(key); // Push field key
184+
mv.visitMethodInsn(Opcodes.INVOKESTATIC,
185+
"com/lambdatest/selenium/agent/ThreadLocalDriverStorage",
186+
"getDriver",
187+
"(Ljava/lang/String;)Lorg/openqa/selenium/WebDriver;",
188+
false);
189+
190+
return; // Don't call original GETFIELD
191+
}
192+
}
193+
194+
// Not a WebDriver field, use original instruction
195+
super.visitFieldInsn(opcode, owner, name, descriptor);
196+
}
197+
198+
private String resolveCanonicalFieldKey(String ownerInternalName, String fieldName) {
199+
// Use ASM to walk the class hierarchy without loading classes (avoids LinkageError)
200+
String declaringClass = findDeclaringClassViaASM(ownerInternalName, fieldName);
201+
if (declaringClass != null) {
202+
return declaringClass.replace('/', '.') + "." + fieldName;
203+
}
204+
return ownerInternalName.replace('/', '.') + "." + fieldName;
205+
}
206+
207+
private String findDeclaringClassViaASM(String startClass, String fieldName) {
208+
String current = startClass;
209+
ClassLoader cl = parentVisitor.loader;
210+
211+
// Walk up the hierarchy using ASM to avoid class loading
212+
while (current != null && !current.equals("java/lang/Object")) {
213+
try {
214+
// Read the class bytecode
215+
String resourceName = current + ".class";
216+
java.io.InputStream is = cl.getResourceAsStream(resourceName);
217+
if (is == null) {
218+
break;
219+
}
220+
221+
ClassReader cr = new ClassReader(is);
222+
FieldFinder finder = new FieldFinder(fieldName);
223+
cr.accept(finder, ClassReader.SKIP_CODE | ClassReader.SKIP_DEBUG | ClassReader.SKIP_FRAMES);
224+
225+
if (finder.fieldFound) {
226+
return current;
227+
}
228+
229+
current = finder.superName;
230+
} catch (Exception e) {
231+
break;
232+
}
233+
}
234+
return null;
235+
}
236+
237+
private static class FieldFinder extends ClassVisitor {
238+
private final String targetField;
239+
boolean fieldFound = false;
240+
String superName;
241+
242+
FieldFinder(String targetField) {
243+
super(Opcodes.ASM9);
244+
this.targetField = targetField;
245+
}
246+
247+
@Override
248+
public void visit(int version, int access, String name, String signature, String superName, String[] interfaces) {
249+
this.superName = superName;
250+
}
251+
252+
@Override
253+
public FieldVisitor visitField(int access, String name, String descriptor, String signature, Object value) {
254+
if (name.equals(targetField)) {
255+
fieldFound = true;
256+
}
257+
return null;
258+
}
259+
}
260+
}
261+
}
262+

0 commit comments

Comments
 (0)