Skip to content

Commit

Permalink
refactor(dslx): reduce tlsHandshakeFunc state
Browse files Browse the repository at this point in the history
This diff reduces the state kept by the tlsHandshakeFunc struct
so that we can apply a transformation similar to the one we applied
for TCPConnect() and implement TLSHandshake() using a pure func.

The overall objective is that of factoring away completixity to
enable manipulating this code more easily.

While there, let's note that the changes applied here mean that we
can reuse this code for configuring tls.Config for the QUICHandshake.
  • Loading branch information
bassosimone committed Oct 18, 2023
1 parent 7c32b7e commit 57243bb
Show file tree
Hide file tree
Showing 2 changed files with 114 additions and 139 deletions.
110 changes: 53 additions & 57 deletions internal/dslx/tls.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,75 +12,58 @@ import (
"time"

"github.com/ooni/probe-cli/v3/internal/logx"
"github.com/ooni/probe-cli/v3/internal/model"
"github.com/ooni/probe-cli/v3/internal/netxlite"
)

// TLSHandshakeOption is an option you can pass to TLSHandshake.
type TLSHandshakeOption func(*tlsHandshakeFunc)
type TLSHandshakeOption func(config *tls.Config)

// TLSHandshakeOptionInsecureSkipVerify controls whether TLS verification is enabled.
func TLSHandshakeOptionInsecureSkipVerify(value bool) TLSHandshakeOption {
return func(thf *tlsHandshakeFunc) {
thf.InsecureSkipVerify = value
return func(config *tls.Config) {
config.InsecureSkipVerify = value
}
}

// TLSHandshakeOptionNextProto allows to configure the ALPN protocols.
func TLSHandshakeOptionNextProto(value []string) TLSHandshakeOption {
return func(thf *tlsHandshakeFunc) {
thf.NextProto = value
return func(config *tls.Config) {
config.NextProtos = value
}
}

// TLSHandshakeOptionRootCAs allows to configure custom root CAs.
func TLSHandshakeOptionRootCAs(value *x509.CertPool) TLSHandshakeOption {
return func(thf *tlsHandshakeFunc) {
thf.RootCAs = value
return func(config *tls.Config) {
config.RootCAs = value
}
}

// TLSHandshakeOptionServerName allows to configure the SNI to use.
func TLSHandshakeOptionServerName(value string) TLSHandshakeOption {
return func(thf *tlsHandshakeFunc) {
thf.ServerName = value
return func(config *tls.Config) {
config.ServerName = value
}
}

// TLSHandshake returns a function performing TSL handshakes.
func TLSHandshake(rt Runtime, options ...TLSHandshakeOption) Func[
*TCPConnection, *Maybe[*TLSConnection]] {
// See https://github.com/ooni/probe/issues/2413 to understand
// why we're using nil to force netxlite to use the cached
// default Mozilla cert pool.
f := &tlsHandshakeFunc{
InsecureSkipVerify: false,
NextProto: []string{},
RootCAs: nil,
Rt: rt,
ServerName: "",
}
for _, option := range options {
option(f)
Options: options,
Rt: rt,
}
return f
}

// tlsHandshakeFunc performs TLS handshakes.
type tlsHandshakeFunc struct {
// InsecureSkipVerify allows to skip TLS verification.
InsecureSkipVerify bool

// NextProto contains the ALPNs to negotiate.
NextProto []string
// Options contains the options.
Options []TLSHandshakeOption

// RootCAs contains the Root CAs to use.
RootCAs *x509.CertPool

// Pool is the Pool that owns us.
// Rt is the runtime that owns us.
Rt Runtime

// ServerName is the ServerName to handshake for.
ServerName string
}

// Apply implements Func.
Expand All @@ -89,30 +72,23 @@ func (f *tlsHandshakeFunc) Apply(
// keep using the same trace
trace := input.Trace

// use defaults or user-configured overrides
serverName := f.serverName(input)
nextProto := f.nextProto()
// create a suitable TLS configuration
config := tlsNewConfig(input.Address, []string{"h2", "http/1.1"}, input.Domain, f.Rt.Logger(), f.Options...)

// start the operation logger
ol := logx.NewOperationLogger(
f.Rt.Logger(),
"[#%d] TLSHandshake with %s SNI=%s ALPN=%v",
trace.Index(),
input.Address,
serverName,
nextProto,
config.ServerName,
config.NextProtos,
)

// obtain the handshaker for use
handshaker := trace.NewTLSHandshakerStdlib(f.Rt.Logger())

// setup
config := &tls.Config{
NextProtos: nextProto,
InsecureSkipVerify: f.InsecureSkipVerify,
RootCAs: f.RootCAs,
ServerName: serverName,
}
const timeout = 10 * time.Second
ctx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
Expand Down Expand Up @@ -143,31 +119,51 @@ func (f *tlsHandshakeFunc) Apply(
}
}

func (f *tlsHandshakeFunc) serverName(input *TCPConnection) string {
if f.ServerName != "" {
return f.ServerName
// tlsNewConfig is an utility function to create a new TLS config.
//
// Arguments:
//
// - address is the endpoint address (e.g., 1.1.1.1:443);
//
// - defaultALPN contains the default to be used for configuring ALPN;
//
// - domain is the possibly empty domain to use;
//
// - logger is the logger to use;
//
// - options contains options to modify the TLS handshake defaults.
func tlsNewConfig(address string, defaultALPN []string, domain string, logger model.Logger, options ...TLSHandshakeOption) *tls.Config {
// See https://github.com/ooni/probe/issues/2413 to understand
// why we're using nil to force netxlite to use the cached
// default Mozilla cert pool.
config := &tls.Config{
NextProtos: append([]string{}, defaultALPN...),
InsecureSkipVerify: false,
RootCAs: nil,
ServerName: tlsServerName(address, domain, logger),
}
if input.Domain != "" {
return input.Domain
for _, option := range options {
option(config)
}
return config
}

// tlsServerName is an utility function to obtina the server name from a TCPConnection.
func tlsServerName(address, domain string, logger model.Logger) string {
if domain != "" {
return domain
}
addr, _, err := net.SplitHostPort(input.Address)
addr, _, err := net.SplitHostPort(address)
if err == nil {
return addr
}
// Note: golang requires a ServerName and fails if it's empty. If the provided
// ServerName is an IP address, however, golang WILL NOT emit any SNI extension
// in the ClientHello, consistently with RFC 6066 Section 3 requirements.
f.Rt.Logger().Warn("TLSHandshake: cannot determine which SNI to use")
logger.Warn("TLSHandshake: cannot determine which SNI to use")
return ""
}

func (f *tlsHandshakeFunc) nextProto() []string {
if len(f.NextProto) > 0 {
return f.NextProto
}
return []string{"h2", "http/1.1"}
}

// TLSConnection is an established TLS connection. If you initialize
// manually, init at least the ones marked as MANDATORY.
type TLSConnection struct {
Expand Down
143 changes: 61 additions & 82 deletions internal/dslx/tls_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,51 +10,66 @@ import (
"testing"
"time"

"github.com/google/go-cmp/cmp"
"github.com/ooni/probe-cli/v3/internal/mocks"
"github.com/ooni/probe-cli/v3/internal/model"
)

/*
Test cases:
- Get tlsHandshakeFunc with options
- Apply tlsHandshakeFunc:
- with EOF
- with invalid address
- with success
- with sni
- with options
*/
func TestTLSHandshake(t *testing.T) {
t.Run("Get tlsHandshakeFunc with options", func(t *testing.T) {
func TestTLSNewConfig(t *testing.T) {
t.Run("without options", func(t *testing.T) {
config := tlsNewConfig("1.1.1.1:443", []string{"h2", "http/1.1"}, "sni", model.DiscardLogger)

if config.InsecureSkipVerify {
t.Fatalf("unexpected %s, expected %v, got %v", "InsecureSkipVerify", false, config.InsecureSkipVerify)
}
if diff := cmp.Diff([]string{"h2", "http/1.1"}, config.NextProtos); diff != "" {
t.Fatal(diff)
}
if config.ServerName != "sni" {
t.Fatalf("unexpected %s, expected %s, got %s", "ServerName", "sni", config.ServerName)
}
if !config.RootCAs.Equal(nil) {
t.Fatalf("unexpected %s, expected %v, got %v", "RootCAs", nil, config.RootCAs)
}
})

t.Run("with options", func(t *testing.T) {
certpool := x509.NewCertPool()
certpool.AddCert(&x509.Certificate{})

f := TLSHandshake(
NewMinimalRuntime(model.DiscardLogger, time.Now()),
config := tlsNewConfig(
"1.1.1.1:443", []string{"h2", "http/1.1"}, "sni", model.DiscardLogger,
TLSHandshakeOptionInsecureSkipVerify(true),
TLSHandshakeOptionNextProto([]string{"h2"}),
TLSHandshakeOptionServerName("sni"),
TLSHandshakeOptionServerName("example.domain"),
TLSHandshakeOptionRootCAs(certpool),
)
var handshakeFunc *tlsHandshakeFunc
var ok bool
if handshakeFunc, ok = f.(*tlsHandshakeFunc); !ok {
t.Fatal("unexpected type. Expected: tlsHandshakeFunc")
}
if !handshakeFunc.InsecureSkipVerify {
t.Fatalf("unexpected %s, expected %v, got %v", "InsecureSkipVerify", true, false)

if !config.InsecureSkipVerify {
t.Fatalf("unexpected %s, expected %v, got %v", "InsecureSkipVerify", true, config.InsecureSkipVerify)
}
if len(handshakeFunc.NextProto) != 1 || handshakeFunc.NextProto[0] != "h2" {
t.Fatalf("unexpected %s, expected %v, got %v", "NextProto", []string{"h2"}, handshakeFunc.NextProto)
if diff := cmp.Diff([]string{"h2"}, config.NextProtos); diff != "" {
t.Fatal(diff)
}
if handshakeFunc.ServerName != "sni" {
t.Fatalf("unexpected %s, expected %s, got %s", "ServerName", "sni", handshakeFunc.ServerName)
if config.ServerName != "example.domain" {
t.Fatalf("unexpected %s, expected %s, got %s", "ServerName", "example.domain", config.ServerName)
}
if !handshakeFunc.RootCAs.Equal(certpool) {
t.Fatalf("unexpected %s, expected %v, got %v", "RootCAs", certpool, handshakeFunc.RootCAs)
if !config.RootCAs.Equal(certpool) {
t.Fatalf("unexpected %s, expected %v, got %v", "RootCAs", nil, config.RootCAs)
}
})
}

/*
Test cases:
- Apply tlsHandshakeFunc:
- with EOF
- with invalid address
- with success
- with sni
- with options
*/
func TestTLSHandshake(t *testing.T) {
t.Run("Apply tlsHandshakeFunc", func(t *testing.T) {
wasClosed := false

Expand Down Expand Up @@ -137,11 +152,10 @@ func TestTLSHandshake(t *testing.T) {
return tt.handshaker
},
}))
tlsHandshake := &tlsHandshakeFunc{
NextProto: tt.config.nextProtos,
Rt: rt,
ServerName: tt.config.sni,
}
tlsHandshake := TLSHandshake(rt,
TLSHandshakeOptionNextProto(tt.config.nextProtos),
TLSHandshakeOptionServerName(tt.config.sni),
)
idGen := &atomic.Int64{}
zeroTime := time.Time{}
trace := rt.NewTrace(idGen.Add(1), zeroTime)
Expand Down Expand Up @@ -174,62 +188,27 @@ func TestTLSHandshake(t *testing.T) {

/*
Test cases:
- With input SNI
- With input domain
- With input host address
- With input IP address
- With domain
- With host address
- With IP address
*/
func TestServerNameTLS(t *testing.T) {
t.Run("With input SNI", func(t *testing.T) {
sni := "sni"
tcpConn := TCPConnection{
Address: "example.com:123",
}
f := &tlsHandshakeFunc{
Rt: NewMinimalRuntime(model.DiscardLogger, time.Now()),
ServerName: sni,
}
serverName := f.serverName(&tcpConn)
if serverName != sni {
func TestTLSServerName(t *testing.T) {
t.Run("With domain", func(t *testing.T) {
serverName := tlsServerName("example.com:123", "domain", model.DiscardLogger)
if serverName != "domain" {
t.Fatalf("unexpected server name: %s", serverName)
}
})
t.Run("With input domain", func(t *testing.T) {
domain := "domain"
tcpConn := TCPConnection{
Address: "example.com:123",
Domain: domain,
}
f := &tlsHandshakeFunc{
Rt: NewMinimalRuntime(model.DiscardLogger, time.Now()),
}
serverName := f.serverName(&tcpConn)
if serverName != domain {
t.Fatalf("unexpected server name: %s", serverName)
}
})
t.Run("With input host address", func(t *testing.T) {
hostaddr := "example.com"
tcpConn := TCPConnection{
Address: hostaddr + ":123",
}
f := &tlsHandshakeFunc{
Rt: NewMinimalRuntime(model.DiscardLogger, time.Now()),
}
serverName := f.serverName(&tcpConn)
if serverName != hostaddr {

t.Run("With host address", func(t *testing.T) {
serverName := tlsServerName("1.1.1.1:443", "", model.DiscardLogger)
if serverName != "1.1.1.1" {
t.Fatalf("unexpected server name: %s", serverName)
}
})
t.Run("With input IP address", func(t *testing.T) {
ip := "1.1.1.1"
tcpConn := TCPConnection{
Address: ip,
}
f := &tlsHandshakeFunc{
Rt: NewMinimalRuntime(model.DiscardLogger, time.Now()),
}
serverName := f.serverName(&tcpConn)

t.Run("With IP address", func(t *testing.T) {
serverName := tlsServerName("1.1.1.1", "", model.DiscardLogger)
if serverName != "" {
t.Fatalf("unexpected server name: %s", serverName)
}
Expand Down

0 comments on commit 57243bb

Please sign in to comment.