Skip to content

Commit

Permalink
refactor(dslx): rewrite QUICHandshake using TLSHandshakeOption
Browse files Browse the repository at this point in the history
This diff takes advantage of the fact that now the TLSHandshakeOption
are independent of the tlsHandshakeFunc structure, so we can use the
same options for configuring the QUIC handshake.
  • Loading branch information
bassosimone committed Oct 18, 2023
1 parent 57243bb commit e1f5bc9
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 153 deletions.
85 changes: 11 additions & 74 deletions internal/dslx/quic.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,71 +7,31 @@ package dslx
import (
"context"
"crypto/tls"
"crypto/x509"
"io"
"net"
"time"

"github.com/ooni/probe-cli/v3/internal/logx"
"github.com/ooni/probe-cli/v3/internal/netxlite"
"github.com/quic-go/quic-go"
)

// QUICHandshakeOption is an option you can pass to QUICHandshake.
type QUICHandshakeOption func(*quicHandshakeFunc)

// QUICHandshakeOptionInsecureSkipVerify controls whether QUIC verification is enabled.
func QUICHandshakeOptionInsecureSkipVerify(value bool) QUICHandshakeOption {
return func(thf *quicHandshakeFunc) {
thf.InsecureSkipVerify = value
}
}

// QUICHandshakeOptionRootCAs allows to configure custom root CAs.
func QUICHandshakeOptionRootCAs(value *x509.CertPool) QUICHandshakeOption {
return func(thf *quicHandshakeFunc) {
thf.RootCAs = value
}
}

// QUICHandshakeOptionServerName allows to configure the SNI to use.
func QUICHandshakeOptionServerName(value string) QUICHandshakeOption {
return func(thf *quicHandshakeFunc) {
thf.ServerName = value
}
}

