Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(dslx): unify TLS and QUIC handshake options #1378

Merged
merged 11 commits into from
Oct 25, 2023
85 changes: 11 additions & 74 deletions internal/dslx/quic.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,71 +7,31 @@ package dslx
import (
"context"
"crypto/tls"
"crypto/x509"
"io"
"net"
"time"

"github.com/ooni/probe-cli/v3/internal/logx"
"github.com/ooni/probe-cli/v3/internal/netxlite"
"github.com/quic-go/quic-go"
)

// QUICHandshakeOption is an option you can pass to QUICHandshake.
type QUICHandshakeOption func(*quicHandshakeFunc)

// QUICHandshakeOptionInsecureSkipVerify controls whether QUIC verification is enabled.
func QUICHandshakeOptionInsecureSkipVerify(value bool) QUICHandshakeOption {
return func(thf *quicHandshakeFunc) {
thf.InsecureSkipVerify = value
}
}

// QUICHandshakeOptionRootCAs allows to configure custom root CAs.
func QUICHandshakeOptionRootCAs(value *x509.CertPool) QUICHandshakeOption {
return func(thf *quicHandshakeFunc) {
thf.RootCAs = value
}
}

// QUICHandshakeOptionServerName allows to configure the SNI to use.
func QUICHandshakeOptionServerName(value string) QUICHandshakeOption {
return func(thf *quicHandshakeFunc) {
thf.ServerName = value
}
}

