diff --git a/internal/dslx/tls.go b/internal/dslx/tls.go index af67d59f61..ee75086b4d 100644 --- a/internal/dslx/tls.go +++ b/internal/dslx/tls.go @@ -12,75 +12,58 @@ import ( "time" "github.com/ooni/probe-cli/v3/internal/logx" + "github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/netxlite" ) // TLSHandshakeOption is an option you can pass to TLSHandshake. -type TLSHandshakeOption func(*tlsHandshakeFunc) +type TLSHandshakeOption func(config *tls.Config) // TLSHandshakeOptionInsecureSkipVerify controls whether TLS verification is enabled. func TLSHandshakeOptionInsecureSkipVerify(value bool) TLSHandshakeOption { - return func(thf *tlsHandshakeFunc) { - thf.InsecureSkipVerify = value + return func(config *tls.Config) { + config.InsecureSkipVerify = value } } // TLSHandshakeOptionNextProto allows to configure the ALPN protocols. func TLSHandshakeOptionNextProto(value []string) TLSHandshakeOption { - return func(thf *tlsHandshakeFunc) { - thf.NextProto = value + return func(config *tls.Config) { + config.NextProtos = value } } // TLSHandshakeOptionRootCAs allows to configure custom root CAs. func TLSHandshakeOptionRootCAs(value *x509.CertPool) TLSHandshakeOption { - return func(thf *tlsHandshakeFunc) { - thf.RootCAs = value + return func(config *tls.Config) { + config.RootCAs = value } } // TLSHandshakeOptionServerName allows to configure the SNI to use. func TLSHandshakeOptionServerName(value string) TLSHandshakeOption { - return func(thf *tlsHandshakeFunc) { - thf.ServerName = value + return func(config *tls.Config) { + config.ServerName = value } } // TLSHandshake returns a function performing TSL handshakes. func TLSHandshake(rt Runtime, options ...TLSHandshakeOption) Func[ *TCPConnection, *Maybe[*TLSConnection]] { - // See https://github.com/ooni/probe/issues/2413 to understand - // why we're using nil to force netxlite to use the cached - // default Mozilla cert pool. f := &tlsHandshakeFunc{ - InsecureSkipVerify: false, - NextProto: []string{}, - RootCAs: nil, - Rt: rt, - ServerName: "", - } - for _, option := range options { - option(f) + Options: options, + Rt: rt, } return f } // tlsHandshakeFunc performs TLS handshakes. type tlsHandshakeFunc struct { - // InsecureSkipVerify allows to skip TLS verification. - InsecureSkipVerify bool - - // NextProto contains the ALPNs to negotiate. - NextProto []string + // Options contains the options. + Options []TLSHandshakeOption - // RootCAs contains the Root CAs to use. - RootCAs *x509.CertPool - - // Pool is the Pool that owns us. + // Rt is the runtime that owns us. Rt Runtime - - // ServerName is the ServerName to handshake for. - ServerName string } // Apply implements Func. @@ -89,9 +72,8 @@ func (f *tlsHandshakeFunc) Apply( // keep using the same trace trace := input.Trace - // use defaults or user-configured overrides - serverName := f.serverName(input) - nextProto := f.nextProto() + // create a suitable TLS configuration + config := tlsNewConfig(input.Address, []string{"h2", "http/1.1"}, input.Domain, f.Rt.Logger(), f.Options...) // start the operation logger ol := logx.NewOperationLogger( @@ -99,20 +81,14 @@ func (f *tlsHandshakeFunc) Apply( "[#%d] TLSHandshake with %s SNI=%s ALPN=%v", trace.Index(), input.Address, - serverName, - nextProto, + config.ServerName, + config.NextProtos, ) // obtain the handshaker for use handshaker := trace.NewTLSHandshakerStdlib(f.Rt.Logger()) // setup - config := &tls.Config{ - NextProtos: nextProto, - InsecureSkipVerify: f.InsecureSkipVerify, - RootCAs: f.RootCAs, - ServerName: serverName, - } const timeout = 10 * time.Second ctx, cancel := context.WithTimeout(ctx, timeout) defer cancel() @@ -143,31 +119,51 @@ func (f *tlsHandshakeFunc) Apply( } } -func (f *tlsHandshakeFunc) serverName(input *TCPConnection) string { - if f.ServerName != "" { - return f.ServerName +// tlsNewConfig is an utility function to create a new TLS config. +// +// Arguments: +// +// - address is the endpoint address (e.g., 1.1.1.1:443); +// +// - defaultALPN contains the default to be used for configuring ALPN; +// +// - domain is the possibly empty domain to use; +// +// - logger is the logger to use; +// +// - options contains options to modify the TLS handshake defaults. +func tlsNewConfig(address string, defaultALPN []string, domain string, logger model.Logger, options ...TLSHandshakeOption) *tls.Config { + // See https://github.com/ooni/probe/issues/2413 to understand + // why we're using nil to force netxlite to use the cached + // default Mozilla cert pool. + config := &tls.Config{ + NextProtos: append([]string{}, defaultALPN...), + InsecureSkipVerify: false, + RootCAs: nil, + ServerName: tlsServerName(address, domain, logger), } - if input.Domain != "" { - return input.Domain + for _, option := range options { + option(config) + } + return config +} + +// tlsServerName is an utility function to obtina the server name from a TCPConnection. +func tlsServerName(address, domain string, logger model.Logger) string { + if domain != "" { + return domain } - addr, _, err := net.SplitHostPort(input.Address) + addr, _, err := net.SplitHostPort(address) if err == nil { return addr } // Note: golang requires a ServerName and fails if it's empty. If the provided // ServerName is an IP address, however, golang WILL NOT emit any SNI extension // in the ClientHello, consistently with RFC 6066 Section 3 requirements. - f.Rt.Logger().Warn("TLSHandshake: cannot determine which SNI to use") + logger.Warn("TLSHandshake: cannot determine which SNI to use") return "" } -func (f *tlsHandshakeFunc) nextProto() []string { - if len(f.NextProto) > 0 { - return f.NextProto - } - return []string{"h2", "http/1.1"} -} - // TLSConnection is an established TLS connection. If you initialize // manually, init at least the ones marked as MANDATORY. type TLSConnection struct { diff --git a/internal/dslx/tls_test.go b/internal/dslx/tls_test.go index 4dcce7e0e3..36df4a79ca 100644 --- a/internal/dslx/tls_test.go +++ b/internal/dslx/tls_test.go @@ -10,51 +10,66 @@ import ( "testing" "time" + "github.com/google/go-cmp/cmp" "github.com/ooni/probe-cli/v3/internal/mocks" "github.com/ooni/probe-cli/v3/internal/model" ) -/* -Test cases: -- Get tlsHandshakeFunc with options -- Apply tlsHandshakeFunc: - - with EOF - - with invalid address - - with success - - with sni - - with options -*/ -func TestTLSHandshake(t *testing.T) { - t.Run("Get tlsHandshakeFunc with options", func(t *testing.T) { +func TestTLSNewConfig(t *testing.T) { + t.Run("without options", func(t *testing.T) { + config := tlsNewConfig("1.1.1.1:443", []string{"h2", "http/1.1"}, "sni", model.DiscardLogger) + + if config.InsecureSkipVerify { + t.Fatalf("unexpected %s, expected %v, got %v", "InsecureSkipVerify", false, config.InsecureSkipVerify) + } + if diff := cmp.Diff([]string{"h2", "http/1.1"}, config.NextProtos); diff != "" { + t.Fatal(diff) + } + if config.ServerName != "sni" { + t.Fatalf("unexpected %s, expected %s, got %s", "ServerName", "sni", config.ServerName) + } + if !config.RootCAs.Equal(nil) { + t.Fatalf("unexpected %s, expected %v, got %v", "RootCAs", nil, config.RootCAs) + } + }) + + t.Run("with options", func(t *testing.T) { certpool := x509.NewCertPool() certpool.AddCert(&x509.Certificate{}) - f := TLSHandshake( - NewMinimalRuntime(model.DiscardLogger, time.Now()), + config := tlsNewConfig( + "1.1.1.1:443", []string{"h2", "http/1.1"}, "sni", model.DiscardLogger, TLSHandshakeOptionInsecureSkipVerify(true), TLSHandshakeOptionNextProto([]string{"h2"}), - TLSHandshakeOptionServerName("sni"), + TLSHandshakeOptionServerName("example.domain"), TLSHandshakeOptionRootCAs(certpool), ) - var handshakeFunc *tlsHandshakeFunc - var ok bool - if handshakeFunc, ok = f.(*tlsHandshakeFunc); !ok { - t.Fatal("unexpected type. Expected: tlsHandshakeFunc") - } - if !handshakeFunc.InsecureSkipVerify { - t.Fatalf("unexpected %s, expected %v, got %v", "InsecureSkipVerify", true, false) + + if !config.InsecureSkipVerify { + t.Fatalf("unexpected %s, expected %v, got %v", "InsecureSkipVerify", true, config.InsecureSkipVerify) } - if len(handshakeFunc.NextProto) != 1 || handshakeFunc.NextProto[0] != "h2" { - t.Fatalf("unexpected %s, expected %v, got %v", "NextProto", []string{"h2"}, handshakeFunc.NextProto) + if diff := cmp.Diff([]string{"h2"}, config.NextProtos); diff != "" { + t.Fatal(diff) } - if handshakeFunc.ServerName != "sni" { - t.Fatalf("unexpected %s, expected %s, got %s", "ServerName", "sni", handshakeFunc.ServerName) + if config.ServerName != "example.domain" { + t.Fatalf("unexpected %s, expected %s, got %s", "ServerName", "example.domain", config.ServerName) } - if !handshakeFunc.RootCAs.Equal(certpool) { - t.Fatalf("unexpected %s, expected %v, got %v", "RootCAs", certpool, handshakeFunc.RootCAs) + if !config.RootCAs.Equal(certpool) { + t.Fatalf("unexpected %s, expected %v, got %v", "RootCAs", nil, config.RootCAs) } }) +} +/* +Test cases: +- Apply tlsHandshakeFunc: + - with EOF + - with invalid address + - with success + - with sni + - with options +*/ +func TestTLSHandshake(t *testing.T) { t.Run("Apply tlsHandshakeFunc", func(t *testing.T) { wasClosed := false @@ -137,11 +152,10 @@ func TestTLSHandshake(t *testing.T) { return tt.handshaker }, })) - tlsHandshake := &tlsHandshakeFunc{ - NextProto: tt.config.nextProtos, - Rt: rt, - ServerName: tt.config.sni, - } + tlsHandshake := TLSHandshake(rt, + TLSHandshakeOptionNextProto(tt.config.nextProtos), + TLSHandshakeOptionServerName(tt.config.sni), + ) idGen := &atomic.Int64{} zeroTime := time.Time{} trace := rt.NewTrace(idGen.Add(1), zeroTime) @@ -174,62 +188,27 @@ func TestTLSHandshake(t *testing.T) { /* Test cases: -- With input SNI -- With input domain -- With input host address -- With input IP address +- With domain +- With host address +- With IP address */ -func TestServerNameTLS(t *testing.T) { - t.Run("With input SNI", func(t *testing.T) { - sni := "sni" - tcpConn := TCPConnection{ - Address: "example.com:123", - } - f := &tlsHandshakeFunc{ - Rt: NewMinimalRuntime(model.DiscardLogger, time.Now()), - ServerName: sni, - } - serverName := f.serverName(&tcpConn) - if serverName != sni { +func TestTLSServerName(t *testing.T) { + t.Run("With domain", func(t *testing.T) { + serverName := tlsServerName("example.com:123", "domain", model.DiscardLogger) + if serverName != "domain" { t.Fatalf("unexpected server name: %s", serverName) } }) - t.Run("With input domain", func(t *testing.T) { - domain := "domain" - tcpConn := TCPConnection{ - Address: "example.com:123", - Domain: domain, - } - f := &tlsHandshakeFunc{ - Rt: NewMinimalRuntime(model.DiscardLogger, time.Now()), - } - serverName := f.serverName(&tcpConn) - if serverName != domain { - t.Fatalf("unexpected server name: %s", serverName) - } - }) - t.Run("With input host address", func(t *testing.T) { - hostaddr := "example.com" - tcpConn := TCPConnection{ - Address: hostaddr + ":123", - } - f := &tlsHandshakeFunc{ - Rt: NewMinimalRuntime(model.DiscardLogger, time.Now()), - } - serverName := f.serverName(&tcpConn) - if serverName != hostaddr { + + t.Run("With host address", func(t *testing.T) { + serverName := tlsServerName("1.1.1.1:443", "", model.DiscardLogger) + if serverName != "1.1.1.1" { t.Fatalf("unexpected server name: %s", serverName) } }) - t.Run("With input IP address", func(t *testing.T) { - ip := "1.1.1.1" - tcpConn := TCPConnection{ - Address: ip, - } - f := &tlsHandshakeFunc{ - Rt: NewMinimalRuntime(model.DiscardLogger, time.Now()), - } - serverName := f.serverName(&tcpConn) + + t.Run("With IP address", func(t *testing.T) { + serverName := tlsServerName("1.1.1.1", "", model.DiscardLogger) if serverName != "" { t.Fatalf("unexpected server name: %s", serverName) }