diff --git a/ARCHITECTURE.md b/ARCHITECTURE.md index c1b1c996..0bdc759b 100644 --- a/ARCHITECTURE.md +++ b/ARCHITECTURE.md @@ -19,7 +19,7 @@ point. More detailed usage instructions and technical details can be found in ea ```mermaid block-beta columns 4 - + API["API/Language Bindings"]:4 RT["Runtime"]:4 CP["Control Plane"]:1 @@ -59,7 +59,7 @@ Implements the TCP, UDP, and raw socket abstractions for the overlay network (ta ### Control Plane -Crates that communicate with Tailscale's control plane (or Headscale) and provide configuration for the data plane. The control plane handles authentication/authorization, node registration, policy updates, network map distribution, and much more for the nodes in a tailnet. +Crates that communicate with Tailscale's control plane (or Headscale) and provide configuration for the data plane. The control plane handles authentication/authorization, node registration, policy updates, network map distribution, and much more for the nodes in a tailnet. - [`ts_control`](ts_control/src/lib.rs): control plane client that handles registration, authorization/authentication, configuration, and streaming updates. - [`ts_control_noise`](ts_control_noise/src/lib.rs): abstraction that wraps control plane communications in a Noise IK tunnel, transparently handling cryptography for the client. @@ -78,7 +78,7 @@ Types and (de)serialization code for control plane traffic "on the wire". `ts_co Crates that communicate with other Tailscale nodes on the tailnet. The data plane is responsible for actually exchanging packets between peers on the tailnet, including transport management (DERP, TUN, etc.), routing, packet filtering, and tunneling. -- [`ts_dataplane`](ts_dataplane/src/lib.rs): wires all the individual data plane functions together, flowing inbound and outbound packets through the components in the correct order. The various data plane components are described below. +- [`ts_dataplane`](ts_dataplane/src/lib.rs): wires all the individual data plane functions together, flowing inbound and outbound packets through the components in the correct order. The various data plane components are described below. #### Packet Filtering @@ -88,13 +88,13 @@ Crates that communicate with other Tailscale nodes on the tailnet. The data plan #### Routing - - [`ts_overlay_router`](ts_overlay_router/src/lib.rs): routing table implementation for overlay (tailnet) traffic; determines which peer to send outbound traffic to, and which overlay transport should receive inbound packets. + - [`ts_overlay_router`](ts_overlay_router/src/lib.rs): routing table implementation for overlay (tailnet) traffic; determines which peer to send outbound traffic to, and which overlay transport should receive inbound packets. - [`ts_underlay_router`](ts_underlay_router/src/lib.rs): routing table implementation for underlay traffic; determines which underlay transport an outbound packet should be sent from, if any. #### Transports - - [`ts_transport`](ts_transport/src/lib.rs): traits that define transports and how they move traffic in and out of the overlay/underlay network. - - [`ts_transport_derp`](ts_transport_derp/src/lib.rs): an underlay transport that exchanges packets between nodes via Designated Encrypted Relay for Packets (DERP) relay servers. + - [`ts_transport`](ts_transport/src/lib.rs): traits that define transports and how they move traffic in and out of the overlay/underlay network. + - [`ts_derp`](ts_derp/src/lib.rs): an underlay transport that exchanges packets between nodes via Designated Encrypted Relay for Packets (DERP) relay servers. - [`ts_transport_tun`](ts_transport_tun/src/lib.rs): an overlay transport that exposes a TUN device on the local machine to send/receive packets on the overlay network (tailnet). #### Tunneling @@ -106,9 +106,9 @@ Crates that communicate with other Tailscale nodes on the tailnet. The data plan Crates used throughout the codebase that provide generic algorithms, data structures, cross-cutting concerns, or development tooling. #### Algorithms and Data Structures - - [`ts_array256`](ts_array256/src/lib.rs): sparse array of 256 elements with configurable backing store, used with `ts_bart`. - - [`ts_bart`](ts_bart/README.md): BAlanced Routing Table (BART) data structure for fast IP address/prefix search in routing tables and packet filtering. - - [`ts_bitset`](ts_bitset/src/lib.rs): fixed-width bitset used to track presence of elements in `ts_array256`. + - [`ts_array256`](ts_array256/src/lib.rs): sparse array of 256 elements with configurable backing store, used with `ts_bart`. + - [`ts_bart`](ts_bart/README.md): BAlanced Routing Table (BART) data structure for fast IP address/prefix search in routing tables and packet filtering. + - [`ts_bitset`](ts_bitset/src/lib.rs): fixed-width bitset used to track presence of elements in `ts_array256`. - [`ts_dynbitset`](ts_dynbitset/src/lib.rs): growable bitset built on top of `ts_bitset`, used with `ts_bart_packetfilter`. - [`ts_keys`](ts_keys/src/lib.rs): data structures representing all of Tailscale's x25519 keys (disco, node, machine, etc.). - [`ts_packet`](ts_packet/src/lib.rs): base types representing network packets. @@ -121,12 +121,12 @@ Crates used throughout the codebase that provide generic algorithms, data struct #### Examples, Debugging, and Testing - [`ts_cli_util`](ts_cli_util/src/lib.rs): helpers for writing command line tools and initializing logging, used in examples. - - [`ts_test_util`](ts_test_util/src/lib.rs): common code used by our unit and integration tests, such as determining if the network is available. + - [`ts_test_util`](ts_test_util/src/lib.rs): common code used by our unit and integration tests, such as determining if the network is available. - [`ts_hexdump`](ts_hexdump/src/lib.rs): traits and functions to generate canonical hexdumps of buffers for debug logging. #### Protocols - [`ts_disco_protocol`](ts_disco_protocol/src/lib.rs): incomplete implementation of Tailscale's discovery protocol (disco). - - [`ts_http_util`](ts_http_util/src/lib.rs): HTTP/1 and HTTP/2 client utilities used in `ts_control` and `ts_transport_derp`. + - [`ts_http_util`](ts_http_util/src/lib.rs): HTTP/1 and HTTP/2 client utilities used in `ts_control` and `ts_derp`. - [`ts_tls_util`](ts_tls_util/src/lib.rs): Transport Layer Sockets (TLS) utilities to manage certificates and establish secure connections over HTTP. #### Time diff --git a/Cargo.lock b/Cargo.lock index dd834ca8..f152427d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4140,8 +4140,8 @@ dependencies = [ "tracing-tracy", "tracy-client", "ts_control", + "ts_derp", "ts_netcheck", - "ts_transport_derp", ] [[package]] @@ -4166,6 +4166,7 @@ dependencies = [ "ts_capabilityversion", "ts_control_noise", "ts_control_serde", + "ts_derp", "ts_dynbitset", "ts_http_util", "ts_keys", @@ -4173,7 +4174,6 @@ dependencies = [ "ts_packetfilter", "ts_packetfilter_state", "ts_tls_util", - "ts_transport_derp", "url", "zerocopy", ] @@ -4229,7 +4229,6 @@ dependencies = [ "tokio", "tracing", "ts_bart", - "ts_keys", "ts_overlay_router", "ts_packet", "ts_packetfilter", @@ -4239,6 +4238,34 @@ dependencies = [ "ts_underlay_router", ] +[[package]] +name = "ts_derp" +version = "0.2.0" +dependencies = [ + "bytes", + "crypto_box", + "futures", + "hex", + "reqwest", + "serde", + "serde_json", + "thiserror", + "tokio", + "tokio-util", + "tracing", + "ts_cli_util", + "ts_control_serde", + "ts_hexdump", + "ts_http_util", + "ts_keys", + "ts_packet", + "ts_tls_util", + "ts_transport", + "url", + "yoke", + "zerocopy", +] + [[package]] name = "ts_devtools" version = "0.2.0" @@ -4249,12 +4276,10 @@ dependencies = [ "tokio", "tracing", "ts_cli_util", + "ts_derp", "ts_keys", - "ts_packet", "ts_packetfilter", "ts_packetfilter_state", - "ts_transport", - "ts_transport_derp", ] [[package]] @@ -4356,9 +4381,9 @@ dependencies = [ "tracing-test", "ts_control", "ts_control_serde", + "ts_derp", "ts_http_util", "ts_test_util", - "ts_transport_derp", "url", ] @@ -4424,7 +4449,6 @@ dependencies = [ "itertools", "tracing", "ts_bart", - "ts_keys", "ts_packet", "ts_transport", ] @@ -4497,8 +4521,12 @@ version = "0.2.0" dependencies = [ "futures", "ipnet", + "itertools", "kameo", "kameo_actors", + "proptest", + "rand 0.10.1", + "smallvec", "thiserror", "tokio", "tracing", @@ -4506,6 +4534,7 @@ dependencies = [ "ts_bart_packetfilter", "ts_control", "ts_dataplane", + "ts_derp", "ts_keys", "ts_netcheck", "ts_netstack_smoltcp", @@ -4514,7 +4543,7 @@ dependencies = [ "ts_packetfilter", "ts_packetfilter_state", "ts_transport", - "ts_transport_derp", + "ts_tunnel", ] [[package]] @@ -4543,36 +4572,7 @@ dependencies = [ name = "ts_transport" version = "0.2.0" dependencies = [ - "ts_keys", - "ts_packet", -] - -[[package]] -name = "ts_transport_derp" -version = "0.2.0" -dependencies = [ - "bytes", - "crypto_box", - "futures", - "hex", - "reqwest", - "serde", - "serde_json", - "thiserror", - "tokio", - "tokio-util", - "tracing", - "ts_cli_util", - "ts_control_serde", - "ts_hexdump", - "ts_http_util", - "ts_keys", "ts_packet", - "ts_tls_util", - "ts_transport", - "url", - "yoke", - "zerocopy", ] [[package]] @@ -4618,7 +4618,6 @@ dependencies = [ name = "ts_underlay_router" version = "0.2.0" dependencies = [ - "ts_keys", "ts_packet", "ts_transport", ] diff --git a/Cargo.toml b/Cargo.toml index 0a65bc31..320097b3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -12,6 +12,7 @@ members = [ "ts_control_noise", "ts_control_serde", "ts_dataplane", + "ts_derp", "ts_devtools", "ts_disco_protocol", "ts_dynbitset", @@ -37,7 +38,6 @@ members = [ "ts_time", "ts_tls_util", "ts_transport", - "ts_transport_derp", "ts_transport_tun", "ts_underlay_router", "ts_tunnel", @@ -115,6 +115,7 @@ ts_control = { path = "ts_control", version = "0.2.0" } ts_control_noise = { path = "ts_control_noise", version = "0.2.0" } ts_control_serde = { path = "ts_control_serde", version = "0.2.0" } ts_dataplane = { path = "ts_dataplane", version = "0.2.0" } +ts_derp = { path = "ts_derp", version = "0.2.0" } ts_disco_protocol = { path = "ts_disco_protocol", version = "0.2.0" } ts_dynbitset = { path = "ts_dynbitset", version = "0.2.0" } ts_hexdump = { path = "ts_hexdump", version = "0.2.0" } @@ -136,7 +137,6 @@ ts_runtime = { path = "ts_runtime", version = "0.2.0" } ts_test_util = { path = "ts_test_util" } ts_time = { path = "ts_time", version = "0.2.0" } ts_transport = { path = "ts_transport", version = "0.2.0" } -ts_transport_derp = { path = "ts_transport_derp", version = "0.2.0" } ts_transport_tun = { path = "ts_transport_tun", version = "0.2.0" } ts_underlay_router = { path = "ts_underlay_router", version = "0.2.0" } ts_tunnel = { path = "ts_tunnel", version = "0.2.0" } diff --git a/ts_cli_util/Cargo.toml b/ts_cli_util/Cargo.toml index 46a30d74..9185e2b2 100644 --- a/ts_cli_util/Cargo.toml +++ b/ts_cli_util/Cargo.toml @@ -16,7 +16,7 @@ rust-version.workspace = true tailscale.workspace = true ts_control.workspace = true ts_netcheck.workspace = true -ts_transport_derp.workspace = true +ts_derp.workspace = true # Unconditionally required dependencies. cfg-if.workspace = true diff --git a/ts_cli_util/src/lib.rs b/ts_cli_util/src/lib.rs index 0860fed6..535b6e34 100644 --- a/ts_cli_util/src/lib.rs +++ b/ts_cli_util/src/lib.rs @@ -5,8 +5,8 @@ use std::sync::Arc; use futures_util::{Stream, StreamExt}; use tracing::level_filters::LevelFilter; use tracing_subscriber::{Layer, layer::SubscriberExt, util::SubscriberInitExt}; +use ts_derp::{RegionId, ServerConnInfo}; use ts_netcheck::RegionResult; -use ts_transport_derp::{RegionId, ServerConnInfo}; /// Result with a boxed [`core::error::Error`] trait object. pub type Result = core::result::Result>; diff --git a/ts_control/Cargo.toml b/ts_control/Cargo.toml index a99edb54..ebaa114b 100644 --- a/ts_control/Cargo.toml +++ b/ts_control/Cargo.toml @@ -24,7 +24,7 @@ ts_packet.workspace = true ts_packetfilter.workspace = true ts_packetfilter_state.workspace = true ts_tls_util.workspace = true -ts_transport_derp.workspace = true +ts_derp.workspace = true # Unconditionally required dependencies. bytes.workspace = true @@ -51,7 +51,7 @@ default = ["async_tokio"] async_tokio = ["dep:futures-util", "dep:tokio", "dep:tokio-stream"] # Allow derp connections to be made without verifying TLS certs. Only for use in tests. -insecure-derp = ["ts_transport_derp/insecure-for-tests"] +insecure-derp = ["ts_derp/insecure-for-tests"] # Allow control keys to be fetched over plain HTTP1 without TLS. Only for use in tests. insecure-keyfetch = [] diff --git a/ts_control/src/derp.rs b/ts_control/src/derp.rs index 6004d45d..036fcdb8 100644 --- a/ts_control/src/derp.rs +++ b/ts_control/src/derp.rs @@ -1,17 +1,17 @@ use alloc::collections::BTreeMap; -use ts_transport_derp::TlsValidationConfig; +use ts_derp::TlsValidationConfig; -/// The full derp state, a map of [`ts_transport_derp::RegionId`]s to [`Region`]s. -pub type Map = BTreeMap; +/// The full derp state, a map of [`ts_derp::RegionId`]s to [`Region`]s. +pub type Map = BTreeMap; -/// Convert a derp map from the [`ts_control_serde`] representation to the [`ts_transport_derp`] +/// Convert a derp map from the [`ts_control_serde`] representation to the [`ts_derp`] /// representation. pub fn convert_derp_map( derp_map: &ts_control_serde::DerpMap<'_>, -) -> impl Iterator { +) -> impl Iterator { derp_map.regions.iter().map(|(id, region)| { - let id = ts_transport_derp::RegionId((*id).into()); + let id = ts_derp::RegionId((*id).into()); let region: Region = region.into(); (id, region) @@ -22,10 +22,10 @@ pub fn convert_derp_map( #[derive(Debug, Clone, PartialEq, Eq)] pub struct Region { /// The info for this region. - pub info: ts_transport_derp::RegionInfo, + pub info: ts_derp::RegionInfo, /// Servers in this region. - pub servers: Vec, + pub servers: Vec, } impl From<&ts_control_serde::DerpRegion<'_>> for Region { @@ -37,15 +37,15 @@ impl From<&ts_control_serde::DerpRegion<'_>> for Region { } } -fn region_info(region: &ts_control_serde::DerpRegion) -> ts_transport_derp::RegionInfo { - ts_transport_derp::RegionInfo { +fn region_info(region: &ts_control_serde::DerpRegion) -> ts_derp::RegionInfo { + ts_derp::RegionInfo { name: region.name.to_string(), code: region.code.to_string(), no_measure_no_home: region.no_measure_no_home, } } -fn server(server: &ts_control_serde::DerpServer) -> ts_transport_derp::ServerConnInfo { +fn server(server: &ts_control_serde::DerpServer) -> ts_derp::ServerConnInfo { const DEFAULT_TLS_PORT: u16 = 443; let https_port = match server.derp_port { @@ -62,7 +62,7 @@ fn server(server: &ts_control_serde::DerpServer) -> ts_transport_derp::ServerCon tls_config = TlsValidationConfig::InsecureForTests; }; - ts_transport_derp::ServerConnInfo { + ts_derp::ServerConnInfo { hostname: server.hostname.to_string(), https_port, stun_port: server.stun_port.into(), @@ -77,13 +77,13 @@ fn server(server: &ts_control_serde::DerpServer) -> ts_transport_derp::ServerCon } } -fn convert_ip_usage(ip: ts_control_serde::DerpIpUsage) -> ts_transport_derp::IpUsage +fn convert_ip_usage(ip: ts_control_serde::DerpIpUsage) -> ts_derp::IpUsage where T: Copy, { match ip { - ts_control_serde::DerpIpUsage::Disable => ts_transport_derp::IpUsage::Disable, - ts_control_serde::DerpIpUsage::UseDns => ts_transport_derp::IpUsage::UseDns, - ts_control_serde::DerpIpUsage::FixedAddr(ip) => ts_transport_derp::IpUsage::FixedAddr(ip), + ts_control_serde::DerpIpUsage::Disable => ts_derp::IpUsage::Disable, + ts_control_serde::DerpIpUsage::UseDns => ts_derp::IpUsage::UseDns, + ts_control_serde::DerpIpUsage::FixedAddr(ip) => ts_derp::IpUsage::FixedAddr(ip), } } diff --git a/ts_control/src/map_request_builder.rs b/ts_control/src/map_request_builder.rs index f5ca7d6e..7c0ae494 100644 --- a/ts_control/src/map_request_builder.rs +++ b/ts_control/src/map_request_builder.rs @@ -67,7 +67,7 @@ impl<'a> MapRequestBuilder<'a> { /// Set the [`NetInfo::preferred_derp`] field (inside [`MapRequest::host_info`] -> /// [`HostInfo::net_info`]). - pub fn preferred_derp(mut self, value: ts_transport_derp::RegionId) -> Self { + pub fn preferred_derp(mut self, value: ts_derp::RegionId) -> Self { self.net_info_mut().preferred_derp = Some(value.0.into()); self } diff --git a/ts_control/src/node.rs b/ts_control/src/node.rs index 3f67bf5a..ad498c77 100644 --- a/ts_control/src/node.rs +++ b/ts_control/src/node.rs @@ -7,7 +7,7 @@ use ts_keys::{DiscoPublicKey, MachinePublicKey, NodePublicKey}; pub type Id = i64; /// The stable ID of a node. -#[derive(Debug, Clone, PartialEq, Eq, Hash)] +#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] pub struct StableId(pub String); /// A node in a tailnet. @@ -46,7 +46,7 @@ pub struct Node { pub underlay_addresses: Vec, /// The DERP region for this node, if known. - pub derp_region: Option, + pub derp_region: Option, } impl Node { @@ -67,6 +67,17 @@ impl Node { } } + /// The fully-qualified domain name of the node, only returning `Some` if the tailnet + /// component is present. + /// + /// See [`Node::fqdn`]. + pub fn fqdn_opt(&self, trailing_dot: bool) -> Option { + let dot = if trailing_dot { "." } else { "" }; + let tailnet = self.tailnet.as_deref()?; + + Some(format!("{}.{tailnet}{dot}", self.hostname)) + } + /// Report whether this node matches the given `name`. /// /// `name` is checked for equality with both this node's bare hostname and its fqdn. A @@ -155,7 +166,7 @@ impl From<&ts_control_serde::Node<'_>> for Node { .home_derp .or(value.legacy_derp_string) .or_else(|| value.host_info.net_info.as_ref()?.preferred_derp) - .map(|x| ts_transport_derp::RegionId(x.into())), + .map(|x| ts_derp::RegionId(x.into())), } } } diff --git a/ts_control/src/tokio/client.rs b/ts_control/src/tokio/client.rs index 9aa664e5..a076b71e 100644 --- a/ts_control/src/tokio/client.rs +++ b/ts_control/src/tokio/client.rs @@ -125,7 +125,7 @@ impl AsyncControlClient { #[tracing::instrument(skip_all, fields(map_url = %self.map_url(), %region_id), level = "trace")] pub async fn set_home_region<'c>( &mut self, - region_id: ts_transport_derp::RegionId, + region_id: ts_derp::RegionId, latencies: impl IntoIterator, ) { tracing::trace!(region = %region_id, "reporting home derp to control server"); @@ -161,7 +161,7 @@ impl AsyncControlClient { #[derive(Debug)] pub enum Command { SetDerpHomeRegion { - id: ts_transport_derp::RegionId, + id: ts_derp::RegionId, latencies: BTreeMap, }, } diff --git a/ts_dataplane/Cargo.toml b/ts_dataplane/Cargo.toml index 0bf16d15..a683b2d2 100644 --- a/ts_dataplane/Cargo.toml +++ b/ts_dataplane/Cargo.toml @@ -16,7 +16,6 @@ async_tokio = ["dep:tokio"] [dependencies] # Our crates. -ts_keys.workspace = true ts_overlay_router.workspace = true ts_packet.workspace = true ts_packetfilter.workspace = true diff --git a/ts_dataplane/src/async_tokio.rs b/ts_dataplane/src/async_tokio.rs index 4cbd82ca..a0ad94c5 100644 --- a/ts_dataplane/src/async_tokio.rs +++ b/ts_dataplane/src/async_tokio.rs @@ -3,9 +3,8 @@ use std::{collections::HashMap, convert::Infallible, ops::DerefMut, sync::atomic::AtomicU32}; use tokio::sync::{Mutex, mpsc}; -use ts_keys::NodePublicKey; use ts_packet::PacketMut; -use ts_transport::{OverlayTransportId, UnderlayTransportId}; +use ts_transport::{OverlayTransportId, PeerId, UnderlayTransportId}; use ts_tunnel::NodeKeyPair; use crate::{EventResult, InboundResult, OutboundResult}; @@ -17,10 +16,10 @@ pub type DataplaneToOverlay = mpsc::UnboundedSender>; pub type DataplaneFromOverlay = mpsc::UnboundedReceiver>; /// Queue for packets leaving the data plane "down" into an underlay transport. -pub type DataplaneToUnderlay = mpsc::UnboundedSender<(NodePublicKey, Vec)>; +pub type DataplaneToUnderlay = mpsc::UnboundedSender<(PeerId, Vec)>; /// Queue for packets entering the data plane "up" from an underlay transport. -pub type DataplaneFromUnderlay = mpsc::UnboundedReceiver<(NodePublicKey, Vec)>; +pub type DataplaneFromUnderlay = mpsc::UnboundedReceiver<(PeerId, Vec)>; // TODO: wire in overlay/underlay transport traits @@ -147,7 +146,7 @@ impl DataPlane { pub async fn step(&self) { enum SelectResult { OverlayDown(Vec), - UnderlayUp(NodePublicKey, Vec), + UnderlayUp(PeerId, Vec), TransportsChanged, Event, } @@ -186,10 +185,10 @@ impl DataPlane { } underlay_pkts = underlay_up.recv() => { - let (node_key, underlay_pkts) = underlay_pkts.unwrap(); - tracing::trace!(%node_key, n_underlay_pkts = underlay_pkts.len()); + let (peer_id, underlay_pkts) = underlay_pkts.unwrap(); + tracing::trace!(%peer_id, n_underlay_pkts = underlay_pkts.len()); - SelectResult::UnderlayUp(node_key, underlay_pkts) + SelectResult::UnderlayUp(peer_id, underlay_pkts) } _ = self.transports_changed.notified() => { @@ -215,14 +214,7 @@ impl DataPlane { (Some(to_peers), Some(loopback)) } - SelectResult::UnderlayUp(node_key, underlay_up) => { - if core.sync.wireguard.peer_id(node_key).is_none() { - core.sync.wireguard.add_peer(ts_tunnel::PeerConfig { - key: node_key, - psk: [0u8; 32].into(), - }); - } - + SelectResult::UnderlayUp(_peer_id, underlay_up) => { let InboundResult { to_local, to_peers } = core.sync.process_inbound(underlay_up); (Some(to_peers), Some(to_local)) @@ -265,13 +257,13 @@ async fn write_to_overlay(slf: &CoreState, packets: HashMap)>, + packets: impl IntoIterator)>, ) { - for ((tid, node_key), packets) in packets { - tracing::trace!(underlay_id = ?tid, %node_key, n_packets = packets.len()); + for ((tid, peer_id), packets) in packets { + tracing::trace!(underlay_id = ?tid, %peer_id, n_packets = packets.len()); if let Some(queue) = slf.underlay_transports.get(&tid) { - queue.send((node_key, packets)).unwrap(); + queue.send((peer_id, packets)).unwrap(); } } } diff --git a/ts_dataplane/src/lib.rs b/ts_dataplane/src/lib.rs index e0b20509..5e8c6524 100644 --- a/ts_dataplane/src/lib.rs +++ b/ts_dataplane/src/lib.rs @@ -1,19 +1,18 @@ #![doc = include_str!("../README.md")] -pub mod async_tokio; - use std::{collections::HashMap, sync::Arc, time::Instant}; use ts_bart::RoutingTable; -use ts_keys::NodePublicKey; use ts_overlay_router as or; use ts_packet::PacketMut; use ts_packetfilter::{FilterExt, IpProto}; use ts_time::{Handle, Scheduler}; -use ts_transport::{OverlayTransportId, UnderlayTransportId}; -use ts_tunnel::{Endpoint, NodeKeyPair, PeerConfig}; +use ts_transport::{OverlayTransportId, PeerId, UnderlayTransportId}; +use ts_tunnel::{Endpoint, NodeKeyPair}; use ts_underlay_router as ur; +pub mod async_tokio; + /// A data plane subsystem that can be the subject of timer events. pub enum Subsystem { /// The wireguard component. @@ -31,7 +30,7 @@ pub struct DataPlane { pub ur_out: ur::outbound::Router, /// Inbound source filter. - pub src_filter_in: Arc>, + pub src_filter_in: Arc>, /// Inbound overlay router. pub or_in: or::inbound::Router, @@ -70,31 +69,16 @@ impl DataPlane { let to_wireguard = to_wireguard .into_iter() - .map(|(k, v)| { - let id = self - .wireguard - .peer_id(k) - .or_else(|| { - self.wireguard.add_peer(PeerConfig { - key: k, - psk: [0u8; 32].into(), - }) - }) - .unwrap(); - - (id, v) - }) + .map(|(k, v)| (ts_tunnel::PeerId(k.0), v)) .collect::>(); let ts_tunnel::SendResult { to_peers: encrypted, } = self.wireguard.send(to_wireguard); - let to_peers = self.ur_out.route( - encrypted - .into_iter() - .filter_map(|(k, v)| Some((self.wireguard.peer_key(k)?, v))), - ); + let to_peers = self + .ur_out + .route(encrypted.into_iter().map(|(k, v)| (PeerId(k.0), v))); if let Some(next) = self.wireguard.next_event() && let Some(prev) = self @@ -116,41 +100,34 @@ impl DataPlane { let to_local = to_local .into_iter() - .map(|(peer_id, mut packets)| { - let span = tracing::trace_span!("src_filter_inbound", peer_id = ?peer_id, n_packet = packets.len(), peer_key = tracing::field::Empty).entered(); - - let Some(key) = self.wireguard.peer_key(peer_id) else { - tracing::warn!("no nodekey for peer"); - return (peer_id, vec![]); - }; - - span.record("peer_key", tracing::field::display(key)); + .map(|(peer_id, mut packets)| -> Vec { + let _span = tracing::trace_span!( + "src_filter_inbound", + peer_id = ?peer_id, + n_packet = packets.len(), + ) + .entered(); packets.retain(|packet| { let Some(src) = packet.get_src_addr() else { tracing::trace!("does not look like ip packet"); return false; }; - let verdict = if let Some(allowed_key) = self.src_filter_in.lookup(src) { - *allowed_key == key + let verdict = if let Some(allowed_peer) = self.src_filter_in.lookup(src) { + *allowed_peer == PeerId(peer_id.0) } else { + tracing::trace!(remote_ip = %src, "unknown peer address"); false }; tracing::trace!(?src, verdict); verdict }); - (peer_id, packets) + packets }) - .map(|(k, mut v)| { - let span = tracing::trace_span!("packet_filter_inbound", peer_id = ?k, n_packet = v.len(), peer_key = tracing::field::Empty).entered(); - - let Some(key) = self.wireguard.peer_key(k) else { - tracing::warn!("no nodekey for peer"); - return (k, vec![]); - }; - - span.record("peer_key", tracing::field::display(key)); + .map(|mut v| { + let _span = + tracing::trace_span!("packet_filter_inbound", n_packet = v.len()).entered(); v.retain(|pkt| { let Ok(pkt) = etherparse::SlicedPacket::from_ip(pkt.as_ref()) else { @@ -159,12 +136,16 @@ impl DataPlane { }; let (proto, src, dst) = match pkt.net { - Some(etherparse::NetSlice::Ipv4(ipv4)) => { - (IpProto::new(ipv4.payload().ip_number.0 as _), ipv4.header().source_addr().into(), ipv4.header().destination_addr().into()) - } - Some(etherparse::NetSlice::Ipv6(ipv6)) => { - (IpProto::new(ipv6.payload().ip_number.0 as _), ipv6.header().source_addr().into(), ipv6.header().destination_addr().into()) - } + Some(etherparse::NetSlice::Ipv4(ipv4)) => ( + IpProto::new(ipv4.payload().ip_number.0 as _), + ipv4.header().source_addr().into(), + ipv4.header().destination_addr().into(), + ), + Some(etherparse::NetSlice::Ipv6(ipv6)) => ( + IpProto::new(ipv6.payload().ip_number.0 as _), + ipv6.header().source_addr().into(), + ipv6.header().destination_addr().into(), + ), _ => { unreachable!("unexpected packet kind"); } @@ -189,24 +170,21 @@ impl DataPlane { // TODO(npry): wire in nodecaps let caps = []; - let verdict = self.packet_filter - .can_access(&info, caps); + let verdict = self.packet_filter.can_access(&info, caps); tracing::trace!(?info, ?caps, verdict); verdict }); - (k, v) + v }); let to_peers = to_peers .into_iter() - .filter_map(|(k, v)| Some((self.wireguard.peer_key(k)?, v))); + .map(|(k, v)| (ts_transport::PeerId(k.0), v)); - let to_local = self - .or_in - .route(to_local.flat_map(|(_id, packets)| packets)); + let to_local = self.or_in.route(to_local.flatten()); let to_peers = self.ur_out.route(to_peers); if let Some(next) = self.wireguard.next_event() @@ -244,7 +222,7 @@ impl DataPlane { to_peers.extend( res.to_peers .into_iter() - .filter_map(|(id, pkts)| Some((self.wireguard.peer_key(id)?, pkts))), + .map(|(id, pkts)| (ts_transport::PeerId(id.0), pkts)), ); } } @@ -266,7 +244,7 @@ impl DataPlane { /// The result of processing outbound packets. pub struct OutboundResult { /// Packets to be sent into underlay transports for transmission. - pub to_peers: HashMap<(UnderlayTransportId, NodePublicKey), Vec>, + pub to_peers: HashMap<(UnderlayTransportId, PeerId), Vec>, /// Packets to be looped back and delivered to overlay transports. pub loopback: HashMap>, } @@ -276,12 +254,12 @@ pub struct InboundResult { /// Decrypted packets to be delivered to overlay transports. pub to_local: HashMap>, /// Encrypted packets to be sent to wireguard peers by the underlay. - pub to_peers: HashMap<(UnderlayTransportId, NodePublicKey), Vec>, + pub to_peers: HashMap<(UnderlayTransportId, PeerId), Vec>, } /// The result of processing an event. #[derive(Default)] pub struct EventResult { /// Encrypted packets to be sent to wireguard peers by the underlay. - pub to_peers: HashMap<(UnderlayTransportId, NodePublicKey), Vec>, + pub to_peers: HashMap<(UnderlayTransportId, PeerId), Vec>, } diff --git a/ts_transport_derp/Cargo.toml b/ts_derp/Cargo.toml similarity index 97% rename from ts_transport_derp/Cargo.toml rename to ts_derp/Cargo.toml index 9cead130..fe56268c 100644 --- a/ts_transport_derp/Cargo.toml +++ b/ts_derp/Cargo.toml @@ -1,5 +1,5 @@ [package] -name = "ts_transport_derp" +name = "ts_derp" version.workspace = true description = "tailscale derp client" categories = ["network-programming", "encoding", "asynchronous"] diff --git a/ts_transport_derp/README.md b/ts_derp/README.md similarity index 64% rename from ts_transport_derp/README.md rename to ts_derp/README.md index ab18af18..90db7697 100644 --- a/ts_transport_derp/README.md +++ b/ts_derp/README.md @@ -1,3 +1,3 @@ -# ts_transport_derp +# ts_derp Tailscale derp protocol and client. diff --git a/ts_transport_derp/examples/common/mod.rs b/ts_derp/examples/common/mod.rs similarity index 80% rename from ts_transport_derp/examples/common/mod.rs rename to ts_derp/examples/common/mod.rs index dcfa7577..f97a3277 100644 --- a/ts_transport_derp/examples/common/mod.rs +++ b/ts_derp/examples/common/mod.rs @@ -1,8 +1,8 @@ -//! Common code used by multiple `ts_transport_derp` examples. +//! Common code used by multiple `ts_derp` examples. use std::{collections::BTreeMap, num::NonZeroU32}; -use ts_transport_derp::{RegionId, ServerConnInfo, TlsValidationConfig}; +use ts_derp::{RegionId, ServerConnInfo, TlsValidationConfig}; /// ID of DERP Region #1, which is New York City. pub const REGION_1: RegionId = RegionId(NonZeroU32::new(1).unwrap()); @@ -49,14 +49,14 @@ pub async fn load_derp_map() -> BTreeMap> { .collect() } -fn convert_ip_usage(ip: ts_control_serde::DerpIpUsage) -> ts_transport_derp::IpUsage +fn convert_ip_usage(ip: ts_control_serde::DerpIpUsage) -> ts_derp::IpUsage where T: Copy, { match ip { - ts_control_serde::DerpIpUsage::Disable => ts_transport_derp::IpUsage::Disable, - ts_control_serde::DerpIpUsage::UseDns => ts_transport_derp::IpUsage::UseDns, - ts_control_serde::DerpIpUsage::FixedAddr(ip) => ts_transport_derp::IpUsage::FixedAddr(ip), + ts_control_serde::DerpIpUsage::Disable => ts_derp::IpUsage::Disable, + ts_control_serde::DerpIpUsage::UseDns => ts_derp::IpUsage::UseDns, + ts_control_serde::DerpIpUsage::FixedAddr(ip) => ts_derp::IpUsage::FixedAddr(ip), } } diff --git a/ts_transport_derp/examples/listen.rs b/ts_derp/examples/listen.rs similarity index 52% rename from ts_transport_derp/examples/listen.rs rename to ts_derp/examples/listen.rs index 19adc342..f88a0d40 100644 --- a/ts_transport_derp/examples/listen.rs +++ b/ts_derp/examples/listen.rs @@ -3,7 +3,6 @@ //! Intended to test ping/pong/keepalive. use ts_keys::NodeKeyPair; -use ts_transport::UnderlayTransport; mod common; @@ -16,19 +15,16 @@ async fn main() -> ts_cli_util::Result<()> { let keypair = NodeKeyPair::new(); - let client = ts_transport_derp::Client::connect(region, &keypair).await?; + let client = ts_derp::Client::connect(region, &keypair).await?; tracing::info!("derp handshake done"); loop { - for result in client.recv().await { - match result { - Ok((peer, pkts)) => { - let pkts = pkts.into_iter().collect::>(); - tracing::info!(?peer, ?pkts); - } - Err(e) => { - tracing::error!(err = %e, "recv"); - } + match client.recv_one().await { + Ok((peer, pkt)) => { + tracing::info!(?peer, ?pkt); + } + Err(e) => { + tracing::error!(err = %e, "recv"); } } } diff --git a/ts_transport_derp/examples/ping.rs b/ts_derp/examples/ping.rs similarity index 78% rename from ts_transport_derp/examples/ping.rs rename to ts_derp/examples/ping.rs index dbd9165c..a52be949 100644 --- a/ts_transport_derp/examples/ping.rs +++ b/ts_derp/examples/ping.rs @@ -4,7 +4,6 @@ use std::{sync::Arc, time::Duration}; use tokio::task::JoinSet; use ts_keys::NodeKeyPair; -use ts_transport::UnderlayTransport; mod common; @@ -17,7 +16,7 @@ async fn main() -> ts_cli_util::Result<()> { let keypair = NodeKeyPair::new(); - let client = ts_transport_derp::Client::connect(region, &keypair).await?; + let client = ts_derp::Client::connect(region, &keypair).await?; tracing::info!("derp handshake done"); let client = Arc::new(client); @@ -29,7 +28,7 @@ async fn main() -> ts_cli_util::Result<()> { let mut ticker = tokio::time::interval(Duration::from_secs(1)); loop { - if let Err(e) = pinger.send([(keypair.public, vec![vec![1].into()])]).await { + if let Err(e) = pinger.send_one(keypair.public, &[1]).await { tracing::error!(err = %e, "ping"); } else { tracing::info!("ping"); @@ -43,8 +42,8 @@ async fn main() -> ts_cli_util::Result<()> { js.spawn(async move { loop { match recv.recv_one().await { - Ok((pkt, peer)) => { - tracing::info!(?pkt, ?peer, "pong"); + Ok((peer_key, pkt)) => { + tracing::info!(?pkt, %peer_key, "pong"); } Err(e) => { tracing::error!(err = %e, "recv"); diff --git a/ts_transport_derp/src/async_tokio.rs b/ts_derp/src/client.rs similarity index 91% rename from ts_transport_derp/src/async_tokio.rs rename to ts_derp/src/client.rs index f0570271..1d03b040 100644 --- a/ts_transport_derp/src/async_tokio.rs +++ b/ts_derp/src/client.rs @@ -10,7 +10,7 @@ use tokio_util::codec::{FramedRead, FramedWrite}; use ts_http_util::Client as _; use ts_keys::{NodeKeyPair, NodePublicKey}; use ts_packet::PacketMut; -use ts_transport::UnderlayTransport; +use ts_transport::{BatchRecvIter, BatchSendIter, UnderlayTransport}; use url::Url; use crate::{ @@ -23,7 +23,7 @@ type DefaultIo = ts_http_util::Upgraded; /// Type alias for the default derp client over upgraded HTTP on a tokio executor. pub type DefaultClient = Client; -/// Asynchronous DERP transport for a single DERP region. +/// Single-region DERP client. pub struct Client { read_conn: Mutex, frame::Codec>>, write_conn: Mutex, frame::Codec>>, @@ -117,6 +117,12 @@ where }) } + /// Send a message to a nodekey on the derp server. + pub async fn send_one(&self, node_key: NodePublicKey, msg: &[u8]) -> Result<(), Error> { + self.send_frame_with_extra(&frame::SendPacket { dest: node_key }, msg) + .await + } + /// Send a frame to the derp server. pub async fn send_frame( &self, @@ -191,6 +197,7 @@ where } FrameType::RecvPacket => { let (recv, payload) = frame.as_type::().unwrap(); + return Ok((recv.src, payload.into())); } t => { @@ -278,32 +285,23 @@ impl UnderlayTransport for Client where Io: AsyncRead + AsyncWrite + Send, { + type PeerKey = NodePublicKey; type Error = Error; - #[tracing::instrument(fields(%self))] - async fn recv( + async fn send( &self, - ) -> impl IntoIterator< - Item = Result<(NodePublicKey, impl IntoIterator), Self::Error>, - > { - [self.recv_one().await.map(|(k, pkt)| (k, [pkt]))] - } - - /// Send a batch of packets to a peer via this DERP server. - async fn send(&self, peer_packets: BatchIter) -> Result<(), Self::Error> - where - BatchIter: IntoIterator + Send, - BatchIter::IntoIter: Send, - PacketIter: IntoIterator + Send, - PacketIter::IntoIter: Send, - { - for (peer, packets) in peer_packets { - for packet in packets { - self.send_frame_with_extra(&frame::SendPacket { dest: peer }, packet.as_ref()) - .await?; + packet_batch: impl BatchSendIter, + ) -> Result<(), Self::Error> { + for (key, pkt) in packet_batch.batch_iter() { + for pkt in pkt { + self.send_one(key, pkt.as_ref()).await?; } } Ok(()) } + + async fn recv(&self) -> impl BatchRecvIter { + [self.recv_one().await.map(|(k, pkt)| (k, [pkt]))] + } } diff --git a/ts_transport_derp/src/dial.rs b/ts_derp/src/dial.rs similarity index 100% rename from ts_transport_derp/src/dial.rs rename to ts_derp/src/dial.rs diff --git a/ts_transport_derp/src/error.rs b/ts_derp/src/error.rs similarity index 100% rename from ts_transport_derp/src/error.rs rename to ts_derp/src/error.rs diff --git a/ts_transport_derp/src/frame/body/client_info.rs b/ts_derp/src/frame/body/client_info.rs similarity index 100% rename from ts_transport_derp/src/frame/body/client_info.rs rename to ts_derp/src/frame/body/client_info.rs diff --git a/ts_transport_derp/src/frame/body/close_peer.rs b/ts_derp/src/frame/body/close_peer.rs similarity index 100% rename from ts_transport_derp/src/frame/body/close_peer.rs rename to ts_derp/src/frame/body/close_peer.rs diff --git a/ts_transport_derp/src/frame/body/forward_packet.rs b/ts_derp/src/frame/body/forward_packet.rs similarity index 100% rename from ts_transport_derp/src/frame/body/forward_packet.rs rename to ts_derp/src/frame/body/forward_packet.rs diff --git a/ts_transport_derp/src/frame/body/health.rs b/ts_derp/src/frame/body/health.rs similarity index 100% rename from ts_transport_derp/src/frame/body/health.rs rename to ts_derp/src/frame/body/health.rs diff --git a/ts_transport_derp/src/frame/body/keep_alive.rs b/ts_derp/src/frame/body/keep_alive.rs similarity index 100% rename from ts_transport_derp/src/frame/body/keep_alive.rs rename to ts_derp/src/frame/body/keep_alive.rs diff --git a/ts_transport_derp/src/frame/body/mod.rs b/ts_derp/src/frame/body/mod.rs similarity index 100% rename from ts_transport_derp/src/frame/body/mod.rs rename to ts_derp/src/frame/body/mod.rs diff --git a/ts_transport_derp/src/frame/body/note_preferred.rs b/ts_derp/src/frame/body/note_preferred.rs similarity index 100% rename from ts_transport_derp/src/frame/body/note_preferred.rs rename to ts_derp/src/frame/body/note_preferred.rs diff --git a/ts_transport_derp/src/frame/body/peer_gone.rs b/ts_derp/src/frame/body/peer_gone.rs similarity index 100% rename from ts_transport_derp/src/frame/body/peer_gone.rs rename to ts_derp/src/frame/body/peer_gone.rs diff --git a/ts_transport_derp/src/frame/body/peer_present.rs b/ts_derp/src/frame/body/peer_present.rs similarity index 100% rename from ts_transport_derp/src/frame/body/peer_present.rs rename to ts_derp/src/frame/body/peer_present.rs diff --git a/ts_transport_derp/src/frame/body/ping.rs b/ts_derp/src/frame/body/ping.rs similarity index 100% rename from ts_transport_derp/src/frame/body/ping.rs rename to ts_derp/src/frame/body/ping.rs diff --git a/ts_transport_derp/src/frame/body/pong.rs b/ts_derp/src/frame/body/pong.rs similarity index 100% rename from ts_transport_derp/src/frame/body/pong.rs rename to ts_derp/src/frame/body/pong.rs diff --git a/ts_transport_derp/src/frame/body/recv_packet.rs b/ts_derp/src/frame/body/recv_packet.rs similarity index 100% rename from ts_transport_derp/src/frame/body/recv_packet.rs rename to ts_derp/src/frame/body/recv_packet.rs diff --git a/ts_transport_derp/src/frame/body/restarting.rs b/ts_derp/src/frame/body/restarting.rs similarity index 100% rename from ts_transport_derp/src/frame/body/restarting.rs rename to ts_derp/src/frame/body/restarting.rs diff --git a/ts_transport_derp/src/frame/body/send_packet.rs b/ts_derp/src/frame/body/send_packet.rs similarity index 100% rename from ts_transport_derp/src/frame/body/send_packet.rs rename to ts_derp/src/frame/body/send_packet.rs diff --git a/ts_transport_derp/src/frame/body/server_info.rs b/ts_derp/src/frame/body/server_info.rs similarity index 100% rename from ts_transport_derp/src/frame/body/server_info.rs rename to ts_derp/src/frame/body/server_info.rs diff --git a/ts_transport_derp/src/frame/body/server_key.rs b/ts_derp/src/frame/body/server_key.rs similarity index 100% rename from ts_transport_derp/src/frame/body/server_key.rs rename to ts_derp/src/frame/body/server_key.rs diff --git a/ts_transport_derp/src/frame/body/watch_conns.rs b/ts_derp/src/frame/body/watch_conns.rs similarity index 100% rename from ts_transport_derp/src/frame/body/watch_conns.rs rename to ts_derp/src/frame/body/watch_conns.rs diff --git a/ts_transport_derp/src/frame/codec.rs b/ts_derp/src/frame/codec.rs similarity index 100% rename from ts_transport_derp/src/frame/codec.rs rename to ts_derp/src/frame/codec.rs diff --git a/ts_transport_derp/src/frame/error.rs b/ts_derp/src/frame/error.rs similarity index 100% rename from ts_transport_derp/src/frame/error.rs rename to ts_derp/src/frame/error.rs diff --git a/ts_transport_derp/src/frame/frame_type.rs b/ts_derp/src/frame/frame_type.rs similarity index 100% rename from ts_transport_derp/src/frame/frame_type.rs rename to ts_derp/src/frame/frame_type.rs diff --git a/ts_transport_derp/src/frame/header.rs b/ts_derp/src/frame/header.rs similarity index 100% rename from ts_transport_derp/src/frame/header.rs rename to ts_derp/src/frame/header.rs diff --git a/ts_transport_derp/src/frame/magic.rs b/ts_derp/src/frame/magic.rs similarity index 100% rename from ts_transport_derp/src/frame/magic.rs rename to ts_derp/src/frame/magic.rs diff --git a/ts_transport_derp/src/frame/mod.rs b/ts_derp/src/frame/mod.rs similarity index 100% rename from ts_transport_derp/src/frame/mod.rs rename to ts_derp/src/frame/mod.rs diff --git a/ts_transport_derp/src/frame/raw.rs b/ts_derp/src/frame/raw.rs similarity index 100% rename from ts_transport_derp/src/frame/raw.rs rename to ts_derp/src/frame/raw.rs diff --git a/ts_transport_derp/src/lib.rs b/ts_derp/src/lib.rs similarity index 99% rename from ts_transport_derp/src/lib.rs rename to ts_derp/src/lib.rs index d58fe004..654ca8ce 100644 --- a/ts_transport_derp/src/lib.rs +++ b/ts_derp/src/lib.rs @@ -8,12 +8,12 @@ use core::{ use zerocopy::{FromBytes, Immutable, IntoBytes, KnownLayout}; -mod async_tokio; +mod client; pub mod dial; mod error; pub mod frame; -pub use async_tokio::{Client, DefaultClient}; +pub use client::{Client, DefaultClient}; pub use error::Error; /// A 24-byte nonce for symmetric encryption with ChaCha20Poly1305. diff --git a/ts_devtools/Cargo.toml b/ts_devtools/Cargo.toml index 272cb3a1..a49ee220 100644 --- a/ts_devtools/Cargo.toml +++ b/ts_devtools/Cargo.toml @@ -20,11 +20,9 @@ tracing.workspace = true ts_cli_util.workspace = true ts_keys.workspace = true -ts_packet.workspace = true ts_packetfilter.workspace = true ts_packetfilter_state.workspace = true -ts_transport.workspace = true -ts_transport_derp.workspace = true +ts_derp.workspace = true [lints] workspace = true diff --git a/ts_devtools/src/bin/derp_ping.rs b/ts_devtools/src/bin/derp_ping.rs index 6ec3d683..3fd3a12c 100644 --- a/ts_devtools/src/bin/derp_ping.rs +++ b/ts_devtools/src/bin/derp_ping.rs @@ -6,8 +6,6 @@ use std::sync::Arc; use clap::Parser; use tokio::task::JoinSet; use ts_keys::NodePublicKey; -use ts_packet::PacketMut; -use ts_transport::UnderlayTransport; /// Authenticate with control, load the derp map, and attempt to exchange derp pings with /// a selected peer. @@ -38,17 +36,15 @@ async fn main() -> ts_cli_util::Result<()> { let mut tasks = JoinSet::new(); - tracing::info!(?region_id, "starting derp transport"); - - let derp = ts_transport_derp::Client::connect(&derp_servers, &config.key_state.node_key.into()) - .await?; - let derp = Arc::new(derp); - let peer = args .send_to_self .then_some(config.key_state.node_key.public_key()) .or(args.peer); + tracing::info!(?region_id, "starting derp transport"); + let derp = ts_derp::Client::connect(&derp_servers, &config.key_state.node_key.into()).await?; + let derp = Arc::new(derp); + if let Some(peer) = peer { tasks.spawn(derp_send_ping(peer, derp.clone())); } else { @@ -64,7 +60,7 @@ async fn main() -> ts_cli_util::Result<()> { static PING_MAX: AtomicU32 = AtomicU32::new(0); -async fn derp_receive_ping(derp: impl Borrow) { +async fn derp_receive_ping(derp: impl Borrow) { use bytes::Buf; let derp = derp.borrow(); @@ -84,19 +80,18 @@ async fn derp_receive_ping(derp: impl Borrow) } #[tracing::instrument(skip(derp), fields(%peer))] -async fn derp_send_ping(peer: NodePublicKey, derp: impl Borrow) { +async fn derp_send_ping(peer: NodePublicKey, derp: impl Borrow) { use bytes::BufMut; let mut ticker = tokio::time::interval(Duration::from_secs(1)); let derp = derp.borrow(); + let mut packet = [0u8; size_of::()]; loop { let val = PING_MAX.fetch_add(1, core::sync::atomic::Ordering::SeqCst); + (&mut packet[..]).put_u32(val); - let mut packet = PacketMut::with_capacity(size_of::()); - packet.put_u32(val); - - derp.send([(peer, [packet])]).await.unwrap(); + derp.send_one(peer, &packet).await.unwrap(); tracing::info!(value = val, "send ping"); ticker.tick().await; diff --git a/ts_netcheck/Cargo.toml b/ts_netcheck/Cargo.toml index 0ba4442c..824f56e5 100644 --- a/ts_netcheck/Cargo.toml +++ b/ts_netcheck/Cargo.toml @@ -14,7 +14,7 @@ rust-version.workspace = true # Our crates. ts_control.workspace = true ts_http_util.workspace = true -ts_transport_derp.workspace = true +ts_derp.workspace = true # Unconditionally required depdendencies. bytes.workspace = true diff --git a/ts_netcheck/src/derp_latency.rs b/ts_netcheck/src/derp_latency.rs index 6b7e7e5a..87240908 100644 --- a/ts_netcheck/src/derp_latency.rs +++ b/ts_netcheck/src/derp_latency.rs @@ -3,7 +3,7 @@ use core::{fmt::Debug, net::SocketAddr, time::Duration}; use ts_control::DerpMap; -use ts_transport_derp::RegionId; +use ts_derp::RegionId; /// Configuration for probing derp map latency. #[derive(Debug, Copy, Clone)] diff --git a/ts_netcheck/src/https.rs b/ts_netcheck/src/https.rs index d0fccc0d..82c64d49 100644 --- a/ts_netcheck/src/https.rs +++ b/ts_netcheck/src/https.rs @@ -4,8 +4,8 @@ use core::{net::SocketAddr, time::Duration}; use std::{io, time::Instant}; use tokio::io::{AsyncRead, AsyncWrite}; +use ts_derp::ServerConnInfo; use ts_http_util::{ClientExt, EmptyBody, Http1}; -use ts_transport_derp::ServerConnInfo; use url::Url; /// Errors that may occur while probing derp latency. @@ -44,9 +44,9 @@ impl From for Error { } } -impl From for Error { - fn from(value: ts_transport_derp::dial::Error) -> Self { - use ts_transport_derp::dial; +impl From for Error { + fn from(value: ts_derp::dial::Error) -> Self { + use ts_derp::dial; match value { dial::Error::Io => Error::Io, @@ -101,8 +101,8 @@ impl Default for Config { /// /// Returns `None` iff no servers could be successfully measured, either due to connectivity errors /// or because they were not configured to be reachable. See the notes on -/// [`dial_region_tls`][ts_transport_derp::dial::dial_region_tls] and -/// [`dial_region_tcp`][ts_transport_derp::dial::dial_region_tcp] for more details on when +/// [`dial_region_tls`][ts_derp::dial::dial_region_tls] and +/// [`dial_region_tcp`][ts_derp::dial::dial_region_tcp] for more details on when /// servers are treated as not configured for reachability. pub async fn measure_https_latency<'c>( servers: impl IntoIterator, @@ -120,18 +120,17 @@ pub async fn measure_https_latency<'c>( let mut servers = servers.into_iter(); loop { - let (conn, server, remote) = - match ts_transport_derp::dial::dial_region_tls(&mut servers).await { - Ok(Some(x)) => x, - Ok(None) => { - tracing::warn!("ran out of servers to dial"); - return None; - } - Err(e) => { - tracing::error!(error = %e, "dialing tls"); - continue; - } - }; + let (conn, server, remote) = match ts_derp::dial::dial_region_tls(&mut servers).await { + Ok(Some(x)) => x, + Ok(None) => { + tracing::warn!("ran out of servers to dial"); + return None; + } + Err(e) => { + tracing::error!(error = %e, "dialing tls"); + continue; + } + }; match measure_server_latency(conn, server, &config).await { Ok(dur) => return Some((dur, server, remote)), @@ -231,7 +230,7 @@ mod test { let info = info(); - let (conn, server, remote) = ts_transport_derp::dial::dial_region_tls([&info]) + let (conn, server, remote) = ts_derp::dial::dial_region_tls([&info]) .await .unwrap() .unwrap(); diff --git a/ts_overlay_router/Cargo.toml b/ts_overlay_router/Cargo.toml index 1ebcb89b..5a033e0a 100644 --- a/ts_overlay_router/Cargo.toml +++ b/ts_overlay_router/Cargo.toml @@ -12,7 +12,6 @@ rust-version.workspace = true [dependencies] ts_bart.workspace = true -ts_keys.workspace = true ts_packet.workspace = true ts_transport.workspace = true diff --git a/ts_overlay_router/src/outbound.rs b/ts_overlay_router/src/outbound.rs index fd2fc5a5..09d5f9c3 100644 --- a/ts_overlay_router/src/outbound.rs +++ b/ts_overlay_router/src/outbound.rs @@ -4,9 +4,8 @@ use std::collections::HashMap; use itertools::Itertools; use ts_bart::{RoutingTable, Table}; -use ts_keys::NodePublicKey; use ts_packet::PacketMut; -use ts_transport::OverlayTransportId; +use ts_transport::{OverlayTransportId, PeerId}; /// An outbound routing action. #[derive(Debug, Clone)] @@ -18,7 +17,7 @@ pub enum RouteAction { Drop, /// Send to a wireguard peer. - Wireguard(NodePublicKey), + Wireguard(PeerId), /// Loop the packet back to a local overlay transport. /// @@ -36,7 +35,7 @@ pub struct Router { #[derive(Debug, Default, Eq, PartialEq)] pub struct Result { /// Packets to send through wireguard. - pub to_wireguard: HashMap>, + pub to_wireguard: HashMap>, /// Packets to return to a local transport. pub loopback: HashMap>, } @@ -102,8 +101,8 @@ mod tests { #[test] fn test_outbound_overlay() { - let peer_a = NodePublicKey::from([1u8; 32]); - let peer_b = NodePublicKey::from([2u8; 32]); + let peer_a = PeerId(0); + let peer_b = PeerId(1); let magicdns = 42.into(); let mut routes = Table::default(); diff --git a/ts_runtime/Cargo.toml b/ts_runtime/Cargo.toml index f9929bbf..1b9313b5 100644 --- a/ts_runtime/Cargo.toml +++ b/ts_runtime/Cargo.toml @@ -24,7 +24,8 @@ ts_packet.workspace = true ts_packetfilter = { workspace = true, features = ["checking-filter"] } ts_packetfilter_state.workspace = true ts_transport.workspace = true -ts_transport_derp.workspace = true +ts_derp.workspace = true +ts_tunnel.workspace = true # Unconditionally required dependencies. futures.workspace = true @@ -34,6 +35,12 @@ kameo_actors = "0.4" thiserror.workspace = true tokio.workspace = true tracing.workspace = true +smallvec.workspace = true + +[dev-dependencies] +rand = "0.10" +proptest = "1.11" +itertools = "0.14" [lints] workspace = true diff --git a/ts_runtime/src/dataplane.rs b/ts_runtime/src/dataplane.rs index f36520dc..86ea08dc 100644 --- a/ts_runtime/src/dataplane.rs +++ b/ts_runtime/src/dataplane.rs @@ -5,14 +5,14 @@ use kameo::{ message::{Context, Message}, }; use tokio::sync::mpsc; -use ts_keys::NodePublicKey; use ts_packet::PacketMut; -use ts_transport::{OverlayTransportId, UnderlayTransportId}; +use ts_transport::{OverlayTransportId, PeerId, UnderlayTransportId}; use crate::{ Error, env::Env, packetfilter::PacketFilterState, + peer_tracker::PeerState, route_updater::{PeerRouteUpdate, SelfRouteUpdate}, src_filter::SourceFilterState, }; @@ -24,10 +24,10 @@ pub type OverlayToDataplane = mpsc::UnboundedSender>; pub type OverlayFromDataplane = mpsc::UnboundedReceiver>; /// Queue for packets leaving the underlay to the dataplane. -pub type UnderlayToDataplane = mpsc::UnboundedSender<(NodePublicKey, Vec)>; +pub type UnderlayToDataplane = mpsc::UnboundedSender<(PeerId, Vec)>; /// Queue for packets entering an underlay from the dataplane. -pub type UnderlayFromDataplane = mpsc::UnboundedReceiver<(NodePublicKey, Vec)>; +pub type UnderlayFromDataplane = mpsc::UnboundedReceiver<(PeerId, Vec)>; pub struct DataplaneActor { dataplane: Arc, @@ -74,6 +74,7 @@ impl kameo::Actor for DataplaneActor { env.subscribe::(&slf).await?; env.subscribe::(&slf).await?; env.subscribe::(&slf).await?; + env.subscribe::>(&slf).await?; let task_dataplane = dataplane.clone(); @@ -138,3 +139,32 @@ impl Message for DataplaneActor { tracing::trace!("applied new source filter"); } } + +impl Message> for DataplaneActor { + type Reply = (); + + async fn handle(&mut self, msg: Arc, _ctx: &mut Context) { + { + let mut dp = self.dataplane.inner().await; + let wg = &mut dp.wireguard; + + for &upsert in &msg.upserts { + let (_, node) = msg.peers.get(&upsert).unwrap(); + + wg.upsert_peer( + ts_tunnel::PeerId(upsert.0), + ts_tunnel::PeerConfig { + key: node.node_key, + psk: [0u8; 32].into(), + }, + ); + } + + for delete in &msg.deletions { + wg.remove_peer(ts_tunnel::PeerId(delete.0)); + } + } + + tracing::trace!("applied new peer state"); + } +} diff --git a/ts_runtime/src/lib.rs b/ts_runtime/src/lib.rs index f495ab1b..dbecf9c6 100644 --- a/ts_runtime/src/lib.rs +++ b/ts_runtime/src/lib.rs @@ -40,7 +40,7 @@ pub struct Runtime { pub control: ActorRef, dataplane: ActorRef, netstack: WeakActorRef, - /// Reference to the peer state tracker actor, used for lookup. + /// Reference to the peer tracker for peer lookups. pub peer_tracker: WeakActorRef, env: Env, shutdown: watch::Sender, diff --git a/ts_runtime/src/multiderp.rs b/ts_runtime/src/multiderp.rs index f7551247..c6a51df7 100644 --- a/ts_runtime/src/multiderp.rs +++ b/ts_runtime/src/multiderp.rs @@ -1,6 +1,6 @@ use std::{ collections::HashMap, - sync::Arc, + sync::{Arc, RwLock}, time::{Duration, Instant}, }; @@ -11,14 +11,17 @@ use kameo::{ }; use tokio::{sync::watch, task::JoinSet}; use ts_control::DerpRegion; -use ts_keys::NodeKeyPair; -use ts_transport::{UnderlayTransport, UnderlayTransportId}; -use ts_transport_derp::RegionId; +use ts_derp::RegionId; +use ts_keys::{NodeKeyPair, NodePublicKey}; +use ts_transport::{ + BatchRecvIter, PeerId, UnderlayTransport, UnderlayTransportExt, UnderlayTransportId, +}; use crate::{ Env, Error, dataplane::{DataplaneActor, NewUnderlayTransport, UnderlayFromDataplane, UnderlayToDataplane}, derp_latency::DerpLatencyMeasurement, + peer_tracker::{PeerDb, PeerState}, }; /// Consumes derp map updates and spawns a task per region that runs an underlay transport. @@ -33,6 +36,7 @@ pub struct Multiderp { dataplane: ActorRef, derps: HashMap, current_home_derp: Option, + peer_db: Arc>>>, tasks: JoinSet<()>, } @@ -72,6 +76,8 @@ impl Multiderp { }; let (home_derp_tx, mut home_derp_rx) = watch::channel(false); + let peer_db = self.peer_db.clone(); + self.tasks.spawn(async move { while !*shutdown.borrow() { tokio::select! { @@ -85,6 +91,7 @@ impl Multiderp { &down, &mut up, &mut home_derp_rx, + &peer_db, ) => if let Err(e) = ret { tracing::error!(error = %e, region_id = %id, "running derp client"); tokio::time::sleep(Duration::from_millis(500)).await; @@ -121,6 +128,29 @@ impl Multiderp { } } +struct PeerDbLookup<'a>(&'a RwLock>>); + +impl ts_transport::PeerLookup for PeerDbLookup<'_> { + fn lookup_key(&self, id: PeerId) -> Option { + let db = self.0.read().unwrap(); + let db = db.as_ref()?; + + let (_, node) = db.get(&id)?; + Some(node.node_key) + } +} + +impl ts_transport::PeerLookup for PeerDbLookup<'_> { + fn lookup_key(&self, key: NodePublicKey) -> Option { + let db = self.0.read().unwrap(); + let db = db.as_ref()?; + + let (id, _) = db.get(&key)?; + + Some(id) + } +} + #[tracing::instrument(skip_all, fields(region_id = %id), name = "derp packet transport")] async fn run_derp_once( id: RegionId, @@ -129,7 +159,8 @@ async fn run_derp_once( to_dataplane: &UnderlayToDataplane, from_dataplane: &mut UnderlayFromDataplane, home_derp_rx: &mut watch::Receiver, -) -> Result<(), ts_transport_derp::Error> { + peer_db: &RwLock>>, +) -> Result<(), ts_derp::Error> { const INACTIVITY_TIMEOUT: Duration = Duration::from_secs(10); loop { @@ -153,11 +184,12 @@ async fn run_derp_once( tracing::trace!("establishing derp connection"); - let client = ts_transport_derp::DefaultClient::connect(®ion.servers, &keys).await?; + let client = ts_derp::DefaultClient::connect(®ion.servers, &keys).await?; + let transport = client.with_key_lookup(PeerDbLookup(peer_db)); if let Some(pending) = pending { tracing::trace!("sending queued packet"); - client.send([pending]).await?; + transport.send([pending]).await?; } let mut last_activity = Instant::now(); @@ -169,16 +201,20 @@ async fn run_derp_once( (!*home_derp_rx.borrow()).then(|| last_activity + INACTIVITY_TIMEOUT); tokio::select! { - from_derp = client.recv_one() => { + from_derp = transport.recv() => { last_activity = Instant::now(); - let (peer, pkt) = from_derp?; - tracing::trace!(parent: &span, %peer, len = pkt.len(), "packet from derp server"); + for ret in from_derp.batch_iter() { + let (peer_id, pkts) = ret?; + let pkts = pkts.into_iter().collect::>(); - let Ok(()) = to_dataplane.send((peer, vec![pkt])) else { - tracing::error!(parent: &span, "underlay receive channel closed"); - break; - }; + tracing::trace!(parent: &span, %peer_id, len = pkts.len(), "packet from derp server"); + + let Ok(()) = to_dataplane.send((peer_id, pkts)) else { + tracing::error!(parent: &span, "underlay receive channel closed"); + break; + }; + } }, from_net = from_dataplane.recv() => { @@ -191,7 +227,7 @@ async fn run_derp_once( tracing::trace!(parent: &span, peer = %from_net.0, packets = from_net.1.len(), "packets to derp server"); - client.send([from_net]).await?; + transport.send([from_net]).await?; }, _ = option_timeout(inactivity_timeout) => { @@ -225,11 +261,13 @@ impl kameo::Actor for Multiderp { slf: ActorRef, ) -> Result { env.subscribe::>(&slf).await?; + env.subscribe::>(&slf).await?; env.subscribe::(&slf).await?; Ok(Self { env, dataplane, + peer_db: Default::default(), derps: Default::default(), tasks: JoinSet::new(), current_home_derp: None, @@ -269,6 +307,15 @@ impl Message> for Multiderp { } } +impl Message> for Multiderp { + type Reply = (); + + async fn handle(&mut self, msg: Arc, _ctx: &mut Context) { + let mut db = self.peer_db.write().unwrap(); + *db = Some(msg.peers.clone()); + } +} + impl Message for Multiderp { type Reply = (); diff --git a/ts_runtime/src/peer_tracker.rs b/ts_runtime/src/peer_tracker.rs deleted file mode 100644 index 8bc0b03d..00000000 --- a/ts_runtime/src/peer_tracker.rs +++ /dev/null @@ -1,436 +0,0 @@ -//! Peer delta update tracking. - -use std::{ - borrow::Borrow, - collections::{HashMap, HashSet}, - net::IpAddr, - sync::Arc, -}; - -use ipnet::IpNet; -use kameo::{ - actor::ActorRef, - message::{Context, Message}, - reply::ReplySender, -}; -use ts_control::{Node, NodeId}; -use ts_keys::NodePublicKey; - -use crate::{Error, env::Env}; - -/// Actor that tracks peer delta updates and emits new states. -pub struct PeerTracker { - peers: HashMap, - id_to_nodekey: HashMap, - seen_state_update: bool, - pending_requests: Vec, - env: Env, -} - -// TODO(npry): accelerate with indexed data structures, linear search won't be -// acceptable on large tailnets. -impl PeerTracker { - fn peer_by_name_opt(&self, name: &str) -> Option<&Node> { - self.peers.values().find(|&peer| peer.matches_name(name)) - } - - fn peer_by_tailnet_ip_opt(&self, ip: IpAddr) -> Option<&Node> { - self.peers.values().find(|&peer| { - peer.tailnet_address.ipv4.addr() == ip || peer.tailnet_address.ipv6.addr() == ip - }) - } -} - -impl kameo::Actor for PeerTracker { - type Args = Env; - type Error = Error; - - async fn on_start(env: Self::Args, slf: ActorRef) -> Result { - env.subscribe::>(&slf).await?; - - Ok(Self { - peers: Default::default(), - id_to_nodekey: Default::default(), - pending_requests: Default::default(), - seen_state_update: false, - env, - }) - } -} - -enum Pending { - PeerByName(PeerByName, ReplySender>), - AcceptedRoute(PeerByAcceptedRoute, ReplySender>), - TailnetIp(PeerByTailnetIp, ReplySender>), -} - -// For messages with arguments, a struct is generated with the args as fields. They aren't -// documented, and we can't apply attributes directly to the fields. Hence, wrap in a module where -// docs are turned off everywhere. -#[allow(missing_docs)] -mod msg_impl { - use std::net::IpAddr; - - use kameo::prelude::DelegatedReply; - - use super::*; - - #[kameo::messages] - impl PeerTracker { - /// Lookup a peer by name. - /// - /// Waits until we've received at least one peer update from control. - #[message(ctx)] - pub async fn peer_by_name( - &mut self, - ctx: &mut Context>>, - name: String, - ) -> DelegatedReply> { - let (deleg, sender) = ctx.reply_sender(); - let Some(sender) = sender else { return deleg }; - - if !self.seen_state_update { - tracing::debug!(query = name, "no peer state seen yet, queueing request"); - - self.pending_requests - .push(Pending::PeerByName(PeerByName { name }, sender)); - - return deleg; - } - - sender.send(self.peer_by_name_opt(&name).cloned()); - - deleg - } - - /// Lookup all peers that accept packets addressed to the given IP. - /// - /// This includes the peer's tailnet address and any subnet routes it provides. Only - /// the peers with the most specific subnet route match that covers `ip` will be - /// returned. - /// - /// E.g., suppose: - /// - /// - We're querying for `10.1.2.3` - /// - `PeerA` and `PeerB` have accepted routes for `10.1.2.0/24` - /// - `PeerC` has an accepted route for `10.1.0.0/16` - /// - /// Only `PeerA` and `PeerB` will be returned, since they have the most specific - /// prefix match. - #[message(ctx)] - pub fn peer_by_accepted_route( - &mut self, - ctx: &mut Context>>, - ip: IpAddr, - ) -> DelegatedReply> { - let (deleg, sender) = ctx.reply_sender(); - let Some(sender) = sender else { return deleg }; - - if !self.seen_state_update { - tracing::debug!(query = %ip, "no peer state seen yet, queueing request"); - - self.pending_requests - .push(Pending::AcceptedRoute(PeerByAcceptedRoute { ip }, sender)); - - return deleg; - } - - sender.send(best_route_match(ip, self.peers.values())); - - deleg - } - - /// Lookup the peer that has the given tailnet IP address. - #[message(ctx)] - pub fn peer_by_tailnet_ip( - &mut self, - ctx: &mut Context>>, - ip: IpAddr, - ) -> DelegatedReply> { - let (deleg, sender) = ctx.reply_sender(); - let Some(sender) = sender else { return deleg }; - - if !self.seen_state_update { - tracing::debug!(query = %ip, "no peer state seen yet, queueing request"); - - self.pending_requests - .push(Pending::TailnetIp(PeerByTailnetIp { ip }, sender)); - - return deleg; - } - - sender.send(self.peer_by_tailnet_ip_opt(ip).cloned()); - - deleg - } - } -} - -pub use msg_impl::*; - -#[derive(Debug, Clone)] -pub(crate) struct PeerState { - #[allow(unused)] - pub deletions: HashSet, - #[allow(unused)] - pub upserts: HashSet, - pub peers: Arc>, -} - -// TODO: rpds - -impl Message> for PeerTracker { - type Reply = (); - - async fn handle( - &mut self, - msg: Arc, - _ctx: &mut Context, - ) { - let Some(peer_update) = &msg.peer_update else { - return; - }; - - let mut upserts = HashSet::default(); - let mut deletions = HashSet::default(); - - match peer_update { - ts_control::PeerUpdate::Full(nodes) => { - tracing::trace!("full peer update"); - - deletions = self.peers.keys().copied().collect(); - - self.peers.clear(); - self.id_to_nodekey.clear(); - - for node in nodes { - upserts.insert(node.node_key); - deletions.remove(&node.node_key); - - self.id_to_nodekey.insert(node.id, node.node_key); - self.peers.insert(node.node_key, node.clone()); - } - } - - ts_control::PeerUpdate::Delta { remove, upsert } => { - tracing::trace!("delta peer update"); - - for peer in upsert { - self.id_to_nodekey.insert(peer.id, peer.node_key); - self.peers.insert(peer.node_key, peer.clone()); - - upserts.insert(peer.node_key); - } - - for peer in remove { - let node_key = self.id_to_nodekey.remove(peer); - - if let Some(node_key) = node_key { - self.peers.remove(&node_key); - deletions.insert(node_key); - } - } - } - } - - tracing::debug!( - n_upsert = upserts.len(), - n_delete = deletions.len(), - peer_count = self.peers.len(), - "new peer state" - ); - - if !self.seen_state_update { - self.seen_state_update = true; - - if !self.pending_requests.is_empty() { - tracing::debug!( - n_pending = self.pending_requests.len(), - "state update received, servicing pending requests" - ); - } - - for req in core::mem::take(&mut self.pending_requests) { - match req { - Pending::PeerByName(PeerByName { name }, reply) => { - reply.send(self.peer_by_name_opt(&name).cloned()); - } - Pending::TailnetIp(PeerByTailnetIp { ip }, reply) => { - reply.send(self.peer_by_tailnet_ip_opt(ip).cloned()); - } - Pending::AcceptedRoute(PeerByAcceptedRoute { ip }, reply) => { - reply.send(best_route_match(ip, self.peers.values())); - } - } - } - } - - if let Err(e) = self - .env - .publish(PeerState { - upserts, - deletions, - peers: Arc::new(self.peers.clone()), - }) - .await - { - tracing::error!(error = %e, "publishing peer state update"); - } - } -} - -/// Get the most-narrow set of peers that have routes for the given IP. -fn best_route_match<'n, N>(query_ip: IpAddr, it: impl IntoIterator) -> Vec -where - N: Borrow + 'n, -{ - // TODO(npry): accelerate with an indexed data structure, linear search won't be - // acceptable on large tailnets. - - let (_, matching_peers) = it.into_iter().fold( - (None, vec![]), - |(mut best_match, mut matching_peers), peer: N| { - let peer = peer.borrow(); - let mut peer_best = None; - - for &candidate in &peer.accepted_routes { - // Normalize all prefixes to truncated form (mask off the host bits). - let candidate = candidate.trunc(); - - if !candidate.contains(&query_ip) { - continue; - } - - if peer_best - .as_ref() - .is_none_or(|existing: &IpNet| existing.contains(&candidate)) - { - peer_best = Some(candidate); - } - } - - match (best_match.as_ref(), peer_best) { - // This peer doesn't match, skip - (_, None) => return (best_match, matching_peers), - - // No previous match, set unconditionally - (None, _) => best_match = peer_best, - - // Previous match (same prefix), don't update - (Some(x), Some(y)) if x == &y => {} - - // New best match, clear old state - (Some(existing), Some(candidate)) if existing.contains(&candidate) => { - matching_peers.clear(); - best_match = peer_best; - } - - // This peer doesn't have as good a match - _ => return (best_match, matching_peers), - } - - matching_peers.push(peer.clone()); - - (best_match, matching_peers) - }, - ); - - matching_peers -} - -#[cfg(test)] -mod test { - use std::net::Ipv4Addr; - - use ipnet::Ipv4Net; - use ts_control::{StableNodeId, TailnetAddress}; - - use super::*; - - fn dummy_node(routes: impl IntoIterator) -> Node { - Node { - accepted_routes: routes.into_iter().collect(), - - node_key: Default::default(), - id: 0, - stable_id: StableNodeId("".to_owned()), - disco_key: Default::default(), - machine_key: None, - tailnet: None, - hostname: "".to_owned(), - tailnet_address: TailnetAddress { - ipv4: Default::default(), - ipv6: Default::default(), - }, - underlay_addresses: vec![], - node_key_expiry: None, - derp_region: None, - tags: vec![], - } - } - - fn ipv4net(ip: impl Into, pfx_len: usize) -> IpNet { - Ipv4Net::new(ip.into(), pfx_len as _).unwrap().into() - } - - #[test] - fn route_match() { - // no peers, no match - let m = best_route_match::([1, 2, 3, 4].into(), []); - assert!(m.is_empty()); - - // peer with no routes, no match - let m = best_route_match::([1, 2, 3, 4].into(), [dummy_node([])]); - assert!(m.is_empty()); - - // single peer, single match -- typical case - let m = best_route_match::( - [1, 2, 3, 4].into(), - [dummy_node([ipv4net([1, 2, 3, 4], 32)])], - ); - assert_eq!(m.len(), 1); - - // two matches both succeed - let m = best_route_match::( - [1, 2, 3, 4].into(), - [ - dummy_node([ipv4net([1, 2, 3, 4], 32)]), - dummy_node([ipv4net([1, 2, 3, 4], 32)]), - ], - ); - assert_eq!(m.len(), 2); - - // more-specific match wins - let m = best_route_match::( - [1, 2, 3, 4].into(), - [ - dummy_node([ipv4net([1, 2, 3, 4], 31)]), - dummy_node([ipv4net([1, 2, 3, 4], 32)]), - ], - ); - assert_eq!(m.len(), 1); - assert_eq!(m[0].accepted_routes[0].prefix_len(), 32); - - // denormalized prefix - let m = best_route_match::( - [1, 2, 3, 4].into(), - [ - dummy_node([ipv4net([1, 2, 3, 0], 24)]), - dummy_node([ipv4net([1, 2, 3, 8], 24)]), - ], - ); - assert_eq!(m.len(), 2); - assert_eq!(m[0].accepted_routes[0].prefix_len(), 24); - - // overlapping routes - let m = best_route_match::( - [1, 2, 3, 4].into(), - [ - dummy_node([ipv4net([1, 2, 3, 0], 24), ipv4net([1, 2, 3, 123], 24)]), - dummy_node([ipv4net([1, 2, 3, 8], 24)]), - ], - ); - assert_eq!(m.len(), 2); - assert_eq!(m[0].accepted_routes[0].prefix_len(), 24); - } -} diff --git a/ts_runtime/src/peer_tracker/mod.rs b/ts_runtime/src/peer_tracker/mod.rs new file mode 100644 index 00000000..ce8584be --- /dev/null +++ b/ts_runtime/src/peer_tracker/mod.rs @@ -0,0 +1,295 @@ +//! Peer delta update tracking. + +use std::{collections::HashSet, net::IpAddr, sync::Arc}; + +use kameo::{ + actor::ActorRef, + message::{Context, Message}, + reply::ReplySender, +}; +use ts_control::Node; +use ts_transport::PeerId; + +use crate::{Error, env::Env}; + +mod peer_db; + +pub use peer_db::PeerDb; + +/// Actor that tracks peer delta updates and emits new states. +pub struct PeerTracker { + peer_db: PeerDb, + seen_state_update: bool, + pending_requests: Vec, + env: Env, +} + +impl PeerTracker { + fn peer_by_name_opt(&self, name: &str) -> Option<&Node> { + let name = name.trim_end_matches('.'); + self.peer_db.get(&name).map(|(_id, node)| node) + } + + fn peer_by_tailnet_ip_opt(&self, ip: IpAddr) -> Option<&Node> { + self.peer_db.get(&ip).map(|(_id, node)| node) + } +} + +impl kameo::Actor for PeerTracker { + type Args = Env; + type Error = Error; + + async fn on_start(env: Self::Args, slf: ActorRef) -> Result { + env.subscribe::>(&slf).await?; + + Ok(Self { + peer_db: PeerDb::default(), + pending_requests: Default::default(), + seen_state_update: false, + env, + }) + } +} + +enum Pending { + PeerByName(PeerByName, ReplySender>), + AcceptedRoute(PeerByAcceptedRoute, ReplySender>), + TailnetIp(PeerByTailnetIp, ReplySender>), +} + +// For messages with arguments, a struct is generated with the args as fields. They aren't +// documented, and we can't apply attributes directly to the fields. Hence, wrap in a module where +// docs are turned off everywhere. +#[allow(missing_docs)] +mod msg_impl { + use std::net::IpAddr; + + use kameo::prelude::DelegatedReply; + + use super::*; + + #[kameo::messages] + impl PeerTracker { + /// Lookup a peer by name. + /// + /// Waits until we've received at least one peer update from control. + #[message(ctx)] + pub async fn peer_by_name( + &mut self, + ctx: &mut Context>>, + name: String, + ) -> DelegatedReply> { + let (deleg, sender) = ctx.reply_sender(); + let Some(sender) = sender else { return deleg }; + + if !self.seen_state_update { + tracing::debug!(query = name, "no peer state seen yet, queueing request"); + + self.pending_requests + .push(Pending::PeerByName(PeerByName { name }, sender)); + + return deleg; + } + + sender.send(self.peer_by_name_opt(&name).cloned()); + + deleg + } + + /// Lookup all peers that accept packets addressed to the given IP. + /// + /// This includes the peer's tailnet address and any subnet routes it provides. Only + /// the peers with the most specific subnet route match that covers `ip` will be + /// returned. + /// + /// E.g., suppose: + /// + /// - We're querying for `10.1.2.3` + /// - `PeerA` and `PeerB` have accepted routes for `10.1.2.0/24` + /// - `PeerC` has an accepted route for `10.1.0.0/16` + /// + /// Only `PeerA` and `PeerB` will be returned, since they have the most specific + /// prefix match. + #[message(ctx)] + pub fn peer_by_accepted_route( + &mut self, + ctx: &mut Context>>, + ip: IpAddr, + ) -> DelegatedReply> { + let (deleg, sender) = ctx.reply_sender(); + let Some(sender) = sender else { return deleg }; + + if !self.seen_state_update { + tracing::debug!(query = %ip, "no peer state seen yet, queueing request"); + + self.pending_requests + .push(Pending::AcceptedRoute(PeerByAcceptedRoute { ip }, sender)); + + return deleg; + } + + sender.send( + self.peer_db + .get_route(ip.into()) + .map(|(_id, node)| node.clone()) + .collect(), + ); + + deleg + } + + /// Lookup the peer that has the given tailnet IP address. + #[message(ctx)] + pub fn peer_by_tailnet_ip( + &mut self, + ctx: &mut Context>>, + ip: IpAddr, + ) -> DelegatedReply> { + let (deleg, sender) = ctx.reply_sender(); + let Some(sender) = sender else { return deleg }; + + if !self.seen_state_update { + tracing::debug!(query = %ip, "no peer state seen yet, queueing request"); + + self.pending_requests + .push(Pending::TailnetIp(PeerByTailnetIp { ip }, sender)); + + return deleg; + } + + sender.send(self.peer_by_tailnet_ip_opt(ip).cloned()); + + deleg + } + } +} + +pub use msg_impl::*; + +#[derive(Debug, Clone)] +pub(crate) struct PeerState { + #[allow(unused)] + pub deletions: HashSet, + #[allow(unused)] + pub upserts: HashSet, + pub peers: Arc, +} + +impl Message> for PeerTracker { + type Reply = (); + + async fn handle( + &mut self, + msg: Arc, + _ctx: &mut Context, + ) { + let Some(peer_update) = &msg.peer_update else { + return; + }; + + let mut upserts = HashSet::default(); + let mut deletions = HashSet::default(); + + match peer_update { + ts_control::PeerUpdate::Full(new_nodes) => { + tracing::trace!("full peer update"); + + let new_ids = new_nodes + .iter() + .map(|x| &x.stable_id) + .collect::>(); + + self.peer_db.retain(|id, peer| { + let retain = new_ids.contains(&peer.stable_id); + + if !retain { + deletions.insert(id); + } + + retain + }); + + for node in new_nodes { + let peer_id = self.peer_db.upsert(node); + upserts.insert(peer_id); + } + } + + ts_control::PeerUpdate::Delta { remove, upsert } => { + tracing::trace!("delta peer update"); + + for peer in upsert { + let id = self.peer_db.upsert(peer); + + upserts.insert(id); + } + + for peer in remove { + let Some((id, _node)) = self.peer_db.remove(peer) else { + tracing::error!(control_node_id = peer, "removed peer was unknown"); + continue; + }; + + deletions.insert(id); + } + } + } + + tracing::debug!( + n_upsert = upserts.len(), + n_delete = deletions.len(), + peer_count = self.peer_db.peers().len(), + "new peer state" + ); + + self.service_pending_requests(); + + if let Err(e) = self + .env + .publish(Arc::new(PeerState { + upserts, + deletions, + peers: Arc::new(self.peer_db.clone()), + })) + .await + { + tracing::error!(error = %e, "publishing peer state update"); + } + } +} + +impl PeerTracker { + fn service_pending_requests(&mut self) { + if self.seen_state_update { + return; + } + + self.seen_state_update = true; + + if !self.pending_requests.is_empty() { + tracing::debug!( + n_pending = self.pending_requests.len(), + "state update received, servicing pending requests" + ); + } + + for req in core::mem::take(&mut self.pending_requests) { + match req { + Pending::PeerByName(PeerByName { name }, reply) => { + reply.send(self.peer_by_name_opt(&name).cloned()); + } + Pending::TailnetIp(PeerByTailnetIp { ip }, reply) => { + reply.send(self.peer_by_tailnet_ip_opt(ip).cloned()); + } + Pending::AcceptedRoute(PeerByAcceptedRoute { ip }, reply) => { + reply.send( + self.peer_db + .get_route(ip.into()) + .map(|(_id, node)| node.clone()) + .collect(), + ); + } + } + } + } +} diff --git a/ts_runtime/src/peer_tracker/peer_db.rs b/ts_runtime/src/peer_tracker/peer_db.rs new file mode 100644 index 00000000..076c5e4a --- /dev/null +++ b/ts_runtime/src/peer_tracker/peer_db.rs @@ -0,0 +1,830 @@ +use std::{ + collections::HashMap, + fmt::{Debug, Formatter}, + hash::Hash, + net::IpAddr, +}; + +use ts_bart::{RouteModification, RoutingTable, RoutingTableExt}; +use ts_control::{Node, StableNodeId}; +use ts_keys::{DiscoPublicKey, NodePublicKey}; +use ts_transport::PeerId; + +mod private { + use super::*; + + pub trait Sealed {} + + impl Sealed for PeerId {} + impl Sealed for NodePublicKey {} + impl Sealed for DiscoPublicKey {} + impl Sealed for StableNodeId {} + impl Sealed for ts_control::NodeId {} + impl Sealed for PeerName {} + impl Sealed for &str {} + impl Sealed for IpAddr {} + impl Sealed for ipnet::IpNet {} +} + +/// A [`Node`] field indexed by [`PeerDb`]. +pub trait IndexedField: Debug + private::Sealed { + /// Look up the peer id that has this field. + fn lookup(&self, db: &PeerDb) -> Option; +} + +type Index = HashMap; +type PeerName = String; + +/// A database that stores a map of peers by [`PeerId`] and multiple indices. +/// +/// Assumes that _all indexed fields_ are unique per-node, with a few notable exceptions: +/// +/// - Hostname may be duplicated, though the fqdn (including the tailnet component) may not +/// be. +/// - Accepted routes may overlap. +#[derive(Default, Clone)] +pub struct PeerDb { + peers: HashMap, + index_state: IndexState, + next_id: u32, +} + +impl Debug for PeerDb { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + self.peers.fmt(f) + } +} + +#[derive(Default, Clone)] +struct IndexState { + /// Index on the node's [`NodePublicKey`]. + nk_idx: Index, + /// Index on the [`DiscoPublicKey`], assuming it's known. + disco_idx: Index, + /// Index on the peer [`StableNodeId`]. + stableid_idx: Index, + /// Index for the [`ts_control::NodeId`]. + /// + /// This is a numeric ID assigned by control which could overlap across different + /// control regions (by contrast to [`StableNodeId`], which should not). We need this + /// field because control indicates node patches and deletions by this id rather than + /// the stable id. + control_idx: Index, + /// Index on the peer name and FQDN. + name_idx: Index, + /// Index on the node's tailnet IPv4 and IPv6. + ip_idx: ts_bart::Table, + /// Index on the node's accepted routes. + /// + /// These may overlap between nodes, hence this stores a vec of matching node ids for + /// each route. + route_idx: ts_bart::Table>, +} + +impl PeerDb { + /// Upsert a node into the peer db. + /// + /// The [`StableNodeId`] is used as the primary key to identify the node. + pub fn upsert(&mut self, new: &Node) -> PeerId { + let id = self + .index_state + .stableid_idx + .get(&new.stable_id) + .copied() + .unwrap_or_else(|| { + let id = self.next_id; + self.next_id += 1; + + PeerId(id) + }); + + let old = self.peers.get(&id); + + // no update: same node + if old.is_some_and(|x| x == new) { + return id; + } + + maybe_update_idx(new, old, |x| &x.node_key, &mut self.index_state.nk_idx, id); + maybe_update_idx( + new, + old, + |x| &x.stable_id, + &mut self.index_state.stableid_idx, + id, + ); + maybe_update_idx(new, old, |x| &x.id, &mut self.index_state.control_idx, id); + + maybe_update( + new, + old, + |x| &x.disco_key, + &mut self.index_state.disco_idx, + |old, idx| { + if let Some(key) = &old.disco_key { + let old_id = idx.remove(key); + assert!(old_id.is_some_and(|old_id| old_id == id)); + } + }, + |new, idx| { + if let Some(key) = &new.disco_key { + idx.insert(*key, id); + } + }, + ); + + // Store both `hostname` and fqdn (no trailing dot) in the `name_idx` index. This _does not_ + // preserve uniqueness for `hostname`; as documented on external API such as + // `tailscale::Device::peer_by_name`, there may be collisions in this field (typically when + // nodes are shared into the tailnet with the same name as an existing tailnet device). + // + // We don't resolve this conflict here and make it the caller's problem to include the fqdn + // if there is ambiguity; the index just stores the most recently updated node with a given + // hostname. + // + // Also, this index is overloaded to store both the fqdn and the hostname, but this is + // fine since the fqdn always includes `.`, while the hostname never does, so they're always + // distinguishable. + maybe_update( + new, + old, + |x| (&x.hostname, &x.tailnet), + &mut self.index_state.name_idx, + |old, idx| { + if idx.get(&old.hostname).is_some_and(|&x| x == id) { + idx.remove(&old.hostname); + } + + if let Some(fqdn) = old.fqdn_opt(false) { + let removed_id = idx.remove(&fqdn); + assert!(removed_id.is_some_and(|removed_id| removed_id == id)); + } + }, + |new, idx| { + idx.insert(new.hostname.clone(), id); + + if let Some(fqdn) = new.fqdn_opt(false) { + idx.insert(fqdn, id); + } + }, + ); + + maybe_update( + new, + old, + |x| &x.tailnet_address, + &mut self.index_state.ip_idx, + |old, idx| { + let id4 = idx.remove(old.tailnet_address.ipv4.into()); + let id6 = idx.remove(old.tailnet_address.ipv6.into()); + + assert!(id4.is_some_and(|old_id| old_id == id)); + assert!(id6.is_some_and(|old_id| old_id == id)); + }, + |new, idx| { + idx.insert(new.tailnet_address.ipv4.into(), id); + idx.insert(new.tailnet_address.ipv6.into(), id); + }, + ); + + maybe_update( + new, + old, + |x| &x.accepted_routes, + &mut self.index_state, + |old, idx| { + for &route in &old.accepted_routes { + idx.remove_route(route, id); + } + }, + |new, idx| { + for &route in &new.accepted_routes { + idx.route_idx.modify(route, |val| { + if let Some(val) = val { + val.push(id); + return RouteModification::Noop; + } + + RouteModification::Insert(smallvec::smallvec![id]) + }); + } + }, + ); + + self.peers.insert(id, new.clone()); + + id + } + + /// Remove a peer by a given indexed field. + pub fn remove(&mut self, field: &dyn IndexedField) -> Option<(PeerId, Node)> { + let id = field.lookup(self)?; + + let node = self.peers.remove(&id)?; + self.index_state.remove(id, &node); + + Some((id, node)) + } + + /// Get the node with the given field. + pub fn get(&self, field: &dyn IndexedField) -> Option<(PeerId, &Node)> { + let id = field.lookup(self)?; + let peer = self.peers.get(&id)?; + + Some((id, peer)) + } + + /// Get the nodes with the closest matching route. + pub fn get_route(&self, route: ipnet::IpNet) -> impl Iterator { + // this doesn't use IndexedField because more than one result can be returned + + self.index_state + .route_idx + .lookup_prefix(route) + .into_iter() + .flat_map(|x| x.iter()) + .map(|&id| (id, self.peers.get(&id).unwrap())) + } + + /// Check whether there is a peer with the given field in the db. + pub fn has(&self, field: &dyn IndexedField) -> Option { + field.lookup(self) + } + + /// Get a reference to the peer map. + pub const fn peers(&self) -> &HashMap { + &self.peers + } + + /// Remove the nodes in the db that don't satisfy the predicate function. + pub fn retain(&mut self, mut predicate: impl FnMut(PeerId, &Node) -> bool) { + self.peers.retain(|&id, node| { + let retain = predicate(id, node); + + if !retain { + self.index_state.remove(id, node); + } + + retain + }); + } +} + +impl IndexState { + fn remove(&mut self, id: PeerId, node: &Node) { + self.nk_idx.remove(&node.node_key); + self.stableid_idx.remove(&node.stable_id); + self.control_idx.remove(&node.id); + self.ip_idx.remove(node.tailnet_address.ipv4.into()); + self.ip_idx.remove(node.tailnet_address.ipv6.into()); + + if self.name_idx.get(&node.hostname).is_some_and(|&x| x == id) { + self.name_idx.remove(&node.hostname); + } + + if let Some(fqdn) = node.fqdn_opt(false) { + self.name_idx.remove(&fqdn); + } + + for route in &node.accepted_routes { + self.remove_route(*route, id); + } + + if let Some(disco) = &node.disco_key { + self.disco_idx.remove(disco); + } + } + + /// Remove `route` from the `route_idx`. + fn remove_route(&mut self, route: ipnet::IpNet, id: PeerId) { + self.route_idx.modify(route, |val| match val { + Some(val) => { + let mut some_matched = false; + + val.retain(|&mut x| { + let ids_match = x == id; + if ids_match { + some_matched = true; + } + + !ids_match + }); + + assert!(some_matched); + + if val.is_empty() { + RouteModification::Remove + } else { + RouteModification::Noop + } + } + None => RouteModification::Noop, + }); + } + + #[cfg(test)] + fn is_empty(&self) -> bool { + self.nk_idx.is_empty() + && self.stableid_idx.is_empty() + && self.control_idx.is_empty() + && self.ip_idx.size() == 0 + && self.name_idx.is_empty() + && self.route_idx.size() == 0 + && self.disco_idx.is_empty() + } +} + +/// Attempt to update the `idx` with the `new` node. +/// +/// The `accessor` selects a set of fields to check (by `PartialEq`) for whether the `new` +/// node has changed compared to the `old` one: +/// +/// - If the value returned by `accessor` is the same between `new` and `old`, nothing +/// happens. +/// - If the value has changed and `old` is `Some`, `remove(old, idx)` is called. +/// - If the value has changed, `insert(new, idx)` is called. +fn maybe_update<'n, T, Idx>( + new: &'n Node, + old: Option<&'n Node>, + accessor: impl Fn(&'n Node) -> T, + idx: &mut Idx, + mut remove: impl FnMut(&'n Node, &mut Idx), + mut insert: impl FnMut(&'n Node, &mut Idx), +) where + T: PartialEq + 'n, +{ + match old { + Some(old) if accessor(old) == accessor(new) => { + return; + } + Some(x) => { + remove(x, idx); + } + None => {} + } + + insert(new, idx) +} + +/// Specialization of [`maybe_update`] to work on [`Index`]. +fn maybe_update_idx( + new: &Node, + old: Option<&Node>, + accessor: impl Fn(&Node) -> &T, + idx: &mut Index, + new_id: PeerId, +) where + T: Eq + Hash + Clone, +{ + maybe_update( + new, + old, + &accessor, + idx, + |old, idx| { + let old_id = idx.remove(accessor(old)); + assert!(old_id.is_some_and(|old_id| old_id == new_id)); + }, + |new, idx| { + idx.insert(accessor(new).clone(), new_id); + }, + ) +} + +impl IndexedField for PeerId { + fn lookup(&self, db: &PeerDb) -> Option { + if db.peers.contains_key(self) { + Some(*self) + } else { + None + } + } +} + +impl IndexedField for NodePublicKey { + fn lookup(&self, db: &PeerDb) -> Option { + db.index_state.nk_idx.get(self).copied() + } +} + +impl IndexedField for DiscoPublicKey { + fn lookup(&self, db: &PeerDb) -> Option { + db.index_state.disco_idx.get(self).copied() + } +} + +impl IndexedField for StableNodeId { + fn lookup(&self, db: &PeerDb) -> Option { + db.index_state.stableid_idx.get(self).copied() + } +} + +impl IndexedField for ts_control::NodeId { + fn lookup(&self, db: &PeerDb) -> Option { + db.index_state.control_idx.get(self).copied() + } +} + +impl IndexedField for PeerName { + fn lookup(&self, db: &PeerDb) -> Option { + db.index_state.name_idx.get(self).copied() + } +} + +impl IndexedField for &str { + fn lookup(&self, db: &PeerDb) -> Option { + db.index_state.name_idx.get(*self).copied() + } +} + +impl IndexedField for IpAddr { + fn lookup(&self, db: &PeerDb) -> Option { + db.index_state.ip_idx.lookup(*self).copied() + } +} + +#[cfg(test)] +mod test { + use std::{ + collections::{HashMap, HashSet}, + net::{Ipv4Addr, Ipv6Addr, SocketAddr}, + num::NonZeroU32, + }; + + use proptest::{ + collection::{hash_set, vec}, + prelude::any, + strategy::Strategy, + }; + use rand::{ + RngExt, + distr::{Alphanumeric, SampleString}, + }; + use ts_control::TailnetAddress; + + use super::*; + + fn rand_string(rng: &mut dyn rand::Rng, max_len: usize) -> String { + let len = rng.random_range(1..max_len); + Alphanumeric.sample_string(rng, len) + } + + fn rand_route(rng: &mut dyn rand::Rng) -> ipnet::IpNet { + if rng.random::() { + let ip = rand_ipv4(rng); + ipnet::Ipv4Net::new(ip, rand::random_range(0..=32)) + .unwrap() + .trunc() + .into() + } else { + let ip = rand_ipv6(rng); + ipnet::Ipv6Net::new(ip, rand::random_range(0..=128)) + .unwrap() + .trunc() + .into() + } + } + + fn rand_ipv4(rng: &mut dyn rand::Rng) -> Ipv4Addr { + Ipv4Addr::from_octets(rng.random::<[u8; 4]>()) + } + + fn rand_ipv6(rng: &mut dyn rand::Rng) -> Ipv6Addr { + Ipv6Addr::from_segments(rng.random::<[u16; 8]>()) + } + + fn rand_node() -> Node { + let mut rng = rand::rng(); + + Node { + stable_id: StableNodeId(rand_string(&mut rng, 32)), + tailnet_address: TailnetAddress { + ipv4: rand_ipv4(&mut rng).into(), + ipv6: rand_ipv6(&mut rng).into(), + }, + node_key: rng.random::<[u8; 32]>().into(), + disco_key: rng + .random::() + .then_some(rng.random::<[u8; 32]>().into()), + machine_key: rng + .random::() + .then_some(rng.random::<[u8; 32]>().into()), + id: rng.random(), + accepted_routes: (0..rng.random_range(0..32)) + .map(|_| rand_route(&mut rng)) + .collect(), + + hostname: rand_string(&mut rng, 32), + tailnet: rng.random::().then_some(rand_string(&mut rng, 32)), + + node_key_expiry: None, + underlay_addresses: vec![], + derp_region: rng + .random::() + .then_some(ts_derp::RegionId(rng.random())), + + tags: (0..rng.random_range(0..8)) + .map(|_| rand_string(&mut rng, 32)) + .collect(), + } + } + + fn validate_indices(db: &PeerDb, node: &Node, id: PeerId) { + let ipv4 = IpAddr::from(node.tailnet_address.ipv4.addr()); + let ipv6 = IpAddr::from(node.tailnet_address.ipv6.addr()); + let fqdn = node.fqdn_opt(false); + + let mut keys: Vec<&dyn IndexedField> = + vec![&id, &node.node_key, &node.stable_id, &node.id, &ipv4, &ipv6]; + + if let Some(disco) = &node.disco_key { + keys.push(disco); + } + + if let Some(fqdn) = &fqdn { + keys.push(fqdn); + } + + for k in keys { + let lookup_id = k.lookup(db).unwrap(); + assert_eq!(lookup_id, id, "wrong id for key {k:?}"); + + let (lookup_id, lookup_node) = db.get(k).unwrap(); + assert_eq!(lookup_id, id, "wrong id for key {k:?}"); + assert_eq!(lookup_node, node, "wrong node for key {k:?}"); + } + + // We don't know if the hostname collides, but it should resolve to something + node.hostname.lookup(db).unwrap(); + + for &route in &node.accepted_routes { + // Generically we don't actually know if this node has the most specific match for this + // route, but there should at least be one match, and all matches should have at least + // one route that (inclusively) subsets our route. + + let routes = db.get_route(route).collect::>(); + assert!(!routes.is_empty()); + + for (found_id, found_node) in routes { + if found_id == id { + assert_eq!(found_node, node); + break; + } + + let has_subset = found_node + .accepted_routes + .iter() + .any(|found_route| route.contains(found_route)); + + assert!(has_subset); + } + } + } + + /// Assert that the node's routes are all present as the most specific routes in the + /// db. + fn assert_has_routes_exact(db: &PeerDb, node: &Node, id: PeerId) { + for &route in &node.accepted_routes { + let match_exists = db + .get_route(route) + .any(|(found_id, found_node)| found_id == id && found_node == node); + + assert!(match_exists); + } + } + + #[test] + fn test_indices() { + let mut db = PeerDb::default(); + let node = rand_node(); + let id = db.upsert(&node); + + validate_indices(&db, &node, id); + assert_has_routes_exact(&db, &node, id); + } + + #[test] + fn test_names() { + let mut db = PeerDb::default(); + + let node1 = Node { + hostname: "test".to_string(), + tailnet: Some("ts.net".to_string()), + ..rand_node() + }; + let node2 = Node { + hostname: "test".to_string(), + tailnet: Some("ts2.net".to_string()), + ..rand_node() + }; + let node3 = Node { + hostname: "test".to_string(), + tailnet: None, + ..rand_node() + }; + + let id1 = db.upsert(&node1); + let id2 = db.upsert(&node2); + let id3 = db.upsert(&node3); + + let nodes = [(id1, &node1), (id2, &node2), (id3, &node3)]; + + for (id, node) in &nodes { + validate_indices(&db, node, *id); + } + + let (id, node) = db.get(&"test").unwrap(); + assert!(nodes.iter().any(|(x, _node)| *x == id)); + + for &(x, curnode) in &nodes { + if x == id { + assert_eq!(node, curnode); + } else { + assert_ne!(node, curnode); + } + } + + let (id, node) = db.get(&"test.ts.net").unwrap(); + assert_eq!(id, id1); + assert_eq!(node, &node1); + + let (id, node) = db.get(&"test.ts2.net").unwrap(); + assert_eq!(id, id2); + assert_eq!(node, &node2); + } + + proptest::prop_compose! { + fn ipv4net()( + addr: Ipv4Addr, + pfx in 0u8..=32, + ) -> ipnet::Ipv4Net { + ipnet::Ipv4Net::new(addr, pfx).unwrap().trunc() + } + } + + proptest::prop_compose! { + fn ipv6net()( + addr: Ipv6Addr, + pfx in 0u8..=32, + ) -> ipnet::Ipv6Net { + ipnet::Ipv6Net::new(addr, pfx).unwrap().trunc() + } + } + + fn ipnet() -> impl Strategy { + proptest::prop_oneof![ + ipv4net().prop_map(ipnet::IpNet::from), + ipv6net().prop_map(ipnet::IpNet::from) + ] + } + + proptest::prop_compose! { + fn domain_segment()( + seg in "[[:alpha:]][[:alnum:]]*" + ) -> String { + seg + } + } + + proptest::prop_compose! { + fn domain(max_count: usize)( + segs in proptest::collection::vec(domain_segment(), 0..max_count) + ) -> String { + segs.join(".") + } + } + + type Key = [u8; 32]; + + proptest::prop_compose! { + // This is set up this way to ensure uniqueness among all the required-unique keys in a + // node. The `hash_set`s ensure that all ids AND stable ids AND node keys etc. are unique. + fn nodes(n: usize)( + id in hash_set(any::(), n), + stable_id in hash_set(".+", n), + tags in vec(hash_set(".+", 0..32), n), + accepted_routes in vec(hash_set(ipnet(), 0..32), n), + node_key in hash_set(any::(), n), + machine_key in vec(any::>(), n), + disco_key in vec(any::>(), n), + ipv4 in hash_set(any::(), n), + ipv6 in hash_set(any::(), n), + name in hash_set(domain_segment(), n), + tailnet in vec(domain(5), n), + has_tailnet in vec(any::(), n), + derp_region in vec(any::>(), n), + underlay_addrs in vec(any::>(), n), + ) -> Vec { + itertools::izip![ + id, + stable_id, + tags, + accepted_routes, + node_key, + machine_key, + disco_key, + ipv4, + ipv6, + name, + tailnet, + has_tailnet, + derp_region, + underlay_addrs, + ].map(|( + id, + stable_id, + tags, + mut accepted_routes, + node_key, + machine_key, + disco_key, + ipv4, + ipv6, + name, + tailnet, + has_tailnet, + derp_region, + underlay_addrs, + )| { + accepted_routes.insert(ipnet::Ipv4Net::from(ipv4).into()); + accepted_routes.insert(ipnet::Ipv6Net::from(ipv6).into()); + + Node { + id, + stable_id: StableNodeId(stable_id), + + hostname: name, + tailnet: has_tailnet.then_some(tailnet), + + node_key: node_key.into(), + disco_key: disco_key.map(Into::into), + machine_key: machine_key.map(Into::into), + + node_key_expiry: None, + + tailnet_address: TailnetAddress { + ipv4: ipv4.into(), + ipv6: ipv6.into(), + }, + tags: tags.into_iter().collect(), + + derp_region: derp_region.map(ts_derp::RegionId), + + accepted_routes: accepted_routes.into_iter().collect(), + underlay_addresses: underlay_addrs.into_iter().collect(), + } + }) + .collect() + } + } + + proptest::proptest! { + #[test] + fn prop_one_node_indices(mut nodes in nodes(1)) { + let node = nodes.pop().unwrap(); + + let mut db = PeerDb::default(); + let id = db.upsert(&node); + + validate_indices(&db, &node, id); + assert_has_routes_exact(&db, &node, id); + } + + #[test] + fn prop_many_nodes_indexed(nodes in nodes(16)) { + let mut db = PeerDb::default(); + + let mut nodes_by_id = HashMap::new(); + + for node in &nodes { + let id = db.upsert(node); + nodes_by_id.insert(id, node.clone()); + } + + for (id, node) in &nodes_by_id { + validate_indices(&db, node, *id); + } + } + + #[test] + fn prop_remove(nodes in nodes(16)) { + let mut db = PeerDb::default(); + + let mut ids = vec![]; + + for node in &nodes { + ids.push((db.upsert(node), node)); + } + + for (id, node) in ids { + let (removed_id, removed_node) = db.remove(&id).unwrap(); + + proptest::prop_assert_eq!(removed_id, id); + proptest::prop_assert_eq!(&removed_node, node); + } + + proptest::prop_assert!(db.peers.is_empty()); + proptest::prop_assert!(db.index_state.is_empty()); + } + } +} diff --git a/ts_runtime/src/route_updater.rs b/ts_runtime/src/route_updater.rs index d85d0a54..6fbc7511 100644 --- a/ts_runtime/src/route_updater.rs +++ b/ts_runtime/src/route_updater.rs @@ -5,11 +5,10 @@ use kameo::{ message::{Context, Message}, }; use ts_bart::RoutingTable; -use ts_keys::NodePublicKey; use ts_overlay_router::{ inbound::RouteAction as InboundRouteAction, outbound::RouteAction as OutboundRouteAction, }; -use ts_transport::{OverlayTransportId, UnderlayTransportId}; +use ts_transport::{OverlayTransportId, PeerId, UnderlayTransportId}; use crate::{Error, env::Env, multiderp, multiderp::Multiderp, peer_tracker::PeerState}; @@ -27,7 +26,7 @@ impl kameo::Actor for RouteUpdater { (multiderp, env, default_transport): Self::Args, actor_ref: ActorRef, ) -> Result { - env.subscribe::(&actor_ref).await?; + env.subscribe::>(&actor_ref).await?; env.subscribe::>(&actor_ref) .await?; @@ -50,35 +49,35 @@ pub struct PeerRouteUpdate { } pub struct PeerRoutesInner { - pub underlay_routes: HashMap, + pub underlay_routes: HashMap, pub overlay_out_routes: ts_bart::Table, } -impl Message for RouteUpdater { +impl Message> for RouteUpdater { type Reply = (); - async fn handle(&mut self, msg: PeerState, _ctx: &mut Context) { + async fn handle(&mut self, msg: Arc, _ctx: &mut Context) { tracing::trace!( - n_peers = msg.peers.len(), + n_peers = msg.peers.peers().len(), "reconstructing routes for peer update" ); let mut overlay_out = ts_bart::Table::default(); let mut underlay_out = HashMap::default(); - for peer in msg.peers.values() { + for (id, peer) in msg.peers.peers() { let span = tracing::trace_span!( "peer_update", - peer = %peer.node_key, - region = tracing::field::Empty, + peer_key = %peer.node_key, + region = ?peer.derp_region, underlay_transport = tracing::field::Empty, + peer_id = ?id, ); let Some(region) = peer.derp_region else { tracing::trace!(parent: &span, "peer has no derp region"); continue; }; - span.record("region", tracing::field::debug(region)); tracing::trace!(parent: &span, "ask multiderp for transport id"); @@ -89,7 +88,7 @@ impl Message for RouteUpdater { { Ok(Some(transport_id)) => { span.record("underlay_transport", tracing::field::debug(transport_id)); - underlay_out.insert(peer.node_key, transport_id); + underlay_out.insert(*id, transport_id); tracing::trace!(parent: &span, "set underlay route"); } Ok(None) => { @@ -103,7 +102,7 @@ impl Message for RouteUpdater { for route in &peer.accepted_routes { tracing::trace!(parent: &span, %route, "routes"); - overlay_out.insert(*route, OutboundRouteAction::Wireguard(peer.node_key)); + overlay_out.insert(*route, OutboundRouteAction::Wireguard(*id)); } } diff --git a/ts_runtime/src/src_filter.rs b/ts_runtime/src/src_filter.rs index cfa1aae6..29f00a9c 100644 --- a/ts_runtime/src/src_filter.rs +++ b/ts_runtime/src/src_filter.rs @@ -5,7 +5,7 @@ use kameo::{ message::{Context, Message}, }; use ts_bart::{RoutingTable, Table}; -use ts_keys::NodePublicKey; +use ts_transport::PeerId; use crate::{Error, env::Env, peer_tracker::PeerState}; @@ -18,23 +18,28 @@ impl kameo::Actor for SourceFilterUpdater { type Error = Error; async fn on_start(env: Self::Args, slf: ActorRef) -> Result { - env.subscribe::(&slf).await?; + env.subscribe::>(&slf).await?; Ok(Self { env }) } } #[derive(Clone)] -pub struct SourceFilterState(pub Arc>); +pub struct SourceFilterState(pub Arc>); -impl Message for SourceFilterUpdater { +impl Message> for SourceFilterUpdater { type Reply = (); - async fn handle(&mut self, state_update: PeerState, _ctx: &mut Context) { + async fn handle( + &mut self, + state_update: Arc, + _ctx: &mut Context, + ) { let mut src_filter = Table::default(); - for (nodekey, node) in state_update.peers.iter() { + + for (id, node) in state_update.peers.peers() { for route in node.accepted_routes.iter() { - src_filter.insert(route.to_owned(), *nodekey); + src_filter.insert(route.to_owned(), *id); } } diff --git a/ts_transport/Cargo.toml b/ts_transport/Cargo.toml index 358284fa..184de1fc 100644 --- a/ts_transport/Cargo.toml +++ b/ts_transport/Cargo.toml @@ -12,7 +12,6 @@ rust-version.workspace = true [dependencies] # Our crates. -ts_keys.workspace = true ts_packet.workspace = true [lints] diff --git a/ts_transport/src/batch_iter.rs b/ts_transport/src/batch_iter.rs new file mode 100644 index 00000000..597a9ea6 --- /dev/null +++ b/ts_transport/src/batch_iter.rs @@ -0,0 +1,88 @@ +use ts_packet::PacketMut; + +/// Wrapper around [`IntoIterator`] for a batch of packets keyed by `Key` which ensures +/// that it and all nested iterators are [`Send`]. +/// +/// Think of this as morally `HashMap>`, but with added flexibility +/// for the caller to convert source values on-the-fly without having to allocate an +/// intermediate collection. +pub trait BatchSendIter: Send { + /// Equivalent of the `IntoIter` type with the `Send` bound applied and `Item` + /// specified. + type BatchIt: Iterator + Send; + + /// Inner packet iterator (per-`Key`). + type PacketIt: PacketIter; + + /// Equivalent of [`IntoIterator::into_iter`], but with the bounds from `BatchIt` + /// enforced. + fn batch_iter(self) -> Self::BatchIt; +} + +/// Wrapper around [`IntoIterator`] for a batch of packets keyed by `Key` which ensures that +/// it and the nested iterators are [`Send`]. +/// +/// This is used to _return_ values from [`crate::UnderlayTransport::recv`], and so has a +/// slightly different shape than [`BatchSendIter`] (the items are `Result`s). +/// +/// Think of this as morally `HashMap>`, but with added flexibility +/// for the caller to convert source values on-the-fly without having to allocate an +/// intermediate collection. +pub trait BatchRecvIter: Send { + /// The error type this iterator may have. + type Error; + + /// Equivalent of the `IntoIter` type with the `Send` bound applied and `Item` + /// specified. + type BatchIt: Iterator> + Send; + + /// Inner packet iterator (per-`Key`). + type PacketIt: PacketIter; + + /// Equivalent of [`IntoIterator::into_iter`], but with the bounds from `BatchIt` + /// enforced. + fn batch_iter(self) -> Self::BatchIt; +} + +impl BatchSendIter for T +where + T: IntoIterator + Send, + ::IntoIter: Send, + P: PacketIter, +

