diff --git a/.github/workflows/release-rust.yml b/.github/workflows/release-rust.yml index 27a77b38..a507d4b0 100644 --- a/.github/workflows/release-rust.yml +++ b/.github/workflows/release-rust.yml @@ -31,10 +31,9 @@ jobs: - os: macos-latest target: x86_64-apple-darwin suffix: darwin-x64 - # TODO: Windows support needs cross-platform IPC (see feat/windows-support branch) - # - os: windows-latest - # target: x86_64-pc-windows-msvc - # suffix: windows-x64 + - os: windows-latest + target: x86_64-pc-windows-msvc + suffix: windows-x64 runs-on: ${{ matrix.os }} @@ -49,17 +48,19 @@ jobs: if: runner.os == 'macOS' run: brew install protobuf - # - name: Install protoc (Windows) - # if: runner.os == 'Windows' - # uses: arduino/setup-protoc@v3 - # with: - # repo-token: ${{ secrets.GITHUB_TOKEN }} + - name: Install protoc (Windows) + if: runner.os == 'Windows' + uses: arduino/setup-protoc@v3 + with: + repo-token: ${{ secrets.GITHUB_TOKEN }} - name: Install Rust toolchain uses: dtolnay/rust-toolchain@stable with: targets: ${{ matrix.target }} + - uses: Swatinem/rust-cache@v2 + - name: Install cross-compilation tools if: matrix.cross run: cargo install cross --git https://github.com/cross-rs/cross @@ -73,12 +74,22 @@ jobs: run: cross build --release --target ${{ matrix.target }} -p ahandd -p ahandctl - name: Prepare artifacts + if: runner.os != 'Windows' run: | mkdir -p release cp target/${{ matrix.target }}/release/ahandd release/ahandd-${{ matrix.suffix }} cp target/${{ matrix.target }}/release/ahandctl release/ahandctl-${{ matrix.suffix }} cd release && shasum -a 256 * > checksums-rust-${{ matrix.suffix }}.txt + - name: Prepare artifacts (Windows) + if: runner.os == 'Windows' + shell: bash + run: | + mkdir -p release + cp target/${{ matrix.target }}/release/ahandd.exe release/ahandd-${{ matrix.suffix }}.exe + cp target/${{ matrix.target }}/release/ahandctl.exe release/ahandctl-${{ matrix.suffix }}.exe + cd release && sha256sum * > checksums-rust-${{ matrix.suffix }}.txt + - name: Upload artifacts uses: actions/upload-artifact@v4 with: diff --git a/.github/workflows/test-rust.yml b/.github/workflows/test-rust.yml new file mode 100644 index 00000000..6a2d39ec --- /dev/null +++ b/.github/workflows/test-rust.yml @@ -0,0 +1,45 @@ +name: Test Rust + +on: + push: + branches: [dev, main] + paths: + - "crates/**" + - "proto/**" + - "Cargo.*" + pull_request: + paths: + - "crates/**" + - "proto/**" + - "Cargo.*" + +jobs: + test: + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest, macos-latest, windows-latest] + runs-on: ${{ matrix.os }} + steps: + - uses: actions/checkout@v4 + + - name: Install system dependencies (Linux) + if: runner.os == 'Linux' + run: sudo apt-get update && sudo apt-get install -y protobuf-compiler libssl-dev pkg-config + + - name: Install protoc (macOS) + if: runner.os == 'macOS' + run: brew install protobuf + + - name: Install protoc (Windows) + if: runner.os == 'Windows' + uses: arduino/setup-protoc@v3 + with: + repo-token: ${{ secrets.GITHUB_TOKEN }} + + - uses: dtolnay/rust-toolchain@stable + + - uses: Swatinem/rust-cache@v2 + + - name: Run tests + run: cargo test --workspace diff --git a/Cargo.lock b/Cargo.lock index f8bce4da..4c43d3a2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8,6 +8,17 @@ version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "320119579fcad9c21884f5c4861d16174d0e06250625266f50fe6898340abefa" +[[package]] +name = "aes" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b169f7a6d4742236a0a00c541b845991d0ac43e546831af1249753ab4c3aa3a0" +dependencies = [ + "cfg-if", + "cipher", + "cpufeatures", +] + [[package]] name = "ahand-hub" version = "0.1.2" @@ -96,8 +107,10 @@ dependencies = [ "open", "prost", "rand 0.8.5", + "reqwest", "serde", "serde_json", + "sha2", "tokio", "tokio-tungstenite 0.24.0", "toml", @@ -119,13 +132,17 @@ dependencies = [ "flate2", "futures-util", "gethostname", + "hex", + "libc", "prost", "rand 0.8.5", "reqwest", + "semver", "serde", "serde_json", "sha2", "tar", + "tempfile", "thiserror 2.0.18", "tokio", "tokio-tungstenite 0.24.0", @@ -135,7 +152,9 @@ dependencies = [ "tracing-subscriber", "url", "uuid", + "windows-sys 0.59.0", "xz2", + "zip", ] [[package]] @@ -218,6 +237,15 @@ version = "1.0.100" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a23eb6b1614318a8071c9b2521f36b424b2c83db5eb3a0fead4a6c0809af6e61" +[[package]] +name = "arbitrary" +version = "1.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c3d036a3c4ab069c7b410a2ce876bd74808d2d0888a82667669f8e783a898bf1" +dependencies = [ + "derive_arbitrary", +] + [[package]] name = "arc-swap" version = "1.9.0" @@ -488,6 +516,25 @@ version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b35204fbdc0b3f4446b89fc1ac2cf84a8a68971995d0bf2e925ec7cd960f9cb3" +[[package]] +name = "bzip2" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49ecfb22d906f800d4fe833b6282cf4dc1c298f5057ca0b5445e5c209735ca47" +dependencies = [ + "bzip2-sys", +] + +[[package]] +name = "bzip2-sys" +version = "0.1.13+1.0.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "225bff33b2141874fe80d71e07d6eec4f85c5c216453dd96388240f96e1acc14" +dependencies = [ + "cc", + "pkg-config", +] + [[package]] name = "cc" version = "1.2.55" @@ -495,6 +542,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "47b26a0954ae34af09b50f0de26458fa95369a0d478d8236d3f93082b219bd29" dependencies = [ "find-msvc-tools", + "jobserver", + "libc", "shlex", ] @@ -504,6 +553,12 @@ version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801" +[[package]] +name = "cfg_aliases" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "613afe47fcd5fac7ccf1db93babcb082c5994d996f20b8b159f2ad1658eb5724" + [[package]] name = "chrono" version = "0.4.44" @@ -518,6 +573,16 @@ dependencies = [ "windows-link", ] +[[package]] +name = "cipher" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "773f3b9af64447d2ce9850330c473515014aa235e6a783b02db81ff39e4a3dad" +dependencies = [ + "crypto-common", + "inout", +] + [[package]] name = "clap" version = "4.5.56" @@ -593,6 +658,12 @@ version = "0.9.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c2459377285ad874054d797f3ccebf984978aa39129f6eafde5cdc8315b612f8" +[[package]] +name = "constant_time_eq" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7c74b8349d32d297c9134b8c88677813a227df8f779daa29bfc29c183fe3dca6" + [[package]] name = "core-foundation" version = "0.9.4" @@ -758,6 +829,12 @@ version = "2.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d7a1e2f27636f116493b8b860f5546edb47c8d8f8ea73e1d2a20be88e28d1fea" +[[package]] +name = "deflate64" +version = "0.1.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac6b926516df9c60bfa16e107b21086399f8285a44ca9711344b9e553c5146e2" + [[package]] name = "der" version = "0.7.10" @@ -779,6 +856,17 @@ dependencies = [ "serde_core", ] +[[package]] +name = "derive_arbitrary" +version = "1.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e567bd82dcff979e4b03460c307b3cdc9e96fde3d73bed1496d2bc75d9dd62a" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "digest" version = "0.10.7" @@ -1191,9 +1279,11 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "899def5c37c4fd7b2664648c28120ecec138e4d395b459e5ca34f9cce2dd77fd" dependencies = [ "cfg-if", + "js-sys", "libc", "r-efi", "wasip2", + "wasm-bindgen", ] [[package]] @@ -1484,6 +1574,7 @@ dependencies = [ "tokio", "tokio-rustls", "tower-service", + "webpki-roots 1.0.6", ] [[package]] @@ -1697,6 +1788,15 @@ dependencies = [ "serde_core", ] +[[package]] +name = "inout" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "879f10e63c20629ecabbb64a8010319738c66a5cd0c29b02d63d272b03751d01" +dependencies = [ + "generic-array", +] + [[package]] name = "ipnet" version = "2.12.0" @@ -1753,6 +1853,16 @@ version = "1.0.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "92ecc6618181def0457392ccd0ee51198e065e016d1d527a7ac1b6dc7c1f09d2" +[[package]] +name = "jobserver" +version = "0.1.34" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9afb3de4395d6b3e67a780b6de64b51c978ecf11cb9a462c66be7d4ca9039d33" +dependencies = [ + "getrandom 0.3.4", + "libc", +] + [[package]] name = "js-sys" version = "0.3.85" @@ -1847,6 +1957,22 @@ version = "0.4.29" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897" +[[package]] +name = "lru-slab" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "112b39cec0b298b6c1999fee3e31427f74f676e4cb9879ed1a121b43661a4154" + +[[package]] +name = "lzma-rs" +version = "0.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "297e814c836ae64db86b36cf2a557ba54368d03f6afcd7d947c266692f71115e" +dependencies = [ + "byteorder", + "crc", +] + [[package]] name = "lzma-sys" version = "0.1.20" @@ -2179,6 +2305,16 @@ version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "df94ce210e5bc13cb6651479fa48d14f601d9858cfe0467f43ae157023b938d3" +[[package]] +name = "pbkdf2" +version = "0.12.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f8ed6a7761f76e3b9f92dfb0a60a6a6477c61024b775147ff0973a02653abaf2" +dependencies = [ + "digest", + "hmac", +] + [[package]] name = "pem" version = "3.0.6" @@ -2432,6 +2568,61 @@ version = "3.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "95067976aca6421a523e491fce939a3e65249bac4b977adee0ee9771568e8aa3" +[[package]] +name = "quinn" +version = "0.11.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9e20a958963c291dc322d98411f541009df2ced7b5a4f2bd52337638cfccf20" +dependencies = [ + "bytes", + "cfg_aliases", + "pin-project-lite", + "quinn-proto", + "quinn-udp", + "rustc-hash", + "rustls", + "socket2 0.5.10", + "thiserror 2.0.18", + "tokio", + "tracing", + "web-time", +] + +[[package]] +name = "quinn-proto" +version = "0.11.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1906b49b0c3bc04b5fe5d86a77925ae6524a19b816ae38ce1e426255f1d8a31" +dependencies = [ + "bytes", + "getrandom 0.3.4", + "lru-slab", + "rand 0.9.2", + "ring", + "rustc-hash", + "rustls", + "rustls-pki-types", + "slab", + "thiserror 2.0.18", + "tinyvec", + "tracing", + "web-time", +] + +[[package]] +name = "quinn-udp" +version = "0.5.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "addec6a0dcad8a8d96a771f815f0eaf55f9d1805756410b39f5fa81332574cbd" +dependencies = [ + "cfg_aliases", + "libc", + "once_cell", + "socket2 0.5.10", + "tracing", + "windows-sys 0.52.0", +] + [[package]] name = "quote" version = "1.0.44" @@ -2653,6 +2844,8 @@ dependencies = [ "native-tls", "percent-encoding", "pin-project-lite", + "quinn", + "rustls", "rustls-pki-types", "serde", "serde_json", @@ -2660,6 +2853,7 @@ dependencies = [ "sync_wrapper", "tokio", "tokio-native-tls", + "tokio-rustls", "tokio-util", "tower", "tower-http", @@ -2669,6 +2863,7 @@ dependencies = [ "wasm-bindgen-futures", "wasm-streams", "web-sys", + "webpki-roots 1.0.6", ] [[package]] @@ -2705,6 +2900,12 @@ dependencies = [ "zeroize", ] +[[package]] +name = "rustc-hash" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d" + [[package]] name = "rustc_version" version = "0.4.1" @@ -2768,6 +2969,7 @@ version = "1.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "be040f8b0a225e40375822a563fa9524378b9d63112f53e19ffff34df5d33fdd" dependencies = [ + "web-time", "zeroize", ] @@ -4181,6 +4383,16 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "web-time" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5a6580f308b1fad9207618087a65c04e7a10bc77e02c8e84e9b00dd4b12fa0bb" +dependencies = [ + "js-sys", + "wasm-bindgen", +] + [[package]] name = "webpki-roots" version = "0.26.11" @@ -4641,6 +4853,20 @@ name = "zeroize" version = "1.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b97154e67e32c85465826e8bcc1c59429aaaf107c1e4a9e53c8d8ccd5eff88d0" +dependencies = [ + "zeroize_derive", +] + +[[package]] +name = "zeroize_derive" +version = "1.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "85a5b4158499876c763cb03bc4e49185d3cccbabb15b33c627f7884f43db852e" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] [[package]] name = "zerotrie" @@ -4675,8 +4901,78 @@ dependencies = [ "syn", ] +[[package]] +name = "zip" +version = "2.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fabe6324e908f85a1c52063ce7aa26b68dcb7eb6dbc83a2d148403c9bc3eba50" +dependencies = [ + "aes", + "arbitrary", + "bzip2", + "constant_time_eq", + "crc32fast", + "crossbeam-utils", + "deflate64", + "displaydoc", + "flate2", + "getrandom 0.3.4", + "hmac", + "indexmap 2.13.0", + "lzma-rs", + "memchr", + "pbkdf2", + "sha1", + "thiserror 2.0.18", + "time", + "xz2", + "zeroize", + "zopfli", + "zstd", +] + [[package]] name = "zmij" version = "1.0.19" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3ff05f8caa9038894637571ae6b9e29466c1f4f829d26c9b28f869a29cbe3445" + +[[package]] +name = "zopfli" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f05cd8797d63865425ff89b5c4a48804f35ba0ce8d125800027ad6017d2b5249" +dependencies = [ + "bumpalo", + "crc32fast", + "log", + "simd-adler32", +] + +[[package]] +name = "zstd" +version = "0.13.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e91ee311a569c327171651566e07972200e76fcfe2242a4fa446149a3881c08a" +dependencies = [ + "zstd-safe", +] + +[[package]] +name = "zstd-safe" +version = "7.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f49c4d5f0abb602a93fb8736af2a4f4dd9512e36f7f570d66e65ff867ed3b9d" +dependencies = [ + "zstd-sys", +] + +[[package]] +name = "zstd-sys" +version = "2.0.16+zstd.1.5.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "91e19ebc2adc8f83e43039e79776e3fda8ca919132d68a1fed6a5faca2683748" +dependencies = [ + "cc", + "pkg-config", +] diff --git a/crates/ahandctl/Cargo.toml b/crates/ahandctl/Cargo.toml index da9e3cd8..39fa4924 100644 --- a/crates/ahandctl/Cargo.toml +++ b/crates/ahandctl/Cargo.toml @@ -24,3 +24,5 @@ open = "5" toml = "0.8" rand = "0.8" dirs = "5" +reqwest = { version = "0.12", default-features = false, features = ["rustls-tls", "json"] } +sha2 = "0.10" diff --git a/crates/ahandctl/src/admin.rs b/crates/ahandctl/src/admin.rs index df468634..01196784 100644 --- a/crates/ahandctl/src/admin.rs +++ b/crates/ahandctl/src/admin.rs @@ -3,7 +3,6 @@ use serde::Serialize; use std::convert::Infallible; use std::path::{Path, PathBuf}; use std::sync::Arc; -use tokio::io::{AsyncBufReadExt, BufReader}; use warp::http::StatusCode; use warp::{Filter, Rejection, Reply, reject}; @@ -356,83 +355,12 @@ fn browser_init_stream() -> impl futures_util::Stream> { let (tx, rx) = tokio::sync::mpsc::unbounded_channel::(); + // Instead of spawning bash, direct user to the Rust-native command tokio::spawn(async move { - let home = match dirs::home_dir() { - Some(h) => h, - None => { - let _ = tx.send( - warp::sse::Event::default() - .data(r#"{"line":"ERROR: Failed to find home directory","status":"error","exit_code":1}"#), - ); - return; - } - }; - - let script_path = home.join(".ahand").join("bin").join("setup-browser.sh"); - if !script_path.exists() { - let msg = format!( - r#"{{"line":"ERROR: setup-browser.sh not found at {}","status":"error","exit_code":1}}"#, - script_path.display() - ); - let _ = tx.send(warp::sse::Event::default().data(msg)); - return; - } - - let mut cmd = tokio::process::Command::new("bash"); - cmd.arg(&script_path); - cmd.arg("--from-release"); - cmd.stdout(std::process::Stdio::piped()); - cmd.stderr(std::process::Stdio::piped()); - - let mut child = match cmd.spawn() { - Ok(c) => c, - Err(e) => { - let msg = format!( - r#"{{"line":"ERROR: Failed to spawn setup-browser.sh: {}","status":"error","exit_code":1}}"#, - e - ); - let _ = tx.send(warp::sse::Event::default().data(msg)); - return; - } - }; - - let stdout = child.stdout.take().expect("stdout"); - let stderr = child.stderr.take().expect("stderr"); - - let tx_out = tx.clone(); - let stdout_task = tokio::spawn(async move { - let reader = BufReader::new(stdout); - let mut lines = reader.lines(); - while let Ok(Some(line)) = lines.next_line().await { - let escaped = line.replace('\\', "\\\\").replace('"', "\\\""); - let data = format!(r#"{{"line":"{}"}}"#, escaped); - if tx_out.send(warp::sse::Event::default().data(data)).is_err() { - break; - } - } - }); - - let tx_err = tx.clone(); - let stderr_task = tokio::spawn(async move { - let reader = BufReader::new(stderr); - let mut lines = reader.lines(); - while let Ok(Some(line)) = lines.next_line().await { - let escaped = line.replace('\\', "\\\\").replace('"', "\\\""); - let data = format!(r#"{{"line":"[stderr] {}"}}"#, escaped); - if tx_err.send(warp::sse::Event::default().data(data)).is_err() { - break; - } - } - }); - - let status = child.wait().await; - let _ = stdout_task.await; - let _ = stderr_task.await; - - let exit_code = status.map(|s| s.code().unwrap_or(1)).unwrap_or(1); - let status_str = if exit_code == 0 { "done" } else { "error" }; - let data = format!(r#"{{"status":"{}","exit_code":{}}}"#, status_str, exit_code); - let _ = tx.send(warp::sse::Event::default().data(data)); + let msg = r#"{"line":"Browser setup is now built into the daemon. Run: ahandd browser-init","status":"info"}"#; + let _ = tx.send(warp::sse::Event::default().data(msg)); + let msg = r#"{"line":"For force reinstall: ahandd browser-init --force","status":"complete","exit_code":0}"#; + let _ = tx.send(warp::sse::Event::default().data(msg)); }); futures_util::stream::unfold(rx, |mut rx| async move { @@ -732,7 +660,7 @@ fn is_process_running(pid: u32) -> bool { std::path::Path::new(&format!("/proc/{}", pid)).exists() } -#[cfg(not(target_os = "linux"))] +#[cfg(all(not(target_os = "linux"), unix))] fn is_process_running(pid: u32) -> bool { use std::process::Command; Command::new("ps") @@ -741,3 +669,17 @@ fn is_process_running(pid: u32) -> bool { .map(|output| output.status.success()) .unwrap_or(false) } + +#[cfg(windows)] +fn is_process_running(pid: u32) -> bool { + use std::process::Command; + Command::new("tasklist") + .args(["/FI", &format!("PID eq {}", pid), "/NH"]) + .output() + .map(|output| { + let stdout = String::from_utf8_lossy(&output.stdout); + output.status.success() + && stdout.split_whitespace().any(|w| w == pid.to_string().as_str()) + }) + .unwrap_or(false) +} diff --git a/crates/ahandctl/src/daemon.rs b/crates/ahandctl/src/daemon.rs index 767d8bfc..5b21d66d 100644 --- a/crates/ahandctl/src/daemon.rs +++ b/crates/ahandctl/src/daemon.rs @@ -16,9 +16,11 @@ fn get_log_path() -> Result { /// Find the ahandd binary: installed path → sibling of current exe → error. fn find_ahandd_binary() -> Result { + let binary_name = if cfg!(windows) { "ahandd.exe" } else { "ahandd" }; + // 1. Installed location: ~/.ahand/bin/ahandd if let Some(home) = dirs::home_dir() { - let installed = home.join(".ahand").join("bin").join("ahandd"); + let installed = home.join(".ahand").join("bin").join(binary_name); if installed.exists() { return Ok(installed); } @@ -27,7 +29,7 @@ fn find_ahandd_binary() -> Result { // 2. Sibling of current executable (dev builds: target/debug/) if let Ok(current_exe) = std::env::current_exe() { if let Some(dir) = current_exe.parent() { - let sibling = dir.join("ahandd"); + let sibling = dir.join(binary_name); if sibling.exists() { return Ok(sibling); } @@ -62,7 +64,7 @@ fn is_process_running(pid: u32) -> bool { std::path::Path::new(&format!("/proc/{}", pid)).exists() } -#[cfg(not(target_os = "linux"))] +#[cfg(all(not(target_os = "linux"), unix))] fn is_process_running(pid: u32) -> bool { std::process::Command::new("ps") .args(["-p", &pid.to_string()]) @@ -73,6 +75,21 @@ fn is_process_running(pid: u32) -> bool { .unwrap_or(false) } +#[cfg(windows)] +fn is_process_running(pid: u32) -> bool { + std::process::Command::new("tasklist") + .args(["/FI", &format!("PID eq {}", pid), "/NH"]) + .output() + .map(|output| { + let stdout = String::from_utf8_lossy(&output.stdout); + // Locale-independent: check if PID appears as a word in output + output.status.success() + && stdout.split_whitespace().any(|w| w == pid.to_string().as_str()) + }) + .unwrap_or(false) +} + +#[cfg(unix)] fn send_signal(pid: u32, sig: &str) -> Result<()> { let status = std::process::Command::new("kill") .args([sig, &pid.to_string()]) @@ -84,6 +101,18 @@ fn send_signal(pid: u32, sig: &str) -> Result<()> { Ok(()) } +#[cfg(windows)] +fn send_signal(pid: u32, _sig: &str) -> Result<()> { + let status = std::process::Command::new("taskkill") + .args(["/PID", &pid.to_string(), "/F"]) + .status() + .context("Failed to run taskkill command")?; + if !status.success() { + anyhow::bail!("taskkill /PID {} failed", pid); + } + Ok(()) +} + pub async fn start(config: Option) -> Result<()> { if let Some(pid) = read_running_pid()? { println!("Daemon is already running (PID {}).", pid); @@ -120,6 +149,12 @@ pub async fn start(config: Option) -> Result<()> { cmd.process_group(0); } + #[cfg(windows)] + { + use std::os::windows::process::CommandExt; + cmd.creation_flags(0x08000008); // CREATE_NO_WINDOW | DETACHED_PROCESS + } + let child = cmd .spawn() .with_context(|| format!("Failed to start daemon: {}", ahandd.display()))?; @@ -193,3 +228,75 @@ pub async fn status() -> Result<()> { } Ok(()) } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn get_data_dir_returns_ahand_data() { + let dir = get_data_dir().unwrap(); + assert!(dir.to_string_lossy().contains(".ahand")); + assert!(dir.to_string_lossy().ends_with("data")); + } + + #[test] + fn get_pid_path_under_data_dir() { + let pid = get_pid_path().unwrap(); + assert!(pid.to_string_lossy().ends_with("daemon.pid")); + } + + #[test] + fn get_log_path_under_data_dir() { + let log = get_log_path().unwrap(); + assert!(log.to_string_lossy().ends_with("daemon.log")); + } + + #[test] + fn is_process_running_with_zero_pid() { + // PID 0 should not be running (it's the system idle process / kernel) + assert!(!is_process_running(0)); + } + + #[test] + fn is_process_running_with_current_pid() { + let pid = std::process::id(); + assert!(is_process_running(pid)); + } + + #[test] + fn is_process_running_with_nonexistent_pid() { + // Very high PID unlikely to exist + assert!(!is_process_running(4_000_000)); + } + + #[test] + fn read_running_pid_no_pid_file() { + // This test depends on whether there's actually a PID file. + // If daemon is not running, should return Ok(None) or Ok(Some(pid)). + // At minimum, it should not panic. + let result = read_running_pid(); + assert!(result.is_ok()); + } + + #[test] + fn find_ahandd_binary_does_not_panic() { + // May succeed or fail depending on environment, + // but should never panic. + let _result = find_ahandd_binary(); + } + + #[cfg(unix)] + #[test] + fn send_signal_to_nonexistent_process() { + let result = send_signal(4_000_000, "-0"); + assert!(result.is_err()); + } + + #[cfg(windows)] + #[test] + fn send_signal_to_nonexistent_process_windows() { + let result = send_signal(4_000_000, "-TERM"); + assert!(result.is_err()); + } +} diff --git a/crates/ahandctl/src/github_release.rs b/crates/ahandctl/src/github_release.rs new file mode 100644 index 00000000..df3a8ef5 --- /dev/null +++ b/crates/ahandctl/src/github_release.rs @@ -0,0 +1,91 @@ +use anyhow::{Context, Result}; + +pub const GITHUB_REPO: &str = "team9ai/aHand"; + +/// Fetch the latest release version from GitHub (strips the `rust-v` prefix). +pub async fn fetch_latest_version() -> Result { + let url = format!("https://api.github.com/repos/{GITHUB_REPO}/releases/latest"); + let client = reqwest::Client::new(); + let resp = client + .get(&url) + .header("User-Agent", "ahandctl") + .send() + .await + .context("Failed to fetch latest release")? + .json::() + .await + .context("Failed to parse release response")?; + let tag = resp["tag_name"] + .as_str() + .context("no tag_name in release")?; + Ok(tag.strip_prefix("rust-v").unwrap_or(tag).to_string()) +} + +/// Returns `(platform_suffix, exe_extension)` for the current target. +pub fn platform_suffix() -> (&'static str, &'static str) { + if cfg!(target_os = "macos") && cfg!(target_arch = "aarch64") { + ("darwin-arm64", "") + } else if cfg!(target_os = "macos") && cfg!(target_arch = "x86_64") { + ("darwin-x64", "") + } else if cfg!(target_os = "linux") && cfg!(target_arch = "x86_64") { + ("linux-x64", "") + } else if cfg!(target_os = "linux") && cfg!(target_arch = "aarch64") { + ("linux-arm64", "") + } else if cfg!(target_os = "windows") && cfg!(target_arch = "x86_64") { + ("windows-x64", ".exe") + } else if cfg!(target_os = "windows") && cfg!(target_arch = "aarch64") { + ("windows-arm64", ".exe") + } else { + ("unknown", "") + } +} + +/// Download raw bytes from `url` with a User-Agent header. +pub async fn download_bytes(url: &str) -> Result> { + let client = reqwest::Client::new(); + let resp = client + .get(url) + .header("User-Agent", "ahandctl") + .send() + .await + .context("HTTP request failed")?; + if !resp.status().is_success() { + anyhow::bail!("HTTP {} for {}", resp.status(), url); + } + let bytes = resp.bytes().await.context("Failed to read response body")?; + Ok(bytes.to_vec()) +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn platform_suffix_returns_valid_tuple() { + let (suffix, ext) = platform_suffix(); + // On macOS CI: ("darwin-arm64", "") or ("darwin-x64", "") + // On Linux CI: ("linux-x64", "") or ("linux-arm64", "") + // On Windows CI: ("windows-x64", ".exe") + assert!(!suffix.is_empty()); + assert!( + suffix.starts_with("darwin-") + || suffix.starts_with("linux-") + || suffix.starts_with("windows-") + || suffix == "unknown", + "unexpected suffix: {suffix}" + ); + + if suffix.starts_with("windows-") { + assert_eq!(ext, ".exe"); + } else if suffix != "unknown" { + assert_eq!(ext, ""); + } + } + + #[test] + fn platform_suffix_is_deterministic() { + let a = platform_suffix(); + let b = platform_suffix(); + assert_eq!(a, b); + } +} diff --git a/crates/ahandctl/src/install_daemon.rs b/crates/ahandctl/src/install_daemon.rs new file mode 100644 index 00000000..2d9d7918 --- /dev/null +++ b/crates/ahandctl/src/install_daemon.rs @@ -0,0 +1,84 @@ +use anyhow::{Context, Result}; + +use crate::github_release::{self, GITHUB_REPO}; + +pub async fn run(target_version: Option) -> Result<()> { + let version = match target_version { + Some(v) => v, + None => github_release::fetch_latest_version().await?, + }; + + let (suffix, exe_ext) = github_release::platform_suffix(); + let bin_dir = dirs::home_dir() + .context("cannot determine home directory")? + .join(".ahand") + .join("bin"); + std::fs::create_dir_all(&bin_dir)?; + + let asset = format!("ahandd-{suffix}{exe_ext}"); + let url = format!( + "https://github.com/{GITHUB_REPO}/releases/download/rust-v{version}/{asset}" + ); + + println!("Downloading ahandd v{version} ({suffix})..."); + let bytes = github_release::download_bytes(&url) + .await + .with_context(|| format!("Failed to download {asset}"))?; + + // Verify checksum + let checksums_url = format!( + "https://github.com/{GITHUB_REPO}/releases/download/rust-v{version}/checksums-rust-{suffix}.txt" + ); + let checksums_bytes = github_release::download_bytes(&checksums_url) + .await + .context("Failed to download checksums — cannot verify binary integrity")?; + let checksums_str = String::from_utf8_lossy(&checksums_bytes); + if let Some(expected) = checksums_str + .lines() + .find(|line| line.ends_with(&asset)) + .and_then(|line| line.split_whitespace().next()) + { + use sha2::{Digest, Sha256}; + let actual = format!("{:x}", Sha256::digest(&bytes)); + if actual != expected { + anyhow::bail!("Checksum mismatch for {asset}: expected {expected}, got {actual}"); + } + println!(" Checksum OK"); + } else { + anyhow::bail!("Checksum entry missing for {asset} — cannot verify binary integrity"); + } + + let dest = bin_dir.join(format!("ahandd{exe_ext}")); + + // On Windows, rename existing binary before overwriting (can't overwrite running exe) + #[cfg(windows)] + { + let backup = bin_dir.join(format!("ahandd.old{exe_ext}")); + let _ = std::fs::remove_file(&backup); + if dest.exists() { + std::fs::rename(&dest, &backup) + .with_context(|| format!("Failed to backup {}", dest.display()))?; + } + } + + std::fs::write(&dest, &bytes) + .with_context(|| format!("Failed to write {}", dest.display()))?; + + #[cfg(unix)] + { + use std::os::unix::fs::PermissionsExt; + std::fs::set_permissions(&dest, std::fs::Permissions::from_mode(0o755))?; + } + + println!("Installed: {}", dest.display()); + Ok(()) +} + +#[cfg(test)] +mod tests { + #[test] + fn install_daemon_module_compiles() { + // Smoke test — module is reachable and types resolve + assert!(true); + } +} diff --git a/crates/ahandctl/src/main.rs b/crates/ahandctl/src/main.rs index 4664ded6..c90ec5b8 100644 --- a/crates/ahandctl/src/main.rs +++ b/crates/ahandctl/src/main.rs @@ -12,6 +12,8 @@ use tracing::info; mod admin; mod browser_init; mod daemon; +mod github_release; +mod install_daemon; mod upgrade; #[derive(Parser)] @@ -84,6 +86,12 @@ enum Cmd { #[arg(long)] version: Option, }, + /// Install the ahandd daemon binary from GitHub Releases + InstallDaemon { + /// Specific version to install (default: latest) + #[arg(long)] + version: Option, + }, /// Start the ahandd daemon in the background Start { /// Path to config file (defaults to ~/.ahand/config.toml) @@ -187,6 +195,9 @@ async fn main() -> anyhow::Result<()> { Cmd::Upgrade { check, version } => { return upgrade::run(*check, version.clone()).await; } + Cmd::InstallDaemon { version } => { + return install_daemon::run(version.clone()).await; + } Cmd::Start { config } => { return daemon::start(config.clone()).await; } @@ -230,6 +241,7 @@ async fn main() -> anyhow::Result<()> { Cmd::Configure { .. } | Cmd::BrowserInit { .. } | Cmd::Upgrade { .. } + | Cmd::InstallDaemon { .. } | Cmd::Start { .. } | Cmd::Stop | Cmd::Restart { .. } @@ -265,6 +277,7 @@ async fn main() -> anyhow::Result<()> { Cmd::Configure { .. } | Cmd::BrowserInit { .. } | Cmd::Upgrade { .. } + | Cmd::InstallDaemon { .. } | Cmd::Start { .. } | Cmd::Stop | Cmd::Restart { .. } @@ -293,17 +306,48 @@ async fn read_frame(reader: &mut R) -> std::io::Result< } async fn write_frame(writer: &mut W, data: &[u8]) -> std::io::Result<()> { + if data.len() > 16 * 1024 * 1024 { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "outgoing frame too large", + )); + } writer.write_u32(data.len() as u32).await?; writer.write_all(data).await?; writer.flush().await?; Ok(()) } +// ── IPC connect (cross-platform) ───────────────────────────────────── + +#[cfg(unix)] +async fn ipc_connect( + path: &str, +) -> anyhow::Result<( + impl tokio::io::AsyncRead + Unpin, + impl tokio::io::AsyncWrite + Unpin, +)> { + let stream = tokio::net::UnixStream::connect(path).await?; + let (r, w) = stream.into_split(); + Ok((r, w)) +} + +#[cfg(windows)] +async fn ipc_connect( + path: &str, +) -> anyhow::Result<( + impl tokio::io::AsyncRead + Unpin, + impl tokio::io::AsyncWrite + Unpin, +)> { + let client = tokio::net::windows::named_pipe::ClientOptions::new().open(path)?; + let (r, w) = tokio::io::split(client); + Ok((r, w)) +} + // ── IPC exec ───────────────────────────────────────────────────────── async fn ipc_exec(socket_path: &str, tool: &str, args: &[String]) -> anyhow::Result<()> { - let stream = tokio::net::UnixStream::connect(socket_path).await?; - let (mut reader, mut writer) = stream.into_split(); + let (mut reader, mut writer) = ipc_connect(socket_path).await?; let mut reader = tokio::io::BufReader::new(&mut reader); let device_id = format!("ctl-{}", std::process::id()); @@ -396,8 +440,7 @@ async fn ipc_exec(socket_path: &str, tool: &str, args: &[String]) -> anyhow::Res // ── IPC cancel ─────────────────────────────────────────────────────── async fn ipc_cancel(socket_path: &str, job_id: &str) -> anyhow::Result<()> { - let stream = tokio::net::UnixStream::connect(socket_path).await?; - let (mut reader, mut writer) = stream.into_split(); + let (mut reader, mut writer) = ipc_connect(socket_path).await?; let mut reader = tokio::io::BufReader::new(&mut reader); let device_id = format!("ctl-{}", std::process::id()); @@ -618,8 +661,7 @@ async fn ws_ping(url: &str) -> anyhow::Result<()> { // ── IPC approve ────────────────────────────────────────────────────── async fn ipc_approve(socket_path: &str) -> anyhow::Result<()> { - let stream = tokio::net::UnixStream::connect(socket_path).await?; - let (mut reader, mut writer) = stream.into_split(); + let (mut reader, mut writer) = ipc_connect(socket_path).await?; let mut reader = tokio::io::BufReader::new(&mut reader); let device_id = format!("ctl-{}", std::process::id()); @@ -722,8 +764,7 @@ async fn ipc_approve(socket_path: &str) -> anyhow::Result<()> { // ── IPC policy ─────────────────────────────────────────────────────── async fn ipc_policy(socket_path: &str, action: PolicyAction) -> anyhow::Result<()> { - let stream = tokio::net::UnixStream::connect(socket_path).await?; - let (mut reader, mut writer) = stream.into_split(); + let (mut reader, mut writer) = ipc_connect(socket_path).await?; let mut reader = tokio::io::BufReader::new(&mut reader); let device_id = format!("ctl-{}", std::process::id()); @@ -824,8 +865,7 @@ async fn ws_policy(url: &str, action: PolicyAction) -> anyhow::Result<()> { // ── IPC session ───────────────────────────────────────────────────── async fn ipc_session(socket_path: &str, action: SessionAction) -> anyhow::Result<()> { - let stream = tokio::net::UnixStream::connect(socket_path).await?; - let (mut reader, mut writer) = stream.into_split(); + let (mut reader, mut writer) = ipc_connect(socket_path).await?; let mut reader = tokio::io::BufReader::new(&mut reader); let device_id = format!("ctl-{}", std::process::id()); diff --git a/crates/ahandctl/src/upgrade.rs b/crates/ahandctl/src/upgrade.rs index dc54b39c..8d86c403 100644 --- a/crates/ahandctl/src/upgrade.rs +++ b/crates/ahandctl/src/upgrade.rs @@ -1,65 +1,121 @@ use anyhow::{Context, Result}; -use std::path::PathBuf; -use tokio::io::{AsyncBufReadExt, BufReader}; -use tokio::process::Command; + +use crate::github_release::{self, GITHUB_REPO}; pub async fn run(check_only: bool, target_version: Option) -> Result<()> { - let script_path = resolve_script_path()?; + let current = current_version(); + let latest = match target_version { + Some(v) => v, + None => github_release::fetch_latest_version().await?, + }; - if !script_path.exists() { - anyhow::bail!( - "upgrade.sh not found at {}\nRun: bash scripts/deploy-admin.sh", - script_path.display() - ); - } + println!("Current: {current}"); + println!("Latest: {latest}"); - let mut cmd = Command::new("bash"); - cmd.arg(&script_path); + if current == latest { + println!("Already up to date."); + return Ok(()); + } if check_only { - cmd.arg("--check"); + println!("Update available: {current} → {latest}"); + return Ok(()); } - if let Some(version) = target_version { - cmd.arg("--version"); - cmd.arg(version); + + println!("Upgrading {current} → {latest}..."); + download_and_install(&latest).await?; + println!("Upgrade complete. Restart the daemon to use the new version."); + Ok(()) +} + +fn current_version() -> String { + env!("CARGO_PKG_VERSION").to_string() +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn current_version_returns_cargo_version() { + let version = current_version(); + assert!(!version.is_empty()); + // Should be semver-like + assert!(version.contains('.'), "expected semver: {version}"); } +} - cmd.stdout(std::process::Stdio::piped()); - cmd.stderr(std::process::Stdio::piped()); +async fn download_and_install(version: &str) -> Result<()> { + let (suffix, exe_ext) = github_release::platform_suffix(); + let bin_dir = dirs::home_dir() + .context("cannot determine home directory")? + .join(".ahand") + .join("bin"); + std::fs::create_dir_all(&bin_dir)?; - let mut child = cmd.spawn().context("Failed to spawn upgrade.sh")?; + // Stop daemon before replacing binaries + if let Err(e) = crate::daemon::stop().await { + eprintln!("Note: could not stop daemon: {e}"); + } - let stdout = child.stdout.take().expect("stdout"); - let stderr = child.stderr.take().expect("stderr"); + // Download checksums for verification + let checksums_url = format!( + "https://github.com/{GITHUB_REPO}/releases/download/rust-v{version}/checksums-rust-{suffix}.txt" + ); + let checksums_bytes = github_release::download_bytes(&checksums_url) + .await + .context("Failed to download checksums — cannot verify binary integrity")?; + let checksums_text = String::from_utf8_lossy(&checksums_bytes).to_string(); - let stdout_task = tokio::spawn(async move { - let reader = BufReader::new(stdout); - let mut lines = reader.lines(); - while let Ok(Some(line)) = lines.next_line().await { - println!("{}", line); + for binary in &["ahandd", "ahandctl"] { + let asset = format!("{binary}-{suffix}{exe_ext}"); + let url = format!( + "https://github.com/{GITHUB_REPO}/releases/download/rust-v{version}/{asset}" + ); + println!(" Downloading {asset}..."); + let bytes = github_release::download_bytes(&url) + .await + .with_context(|| format!("Failed to download {asset}"))?; + + // Verify checksum if available + if let Some(expected) = checksums_text + .lines() + .find(|line| line.ends_with(&asset)) + .and_then(|line| line.split_whitespace().next()) + { + use sha2::{Digest, Sha256}; + let actual = format!("{:x}", Sha256::digest(&bytes)); + if actual != expected { + anyhow::bail!("Checksum mismatch for {asset}: expected {expected}, got {actual}"); + } + println!(" Checksum OK: {asset}"); + } else { + anyhow::bail!("Checksum entry missing for {asset} — cannot verify binary integrity"); } - }); - let stderr_task = tokio::spawn(async move { - let reader = BufReader::new(stderr); - let mut lines = reader.lines(); - while let Ok(Some(line)) = lines.next_line().await { - eprintln!("{}", line); + let dest = bin_dir.join(format!("{binary}{exe_ext}")); + + // On Windows, rename current binary before overwriting (can't overwrite running exe) + #[cfg(windows)] + { + let backup = bin_dir.join(format!("{binary}.old{exe_ext}")); + let _ = std::fs::remove_file(&backup); + if dest.exists() { + std::fs::rename(&dest, &backup) + .with_context(|| format!("Failed to backup {}", dest.display()))?; + } } - }); - let status = child.wait().await?; - let _ = stdout_task.await; - let _ = stderr_task.await; + std::fs::write(&dest, &bytes) + .with_context(|| format!("Failed to write {}", dest.display()))?; - if !status.success() { - anyhow::bail!("Upgrade failed with exit code: {}", status); - } + #[cfg(unix)] + { + use std::os::unix::fs::PermissionsExt; + std::fs::set_permissions(&dest, std::fs::Permissions::from_mode(0o755))?; + } + println!(" Installed: {}", dest.display()); + } Ok(()) } - -fn resolve_script_path() -> Result { - let home = dirs::home_dir().context("Failed to find home directory")?; - Ok(home.join(".ahand").join("bin").join("upgrade.sh")) -} diff --git a/crates/ahandd/Cargo.toml b/crates/ahandd/Cargo.toml index d64a0a7e..4ff584bb 100644 --- a/crates/ahandd/Cargo.toml +++ b/crates/ahandd/Cargo.toml @@ -4,6 +4,14 @@ version.workspace = true edition.workspace = true license.workspace = true +[lib] +name = "ahandd" +path = "src/lib.rs" + +[[bin]] +name = "ahandd" +path = "src/main.rs" + [dependencies] ahand-protocol = { path = "../ahand-protocol" } tokio.workspace = true @@ -33,3 +41,19 @@ xz2 = { version = "0.1", features = ["static"] } tokio-util = { version = "0.7", features = ["io"] } semver = "1" hex = "0.4" + +[target.'cfg(windows)'.dependencies] +windows-sys = { version = "0.59", features = [ + "Win32_Foundation", + "Win32_Security", + "Win32_Security_Authorization", + "Win32_Security_Cryptography", + "Win32_System_Memory", + "Win32_System_Pipes", + "Win32_System_Threading", +] } +zip = "2" + +[dev-dependencies] +tempfile = "3" +libc = "0.2" diff --git a/crates/ahandd/src/browser_init.rs b/crates/ahandd/src/browser_init.rs index 202ae323..0698ec9f 100644 --- a/crates/ahandd/src/browser_init.rs +++ b/crates/ahandd/src/browser_init.rs @@ -40,7 +40,10 @@ pub async fn run(force: bool) -> Result<()> { println!(); println!("Setup complete!"); println!(" Node.js: {}", node_bin.display()); + #[cfg(unix)] let cli_path = dirs.node.join("bin").join("playwright-cli"); + #[cfg(windows)] + let cli_path = dirs.node.join("playwright-cli.cmd"); println!(" playwright-cli: {}", cli_path.display()); println!(); println!("playwright-cli will use the browser installed on your system (Chrome, Edge, etc.)."); @@ -48,24 +51,49 @@ pub async fn run(force: bool) -> Result<()> { } async fn clean(dirs: &Dirs) { - // Uninstall playwright-cli from our managed prefix - let npm = dirs.node.join("bin").join("npm"); - if npm.exists() { - let prefix = dirs.node.to_string_lossy().to_string(); - let _ = tokio::process::Command::new(&npm) - .args(["uninstall", "-g", "--prefix", &prefix, "@playwright/cli"]) - .stdout(Stdio::null()) - .stderr(Stdio::null()) - .status() - .await; + #[cfg(unix)] + { + // Uninstall playwright-cli from our managed prefix + let npm = dirs.node.join("bin").join("npm"); + if npm.exists() { + let prefix = dirs.node.to_string_lossy().to_string(); + let _ = tokio::process::Command::new(&npm) + .args(["uninstall", "-g", "--prefix", &prefix, "@playwright/cli"]) + .stdout(Stdio::null()) + .stderr(Stdio::null()) + .status() + .await; + } + } + + #[cfg(windows)] + { + // Note: we intentionally do NOT kill all node.exe processes system-wide + // as that would affect unrelated Node.js applications running on the machine. + + // Uninstall playwright-cli from our managed prefix + let npm = dirs.node.join("npm.cmd"); + if npm.exists() { + let prefix = dirs.node.to_string_lossy().to_string(); + let _ = tokio::process::Command::new(&npm) + .args(["uninstall", "-g", "--prefix", &prefix, "@playwright/cli"]) + .stdout(Stdio::null()) + .stderr(Stdio::null()) + .status() + .await; + } } - println!(" Cleaned playwright-cli installation."); + + println!(" Cleaned browser installation."); } // Step 1: Node.js async fn ensure_node(dirs: &Dirs) -> Result { + #[cfg(unix)] let local_node = dirs.node.join("bin").join("node"); + #[cfg(windows)] + let local_node = dirs.node.join("node.exe"); // Check if we already have a suitable local node if local_node.exists() { @@ -133,40 +161,89 @@ async fn node_major_version(node_bin: &Path) -> Option { } async fn install_node(dirs: &Dirs) -> Result<()> { - let (os, arch) = platform_info(); - let tarball = format!("node-v{NODE_LTS_VERSION}-{os}-{arch}.tar.xz"); - let url = format!("https://nodejs.org/dist/v{NODE_LTS_VERSION}/{tarball}"); + #[cfg(windows)] + { + let (_os, arch) = platform_info(); + let zipfile = format!("node-v{NODE_LTS_VERSION}-win-{arch}.zip"); + let url = format!("https://nodejs.org/dist/v{NODE_LTS_VERSION}/{zipfile}"); + + let bytes = download_bytes(&url).await.context(format!( + "Failed to download Node.js from {url} — check your network connection" + ))?; - let bytes = download_bytes(&url).await.context(format!( - "Failed to download Node.js from {url} — check your network connection" - ))?; + std::fs::create_dir_all(&dirs.node).context(format!( + "Failed to create directory {}: permission denied or disk full", + dirs.node.display() + ))?; - std::fs::create_dir_all(&dirs.node).context(format!( - "Failed to create directory {}: permission denied or disk full", - dirs.node.display() - ))?; - let decoder = xz2::read::XzDecoder::new(std::io::Cursor::new(bytes)); - let mut archive = tar::Archive::new(decoder); - archive.set_preserve_permissions(true); - for entry in archive - .entries() - .context("Failed to read Node.js archive — download may be corrupted")? - { - let mut entry = entry.context("Corrupted entry in Node.js archive")?; - let path = entry.path()?.into_owned(); - // Strip first component (e.g. "node-v24.13.0-darwin-arm64/bin/node" -> "bin/node") - let stripped: PathBuf = path.components().skip(1).collect(); - if stripped.components().count() == 0 { - continue; - } - let dest = dirs.node.join(&stripped); - if let Some(parent) = dest.parent() { - std::fs::create_dir_all(parent)?; + let cursor = std::io::Cursor::new(bytes); + let mut archive = zip::ZipArchive::new(cursor) + .context("Failed to read Node.js zip archive")?; + for i in 0..archive.len() { + let mut file = archive.by_index(i)?; + let path = match file.enclosed_name() { + Some(p) => p.to_owned(), + None => continue, + }; + // Strip first component (e.g. "node-v24.13.0-win-x64/node.exe" -> "node.exe") + let stripped: PathBuf = path.components().skip(1).collect(); + if stripped.components().count() == 0 { + continue; + } + let dest = dirs.node.join(&stripped); + // Defense-in-depth: ensure destination stays inside extraction root + if !dest.starts_with(&dirs.node) { + continue; + } + if file.is_dir() { + std::fs::create_dir_all(&dest)?; + } else { + if let Some(parent) = dest.parent() { + std::fs::create_dir_all(parent)?; + } + let mut outfile = std::fs::File::create(&dest)?; + std::io::copy(&mut file, &mut outfile)?; + } } - entry.unpack(&dest).context(format!( - "Failed to extract {} — disk may be full", - dest.display() + } + + #[cfg(unix)] + { + let (os, arch) = platform_info(); + let tarball = format!("node-v{NODE_LTS_VERSION}-{os}-{arch}.tar.xz"); + let url = format!("https://nodejs.org/dist/v{NODE_LTS_VERSION}/{tarball}"); + + let bytes = download_bytes(&url).await.context(format!( + "Failed to download Node.js from {url} — check your network connection" ))?; + + std::fs::create_dir_all(&dirs.node).context(format!( + "Failed to create directory {}: permission denied or disk full", + dirs.node.display() + ))?; + let decoder = xz2::read::XzDecoder::new(std::io::Cursor::new(bytes)); + let mut archive = tar::Archive::new(decoder); + archive.set_preserve_permissions(true); + for entry in archive + .entries() + .context("Failed to read Node.js archive — download may be corrupted")? + { + let mut entry = entry.context("Corrupted entry in Node.js archive")?; + let path = entry.path()?.into_owned(); + // Strip first component (e.g. "node-v24.13.0-darwin-arm64/bin/node" -> "bin/node") + let stripped: PathBuf = path.components().skip(1).collect(); + if stripped.components().count() == 0 { + continue; + } + let dest = dirs.node.join(&stripped); + if let Some(parent) = dest.parent() { + std::fs::create_dir_all(parent)?; + } + entry.unpack(&dest).context(format!( + "Failed to extract {} — disk may be full", + dest.display() + ))?; + } } Ok(()) @@ -175,13 +252,22 @@ async fn install_node(dirs: &Dirs) -> Result<()> { // Step 2: playwright-cli via npm async fn install_playwright_cli(dirs: &Dirs, node_bin: &Path) -> Result<()> { + #[cfg(unix)] let npm = node_bin .parent() .map(|p| p.join("npm")) .unwrap_or_else(|| PathBuf::from("npm")); + #[cfg(windows)] + let npm = node_bin + .parent() + .map(|p| p.join("npm.cmd")) + .unwrap_or_else(|| PathBuf::from("npm.cmd")); // Check if already installed at the correct version + #[cfg(unix)] let cli_path = dirs.node.join("bin").join("playwright-cli"); + #[cfg(windows)] + let cli_path = dirs.node.join("playwright-cli.cmd"); if cli_path.exists() { // Verify version let output = tokio::process::Command::new(&cli_path) @@ -337,3 +423,36 @@ fn which(bin: &str) -> Result { } anyhow::bail!("{bin} not found in PATH") } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn platform_info_returns_valid() { + let (os, arch) = platform_info(); + assert!( + ["darwin", "linux", "win", "unknown"].contains(&os), + "unexpected os: {os}" + ); + assert!( + ["arm64", "x64", "unknown"].contains(&arch), + "unexpected arch: {arch}" + ); + } + + #[test] + fn dirs_new_succeeds() { + // Should not panic as long as home dir is available + let dirs = Dirs::new(); + assert!(dirs.is_ok()); + } + + #[cfg(unix)] + #[test] + fn dirs_node_path_is_under_ahand() { + let dirs = Dirs::new().unwrap(); + assert!(dirs.node.to_string_lossy().contains(".ahand")); + assert!(dirs.node.to_string_lossy().ends_with("node")); + } +} diff --git a/crates/ahandd/src/config.rs b/crates/ahandd/src/config.rs index 1116224c..7e6808c6 100644 --- a/crates/ahandd/src/config.rs +++ b/crates/ahandd/src/config.rs @@ -245,14 +245,25 @@ impl Config { self.device_id.clone().unwrap_or_else(uuid_v4) } - /// Resolve the IPC socket path. Default: ~/.ahand/ahandd.sock. + /// Resolve the IPC socket path. + /// Unix default: ~/.ahand/ahandd.sock + /// Windows default: \\.\pipe\ahandd pub fn ipc_socket_path(&self) -> PathBuf { match &self.ipc_socket_path { Some(p) => PathBuf::from(p), - None => dirs::home_dir() - .unwrap_or_else(|| PathBuf::from("/tmp")) - .join(".ahand") - .join("ahandd.sock"), + None => { + #[cfg(unix)] + { + dirs::home_dir() + .unwrap_or_else(|| PathBuf::from("/tmp")) + .join(".ahand") + .join("ahandd.sock") + } + #[cfg(windows)] + { + PathBuf::from(r"\\.\pipe\ahandd") + } + } } } @@ -283,3 +294,93 @@ fn uuid_v4() -> String { .as_nanos(); format!("{:032x}", ts) } + +#[cfg(test)] +mod tests { + use super::*; + + fn minimal_config() -> Config { + Config { + mode: None, + server_url: "ws://localhost:3000/ws".to_string(), + device_id: None, + max_concurrent_jobs: None, + data_dir: None, + debug_ipc: None, + ipc_socket_path: None, + ipc_socket_mode: None, + trust_timeout_mins: None, + default_session_mode: None, + policy: Default::default(), + openclaw: None, + browser: None, + hub: None, + } + } + + #[test] + fn ipc_socket_path_custom() { + let mut cfg = minimal_config(); + cfg.ipc_socket_path = Some("/custom/path.sock".to_string()); + assert_eq!(cfg.ipc_socket_path(), PathBuf::from("/custom/path.sock")); + } + + #[cfg(unix)] + #[test] + fn ipc_socket_path_unix_default() { + let cfg = minimal_config(); + let path = cfg.ipc_socket_path(); + assert!(path.to_string_lossy().ends_with("ahandd.sock")); + assert!(path.to_string_lossy().contains(".ahand")); + } + + #[cfg(windows)] + #[test] + fn ipc_socket_path_windows_default() { + let cfg = minimal_config(); + let path = cfg.ipc_socket_path(); + assert_eq!(path, PathBuf::from(r"\\.\pipe\ahandd")); + } + + #[test] + fn ipc_socket_mode_default() { + let cfg = minimal_config(); + assert_eq!(cfg.ipc_socket_mode(), 0o660); + } + + #[test] + fn ipc_socket_mode_custom() { + let mut cfg = minimal_config(); + cfg.ipc_socket_mode = Some(0o600); + assert_eq!(cfg.ipc_socket_mode(), 0o600); + } + + #[test] + fn device_id_auto_generated_when_none() { + let cfg = minimal_config(); + let id = cfg.device_id(); + assert!(!id.is_empty()); + } + + #[test] + fn device_id_custom() { + let mut cfg = minimal_config(); + cfg.device_id = Some("my-device".to_string()); + assert_eq!(cfg.device_id(), "my-device"); + } + + #[test] + fn data_dir_default() { + let cfg = minimal_config(); + let dir = cfg.data_dir(); + assert!(dir.is_some()); + assert!(dir.unwrap().to_string_lossy().contains(".ahand")); + } + + #[test] + fn data_dir_empty_disables() { + let mut cfg = minimal_config(); + cfg.data_dir = Some("".to_string()); + assert!(cfg.data_dir().is_none()); + } +} diff --git a/crates/ahandd/src/dpapi.rs b/crates/ahandd/src/dpapi.rs new file mode 100644 index 00000000..090b89de --- /dev/null +++ b/crates/ahandd/src/dpapi.rs @@ -0,0 +1,78 @@ +//! Windows DPAPI wrapper for encrypting sensitive data at rest. +//! +//! Data is bound to the current Windows user account. +//! Only the same user on the same machine can decrypt it. + +use std::io; +use windows_sys::Win32::Foundation::LocalFree; +use windows_sys::Win32::Security::Cryptography::{ + CryptProtectData, CryptUnprotectData, CRYPTPROTECT_UI_FORBIDDEN, CRYPT_INTEGER_BLOB, +}; + +/// Encrypt data using DPAPI, bound to current user. +pub fn protect(plaintext: &[u8]) -> io::Result> { + unsafe { + let input = CRYPT_INTEGER_BLOB { + cbData: plaintext.len() as u32, + pbData: plaintext.as_ptr() as *mut u8, + }; + let mut output = CRYPT_INTEGER_BLOB { + cbData: 0, + pbData: std::ptr::null_mut(), + }; + + let result = CryptProtectData( + &input, + std::ptr::null(), // description + std::ptr::null(), // entropy + std::ptr::null(), // reserved + std::ptr::null(), // prompt + CRYPTPROTECT_UI_FORBIDDEN, + &mut output, + ); + + if result == 0 { + return Err(io::Error::last_os_error()); + } + + let encrypted = + std::slice::from_raw_parts(output.pbData, output.cbData as usize).to_vec(); + LocalFree(output.pbData as *mut _); + Ok(encrypted) + } +} + +/// Decrypt DPAPI-protected data. Only works for the same user who encrypted it. +pub fn unprotect(ciphertext: &[u8]) -> io::Result> { + unsafe { + let input = CRYPT_INTEGER_BLOB { + cbData: ciphertext.len() as u32, + pbData: ciphertext.as_ptr() as *mut u8, + }; + let mut output = CRYPT_INTEGER_BLOB { + cbData: 0, + pbData: std::ptr::null_mut(), + }; + + let result = CryptUnprotectData( + &input, + std::ptr::null_mut(), // description + std::ptr::null(), // entropy + std::ptr::null(), // reserved + std::ptr::null(), // prompt + CRYPTPROTECT_UI_FORBIDDEN, + &mut output, + ); + + if result == 0 { + return Err(io::Error::last_os_error()); + } + + let decrypted = + std::slice::from_raw_parts(output.pbData, output.cbData as usize).to_vec(); + // Zero plaintext before freeing to prevent key material lingering in heap + std::ptr::write_bytes(output.pbData, 0, output.cbData as usize); + LocalFree(output.pbData as *mut _); + Ok(decrypted) + } +} diff --git a/crates/ahandd/src/fs_perms.rs b/crates/ahandd/src/fs_perms.rs new file mode 100644 index 00000000..e3a4a7ac --- /dev/null +++ b/crates/ahandd/src/fs_perms.rs @@ -0,0 +1,176 @@ +use std::io; +use std::path::Path; + +/// Restrict file to owner-only read/write (Unix 0o600 equivalent). +#[cfg(unix)] +pub fn restrict_owner_only(path: &Path) -> io::Result<()> { + use std::os::unix::fs::PermissionsExt; + std::fs::set_permissions(path, std::fs::Permissions::from_mode(0o600)) +} + +/// Restrict file to owner + group read/write (Unix 0o660 equivalent). +#[cfg(unix)] +pub fn restrict_owner_and_group(path: &Path) -> io::Result<()> { + use std::os::unix::fs::PermissionsExt; + std::fs::set_permissions(path, std::fs::Permissions::from_mode(0o660)) +} + +#[cfg(windows)] +pub fn restrict_owner_only(path: &Path) -> io::Result<()> { + win_acl::set_owner_only_acl(path) +} + +#[cfg(windows)] +pub fn restrict_owner_and_group(path: &Path) -> io::Result<()> { + win_acl::set_owner_and_users_acl(path) +} + +#[cfg(windows)] +mod win_acl { + use std::io; + use std::path::Path; + + pub fn set_owner_only_acl(path: &Path) -> io::Result<()> { + use windows_sys::Win32::Security::*; + use windows_sys::Win32::System::Threading::GetCurrentProcess; + + unsafe { + let mut token_handle = 0isize; + if OpenProcessToken(GetCurrentProcess(), TOKEN_QUERY, &mut token_handle) == 0 { + return Err(io::Error::last_os_error()); + } + + let result = set_acl_with_token(path, token_handle, false); + windows_sys::Win32::Foundation::CloseHandle(token_handle); + result + } + } + + pub fn set_owner_and_users_acl(path: &Path) -> io::Result<()> { + use windows_sys::Win32::Security::*; + use windows_sys::Win32::System::Threading::GetCurrentProcess; + + unsafe { + let mut token_handle = 0isize; + if OpenProcessToken(GetCurrentProcess(), TOKEN_QUERY, &mut token_handle) == 0 { + return Err(io::Error::last_os_error()); + } + + let result = set_acl_with_token(path, token_handle, true); + windows_sys::Win32::Foundation::CloseHandle(token_handle); + result + } + } + + unsafe fn set_acl_with_token( + path: &Path, + token_handle: isize, + include_users_group: bool, + ) -> io::Result<()> { + use windows_sys::Win32::Security::Authorization::*; + use windows_sys::Win32::Security::*; + + // Get current user SID from token + let mut info_len = 0u32; + GetTokenInformation( + token_handle, + TokenUser, + std::ptr::null_mut(), + 0, + &mut info_len, + ); + let mut buffer = vec![0u8; info_len as usize]; + if GetTokenInformation( + token_handle, + TokenUser, + buffer.as_mut_ptr() as *mut _, + info_len, + &mut info_len, + ) == 0 + { + return Err(io::Error::last_os_error()); + } + + let token_user = &*(buffer.as_ptr() as *const TOKEN_USER); + let user_sid = token_user.User.Sid; + + // GENERIC_READ | GENERIC_WRITE + const GENERIC_RW: u32 = 0x10000000 | 0x20000000; + + // Build ACE for current user + let mut entries = Vec::with_capacity(2); + entries.push(EXPLICIT_ACCESS_W { + grfAccessPermissions: GENERIC_RW, + grfAccessMode: SET_ACCESS, + grfInheritance: NO_INHERITANCE, + Trustee: TRUSTEE_W { + pMultipleTrustee: std::ptr::null_mut(), + MultipleTrusteeOperation: NO_MULTIPLE_TRUSTEE, + TrusteeForm: TRUSTEE_IS_SID, + TrusteeType: TRUSTEE_IS_USER, + ptstrName: user_sid as *mut u16, + }, + }); + + // Optionally add BUILTIN\Users group + let mut users_sid_buf = [0u8; 68]; + if include_users_group { + let mut sid_size = users_sid_buf.len() as u32; + if CreateWellKnownSid( + WinBuiltinUsersSid, + std::ptr::null_mut(), + users_sid_buf.as_mut_ptr() as *mut _, + &mut sid_size, + ) == 0 + { + return Err(io::Error::last_os_error()); + } + entries.push(EXPLICIT_ACCESS_W { + grfAccessPermissions: GENERIC_RW, + grfAccessMode: SET_ACCESS, + grfInheritance: NO_INHERITANCE, + Trustee: TRUSTEE_W { + pMultipleTrustee: std::ptr::null_mut(), + MultipleTrusteeOperation: NO_MULTIPLE_TRUSTEE, + TrusteeForm: TRUSTEE_IS_SID, + TrusteeType: TRUSTEE_IS_WELL_KNOWN_GROUP, + ptstrName: users_sid_buf.as_mut_ptr() as *mut u16, + }, + }); + } + + let mut acl = std::ptr::null_mut(); + let result = SetEntriesInAclW( + entries.len() as u32, + entries.as_mut_ptr(), + std::ptr::null_mut(), + &mut acl, + ); + if result != 0 { + return Err(io::Error::from_raw_os_error(result as i32)); + } + + let path_wide: Vec = path + .to_string_lossy() + .encode_utf16() + .chain(std::iter::once(0)) + .collect(); + + let result = SetNamedSecurityInfoW( + path_wide.as_ptr(), + SE_FILE_OBJECT, + DACL_SECURITY_INFORMATION | PROTECTED_DACL_SECURITY_INFORMATION, + std::ptr::null_mut(), + std::ptr::null_mut(), + acl, + std::ptr::null_mut(), + ); + + windows_sys::Win32::System::Memory::LocalFree(acl as *mut _); + + if result != 0 { + return Err(io::Error::from_raw_os_error(result as i32)); + } + Ok(()) + } +} diff --git a/crates/ahandd/src/ipc.rs b/crates/ahandd/src/ipc.rs index e1c4c8b9..36c64c7e 100644 --- a/crates/ahandd/src/ipc.rs +++ b/crates/ahandd/src/ipc.rs @@ -1,10 +1,9 @@ -use std::path::{Path, PathBuf}; +use std::path::PathBuf; use std::sync::Arc; use ahand_protocol::{BrowserResponse, Envelope, JobFinished, JobRejected, SessionMode, envelope}; use prost::Message; -use tokio::io::{AsyncReadExt, AsyncWriteExt}; -use tokio::net::{UnixListener, UnixStream}; +use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt}; use tokio::sync::{broadcast, mpsc}; use tracing::{error, info, warn}; @@ -15,9 +14,38 @@ use crate::registry::{IsKnown, JobRegistry}; use crate::session::{SessionDecision, SessionManager}; use crate::store::RunStore; -/// Start the IPC server on the given Unix socket path. +/// Start the IPC server on the given socket path. #[allow(clippy::too_many_arguments)] pub async fn serve_ipc( + socket_path: PathBuf, + #[allow(unused_variables)] socket_mode: u32, + registry: Arc, + store: Option>, + session_mgr: Arc, + approval_mgr: Arc, + approval_broadcast_tx: broadcast::Sender, + device_id: String, + browser_mgr: Arc, +) -> anyhow::Result<()> { + #[cfg(unix)] + { + serve_ipc_unix( + socket_path, socket_mode, registry, store, session_mgr, + approval_mgr, approval_broadcast_tx, device_id, browser_mgr, + ).await + } + #[cfg(windows)] + { + serve_ipc_windows( + socket_path, registry, store, session_mgr, + approval_mgr, approval_broadcast_tx, device_id, browser_mgr, + ).await + } +} + +#[cfg(unix)] +#[allow(clippy::too_many_arguments)] +async fn serve_ipc_unix( socket_path: PathBuf, socket_mode: u32, registry: Arc, @@ -28,6 +56,8 @@ pub async fn serve_ipc( device_id: String, browser_mgr: Arc, ) -> anyhow::Result<()> { + use tokio::net::UnixListener; + // Remove stale socket file if it exists. let _ = std::fs::remove_file(&socket_path); @@ -39,14 +69,13 @@ pub async fn serve_ipc( let listener = UnixListener::bind(&socket_path)?; // Set socket permissions. - set_permissions(&socket_path, socket_mode)?; + crate::fs_perms::restrict_owner_and_group(&socket_path)?; info!(path = %socket_path.display(), mode = format!("{:04o}", socket_mode), "IPC server listening"); loop { match listener.accept().await { Ok((stream, _addr)) => { - // Get peer credentials before splitting the stream. let caller_uid = match stream.peer_cred() { Ok(cred) => format!("uid:{}", cred.uid()), Err(e) => { @@ -55,21 +84,14 @@ pub async fn serve_ipc( } }; - let reg = Arc::clone(®istry); - let st = store.clone(); - let smgr = Arc::clone(&session_mgr); - let amgr = Arc::clone(&approval_mgr); - let bcast = approval_broadcast_tx.clone(); - let did = device_id.clone(); - let bmgr = Arc::clone(&browser_mgr); - tokio::spawn(async move { - if let Err(e) = - handle_ipc_conn(stream, reg, st, smgr, amgr, bcast, did, caller_uid, bmgr) - .await - { - warn!(error = %e, "IPC connection error"); - } - }); + let (reader, writer) = stream.into_split(); + spawn_ipc_handler( + reader, writer, + Arc::clone(®istry), store.clone(), + Arc::clone(&session_mgr), Arc::clone(&approval_mgr), + approval_broadcast_tx.clone(), device_id.clone(), + caller_uid, Arc::clone(&browser_mgr), + ); } Err(e) => { error!(error = %e, "IPC accept error"); @@ -78,19 +100,186 @@ pub async fn serve_ipc( } } +#[cfg(windows)] #[allow(clippy::too_many_arguments)] -async fn handle_ipc_conn( - stream: UnixStream, +async fn serve_ipc_windows( + socket_path: PathBuf, registry: Arc, store: Option>, session_mgr: Arc, approval_mgr: Arc, approval_broadcast_tx: broadcast::Sender, device_id: String, - caller_uid: String, browser_mgr: Arc, ) -> anyhow::Result<()> { - let (reader, writer) = stream.into_split(); + use tokio::net::windows::named_pipe::ServerOptions; + + let pipe_name = socket_path.to_string_lossy().to_string(); + info!(path = %pipe_name, "IPC server listening (Named Pipe)"); + + let mut server = ServerOptions::new() + .first_pipe_instance(true) + .create(&pipe_name)?; + + loop { + if let Err(e) = server.connect().await { + error!(error = %e, "IPC pipe connect error"); + // Rebuild server instance and retry + match ServerOptions::new().create(&pipe_name) { + Ok(s) => server = s, + Err(e) => { + error!(error = %e, "IPC pipe recreate failed, retrying in 100ms"); + tokio::time::sleep(std::time::Duration::from_millis(100)).await; + continue; + } + } + continue; + } + + let connected = server; + server = match ServerOptions::new().create(&pipe_name) { + Ok(s) => s, + Err(e) => { + error!(error = %e, "IPC pipe create failed"); + continue; + } + }; + + let caller_uid = get_pipe_caller_identity(&connected); + + // Security: reject connections from other users + if caller_uid == "user:unknown" { + warn!("IPC: rejecting connection — could not verify caller identity"); + drop(connected); + continue; + } + + let (reader, writer) = tokio::io::split(connected); + spawn_ipc_handler( + reader, writer, + Arc::clone(®istry), store.clone(), + Arc::clone(&session_mgr), Arc::clone(&approval_mgr), + approval_broadcast_tx.clone(), device_id.clone(), + caller_uid, Arc::clone(&browser_mgr), + ); + } +} + +#[cfg(windows)] +fn get_pipe_caller_identity(pipe: &tokio::net::windows::named_pipe::NamedPipeServer) -> String { + use std::os::windows::io::AsRawHandle; + use windows_sys::Win32::Foundation::CloseHandle; + use windows_sys::Win32::Security::*; + use windows_sys::Win32::System::Pipes::GetNamedPipeClientProcessId; + use windows_sys::Win32::System::Threading::{OpenProcess, PROCESS_QUERY_LIMITED_INFORMATION}; + + unsafe { + let handle = pipe.as_raw_handle() as isize; + let mut pid = 0u32; + if GetNamedPipeClientProcessId(handle, &mut pid) == 0 { + return "user:unknown".to_string(); + } + + let process = OpenProcess(PROCESS_QUERY_LIMITED_INFORMATION, 0, pid); + if process == 0 { + return format!("pid:{pid}"); + } + + let mut token = 0isize; + if OpenProcessToken(process, TOKEN_QUERY, &mut token) == 0 { + CloseHandle(process); + return format!("pid:{pid}"); + } + + let mut info_len = 0u32; + GetTokenInformation(token, TokenUser, std::ptr::null_mut(), 0, &mut info_len); + if info_len == 0 { + CloseHandle(token); + CloseHandle(process); + return format!("pid:{pid}"); + } + let mut buffer = vec![0u8; info_len as usize]; + if GetTokenInformation( + token, TokenUser, buffer.as_mut_ptr() as *mut _, + info_len, &mut info_len, + ) == 0 { + CloseHandle(token); + CloseHandle(process); + return format!("pid:{pid}"); + } + + let token_user = &*(buffer.as_ptr() as *const TOKEN_USER); + let sid = token_user.User.Sid; + + let mut name_buf = [0u16; 256]; + let mut name_len = 256u32; + let mut domain_buf = [0u16; 256]; + let mut domain_len = 256u32; + let mut sid_type = 0; + + if LookupAccountSidW( + std::ptr::null(), sid, + name_buf.as_mut_ptr(), &mut name_len, + domain_buf.as_mut_ptr(), &mut domain_len, + &mut sid_type, + ) == 0 { + CloseHandle(token); + CloseHandle(process); + return format!("pid:{pid}"); + } + + CloseHandle(token); + CloseHandle(process); + + let username = String::from_utf16_lossy(&name_buf[..name_len as usize]); + format!("user:{username}") + } +} + +#[allow(clippy::too_many_arguments)] +fn spawn_ipc_handler( + reader: R, + writer: W, + registry: Arc, + store: Option>, + session_mgr: Arc, + approval_mgr: Arc, + approval_broadcast_tx: broadcast::Sender, + device_id: String, + caller_uid: String, + browser_mgr: Arc, +) where + R: AsyncRead + Unpin + Send + 'static, + W: AsyncWrite + Unpin + Send + 'static, +{ + tokio::spawn(async move { + if let Err(e) = handle_ipc_conn( + reader, writer, registry, store, session_mgr, + approval_mgr, approval_broadcast_tx, device_id, + caller_uid, browser_mgr, + ).await { + warn!(error = %e, "IPC connection error"); + } + }); +} + +#[allow(clippy::too_many_arguments)] +async fn handle_ipc_conn( + reader: R, + writer: W, + registry: Arc, + store: Option>, + session_mgr: Arc, + approval_mgr: Arc, + approval_broadcast_tx: broadcast::Sender, + device_id: String, + caller_uid: String, + browser_mgr: Arc, +) -> anyhow::Result<()> +where + R: AsyncRead + Unpin + Send + 'static, + W: AsyncWrite + Unpin + Send + 'static, +{ let mut reader = tokio::io::BufReader::new(reader); info!(caller_uid = %caller_uid, "IPC: new connection"); @@ -470,18 +659,18 @@ async fn read_frame(reader: &mut R) -> std::io::Result< /// Write a length-prefixed frame. async fn write_frame(writer: &mut W, data: &[u8]) -> std::io::Result<()> { + if data.len() > 16 * 1024 * 1024 { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "outgoing frame too large", + )); + } writer.write_u32(data.len() as u32).await?; writer.write_all(data).await?; writer.flush().await?; Ok(()) } -fn set_permissions(path: &Path, mode: u32) -> std::io::Result<()> { - use std::os::unix::fs::PermissionsExt; - let perms = std::fs::Permissions::from_mode(mode); - std::fs::set_permissions(path, perms) -} - fn now_ms() -> u64 { std::time::SystemTime::now() .duration_since(std::time::UNIX_EPOCH) @@ -494,3 +683,76 @@ fn new_msg_id() -> String { static COUNTER: AtomicU64 = AtomicU64::new(0); format!("ipc-{}", COUNTER.fetch_add(1, Ordering::Relaxed)) } + +#[cfg(test)] +mod tests { + #[cfg(windows)] + #[tokio::test] + async fn test_named_pipe_roundtrip() { + use tokio::net::windows::named_pipe::{ClientOptions, ServerOptions}; + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + + let pipe_name = format!(r"\\.\pipe\ahand-test-{}", std::process::id()); + + let mut server = ServerOptions::new() + .first_pipe_instance(true) + .create(&pipe_name) + .unwrap(); + + let server_task = tokio::spawn(async move { + server.connect().await.unwrap(); + let (mut reader, mut writer) = tokio::io::split(server); + + // Read frame + let len = reader.read_u32().await.unwrap() as usize; + let mut buf = vec![0u8; len]; + reader.read_exact(&mut buf).await.unwrap(); + assert_eq!(buf, b"hello pipe"); + + // Write frame back + writer.write_u32(5).await.unwrap(); + writer.write_all(b"world").await.unwrap(); + writer.flush().await.unwrap(); + }); + + // Brief delay for server to start + tokio::time::sleep(std::time::Duration::from_millis(50)).await; + + let client = ClientOptions::new().open(&pipe_name).unwrap(); + let (mut reader, mut writer) = tokio::io::split(client); + + // Send frame + writer.write_u32(10).await.unwrap(); + writer.write_all(b"hello pipe").await.unwrap(); + writer.flush().await.unwrap(); + + // Read response + let len = reader.read_u32().await.unwrap() as usize; + let mut buf = vec![0u8; len]; + reader.read_exact(&mut buf).await.unwrap(); + assert_eq!(buf, b"world"); + + server_task.await.unwrap(); + } + + #[cfg(windows)] + #[test] + fn test_get_pipe_caller_identity_format() { + // Will be tested when we have a connected pipe in the full integration test + // For now, verify the function exists and compiles + } + + #[test] + fn test_now_ms_returns_nonzero() { + let ts = super::now_ms(); + assert!(ts > 0); + } + + #[test] + fn test_new_msg_id_unique() { + let a = super::new_msg_id(); + let b = super::new_msg_id(); + assert_ne!(a, b); + assert!(a.starts_with("ipc-")); + } +} diff --git a/crates/ahandd/src/lib.rs b/crates/ahandd/src/lib.rs index 93eb5797..84161716 100644 --- a/crates/ahandd/src/lib.rs +++ b/crates/ahandd/src/lib.rs @@ -1,11 +1,17 @@ pub mod ahand_client; -mod approval; -mod browser; +pub mod approval; +pub mod browser; pub mod config; pub mod device_identity; pub mod executor; +pub mod fs_perms; +pub mod ipc; mod outbox; -mod registry; -mod session; +pub mod policy; +pub mod registry; +pub mod session; mod store; pub mod updater; + +#[cfg(windows)] +pub mod dpapi; diff --git a/crates/ahandd/src/main.rs b/crates/ahandd/src/main.rs index e2895613..045cce9a 100644 --- a/crates/ahandd/src/main.rs +++ b/crates/ahandd/src/main.rs @@ -5,6 +5,7 @@ mod browser_init; mod config; mod device_identity; mod executor; +mod fs_perms; mod ipc; mod openclaw; mod outbox; @@ -20,7 +21,6 @@ use std::sync::Arc; use ahand_protocol::Envelope; use clap::{Parser, Subcommand}; use config::ConnectionMode; -use tokio::signal::unix::{SignalKind, signal}; use tracing::info; #[derive(Parser)] @@ -290,8 +290,10 @@ async fn main() -> anyhow::Result<()> { let (approval_broadcast_tx, _) = tokio::sync::broadcast::channel::(64); // Set up signal handlers for graceful shutdown. - let mut sigterm = signal(SignalKind::terminate())?; - let mut sigint = signal(SignalKind::interrupt())?; + #[cfg(unix)] + let mut sigterm = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())?; + #[cfg(unix)] + let mut sigint = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::interrupt())?; let main_future = async { match connection_mode { @@ -391,15 +393,30 @@ async fn main() -> anyhow::Result<()> { }; // Race main event loop against shutdown signals. - let result = tokio::select! { - r = main_future => r, - _ = sigterm.recv() => { - info!("received SIGTERM, shutting down"); - Ok(()) + let result = { + #[cfg(unix)] + { + tokio::select! { + r = main_future => r, + _ = sigterm.recv() => { + info!("received SIGTERM, shutting down"); + Ok(()) + } + _ = sigint.recv() => { + info!("received SIGINT, shutting down"); + Ok(()) + } + } } - _ = sigint.recv() => { - info!("received SIGINT, shutting down"); - Ok(()) + #[cfg(windows)] + { + tokio::select! { + r = main_future => r, + _ = tokio::signal::ctrl_c() => { + info!("received Ctrl+C, shutting down"); + Ok(()) + } + } } }; diff --git a/crates/ahandd/src/openclaw/device_identity.rs b/crates/ahandd/src/openclaw/device_identity.rs index 939add43..df69464a 100644 --- a/crates/ahandd/src/openclaw/device_identity.rs +++ b/crates/ahandd/src/openclaw/device_identity.rs @@ -65,6 +65,16 @@ impl DeviceIdentity { /// Load from file fn load(path: &PathBuf) -> Result { + #[cfg(windows)] + let content = { + let encrypted = std::fs::read(path) + .with_context(|| format!("failed to read {}", path.display()))?; + let decrypted = crate::dpapi::unprotect(&encrypted) + .map_err(|e| anyhow::anyhow!("DPAPI decrypt failed for {}: {}", path.display(), e))?; + String::from_utf8(decrypted) + .with_context(|| "decrypted identity is not valid UTF-8")? + }; + #[cfg(unix)] let content = std::fs::read_to_string(path) .with_context(|| format!("failed to read {}", path.display()))?; @@ -127,15 +137,21 @@ impl DeviceIdentity { let content = serde_json::to_string_pretty(&stored).context("failed to serialize identity")?; - std::fs::write(path, format!("{}\n", content)) - .with_context(|| format!("failed to write {}", path.display()))?; - - // Set file permissions to 0600 (user read/write only) + #[cfg(windows)] + { + let encrypted = crate::dpapi::protect(format!("{}\n", content).as_bytes()) + .map_err(|e| anyhow::anyhow!("DPAPI encrypt failed: {}", e))?; + std::fs::write(path, &encrypted) + .with_context(|| format!("failed to write {}", path.display()))?; + crate::fs_perms::restrict_owner_only(path) + .with_context(|| format!("failed to set permissions on {}", path.display()))?; + } #[cfg(unix)] { - use std::os::unix::fs::PermissionsExt; - let perms = std::fs::Permissions::from_mode(0o600); - let _ = std::fs::set_permissions(path, perms); + std::fs::write(path, format!("{}\n", content)) + .with_context(|| format!("failed to write {}", path.display()))?; + crate::fs_perms::restrict_owner_only(path) + .with_context(|| format!("failed to set permissions on {}", path.display()))?; } Ok(()) @@ -220,3 +236,108 @@ mod hex { .collect() } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn generate_identity_has_valid_fields() { + let id = DeviceIdentity::generate(); + assert!(!id.device_id.is_empty()); + assert_eq!(id.public_key_raw().len(), 32); + assert!(!id.public_key_base64url().is_empty()); + } + + #[test] + fn sign_produces_nonempty_signature() { + let id = DeviceIdentity::generate(); + let sig = id.sign("test payload"); + assert!(!sig.is_empty()); + } + + #[test] + fn save_and_load_roundtrip() { + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().join("device-identity.json"); + + let original = DeviceIdentity::generate(); + original.save(&path).unwrap(); + + let loaded = DeviceIdentity::load(&path).unwrap(); + assert_eq!(original.device_id, loaded.device_id); + assert_eq!(original.public_key_raw(), loaded.public_key_raw()); + } + + #[test] + fn load_or_create_generates_when_missing() { + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().join("new-identity.json"); + assert!(!path.exists()); + + let id = DeviceIdentity::load_or_create(&path).unwrap(); + assert!(path.exists()); + assert!(!id.device_id.is_empty()); + } + + #[test] + fn load_or_create_loads_existing() { + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().join("existing.json"); + + let original = DeviceIdentity::generate(); + original.save(&path).unwrap(); + + let loaded = DeviceIdentity::load_or_create(&path).unwrap(); + assert_eq!(original.device_id, loaded.device_id); + } + + #[cfg(windows)] + #[test] + fn save_and_load_roundtrip_with_dpapi() { + // On Windows, save() encrypts with DPAPI and load() decrypts. + // This tests the full integration. + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().join("dpapi-identity.json"); + + let original = DeviceIdentity::generate(); + original.save(&path).unwrap(); + + // File should be encrypted (not valid JSON) + let raw = std::fs::read(&path).unwrap(); + assert!(serde_json::from_slice::(&raw).is_err(), + "file should be DPAPI-encrypted, not plain JSON"); + + let loaded = DeviceIdentity::load(&path).unwrap(); + assert_eq!(original.device_id, loaded.device_id); + } + + #[test] + fn load_rejects_invalid_json() { + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().join("bad.json"); + std::fs::write(&path, "not json").unwrap(); + assert!(DeviceIdentity::load(&path).is_err()); + } + + #[test] + fn build_auth_payload_v1_format() { + let payload = build_auth_payload( + "dev1", "cli1", "tool", "user", &["read".to_string()], + 1234567890, None, None, + ); + assert!(payload.starts_with("v1|")); + assert!(payload.contains("dev1")); + assert!(payload.contains("1234567890")); + } + + #[test] + fn build_auth_payload_v2_with_nonce() { + let payload = build_auth_payload( + "dev1", "cli1", "tool", "user", &[], + 1234567890, Some("token"), Some("nonce123"), + ); + assert!(payload.starts_with("v2|")); + assert!(payload.contains("nonce123")); + } +} diff --git a/crates/ahandd/src/openclaw/exec_approvals.rs b/crates/ahandd/src/openclaw/exec_approvals.rs index 51a974c5..b9a966d0 100644 --- a/crates/ahandd/src/openclaw/exec_approvals.rs +++ b/crates/ahandd/src/openclaw/exec_approvals.rs @@ -62,13 +62,7 @@ pub fn save_exec_approvals(path: &Path, file: &ExecApprovalsFile) -> Result<()> std::fs::write(path, format!("{}\n", content)) .with_context(|| format!("failed to write {}", path.display()))?; - // Set file permissions to 0600 (user read/write only) - #[cfg(unix)] - { - use std::os::unix::fs::PermissionsExt; - let perms = std::fs::Permissions::from_mode(0o600); - let _ = std::fs::set_permissions(path, perms); - } + let _ = crate::fs_perms::restrict_owner_only(path); Ok(()) } diff --git a/crates/ahandd/src/openclaw/pairing.rs b/crates/ahandd/src/openclaw/pairing.rs index c3936179..79bd0d87 100644 --- a/crates/ahandd/src/openclaw/pairing.rs +++ b/crates/ahandd/src/openclaw/pairing.rs @@ -81,13 +81,7 @@ pub fn save_pairing_state(path: &PathBuf, state: &PairingState) -> Result<()> { std::fs::write(path, format!("{}\n", content)) .with_context(|| format!("failed to write {}", path.display()))?; - // Set file permissions to 0600 (user read/write only) - #[cfg(unix)] - { - use std::os::unix::fs::PermissionsExt; - let perms = std::fs::Permissions::from_mode(0o600); - let _ = std::fs::set_permissions(path, perms); - } + let _ = crate::fs_perms::restrict_owner_only(path); Ok(()) } diff --git a/crates/ahandd/tests/dpapi_test.rs b/crates/ahandd/tests/dpapi_test.rs new file mode 100644 index 00000000..25904b34 --- /dev/null +++ b/crates/ahandd/tests/dpapi_test.rs @@ -0,0 +1,52 @@ +//! DPAPI tests — only compile and run on Windows. + +#[cfg(windows)] +mod tests { + #[test] + fn dpapi_roundtrip() { + let plaintext = b"Ed25519-secret-key-bytes-here-32"; + let encrypted = ahandd::dpapi::protect(plaintext).unwrap(); + + assert_ne!( + &encrypted[..], + plaintext.as_slice(), + "ciphertext should differ from plaintext" + ); + assert!( + encrypted.len() > plaintext.len(), + "ciphertext should be larger than plaintext" + ); + + let decrypted = ahandd::dpapi::unprotect(&encrypted).unwrap(); + assert_eq!( + &decrypted[..], + plaintext.as_slice(), + "roundtrip decryption failed" + ); + } + + #[test] + fn dpapi_empty_input() { + let encrypted = ahandd::dpapi::protect(b"").unwrap(); + let decrypted = ahandd::dpapi::unprotect(&encrypted).unwrap(); + assert!(decrypted.is_empty(), "empty input should roundtrip to empty"); + } + + #[test] + fn dpapi_large_payload() { + let plaintext = vec![0xABu8; 4096]; + let encrypted = ahandd::dpapi::protect(&plaintext).unwrap(); + let decrypted = ahandd::dpapi::unprotect(&encrypted).unwrap(); + assert_eq!(decrypted, plaintext, "large payload roundtrip failed"); + } + + #[test] + fn dpapi_unprotect_invalid_data() { + let garbage = b"this is not valid DPAPI ciphertext"; + let result = ahandd::dpapi::unprotect(garbage); + assert!( + result.is_err(), + "unprotect should fail on invalid ciphertext" + ); + } +} diff --git a/crates/ahandd/tests/fs_perms_test.rs b/crates/ahandd/tests/fs_perms_test.rs new file mode 100644 index 00000000..5a094f7f --- /dev/null +++ b/crates/ahandd/tests/fs_perms_test.rs @@ -0,0 +1,125 @@ +use tempfile::NamedTempFile; + +// We test against the ahandd crate's public fs_perms module. +// On Unix we verify with std::os::unix::fs::PermissionsExt. +// On Windows the DACL verification would be tested in CI. + +#[test] +fn test_restrict_owner_only() { + let tmp = NamedTempFile::new().unwrap(); + let path = tmp.path(); + std::fs::write(path, "secret").unwrap(); + + ahandd::fs_perms::restrict_owner_only(path).unwrap(); + + #[cfg(unix)] + { + use std::os::unix::fs::PermissionsExt; + let mode = std::fs::metadata(path).unwrap().permissions().mode() & 0o777; + assert_eq!(mode, 0o600, "expected 0o600, got 0o{:03o}", mode); + } +} + +#[test] +fn test_restrict_owner_and_group() { + let tmp = NamedTempFile::new().unwrap(); + let path = tmp.path(); + std::fs::write(path, "shared").unwrap(); + + ahandd::fs_perms::restrict_owner_and_group(path).unwrap(); + + #[cfg(unix)] + { + use std::os::unix::fs::PermissionsExt; + let mode = std::fs::metadata(path).unwrap().permissions().mode() & 0o777; + assert_eq!(mode, 0o660, "expected 0o660, got 0o{:03o}", mode); + } +} + +#[test] +fn test_restrict_owner_only_nonexistent_file() { + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().join("does_not_exist.txt"); + + let result = ahandd::fs_perms::restrict_owner_only(&path); + assert!(result.is_err(), "should fail for nonexistent file"); +} + +#[test] +fn test_restrict_owner_and_group_nonexistent_file() { + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().join("does_not_exist.txt"); + + let result = ahandd::fs_perms::restrict_owner_and_group(&path); + assert!(result.is_err(), "should fail for nonexistent file"); +} + +#[cfg(unix)] +#[test] +fn test_restrict_owner_only_idempotent() { + use std::os::unix::fs::PermissionsExt; + + let tmp = NamedTempFile::new().unwrap(); + let path = tmp.path(); + std::fs::write(path, "data").unwrap(); + + // Apply twice -- should succeed both times with same result. + ahandd::fs_perms::restrict_owner_only(path).unwrap(); + ahandd::fs_perms::restrict_owner_only(path).unwrap(); + + let mode = std::fs::metadata(path).unwrap().permissions().mode() & 0o777; + assert_eq!(mode, 0o600); +} + +#[cfg(unix)] +#[test] +fn test_restrict_owner_and_group_idempotent() { + use std::os::unix::fs::PermissionsExt; + + let tmp = NamedTempFile::new().unwrap(); + let path = tmp.path(); + std::fs::write(path, "data").unwrap(); + + // Apply twice -- should succeed both times with same result. + ahandd::fs_perms::restrict_owner_and_group(path).unwrap(); + ahandd::fs_perms::restrict_owner_and_group(path).unwrap(); + + let mode = std::fs::metadata(path).unwrap().permissions().mode() & 0o777; + assert_eq!(mode, 0o660); +} + +#[cfg(unix)] +#[test] +fn test_restrict_owner_only_from_permissive() { + use std::os::unix::fs::PermissionsExt; + + let tmp = NamedTempFile::new().unwrap(); + let path = tmp.path(); + std::fs::write(path, "data").unwrap(); + + // Start with wide-open permissions. + std::fs::set_permissions(path, std::fs::Permissions::from_mode(0o777)).unwrap(); + + ahandd::fs_perms::restrict_owner_only(path).unwrap(); + + let mode = std::fs::metadata(path).unwrap().permissions().mode() & 0o777; + assert_eq!(mode, 0o600, "should restrict from 0o777 to 0o600"); +} + +#[cfg(unix)] +#[test] +fn test_restrict_owner_and_group_from_permissive() { + use std::os::unix::fs::PermissionsExt; + + let tmp = NamedTempFile::new().unwrap(); + let path = tmp.path(); + std::fs::write(path, "data").unwrap(); + + // Start with wide-open permissions. + std::fs::set_permissions(path, std::fs::Permissions::from_mode(0o777)).unwrap(); + + ahandd::fs_perms::restrict_owner_and_group(path).unwrap(); + + let mode = std::fs::metadata(path).unwrap().permissions().mode() & 0o777; + assert_eq!(mode, 0o660, "should restrict from 0o777 to 0o660"); +} diff --git a/crates/ahandd/tests/ipc_roundtrip.rs b/crates/ahandd/tests/ipc_roundtrip.rs new file mode 100644 index 00000000..7bc7e9e7 --- /dev/null +++ b/crates/ahandd/tests/ipc_roundtrip.rs @@ -0,0 +1,682 @@ +//! IPC frame protocol integration tests. +//! +//! Verifies the length-prefixed frame protocol works correctly, +//! which is the foundation for cross-platform IPC communication. +//! On Unix we test over Unix domain sockets; on Windows these tests +//! would exercise the same framing logic over named pipes (not yet +//! wired here — the protocol layer is platform-agnostic). + +use tokio::io::{AsyncReadExt, AsyncWriteExt}; + +/// Write a length-prefixed frame: `[4-byte big-endian length][payload]`. +/// +/// This mirrors the production `write_frame` in `ipc.rs`. +async fn write_frame(w: &mut W, data: &[u8]) -> std::io::Result<()> { + w.write_u32(data.len() as u32).await?; + w.write_all(data).await?; + w.flush().await +} + +/// Read a length-prefixed frame, rejecting payloads >16 MiB. +/// +/// This mirrors the production `read_frame` in `ipc.rs`. +async fn read_frame(r: &mut R) -> std::io::Result> { + let len = r.read_u32().await? as usize; + if len > 16 * 1024 * 1024 { + return Err(std::io::Error::new( + std::io::ErrorKind::InvalidData, + "frame too large", + )); + } + let mut buf = vec![0u8; len]; + r.read_exact(&mut buf).await?; + Ok(buf) +} + +// --------------------------------------------------------------------------- +// Unix-socket tests (cfg(unix)) +// --------------------------------------------------------------------------- + +#[cfg(unix)] +mod unix { + use super::*; + + #[tokio::test] + async fn frame_roundtrip() { + let dir = tempfile::tempdir().unwrap(); + let sock_path = dir.path().join("test.sock"); + + let listener = tokio::net::UnixListener::bind(&sock_path).unwrap(); + + let path = sock_path.clone(); + let server = tokio::spawn(async move { + let (stream, _) = listener.accept().await.unwrap(); + let (reader, mut writer) = stream.into_split(); + let mut reader = tokio::io::BufReader::new(reader); + + // Read a frame from the client. + let data = read_frame(&mut reader).await.unwrap(); + assert_eq!(data, b"hello from client"); + + // Send a frame back. + write_frame(&mut writer, b"hello from server").await.unwrap(); + }); + + let client = tokio::spawn(async move { + let stream = tokio::net::UnixStream::connect(&path).await.unwrap(); + let (reader, mut writer) = stream.into_split(); + let mut reader = tokio::io::BufReader::new(reader); + + write_frame(&mut writer, b"hello from client").await.unwrap(); + + let data = read_frame(&mut reader).await.unwrap(); + assert_eq!(data, b"hello from server"); + }); + + server.await.unwrap(); + client.await.unwrap(); + } + + #[tokio::test] + async fn frame_protocol_with_protobuf() { + use ahand_protocol::{envelope, Envelope, JobRequest}; + use prost::Message; + + let dir = tempfile::tempdir().unwrap(); + let sock_path = dir.path().join("proto.sock"); + + let listener = tokio::net::UnixListener::bind(&sock_path).unwrap(); + + let path = sock_path.clone(); + let server = tokio::spawn(async move { + let (stream, _) = listener.accept().await.unwrap(); + let (reader, _writer) = stream.into_split(); + let mut reader = tokio::io::BufReader::new(reader); + + let data = read_frame(&mut reader).await.unwrap(); + let env = Envelope::decode(data.as_slice()).unwrap(); + + assert_eq!(env.device_id, "test-device"); + if let Some(envelope::Payload::JobRequest(req)) = env.payload { + assert_eq!(req.job_id, "job-1"); + assert_eq!(req.tool, "echo"); + assert_eq!(req.args, vec!["hello".to_string()]); + } else { + panic!("expected JobRequest payload"); + } + }); + + let client = tokio::spawn(async move { + let stream = tokio::net::UnixStream::connect(&path).await.unwrap(); + let (_reader, mut writer) = stream.into_split(); + + let env = Envelope { + device_id: "test-device".to_string(), + msg_id: "msg-1".to_string(), + ts_ms: 12345, + payload: Some(envelope::Payload::JobRequest(JobRequest { + job_id: "job-1".to_string(), + tool: "echo".to_string(), + args: vec!["hello".to_string()], + ..Default::default() + })), + ..Default::default() + }; + + write_frame(&mut writer, &env.encode_to_vec()).await.unwrap(); + }); + + server.await.unwrap(); + client.await.unwrap(); + } + + #[tokio::test] + async fn frame_multiple_messages() { + let dir = tempfile::tempdir().unwrap(); + let sock_path = dir.path().join("multi.sock"); + + let listener = tokio::net::UnixListener::bind(&sock_path).unwrap(); + + let path = sock_path.clone(); + let server = tokio::spawn(async move { + let (stream, _) = listener.accept().await.unwrap(); + let (reader, _) = stream.into_split(); + let mut reader = tokio::io::BufReader::new(reader); + + for i in 0..5u32 { + let data = read_frame(&mut reader).await.unwrap(); + assert_eq!(data, format!("message-{i}").as_bytes()); + } + }); + + let client = tokio::spawn(async move { + let stream = tokio::net::UnixStream::connect(&path).await.unwrap(); + let (_, mut writer) = stream.into_split(); + + for i in 0..5u32 { + write_frame(&mut writer, format!("message-{i}").as_bytes()) + .await + .unwrap(); + } + }); + + server.await.unwrap(); + client.await.unwrap(); + } + + #[tokio::test] + async fn frame_empty_message() { + let dir = tempfile::tempdir().unwrap(); + let sock_path = dir.path().join("empty.sock"); + + let listener = tokio::net::UnixListener::bind(&sock_path).unwrap(); + + let path = sock_path.clone(); + let server = tokio::spawn(async move { + let (stream, _) = listener.accept().await.unwrap(); + let (reader, _) = stream.into_split(); + let mut reader = tokio::io::BufReader::new(reader); + + let data = read_frame(&mut reader).await.unwrap(); + assert!(data.is_empty()); + }); + + let client = tokio::spawn(async move { + let stream = tokio::net::UnixStream::connect(&path).await.unwrap(); + let (_, mut writer) = stream.into_split(); + write_frame(&mut writer, b"").await.unwrap(); + }); + + server.await.unwrap(); + client.await.unwrap(); + } + + #[tokio::test] + async fn frame_too_large_rejected() { + let dir = tempfile::tempdir().unwrap(); + let sock_path = dir.path().join("large.sock"); + + let listener = tokio::net::UnixListener::bind(&sock_path).unwrap(); + + let path = sock_path.clone(); + let server = tokio::spawn(async move { + let (stream, _) = listener.accept().await.unwrap(); + let (reader, _) = stream.into_split(); + let mut reader = tokio::io::BufReader::new(reader); + + let result = read_frame(&mut reader).await; + assert!(result.is_err()); + let err = result.unwrap_err(); + assert_eq!(err.kind(), std::io::ErrorKind::InvalidData); + assert!(err.to_string().contains("frame too large")); + }); + + let client = tokio::spawn(async move { + let stream = tokio::net::UnixStream::connect(&path).await.unwrap(); + let (_, mut writer) = stream.into_split(); + + // Write a length header claiming 17 MiB (exceeds 16 MiB limit), + // without actually sending that much data. + let fake_len: u32 = 17 * 1024 * 1024; + writer.write_u32(fake_len).await.unwrap(); + writer.flush().await.unwrap(); + }); + + server.await.unwrap(); + client.await.unwrap(); + } + + #[tokio::test] + async fn frame_max_allowed_size_accepted() { + let dir = tempfile::tempdir().unwrap(); + let sock_path = dir.path().join("maxsize.sock"); + + let listener = tokio::net::UnixListener::bind(&sock_path).unwrap(); + + let path = sock_path.clone(); + // Use a smaller payload to avoid allocating 16 MiB in tests. + // We verify the boundary by testing a frame of exactly 1024 bytes. + let payload = vec![0xABu8; 1024]; + let expected = payload.clone(); + + let server = tokio::spawn(async move { + let (stream, _) = listener.accept().await.unwrap(); + let (reader, _) = stream.into_split(); + let mut reader = tokio::io::BufReader::new(reader); + + let data = read_frame(&mut reader).await.unwrap(); + assert_eq!(data, expected); + }); + + let client = tokio::spawn(async move { + let stream = tokio::net::UnixStream::connect(&path).await.unwrap(); + let (_, mut writer) = stream.into_split(); + write_frame(&mut writer, &payload).await.unwrap(); + }); + + server.await.unwrap(); + client.await.unwrap(); + } + + #[tokio::test] + async fn frame_bidirectional_protobuf() { + use ahand_protocol::{envelope, Envelope, JobFinished, JobRequest}; + use prost::Message; + + let dir = tempfile::tempdir().unwrap(); + let sock_path = dir.path().join("bidir.sock"); + + let listener = tokio::net::UnixListener::bind(&sock_path).unwrap(); + + let path = sock_path.clone(); + let server = tokio::spawn(async move { + let (stream, _) = listener.accept().await.unwrap(); + let (reader, mut writer) = stream.into_split(); + let mut reader = tokio::io::BufReader::new(reader); + + // Receive a JobRequest. + let data = read_frame(&mut reader).await.unwrap(); + let env = Envelope::decode(data.as_slice()).unwrap(); + let job_id = match env.payload { + Some(envelope::Payload::JobRequest(ref req)) => req.job_id.clone(), + other => panic!("expected JobRequest, got {other:?}"), + }; + + // Respond with JobFinished. + let response = Envelope { + device_id: "daemon".to_string(), + msg_id: "resp-1".to_string(), + payload: Some(envelope::Payload::JobFinished(JobFinished { + job_id, + exit_code: 0, + error: String::new(), + })), + ..Default::default() + }; + write_frame(&mut writer, &response.encode_to_vec()) + .await + .unwrap(); + }); + + let client = tokio::spawn(async move { + let stream = tokio::net::UnixStream::connect(&path).await.unwrap(); + let (reader, mut writer) = stream.into_split(); + let mut reader = tokio::io::BufReader::new(reader); + + // Send a JobRequest. + let request = Envelope { + device_id: "client".to_string(), + msg_id: "req-1".to_string(), + payload: Some(envelope::Payload::JobRequest(JobRequest { + job_id: "job-42".to_string(), + tool: "ls".to_string(), + args: vec!["-la".to_string()], + cwd: "/tmp".to_string(), + ..Default::default() + })), + ..Default::default() + }; + write_frame(&mut writer, &request.encode_to_vec()) + .await + .unwrap(); + + // Read the response. + let data = read_frame(&mut reader).await.unwrap(); + let env = Envelope::decode(data.as_slice()).unwrap(); + match env.payload { + Some(envelope::Payload::JobFinished(finished)) => { + assert_eq!(finished.job_id, "job-42"); + assert_eq!(finished.exit_code, 0); + assert!(finished.error.is_empty()); + } + other => panic!("expected JobFinished, got {other:?}"), + } + }); + + server.await.unwrap(); + client.await.unwrap(); + } + + #[tokio::test] + async fn frame_truncated_payload_returns_error() { + let dir = tempfile::tempdir().unwrap(); + let sock_path = dir.path().join("trunc.sock"); + + let listener = tokio::net::UnixListener::bind(&sock_path).unwrap(); + + let path = sock_path.clone(); + let server = tokio::spawn(async move { + let (stream, _) = listener.accept().await.unwrap(); + let (reader, _) = stream.into_split(); + let mut reader = tokio::io::BufReader::new(reader); + + // The client claims 100 bytes but only sends 5, then closes. + // read_frame should return an error (UnexpectedEof). + let result = read_frame(&mut reader).await; + assert!(result.is_err()); + }); + + let client = tokio::spawn(async move { + let stream = tokio::net::UnixStream::connect(&path).await.unwrap(); + let (_, mut writer) = stream.into_split(); + + // Write a length header of 100 but only 5 bytes of payload, + // then drop the connection. + writer.write_u32(100).await.unwrap(); + writer.write_all(&[1, 2, 3, 4, 5]).await.unwrap(); + writer.flush().await.unwrap(); + drop(writer); + }); + + server.await.unwrap(); + client.await.unwrap(); + } +} + +// --------------------------------------------------------------------------- +// Full serve_ipc integration tests (cfg(unix)) +// --------------------------------------------------------------------------- + +#[cfg(unix)] +mod serve_ipc_integration { + use std::sync::Arc; + + use ahand_protocol::{envelope, Envelope, JobRequest, SessionMode, SessionQuery}; + use ahandd::{approval, browser, config, ipc, registry, session}; + use prost::Message; + use tokio::io::{AsyncReadExt, AsyncWriteExt}; + + /// Helper: write a length-prefixed protobuf frame. + async fn write_envelope(w: &mut W, env: &Envelope) { + let data = env.encode_to_vec(); + w.write_u32(data.len() as u32).await.unwrap(); + w.write_all(&data).await.unwrap(); + w.flush().await.unwrap(); + } + + /// Helper: read one length-prefixed protobuf frame with a timeout. + async fn read_envelope(r: &mut R) -> Option { + let len = tokio::time::timeout( + std::time::Duration::from_secs(5), + r.read_u32(), + ) + .await + .ok()? + .ok()? as usize; + + let mut buf = vec![0u8; len]; + r.read_exact(&mut buf).await.ok()?; + Envelope::decode(buf.as_slice()).ok() + } + + /// Spin up a real `serve_ipc` server bound to a temp Unix socket and return + /// the socket path and the server's `JoinHandle` (caller must abort it). + async fn start_server( + dir: &tempfile::TempDir, + session_mgr: Arc, + ) -> (std::path::PathBuf, tokio::task::JoinHandle>) { + let sock_path = dir.path().join("test.sock"); + let registry = Arc::new(registry::JobRegistry::new(4)); + let approval_mgr = Arc::new(approval::ApprovalManager::new(300)); + let browser_cfg = config::BrowserConfig::default(); + let browser_mgr = Arc::new(browser::BrowserManager::new(browser_cfg)); + let (broadcast_tx, _) = tokio::sync::broadcast::channel::(16); + + let path_clone = sock_path.clone(); + let handle = tokio::spawn(ipc::serve_ipc( + path_clone, + 0o660, + registry, + None, + session_mgr, + approval_mgr, + broadcast_tx, + "test-device".to_string(), + browser_mgr, + )); + + // Wait for socket to be ready (poll instead of fixed sleep). + let deadline = tokio::time::Instant::now() + std::time::Duration::from_secs(5); + while !sock_path.exists() { + if tokio::time::Instant::now() > deadline { + panic!("IPC server did not start within 5 seconds"); + } + tokio::time::sleep(std::time::Duration::from_millis(10)).await; + } + + (sock_path, handle) + } + + // ----------------------------------------------------------------------- + // Test: send a JobRequest and receive JobFinished or JobRejected. + // ----------------------------------------------------------------------- + + #[tokio::test] + async fn serve_ipc_job_roundtrip() { + let dir = tempfile::tempdir().unwrap(); + let session_mgr = Arc::new(session::SessionManager::new(60)); + + let (sock_path, server_handle) = start_server(&dir, Arc::clone(&session_mgr)).await; + + // Connect as a client. + let stream = tokio::net::UnixStream::connect(&sock_path).await.unwrap(); + let (reader, mut writer) = stream.into_split(); + let mut reader = tokio::io::BufReader::new(reader); + + // Send a JobRequest. + let job_id = format!("test-job-{}", std::process::id()); + let req = Envelope { + device_id: "test-client".to_string(), + msg_id: "msg-1".to_string(), + ts_ms: 0, + payload: Some(envelope::Payload::JobRequest(JobRequest { + job_id: job_id.clone(), + tool: "echo".to_string(), + args: vec!["hello".to_string()], + ..Default::default() + })), + ..Default::default() + }; + write_envelope(&mut writer, &req).await; + + // Read responses until we get a JobFinished or JobRejected for our job. + let deadline = tokio::time::Instant::now() + std::time::Duration::from_secs(10); + let mut got_response = false; + while tokio::time::Instant::now() < deadline { + let env = match read_envelope(&mut reader).await { + Some(e) => e, + None => break, + }; + match env.payload { + Some(envelope::Payload::JobFinished(ref fin)) if fin.job_id == job_id => { + got_response = true; + break; + } + Some(envelope::Payload::JobRejected(ref rej)) if rej.job_id == job_id => { + // Default session mode is Inactive, so rejection is expected. + // A rejection still proves the server processed our request. + got_response = true; + break; + } + _ => continue, + } + } + assert!( + got_response, + "did not receive JobFinished or JobRejected within timeout" + ); + + server_handle.abort(); + } + + // ----------------------------------------------------------------------- + // Test: send a JobRequest with AutoAccept mode and get JobFinished. + // ----------------------------------------------------------------------- + + #[tokio::test] + async fn serve_ipc_job_auto_accept() { + let dir = tempfile::tempdir().unwrap(); + let session_mgr = Arc::new(session::SessionManager::new(60)); + // Pre-set the default mode so new callers are auto-accepted. + session_mgr + .set_default_mode(SessionMode::AutoAccept) + .await; + + let (sock_path, server_handle) = start_server(&dir, Arc::clone(&session_mgr)).await; + + let stream = tokio::net::UnixStream::connect(&sock_path).await.unwrap(); + let (reader, mut writer) = stream.into_split(); + let mut reader = tokio::io::BufReader::new(reader); + + let job_id = format!("test-auto-{}", std::process::id()); + let req = Envelope { + device_id: "test-client".to_string(), + msg_id: "msg-auto".to_string(), + ts_ms: 0, + payload: Some(envelope::Payload::JobRequest(JobRequest { + job_id: job_id.clone(), + tool: "echo".to_string(), + args: vec!["hello".to_string()], + ..Default::default() + })), + ..Default::default() + }; + write_envelope(&mut writer, &req).await; + + // With AutoAccept the job should be accepted and run. + // We expect a JobFinished (echo should complete quickly). + let deadline = tokio::time::Instant::now() + std::time::Duration::from_secs(10); + let mut got_finished = false; + while tokio::time::Instant::now() < deadline { + let env = match read_envelope(&mut reader).await { + Some(e) => e, + None => break, + }; + if let Some(envelope::Payload::JobFinished(ref fin)) = env.payload { + if fin.job_id == job_id { + got_finished = true; + break; + } + } + } + assert!(got_finished, "did not receive JobFinished within timeout"); + + server_handle.abort(); + } + + // ----------------------------------------------------------------------- + // Test: SessionQuery -> SessionState roundtrip. + // ----------------------------------------------------------------------- + + #[tokio::test] + async fn serve_ipc_session_query_roundtrip() { + let dir = tempfile::tempdir().unwrap(); + let session_mgr = Arc::new(session::SessionManager::new(60)); + + let (sock_path, server_handle) = start_server(&dir, Arc::clone(&session_mgr)).await; + + let stream = tokio::net::UnixStream::connect(&sock_path).await.unwrap(); + let (reader, mut writer) = stream.into_split(); + let mut reader = tokio::io::BufReader::new(reader); + + // Give the server a moment to register our peer_cred via register_caller. + tokio::time::sleep(std::time::Duration::from_millis(200)).await; + + // Query our own session state using our UID. + // The server registers the caller_uid from peer_cred on connect, so + // querying with that UID should return a SessionState. + let our_uid = format!("uid:{}", unsafe { libc::getuid() }); + let query = Envelope { + device_id: "test-client".to_string(), + msg_id: "msg-session-query".to_string(), + ts_ms: 0, + payload: Some(envelope::Payload::SessionQuery(SessionQuery { + caller_uid: our_uid.clone(), + })), + ..Default::default() + }; + write_envelope(&mut writer, &query).await; + + // Read back a SessionState. + let env = read_envelope(&mut reader) + .await + .expect("expected SessionState response from server"); + + match env.payload { + Some(envelope::Payload::SessionState(state)) => { + assert_eq!(state.caller_uid, our_uid); + // Default mode is Inactive. + assert_eq!(state.mode, i32::from(SessionMode::Inactive)); + } + other => panic!("expected SessionState, got {other:?}"), + } + + server_handle.abort(); + } + + // ----------------------------------------------------------------------- + // Test: SessionQuery with empty caller_uid returns all sessions. + // ----------------------------------------------------------------------- + + #[tokio::test] + async fn serve_ipc_session_query_all() { + let dir = tempfile::tempdir().unwrap(); + let session_mgr = Arc::new(session::SessionManager::new(60)); + + let (sock_path, server_handle) = start_server(&dir, Arc::clone(&session_mgr)).await; + + let stream = tokio::net::UnixStream::connect(&sock_path).await.unwrap(); + let (reader, mut writer) = stream.into_split(); + let mut reader = tokio::io::BufReader::new(reader); + + // Let server register our peer_cred. + tokio::time::sleep(std::time::Duration::from_millis(100)).await; + + // Query all sessions (empty caller_uid). + let query = Envelope { + device_id: "test-client".to_string(), + msg_id: "msg-session-all".to_string(), + ts_ms: 0, + payload: Some(envelope::Payload::SessionQuery(SessionQuery { + caller_uid: String::new(), + })), + ..Default::default() + }; + write_envelope(&mut writer, &query).await; + + // We should get at least one SessionState (for our own connection). + let env = read_envelope(&mut reader) + .await + .expect("expected at least one SessionState response"); + + match env.payload { + Some(envelope::Payload::SessionState(state)) => { + // The caller_uid should be our uid. + let our_uid = format!("uid:{}", unsafe { libc::getuid() }); + assert_eq!(state.caller_uid, our_uid); + } + other => panic!("expected SessionState, got {other:?}"), + } + + server_handle.abort(); + } + + // ----------------------------------------------------------------------- + // Test: peer identity format on Unix. + // ----------------------------------------------------------------------- + + #[test] + fn peer_identity_format_unix() { + // On Unix, the serve_ipc_unix handler formats peer_cred as "uid:{number}". + let uid = unsafe { libc::getuid() }; + let identity = format!("uid:{uid}"); + assert!(identity.starts_with("uid:")); + // UID should be a valid non-negative integer. + let parsed: u32 = identity + .strip_prefix("uid:") + .unwrap() + .parse() + .expect("uid should be a valid u32"); + assert_eq!(parsed, uid); + } +} diff --git a/scripts/dist/install.ps1 b/scripts/dist/install.ps1 new file mode 100644 index 00000000..20d3c678 --- /dev/null +++ b/scripts/dist/install.ps1 @@ -0,0 +1,146 @@ +#Requires -Version 5.1 +<# +.SYNOPSIS + Install aHand (ahandd + ahandctl) on Windows. +.DESCRIPTION + Downloads prebuilt binaries from GitHub Releases and installs to ~/.ahand/bin/. + Verifies SHA-256 checksums and adds the install directory to the user PATH. +.EXAMPLE + irm https://raw.githubusercontent.com/team9ai/aHand/main/scripts/dist/install.ps1 | iex +.PARAMETER Version + Specific version to install (default: latest). +.PARAMETER InstallDir + Directory to install binaries (default: $env:USERPROFILE\.ahand\bin). +#> +param( + [string]$Version = "", + [string]$InstallDir = "$env:USERPROFILE\.ahand\bin" +) + +$ErrorActionPreference = "Stop" +$REPO = "team9ai/aHand" + +# ── Detect architecture ────────────────────────────────────────────── + +$arch = if ([Environment]::Is64BitOperatingSystem) { + if ($env:PROCESSOR_ARCHITECTURE -eq "ARM64") { "arm64" } else { "x64" } +} else { + Write-Error "32-bit Windows is not supported." + exit 1 +} +$suffix = "windows-$arch" + +# ── Determine version ──────────────────────────────────────────────── + +if (-not $Version) { + Write-Host "Fetching latest version..." + try { + $release = Invoke-RestMethod "https://api.github.com/repos/$REPO/releases/latest" ` + -Headers @{ "User-Agent" = "ahand-installer" } + $Version = $release.tag_name -replace '^rust-v', '' + } catch { + Write-Error "Failed to fetch latest release: $_" + exit 1 + } +} + +if (-not $Version) { + Write-Error "Could not determine version to install." + exit 1 +} + +Write-Host "" +Write-Host "Installing aHand v$Version ($suffix)..." + +# ── Create install directory ───────────────────────────────────────── + +New-Item -ItemType Directory -Force -Path $InstallDir | Out-Null + +# ── Download binaries ──────────────────────────────────────────────── + +$baseUrl = "https://github.com/$REPO/releases/download/rust-v$Version" + +foreach ($binary in @("ahandd", "ahandctl")) { + $asset = "$binary-$suffix.exe" + $url = "$baseUrl/$asset" + $dest = Join-Path $InstallDir "$binary.exe" + + Write-Host " Downloading $asset..." + try { + Invoke-WebRequest -Uri $url -OutFile $dest -UseBasicParsing + } catch { + Write-Error "Failed to download $asset from $url : $_" + exit 1 + } + Write-Host " Installed: $dest" +} + +# ── Verify checksums ──────────────────────────────────────────────── + +# Build a lookup from asset names (in checksum file) to installed paths. +$nameToPath = @{ + "ahandd-$suffix.exe" = Join-Path $InstallDir "ahandd.exe" + "ahandctl-$suffix.exe" = Join-Path $InstallDir "ahandctl.exe" +} + +$checksumUrl = "$baseUrl/checksums-rust-$suffix.txt" +try { + $checksums = Invoke-RestMethod -Uri $checksumUrl -Headers @{ "User-Agent" = "ahand-installer" } + $verified = @{} + foreach ($line in $checksums -split "`n") { + $line = $line.Trim() + if ($line -match "^([0-9a-f]+)\s+(.+)$") { + $expected = $Matches[1] + $assetName = $Matches[2].Trim() + $filePath = $nameToPath[$assetName] + if ($filePath -and (Test-Path $filePath)) { + $actual = (Get-FileHash $filePath -Algorithm SHA256).Hash.ToLower() + if ($actual -ne $expected) { + Remove-Item $filePath -Force -ErrorAction SilentlyContinue + Write-Error "Checksum mismatch for $assetName! Expected $expected, got $actual. File removed." + exit 1 + } else { + Write-Host " Checksum OK: $assetName" + $verified[$assetName] = $true + } + } + } + } + # Ensure both binaries were actually verified + foreach ($asset in $nameToPath.Keys) { + if (-not $verified.ContainsKey($asset)) { + Write-Error "Checksum entry missing for $asset — cannot verify binary integrity." + exit 1 + } + } +} catch { + Write-Error "Could not verify checksums: $_" + Write-Error "Installation aborted — cannot verify binary integrity." + exit 1 +} + +# ── Write version marker ──────────────────────────────────────────── + +$versionFile = Join-Path (Split-Path $InstallDir -Parent) "version" +$Version | Out-File -FilePath $versionFile -Encoding utf8 -NoNewline + +# ── Add to PATH ───────────────────────────────────────────────────── + +$userPath = [Environment]::GetEnvironmentVariable("PATH", "User") +if ($userPath -notlike "*$InstallDir*") { + [Environment]::SetEnvironmentVariable("PATH", "$userPath;$InstallDir", "User") + Write-Host "" + Write-Host "Added $InstallDir to user PATH." + Write-Host "Restart your terminal for PATH changes to take effect." +} + +# ── Done ───────────────────────────────────────────────────────────── + +Write-Host "" +Write-Host "aHand v$Version installed successfully!" +Write-Host " ahandd: $(Join-Path $InstallDir 'ahandd.exe')" +Write-Host " ahandctl: $(Join-Path $InstallDir 'ahandctl.exe')" +Write-Host "" +Write-Host "Get started:" +Write-Host " ahandctl configure" +Write-Host ""