From 57243bb101195247cd16c0d619c1eb673b6c9821 Mon Sep 17 00:00:00 2001 From: Simone Basso Date: Wed, 18 Oct 2023 11:41:31 +0200 Subject: [PATCH] refactor(dslx): reduce tlsHandshakeFunc state This diff reduces the state kept by the tlsHandshakeFunc struct so that we can apply a transformation similar to the one we applied for TCPConnect() and implement TLSHandshake() using a pure func. The overall objective is that of factoring away completixity to enable manipulating this code more easily. While there, let's note that the changes applied here mean that we can reuse this code for configuring tls.Config for the QUICHandshake. --- internal/dslx/tls.go | 110 ++++++++++++++--------------- internal/dslx/tls_test.go | 143 ++++++++++++++++---------------------- 2 files changed, 114 insertions(+), 139 deletions(-) diff --git a/internal/dslx/tls.go b/internal/dslx/tls.go index af67d59f61..ee75086b4d 100644 --- a/internal/dslx/tls.go +++ b/internal/dslx/tls.go @@ -12,75 +12,58 @@ import ( "time" "github.com/ooni/probe-cli/v3/internal/logx" + "github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/netxlite" ) // TLSHandshakeOption is an option you can pass to TLSHandshake. -type TLSHandshakeOption func(*tlsHandshakeFunc) +type TLSHandshakeOption func(config *tls.Config) // TLSHandshakeOptionInsecureSkipVerify controls whether TLS verification is enabled. func TLSHandshakeOptionInsecureSkipVerify(value bool) TLSHandshakeOption { - return func(thf *tlsHandshakeFunc) { - thf.InsecureSkipVerify = value + return func(config *tls.Config) { + config.InsecureSkipVerify = value } } // TLSHandshakeOptionNextProto allows to configure the ALPN protocols. func TLSHandshakeOptionNextProto(value []string) TLSHandshakeOption { - return func(thf *tlsHandshakeFunc) { - thf.NextProto = value + return func(config *tls.Config) { + config.NextProtos = value } } // TLSHandshakeOptionRootCAs allows to configure custom root CAs. func TLSHandshakeOptionRootCAs(value *x509.CertPool) TLSHandshakeOption { - return func(thf *tlsHandshakeFunc) { - thf.RootCAs = value + return func(config *tls.Config) { + config.RootCAs = value } } // TLSHandshakeOptionServerName allows to configure the SNI to use. func TLSHandshakeOptionServerName(value string) TLSHandshakeOption { - return func(thf *tlsHandshakeFunc) { - thf.ServerName = value + return func(config *tls.Config) { + config.ServerName = value } } // TLSHandshake returns a function performing TSL handshakes. func TLSHandshake(rt Runtime, options ...TLSHandshakeOption) Func[ *TCPConnection, *Maybe[*TLSConnection]] { - // See https://github.com/ooni/probe/issues/2413 to understand - // why we're using nil to force netxlite to use the cached - // default Mozilla cert pool. f := &tlsHandshakeFunc{ - InsecureSkipVerify: false, - NextProto: []string{}, - RootCAs: nil, - Rt: rt, - ServerName: "", - } - for _, option := range options { - option(f) + Options: options, + Rt: rt, } return f } // tlsHandshakeFunc performs TLS handshakes. type tlsHandshakeFunc struct { - // InsecureSkipVerify allows to skip TLS verification. - InsecureSkipVerify bool - - // NextProto contains the ALPNs to negotiate. - NextProto []string + // Options contains the options. + Options []TLSHandshakeOption - // RootCAs contains the Root CAs to use. - RootCAs *x509.CertPool - - // Pool is the Pool that owns us. + // Rt is the runtime that owns us. Rt Runtime - - // ServerName is the ServerName to handshake for. - ServerName string } // Apply implements Func. @@ -89,9 +72,8 @@ func (f *tlsHandshakeFunc) Apply( // keep using the same trace trace := input.Trace - // use defaults or user-configured overrides - serverName := f.serverName(input) - nextProto := f.nextProto() + // create a suitable TLS configuration + config := tlsNewConfig(input.Address, []string{"h2", "http/1.1"}, input.Domain, f.Rt.Logger(), f.Options...) // start the operation logger ol := logx.NewOperationLogger( @@ -99,20 +81,14 @@ func (f *tlsHandshakeFunc) Apply( "[#%d] TLSHandshake with %s SNI=%s ALPN=%v", trace.Index(), input.Address, - serverName, - nextProto, + config.ServerName, + config.NextProtos, ) // obtain the handshaker for use handshaker := trace.NewTLSHandshakerStdlib(f.Rt.Logger()) // setup - config := &tls.Config{ - NextProtos: nextProto, - InsecureSkipVerify: f.InsecureSkipVerify, - RootCAs: f.RootCAs, - ServerName: serverName, - } const timeout = 10 * time.Second ctx, cancel := context.WithTimeout(ctx, timeout) defer cancel() @@ -143,31 +119,51 @@ func (f *tlsHandshakeFunc) Apply( } } -func (f *tlsHandshakeFunc) serverName(input *TCPConnection) string { - if f.ServerName != "" { - return f.ServerName +// tlsNewConfig is an utility function to create a new TLS config. +// +// Arguments: +// +// - address is the endpoint address (e.g., 1.1.1.1:443); +// +// - defaultALPN contains the default to be used for configuring ALPN; +// +// - domain is the possibly empty domain to use; +// +// - logger is the logger to use; +// +// - options contains options to modify the TLS handshake defaults. +func tlsNewConfig(address string, defaultALPN []string, domain string, logger model.Logger, options ...TLSHandshakeOption) *tls.Config { + // See https://github.com/ooni/probe/issues/2413 to understand + // why we're using nil to force netxlite to use the cached + // default Mozilla cert pool. + config := &tls.Config{ + NextProtos: append([]string{}, defaultALPN...), + InsecureSkipVerify: false, + RootCAs: nil, + ServerName: tlsServerName(address, domain, logger), } - if input.Domain != "" { - return input.Domain + for _, option := range options { + option(config) + } + return config +} + +// tlsServerName is an utility function to obtina the server name from a TCPConnection. +func tlsServerName(address, domain string, logger model.Logger) string { + if domain != "" { + return domain } - addr, _, err := net.SplitHostPort(input.Address) + addr, _, err := net.SplitHostPort(address) if err == nil { return addr } // Note: golang requires a ServerName and fails if it's empty. If the provided // ServerName is an IP address, however, golang WILL NOT emit any SNI extension // in the ClientHello, consistently with RFC 6066 Section 3 requirements. - f.Rt.Logger().Warn("TLSHandshake: cannot determine which SNI to use") + logger.Warn("TLSHandshake: cannot determine which SNI to use") return "" } -func (f *tlsHandshakeFunc) nextProto() []string { - if len(f.NextProto) > 0 { - return f.NextProto - } - return []string{"h2", "http/1.1"} -} - // TLSConnection is an established TLS connection. If you initialize // manually, init at least the ones marked as MANDATORY. type TLSConnection struct { diff --git a/internal/dslx/tls_test.go b/internal/dslx/tls_test.go index 4dcce7e0e3..36df4a79ca 100644 --- a/internal/dslx/tls_test.go +++ b/internal/dslx/tls_test.go @@ -10,51 +10,66 @@ import ( "testing" "time" + "github.com/google/go-cmp/cmp" "github.com/ooni/probe-cli/v3/internal/mocks" "github.com/ooni/probe-cli/v3/internal/model" ) -/* -Test cases: -- Get tlsHandshakeFunc with options -- Apply tlsHandshakeFunc: - - with EOF - - with invalid address - - with success - - with sni - - with options -*/ -func TestTLSHandshake(t *testing.T) { - t.Run("Get tlsHandshakeFunc with options", func(t *testing.T) { +func TestTLSNewConfig(t *testing.T) { + t.Run("without options", func(t *testing.T) { + config := tlsNewConfig("1.1.1.1:443", []string{"h2", "http/1.1"}, "sni", model.DiscardLogger) + + if config.InsecureSkipVerify { + t.Fatalf("unexpected %s, expected %v, got %v", "InsecureSkipVerify", false, config.InsecureSkipVerify) + } + if diff := cmp.Diff([]string{"h2", "http/1.1"}, config.NextProtos); diff != "" { + t.Fatal(diff) + } + if config.ServerName != "sni" { + t.Fatalf("unexpected %s, expected %s, got %s", "ServerName", "sni", config.ServerName) + } + if !config.RootCAs.Equal(nil) { + t.Fatalf("unexpected %s, expected %v, got %v", "RootCAs", nil, config.RootCAs) + } + }) + + t.Run("with options", func(t *testing.T) { certpool := x509.NewCertPool() certpool.AddCert(&x509.Certificate{}) - f := TLSHandshake( - NewMinimalRuntime(model.DiscardLogger, time.Now()), + config := tlsNewConfig( + "1.1.1.1:443", []string{"h2", "http/1.1"}, "sni", model.DiscardLogger, TLSHandshakeOptionInsecureSkipVerify(true), TLSHandshakeOptionNextProto([]string{"h2"}), - TLSHandshakeOptionServerName("sni"), + TLSHandshakeOptionServerName("example.domain"), TLSHandshakeOptionRootCAs(certpool), ) - var handshakeFunc *tlsHandshakeFunc - var ok bool - if handshakeFunc, ok = f.(*tlsHandshakeFunc); !ok { - t.Fatal("unexpected type. Expected: tlsHandshakeFunc") - } - if !handshakeFunc.InsecureSkipVerify { - t.Fatalf("unexpected %s, expected %v, got %v", "InsecureSkipVerify", true, false) + + if !config.InsecureSkipVerify { + t.Fatalf("unexpected %s, expected %v, got %v", "InsecureSkipVerify", true, config.InsecureSkipVerify) } - if len(handshakeFunc.NextProto) != 1 || handshakeFunc.NextProto[0] != "h2" { - t.Fatalf("unexpected %s, expected %v, got %v", "NextProto", []string{"h2"}, handshakeFunc.NextProto) + if diff := cmp.Diff([]string{"h2"}, config.NextProtos); diff != "" { + t.Fatal(diff) } - if handshakeFunc.ServerName != "sni" { - t.Fatalf("unexpected %s, expected %s, got %s", "ServerName", "sni", handshakeFunc.ServerName) + if config.ServerName != "example.domain" { + t.Fatalf("unexpected %s, expected %s, got %s", "ServerName", "example.domain", config.ServerName) } - if !handshakeFunc.RootCAs.Equal(certpool) { - t.Fatalf("unexpected %s, expected %v, got %v", "RootCAs", certpool, handshakeFunc.RootCAs) + if !config.RootCAs.Equal(certpool) { + t.Fatalf("unexpected %s, expected %v, got %v", "RootCAs", nil, config.RootCAs) } }) +} +/* +Test cases: +- Apply tlsHandshakeFunc: + - with EOF + - with invalid address + - with success + - with sni + - with options +*/ +func TestTLSHandshake(t *testing.T) { t.Run("Apply tlsHandshakeFunc", func(t *testing.T) { wasClosed := false @@ -137,11 +152,10 @@ func TestTLSHandshake(t *testing.T) { return tt.handshaker }, })) - tlsHandshake := &tlsHandshakeFunc{ - NextProto: tt.config.nextProtos, - Rt: rt, - ServerName: tt.config.sni, - } + tlsHandshake := TLSHandshake(rt, + TLSHandshakeOptionNextProto(tt.config.nextProtos), + TLSHandshakeOptionServerName(tt.config.sni), + ) idGen := &atomic.Int64{} zeroTime := time.Time{} trace := rt.NewTrace(idGen.Add(1), zeroTime) @@ -174,62 +188,27 @@ func TestTLSHandshake(t *testing.T) { /* Test cases: -- With input SNI -- With input domain -- With input host address -- With input IP address +- With domain +- With host address +- With IP address */ -func TestServerNameTLS(t *testing.T) { - t.Run("With input SNI", func(t *testing.T) { - sni := "sni" - tcpConn := TCPConnection{ - Address: "example.com:123", - } - f := &tlsHandshakeFunc{ - Rt: NewMinimalRuntime(model.DiscardLogger, time.Now()), - ServerName: sni, - } - serverName := f.serverName(&tcpConn) - if serverName != sni { +func TestTLSServerName(t *testing.T) { + t.Run("With domain", func(t *testing.T) { + serverName := tlsServerName("example.com:123", "domain", model.DiscardLogger) + if serverName != "domain" { t.Fatalf("unexpected server name: %s", serverName) } }) - t.Run("With input domain", func(t *testing.T) { - domain := "domain" - tcpConn := TCPConnection{ - Address: "example.com:123", - Domain: domain, - } - f := &tlsHandshakeFunc{ - Rt: NewMinimalRuntime(model.DiscardLogger, time.Now()), - } - serverName := f.serverName(&tcpConn) - if serverName != domain { - t.Fatalf("unexpected server name: %s", serverName) - } - }) - t.Run("With input host address", func(t *testing.T) { - hostaddr := "example.com" - tcpConn := TCPConnection{ - Address: hostaddr + ":123", - } - f := &tlsHandshakeFunc{ - Rt: NewMinimalRuntime(model.DiscardLogger, time.Now()), - } - serverName := f.serverName(&tcpConn) - if serverName != hostaddr { + + t.Run("With host address", func(t *testing.T) { + serverName := tlsServerName("1.1.1.1:443", "", model.DiscardLogger) + if serverName != "1.1.1.1" { t.Fatalf("unexpected server name: %s", serverName) } }) - t.Run("With input IP address", func(t *testing.T) { - ip := "1.1.1.1" - tcpConn := TCPConnection{ - Address: ip, - } - f := &tlsHandshakeFunc{ - Rt: NewMinimalRuntime(model.DiscardLogger, time.Now()), - } - serverName := f.serverName(&tcpConn) + + t.Run("With IP address", func(t *testing.T) { + serverName := tlsServerName("1.1.1.1", "", model.DiscardLogger) if serverName != "" { t.Fatalf("unexpected server name: %s", serverName) }