diff --git a/internal/dslx/dns.go b/internal/dslx/dns.go index 9da9c9ed0f..62b86293f6 100644 --- a/internal/dslx/dns.go +++ b/internal/dslx/dns.go @@ -9,7 +9,6 @@ import ( "time" "github.com/ooni/probe-cli/v3/internal/logx" - "github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/netxlite" ) @@ -73,13 +72,12 @@ type ResolvedAddresses struct { // DNSLookupGetaddrinfo returns a function that resolves a domain name to // IP addresses using libc's getaddrinfo function. func DNSLookupGetaddrinfo(rt Runtime) Func[*DomainToResolve, *Maybe[*ResolvedAddresses]] { - return &dnsLookupGetaddrinfoFunc{nil, rt} + return &dnsLookupGetaddrinfoFunc{rt} } // dnsLookupGetaddrinfoFunc is the function returned by DNSLookupGetaddrinfo. type dnsLookupGetaddrinfoFunc struct { - resolver model.Resolver // for testing - rt Runtime + rt Runtime } // Apply implements Func. @@ -102,10 +100,8 @@ func (f *dnsLookupGetaddrinfoFunc) Apply( ctx, cancel := context.WithTimeout(ctx, timeout) defer cancel() - resolver := f.resolver - if resolver == nil { - resolver = trace.NewStdlibResolver(f.rt.Logger()) - } + // create the resolver + resolver := trace.NewStdlibResolver(f.rt.Logger()) // lookup addrs, err := resolver.LookupHost(ctx, input.Domain) @@ -131,18 +127,16 @@ func (f *dnsLookupGetaddrinfoFunc) Apply( // IP addresses using the given DNS-over-UDP resolver. func DNSLookupUDP(rt Runtime, resolver string) Func[*DomainToResolve, *Maybe[*ResolvedAddresses]] { return &dnsLookupUDPFunc{ - Resolver: resolver, - mockResolver: nil, - rt: rt, + Resolver: resolver, + rt: rt, } } // dnsLookupUDPFunc is the function returned by DNSLookupUDP. type dnsLookupUDPFunc struct { // Resolver is the MANDATORY endpointed of the resolver to use. - Resolver string - mockResolver model.Resolver // for testing - rt Runtime + Resolver string + rt Runtime } // Apply implements Func. @@ -166,14 +160,12 @@ func (f *dnsLookupUDPFunc) Apply( ctx, cancel := context.WithTimeout(ctx, timeout) defer cancel() - resolver := f.mockResolver - if resolver == nil { - resolver = trace.NewParallelUDPResolver( - f.rt.Logger(), - trace.NewDialerWithoutResolver(f.rt.Logger()), - f.Resolver, - ) - } + // create the resolver + resolver := trace.NewParallelUDPResolver( + f.rt.Logger(), + trace.NewDialerWithoutResolver(f.rt.Logger()), + f.Resolver, + ) // lookup addrs, err := resolver.LookupHost(ctx, input.Domain) diff --git a/internal/dslx/dns_test.go b/internal/dslx/dns_test.go index 15f08e155e..5dd6f79bd1 100644 --- a/internal/dslx/dns_test.go +++ b/internal/dslx/dns_test.go @@ -3,6 +3,7 @@ package dslx import ( "context" "errors" + "net" "sync/atomic" "testing" "time" @@ -84,10 +85,15 @@ func TestGetaddrinfo(t *testing.T) { t.Run("with lookup error", func(t *testing.T) { mockedErr := errors.New("mocked") f := dnsLookupGetaddrinfoFunc{ - resolver: &mocks.Resolver{MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { - return nil, mockedErr - }}, - rt: NewMinimalRuntime(model.DiscardLogger, time.Now()), + rt: NewMinimalRuntime(model.DiscardLogger, time.Now(), MinimalRuntimeOptionMeasuringNetwork(&mocks.MeasuringNetwork{ + MockNewStdlibResolver: func(logger model.DebugLogger) model.Resolver { + return &mocks.Resolver{ + MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { + return nil, mockedErr + }, + } + }, + })), } res := f.Apply(context.Background(), domain) if res.Observations == nil || len(res.Observations) <= 0 { @@ -106,10 +112,15 @@ func TestGetaddrinfo(t *testing.T) { t.Run("with success", func(t *testing.T) { f := dnsLookupGetaddrinfoFunc{ - resolver: &mocks.Resolver{MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { - return []string{"93.184.216.34"}, nil - }}, - rt: NewRuntimeMeasurexLite(model.DiscardLogger, time.Now()), + rt: NewRuntimeMeasurexLite(model.DiscardLogger, time.Now(), RuntimeMeasurexLiteOptionMeasuringNetwork(&mocks.MeasuringNetwork{ + MockNewStdlibResolver: func(logger model.DebugLogger) model.Resolver { + return &mocks.Resolver{ + MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { + return []string{"93.184.216.34"}, nil + }, + } + }, + })), } res := f.Apply(context.Background(), domain) if res.Observations == nil || len(res.Observations) <= 0 { @@ -171,10 +182,22 @@ func TestLookupUDP(t *testing.T) { mockedErr := errors.New("mocked") f := dnsLookupUDPFunc{ Resolver: "1.1.1.1:53", - mockResolver: &mocks.Resolver{MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { - return nil, mockedErr - }}, - rt: NewMinimalRuntime(model.DiscardLogger, time.Now()), + rt: NewMinimalRuntime(model.DiscardLogger, time.Now(), MinimalRuntimeOptionMeasuringNetwork(&mocks.MeasuringNetwork{ + MockNewParallelUDPResolver: func(logger model.DebugLogger, dialer model.Dialer, endpoint string) model.Resolver { + return &mocks.Resolver{ + MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { + return nil, mockedErr + }, + } + }, + MockNewDialerWithoutResolver: func(dl model.DebugLogger, w ...model.DialerWrapper) model.Dialer { + return &mocks.Dialer{ + MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) { + panic("should not be called") + }, + } + }, + })), } res := f.Apply(context.Background(), domain) if res.Observations == nil || len(res.Observations) <= 0 { @@ -194,10 +217,22 @@ func TestLookupUDP(t *testing.T) { t.Run("with success", func(t *testing.T) { f := dnsLookupUDPFunc{ Resolver: "1.1.1.1:53", - mockResolver: &mocks.Resolver{MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { - return []string{"93.184.216.34"}, nil - }}, - rt: NewRuntimeMeasurexLite(model.DiscardLogger, time.Now()), + rt: NewRuntimeMeasurexLite(model.DiscardLogger, time.Now(), RuntimeMeasurexLiteOptionMeasuringNetwork(&mocks.MeasuringNetwork{ + MockNewParallelUDPResolver: func(logger model.DebugLogger, dialer model.Dialer, address string) model.Resolver { + return &mocks.Resolver{ + MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { + return []string{"93.184.216.34"}, nil + }, + } + }, + MockNewDialerWithoutResolver: func(dl model.DebugLogger, w ...model.DialerWrapper) model.Dialer { + return &mocks.Dialer{ + MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) { + panic("should not be called") + }, + } + }, + })), } res := f.Apply(context.Background(), domain) if res.Observations == nil || len(res.Observations) <= 0 { diff --git a/internal/dslx/quic.go b/internal/dslx/quic.go index c643584e4a..3acf675ac9 100644 --- a/internal/dslx/quic.go +++ b/internal/dslx/quic.go @@ -13,7 +13,6 @@ import ( "time" "github.com/ooni/probe-cli/v3/internal/logx" - "github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/netxlite" "github.com/quic-go/quic-go" ) @@ -73,8 +72,6 @@ type quicHandshakeFunc struct { // ServerName is the ServerName to handshake for. ServerName string - - dialer model.QUICDialer // for testing } // Apply implements Func. @@ -97,10 +94,7 @@ func (f *quicHandshakeFunc) Apply( // setup udpListener := netxlite.NewUDPListener() - quicDialer := f.dialer - if quicDialer == nil { - quicDialer = trace.NewQUICDialerWithoutResolver(udpListener, f.Rt.Logger()) - } + quicDialer := trace.NewQUICDialerWithoutResolver(udpListener, f.Rt.Logger()) config := &tls.Config{ NextProtos: []string{"h3"}, InsecureSkipVerify: f.InsecureSkipVerify, diff --git a/internal/dslx/quic_test.go b/internal/dslx/quic_test.go index 40c4923812..2d34954bae 100644 --- a/internal/dslx/quic_test.go +++ b/internal/dslx/quic_test.go @@ -98,10 +98,13 @@ func TestQUICHandshake(t *testing.T) { for name, tt := range tests { t.Run(name, func(t *testing.T) { - rt := NewRuntimeMeasurexLite(model.DiscardLogger, time.Now()) + rt := NewRuntimeMeasurexLite(model.DiscardLogger, time.Now(), RuntimeMeasurexLiteOptionMeasuringNetwork(&mocks.MeasuringNetwork{ + MockNewQUICDialerWithoutResolver: func(listener model.UDPListener, logger model.DebugLogger, w ...model.QUICDialerWrapper) model.QUICDialer { + return tt.dialer + }, + })) quicHandshake := &quicHandshakeFunc{ Rt: rt, - dialer: tt.dialer, ServerName: tt.sni, } endpoint := &Endpoint{ @@ -131,24 +134,6 @@ func TestQUICHandshake(t *testing.T) { }) wasClosed = false } - - t.Run("with nil dialer", func(t *testing.T) { - quicHandshake := &quicHandshakeFunc{Rt: NewMinimalRuntime(model.DiscardLogger, time.Now()), dialer: nil} - endpoint := &Endpoint{ - Address: "1.2.3.4:567", - Network: "udp", - } - ctx, cancel := context.WithCancel(context.Background()) - cancel() - res := quicHandshake.Apply(ctx, endpoint) - - if res.Error == nil { - t.Fatalf("expected an error here") - } - if res.State.QUICConn != nil { - t.Fatalf("unexpected conn: %s", res.State.QUICConn) - } - }) }) } diff --git a/internal/dslx/tcp.go b/internal/dslx/tcp.go index af5dbcff3c..eaa54c2d30 100644 --- a/internal/dslx/tcp.go +++ b/internal/dslx/tcp.go @@ -10,20 +10,18 @@ import ( "time" "github.com/ooni/probe-cli/v3/internal/logx" - "github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/netxlite" ) // TCPConnect returns a function that establishes TCP connections. func TCPConnect(rt Runtime) Func[*Endpoint, *Maybe[*TCPConnection]] { - f := &tcpConnectFunc{nil, rt} + f := &tcpConnectFunc{rt} return f } // tcpConnectFunc is a function that establishes TCP connections. type tcpConnectFunc struct { - dialer model.Dialer // for testing - rt Runtime + rt Runtime } // Apply applies the function to its arguments. @@ -47,7 +45,7 @@ func (f *tcpConnectFunc) Apply( defer cancel() // obtain the dialer to use - dialer := f.dialerOrDefault(trace, f.rt.Logger()) + dialer := trace.NewDialerWithoutResolver(f.rt.Logger()) // connect conn, err := dialer.DialContext(ctx, "tcp", input.Address) @@ -74,15 +72,6 @@ func (f *tcpConnectFunc) Apply( } } -// dialerOrDefault is the function used to obtain a dialer -func (f *tcpConnectFunc) dialerOrDefault(trace Trace, logger model.Logger) model.Dialer { - dialer := f.dialer - if dialer == nil { - dialer = trace.NewDialerWithoutResolver(logger) - } - return dialer -} - // TCPConnection is an established TCP connection. If you initialize // manually, init at least the ones marked as MANDATORY. type TCPConnection struct { diff --git a/internal/dslx/tcp_test.go b/internal/dslx/tcp_test.go index 1ec42ef88a..8748b634bb 100644 --- a/internal/dslx/tcp_test.go +++ b/internal/dslx/tcp_test.go @@ -8,7 +8,6 @@ import ( "time" "github.com/google/go-cmp/cmp" - "github.com/ooni/probe-cli/v3/internal/measurexlite" "github.com/ooni/probe-cli/v3/internal/mocks" "github.com/ooni/probe-cli/v3/internal/model" ) @@ -68,8 +67,12 @@ func TestTCPConnect(t *testing.T) { for name, tt := range tests { t.Run(name, func(t *testing.T) { - rt := NewRuntimeMeasurexLite(model.DiscardLogger, time.Now()) - tcpConnect := &tcpConnectFunc{tt.dialer, rt} + rt := NewRuntimeMeasurexLite(model.DiscardLogger, time.Now(), RuntimeMeasurexLiteOptionMeasuringNetwork(&mocks.MeasuringNetwork{ + MockNewDialerWithoutResolver: func(dl model.DebugLogger, w ...model.DialerWrapper) model.Dialer { + return tt.dialer + }, + })) + tcpConnect := &tcpConnectFunc{rt} endpoint := &Endpoint{ Address: "1.2.3.4:567", Network: "tcp", @@ -99,15 +102,3 @@ func TestTCPConnect(t *testing.T) { } }) } - -// Make sure we get a valid dialer if no mocked dialer is configured -func TestDialerOrDefault(t *testing.T) { - f := &tcpConnectFunc{ - rt: NewMinimalRuntime(model.DiscardLogger, time.Now()), - dialer: nil, - } - dialer := f.dialerOrDefault(measurexlite.NewTrace(0, time.Now()), model.DiscardLogger) - if dialer == nil { - t.Fatal("expected non-nil dialer here") - } -} diff --git a/internal/dslx/tls.go b/internal/dslx/tls.go index 0c13215075..af67d59f61 100644 --- a/internal/dslx/tls.go +++ b/internal/dslx/tls.go @@ -12,7 +12,6 @@ import ( "time" "github.com/ooni/probe-cli/v3/internal/logx" - "github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/netxlite" ) @@ -82,9 +81,6 @@ type tlsHandshakeFunc struct { // ServerName is the ServerName to handshake for. ServerName string - - // for testing - handshaker model.TLSHandshaker } // Apply implements Func. @@ -108,7 +104,7 @@ func (f *tlsHandshakeFunc) Apply( ) // obtain the handshaker for use - handshaker := f.handshakerOrDefault(trace, f.Rt.Logger()) + handshaker := trace.NewTLSHandshakerStdlib(f.Rt.Logger()) // setup config := &tls.Config{ @@ -147,15 +143,6 @@ func (f *tlsHandshakeFunc) Apply( } } -// handshakerOrDefault is the function used to obtain an handshaker -func (f *tlsHandshakeFunc) handshakerOrDefault(trace Trace, logger model.Logger) model.TLSHandshaker { - handshaker := f.handshaker - if handshaker == nil { - handshaker = trace.NewTLSHandshakerStdlib(logger) - } - return handshaker -} - func (f *tlsHandshakeFunc) serverName(input *TCPConnection) string { if f.ServerName != "" { return f.ServerName diff --git a/internal/dslx/tls_test.go b/internal/dslx/tls_test.go index 2fd209661b..4dcce7e0e3 100644 --- a/internal/dslx/tls_test.go +++ b/internal/dslx/tls_test.go @@ -10,7 +10,6 @@ import ( "testing" "time" - "github.com/ooni/probe-cli/v3/internal/measurexlite" "github.com/ooni/probe-cli/v3/internal/mocks" "github.com/ooni/probe-cli/v3/internal/model" ) @@ -133,16 +132,19 @@ func TestTLSHandshake(t *testing.T) { for name, tt := range tests { t.Run(name, func(t *testing.T) { - rt := NewMinimalRuntime(model.DiscardLogger, time.Now()) + rt := NewMinimalRuntime(model.DiscardLogger, time.Now(), MinimalRuntimeOptionMeasuringNetwork(&mocks.MeasuringNetwork{ + MockNewTLSHandshakerStdlib: func(logger model.DebugLogger) model.TLSHandshaker { + return tt.handshaker + }, + })) tlsHandshake := &tlsHandshakeFunc{ NextProto: tt.config.nextProtos, Rt: rt, ServerName: tt.config.sni, - handshaker: tt.handshaker, } idGen := &atomic.Int64{} zeroTime := time.Time{} - trace := measurexlite.NewTrace(idGen.Add(1), zeroTime) + trace := rt.NewTrace(idGen.Add(1), zeroTime) address := tt.config.address if address == "" { address = "1.2.3.4:567" @@ -233,19 +235,3 @@ func TestServerNameTLS(t *testing.T) { } }) } - -// Make sure we get a valid handshaker if no mocked handshaker is configured -func TestHandshakerOrDefault(t *testing.T) { - f := &tlsHandshakeFunc{ - InsecureSkipVerify: false, - NextProto: []string{}, - Rt: NewMinimalRuntime(model.DiscardLogger, time.Now()), - RootCAs: &x509.CertPool{}, - ServerName: "", - handshaker: nil, - } - handshaker := f.handshakerOrDefault(measurexlite.NewTrace(0, time.Now()), model.DiscardLogger) - if handshaker == nil { - t.Fatal("expected non-nil handshaker here") - } -}