Skip to content

Commit

Permalink
refactor(dslx): unify TLS and QUIC handshake options (#1378)
Browse files Browse the repository at this point in the history
  • Loading branch information
bassosimone authored Oct 25, 2023
1 parent 18cef86 commit b528666
Show file tree
Hide file tree
Showing 4 changed files with 124 additions and 276 deletions.
83 changes: 10 additions & 73 deletions internal/dslx/quic.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,71 +7,31 @@ package dslx
import (
"context"
"crypto/tls"
"crypto/x509"
"io"
"net"
"time"

"github.com/ooni/probe-cli/v3/internal/logx"
"github.com/ooni/probe-cli/v3/internal/netxlite"
"github.com/quic-go/quic-go"
)

// QUICHandshakeOption is an option you can pass to QUICHandshake.
type QUICHandshakeOption func(*quicHandshakeFunc)

// QUICHandshakeOptionInsecureSkipVerify controls whether QUIC verification is enabled.
func QUICHandshakeOptionInsecureSkipVerify(value bool) QUICHandshakeOption {
return func(thf *quicHandshakeFunc) {
thf.InsecureSkipVerify = value
}
}

// QUICHandshakeOptionRootCAs allows to configure custom root CAs.
func QUICHandshakeOptionRootCAs(value *x509.CertPool) QUICHandshakeOption {
return func(thf *quicHandshakeFunc) {
thf.RootCAs = value
}
}

// QUICHandshakeOptionServerName allows to configure the SNI to use.
func QUICHandshakeOptionServerName(value string) QUICHandshakeOption {
return func(thf *quicHandshakeFunc) {
thf.ServerName = value
}
}

// QUICHandshake returns a function performing QUIC handshakes.
func QUICHandshake(rt Runtime, options ...QUICHandshakeOption) Func[
func QUICHandshake(rt Runtime, options ...TLSHandshakeOption) Func[
*Endpoint, *Maybe[*QUICConnection]] {
// See https://github.com/ooni/probe/issues/2413 to understand
// why we're using nil to force netxlite to use the cached
// default Mozilla cert pool.
f := &quicHandshakeFunc{
InsecureSkipVerify: false,
RootCAs: nil,
Rt: rt,
ServerName: "",
}
for _, option := range options {
option(f)
Options: options,
Rt: rt,
}
return f
}

// quicHandshakeFunc performs QUIC handshakes.
type quicHandshakeFunc struct {
// InsecureSkipVerify allows to skip TLS verification.
InsecureSkipVerify bool

// RootCAs contains the Root CAs to use.
RootCAs *x509.CertPool
// Options contains the options.
Options []TLSHandshakeOption

// Rt is the Runtime that owns us.
Rt Runtime

// ServerName is the ServerName to handshake for.
ServerName string
}

// Apply implements Func.
Expand All @@ -80,27 +40,22 @@ func (f *quicHandshakeFunc) Apply(
// create trace
trace := f.Rt.NewTrace(f.Rt.IDGenerator().Add(1), f.Rt.ZeroTime(), input.Tags...)

// use defaults or user-configured overrides
serverName := f.serverName(input)
// create a suitable TLS configuration
config := tlsNewConfig(input.Address, []string{"h3"}, input.Domain, f.Rt.Logger(), f.Options...)

// start the operation logger
ol := logx.NewOperationLogger(
f.Rt.Logger(),
"[#%d] QUICHandshake with %s SNI=%s",
"[#%d] QUICHandshake with %s SNI=%s ALPN=%v",
trace.Index(),
input.Address,
serverName,
config.ServerName,
config.NextProtos,
)

// setup
udpListener := netxlite.NewUDPListener()
quicDialer := trace.NewQUICDialerWithoutResolver(udpListener, f.Rt.Logger())
config := &tls.Config{
NextProtos: []string{"h3"},
InsecureSkipVerify: f.InsecureSkipVerify,
RootCAs: f.RootCAs,
ServerName: serverName,
}
const timeout = 10 * time.Second
ctx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
Expand Down Expand Up @@ -139,24 +94,6 @@ func (f *quicHandshakeFunc) Apply(
}
}

func (f *quicHandshakeFunc) serverName(input *Endpoint) string {
if f.ServerName != "" {
return f.ServerName
}
if input.Domain != "" {
return input.Domain
}
addr, _, err := net.SplitHostPort(input.Address)
if err == nil {
return addr
}
// Note: golang requires a ServerName and fails if it's empty. If the provided
// ServerName is an IP address, however, golang WILL NOT emit any SNI extension
// in the ClientHello, consistently with RFC 6066 Section 3 requirements.
f.Rt.Logger().Warn("TLSHandshake: cannot determine which SNI to use")
return ""
}

// QUICConnection is an established QUIC connection. If you initialize
// manually, init at least the ones marked as MANDATORY.
type QUICConnection struct {
Expand Down
66 changes: 1 addition & 65 deletions internal/dslx/quic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,6 @@ func TestQUICHandshake(t *testing.T) {

f := QUICHandshake(
NewMinimalRuntime(model.DiscardLogger, time.Now()),
QUICHandshakeOptionInsecureSkipVerify(true),
QUICHandshakeOptionServerName("sni"),
QUICHandshakeOptionRootCAs(certpool),
)
if _, ok := f.(*quicHandshakeFunc); !ok {
t.Fatal("unexpected type. Expected: quicHandshakeFunc")
Expand Down Expand Up @@ -103,10 +100,7 @@ func TestQUICHandshake(t *testing.T) {
return tt.dialer
},
}))
quicHandshake := &quicHandshakeFunc{
Rt: rt,
ServerName: tt.sni,
}
quicHandshake := QUICHandshake(rt, TLSHandshakeOptionServerName(tt.sni))
endpoint := &Endpoint{
Address: "1.2.3.4:567",
Network: "udp",
Expand Down Expand Up @@ -136,61 +130,3 @@ func TestQUICHandshake(t *testing.T) {
}
})
}

