Skip to content

Commit

Permalink
cleanup(dslx): use model.UnderlyingNetwork for testing (#1377)
Browse files Browse the repository at this point in the history
  • Loading branch information
bassosimone authored Oct 20, 2023
1 parent 2da2e4d commit 18cef86
Show file tree
Hide file tree
Showing 12 changed files with 170 additions and 139 deletions.
36 changes: 14 additions & 22 deletions internal/dslx/dns.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
"time"

"github.com/ooni/probe-cli/v3/internal/logx"
"github.com/ooni/probe-cli/v3/internal/model"
"github.com/ooni/probe-cli/v3/internal/netxlite"
)

Expand Down Expand Up @@ -73,13 +72,12 @@ type ResolvedAddresses struct {
// DNSLookupGetaddrinfo returns a function that resolves a domain name to
// IP addresses using libc's getaddrinfo function.
func DNSLookupGetaddrinfo(rt Runtime) Func[*DomainToResolve, *Maybe[*ResolvedAddresses]] {
return &dnsLookupGetaddrinfoFunc{nil, rt}
return &dnsLookupGetaddrinfoFunc{rt}
}

// dnsLookupGetaddrinfoFunc is the function returned by DNSLookupGetaddrinfo.
type dnsLookupGetaddrinfoFunc struct {
resolver model.Resolver // for testing
rt Runtime
rt Runtime
}

// Apply implements Func.
Expand All @@ -102,10 +100,8 @@ func (f *dnsLookupGetaddrinfoFunc) Apply(
ctx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()

resolver := f.resolver
if resolver == nil {
resolver = trace.NewStdlibResolver(f.rt.Logger())
}
// create the resolver
resolver := trace.NewStdlibResolver(f.rt.Logger())

// lookup
addrs, err := resolver.LookupHost(ctx, input.Domain)
Expand All @@ -131,18 +127,16 @@ func (f *dnsLookupGetaddrinfoFunc) Apply(
// IP addresses using the given DNS-over-UDP resolver.
func DNSLookupUDP(rt Runtime, resolver string) Func[*DomainToResolve, *Maybe[*ResolvedAddresses]] {
return &dnsLookupUDPFunc{
Resolver: resolver,
mockResolver: nil,
rt: rt,
Resolver: resolver,
rt: rt,
}
}

// dnsLookupUDPFunc is the function returned by DNSLookupUDP.
type dnsLookupUDPFunc struct {
// Resolver is the MANDATORY endpointed of the resolver to use.
Resolver string
mockResolver model.Resolver // for testing
rt Runtime
Resolver string
rt Runtime
}

// Apply implements Func.
Expand All @@ -166,14 +160,12 @@ func (f *dnsLookupUDPFunc) Apply(
ctx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()

resolver := f.mockResolver
if resolver == nil {
resolver = trace.NewParallelUDPResolver(
f.rt.Logger(),
trace.NewDialerWithoutResolver(f.rt.Logger()),
f.Resolver,
)
}
// create the resolver
resolver := trace.NewParallelUDPResolver(
f.rt.Logger(),
trace.NewDialerWithoutResolver(f.rt.Logger()),
f.Resolver,
)

// lookup
addrs, err := resolver.LookupHost(ctx, input.Domain)
Expand Down
67 changes: 51 additions & 16 deletions internal/dslx/dns_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package dslx
import (
"context"
"errors"
"net"
"sync/atomic"
"testing"
"time"
Expand Down Expand Up @@ -84,10 +85,15 @@ func TestGetaddrinfo(t *testing.T) {
t.Run("with lookup error", func(t *testing.T) {
mockedErr := errors.New("mocked")
f := dnsLookupGetaddrinfoFunc{
resolver: &mocks.Resolver{MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
return nil, mockedErr
}},
rt: NewMinimalRuntime(model.DiscardLogger, time.Now()),
rt: NewMinimalRuntime(model.DiscardLogger, time.Now(), MinimalRuntimeOptionMeasuringNetwork(&mocks.MeasuringNetwork{
MockNewStdlibResolver: func(logger model.DebugLogger) model.Resolver {
return &mocks.Resolver{
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
return nil, mockedErr
},
}
},
})),
}
res := f.Apply(context.Background(), domain)
if res.Observations == nil || len(res.Observations) <= 0 {
Expand All @@ -106,10 +112,15 @@ func TestGetaddrinfo(t *testing.T) {

t.Run("with success", func(t *testing.T) {
f := dnsLookupGetaddrinfoFunc{
resolver: &mocks.Resolver{MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
return []string{"93.184.216.34"}, nil
}},
rt: NewRuntimeMeasurexLite(model.DiscardLogger, time.Now()),
rt: NewRuntimeMeasurexLite(model.DiscardLogger, time.Now(), RuntimeMeasurexLiteOptionMeasuringNetwork(&mocks.MeasuringNetwork{
MockNewStdlibResolver: func(logger model.DebugLogger) model.Resolver {
return &mocks.Resolver{
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
return []string{"93.184.216.34"}, nil
},
}
},
})),
}
res := f.Apply(context.Background(), domain)
if res.Observations == nil || len(res.Observations) <= 0 {
Expand Down Expand Up @@ -171,10 +182,22 @@ func TestLookupUDP(t *testing.T) {
mockedErr := errors.New("mocked")
f := dnsLookupUDPFunc{
Resolver: "1.1.1.1:53",
mockResolver: &mocks.Resolver{MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
return nil, mockedErr
}},
rt: NewMinimalRuntime(model.DiscardLogger, time.Now()),
rt: NewMinimalRuntime(model.DiscardLogger, time.Now(), MinimalRuntimeOptionMeasuringNetwork(&mocks.MeasuringNetwork{
MockNewParallelUDPResolver: func(logger model.DebugLogger, dialer model.Dialer, endpoint string) model.Resolver {
return &mocks.Resolver{
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
return nil, mockedErr
},
}
},
MockNewDialerWithoutResolver: func(dl model.DebugLogger, w ...model.DialerWrapper) model.Dialer {
return &mocks.Dialer{
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
panic("should not be called")
},
}
},
})),
}
res := f.Apply(context.Background(), domain)
if res.Observations == nil || len(res.Observations) <= 0 {
Expand All @@ -194,10 +217,22 @@ func TestLookupUDP(t *testing.T) {
t.Run("with success", func(t *testing.T) {
f := dnsLookupUDPFunc{
Resolver: "1.1.1.1:53",
mockResolver: &mocks.Resolver{MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
return []string{"93.184.216.34"}, nil
}},
rt: NewRuntimeMeasurexLite(model.DiscardLogger, time.Now()),
rt: NewRuntimeMeasurexLite(model.DiscardLogger, time.Now(), RuntimeMeasurexLiteOptionMeasuringNetwork(&mocks.MeasuringNetwork{
MockNewParallelUDPResolver: func(logger model.DebugLogger, dialer model.Dialer, address string) model.Resolver {
return &mocks.Resolver{
MockLookupHost: func(ctx context.Context, domain string) ([]string, error) {
return []string{"93.184.216.34"}, nil
},
}
},
MockNewDialerWithoutResolver: func(dl model.DebugLogger, w ...model.DialerWrapper) model.Dialer {
return &mocks.Dialer{
MockDialContext: func(ctx context.Context, network, address string) (net.Conn, error) {
panic("should not be called")
},
}
},
})),
}
res := f.Apply(context.Background(), domain)
if res.Observations == nil || len(res.Observations) <= 0 {
Expand Down
8 changes: 1 addition & 7 deletions internal/dslx/quic.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ import (
"time"

"github.com/ooni/probe-cli/v3/internal/logx"
"github.com/ooni/probe-cli/v3/internal/model"
"github.com/ooni/probe-cli/v3/internal/netxlite"
"github.com/quic-go/quic-go"
)
Expand Down Expand Up @@ -73,8 +72,6 @@ type quicHandshakeFunc struct {

// ServerName is the ServerName to handshake for.
ServerName string

dialer model.QUICDialer // for testing
}

// Apply implements Func.
Expand All @@ -97,10 +94,7 @@ func (f *quicHandshakeFunc) Apply(

// setup
udpListener := netxlite.NewUDPListener()
quicDialer := f.dialer
if quicDialer == nil {
quicDialer = trace.NewQUICDialerWithoutResolver(udpListener, f.Rt.Logger())
}
quicDialer := trace.NewQUICDialerWithoutResolver(udpListener, f.Rt.Logger())
config := &tls.Config{
NextProtos: []string{"h3"},
InsecureSkipVerify: f.InsecureSkipVerify,
Expand Down
25 changes: 5 additions & 20 deletions internal/dslx/quic_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,10 +98,13 @@ func TestQUICHandshake(t *testing.T) {

for name, tt := range tests {
t.Run(name, func(t *testing.T) {
rt := NewRuntimeMeasurexLite(model.DiscardLogger, time.Now())
rt := NewRuntimeMeasurexLite(model.DiscardLogger, time.Now(), RuntimeMeasurexLiteOptionMeasuringNetwork(&mocks.MeasuringNetwork{
MockNewQUICDialerWithoutResolver: func(listener model.UDPListener, logger model.DebugLogger, w ...model.QUICDialerWrapper) model.QUICDialer {
return tt.dialer
},
}))
quicHandshake := &quicHandshakeFunc{
Rt: rt,
dialer: tt.dialer,
ServerName: tt.sni,
}
endpoint := &Endpoint{
Expand Down Expand Up @@ -131,24 +134,6 @@ func TestQUICHandshake(t *testing.T) {
})
wasClosed = false
}

t.Run("with nil dialer", func(t *testing.T) {
quicHandshake := &quicHandshakeFunc{Rt: NewMinimalRuntime(model.DiscardLogger, time.Now()), dialer: nil}
endpoint := &Endpoint{
Address: "1.2.3.4:567",
Network: "udp",
}
ctx, cancel := context.WithCancel(context.Background())
cancel()
res := quicHandshake.Apply(ctx, endpoint)

if res.Error == nil {
t.Fatalf("expected an error here")
}
if res.State.QUICConn != nil {
t.Fatalf("unexpected conn: %s", res.State.QUICConn)
}
})
})
}

Expand Down
25 changes: 22 additions & 3 deletions internal/dslx/runtimemeasurex.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,23 +5,42 @@ import (

"github.com/ooni/probe-cli/v3/internal/measurexlite"
"github.com/ooni/probe-cli/v3/internal/model"
"github.com/ooni/probe-cli/v3/internal/netxlite"
)

// RuntimeMeasurexLiteOption is an option for initializing a [*RuntimeMeasurexLite].
type RuntimeMeasurexLiteOption func(rt *RuntimeMeasurexLite)

// RuntimeMeasurexLiteOptionMeasuringNetwork allows to configure which [model.MeasuringNetwork] to use.
func RuntimeMeasurexLiteOptionMeasuringNetwork(netx model.MeasuringNetwork) RuntimeMeasurexLiteOption {
return func(rt *RuntimeMeasurexLite) {
rt.netx = netx
}
}

// NewRuntimeMeasurexLite creates a [Runtime] using [measurexlite] to collect [*Observations].
func NewRuntimeMeasurexLite(logger model.Logger, zeroTime time.Time) *RuntimeMeasurexLite {
return &RuntimeMeasurexLite{
func NewRuntimeMeasurexLite(logger model.Logger, zeroTime time.Time, options ...RuntimeMeasurexLiteOption) *RuntimeMeasurexLite {
rt := &RuntimeMeasurexLite{
MinimalRuntime: NewMinimalRuntime(logger, zeroTime),
netx: &netxlite.Netx{Underlying: nil}, // implies using the host's network
}
for _, option := range options {
option(rt)
}
return rt
}

// RuntimeMeasurexLite uses [measurexlite] to collect [*Observations.]
type RuntimeMeasurexLite struct {
*MinimalRuntime
netx model.MeasuringNetwork
}

// NewTrace implements Runtime.
func (p *RuntimeMeasurexLite) NewTrace(index int64, zeroTime time.Time, tags ...string) Trace {
return measurexlite.NewTrace(index, zeroTime, tags...)
trace := measurexlite.NewTrace(index, zeroTime, tags...)
trace.Netx = p.netx
return trace
}

var _ Runtime = &RuntimeMeasurexLite{}
24 changes: 24 additions & 0 deletions internal/dslx/runtimemeasurex_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
package dslx

import (
"testing"
"time"

"github.com/ooni/probe-cli/v3/internal/measurexlite"
"github.com/ooni/probe-cli/v3/internal/mocks"
"github.com/ooni/probe-cli/v3/internal/model"
)

func TestMeasurexLiteRuntime(t *testing.T) {
t.Run("we can configure a custom model.MeasuringNetwork", func(t *testing.T) {
netx := &mocks.MeasuringNetwork{}
rt := NewRuntimeMeasurexLite(model.DiscardLogger, time.Now(), RuntimeMeasurexLiteOptionMeasuringNetwork(netx))
if rt.netx != netx {
t.Fatal("did not set the measuring network")
}
trace := rt.NewTrace(rt.IDGenerator().Add(1), rt.ZeroTime()).(*measurexlite.Trace)
if trace.Netx != netx {
t.Fatal("did not set the measuring network")
}
})
}
Loading

0 comments on commit 18cef86

Please sign in to comment.