Skip to content

Commit

Permalink
refactor(dslx): use model.MeasuringNetwork for testing
Browse files Browse the repository at this point in the history
This diff modifies dslx functions to always use the MeasuringNetwork
for testing rather than using specific func fields.

By doing this, we open up the possibility of simplifying the state of
each func, with the ultimate goal of making them pure functions.

By making them pure functions, we make the code more manageable and
easy to modify, which opens up for additional refactorings.
  • Loading branch information
bassosimone committed Oct 20, 2023
1 parent 12cf2d3 commit d1db35d
Show file tree
Hide file tree
Showing 8 changed files with 87 additions and 128 deletions.
36 changes: 14 additions & 22 deletions internal/dslx/dns.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
"time"

"github.com/ooni/probe-cli/v3/internal/logx"
"github.com/ooni/probe-cli/v3/internal/model"
"github.com/ooni/probe-cli/v3/internal/netxlite"
)

Expand Down Expand Up @@ -73,13 +72,12 @@ type ResolvedAddresses struct {
// DNSLookupGetaddrinfo returns a function that resolves a domain name to
// IP addresses using libc's getaddrinfo function.
func DNSLookupGetaddrinfo(rt Runtime) Func[*DomainToResolve, *Maybe[*ResolvedAddresses]] {
return &dnsLookupGetaddrinfoFunc{nil, rt}
return &dnsLookupGetaddrinfoFunc{rt}
}

// dnsLookupGetaddrinfoFunc is the function returned by DNSLookupGetaddrinfo.
type dnsLookupGetaddrinfoFunc struct {
resolver model.Resolver // for testing
rt Runtime
rt Runtime
}

// Apply implements Func.
Expand All @@ -102,10 +100,8 @@ func (f *dnsLookupGetaddrinfoFunc) Apply(
ctx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()

resolver := f.resolver
if resolver == nil {
resolver = trace.NewStdlibResolver(f.rt.Logger())
}
// create the resolver
resolver := trace.NewStdlibResolver(f.rt.Logger())

// lookup
addrs, err := resolver.LookupHost(ctx, input.Domain)
Expand All @@ -131,18 +127,16 @@ func (f *dnsLookupGetaddrinfoFunc) Apply(
// IP addresses using the given DNS-over-UDP resolver.
func DNSLookupUDP(rt Runtime, resolver string) Func[*DomainToResolve, *Maybe[*ResolvedAddresses]] {
return &dnsLookupUDPFunc{
Resolver: resolver,
mockResolver: nil,
rt: rt,
Resolver: resolver,
rt: rt,
}
}

// dnsLookupUDPFunc is the function returned by DNSLookupUDP.
type dnsLookupUDPFunc struct {
// Resolver is the MANDATORY endpointed of the resolver to use.
Resolver string
mockResolver model.Resolver // for testing
rt Runtime
Resolver string
rt Runtime
}

// Apply implements Func.
Expand All @@ -166,14 +160,12 @@ func (f *dnsLookupUDPFunc) Apply(
ctx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()

resolver := f.mockResolver
if resolver == nil {
resolver = trace.NewParallelUDPResolver(
f.rt.Logger(),
trace.NewDialerWithoutResolver(f.rt.Logger()),
f.Resolver,
)
}
// create the resolver
resolver := trace.NewParallelUDPResolver(
f.rt.Logger(),
trace.NewDialerWithoutResolver(f.rt.Logger()),
f.Resolver,
)

// lookup
addrs, err := resolver.LookupHost(ctx, input.Domain)
Expand Down
67 changes: 51 additions & 16 deletions internal/dslx/dns_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package dslx
import (
"context"
"errors"
"net"
"sync/atomic"
"testing"
"time"
Expand Down Expand Up @@ -84,10 +85,15 @@ func TestGetaddrinfo(t *testing.T) {
t.Run("with lookup error", func(t *testing.T) {
mockedErr := errors.New("mocked")
f := dnsLookupGetaddrinfoFunc{
resolver: &mocks.Resolver{MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
return nil, mockedErr
}},
rt: NewMinimalRuntime(model.DiscardLogger, time.Now()),
rt: NewMinimalRuntime(model.DiscardLogger, time.Now(), MinimalRuntimeOptionMeasuringNetwork(&mocks.MeasuringNetwork{
MockNewStdlibResolver: func(logger model.DebugLogger) model.Resolver {
return &mocks.Resolver{
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
return nil, mockedErr
},
}
},
})),
}
res := f.Apply(context.Background(), domain)
if res.Observations == nil || len(res.Observations) <= 0 {
Expand All @@ -106,10 +112,15 @@ func TestGetaddrinfo(t *testing.T) {

t.Run("with success", func(t *testing.T) {
f := dnsLookupGetaddrinfoFunc{
resolver: &mocks.Resolver{MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
return []string{"93.184.216.34"}, nil
}},
rt: NewRuntimeMeasurexLite(model.DiscardLogger, time.Now()),
rt: NewRuntimeMeasurexLite(model.DiscardLogger, time.Now(), RuntimeMeasurexLiteOptionMeasuringNetwork(&mocks.MeasuringNetwork{
MockNewStdlibResolver: func(logger model.DebugLogger) model.Resolver {
return &mocks.Resolver{
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
return []string{"93.184.216.34"}, nil
},
}
},
})),
}
res := f.Apply(context.Background(), domain)
if res.Observations == nil || len(res.Observations) <= 0 {
Expand Down Expand Up @@ -171,10 +182,22 @@ func TestLookupUDP(t *testing.T) {
mockedErr := errors.New("mocked")
f := dnsLookupUDPFunc{
Resolver: "1.1.1.1:53",
mockResolver: &mocks.Resolver{MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
return nil, mockedErr
}},
rt: NewMinimalRuntime(model.DiscardLogger, time.Now()),
rt: NewMinimalRuntime(model.DiscardLogger, time.Now(), MinimalRuntimeOptionMeasuringNetwork(&mocks.MeasuringNetwork{
MockNewParallelUDPResolver: func(logger model.DebugLogger, dialer model.Dialer, endpoint string) model.Resolver {
return &mocks.Resolver{
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
return nil, mockedErr
},
}
},
MockNewDialerWithoutResolver: func(dl model.DebugLogger, w ...model.DialerWrapper) model.Dialer {
return &mocks.Dialer{
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
panic("should not be called")
},
}
},
})),
}
res := f.Apply(context.Background(), domain)
if res.Observations == nil || len(res.Observations) <= 0 {
Expand All @@ -194,10 +217,22 @@ func TestLookupUDP(t *testing.T) {
t.Run("with success", func(t *testing.T) {
f := dnsLookupUDPFunc{
Resolver: "1.1.1.1:53",
mockResolver: &mocks.Resolver{MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
return []string{"93.184.216.34"}, nil
}},
rt: NewRuntimeMeasurexLite(model.DiscardLogger, time.Now()),
rt: NewRuntimeMeasurexLite(model.DiscardLogger, time.Now(), RuntimeMeasurexLiteOptionMeasuringNetwork(&mocks.MeasuringNetwork{
MockNewParallelUDPResolver: func(logger model.DebugLogger, dialer model.Dialer, address string) model.Resolver {
return &mocks.Resolver{
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
return []string{"93.184.216.34"}, nil
},
}
},
MockNewDialerWithoutResolver: func(dl model.DebugLogger, w ...model.DialerWrapper) model.Dialer {
return &mocks.Dialer{
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
panic("should not be called")
},
}
},
})),
}
res := f.Apply(context.Background(), domain)
if res.Observations == nil || len(res.Observations) <= 0 {
Expand Down
8 changes: 1 addition & 7 deletions internal/dslx/quic.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ import (
"time"

"github.com/ooni/probe-cli/v3/internal/logx"
"github.com/ooni/probe-cli/v3/internal/model"
"github.com/ooni/probe-cli/v3/internal/netxlite"
"github.com/quic-go/quic-go"
)
Expand Down Expand Up @@ -73,8 +72,6 @@ type quicHandshakeFunc struct {

// ServerName is the ServerName to handshake for.
ServerName string

dialer model.QUICDialer // for testing
}

// Apply implements Func.
Expand All @@ -97,10 +94,7 @@ func (f *quicHandshakeFunc) Apply(

// setup
udpListener := netxlite.NewUDPListener()
quicDialer := f.dialer
if quicDialer == nil {
quicDialer = trace.NewQUICDialerWithoutResolver(udpListener, f.Rt.Logger())
}
quicDialer := trace.NewQUICDialerWithoutResolver(udpListener, f.Rt.Logger())
config := &tls.Config{
NextProtos: []string{"h3"},
InsecureSkipVerify: f.InsecureSkipVerify,
Expand Down
25 changes: 5 additions & 20 deletions internal/dslx/quic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,10 +98,13 @@ func TestQUICHandshake(t *testing.T) {

for name, tt := range tests {
t.Run(name, func(t *testing.T) {
rt := NewRuntimeMeasurexLite(model.DiscardLogger, time.Now())
rt := NewRuntimeMeasurexLite(model.DiscardLogger, time.Now(), RuntimeMeasurexLiteOptionMeasuringNetwork(&mocks.MeasuringNetwork{
MockNewQUICDialerWithoutResolver: func(listener model.UDPListener, logger model.DebugLogger, w ...model.QUICDialerWrapper) model.QUICDialer {
return tt.dialer
},
}))
quicHandshake := &quicHandshakeFunc{
Rt: rt,
dialer: tt.dialer,
ServerName: tt.sni,
}
endpoint := &Endpoint{
Expand Down Expand Up @@ -131,24 +134,6 @@ func TestQUICHandshake(t *testing.T) {
})
wasClosed = false
}

t.Run("with nil dialer", func(t *testing.T) {
quicHandshake := &quicHandshakeFunc{Rt: NewMinimalRuntime(model.DiscardLogger, time.Now()), dialer: nil}
endpoint := &Endpoint{
Address: "1.2.3.4:567",
Network: "udp",
}
ctx, cancel := context.WithCancel(context.Background())
cancel()
res := quicHandshake.Apply(ctx, endpoint)

if res.Error == nil {
t.Fatalf("expected an error here")
}
if res.State.QUICConn != nil {
t.Fatalf("unexpected conn: %s", res.State.QUICConn)
}
})
})
}

Expand Down
17 changes: 3 additions & 14 deletions internal/dslx/tcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,20 +10,18 @@ import (
"time"

"github.com/ooni/probe-cli/v3/internal/logx"
"github.com/ooni/probe-cli/v3/internal/model"
"github.com/ooni/probe-cli/v3/internal/netxlite"
)

// TCPConnect returns a function that establishes TCP connections.
func TCPConnect(rt Runtime) Func[*Endpoint, *Maybe[*TCPConnection]] {
f := &tcpConnectFunc{nil, rt}
f := &tcpConnectFunc{rt}
return f
}

// tcpConnectFunc is a function that establishes TCP connections.
type tcpConnectFunc struct {
dialer model.Dialer // for testing
rt Runtime
rt Runtime
}

// Apply applies the function to its arguments.
Expand All @@ -47,7 +45,7 @@ func (f *tcpConnectFunc) Apply(
defer cancel()

// obtain the dialer to use
dialer := f.dialerOrDefault(trace, f.rt.Logger())
dialer := trace.NewDialerWithoutResolver(f.rt.Logger())

// connect
conn, err := dialer.DialContext(ctx, "tcp", input.Address)
Expand All @@ -74,15 +72,6 @@ func (f *tcpConnectFunc) Apply(
}
}

// dialerOrDefault is the function used to obtain a dialer
func (f *tcpConnectFunc) dialerOrDefault(trace Trace, logger model.Logger) model.Dialer {
dialer := f.dialer
if dialer == nil {
dialer = trace.NewDialerWithoutResolver(logger)
}
return dialer
}

// TCPConnection is an established TCP connection. If you initialize
// manually, init at least the ones marked as MANDATORY.
type TCPConnection struct {
Expand Down
21 changes: 6 additions & 15 deletions internal/dslx/tcp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"time"

"github.com/google/go-cmp/cmp"
"github.com/ooni/probe-cli/v3/internal/measurexlite"
"github.com/ooni/probe-cli/v3/internal/mocks"
"github.com/ooni/probe-cli/v3/internal/model"
)
Expand Down Expand Up @@ -68,8 +67,12 @@ func TestTCPConnect(t *testing.T) {

for name, tt := range tests {
t.Run(name, func(t *testing.T) {
rt := NewRuntimeMeasurexLite(model.DiscardLogger, time.Now())
tcpConnect := &tcpConnectFunc{tt.dialer, rt}
rt := NewRuntimeMeasurexLite(model.DiscardLogger, time.Now(), RuntimeMeasurexLiteOptionMeasuringNetwork(&mocks.MeasuringNetwork{
MockNewDialerWithoutResolver: func(dl model.DebugLogger, w ...model.DialerWrapper) model.Dialer {
return tt.dialer
},
}))
tcpConnect := &tcpConnectFunc{rt}
endpoint := &Endpoint{
Address: "1.2.3.4:567",
Network: "tcp",
Expand Down Expand Up @@ -99,15 +102,3 @@ func TestTCPConnect(t *testing.T) {
}
})
}

// Make sure we get a valid dialer if no mocked dialer is configured
func TestDialerOrDefault(t *testing.T) {
f := &tcpConnectFunc{
rt: NewMinimalRuntime(model.DiscardLogger, time.Now()),
dialer: nil,
}
dialer := f.dialerOrDefault(measurexlite.NewTrace(0, time.Now()), model.DiscardLogger)
if dialer == nil {
t.Fatal("expected non-nil dialer here")
}
}
15 changes: 1 addition & 14 deletions internal/dslx/tls.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import (
"time"

"github.com/ooni/probe-cli/v3/internal/logx"
"github.com/ooni/probe-cli/v3/internal/model"
"github.com/ooni/probe-cli/v3/internal/netxlite"
)

Expand Down Expand Up @@ -82,9 +81,6 @@ type tlsHandshakeFunc struct {

// ServerName is the ServerName to handshake for.
ServerName string

// for testing
handshaker model.TLSHandshaker
}

// Apply implements Func.
Expand All @@ -108,7 +104,7 @@ func (f *tlsHandshakeFunc) Apply(
)

// obtain the handshaker for use
handshaker := f.handshakerOrDefault(trace, f.Rt.Logger())
handshaker := trace.NewTLSHandshakerStdlib(f.Rt.Logger())

// setup
config := &tls.Config{
Expand Down Expand Up @@ -147,15 +143,6 @@ func (f *tlsHandshakeFunc) Apply(
}
}

// handshakerOrDefault is the function used to obtain an handshaker
func (f *tlsHandshakeFunc) handshakerOrDefault(trace Trace, logger model.Logger) model.TLSHandshaker {
handshaker := f.handshaker
if handshaker == nil {
handshaker = trace.NewTLSHandshakerStdlib(logger)
}
return handshaker
}

func (f *tlsHandshakeFunc) serverName(input *TCPConnection) string {
if f.ServerName != "" {
return f.ServerName
Expand Down
Loading

0 comments on commit d1db35d

Please sign in to comment.