From 6497f87829cf825c1fe9353a5bd34cfa3deb0382 Mon Sep 17 00:00:00 2001 From: Simone Basso Date: Fri, 20 Oct 2023 14:35:44 +0200 Subject: [PATCH] cleanup(dslx): use model.UnderlyingNetwork for testing (#1377) Closes https://github.com/ooni/probe/issues/2582 --- internal/dslx/dns.go | 36 ++++++-------- internal/dslx/dns_test.go | 67 ++++++++++++++++++++------- internal/dslx/quic.go | 8 +--- internal/dslx/quic_test.go | 25 ++-------- internal/dslx/runtimemeasurex.go | 25 ++++++++-- internal/dslx/runtimemeasurex_test.go | 24 ++++++++++ internal/dslx/runtimeminimal.go | 33 +++++++++---- internal/dslx/runtimeminimal_test.go | 12 +++++ internal/dslx/tcp.go | 17 ++----- internal/dslx/tcp_test.go | 21 +++------ internal/dslx/tls.go | 15 +----- internal/dslx/tls_test.go | 26 +++-------- 12 files changed, 170 insertions(+), 139 deletions(-) create mode 100644 internal/dslx/runtimemeasurex_test.go diff --git a/internal/dslx/dns.go b/internal/dslx/dns.go index 9da9c9ed0f..62b86293f6 100644 --- a/internal/dslx/dns.go +++ b/internal/dslx/dns.go @@ -9,7 +9,6 @@ import ( "time" "github.com/ooni/probe-cli/v3/internal/logx" - "github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/netxlite" ) @@ -73,13 +72,12 @@ type ResolvedAddresses struct { // DNSLookupGetaddrinfo returns a function that resolves a domain name to // IP addresses using libc's getaddrinfo function. func DNSLookupGetaddrinfo(rt Runtime) Func[*DomainToResolve, *Maybe[*ResolvedAddresses]] { - return &dnsLookupGetaddrinfoFunc{nil, rt} + return &dnsLookupGetaddrinfoFunc{rt} } // dnsLookupGetaddrinfoFunc is the function returned by DNSLookupGetaddrinfo. type dnsLookupGetaddrinfoFunc struct { - resolver model.Resolver // for testing - rt Runtime + rt Runtime } // Apply implements Func. @@ -102,10 +100,8 @@ func (f *dnsLookupGetaddrinfoFunc) Apply( ctx, cancel := context.WithTimeout(ctx, timeout) defer cancel() - resolver := f.resolver - if resolver == nil { - resolver = trace.NewStdlibResolver(f.rt.Logger()) - } + // create the resolver + resolver := trace.NewStdlibResolver(f.rt.Logger()) // lookup addrs, err := resolver.LookupHost(ctx, input.Domain) @@ -131,18 +127,16 @@ func (f *dnsLookupGetaddrinfoFunc) Apply( // IP addresses using the given DNS-over-UDP resolver. func DNSLookupUDP(rt Runtime, resolver string) Func[*DomainToResolve, *Maybe[*ResolvedAddresses]] { return &dnsLookupUDPFunc{ - Resolver: resolver, - mockResolver: nil, - rt: rt, + Resolver: resolver, + rt: rt, } } // dnsLookupUDPFunc is the function returned by DNSLookupUDP. type dnsLookupUDPFunc struct { // Resolver is the MANDATORY endpointed of the resolver to use. - Resolver string - mockResolver model.Resolver // for testing - rt Runtime + Resolver string + rt Runtime } // Apply implements Func. @@ -166,14 +160,12 @@ func (f *dnsLookupUDPFunc) Apply( ctx, cancel := context.WithTimeout(ctx, timeout) defer cancel() - resolver := f.mockResolver - if resolver == nil { - resolver = trace.NewParallelUDPResolver( - f.rt.Logger(), - trace.NewDialerWithoutResolver(f.rt.Logger()), - f.Resolver, - ) - } + // create the resolver + resolver := trace.NewParallelUDPResolver( + f.rt.Logger(), + trace.NewDialerWithoutResolver(f.rt.Logger()), + f.Resolver, + ) // lookup addrs, err := resolver.LookupHost(ctx, input.Domain) diff --git a/internal/dslx/dns_test.go b/internal/dslx/dns_test.go index 15f08e155e..5dd6f79bd1 100644 --- a/internal/dslx/dns_test.go +++ b/internal/dslx/dns_test.go @@ -3,6 +3,7 @@ package dslx import ( "context" "errors" + "net" "sync/atomic" "testing" "time" @@ -84,10 +85,15 @@ func TestGetaddrinfo(t *testing.T) { t.Run("with lookup error", func(t *testing.T) { mockedErr := errors.New("mocked") f := dnsLookupGetaddrinfoFunc{ - resolver: &mocks.Resolver{MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { - return nil, mockedErr - }}, - rt: NewMinimalRuntime(model.DiscardLogger, time.Now()), + rt: NewMinimalRuntime(model.DiscardLogger, time.Now(), MinimalRuntimeOptionMeasuringNetwork(&mocks.MeasuringNetwork{ + MockNewStdlibResolver: func(logger model.DebugLogger) model.Resolver { + return &mocks.Resolver{ + MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { + return nil, mockedErr + }, + } + }, + })), } res := f.Apply(context.Background(), domain) if res.Observations == nil || len(res.Observations) <= 0 { @@ -106,10 +112,15 @@ func TestGetaddrinfo(t *testing.T) { t.Run("with success", func(t *testing.T) { f := dnsLookupGetaddrinfoFunc{ - resolver: &mocks.Resolver{MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { - return []string{"93.184.216.34"}, nil - }}, - rt: NewRuntimeMeasurexLite(model.DiscardLogger, time.Now()), + rt: NewRuntimeMeasurexLite(model.DiscardLogger, time.Now(), RuntimeMeasurexLiteOptionMeasuringNetwork(&mocks.MeasuringNetwork{ + MockNewStdlibResolver: func(logger model.DebugLogger) model.Resolver { + return &mocks.Resolver{ + MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { + return []string{"93.184.216.34"}, nil + }, + } + }, + })), } res := f.Apply(context.Background(), domain) if res.Observations == nil || len(res.Observations) <= 0 { @@ -171,10 +182,22 @@ func TestLookupUDP(t *testing.T) { mockedErr := errors.New("mocked") f := dnsLookupUDPFunc{ Resolver: "1.1.1.1:53", - mockResolver: &mocks.Resolver{MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { - return nil, mockedErr - }}, - rt: NewMinimalRuntime(model.DiscardLogger, time.Now()), + rt: NewMinimalRuntime(model.DiscardLogger, time.Now(), MinimalRuntimeOptionMeasuringNetwork(&mocks.MeasuringNetwork{ + MockNewParallelUDPResolver: func(logger model.DebugLogger, dialer model.Dialer, endpoint string) model.Resolver { + return &mocks.Resolver{ + MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { + return nil, mockedErr + }, + } + }, + MockNewDialerWithoutResolver: func(dl model.DebugLogger, w ...model.DialerWrapper) model.Dialer { + return &mocks.Dialer{ + MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) { + panic("should not be called") + }, + } + }, + })), } res := f.Apply(context.Background(), domain) if res.Observations == nil || len(res.Observations) <= 0 { @@ -194,10 +217,22 @@ func TestLookupUDP(t *testing.T) { t.Run("with success", func(t *testing.T) { f := dnsLookupUDPFunc{ Resolver: "1.1.1.1:53", - mockResolver: &mocks.Resolver{MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { - return []string{"93.184.216.34"}, nil - }}, - rt: NewRuntimeMeasurexLite(model.DiscardLogger, time.Now()), + rt: NewRuntimeMeasurexLite(model.DiscardLogger, time.Now(), RuntimeMeasurexLiteOptionMeasuringNetwork(&mocks.MeasuringNetwork{ + MockNewParallelUDPResolver: func(logger model.DebugLogger, dialer model.Dialer, address string) model.Resolver { + return &mocks.Resolver{ + MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { + return []string{"93.184.216.34"}, nil + }, + } + }, + MockNewDialerWithoutResolver: func(dl model.DebugLogger, w ...model.DialerWrapper) model.Dialer { + return &mocks.Dialer{ + MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) { + panic("should not be called") + }, + } + }, + })), } res := f.Apply(context.Background(), domain) if res.Observations == nil || len(res.Observations) <= 0 { diff --git a/internal/dslx/quic.go b/internal/dslx/quic.go index c643584e4a..3acf675ac9 100644 --- a/internal/dslx/quic.go +++ b/internal/dslx/quic.go @@ -13,7 +13,6 @@ import ( "time" "github.com/ooni/probe-cli/v3/internal/logx" - "github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/netxlite" "github.com/quic-go/quic-go" ) @@ -73,8 +72,6 @@ type quicHandshakeFunc struct { // ServerName is the ServerName to handshake for. ServerName string - - dialer model.QUICDialer // for testing } // Apply implements Func. @@ -97,10 +94,7 @@ func (f *quicHandshakeFunc) Apply( // setup udpListener := netxlite.NewUDPListener() - quicDialer := f.dialer - if quicDialer == nil { - quicDialer = trace.NewQUICDialerWithoutResolver(udpListener, f.Rt.Logger()) - } + quicDialer := trace.NewQUICDialerWithoutResolver(udpListener, f.Rt.Logger()) config := &tls.Config{ NextProtos: []string{"h3"}, InsecureSkipVerify: f.InsecureSkipVerify, diff --git a/internal/dslx/quic_test.go b/internal/dslx/quic_test.go index 40c4923812..2d34954bae 100644 --- a/internal/dslx/quic_test.go +++ b/internal/dslx/quic_test.go @@ -98,10 +98,13 @@ func TestQUICHandshake(t *testing.T) { for name, tt := range tests { t.Run(name, func(t *testing.T) { - rt := NewRuntimeMeasurexLite(model.DiscardLogger, time.Now()) + rt := NewRuntimeMeasurexLite(model.DiscardLogger, time.Now(), RuntimeMeasurexLiteOptionMeasuringNetwork(&mocks.MeasuringNetwork{ + MockNewQUICDialerWithoutResolver: func(listener model.UDPListener, logger model.DebugLogger, w ...model.QUICDialerWrapper) model.QUICDialer { + return tt.dialer + }, + })) quicHandshake := &quicHandshakeFunc{ Rt: rt, - dialer: tt.dialer, ServerName: tt.sni, } endpoint := &Endpoint{ @@ -131,24 +134,6 @@ func TestQUICHandshake(t *testing.T) { }) wasClosed = false } - - t.Run("with nil dialer", func(t *testing.T) { - quicHandshake := &quicHandshakeFunc{Rt: NewMinimalRuntime(model.DiscardLogger, time.Now()), dialer: nil} - endpoint := &Endpoint{ - Address: "1.2.3.4:567", - Network: "udp", - } - ctx, cancel := context.WithCancel(context.Background()) - cancel() - res := quicHandshake.Apply(ctx, endpoint) - - if res.Error == nil { - t.Fatalf("expected an error here") - } - if res.State.QUICConn != nil { - t.Fatalf("unexpected conn: %s", res.State.QUICConn) - } - }) }) } diff --git a/internal/dslx/runtimemeasurex.go b/internal/dslx/runtimemeasurex.go index a075d085b5..a5b7fd31f9 100644 --- a/internal/dslx/runtimemeasurex.go +++ b/internal/dslx/runtimemeasurex.go @@ -5,23 +5,42 @@ import ( "github.com/ooni/probe-cli/v3/internal/measurexlite" "github.com/ooni/probe-cli/v3/internal/model" + "github.com/ooni/probe-cli/v3/internal/netxlite" ) +// RuntimeMeasurexLiteOption is an option for initializing a [*RuntimeMeasurexLite]. +type RuntimeMeasurexLiteOption func(rt *RuntimeMeasurexLite) + +// RuntimeMeasurexLiteOptionMeasuringNetwork allows to configure which [model.MeasuringNetwork] to use. +func RuntimeMeasurexLiteOptionMeasuringNetwork(netx model.MeasuringNetwork) RuntimeMeasurexLiteOption { + return func(rt *RuntimeMeasurexLite) { + rt.netx = netx + } +} + // NewRuntimeMeasurexLite creates a [Runtime] using [measurexlite] to collect [*Observations]. -func NewRuntimeMeasurexLite(logger model.Logger, zeroTime time.Time) *RuntimeMeasurexLite { - return &RuntimeMeasurexLite{ +func NewRuntimeMeasurexLite(logger model.Logger, zeroTime time.Time, options ...RuntimeMeasurexLiteOption) *RuntimeMeasurexLite { + rt := &RuntimeMeasurexLite{ MinimalRuntime: NewMinimalRuntime(logger, zeroTime), + netx: &netxlite.Netx{Underlying: nil}, // implies using the host's network + } + for _, option := range options { + option(rt) } + return rt } // RuntimeMeasurexLite uses [measurexlite] to collect [*Observations.] type RuntimeMeasurexLite struct { *MinimalRuntime + netx model.MeasuringNetwork } // NewTrace implements Runtime. func (p *RuntimeMeasurexLite) NewTrace(index int64, zeroTime time.Time, tags ...string) Trace { - return measurexlite.NewTrace(index, zeroTime, tags...) + trace := measurexlite.NewTrace(index, zeroTime, tags...) + trace.Netx = p.netx + return trace } var _ Runtime = &RuntimeMeasurexLite{} diff --git a/internal/dslx/runtimemeasurex_test.go b/internal/dslx/runtimemeasurex_test.go new file mode 100644 index 0000000000..1deb5a4547 --- /dev/null +++ b/internal/dslx/runtimemeasurex_test.go @@ -0,0 +1,24 @@ +package dslx + +import ( + "testing" + "time" + + "github.com/ooni/probe-cli/v3/internal/measurexlite" + "github.com/ooni/probe-cli/v3/internal/mocks" + "github.com/ooni/probe-cli/v3/internal/model" +) + +func TestMeasurexLiteRuntime(t *testing.T) { + t.Run("we can configure a custom model.MeasuringNetwork", func(t *testing.T) { + netx := &mocks.MeasuringNetwork{} + rt := NewRuntimeMeasurexLite(model.DiscardLogger, time.Now(), RuntimeMeasurexLiteOptionMeasuringNetwork(netx)) + if rt.netx != netx { + t.Fatal("did not set the measuring network") + } + trace := rt.NewTrace(rt.IDGenerator().Add(1), rt.ZeroTime()).(*measurexlite.Trace) + if trace.Netx != netx { + t.Fatal("did not set the measuring network") + } + }) +} diff --git a/internal/dslx/runtimeminimal.go b/internal/dslx/runtimeminimal.go index 7003c2ceec..522505ef24 100644 --- a/internal/dslx/runtimeminimal.go +++ b/internal/dslx/runtimeminimal.go @@ -10,17 +10,32 @@ import ( "github.com/ooni/probe-cli/v3/internal/netxlite" ) +// MinimalRuntimeOption is an option for configuring the [*MinimalRuntime]. +type MinimalRuntimeOption func(rt *MinimalRuntime) + +// MinimalRuntimeOptionMeasuringNetwork configures the [model.MeasuringNetwork] to use. +func MinimalRuntimeOptionMeasuringNetwork(netx model.MeasuringNetwork) MinimalRuntimeOption { + return func(rt *MinimalRuntime) { + rt.netx = netx + } +} + // NewMinimalRuntime creates a minimal [Runtime] implementation. // // This [Runtime] implementation does not collect any [*Observations]. -func NewMinimalRuntime(logger model.Logger, zeroTime time.Time) *MinimalRuntime { - return &MinimalRuntime{ +func NewMinimalRuntime(logger model.Logger, zeroTime time.Time, options ...MinimalRuntimeOption) *MinimalRuntime { + rt := &MinimalRuntime{ idg: &atomic.Int64{}, logger: logger, mu: sync.Mutex{}, + netx: &netxlite.Netx{Underlying: nil}, // implies using the host's network v: []io.Closer{}, zeroT: zeroTime, } + for _, option := range options { + option(rt) + } + return rt } var _ Runtime = &MinimalRuntime{} @@ -30,6 +45,7 @@ type MinimalRuntime struct { idg *atomic.Int64 logger model.Logger mu sync.Mutex + netx model.MeasuringNetwork v []io.Closer zeroT time.Time } @@ -74,11 +90,12 @@ func (p *MinimalRuntime) Close() error { // NewTrace implements Runtime. func (p *MinimalRuntime) NewTrace(index int64, zeroTime time.Time, tags ...string) Trace { - return &minimalTrace{idx: index, tags: tags, zt: zeroTime} + return &minimalTrace{idx: index, netx: p.netx, tags: tags, zt: zeroTime} } type minimalTrace struct { idx int64 + netx model.MeasuringNetwork tags []string zt time.Time } @@ -105,27 +122,27 @@ func (tx *minimalTrace) NetworkEvents() (out []*model.ArchivalNetworkEvent) { // NewDialerWithoutResolver implements Trace. func (tx *minimalTrace) NewDialerWithoutResolver(dl model.DebugLogger, wrappers ...model.DialerWrapper) model.Dialer { - return netxlite.NewDialerWithoutResolver(dl, wrappers...) + return tx.netx.NewDialerWithoutResolver(dl, wrappers...) } // NewParallelUDPResolver implements Trace. func (tx *minimalTrace) NewParallelUDPResolver(logger model.DebugLogger, dialer model.Dialer, address string) model.Resolver { - return netxlite.NewParallelUDPResolver(logger, dialer, address) + return tx.netx.NewParallelUDPResolver(logger, dialer, address) } // NewQUICDialerWithoutResolver implements Trace. func (tx *minimalTrace) NewQUICDialerWithoutResolver(listener model.UDPListener, dl model.DebugLogger, wrappers ...model.QUICDialerWrapper) model.QUICDialer { - return netxlite.NewQUICDialerWithoutResolver(listener, dl, wrappers...) + return tx.netx.NewQUICDialerWithoutResolver(listener, dl, wrappers...) } // NewStdlibResolver implements Trace. func (tx *minimalTrace) NewStdlibResolver(logger model.DebugLogger) model.Resolver { - return netxlite.NewStdlibResolver(logger) + return tx.netx.NewStdlibResolver(logger) } // NewTLSHandshakerStdlib implements Trace. func (tx *minimalTrace) NewTLSHandshakerStdlib(dl model.DebugLogger) model.TLSHandshaker { - return netxlite.NewTLSHandshakerStdlib(dl) + return tx.netx.NewTLSHandshakerStdlib(dl) } // QUICHandshakes implements Trace. diff --git a/internal/dslx/runtimeminimal_test.go b/internal/dslx/runtimeminimal_test.go index 4699787fb9..f773ccf82d 100644 --- a/internal/dslx/runtimeminimal_test.go +++ b/internal/dslx/runtimeminimal_test.go @@ -237,4 +237,16 @@ func TestMinimalRuntime(t *testing.T) { } }) }) + + t.Run("we can use a custom model.MeasuringNetwork", func(t *testing.T) { + netx := &mocks.MeasuringNetwork{} + rt := NewMinimalRuntime(model.DiscardLogger, time.Now(), MinimalRuntimeOptionMeasuringNetwork(netx)) + if rt.netx != netx { + t.Fatal("did not set the measuring network") + } + trace := rt.NewTrace(rt.IDGenerator().Add(1), rt.ZeroTime()).(*minimalTrace) + if trace.netx != netx { + t.Fatal("did not set the measuring network") + } + }) } diff --git a/internal/dslx/tcp.go b/internal/dslx/tcp.go index af5dbcff3c..eaa54c2d30 100644 --- a/internal/dslx/tcp.go +++ b/internal/dslx/tcp.go @@ -10,20 +10,18 @@ import ( "time" "github.com/ooni/probe-cli/v3/internal/logx" - "github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/netxlite" ) // TCPConnect returns a function that establishes TCP connections. func TCPConnect(rt Runtime) Func[*Endpoint, *Maybe[*TCPConnection]] { - f := &tcpConnectFunc{nil, rt} + f := &tcpConnectFunc{rt} return f } // tcpConnectFunc is a function that establishes TCP connections. type tcpConnectFunc struct { - dialer model.Dialer // for testing - rt Runtime + rt Runtime } // Apply applies the function to its arguments. @@ -47,7 +45,7 @@ func (f *tcpConnectFunc) Apply( defer cancel() // obtain the dialer to use - dialer := f.dialerOrDefault(trace, f.rt.Logger()) + dialer := trace.NewDialerWithoutResolver(f.rt.Logger()) // connect conn, err := dialer.DialContext(ctx, "tcp", input.Address) @@ -74,15 +72,6 @@ func (f *tcpConnectFunc) Apply( } } -// dialerOrDefault is the function used to obtain a dialer -func (f *tcpConnectFunc) dialerOrDefault(trace Trace, logger model.Logger) model.Dialer { - dialer := f.dialer - if dialer == nil { - dialer = trace.NewDialerWithoutResolver(logger) - } - return dialer -} - // TCPConnection is an established TCP connection. If you initialize // manually, init at least the ones marked as MANDATORY. type TCPConnection struct { diff --git a/internal/dslx/tcp_test.go b/internal/dslx/tcp_test.go index 1ec42ef88a..8748b634bb 100644 --- a/internal/dslx/tcp_test.go +++ b/internal/dslx/tcp_test.go @@ -8,7 +8,6 @@ import ( "time" "github.com/google/go-cmp/cmp" - "github.com/ooni/probe-cli/v3/internal/measurexlite" "github.com/ooni/probe-cli/v3/internal/mocks" "github.com/ooni/probe-cli/v3/internal/model" ) @@ -68,8 +67,12 @@ func TestTCPConnect(t *testing.T) { for name, tt := range tests { t.Run(name, func(t *testing.T) { - rt := NewRuntimeMeasurexLite(model.DiscardLogger, time.Now()) - tcpConnect := &tcpConnectFunc{tt.dialer, rt} + rt := NewRuntimeMeasurexLite(model.DiscardLogger, time.Now(), RuntimeMeasurexLiteOptionMeasuringNetwork(&mocks.MeasuringNetwork{ + MockNewDialerWithoutResolver: func(dl model.DebugLogger, w ...model.DialerWrapper) model.Dialer { + return tt.dialer + }, + })) + tcpConnect := &tcpConnectFunc{rt} endpoint := &Endpoint{ Address: "1.2.3.4:567", Network: "tcp", @@ -99,15 +102,3 @@ func TestTCPConnect(t *testing.T) { } }) } - -// Make sure we get a valid dialer if no mocked dialer is configured -func TestDialerOrDefault(t *testing.T) { - f := &tcpConnectFunc{ - rt: NewMinimalRuntime(model.DiscardLogger, time.Now()), - dialer: nil, - } - dialer := f.dialerOrDefault(measurexlite.NewTrace(0, time.Now()), model.DiscardLogger) - if dialer == nil { - t.Fatal("expected non-nil dialer here") - } -} diff --git a/internal/dslx/tls.go b/internal/dslx/tls.go index 59e508f681..5a37685dba 100644 --- a/internal/dslx/tls.go +++ b/internal/dslx/tls.go @@ -12,7 +12,6 @@ import ( "time" "github.com/ooni/probe-cli/v3/internal/logx" - "github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/netxlite" ) @@ -82,9 +81,6 @@ type tlsHandshakeFunc struct { // ServerName is the ServerName to handshake for. ServerName string - - // for testing - handshaker model.TLSHandshaker } // Apply implements Func. @@ -108,7 +104,7 @@ func (f *tlsHandshakeFunc) Apply( ) // obtain the handshaker for use - handshaker := f.handshakerOrDefault(trace, f.Rt.Logger()) + handshaker := trace.NewTLSHandshakerStdlib(f.Rt.Logger()) // setup config := &tls.Config{ @@ -147,15 +143,6 @@ func (f *tlsHandshakeFunc) Apply( } } -// handshakerOrDefault is the function used to obtain an handshaker -func (f *tlsHandshakeFunc) handshakerOrDefault(trace Trace, logger model.Logger) model.TLSHandshaker { - handshaker := f.handshaker - if handshaker == nil { - handshaker = trace.NewTLSHandshakerStdlib(logger) - } - return handshaker -} - func (f *tlsHandshakeFunc) serverName(input *TCPConnection) string { if f.ServerName != "" { return f.ServerName diff --git a/internal/dslx/tls_test.go b/internal/dslx/tls_test.go index 2fd209661b..4dcce7e0e3 100644 --- a/internal/dslx/tls_test.go +++ b/internal/dslx/tls_test.go @@ -10,7 +10,6 @@ import ( "testing" "time" - "github.com/ooni/probe-cli/v3/internal/measurexlite" "github.com/ooni/probe-cli/v3/internal/mocks" "github.com/ooni/probe-cli/v3/internal/model" ) @@ -133,16 +132,19 @@ func TestTLSHandshake(t *testing.T) { for name, tt := range tests { t.Run(name, func(t *testing.T) { - rt := NewMinimalRuntime(model.DiscardLogger, time.Now()) + rt := NewMinimalRuntime(model.DiscardLogger, time.Now(), MinimalRuntimeOptionMeasuringNetwork(&mocks.MeasuringNetwork{ + MockNewTLSHandshakerStdlib: func(logger model.DebugLogger) model.TLSHandshaker { + return tt.handshaker + }, + })) tlsHandshake := &tlsHandshakeFunc{ NextProto: tt.config.nextProtos, Rt: rt, ServerName: tt.config.sni, - handshaker: tt.handshaker, } idGen := &atomic.Int64{} zeroTime := time.Time{} - trace := measurexlite.NewTrace(idGen.Add(1), zeroTime) + trace := rt.NewTrace(idGen.Add(1), zeroTime) address := tt.config.address if address == "" { address = "1.2.3.4:567" @@ -233,19 +235,3 @@ func TestServerNameTLS(t *testing.T) { } }) } - -// Make sure we get a valid handshaker if no mocked handshaker is configured -func TestHandshakerOrDefault(t *testing.T) { - f := &tlsHandshakeFunc{ - InsecureSkipVerify: false, - NextProto: []string{}, - Rt: NewMinimalRuntime(model.DiscardLogger, time.Now()), - RootCAs: &x509.CertPool{}, - ServerName: "", - handshaker: nil, - } - handshaker := f.handshakerOrDefault(measurexlite.NewTrace(0, time.Now()), model.DiscardLogger) - if handshaker == nil { - t.Fatal("expected non-nil handshaker here") - } -}