diff --git a/internal/dslx/connpool.go b/internal/dslx/connpool.go deleted file mode 100644 index 0636147b02..0000000000 --- a/internal/dslx/connpool.go +++ /dev/null @@ -1,42 +0,0 @@ -package dslx - -// -// Connection pooling to streamline closing connections. -// - -import ( - "io" - "sync" -) - -// ConnPool tracks established connections. The zero value -// of this struct is ready to use. -type ConnPool struct { - mu sync.Mutex - v []io.Closer -} - -// MaybeTrack tracks the given connection if not nil. This -// method is safe for use by multiple goroutines. -func (p *ConnPool) MaybeTrack(c io.Closer) { - if c != nil { - defer p.mu.Unlock() - p.mu.Lock() - p.v = append(p.v, c) - } -} - -// Close closes all the tracked connections in reverse order. This -// method is safe for use by multiple goroutines. -func (p *ConnPool) Close() error { - // Implementation note: reverse order is such that we close TLS - // connections before we close the TCP connections they use. Hence - // we'll _gracefully_ close TLS connections. - defer p.mu.Unlock() - p.mu.Lock() - for idx := len(p.v) - 1; idx >= 0; idx-- { - _ = p.v[idx].Close() - } - p.v = nil // reset - return nil -} diff --git a/internal/dslx/connpool_test.go b/internal/dslx/connpool_test.go deleted file mode 100644 index daba7799f5..0000000000 --- a/internal/dslx/connpool_test.go +++ /dev/null @@ -1,100 +0,0 @@ -package dslx - -import ( - "errors" - "io" - "testing" - - "github.com/ooni/probe-cli/v3/internal/mocks" - "github.com/quic-go/quic-go" -) - -/* -Test cases: -- Maybe track connections: - - with nil - - with connection - - with quic connection - -- Close ConnPool: - - all Close() calls succeed - - one Close() call fails -*/ - -func closeableConnWithErr(err error) io.Closer { - return &mocks.Conn{ - MockClose: func() error { - return err - }, - } -} - -func closeableQUICConnWithErr(err error) io.Closer { - return &quicCloserConn{ - &mocks.QUICEarlyConnection{ - MockCloseWithError: func(code quic.ApplicationErrorCode, reason string) error { - return err - }, - }, - } -} - -func TestConnPool(t *testing.T) { - type connpoolTest struct { - mockConn io.Closer - want int // len of connpool.v - } - - t.Run("Maybe track connections", func(t *testing.T) { - tests := map[string]connpoolTest{ - "with nil": {mockConn: nil, want: 0}, - "with connection": {mockConn: closeableConnWithErr(nil), want: 1}, - "with quic connection": {mockConn: closeableQUICConnWithErr(nil), want: 1}, - } - for name, tt := range tests { - t.Run(name, func(t *testing.T) { - connpool := &ConnPool{} - connpool.MaybeTrack(tt.mockConn) - if len(connpool.v) != tt.want { - t.Fatalf("expected %d tracked connections, got: %d", tt.want, len(connpool.v)) - } - }) - } - }) - - t.Run("Close ConnPool", func(t *testing.T) { - mockErr := errors.New("mocked") - tests := map[string]struct { - pool *ConnPool - }{ - "all Close() calls succeed": { - pool: &ConnPool{ - v: []io.Closer{ - closeableConnWithErr(nil), - closeableQUICConnWithErr(nil), - }, - }, - }, - "one Close() call fails": { - pool: &ConnPool{ - v: []io.Closer{ - closeableConnWithErr(nil), - closeableConnWithErr(mockErr), - }, - }, - }, - } - - for name, tt := range tests { - t.Run(name, func(t *testing.T) { - err := tt.pool.Close() - if err != nil { // Close() should always return nil - t.Fatalf("unexpected error %s", err) - } - if tt.pool.v != nil { - t.Fatalf("v should be reset but is not") - } - }) - } - }) -} diff --git a/internal/dslx/dns.go b/internal/dslx/dns.go index 63759c544f..9da9c9ed0f 100644 --- a/internal/dslx/dns.go +++ b/internal/dslx/dns.go @@ -6,11 +6,9 @@ package dslx import ( "context" - "sync/atomic" "time" "github.com/ooni/probe-cli/v3/internal/logx" - "github.com/ooni/probe-cli/v3/internal/measurexlite" "github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/netxlite" ) @@ -21,22 +19,6 @@ type DomainName string // DNSLookupOption is an option you can pass to NewDomainToResolve. type DNSLookupOption func(*DomainToResolve) -// DNSLookupOptionIDGenerator configures a specific ID generator. -// See DomainToResolve docs for more information. -func DNSLookupOptionIDGenerator(value *atomic.Int64) DNSLookupOption { - return func(dis *DomainToResolve) { - dis.IDGenerator = value - } -} - -// DNSLookupOptionLogger configures a specific logger. -// See DomainToResolve docs for more information. -func DNSLookupOptionLogger(value model.Logger) DNSLookupOption { - return func(dis *DomainToResolve) { - dis.Logger = value - } -} - // DNSLookupOptionTags allows to set tags to tag observations. func DNSLookupOptionTags(value ...string) DNSLookupOption { return func(dis *DomainToResolve) { @@ -44,24 +26,13 @@ func DNSLookupOptionTags(value ...string) DNSLookupOption { } } -// DNSLookupOptionZeroTime configures the measurement's zero time. -// See DomainToResolve docs for more information. -func DNSLookupOptionZeroTime(value time.Time) DNSLookupOption { - return func(dis *DomainToResolve) { - dis.ZeroTime = value - } -} - // NewDomainToResolve creates input for performing DNS lookups. The only mandatory // argument is the domain name to resolve. You can also supply optional // values by passing options to this function. func NewDomainToResolve(domain DomainName, options ...DNSLookupOption) *DomainToResolve { state := &DomainToResolve{ - Domain: string(domain), - IDGenerator: &atomic.Int64{}, - Logger: model.DiscardLogger, - Tags: []string{}, - ZeroTime: time.Now(), + Domain: string(domain), + Tags: []string{}, } for _, option := range options { option(state) @@ -79,25 +50,8 @@ type DomainToResolve struct { // Domain is the MANDATORY domain name to lookup. Domain string - // IDGenerator is the MANDATORY ID generator. We will use this field - // to assign unique IDs to distinct sub-measurements. The default - // construction implemented by NewDomainToResolve creates a new generator - // that starts counting from zero, leading to the first trace having - // one as its index. - IDGenerator *atomic.Int64 - - // Logger is the MANDATORY logger to use. The default construction - // implemented by NewDomainToResolve uses model.DiscardLogger. - Logger model.Logger - // Tags contains OPTIONAL tags to tag observations. Tags []string - - // ZeroTime is the MANDATORY zero time of the measurement. We will - // use this field as the zero value to compute relative elapsed times - // when generating measurements. The default construction by - // NewDomainToResolve initializes this field with the current time. - ZeroTime time.Time } // ResolvedAddresses contains the results of DNS lookups. To initialize @@ -110,33 +64,22 @@ type ResolvedAddresses struct { // from the value inside the DomainToResolve. Domain string - // IDGenerator is the ID generator. We inherit this field - // from the value inside the DomainToResolve. - IDGenerator *atomic.Int64 - - // Logger is the logger to use. We inherit this field - // from the value inside the DomainToResolve. - Logger model.Logger - // Trace is the trace we're currently using. This struct is // created by the various Apply functions using values inside // the DomainToResolve to initialize the Trace. - Trace *measurexlite.Trace - - // ZeroTime is the zero time of the measurement. We inherit this field - // from the value inside the DomainToResolve. - ZeroTime time.Time + Trace Trace } // DNSLookupGetaddrinfo returns a function that resolves a domain name to // IP addresses using libc's getaddrinfo function. -func DNSLookupGetaddrinfo() Func[*DomainToResolve, *Maybe[*ResolvedAddresses]] { - return &dnsLookupGetaddrinfoFunc{} +func DNSLookupGetaddrinfo(rt Runtime) Func[*DomainToResolve, *Maybe[*ResolvedAddresses]] { + return &dnsLookupGetaddrinfoFunc{nil, rt} } // dnsLookupGetaddrinfoFunc is the function returned by DNSLookupGetaddrinfo. type dnsLookupGetaddrinfoFunc struct { resolver model.Resolver // for testing + rt Runtime } // Apply implements Func. @@ -144,13 +87,13 @@ func (f *dnsLookupGetaddrinfoFunc) Apply( ctx context.Context, input *DomainToResolve) *Maybe[*ResolvedAddresses] { // create trace - trace := measurexlite.NewTrace(input.IDGenerator.Add(1), input.ZeroTime, input.Tags...) + trace := f.rt.NewTrace(f.rt.IDGenerator().Add(1), f.rt.ZeroTime(), input.Tags...) // start the operation logger ol := logx.NewOperationLogger( - input.Logger, + f.rt.Logger(), "[#%d] DNSLookup[getaddrinfo] %s", - trace.Index, + trace.Index(), input.Domain, ) @@ -161,7 +104,7 @@ func (f *dnsLookupGetaddrinfoFunc) Apply( resolver := f.resolver if resolver == nil { - resolver = trace.NewStdlibResolver(input.Logger) + resolver = trace.NewStdlibResolver(f.rt.Logger()) } // lookup @@ -171,12 +114,9 @@ func (f *dnsLookupGetaddrinfoFunc) Apply( ol.Stop(err) state := &ResolvedAddresses{ - Addresses: addrs, // maybe empty - Domain: input.Domain, - IDGenerator: input.IDGenerator, - Logger: input.Logger, - Trace: trace, - ZeroTime: input.ZeroTime, + Addresses: addrs, // maybe empty + Domain: input.Domain, + Trace: trace, } return &Maybe[*ResolvedAddresses]{ @@ -189,9 +129,11 @@ func (f *dnsLookupGetaddrinfoFunc) Apply( // DNSLookupUDP returns a function that resolves a domain name to // IP addresses using the given DNS-over-UDP resolver. -func DNSLookupUDP(resolver string) Func[*DomainToResolve, *Maybe[*ResolvedAddresses]] { +func DNSLookupUDP(rt Runtime, resolver string) Func[*DomainToResolve, *Maybe[*ResolvedAddresses]] { return &dnsLookupUDPFunc{ - Resolver: resolver, + Resolver: resolver, + mockResolver: nil, + rt: rt, } } @@ -200,6 +142,7 @@ type dnsLookupUDPFunc struct { // Resolver is the MANDATORY endpointed of the resolver to use. Resolver string mockResolver model.Resolver // for testing + rt Runtime } // Apply implements Func. @@ -207,13 +150,13 @@ func (f *dnsLookupUDPFunc) Apply( ctx context.Context, input *DomainToResolve) *Maybe[*ResolvedAddresses] { // create trace - trace := measurexlite.NewTrace(input.IDGenerator.Add(1), input.ZeroTime, input.Tags...) + trace := f.rt.NewTrace(f.rt.IDGenerator().Add(1), f.rt.ZeroTime(), input.Tags...) // start the operation logger ol := logx.NewOperationLogger( - input.Logger, + f.rt.Logger(), "[#%d] DNSLookup[%s/udp] %s", - trace.Index, + trace.Index(), f.Resolver, input.Domain, ) @@ -226,8 +169,8 @@ func (f *dnsLookupUDPFunc) Apply( resolver := f.mockResolver if resolver == nil { resolver = trace.NewParallelUDPResolver( - input.Logger, - netxlite.NewDialerWithoutResolver(input.Logger), + f.rt.Logger(), + trace.NewDialerWithoutResolver(f.rt.Logger()), f.Resolver, ) } @@ -239,12 +182,9 @@ func (f *dnsLookupUDPFunc) Apply( ol.Stop(err) state := &ResolvedAddresses{ - Addresses: addrs, // maybe empty - Domain: input.Domain, - IDGenerator: input.IDGenerator, - Logger: input.Logger, - Trace: trace, - ZeroTime: input.ZeroTime, + Addresses: addrs, // maybe empty + Domain: input.Domain, + Trace: trace, } return &Maybe[*ResolvedAddresses]{ diff --git a/internal/dslx/dns_test.go b/internal/dslx/dns_test.go index 2f3e292804..15f08e155e 100644 --- a/internal/dslx/dns_test.go +++ b/internal/dslx/dns_test.go @@ -30,26 +30,13 @@ func TestNewDomainToResolve(t *testing.T) { t.Run("with options", func(t *testing.T) { idGen := &atomic.Int64{} idGen.Add(42) - zt := time.Now() domainToResolve := NewDomainToResolve( DomainName("www.example.com"), - DNSLookupOptionIDGenerator(idGen), - DNSLookupOptionLogger(model.DiscardLogger), - DNSLookupOptionZeroTime(zt), DNSLookupOptionTags("antani"), ) if domainToResolve.Domain != "www.example.com" { t.Fatalf("unexpected domain") } - if domainToResolve.IDGenerator != idGen { - t.Fatalf("unexpected id generator") - } - if domainToResolve.Logger != model.DiscardLogger { - t.Fatalf("unexpected logger") - } - if domainToResolve.ZeroTime != zt { - t.Fatalf("unexpected zerotime") - } if diff := cmp.Diff([]string{"antani"}, domainToResolve.Tags); diff != "" { t.Fatal(diff) } @@ -67,7 +54,7 @@ Test cases: */ func TestGetaddrinfo(t *testing.T) { t.Run("Get dnsLookupGetaddrinfoFunc", func(t *testing.T) { - f := DNSLookupGetaddrinfo() + f := DNSLookupGetaddrinfo(NewMinimalRuntime(model.DiscardLogger, time.Now())) if _, ok := f.(*dnsLookupGetaddrinfoFunc); !ok { t.Fatal("unexpected type, want dnsLookupGetaddrinfoFunc") } @@ -75,15 +62,14 @@ func TestGetaddrinfo(t *testing.T) { t.Run("Apply dnsLookupGetaddrinfoFunc", func(t *testing.T) { domain := &DomainToResolve{ - Domain: "example.com", - Logger: model.DiscardLogger, - IDGenerator: &atomic.Int64{}, - Tags: []string{"antani"}, - ZeroTime: time.Time{}, + Domain: "example.com", + Tags: []string{"antani"}, } t.Run("with nil resolver", func(t *testing.T) { - f := dnsLookupGetaddrinfoFunc{} + f := dnsLookupGetaddrinfoFunc{ + rt: NewMinimalRuntime(model.DiscardLogger, time.Now()), + } ctx, cancel := context.WithCancel(context.Background()) cancel() // immediately cancel the lookup res := f.Apply(ctx, domain) @@ -101,6 +87,7 @@ func TestGetaddrinfo(t *testing.T) { resolver: &mocks.Resolver{MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { return nil, mockedErr }}, + rt: NewMinimalRuntime(model.DiscardLogger, time.Now()), } res := f.Apply(context.Background(), domain) if res.Observations == nil || len(res.Observations) <= 0 { @@ -122,6 +109,7 @@ func TestGetaddrinfo(t *testing.T) { resolver: &mocks.Resolver{MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { return []string{"93.184.216.34"}, nil }}, + rt: NewRuntimeMeasurexLite(model.DiscardLogger, time.Now()), } res := f.Apply(context.Background(), domain) if res.Observations == nil || len(res.Observations) <= 0 { @@ -153,7 +141,8 @@ Test cases: */ func TestLookupUDP(t *testing.T) { t.Run("Get dnsLookupUDPFunc", func(t *testing.T) { - f := DNSLookupUDP("1.1.1.1:53") + rt := NewMinimalRuntime(model.DiscardLogger, time.Now()) + f := DNSLookupUDP(rt, "1.1.1.1:53") if _, ok := f.(*dnsLookupUDPFunc); !ok { t.Fatal("unexpected type, want dnsLookupUDPFunc") } @@ -161,15 +150,12 @@ func TestLookupUDP(t *testing.T) { t.Run("Apply dnsLookupGetaddrinfoFunc", func(t *testing.T) { domain := &DomainToResolve{ - Domain: "example.com", - Logger: model.DiscardLogger, - IDGenerator: &atomic.Int64{}, - Tags: []string{"antani"}, - ZeroTime: time.Time{}, + Domain: "example.com", + Tags: []string{"antani"}, } t.Run("with nil resolver", func(t *testing.T) { - f := dnsLookupUDPFunc{Resolver: "1.1.1.1:53"} + f := dnsLookupUDPFunc{Resolver: "1.1.1.1:53", rt: NewMinimalRuntime(model.DiscardLogger, time.Now())} ctx, cancel := context.WithCancel(context.Background()) cancel() res := f.Apply(ctx, domain) @@ -188,6 +174,7 @@ func TestLookupUDP(t *testing.T) { mockResolver: &mocks.Resolver{MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { return nil, mockedErr }}, + rt: NewMinimalRuntime(model.DiscardLogger, time.Now()), } res := f.Apply(context.Background(), domain) if res.Observations == nil || len(res.Observations) <= 0 { @@ -210,6 +197,7 @@ func TestLookupUDP(t *testing.T) { mockResolver: &mocks.Resolver{MockLookupHost: func(ctx context.Context, domain string) ([]string, error) { return []string{"93.184.216.34"}, nil }}, + rt: NewRuntimeMeasurexLite(model.DiscardLogger, time.Now()), } res := f.Apply(context.Background(), domain) if res.Observations == nil || len(res.Observations) <= 0 { diff --git a/internal/dslx/endpoint.go b/internal/dslx/endpoint.go index cd17d68c88..ea725f5d58 100644 --- a/internal/dslx/endpoint.go +++ b/internal/dslx/endpoint.go @@ -4,13 +4,6 @@ package dslx // Manipulate endpoints // -import ( - "sync/atomic" - "time" - - "github.com/ooni/probe-cli/v3/internal/model" -) - type ( // EndpointNetwork is the network of the endpoint EndpointNetwork string @@ -29,20 +22,11 @@ type Endpoint struct { // Domain is the OPTIONAL domain used to resolve the endpoints' IP address. Domain string - // IDGenerator is MANDATORY the ID generator to use. - IDGenerator *atomic.Int64 - - // Logger is the MANDATORY logger to use. - Logger model.Logger - // Network is the MANDATORY endpoint network. Network string // Tags contains OPTIONAL tags for tagging observations. Tags []string - - // ZeroTime is the MANDATORY zero time of the measurement. - ZeroTime time.Time } // EndpointOption is an option you can use to construct EndpointState. @@ -55,20 +39,6 @@ func EndpointOptionDomain(value string) EndpointOption { } } -// EndpointOptionIDGenerator allows to set the ID generator. -func EndpointOptionIDGenerator(value *atomic.Int64) EndpointOption { - return func(es *Endpoint) { - es.IDGenerator = value - } -} - -// EndpointOptionLogger allows to set the logger. -func EndpointOptionLogger(value model.Logger) EndpointOption { - return func(es *Endpoint) { - es.Logger = value - } -} - // EndpointOptionTags allows to set tags to tag observations. func EndpointOptionTags(value ...string) EndpointOption { return func(es *Endpoint) { @@ -76,13 +46,6 @@ func EndpointOptionTags(value ...string) EndpointOption { } } -// EndpointOptionZeroTime allows to set the zero time. -func EndpointOptionZeroTime(value time.Time) EndpointOption { - return func(es *Endpoint) { - es.ZeroTime = value - } -} - // NewEndpoint creates a new network endpoint (i.e., a three tuple composed // of a network protocol, an IP address, and a port). // @@ -97,13 +60,10 @@ func EndpointOptionZeroTime(value time.Time) EndpointOption { func NewEndpoint( network EndpointNetwork, address EndpointAddress, options ...EndpointOption) *Endpoint { epnt := &Endpoint{ - Address: string(address), - Domain: "", - IDGenerator: &atomic.Int64{}, - Logger: model.DiscardLogger, - Network: string(network), - Tags: []string{}, - ZeroTime: time.Now(), + Address: string(address), + Domain: "", + Network: string(network), + Tags: []string{}, } for _, option := range options { option(epnt) diff --git a/internal/dslx/endpoint_test.go b/internal/dslx/endpoint_test.go index 61170f1fc0..cc62ecedea 100644 --- a/internal/dslx/endpoint_test.go +++ b/internal/dslx/endpoint_test.go @@ -3,25 +3,19 @@ package dslx import ( "sync/atomic" "testing" - "time" "github.com/google/go-cmp/cmp" - "github.com/ooni/probe-cli/v3/internal/model" ) func TestEndpoint(t *testing.T) { idGen := &atomic.Int64{} idGen.Add(42) - zt := time.Now() t.Run("Create new endpoint", func(t *testing.T) { testEndpoint := NewEndpoint( "network", "10.9.8.76", EndpointOptionDomain("www.example.com"), - EndpointOptionIDGenerator(idGen), - EndpointOptionLogger(model.DiscardLogger), - EndpointOptionZeroTime(zt), EndpointOptionTags("antani"), ) if testEndpoint.Network != "network" { @@ -33,15 +27,6 @@ func TestEndpoint(t *testing.T) { if testEndpoint.Domain != "www.example.com" { t.Fatalf("unexpected domain") } - if testEndpoint.IDGenerator != idGen { - t.Fatalf("unexpected IDGenerator") - } - if testEndpoint.Logger != model.DiscardLogger { - t.Fatalf("unexpected logger") - } - if testEndpoint.ZeroTime != zt { - t.Fatalf("unexpected zero time") - } if diff := cmp.Diff([]string{"antani"}, testEndpoint.Tags); diff != "" { t.Fatal(diff) } diff --git a/internal/dslx/http_test.go b/internal/dslx/http_test.go index 17d2419de9..e95ca14e0f 100644 --- a/internal/dslx/http_test.go +++ b/internal/dslx/http_test.go @@ -30,7 +30,8 @@ Test cases: */ func TestHTTPRequest(t *testing.T) { t.Run("Get httpRequestFunc with options", func(t *testing.T) { - f := HTTPRequest( + rt := NewMinimalRuntime(model.DiscardLogger, time.Now()) + f := HTTPRequest(rt, HTTPRequestOptionAccept("text/html"), HTTPRequestOptionAcceptLanguage("de"), HTTPRequestOptionHost("host"), @@ -96,16 +97,15 @@ func TestHTTPRequest(t *testing.T) { t.Run("with EOF", func(t *testing.T) { httpTransport := HTTPTransport{ - Address: "1.2.3.4:567", - IDGenerator: idGen, - Logger: model.DiscardLogger, - Network: "tcp", - Scheme: "https", - Trace: trace, - Transport: eofTransport, - ZeroTime: zeroTime, + Address: "1.2.3.4:567", + Network: "tcp", + Scheme: "https", + Trace: trace, + Transport: eofTransport, + } + httpRequest := &httpRequestFunc{ + Rt: NewMinimalRuntime(model.DiscardLogger, time.Now()), } - httpRequest := &httpRequestFunc{} res := httpRequest.Apply(context.Background(), &httpTransport) if res.Error != io.EOF { t.Fatal("not the error we expected") @@ -117,14 +117,11 @@ func TestHTTPRequest(t *testing.T) { t.Run("with invalid method", func(t *testing.T) { httpTransport := HTTPTransport{ - Address: "1.2.3.4:567", - IDGenerator: idGen, - Logger: model.DiscardLogger, - Network: "tcp", - Scheme: "https", - Trace: trace, - Transport: goodTransport, - ZeroTime: zeroTime, + Address: "1.2.3.4:567", + Network: "tcp", + Scheme: "https", + Trace: trace, + Transport: goodTransport, } httpRequest := &httpRequestFunc{ Method: "€", @@ -140,16 +137,15 @@ func TestHTTPRequest(t *testing.T) { t.Run("with port-less address", func(t *testing.T) { httpTransport := HTTPTransport{ - Address: "1.2.3.4", - IDGenerator: idGen, - Logger: model.DiscardLogger, - Network: "tcp", - Scheme: "https", - Trace: trace, - Transport: goodTransport, - ZeroTime: zeroTime, + Address: "1.2.3.4", + Network: "tcp", + Scheme: "https", + Trace: trace, + Transport: goodTransport, + } + httpRequest := &httpRequestFunc{ + Rt: NewMinimalRuntime(model.DiscardLogger, time.Now()), } - httpRequest := &httpRequestFunc{} res := httpRequest.Apply(context.Background(), &httpTransport) if res.Error != nil { t.Fatal("expected error") @@ -193,16 +189,15 @@ func TestHTTPRequest(t *testing.T) { t.Run("with success (https)", func(t *testing.T) { httpTransport := HTTPTransport{ - Address: "1.2.3.4:443", - IDGenerator: idGen, - Logger: model.DiscardLogger, - Network: "tcp", - Scheme: "https", - Trace: trace, - Transport: goodTransport, - ZeroTime: zeroTime, + Address: "1.2.3.4:443", + Network: "tcp", + Scheme: "https", + Trace: trace, + Transport: goodTransport, + } + httpRequest := &httpRequestFunc{ + Rt: NewMinimalRuntime(model.DiscardLogger, time.Now()), } - httpRequest := &httpRequestFunc{} res := httpRequest.Apply(context.Background(), &httpTransport) if res.Error != nil { t.Fatal("unexpected error") @@ -215,16 +210,15 @@ func TestHTTPRequest(t *testing.T) { t.Run("with success (http)", func(t *testing.T) { httpTransport := HTTPTransport{ - Address: "1.2.3.4:80", - IDGenerator: idGen, - Logger: model.DiscardLogger, - Network: "tcp", - Scheme: "http", - Trace: trace, - Transport: goodTransport, - ZeroTime: zeroTime, + Address: "1.2.3.4:80", + Network: "tcp", + Scheme: "http", + Trace: trace, + Transport: goodTransport, + } + httpRequest := &httpRequestFunc{ + Rt: NewMinimalRuntime(model.DiscardLogger, time.Now()), } - httpRequest := &httpRequestFunc{} res := httpRequest.Apply(context.Background(), &httpTransport) if res.Error != nil { t.Fatal("unexpected error") @@ -237,21 +231,19 @@ func TestHTTPRequest(t *testing.T) { t.Run("with header options", func(t *testing.T) { httpTransport := HTTPTransport{ - Address: "1.2.3.4:567", - Domain: "domain.com", - IDGenerator: idGen, - Logger: model.DiscardLogger, - Network: "tcp", - Scheme: "https", - Trace: trace, - Transport: goodTransport, - ZeroTime: zeroTime, + Address: "1.2.3.4:567", + Domain: "domain.com", + Network: "tcp", + Scheme: "https", + Trace: trace, + Transport: goodTransport, } httpRequest := &httpRequestFunc{ Accept: "text/html", AcceptLanguage: "de", Host: "host", Referer: "https://example.org", + Rt: NewMinimalRuntime(model.DiscardLogger, time.Now()), URLPath: "/path/to/example", UserAgent: "Mozilla/5.0 Gecko/geckotrail Firefox/firefoxversion", } @@ -284,14 +276,16 @@ Test cases: */ func TestHTTPTCP(t *testing.T) { t.Run("Get httpTransportTCPFunc", func(t *testing.T) { - f := HTTPTransportTCP() + rt := NewMinimalRuntime(model.DiscardLogger, time.Now()) + f := HTTPTransportTCP(rt) if _, ok := f.(*httpTransportTCPFunc); !ok { t.Fatal("unexpected type") } }) t.Run("Get composed function: TCP with HTTP", func(t *testing.T) { - f := HTTPRequestOverTCP() + rt := NewMinimalRuntime(model.DiscardLogger, time.Now()) + f := HTTPRequestOverTCP(rt) if _, ok := f.(*compose2Func[*TCPConnection, *HTTPTransport, *HTTPResponse]); !ok { t.Fatal("unexpected type") } @@ -304,15 +298,14 @@ func TestHTTPTCP(t *testing.T) { trace := measurexlite.NewTrace(idGen.Add(1), zeroTime) address := "1.2.3.4:567" tcpConn := &TCPConnection{ - Address: address, - Conn: conn, - IDGenerator: idGen, - Logger: model.DiscardLogger, - Network: "tcp", - Trace: trace, - ZeroTime: zeroTime, - } - f := httpTransportTCPFunc{} + Address: address, + Conn: conn, + Network: "tcp", + Trace: trace, + } + f := httpTransportTCPFunc{ + rt: NewMinimalRuntime(model.DiscardLogger, time.Now()), + } res := f.Apply(context.Background(), tcpConn) if res.Error != nil { t.Fatalf("unexpected error: %s", res.Error) @@ -337,14 +330,16 @@ Test cases: */ func TestHTTPQUIC(t *testing.T) { t.Run("Get httpTransportQUICFunc", func(t *testing.T) { - f := HTTPTransportQUIC() + rt := NewMinimalRuntime(model.DiscardLogger, time.Now()) + f := HTTPTransportQUIC(rt) if _, ok := f.(*httpTransportQUICFunc); !ok { t.Fatal("unexpected type") } }) t.Run("Get composed function: QUIC with HTTP", func(t *testing.T) { - f := HTTPRequestOverQUIC() + rt := NewMinimalRuntime(model.DiscardLogger, time.Now()) + f := HTTPRequestOverQUIC(rt) if _, ok := f.(*compose2Func[*QUICConnection, *HTTPTransport, *HTTPResponse]); !ok { t.Fatal("unexpected type") } @@ -357,15 +352,14 @@ func TestHTTPQUIC(t *testing.T) { trace := measurexlite.NewTrace(idGen.Add(1), zeroTime) address := "1.2.3.4:567" quicConn := &QUICConnection{ - Address: address, - QUICConn: conn, - IDGenerator: idGen, - Logger: model.DiscardLogger, - Network: "udp", - Trace: trace, - ZeroTime: zeroTime, - } - f := httpTransportQUICFunc{} + Address: address, + QUICConn: conn, + Network: "udp", + Trace: trace, + } + f := httpTransportQUICFunc{ + rt: NewMinimalRuntime(model.DiscardLogger, time.Now()), + } res := f.Apply(context.Background(), quicConn) if res.Error != nil { t.Fatalf("unexpected error: %s", res.Error) @@ -390,14 +384,16 @@ Test cases: */ func TestHTTPTLS(t *testing.T) { t.Run("Get httpTransportTLSFunc", func(t *testing.T) { - f := HTTPTransportTLS() + rt := NewMinimalRuntime(model.DiscardLogger, time.Now()) + f := HTTPTransportTLS(rt) if _, ok := f.(*httpTransportTLSFunc); !ok { t.Fatal("unexpected type") } }) t.Run("Get composed function: TLS with HTTP", func(t *testing.T) { - f := HTTPRequestOverTLS() + rt := NewMinimalRuntime(model.DiscardLogger, time.Now()) + f := HTTPRequestOverTLS(rt) if _, ok := f.(*compose2Func[*TLSConnection, *HTTPTransport, *HTTPResponse]); !ok { t.Fatal("unexpected type") } @@ -410,15 +406,14 @@ func TestHTTPTLS(t *testing.T) { trace := measurexlite.NewTrace(idGen.Add(1), zeroTime) address := "1.2.3.4:567" tlsConn := &TLSConnection{ - Address: address, - Conn: conn, - IDGenerator: idGen, - Logger: model.DiscardLogger, - Network: "tcp", - Trace: trace, - ZeroTime: zeroTime, - } - f := httpTransportTLSFunc{} + Address: address, + Conn: conn, + Network: "tcp", + Trace: trace, + } + f := httpTransportTLSFunc{ + rt: NewMinimalRuntime(model.DiscardLogger, time.Now()), + } res := f.Apply(context.Background(), tlsConn) if res.Error != nil { t.Fatalf("unexpected error: %s", res.Error) diff --git a/internal/dslx/httpcore.go b/internal/dslx/httpcore.go index b824f8f848..08ff04eb2b 100644 --- a/internal/dslx/httpcore.go +++ b/internal/dslx/httpcore.go @@ -10,7 +10,6 @@ import ( "net" "net/http" "net/url" - "sync/atomic" "time" "github.com/ooni/probe-cli/v3/internal/logx" @@ -32,12 +31,6 @@ type HTTPTransport struct { // Domain is the OPTIONAL domain from which the address was resolved. Domain string - // IDGenerator is the MANDATORY ID generator. - IDGenerator *atomic.Int64 - - // Logger is the MANDATORY logger to use. - Logger model.Logger - // Network is the MANDATORY network used by the underlying conn. Network string @@ -48,13 +41,10 @@ type HTTPTransport struct { TLSNegotiatedProtocol string // Trace is the MANDATORY trace we're using. - Trace *measurexlite.Trace + Trace Trace // Transport is the MANDATORY HTTP transport we're using. Transport model.HTTPTransport - - // ZeroTime is the MANDATORY zero time of the measurement. - ZeroTime time.Time } // HTTPRequestOption is an option you can pass to HTTPRequest. @@ -110,8 +100,8 @@ func HTTPRequestOptionUserAgent(value string) HTTPRequestOption { } // HTTPRequest issues an HTTP request using a transport and returns a response. -func HTTPRequest(options ...HTTPRequestOption) Func[*HTTPTransport, *Maybe[*HTTPResponse]] { - f := &httpRequestFunc{} +func HTTPRequest(rt Runtime, options ...HTTPRequestOption) Func[*HTTPTransport, *Maybe[*HTTPResponse]] { + f := &httpRequestFunc{Rt: rt} for _, option := range options { option(f) } @@ -135,6 +125,9 @@ type httpRequestFunc struct { // Referer is the OPTIONAL referer header. Referer string + // Rt is the MANDATORY runtime. + Rt Runtime + // URLPath is the OPTIONAL URL path. URLPath string @@ -162,9 +155,9 @@ func (f *httpRequestFunc) Apply( // start the operation logger ol := logx.NewOperationLogger( - input.Logger, + f.Rt.Logger(), "[#%d] HTTPRequest %s with %s/%s host=%s", - input.Trace.Index, + input.Trace.Index(), req.URL.String(), input.Address, input.Network, @@ -186,11 +179,8 @@ func (f *httpRequestFunc) Apply( HTTPRequest: req, // possibly nil HTTPResponse: resp, // possibly nil HTTPResponseBodySnapshot: body, // possibly nil - IDGenerator: input.IDGenerator, - Logger: input.Logger, Network: input.Network, Trace: input.Trace, - ZeroTime: input.ZeroTime, } return &Maybe[*HTTPResponse]{ @@ -262,7 +252,7 @@ func (f *httpRequestFunc) urlHost(input *HTTPTransport) string { } addr, port, err := net.SplitHostPort(input.Address) if err != nil { - input.Logger.Warnf("httpRequestFunc: cannot SplitHostPort for input.Address") + f.Rt.Logger().Warnf("httpRequestFunc: cannot SplitHostPort for input.Address") return input.Address } switch { @@ -288,7 +278,7 @@ func (f *httpRequestFunc) do( req *http.Request, ) (*http.Response, []byte, []*Observations, error) { const maxbody = 1 << 19 // TODO(bassosimone): allow to configure this value? - started := input.Trace.TimeSince(input.Trace.ZeroTime) + started := input.Trace.TimeSince(input.Trace.ZeroTime()) // manually create a single 1-length observations structure because // the trace cannot automatically capture HTTP events @@ -298,7 +288,7 @@ func (f *httpRequestFunc) do( observations[0].NetworkEvents = append(observations[0].NetworkEvents, measurexlite.NewAnnotationArchivalNetworkEvent( - input.Trace.Index, + input.Trace.Index(), started, "http_transaction_start", input.Trace.Tags()..., @@ -321,11 +311,11 @@ func (f *httpRequestFunc) do( samples := sampler.ExtractSamples() observations[0].NetworkEvents = append(observations[0].NetworkEvents, samples...) } - finished := input.Trace.TimeSince(input.Trace.ZeroTime) + finished := input.Trace.TimeSince(input.Trace.ZeroTime()) observations[0].NetworkEvents = append(observations[0].NetworkEvents, measurexlite.NewAnnotationArchivalNetworkEvent( - input.Trace.Index, + input.Trace.Index(), finished, "http_transaction_done", input.Trace.Tags()..., @@ -333,7 +323,7 @@ func (f *httpRequestFunc) do( observations[0].Requests = append(observations[0].Requests, measurexlite.NewArchivalHTTPRequestResult( - input.Trace.Index, + input.Trace.Index(), started, input.Network, input.Address, @@ -369,19 +359,10 @@ type HTTPResponse struct { // HTTPResponseBodySnapshot is the response body or nil if Err != nil. HTTPResponseBodySnapshot []byte - // IDGenerator is the MANDATORY ID generator. - IDGenerator *atomic.Int64 - - // Logger is the MANDATORY logger to use. - Logger model.Logger - // Network is the MANDATORY network we're connected to. Network string // Trace is the MANDATORY trace we're using. The trace is drained // when you call the Observations method. - Trace *measurexlite.Trace - - // ZeroTime is the MANDATORY zero time of the measurement. - ZeroTime time.Time + Trace Trace } diff --git a/internal/dslx/httpquic.go b/internal/dslx/httpquic.go index b2809bb2ca..6553f50d3f 100644 --- a/internal/dslx/httpquic.go +++ b/internal/dslx/httpquic.go @@ -11,24 +11,26 @@ import ( ) // HTTPRequestOverQUIC returns a Func that issues HTTP requests over QUIC. -func HTTPRequestOverQUIC(options ...HTTPRequestOption) Func[*QUICConnection, *Maybe[*HTTPResponse]] { - return Compose2(HTTPTransportQUIC(), HTTPRequest(options...)) +func HTTPRequestOverQUIC(rt Runtime, options ...HTTPRequestOption) Func[*QUICConnection, *Maybe[*HTTPResponse]] { + return Compose2(HTTPTransportQUIC(rt), HTTPRequest(rt, options...)) } // HTTPTransportQUIC converts a QUIC connection into an HTTP transport. -func HTTPTransportQUIC() Func[*QUICConnection, *Maybe[*HTTPTransport]] { - return &httpTransportQUICFunc{} +func HTTPTransportQUIC(rt Runtime) Func[*QUICConnection, *Maybe[*HTTPTransport]] { + return &httpTransportQUICFunc{rt} } // httpTransportQUICFunc is the function returned by HTTPTransportQUIC. -type httpTransportQUICFunc struct{} +type httpTransportQUICFunc struct { + rt Runtime +} // Apply implements Func. func (f *httpTransportQUICFunc) Apply( ctx context.Context, input *QUICConnection) *Maybe[*HTTPTransport] { // create transport httpTransport := netxlite.NewHTTP3Transport( - input.Logger, + f.rt.Logger(), netxlite.NewSingleUseQUICDialer(input.QUICConn), input.TLSConfig, ) @@ -36,14 +38,11 @@ func (f *httpTransportQUICFunc) Apply( state := &HTTPTransport{ Address: input.Address, Domain: input.Domain, - IDGenerator: input.IDGenerator, - Logger: input.Logger, Network: input.Network, Scheme: "https", TLSNegotiatedProtocol: input.TLSState.NegotiatedProtocol, Trace: input.Trace, Transport: httpTransport, - ZeroTime: input.ZeroTime, } return &Maybe[*HTTPTransport]{ Error: nil, diff --git a/internal/dslx/httptcp.go b/internal/dslx/httptcp.go index 2571322130..b0d670cec7 100644 --- a/internal/dslx/httptcp.go +++ b/internal/dslx/httptcp.go @@ -11,17 +11,19 @@ import ( ) // HTTPRequestOverTCP returns a Func that issues HTTP requests over TCP. -func HTTPRequestOverTCP(options ...HTTPRequestOption) Func[*TCPConnection, *Maybe[*HTTPResponse]] { - return Compose2(HTTPTransportTCP(), HTTPRequest(options...)) +func HTTPRequestOverTCP(rt Runtime, options ...HTTPRequestOption) Func[*TCPConnection, *Maybe[*HTTPResponse]] { + return Compose2(HTTPTransportTCP(rt), HTTPRequest(rt, options...)) } // HTTPTransportTCP converts a TCP connection into an HTTP transport. -func HTTPTransportTCP() Func[*TCPConnection, *Maybe[*HTTPTransport]] { - return &httpTransportTCPFunc{} +func HTTPTransportTCP(rt Runtime) Func[*TCPConnection, *Maybe[*HTTPTransport]] { + return &httpTransportTCPFunc{rt} } // httpTransportTCPFunc is the function returned by HTTPTransportTCP -type httpTransportTCPFunc struct{} +type httpTransportTCPFunc struct { + rt Runtime +} // Apply implements Func func (f *httpTransportTCPFunc) Apply( @@ -30,21 +32,18 @@ func (f *httpTransportTCPFunc) Apply( // function, but we can probably avoid using it, given that this code is // not using tracing and does not care about those quirks. httpTransport := netxlite.NewHTTPTransport( - input.Logger, + f.rt.Logger(), netxlite.NewSingleUseDialer(input.Conn), netxlite.NewNullTLSDialer(), ) state := &HTTPTransport{ Address: input.Address, Domain: input.Domain, - IDGenerator: input.IDGenerator, - Logger: input.Logger, Network: input.Network, Scheme: "http", TLSNegotiatedProtocol: "", Trace: input.Trace, Transport: httpTransport, - ZeroTime: input.ZeroTime, } return &Maybe[*HTTPTransport]{ Error: nil, diff --git a/internal/dslx/httptls.go b/internal/dslx/httptls.go index 9c1448ce0a..1c541afd1f 100644 --- a/internal/dslx/httptls.go +++ b/internal/dslx/httptls.go @@ -11,17 +11,19 @@ import ( ) // HTTPRequestOverTLS returns a Func that issues HTTP requests over TLS. -func HTTPRequestOverTLS(options ...HTTPRequestOption) Func[*TLSConnection, *Maybe[*HTTPResponse]] { - return Compose2(HTTPTransportTLS(), HTTPRequest(options...)) +func HTTPRequestOverTLS(rt Runtime, options ...HTTPRequestOption) Func[*TLSConnection, *Maybe[*HTTPResponse]] { + return Compose2(HTTPTransportTLS(rt), HTTPRequest(rt, options...)) } // HTTPTransportTLS converts a TLS connection into an HTTP transport. -func HTTPTransportTLS() Func[*TLSConnection, *Maybe[*HTTPTransport]] { - return &httpTransportTLSFunc{} +func HTTPTransportTLS(rt Runtime) Func[*TLSConnection, *Maybe[*HTTPTransport]] { + return &httpTransportTLSFunc{rt} } // httpTransportTLSFunc is the function returned by HTTPTransportTLS. -type httpTransportTLSFunc struct{} +type httpTransportTLSFunc struct { + rt Runtime +} // Apply implements Func. func (f *httpTransportTLSFunc) Apply( @@ -30,21 +32,18 @@ func (f *httpTransportTLSFunc) Apply( // function, but we can probably avoid using it, given that this code is // not using tracing and does not care about those quirks. httpTransport := netxlite.NewHTTPTransport( - input.Logger, + f.rt.Logger(), netxlite.NewNullDialer(), netxlite.NewSingleUseTLSDialer(input.Conn), ) state := &HTTPTransport{ Address: input.Address, Domain: input.Domain, - IDGenerator: input.IDGenerator, - Logger: input.Logger, Network: input.Network, Scheme: "https", TLSNegotiatedProtocol: input.TLSState.NegotiatedProtocol, Trace: input.Trace, Transport: httpTransport, - ZeroTime: input.ZeroTime, } return &Maybe[*HTTPTransport]{ Error: nil, diff --git a/internal/dslx/integration_test.go b/internal/dslx/integration_test.go index df182e3af7..1eddf74457 100644 --- a/internal/dslx/integration_test.go +++ b/internal/dslx/integration_test.go @@ -4,7 +4,6 @@ import ( "context" "net/http" "net/http/httptest" - "sync/atomic" "testing" "time" @@ -31,26 +30,23 @@ func TestMakeSureWeCollectSpeedSamples(t *testing.T) { })) defer server.Close() - // instantiate a connection pool - pool := &ConnPool{} - defer pool.Close() + // instantiate a runtime + rt := NewRuntimeMeasurexLite(model.DiscardLogger, time.Now()) + defer rt.Close() // create a measuring function f0 := Compose3( - TCPConnect(pool), - HTTPTransportTCP(), - HTTPRequest(), + TCPConnect(rt), + HTTPTransportTCP(rt), + HTTPRequest(rt), ) // create the endpoint to measure epnt := &Endpoint{ - Address: server.Listener.Addr().String(), - Domain: "", - IDGenerator: &atomic.Int64{}, - Logger: model.DiscardLogger, - Network: "tcp", - Tags: []string{}, - ZeroTime: time.Now(), + Address: server.Listener.Addr().String(), + Domain: "", + Network: "tcp", + Tags: []string{}, } // measure the endpoint diff --git a/internal/dslx/observations.go b/internal/dslx/observations.go index e1be1b611d..269e999463 100644 --- a/internal/dslx/observations.go +++ b/internal/dslx/observations.go @@ -5,7 +5,6 @@ package dslx // import ( - "github.com/ooni/probe-cli/v3/internal/measurexlite" "github.com/ooni/probe-cli/v3/internal/model" ) @@ -54,7 +53,7 @@ func ExtractObservations[T any](rs ...*Maybe[T]) (out []*Observations) { // maybeTraceToObservations returns the observations inside the // trace taking into account the case where trace is nil. -func maybeTraceToObservations(trace *measurexlite.Trace) (out []*Observations) { +func maybeTraceToObservations(trace Trace) (out []*Observations) { if trace != nil { out = append(out, &Observations{ NetworkEvents: trace.NetworkEvents(), diff --git a/internal/dslx/quic.go b/internal/dslx/quic.go index d4e88573d7..c643584e4a 100644 --- a/internal/dslx/quic.go +++ b/internal/dslx/quic.go @@ -10,11 +10,9 @@ import ( "crypto/x509" "io" "net" - "sync/atomic" "time" "github.com/ooni/probe-cli/v3/internal/logx" - "github.com/ooni/probe-cli/v3/internal/measurexlite" "github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/netxlite" "github.com/quic-go/quic-go" @@ -45,15 +43,15 @@ func QUICHandshakeOptionServerName(value string) QUICHandshakeOption { } // QUICHandshake returns a function performing QUIC handshakes. -func QUICHandshake(pool *ConnPool, options ...QUICHandshakeOption) Func[ +func QUICHandshake(rt Runtime, options ...QUICHandshakeOption) Func[ *Endpoint, *Maybe[*QUICConnection]] { // See https://github.com/ooni/probe/issues/2413 to understand // why we're using nil to force netxlite to use the cached // default Mozilla cert pool. f := &quicHandshakeFunc{ InsecureSkipVerify: false, - Pool: pool, RootCAs: nil, + Rt: rt, ServerName: "", } for _, option := range options { @@ -67,12 +65,12 @@ type quicHandshakeFunc struct { // InsecureSkipVerify allows to skip TLS verification. InsecureSkipVerify bool - // Pool is the ConnPool that owns us. - Pool *ConnPool - // RootCAs contains the Root CAs to use. RootCAs *x509.CertPool + // Rt is the Runtime that owns us. + Rt Runtime + // ServerName is the ServerName to handshake for. ServerName string @@ -83,16 +81,16 @@ type quicHandshakeFunc struct { func (f *quicHandshakeFunc) Apply( ctx context.Context, input *Endpoint) *Maybe[*QUICConnection] { // create trace - trace := measurexlite.NewTrace(input.IDGenerator.Add(1), input.ZeroTime, input.Tags...) + trace := f.Rt.NewTrace(f.Rt.IDGenerator().Add(1), f.Rt.ZeroTime(), input.Tags...) // use defaults or user-configured overrides serverName := f.serverName(input) // start the operation logger ol := logx.NewOperationLogger( - input.Logger, + f.Rt.Logger(), "[#%d] QUICHandshake with %s SNI=%s", - trace.Index, + trace.Index(), input.Address, serverName, ) @@ -101,7 +99,7 @@ func (f *quicHandshakeFunc) Apply( udpListener := netxlite.NewUDPListener() quicDialer := f.dialer if quicDialer == nil { - quicDialer = trace.NewQUICDialerWithoutResolver(udpListener, input.Logger) + quicDialer = trace.NewQUICDialerWithoutResolver(udpListener, f.Rt.Logger()) } config := &tls.Config{ NextProtos: []string{"h3"}, @@ -124,22 +122,19 @@ func (f *quicHandshakeFunc) Apply( } // possibly track established conn for late close - f.Pool.MaybeTrack(closerConn) + f.Rt.MaybeTrackConn(closerConn) // stop the operation logger ol.Stop(err) state := &QUICConnection{ - Address: input.Address, - QUICConn: quicConn, // possibly nil - Domain: input.Domain, - IDGenerator: input.IDGenerator, - Logger: input.Logger, - Network: input.Network, - TLSConfig: config, - TLSState: tlsState, - Trace: trace, - ZeroTime: input.ZeroTime, + Address: input.Address, + QUICConn: quicConn, // possibly nil + Domain: input.Domain, + Network: input.Network, + TLSConfig: config, + TLSState: tlsState, + Trace: trace, } return &Maybe[*QUICConnection]{ @@ -164,7 +159,7 @@ func (f *quicHandshakeFunc) serverName(input *Endpoint) string { // Note: golang requires a ServerName and fails if it's empty. If the provided // ServerName is an IP address, however, golang WILL NOT emit any SNI extension // in the ClientHello, consistently with RFC 6066 Section 3 requirements. - input.Logger.Warn("TLSHandshake: cannot determine which SNI to use") + f.Rt.Logger().Warn("TLSHandshake: cannot determine which SNI to use") return "" } @@ -180,12 +175,6 @@ type QUICConnection struct { // Domain is the OPTIONAL domain we resolved. Domain string - // IDGenerator is the MANDATORY ID generator to use. - IDGenerator *atomic.Int64 - - // Logger is the MANDATORY logger to use. - Logger model.Logger - // Network is the MANDATORY network we tried to use when connecting. Network string @@ -197,10 +186,7 @@ type QUICConnection struct { TLSState tls.ConnectionState // Trace is the MANDATORY trace we're using. - Trace *measurexlite.Trace - - // ZeroTime is the MANDATORY zero time of the measurement. - ZeroTime time.Time + Trace Trace } type quicCloserConn struct { diff --git a/internal/dslx/quic_test.go b/internal/dslx/quic_test.go index 7e0a30cf9d..40c4923812 100644 --- a/internal/dslx/quic_test.go +++ b/internal/dslx/quic_test.go @@ -5,7 +5,6 @@ import ( "crypto/tls" "crypto/x509" "io" - "sync/atomic" "testing" "time" @@ -29,7 +28,7 @@ func TestQUICHandshake(t *testing.T) { certpool.AddCert(&x509.Certificate{}) f := QUICHandshake( - &ConnPool{}, + NewMinimalRuntime(model.DiscardLogger, time.Now()), QUICHandshakeOptionInsecureSkipVerify(true), QUICHandshakeOptionServerName("sni"), QUICHandshakeOptionRootCAs(certpool), @@ -99,19 +98,16 @@ func TestQUICHandshake(t *testing.T) { for name, tt := range tests { t.Run(name, func(t *testing.T) { - pool := &ConnPool{} + rt := NewRuntimeMeasurexLite(model.DiscardLogger, time.Now()) quicHandshake := &quicHandshakeFunc{ - Pool: pool, + Rt: rt, dialer: tt.dialer, ServerName: tt.sni, } endpoint := &Endpoint{ - Address: "1.2.3.4:567", - Network: "udp", - IDGenerator: &atomic.Int64{}, - Logger: model.DiscardLogger, - Tags: tt.tags, - ZeroTime: time.Time{}, + Address: "1.2.3.4:567", + Network: "udp", + Tags: tt.tags, } res := quicHandshake.Apply(context.Background(), endpoint) if res.Error != tt.expectErr { @@ -120,7 +116,7 @@ func TestQUICHandshake(t *testing.T) { if res.State == nil || res.State.QUICConn != tt.expectConn { t.Fatal("unexpected conn") } - pool.Close() + rt.Close() if wasClosed != tt.closed { t.Fatalf("unexpected connection closed state: %v", wasClosed) } @@ -137,13 +133,10 @@ func TestQUICHandshake(t *testing.T) { } t.Run("with nil dialer", func(t *testing.T) { - quicHandshake := &quicHandshakeFunc{Pool: &ConnPool{}, dialer: nil} + quicHandshake := &quicHandshakeFunc{Rt: NewMinimalRuntime(model.DiscardLogger, time.Now()), dialer: nil} endpoint := &Endpoint{ - Address: "1.2.3.4:567", - Network: "udp", - IDGenerator: &atomic.Int64{}, - Logger: model.DiscardLogger, - ZeroTime: time.Time{}, + Address: "1.2.3.4:567", + Network: "udp", } ctx, cancel := context.WithCancel(context.Background()) cancel() @@ -171,9 +164,8 @@ func TestServerNameQUIC(t *testing.T) { sni := "sni" endpoint := &Endpoint{ Address: "example.com:123", - Logger: model.DiscardLogger, } - f := &quicHandshakeFunc{Pool: &ConnPool{}, ServerName: sni} + f := &quicHandshakeFunc{Rt: NewMinimalRuntime(model.DiscardLogger, time.Now()), ServerName: sni} serverName := f.serverName(endpoint) if serverName != sni { t.Fatalf("unexpected server name: %s", serverName) @@ -185,9 +177,8 @@ func TestServerNameQUIC(t *testing.T) { endpoint := &Endpoint{ Address: "example.com:123", Domain: domain, - Logger: model.DiscardLogger, } - f := &quicHandshakeFunc{Pool: &ConnPool{}} + f := &quicHandshakeFunc{Rt: NewMinimalRuntime(model.DiscardLogger, time.Now())} serverName := f.serverName(endpoint) if serverName != domain { t.Fatalf("unexpected server name: %s", serverName) @@ -198,9 +189,8 @@ func TestServerNameQUIC(t *testing.T) { hostaddr := "example.com" endpoint := &Endpoint{ Address: hostaddr + ":123", - Logger: model.DiscardLogger, } - f := &quicHandshakeFunc{Pool: &ConnPool{}} + f := &quicHandshakeFunc{Rt: NewMinimalRuntime(model.DiscardLogger, time.Now())} serverName := f.serverName(endpoint) if serverName != hostaddr { t.Fatalf("unexpected server name: %s", serverName) @@ -211,9 +201,8 @@ func TestServerNameQUIC(t *testing.T) { ip := "1.1.1.1" endpoint := &Endpoint{ Address: ip, - Logger: model.DiscardLogger, } - f := &quicHandshakeFunc{Pool: &ConnPool{}} + f := &quicHandshakeFunc{Rt: NewMinimalRuntime(model.DiscardLogger, time.Now())} serverName := f.serverName(endpoint) if serverName != "" { t.Fatalf("unexpected server name: %s", serverName) diff --git a/internal/dslx/runtimecore.go b/internal/dslx/runtimecore.go new file mode 100644 index 0000000000..4215431e54 --- /dev/null +++ b/internal/dslx/runtimecore.go @@ -0,0 +1,36 @@ +package dslx + +import ( + "io" + "sync/atomic" + "time" + + "github.com/ooni/probe-cli/v3/internal/model" +) + +// Runtime is the runtime in which we execute the DSL. +type Runtime interface { + // Close closes all the connection tracked using MaybeTrackConn. + Close() error + + // IDGenerator returns an atomic counter used to generate + // separate unique IDs for each trace. + IDGenerator() *atomic.Int64 + + // Logger returns the base logger to use. + Logger() model.Logger + + // MaybeTrackConn tracks a connection such that it is closed + // when you call the Runtime's Close method. + MaybeTrackConn(conn io.Closer) + + // NewTrace creates a [Trace] instance. Note that each [Runtime] + // creates its own [Trace] type. A [Trace] is not guaranteed to collect + // [*Observations]. For example, [NewMinimalRuntime] creates a [Runtime] + // that does not collect any [*Observations]. + NewTrace(index int64, zeroTime time.Time, tags ...string) Trace + + // ZeroTime returns the runtime's "zero" time, which is used as the + // starting point to generate observation's delta times. + ZeroTime() time.Time +} diff --git a/internal/dslx/runtimemeasurex.go b/internal/dslx/runtimemeasurex.go new file mode 100644 index 0000000000..a075d085b5 --- /dev/null +++ b/internal/dslx/runtimemeasurex.go @@ -0,0 +1,27 @@ +package dslx + +import ( + "time" + + "github.com/ooni/probe-cli/v3/internal/measurexlite" + "github.com/ooni/probe-cli/v3/internal/model" +) + +// NewRuntimeMeasurexLite creates a [Runtime] using [measurexlite] to collect [*Observations]. +func NewRuntimeMeasurexLite(logger model.Logger, zeroTime time.Time) *RuntimeMeasurexLite { + return &RuntimeMeasurexLite{ + MinimalRuntime: NewMinimalRuntime(logger, zeroTime), + } +} + +// RuntimeMeasurexLite uses [measurexlite] to collect [*Observations.] +type RuntimeMeasurexLite struct { + *MinimalRuntime +} + +// NewTrace implements Runtime. +func (p *RuntimeMeasurexLite) NewTrace(index int64, zeroTime time.Time, tags ...string) Trace { + return measurexlite.NewTrace(index, zeroTime, tags...) +} + +var _ Runtime = &RuntimeMeasurexLite{} diff --git a/internal/dslx/runtimeminimal.go b/internal/dslx/runtimeminimal.go new file mode 100644 index 0000000000..7003c2ceec --- /dev/null +++ b/internal/dslx/runtimeminimal.go @@ -0,0 +1,159 @@ +package dslx + +import ( + "io" + "sync" + "sync/atomic" + "time" + + "github.com/ooni/probe-cli/v3/internal/model" + "github.com/ooni/probe-cli/v3/internal/netxlite" +) + +// NewMinimalRuntime creates a minimal [Runtime] implementation. +// +// This [Runtime] implementation does not collect any [*Observations]. +func NewMinimalRuntime(logger model.Logger, zeroTime time.Time) *MinimalRuntime { + return &MinimalRuntime{ + idg: &atomic.Int64{}, + logger: logger, + mu: sync.Mutex{}, + v: []io.Closer{}, + zeroT: zeroTime, + } +} + +var _ Runtime = &MinimalRuntime{} + +// MinimalRuntime is a minimal [Runtime] implementation. +type MinimalRuntime struct { + idg *atomic.Int64 + logger model.Logger + mu sync.Mutex + v []io.Closer + zeroT time.Time +} + +// IDGenerator implements Runtime. +func (p *MinimalRuntime) IDGenerator() *atomic.Int64 { + return p.idg +} + +// Logger implements Runtime. +func (p *MinimalRuntime) Logger() model.Logger { + return p.logger +} + +// ZeroTime implements Runtime. +func (p *MinimalRuntime) ZeroTime() time.Time { + return p.zeroT +} + +// MaybeTrackConn implements Runtime. +func (p *MinimalRuntime) MaybeTrackConn(conn io.Closer) { + if conn != nil { + defer p.mu.Unlock() + p.mu.Lock() + p.v = append(p.v, conn) + } +} + +// Close implements Runtime. +func (p *MinimalRuntime) Close() error { + // Implementation note: reverse order is such that we close TLS + // connections before we close the TCP connections they use. Hence + // we'll _gracefully_ close TLS connections. + defer p.mu.Unlock() + p.mu.Lock() + for idx := len(p.v) - 1; idx >= 0; idx-- { + _ = p.v[idx].Close() + } + p.v = nil // reset + return nil +} + +// NewTrace implements Runtime. +func (p *MinimalRuntime) NewTrace(index int64, zeroTime time.Time, tags ...string) Trace { + return &minimalTrace{idx: index, tags: tags, zt: zeroTime} +} + +type minimalTrace struct { + idx int64 + tags []string + zt time.Time +} + +// CloneBytesReceivedMap implements Trace. +func (tx *minimalTrace) CloneBytesReceivedMap() (out map[string]int64) { + return make(map[string]int64) +} + +// DNSLookupsFromRoundTrip implements Trace. +func (tx *minimalTrace) DNSLookupsFromRoundTrip() (out []*model.ArchivalDNSLookupResult) { + return []*model.ArchivalDNSLookupResult{} +} + +// Index implements Trace. +func (tx *minimalTrace) Index() int64 { + return tx.idx +} + +// NetworkEvents implements Trace. +func (tx *minimalTrace) NetworkEvents() (out []*model.ArchivalNetworkEvent) { + return []*model.ArchivalNetworkEvent{} +} + +// NewDialerWithoutResolver implements Trace. +func (tx *minimalTrace) NewDialerWithoutResolver(dl model.DebugLogger, wrappers ...model.DialerWrapper) model.Dialer { + return netxlite.NewDialerWithoutResolver(dl, wrappers...) +} + +// NewParallelUDPResolver implements Trace. +func (tx *minimalTrace) NewParallelUDPResolver(logger model.DebugLogger, dialer model.Dialer, address string) model.Resolver { + return netxlite.NewParallelUDPResolver(logger, dialer, address) +} + +// NewQUICDialerWithoutResolver implements Trace. +func (tx *minimalTrace) NewQUICDialerWithoutResolver(listener model.UDPListener, dl model.DebugLogger, wrappers ...model.QUICDialerWrapper) model.QUICDialer { + return netxlite.NewQUICDialerWithoutResolver(listener, dl, wrappers...) +} + +// NewStdlibResolver implements Trace. +func (tx *minimalTrace) NewStdlibResolver(logger model.DebugLogger) model.Resolver { + return netxlite.NewStdlibResolver(logger) +} + +// NewTLSHandshakerStdlib implements Trace. +func (tx *minimalTrace) NewTLSHandshakerStdlib(dl model.DebugLogger) model.TLSHandshaker { + return netxlite.NewTLSHandshakerStdlib(dl) +} + +// QUICHandshakes implements Trace. +func (tx *minimalTrace) QUICHandshakes() (out []*model.ArchivalTLSOrQUICHandshakeResult) { + return []*model.ArchivalTLSOrQUICHandshakeResult{} +} + +// TCPConnects implements Trace. +func (tx *minimalTrace) TCPConnects() (out []*model.ArchivalTCPConnectResult) { + return []*model.ArchivalTCPConnectResult{} +} + +// TLSHandshakes implements Trace. +func (tx *minimalTrace) TLSHandshakes() (out []*model.ArchivalTLSOrQUICHandshakeResult) { + return []*model.ArchivalTLSOrQUICHandshakeResult{} +} + +// Tags implements Trace. +func (tx *minimalTrace) Tags() []string { + return tx.tags +} + +// TimeSince implements Trace. +func (tx *minimalTrace) TimeSince(t0 time.Time) time.Duration { + return time.Since(t0) +} + +// ZeroTime implements Trace. +func (tx *minimalTrace) ZeroTime() time.Time { + return tx.zt +} diff --git a/internal/dslx/runtimeminimal_test.go b/internal/dslx/runtimeminimal_test.go new file mode 100644 index 0000000000..4699787fb9 --- /dev/null +++ b/internal/dslx/runtimeminimal_test.go @@ -0,0 +1,240 @@ +package dslx + +import ( + "errors" + "io" + "testing" + "time" + + "github.com/google/go-cmp/cmp" + "github.com/ooni/probe-cli/v3/internal/mocks" + "github.com/ooni/probe-cli/v3/internal/model" + "github.com/quic-go/quic-go" +) + +/* +Test cases: +- Maybe track connections: + - with nil + - with connection + - with quic connection + +- Close MinimalRuntime: + - all Close() calls succeed + - one Close() call fails +*/ + +func closeableConnWithErr(err error) io.Closer { + return &mocks.Conn{ + MockClose: func() error { + return err + }, + } +} + +func closeableQUICConnWithErr(err error) io.Closer { + return &quicCloserConn{ + &mocks.QUICEarlyConnection{ + MockCloseWithError: func(code quic.ApplicationErrorCode, reason string) error { + return err + }, + }, + } +} + +func TestMinimalRuntime(t *testing.T) { + // testcase is a test case implemented by this function + type testcase struct { + mockConn io.Closer + want int // len of (*minimalRuntime).v + } + + t.Run("Maybe track connections", func(t *testing.T) { + tests := map[string]testcase{ + "with nil": {mockConn: nil, want: 0}, + "with connection": {mockConn: closeableConnWithErr(nil), want: 1}, + "with quic connection": {mockConn: closeableQUICConnWithErr(nil), want: 1}, + } + for name, tt := range tests { + t.Run(name, func(t *testing.T) { + rt := NewMinimalRuntime(model.DiscardLogger, time.Now()) + rt.MaybeTrackConn(tt.mockConn) + if len(rt.v) != tt.want { + t.Fatalf("expected %d tracked connections, got: %d", tt.want, len(rt.v)) + } + }) + } + }) + + t.Run("Close MinimalRuntime", func(t *testing.T) { + mockErr := errors.New("mocked") + tests := map[string]struct { + rt *MinimalRuntime + }{ + "all Close() calls succeed": { + rt: &MinimalRuntime{ + v: []io.Closer{ + closeableConnWithErr(nil), + closeableQUICConnWithErr(nil), + }, + }, + }, + "one Close() call fails": { + rt: &MinimalRuntime{ + v: []io.Closer{ + closeableConnWithErr(nil), + closeableConnWithErr(mockErr), + }, + }, + }, + } + + for name, tt := range tests { + t.Run(name, func(t *testing.T) { + err := tt.rt.Close() + if err != nil { // Close() should always return nil + t.Fatalf("unexpected error %s", err) + } + if tt.rt.v != nil { + t.Fatalf("v should be reset but is not") + } + }) + } + }) + + t.Run("IDGenerator", func(t *testing.T) { + rt := NewMinimalRuntime(model.DiscardLogger, time.Now()) + out := rt.IDGenerator() + if out == nil { + t.Fatal("expected non-nil pointer") + } + }) + + t.Run("Logger", func(t *testing.T) { + rt := NewMinimalRuntime(model.DiscardLogger, time.Now()) + out := rt.Logger() + if out == nil { + t.Fatal("expected non-nil pointer") + } + }) + + t.Run("ZeroTime", func(t *testing.T) { + rt := NewMinimalRuntime(model.DiscardLogger, time.Now()) + out := rt.ZeroTime() + if out.IsZero() { + t.Fatal("expected non-zero time") + } + }) + + t.Run("Trace", func(t *testing.T) { + tags := []string{"antani", "mascetti", "melandri"} + rt := NewMinimalRuntime(model.DiscardLogger, time.Now()) + now := time.Now() + trace := rt.NewTrace(10, now, tags...) + + t.Run("CloneBytesReceivedMap", func(t *testing.T) { + out := trace.CloneBytesReceivedMap() + if out == nil || len(out) != 0 { + t.Fatal("expected zero-length map") + } + }) + + t.Run("DNSLookupsFromRoundTrip", func(t *testing.T) { + out := trace.DNSLookupsFromRoundTrip() + if out == nil || len(out) != 0 { + t.Fatal("expected zero-length slice") + } + }) + + t.Run("Index", func(t *testing.T) { + out := trace.Index() + if out != 10 { + t.Fatal("expected 10, got", out) + } + }) + + t.Run("NetworkEvents", func(t *testing.T) { + out := trace.NetworkEvents() + if out == nil || len(out) != 0 { + t.Fatal("expected zero-length slice") + } + }) + + t.Run("NewDialerWithoutResolver", func(t *testing.T) { + out := trace.NewDialerWithoutResolver(model.DiscardLogger) + if out == nil { + t.Fatal("expected non-nil pointer") + } + }) + + t.Run("NewParallelUDPResolver", func(t *testing.T) { + out := trace.NewParallelUDPResolver(model.DiscardLogger, &mocks.Dialer{}, "8.8.8.8:53") + if out == nil { + t.Fatal("expected non-nil pointer") + } + }) + + t.Run("NewQUICDialerWithoutResolver", func(t *testing.T) { + out := trace.NewQUICDialerWithoutResolver(&mocks.UDPListener{}, model.DiscardLogger) + if out == nil { + t.Fatal("expected non-nil pointer") + } + }) + + t.Run("NewStdlibResolver", func(t *testing.T) { + out := trace.NewStdlibResolver(model.DiscardLogger) + if out == nil { + t.Fatal("expected non-nil pointer") + } + }) + + t.Run("NewTLSHandshakerStdlib", func(t *testing.T) { + out := trace.NewTLSHandshakerStdlib(model.DiscardLogger) + if out == nil { + t.Fatal("expected non-nil pointer") + } + }) + + t.Run("QUICHandshakes", func(t *testing.T) { + out := trace.QUICHandshakes() + if out == nil || len(out) != 0 { + t.Fatal("expected zero-length slice") + } + }) + + t.Run("TCPConnects", func(t *testing.T) { + out := trace.TCPConnects() + if out == nil || len(out) != 0 { + t.Fatal("expected zero-length slice") + } + }) + + t.Run("TLSHandshakes", func(t *testing.T) { + out := trace.TLSHandshakes() + if out == nil || len(out) != 0 { + t.Fatal("expected zero-length slice") + } + }) + + t.Run("Tags", func(t *testing.T) { + out := trace.Tags() + if diff := cmp.Diff(tags, out); diff != "" { + t.Fatal(diff) + } + }) + + t.Run("TimeSince", func(t *testing.T) { + out := trace.TimeSince(now.Add(-10 * time.Second)) + if out == 0 { + t.Fatal("expected non-zero time") + } + }) + + t.Run("ZeroTime", func(t *testing.T) { + out := trace.ZeroTime() + if out.IsZero() { + t.Fatal("expected non-zero time") + } + }) + }) +} diff --git a/internal/dslx/tcp.go b/internal/dslx/tcp.go index fe5d769000..af5dbcff3c 100644 --- a/internal/dslx/tcp.go +++ b/internal/dslx/tcp.go @@ -7,25 +7,23 @@ package dslx import ( "context" "net" - "sync/atomic" "time" "github.com/ooni/probe-cli/v3/internal/logx" - "github.com/ooni/probe-cli/v3/internal/measurexlite" "github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/netxlite" ) // TCPConnect returns a function that establishes TCP connections. -func TCPConnect(pool *ConnPool) Func[*Endpoint, *Maybe[*TCPConnection]] { - f := &tcpConnectFunc{pool, nil} +func TCPConnect(rt Runtime) Func[*Endpoint, *Maybe[*TCPConnection]] { + f := &tcpConnectFunc{nil, rt} return f } // tcpConnectFunc is a function that establishes TCP connections. type tcpConnectFunc struct { - p *ConnPool dialer model.Dialer // for testing + rt Runtime } // Apply applies the function to its arguments. @@ -33,13 +31,13 @@ func (f *tcpConnectFunc) Apply( ctx context.Context, input *Endpoint) *Maybe[*TCPConnection] { // create trace - trace := measurexlite.NewTrace(input.IDGenerator.Add(1), input.ZeroTime, input.Tags...) + trace := f.rt.NewTrace(f.rt.IDGenerator().Add(1), f.rt.ZeroTime(), input.Tags...) // start the operation logger ol := logx.NewOperationLogger( - input.Logger, + f.rt.Logger(), "[#%d] TCPConnect %s", - trace.Index, + trace.Index(), input.Address, ) @@ -49,26 +47,23 @@ func (f *tcpConnectFunc) Apply( defer cancel() // obtain the dialer to use - dialer := f.dialerOrDefault(trace, input.Logger) + dialer := f.dialerOrDefault(trace, f.rt.Logger()) // connect conn, err := dialer.DialContext(ctx, "tcp", input.Address) // possibly register established conn for late close - f.p.MaybeTrack(conn) + f.rt.MaybeTrackConn(conn) // stop the operation logger ol.Stop(err) state := &TCPConnection{ - Address: input.Address, - Conn: conn, // possibly nil - Domain: input.Domain, - IDGenerator: input.IDGenerator, - Logger: input.Logger, - Network: input.Network, - Trace: trace, - ZeroTime: input.ZeroTime, + Address: input.Address, + Conn: conn, // possibly nil + Domain: input.Domain, + Network: input.Network, + Trace: trace, } return &Maybe[*TCPConnection]{ @@ -80,7 +75,7 @@ func (f *tcpConnectFunc) Apply( } // dialerOrDefault is the function used to obtain a dialer -func (f *tcpConnectFunc) dialerOrDefault(trace *measurexlite.Trace, logger model.Logger) model.Dialer { +func (f *tcpConnectFunc) dialerOrDefault(trace Trace, logger model.Logger) model.Dialer { dialer := f.dialer if dialer == nil { dialer = trace.NewDialerWithoutResolver(logger) @@ -100,18 +95,9 @@ type TCPConnection struct { // Domain is the OPTIONAL domain from which we resolved the Address. Domain string - // IDGenerator is the MANDATORY ID generator. - IDGenerator *atomic.Int64 - - // Logger is the MANDATORY logger to use. - Logger model.Logger - // Network is the MANDATORY network we tried to use when connecting. Network string // Trace is the MANDATORY trace we're using. - Trace *measurexlite.Trace - - // ZeroTime is the MANDATORY zero time of the measurement. - ZeroTime time.Time + Trace Trace } diff --git a/internal/dslx/tcp_test.go b/internal/dslx/tcp_test.go index 6a94ea35c5..1ec42ef88a 100644 --- a/internal/dslx/tcp_test.go +++ b/internal/dslx/tcp_test.go @@ -4,7 +4,6 @@ import ( "context" "io" "net" - "sync/atomic" "testing" "time" @@ -17,7 +16,7 @@ import ( func TestTCPConnect(t *testing.T) { t.Run("Get tcpConnectFunc", func(t *testing.T) { f := TCPConnect( - &ConnPool{}, + NewMinimalRuntime(model.DiscardLogger, time.Now()), ) if _, ok := f.(*tcpConnectFunc); !ok { t.Fatal("unexpected type. Expected: tcpConnectFunc") @@ -69,15 +68,12 @@ func TestTCPConnect(t *testing.T) { for name, tt := range tests { t.Run(name, func(t *testing.T) { - pool := &ConnPool{} - tcpConnect := &tcpConnectFunc{pool, tt.dialer} + rt := NewRuntimeMeasurexLite(model.DiscardLogger, time.Now()) + tcpConnect := &tcpConnectFunc{tt.dialer, rt} endpoint := &Endpoint{ - Address: "1.2.3.4:567", - Network: "tcp", - IDGenerator: &atomic.Int64{}, - Logger: model.DiscardLogger, - Tags: tt.tags, - ZeroTime: time.Time{}, + Address: "1.2.3.4:567", + Network: "tcp", + Tags: tt.tags, } res := tcpConnect.Apply(context.Background(), endpoint) if res.Error != tt.expectErr { @@ -86,7 +82,7 @@ func TestTCPConnect(t *testing.T) { if res.State == nil || res.State.Conn != tt.expectConn { t.Fatal("unexpected conn") } - pool.Close() + rt.Close() if wasClosed != tt.closed { t.Fatalf("unexpected connection closed state: %v", wasClosed) } @@ -107,7 +103,7 @@ func TestTCPConnect(t *testing.T) { // Make sure we get a valid dialer if no mocked dialer is configured func TestDialerOrDefault(t *testing.T) { f := &tcpConnectFunc{ - p: &ConnPool{}, + rt: NewMinimalRuntime(model.DiscardLogger, time.Now()), dialer: nil, } dialer := f.dialerOrDefault(measurexlite.NewTrace(0, time.Now()), model.DiscardLogger) diff --git a/internal/dslx/tls.go b/internal/dslx/tls.go index eac86bc923..59e508f681 100644 --- a/internal/dslx/tls.go +++ b/internal/dslx/tls.go @@ -9,11 +9,9 @@ import ( "crypto/tls" "crypto/x509" "net" - "sync/atomic" "time" "github.com/ooni/probe-cli/v3/internal/logx" - "github.com/ooni/probe-cli/v3/internal/measurexlite" "github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/netxlite" ) @@ -50,7 +48,7 @@ func TLSHandshakeOptionServerName(value string) TLSHandshakeOption { } // TLSHandshake returns a function performing TSL handshakes. -func TLSHandshake(pool *ConnPool, options ...TLSHandshakeOption) Func[ +func TLSHandshake(rt Runtime, options ...TLSHandshakeOption) Func[ *TCPConnection, *Maybe[*TLSConnection]] { // See https://github.com/ooni/probe/issues/2413 to understand // why we're using nil to force netxlite to use the cached @@ -58,8 +56,8 @@ func TLSHandshake(pool *ConnPool, options ...TLSHandshakeOption) Func[ f := &tlsHandshakeFunc{ InsecureSkipVerify: false, NextProto: []string{}, - Pool: pool, RootCAs: nil, + Rt: rt, ServerName: "", } for _, option := range options { @@ -76,12 +74,12 @@ type tlsHandshakeFunc struct { // NextProto contains the ALPNs to negotiate. NextProto []string - // Pool is the Pool that owns us. - Pool *ConnPool - // RootCAs contains the Root CAs to use. RootCAs *x509.CertPool + // Rt is the Runtime that owns us. + Rt Runtime + // ServerName is the ServerName to handshake for. ServerName string @@ -101,16 +99,16 @@ func (f *tlsHandshakeFunc) Apply( // start the operation logger ol := logx.NewOperationLogger( - input.Logger, + f.Rt.Logger(), "[#%d] TLSHandshake with %s SNI=%s ALPN=%v", - trace.Index, + trace.Index(), input.Address, serverName, nextProto, ) // obtain the handshaker for use - handshaker := f.handshakerOrDefault(trace, input.Logger) + handshaker := f.handshakerOrDefault(trace, f.Rt.Logger()) // setup config := &tls.Config{ @@ -127,21 +125,18 @@ func (f *tlsHandshakeFunc) Apply( conn, err := handshaker.Handshake(ctx, input.Conn, config) // possibly register established conn for late close - f.Pool.MaybeTrack(conn) + f.Rt.MaybeTrackConn(conn) // stop the operation logger ol.Stop(err) state := &TLSConnection{ - Address: input.Address, - Conn: conn, // possibly nil - Domain: input.Domain, - IDGenerator: input.IDGenerator, - Logger: input.Logger, - Network: input.Network, - TLSState: netxlite.MaybeTLSConnectionState(conn), - Trace: trace, - ZeroTime: input.ZeroTime, + Address: input.Address, + Conn: conn, // possibly nil + Domain: input.Domain, + Network: input.Network, + TLSState: netxlite.MaybeTLSConnectionState(conn), + Trace: trace, } return &Maybe[*TLSConnection]{ @@ -153,7 +148,7 @@ func (f *tlsHandshakeFunc) Apply( } // handshakerOrDefault is the function used to obtain an handshaker -func (f *tlsHandshakeFunc) handshakerOrDefault(trace *measurexlite.Trace, logger model.Logger) model.TLSHandshaker { +func (f *tlsHandshakeFunc) handshakerOrDefault(trace Trace, logger model.Logger) model.TLSHandshaker { handshaker := f.handshaker if handshaker == nil { handshaker = trace.NewTLSHandshakerStdlib(logger) @@ -175,7 +170,7 @@ func (f *tlsHandshakeFunc) serverName(input *TCPConnection) string { // Note: golang requires a ServerName and fails if it's empty. If the provided // ServerName is an IP address, however, golang WILL NOT emit any SNI extension // in the ClientHello, consistently with RFC 6066 Section 3 requirements. - input.Logger.Warn("TLSHandshake: cannot determine which SNI to use") + f.Rt.Logger().Warn("TLSHandshake: cannot determine which SNI to use") return "" } @@ -198,12 +193,6 @@ type TLSConnection struct { // Domain is the OPTIONAL domain we resolved. Domain string - // IDGenerator is the MANDATORY ID generator to use. - IDGenerator *atomic.Int64 - - // Logger is the MANDATORY logger to use. - Logger model.Logger - // Network is the MANDATORY network we tried to use when connecting. Network string @@ -211,8 +200,5 @@ type TLSConnection struct { TLSState tls.ConnectionState // Trace is the MANDATORY trace we're using. - Trace *measurexlite.Trace - - // ZeroTime is the MANDATORY zero time of the measurement. - ZeroTime time.Time + Trace Trace } diff --git a/internal/dslx/tls_test.go b/internal/dslx/tls_test.go index 3cba8f81d8..2fd209661b 100644 --- a/internal/dslx/tls_test.go +++ b/internal/dslx/tls_test.go @@ -31,7 +31,7 @@ func TestTLSHandshake(t *testing.T) { certpool.AddCert(&x509.Certificate{}) f := TLSHandshake( - &ConnPool{}, + NewMinimalRuntime(model.DiscardLogger, time.Now()), TLSHandshakeOptionInsecureSkipVerify(true), TLSHandshakeOptionNextProto([]string{"h2"}), TLSHandshakeOptionServerName("sni"), @@ -133,10 +133,10 @@ func TestTLSHandshake(t *testing.T) { for name, tt := range tests { t.Run(name, func(t *testing.T) { - pool := &ConnPool{} + rt := NewMinimalRuntime(model.DiscardLogger, time.Now()) tlsHandshake := &tlsHandshakeFunc{ NextProto: tt.config.nextProtos, - Pool: pool, + Rt: rt, ServerName: tt.config.sni, handshaker: tt.handshaker, } @@ -148,13 +148,10 @@ func TestTLSHandshake(t *testing.T) { address = "1.2.3.4:567" } tcpConn := TCPConnection{ - Address: address, - Conn: &tcpConn, - IDGenerator: idGen, - Logger: model.DiscardLogger, - Network: "tcp", - Trace: trace, - ZeroTime: zeroTime, + Address: address, + Conn: &tcpConn, + Network: "tcp", + Trace: trace, } res := tlsHandshake.Apply(context.Background(), &tcpConn) if res.Error != tt.expectErr { @@ -163,7 +160,7 @@ func TestTLSHandshake(t *testing.T) { if res.State.Conn != tt.expectConn { t.Fatalf("unexpected conn %v", res.State.Conn) } - pool.Close() + rt.Close() if wasClosed != tt.closed { t.Fatalf("unexpected connection closed state %v", wasClosed) } @@ -185,10 +182,9 @@ func TestServerNameTLS(t *testing.T) { sni := "sni" tcpConn := TCPConnection{ Address: "example.com:123", - Logger: model.DiscardLogger, } f := &tlsHandshakeFunc{ - Pool: &ConnPool{}, + Rt: NewMinimalRuntime(model.DiscardLogger, time.Now()), ServerName: sni, } serverName := f.serverName(&tcpConn) @@ -201,10 +197,9 @@ func TestServerNameTLS(t *testing.T) { tcpConn := TCPConnection{ Address: "example.com:123", Domain: domain, - Logger: model.DiscardLogger, } f := &tlsHandshakeFunc{ - Pool: &ConnPool{}, + Rt: NewMinimalRuntime(model.DiscardLogger, time.Now()), } serverName := f.serverName(&tcpConn) if serverName != domain { @@ -215,10 +210,9 @@ func TestServerNameTLS(t *testing.T) { hostaddr := "example.com" tcpConn := TCPConnection{ Address: hostaddr + ":123", - Logger: model.DiscardLogger, } f := &tlsHandshakeFunc{ - Pool: &ConnPool{}, + Rt: NewMinimalRuntime(model.DiscardLogger, time.Now()), } serverName := f.serverName(&tcpConn) if serverName != hostaddr { @@ -229,10 +223,9 @@ func TestServerNameTLS(t *testing.T) { ip := "1.1.1.1" tcpConn := TCPConnection{ Address: ip, - Logger: model.DiscardLogger, } f := &tlsHandshakeFunc{ - Pool: &ConnPool{}, + Rt: NewMinimalRuntime(model.DiscardLogger, time.Now()), } serverName := f.serverName(&tcpConn) if serverName != "" { @@ -246,7 +239,7 @@ func TestHandshakerOrDefault(t *testing.T) { f := &tlsHandshakeFunc{ InsecureSkipVerify: false, NextProto: []string{}, - Pool: &ConnPool{}, + Rt: NewMinimalRuntime(model.DiscardLogger, time.Now()), RootCAs: &x509.CertPool{}, ServerName: "", handshaker: nil, diff --git a/internal/dslx/trace.go b/internal/dslx/trace.go new file mode 100644 index 0000000000..09094712ac --- /dev/null +++ b/internal/dslx/trace.go @@ -0,0 +1,71 @@ +package dslx + +import ( + "time" + + "github.com/ooni/probe-cli/v3/internal/model" +) + +// Trace collects [*Observations] using tracing. Specific implementations +// of this interface may be engineered to collect no [*Observations] for +// efficiency (i.e., when you don't care about collecting [*Observations] +// but you still want to use this package). +type Trace interface { + // CloneBytesReceivedMap returns a clone of the internal bytes received map. The key of the + // map is a string following the "EPNT_ADDRESS PROTO" pattern where the "EPNT_ADDRESS" contains + // the endpoint address and "PROTO" is "tcp" or "udp". + CloneBytesReceivedMap() (out map[string]int64) + + // DNSLookupsFromRoundTrip returns all the DNS lookup results collected so far. + DNSLookupsFromRoundTrip() (out []*model.ArchivalDNSLookupResult) + + // Index returns the unique index used by this trace. + Index() int64 + + // NewDialerWithoutResolver is equivalent to netxlite.NewDialerWithoutResolver + // except that it returns a model.Dialer that uses this trace. + // + // Caveat: the dialer wrappers are there to implement the + // model.MeasuringNetwork interface, but they're not used by this function. + NewDialerWithoutResolver(dl model.DebugLogger, wrappers ...model.DialerWrapper) model.Dialer + + // NewParallelUDPResolver returns a possibly-trace-ware parallel UDP resolver + NewParallelUDPResolver(logger model.DebugLogger, dialer model.Dialer, address string) model.Resolver + + // NewQUICDialerWithoutResolver is equivalent to + // netxlite.NewQUICDialerWithoutResolver except that it returns a + // model.QUICDialer that uses this trace. + // + // Caveat: the dialer wrappers are there to implement the + // model.MeasuringNetwork interface, but they're not used by this function. + NewQUICDialerWithoutResolver(listener model.UDPListener, + dl model.DebugLogger, wrappers ...model.QUICDialerWrapper) model.QUICDialer + + // NewTLSHandshakerStdlib is equivalent to netxlite.NewTLSHandshakerStdlib + // except that it returns a model.TLSHandshaker that uses this trace. + NewTLSHandshakerStdlib(dl model.DebugLogger) model.TLSHandshaker + + // NetworkEvents returns all the network events collected so far. + NetworkEvents() (out []*model.ArchivalNetworkEvent) + + // NewStdlibResolver returns a possibly-trace-ware system resolver. + NewStdlibResolver(logger model.DebugLogger) model.Resolver + + // QUICHandshakes collects all the QUIC handshake results collected so far. + QUICHandshakes() (out []*model.ArchivalTLSOrQUICHandshakeResult) + + // TCPConnects collects all the TCP connect results collected so far. + TCPConnects() (out []*model.ArchivalTCPConnectResult) + + // TLSHandshakes collects all the TLS handshake results collected so far. + TLSHandshakes() (out []*model.ArchivalTLSOrQUICHandshakeResult) + + // Tags returns the trace tags. + Tags() []string + + // TimeSince is equivalent to Trace.TimeNow().Sub(t0). + TimeSince(t0 time.Time) time.Duration + + // ZeroTime returns the "zero" time of this trace. + ZeroTime() time.Time +} diff --git a/internal/experiment/webconnectivitylte/cleartextflow.go b/internal/experiment/webconnectivitylte/cleartextflow.go index fb51486bf4..8ec7cc663c 100644 --- a/internal/experiment/webconnectivitylte/cleartextflow.go +++ b/internal/experiment/webconnectivitylte/cleartextflow.go @@ -242,9 +242,9 @@ func (t *CleartextFlow) newHTTPRequest(ctx context.Context) (*http.Request, erro func (t *CleartextFlow) httpTransaction(ctx context.Context, network, address, alpn string, txp model.HTTPTransport, req *http.Request, trace *measurexlite.Trace) (*http.Response, []byte, error) { const maxbody = 1 << 19 - started := trace.TimeSince(trace.ZeroTime) + started := trace.TimeSince(trace.ZeroTime()) t.TestKeys.AppendNetworkEvents(measurexlite.NewAnnotationArchivalNetworkEvent( - trace.Index, started, "http_transaction_start", + trace.Index(), started, "http_transaction_start", )) resp, err := txp.RoundTrip(req) var body []byte @@ -256,12 +256,12 @@ func (t *CleartextFlow) httpTransaction(ctx context.Context, network, address, a reader := io.LimitReader(resp.Body, maxbody) body, err = StreamAllContext(ctx, reader) } - finished := trace.TimeSince(trace.ZeroTime) + finished := trace.TimeSince(trace.ZeroTime()) t.TestKeys.AppendNetworkEvents(measurexlite.NewAnnotationArchivalNetworkEvent( - trace.Index, finished, "http_transaction_done", + trace.Index(), finished, "http_transaction_done", )) ev := measurexlite.NewArchivalHTTPRequestResult( - trace.Index, + trace.Index(), started, network, address, diff --git a/internal/experiment/webconnectivitylte/secureflow.go b/internal/experiment/webconnectivitylte/secureflow.go index 1f63c17434..6284eca649 100644 --- a/internal/experiment/webconnectivitylte/secureflow.go +++ b/internal/experiment/webconnectivitylte/secureflow.go @@ -297,9 +297,9 @@ func (t *SecureFlow) newHTTPRequest(ctx context.Context) (*http.Request, error) func (t *SecureFlow) httpTransaction(ctx context.Context, network, address, alpn string, txp model.HTTPTransport, req *http.Request, trace *measurexlite.Trace) (*http.Response, []byte, error) { const maxbody = 1 << 19 - started := trace.TimeSince(trace.ZeroTime) + started := trace.TimeSince(trace.ZeroTime()) t.TestKeys.AppendNetworkEvents(measurexlite.NewAnnotationArchivalNetworkEvent( - trace.Index, started, "http_transaction_start", + trace.Index(), started, "http_transaction_start", )) resp, err := txp.RoundTrip(req) var body []byte @@ -311,12 +311,12 @@ func (t *SecureFlow) httpTransaction(ctx context.Context, network, address, alpn reader := io.LimitReader(resp.Body, maxbody) body, err = StreamAllContext(ctx, reader) } - finished := trace.TimeSince(trace.ZeroTime) + finished := trace.TimeSince(trace.ZeroTime()) t.TestKeys.AppendNetworkEvents(measurexlite.NewAnnotationArchivalNetworkEvent( - trace.Index, finished, "http_transaction_done", + trace.Index(), finished, "http_transaction_done", )) ev := measurexlite.NewArchivalHTTPRequestResult( - trace.Index, + trace.Index(), started, network, address, diff --git a/internal/measurexlite/conn.go b/internal/measurexlite/conn.go index deb14293cf..54beade596 100644 --- a/internal/measurexlite/conn.go +++ b/internal/measurexlite/conn.go @@ -44,16 +44,16 @@ func (c *connTrace) Read(b []byte) (int, error) { // collect preliminary stats when the connection is surely active network := c.RemoteAddr().Network() addr := c.RemoteAddr().String() - started := c.tx.TimeSince(c.tx.ZeroTime) + started := c.tx.TimeSince(c.tx.ZeroTime()) // perform the underlying network operation count, err := c.Conn.Read(b) // emit the network event - finished := c.tx.TimeSince(c.tx.ZeroTime) + finished := c.tx.TimeSince(c.tx.ZeroTime()) select { case c.tx.networkEvent <- NewArchivalNetworkEvent( - c.tx.Index, started, netxlite.ReadOperation, network, addr, count, + c.tx.Index(), started, netxlite.ReadOperation, network, addr, count, err, finished, c.tx.tags...): default: // buffer is full } @@ -101,14 +101,14 @@ func (tx *Trace) CloneBytesReceivedMap() (out map[string]int64) { func (c *connTrace) Write(b []byte) (int, error) { network := c.RemoteAddr().Network() addr := c.RemoteAddr().String() - started := c.tx.TimeSince(c.tx.ZeroTime) + started := c.tx.TimeSince(c.tx.ZeroTime()) count, err := c.Conn.Write(b) - finished := c.tx.TimeSince(c.tx.ZeroTime) + finished := c.tx.TimeSince(c.tx.ZeroTime()) select { case c.tx.networkEvent <- NewArchivalNetworkEvent( - c.tx.Index, started, netxlite.WriteOperation, network, addr, count, + c.tx.Index(), started, netxlite.WriteOperation, network, addr, count, err, finished, c.tx.tags...): default: // buffer is full } @@ -143,17 +143,17 @@ type udpLikeConnTrace struct { // Read implements model.UDPLikeConn.ReadFrom and saves network events. func (c *udpLikeConnTrace) ReadFrom(b []byte) (int, net.Addr, error) { // record when we started measuring - started := c.tx.TimeSince(c.tx.ZeroTime) + started := c.tx.TimeSince(c.tx.ZeroTime()) // perform the network operation count, addr, err := c.UDPLikeConn.ReadFrom(b) // emit the network event - finished := c.tx.TimeSince(c.tx.ZeroTime) + finished := c.tx.TimeSince(c.tx.ZeroTime()) address := addrStringIfNotNil(addr) select { case c.tx.networkEvent <- NewArchivalNetworkEvent( - c.tx.Index, started, netxlite.ReadFromOperation, "udp", address, count, + c.tx.Index(), started, netxlite.ReadFromOperation, "udp", address, count, err, finished, c.tx.tags...): default: // buffer is full } @@ -176,15 +176,15 @@ func (tx *Trace) maybeUpdateBytesReceivedMapUDPLikeConn(addr net.Addr, count int // Write implements model.UDPLikeConn.WriteTo and saves network events. func (c *udpLikeConnTrace) WriteTo(b []byte, addr net.Addr) (int, error) { - started := c.tx.TimeSince(c.tx.ZeroTime) + started := c.tx.TimeSince(c.tx.ZeroTime()) address := addr.String() count, err := c.UDPLikeConn.WriteTo(b, addr) - finished := c.tx.TimeSince(c.tx.ZeroTime) + finished := c.tx.TimeSince(c.tx.ZeroTime()) select { case c.tx.networkEvent <- NewArchivalNetworkEvent( - c.tx.Index, started, netxlite.WriteToOperation, "udp", address, count, + c.tx.Index(), started, netxlite.WriteToOperation, "udp", address, count, err, finished, c.tx.tags...): default: // buffer is full } diff --git a/internal/measurexlite/dialer.go b/internal/measurexlite/dialer.go index e0837a6b66..8284ec2558 100644 --- a/internal/measurexlite/dialer.go +++ b/internal/measurexlite/dialer.go @@ -55,11 +55,11 @@ func (tx *Trace) OnConnectDone( // insert into the tcpConnect buffer select { case tx.tcpConnect <- NewArchivalTCPConnectResult( - tx.Index, - started.Sub(tx.ZeroTime), + tx.Index(), + started.Sub(tx.ZeroTime()), remoteAddr, err, - finished.Sub(tx.ZeroTime), + finished.Sub(tx.ZeroTime()), tx.tags..., ): default: // buffer is full @@ -69,14 +69,14 @@ func (tx *Trace) OnConnectDone( // see https://github.com/ooni/probe/issues/2254 select { case tx.networkEvent <- NewArchivalNetworkEvent( - tx.Index, - started.Sub(tx.ZeroTime), + tx.Index(), + started.Sub(tx.ZeroTime()), netxlite.ConnectOperation, "tcp", remoteAddr, 0, err, - finished.Sub(tx.ZeroTime), + finished.Sub(tx.ZeroTime()), tx.tags..., ): default: // buffer is full diff --git a/internal/measurexlite/dns.go b/internal/measurexlite/dns.go index 8c72544426..f450f33673 100644 --- a/internal/measurexlite/dns.go +++ b/internal/measurexlite/dns.go @@ -52,7 +52,7 @@ func (r *resolverTrace) CloseIdleConnections() { func (r *resolverTrace) emitResolveStart() { select { case r.tx.networkEvent <- NewAnnotationArchivalNetworkEvent( - r.tx.Index, r.tx.TimeSince(r.tx.ZeroTime), "resolve_start", + r.tx.Index(), r.tx.TimeSince(r.tx.ZeroTime()), "resolve_start", r.tx.tags..., ): default: // buffer is full @@ -63,7 +63,7 @@ func (r *resolverTrace) emitResolveStart() { func (r *resolverTrace) emiteResolveDone() { select { case r.tx.networkEvent <- NewAnnotationArchivalNetworkEvent( - r.tx.Index, r.tx.TimeSince(r.tx.ZeroTime), "resolve_done", + r.tx.Index(), r.tx.TimeSince(r.tx.ZeroTime()), "resolve_done", r.tx.tags..., ): default: // buffer is full @@ -109,12 +109,12 @@ func (tx *Trace) NewParallelDNSOverHTTPSResolver(logger model.DebugLogger, URL s // OnDNSRoundTripForLookupHost implements model.Trace.OnDNSRoundTripForLookupHost func (tx *Trace) OnDNSRoundTripForLookupHost(started time.Time, reso model.Resolver, query model.DNSQuery, response model.DNSResponse, addrs []string, err error, finished time.Time) { - t := finished.Sub(tx.ZeroTime) + t := finished.Sub(tx.ZeroTime()) select { case tx.dnsLookup <- NewArchivalDNSLookupResultFromRoundTrip( - tx.Index, - started.Sub(tx.ZeroTime), + tx.Index(), + started.Sub(tx.ZeroTime()), reso, query, response, @@ -274,12 +274,12 @@ var ErrDelayedDNSResponseBufferFull = errors.New( // OnDelayedDNSResponse implements model.Trace.OnDelayedDNSResponse func (tx *Trace) OnDelayedDNSResponse(started time.Time, txp model.DNSTransport, query model.DNSQuery, response model.DNSResponse, addrs []string, err error, finished time.Time) error { - t := finished.Sub(tx.ZeroTime) + t := finished.Sub(tx.ZeroTime()) select { case tx.delayedDNSResponse <- NewArchivalDNSLookupResultFromRoundTrip( - tx.Index, - started.Sub(tx.ZeroTime), + tx.Index(), + started.Sub(tx.ZeroTime()), txp, query, response, diff --git a/internal/measurexlite/dns_test.go b/internal/measurexlite/dns_test.go index aba71d1489..7b2cd48886 100644 --- a/internal/measurexlite/dns_test.go +++ b/internal/measurexlite/dns_test.go @@ -519,8 +519,8 @@ func TestDelayedDNSResponseWithTimeout(t *testing.T) { } for i := 0; i < events; i++ { // fill the trace - trace.delayedDNSResponse <- NewArchivalDNSLookupResultFromRoundTrip(trace.Index, started.Sub(trace.ZeroTime), - txp, query, dnsResponse, addrs, nil, finished.Sub(trace.ZeroTime)) + trace.delayedDNSResponse <- NewArchivalDNSLookupResultFromRoundTrip(trace.Index(), started.Sub(trace.ZeroTime()), + txp, query, dnsResponse, addrs, nil, finished.Sub(trace.ZeroTime())) } ctx, cancel := context.WithCancel(context.Background()) cancel() // we ensure that the context cancels before draining all the events @@ -566,8 +566,8 @@ func TestDelayedDNSResponseWithTimeout(t *testing.T) { return []byte{} }, } - trace.delayedDNSResponse <- NewArchivalDNSLookupResultFromRoundTrip(trace.Index, started.Sub(trace.ZeroTime), - txp, query, dnsResponse, addrs, nil, finished.Sub(trace.ZeroTime)) + trace.delayedDNSResponse <- NewArchivalDNSLookupResultFromRoundTrip(trace.Index(), started.Sub(trace.ZeroTime()), + txp, query, dnsResponse, addrs, nil, finished.Sub(trace.ZeroTime())) got := trace.DelayedDNSResponseWithTimeout(context.Background(), time.Second) if len(got) != 1 { t.Fatal("unexpected output from trace") diff --git a/internal/measurexlite/quic.go b/internal/measurexlite/quic.go index 949068f505..6df2149c57 100644 --- a/internal/measurexlite/quic.go +++ b/internal/measurexlite/quic.go @@ -49,10 +49,10 @@ func (qdx *quicDialerTrace) CloseIdleConnections() { // OnQUICHandshakeStart implements model.Trace.OnQUICHandshakeStart func (tx *Trace) OnQUICHandshakeStart(now time.Time, remoteAddr string, config *quic.Config) { - t := now.Sub(tx.ZeroTime) + t := now.Sub(tx.ZeroTime()) select { case tx.networkEvent <- NewAnnotationArchivalNetworkEvent( - tx.Index, t, "quic_handshake_start", tx.tags...): + tx.Index(), t, "quic_handshake_start", tx.tags...): default: } } @@ -60,7 +60,7 @@ func (tx *Trace) OnQUICHandshakeStart(now time.Time, remoteAddr string, config * // OnQUICHandshakeDone implements model.Trace.OnQUICHandshakeDone func (tx *Trace) OnQUICHandshakeDone(started time.Time, remoteAddr string, qconn quic.EarlyConnection, config *tls.Config, err error, finished time.Time) { - t := finished.Sub(tx.ZeroTime) + t := finished.Sub(tx.ZeroTime()) state := tls.ConnectionState{} if qconn != nil { @@ -69,8 +69,8 @@ func (tx *Trace) OnQUICHandshakeDone(started time.Time, remoteAddr string, qconn select { case tx.quicHandshake <- NewArchivalTLSOrQUICHandshakeResult( - tx.Index, - started.Sub(tx.ZeroTime), + tx.Index(), + started.Sub(tx.ZeroTime()), "udp", remoteAddr, config, @@ -84,7 +84,7 @@ func (tx *Trace) OnQUICHandshakeDone(started time.Time, remoteAddr string, qconn select { case tx.networkEvent <- NewAnnotationArchivalNetworkEvent( - tx.Index, t, "quic_handshake_done", tx.tags...): + tx.Index(), t, "quic_handshake_done", tx.tags...): default: // buffer is full } } diff --git a/internal/measurexlite/tls.go b/internal/measurexlite/tls.go index af0850c2e3..636bdae3aa 100644 --- a/internal/measurexlite/tls.go +++ b/internal/measurexlite/tls.go @@ -41,10 +41,10 @@ func (thx *tlsHandshakerTrace) Handshake( // OnTLSHandshakeStart implements model.Trace.OnTLSHandshakeStart. func (tx *Trace) OnTLSHandshakeStart(now time.Time, remoteAddr string, config *tls.Config) { - t := now.Sub(tx.ZeroTime) + t := now.Sub(tx.ZeroTime()) select { case tx.networkEvent <- NewAnnotationArchivalNetworkEvent( - tx.Index, t, "tls_handshake_start", tx.tags...): + tx.Index(), t, "tls_handshake_start", tx.tags...): default: // buffer is full } } @@ -52,12 +52,12 @@ func (tx *Trace) OnTLSHandshakeStart(now time.Time, remoteAddr string, config *t // OnTLSHandshakeDone implements model.Trace.OnTLSHandshakeDone. func (tx *Trace) OnTLSHandshakeDone(started time.Time, remoteAddr string, config *tls.Config, state tls.ConnectionState, err error, finished time.Time) { - t := finished.Sub(tx.ZeroTime) + t := finished.Sub(tx.ZeroTime()) select { case tx.tlsHandshake <- NewArchivalTLSOrQUICHandshakeResult( - tx.Index, - started.Sub(tx.ZeroTime), + tx.Index(), + started.Sub(tx.ZeroTime()), "tcp", remoteAddr, config, @@ -71,7 +71,7 @@ func (tx *Trace) OnTLSHandshakeDone(started time.Time, remoteAddr string, config select { case tx.networkEvent <- NewAnnotationArchivalNetworkEvent( - tx.Index, t, "tls_handshake_done", tx.tags...): + tx.Index(), t, "tls_handshake_done", tx.tags...): default: // buffer is full } } diff --git a/internal/measurexlite/trace.go b/internal/measurexlite/trace.go index 7a7e79a6f8..91187fc89c 100644 --- a/internal/measurexlite/trace.go +++ b/internal/measurexlite/trace.go @@ -25,10 +25,8 @@ import ( // // [step-by-step measurements]: https://github.com/ooni/probe-cli/blob/master/docs/design/dd-003-step-by-step.md type Trace struct { - // Index is the unique index of this trace within the - // current measurement. Note that this field MUST be read-only. Writing it - // once you have constructed a trace MAY lead to data races. - Index int64 + // index is the unique index of this trace within the current measurement. + index int64 // Netx is the network to use for measuring. The constructor inits this // field using a [*netxlite.Netx]. You MAY override this field for testing. Make @@ -69,10 +67,8 @@ type Trace struct { // to produce deterministic timing when testing. timeNowFn func() time.Time - // ZeroTime is the time when we started the current measurement. This field - // MUST be read-only. Writing it once you have constructed the trace will - // likely read to data races. - ZeroTime time.Time + // zeroTime is the time when we started the current measurement. + zeroTime time.Time } var _ model.MeasuringNetwork = &Trace{} @@ -111,7 +107,7 @@ const QUICHandshakeBufferSize = 8 // to identify that some traces belong to some submeasurements). func NewTrace(index int64, zeroTime time.Time, tags ...string) *Trace { return &Trace{ - Index: index, + index: index, Netx: &netxlite.Netx{Underlying: nil}, // use the host network bytesReceivedMap: make(map[string]int64), bytesReceivedMu: &sync.Mutex{}, @@ -141,10 +137,20 @@ func NewTrace(index int64, zeroTime time.Time, tags ...string) *Trace { ), tags: tags, timeNowFn: nil, // use default - ZeroTime: zeroTime, + zeroTime: zeroTime, } } +// Index returns the trace index. +func (tx *Trace) Index() int64 { + return tx.index +} + +// ZeroTime returns trace's zero time. +func (tx *Trace) ZeroTime() time.Time { + return tx.zeroTime +} + // TimeNow implements model.Trace.TimeNow. func (tx *Trace) TimeNow() time.Time { if tx.timeNowFn != nil { diff --git a/internal/measurexlite/trace_test.go b/internal/measurexlite/trace_test.go index 111a5274ea..686d6540d7 100644 --- a/internal/measurexlite/trace_test.go +++ b/internal/measurexlite/trace_test.go @@ -25,7 +25,7 @@ func TestNewTrace(t *testing.T) { trace := NewTrace(index, zeroTime) t.Run("Index", func(t *testing.T) { - if trace.Index != index { + if trace.Index() != index { t.Fatal("invalid index") } }) @@ -164,7 +164,7 @@ func TestNewTrace(t *testing.T) { }) t.Run("ZeroTime", func(t *testing.T) { - if !trace.ZeroTime.Equal(zeroTime) { + if !trace.ZeroTime().Equal(zeroTime) { t.Fatal("invalid zero time") } }) diff --git a/internal/throttling/throttling.go b/internal/throttling/throttling.go index 46252ed33f..a9de7e3800 100644 --- a/internal/throttling/throttling.go +++ b/internal/throttling/throttling.go @@ -7,13 +7,32 @@ import ( "sync" "time" - "github.com/ooni/probe-cli/v3/internal/measurexlite" "github.com/ooni/probe-cli/v3/internal/memoryless" "github.com/ooni/probe-cli/v3/internal/model" "github.com/ooni/probe-cli/v3/internal/runtimex" ) -// Sampler periodically samples the bytes sent and received by a [*measurexlite.Trace]. The zero +// Trace is the [*measurexlite.Trace] abstraction used by this package. +type Trace interface { + // CloneBytesReceivedMap returns a clone of the internal bytes received map. The key of the + // map is a string following the "EPNT_ADDRESS PROTO" pattern where the "EPNT_ADDRESS" contains + // the endpoint address and "PROTO" is "tcp" or "udp". + CloneBytesReceivedMap() (out map[string]int64) + + // Index returns the unique index used by this trace. + Index() int64 + + // Tags returns the trace tags. + Tags() []string + + // TimeSince is equivalent to Trace.TimeNow().Sub(t0). + TimeSince(t0 time.Time) time.Duration + + // ZeroTime returns the "zero" time of this trace. + ZeroTime() time.Time +} + +// Sampler periodically samples the bytes sent and received by a [Trace]. The zero // value of this structure is invalid; please, construct using [NewSampler]. type Sampler struct { // cancel tells the background goroutine to stop @@ -29,16 +48,16 @@ type Sampler struct { q []*model.ArchivalNetworkEvent // tx is the trace we are sampling from - tx *measurexlite.Trace + tx Trace // wg is the waitgroup to wait for the sampler to join wg *sync.WaitGroup } -// NewSampler attaches a [*Sampler] to a [*measurexlite.Trace], starts sampling in the +// NewSampler attaches a [*Sampler] to a [Trace], starts sampling in the // background and returns the [*Sampler]. Remember to call [*Sampler.Close] to stop // the background goroutine that performs the sampling. -func NewSampler(tx *measurexlite.Trace) *Sampler { +func NewSampler(tx Trace) *Sampler { ctx, cancel := context.WithCancel(context.Background()) smpl := &Sampler{ cancel: cancel, @@ -95,7 +114,7 @@ const BytesReceivedCumulativeOperation = "bytes_received_cumulative" func (smpl *Sampler) collectSnapshot(stats map[string]int64) { // compute just once the events sampling time - now := smpl.tx.TimeSince(smpl.tx.ZeroTime).Seconds() + now := smpl.tx.TimeSince(smpl.tx.ZeroTime()).Seconds() // process each entry for key, count := range stats { @@ -116,7 +135,7 @@ func (smpl *Sampler) collectSnapshot(stats map[string]int64) { Proto: network, T0: now, T: now, - TransactionID: smpl.tx.Index, + TransactionID: smpl.tx.Index(), Tags: smpl.tx.Tags(), } diff --git a/internal/tutorial/dslx/chapter02/README.md b/internal/tutorial/dslx/chapter02/README.md index 5722218143..359344e4c4 100644 --- a/internal/tutorial/dslx/chapter02/README.md +++ b/internal/tutorial/dslx/chapter02/README.md @@ -43,7 +43,6 @@ import ( "context" "errors" "net" - "sync/atomic" "github.com/ooni/probe-cli/v3/internal/dslx" "github.com/ooni/probe-cli/v3/internal/model" @@ -133,7 +132,6 @@ of dslx pipelines a unique identifier). ```Go type Measurer struct { config Config - idGen atomic.Int64 } var _ model.ExperimentMeasurer = &Measurer{} @@ -178,15 +176,6 @@ So, this is where we will use `dslx` to implement the SNI blocking experiment. func (m *Measurer) Run(ctx context.Context, args *model.ExperimentArgs) error { ``` -### Define measurement parameters - -`sess` is the session of this measurement run. - -```Go - sess := args.Session - -``` - `measurement` contains metadata, the (required) input in form of the target SNI, and the nettest results (`TestKeys`). @@ -249,18 +238,24 @@ experiment's start time. ```Go dnsInput := dslx.NewDomainToResolve( dslx.DomainName(thaddrHost), - dslx.DNSLookupOptionIDGenerator(&m.idGen), - dslx.DNSLookupOptionLogger(sess.Logger()), - dslx.DNSLookupOptionZeroTime(measurement.MeasurementStartTimeSaved), ) ``` +Next, we create a minimal runtime. This data structure helps us to manage +open connections and close them when `rt.Close` is invoked. + +```Go + rt := dslx.NewMinimalRuntime(args.Session.Logger(), args.Measurement.MeasurementStartTimeSaved) + defer rt.Close() + +``` + We construct the resolver dslx function which can be - like in this case - the system resolver, or a custom UDP resolver. ```Go - lookupFn := dslx.DNSLookupGetaddrinfo() + lookupFn := dslx.DNSLookupGetaddrinfo(rt) ``` @@ -322,24 +317,12 @@ the protocol, address, and port three-tuple.) dslx.EndpointNetwork("tcp"), dslx.EndpointPort(443), dslx.EndpointOptionDomain(m.config.TestHelperAddress), - dslx.EndpointOptionIDGenerator(&m.idGen), - dslx.EndpointOptionLogger(sess.Logger()), - dslx.EndpointOptionZeroTime(measurement.MeasurementStartTimeSaved), ) runtimex.Assert(len(endpoints) >= 1, "expected at least one endpoint here") endpoint := endpoints[0] ``` -Next, we create a connection pool. This data structure helps us to manage -open connections and close them when `connpool.Close` is invoked. - -```Go - connpool := &dslx.ConnPool{} - defer connpool.Close() - -``` - In the following we compose step-by-step measurement "pipelines", represented by `dslx` functions. @@ -350,9 +333,9 @@ target SNI to be used within the TLS Client Hello. ```Go pipelineTarget := dslx.Compose2( - dslx.TCPConnect(connpool), + dslx.TCPConnect(rt), dslx.TLSHandshake( - connpool, + rt, dslx.TLSHandshakeOptionServerName(targetSNI), ), ) @@ -364,9 +347,9 @@ specify the *control* SNI to be used within the TLS Client Hello. ```Go pipelineControl := dslx.Compose2( - dslx.TCPConnect(connpool), + dslx.TCPConnect(rt), dslx.TLSHandshake( - connpool, + rt, dslx.TLSHandshakeOptionServerName(m.config.ControlSNI), ), ) diff --git a/internal/tutorial/dslx/chapter02/main.go b/internal/tutorial/dslx/chapter02/main.go index e956c0cacf..ef0e3c62c4 100644 --- a/internal/tutorial/dslx/chapter02/main.go +++ b/internal/tutorial/dslx/chapter02/main.go @@ -44,7 +44,6 @@ import ( "context" "errors" "net" - "sync/atomic" "github.com/ooni/probe-cli/v3/internal/dslx" "github.com/ooni/probe-cli/v3/internal/model" @@ -134,7 +133,6 @@ func (tk *Subresult) mergeObservations(obs []*dslx.Observations) { // ```Go type Measurer struct { config Config - idGen atomic.Int64 } var _ model.ExperimentMeasurer = &Measurer{} @@ -177,15 +175,6 @@ func (m *Measurer) GetSummaryKeys(measurement *model.Measurement) (interface{}, // // ```Go func (m *Measurer) Run(ctx context.Context, args *model.ExperimentArgs) error { - // ``` - // - // ### Define measurement parameters - // - // `sess` is the session of this measurement run. - // - // ```Go - sess := args.Session - // ``` // // `measurement` contains metadata, the (required) input in form of @@ -250,18 +239,24 @@ func (m *Measurer) Run(ctx context.Context, args *model.ExperimentArgs) error { // ```Go dnsInput := dslx.NewDomainToResolve( dslx.DomainName(thaddrHost), - dslx.DNSLookupOptionIDGenerator(&m.idGen), - dslx.DNSLookupOptionLogger(sess.Logger()), - dslx.DNSLookupOptionZeroTime(measurement.MeasurementStartTimeSaved), ) + // ``` + // + // Next, we create a minimal runtime. This data structure helps us to manage + // open connections and close them when `rt.Close` is invoked. + // + // ```Go + rt := dslx.NewMinimalRuntime(args.Session.Logger(), args.Measurement.MeasurementStartTimeSaved) + defer rt.Close() + // ``` // // We construct the resolver dslx function which can be - like in this case - the // system resolver, or a custom UDP resolver. // // ```Go - lookupFn := dslx.DNSLookupGetaddrinfo() + lookupFn := dslx.DNSLookupGetaddrinfo(rt) // ``` // @@ -323,22 +318,10 @@ func (m *Measurer) Run(ctx context.Context, args *model.ExperimentArgs) error { dslx.EndpointNetwork("tcp"), dslx.EndpointPort(443), dslx.EndpointOptionDomain(m.config.TestHelperAddress), - dslx.EndpointOptionIDGenerator(&m.idGen), - dslx.EndpointOptionLogger(sess.Logger()), - dslx.EndpointOptionZeroTime(measurement.MeasurementStartTimeSaved), ) runtimex.Assert(len(endpoints) >= 1, "expected at least one endpoint here") endpoint := endpoints[0] - // ``` - // - // Next, we create a connection pool. This data structure helps us to manage - // open connections and close them when `connpool.Close` is invoked. - // - // ```Go - connpool := &dslx.ConnPool{} - defer connpool.Close() - // ``` // // In the following we compose step-by-step measurement "pipelines", @@ -351,9 +334,9 @@ func (m *Measurer) Run(ctx context.Context, args *model.ExperimentArgs) error { // // ```Go pipelineTarget := dslx.Compose2( - dslx.TCPConnect(connpool), + dslx.TCPConnect(rt), dslx.TLSHandshake( - connpool, + rt, dslx.TLSHandshakeOptionServerName(targetSNI), ), ) @@ -365,9 +348,9 @@ func (m *Measurer) Run(ctx context.Context, args *model.ExperimentArgs) error { // // ```Go pipelineControl := dslx.Compose2( - dslx.TCPConnect(connpool), + dslx.TCPConnect(rt), dslx.TLSHandshake( - connpool, + rt, dslx.TLSHandshakeOptionServerName(m.config.ControlSNI), ), )