Skip to content

Commit

Permalink
quicreuse: make it possible to use an application-constructed quic.Tr…
Browse files Browse the repository at this point in the history
…ansport (#3122)

Co-authored-by: Marco Munizaga <[email protected]>
  • Loading branch information
marten-seemann and MarcoPolo authored Jan 10, 2025
1 parent a2993c1 commit 4651a0d
Show file tree
Hide file tree
Showing 5 changed files with 188 additions and 33 deletions.
65 changes: 59 additions & 6 deletions p2p/transport/quicreuse/connmgr.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"crypto/tls"
"errors"
"io"
"net"
"sync"

Expand All @@ -15,6 +16,22 @@ import (
quicmetrics "github.com/quic-go/quic-go/metrics"
)

type QUICListener interface {
Accept(ctx context.Context) (quic.Connection, error)
Close() error
Addr() net.Addr
}

var _ QUICListener = &quic.Listener{}

type QUICTransport interface {
Listen(tlsConf *tls.Config, conf *quic.Config) (QUICListener, error)
Dial(ctx context.Context, addr net.Addr, tlsConf *tls.Config, conf *quic.Config) (quic.Connection, error)
WriteTo(b []byte, addr net.Addr) (int, error)
ReadNonQUICPacket(ctx context.Context, b []byte) (int, net.Addr, error)
io.Closer
}

type ConnManager struct {
reuseUDP4 *reuse
reuseUDP6 *reuse
Expand Down Expand Up @@ -101,6 +118,32 @@ func (c *ConnManager) getReuse(network string) (*reuse, error) {
}
}

// LendTransport is an advanced method used to lend an existing QUICTransport
// to the ConnManager. The ConnManager will close the returned channel when it
// is done with the transport, so that the owner may safely close the transport.
func (c *ConnManager) LendTransport(network string, tr QUICTransport, conn net.PacketConn) (<-chan struct{}, error) {
c.quicListenersMu.Lock()
defer c.quicListenersMu.Unlock()

localAddr, ok := conn.LocalAddr().(*net.UDPAddr)
if !ok {
return nil, errors.New("expected a conn.LocalAddr() to return a *net.UDPAddr")
}

refCountedTr := &refcountedTransport{
QUICTransport: tr,
packetConn: conn,
borrowDoneSignal: make(chan struct{}),
}

var reuse *reuse
reuse, err := c.getReuse(network)
if err != nil {
return nil, err
}
return refCountedTr.borrowDoneSignal, reuse.AddTransport(refCountedTr, localAddr)
}

func (c *ConnManager) ListenQUIC(addr ma.Multiaddr, tlsConf *tls.Config, allowWindowIncrease func(conn quic.Connection, delta uint64) bool) (Listener, error) {
return c.ListenQUICAndAssociate(nil, addr, tlsConf, allowWindowIncrease)
}
Expand Down Expand Up @@ -175,7 +218,7 @@ func (c *ConnManager) SharedNonQUICPacketConn(network string, laddr *net.UDPAddr
ctx: ctx,
ctxCancel: cancel,
owningTransport: t,
tr: &t.Transport,
tr: t.QUICTransport,
}, nil
}
return nil, errors.New("expected to be able to share with a QUIC listener, but the QUIC listener is not using a refcountedTransport. `DisableReuseport` should not be set")
Expand All @@ -201,10 +244,12 @@ func (c *ConnManager) transportForListen(association any, network string, laddr
}
return &singleOwnerTransport{
packetConn: conn,
Transport: quic.Transport{
Conn: conn,
StatelessResetKey: &c.srk,
TokenGeneratorKey: &c.tokenKey,
Transport: &wrappedQUICTransport{
&quic.Transport{
Conn: conn,
StatelessResetKey: &c.srk,
TokenGeneratorKey: &c.tokenKey,
},
},
}, nil
}
Expand Down Expand Up @@ -279,7 +324,7 @@ func (c *ConnManager) TransportWithAssociationForDial(association any, network s
if err != nil {
return nil, err
}
return &singleOwnerTransport{Transport: quic.Transport{Conn: conn, StatelessResetKey: &c.srk}, packetConn: conn}, nil
return &singleOwnerTransport{Transport: &wrappedQUICTransport{&quic.Transport{Conn: conn, StatelessResetKey: &c.srk}}, packetConn: conn}, nil
}