::IntoIter: Send, +{ + type BatchIt = ::IntoIter; + type PacketIt = P; + + fn batch_iter(self) -> Self::BatchIt { + self.into_iter() + } +} + +impl BatchRecvIter for T +where + T: IntoIterator> + Send, + ::IntoIter: Send, + P: PacketIter, +

::IntoIter: Send, +{ + type Error = E; + type BatchIt = ::IntoIter; + type PacketIt = P; + + fn batch_iter(self) -> Self::BatchIt { + self.into_iter() + } +} + +pub trait PacketIter: IntoIterator + Send { + type PacketIt: Send + Iterator; +} + +impl

PacketIter for P +where + P: IntoIterator + Send, +

::IntoIter: Send, +{ + type PacketIt = P::IntoIter; +} diff --git a/ts_transport/src/lib.rs b/ts_transport/src/lib.rs index 396367c0..94ab1876 100644 --- a/ts_transport/src/lib.rs +++ b/ts_transport/src/lib.rs @@ -3,11 +3,19 @@ extern crate alloc; -use core::error::Error; +use core::{ + error::Error, + fmt::{Debug, Display, Formatter}, +}; -use ts_keys::NodePublicKey; use ts_packet::PacketMut; +mod batch_iter; +mod map_key; + +pub use batch_iter::{BatchRecvIter, BatchSendIter}; +pub use map_key::{MapPeerKey, PeerLookup}; + /// The unique id of an overlay transport. #[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] pub struct OverlayTransportId(pub u32); @@ -40,44 +48,71 @@ impl From for u32 { } } +/// The unique id of a peer. +#[derive(Debug, Copy, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub struct PeerId(pub u32); + +impl Display for PeerId { + fn fmt(&self, f: &mut Formatter<'_>) -> core::fmt::Result { + Debug::fmt(self, f) + } +} + /// An abstract transport that can carry packets to configurable destinations. pub trait UnderlayTransport { + /// The type of key this transport uses to identify peers. + /// + /// The runtime generally wants to use [`PeerId`] here, but transports will almost + /// always want to use a different key type for communication (however the peer is + /// known on the wire). + /// + /// To decouple, transport implementations can use their wire type here, while the + /// runtime wraps the implementation with [`UnderlayTransportExt::with_key_lookup`] + /// to provide functionality to convert the wire type to and from [`PeerId`]. + type PeerKey: Send + Sync + 'static; + /// The error type that this transport may produce. type Error: Error + Send + Sync + 'static; /// Send packets through the transport. /// /// The return type should be interpreted as meaning essentially - /// `HashMap>`. It is set up this way to enable the caller + /// `HashMap>`. It is set up this way to enable the caller /// to use iterators to transform a collection of a slightly different shape, or e.g. - /// look up `NodePublicKey`s on-the-fly, without having to `.collect()` into an + /// look up `PeerId`s on-the-fly, without having to `.collect()` into an /// intermediary collection. - fn send( + fn send( &self, - packet_batch: BatchIter, - ) -> impl Future> + Send - where - BatchIter: IntoIterator + Send, - BatchIter::IntoIter: Send, - PacketIter: IntoIterator + Send, - PacketIter::IntoIter: Send; + packet_batch: impl BatchSendIter, + ) -> impl Future> + Send; /// Receive packets from the transport. /// /// The return type should be interpreted as meaning essentially - /// `HashMap>`, but allows for the implementation to + /// `HashMap>`, but allows for the implementation to /// use iterators to map a collection of a slightly different shape, or e.g. look up - /// `NodePublicKey`s on-the-fly, without having to `.collect()` into an intermediary + /// `PeerId`s on-the-fly, without having to `.collect()` into an intermediary /// collection. fn recv( &self, - ) -> impl Future< - Output = impl IntoIterator< - Item = Result<(NodePublicKey, impl IntoIterator), Self::Error>, - >, - > + Send; + ) -> impl Future> + Send; } +/// Extension methods on [`UnderlayTransport`]. +pub trait UnderlayTransportExt: UnderlayTransport { + /// Map the keys used by this transport with the given [`PeerLookup`]. + fn with_key_lookup(self, lookup: Lookup) -> MapPeerKey + where + Self: Sized + Send + Sync, + Lookup: PeerLookup + PeerLookup + Send + Sync, + DstKey: Send + Sync + 'static, + { + MapPeerKey::new(self, lookup) + } +} + +impl UnderlayTransportExt for T where T: UnderlayTransport {} + /// A transport that can carry packets to and from the overlay network. pub trait OverlayTransport { /// The error type this transport may produce. diff --git a/ts_transport/src/map_key.rs b/ts_transport/src/map_key.rs new file mode 100644 index 00000000..b953e25c --- /dev/null +++ b/ts_transport/src/map_key.rs @@ -0,0 +1,60 @@ +use core::marker::PhantomData; + +use crate::{BatchRecvIter, BatchSendIter, UnderlayTransport}; + +/// Trait providing key lookup from one type to another. +pub trait PeerLookup { + /// Lookup the corresponding `To` key from this `From` key. + fn lookup_key(&self, from: From) -> Option; +} + +/// An [`UnderlayTransport`] that converts keys between two types using a [`PeerLookup`]. +pub struct MapPeerKey { + inner: Inner, + lookup: Lookup, + dst: PhantomData, +} + +impl MapPeerKey { + /// Construct a new [`MapPeerKey`] with the given lookup. + pub const fn new(t: T, lookup: Lookup) -> Self { + Self { + inner: t, + lookup, + dst: PhantomData, + } + } +} + +impl UnderlayTransport for MapPeerKey +where + Inner: UnderlayTransport + Send + Sync, + Lookup: PeerLookup + PeerLookup + Send + Sync, + DstKey: Send + Sync + 'static, +{ + type PeerKey = DstKey; + type Error = Inner::Error; + + async fn send(&self, packet_batch: impl BatchSendIter) -> Result<(), Self::Error> { + self.inner + .send(packet_batch.batch_iter().filter_map(|(key, packets)| { + let k = self.lookup.lookup_key(key)?; + Some((k, packets)) + })) + .await + } + + async fn recv(&self) -> impl BatchRecvIter { + self.inner + .recv() + .await + .batch_iter() + .filter_map(|result| match result { + Ok((key, pkts)) => { + let k = self.lookup.lookup_key(key)?; + Some(Ok((k, pkts))) + } + Err(e) => Some(Err(e)), + }) + } +} diff --git a/ts_tunnel/examples/handshake.rs b/ts_tunnel/examples/handshake.rs index 8412a812..2370bd84 100644 --- a/ts_tunnel/examples/handshake.rs +++ b/ts_tunnel/examples/handshake.rs @@ -84,12 +84,18 @@ async fn main() -> BoxResult<()> { let mut ep = Endpoint::new(privkey.into()); - let peer_id = ep - .add_peer(ts_tunnel::PeerConfig { - key: peer_key, - psk: [0; 32].into(), - }) - .ok_or("couldn't add peer")?; + let peer_id = ts_tunnel::PeerId(1); + + assert!( + ep.upsert_peer( + peer_id, + ts_tunnel::PeerConfig { + key: peer_key, + psk: [0; 32].into(), + } + ) + .is_none() + ); let sock = tokio::net::UdpSocket::bind("0.0.0.0:0").await?; eprintln!("socket bound to {}", sock.local_addr()?.port()); diff --git a/ts_tunnel/src/endpoint.rs b/ts_tunnel/src/endpoint.rs index f8bd141c..73cbce18 100644 --- a/ts_tunnel/src/endpoint.rs +++ b/ts_tunnel/src/endpoint.rs @@ -215,7 +215,6 @@ struct IdMap { // TODO: track recently abandoned session IDs, avoid reusing them for // one or two session lifetimes to avoid confusion with reordered packets. node_keys: HashMap, - next_peer_id: u32, } impl IdMap { @@ -229,17 +228,16 @@ impl IdMap { self.sessions.get(key) } - /// Allocate a new peer handle for communicating with the given peer pubkey. + /// Add a peer handle for communicating with the given peer pubkey. /// - /// Returns None if a peer already exists for the key. - fn allocate_peer(&mut self, key: &NodePublicKey) -> Option { + /// Returns `false` if a peer already exists for the key. + fn add_peer(&mut self, id: PeerId, key: &NodePublicKey) -> bool { if self.node_keys.contains_key(key) { - return None; + return false; } - self.next_peer_id += 1; - let ret = PeerId(self.next_peer_id); - self.node_keys.insert(*key, ret); - Some(ret) + + self.node_keys.insert(*key, id); + true } /// Allocate a new session ID for communication with the given peer. @@ -568,14 +566,35 @@ impl Endpoint { } } - /// Add a new peer. + /// Insert a peer if it doesn't exist, otherwise update the peer with the given `id` + /// with the given config. + /// + /// Returns the old [`PeerConfig`] if there was one. + /// + /// # Panics /// - /// Returns a handle to the newly configured peer, or None if a peer is already configured - /// with the given node key. - pub fn add_peer(&mut self, cfg: PeerConfig) -> Option { - let ret = self.state.ids.allocate_peer(&cfg.key)?; - self.peers.insert(ret, Peer::new(ret, cfg)); - Some(ret) + /// If the [`NodePublicKey`] in the new [`PeerConfig`] collides with an existing key + /// for a different [`PeerId`]. + pub fn upsert_peer(&mut self, id: PeerId, mut cfg: PeerConfig) -> Option { + match self.peers.get_mut(&id) { + Some(peer) => { + if peer.config.key != cfg.key { + self.state.ids.remove_peer(&peer.config.key); + self.state.ids.add_peer(id, &cfg.key); + } + + core::mem::swap(&mut peer.config, &mut cfg); + Some(cfg) + } + None => { + if !self.state.ids.add_peer(id, &cfg.key) { + panic!("nodekey collision"); + } + + self.peers.insert(id, Peer::new(id, cfg)); + None + } + } } /// Remove the given peer. @@ -803,19 +822,30 @@ mod tests { let (mut a_ep, mut b_ep) = (Endpoint::new(a_static), Endpoint::new(b_static)); - let a_peer = a_ep - .add_peer(PeerConfig { - key: b_static.public, - psk, - }) - .unwrap(); - - let b_peer = b_ep - .add_peer(PeerConfig { - key: a_static.public, - psk, - }) - .unwrap(); + let a_peer = PeerId(1); + let b_peer = PeerId(1); + + assert!( + a_ep.upsert_peer( + a_peer, + PeerConfig { + key: b_static.public, + psk, + }, + ) + .is_none() + ); + + assert!( + b_ep.upsert_peer( + b_peer, + PeerConfig { + key: a_static.public, + psk, + }, + ) + .is_none() + ); let a_to_b_packets = [ PacketMut::from(vec![1, 2, 3, 4]), diff --git a/ts_underlay_router/Cargo.toml b/ts_underlay_router/Cargo.toml index e176e017..d8d2b05c 100644 --- a/ts_underlay_router/Cargo.toml +++ b/ts_underlay_router/Cargo.toml @@ -11,7 +11,6 @@ license.workspace = true rust-version.workspace = true [dependencies] -ts_keys.workspace = true ts_packet.workspace = true ts_transport.workspace = true diff --git a/ts_underlay_router/src/outbound.rs b/ts_underlay_router/src/outbound.rs index 8519099d..7691a40b 100644 --- a/ts_underlay_router/src/outbound.rs +++ b/ts_underlay_router/src/outbound.rs @@ -2,28 +2,24 @@ use std::collections::HashMap; -use ts_keys::NodePublicKey; use ts_packet::PacketMut; -use ts_transport::UnderlayTransportId; +use ts_transport::{PeerId, UnderlayTransportId}; /// Routes packets that originate from the local device. #[derive(Default)] pub struct Router { /// The transport to use for sending to each wireguard peer. - pub table: HashMap, + pub table: HashMap, } /// The outcome of routing packets. -pub type Result = HashMap<(UnderlayTransportId, NodePublicKey), Vec>; +pub type Result = HashMap<(UnderlayTransportId, PeerId), Vec>; impl Router { /// Assigns a batch of packets to their next hop. /// /// Packets that don't match any routes are dropped. - pub fn route( - &self, - batches: impl IntoIterator)>, - ) -> Result { + pub fn route(&self, batches: impl IntoIterator)>) -> Result { let mut ret = Result::default(); for (peer_id, packets) in batches { @@ -44,11 +40,11 @@ mod tests { #[test] fn test_outbound_underlay() { - let peer_a = NodePublicKey::from([1u8; 32]); - let peer_b = NodePublicKey::from([2u8; 32]); - let peer_c = NodePublicKey::from([3u8; 32]); - let peer_d = NodePublicKey::from([4u8; 32]); - let peer_e = NodePublicKey::from([5u8; 32]); + let peer_a = PeerId(1); + let peer_b = PeerId(2); + let peer_c = PeerId(3); + let peer_d = PeerId(4); + let peer_e = PeerId(5); let transport_a = 5.into(); let transport_b = 6.into(); let transport_c = 7.into();