/*
Test cases:
- With input SNI
- With input domain
- With input host address
- With input IP address
*/
func TestServerNameQUIC(t *testing.T) {
t.Run("With input SNI", func(t *testing.T) {
sni := "sni"
endpoint := &Endpoint{
Address: "example.com:123",
}
f := &quicHandshakeFunc{Rt: NewMinimalRuntime(model.DiscardLogger, time.Now()), ServerName: sni}
serverName := f.serverName(endpoint)
if serverName != sni {
t.Fatalf("unexpected server name: %s", serverName)
}
})

t.Run("With input domain", func(t *testing.T) {
domain := "domain"
endpoint := &Endpoint{
Address: "example.com:123",
Domain: domain,
}
f := &quicHandshakeFunc{Rt: NewMinimalRuntime(model.DiscardLogger, time.Now())}
serverName := f.serverName(endpoint)
if serverName != domain {
t.Fatalf("unexpected server name: %s", serverName)
}
})

t.Run("With input host address", func(t *testing.T) {
hostaddr := "example.com"
endpoint := &Endpoint{
Address: hostaddr + ":123",
}
f := &quicHandshakeFunc{Rt: NewMinimalRuntime(model.DiscardLogger, time.Now())}
serverName := f.serverName(endpoint)
if serverName != hostaddr {
t.Fatalf("unexpected server name: %s", serverName)
}
})

t.Run("With input IP address", func(t *testing.T) {
ip := "1.1.1.1"
endpoint := &Endpoint{
Address: ip,
}
f := &quicHandshakeFunc{Rt: NewMinimalRuntime(model.DiscardLogger, time.Now())}
serverName := f.serverName(endpoint)
if serverName != "" {
t.Fatalf("unexpected server name: %s", serverName)
}
})
}
Loading

0 comments on commit b528666

Please sign in to comment.