// QUICHandshake returns a function performing QUIC handshakes.
func QUICHandshake(rt Runtime, options ...QUICHandshakeOption) Func[
func QUICHandshake(rt Runtime, options ...TLSHandshakeOption) Func[
*Endpoint, *Maybe[*QUICConnection]] {
// See https://github.com/ooni/probe/issues/2413 to understand
// why we're using nil to force netxlite to use the cached
// default Mozilla cert pool.
f := &quicHandshakeFunc{
InsecureSkipVerify: false,
RootCAs: nil,
Rt: rt,
ServerName: "",
}
for _, option := range options {
option(f)
Options: options,
Rt: rt,
}
return f
}

// quicHandshakeFunc performs QUIC handshakes.
type quicHandshakeFunc struct {
// InsecureSkipVerify allows to skip TLS verification.
InsecureSkipVerify bool

// RootCAs contains the Root CAs to use.
RootCAs *x509.CertPool
// Options contains the options.
Options []TLSHandshakeOption

// Rt is the Runtime that owns us.
// Rt is the runtime that owns us.
bassosimone marked this conversation as resolved.
Show resolved Hide resolved
Rt Runtime

// ServerName is the ServerName to handshake for.
ServerName string
}

// Apply implements Func.
Expand All @@ -80,27 +40,22 @@ func (f *quicHandshakeFunc) Apply(
// create trace
trace := f.Rt.NewTrace(f.Rt.IDGenerator().Add(1), f.Rt.ZeroTime(), input.Tags...)

// use defaults or user-configured overrides
serverName := f.serverName(input)
// create a suitable TLS configuration
config := tlsNewConfig(input.Address, []string{"h3"}, input.Domain, f.Rt.Logger(), f.Options...)

// start the operation logger
ol := logx.NewOperationLogger(
f.Rt.Logger(),
"[#%d] QUICHandshake with %s SNI=%s",
"[#%d] QUICHandshake with %s SNI=%s ALPN=%v",
trace.Index(),
input.Address,
serverName,
config.ServerName,
config.NextProtos,
)

// setup
udpListener := netxlite.NewUDPListener()
quicDialer := trace.NewQUICDialerWithoutResolver(udpListener, f.Rt.Logger())
config := &tls.Config{
NextProtos: []string{"h3"},
InsecureSkipVerify: f.InsecureSkipVerify,
RootCAs: f.RootCAs,
ServerName: serverName,
}
const timeout = 10 * time.Second
ctx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
Expand Down Expand Up @@ -139,24 +94,6 @@ func (f *quicHandshakeFunc) Apply(
}
}

func (f *quicHandshakeFunc) serverName(input *Endpoint) string {
if f.ServerName != "" {
return f.ServerName
}
if input.Domain != "" {
return input.Domain
}
addr, _, err := net.SplitHostPort(input.Address)
if err == nil {
return addr
}
// Note: golang requires a ServerName and fails if it's empty. If the provided
// ServerName is an IP address, however, golang WILL NOT emit any SNI extension
// in the ClientHello, consistently with RFC 6066 Section 3 requirements.
f.Rt.Logger().Warn("TLSHandshake: cannot determine which SNI to use")
return ""
}

// QUICConnection is an established QUIC connection. If you initialize
// manually, init at least the ones marked as MANDATORY.
type QUICConnection struct {
Expand Down
80 changes: 1 addition & 79 deletions internal/dslx/quic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package dslx
import (
"context"
"crypto/tls"
"crypto/x509"
"io"
"testing"
"time"
Expand All @@ -16,28 +15,12 @@ import (

/*
Test cases:
- Get quicHandshakeFunc with options
- Apply quicHandshakeFunc:
- with EOF
- success
- with sni
*/
func TestQUICHandshake(t *testing.T) {
t.Run("Get quicHandshakeFunc with options", func(t *testing.T) {
certpool := x509.NewCertPool()
certpool.AddCert(&x509.Certificate{})

f := QUICHandshake(
NewMinimalRuntime(model.DiscardLogger, time.Now()),
QUICHandshakeOptionInsecureSkipVerify(true),
QUICHandshakeOptionServerName("sni"),
QUICHandshakeOptionRootCAs(certpool),
)
if _, ok := f.(*quicHandshakeFunc); !ok {
t.Fatal("unexpected type. Expected: quicHandshakeFunc")
}
})

t.Run("Apply quicHandshakeFunc", func(t *testing.T) {
wasClosed := false
plainConn := &mocks.QUICEarlyConnection{
Expand Down Expand Up @@ -103,10 +86,7 @@ func TestQUICHandshake(t *testing.T) {
return tt.dialer
},
}))
quicHandshake := &quicHandshakeFunc{
Rt: rt,
ServerName: tt.sni,
}
quicHandshake := QUICHandshake(rt, TLSHandshakeOptionServerName(tt.sni))
endpoint := &Endpoint{
Address: "1.2.3.4:567",
Network: "udp",
Expand Down Expand Up @@ -136,61 +116,3 @@ func TestQUICHandshake(t *testing.T) {
}
})
}

/*
Test cases:
- With input SNI
- With input domain
- With input host address
- With input IP address
*/
func TestServerNameQUIC(t *testing.T) {
t.Run("With input SNI", func(t *testing.T) {
sni := "sni"
endpoint := &Endpoint{
Address: "example.com:123",
}
f := &quicHandshakeFunc{Rt: NewMinimalRuntime(model.DiscardLogger, time.Now()), ServerName: sni}
serverName := f.serverName(endpoint)
if serverName != sni {
t.Fatalf("unexpected server name: %s", serverName)
}
})

t.Run("With input domain", func(t *testing.T) {
domain := "domain"
endpoint := &Endpoint{
Address: "example.com:123",
Domain: domain,
}
f := &quicHandshakeFunc{Rt: NewMinimalRuntime(model.DiscardLogger, time.Now())}
serverName := f.serverName(endpoint)
if serverName != domain {
t.Fatalf("unexpected server name: %s", serverName)
}
})

t.Run("With input host address", func(t *testing.T) {
hostaddr := "example.com"
endpoint := &Endpoint{
Address: hostaddr + ":123",
}
f := &quicHandshakeFunc{Rt: NewMinimalRuntime(model.DiscardLogger, time.Now())}
serverName := f.serverName(endpoint)
if serverName != hostaddr {
t.Fatalf("unexpected server name: %s", serverName)
}
})

t.Run("With input IP address", func(t *testing.T) {
ip := "1.1.1.1"
endpoint := &Endpoint{
Address: ip,
}
f := &quicHandshakeFunc{Rt: NewMinimalRuntime(model.DiscardLogger, time.Now())}
serverName := f.serverName(endpoint)
if serverName != "" {
t.Fatalf("unexpected server name: %s", serverName)
}
})
}
Loading