// QUICHandshake returns a function performing QUIC handshakes.
func QUICHandshake(rt Runtime, options ...QUICHandshakeOption) Func[
func QUICHandshake(rt Runtime, options ...TLSHandshakeOption) Func[
*Endpoint, *Maybe[*QUICConnection]] {
// See https://github.com/ooni/probe/issues/2413 to understand
// why we're using nil to force netxlite to use the cached
// default Mozilla cert pool.
f := &quicHandshakeFunc{
InsecureSkipVerify: false,
RootCAs: nil,
Rt: rt,
ServerName: "",
}
for _, option := range options {
option(f)
Options: options,
Rt: rt,
}
return f
}

// quicHandshakeFunc performs QUIC handshakes.
type quicHandshakeFunc struct {
// InsecureSkipVerify allows to skip TLS verification.
InsecureSkipVerify bool

// RootCAs contains the Root CAs to use.
RootCAs *x509.CertPool
// Options contains the options.
Options []TLSHandshakeOption

// Rt is the Runtime that owns us.
// Rt is the runtime that owns us.
Rt Runtime

// ServerName is the ServerName to handshake for.
ServerName string
}

// Apply implements Func.
Expand All @@ -80,27 +40,22 @@ func (f *quicHandshakeFunc) Apply(
// create trace
trace := f.Rt.NewTrace(f.Rt.IDGenerator().Add(1), f.Rt.ZeroTime(), input.Tags...)

// use defaults or user-configured overrides
serverName := f.serverName(input)
// create a suitable TLS configuration
config := tlsNewConfig(input.Address, []string{"h3"}, input.Domain, f.Rt.Logger(), f.Options...)

// start the operation logger
ol := logx.NewOperationLogger(
f.Rt.Logger(),
"[#%d] QUICHandshake with %s SNI=%s",
"[#%d] QUICHandshake with %s SNI=%s ALPN=%v",
trace.Index(),
input.Address,
serverName,
config.ServerName,
config.NextProtos,
)

// setup
udpListener := netxlite.NewUDPListener()
quicDialer := trace.NewQUICDialerWithoutResolver(udpListener, f.Rt.Logger())
config := &tls.Config{
NextProtos: []string{"h3"},
InsecureSkipVerify: f.InsecureSkipVerify,
RootCAs: f.RootCAs,
ServerName: serverName,
}
const timeout = 10 * time.Second
ctx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
Expand Down Expand Up @@ -139,24 +94,6 @@ func (f *quicHandshakeFunc) Apply(
}
}

func (f *quicHandshakeFunc) serverName(input *Endpoint) string {
if f.ServerName != "" {
return f.ServerName
}
if input.Domain != "" {
return input.Domain
}
addr, _, err := net.SplitHostPort(input.Address)
if err == nil {
return addr
}
// Note: golang requires a ServerName and fails if it's empty. If the provided
// ServerName is an IP address, however, golang WILL NOT emit any SNI extension
// in the ClientHello, consistently with RFC 6066 Section 3 requirements.
f.Rt.Logger().Warn("TLSHandshake: cannot determine which SNI to use")
return ""
}

// QUICConnection is an established QUIC connection. If you initialize
// manually, init at least the ones marked as MANDATORY.
type QUICConnection struct {
Expand Down
80 changes: 1 addition & 79 deletions internal/dslx/quic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package dslx
import (
"context"
"crypto/tls"
"crypto/x509"
"io"
"testing"
"time"
Expand All @@ -16,28 +15,12 @@ import (

/*
Test cases:
- Get quicHandshakeFunc with options
- Apply quicHandshakeFunc:
- with EOF
- success
- with sni
*/
func TestQUICHandshake(t *testing.T) {
t.Run("Get quicHandshakeFunc with options", func(t *testing.T) {
certpool := x509.NewCertPool()
certpool.AddCert(&x509.Certificate{})

f := QUICHandshake(
NewMinimalRuntime(model.DiscardLogger, time.Now()),
QUICHandshakeOptionInsecureSkipVerify(true),
QUICHandshakeOptionServerName("sni"),
QUICHandshakeOptionRootCAs(certpool),
)
if _, ok := f.(*quicHandshakeFunc); !ok {
t.Fatal("unexpected type. Expected: quicHandshakeFunc")
}
})

t.Run("Apply quicHandshakeFunc", func(t *testing.T) {
wasClosed := false
plainConn := &mocks.QUICEarlyConnection{
Expand Down Expand Up @@ -103,10 +86,7 @@ func TestQUICHandshake(t *testing.T) {
return tt.dialer
},
}))
quicHandshake := &quicHandshakeFunc{
Rt: rt,
ServerName: tt.sni,
}
quicHandshake := QUICHandshake(rt, TLSHandshakeOptionServerName(tt.sni))
endpoint := &Endpoint{
Address: "1.2.3.4:567",
Network: "udp",
Expand Down Expand Up @@ -136,61 +116,3 @@ func TestQUICHandshake(t *testing.T) {
}
})
}

/*
Test cases:
- With input SNI
- With input domain
- With input host address
- With input IP address
*/
func TestServerNameQUIC(t *testing.T) {
t.Run("With input SNI", func(t *testing.T) {
sni := "sni"
endpoint := &Endpoint{
Address: "example.com:123",
}
f := &quicHandshakeFunc{Rt: NewMinimalRuntime(model.DiscardLogger, time.Now()), ServerName: sni}
serverName := f.serverName(endpoint)
if serverName != sni {
t.Fatalf("unexpected server name: %s", serverName)
}
})

t.Run("With input domain", func(t *testing.T) {
domain := "domain"
endpoint := &Endpoint{
Address: "example.com:123",
Domain: domain,
}
f := &quicHandshakeFunc{Rt: NewMinimalRuntime(model.DiscardLogger, time.Now())}
serverName := f.serverName(endpoint)
if serverName != domain {
t.Fatalf("unexpected server name: %s", serverName)
}
})

t.Run("With input host address", func(t *testing.T) {
hostaddr := "example.com"
endpoint := &Endpoint{
Address: hostaddr + ":123",
}
f := &quicHandshakeFunc{Rt: NewMinimalRuntime(model.DiscardLogger, time.Now())}
serverName := f.serverName(endpoint)
if serverName != hostaddr {
t.Fatalf("unexpected server name: %s", serverName)
}
})

t.Run("With input IP address", func(t *testing.T) {
ip := "1.1.1.1"
endpoint := &Endpoint{
Address: ip,
}
f := &quicHandshakeFunc{Rt: NewMinimalRuntime(model.DiscardLogger, time.Now())}
serverName := f.serverName(endpoint)
if serverName != "" {
t.Fatalf("unexpected server name: %s", serverName)
}
})
}

0 comments on commit e1f5bc9

Please sign in to comment.