Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 10 additions & 9 deletions conn/bind_std.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ import (
"sync"
"syscall"

"github.com/tailscale/wireguard-go/iobuf"
"golang.org/x/net/ipv4"
"golang.org/x/net/ipv6"
)
Expand Down Expand Up @@ -233,13 +234,13 @@ func (s *StdNetBind) receiveIP(
br batchReader,
conn *net.UDPConn,
rxOffload bool,
bufs [][]byte,
sizes []int,
bufs []iobuf.View,
eps []Endpoint,
) (n int, err error) {
msgs := s.getMessages()
iobuf.EnsureAllocated(bufs)
for i := range bufs {
(*msgs)[i].Buffers[0] = bufs[i]
(*msgs)[i].Buffers[0] = bufs[i].Bytes
(*msgs)[i].OOB = (*msgs)[i].OOB[:cap((*msgs)[i].OOB)]
}
defer s.putMessages(msgs)
Expand Down Expand Up @@ -271,8 +272,8 @@ func (s *StdNetBind) receiveIP(
}
for i := 0; i < numMsgs; i++ {
msg := &(*msgs)[i]
sizes[i] = msg.N
if sizes[i] == 0 {
bufs[i].Bytes = bufs[i].Bytes[:msg.N]
if len(bufs[i].Bytes) == 0 {
continue
}
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
Expand All @@ -284,14 +285,14 @@ func (s *StdNetBind) receiveIP(
}

func (s *StdNetBind) makeReceiveIPv4(pc *ipv4.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc {
return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps)
return func(bufs []iobuf.View, eps []Endpoint) (n int, err error) {
return s.receiveIP(pc, conn, rxOffload, bufs, eps)
}
}

func (s *StdNetBind) makeReceiveIPv6(pc *ipv6.PacketConn, conn *net.UDPConn, rxOffload bool) ReceiveFunc {
return func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) {
return s.receiveIP(pc, conn, rxOffload, bufs, sizes, eps)
return func(bufs []iobuf.View, eps []Endpoint) (n int, err error) {
return s.receiveIP(pc, conn, rxOffload, bufs, eps)
}
}

Expand Down
8 changes: 4 additions & 4 deletions conn/bind_std_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"net"
"testing"

"github.com/tailscale/wireguard-go/iobuf"
"golang.org/x/net/ipv6"
)

Expand All @@ -15,15 +16,14 @@ func TestStdNetBindReceiveFuncAfterClose(t *testing.T) {
t.Fatal(err)
}
bind.Close()
bufs := make([][]byte, 1)
bufs[0] = make([]byte, 1)
sizes := make([]int, 1)
bufs := make([]iobuf.View, 1)
bufs[0] = iobuf.View{Bytes: make([]byte, 1)}
eps := make([]Endpoint, 1)
for _, fn := range fns {
// The ReceiveFuncs must not access conn-related fields on StdNetBind
// unguarded. Close() nils the conn-related fields resulting in a panic
// if they violate the mutex.
fn(bufs, sizes, eps)
fn(bufs, eps)
}
}

Expand Down
15 changes: 9 additions & 6 deletions conn/bind_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"golang.org/x/sys/windows"

"github.com/tailscale/wireguard-go/conn/winrio"
"github.com/tailscale/wireguard-go/iobuf"
)

const (
Expand Down Expand Up @@ -416,20 +417,22 @@ retry:
return n, &ep, nil
}

