diff --git a/device/deadlock_test.go b/device/deadlock_test.go new file mode 100644 index 000000000..2ced77649 --- /dev/null +++ b/device/deadlock_test.go @@ -0,0 +1,109 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package device + +import ( + "testing" + "time" + + "github.com/tailscale/wireguard-go/conn" + "github.com/tailscale/wireguard-go/tun/tuntest" +) + +// TestSetPrivateKeyConsumeInitiationDeadlock verifies that SetPrivateKey and +// ConsumeMessageInitiation do not deadlock. +// +// ConsumeMessageInitiation holds staticIdentity.RLock (noise-protocol.go:352) +// for the duration of the call. When processing an unknown peer, it calls +// LookupPeer -> PeerLookupFunc -> NewPeer, and NewPeer attempts a reentrant +// staticIdentity.RLock (peer.go:82). If SetPrivateKey has called +// staticIdentity.Lock (device.go:235) in between, the pending writer blocks +// the reentrant reader, while the writer waits for the original reader to +// release, a classic sync.RWMutex reentrant-reader deadlock. +// +// We would like to use testing/synctest to eliminate the wall-clock sleep +// and 5-second timeout, but sync.RWMutex is not bubble-aware as of Go 1.26, +// so a goroutine blocked in RWMutex.Lock is not classified as "durably +// blocked" and synctest cannot detect this deadlock. Revisit once the sync +// package participates in synctest bubbles. +func TestSetPrivateKeyConsumeInitiationDeadlock(t *testing.T) { + // Create the receiver device. We intentionally avoid t.Cleanup(Close) + // because in the deadlock case the locks are permanently held and + // Close would block. + recvSK, err := newPrivateKey() + if err != nil { + t.Fatal(err) + } + receiver := NewDevice( + tuntest.NewChannelTUN().TUN(), + conn.NewDefaultBind(), + NewLogger(LogLevelError, ""), + ) + receiver.SetPrivateKey(recvSK) + + // Create the initiator device and add receiver as a peer so we can + // produce a valid MessageInitiation. 
+ initiator := randDevice(t) + defer initiator.Close() + + peer, err := initiator.NewPeer(recvSK.publicKey()) + if err != nil { + t.Fatal(err) + } + peer.Start() + + msg, err := initiator.CreateMessageInitiation(peer) + if err != nil { + t.Fatal(err) + } + + newSK, err := newPrivateKey() + if err != nil { + t.Fatal(err) + } + + // PeerLookupFunc is the synchronization point: when it fires, + // ConsumeMessageInitiation is holding staticIdentity.RLock and we are + // between the two RLock acquisitions (the second one is inside NewPeer). + inLookup := make(chan struct{}) + proceed := make(chan struct{}) + receiver.SetPeerLookupFunc(func(pk NoisePublicKey) (_ *NewPeerConfig, ok bool) { + close(inLookup) + <-proceed + return &NewPeerConfig{}, true + }) + + // Goroutine A: ConsumeMessageInitiation holds staticIdentity.RLock. + go receiver.ConsumeMessageInitiation(msg, nil) + <-inLookup + + // Goroutine B: SetPrivateKey calls staticIdentity.Lock (write); + // blocks because A holds the read lock. + setKeyDone := make(chan error, 1) + go func() { + setKeyDone <- receiver.SetPrivateKey(newSK) + }() + + // Give B time to reach staticIdentity.Lock (the very first operation + // in SetPrivateKey), so it becomes a pending writer. + time.Sleep(20 * time.Millisecond) + + // A proceeds: LookupPeer → NewPeer → staticIdentity.RLock (reentrant). + // With B's write pending, the reentrant read lock blocks. Deadlock. 
+ close(proceed) + + select { + case err := <-setKeyDone: + if err != nil { + t.Fatal(err) + } + receiver.Close() + case <-time.After(5 * time.Second): + t.Fatal("deadlock: ConsumeMessageInitiation holds staticIdentity.RLock, " + + "SetPrivateKey is pending staticIdentity.Lock (write), and NewPeer's " + + "reentrant staticIdentity.RLock is blocked by the pending writer") + } +} diff --git a/device/device.go b/device/device.go index 0e720f251..da668d6f8 100644 --- a/device/device.go +++ b/device/device.go @@ -182,14 +182,22 @@ func (device *Device) upLocked() error { device.ipcMutex.Lock() defer device.ipcMutex.Unlock() + // Collect peers under RLock and then release before calling into them, + // because SendKeepalive can reach CreateMessageInitiation which acquires + // staticIdentity.RLock; holding peers.RLock across that path would + // invert the staticIdentity < peers hierarchy (see lock-ordering.md). device.peers.RLock() + peers := make([]*Peer, 0, len(device.peers.keyMap)) for _, peer := range device.peers.keyMap { + peers = append(peers, peer) + } + device.peers.RUnlock() + for _, peer := range peers { peer.Start() if peer.persistentKeepaliveInterval.Load() > 0 { peer.SendKeepalive() } } - device.peers.RUnlock() return nil } @@ -533,16 +541,25 @@ func (device *Device) SendKeepalivesToPeersWithCurrentKeypair() { return } + // Collect the set of peers to keepalive under peers.RLock, then release + // before invoking SendKeepalive. SendKeepalive can reach + // CreateMessageInitiation which acquires staticIdentity.RLock; holding + // peers.RLock across that path would invert the + // staticIdentity < peers hierarchy (see lock-ordering.md). 
+ var peers []*Peer device.peers.RLock() for _, peer := range device.peers.keyMap { peer.keypairs.RLock() sendKeepalive := peer.keypairs.current != nil && !peer.keypairs.current.created.Add(RejectAfterTime).Before(time.Now()) peer.keypairs.RUnlock() if sendKeepalive { - peer.SendKeepalive() + peers = append(peers, peer) } } device.peers.RUnlock() + for _, peer := range peers { + peer.SendKeepalive() + } } // closeBindLocked closes the device's net.bind. diff --git a/device/lock-ordering.md b/device/lock-ordering.md new file mode 100644 index 000000000..55a15c0b7 --- /dev/null +++ b/device/lock-ordering.md @@ -0,0 +1,27 @@ +# Lock Ordering in wireguard-go/device + +## Lock hierarchy + +Locks must be acquired in the order listed below. A goroutine holding a +lock with a higher number must never attempt to acquire a lock with a +lower number. + +``` +Level 0 device.state.Mutex +Level 1 device.ipcMutex (sync.RWMutex) +Level 2 device.net.RWMutex +Level 3 device.staticIdentity.RWMutex +Level 4 device.peers.RWMutex +Level 5 peer.state.Mutex +Level 6 peer.handshake.mutex (sync.RWMutex) +Level 7 peer.keypairs.RWMutex +Level 8 device.allowedips.mu (sync.RWMutex) +Level 9 device.indexTable.RWMutex +Level 10 peer.endpoint.Mutex +Level 11 device.cookieChecker.RWMutex +Level 12 peer.cookieGenerator.RWMutex +Level 13 Timer.modifyingLock / Timer.runningLock +``` + +Not every pair of locks is nested in practice; the ordering above is one +total order consistent with the pairs that are. diff --git a/device/noise-protocol.go b/device/noise-protocol.go index ad5838e1d..18602f1dd 100644 --- a/device/noise-protocol.go +++ b/device/noise-protocol.go @@ -349,17 +349,22 @@ func (device *Device) ConsumeMessageInitiation(msg *MessageInitiation, endpoint return nil } + // Snapshot staticIdentity so we don't hold the RLock across LookupPeer, + // which may call NewPeer (reentrant RLock deadlocks against a pending + // SetPrivateKey writer; see lock-ordering.md). 
device.staticIdentity.RLock() - defer device.staticIdentity.RUnlock() + publicKey := device.staticIdentity.publicKey + privateKey := device.staticIdentity.privateKey + device.staticIdentity.RUnlock() - mixHash(&hash, &InitialHash, device.staticIdentity.publicKey[:]) + mixHash(&hash, &InitialHash, publicKey[:]) mixHash(&hash, &hash, msg.Ephemeral[:]) mixKey(&chainKey, &InitialChainKey, msg.Ephemeral[:]) // decrypt static key var peerPK NoisePublicKey var key [chacha20poly1305.KeySize]byte - ss, err := device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral) + ss, err := privateKey.sharedSecret(msg.Ephemeral) if err != nil { return nil } @@ -534,6 +539,14 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { chainKey [blake2s.Size]byte ) + // Snapshot the static private key before acquiring handshake.mutex so + // that handshake.mutex is never held while acquiring staticIdentity + // (which would invert the staticIdentity < handshake.mutex hierarchy; + // see lock-ordering.md). + device.staticIdentity.RLock() + privateKey := device.staticIdentity.privateKey + device.staticIdentity.RUnlock() + ok := func() bool { // lock handshake state @@ -544,11 +557,6 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { return false } - // lock private key for reading - - device.staticIdentity.RLock() - defer device.staticIdentity.RUnlock() - // finish 3-way DH mixHash(&hash, &handshake.hash, msg.Ephemeral[:]) @@ -561,7 +569,7 @@ func (device *Device) ConsumeMessageResponse(msg *MessageResponse) *Peer { mixKey(&chainKey, &chainKey, ss[:]) setZero(ss[:]) - ss, err = device.staticIdentity.privateKey.sharedSecret(msg.Ephemeral) + ss, err = privateKey.sharedSecret(msg.Ephemeral) if err != nil { return false }