diff --git a/conn/bind_std.go b/conn/bind_std.go index c13891e67..a4774dea8 100644 --- a/conn/bind_std.go +++ b/conn/bind_std.go @@ -16,6 +16,7 @@ import ( "sync" "syscall" + "github.com/tailscale/wireguard-go/iobuf" "golang.org/x/net/ipv4" "golang.org/x/net/ipv6" ) @@ -233,13 +234,13 @@ func (s *StdNetBind) receiveIP( br batchReader, conn *net.UDPConn, rxOffload bool, - bufs [][]byte, - sizes []int, + bufs []iobuf.View, eps []Endpoint, ) (n int, err error) { msgs := s.getMessages() + iobuf.EnsureAllocated(bufs) for i := range bufs { - (*msgs)[i].Buffers[0] = bufs[i] + (*msgs)[i].Buffers[0] = bufs[i].Bytes (*msgs)[i].OOB = (*msgs)[i].OOB[:cap((*msgs)[i].OOB)] } defer s.putMessages(msgs) @@ -271,8 +272,8 @@ func (s *StdNetBind) receiveIP( } for i := 0; i < numMsgs; i++ { msg := &(*msgs)[i] - sizes[i] = msg.N - if sizes[i] == 0 { + bufs[i].Bytes = bufs[i].Bytes[:msg.N] + if len(bufs[i].Bytes) == 0 { continue } addrPort := msg.Addr.(*net.UDPAddr).AddrPort() @@ -284,14 +285,14 @@ func (s *StdNetBind) receiveIP( } func (s *StdNetBind) makeReceiveIPv4(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc { - return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { - return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps) + return func(bufs []iobuf.View, eps []Endpoint) (n int, err error) { + return s.receiveIP(pc, conn, rxOffload, bufs, eps) } } func (s *StdNetBind) makeReceiveIPv6(pc *ipv6.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc { - return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { - return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps) + return func(bufs []iobuf.View, eps []Endpoint) (n int, err error) { + return s.receiveIP(pc, conn, rxOffload, bufs, eps) } } diff --git a/conn/bind_std_test.go b/conn/bind_std_test.go index 254952f0a..88dcc6c2f 100644 --- a/conn/bind_std_test.go +++ b/conn/bind_std_test.go @@ -5,6 +5,7 @@ import ( "net" "testing" + "github.com/tailscale/wireguard-go/iobuf" "golang.org/x/net/ipv6" ) @@ -15,15 +16,14 @@ func TestStdNetBindReceiveFuncAfterClose(t *testing.T) { t.Fatal(err) } bind.Close() - bufs := make([][]byte, 1) - bufs[0] = make([]byte, 1) - sizes := make([]int, 1) + bufs := make([]iobuf.View, 1) + bufs[0] = iobuf.View{Bytes: make([]byte, 1)} eps := make([]Endpoint, 1) for _, fn := range fns { // The ReceiveFuncs must not access conn-related fields on StdNetBind // unguarded. Close() nils the conn-related fields resulting in a panic // if they violate the mutex. 
- fn(bufs, sizes, eps) + fn(bufs, eps) } } diff --git a/conn/bind_windows.go b/conn/bind_windows.go index 737b475e1..ba2b45bd4 100644 --- a/conn/bind_windows.go +++ b/conn/bind_windows.go @@ -18,6 +18,7 @@ import ( "golang.org/x/sys/windows" "github.com/tailscale/wireguard-go/conn/winrio" + "github.com/tailscale/wireguard-go/iobuf" ) const ( @@ -416,20 +417,22 @@ retry: return n, &ep, nil } -func (bind *WinRingBind) receiveIPv4(bufs [][]byte, sizes []int, eps []Endpoint) (int, error) { +func (bind *WinRingBind) receiveIPv4(bufs []iobuf.View, eps []Endpoint) (int, error) { bind.mu.RLock() defer bind.mu.RUnlock() - n, ep, err := bind.v4.Receive(bufs[0], &bind.isOpen) - sizes[0] = n + iobuf.EnsureAllocated(bufs[:1]) + n, ep, err := bind.v4.Receive(bufs[0].Bytes, &bind.isOpen) + bufs[0].Bytes = bufs[0].Bytes[:n] eps[0] = ep return 1, err } -func (bind *WinRingBind) receiveIPv6(bufs [][]byte, sizes []int, eps []Endpoint) (int, error) { +func (bind *WinRingBind) receiveIPv6(bufs []iobuf.View, eps []Endpoint) (int, error) { bind.mu.RLock() defer bind.mu.RUnlock() - n, ep, err := bind.v6.Receive(bufs[0], &bind.isOpen) - sizes[0] = n + iobuf.EnsureAllocated(bufs[:1]) + n, ep, err := bind.v6.Receive(bufs[0].Bytes, &bind.isOpen) + bufs[0].Bytes = bufs[0].Bytes[:n] eps[0] = ep return 1, err } diff --git a/conn/bindtest/bindtest.go b/conn/bindtest/bindtest.go index 741b776c4..e75749982 100644 --- a/conn/bindtest/bindtest.go +++ b/conn/bindtest/bindtest.go @@ -13,6 +13,7 @@ import ( "os" "github.com/tailscale/wireguard-go/conn" + "github.com/tailscale/wireguard-go/iobuf" ) type ChannelBind struct { @@ -94,13 +95,14 @@ func (c *ChannelBind) BatchSize() int { return 1 } func (c *ChannelBind) SetMark(mark uint32) error { return nil } func (c *ChannelBind) makeReceiveFunc(ch chan []byte) conn.ReceiveFunc { - return func(bufs [][]byte, sizes []int, eps []conn.Endpoint) (n int, err error) { + return func(bufs []iobuf.View, eps []conn.Endpoint) (n int, err error) { select { case <-c.closeSignal: return 0, net.ErrClosed case rx := <-ch: - copied := copy(bufs[0], rx) - sizes[0] = copied + iobuf.EnsureAllocated(bufs[:1]) + n := copy(bufs[0].Bytes, rx) + bufs[0].Bytes = bufs[0].Bytes[:n] eps[0] = c.target6 return 1, nil } diff --git a/conn/conn.go b/conn/conn.go index f1781614d..a0a8ffa36 100644 --- a/conn/conn.go +++ b/conn/conn.go @@ -13,19 +13,20 @@ import ( "reflect" "runtime" "strings" + + "github.com/tailscale/wireguard-go/iobuf" ) const ( IdealBatchSize = 128 // maximum number of packets handled per read and write ) -// A ReceiveFunc receives at least one packet from the network and writes them -// into packets. On a successful read it returns the number of elements of -// sizes, packets, and endpoints that should be evaluated. Some elements of -// sizes may be zero, and callers should ignore them. Callers must pass a sizes -// and eps slice with a length greater than or equal to the length of packets. -// These lengths must not exceed the length of the associated Bind.BatchSize(). -type ReceiveFunc func(packets [][]byte, sizes []int, eps []Endpoint) (n int, err error) +// A ReceiveFunc receives at least one packet from the network into bufs. +// On a successful read it returns the number of elements of bufs and eps +// that should be evaluated. Callers must pass an eps slice with a length +// greater than or equal to the length of bufs. These lengths must not +// exceed the length of the associated Bind.BatchSize(). 
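The sizes slice is gone: a View's own length now carries the packet size. For orientation, a minimal single-packet implementation of the new contract might look like the sketch below. This is illustrative only, not code from this patch; it assumes it sits in package conn alongside the iobuf package introduced later in this diff, and that StdNetEndpoint can be built from a netip.AddrPort as elsewhere in this package.

// exampleReceiveFunc is an illustrative sketch of the new ReceiveFunc shape.
func exampleReceiveFunc(udp *net.UDPConn) ReceiveFunc {
	return func(bufs []iobuf.View, eps []Endpoint) (int, error) {
		iobuf.EnsureAllocated(bufs[:1]) // backfill nil Views from the default pool
		n, addr, err := udp.ReadFromUDPAddrPort(bufs[0].Bytes)
		if err != nil {
			return 0, err
		}
		bufs[0].Bytes = bufs[0].Bytes[:n] // the View's length replaces sizes[0]
		eps[0] = &StdNetEndpoint{AddrPort: addr}
		return 1, nil
	}
}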
+type ReceiveFunc func(bufs []iobuf.View, eps []Endpoint) (n int, err error) // A Bind listens on a port for both IPv6 and IPv4 UDP traffic. // diff --git a/conn/conn_test.go b/conn/conn_test.go index c6194ee0c..a39ac8fe5 100644 --- a/conn/conn_test.go +++ b/conn/conn_test.go @@ -7,11 +7,13 @@ package conn import ( "testing" + + "github.com/tailscale/wireguard-go/iobuf" ) func TestPrettyName(t *testing.T) { var ( - recvFunc ReceiveFunc = func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { return } + recvFunc ReceiveFunc = func(bufs []iobuf.View, eps []Endpoint) (n int, err error) { return } ) const want = "TestPrettyName" diff --git a/device/channels.go b/device/channels.go index e526f6bb1..cfa0d0587 100644 --- a/device/channels.go +++ b/device/channels.go @@ -93,7 +93,7 @@ func (device *Device) flushInboundQueue(q *autodrainingInboundQueue) { case elemsContainer := <-q.c: elemsContainer.Lock() for _, elem := range elemsContainer.elems { - device.PutMessageBuffer(elem.buffer) + elem.buffer.Release() device.PutInboundElement(elem) } device.PutInboundElementsContainer(elemsContainer) @@ -126,7 +126,7 @@ func (device *Device) flushOutboundQueue(q *autodrainingOutboundQueue) { case elemsContainer := <-q.c: elemsContainer.Lock() for _, elem := range elemsContainer.elems { - device.PutMessageBuffer(elem.buffer) + elem.buffer.Release() device.PutOutboundElement(elem) } device.PutOutboundElementsContainer(elemsContainer) diff --git a/device/constants.go b/device/constants.go index 92c3bdea8..1fc691c88 100644 --- a/device/constants.go +++ b/device/constants.go @@ -7,6 +7,8 @@ package device import ( "time" + + "github.com/tailscale/wireguard-go/iobuf" ) /* Specification constants */ @@ -27,9 +29,9 @@ const ( ) const ( - MinMessageSize = MessageKeepaliveSize // minimum size of transport message (keepalive) - MaxMessageSize = MaxSegmentSize // maximum size of transport message - MaxContentSize = MaxSegmentSize - MessageTransportSize - MessageEncapsulatingTransportSize // maximum size of transport message content + MinMessageSize = MessageKeepaliveSize // minimum size of transport message (keepalive) + MaxMessageSize = iobuf.MaxReadSize // maximum size of transport message + MaxContentSize = iobuf.MaxReadSize - MessageTransportSize - MessageEncapsulatingTransportSize // maximum size of transport message content ) /* Implementation constants */ diff --git a/device/device.go b/device/device.go index 0e720f251..ba25e1136 100644 --- a/device/device.go +++ b/device/device.go @@ -17,6 +17,7 @@ import ( "github.com/tailscale/wireguard-go/ratelimiter" "github.com/tailscale/wireguard-go/rwcancel" "github.com/tailscale/wireguard-go/tun" + "github.com/tailscale/wireguard-go/waitpool" ) type Device struct { @@ -71,11 +72,10 @@ type Device struct { cookieChecker CookieChecker pool struct { - inboundElementsContainer *WaitPool - outboundElementsContainer *WaitPool - messageBuffers *WaitPool - inboundElements *WaitPool - outboundElements *WaitPool + inboundElementsContainer *waitpool.WaitPool + outboundElementsContainer *waitpool.WaitPool + inboundElements *waitpool.WaitPool + outboundElements *waitpool.WaitPool } queue struct { diff --git a/device/device_test.go b/device/device_test.go index e44342170..0c07a42b0 100644 --- a/device/device_test.go +++ b/device/device_test.go @@ -22,6 +22,7 @@ import ( "github.com/tailscale/wireguard-go/conn" "github.com/tailscale/wireguard-go/conn/bindtest" + "github.com/tailscale/wireguard-go/iobuf" "github.com/tailscale/wireguard-go/tun" 
"github.com/tailscale/wireguard-go/tun/tuntest" ) @@ -437,7 +438,7 @@ type fakeTUNDeviceSized struct { } func (t *fakeTUNDeviceSized) File() *os.File { return nil } -func (t *fakeTUNDeviceSized) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) { +func (t *fakeTUNDeviceSized) Read(bufs []iobuf.View, offset int) (n int, err error) { return 0, nil } func (t *fakeTUNDeviceSized) Write(bufs [][]byte, offset int) (int, error) { return 0, nil } diff --git a/device/pools.go b/device/pools.go index 55d2be7df..d18aa6fb8 100644 --- a/device/pools.go +++ b/device/pools.go @@ -7,61 +7,24 @@ package device import ( "sync" -) - -type WaitPool struct { - pool sync.Pool - cond sync.Cond - lock sync.Mutex - count uint32 // Get calls not yet Put back - max uint32 -} - -func NewWaitPool(max uint32, new func() any) *WaitPool { - p := &WaitPool{pool: sync.Pool{New: new}, max: max} - p.cond = sync.Cond{L: &p.lock} - return p -} - -func (p *WaitPool) Get() any { - if p.max != 0 { - p.lock.Lock() - for p.count >= p.max { - p.cond.Wait() - } - p.count++ - p.lock.Unlock() - } - return p.pool.Get() -} -func (p *WaitPool) Put(x any) { - p.pool.Put(x) - if p.max == 0 { - return - } - p.lock.Lock() - defer p.lock.Unlock() - p.count-- - p.cond.Signal() -} + "github.com/tailscale/wireguard-go/iobuf" + "github.com/tailscale/wireguard-go/waitpool" +) func (device *Device) PopulatePools() { - device.pool.inboundElementsContainer = NewWaitPool(PreallocatedBuffersPerPool, func() any { + device.pool.inboundElementsContainer = waitpool.New(iobuf.MaxPooledBuffers, func() any { s := make([]*QueueInboundElement, 0, device.BatchSize()) return &QueueInboundElementsContainer{elems: s} }) - device.pool.outboundElementsContainer = NewWaitPool(PreallocatedBuffersPerPool, func() any { + device.pool.outboundElementsContainer = waitpool.New(iobuf.MaxPooledBuffers, func() any { s := make([]*QueueOutboundElement, 0, device.BatchSize()) return &QueueOutboundElementsContainer{elems: s} }) - device.pool.messageBuffers = NewWaitPool(PreallocatedBuffersPerPool, func() any { - return new([MaxMessageSize]byte) - }) - device.pool.inboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() any { + device.pool.inboundElements = waitpool.New(iobuf.MaxPooledBuffers, func() any { return new(QueueInboundElement) }) - device.pool.outboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() any { + device.pool.outboundElements = waitpool.New(iobuf.MaxPooledBuffers, func() any { return new(QueueOutboundElement) }) } @@ -94,14 +57,6 @@ func (device *Device) PutOutboundElementsContainer(c *QueueOutboundElementsConta device.pool.outboundElementsContainer.Put(c) } -func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte { - return device.pool.messageBuffers.Get().(*[MaxMessageSize]byte) -} - -func (device *Device) PutMessageBuffer(msg *[MaxMessageSize]byte) { - device.pool.messageBuffers.Put(msg) -} - func (device *Device) GetInboundElement() *QueueInboundElement { return device.pool.inboundElements.Get().(*QueueInboundElement) } @@ -117,5 +72,6 @@ func (device *Device) GetOutboundElement() *QueueOutboundElement { func (device *Device) PutOutboundElement(elem *QueueOutboundElement) { elem.clearPointers() + elem.nonce = 0 device.pool.outboundElements.Put(elem) } diff --git a/device/queueconstants_android.go b/device/queueconstants_android.go index bab9625c4..8f564db55 100644 --- a/device/queueconstants_android.go +++ b/device/queueconstants_android.go @@ -10,10 +10,8 @@ import "github.com/tailscale/wireguard-go/conn" /* Reduce 
memory consumption for Android */ const ( - QueueStagedSize = conn.IdealBatchSize - QueueOutboundSize = 1024 - QueueInboundSize = 1024 - QueueHandshakeSize = 1024 - MaxSegmentSize = 2200 - PreallocatedBuffersPerPool = 4096 + QueueStagedSize = conn.IdealBatchSize + QueueOutboundSize = 1024 + QueueInboundSize = 1024 + QueueHandshakeSize = 1024 ) diff --git a/device/queueconstants_default.go b/device/queueconstants_default.go index 9749cb789..56b1ab995 100644 --- a/device/queueconstants_default.go +++ b/device/queueconstants_default.go @@ -10,10 +10,8 @@ package device import "github.com/tailscale/wireguard-go/conn" const ( - QueueStagedSize = conn.IdealBatchSize - QueueOutboundSize = 1024 - QueueInboundSize = 1024 - QueueHandshakeSize = 1024 - MaxSegmentSize = (1 << 16) - 1 // largest possible UDP datagram - PreallocatedBuffersPerPool = 0 // Disable and allow for infinite memory growth + QueueStagedSize = conn.IdealBatchSize + QueueOutboundSize = 1024 + QueueInboundSize = 1024 + QueueHandshakeSize = 1024 ) diff --git a/device/queueconstants_ios.go b/device/queueconstants_ios.go index acd3cec13..258e1eed3 100644 --- a/device/queueconstants_ios.go +++ b/device/queueconstants_ios.go @@ -11,11 +11,8 @@ package device // These are vars instead of consts, because heavier network extensions might want to reduce // them further. var ( - QueueStagedSize = 128 - QueueOutboundSize = 1024 - QueueInboundSize = 1024 - QueueHandshakeSize = 1024 - PreallocatedBuffersPerPool uint32 = 1024 + QueueStagedSize = 128 + QueueOutboundSize = 1024 + QueueInboundSize = 1024 + QueueHandshakeSize = 1024 ) - -const MaxSegmentSize = 1700 diff --git a/device/queueconstants_windows.go b/device/queueconstants_windows.go index 1eee32ba1..169c439de 100644 --- a/device/queueconstants_windows.go +++ b/device/queueconstants_windows.go @@ -6,10 +6,8 @@ package device const ( - QueueStagedSize = 128 - QueueOutboundSize = 1024 - QueueInboundSize = 1024 - QueueHandshakeSize = 1024 - MaxSegmentSize = 2048 - 32 // largest possible UDP datagram - PreallocatedBuffersPerPool = 0 // Disable and allow for infinite memory growth + QueueStagedSize = 128 + QueueOutboundSize = 1024 + QueueInboundSize = 1024 + QueueHandshakeSize = 1024 ) diff --git a/device/receive.go b/device/receive.go index 56cde1047..b15ec7bfb 100644 --- a/device/receive.go +++ b/device/receive.go @@ -14,6 +14,7 @@ import ( "time" "github.com/tailscale/wireguard-go/conn" + "github.com/tailscale/wireguard-go/iobuf" "golang.org/x/crypto/chacha20poly1305" "golang.org/x/net/ipv4" "golang.org/x/net/ipv6" @@ -23,11 +24,11 @@ type QueueHandshakeElement struct { msgType uint32 packet []byte endpoint conn.Endpoint - buffer *[MaxMessageSize]byte + buffer iobuf.View } type QueueInboundElement struct { - buffer *[MaxMessageSize]byte + buffer iobuf.View packet []byte counter uint64 keypair *Keypair @@ -44,7 +45,7 @@ type QueueInboundElementsContainer struct { // avoids accidentally keeping other objects around unnecessarily. // It also reduces the possible collateral damage from use-after-free bugs. 
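A property the receive path leans on here: the zero View is inert. Release on a zero View is a no-op (nil Recycler), and Claim zeroes its source, so ownership moves without double-release hazards. A small sketch of that behavior, assuming the iobuf package introduced later in this diff:

var v iobuf.View
v.Release() // safe no-op: the zero View has a nil Recycler

iobuf.Init(&v)     // attach a pooled backing array (see iobuf/raw.go below)
moved := v.Claim() // v is the zero View again; releasing v is now a no-op
moved.Release()    // the sole owner returns the backing to its pool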
func (elem *QueueInboundElement) clearPointers() { - elem.buffer = nil + elem.buffer = iobuf.View{} elem.packet = nil elem.keypair = nil elem.endpoint = nil @@ -84,31 +85,20 @@ func (device *Device) RoutineReceiveIncoming(maxBatchSize int, recv conn.Receive // receive datagrams until conn is closed var ( - bufsArrs = make([]*[MaxMessageSize]byte, maxBatchSize) - bufs = make([][]byte, maxBatchSize) + bufs = make([]iobuf.View, maxBatchSize) // nil entries; recv allocates err error - sizes = make([]int, maxBatchSize) count int endpoints = make([]conn.Endpoint, maxBatchSize) deathSpiral int elemsByPeer = make(map[*Peer]*QueueInboundElementsContainer, maxBatchSize) ) - for i := range bufsArrs { - bufsArrs[i] = device.GetMessageBuffer() - bufs[i] = bufsArrs[i][:] - } - defer func() { - for i := 0; i < maxBatchSize; i++ { - if bufsArrs[i] != nil { - device.PutMessageBuffer(bufsArrs[i]) - } - } + iobuf.ReleaseAll(bufs) }() for { - count, err = recv(bufs, sizes, endpoints) + count, err = recv(bufs, endpoints) if err != nil { if errors.Is(err, net.ErrClosed) { return @@ -127,14 +117,14 @@ func (device *Device) RoutineReceiveIncoming(maxBatchSize int, recv conn.Receive deathSpiral = 0 // handle each packet in the batch - for i, size := range sizes[:count] { - if size < MinMessageSize { + for i := 0; i < count; i++ { + if len(bufs[i].Bytes) < MinMessageSize { continue } // check size of packet - packet := bufsArrs[i][:size] + packet := bufs[i].Bytes msgType := binary.LittleEndian.Uint32(packet[:4]) switch msgType { @@ -170,7 +160,7 @@ func (device *Device) RoutineReceiveIncoming(maxBatchSize int, recv conn.Receive peer := value.peer elem := device.GetInboundElement() elem.packet = packet - elem.buffer = bufsArrs[i] + elem.buffer = bufs[i].Claim() elem.keypair = keypair elem.endpoint = endpoints[i] elem.counter = 0 @@ -182,8 +172,6 @@ func (device *Device) RoutineReceiveIncoming(maxBatchSize int, recv conn.Receive elemsByPeer[peer] = elemsForPeer } elemsForPeer.elems = append(elemsForPeer.elems, elem) - bufsArrs[i] = device.GetMessageBuffer() - bufs[i] = bufsArrs[i][:] continue // otherwise it is a fixed size & handshake related packet @@ -211,22 +199,21 @@ func (device *Device) RoutineReceiveIncoming(maxBatchSize int, recv conn.Receive select { case device.queue.handshake.c <- QueueHandshakeElement{ msgType: msgType, - buffer: bufsArrs[i], + buffer: bufs[i].Claim(), packet: packet, endpoint: endpoints[i], }: - bufsArrs[i] = device.GetMessageBuffer() - bufs[i] = bufsArrs[i][:] default: } } + iobuf.ReleaseAll(bufs[:count]) // release unclaimed for peer, elemsContainer := range elemsByPeer { if peer.isRunning.Load() { peer.queue.inbound.c <- elemsContainer device.queue.decryption.c <- elemsContainer } else { for _, elem := range elemsContainer.elems { - device.PutMessageBuffer(elem.buffer) + elem.buffer.Release() device.PutInboundElement(elem) } device.PutInboundElementsContainer(elemsContainer) @@ -423,7 +410,7 @@ func (device *Device) RoutineHandshake(id int) { peer.SendKeepalive() } skip: - device.PutMessageBuffer(elem.buffer) + elem.buffer.Release() } } @@ -435,7 +422,8 @@ func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) { }() device.log.Verbosef("%v - Routine: sequential receiver - started", peer) - bufs := make([][]byte, 0, maxBatchSize) + toWrite := make([]iobuf.View, 0, maxBatchSize) // reference to transferred buffers, released after batch write + bufs := make([][]byte, 0, maxBatchSize) // slices of the above buffers, passed to TUN device for elemsContainer := range 
peer.queue.inbound.c {
 		if elemsContainer == nil {
@@ -513,7 +501,8 @@ func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) {
 				continue
 			}
 
-			bufs = append(bufs, elem.buffer[:MessageTransportOffsetContent+len(elem.packet)])
+			bufs = append(bufs, elem.buffer.Bytes[:MessageTransportOffsetContent+len(elem.packet)])
+			toWrite = append(toWrite, elem.buffer.Claim())
 		}
 
 		peer.rxBytes.Add(rxBytesLen)
@@ -531,12 +520,15 @@ func (peer *Peer) RoutineSequentialReceiver(maxBatchSize int) {
 			if err != nil && !device.isClosed() {
 				device.log.Errorf("Failed to write packets to TUN device: %v", err)
 			}
+		}
 		for _, elem := range elemsContainer.elems {
-			device.PutMessageBuffer(elem.buffer)
+			elem.buffer.Release()
 			device.PutInboundElement(elem)
 		}
 		bufs = bufs[:0]
+		iobuf.ReleaseAll(toWrite) // release the Views claimed above, now that the batch write is done
+		toWrite = toWrite[:0]
 		device.PutInboundElementsContainer(elemsContainer)
 	}
}
diff --git a/device/send.go b/device/send.go
index 89269fc07..4812bc6c7 100644
--- a/device/send.go
+++ b/device/send.go
@@ -15,6 +15,7 @@ import (
 	"time"
 
 	"github.com/tailscale/wireguard-go/conn"
+	"github.com/tailscale/wireguard-go/iobuf"
 	"github.com/tailscale/wireguard-go/tun"
 	"golang.org/x/crypto/chacha20poly1305"
 	"golang.org/x/net/ipv4"
@@ -46,8 +47,8 @@ import (
 */
 
 type QueueOutboundElement struct {
-	buffer  *[MaxMessageSize]byte // slice holding the packet data
-	// packet is always a slice of "buffer". The starting offset in buffer
+	buffer iobuf.View
+	// packet is always a slice of buffer.Bytes. The starting offset in buffer.Bytes
 	// is either:
 	// a) MessageEncapsulatingTransportSize+MessageTransportHeaderSize (plaintext)
 	// b) 0 (post-encryption)
@@ -62,20 +63,12 @@ type QueueOutboundElementsContainer struct {
 	elems []*QueueOutboundElement
 }
 
-func (device *Device) NewOutboundElement() *QueueOutboundElement {
-	elem := device.GetOutboundElement()
-	elem.buffer = device.GetMessageBuffer()
-	elem.nonce = 0
-	// keypair and peer were cleared (if necessary) by clearPointers.
-	return elem
-}
-
 // clearPointers clears elem fields that contain pointers.
 // This makes the garbage collector's life easier and
 // avoids accidentally keeping other objects around unnecessarily.
 // It also reduces the possible collateral damage from use-after-free bugs.
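The same ownership discipline repeats on every batch path in this patch: claim the Views you keep, then ReleaseAll the rest. Distilled into one helper as a hedged sketch; keep and enqueue are hypothetical stand-ins for the real pipeline, and conn/iobuf are the packages from this diff:

// claimOrRelease sketches the batch ownership dance used above: recv fills
// bufs, Claim zeroes the entries we keep, ReleaseAll returns the remainder.
func claimOrRelease(recv conn.ReceiveFunc, bufs []iobuf.View, eps []conn.Endpoint,
	keep func([]byte) bool, enqueue func(iobuf.View)) error {
	n, err := recv(bufs, eps)
	if err != nil {
		return err
	}
	for i := 0; i < n; i++ {
		if keep(bufs[i].Bytes) {
			enqueue(bufs[i].Claim()) // bufs[i] is now the zero View
		}
	}
	iobuf.ReleaseAll(bufs[:n]) // no-op for the claimed (zeroed) entries
	return nil
}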
func (elem *QueueOutboundElement) clearPointers() { - elem.buffer = nil + elem.buffer = iobuf.View{} elem.packet = nil elem.keypair = nil elem.peer = nil @@ -85,14 +78,15 @@ func (elem *QueueOutboundElement) clearPointers() { */ func (peer *Peer) SendKeepalive() { if len(peer.queue.staged) == 0 && peer.isRunning.Load() { - elem := peer.device.NewOutboundElement() + elem := peer.device.GetOutboundElement() + elem.buffer = iobuf.View{Bytes: make([]byte, MessageEncapsulatingTransportSize+MessageTransportSize)} elemsContainer := peer.device.GetOutboundElementsContainer() elemsContainer.elems = append(elemsContainer.elems, elem) select { case peer.queue.staged <- elemsContainer: peer.device.log.Verbosef("%v - Sending keepalive packet", peer) default: - peer.device.PutMessageBuffer(elem.buffer) + elem.buffer.Release() peer.device.PutOutboundElement(elem) peer.device.PutOutboundElementsContainer(elemsContainer) } @@ -128,15 +122,16 @@ func (peer *Peer) SendHandshakeInitiation(isRetry bool) error { return err } - buf := make([]byte, MessageEncapsulatingTransportSize+MessageInitiationSize) - packet := buf[MessageEncapsulatingTransportSize:] + buf := iobuf.View{Bytes: make([]byte, MessageEncapsulatingTransportSize+MessageInitiationSize)} + defer buf.Release() + packet := buf.Bytes[MessageEncapsulatingTransportSize:] _ = msg.marshal(packet) peer.cookieGenerator.AddMacs(packet) peer.timersAnyAuthenticatedPacketTraversal() peer.timersAnyAuthenticatedPacketSent() - err = peer.SendBuffers([][]byte{buf}) + err = peer.SendBuffers([][]byte{buf.Bytes}) if err != nil { peer.device.log.Errorf("%v - Failed to send handshake initiation: %v", peer, err) } @@ -158,8 +153,9 @@ func (peer *Peer) SendHandshakeResponse() error { return err } - buf := make([]byte, MessageEncapsulatingTransportSize+MessageResponseSize) - packet := buf[MessageEncapsulatingTransportSize:] + buf := iobuf.View{Bytes: make([]byte, MessageEncapsulatingTransportSize+MessageResponseSize)} + defer buf.Release() + packet := buf.Bytes[MessageEncapsulatingTransportSize:] _ = response.marshal(packet) peer.cookieGenerator.AddMacs(packet) @@ -174,7 +170,7 @@ func (peer *Peer) SendHandshakeResponse() error { peer.timersAnyAuthenticatedPacketSent() // TODO: allocation could be avoided - err = peer.SendBuffers([][]byte{buf}) + err = peer.SendBuffers([][]byte{buf.Bytes}) if err != nil { peer.device.log.Errorf("%v - Failed to send handshake response: %v", peer, err) } @@ -191,11 +187,12 @@ func (device *Device) SendHandshakeCookie(initiatingElem *QueueHandshakeElement) return err } - buf := make([]byte, MessageEncapsulatingTransportSize+MessageCookieReplySize) - packet := buf[MessageEncapsulatingTransportSize:] + buf := iobuf.View{Bytes: make([]byte, MessageEncapsulatingTransportSize+MessageCookieReplySize)} + defer buf.Release() + packet := buf.Bytes[MessageEncapsulatingTransportSize:] _ = reply.marshal(packet) - // TODO: allocation could be avoided - device.net.bind.Send([][]byte{buf}, initiatingElem.endpoint, MessageEncapsulatingTransportSize) + + device.net.bind.Send([][]byte{buf.Bytes}, initiatingElem.endpoint, MessageEncapsulatingTransportSize) return nil } @@ -223,57 +220,43 @@ func (device *Device) RoutineReadFromTUN() { var ( batchSize = device.BatchSize() readErr error - elems = make([]*QueueOutboundElement, batchSize) - bufs = make([][]byte, batchSize) + bufs = make([]iobuf.View, batchSize) elemsByPeer = make(map[*Peer]*QueueOutboundElementsContainer, batchSize) count = 0 - sizes = make([]int, batchSize) offset = 
MessageEncapsulatingTransportSize + MessageTransportHeaderSize ) - for i := range elems { - elems[i] = device.NewOutboundElement() - bufs[i] = elems[i].buffer[:] - } - defer func() { - for _, elem := range elems { - if elem != nil { - device.PutMessageBuffer(elem.buffer) - device.PutOutboundElement(elem) - } - } + iobuf.ReleaseAll(bufs) }() for { - // read packets - count, readErr = device.tun.device.Read(bufs, sizes, offset) + count, readErr = device.tun.device.Read(bufs, offset) + for i := 0; i < count; i++ { - if sizes[i] < 1 { + packet := bufs[i].Bytes[offset:] + if len(packet) < 1 { continue } - elem := elems[i] - elem.packet = bufs[i][offset : offset+sizes[i]] - // lookup peer var peer *Peer - switch elem.packet[0] >> 4 { + switch packet[0] >> 4 { case 4: - if len(elem.packet) < ipv4.HeaderLen { + if len(packet) < ipv4.HeaderLen { continue } - src := netip.AddrFrom4([4]byte(elem.packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len])) - dst := netip.AddrFrom4([4]byte(elem.packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len])) - peer = device.allowedips.LookupFromPacket(src, dst, elem.packet) + src := netip.AddrFrom4([4]byte(packet[IPv4offsetSrc : IPv4offsetSrc+net.IPv4len])) + dst := netip.AddrFrom4([4]byte(packet[IPv4offsetDst : IPv4offsetDst+net.IPv4len])) + peer = device.allowedips.LookupFromPacket(src, dst, packet) case 6: - if len(elem.packet) < ipv6.HeaderLen { + if len(packet) < ipv6.HeaderLen { continue } - src := netip.AddrFrom16([16]byte(elem.packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len])) - dst := netip.AddrFrom16([16]byte(elem.packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len])) - peer = device.allowedips.LookupFromPacket(src, dst, elem.packet) + src := netip.AddrFrom16([16]byte(packet[IPv6offsetSrc : IPv6offsetSrc+net.IPv6len])) + dst := netip.AddrFrom16([16]byte(packet[IPv6offsetDst : IPv6offsetDst+net.IPv6len])) + peer = device.allowedips.LookupFromPacket(src, dst, packet) default: device.log.Verbosef("Received packet with unknown IP version") @@ -282,15 +265,19 @@ func (device *Device) RoutineReadFromTUN() { if peer == nil { continue } + + elem := device.GetOutboundElement() + elem.packet = packet + elem.buffer = bufs[i].Claim() + elemsForPeer, ok := elemsByPeer[peer] if !ok { elemsForPeer = device.GetOutboundElementsContainer() elemsByPeer[peer] = elemsForPeer } elemsForPeer.elems = append(elemsForPeer.elems, elem) - elems[i] = device.NewOutboundElement() - bufs[i] = elems[i].buffer[:] } + iobuf.ReleaseAll(bufs[:count]) // release unclaimed for peer, elemsForPeer := range elemsByPeer { if peer.isRunning.Load() { @@ -298,7 +285,7 @@ func (device *Device) RoutineReadFromTUN() { peer.SendStagedPackets() } else { for _, elem := range elemsForPeer.elems { - device.PutMessageBuffer(elem.buffer) + elem.buffer.Release() device.PutOutboundElement(elem) } device.PutOutboundElementsContainer(elemsForPeer) @@ -335,7 +322,7 @@ func (peer *Peer) StagePackets(elems *QueueOutboundElementsContainer) { select { case tooOld := <-peer.queue.staged: for _, elem := range tooOld.elems { - peer.device.PutMessageBuffer(elem.buffer) + elem.buffer.Release() peer.device.PutOutboundElement(elem) } peer.device.PutOutboundElementsContainer(tooOld) @@ -396,7 +383,7 @@ top: peer.device.queue.encryption.c <- elemsContainer } else { for _, elem := range elemsContainer.elems { - peer.device.PutMessageBuffer(elem.buffer) + elem.buffer.Release() peer.device.PutOutboundElement(elem) } peer.device.PutOutboundElementsContainer(elemsContainer) @@ -416,7 +403,7 @@ func (peer *Peer) FlushStagedPackets() { select { 
case elemsContainer := <-peer.queue.staged: for _, elem := range elemsContainer.elems { - peer.device.PutMessageBuffer(elem.buffer) + elem.buffer.Release() peer.device.PutOutboundElement(elem) } peer.device.PutOutboundElementsContainer(elemsContainer) @@ -456,7 +443,7 @@ func (device *Device) RoutineEncryption(id int) { for elemsContainer := range device.queue.encryption.c { for _, elem := range elemsContainer.elems { // populate header fields - header := elem.buffer[MessageEncapsulatingTransportSize : MessageEncapsulatingTransportSize+MessageTransportHeaderSize] + header := elem.buffer.Bytes[MessageEncapsulatingTransportSize : MessageEncapsulatingTransportSize+MessageTransportHeaderSize] fieldType := header[0:4] fieldReceiver := header[4:8] @@ -481,7 +468,8 @@ func (device *Device) RoutineEncryption(id int) { ) // re-slice packet to include encapsulating transport space - elem.packet = elem.buffer[:MessageEncapsulatingTransportSize+len(elem.packet)] + elem.buffer.Bytes = elem.buffer.Bytes[:MessageEncapsulatingTransportSize+len(elem.packet)] + elem.packet = elem.buffer.Bytes } elemsContainer.Unlock() } @@ -495,10 +483,12 @@ func (peer *Peer) RoutineSequentialSender(maxBatchSize int) { }() device.log.Verbosef("%v - Routine: sequential sender - started", peer) - bufs := make([][]byte, 0, maxBatchSize) + toWrite := make([]iobuf.View, 0, maxBatchSize) // reference to transferred buffers, released after batch write + bufs := make([][]byte, 0, maxBatchSize) // slices of the above buffers, passed to Bind for elemsContainer := range peer.queue.outbound.c { bufs = bufs[:0] + toWrite = toWrite[:0] if elemsContainer == nil { return } @@ -511,7 +501,7 @@ func (peer *Peer) RoutineSequentialSender(maxBatchSize int) { // that we never accidentally keep timers alive longer than necessary. elemsContainer.Lock() for _, elem := range elemsContainer.elems { - device.PutMessageBuffer(elem.buffer) + elem.buffer.Release() device.PutOutboundElement(elem) } device.PutOutboundElementsContainer(elemsContainer) @@ -524,17 +514,19 @@ func (peer *Peer) RoutineSequentialSender(maxBatchSize int) { dataSent = true } bufs = append(bufs, elem.packet) + toWrite = append(toWrite, elem.buffer.Claim()) } peer.timersAnyAuthenticatedPacketTraversal() peer.timersAnyAuthenticatedPacketSent() err := peer.SendBuffers(bufs) + iobuf.ReleaseAll(toWrite) if dataSent { peer.timersDataSent() } for _, elem := range elemsContainer.elems { - device.PutMessageBuffer(elem.buffer) + elem.buffer.Release() device.PutOutboundElement(elem) } device.PutOutboundElementsContainer(elemsContainer) diff --git a/iobuf/constants.go b/iobuf/constants.go new file mode 100644 index 000000000..99a2ec005 --- /dev/null +++ b/iobuf/constants.go @@ -0,0 +1,10 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +package iobuf + +const ( + MaxBufferSize = MaxReadSize // the largest buffer that I/O may attempt to read or write. +) diff --git a/iobuf/constants_android.go b/iobuf/constants_android.go new file mode 100644 index 000000000..bad0d014e --- /dev/null +++ b/iobuf/constants_android.go @@ -0,0 +1,13 @@ +//go:build android + +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. 
+ */
+
+package iobuf
+
+const (
+	MaxReadSize      = 2200
+	MaxPooledBuffers = 4096
+)
diff --git a/iobuf/constants_default.go b/iobuf/constants_default.go
new file mode 100644
index 000000000..ff656532f
--- /dev/null
+++ b/iobuf/constants_default.go
@@ -0,0 +1,13 @@
+//go:build !android && !ios && !windows
+
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package iobuf
+
+const (
+	MaxReadSize      = (1 << 16) - 1
+	MaxPooledBuffers = 0 // Disable and allow for infinite memory growth
+)
diff --git a/iobuf/constants_ios.go b/iobuf/constants_ios.go
new file mode 100644
index 000000000..96a0d39a3
--- /dev/null
+++ b/iobuf/constants_ios.go
@@ -0,0 +1,14 @@
+//go:build ios
+
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package iobuf
+
+var (
+	MaxPooledBuffers = 1024 // Var to allow further reduction. Recreate [DefaultRawPool] if changed.
+)
+
+const MaxReadSize = 1700
diff --git a/iobuf/constants_windows.go b/iobuf/constants_windows.go
new file mode 100644
index 000000000..d1000f023
--- /dev/null
+++ b/iobuf/constants_windows.go
@@ -0,0 +1,13 @@
+//go:build windows
+
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package iobuf
+
+const (
+	MaxReadSize      = 2048 - 32
+	MaxPooledBuffers = 0 // Disable and allow for infinite memory growth
+)
diff --git a/iobuf/raw.go b/iobuf/raw.go
new file mode 100644
index 000000000..c4ab50a86
--- /dev/null
+++ b/iobuf/raw.go
@@ -0,0 +1,62 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+package iobuf
+
+import (
+	"unsafe"
+
+	"github.com/tailscale/wireguard-go/waitpool"
+)
+
+var _ Recycler = (*RawPool)(nil)
+
+// Raw is the fundamental byte array.
+type Raw [MaxBufferSize]byte
+
+// RawPool wraps a [waitpool.WaitPool] of [Raw] buffers
+// to configure their return via [RawPool.Recycle].
+type RawPool struct {
+	*waitpool.WaitPool
+}
+
+func (p *RawPool) Get() *Raw {
+	return p.WaitPool.Get().(*Raw)
+}
+
+// Recycle returns the buffer to the pool.
+//
+//go:nocheckptr
+func (p *RawPool) Recycle(ptr RecycleHandle) {
+	arr := (*Raw)(unsafe.Pointer(ptr)) //nolint:govet
+	p.Put(arr)
+}
+
+func NewRawPool(maxBuffers int) *RawPool {
+	return &RawPool{waitpool.New(maxBuffers, func() any {
+		return new(Raw)
+	})}
+}
+
+// DefaultRawPool backs the package-level [Init] and [EnsureAllocated].
+var DefaultRawPool = NewRawPool(MaxPooledBuffers)
+
+// EnsureAllocated fills zero-valued Views from the [DefaultRawPool].
+func EnsureAllocated(bufs []View) {
+	for i := range bufs {
+		if bufs[i].Bytes == nil {
+			Init(&bufs[i])
+		}
+	}
+}
+
+// Init initializes a [View] in-place with a fresh backing from the pool.
+// Sets Bytes to the full backing array.
+func Init(b *View) {
+	arr := DefaultRawPool.Get()
+	b.Recycler = DefaultRawPool
+	b.Handle = RecycleHandle(unsafe.Pointer(arr))
+	b.Bytes = arr[:]
+}
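A View does not require a RawPool; anything satisfying Recycler can own the backing memory. A self-contained, runnable sketch using the RecycleFunc adapter from view.go below (illustrative only, assuming the iobuf package as introduced in this diff):

package main

import (
	"fmt"

	"github.com/tailscale/wireguard-go/iobuf"
)

func main() {
	// Count releases instead of pooling, via the RecycleFunc adapter.
	var recycled int
	counter := iobuf.RecycleFunc(func(iobuf.RecycleHandle) { recycled++ })

	v := iobuf.View{
		Recycler: counter,
		Bytes:    make([]byte, iobuf.MaxBufferSize),
	}
	v.Release()           // zeroes v and invokes the recycler
	fmt.Println(recycled) // 1
}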
diff --git a/iobuf/view.go b/iobuf/view.go
new file mode 100644
index 000000000..a6a81ba06
--- /dev/null
+++ b/iobuf/view.go
@@ -0,0 +1,57 @@
+/* SPDX-License-Identifier: MIT
+ *
+ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved.
+ */
+
+// Package iobuf provides pooled packet buffers for the I/O pipeline.
+// Each [View] carries one packet and a recycle function that returns
+// its backing storage to the originating pool on [View.Release].
+package iobuf
+
+// RecycleHandle is an opaque reference to a [View]'s backing array.
+type RecycleHandle uintptr
+
+// Recycler returns a backing array for reuse.
+// The argument is the Handle of the [View] being released.
+type Recycler interface {
+	Recycle(RecycleHandle)
+}
+
+// RecycleFunc is a function adapter for Recycler.
+type RecycleFunc func(RecycleHandle)
+
+func (f RecycleFunc) Recycle(ptr RecycleHandle) { f(ptr) }
+
+// View is the packet envelope. Meant to be a value type,
+// allocated once per goroutine and reused across read cycles.
+type View struct {
+	Recycler Recycler      // nil for external/unmanaged Views.
+	Handle   RecycleHandle // zero for external/unmanaged Views.
+
+	// Bytes holds the bounded packet data. Cut from the backing array,
+	// it may be re-sliced by the caller. Do not append to this slice.
+	// Nil for uninitialized Views.
+	Bytes []byte
+}
+
+// Release returns the backing data to its source and zeros the View.
+func (b *View) Release() {
+	if b.Recycler != nil {
+		b.Recycler.Recycle(b.Handle)
+	}
+	*b = View{}
+}
+
+// Claim transfers ownership: it returns a copy of the View and zeros the source.
+func (b *View) Claim() View {
+	c := *b
+	*b = View{}
+	return c
+}
+
+// ReleaseAll releases each View in the slice.
+func ReleaseAll(bufs []View) {
+	for i := range bufs {
+		bufs[i].Release()
+	}
+}
diff --git a/tun/netstack/tun.go b/tun/netstack/tun.go
index d8e70bb03..a0f125a4e 100644
--- a/tun/netstack/tun.go
+++ b/tun/netstack/tun.go
@@ -22,6 +22,7 @@ import (
 	"syscall"
 	"time"
 
+	"github.com/tailscale/wireguard-go/iobuf"
 	"github.com/tailscale/wireguard-go/tun"
 
 	"golang.org/x/net/dns/dnsmessage"
@@ -119,17 +120,19 @@ func (tun *netTun) Events() <-chan tun.Event {
 	return tun.events
 }
 
-func (tun *netTun) Read(buf [][]byte, sizes []int, offset int) (int, error) {
+func (tun *netTun) Read(bufs []iobuf.View, offset int) (int, error) {
 	view, ok := <-tun.incomingPacket
 	if !ok {
 		return 0, os.ErrClosed
 	}
-
-	n, err := view.Read(buf[0][offset:])
+	// TODO: If not for the offset, we could use view.AsSlice() and wrap view.Release() in an [iobuf.Recycler].
+	// TODO: Allocate a view.Size() buffer.
+	iobuf.EnsureAllocated(bufs[:1])
+	n, err := view.Read(bufs[0].Bytes[offset:])
 	if err != nil {
 		return 0, err
 	}
-	sizes[0] = n
+	bufs[0].Bytes = bufs[0].Bytes[:offset+n]
 	return 1, nil
 }
diff --git a/tun/offload.go b/tun/offload.go
index 6db437c34..a2cb15e27 100644
--- a/tun/offload.go
+++ b/tun/offload.go
@@ -3,6 +3,8 @@
 import (
 	"encoding/binary"
 	"fmt"
+
+	"github.com/tailscale/wireguard-go/iobuf"
 )
 
 // GSOType represents the type of segmentation offload.
@@ -73,15 +75,15 @@ const (
 	ipProtoUDP = 17
 )
 
-// GSOSplit splits packets from 'in' into outBufs[][outOffset:], writing
-// the size of each element into sizes. It returns the number of buffers
+// GSOSplit splits packets from 'in' into one or more entries in outBufs, writing
+// each output packet to outBufs[i].Bytes starting at outOffset. It returns the number of buffers
 // populated, and/or an error. Callers may pass an 'in' slice that overlaps with
 // the first element of outBuffers, i.e. &in[0] may be equal to
-// &outBufs[0][outOffset]. GSONone is a valid options.GSOType regardless of the
+// &outBufs[0].Bytes[outOffset]. GSONone is a valid options.GSOType regardless of the
 // value of options.NeedsCsum. Length of each outBufs element must be greater
 // than or equal to the length of 'in', otherwise output may be silently
 // truncated.
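Under the new signature the caller allocates the output Views up front and reads each packet's length from the View itself. A hedged sketch of a call site in package tun; coalesced, opts, hdrLen, and handle are assumptions for illustration, not code from this patch:

// Splitting one GSO super-packet into per-MTU Views — illustrative only.
out := make([]iobuf.View, conn.IdealBatchSize)
iobuf.EnsureAllocated(out) // GSOSplit writes into out[i].Bytes; backing must exist
n, err := GSOSplit(coalesced, opts, out, hdrLen)
if err != nil {
	return err
}
for i := 0; i < n; i++ {
	// len(out[i].Bytes) now encodes what sizes[i] used to carry.
	handle(out[i].Bytes[hdrLen:])
}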
-func GSOSplit(in []byte, options GSOOptions, outBufs [][]byte, sizes []int, outOffset int) (int, error) { +func GSOSplit(in []byte, options GSOOptions, outBufs []iobuf.View, outOffset int) (int, error) { cSumAt := int(options.CsumStart) + int(options.CsumOffset) if cSumAt+1 >= len(in) { return 0, fmt.Errorf("end of checksum offset (%d) exceeds packet length (%d)", cSumAt+1, len(in)) @@ -94,8 +96,8 @@ func GSOSplit(in []byte, options GSOOptions, outBufs [][]byte, sizes []int, outO // Handle the conditions where we are copying a single element to outBuffs. payloadLen := len(in) - int(options.HdrLen) if options.GSOType == GSONone || payloadLen < int(options.GSOSize) { - if len(in) > len(outBufs[0][outOffset:]) { - return 0, fmt.Errorf("length of packet (%d) exceeds output element length (%d)", len(in), len(outBufs[0][outOffset:])) + if len(in) > len(outBufs[0].Bytes[outOffset:]) { + return 0, fmt.Errorf("length of packet (%d) exceeds output element length (%d)", len(in), len(outBufs[0].Bytes[outOffset:])) } if options.NeedsCsum { // The initial value at the checksum offset should be summed with @@ -104,7 +106,8 @@ func GSOSplit(in []byte, options GSOOptions, outBufs [][]byte, sizes []int, outO in[cSumAt], in[cSumAt+1] = 0, 0 binary.BigEndian.PutUint16(in[cSumAt:], ^Checksum(in[options.CsumStart:], initial)) } - sizes[0] = copy(outBufs[0][outOffset:], in) + n := copy(outBufs[0].Bytes[outOffset:], in) + outBufs[0].Bytes = outBufs[0].Bytes[:outOffset+n] return 1, nil } @@ -164,8 +167,8 @@ func GSOSplit(in []byte, options GSOOptions, outBufs [][]byte, sizes []int, outO } segmentDataLen := nextSegmentEnd - nextSegmentDataAt totalLen := int(options.HdrLen) + segmentDataLen - sizes[i] = totalLen - out := outBufs[i][outOffset:] + outBufs[i].Bytes = outBufs[i].Bytes[:outOffset+totalLen] + out := outBufs[i].Bytes[outOffset:] copy(out, in[:iphLen]) if ipVersion == 4 { diff --git a/tun/offload_linux.go b/tun/offload_linux.go index fb6ac5b94..9d0c9e707 100644 --- a/tun/offload_linux.go +++ b/tun/offload_linux.go @@ -11,6 +11,7 @@ import ( "errors" "fmt" "io" + "slices" "unsafe" "github.com/tailscale/wireguard-go/conn" @@ -72,8 +73,63 @@ const ( // virtioNetHdrLen is the length in bytes of virtioNetHdr. This matches the // shape of the C ABI for its kernel counterpart -- sizeof(virtio_net_hdr). virtioNetHdrLen = int(unsafe.Sizeof(virtioNetHdr{})) + + // Vector layout: [virtioHdr | headPacket | coalescedPayloads...] + virtioNetHdrIdx = 0 + emptyVectorLen = 1 + headPacketIdx = 1 + singlePacketLen = 2 + coalescedPacketsIdx = 2 + + maxScatterGatherFragments = 1024 // Limited by UIO_MAXIOV of 1024. ) +// groToWrite holds the write-ordered scatter-gather IO vectors for +// writev. +type groToWrite struct { + // Each iov is a [][]byte: + // - iovs[i][0] is the pre-allocated virtio header + // - iovs[i][1] is the head packet with transport headers + // - iovs[i][2:] are coalesced payload fragments + iovs [][][]byte + allocated int +} + +func newGROToWrite() groToWrite { + wi := groToWrite{ + iovs: make([][][]byte, 0, conn.IdealBatchSize), + } + for range cap(wi.iovs) { + _, _ = wi.next() // pre-allocate virtio headers + } + wi.iovs = wi.iovs[:0] + return wi +} + +// next extends iovs by one, reusing the pre-allocated backing. +// Returns the index and a pointer to the new item's [][]byte. +// +// Do not use the returned pointer after the next call to next. 
+func (w *groToWrite) next() (int, *[][]byte) { + n := len(w.iovs) + if n < w.allocated { + w.iovs = w.iovs[:n+1] + } else { + iov := make([][]byte, 1, conn.IdealBatchSize) + iov[virtioNetHdrIdx] = make([]byte, virtioNetHdrLen) + w.iovs = append(w.iovs, iov) + w.allocated++ + } + return n, &w.iovs[n] +} + +func (w *groToWrite) reset() { + for i := range w.iovs { + w.iovs[i] = w.iovs[i][:emptyVectorLen] // keep virtio header alloc + } + w.iovs = w.iovs[:0] +} + // tcpFlowKey represents the key for a TCP flow. type tcpFlowKey struct { srcAddr, dstAddr [16]byte @@ -114,28 +170,31 @@ func newTCPFlowKey(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset int) tcp // lookupOrInsert looks up a flow for the provided packet and metadata, // returning the packets found for the flow, or inserting a new one if none // is found. -func (t *tcpGROTable) lookupOrInsert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex int) ([]tcpGROItem, bool) { +func (t *tcpGROTable) lookupOrInsert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen int, wi *groToWrite) ([]tcpGROItem, bool) { key := newTCPFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset) items, ok := t.itemsByFlow[key] if ok { return items, ok } // TODO: insert() performs another map lookup. This could be rearranged to avoid. - t.insert(pkt, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex) + t.insert(pkt, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, wi) return nil, false } // insert an item in the table for the provided packet and packet metadata. -func (t *tcpGROTable) insert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen, bufsIndex int) { +func (t *tcpGROTable) insert(pkt []byte, srcAddrOffset, dstAddrOffset, tcphOffset, tcphLen int, wi *groToWrite) { key := newTCPFlowKey(pkt, srcAddrOffset, dstAddrOffset, tcphOffset) + idx, iov := wi.next() + *iov = append(*iov, pkt) item := tcpGROItem{ - key: key, - bufsIndex: uint16(bufsIndex), - gsoSize: uint16(len(pkt[tcphOffset+tcphLen:])), - iphLen: uint8(tcphOffset), - tcphLen: uint8(tcphLen), - sentSeq: binary.BigEndian.Uint32(pkt[tcphOffset+4:]), - pshSet: pkt[tcphOffset+tcpFlagsOffset]&tcpFlagPSH != 0, + key: key, + outputIdx: uint16(idx), + gsoSize: uint16(len(pkt[tcphOffset+tcphLen:])), + iphLen: uint8(tcphOffset), + tcphLen: uint8(tcphLen), + sentSeq: binary.BigEndian.Uint32(pkt[tcphOffset+4:]), + pshSet: pkt[tcphOffset+tcpFlagsOffset]&tcpFlagPSH != 0, + payloadLen: uint16(len(pkt[tcphOffset+tcphLen:])), } items, ok := t.itemsByFlow[key] if !ok { @@ -159,14 +218,14 @@ func (t *tcpGROTable) deleteAt(key tcpFlowKey, i int) { // tcpGROItem represents bookkeeping data for a TCP packet during the lifetime // of a GRO evaluation across a vector of packets. 
type tcpGROItem struct { - key tcpFlowKey - sentSeq uint32 // the sequence number - bufsIndex uint16 // the index into the original bufs slice - numMerged uint16 // the number of packets merged into this item - gsoSize uint16 // payload size - iphLen uint8 // ip header len - tcphLen uint8 // tcp header len - pshSet bool // psh flag is set + key tcpFlowKey + sentSeq uint32 // the sequence number + outputIdx uint16 // index into groToWrite + payloadLen uint16 // accumulated payload bytes + gsoSize uint16 // payload size + iphLen uint8 // ip header len + tcphLen uint8 // tcp header len + pshSet bool // psh flag is set } func (t *tcpGROTable) newItems() []tcpGROItem { @@ -221,26 +280,29 @@ func newUDPFlowKey(pkt []byte, srcAddrOffset, dstAddrOffset, udphOffset int) udp // lookupOrInsert looks up a flow for the provided packet and metadata, // returning the packets found for the flow, or inserting a new one if none // is found. -func (u *udpGROTable) lookupOrInsert(pkt []byte, srcAddrOffset, dstAddrOffset, udphOffset, bufsIndex int) ([]udpGROItem, bool) { +func (u *udpGROTable) lookupOrInsert(pkt []byte, srcAddrOffset, dstAddrOffset, udphOffset int, wi *groToWrite) ([]udpGROItem, bool) { key := newUDPFlowKey(pkt, srcAddrOffset, dstAddrOffset, udphOffset) items, ok := u.itemsByFlow[key] if ok { return items, ok } // TODO: insert() performs another map lookup. This could be rearranged to avoid. - u.insert(pkt, srcAddrOffset, dstAddrOffset, udphOffset, bufsIndex, false) + u.insert(pkt, srcAddrOffset, dstAddrOffset, udphOffset, wi, false) return nil, false } // insert an item in the table for the provided packet and packet metadata. -func (u *udpGROTable) insert(pkt []byte, srcAddrOffset, dstAddrOffset, udphOffset, bufsIndex int, cSumKnownInvalid bool) { +func (u *udpGROTable) insert(pkt []byte, srcAddrOffset, dstAddrOffset, udphOffset int, wi *groToWrite, cSumKnownInvalid bool) { key := newUDPFlowKey(pkt, srcAddrOffset, dstAddrOffset, udphOffset) + idx, iov := wi.next() + *iov = append(*iov, pkt) item := udpGROItem{ key: key, - bufsIndex: uint16(bufsIndex), + outputIdx: uint16(idx), gsoSize: uint16(len(pkt[udphOffset+udphLen:])), iphLen: uint8(udphOffset), cSumKnownInvalid: cSumKnownInvalid, + payloadLen: uint16(len(pkt[udphOffset+udphLen:])), } items, ok := u.itemsByFlow[key] if !ok { @@ -259,8 +321,8 @@ func (u *udpGROTable) updateAt(item udpGROItem, i int) { // of a GRO evaluation across a vector of packets. type udpGROItem struct { key udpFlowKey - bufsIndex uint16 // the index into the original bufs slice - numMerged uint16 // the number of packets merged into this item + outputIdx uint16 // index into groToWrite + payloadLen uint16 // accumulated payload bytes gsoSize uint16 // payload size iphLen uint8 // ip header len cSumKnownInvalid bool // UDP header checksum validity; a false value DOES NOT imply valid, just unknown. @@ -325,15 +387,16 @@ func ipHeadersCanCoalesce(pktA, pktB []byte) bool { } // udpPacketsCanCoalesce evaluates if pkt can be coalesced with the packet -// described by item. iphLen and gsoSize describe pkt. bufs is the vector of -// packets involved in the current GRO evaluation. bufsOffset is the offset at -// which packet data begins within bufs. -func udpPacketsCanCoalesce(pkt []byte, iphLen uint8, gsoSize uint16, item udpGROItem, bufs [][]byte, bufsOffset int) canCoalesce { - pktTarget := bufs[item.bufsIndex][bufsOffset:] +// described by item. iphLen and gsoSize describe pkt. 
+func udpPacketsCanCoalesce(pkt []byte, iphLen uint8, gsoSize uint16, item udpGROItem, wi *groToWrite) canCoalesce { + if len(wi.iovs[item.outputIdx]) >= maxScatterGatherFragments { + return coalesceUnavailable + } + pktTarget := wi.iovs[item.outputIdx][headPacketIdx] if !ipHeadersCanCoalesce(pkt, pktTarget) { return coalesceUnavailable } - if len(pktTarget[iphLen+udphLen:])%int(item.gsoSize) != 0 { + if item.payloadLen%item.gsoSize != 0 { // A smaller than gsoSize packet has been appended previously. // Nothing can come after a smaller packet on the end. return coalesceUnavailable @@ -342,14 +405,20 @@ func udpPacketsCanCoalesce(pkt []byte, iphLen uint8, gsoSize uint16, item udpGRO // We cannot have a larger packet following a smaller one. return coalesceUnavailable } + if int(item.iphLen)+udphLen+int(item.payloadLen)+int(gsoSize) > maxUint16 { + return coalesceUnavailable + } return coalesceAppend } // tcpPacketsCanCoalesce evaluates if pkt can be coalesced with the packet // described by item. This function makes considerations that match the kernel's // GRO self tests, which can be found in tools/testing/selftests/net/gro.c. -func tcpPacketsCanCoalesce(pkt []byte, iphLen, tcphLen uint8, seq uint32, pshSet bool, gsoSize uint16, item tcpGROItem, bufs [][]byte, bufsOffset int) canCoalesce { - pktTarget := bufs[item.bufsIndex][bufsOffset:] +func tcpPacketsCanCoalesce(pkt []byte, iphLen, tcphLen uint8, seq uint32, pshSet bool, gsoSize uint16, item tcpGROItem, wi *groToWrite) canCoalesce { + if len(wi.iovs[item.outputIdx]) >= maxScatterGatherFragments { + return coalesceUnavailable + } + pktTarget := wi.iovs[item.outputIdx][headPacketIdx] if tcphLen != item.tcphLen { // cannot coalesce with unequal tcp options len return coalesceUnavailable @@ -363,16 +432,17 @@ func tcpPacketsCanCoalesce(pkt []byte, iphLen, tcphLen uint8, seq uint32, pshSet if !ipHeadersCanCoalesce(pkt, pktTarget) { return coalesceUnavailable } + if int(item.iphLen)+int(item.tcphLen)+int(item.payloadLen)+int(gsoSize) > maxUint16 { + return coalesceUnavailable + } // seq adjacency - lhsLen := item.gsoSize - lhsLen += item.numMerged * item.gsoSize - if seq == item.sentSeq+uint32(lhsLen) { // pkt aligns following item from a seq num perspective + if seq == item.sentSeq+uint32(item.payloadLen) { // pkt aligns following item from a seq num perspective if item.pshSet { // We cannot append to a segment that has the PSH flag set, PSH // can only be set on the final segment in a reassembled group. return coalesceUnavailable } - if len(pktTarget[iphLen+tcphLen:])%int(item.gsoSize) != 0 { + if item.payloadLen%item.gsoSize != 0 { // A smaller than gsoSize packet has been appended previously. // Nothing can come after a smaller packet on the end. return coalesceUnavailable @@ -392,7 +462,7 @@ func tcpPacketsCanCoalesce(pkt []byte, iphLen, tcphLen uint8, seq uint32, pshSet // We cannot have a larger packet following a smaller one. return coalesceUnavailable } - if gsoSize > item.gsoSize && item.numMerged > 0 { + if gsoSize > item.gsoSize && len(wi.iovs[item.outputIdx]) > singlePacketLen { // There's at least one previous merge, and we're larger than all // previous. This would put multiple smaller packets on the end. return coalesceUnavailable @@ -414,13 +484,11 @@ func checksumValid(pkt []byte, iphLen, proto uint8, isV6 bool) bool { return ^Checksum(pkt[iphLen:], cSum) == 0 } -// coalesceResult represents the result of attempting to coalesce two TCP -// packets. 
+// coalesceResult represents the result of attempting to coalesce two packets. type coalesceResult int const ( - coalesceInsufficientCap coalesceResult = iota - coalescePSHEnding + coalescePSHEnding coalesceResult = iota coalesceItemInvalidCSum coalescePktInvalidCSum coalesceSuccess @@ -428,54 +496,33 @@ const ( // coalesceUDPPackets attempts to coalesce pkt with the packet described by // item, and returns the outcome. -func coalesceUDPPackets(pkt []byte, item *udpGROItem, bufs [][]byte, bufsOffset int, isV6 bool) coalesceResult { - pktHead := bufs[item.bufsIndex][bufsOffset:] // the packet that will end up at the front - headersLen := item.iphLen + udphLen - coalescedLen := len(bufs[item.bufsIndex][bufsOffset:]) + len(pkt) - int(headersLen) - - if cap(pktHead)-bufsOffset < coalescedLen { - // We don't want to allocate a new underlying array if capacity is - // too small. - return coalesceInsufficientCap - } - if item.numMerged == 0 { - if item.cSumKnownInvalid || !checksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, unix.IPPROTO_UDP, isV6) { +func coalesceUDPPackets(pkt []byte, item *udpGROItem, wi *groToWrite, isV6 bool) coalesceResult { + headersLen := int(item.iphLen) + udphLen + iov := &wi.iovs[item.outputIdx] + if len(*iov) == singlePacketLen { + if item.cSumKnownInvalid || !checksumValid((*iov)[headPacketIdx], item.iphLen, unix.IPPROTO_UDP, isV6) { return coalesceItemInvalidCSum } } if !checksumValid(pkt, item.iphLen, unix.IPPROTO_UDP, isV6) { return coalescePktInvalidCSum } - extendBy := len(pkt) - int(headersLen) - bufs[item.bufsIndex] = append(bufs[item.bufsIndex], make([]byte, extendBy)...) - copy(bufs[item.bufsIndex][bufsOffset+len(pktHead):], pkt[headersLen:]) - - item.numMerged++ + *iov = append(*iov, pkt[headersLen:]) + item.payloadLen += uint16(len(pkt) - headersLen) return coalesceSuccess } // coalesceTCPPackets attempts to coalesce pkt with the packet described by -// item, and returns the outcome. This function may swap bufs elements in the -// event of a prepend as item's bufs index is already being tracked for writing -// to a Device. -func coalesceTCPPackets(mode canCoalesce, pkt []byte, pktBuffsIndex int, gsoSize uint16, seq uint32, pshSet bool, item *tcpGROItem, bufs [][]byte, bufsOffset int, isV6 bool) coalesceResult { - var pktHead []byte // the packet that will end up at the front - headersLen := item.iphLen + item.tcphLen - coalescedLen := len(bufs[item.bufsIndex][bufsOffset:]) + len(pkt) - int(headersLen) - - // Copy data +// item, and returns the outcome. +func coalesceTCPPackets(mode canCoalesce, pkt []byte, gsoSize uint16, seq uint32, pshSet bool, item *tcpGROItem, wi *groToWrite, isV6 bool) coalesceResult { + headersLen := int(item.iphLen) + int(item.tcphLen) + iov := &wi.iovs[item.outputIdx] if mode == coalescePrepend { - pktHead = pkt - if cap(pkt)-bufsOffset < coalescedLen { - // We don't want to allocate a new underlying array if capacity is - // too small. 
- return coalesceInsufficientCap - } if pshSet { return coalescePSHEnding } - if item.numMerged == 0 { - if !checksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, unix.IPPROTO_TCP, isV6) { + if len(*iov) == singlePacketLen { + if !checksumValid((*iov)[headPacketIdx], item.iphLen, unix.IPPROTO_TCP, isV6) { return coalesceItemInvalidCSum } } @@ -483,21 +530,15 @@ func coalesceTCPPackets(mode canCoalesce, pkt []byte, pktBuffsIndex int, gsoSize return coalescePktInvalidCSum } item.sentSeq = seq - extendBy := coalescedLen - len(pktHead) - bufs[pktBuffsIndex] = append(bufs[pktBuffsIndex], make([]byte, extendBy)...) - copy(bufs[pktBuffsIndex][bufsOffset+len(pkt):], bufs[item.bufsIndex][bufsOffset+int(headersLen):]) - // Flip the slice headers in bufs as part of prepend. The index of item - // is already being tracked for writing. - bufs[item.bufsIndex], bufs[pktBuffsIndex] = bufs[pktBuffsIndex], bufs[item.bufsIndex] + oldHead := (*iov)[headPacketIdx] + (*iov)[headPacketIdx] = pkt + oldHeadPayload := oldHead[headersLen:] + if len(oldHeadPayload) > 0 { + *iov = slices.Insert(*iov, coalescedPacketsIdx, oldHeadPayload) + } } else { - pktHead = bufs[item.bufsIndex][bufsOffset:] - if cap(pktHead)-bufsOffset < coalescedLen { - // We don't want to allocate a new underlying array if capacity is - // too small. - return coalesceInsufficientCap - } - if item.numMerged == 0 { - if !checksumValid(bufs[item.bufsIndex][bufsOffset:], item.iphLen, unix.IPPROTO_TCP, isV6) { + if len(*iov) == singlePacketLen { + if !checksumValid((*iov)[headPacketIdx], item.iphLen, unix.IPPROTO_TCP, isV6) { return coalesceItemInvalidCSum } } @@ -507,18 +548,16 @@ func coalesceTCPPackets(mode canCoalesce, pkt []byte, pktBuffsIndex int, gsoSize if pshSet { // We are appending a segment with PSH set. item.pshSet = pshSet - pktHead[item.iphLen+tcpFlagsOffset] |= tcpFlagPSH + (*iov)[headPacketIdx][item.iphLen+tcpFlagsOffset] |= tcpFlagPSH } - extendBy := len(pkt) - int(headersLen) - bufs[item.bufsIndex] = append(bufs[item.bufsIndex], make([]byte, extendBy)...) - copy(bufs[item.bufsIndex][bufsOffset+len(pktHead):], pkt[headersLen:]) + *iov = append(*iov, pkt[headersLen:]) } if gsoSize > item.gsoSize { item.gsoSize = gsoSize } - item.numMerged++ + item.payloadLen += uint16(len(pkt) - headersLen) return coalesceSuccess } @@ -538,13 +577,12 @@ const ( groResultCoalesced ) -// tcpGRO evaluates the TCP packet at pktI in bufs for coalescing with +// tcpGRO evaluates the TCP packet for coalescing with // existing packets tracked in table. It returns a groResultNoop when no // action was taken, groResultTableInsert when the evaluated packet was // inserted into table, and groResultCoalesced when the evaluated packet was // coalesced with another packet in table. -func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) groResult { - pkt := bufs[pktI][offset:] +func tcpGRO(pkt []byte, table *tcpGROTable, wi *groToWrite, isV6 bool) groResult { if len(pkt) > maxUint16 { // A valid IPv4 or IPv6 packet will never exceed this. 
return groResultNoop @@ -599,7 +637,7 @@ func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) srcAddrOffset = ipv6SrcAddrOffset addrLen = 16 } - items, existing := table.lookupOrInsert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI) + items, existing := table.lookupOrInsert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, wi) if !existing { return groResultTableInsert } @@ -613,9 +651,9 @@ func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) // sequence number perspective, however once an item is inserted into // the table it is never compared across other items later. item := items[i] - can := tcpPacketsCanCoalesce(pkt, uint8(iphLen), uint8(tcphLen), seq, pshSet, gsoSize, item, bufs, offset) + can := tcpPacketsCanCoalesce(pkt, uint8(iphLen), uint8(tcphLen), seq, pshSet, gsoSize, item, wi) if can != coalesceUnavailable { - result := coalesceTCPPackets(can, pkt, pktI, gsoSize, seq, pshSet, &item, bufs, offset, isV6) + result := coalesceTCPPackets(can, pkt, gsoSize, seq, pshSet, &item, wi, isV6) switch result { case coalesceSuccess: table.updateAt(item, i) @@ -631,16 +669,19 @@ func tcpGRO(bufs [][]byte, offset int, pktI int, table *tcpGROTable, isV6 bool) } } // failed to coalesce with any other packets; store the item in the flow - table.insert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, pktI) + table.insert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, tcphLen, wi) return groResultTableInsert } -// applyTCPCoalesceAccounting updates bufs to account for coalescing based on the +// applyTCPCoalesceAccounting updates headers to account for coalescing based on the // metadata found in table. -func applyTCPCoalesceAccounting(bufs [][]byte, offset int, table *tcpGROTable) error { +func applyTCPCoalesceAccounting(wi *groToWrite, table *tcpGROTable) error { for _, items := range table.itemsByFlow { for _, item := range items { - if item.numMerged > 0 { + iov := wi.iovs[item.outputIdx] + pkt := iov[headPacketIdx] + if len(iov) > singlePacketLen { + totalLen := uint16(item.iphLen) + uint16(item.tcphLen) + item.payloadLen hdr := virtioNetHdr{ flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, // this turns into CHECKSUM_PARTIAL in the skb hdrLen: uint16(item.iphLen + item.tcphLen), @@ -648,21 +689,20 @@ func applyTCPCoalesceAccounting(bufs [][]byte, offset int, table *tcpGROTable) e csumStart: uint16(item.iphLen), csumOffset: 16, } - pkt := bufs[item.bufsIndex][offset:] // Recalculate the total len (IPv4) or payload len (IPv6). // Recalculate the (IPv4) header checksum. 
if item.key.isV6 { hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV6 - binary.BigEndian.PutUint16(pkt[4:], uint16(len(pkt))-uint16(item.iphLen)) // set new IPv6 header payload len + binary.BigEndian.PutUint16(pkt[4:], totalLen-uint16(item.iphLen)) // set new IPv6 header payload len } else { hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_TCPV4 pkt[10], pkt[11] = 0, 0 - binary.BigEndian.PutUint16(pkt[2:], uint16(len(pkt))) // set new total length - iphCSum := ^Checksum(pkt[:item.iphLen], 0) // compute IPv4 header checksum - binary.BigEndian.PutUint16(pkt[10:], iphCSum) // set IPv4 header checksum field + binary.BigEndian.PutUint16(pkt[2:], totalLen) // set new total length + iphCSum := ^Checksum(pkt[:item.iphLen], 0) // compute IPv4 header checksum + binary.BigEndian.PutUint16(pkt[10:], iphCSum) // set IPv4 header checksum field } - err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:]) + err := hdr.encode(iov[virtioNetHdrIdx]) if err != nil { return err } @@ -676,56 +716,53 @@ func applyTCPCoalesceAccounting(bufs [][]byte, offset int, table *tcpGROTable) e addrLen = 16 addrOffset = ipv6SrcAddrOffset } - srcAddrAt := offset + addrOffset - srcAddr := bufs[item.bufsIndex][srcAddrAt : srcAddrAt+addrLen] - dstAddr := bufs[item.bufsIndex][srcAddrAt+addrLen : srcAddrAt+addrLen*2] - psum := PseudoHeaderChecksum(unix.IPPROTO_TCP, srcAddr, dstAddr, uint16(len(pkt)-int(item.iphLen))) + srcAddr := pkt[addrOffset : addrOffset+addrLen] + dstAddr := pkt[addrOffset+addrLen : addrOffset+addrLen*2] + psum := PseudoHeaderChecksum(unix.IPPROTO_TCP, srcAddr, dstAddr, totalLen-uint16(item.iphLen)) binary.BigEndian.PutUint16(pkt[hdr.csumStart+hdr.csumOffset:], Checksum([]byte{}, psum)) } else { - hdr := virtioNetHdr{} - err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:]) - if err != nil { - return err - } + clear(iov[virtioNetHdrIdx]) } } } return nil } -// applyUDPCoalesceAccounting updates bufs to account for coalescing based on the +// applyUDPCoalesceAccounting updates headers to account for coalescing based on the // metadata found in table. -func applyUDPCoalesceAccounting(bufs [][]byte, offset int, table *udpGROTable) error { +func applyUDPCoalesceAccounting(wi *groToWrite, table *udpGROTable) error { for _, items := range table.itemsByFlow { for _, item := range items { - if item.numMerged > 0 { + iov := wi.iovs[item.outputIdx] + pkt := iov[headPacketIdx] + if len(iov) > singlePacketLen { + totalLen := uint16(item.iphLen) + udphLen + item.payloadLen hdr := virtioNetHdr{ flags: unix.VIRTIO_NET_HDR_F_NEEDS_CSUM, // this turns into CHECKSUM_PARTIAL in the skb - hdrLen: uint16(item.iphLen + udphLen), + hdrLen: uint16(item.iphLen) + udphLen, gsoSize: item.gsoSize, csumStart: uint16(item.iphLen), csumOffset: 6, } - pkt := bufs[item.bufsIndex][offset:] // Recalculate the total len (IPv4) or payload len (IPv6). // Recalculate the (IPv4) header checksum. 
hdr.gsoType = unix.VIRTIO_NET_HDR_GSO_UDP_L4 if item.key.isV6 { - binary.BigEndian.PutUint16(pkt[4:], uint16(len(pkt))-uint16(item.iphLen)) // set new IPv6 header payload len + binary.BigEndian.PutUint16(pkt[4:], totalLen-uint16(item.iphLen)) // set new IPv6 header payload len } else { pkt[10], pkt[11] = 0, 0 - binary.BigEndian.PutUint16(pkt[2:], uint16(len(pkt))) // set new total length - iphCSum := ^Checksum(pkt[:item.iphLen], 0) // compute IPv4 header checksum - binary.BigEndian.PutUint16(pkt[10:], iphCSum) // set IPv4 header checksum field + binary.BigEndian.PutUint16(pkt[2:], totalLen) // set new total length + iphCSum := ^Checksum(pkt[:item.iphLen], 0) // compute IPv4 header checksum + binary.BigEndian.PutUint16(pkt[10:], iphCSum) // set IPv4 header checksum field } - err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:]) + err := hdr.encode(iov[virtioNetHdrIdx]) if err != nil { return err } // Recalculate the UDP len field value - binary.BigEndian.PutUint16(pkt[item.iphLen+4:], uint16(len(pkt[item.iphLen:]))) + binary.BigEndian.PutUint16(pkt[item.iphLen+4:], udphLen+item.payloadLen) // Calculate the pseudo header checksum and place it at the UDP // checksum offset. Downstream checksum offloading will combine @@ -736,17 +773,12 @@ func applyUDPCoalesceAccounting(bufs [][]byte, offset int, table *udpGROTable) e addrLen = 16 addrOffset = ipv6SrcAddrOffset } - srcAddrAt := offset + addrOffset - srcAddr := bufs[item.bufsIndex][srcAddrAt : srcAddrAt+addrLen] - dstAddr := bufs[item.bufsIndex][srcAddrAt+addrLen : srcAddrAt+addrLen*2] - psum := PseudoHeaderChecksum(unix.IPPROTO_UDP, srcAddr, dstAddr, uint16(len(pkt)-int(item.iphLen))) + srcAddr := pkt[addrOffset : addrOffset+addrLen] + dstAddr := pkt[addrOffset+addrLen : addrOffset+addrLen*2] + psum := PseudoHeaderChecksum(unix.IPPROTO_UDP, srcAddr, dstAddr, totalLen-uint16(item.iphLen)) binary.BigEndian.PutUint16(pkt[hdr.csumStart+hdr.csumOffset:], Checksum([]byte{}, psum)) } else { - hdr := virtioNetHdr{} - err := hdr.encode(bufs[item.bufsIndex][offset-virtioNetHdrLen:]) - if err != nil { - return err - } + clear(iov[virtioNetHdrIdx]) } } } @@ -793,13 +825,12 @@ const ( udphLen = 8 ) -// udpGRO evaluates the UDP packet at pktI in bufs for coalescing with +// udpGRO evaluates the UDP packet for coalescing with // existing packets tracked in table. It returns a groResultNoop when no // action was taken, groResultTableInsert when the evaluated packet was // inserted into table, and groResultCoalesced when the evaluated packet was // coalesced with another packet in table. -func udpGRO(bufs [][]byte, offset int, pktI int, table *udpGROTable, isV6 bool) groResult { - pkt := bufs[pktI][offset:] +func udpGRO(pkt []byte, table *udpGROTable, wi *groToWrite, isV6 bool) groResult { if len(pkt) > maxUint16 { // A valid IPv4 or IPv6 packet will never exceed this. return groResultNoop @@ -840,7 +871,7 @@ func udpGRO(bufs [][]byte, offset int, pktI int, table *udpGROTable, isV6 bool) srcAddrOffset = ipv6SrcAddrOffset addrLen = 16 } - items, existing := table.lookupOrInsert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, pktI) + items, existing := table.lookupOrInsert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, wi) if !existing { return groResultTableInsert } @@ -848,10 +879,10 @@ func udpGRO(bufs [][]byte, offset int, pktI int, table *udpGROTable, isV6 bool) // for a given flow. We must also always insert a new item, or successfully // coalesce with an existing item, for the same reason. 
item := items[len(items)-1] - can := udpPacketsCanCoalesce(pkt, uint8(iphLen), gsoSize, item, bufs, offset) + can := udpPacketsCanCoalesce(pkt, uint8(iphLen), gsoSize, item, wi) var pktCSumKnownInvalid bool if can == coalesceAppend { - result := coalesceUDPPackets(pkt, &item, bufs, offset, isV6) + result := coalesceUDPPackets(pkt, &item, wi, isV6) switch result { case coalesceSuccess: table.updateAt(item, len(items)-1) @@ -868,44 +899,42 @@ func udpGRO(bufs [][]byte, offset int, pktI int, table *udpGROTable, isV6 bool) } } // failed to coalesce with any other packets; store the item in the flow - table.insert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, pktI, pktCSumKnownInvalid) + table.insert(pkt, srcAddrOffset, srcAddrOffset+addrLen, iphLen, wi, pktCSumKnownInvalid) return groResultTableInsert } -// handleGRO evaluates bufs for GRO, and writes the indices of the resulting -// packets into toWrite. toWrite, tcpTable, and udpTable should initially be +// handleGRO evaluates bufs for GRO, and populates wi with the resulting +// io vectors for writev. wi, tcpTable, and udpTable should initially be // empty (but non-nil), and are passed in to save allocs as the caller may reset // and recycle them across vectors of packets. gro indicates if TCP and UDP GRO // are supported/enabled. -func handleGRO(bufs [][]byte, offset int, tcpTable *tcpGROTable, udpTable *udpGROTable, gro groDisablementFlags, toWrite *[]int) error { +func handleGRO(bufs [][]byte, offset int, tcpTable *tcpGROTable, udpTable *udpGROTable, gro groDisablementFlags, wi *groToWrite) error { for i := range bufs { if offset < virtioNetHdrLen || offset > len(bufs[i])-1 { return errors.New("invalid offset") } + pkt := bufs[i][offset:] var result groResult - switch packetIsGROCandidate(bufs[i][offset:], gro) { + switch packetIsGROCandidate(pkt, gro) { case tcp4GROCandidate: - result = tcpGRO(bufs, offset, i, tcpTable, false) + result = tcpGRO(pkt, tcpTable, wi, false) case tcp6GROCandidate: - result = tcpGRO(bufs, offset, i, tcpTable, true) + result = tcpGRO(pkt, tcpTable, wi, true) case udp4GROCandidate: - result = udpGRO(bufs, offset, i, udpTable, false) + result = udpGRO(pkt, udpTable, wi, false) case udp6GROCandidate: - result = udpGRO(bufs, offset, i, udpTable, true) + result = udpGRO(pkt, udpTable, wi, true) } switch result { case groResultNoop: - hdr := virtioNetHdr{} - err := hdr.encode(bufs[i][offset-virtioNetHdrLen:]) - if err != nil { - return err - } - fallthrough + _, iov := wi.next() + clear((*iov)[virtioNetHdrIdx]) + *iov = append(*iov, pkt) case groResultTableInsert: - *toWrite = append(*toWrite, i) + // already in wi via table insert } } - errTCP := applyTCPCoalesceAccounting(bufs, offset, tcpTable) - errUDP := applyUDPCoalesceAccounting(bufs, offset, udpTable) + errTCP := applyTCPCoalesceAccounting(wi, tcpTable) + errUDP := applyUDPCoalesceAccounting(wi, udpTable) return errors.Join(errTCP, errUDP) } diff --git a/tun/offload_linux_test.go b/tun/offload_linux_test.go index 407037863..71d5bef3d 100644 --- a/tun/offload_linux_test.go +++ b/tun/offload_linux_test.go @@ -7,9 +7,11 @@ package tun import ( "net/netip" + "slices" "testing" "github.com/tailscale/wireguard-go/conn" + "github.com/tailscale/wireguard-go/iobuf" "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -234,13 +236,12 @@ func Test_handleVirtioRead(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - out := make([][]byte, conn.IdealBatchSize) - sizes := make([]int, 
conn.IdealBatchSize) + out := make([]iobuf.View, conn.IdealBatchSize) for i := range out { - out[i] = make([]byte, 65535) + out[i] = iobuf.View{Bytes: make([]byte, 65535)} } tt.hdr.encode(tt.pktIn) - n, err := handleVirtioRead(tt.pktIn, out, sizes, offset) + n, err := handleVirtioRead(tt.pktIn, out, offset) if err != nil { if tt.wantErr { return @@ -251,8 +252,8 @@ func Test_handleVirtioRead(t *testing.T) { t.Fatalf("got %d packets, wanted %d", n, len(tt.wantLens)) } for i := range tt.wantLens { - if tt.wantLens[i] != sizes[i] { - t.Fatalf("wantLens[%d]: %d != outSizes: %d", i, tt.wantLens[i], sizes[i]) + if size := len(out[i].Bytes) - offset; tt.wantLens[i] != size { + t.Fatalf("wantLens[%d]: %d != size: %d", i, tt.wantLens[i], size) } } }) @@ -289,32 +290,21 @@ func Fuzz_handleGRO(f *testing.F) { f.Add(pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, pkt6, pkt7, pkt8, pkt9, pkt10, pkt11, 0, offset) f.Fuzz(func(t *testing.T, pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, pkt6, pkt7, pkt8, pkt9, pkt10, pkt11 []byte, gro int, offset int) { pkts := [][]byte{pkt0, pkt1, pkt2, pkt3, pkt4, pkt5, pkt6, pkt7, pkt8, pkt9, pkt10, pkt11} - toWrite := make([]int, 0, len(pkts)) - handleGRO(pkts, offset, newTCPGROTable(), newUDPGROTable(), groDisablementFlags(gro), &toWrite) - if len(toWrite) > len(pkts) { - t.Errorf("len(toWrite): %d > len(pkts): %d", len(toWrite), len(pkts)) - } - seenWriteI := make(map[int]bool) - for _, writeI := range toWrite { - if writeI < 0 || writeI > len(pkts)-1 { - t.Errorf("toWrite value (%d) outside bounds of len(pkts): %d", writeI, len(pkts)) - } - if seenWriteI[writeI] { - t.Errorf("duplicate toWrite value: %d", writeI) - } - seenWriteI[writeI] = true + wi := newGROToWrite() + handleGRO(pkts, offset, newTCPGROTable(), newUDPGROTable(), groDisablementFlags(gro), &wi) + if len(wi.iovs) > len(pkts) { + t.Errorf("len(wi.iovs): %d > len(pkts): %d", len(wi.iovs), len(pkts)) } }) } func Test_handleGRO(t *testing.T) { tests := []struct { - name string - pktsIn [][]byte - gro groDisablementFlags - wantToWrite []int - wantLens []int - wantErr bool + name string + pktsIn [][]byte + gro groDisablementFlags + wantLens [][]int + wantErr bool }{ { "multiple protocols and flows", @@ -332,8 +322,15 @@ func Test_handleGRO(t *testing.T) { udp6Packet(ip6PortA, ip6PortB, 100), // udp6 flow 1 }, 0, - []int{0, 1, 2, 4, 5, 7, 9}, - []int{240, 228, 128, 140, 260, 160, 248}, + [][]int{ + {virtioNetHdrLen, 140, 100}, // tcp4 A->B merged + {virtioNetHdrLen, 128, 100}, // udp4 A->B merged + {virtioNetHdrLen, 128}, // udp4 A->C + {virtioNetHdrLen, 140}, // tcp4 A->C + {virtioNetHdrLen, 160, 100}, // tcp6 A->B merged + {virtioNetHdrLen, 160}, // tcp6 A->C + {virtioNetHdrLen, 148, 100}, // udp6 A->B merged + }, false, }, { @@ -352,8 +349,17 @@ func Test_handleGRO(t *testing.T) { udp6Packet(ip6PortA, ip6PortB, 100), // udp6 flow 1 }, udpGRODisabled, - []int{0, 1, 2, 4, 5, 7, 8, 9, 10}, - []int{240, 128, 128, 140, 260, 160, 128, 148, 148}, + [][]int{ + {virtioNetHdrLen, 140, 100}, // tcp4 A->B merged + {virtioNetHdrLen, 128}, // udp4 A->B noop + {virtioNetHdrLen, 128}, // udp4 A->C noop + {virtioNetHdrLen, 140}, // tcp4 A->C + {virtioNetHdrLen, 160, 100}, // tcp6 A->B merged + {virtioNetHdrLen, 160}, // tcp6 A->C + {virtioNetHdrLen, 128}, // udp4 A->B noop + {virtioNetHdrLen, 148}, // udp6 A->B noop + {virtioNetHdrLen, 148}, // udp6 A->B noop + }, false, }, { @@ -369,8 +375,12 @@ func Test_handleGRO(t *testing.T) { tcp6Packet(ip6PortA, ip6PortB, header.TCPFlagAck, 100, 301), // v6 flow 1 }, 0, - []int{0, 2, 4, 6}, - []int{240, 
240, 260, 260}, + [][]int{ + {virtioNetHdrLen, 140, 100}, // v4 merged (seq 1+101 PSH) + {virtioNetHdrLen, 140, 100}, // v4 merged (seq 201+301) + {virtioNetHdrLen, 160, 100}, // v6 merged (seq 1+101 PSH) + {virtioNetHdrLen, 160, 100}, // v6 merged (seq 201+301) + }, false, }, { @@ -384,8 +394,12 @@ func Test_handleGRO(t *testing.T) { udp4Packet(ip4PortA, ip4PortB, 100), }, 0, - []int{0, 1, 3, 4}, - []int{140, 240, 128, 228}, + [][]int{ + {virtioNetHdrLen, 140}, // tcp4 bad csum, unmerged + {virtioNetHdrLen, 140, 100}, // tcp4 merged (seq 101+201) + {virtioNetHdrLen, 128}, // udp4 bad csum, unmerged + {virtioNetHdrLen, 128, 100}, // udp4 merged + }, false, }, { @@ -396,8 +410,9 @@ func Test_handleGRO(t *testing.T) { tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 201), // v4 flow 1 seq 201 len 100 }, 0, - []int{0}, - []int{340}, + [][]int{ + {virtioNetHdrLen, 140, 100, 100}, // prepend seq 1, original seq 101 payload, append seq 201 + }, false, }, { @@ -413,8 +428,12 @@ func Test_handleGRO(t *testing.T) { }), }, 0, - []int{0, 1, 2, 3}, - []int{140, 140, 128, 128}, + [][]int{ + {virtioNetHdrLen, 140}, + {virtioNetHdrLen, 140}, + {virtioNetHdrLen, 128}, + {virtioNetHdrLen, 128}, + }, false, }, { @@ -430,8 +449,12 @@ func Test_handleGRO(t *testing.T) { }), }, 0, - []int{0, 1, 2, 3}, - []int{140, 140, 128, 128}, + [][]int{ + {virtioNetHdrLen, 140}, + {virtioNetHdrLen, 140}, + {virtioNetHdrLen, 128}, + {virtioNetHdrLen, 128}, + }, false, }, { @@ -447,8 +470,12 @@ func Test_handleGRO(t *testing.T) { }), }, 0, - []int{0, 1, 2, 3}, - []int{140, 140, 128, 128}, + [][]int{ + {virtioNetHdrLen, 140}, + {virtioNetHdrLen, 140}, + {virtioNetHdrLen, 128}, + {virtioNetHdrLen, 128}, + }, false, }, { @@ -464,8 +491,12 @@ func Test_handleGRO(t *testing.T) { }), }, 0, - []int{0, 1, 2, 3}, - []int{140, 140, 128, 128}, + [][]int{ + {virtioNetHdrLen, 140}, + {virtioNetHdrLen, 140}, + {virtioNetHdrLen, 128}, + {virtioNetHdrLen, 128}, + }, false, }, { @@ -481,8 +512,12 @@ func Test_handleGRO(t *testing.T) { }), }, 0, - []int{0, 1, 2, 3}, - []int{160, 160, 148, 148}, + [][]int{ + {virtioNetHdrLen, 160}, + {virtioNetHdrLen, 160}, + {virtioNetHdrLen, 148}, + {virtioNetHdrLen, 148}, + }, false, }, { @@ -498,31 +533,46 @@ func Test_handleGRO(t *testing.T) { }), }, 0, - []int{0, 1, 2, 3}, - []int{160, 160, 148, 148}, + [][]int{ + {virtioNetHdrLen, 160}, + {virtioNetHdrLen, 160}, + {virtioNetHdrLen, 148}, + {virtioNetHdrLen, 148}, + }, false, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - toWrite := make([]int, 0, len(tt.pktsIn)) - err := handleGRO(tt.pktsIn, offset, newTCPGROTable(), newUDPGROTable(), tt.gro, &toWrite) - if err != nil { - if tt.wantErr { - return + wi := newGROToWrite() + for range 2 { // validating reset() correctness + wi.reset() + // Deep copy pktsIn since coalesce accounting mutates head packet headers. 
+ pktsIn := make([][]byte, len(tt.pktsIn)) + for k, p := range tt.pktsIn { + pktsIn[k] = slices.Clone(p) } - t.Fatalf("got err: %v", err) - } - if len(toWrite) != len(tt.wantToWrite) { - t.Fatalf("got %d packets, wanted %d", len(toWrite), len(tt.wantToWrite)) - } - for i, pktI := range tt.wantToWrite { - if tt.wantToWrite[i] != toWrite[i] { - t.Fatalf("wantToWrite[%d]: %d != toWrite: %d", i, tt.wantToWrite[i], toWrite[i]) + err := handleGRO(pktsIn, offset, newTCPGROTable(), newUDPGROTable(), tt.gro, &wi) + if err != nil { + if tt.wantErr { + return + } + t.Fatalf("got err: %v", err) + } + if len(wi.iovs) != len(tt.wantLens) { + t.Fatalf("got %d packets, wanted %d", len(wi.iovs), len(tt.wantLens)) } - if tt.wantLens[i] != len(tt.pktsIn[pktI][offset:]) { - t.Errorf("wanted len %d packet at %d, got: %d", tt.wantLens[i], i, len(tt.pktsIn[pktI][offset:])) + for i, wantFragLens := range tt.wantLens { + iov := wi.iovs[i] + if len(iov) != len(wantFragLens) { + t.Fatalf("items[%d]: got %d fragments, wanted %d", i, len(iov), len(wantFragLens)) + } + for j, wantLen := range wantFragLens { + if len(iov[j]) != wantLen { + t.Errorf("items[%d][%d]: got len %d, want %d", i, j, len(iov[j]), wantLen) + } + } } } }) @@ -669,12 +719,11 @@ func Test_udpPacketsCanCoalesce(t *testing.T) { udp4c := udp4Packet(ip4PortA, ip4PortB, 110) type args struct { - pkt []byte - iphLen uint8 - gsoSize uint16 - item udpGROItem - bufs [][]byte - bufsOffset int + pkt []byte + iphLen uint8 + gsoSize uint16 + item udpGROItem + wi groToWrite } tests := []struct { name string @@ -688,14 +737,12 @@ func Test_udpPacketsCanCoalesce(t *testing.T) { iphLen: 20, gsoSize: 100, item: udpGROItem{ - gsoSize: 100, - iphLen: 20, + gsoSize: 100, + iphLen: 20, + outputIdx: 0, + payloadLen: 100, }, - bufs: [][]byte{ - udp4a, - udp4b, - }, - bufsOffset: offset, + wi: groToWrite{iovs: [][][]byte{{make([]byte, virtioNetHdrLen), udp4b[offset:]}}}, }, coalesceAppend, }, @@ -706,14 +753,12 @@ func Test_udpPacketsCanCoalesce(t *testing.T) { iphLen: 20, gsoSize: 10, item: udpGROItem{ - gsoSize: 100, - iphLen: 20, - }, - bufs: [][]byte{ - udp4a, - udp4b, + gsoSize: 100, + iphLen: 20, + outputIdx: 0, + payloadLen: 100, }, - bufsOffset: offset, + wi: groToWrite{iovs: [][][]byte{{make([]byte, virtioNetHdrLen), udp4b[offset:]}}}, }, coalesceAppend, }, @@ -724,14 +769,12 @@ func Test_udpPacketsCanCoalesce(t *testing.T) { iphLen: 20, gsoSize: 100, item: udpGROItem{ - gsoSize: 100, - iphLen: 20, + gsoSize: 100, + iphLen: 20, + outputIdx: 0, + payloadLen: 110, }, - bufs: [][]byte{ - udp4c, - udp4b, - }, - bufsOffset: offset, + wi: groToWrite{iovs: [][][]byte{{make([]byte, virtioNetHdrLen), udp4c[offset:]}}}, }, coalesceUnavailable, }, @@ -742,23 +785,123 @@ func Test_udpPacketsCanCoalesce(t *testing.T) { iphLen: 20, gsoSize: 110, item: udpGROItem{ - gsoSize: 100, - iphLen: 20, + gsoSize: 100, + iphLen: 20, + outputIdx: 0, + payloadLen: 100, }, - bufs: [][]byte{ - udp4a, - udp4c, + wi: groToWrite{iovs: [][][]byte{{make([]byte, virtioNetHdrLen), udp4a[offset:]}}}, + }, + coalesceUnavailable, + }, + { + "coalesceUnavailable too many fragments", + args{ + pkt: udp4a[offset:], + iphLen: 20, + gsoSize: 100, + item: udpGROItem{ + gsoSize: 100, + iphLen: 20, + outputIdx: 0, + payloadLen: 100, }, - bufsOffset: offset, + wi: groToWrite{iovs: [][][]byte{make([][]byte, maxScatterGatherFragments)}}, + }, + coalesceUnavailable, + }, + { + "coalesceUnavailable payload overflow", + args{ + pkt: udp4a[offset:], + iphLen: 20, + gsoSize: 100, + item: udpGROItem{ + gsoSize: 100, + 
iphLen: 20, + outputIdx: 0, + payloadLen: maxUint16 - 20 - 8, + }, + wi: groToWrite{iovs: [][][]byte{{make([]byte, virtioNetHdrLen), udp4a[offset:]}}}, }, coalesceUnavailable, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - if got := udpPacketsCanCoalesce(tt.args.pkt, tt.args.iphLen, tt.args.gsoSize, tt.args.item, tt.args.bufs, tt.args.bufsOffset); got != tt.want { + if got := udpPacketsCanCoalesce(tt.args.pkt, tt.args.iphLen, tt.args.gsoSize, tt.args.item, &tt.args.wi); got != tt.want { t.Errorf("udpPacketsCanCoalesce() = %v, want %v", got, tt.want) } }) } } + +func Test_tcpPacketsCanCoalesce(t *testing.T) { + tcp4a := tcp4Packet(ip4PortA, ip4PortB, header.TCPFlagAck, 100, 100) + + type args struct { + pkt []byte + iphLen uint8 + tcphLen uint8 + seq uint32 + pshSet bool + gsoSize uint16 + item tcpGROItem + wi groToWrite + } + tests := []struct { + name string + args args + want canCoalesce + }{ + { + "coalesceUnavailable too many fragments", + args{ + pkt: tcp4a[offset:], + iphLen: 20, + tcphLen: 20, + seq: 200, + pshSet: false, + gsoSize: 100, + item: tcpGROItem{ + gsoSize: 100, + iphLen: 20, + tcphLen: 20, + sentSeq: 100, + outputIdx: 0, + payloadLen: 100, + }, + wi: groToWrite{iovs: [][][]byte{make([][]byte, maxScatterGatherFragments)}}, + }, + coalesceUnavailable, + }, + { + "coalesceUnavailable payload overflow", + args{ + pkt: tcp4a[offset:], + iphLen: 20, + tcphLen: 20, + seq: 200, + pshSet: false, + gsoSize: 100, + item: tcpGROItem{ + gsoSize: 100, + iphLen: 20, + tcphLen: 20, + sentSeq: 100, + outputIdx: 0, + payloadLen: maxUint16 - 20 - 20, + }, + wi: groToWrite{iovs: [][][]byte{{make([]byte, virtioNetHdrLen), tcp4a[offset:]}}}, + }, + coalesceUnavailable, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tcpPacketsCanCoalesce(tt.args.pkt, tt.args.iphLen, tt.args.tcphLen, tt.args.seq, tt.args.pshSet, tt.args.gsoSize, tt.args.item, &tt.args.wi); got != tt.want { + t.Errorf("tcpPacketsCanCoalesce() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/tun/offload_test.go b/tun/offload_test.go index 82a37b9cc..33109b72c 100644 --- a/tun/offload_test.go +++ b/tun/offload_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/tailscale/wireguard-go/conn" + "github.com/tailscale/wireguard-go/iobuf" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" ) @@ -67,11 +68,10 @@ func Fuzz_GSOSplit(f *testing.F) { }) header.UDP(gsoUDPv6[20:]).Encode(udpFields) - out := make([][]byte, conn.IdealBatchSize) + out := make([]iobuf.View, conn.IdealBatchSize) for i := range out { - out[i] = make([]byte, 65535) + out[i] = iobuf.View{Bytes: make([]byte, 65535)} } - sizes := make([]int, conn.IdealBatchSize) f.Add(gsoTCPv4, int(GSOTCPv4), uint16(40), uint16(20), uint16(16), uint16(100), false) f.Add(gsoUDPv4, int(GSOUDPL4), uint16(28), uint16(20), uint16(6), uint16(100), false) @@ -87,9 +87,9 @@ func Fuzz_GSOSplit(f *testing.F) { GSOSize: gsoSize, NeedsCsum: needsCsum, } - n, _ := GSOSplit(pkt, options, out, sizes, 0) - if n > len(sizes) { - t.Errorf("n (%d) > len(sizes): %d", n, len(sizes)) + n, _ := GSOSplit(pkt, options, out, 0) + if n > len(out) { + t.Errorf("n (%d) > len(out): %d", n, len(out)) } }) } diff --git a/tun/tun.go b/tun/tun.go index 719a60631..5b33b92df 100644 --- a/tun/tun.go +++ b/tun/tun.go @@ -7,6 +7,8 @@ package tun import ( "os" + + "github.com/tailscale/wireguard-go/iobuf" ) type Event int @@ -23,10 +25,11 @@ type Device interface { // Read one or more packets from the Device (without any additional 
headers). // On a successful read it returns the number of packets read, and sets - // packet lengths within the sizes slice. len(sizes) must be >= len(bufs). + // each buf's length to include the read data. + // Zero-valued entries in bufs are allocated by the implementation. // A nonzero offset can be used to instruct the Device on where to begin // reading into each element of the bufs slice. - Read(bufs [][]byte, sizes []int, offset int) (n int, err error) + Read(bufs []iobuf.View, offset int) (n int, err error) // Write one or more packets to the device (without any additional headers). // On a successful write it returns the number of packets written. A nonzero diff --git a/tun/tun_darwin.go b/tun/tun_darwin.go index c9a6c0bc4..f91b11c17 100644 --- a/tun/tun_darwin.go +++ b/tun/tun_darwin.go @@ -16,6 +16,7 @@ import ( "time" "unsafe" + "github.com/tailscale/wireguard-go/iobuf" "golang.org/x/sys/unix" ) @@ -217,7 +218,7 @@ func (tun *NativeTun) Events() <-chan Event { return tun.events } -func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) { +func (tun *NativeTun) Read(bufs []iobuf.View, offset int) (n int, err error) { // TODO: the BSDs look very similar in Read() and Write(). They should be // collapsed, with platform-specific files containing the varying parts of // their implementations. @@ -225,12 +226,12 @@ func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) case err := <-tun.errors: return 0, err default: - buf := bufs[0][offset-4:] - n, err := tun.tunFile.Read(buf[:]) + iobuf.EnsureAllocated(bufs[:1]) + n, err := tun.tunFile.Read(bufs[0].Bytes[offset-4:]) if n < 4 { return 0, err } - sizes[0] = n - 4 + bufs[0].Bytes = bufs[0].Bytes[:offset+n-4] return 1, err } } diff --git a/tun/tun_freebsd.go b/tun/tun_freebsd.go index 7c65fd999..2df631b54 100644 --- a/tun/tun_freebsd.go +++ b/tun/tun_freebsd.go @@ -15,6 +15,7 @@ import ( "syscall" "unsafe" + "github.com/tailscale/wireguard-go/iobuf" "golang.org/x/sys/unix" ) @@ -38,9 +39,9 @@ type ifreqName struct { // Iface requests with a pointer type ifreqPtr struct { - Name [unix.IFNAMSIZ]byte - Data uintptr - _ [16 - unsafe.Sizeof(uintptr(0))]byte + Name [unix.IFNAMSIZ]byte + Bytes uintptr + _ [16 - unsafe.Sizeof(uintptr(0))]byte } // Iface requests with MTU @@ -249,7 +250,7 @@ func CreateTUN(name string, mtu int) (Device, error) { copy(newnp[:], name) var ifr ifreqPtr copy(ifr.Name[:], assignedName) - ifr.Data = uintptr(unsafe.Pointer(&newnp[0])) + ifr.Bytes = uintptr(unsafe.Pointer(&newnp[0])) _, _, errno = unix.Syscall(unix.SYS_IOCTL, uintptr(confd), uintptr(unix.SIOCSIFNAME), uintptr(unsafe.Pointer(&ifr))) if errno != 0 { tunFile.Close() @@ -333,17 +334,17 @@ func (tun *NativeTun) Events() <-chan Event { return tun.events } -func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) { +func (tun *NativeTun) Read(bufs []iobuf.View, offset int) (n int, err error) { select { case err := <-tun.errors: return 0, err default: - buf := bufs[0][offset-4:] - n, err := tun.tunFile.Read(buf[:]) + iobuf.EnsureAllocated(bufs[:1]) + n, err := tun.tunFile.Read(bufs[0].Bytes[offset-4:]) if n < 4 { return 0, err } - sizes[0] = n - 4 + bufs[0].Bytes = bufs[0].Bytes[:offset+n-4] return 1, err } } diff --git a/tun/tun_linux.go b/tun/tun_linux.go index 7cdbf8825..3ef1dc617 100644 --- a/tun/tun_linux.go +++ b/tun/tun_linux.go @@ -18,6 +18,7 @@ import ( "unsafe" "github.com/tailscale/wireguard-go/conn" + "github.com/tailscale/wireguard-go/iobuf" 
"github.com/tailscale/wireguard-go/rwcancel" "golang.org/x/sys/unix" ) @@ -29,6 +30,7 @@ const ( type NativeTun struct { tunFile *os.File + tunRawConn syscall.RawConn index int32 // if index errors chan error // async error handling events chan Event // device related events @@ -49,7 +51,7 @@ type NativeTun struct { readBuff [virtioNetHdrLen + 65535]byte // if vnetHdr every read() is prefixed by virtioNetHdr writeOpMu sync.Mutex // writeOpMu guards the following fields - toWrite []int + toWrite groToWrite tcpGROTable *tcpGROTable udpGROTable *udpGROTable gro groDisablementFlags @@ -354,31 +356,53 @@ func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) { defer func() { tun.tcpGROTable.reset() tun.udpGROTable.reset() + tun.toWrite.reset() tun.writeOpMu.Unlock() }() var ( errs error total int ) - tun.toWrite = tun.toWrite[:0] - if tun.vnetHdr { - err := handleGRO(bufs, offset, tun.tcpGROTable, tun.udpGROTable, tun.gro, &tun.toWrite) - if err != nil { - return 0, err - } - offset -= virtioNetHdrLen - } else { + if !tun.vnetHdr { for i := range bufs { - tun.toWrite = append(tun.toWrite, i) + n, err := tun.tunFile.Write(bufs[i][offset:]) + if errors.Is(err, syscall.EBADFD) { + return total, os.ErrClosed + } + if err != nil { + errs = errors.Join(errs, err) + } else { + total += n + } } + return total, errs } - for _, bufsI := range tun.toWrite { - n, err := tun.tunFile.Write(bufs[bufsI][offset:]) - if errors.Is(err, syscall.EBADFD) { + err := handleGRO(bufs, offset, tun.tcpGROTable, tun.udpGROTable, tun.gro, &tun.toWrite) + if err != nil { + return 0, err + } + for _, nb := range tun.toWrite.iovs { + var werr error + var n int + err := tun.tunRawConn.Write(func(fd uintptr) bool { + for { + n, werr = unix.Writev(int(fd), nb) + if werr == syscall.EINTR { + continue // quick retry on interrupt, EINTR is never returned with partial writes + } + return werr != syscall.EAGAIN // poller retry on "would block" + } + }) + // err is a poller error (e.g. fd closed before the syscall) + // werr is the Writev syscall error itself. + if err != nil { + return total, err + } + if errors.Is(werr, syscall.EBADFD) { return total, os.ErrClosed } - if err != nil { - errs = errors.Join(errs, err) + if werr != nil { + errs = errors.Join(errs, werr) } else { total += n } @@ -387,9 +411,9 @@ func (tun *NativeTun) Write(bufs [][]byte, offset int) (int, error) { } // handleVirtioRead splits in into bufs, leaving offset bytes at the front of -// each buffer. It mutates sizes to reflect the size of each element of bufs, -// and returns the number of packets read. -func handleVirtioRead(in []byte, bufs [][]byte, sizes []int, offset int) (int, error) { +// each buffer. It sets each buffer's Bytes length to reflect the size of each +// element of bufs, and returns the number of packets read. 
+func handleVirtioRead(in []byte, bufs []iobuf.View, offset int) (int, error) { var hdr virtioNetHdr err := hdr.decode(in) if err != nil { @@ -421,17 +445,18 @@ func handleVirtioRead(in []byte, bufs [][]byte, sizes []int, offset int) (int, e options.HdrLen = options.CsumStart + tcpHLen } - return GSOSplit(in, options, bufs, sizes, offset) + return GSOSplit(in, options, bufs, offset) } -func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) { +func (tun *NativeTun) Read(bufs []iobuf.View, offset int) (int, error) { tun.readOpMu.Lock() defer tun.readOpMu.Unlock() select { case err := <-tun.errors: return 0, err default: - readInto := bufs[0][offset:] + iobuf.EnsureAllocated(bufs) + readInto := bufs[0].Bytes[offset:] if tun.vnetHdr { readInto = tun.readBuff[:] } @@ -443,9 +468,9 @@ func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) return 0, err } if tun.vnetHdr { - return handleVirtioRead(readInto[:n], bufs, sizes, offset) + return handleVirtioRead(readInto[:n], bufs, offset) } else { - sizes[0] = n + bufs[0].Bytes = bufs[0].Bytes[:n+offset] return 1, nil } } @@ -577,6 +602,7 @@ func CreateTUN(name string, mtu int) (Device, error) { // CreateTUNFromFile creates a Device from an os.File with the provided MTU. func CreateTUNFromFile(file *os.File, mtu int) (Device, error) { + var err error tun := &NativeTun{ tunFile: file, events: make(chan Event, 5), @@ -584,7 +610,12 @@ func CreateTUNFromFile(file *os.File, mtu int) (Device, error) { statusListenersShutdown: make(chan struct{}), tcpGROTable: newTCPGROTable(), udpGROTable: newUDPGROTable(), - toWrite: make([]int, 0, conn.IdealBatchSize), + toWrite: newGROToWrite(), + } + + tun.tunRawConn, err = tun.tunFile.SyscallConn() + if err != nil { + return nil, err } name, err := tun.Name() @@ -640,7 +671,11 @@ func CreateUnmonitoredTUNFromFD(fd int) (Device, string, error) { errors: make(chan error, 5), tcpGROTable: newTCPGROTable(), udpGROTable: newUDPGROTable(), - toWrite: make([]int, 0, conn.IdealBatchSize), + toWrite: newGROToWrite(), + } + tun.tunRawConn, err = tun.tunFile.SyscallConn() + if err != nil { + return nil, "", err } name, err := tun.Name() if err != nil { diff --git a/tun/tun_openbsd.go b/tun/tun_openbsd.go index ae571b90c..aa25529f4 100644 --- a/tun/tun_openbsd.go +++ b/tun/tun_openbsd.go @@ -15,6 +15,7 @@ import ( "syscall" "unsafe" + "github.com/tailscale/wireguard-go/iobuf" "golang.org/x/sys/unix" ) @@ -204,17 +205,17 @@ func (tun *NativeTun) Events() <-chan Event { return tun.events } -func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) { +func (tun *NativeTun) Read(bufs []iobuf.View, offset int) (n int, err error) { select { case err := <-tun.errors: return 0, err default: - buf := bufs[0][offset-4:] - n, err := tun.tunFile.Read(buf[:]) + iobuf.EnsureAllocated(bufs[:1]) + n, err := tun.tunFile.Read(bufs[0].Bytes[offset-4:]) if n < 4 { return 0, err } - sizes[0] = n - 4 + bufs[0].Bytes = bufs[0].Bytes[:offset+n-4] return 1, err } } diff --git a/tun/tun_plan9.go b/tun/tun_plan9.go index 7b66eadf6..180b3ada4 100644 --- a/tun/tun_plan9.go +++ b/tun/tun_plan9.go @@ -12,6 +12,8 @@ import ( "strconv" "strings" "sync" + + "github.com/tailscale/wireguard-go/iobuf" ) type NativeTun struct { @@ -81,18 +83,19 @@ func (tun *NativeTun) Events() <-chan Event { return tun.events } -func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) { +func (tun *NativeTun) Read(bufs []iobuf.View, offset int) (n int, err error) { select { case err := 
<-tun.errors: return 0, err default: - n, err := tun.dataFile.Read(bufs[0][offset:]) - if n == 1 && bufs[0][offset] == 0 { + iobuf.EnsureAllocated(bufs[:1]) + n, err := tun.dataFile.Read(bufs[0].Bytes[offset:]) + if n == 1 && bufs[0].Bytes[offset] == 0 { // EOF err = io.EOF n = 0 } - sizes[0] = n + bufs[0].Bytes = bufs[0].Bytes[:offset+n] return 1, err } } diff --git a/tun/tun_windows.go b/tun/tun_windows.go index 34f29805d..e9096f165 100644 --- a/tun/tun_windows.go +++ b/tun/tun_windows.go @@ -14,6 +14,7 @@ import ( "time" _ "unsafe" + "github.com/tailscale/wireguard-go/iobuf" "golang.org/x/sys/windows" "golang.zx2c4.com/wintun" ) @@ -144,7 +145,7 @@ func (tun *NativeTun) BatchSize() int { // Note: Read() and Write() assume the caller comes only from a single thread; there's no locking. -func (tun *NativeTun) Read(bufs [][]byte, sizes []int, offset int) (int, error) { +func (tun *NativeTun) Read(bufs []iobuf.View, offset int) (n int, err error) { tun.running.Add(1) defer tun.running.Done() retry: @@ -161,8 +162,9 @@ retry: switch err { case nil: packetSize := len(packet) - copy(bufs[0][offset:], packet) - sizes[0] = packetSize + iobuf.EnsureAllocated(bufs[:1]) + n := copy(bufs[0].Bytes[offset:], packet) + bufs[0].Bytes = bufs[0].Bytes[:offset+n] tun.session.ReleaseReceivePacket(packet) tun.rate.update(uint64(packetSize)) return 1, nil diff --git a/tun/tuntest/tuntest.go b/tun/tuntest/tuntest.go index e7507c26c..9e1c924e7 100644 --- a/tun/tuntest/tuntest.go +++ b/tun/tuntest/tuntest.go @@ -11,6 +11,7 @@ import ( "net/netip" "os" + "github.com/tailscale/wireguard-go/iobuf" "github.com/tailscale/wireguard-go/tun" ) @@ -110,13 +111,15 @@ type chTun struct { func (t *chTun) File() *os.File { return nil } -func (t *chTun) Read(packets [][]byte, sizes []int, offset int) (int, error) { +func (t *chTun) Read(bufs []iobuf.View, offset int) (int, error) { select { case <-t.c.closed: return 0, os.ErrClosed case msg := <-t.c.Outbound: - n := copy(packets[0][offset:], msg) - sizes[0] = n + // TODO: Allocate len(msg) buffer. + iobuf.EnsureAllocated(bufs[:1]) + n := copy(bufs[0].Bytes[offset:], msg) + bufs[0].Bytes = bufs[0].Bytes[:offset+n] return 1, nil } } diff --git a/device/race_disabled_test.go b/waitpool/race_disabled_test.go similarity index 89% rename from device/race_disabled_test.go rename to waitpool/race_disabled_test.go index bb5c45032..24d3d22d3 100644 --- a/device/race_disabled_test.go +++ b/waitpool/race_disabled_test.go @@ -5,6 +5,6 @@ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ -package device +package waitpool const raceEnabled = false diff --git a/device/race_enabled_test.go b/waitpool/race_enabled_test.go similarity index 89% rename from device/race_enabled_test.go rename to waitpool/race_enabled_test.go index 4e9daea79..1193d02be 100644 --- a/device/race_enabled_test.go +++ b/waitpool/race_enabled_test.go @@ -5,6 +5,6 @@ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ -package device +package waitpool const raceEnabled = true diff --git a/waitpool/waitpool.go b/waitpool/waitpool.go new file mode 100644 index 000000000..686407758 --- /dev/null +++ b/waitpool/waitpool.go @@ -0,0 +1,59 @@ +/* SPDX-License-Identifier: MIT + * + * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. + */ + +// Package waitpool provides a sync.Pool wrapper that caps the number of +// concurrently checked-out elements, blocking Get when the cap is reached. +package waitpool + +import ( + "sync" +) + +// WaitPool is a sync.Pool with an optional concurrency cap. 
When max > 0, +// Get blocks once max elements are checked out until a corresponding Put +// returns one. When max == 0 there is no cap and Get never blocks. +type WaitPool struct { + pool sync.Pool + cond sync.Cond + lock sync.Mutex + count int // Get calls not yet Put back + max int +} + +// New returns a WaitPool with the given concurrency cap and constructor. +// A max of 0 (or negative) disables the cap. +func New(max int, newFn func() any) *WaitPool { + if max < 0 { + max = 0 + } + p := &WaitPool{pool: sync.Pool{New: newFn}, max: max} + p.cond = sync.Cond{L: &p.lock} + return p +} + +// Get returns an element from the pool, blocking if the concurrency cap is reached. +func (p *WaitPool) Get() any { + if p.max != 0 { + p.lock.Lock() + for p.count >= p.max { + p.cond.Wait() + } + p.count++ + p.lock.Unlock() + } + return p.pool.Get() +} + +// Put returns an element to the pool and unblocks one waiting Get if any. +func (p *WaitPool) Put(x any) { + p.pool.Put(x) + if p.max == 0 { + return + } + p.lock.Lock() + defer p.lock.Unlock() + p.count-- + p.cond.Signal() +} diff --git a/device/pools_test.go b/waitpool/waitpool_test.go similarity index 88% rename from device/pools_test.go rename to waitpool/waitpool_test.go index 9c3e8d733..1d672bbb0 100644 --- a/device/pools_test.go +++ b/waitpool/waitpool_test.go @@ -3,7 +3,7 @@ * Copyright (C) 2017-2023 WireGuard LLC. All Rights Reserved. */ -package device +package waitpool import ( "math/rand" @@ -30,9 +30,9 @@ func TestWaitPool(t *testing.T) { if workers-4 <= 0 { t.Skip("Not enough cores") } - p := NewWaitPool(uint32(workers-4), func() any { return make([]byte, 16) }) + p := New(workers-4, func() any { return make([]byte, 16) }) wg.Add(workers) - var max atomic.Uint32 + var max atomic.Int64 updateMax := func() { p.lock.Lock() count := p.count @@ -42,10 +42,10 @@ func TestWaitPool(t *testing.T) { } for { old := max.Load() - if count <= old { + if int64(count) <= old { break } - if max.CompareAndSwap(old, count) { + if max.CompareAndSwap(old, int64(count)) { break } } @@ -65,7 +65,7 @@ func TestWaitPool(t *testing.T) { }() } wg.Wait() - if max.Load() != p.max { + if max.Load() != int64(p.max) { t.Errorf("Actual maximum count (%d) != ideal maximum count (%d)", max.Load(), p.max) } } @@ -78,7 +78,7 @@ func BenchmarkWaitPool(b *testing.B) { if workers-4 <= 0 { b.Skip("Not enough cores") } - p := NewWaitPool(uint32(workers-4), func() any { return make([]byte, 16) }) + p := New(workers-4, func() any { return make([]byte, 16) }) wg.Add(workers) b.ResetTimer() for i := 0; i < workers; i++ { @@ -102,7 +102,7 @@ func BenchmarkWaitPoolEmpty(b *testing.B) { if workers-4 <= 0 { b.Skip("Not enough cores") } - p := NewWaitPool(0, func() any { return make([]byte, 16) }) + p := New(0, func() any { return make([]byte, 16) }) wg.Add(workers) b.ResetTimer() for i := 0; i < workers; i++ {
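
Implementation notes with sketches follow. The iobuf package is imported throughout the diff but its source is not included here. A minimal sketch consistent with the call sites, where a zero View means "allocate for me" and Bytes is re-sliced to the packet length after each read, might look like this; the field set and the 65535 allocation size are assumptions taken from the test buffers, not the package's actual definition.

package iobuf

// View is a window over a packet buffer. Readers in the diff above shrink
// Bytes to offset+packet length; a zero-valued View asks the callee to
// allocate on the caller's behalf.
type View struct {
	Bytes []byte
}

// EnsureAllocated backfills zero-valued Views so callers may pass a
// lazily allocated batch. The allocation size mirrors the 65535-byte
// buffers used in the tests above; the real policy may differ.
func EnsureAllocated(views []View) {
	for i := range views {
		if views[i].Bytes == nil {
			views[i].Bytes = make([]byte, 65535)
		}
	}
}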
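Callers of the new Device.Read no longer thread a sizes slice; lengths come back as each View's own length, which means a caller that reuses a batch presumably has to restore each View to full length between reads, since Read and EnsureAllocated only fill nil entries. A hypothetical consumer loop under that assumption, with dev and handle as placeholders:

package main

import "github.com/tailscale/wireguard-go/iobuf"

// readLoop is a hypothetical consumer; dev is any tun.Device-like reader
// and handle stands in for packet processing.
func readLoop(dev interface {
	BatchSize() int
	Read(bufs []iobuf.View, offset int) (int, error)
}, offset int, handle func([]byte)) error {
	bufs := make([]iobuf.View, dev.BatchSize())
	for {
		n, err := dev.Read(bufs, offset)
		if err != nil {
			return err
		}
		for i := 0; i < n; i++ {
			handle(bufs[i].Bytes[offset:]) // length now covers offset+packet
			// Restore full length before reuse; assumed caller duty, since
			// Read only allocates entries whose Bytes are nil.
			bufs[i].Bytes = bufs[i].Bytes[:cap(bufs[i].Bytes)]
		}
	}
}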
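groToWrite and the iovec index constants (virtioNetHdrIdx, headPacketIdx, coalescedPacketsIdx, singlePacketLen, and the maxScatterGatherFragments cap checked by the canCoalesce paths) also live outside this diff. The test literals such as groToWrite{iovs: [][][]byte{{make([]byte, virtioNetHdrLen), udp4b[offset:]}}} pin down the shape: one iovec per output packet, slot 0 the virtio-net header, slot 1 the head packet, slots 2 and up the coalesced payloads. A sketch under those assumptions:

package tun

import "github.com/tailscale/wireguard-go/conn"

// Inferred from the call sites and test literals above; the real
// declarations are not part of this diff and may differ.
const (
	virtioNetHdrIdx     = 0 // iovec slot for the virtio-net header
	headPacketIdx       = 1 // IP + L4 headers plus the first payload
	coalescedPacketsIdx = 2 // merged payloads are appended from here on
	singlePacketLen     = 2 // header slot + head packet: nothing merged yet
)

// groToWrite accumulates one writev iovec per output packet.
type groToWrite struct {
	iovs [][][]byte
}

func newGROToWrite() groToWrite {
	return groToWrite{iovs: make([][][]byte, 0, conn.IdealBatchSize)}
}

// next starts a fresh iovec and returns its index plus a pointer for
// in-place appends. This sketch allocates a zeroed header each time; the
// real implementation presumably recycles buffers across reset(), which
// is why handleGRO clear()s the header slot rather than assuming zeroes.
func (w *groToWrite) next() (int, *[][]byte) {
	w.iovs = append(w.iovs, [][]byte{make([]byte, virtioNetHdrLen)})
	return len(w.iovs) - 1, &w.iovs[len(w.iovs)-1]
}

// reset empties the queue for reuse across Write calls.
func (w *groToWrite) reset() {
	w.iovs = w.iovs[:0]
}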
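The rewritten prepend path in coalesceTCPPackets swaps heads by reference instead of copying payload bytes into a grown buffer: the earlier-sequence packet takes over headPacketIdx and the old head's payload is re-inserted at coalescedPacketsIdx, keeping payloads in sequence order. A toy demonstration of that slice surgery:

package main

import (
	"fmt"
	"slices"
)

func main() {
	vnet := []byte{0}             // stands in for the virtio-net header slot
	oldHead := []byte("HHHHpay1") // 4 header bytes + payload, later seq
	newHead := []byte("HHHHpay0") // earlier seq, arrives later, becomes head
	iov := [][]byte{vnet, oldHead}

	iov[1] = newHead                     // (*iov)[headPacketIdx] = pkt
	payload := oldHead[4:]               // strip item.iphLen+item.tcphLen
	iov = slices.Insert(iov, 2, payload) // keep payloads in sequence order

	fmt.Printf("%q\n", iov[1:]) // ["HHHHpay0" "pay1"]
}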
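The accounting functions likewise stop deriving lengths from a contiguous len(pkt) and instead rewrite the head packet's length fields from the tracked payloadLen. Working the first Test_handleGRO expectation by hand: two 100-byte TCP/IPv4 segments with 20-byte IP and TCP headers coalesce into the iovec {virtioNetHdrLen, 140, 100}, and totalLen = 20 + 20 + 200 = 240, the same value the old flat-buffer test asserted as wantLens[0]:

package main

import (
	"encoding/binary"
	"fmt"
)

func main() {
	const iphLen, tcphLen = 20, 20
	payloadLen := uint16(100 + 100)                 // head payload + one merged segment
	totalLen := uint16(iphLen+tcphLen) + payloadLen // 240, the pre-split wire length

	iph := make([]byte, iphLen)
	binary.BigEndian.PutUint16(iph[2:], totalLen) // IPv4 total-length rewrite
	fmt.Println(totalLen, binary.BigEndian.Uint16(iph[2:]))
	// Output: 240 240, matching the old flat-buffer expectation of 240.
}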
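The extracted waitpool package, unlike the sketches above, is shown in full in this diff. A small usage example (the cap of 2 and the 2 KiB buffer size are arbitrary):

package main

import (
	"fmt"

	"github.com/tailscale/wireguard-go/waitpool"
)

func main() {
	// At most 2 buffers checked out at once; a third Get blocks until Put.
	pool := waitpool.New(2, func() any { return make([]byte, 2048) })

	a := pool.Get().([]byte)
	b := pool.Get().([]byte)
	done := make(chan struct{})
	go func() {
		c := pool.Get().([]byte) // blocks until a or b is returned
		defer pool.Put(c)
		close(done)
	}()
	pool.Put(a) // unblocks the goroutine's Get
	<-done
	pool.Put(b)
	fmt.Println("done")
}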