func (bind *WinRingBind) receiveIPv4(bufs [][]byte, sizes []int, eps []Endpoint) (int, error) {
func (bind *WinRingBind) receiveIPv4(bufs []iobuf.View, eps []Endpoint) (int, error) {
bind.mu.RLock()
defer bind.mu.RUnlock()
n, ep, err := bind.v4.Receive(bufs[0], &bind.isOpen)
sizes[0] = n
iobuf.EnsureAllocated(bufs[:1])
n, ep, err := bind.v4.Receive(bufs[0].Bytes, &bind.isOpen)
bufs[0].Bytes = bufs[0].Bytes[:n]
eps[0] = ep
return 1, err
}

func (bind *WinRingBind) receiveIPv6(bufs [][]byte, sizes []int, eps []Endpoint) (int, error) {
func (bind *WinRingBind) receiveIPv6(bufs []iobuf.View, eps []Endpoint) (int, error) {
bind.mu.RLock()
defer bind.mu.RUnlock()
n, ep, err := bind.v6.Receive(bufs[0], &bind.isOpen)
sizes[0] = n
iobuf.EnsureAllocated(bufs[:1])
n, ep, err := bind.v6.Receive(bufs[0].Bytes, &bind.isOpen)
bufs[0].Bytes = bufs[0].Bytes[:n]
eps[0] = ep
return 1, err
}
Expand Down
8 changes: 5 additions & 3 deletions conn/bindtest/bindtest.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ import (
"os"

"github.com/tailscale/wireguard-go/conn"
"github.com/tailscale/wireguard-go/iobuf"
)

type ChannelBind struct {
Expand Down Expand Up @@ -94,13 +95,14 @@ func (c *ChannelBind) BatchSize() int { return 1 }
func (c *ChannelBind) SetMark(mark uint32) error { return nil }

func (c *ChannelBind) makeReceiveFunc(ch chan []byte) conn.ReceiveFunc {
return func(bufs [][]byte, sizes []int, eps []conn.Endpoint) (n int, err error) {
return func(bufs []iobuf.View, eps []conn.Endpoint) (n int, err error) {
select {
case <-c.closeSignal:
return 0, net.ErrClosed
case rx := <-ch:
copied := copy(bufs[0], rx)
sizes[0] = copied
iobuf.EnsureAllocated(bufs[:1])
n := copy(bufs[0].Bytes, rx)
bufs[0].Bytes = bufs[0].Bytes[:n]
eps[0] = c.target6
return 1, nil
}
Expand Down
15 changes: 8 additions & 7 deletions conn/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,19 +13,20 @@ import (
"reflect"
"runtime"
"strings"

"github.com/tailscale/wireguard-go/iobuf"
)

const (
IdealBatchSize = 128 // maximum number of packets handled per read and write
)

// A ReceiveFunc receives at least one packet from the network and writes them
// into packets. On a successful read it returns the number of elements of
// sizes, packets, and endpoints that should be evaluated. Some elements of
// sizes may be zero, and callers should ignore them. Callers must pass a sizes
// and eps slice with a length greater than or equal to the length of packets.
// These lengths must not exceed the length of the associated Bind.BatchSize().
type ReceiveFunc func(packets [][]byte, sizes []int, eps []Endpoint) (n int, err error)
// A ReceiveFunc receives at least one packet from the network into bufs.
// On a successful read it returns the number of elements of bufs and eps
// that should be evaluated. Callers must pass an eps slice with a length
// greater than or equal to the length of bufs. These lengths must not
// exceed the length of the associated Bind.BatchSize().
type ReceiveFunc func(bufs []iobuf.View, eps []Endpoint) (n int, err error)

// A Bind listens on a port for both IPv6 and IPv4 UDP traffic.
//
Expand Down
4 changes: 3 additions & 1 deletion conn/conn_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,13 @@ package conn

import (
"testing"

"github.com/tailscale/wireguard-go/iobuf"
)

func TestPrettyName(t *testing.T) {
var (
recvFunc ReceiveFunc = func(bufs [][]byte, sizes []int, eps []Endpoint) (n int, err error) { return }
recvFunc ReceiveFunc = func(bufs []iobuf.View, eps []Endpoint) (n int, err error) { return }
)

const want = "TestPrettyName"
Expand Down
4 changes: 2 additions & 2 deletions device/channels.go
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ func (device *Device) flushInboundQueue(q *autodrainingInboundQueue) {
case elemsContainer := <-q.c:
elemsContainer.Lock()
for _, elem := range elemsContainer.elems {
device.PutMessageBuffer(elem.buffer)
elem.buffer.Release()
device.PutInboundElement(elem)
}
device.PutInboundElementsContainer(elemsContainer)
Expand Down Expand Up @@ -126,7 +126,7 @@ func (device *Device) flushOutboundQueue(q *autodrainingOutboundQueue) {
case elemsContainer := <-q.c:
elemsContainer.Lock()
for _, elem := range elemsContainer.elems {
device.PutMessageBuffer(elem.buffer)
elem.buffer.Release()
device.PutOutboundElement(elem)
}
device.PutOutboundElementsContainer(elemsContainer)
Expand Down
8 changes: 5 additions & 3 deletions device/constants.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ package device

import (
"time"

"github.com/tailscale/wireguard-go/iobuf"
)

/* Specification constants */
Expand All @@ -27,9 +29,9 @@ const (
)

const (
MinMessageSize = MessageKeepaliveSize // minimum size of transport message (keepalive)
MaxMessageSize = MaxSegmentSize // maximum size of transport message
MaxContentSize = MaxSegmentSize - MessageTransportSize - MessageEncapsulatingTransportSize // maximum size of transport message content
MinMessageSize = MessageKeepaliveSize // minimum size of transport message (keepalive)
MaxMessageSize = iobuf.MaxReadSize // maximum size of transport message
MaxContentSize = iobuf.MaxReadSize - MessageTransportSize - MessageEncapsulatingTransportSize // maximum size of transport message content
)

/* Implementation constants */
Expand Down
10 changes: 5 additions & 5 deletions device/device.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"github.com/tailscale/wireguard-go/ratelimiter"
"github.com/tailscale/wireguard-go/rwcancel"
"github.com/tailscale/wireguard-go/tun"
"github.com/tailscale/wireguard-go/waitpool"
)

type Device struct {
Expand Down Expand Up @@ -71,11 +72,10 @@ type Device struct {
cookieChecker CookieChecker

pool struct {
inboundElementsContainer *WaitPool
outboundElementsContainer *WaitPool
messageBuffers *WaitPool
inboundElements *WaitPool
outboundElements *WaitPool
inboundElementsContainer *waitpool.WaitPool
outboundElementsContainer *waitpool.WaitPool
inboundElements *waitpool.WaitPool
outboundElements *waitpool.WaitPool
}

queue struct {
Expand Down
3 changes: 2 additions & 1 deletion device/device_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (

"github.com/tailscale/wireguard-go/conn"
"github.com/tailscale/wireguard-go/conn/bindtest"
"github.com/tailscale/wireguard-go/iobuf"
"github.com/tailscale/wireguard-go/tun"
"github.com/tailscale/wireguard-go/tun/tuntest"
)
Expand Down Expand Up @@ -437,7 +438,7 @@ type fakeTUNDeviceSized struct {
}

func (t *fakeTUNDeviceSized) File() *os.File { return nil }
func (t *fakeTUNDeviceSized) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
func (t *fakeTUNDeviceSized) Read(bufs []iobuf.View, offset int) (n int, err error) {
return 0, nil
}
func (t *fakeTUNDeviceSized) Write(bufs [][]byte, offset int) (int, error) { return 0, nil }
Expand Down
60 changes: 8 additions & 52 deletions device/pools.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,61 +7,24 @@ package device

import (
"sync"
)

type WaitPool struct {
pool sync.Pool
cond sync.Cond
lock sync.Mutex
count uint32 // Get calls not yet Put back
max uint32
}

func NewWaitPool(max uint32, new func() any) *WaitPool {
p := &WaitPool{pool: sync.Pool{New: new}, max: max}
p.cond = sync.Cond{L: &p.lock}
return p
}

func (p *WaitPool) Get() any {
if p.max != 0 {
p.lock.Lock()
for p.count >= p.max {
p.cond.Wait()
}
p.count++
p.lock.Unlock()
}
return p.pool.Get()
}

func (p *WaitPool) Put(x any) {
p.pool.Put(x)
if p.max == 0 {
return
}
p.lock.Lock()
defer p.lock.Unlock()
p.count--
p.cond.Signal()
}
"github.com/tailscale/wireguard-go/iobuf"
"github.com/tailscale/wireguard-go/waitpool"
)

func (device *Device) PopulatePools() {
device.pool.inboundElementsContainer = NewWaitPool(PreallocatedBuffersPerPool, func() any {
device.pool.inboundElementsContainer = waitpool.New(iobuf.MaxPooledBuffers, func() any {
s := make([]*QueueInboundElement, 0, device.BatchSize())
return &QueueInboundElementsContainer{elems: s}
})
device.pool.outboundElementsContainer = NewWaitPool(PreallocatedBuffersPerPool, func() any {
device.pool.outboundElementsContainer = waitpool.New(iobuf.MaxPooledBuffers, func() any {
s := make([]*QueueOutboundElement, 0, device.BatchSize())
return &QueueOutboundElementsContainer{elems: s}
})
device.pool.messageBuffers = NewWaitPool(PreallocatedBuffersPerPool, func() any {
return new([MaxMessageSize]byte)
})
device.pool.inboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() any {
device.pool.inboundElements = waitpool.New(iobuf.MaxPooledBuffers, func() any {
return new(QueueInboundElement)
})
device.pool.outboundElements = NewWaitPool(PreallocatedBuffersPerPool, func() any {
device.pool.outboundElements = waitpool.New(iobuf.MaxPooledBuffers, func() any {
return new(QueueOutboundElement)
})
}
Expand Down Expand Up @@ -94,14 +57,6 @@ func (device *Device) PutOutboundElementsContainer(c *QueueOutboundElementsConta
device.pool.outboundElementsContainer.Put(c)
}

func (device *Device) GetMessageBuffer() *[MaxMessageSize]byte {
return device.pool.messageBuffers.Get().(*[MaxMessageSize]byte)
}

func (device *Device) PutMessageBuffer(msg *[MaxMessageSize]byte) {
device.pool.messageBuffers.Put(msg)
}

func (device *Device) GetInboundElement() *QueueInboundElement {
return device.pool.inboundElements.Get().(*QueueInboundElement)
}
Expand All @@ -117,5 +72,6 @@ func (device *Device) GetOutboundElement() *QueueOutboundElement {

func (device *Device) PutOutboundElement(elem *QueueOutboundElement) {
elem.clearPointers()
elem.nonce = 0
device.pool.outboundElements.Put(elem)
}
10 changes: 4 additions & 6 deletions device/queueconstants_android.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,8 @@ import "github.com/tailscale/wireguard-go/conn"
/* Reduce memory consumption for Android */

const (
QueueStagedSize = conn.IdealBatchSize
QueueOutboundSize = 1024
QueueInboundSize = 1024
QueueHandshakeSize = 1024
MaxSegmentSize = 2200
PreallocatedBuffersPerPool = 4096
QueueStagedSize = conn.IdealBatchSize
QueueOutboundSize = 1024
QueueInboundSize = 1024
QueueHandshakeSize = 1024
)
10 changes: 4 additions & 6 deletions device/queueconstants_default.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,8 @@ package device
import "github.com/tailscale/wireguard-go/conn"

const (
QueueStagedSize = conn.IdealBatchSize
QueueOutboundSize = 1024
QueueInboundSize = 1024
QueueHandshakeSize = 1024
MaxSegmentSize = (1 << 16) - 1 // largest possible UDP datagram
PreallocatedBuffersPerPool = 0 // Disable and allow for infinite memory growth
QueueStagedSize = conn.IdealBatchSize
QueueOutboundSize = 1024
QueueInboundSize = 1024
QueueHandshakeSize = 1024
)
Loading
Loading