func (c *ConnManager) Protocols() []int {
Expand All @@ -299,3 +344,11 @@ func (c *ConnManager) Close() error {
func (c *ConnManager) ClientConfig() *quic.Config {
return c.clientConfig
}

type wrappedQUICTransport struct {
*quic.Transport
}

func (t *wrappedQUICTransport) Listen(tlsConf *tls.Config, conf *quic.Config) (QUICListener, error) {
return t.Transport.Listen(tlsConf, conf)
}
62 changes: 60 additions & 2 deletions p2p/transport/quicreuse/connmgr_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ func TestConnectionPassedToQUICForListening(t *testing.T) {
quicTr, err := cm.transportForListen(nil, netw, naddr)
require.NoError(t, err)
defer quicTr.Close()
if _, ok := quicTr.(*singleOwnerTransport).Transport.Conn.(quic.OOBCapablePacketConn); !ok {
if _, ok := quicTr.(*singleOwnerTransport).packetConn.(quic.OOBCapablePacketConn); !ok {
t.Fatal("connection passed to quic-go cannot be type asserted to a *net.UDPConn")
}
}
Expand Down Expand Up @@ -156,7 +156,7 @@ func TestConnectionPassedToQUICForDialing(t *testing.T) {

require.NoError(t, err, "dial error")
defer quicTr.Close()
if _, ok := quicTr.(*singleOwnerTransport).Transport.Conn.(quic.OOBCapablePacketConn); !ok {
if _, ok := quicTr.(*singleOwnerTransport).packetConn.(quic.OOBCapablePacketConn); !ok {
t.Fatal("connection passed to quic-go cannot be type asserted to a *net.UDPConn")
}
}
Expand Down Expand Up @@ -257,3 +257,61 @@ func testListener(t *testing.T, enableReuseport bool) {

checkClosed(t, cm)
}

func TestExternalTransport(t *testing.T) {
conn, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero})
require.NoError(t, err)
defer conn.Close()
port := conn.LocalAddr().(*net.UDPAddr).Port
tr := &quic.Transport{Conn: conn}
defer tr.Close()

cm, err := NewConnManager(quic.StatelessResetKey{}, quic.TokenGeneratorKey{})
require.NoError(t, err)
doneWithTr, err := cm.LendTransport("udp4", &wrappedQUICTransport{tr}, conn)
require.NoError(t, err)

// make sure this transport is used when listening on the same port
ln, err := cm.ListenQUICAndAssociate(
"quic",
ma.StringCast(fmt.Sprintf("/ip4/0.0.0.0/udp/%d", port)),
&tls.Config{NextProtos: []string{"libp2p"}},
func(quic.Connection, uint64) bool { return false },
)
require.NoError(t, err)
defer ln.Close()
require.Equal(t, port, ln.Addr().(*net.UDPAddr).Port)

// make sure this transport is used when dialing out
udpLn, err := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4(127, 0, 0, 1)})
require.NoError(t, err)
defer udpLn.Close()
addrChan := make(chan net.Addr, 1)
go func() {
_, addr, _ := udpLn.ReadFrom(make([]byte, 2000))
addrChan <- addr
}()
ctx, cancel := context.WithTimeout(context.Background(), 20*time.Millisecond)
defer cancel()
_, err = cm.DialQUIC(
ctx,
ma.StringCast(fmt.Sprintf("/ip4/127.0.0.1/udp/%d/quic-v1", udpLn.LocalAddr().(*net.UDPAddr).Port)),
&tls.Config{NextProtos: []string{"libp2p"}},
func(quic.Connection, uint64) bool { return false },
)
require.ErrorIs(t, err, context.DeadlineExceeded)

select {
case addr := <-addrChan:
require.Equal(t, port, addr.(*net.UDPAddr).Port)
case <-time.After(time.Second):
t.Fatal("timeout")
}

cm.Close()
select {
case <-doneWithTr:
default:
t.Fatal("doneWithTr not closed")
}
}
2 changes: 1 addition & 1 deletion p2p/transport/quicreuse/listener.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ type protoConf struct {
}

