Skip to content

Commit e2ffa00

Browse files
committed
Add GPU data dump hook for circuit gates + witness export
1 parent 464f687 commit e2ffa00

File tree

1 file changed

+72
-0
lines changed
  • expander_compiler/src/zkcuda/proving_system/expander_local_deferred

1 file changed

+72
-0
lines changed

expander_compiler/src/zkcuda/proving_system/expander_local_deferred/api.rs

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -181,6 +181,10 @@ fn prove_one<C: GKREngine, ECCConfig: Config<FieldConfig = C::FieldConfig>>(
181181
}).collect();
182182
&mut _sps_vec
183183
};
184+
// Optional: dump circuit data for GPU prover
185+
if std::env::var("DUMP_GPU_DATA").is_ok() {
186+
dump_circuits_for_gpu(ti, pc, &tc, &circuits);
187+
}
184188
let t1 = std::time::Instant::now();
185189
let (cv, ch) = gkr_prove_batch(&circuits, sps, &mut tr);
186190
// Drop batch circuits without freeing shared gates or aliased buffers
@@ -223,3 +227,71 @@ fn prove_one<C: GKREngine, ECCConfig: Config<FieldConfig = C::FieldConfig>>(
223227
ExpanderProof { data: vec![tr.finalize_and_get_proof()] }
224228
}
225229
}
230+
231+
/// Dump circuit data (gates + witness) to binary files for GPU prover.
232+
/// Each template gets its own directory: gpu_data/tmpl_{ti}/
233+
fn dump_circuits_for_gpu<C: GKREngine>(
234+
ti: usize,
235+
pc: usize,
236+
template_circuit: &expander_circuit::Circuit<C>,
237+
circuits: &[expander_circuit::Circuit<C>],
238+
) {
239+
use std::io::Write;
240+
let dir = format!("gpu_data/tmpl_{}", ti);
241+
std::fs::create_dir_all(&dir).ok();
242+
243+
let num_layers = template_circuit.layers.len();
244+
245+
// Write header: N, num_layers, per-layer sizes
246+
let mut hdr = std::fs::File::create(format!("{}/header.bin", dir)).unwrap();
247+
hdr.write_all(&(pc as u32).to_le_bytes()).unwrap();
248+
hdr.write_all(&(num_layers as u32).to_le_bytes()).unwrap();
249+
for layer in &template_circuit.layers {
250+
hdr.write_all(&(layer.input_var_num as u32).to_le_bytes()).unwrap();
251+
hdr.write_all(&(layer.output_var_num as u32).to_le_bytes()).unwrap();
252+
hdr.write_all(&(layer.mul.len() as u32).to_le_bytes()).unwrap();
253+
hdr.write_all(&(layer.add.len() as u32).to_le_bytes()).unwrap();
254+
}
255+
256+
// Write gates per layer (shared across all instances)
257+
for (li, layer) in template_circuit.layers.iter().enumerate() {
258+
// Mul gates: [o_id, x_id, y_id, coef] x n_mul
259+
let mut gf = std::fs::File::create(format!("{}/layer_{}_mul.bin", dir, li)).unwrap();
260+
for gate in &layer.mul {
261+
gf.write_all(&(gate.o_id as u32).to_le_bytes()).unwrap();
262+
gf.write_all(&(gate.i_ids[0] as u32).to_le_bytes()).unwrap();
263+
gf.write_all(&(gate.i_ids[1] as u32).to_le_bytes()).unwrap();
264+
// coef is M31, extract .v field via unsafe transmute
265+
let coef_bytes: [u8; 4] = unsafe { std::mem::transmute(gate.coef) };
266+
gf.write_all(&coef_bytes).unwrap();
267+
}
268+
269+
// Add gates: [o_id, x_id, coef] x n_add
270+
let mut af = std::fs::File::create(format!("{}/layer_{}_add.bin", dir, li)).unwrap();
271+
for gate in &layer.add {
272+
af.write_all(&(gate.o_id as u32).to_le_bytes()).unwrap();
273+
af.write_all(&(gate.i_ids[0] as u32).to_le_bytes()).unwrap();
274+
let coef_bytes: [u8; 4] = unsafe { std::mem::transmute(gate.coef) };
275+
af.write_all(&coef_bytes).unwrap();
276+
}
277+
}
278+
279+
// Write witness (input_vals) per instance per layer
280+
// Each instance's layer 0 input_vals = the actual witness
281+
// Format: raw M31x16 values as [u32; 16] per element
282+
for (pi, circuit) in circuits.iter().enumerate() {
283+
let mut wf = std::fs::File::create(format!("{}/witness_{}.bin", dir, pi)).unwrap();
284+
// Only dump layer 0 input_vals (other layers computed by evaluate())
285+
let vals = &circuit.layers[0].input_vals;
286+
// SimdCircuitField = M31x16, each is [M31; 16] = [u32; 16] = 64 bytes
287+
let bytes: &[u8] = unsafe {
288+
std::slice::from_raw_parts(
289+
vals.as_ptr() as *const u8,
290+
vals.len() * std::mem::size_of_val(&vals[0]),
291+
)
292+
};
293+
wf.write_all(bytes).unwrap();
294+
}
295+
296+
eprintln!(" [dump] tmpl[{}] N={} layers={} → {}/", ti, pc, num_layers, dir);
297+
}

0 commit comments

Comments
 (0)