Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
183 changes: 172 additions & 11 deletions src/compile/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -486,26 +486,37 @@ fn list_fold<'brand>(
f_array: &ProgNode<'brand>,
f_fold: &ProgNode<'brand>,
) -> Result<ProgNode<'brand>, simplicity::types::Error> {
/* (fold f)_(n + 1) : E<2^(n + 1) × A → A
/* (fold f)_(n + 1) : E^<2^(n + 1) × A → A
* (fold f)_(n + 1) := OOH ▵ (OIH ▵ IH);
* case (drop (fold f)_n)
* ((IOH ▵ (OH ▵ IIH; f_n)); (fold f)_n)
* (IOH ▵ case IIH (OH ▵ IIH; f_n));
* (fold f)_n
*
* Uses f_fold exactly once (no DAG sharing) to avoid incorrect pruning.
*/
let ctx = f_array.inference_context();
// Rearrange input: ((1+E^n) × E^<n) × A → (1+E^n, (E^<n, A))
let case_input = ProgNode::o()
.o()
.h(ctx)
.pair(ProgNode::o().i().h(ctx).pair(ProgNode::i().h(ctx)));
let case_left = ProgNode::drop_(f_fold);

let f_n_input = ProgNode::o().h(ctx).pair(ProgNode::i().i().h(ctx));
let f_n_output = f_n_input.comp(f_array)?;
let fold_n_input = ProgNode::i().o().h(ctx).pair(f_n_output);
let case_right = fold_n_input.comp(f_fold)?;
// acc' = case IIH (OH ▵ IIH; f_n)
// Left(()): head empty, keep acc unchanged (IIH)
let left_branch = ProgNode::i().i().h(ctx);
// Right(h): apply f_n(h, acc) = (OH ▵ IIH); f_n
let right_branch = ProgNode::o()
.h(ctx)
.pair(ProgNode::i().i().h(ctx))
.comp(f_array)?;
let acc_prime = ProgNode::case(left_branch.as_ref(), right_branch.as_ref())?;

// fold_n_input = IOH ▵ acc': gives (tail, acc')
let ioh = ProgNode::i().o().h(ctx).build();
let fold_n_input = ProgNode::pair(&ioh, &acc_prime)?;

case_input
.comp(&ProgNode::case(&case_left, case_right.as_ref())?)
.map(PairBuilder::build)
// Compose: case_input; (IOH ▵ acc'); f_fold
let middle = ProgNode::comp(&fold_n_input, f_fold)?;
case_input.comp(&middle).map(PairBuilder::build)
}

/* f_0 : E × A → A
Expand Down Expand Up @@ -680,3 +691,153 @@ impl Match {
input.comp(&output).with_span(self)
}
}

#[cfg(test)]
mod tests {
use std::borrow::Cow;

use crate::{tests::TestCase, WitnessValues};

// Helper: fold with addition, f(elt, acc) = acc + elt
const ADD_FN: &str = r#"fn add(elt: u32, acc: u32) -> u32 {
let (_, sum): (bool, u32) = jet::add_32(elt, acc);
sum
}
"#;

// Helper: fold that ignores the accumulator and returns the element.
// When folded left-to-right over a non-empty list, the final result equals the LAST element.
const LAST_FN: &str = r#"fn last(elt: u32, _acc: u32) -> u32 {
elt
}
"#;

fn make_prog(fns: &str, body: &str) -> String {
format!("{fns}\nfn main() {{\n{body}\n}}\n")
}

fn run(prog: &str) {
TestCase::program_text(Cow::Owned(prog.to_owned()))
.with_witness_values(WitnessValues::default())
.assert_run_success();
}

// ── bound = 2 (list holds 0 or 1 elements) ──────────────────────────

#[test]
fn list_fold_empty_bound2() {
// Empty list: fold must return the initial accumulator unchanged.
run(&make_prog(
ADD_FN,
r#" let list: List<u32, 2> = list![];
let result: u32 = fold::<add, 2>(list, 77);
assert!(jet::eq_32(result, 77));"#,
));
}

#[test]
fn list_fold_single_element_bound2() {
// One-element list: fold must apply f exactly once.
run(&make_prog(
ADD_FN,
r#" let list: List<u32, 2> = list![42];
let result: u32 = fold::<add, 2>(list, 0);
assert!(jet::eq_32(result, 42));"#,
));
}

// ── bound = 4 (list holds 0–3 elements) ──────────────────────────────

#[test]
fn list_fold_empty_bound4() {
// Empty list with larger bound: must still return initial accumulator.
run(&make_prog(
ADD_FN,
r#" let list: List<u32, 4> = list![];
let result: u32 = fold::<add, 4>(list, 99);
assert!(jet::eq_32(result, 99));"#,
));
}

#[test]
fn list_fold_one_element_bound4() {
// Partition: head-block empty, tail = [5].
run(&make_prog(
ADD_FN,
r#" let list: List<u32, 4> = list![5];
let result: u32 = fold::<add, 4>(list, 0);
assert!(jet::eq_32(result, 5));"#,
));
}

#[test]
fn list_fold_two_elements_bound4() {
// Partition: head-block = [10, 20], tail empty.
run(&make_prog(
ADD_FN,
r#" let list: List<u32, 4> = list![10, 20];
let result: u32 = fold::<add, 4>(list, 0);
assert!(jet::eq_32(result, 30));"#,
));
}

#[test]
fn list_fold_three_elements_bound4() {
// Full list for bound=4: head=[1,2], tail=[3].
run(&make_prog(
ADD_FN,
r#" let list: List<u32, 4> = list![1, 2, 3];
let result: u32 = fold::<add, 4>(list, 0);
assert!(jet::eq_32(result, 6));"#,
));
}

// ── bound = 8 (list holds 0–7 elements) ──────────────────────────────

#[test]
fn list_fold_five_elements_bound8() {
// Partition: outer-head=[1,2,3,4], outer-tail partition [5].
run(&make_prog(
ADD_FN,
r#" let list: List<u32, 8> = list![1, 2, 3, 4, 5];
let result: u32 = fold::<add, 8>(list, 0);
assert!(jet::eq_32(result, 15));"#,
));
}

#[test]
fn list_fold_seven_elements_bound8() {
// Full list for bound=8.
run(&make_prog(
ADD_FN,
r#" let list: List<u32, 8> = list![1, 2, 3, 4, 5, 6, 7];
let result: u32 = fold::<add, 8>(list, 0);
assert!(jet::eq_32(result, 28));"#,
));
}

// ── Ordering tests ────────────────────────────────────────────────────
// `last` ignores the accumulator and returns the element itself.
// A correct left-to-right fold over [a, b, c] returns the *last* element c.
// A right-to-left fold would return the *first* element a instead.

#[test]
fn list_fold_order_bound4() {
run(&make_prog(
LAST_FN,
r#" let list: List<u32, 4> = list![1, 2, 3];
let result: u32 = fold::<last, 4>(list, 0);
assert!(jet::eq_32(result, 3));"#,
));
}

#[test]
fn list_fold_order_bound8() {
run(&make_prog(
LAST_FN,
r#" let list: List<u32, 8> = list![1, 2, 3, 4, 5, 6, 7];
let result: u32 = fold::<last, 8>(list, 0);
assert!(jet::eq_32(result, 7));"#,
));
}
}
Loading