type quicListener struct {
l *quic.Listener
l QUICListener
transport refCountedQuicTransport
running chan struct{}
addrs []ma.Multiaddr
Expand Down
6 changes: 2 additions & 4 deletions p2p/transport/quicreuse/nonquic_packetconn.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,14 @@ import (
"context"
"net"
"time"

"github.com/quic-go/quic-go"
)

// nonQUICPacketConn is a net.PacketConn that can be used to read and write
// non-QUIC packets on a quic.Transport. This lets us reuse this UDP port for
// other transports like WebRTC.
type nonQUICPacketConn struct {
owningTransport refCountedQuicTransport
tr *quic.Transport
tr QUICTransport
ctx context.Context
ctxCancel context.CancelFunc
readCtx context.Context
Expand All @@ -32,7 +30,7 @@ func (n *nonQUICPacketConn) Close() error {

// LocalAddr implements net.PacketConn.
func (n *nonQUICPacketConn) LocalAddr() net.Addr {
return n.tr.Conn.LocalAddr()
return n.owningTransport.LocalAddr()
}

// ReadFrom implements net.PacketConn.
Expand Down
86 changes: 66 additions & 20 deletions p2p/transport/quicreuse/reuse.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ package quicreuse
import (
"context"
"crypto/tls"
"errors"
"fmt"
"net"
"sync"
"time"
Expand All @@ -25,23 +27,30 @@ type refCountedQuicTransport interface {
IncreaseCount()

Dial(ctx context.Context, addr net.Addr, tlsConf *tls.Config, conf *quic.Config) (quic.Connection, error)
Listen(tlsConf *tls.Config, conf *quic.Config) (*quic.Listener, error)
Listen(tlsConf *tls.Config, conf *quic.Config) (QUICListener, error)
}

type singleOwnerTransport struct {
quic.Transport
Transport QUICTransport

// Used to write packets directly around QUIC.
packetConn net.PacketConn
}

var _ QUICTransport = &singleOwnerTransport{}

func (c *singleOwnerTransport) IncreaseCount() {}
func (c *singleOwnerTransport) DecreaseCount() {
c.Transport.Close()
func (c *singleOwnerTransport) DecreaseCount() { c.Transport.Close() }
func (c *singleOwnerTransport) LocalAddr() net.Addr {
return c.packetConn.LocalAddr()
}

func (c *singleOwnerTransport) LocalAddr() net.Addr {
return c.Transport.Conn.LocalAddr()
func (c *singleOwnerTransport) Dial(ctx context.Context, addr net.Addr, tlsConf *tls.Config, conf *quic.Config) (quic.Connection, error) {
return c.Transport.Dial(ctx, addr, tlsConf, conf)
}

func (c *singleOwnerTransport) ReadNonQUICPacket(ctx context.Context, b []byte) (int, net.Addr, error) {
return c.Transport.ReadNonQUICPacket(ctx, b)
}

func (c *singleOwnerTransport) Close() error {
Expand All @@ -54,14 +63,18 @@ func (c *singleOwnerTransport) WriteTo(b []byte, addr net.Addr) (int, error) {
return c.Transport.WriteTo(b, addr)
}

func (c *singleOwnerTransport) Listen(tlsConf *tls.Config, conf *quic.Config) (QUICListener, error) {
return c.Transport.Listen(tlsConf, conf)
}

// Constant. Defined as variables to simplify testing.
var (
garbageCollectInterval = 30 * time.Second
maxUnusedDuration = 10 * time.Second
)

type refcountedTransport struct {
quic.Transport
QUICTransport

// Used to write packets directly around QUIC.
packetConn net.PacketConn
Expand All @@ -70,6 +83,11 @@ type refcountedTransport struct {
refCount int
unusedSince time.Time

// Only set for transports we are borrowing.
// If set, we will _never_ close the underlying transport. We only close this
// channel to signal to the owner that we are done with it.
borrowDoneSignal chan struct{}

assocations map[any]struct{}
}

Expand Down Expand Up @@ -109,17 +127,24 @@ func (c *refcountedTransport) IncreaseCount() {
}

func (c *refcountedTransport) Close() error {
// TODO(when we drop support for go 1.19) use errors.Join
c.Transport.Close()
return c.packetConn.Close()
if c.borrowDoneSignal != nil {
close(c.borrowDoneSignal)
return nil
}

return errors.Join(c.QUICTransport.Close(), c.packetConn.Close())
}

func (c *refcountedTransport) WriteTo(b []byte, addr net.Addr) (int, error) {
return c.Transport.WriteTo(b, addr)
return c.QUICTransport.WriteTo(b, addr)
}

func (c *refcountedTransport) LocalAddr() net.Addr {
return c.Transport.Conn.LocalAddr()
return c.packetConn.LocalAddr()
}

func (c *refcountedTransport) Listen(tlsConf *tls.Config, conf *quic.Config) (QUICListener, error) {
return c.QUICTransport.Listen(tlsConf, conf)
}

func (c *refcountedTransport) DecreaseCount() {
Expand Down Expand Up @@ -302,15 +327,34 @@ func (r *reuse) transportForDialLocked(association any, network string, source *
if err != nil {
return nil, err
}
tr := &refcountedTransport{Transport: quic.Transport{
Conn: conn,
StatelessResetKey: r.statelessResetKey,
TokenGeneratorKey: r.tokenGeneratorKey,
}, packetConn: conn}
tr := &refcountedTransport{
QUICTransport: &wrappedQUICTransport{
Transport: &quic.Transport{
Conn: conn,
StatelessResetKey: r.statelessResetKey,
TokenGeneratorKey: r.tokenGeneratorKey,
},
},
packetConn: conn,
}
r.globalDialers[conn.LocalAddr().(*net.UDPAddr).Port] = tr
return tr, nil
}

func (r *reuse) AddTransport(tr *refcountedTransport, laddr *net.UDPAddr) error {
r.mutex.Lock()
defer r.mutex.Unlock()

if !laddr.IP.IsUnspecified() {
return errors.New("adding transport for specific IP not supported")
}
if _, ok := r.globalDialers[laddr.Port]; ok {
return fmt.Errorf("already have global dialer for port %d", laddr.Port)
}
r.globalDialers[laddr.Port] = tr
return nil
}

func (r *reuse) TransportForListen(network string, laddr *net.UDPAddr) (*refcountedTransport, error) {
r.mutex.Lock()
defer r.mutex.Unlock()
Expand Down Expand Up @@ -351,9 +395,11 @@ func (r *reuse) TransportForListen(network string, laddr *net.UDPAddr) (*refcoun
}
localAddr := conn.LocalAddr().(*net.UDPAddr)
tr := &refcountedTransport{
Transport: quic.Transport{
Conn: conn,
StatelessResetKey: r.statelessResetKey,
QUICTransport: &wrappedQUICTransport{
Transport: &quic.Transport{
Conn: conn,
StatelessResetKey: r.statelessResetKey,
},
},
packetConn: conn,
}
Expand Down

0 comments on commit 4651a0d

Please sign in to comment.